357cab525b
DPPolicy no longer carries a cross-SIP axis. SIP-level placement is solely controlled by torch.ahbm.set_device(rank) (ADR-0024); DPPolicy itself describes only the cube × PE layout within one SIP. ShardSpec switches to structural (sip, cube, pe) coordinates; the flat pe_index field/property is fully removed — silent drift between global-flat and SIP-local interpretations was a foot-gun flagged by ADR-0024 D11. Breaking API (explicit TypeError / AttributeError): - DPPolicy(sip=...) / DPPolicy(num_sips=...) -> TypeError - ShardSpec.pe_index -> AttributeError - ShardSpec(pe_index=...) -> TypeError - resolve_dp_policy now takes target_sip= (required), no num_sips. Downstream migration: - PE allocator dict keyed by (sip, cube, pe) tuples, in both _ensure_allocators and _free_tensor. deploy_tensor uses tuple lookup. - _create_tensor passes target_sip=current_sip; post-hoc pe_index shifting removed entirely. - launch._compute_local_shape drops the dp.sip branch. - Internal resolvers (column_wise / row_wise / replicate / tiled_*) return _LocalPeShard (cube-local identifier) instead of ShardSpec — resolve_dp_policy lifts them to full structural coords. Tests: - New tests/test_adr0026_dppolicy_intra_device.py (12 tests) pins the contract end-to-end. - test_sip_parallel.py rewritten: SIP composition now modeled as two resolve_dp_policy(target_sip=...) calls (ADR-0024 launcher style). - Call-site migration: test_tensor, test_va_integration, test_va_offset, test_runtime_api_tensor, test_tl_recv_async, test_ccl_* and benches gemm_single_pe, gpt3_qkv, va_offset_verify, ccl_allreduce (legacy branch) all use intra-device DPPolicy and structural ShardSpec. Result: 523 passed, 1 strict xfail (ring_default_ws — unchanged ADR-0024 Phase B blocker; architectural fix deferred to ADR-0027). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
289 lines
9.1 KiB
Python
289 lines
9.1 KiB
Python
import pytest
|
||
|
||
from kernbench.policy.address.allocator import AddressConfig, AllocationError, PEMemAllocator
|
||
from kernbench.policy.placement.dp import (
|
||
DPPolicy,
|
||
ShardSpec,
|
||
column_wise,
|
||
replicate,
|
||
resolve_dp_policy,
|
||
row_wise,
|
||
tiled_column_major,
|
||
tiled_row_major,
|
||
)
|
||
from kernbench.runtime_api.kernel import (
|
||
KernelLaunchMsg,
|
||
KernelRef,
|
||
MemoryReadMsg,
|
||
MemoryWriteMsg,
|
||
ScalarArg,
|
||
TensorArg,
|
||
TensorArgShard,
|
||
)
|
||
from kernbench.runtime_api.tensor import (
|
||
TensorHandle,
|
||
TensorShard,
|
||
deploy_tensor,
|
||
dtype_itemsize,
|
||
)
|
||
|
||
_MB = 1 << 20
|
||
_GB = 1 << 30
|
||
|
||
# Shared address-space configuration for the allocators built in these tests:
# 2 SIPs x 16 cubes/SIP x 8 PEs/cube. HBM is 48 GB per cube divided into
# 8 slices (48 / 8 = 6 GB per slice), TCM is 16 MB per PE of which
# tcm_scheduler_reserved_bytes = 4 MB, and SRAM is 32 MB per cube.
_CFG = AddressConfig(
    sip_count=2,
    cubes_per_sip=16,
    pes_per_cube=8,
    hbm_bytes_per_cube=_GB * 48,
    hbm_slices_per_cube=8,
    tcm_bytes_per_pe=_MB * 16,
    tcm_scheduler_reserved_bytes=_MB * 4,
    sram_bytes_per_cube=_MB * 32,
)
|
||
|
||
|
||
def _make_allocators(num_pe: int = 8) -> dict[tuple[int, int, int], PEMemAllocator]:
    """Build one ``PEMemAllocator`` per PE on rack 0 / SIP 0 / cube 0.

    The returned dict is keyed by structural ``(sip, cube, pe)`` tuples
    (here always ``(0, 0, pe_id)``), which is the keying the commit message
    says ``deploy_tensor`` now uses for its allocator lookup.

    Args:
        num_pe: number of allocators to create, for ``pe_id`` 0..num_pe-1.

    Returns:
        Mapping ``(0, 0, pe_id) -> PEMemAllocator`` configured with ``_CFG``.
    """
    allocators = {}
    for pe_id in range(num_pe):
        allocators[(0, 0, pe_id)] = PEMemAllocator(
            rack_id=0, sip_id=0, cube_id=0, pe_id=pe_id, cfg=_CFG,
        )
    return allocators
|
||
|
||
|
||
# ── Tensor types ─────────────────────────────────────────────────────
|
||
|
||
|
||
def test_tensor_shard_immutable():
    """TensorShard rejects attribute reassignment and can be hashed."""
    shard = TensorShard(sip=0, cube=0, pe=0, pa=0x1000, nbytes=4096, offset_bytes=0)

    # Reassigning any field must fail with AttributeError.
    with pytest.raises(AttributeError):
        shard.pa = 0x2000  # type: ignore[misc]

    # hash() raises TypeError for unhashable objects, so this call is the check.
    hash(shard)
|
||
|
||
|
||
def test_tensor_handle_nbytes():
    """TensorHandle.nbytes equals the element count times itemsize."""
    rows, cols, itemsize = 1024, 512, 2
    handle = TensorHandle(
        name="A",
        shape=(rows, cols),
        dtype="fp16",
        itemsize=itemsize,
        shards=(),
    )
    assert handle.nbytes == rows * cols * itemsize  # 1 MiB
|
||
|
||
|
||
# ── Message types (ADR-0012) ─────────────────────────────────────────
|
||
|
||
|
||
def test_memory_write_msg_fields():
    """Pattern-sourced MemoryWriteMsg: type tags, payload fields, immutability."""
    write = MemoryWriteMsg(
        correlation_id="c0",
        request_id="r0",
        dst_sip=0,
        dst_cube=3,
        dst_pe=5,
        dst_pa=0xDEAD,
        nbytes=4096,
        pattern="zero",
    )

    # msg_type and src_kind are not passed to the constructor; the message
    # type supplies them itself.
    assert write.msg_type == "memory_write"
    assert write.src_kind == "pattern"
    assert (write.dst_pa, write.pattern) == (0xDEAD, "zero")

    # Fields cannot be reassigned after construction.
    with pytest.raises(AttributeError):
        write.nbytes = 0  # type: ignore[misc]
|
||
|
||
|
||
def test_memory_read_msg_fields():
    """MemoryReadMsg exposes its type tag plus the source address and size."""
    read = MemoryReadMsg(
        correlation_id="c0",
        request_id="r1",
        src_sip=1,
        src_cube=2,
        src_pe=7,
        src_pa=0xBEEF,
        nbytes=2048,
    )
    # msg_type is not a constructor argument; the message type supplies it.
    assert read.msg_type == "memory_read"
    assert (read.src_pa, read.nbytes) == (0xBEEF, 2048)
|
||
|
||
|
||
def test_kernel_launch_msg_fields():
    """KernelLaunchMsg carries its kernel ref and a mixed tensor/scalar arg list."""
    tensor_arg = TensorArg(
        shards=(
            TensorArgShard(sip=0, cube=0, pe=0, pa=0x100, nbytes=1024, offset_bytes=0),
        ),
    )
    scalar_arg = ScalarArg(dtype="fp32", value=1.0)

    launch = KernelLaunchMsg(
        correlation_id="c0",
        request_id="r2",
        kernel_ref=KernelRef(name="gemm", kind="builtin"),
        args=(tensor_arg, scalar_arg),
    )

    assert launch.msg_type == "kernel_launch"
    assert launch.kernel_ref.name == "gemm"
    # arg_kind tags each argument variant, in the order they were passed.
    assert [arg.arg_kind for arg in launch.args] == ["tensor", "scalar"]
|
||
|
||
|
||
# ── Placement: column_wise ───────────────────────────────────────────
|
||
|
||
|
||
def test_column_wise_placement():
    """(1024, 512) fp16 across 8 PEs → K axis split → 8 shards, each (1024, 64) = 128KB.

    Shard i must sit on cube-local PE i, have the same size as every other
    shard, and start right where shard i-1 ended (contiguous offsets).
    """
    num_pe = 8
    shards = column_wise(shape=(1024, 512), itemsize=2, num_pe=num_pe)
    assert len(shards) == num_pe

    expected_nbytes = 1024 * (512 // num_pe) * 2  # 128 KB per shard
    for i, s in enumerate(shards):
        assert s.local_pe == i
        assert s.nbytes == expected_nbytes
        # Offsets are contiguous. With equal shard sizes that means shard i starts
        # at i * expected_nbytes, so check every shard rather than only the first two.
        assert s.offset_bytes == i * expected_nbytes

    # total coverage
    assert sum(s.nbytes for s in shards) == 1024 * 512 * 2
|
||
|
||
|
||
# ── Placement: row_wise ──────────────────────────────────────────────
|
||
|
||
|
||
def test_row_wise_placement():
    """(1024, 512) fp16 across 8 PEs → M axis split → 8 shards, each (128, 512) = 128KB."""
    total_nbytes = 1024 * 512 * 2
    per_shard = 128 * 512 * 2  # 128 KB

    shards = row_wise(shape=(1024, 512), itemsize=2, num_pe=8)
    assert len(shards) == 8

    # Shard i lives on cube-local PE i and holds an equal slice of rows.
    for pe, shard in enumerate(shards):
        assert (shard.local_pe, shard.nbytes) == (pe, per_shard)

    assert shards[0].offset_bytes == 0
    assert sum(shard.nbytes for shard in shards) == total_nbytes
|
||
|
||
|
||
# ── Placement: replicate ─────────────────────────────────────────────
|
||
|
||
|
||
def test_replicate_placement():
    """(1024, 512) fp16 across 8 PEs → each PE gets a full 1 MB copy."""
    shards = replicate(shape=(1024, 512), itemsize=2, num_pe=8)
    assert len(shards) == 8

    full_copy = 1024 * 512 * 2  # 1 MB
    for pe, shard in enumerate(shards):
        # Every replica covers the whole tensor, starting at byte 0.
        assert (shard.local_pe, shard.nbytes, shard.offset_bytes) == (pe, full_copy, 0)
|
||
|
||
|
||
# ── Placement: tiled_column_major ─────────────────────────────────────
|
||
|
||
|
||
def test_tiled_column_major():
    """(1024, 512) tile=(256, 128) → 4×4=16 tiles, column-major → round-robin 8 PEs.

    Tiles are dealt to PEs round-robin in traversal order, so the i-th
    returned tile must land on cube-local PE ``i % num_pe`` (2 tiles per PE).
    """
    num_pe = 8
    shards = tiled_column_major(
        shape=(1024, 512), itemsize=2, num_pe=num_pe, tile_m=256, tile_k=128,
    )

    # 4 tiles along M, 4 tiles along K → 16 tiles total
    assert len(shards) == 16

    tile_bytes = 256 * 128 * 2  # 64 KB per tile

    # column-major: iterate K first, then M
    # tile (m=0,k=0) → PE0, tile (m=0,k=1) → PE1, ..., (m=0,k=3) → PE3
    # tile (m=1,k=0) → PE4, tile (m=1,k=1) → PE5, ..., (m=1,k=3) → PE7
    # tile (m=2,k=0) → PE0, ...
    for i, s in enumerate(shards):
        assert s.nbytes == tile_bytes
        # Pin the round-robin for all 16 tiles, not just indices 0/1/7/8.
        assert s.local_pe == i % num_pe

    # NOTE(review): the PE assignment above is the same for any traversal order,
    # so this test cannot tell column-major from row-major. Also, the order
    # described above has k varying fastest, which BLAS-style terminology would
    # call row-major. Confirm against tiled_column_major's implementation and
    # consider asserting per-tile offsets.

    # total coverage
    assert sum(s.nbytes for s in shards) == 1024 * 512 * 2
|
||
|
||
|
||
# ── Placement: tiled_row_major ────────────────────────────────────────
|
||
|
||
|
||
def test_tiled_row_major():
    """(1024, 512) tile=(256, 128) → 4×4=16 tiles, row-major → round-robin 8 PEs.

    Tiles are dealt to PEs round-robin in traversal order, so the i-th
    returned tile must land on cube-local PE ``i % num_pe`` (2 tiles per PE).
    """
    num_pe = 8
    shards = tiled_row_major(
        shape=(1024, 512), itemsize=2, num_pe=num_pe, tile_m=256, tile_k=128,
    )

    # 4 tiles along M, 4 tiles along K → 16 tiles total
    assert len(shards) == 16

    tile_bytes = 256 * 128 * 2  # 64 KB per tile

    # row-major: iterate M first, then K
    # tile (m=0,k=0) → PE0, tile (m=1,k=0) → PE1, ..., (m=3,k=0) → PE3
    # tile (m=0,k=1) → PE4, tile (m=1,k=1) → PE5, ..., (m=3,k=1) → PE7
    # tile (m=0,k=2) → PE0, ...
    for i, s in enumerate(shards):
        assert s.nbytes == tile_bytes
        # Pin the round-robin for all 16 tiles, not just indices 0/1/7/8.
        assert s.local_pe == i % num_pe

    # NOTE(review): the PE assignment above is the same for any traversal order,
    # so this test cannot tell row-major from column-major. Also, the order
    # described above has m varying fastest, which BLAS-style terminology would
    # call column-major. Confirm against tiled_row_major's implementation and
    # consider asserting per-tile offsets.

    # total coverage
    assert sum(s.nbytes for s in shards) == 1024 * 512 * 2
|
||
|
||
|
||
# ── deploy_tensor ────────────────────────────────────────────────────
|
||
|
||
|
||
def test_deploy_tensor_hbm():
    """column_wise HBM deployment yields 8 shards with distinct PAs, one per PE.

    The placement comes from resolve_dp_policy with target_sip=0 and a single
    cube, so every shard must land on (sip=0, cube=0, pe=i).
    """
    shape = (1024, 512)
    allocators = _make_allocators()
    placement = resolve_dp_policy(
        DPPolicy(cube="replicate", pe="column_wise"),
        shape=shape,
        itemsize=2,
        num_pe=8,
        num_cubes=1,
        target_sip=0,
    )

    handle = deploy_tensor(
        name="W",
        shape=shape,
        dtype="fp16",
        placement=placement,
        allocators=allocators,
        mem_kind="hbm",
    )

    assert (handle.name, handle.shape, handle.dtype, handle.itemsize) == (
        "W", (1024, 512), "fp16", 2,
    )
    assert len(handle.shards) == 8
    # No two shards may share a physical address.
    assert len({shard.pa for shard in handle.shards}) == 8
    # Structural coordinates: shard i on SIP 0, cube 0, PE i.
    for pe, shard in enumerate(handle.shards):
        assert (shard.sip, shard.cube, shard.pe) == (0, 0, pe)
|
||
|
||
|
||
def test_deploy_tensor_tcm():
    """Deploy one small shard with mem_kind="tcm" (the original note: uses pe_tcm_addr allocation)."""
    # 128 fp16 elements = 256 bytes, placed whole on (sip=0, cube=0, pe=0).
    single_shard = ShardSpec(sip=0, cube=0, pe=0, offset_bytes=0, nbytes=256)
    handle = deploy_tensor(
        name="small",
        shape=(128,),
        dtype="fp16",
        placement=[single_shard],
        allocators=_make_allocators(),
        mem_kind="tcm",
    )

    assert len(handle.shards) == 1
    (only,) = handle.shards
    assert only.pe == 0
    assert only.nbytes == 256
|
||
|
||
|
||
def test_deploy_tensor_overflow():
    """A shard larger than a PE's HBM slice makes deploy_tensor raise AllocationError."""
    allocators = _make_allocators()
    # _CFG gives 48 GB of HBM per cube in 8 slices, i.e. 6 GB per PE slice;
    # request 7 GB on a single PE.
    oversized = ShardSpec(sip=0, cube=0, pe=0, offset_bytes=0, nbytes=7 * _GB)

    # The shape is nominal; presumably the oversized placement nbytes is what
    # drives the allocation request.
    with pytest.raises(AllocationError):
        deploy_tensor(
            name="toobig",
            shape=(1,),
            dtype="int8",
            placement=[oversized],
            allocators=allocators,
        )
|