commit - release 1
This commit is contained in:
@@ -0,0 +1,282 @@
|
||||
import pytest
|
||||
|
||||
from kernbench.policy.address.allocator import AddressConfig, AllocationError, PEMemAllocator
|
||||
from kernbench.policy.placement.dp import (
|
||||
ShardSpec,
|
||||
column_wise,
|
||||
tiled_column_major,
|
||||
replicate,
|
||||
row_wise,
|
||||
tiled_row_major,
|
||||
)
|
||||
from kernbench.runtime_api.kernel import (
|
||||
KernelLaunchMsg,
|
||||
KernelRef,
|
||||
MemoryReadMsg,
|
||||
MemoryWriteMsg,
|
||||
ScalarArg,
|
||||
TensorArg,
|
||||
TensorArgShard,
|
||||
)
|
||||
from kernbench.runtime_api.tensor import (
|
||||
TensorHandle,
|
||||
TensorShard,
|
||||
deploy_tensor,
|
||||
dtype_itemsize,
|
||||
)
|
||||
|
||||
# Byte-size units used by the address configuration and the tests below.
_MB = 2**20
_GB = 2**30

# Address configuration shared by every allocator this module builds.
# 48 GB of HBM per cube over 8 slices works out to 6 GB per slice.
_CFG = AddressConfig(
    sip_count=2,
    cubes_per_sip=16,
    pes_per_cube=8,
    hbm_bytes_per_cube=48 * _GB,
    hbm_slices_per_cube=8,
    tcm_bytes_per_pe=16 * _MB,
    tcm_scheduler_reserved_bytes=4 * _MB,
    sram_bytes_per_cube=32 * _MB,
)
|
||||
|
||||
|
||||
def _make_allocators(num_pe: int = 8) -> dict[int, PEMemAllocator]:
    """Build one PEMemAllocator per PE index in ``range(num_pe)``.

    Every allocator is created for rack 0, SIP 0, cube 0 with the
    module-level ``_CFG``. The returned dict is keyed by PE index.
    """
    allocators = {}
    for pe_id in range(num_pe):
        allocators[pe_id] = PEMemAllocator(
            rack_id=0,
            sip_id=0,
            cube_id=0,
            pe_id=pe_id,
            cfg=_CFG,
        )
    return allocators
|
||||
|
||||
|
||||
# ── Tensor types ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_tensor_shard_immutable():
    """A TensorShard rejects attribute assignment and can be hashed."""
    shard = TensorShard(sip=0, cube=0, pe=0, pa=0x1000, nbytes=4096, offset_bytes=0)

    # Re-binding a field must fail with AttributeError.
    with pytest.raises(AttributeError):
        shard.pa = 0x2000  # type: ignore[misc]

    # Hashing must succeed; an unhashable shard would raise TypeError here.
    hash(shard)
|
||||
|
||||
|
||||
def test_tensor_handle_nbytes():
    """A (1024, 512) handle with itemsize 2 reports nbytes of 1 MB."""
    rows, cols, itemsize = 1024, 512, 2
    handle = TensorHandle(
        name="A",
        shape=(rows, cols),
        dtype="fp16",
        itemsize=itemsize,
        shards=(),
    )
    # rows * cols elements, itemsize bytes each.
    assert handle.nbytes == rows * cols * itemsize
|
||||
|
||||
|
||||
# ── Message types (ADR-0012) ─────────────────────────────────────────
|
||||
|
||||
|
||||
def test_memory_write_msg_fields():
    """A pattern-sourced MemoryWriteMsg exposes its fields and is read-only."""
    fields = dict(
        correlation_id="c0",
        request_id="r0",
        dst_sip=0,
        dst_cube=3,
        dst_pe=5,
        dst_pa=0xDEAD,
        nbytes=4096,
        pattern="zero",
    )
    msg = MemoryWriteMsg(**fields)

    # msg_type and src_kind were not passed in; the message supplies them.
    assert msg.msg_type == "memory_write"
    assert msg.src_kind == "pattern"

    # Constructor arguments are readable back unchanged.
    assert msg.dst_pa == fields["dst_pa"]
    assert msg.pattern == fields["pattern"]

    # Assigning to a field must fail with AttributeError.
    with pytest.raises(AttributeError):
        msg.nbytes = 0  # type: ignore[misc]
|
||||
|
||||
|
||||
def test_memory_read_msg_fields():
    """A MemoryReadMsg reports its message type, source PA and byte count."""
    read_msg = MemoryReadMsg(
        correlation_id="c0",
        request_id="r1",
        src_sip=1,
        src_cube=2,
        src_pe=7,
        src_pa=0xBEEF,
        nbytes=2048,
    )
    assert read_msg.msg_type == "memory_read"
    observed = (read_msg.src_pa, read_msg.nbytes)
    assert observed == (0xBEEF, 2048)
|
||||
|
||||
|
||||
def test_kernel_launch_msg_fields():
    """A KernelLaunchMsg keeps its kernel reference and ordered arguments."""
    # One tensor argument (a single shard) followed by one scalar argument.
    tensor_arg = TensorArg(
        shards=(
            TensorArgShard(sip=0, cube=0, pe=0, pa=0x100, nbytes=1024, offset_bytes=0),
        )
    )
    scalar_arg = ScalarArg(dtype="fp32", value=1.0)
    kernel = KernelRef(name="gemm", kind="builtin")

    launch = KernelLaunchMsg(
        correlation_id="c0",
        request_id="r2",
        kernel_ref=kernel,
        args=(tensor_arg, scalar_arg),
    )

    assert launch.msg_type == "kernel_launch"
    assert launch.kernel_ref.name == "gemm"
    assert len(launch.args) == 2
    # Each argument reports its own kind, in the order it was passed.
    first, second = launch.args[0], launch.args[1]
    assert first.arg_kind == "tensor"
    assert second.arg_kind == "scalar"
|
||||
|
||||
|
||||
# ── Placement: column_wise ───────────────────────────────────────────
|
||||
|
||||
|
||||
def test_column_wise_placement():
    """column_wise on (1024, 512) fp16 over 8 PEs yields 8 shards of 128 KB.

    Each shard is expected to hold a (1024, 64) slice, one per PE in order.
    """
    rows, cols, itemsize = 1024, 512, 2
    specs = column_wise(shape=(rows, cols), itemsize=itemsize, num_pe=8)
    assert len(specs) == 8

    shard_bytes = rows * 64 * itemsize  # 128 KB
    for pe_id, spec in enumerate(specs):
        assert spec.pe_index == pe_id
        assert spec.nbytes == shard_bytes

    # The first two shards sit back to back in the byte layout.
    assert specs[0].offset_bytes == 0
    assert specs[1].offset_bytes == shard_bytes

    # Together the shards account for every byte of the tensor.
    covered = sum(spec.nbytes for spec in specs)
    assert covered == rows * cols * itemsize
|
||||
|
||||
|
||||
# ── Placement: row_wise ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_row_wise_placement():
    """row_wise on (1024, 512) fp16 over 8 PEs yields 8 shards of 128 KB.

    Each shard is expected to hold a (128, 512) slice, one per PE in order.
    """
    rows, cols, itemsize = 1024, 512, 2
    specs = row_wise(shape=(rows, cols), itemsize=itemsize, num_pe=8)
    assert len(specs) == 8

    shard_bytes = 128 * cols * itemsize  # 128 KB
    for pe_id, spec in enumerate(specs):
        assert spec.pe_index == pe_id
        assert spec.nbytes == shard_bytes

    assert specs[0].offset_bytes == 0
    # Together the shards account for every byte of the tensor.
    covered = sum(spec.nbytes for spec in specs)
    assert covered == rows * cols * itemsize
|
||||
|
||||
|
||||
# ── Placement: replicate ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_replicate_placement():
    """replicate on (1024, 512) fp16 over 8 PEs gives every PE a 1 MB copy."""
    rows, cols, itemsize = 1024, 512, 2
    copies = replicate(shape=(rows, cols), itemsize=itemsize, num_pe=8)
    assert len(copies) == 8

    tensor_bytes = rows * cols * itemsize  # 1 MB
    for pe_id, copy in enumerate(copies):
        assert copy.pe_index == pe_id
        assert copy.nbytes == tensor_bytes
        # A full copy always starts at the beginning of the tensor.
        assert copy.offset_bytes == 0
|
||||
|
||||
|
||||
# ── Placement: tiled_column_major ─────────────────────────────────────
|
||||
|
||||
|
||||
def test_tiled_column_major():
    """tiled_column_major: (1024, 512) in (256, 128) tiles → 16 tiles over 8 PEs.

    Tiles are expected to be dealt to PEs round-robin in traversal order.
    """
    rows, cols, itemsize = 1024, 512, 2
    tile_m, tile_k = 256, 128
    tiles = tiled_column_major(
        shape=(rows, cols),
        itemsize=itemsize,
        num_pe=8,
        tile_m=tile_m,
        tile_k=tile_k,
    )

    # 4 tiles along M times 4 tiles along K.
    assert len(tiles) == 16

    per_tile = tile_m * tile_k * itemsize  # 64 KB per tile
    for tile in tiles:
        assert tile.nbytes == per_tile

    # Expected order (K varies fastest within a given M):
    #   (m=0,k=0) → PE0, (m=0,k=1) → PE1, ..., (m=0,k=3) → PE3
    #   (m=1,k=0) → PE4, (m=1,k=1) → PE5, ..., (m=1,k=3) → PE7
    #   (m=2,k=0) → PE0, ...
    # NOTE(review): only list positions 0, 1, 7 and 8 are checked, and the
    # same checks appear in test_tiled_row_major, so this does not pin the
    # traversal order itself — confirm against tiled_column_major.
    for position, expected_pe in ((0, 0), (1, 1), (7, 7), (8, 0)):
        assert tiles[position].pe_index == expected_pe

    # Together the tiles account for every byte of the tensor.
    assert sum(tile.nbytes for tile in tiles) == rows * cols * itemsize
|
||||
|
||||
|
||||
# ── Placement: tiled_row_major ────────────────────────────────────────
|
||||
|
||||
|
||||
def test_tiled_row_major():
    """tiled_row_major: (1024, 512) in (256, 128) tiles → 16 tiles over 8 PEs.

    Tiles are expected to be dealt to PEs round-robin in traversal order.
    """
    rows, cols, itemsize = 1024, 512, 2
    tile_m, tile_k = 256, 128
    tiles = tiled_row_major(
        shape=(rows, cols),
        itemsize=itemsize,
        num_pe=8,
        tile_m=tile_m,
        tile_k=tile_k,
    )
    assert len(tiles) == 16

    per_tile = tile_m * tile_k * itemsize
    for tile in tiles:
        assert tile.nbytes == per_tile

    # Expected order (M varies fastest within a given K):
    #   (m=0,k=0) → PE0, (m=1,k=0) → PE1, ..., (m=3,k=0) → PE3
    #   (m=0,k=1) → PE4, (m=1,k=1) → PE5, ..., (m=3,k=1) → PE7
    #   (m=0,k=2) → PE0, ...
    # NOTE(review): only list positions 0, 1, 7 and 8 are checked, and the
    # same checks appear in test_tiled_column_major, so this does not pin
    # the traversal order itself — confirm against tiled_row_major.
    for position, expected_pe in ((0, 0), (1, 1), (7, 7), (8, 0)):
        assert tiles[position].pe_index == expected_pe

    # Together the tiles account for every byte of the tensor.
    assert sum(tile.nbytes for tile in tiles) == rows * cols * itemsize
|
||||
|
||||
|
||||
# ── deploy_tensor ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_deploy_tensor_hbm():
    """deploy_tensor with a column_wise placement in HBM returns 8 placed shards.

    Checks the handle metadata, that every shard has a distinct PA, and that
    shard i lands on PE i of SIP 0 / cube 0.
    """
    shape = (1024, 512)
    allocators = _make_allocators()
    specs = column_wise(shape=shape, itemsize=2, num_pe=8)

    handle = deploy_tensor(
        name="W",
        shape=shape,
        dtype="fp16",
        placement=specs,
        allocators=allocators,
        mem_kind="hbm",
    )

    # Handle metadata mirrors the request.
    assert handle.name == "W"
    assert handle.shape == (1024, 512)
    assert handle.dtype == "fp16"
    assert handle.itemsize == 2
    assert len(handle.shards) == 8

    # No two shards share a physical address.
    distinct_pas = {shard.pa for shard in handle.shards}
    assert len(distinct_pas) == 8

    # Shard i sits on PE i; the allocators were all built for SIP 0, cube 0.
    for pe_id, shard in enumerate(handle.shards):
        assert shard.pe == pe_id
        assert shard.sip == 0
        assert shard.cube == 0
|
||||
|
||||
|
||||
def test_deploy_tensor_tcm():
    """deploy_tensor with mem_kind="tcm" places a single 256-byte shard on PE 0.

    NOTE(review): the original docstring says this path uses pe_tcm_addr
    allocation; the assertions here only check shard count, PE and size.
    """
    allocators = _make_allocators()
    single_shard = ShardSpec(pe_index=0, offset_bytes=0, nbytes=256)

    handle = deploy_tensor(
        name="small",
        shape=(128,),
        dtype="fp16",
        placement=[single_shard],
        allocators=allocators,
        mem_kind="tcm",
    )

    assert len(handle.shards) == 1
    placed = handle.shards[0]
    assert placed.pe == 0
    assert placed.nbytes == 256
|
||||
|
||||
|
||||
def test_deploy_tensor_overflow():
    """A shard larger than a PE's capacity makes deploy_tensor raise AllocationError.

    mem_kind is left at its default here — presumably HBM; confirm against
    deploy_tensor.
    """
    allocators = _make_allocators()

    # 48 GB per cube over 8 slices is 6 GB per slice; request 7 GB.
    oversized = ShardSpec(pe_index=0, offset_bytes=0, nbytes=7 * _GB)

    with pytest.raises(AllocationError):
        deploy_tensor(
            name="toobig",
            shape=(1,),
            dtype="int8",
            placement=[oversized],
            allocators=allocators,
        )
|
||||
Reference in New Issue
Block a user