Files
kernbench2/tests/test_tensor.py
T
2026-03-18 11:47:48 -07:00

283 lines
8.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import pytest
from kernbench.policy.address.allocator import AddressConfig, AllocationError, PEMemAllocator
from kernbench.policy.placement.dp import (
ShardSpec,
column_wise,
tiled_column_major,
replicate,
row_wise,
tiled_row_major,
)
from kernbench.runtime_api.kernel import (
KernelLaunchMsg,
KernelRef,
MemoryReadMsg,
MemoryWriteMsg,
ScalarArg,
TensorArg,
TensorArgShard,
)
from kernbench.runtime_api.tensor import (
TensorHandle,
TensorShard,
deploy_tensor,
dtype_itemsize,
)
_MB = 1 << 20
_GB = 1 << 30
_CFG = AddressConfig(
sip_count=2,
cubes_per_sip=16,
pes_per_cube=8,
hbm_bytes_per_cube=48 * _GB,
hbm_slices_per_cube=8,
tcm_bytes_per_pe=16 * _MB,
tcm_scheduler_reserved_bytes=4 * _MB,
sram_bytes_per_cube=32 * _MB,
)
def _make_allocators(num_pe: int = 8) -> dict[int, PEMemAllocator]:
return {
i: PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=i, cfg=_CFG)
for i in range(num_pe)
}
# ── Tensor types ─────────────────────────────────────────────────────
def test_tensor_shard_immutable():
ts = TensorShard(sip=0, cube=0, pe=0, pa=0x1000, nbytes=4096, offset_bytes=0)
with pytest.raises(AttributeError):
ts.pa = 0x2000 # type: ignore[misc]
# hashable
{ts}
def test_tensor_handle_nbytes():
th = TensorHandle(
name="A",
shape=(1024, 512),
dtype="fp16",
itemsize=2,
shards=(),
)
assert th.nbytes == 1024 * 512 * 2 # 1 MB
# ── Message types (ADR-0012) ─────────────────────────────────────────
def test_memory_write_msg_fields():
msg = MemoryWriteMsg(
correlation_id="c0",
request_id="r0",
dst_sip=0,
dst_cube=3,
dst_pe=5,
dst_pa=0xDEAD,
nbytes=4096,
pattern="zero",
)
assert msg.msg_type == "memory_write"
assert msg.src_kind == "pattern"
assert msg.dst_pa == 0xDEAD
assert msg.pattern == "zero"
with pytest.raises(AttributeError):
msg.nbytes = 0 # type: ignore[misc]
def test_memory_read_msg_fields():
msg = MemoryReadMsg(
correlation_id="c0",
request_id="r1",
src_sip=1,
src_cube=2,
src_pe=7,
src_pa=0xBEEF,
nbytes=2048,
)
assert msg.msg_type == "memory_read"
assert msg.src_pa == 0xBEEF
assert msg.nbytes == 2048
def test_kernel_launch_msg_fields():
shard = TensorArgShard(sip=0, cube=0, pe=0, pa=0x100, nbytes=1024, offset_bytes=0)
targ = TensorArg(shards=(shard,))
sarg = ScalarArg(dtype="fp32", value=1.0)
kref = KernelRef(name="gemm", kind="builtin")
msg = KernelLaunchMsg(
correlation_id="c0",
request_id="r2",
kernel_ref=kref,
args=(targ, sarg),
)
assert msg.msg_type == "kernel_launch"
assert msg.kernel_ref.name == "gemm"
assert len(msg.args) == 2
assert msg.args[0].arg_kind == "tensor"
assert msg.args[1].arg_kind == "scalar"
# ── Placement: column_wise ───────────────────────────────────────────
def test_column_wise_placement():
"""(1024, 512) fp16 across 8 PEs → K axis split → 8 shards, each (1024, 64) = 128KB"""
shards = column_wise(shape=(1024, 512), itemsize=2, num_pe=8)
assert len(shards) == 8
expected_nbytes = 1024 * 64 * 2 # 128 KB
for i, s in enumerate(shards):
assert s.pe_index == i
assert s.nbytes == expected_nbytes
# offsets are contiguous
assert shards[0].offset_bytes == 0
assert shards[1].offset_bytes == expected_nbytes
# total coverage
assert sum(s.nbytes for s in shards) == 1024 * 512 * 2
# ── Placement: row_wise ──────────────────────────────────────────────
def test_row_wise_placement():
"""(1024, 512) fp16 across 8 PEs → M axis split → 8 shards, each (128, 512) = 128KB"""
shards = row_wise(shape=(1024, 512), itemsize=2, num_pe=8)
assert len(shards) == 8
expected_nbytes = 128 * 512 * 2 # 128 KB
for i, s in enumerate(shards):
assert s.pe_index == i
assert s.nbytes == expected_nbytes
assert shards[0].offset_bytes == 0
assert sum(s.nbytes for s in shards) == 1024 * 512 * 2
# ── Placement: replicate ─────────────────────────────────────────────
def test_replicate_placement():
"""(1024, 512) fp16 across 8 PEs → each PE gets full copy = 1MB"""
shards = replicate(shape=(1024, 512), itemsize=2, num_pe=8)
assert len(shards) == 8
full_nbytes = 1024 * 512 * 2 # 1 MB
for i, s in enumerate(shards):
assert s.pe_index == i
assert s.nbytes == full_nbytes
assert s.offset_bytes == 0 # each is a full copy
# ── Placement: tiled_column_major ─────────────────────────────────────
def test_tiled_column_major():
"""(1024, 512) tile=(256, 128) → 4×4=16 tiles, column-major → round-robin 8 PEs"""
shards = tiled_column_major(
shape=(1024, 512), itemsize=2, num_pe=8, tile_m=256, tile_k=128,
)
# 4 tiles along M, 4 tiles along K → 16 tiles total
assert len(shards) == 16
tile_bytes = 256 * 128 * 2 # 64 KB per tile
for s in shards:
assert s.nbytes == tile_bytes
# column-major: iterate K first, then M
# tile (m=0,k=0) → PE0, tile (m=0,k=1) → PE1, ..., (m=0,k=3) → PE3
# tile (m=1,k=0) → PE4, tile (m=1,k=1) → PE5, ..., (m=1,k=3) → PE7
# tile (m=2,k=0) → PE0, ...
assert shards[0].pe_index == 0
assert shards[1].pe_index == 1
assert shards[7].pe_index == 7
assert shards[8].pe_index == 0 # wraps around
# total coverage
assert sum(s.nbytes for s in shards) == 1024 * 512 * 2
# ── Placement: tiled_row_major ────────────────────────────────────────
def test_tiled_row_major():
"""(1024, 512) tile=(256, 128) → 4×4=16 tiles, row-major → round-robin 8 PEs"""
shards = tiled_row_major(
shape=(1024, 512), itemsize=2, num_pe=8, tile_m=256, tile_k=128,
)
assert len(shards) == 16
tile_bytes = 256 * 128 * 2
for s in shards:
assert s.nbytes == tile_bytes
# row-major: iterate M first, then K
# tile (m=0,k=0) → PE0, tile (m=1,k=0) → PE1, ..., (m=3,k=0) → PE3
# tile (m=0,k=1) → PE4, tile (m=1,k=1) → PE5, ..., (m=3,k=1) → PE7
# tile (m=0,k=2) → PE0, ...
assert shards[0].pe_index == 0
assert shards[1].pe_index == 1
assert shards[7].pe_index == 7
assert shards[8].pe_index == 0 # wraps around
# total coverage
assert sum(s.nbytes for s in shards) == 1024 * 512 * 2
# ── deploy_tensor ────────────────────────────────────────────────────
def test_deploy_tensor_hbm():
"""Deploy with column_wise placement → TensorHandle with valid PA shards."""
allocs = _make_allocators()
placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8)
th = deploy_tensor(
name="W",
shape=(1024, 512),
dtype="fp16",
placement=placement,
allocators=allocs,
mem_kind="hbm",
)
assert th.name == "W"
assert th.shape == (1024, 512)
assert th.dtype == "fp16"
assert th.itemsize == 2
assert len(th.shards) == 8
# each shard has a distinct PA
pas = [s.pa for s in th.shards]
assert len(set(pas)) == 8
# each shard placed on correct PE
for i, s in enumerate(th.shards):
assert s.pe == i
assert s.sip == 0
assert s.cube == 0
def test_deploy_tensor_tcm():
"""Deploy with TCM → uses pe_tcm_addr allocation."""
allocs = _make_allocators()
placement = [ShardSpec(pe_index=0, offset_bytes=0, nbytes=256)]
th = deploy_tensor(
name="small",
shape=(128,),
dtype="fp16",
placement=placement,
allocators=allocs,
mem_kind="tcm",
)
assert len(th.shards) == 1
assert th.shards[0].pe == 0
assert th.shards[0].nbytes == 256
def test_deploy_tensor_overflow():
"""Allocation exceeding PE HBM capacity raises AllocationError."""
allocs = _make_allocators()
# 6 GB per PE slice, try to allocate 7 GB
big_shard = ShardSpec(pe_index=0, offset_bytes=0, nbytes=7 * _GB)
with pytest.raises(AllocationError):
deploy_tensor(
name="toobig",
shape=(1,),
dtype="int8",
placement=[big_shard],
allocators=allocs,
)