"""Tests for kernbench tensor types, runtime messages, DP placement and deploy_tensor."""

import pytest

from kernbench.policy.address.allocator import (
    AddressConfig,
    AllocationError,
    PEMemAllocator,
)
from kernbench.policy.placement.dp import (
    DPPolicy,
    ShardSpec,
    column_wise,
    replicate,
    resolve_dp_policy,
    row_wise,
    tiled_column_major,
    tiled_row_major,
)
from kernbench.runtime_api.kernel import (
    KernelLaunchMsg,
    KernelRef,
    MemoryReadMsg,
    MemoryWriteMsg,
    ScalarArg,
    TensorArg,
    TensorArgShard,
)
from kernbench.runtime_api.tensor import (
    TensorHandle,
    TensorShard,
    deploy_tensor,
    dtype_itemsize,
)

_MB = 1 << 20
_GB = 1 << 30

# Test address-space configuration. The overflow test at the bottom relies on
# a 6 GB per-PE HBM slice (48 GB per cube across 8 PEs / 8 slices).
_CFG = AddressConfig(
    sip_count=2,
    cubes_per_sip=16,
    pes_per_cube=8,
    hbm_bytes_per_cube=48 * _GB,
    hbm_slices_per_cube=8,
    tcm_bytes_per_pe=16 * _MB,
    tcm_scheduler_reserved_bytes=4 * _MB,
    sram_bytes_per_cube=32 * _MB,
)


def _make_allocators(num_pe: int = 8) -> dict[tuple[int, int, int], PEMemAllocator]:
    """Return one PEMemAllocator per PE on sip 0 / cube 0, keyed by (sip, cube, pe).

    Args:
        num_pe: Number of PEs (pe ids ``0..num_pe-1``) to create allocators for.
    """
    return {
        (0, 0, i): PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=i, cfg=_CFG)
        for i in range(num_pe)
    }


# ── Tensor types ─────────────────────────────────────────────────────


def test_tensor_shard_immutable():
    """TensorShard rejects attribute assignment and is hashable."""
    ts = TensorShard(sip=0, cube=0, pe=0, pa=0x1000, nbytes=4096, offset_bytes=0)
    with pytest.raises(AttributeError):
        ts.pa = 0x2000  # type: ignore[misc]
    # Set membership requires __hash__; an explicit assert replaces the former
    # bare ``{ts}`` expression statement.
    assert ts in {ts}


def test_tensor_handle_nbytes():
    """TensorHandle.nbytes is the product of the shape times the itemsize."""
    th = TensorHandle(
        name="A",
        shape=(1024, 512),
        dtype="fp16",
        itemsize=2,
        shards=(),
    )
    assert th.nbytes == 1024 * 512 * 2  # 1 MB


# ── Message types (ADR-0012) ─────────────────────────────────────────


def test_memory_write_msg_fields():
    """MemoryWriteMsg exposes its fields, derives type/src_kind, and is immutable."""
    msg = MemoryWriteMsg(
        correlation_id="c0",
        request_id="r0",
        dst_sip=0,
        dst_cube=3,
        dst_pe=5,
        dst_pa=0xDEAD,
        nbytes=4096,
        pattern="zero",
    )
    assert msg.msg_type == "memory_write"
    assert msg.src_kind == "pattern"
    assert msg.dst_pa == 0xDEAD
    assert msg.pattern == "zero"
    with pytest.raises(AttributeError):
        msg.nbytes = 0  # type: ignore[misc]


def test_memory_read_msg_fields():
    """MemoryReadMsg exposes its fields and reports msg_type 'memory_read'."""
    msg = MemoryReadMsg(
        correlation_id="c0",
        request_id="r1",
        src_sip=1,
        src_cube=2,
        src_pe=7,
        src_pa=0xBEEF,
        nbytes=2048,
    )
    assert msg.msg_type == "memory_read"
    assert msg.src_pa == 0xBEEF
    assert msg.nbytes == 2048


def test_kernel_launch_msg_fields():
    """KernelLaunchMsg keeps its kernel ref and tags each arg with its kind."""
    shard = TensorArgShard(sip=0, cube=0, pe=0, pa=0x100, nbytes=1024, offset_bytes=0)
    targ = TensorArg(shards=(shard,))
    sarg = ScalarArg(dtype="fp32", value=1.0)
    kref = KernelRef(name="gemm", kind="builtin")
    msg = KernelLaunchMsg(
        correlation_id="c0",
        request_id="r2",
        kernel_ref=kref,
        args=(targ, sarg),
    )
    assert msg.msg_type == "kernel_launch"
    assert msg.kernel_ref.name == "gemm"
    assert len(msg.args) == 2
    assert msg.args[0].arg_kind == "tensor"
    assert msg.args[1].arg_kind == "scalar"


# ── Placement: column_wise ───────────────────────────────────────────


def test_column_wise_placement():
    """(1024, 512) fp16 across 8 PEs → K axis split → 8 shards, each (1024, 64) = 128KB"""
    shards = column_wise(shape=(1024, 512), itemsize=2, num_pe=8)
    assert len(shards) == 8
    expected_nbytes = 1024 * 64 * 2  # 128 KB
    for i, s in enumerate(shards):
        assert s.local_pe == i
        assert s.nbytes == expected_nbytes
        # offsets are contiguous: every shard starts where the previous one ends
        assert s.offset_bytes == i * expected_nbytes
    # total coverage
    assert sum(s.nbytes for s in shards) == 1024 * 512 * 2


# ── Placement: row_wise ──────────────────────────────────────────────


def test_row_wise_placement():
    """(1024, 512) fp16 across 8 PEs → M axis split → 8 shards, each (128, 512) = 128KB"""
    shards = row_wise(shape=(1024, 512), itemsize=2, num_pe=8)
    assert len(shards) == 8
    expected_nbytes = 128 * 512 * 2  # 128 KB
    for i, s in enumerate(shards):
        assert s.local_pe == i
        assert s.nbytes == expected_nbytes
        # equal-sized row blocks laid out back to back
        assert s.offset_bytes == i * expected_nbytes
    # total coverage
    assert sum(s.nbytes for s in shards) == 1024 * 512 * 2


# ── Placement: replicate ─────────────────────────────────────────────


def test_replicate_placement():
    """(1024, 512) fp16 across 8 PEs → each PE gets full copy = 1MB"""
    shards = replicate(shape=(1024, 512), itemsize=2, num_pe=8)
    assert len(shards) == 8
    full_nbytes = 1024 * 512 * 2  # 1 MB
    for i, s in enumerate(shards):
        assert s.local_pe == i
        assert s.nbytes == full_nbytes
        assert s.offset_bytes == 0  # each is a full copy


# ── Placement: tiled_column_major ─────────────────────────────────────


def test_tiled_column_major():
    """(1024, 512) tile=(256, 128) → 4×4=16 tiles, column-major → round-robin 8 PEs"""
    shards = tiled_column_major(
        shape=(1024, 512),
        itemsize=2,
        num_pe=8,
        tile_m=256,
        tile_k=128,
    )
    # 4 tiles along M, 4 tiles along K → 16 tiles total
    assert len(shards) == 16
    tile_bytes = 256 * 128 * 2  # 64 KB per tile
    for s in shards:
        assert s.nbytes == tile_bytes
    # column-major: iterate K first, then M
    # tile (m=0,k=0) → PE0, tile (m=0,k=1) → PE1, ..., (m=0,k=3) → PE3
    # tile (m=1,k=0) → PE4, tile (m=1,k=1) → PE5, ..., (m=1,k=3) → PE7
    # tile (m=2,k=0) → PE0, ...
    # NOTE(review): these examples show k varying fastest, which is the usual
    # meaning of *row*-major, and the PE assertions below hold for either
    # traversal order. Confirm which convention tiled_column_major implements
    # and add an assertion that tells the two layouts apart (e.g. on tile
    # offsets).
    # Round-robin over all 16 tiles, wrapping at PE 8.
    assert [s.local_pe for s in shards] == [i % 8 for i in range(16)]
    # total coverage
    assert sum(s.nbytes for s in shards) == 1024 * 512 * 2


# ── Placement: tiled_row_major ────────────────────────────────────────


def test_tiled_row_major():
    """(1024, 512) tile=(256, 128) → 4×4=16 tiles, row-major → round-robin 8 PEs"""
    shards = tiled_row_major(
        shape=(1024, 512),
        itemsize=2,
        num_pe=8,
        tile_m=256,
        tile_k=128,
    )
    assert len(shards) == 16
    tile_bytes = 256 * 128 * 2
    for s in shards:
        assert s.nbytes == tile_bytes
    # row-major: iterate M first, then K
    # tile (m=0,k=0) → PE0, tile (m=1,k=0) → PE1, ..., (m=3,k=0) → PE3
    # tile (m=0,k=1) → PE4, tile (m=1,k=1) → PE5, ..., (m=3,k=1) → PE7
    # tile (m=0,k=2) → PE0, ...
    # NOTE(review): these examples show m varying fastest, which is the usual
    # meaning of *column*-major, and the PE assertions below are identical to
    # test_tiled_column_major's. Confirm the convention and add an
    # order-sensitive assertion.
    # Round-robin over all 16 tiles, wrapping at PE 8.
    assert [s.local_pe for s in shards] == [i % 8 for i in range(16)]
    # total coverage
    assert sum(s.nbytes for s in shards) == 1024 * 512 * 2


# ── deploy_tensor ────────────────────────────────────────────────────


def test_deploy_tensor_hbm():
    """Deploy with column_wise placement → TensorHandle with valid PA shards."""
    allocs = _make_allocators()
    placement = resolve_dp_policy(
        DPPolicy(cube="replicate", pe="column_wise"),
        shape=(1024, 512),
        itemsize=2,
        num_pe=8,
        num_cubes=1,
        target_sip=0,
    )
    th = deploy_tensor(
        name="W",
        shape=(1024, 512),
        dtype="fp16",
        placement=placement,
        allocators=allocs,
        mem_kind="hbm",
    )
    assert th.name == "W"
    assert th.shape == (1024, 512)
    assert th.dtype == "fp16"
    assert th.itemsize == 2
    assert len(th.shards) == 8
    # each shard has a distinct PA
    pas = [s.pa for s in th.shards]
    assert len(set(pas)) == 8
    # each shard placed on correct PE
    for i, s in enumerate(th.shards):
        assert s.pe == i
        assert s.sip == 0
        assert s.cube == 0
    # A single-cube column-wise split covers the whole tensor exactly once.
    assert sum(s.nbytes for s in th.shards) == th.nbytes


def test_deploy_tensor_tcm():
    """Deploy with TCM → uses pe_tcm_addr allocation."""
    allocs = _make_allocators()
    placement = [ShardSpec(sip=0, cube=0, pe=0, offset_bytes=0, nbytes=256)]
    th = deploy_tensor(
        name="small",
        shape=(128,),
        dtype="fp16",
        placement=placement,
        allocators=allocs,
        mem_kind="tcm",
    )
    assert len(th.shards) == 1
    assert th.shards[0].pe == 0
    assert th.shards[0].nbytes == 256


def test_deploy_tensor_overflow():
    """Allocation exceeding PE HBM capacity raises AllocationError."""
    allocs = _make_allocators()
    # 6 GB per PE slice, try to allocate 7 GB
    big_shard = ShardSpec(sip=0, cube=0, pe=0, offset_bytes=0, nbytes=7 * _GB)
    with pytest.raises(AllocationError):
        # NOTE(review): mem_kind is omitted here, presumably defaulting to
        # "hbm" — confirm. shape=(1,) does not match the 7 GB shard; this
        # relies on deploy_tensor not cross-checking shape against placement.
        deploy_tensor(
            name="toobig",
            shape=(1,),
            dtype="int8",
            placement=[big_shard],
            allocators=allocs,
        )