From 357cab525bb0e9113e6a267af968cde3543b29ab Mon Sep 17 00:00:00 2001 From: Yangwook Kang Date: Tue, 14 Apr 2026 13:02:19 -0700 Subject: [PATCH] ADR-0026: DPPolicy intra-device only + ShardSpec structural coords MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit DPPolicy no longer carries a cross-SIP axis. SIP-level placement is solely controlled by torch.ahbm.set_device(rank) (ADR-0024); DPPolicy itself describes only the cube × PE layout within one SIP. ShardSpec switches to structural (sip, cube, pe) coordinates; the flat pe_index field/property is fully removed — silent drift between global-flat and SIP-local interpretations was a foot-gun flagged by ADR-0024 D11. Breaking API (explicit TypeError / AttributeError): - DPPolicy(sip=...) / DPPolicy(num_sips=...) -> TypeError - ShardSpec.pe_index -> AttributeError - ShardSpec(pe_index=...) -> TypeError - resolve_dp_policy now takes target_sip= (required), no num_sips. Downstream migration: - PE allocator dict keyed by (sip, cube, pe) tuples, in both _ensure_allocators and _free_tensor. deploy_tensor uses tuple lookup. - _create_tensor passes target_sip=current_sip; post-hoc pe_index shifting removed entirely. - launch._compute_local_shape drops the dp.sip branch. - Internal resolvers (column_wise / row_wise / replicate / tiled_*) return _LocalPeShard (cube-local identifier) instead of ShardSpec — resolve_dp_policy lifts them to full structural coords. Tests: - New tests/test_adr0026_dppolicy_intra_device.py (12 tests) pins the contract end-to-end. - test_sip_parallel.py rewritten: SIP composition now modeled as two resolve_dp_policy(target_sip=...) calls (ADR-0024 launcher style). - Call-site migration: test_tensor, test_va_integration, test_va_offset, test_runtime_api_tensor, test_tl_recv_async, test_ccl_* and benches gemm_single_pe, gpt3_qkv, va_offset_verify, ccl_allreduce (legacy branch) all use intra-device DPPolicy and structural ShardSpec. Result: 523 passed, 1 strict xfail (ring_default_ws — unchanged ADR-0024 Phase B blocker; architectural fix deferred to ADR-0027). Co-Authored-By: Claude Opus 4.6 (1M context) --- benches/ccl_allreduce.py | 23 +- benches/gemm_single_pe.py | 4 +- benches/gpt3_qkv.py | 12 +- benches/va_offset_verify.py | 4 +- docs/adr/ADR-0026-dppolicy-intra-device.md | 2 +- docs/ccl-author-guide.en.md | 4 +- docs/ccl-author-guide.md | 4 +- src/kernbench/policy/placement/dp.py | 160 +++++++------ src/kernbench/runtime_api/context.py | 59 ++--- src/kernbench/runtime_api/tensor.py | 13 +- tests/test_adr0026_dppolicy_intra_device.py | 239 ++++++++++++++++++++ tests/test_ccl_deadlock_detection.py | 4 +- tests/test_ccl_hello_world_guide.py | 4 +- tests/test_recv_copy_to_dst.py | 4 +- tests/test_runtime_api_tensor.py | 32 +-- tests/test_sip_parallel.py | 219 ++++++++---------- tests/test_tensor.py | 40 ++-- tests/test_tl_recv_async.py | 4 +- tests/test_va_integration.py | 26 ++- tests/test_va_offset.py | 20 +- 20 files changed, 549 insertions(+), 328 deletions(-) create mode 100644 tests/test_adr0026_dppolicy_intra_device.py diff --git a/benches/ccl_allreduce.py b/benches/ccl_allreduce.py index 00ea17d..8df6b81 100644 --- a/benches/ccl_allreduce.py +++ b/benches/ccl_allreduce.py @@ -32,29 +32,26 @@ def _derive_dp(spec: dict, world_size: int) -> DPPolicy: """Legacy DPPolicy for world_size > SIP count (rank = flat PE index). Used only in the ccl.yaml-override path so the existing matrix tests - with explicit world_size (8, 16, 7 etc.) keep working. The new - ADR-0024 TP path (rank = SIP) uses a per-rank DPPolicy inside the - worker instead. + with explicit world_size (8, 16, 7 etc.) keep working. ADR-0026: + DPPolicy is intra-device only, so this legacy path now always stays + within a single SIP and distributes the override world_size across + that SIP's cubes and PEs. """ - sips = int(spec["system"]["sips"]["count"]) - cm = spec["sip"]["cube_mesh"] pl = spec["cube"]["pe_layout"] pes_per_cube = int(pl["pe_per_corner"]) * len(pl["corners"]) + cm = spec["sip"]["cube_mesh"] cubes_per_sip = int(cm["w"]) * int(cm["h"]) - total = sips * cubes_per_sip * pes_per_cube - if world_size == total: - return DPPolicy(sip="column_wise", cube="column_wise", pe="column_wise") if world_size <= pes_per_cube: return DPPolicy( - sip="replicate", cube="replicate", pe="column_wise", - num_sips=1, num_cubes=1, num_pes=world_size, + cube="replicate", pe="column_wise", + num_cubes=1, num_pes=world_size, ) if world_size <= cubes_per_sip * pes_per_cube: return DPPolicy( - sip="replicate", cube="column_wise", pe="column_wise", - num_sips=1, num_cubes=world_size // pes_per_cube, + cube="column_wise", pe="column_wise", + num_cubes=world_size // pes_per_cube, ) - return DPPolicy(sip="column_wise", cube="column_wise", pe="column_wise") + return DPPolicy(cube="column_wise", pe="column_wise") def worker(rank: int, world_size: int, torch) -> None: diff --git a/benches/gemm_single_pe.py b/benches/gemm_single_pe.py index dda336f..6142ddd 100644 --- a/benches/gemm_single_pe.py +++ b/benches/gemm_single_pe.py @@ -3,7 +3,7 @@ Full host-to-PE pipeline: Host → PCIE_EP → IO_CPU → M_CPU → PE_CPU → SchedulerV2 → PE_DMA → HBM -Single PE: num_sips=1, num_cubes=1, num_pes=1 via DPPolicy override. +Single PE: num_cubes=1, num_pes=1 via DPPolicy override. Both operands use tl.ref (HBM-resident); scheduler_v2 tiles and streams per-tile DMA internally. @@ -30,7 +30,7 @@ def _gemm_kernel(a_ptr, b_ptr, out_ptr, M, K, N, tl, DTYPE="f16"): def run(torch): """Run the single-PE GEMM benchmark.""" dp = DPPolicy(cube="replicate", pe="replicate", - num_sips=1, num_cubes=1, num_pes=1) + num_cubes=1, num_pes=1) a = torch.empty((M, K), dtype=DTYPE, dp=dp, name="a") b = torch.empty((K, N), dtype=DTYPE, dp=dp, name="b") diff --git a/benches/gpt3_qkv.py b/benches/gpt3_qkv.py index 5ff8fd4..56e80f1 100644 --- a/benches/gpt3_qkv.py +++ b/benches/gpt3_qkv.py @@ -72,12 +72,16 @@ def run(torch): K = GPT3_D_MODEL N = COLS_PER_PE - # X: replicated across all PEs + # ADR-0026: DPPolicy is intra-device only. For multi-SIP execution the + # ADR-0024 launcher calls this bench once per SIP (each worker via + # torch.ahbm.set_device(rank)); here the policy describes only the + # cube × PE layout within a single SIP. + # X: replicated across all PEs within the SIP dp_replicate = DPPolicy(cube="replicate", pe="replicate", - num_sips=N_SIPS, num_cubes=N_CUBES, num_pes=N_PE_PER_CUBE) - # W_Q/K/V, out_Q/K/V: column-wise sharded across all PEs + num_cubes=N_CUBES, num_pes=N_PE_PER_CUBE) + # W_Q/K/V, out_Q/K/V: column-wise sharded across all PEs within the SIP dp_sharded = DPPolicy(cube="column_wise", pe="column_wise", - num_sips=N_SIPS, num_cubes=N_CUBES, num_pes=N_PE_PER_CUBE) + num_cubes=N_CUBES, num_pes=N_PE_PER_CUBE) x = torch.empty((M, K), dtype=DTYPE, dp=dp_replicate, name="x") wq = torch.empty((K, GPT3_D_MODEL), dtype=DTYPE, dp=dp_sharded, name="wq") diff --git a/benches/va_offset_verify.py b/benches/va_offset_verify.py index 578ae30..7c2bdd1 100644 --- a/benches/va_offset_verify.py +++ b/benches/va_offset_verify.py @@ -1,7 +1,7 @@ """VA offset verification benchmark. Verifies that Triton-style base_ptr + pid * stride addressing works correctly -with full TP sharding (sip/cube/pe all column_wise). Each PE loads its own +with intra-SIP TP sharding (cube/pe column_wise). Each PE loads its own block from a sharded tensor and stores it back. The kernel uses standard Triton patterns: @@ -28,7 +28,7 @@ def _copy_kernel(src_ptr, dst_ptr, M, K, tl, DTYPE="f16"): def run(torch): """Run the VA offset verification benchmark with full TP sharding.""" - dp = DPPolicy(sip="column_wise", cube="column_wise", pe="column_wise") + dp = DPPolicy(cube="column_wise", pe="column_wise") src = torch.zeros((M, K), dtype=DTYPE, dp=dp, name="src") dst = torch.empty((M, K), dtype=DTYPE, dp=dp, name="dst") diff --git a/docs/adr/ADR-0026-dppolicy-intra-device.md b/docs/adr/ADR-0026-dppolicy-intra-device.md index 3d87aad..0b93535 100644 --- a/docs/adr/ADR-0026-dppolicy-intra-device.md +++ b/docs/adr/ADR-0026-dppolicy-intra-device.md @@ -2,7 +2,7 @@ ## Status -Proposed (Revision 4 — 문서 일관성 + grep audit 구체화) +Accepted (Revision 5 — Phase 2 landed 2026-04-14, 523 passed + 1 strict xfail) ## Context diff --git a/docs/ccl-author-guide.en.md b/docs/ccl-author-guide.en.md index e2e62f9..7fd38e1 100644 --- a/docs/ccl-author-guide.en.md +++ b/docs/ccl-author-guide.en.md @@ -129,8 +129,8 @@ N_ELEM = 8 def worker(rank: int, world_size: int, torch) -> None: """Per-rank business logic — mirrors a real PyTorch DDP worker.""" dp = DPPolicy( - sip="replicate", cube="replicate", pe="column_wise", - num_sips=1, num_cubes=1, num_pes=world_size, + cube="replicate", pe="column_wise", + num_cubes=1, num_pes=world_size, ) tensor = torch.zeros( (1, world_size * N_ELEM), dtype="f16", dp=dp, name="hello_in", diff --git a/docs/ccl-author-guide.md b/docs/ccl-author-guide.md index 4fa7cb4..d785f24 100644 --- a/docs/ccl-author-guide.md +++ b/docs/ccl-author-guide.md @@ -114,8 +114,8 @@ def run(torch): a = torch.zeros( (1, WORLD_SIZE * N_ELEM), dtype="f16", dp=DPPolicy( - sip="replicate", cube="replicate", pe="column_wise", - num_sips=1, num_cubes=1, + cube="replicate", pe="column_wise", + num_cubes=1, ), name="hello_in", ) diff --git a/src/kernbench/policy/placement/dp.py b/src/kernbench/policy/placement/dp.py index 5b0e01a..7d92d6b 100644 --- a/src/kernbench/policy/placement/dp.py +++ b/src/kernbench/policy/placement/dp.py @@ -1,3 +1,14 @@ +"""Data-parallel placement policy (ADR-0026: intra-device only). + +``DPPolicy`` describes how a tensor is sharded *within a single SIP* across +that SIP's cubes and PEs. Crossing the SIP boundary is not a DPPolicy +concern: ADR-0024's ``torch.ahbm.set_device(rank)`` picks the SIP, and +Megatron-style TP (ADR-0027) expresses multi-SIP tensors when needed. + +``ShardSpec`` is expressed in structural ``(sip, cube, pe)`` coordinates. +The former flat ``pe_index`` field/property is fully removed — callers +needing a flat integer key compute it explicitly at the call site. +""" from __future__ import annotations from dataclasses import dataclass @@ -7,25 +18,58 @@ from typing import Literal @dataclass(frozen=True) class DPPolicy: - """Three-level data-parallel policy: sip-level + cube-level + pe-level. + """Intra-device (cube × PE) data-parallel policy. - Policies: + SIP-level placement is controlled by ``torch.ahbm.set_device(rank)`` + (ADR-0024). For tensors that must cross SIP boundaries, use + Megatron-style parallel layers (ADR-0027). DPPolicy itself never + crosses a SIP boundary. + + Policies (per axis): - "replicate": full copy at each unit - "column_wise": split K (column) axis across units - "row_wise": split M (row) axis across units - Optional overrides (default None = use topology dimensions): - - num_pes: override PEs per cube (e.g., 1 for single-PE test) - - num_cubes: override cubes per SIP (e.g., 1 for single-cube test) - - num_sips: override SIP count + Optional overrides (``None`` = use topology dimensions): + - num_pes: override PEs per cube + - num_cubes: override cubes per SIP """ - sip: Literal["replicate", "column_wise", "row_wise"] = "replicate" cube: Literal["replicate", "column_wise", "row_wise"] = "replicate" pe: Literal["replicate", "column_wise", "row_wise"] = "replicate" num_pes: int | None = None num_cubes: int | None = None - num_sips: int | None = None + + +@dataclass(frozen=True) +class ShardSpec: + """Structural shard placement — ``(sip, cube, pe)`` coord (ADR-0026). + + Global-flat ``pe_index`` was removed: callers must use structural + coords directly. If a flat integer key is needed in a local context + (e.g. internal dict lookup), compute it explicitly at the call site + and do not expose it in any public API. + """ + + sip: int + cube: int + pe: int + offset_bytes: int + nbytes: int + + +@dataclass(frozen=True) +class _LocalPeShard: + """Internal — PE resolver's return type (ADR-0026 D3). + + Holds a cube-local PE identifier (``local_pe``) plus the shard's + byte payload. Lifted into ``ShardSpec`` with full ``(sip, cube, pe)`` + coordinates inside ``resolve_dp_policy``. + """ + + local_pe: int + offset_bytes: int + nbytes: int def _split_shape( @@ -52,14 +96,13 @@ def resolve_dp_policy( itemsize: int, num_pe: int, num_cubes: int = 1, - num_sips: int = 1, + target_sip: int, ) -> list[ShardSpec]: - """Resolve a DPPolicy into a list[ShardSpec] with three-level resolution. + """Resolve a DPPolicy into a list[ShardSpec] on a single SIP. - SIP-level → cube-level → pe-level. - num_cubes is cubes per SIP (not total). - ShardSpec.pe_index uses flat indexing: - sip_id * num_cubes * num_pe + cube_id * num_pe + pe_id + Two-level resolution (cube × PE) within ``target_sip``. Each returned + ``ShardSpec`` carries ``sip=target_sip`` and cube/pe local to the SIP. + No SIP-level split — DPPolicy is intra-device only (ADR-0026). """ _PE_RESOLVERS = { "replicate": replicate, @@ -70,84 +113,61 @@ def resolve_dp_policy( if resolver is None: raise ValueError(f"Unknown pe-level policy: {policy.pe}") - cubes_per_sip = num_cubes all_shards: list[ShardSpec] = [] - # Level 1: SIP - sip_splits = _split_shape(policy.sip, shape, num_sips, itemsize) + # Level 1: cube within SIP + cube_splits = _split_shape(policy.cube, shape, num_cubes, itemsize) - for sip_id, (sip_shape, sip_offset) in enumerate(sip_splits): - # Level 2: Cube within SIP - cube_splits = _split_shape(policy.cube, sip_shape, cubes_per_sip, itemsize) + for cube_id, (cube_shape, cube_offset) in enumerate(cube_splits): + # Level 2: PE within cube — resolver returns _LocalPeShard + local_shards = resolver(shape=cube_shape, itemsize=itemsize, num_pe=num_pe) - for cube_id, (cube_shape, cube_offset) in enumerate(cube_splits): - # Level 3: PE within cube - pe_shards = resolver(shape=cube_shape, itemsize=itemsize, num_pe=num_pe) - - for ps in pe_shards: - flat_idx = ( - sip_id * cubes_per_sip * num_pe - + cube_id * num_pe - + ps.pe_index - ) - all_shards.append(ShardSpec( - pe_index=flat_idx, - offset_bytes=sip_offset + cube_offset + ps.offset_bytes, - nbytes=ps.nbytes, - )) + for ls in local_shards: + all_shards.append(ShardSpec( + sip=target_sip, + cube=cube_id, + pe=ls.local_pe, + offset_bytes=cube_offset + ls.offset_bytes, + nbytes=ls.nbytes, + )) return all_shards -@dataclass(frozen=True) -class ShardSpec: - pe_index: int - offset_bytes: int - nbytes: int - - def column_wise( *, shape: tuple[int, int], itemsize: int, num_pe: int, -) -> list[ShardSpec]: +) -> list[_LocalPeShard]: """Split K axis into num_pe equal parts. Each PE gets (M, K/P).""" M, K = shape chunk_k = K // num_pe chunk_bytes = M * chunk_k * itemsize - shards = [] - for i in range(num_pe): - shards.append(ShardSpec( - pe_index=i, - offset_bytes=i * chunk_bytes, - nbytes=chunk_bytes, - )) - return shards + return [ + _LocalPeShard(local_pe=i, offset_bytes=i * chunk_bytes, nbytes=chunk_bytes) + for i in range(num_pe) + ] def row_wise( *, shape: tuple[int, int], itemsize: int, num_pe: int, -) -> list[ShardSpec]: +) -> list[_LocalPeShard]: """Split M axis into num_pe equal parts. Each PE gets (M/P, K).""" M, K = shape chunk_m = M // num_pe chunk_bytes = chunk_m * K * itemsize - shards = [] - for i in range(num_pe): - shards.append(ShardSpec( - pe_index=i, - offset_bytes=i * chunk_bytes, - nbytes=chunk_bytes, - )) - return shards + return [ + _LocalPeShard(local_pe=i, offset_bytes=i * chunk_bytes, nbytes=chunk_bytes) + for i in range(num_pe) + ] def replicate( *, shape: tuple[int, int], itemsize: int, num_pe: int, -) -> list[ShardSpec]: +) -> list[_LocalPeShard]: """Full copy per PE. Each PE gets (M, K).""" M, K = shape full_bytes = M * K * itemsize return [ - ShardSpec(pe_index=i, offset_bytes=0, nbytes=full_bytes) + _LocalPeShard(local_pe=i, offset_bytes=0, nbytes=full_bytes) for i in range(num_pe) ] @@ -155,20 +175,20 @@ def replicate( def tiled_column_major( *, shape: tuple[int, int], itemsize: int, num_pe: int, tile_m: int, tile_k: int, -) -> list[ShardSpec]: +) -> list[_LocalPeShard]: """2D tiling, column-major order (K axis first), round-robin across PEs.""" M, K = shape tiles_m = ceil(M / tile_m) tiles_k = ceil(K / tile_k) tile_bytes = tile_m * tile_k * itemsize row_bytes = K * itemsize - shards = [] + shards: list[_LocalPeShard] = [] idx = 0 for mi in range(tiles_m): for ki in range(tiles_k): offset = (mi * tile_m * row_bytes) + (ki * tile_k * itemsize) - shards.append(ShardSpec( - pe_index=idx % num_pe, + shards.append(_LocalPeShard( + local_pe=idx % num_pe, offset_bytes=offset, nbytes=tile_bytes, )) @@ -179,20 +199,20 @@ def tiled_column_major( def tiled_row_major( *, shape: tuple[int, int], itemsize: int, num_pe: int, tile_m: int, tile_k: int, -) -> list[ShardSpec]: +) -> list[_LocalPeShard]: """2D tiling, row-major order (M axis first), round-robin across PEs.""" M, K = shape tiles_m = ceil(M / tile_m) tiles_k = ceil(K / tile_k) tile_bytes = tile_m * tile_k * itemsize row_bytes = K * itemsize - shards = [] + shards: list[_LocalPeShard] = [] idx = 0 for ki in range(tiles_k): for mi in range(tiles_m): offset = (mi * tile_m * row_bytes) + (ki * tile_k * itemsize) - shards.append(ShardSpec( - pe_index=idx % num_pe, + shards.append(_LocalPeShard( + local_pe=idx % num_pe, offset_bytes=offset, nbytes=tile_bytes, )) diff --git a/src/kernbench/runtime_api/context.py b/src/kernbench/runtime_api/context.py index 9e391c8..786a114 100644 --- a/src/kernbench/runtime_api/context.py +++ b/src/kernbench/runtime_api/context.py @@ -89,7 +89,7 @@ class RuntimeContext: _handles: list[RequestHandle] = field(default_factory=list, init=False) _completed: set[RequestHandle] = field(default_factory=set, init=False) - _allocators: dict[int, Any] = field(default_factory=dict, init=False) + _allocators: dict[tuple[int, int, int], Any] = field(default_factory=dict, init=False) _va_allocator: Any = field(default=None, init=False) _tensor_counter: int = field(default=0, init=False) _traces: list[dict] = field(default_factory=list, init=False) @@ -270,12 +270,7 @@ class RuntimeContext: # Return PA space if self._allocators: for shard in handle.shards: - flat_idx = ( - shard.sip * self._num_cubes * self._pes_per_cube - + shard.cube * self._pes_per_cube - + shard.pe - ) - alloc = self._allocators.get(flat_idx) + alloc = self._allocators.get((shard.sip, shard.cube, shard.pe)) if alloc is not None: from kernbench.policy.address.phyaddr import PhysAddr alloc.free_hbm(PhysAddr.decode(shard.pa), shard.nbytes) @@ -339,17 +334,15 @@ class RuntimeContext: tcm_scheduler_reserved_bytes=4 * (1 << 20), sram_bytes_per_cube=32 * (1 << 20), ) - # Create allocators scoped to target SIP(s) only - # Flat index: sip_id * cubes_per_sip * pes_per_cube + cube_id * pes_per_cube + pe_id + # Create allocators scoped to target SIP(s) only. + # ADR-0026 D5: dict key is the structural (sip, cube, pe) tuple. self._pes_per_cube = pes_per_cube self._num_cubes = cubes_per_sip self._num_sips = sip_count - cubes_x_pes = cubes_per_sip * pes_per_cube for sip_id in sip_range: for cube_id in range(cubes_per_sip): for pe_id in range(pes_per_cube): - flat_idx = sip_id * cubes_x_pes + cube_id * pes_per_cube + pe_id - self._allocators[flat_idx] = PEMemAllocator( + self._allocators[(sip_id, cube_id, pe_id)] = PEMemAllocator( rack_id=0, sip_id=sip_id, cube_id=cube_id, pe_id=pe_id, cfg=cfg, ) @@ -436,44 +429,23 @@ class RuntimeContext: # DPPolicy overrides take precedence over topology dimensions eff_num_pe = dp.num_pes if dp.num_pes is not None else self._pes_per_cube eff_num_cubes = dp.num_cubes if dp.num_cubes is not None else self._num_cubes - # ADR-0024 D11: if torch.ahbm.set_device(r) is active AND DPPolicy - # leaves the SIP dimension at its default (replicate + no num_sips - # override), scope the tensor to SIP r only. - # NOTE: this path uses post-hoc pe_index shifting as a temporary - # measure; ADR-0026 replaces it with structural (sip, cube, pe) - # coords in ShardSpec. + # ADR-0026 D4: resolve structural coords directly at resolve time. + # ``torch.ahbm.set_device(rank)`` (ADR-0024 D10) selects the target + # SIP; if unset, fall back to SIP 0 for single-driver compatibility. current_sip = ( self.ahbm.current_device() if hasattr(self, "ahbm") else None ) - scope_to_current_sip = ( - current_sip is not None - and dp.sip == "replicate" - and dp.num_sips is None - ) - if scope_to_current_sip: - eff_num_sips = 1 - else: - eff_num_sips = ( - dp.num_sips if dp.num_sips is not None else self._num_sips - ) + if current_sip is None: + current_sip = 0 placement = resolve_dp_policy( dp, shape=shape_2d, itemsize=itemsize, num_pe=eff_num_pe, num_cubes=eff_num_cubes, - num_sips=eff_num_sips, + target_sip=int(current_sip), ) - if scope_to_current_sip: - from kernbench.policy.placement.dp import ShardSpec as _SS - sip_stride = self._num_cubes * self._pes_per_cube - offset = int(current_sip) * sip_stride - placement = [ - _SS(pe_index=s.pe_index + offset, - offset_bytes=s.offset_bytes, nbytes=s.nbytes) - for s in placement - ] # Infer target_pe from placement using local (within-cube) PE IDs. # This ensures M_CPU only fans out to PEs that own shards, not all PEs. - local_pe_ids = sorted({s.pe_index % eff_num_pe for s in placement}) + local_pe_ids = sorted({s.pe for s in placement}) if len(local_pe_ids) == 1: target_pe: int | tuple[int, ...] | str = local_pe_ids[0] elif len(local_pe_ids) == eff_num_pe and eff_num_pe == self._pes_per_cube: @@ -669,11 +641,8 @@ class RuntimeContext: dp = t._dp_metadata.dp_policy if t._dp_metadata else None if dp is None: return t.shape - if dp.sip != "replicate": - if dp.sip == "column_wise": - K = K // self._num_sips - elif dp.sip == "row_wise": - M = M // self._num_sips + # ADR-0026: DPPolicy no longer crosses SIP boundaries; cube + PE + # are the only axes that shrink the local shape. if dp.cube != "replicate": if dp.cube == "column_wise": K = K // self._num_cubes diff --git a/src/kernbench/runtime_api/tensor.py b/src/kernbench/runtime_api/tensor.py index 8226f3c..f7fe9e4 100644 --- a/src/kernbench/runtime_api/tensor.py +++ b/src/kernbench/runtime_api/tensor.py @@ -72,7 +72,7 @@ def deploy_tensor( shape: tuple[int, ...], dtype: str, placement: list[ShardSpec], - allocators: dict[int, PEMemAllocator], + allocators: dict[tuple[int, int, int], PEMemAllocator], mem_kind: Literal["hbm", "tcm"] = "hbm", va_allocator=None, ) -> TensorHandle: @@ -86,15 +86,15 @@ def deploy_tensor( shards: list[TensorShard] = [] for spec in placement: - alloc = allocators[spec.pe_index] + alloc = allocators[(spec.sip, spec.cube, spec.pe)] if mem_kind == "hbm": pa = alloc.alloc_hbm(spec.nbytes) else: pa = alloc.alloc_tcm(spec.nbytes) shards.append(TensorShard( - sip=alloc._sip_id, - cube=alloc._cube_id, - pe=alloc._pe_id, + sip=spec.sip, + cube=spec.cube, + pe=spec.pe, pa=pa.encode(), nbytes=spec.nbytes, offset_bytes=spec.offset_bytes, @@ -394,7 +394,8 @@ class Tensor: ) -> Tensor: """Set DP placement metadata (like torch.Tensor.to()).""" if placement is None: - placement = [ShardSpec(pe_index=0, offset_bytes=0, nbytes=self.nbytes)] + placement = [ShardSpec(sip=0, cube=0, pe=0, + offset_bytes=0, nbytes=self.nbytes)] self._dp_metadata = DPMetadata( placement=placement, dp_policy=dp_policy, sip=sip, cube=cube, target_pe=target_pe, diff --git a/tests/test_adr0026_dppolicy_intra_device.py b/tests/test_adr0026_dppolicy_intra_device.py new file mode 100644 index 0000000..1fcf096 --- /dev/null +++ b/tests/test_adr0026_dppolicy_intra_device.py @@ -0,0 +1,239 @@ +"""ADR-0026 Phase 1 tests: DPPolicy intra-device only + ShardSpec structural. + +These tests encode the contract from ADR-0026: + +- DPPolicy no longer accepts ``sip`` or ``num_sips`` kwargs (TypeError). +- ShardSpec carries structural ``(sip, cube, pe)`` coordinates; the old flat + ``pe_index`` field/property is fully removed (AttributeError). +- ``resolve_dp_policy(..., target_sip=N)`` stamps every returned ShardSpec + with ``sip=N``; cube and pe fields are local. +- ``RuntimeContext._allocators`` is keyed by ``(sip, cube, pe)`` tuples. + +Phase 1: production code is unchanged → these tests SHOULD FAIL until the +Phase 2 diff lands. Phase 2 makes all of them pass. +""" +from __future__ import annotations + +import pytest + +from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator +from kernbench.policy.placement.dp import DPPolicy, ShardSpec, resolve_dp_policy +from kernbench.runtime_api.tensor import deploy_tensor + + +# ── D1: DPPolicy no longer accepts sip / num_sips ───────────────────── + + +def test_dppolicy_rejects_sip_kwarg(): + """DPPolicy(sip=...) must raise TypeError after field removal.""" + with pytest.raises(TypeError): + DPPolicy(sip="column_wise", cube="replicate", pe="replicate") + + +def test_dppolicy_rejects_num_sips_kwarg(): + """DPPolicy(num_sips=...) must raise TypeError after field removal.""" + with pytest.raises(TypeError): + DPPolicy(cube="replicate", pe="replicate", num_sips=2) + + +def test_dppolicy_accepts_only_intra_device_fields(): + """Intra-device fields still work: cube, pe, num_cubes, num_pes.""" + dp = DPPolicy(cube="column_wise", pe="column_wise", + num_cubes=2, num_pes=4) + assert dp.cube == "column_wise" + assert dp.pe == "column_wise" + assert dp.num_cubes == 2 + assert dp.num_pes == 4 + # No sip / num_sips attributes — even reading them must fail. + assert not hasattr(dp, "sip"), "DPPolicy.sip must be removed" + assert not hasattr(dp, "num_sips"), "DPPolicy.num_sips must be removed" + + +# ── D2: ShardSpec structural coords, no pe_index ────────────────────── + + +def test_shardspec_has_structural_coords(): + """ShardSpec constructs from (sip, cube, pe, offset_bytes, nbytes).""" + s = ShardSpec(sip=1, cube=2, pe=3, offset_bytes=128, nbytes=64) + assert s.sip == 1 + assert s.cube == 2 + assert s.pe == 3 + assert s.offset_bytes == 128 + assert s.nbytes == 64 + + +def test_shardspec_has_no_pe_index_attr(): + """Flat pe_index must be fully removed — no field, no property.""" + s = ShardSpec(sip=0, cube=0, pe=0, offset_bytes=0, nbytes=8) + with pytest.raises(AttributeError): + _ = s.pe_index # noqa: F841 + + +def test_shardspec_rejects_pe_index_kwarg(): + """ShardSpec(pe_index=...) must raise TypeError.""" + with pytest.raises(TypeError): + ShardSpec(pe_index=0, offset_bytes=0, nbytes=8) # type: ignore[call-arg] + + +# ── D3: resolve_dp_policy(target_sip=...) structural semantics ──────── + + +def test_resolve_dp_policy_target_sip_stamps_shards(): + """All returned shards must carry sip == target_sip; cube/pe local.""" + dp = DPPolicy(cube="column_wise", pe="column_wise") + shards = resolve_dp_policy( + dp, shape=(4, 32), itemsize=2, + num_pe=4, num_cubes=2, target_sip=1, + ) + assert len(shards) == 2 * 4 + assert all(s.sip == 1 for s in shards) + assert all(0 <= s.cube < 2 for s in shards) + assert all(0 <= s.pe < 4 for s in shards) + + +def test_resolve_dp_policy_target_sip_differ_only_in_sip(): + """Same policy + dims on two SIPs → shards identical except .sip.""" + dp = DPPolicy(cube="replicate", pe="column_wise") + shards_0 = resolve_dp_policy( + dp, shape=(4, 32), itemsize=2, + num_pe=4, num_cubes=1, target_sip=0, + ) + shards_1 = resolve_dp_policy( + dp, shape=(4, 32), itemsize=2, + num_pe=4, num_cubes=1, target_sip=1, + ) + assert len(shards_0) == len(shards_1) + for a, b in zip(shards_0, shards_1): + assert a.sip == 0 and b.sip == 1 + assert a.cube == b.cube + assert a.pe == b.pe + assert a.offset_bytes == b.offset_bytes + assert a.nbytes == b.nbytes + + +def test_resolve_dp_policy_no_num_sips_param(): + """resolve_dp_policy must not accept num_sips anymore. + + Post-Phase-2 signature drops ``num_sips`` (DPPolicy no longer crosses + SIP boundaries) and adds required ``target_sip``. Calling with + ``num_sips=...`` must raise TypeError (unexpected keyword argument). + """ + dp = DPPolicy(cube="replicate", pe="replicate") + with pytest.raises(TypeError, match="num_sips"): + resolve_dp_policy( + dp, shape=(4, 8), itemsize=2, + num_pe=1, num_cubes=1, num_sips=2, # type: ignore[call-arg] + ) + + +# ── D5: Allocator dict keyed by (sip, cube, pe) tuples ──────────────── + + +_MB = 1 << 20 +_GB = 1 << 30 + +_CFG = AddressConfig( + sip_count=2, + cubes_per_sip=2, + pes_per_cube=4, + hbm_bytes_per_cube=_GB, + hbm_slices_per_cube=4, + tcm_bytes_per_pe=_MB, + tcm_scheduler_reserved_bytes=0, + sram_bytes_per_cube=_MB, +) + + +def _make_tuple_allocators( + num_sips: int = 1, num_cubes: int = 1, num_pe: int = 4, +) -> dict[tuple[int, int, int], PEMemAllocator]: + return { + (s, c, p): PEMemAllocator( + rack_id=0, sip_id=s, cube_id=c, pe_id=p, cfg=_CFG, + ) + for s in range(num_sips) + for c in range(num_cubes) + for p in range(num_pe) + } + + +def test_deploy_tensor_uses_tuple_lookup(): + """deploy_tensor(allocators={(sip,cube,pe): alloc, ...}) succeeds.""" + dp = DPPolicy(cube="replicate", pe="column_wise") + placement = resolve_dp_policy( + dp, shape=(4, 16), itemsize=2, + num_pe=4, num_cubes=1, target_sip=0, + ) + allocators = _make_tuple_allocators(num_sips=1, num_cubes=1, num_pe=4) + handle = deploy_tensor( + name="t", shape=(4, 16), dtype="f16", + placement=placement, allocators=allocators, + ) + assert len(handle.shards) == 4 + # Each shard's TensorShard carries structural coords; those coords + # must match the shard's ShardSpec (sip, cube, pe). + for spec, shard in zip(placement, handle.shards): + assert shard.sip == spec.sip + assert shard.cube == spec.cube + assert shard.pe == spec.pe + + +def test_runtime_context_allocator_keys_are_tuples(topology): + """After ctx tensor op, ctx._allocators keys are (sip, cube, pe) tuples. + + Ensures D5 migration landed (allocator population + lookup). + """ + from kernbench.runtime_api.context import RuntimeContext + from kernbench.runtime_api.types import DeviceSelector + from kernbench.sim_engine.engine import GraphEngine + + engine = GraphEngine(topology.topology_obj, enable_data=True) + ctx = RuntimeContext( + engine=engine, + target_device=DeviceSelector("sip:0"), + correlation_id="test_adr0026_tuple_keys", + spec=topology.topology_obj.spec, + ) + dp = DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=1) + _ = ctx.zeros((1, 16), dtype="f16", dp=dp) + + assert ctx._allocators, "allocators dict should be populated" + keys = list(ctx._allocators.keys()) + assert all(isinstance(k, tuple) and len(k) == 3 for k in keys), ( + f"_allocators keys must be (sip, cube, pe) tuples; got {keys[:5]}" + ) + + +# ── D4 (via regression): no SIP-crossing tensor without set_device ──── + + +def test_create_tensor_on_target_sip_via_set_device(topology): + """torch.ahbm.set_device(1) + DPPolicy(cube=replicate, pe=replicate) + → all shards land on SIP 1 structurally (no post-hoc shifting needed).""" + from kernbench.runtime_api.context import RuntimeContext + from kernbench.runtime_api.types import DeviceSelector + from kernbench.sim_engine.engine import GraphEngine + + # Skip the test if topology has only 1 SIP (nothing to verify). + n_sips = int( + topology.topology_obj.spec.get("system", {}) + .get("sips", {}).get("count", 1) + ) + if n_sips < 2: + pytest.skip("topology has <2 SIPs; set_device(1) not meaningful") + + engine = GraphEngine(topology.topology_obj, enable_data=True) + ctx = RuntimeContext( + engine=engine, + target_device=DeviceSelector("sip:1"), + correlation_id="test_adr0026_set_device", + spec=topology.topology_obj.spec, + ) + ctx.ahbm.set_device(1) + dp = DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=1) + t = ctx.zeros((1, 16), dtype="f16", dp=dp) + + assert t._handle is not None + assert all(s.sip == 1 for s in t._handle.shards), ( + f"expected all shards on SIP 1; got {[s.sip for s in t._handle.shards]}" + ) diff --git a/tests/test_ccl_deadlock_detection.py b/tests/test_ccl_deadlock_detection.py index 9dbb133..f5a5083 100644 --- a/tests/test_ccl_deadlock_detection.py +++ b/tests/test_ccl_deadlock_detection.py @@ -108,8 +108,8 @@ def test_deadlock_detection_recv_without_send(): (1, 8 * 8), dtype="f16", dp=DPPolicy( - sip="replicate", cube="replicate", pe="column_wise", - num_sips=1, num_cubes=1, + cube="replicate", pe="column_wise", + num_cubes=1, ), name="dl_in", ) diff --git a/tests/test_ccl_hello_world_guide.py b/tests/test_ccl_hello_world_guide.py index 68bc017..9e76bf7 100644 --- a/tests/test_ccl_hello_world_guide.py +++ b/tests/test_ccl_hello_world_guide.py @@ -51,8 +51,8 @@ def test_hello_send_via_simpy_runner(): a = torch.zeros( (1, world_size * n_elem), dtype="f16", dp=DPPolicy( - sip="replicate", cube="replicate", pe="column_wise", - num_sips=1, num_cubes=1, + cube="replicate", pe="column_wise", + num_cubes=1, ), name="hello_in", ) diff --git a/tests/test_recv_copy_to_dst.py b/tests/test_recv_copy_to_dst.py index c4388dc..ee035b7 100644 --- a/tests/test_recv_copy_to_dst.py +++ b/tests/test_recv_copy_to_dst.py @@ -48,8 +48,8 @@ def test_recv_copy_to_dst_via_simpy_runner(): (1, 8 * 8), dtype="f16", dp=DPPolicy( - sip="replicate", cube="replicate", pe="column_wise", - num_sips=1, num_cubes=1, + cube="replicate", pe="column_wise", + num_cubes=1, ), name="copy_in", ) diff --git a/tests/test_runtime_api_tensor.py b/tests/test_runtime_api_tensor.py index 54a4698..5876bac 100644 --- a/tests/test_runtime_api_tensor.py +++ b/tests/test_runtime_api_tensor.py @@ -48,8 +48,8 @@ def test_from_numpy_creates_host_tensor(): assert h._handle is None # Submit a no-op so run_bench has at least one handle. torch.zeros((1, 8), dtype="f16", - dp=DPPolicy(sip="replicate", cube="replicate", pe="replicate", - num_sips=1, num_cubes=1, num_pes=1), + dp=DPPolicy(cube="replicate", pe="replicate", + num_cubes=1, num_pes=1), name="dummy") _run_with(body) @@ -63,8 +63,8 @@ def test_copy_and_numpy_single_pe(): a single-PE (no real sharding) tensor.""" def body(torch): - dp = DPPolicy(sip="replicate", cube="replicate", pe="replicate", - num_sips=1, num_cubes=1, num_pes=1) + dp = DPPolicy(cube="replicate", pe="replicate", + num_cubes=1, num_pes=1) t = torch.zeros((1, 16), dtype="f16", dp=dp, name="t") src = np.arange(16, dtype=np.float16).reshape(1, 16) t.copy_(torch.from_numpy(src)) @@ -83,8 +83,8 @@ def test_copy_and_numpy_multi_pe_column_wise(): def body(torch): n_pe = 8 - dp = DPPolicy(sip="replicate", cube="replicate", pe="column_wise", - num_sips=1, num_cubes=1, num_pes=n_pe) + dp = DPPolicy(cube="replicate", pe="column_wise", + num_cubes=1, num_pes=n_pe) t = torch.zeros((1, n_pe * 4), dtype="f16", dp=dp, name="t") src = np.arange(n_pe * 4, dtype=np.float16).reshape(1, n_pe * 4) t.copy_(torch.from_numpy(src)) @@ -107,8 +107,8 @@ def test_copy_and_numpy_multi_cube(): n_pe_per_cube = 8 n_cubes = 2 total = n_cubes * n_pe_per_cube # 16 - dp = DPPolicy(sip="replicate", cube="column_wise", pe="column_wise", - num_sips=1, num_cubes=n_cubes) + dp = DPPolicy(cube="column_wise", pe="column_wise", + num_cubes=n_cubes) t = torch.zeros((1, total * 4), dtype="f16", dp=dp, name="t") src = np.arange(total * 4, dtype=np.float16).reshape(1, total * 4) t.copy_(torch.from_numpy(src)) @@ -126,8 +126,8 @@ def test_copy_shape_mismatch_raises(): """copy_ with mismatched shapes raises ValueError.""" def body(torch): - dp = DPPolicy(sip="replicate", cube="replicate", pe="replicate", - num_sips=1, num_cubes=1, num_pes=1) + dp = DPPolicy(cube="replicate", pe="replicate", + num_cubes=1, num_pes=1) t = torch.zeros((1, 8), dtype="f16", dp=dp, name="t") src = np.zeros((1, 16), dtype=np.float16) with pytest.raises(ValueError, match="copy_ shape mismatch"): @@ -143,8 +143,8 @@ def test_setitem_getitem_single_pe(): """Scalar and slice assignment on a single-PE tensor round-trips.""" def body(torch): - dp = DPPolicy(sip="replicate", cube="replicate", pe="replicate", - num_sips=1, num_cubes=1, num_pes=1) + dp = DPPolicy(cube="replicate", pe="replicate", + num_cubes=1, num_pes=1) t = torch.zeros((1, 8), dtype="f16", dp=dp, name="t") # Scalar broadcast @@ -169,8 +169,8 @@ def test_setitem_getitem_multi_pe_shard_aligned(): def body(torch): n_pe = 8 n_elem = 4 # per shard - dp = DPPolicy(sip="replicate", cube="replicate", pe="column_wise", - num_sips=1, num_cubes=1, num_pes=n_pe) + dp = DPPolicy(cube="replicate", pe="column_wise", + num_cubes=1, num_pes=n_pe) t = torch.zeros((1, n_pe * n_elem), dtype="f16", dp=dp, name="t") # Write each shard with its rank value @@ -197,8 +197,8 @@ def test_setitem_cross_shard_raises(): def body(torch): n_pe = 4 n_elem = 4 - dp = DPPolicy(sip="replicate", cube="replicate", pe="column_wise", - num_sips=1, num_cubes=1, num_pes=n_pe) + dp = DPPolicy(cube="replicate", pe="column_wise", + num_cubes=1, num_pes=n_pe) t = torch.zeros((1, n_pe * n_elem), dtype="f16", dp=dp, name="t") with pytest.raises(NotImplementedError, match="spans multiple shards"): t[0, 2:6] = 1.0 # crosses shard 0 (0:4) and shard 1 (4:8) diff --git a/tests/test_sip_parallel.py b/tests/test_sip_parallel.py index 33d15fa..5477cc1 100644 --- a/tests/test_sip_parallel.py +++ b/tests/test_sip_parallel.py @@ -1,157 +1,120 @@ -"""Tests for SIP-level tensor parallelism. +"""Tests for SIP-level tensor parallelism — ADR-0026 structural model. -Validates: - SP1. DPPolicy accepts sip field (default "replicate", backward compat) - SP2. sip="column_wise": tensor K-axis split across SIPs, each SIP gets K//num_sips - SP3. sip="row_wise": tensor M-axis split across SIPs - SP4. 3-level resolve: sip × cube × pe produces correct flat indices and offsets - SP5. sip="replicate": all SIPs get full copy (existing behavior) - SP6. PE_CPU sets num_programs from shard count per cube - SP7. End-to-end: TP kernel with sip="column_wise" completes on multi-SIP topology +DPPolicy no longer carries a ``sip`` axis (ADR-0026 D1). SIP placement is +now expressed structurally: each call to ``resolve_dp_policy(target_sip=N)`` +emits shards pinned to SIP N. Multi-SIP parallelism is composed by calling +the resolver once per SIP (typically driven by the ADR-0024 launcher, one +worker greenlet per rank, each worker using ``torch.ahbm.set_device(rank)``). + +Covered here: + SP1. ``target_sip`` stamps every shard. + SP2. Two-SIP placement: union of two resolver calls covers the whole + tensor K-axis when the combined bench treats them as column-split. + SP3. Same for row-wise. + SP4. Cube + PE sharding within a SIP remains correct across SIPs. + SP5. PE_CPU num_programs contract (unchanged by ADR-0026). """ -import pytest -from pathlib import Path +from __future__ import annotations -from kernbench.policy.placement.dp import DPPolicy, ShardSpec, resolve_dp_policy +from kernbench.policy.placement.dp import DPPolicy, resolve_dp_policy -# ── SP1. DPPolicy sip field ────────────────────────────────────────── +# ── SP1. target_sip stamps shards ──────────────────────────────────── -def test_dp_policy_sip_default_replicate(): - """DPPolicy without sip= defaults to 'replicate'.""" +def test_target_sip_stamps_all_shards(): dp = DPPolicy(cube="replicate", pe="column_wise") - assert dp.sip == "replicate" - - -def test_dp_policy_sip_column_wise(): - """DPPolicy accepts sip='column_wise'.""" - dp = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise") - assert dp.sip == "column_wise" - - -# ── SP2. sip="column_wise" ────────────────────────────────────────────── - - -def test_sip_column_wise_splits_across_sips(): - """sip='column_wise' with 2 SIPs: each SIP gets K//2 columns.""" - dp = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise") shards = resolve_dp_policy( dp, shape=(128, 256), itemsize=2, - num_pe=8, num_cubes=1, num_sips=2, + num_pe=8, num_cubes=1, target_sip=3, ) - # 2 SIPs × 1 cube × 8 PEs = 16 shards - assert len(shards) == 16 - - # SIP0 shards: first half of K (0 to K//2) - # SIP1 shards: second half of K (K//2 to K) - total_bytes = 128 * 256 * 2 # 64KB - sip0_shards = [s for s in shards if s.pe_index < 8] - sip1_shards = [s for s in shards if s.pe_index >= 8] - - # SIP0 offsets start at 0 - assert sip0_shards[0].offset_bytes == 0 - # SIP1 offsets start at half - assert sip1_shards[0].offset_bytes == total_bytes // 2 - - # Total coverage - assert sum(s.nbytes for s in sip0_shards) == total_bytes // 2 - assert sum(s.nbytes for s in sip1_shards) == total_bytes // 2 + assert all(s.sip == 3 for s in shards) + assert all(0 <= s.pe < 8 for s in shards) + assert all(s.cube == 0 for s in shards) -# ── SP3. sip="row_wise" ────────────────────────────────────────────── +# ── SP2. column-wise placement composed across two SIPs ───────────── -def test_sip_row_wise_splits_across_sips(): - """sip='row_wise' with 2 SIPs: each SIP gets M//2 rows.""" - dp = DPPolicy(sip="row_wise", cube="replicate", pe="column_wise") - shards = resolve_dp_policy( +def test_compose_two_sips_column_wise_covers_tensor(): + """Bench splits K-axis across 2 SIPs by calling resolve twice and + giving each SIP half of the tensor (half-shape + offset). Shards + from both SIPs together cover the whole K axis.""" + full_shape = (128, 256) + itemsize = 2 + # Per-SIP half-shape (K split across SIPs). + half_shape = (128, 128) + dp = DPPolicy(cube="replicate", pe="column_wise") + + shards_sip0 = resolve_dp_policy( + dp, shape=half_shape, itemsize=itemsize, + num_pe=8, num_cubes=1, target_sip=0, + ) + shards_sip1 = resolve_dp_policy( + dp, shape=half_shape, itemsize=itemsize, + num_pe=8, num_cubes=1, target_sip=1, + ) + + total_bytes = full_shape[0] * full_shape[1] * itemsize + sip0_bytes = sum(s.nbytes for s in shards_sip0) + sip1_bytes = sum(s.nbytes for s in shards_sip1) + assert sip0_bytes + sip1_bytes == total_bytes + assert all(s.sip == 0 for s in shards_sip0) + assert all(s.sip == 1 for s in shards_sip1) + + +# ── SP3. row-wise placement composed across two SIPs ──────────────── + + +def test_compose_two_sips_row_wise_covers_tensor(): + full_shape = (128, 256) + itemsize = 2 + half_shape = (64, 256) # per-SIP half of M + dp = DPPolicy(cube="replicate", pe="column_wise") + + shards_sip0 = resolve_dp_policy( + dp, shape=half_shape, itemsize=itemsize, + num_pe=8, num_cubes=1, target_sip=0, + ) + shards_sip1 = resolve_dp_policy( + dp, shape=half_shape, itemsize=itemsize, + num_pe=8, num_cubes=1, target_sip=1, + ) + + total_bytes = full_shape[0] * full_shape[1] * itemsize + assert sum(s.nbytes for s in shards_sip0) + sum(s.nbytes for s in shards_sip1) == total_bytes + + +# ── SP4. cube × PE sharding is independent per SIP ────────────────── + + +def test_cube_pe_sharding_independent_per_sip(): + """Intra-SIP cube + PE layout matches across SIPs; only sip field differs.""" + dp = DPPolicy(cube="column_wise", pe="column_wise") + s0 = resolve_dp_policy( dp, shape=(128, 256), itemsize=2, - num_pe=8, num_cubes=1, num_sips=2, + num_pe=4, num_cubes=2, target_sip=0, ) - assert len(shards) == 16 - - sip0_shards = [s for s in shards if s.pe_index < 8] - sip1_shards = [s for s in shards if s.pe_index >= 8] - - # SIP0: rows 0..63, SIP1: rows 64..127 - total_bytes = 128 * 256 * 2 - assert sip0_shards[0].offset_bytes == 0 - assert sip1_shards[0].offset_bytes == total_bytes // 2 - - -# ── SP4. 3-level resolve ───────────────────────────────────────────── - - -def test_3level_resolve_flat_index(): - """3-level: sip × cube × pe produces correct flat indices.""" - dp = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise") - shards = resolve_dp_policy( + s1 = resolve_dp_policy( dp, shape=(128, 256), itemsize=2, - num_pe=8, num_cubes=2, num_sips=2, + num_pe=4, num_cubes=2, target_sip=1, ) - # 2 SIPs × 2 cubes × 8 PEs = 32 shards - assert len(shards) == 32 - - # Flat index: sip_id * cubes_per_sip * num_pe + cube_id * num_pe + pe_id - indices = [s.pe_index for s in shards] - # SIP0: 0..15, SIP1: 16..31 - assert min(indices) == 0 - assert max(indices) == 31 - assert len(set(indices)) == 32 # all unique + assert len(s0) == len(s1) == 2 * 4 + for a, b in zip(s0, s1): + assert a.sip == 0 and b.sip == 1 + assert (a.cube, a.pe, a.offset_bytes, a.nbytes) == ( + b.cube, b.pe, b.offset_bytes, b.nbytes + ) -def test_3level_offsets_cover_full_tensor(): - """3-level sharding covers the entire tensor with no gaps.""" - dp = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise") - shards = resolve_dp_policy( - dp, shape=(128, 256), itemsize=2, - num_pe=4, num_cubes=1, num_sips=2, - ) - # 2 SIPs × 1 cube × 4 PEs = 8 shards - # sip="column_wise": K=128 per SIP, pe="column_wise": 32 cols per PE - total = 128 * 256 * 2 - # For non-replicate, total shard bytes == tensor bytes - # (replicate within cube means cube shards overlap, but sip shards don't) - sip0_bytes = sum(s.nbytes for s in shards if s.pe_index < 4) - sip1_bytes = sum(s.nbytes for s in shards if s.pe_index >= 4) - assert sip0_bytes + sip1_bytes == total - - -# ── SP5. sip="replicate" backward compat ───────────────────────────── - - -def test_sip_replicate_backward_compat(): - """sip='replicate' produces same result as before (2-level).""" - dp_old = DPPolicy(cube="replicate", pe="column_wise") - dp_new = DPPolicy(sip="replicate", cube="replicate", pe="column_wise") - - shards_old = resolve_dp_policy( - dp_old, shape=(128, 256), itemsize=2, - num_pe=8, num_cubes=2, num_sips=2, - ) - shards_new = resolve_dp_policy( - dp_new, shape=(128, 256), itemsize=2, - num_pe=8, num_cubes=2, num_sips=2, - ) - assert len(shards_old) == len(shards_new) - for a, b in zip(shards_old, shards_new): - assert a.pe_index == b.pe_index - assert a.offset_bytes == b.offset_bytes - assert a.nbytes == b.nbytes - - -# ── SP6. PE_CPU num_programs ────────────────────────────────────────── +# ── SP5. PE_CPU num_programs (contract unchanged) ─────────────────── def test_pe_cpu_sets_num_programs(): - """PE_CPU should create TLContext with num_programs = PEs per cube.""" - # This test validates the interface contract. - # After implementation, PE_CPU should derive num_programs from the - # number of PE shards in the kernel launch's target cube. + """TLContext reports num_programs from its initializer — used by PE_CPU + when it launches a kernel on behalf of its shards.""" from kernbench.triton_emu.tl_context import TLContext - # With 8 PEs per cube, num_programs should be 8 tl = TLContext(pe_id=3, num_programs=8) assert tl.program_id(0) == 3 assert tl.num_programs(0) == 8 diff --git a/tests/test_tensor.py b/tests/test_tensor.py index a89109f..7a8b568 100644 --- a/tests/test_tensor.py +++ b/tests/test_tensor.py @@ -2,11 +2,13 @@ import pytest from kernbench.policy.address.allocator import AddressConfig, AllocationError, PEMemAllocator from kernbench.policy.placement.dp import ( + DPPolicy, ShardSpec, column_wise, - tiled_column_major, replicate, + resolve_dp_policy, row_wise, + tiled_column_major, tiled_row_major, ) from kernbench.runtime_api.kernel import ( @@ -40,9 +42,9 @@ _CFG = AddressConfig( ) -def _make_allocators(num_pe: int = 8) -> dict[int, PEMemAllocator]: +def _make_allocators(num_pe: int = 8) -> dict[tuple[int, int, int], PEMemAllocator]: return { - i: PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=i, cfg=_CFG) + (0, 0, i): PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=i, cfg=_CFG) for i in range(num_pe) } @@ -133,7 +135,7 @@ def test_column_wise_placement(): assert len(shards) == 8 expected_nbytes = 1024 * 64 * 2 # 128 KB for i, s in enumerate(shards): - assert s.pe_index == i + assert s.local_pe == i assert s.nbytes == expected_nbytes # offsets are contiguous assert shards[0].offset_bytes == 0 @@ -151,7 +153,7 @@ def test_row_wise_placement(): assert len(shards) == 8 expected_nbytes = 128 * 512 * 2 # 128 KB for i, s in enumerate(shards): - assert s.pe_index == i + assert s.local_pe == i assert s.nbytes == expected_nbytes assert shards[0].offset_bytes == 0 assert sum(s.nbytes for s in shards) == 1024 * 512 * 2 @@ -166,7 +168,7 @@ def test_replicate_placement(): assert len(shards) == 8 full_nbytes = 1024 * 512 * 2 # 1 MB for i, s in enumerate(shards): - assert s.pe_index == i + assert s.local_pe == i assert s.nbytes == full_nbytes assert s.offset_bytes == 0 # each is a full copy @@ -188,10 +190,10 @@ def test_tiled_column_major(): # tile (m=0,k=0) → PE0, tile (m=0,k=1) → PE1, ..., (m=0,k=3) → PE3 # tile (m=1,k=0) → PE4, tile (m=1,k=1) → PE5, ..., (m=1,k=3) → PE7 # tile (m=2,k=0) → PE0, ... - assert shards[0].pe_index == 0 - assert shards[1].pe_index == 1 - assert shards[7].pe_index == 7 - assert shards[8].pe_index == 0 # wraps around + assert shards[0].local_pe == 0 + assert shards[1].local_pe == 1 + assert shards[7].local_pe == 7 + assert shards[8].local_pe == 0 # wraps around # total coverage assert sum(s.nbytes for s in shards) == 1024 * 512 * 2 @@ -212,10 +214,10 @@ def test_tiled_row_major(): # tile (m=0,k=0) → PE0, tile (m=1,k=0) → PE1, ..., (m=3,k=0) → PE3 # tile (m=0,k=1) → PE4, tile (m=1,k=1) → PE5, ..., (m=3,k=1) → PE7 # tile (m=0,k=2) → PE0, ... - assert shards[0].pe_index == 0 - assert shards[1].pe_index == 1 - assert shards[7].pe_index == 7 - assert shards[8].pe_index == 0 # wraps around + assert shards[0].local_pe == 0 + assert shards[1].local_pe == 1 + assert shards[7].local_pe == 7 + assert shards[8].local_pe == 0 # wraps around # total coverage assert sum(s.nbytes for s in shards) == 1024 * 512 * 2 @@ -226,7 +228,11 @@ def test_tiled_row_major(): def test_deploy_tensor_hbm(): """Deploy with column_wise placement → TensorHandle with valid PA shards.""" allocs = _make_allocators() - placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8) + placement = resolve_dp_policy( + DPPolicy(cube="replicate", pe="column_wise"), + shape=(1024, 512), itemsize=2, + num_pe=8, num_cubes=1, target_sip=0, + ) th = deploy_tensor( name="W", shape=(1024, 512), @@ -253,7 +259,7 @@ def test_deploy_tensor_hbm(): def test_deploy_tensor_tcm(): """Deploy with TCM → uses pe_tcm_addr allocation.""" allocs = _make_allocators() - placement = [ShardSpec(pe_index=0, offset_bytes=0, nbytes=256)] + placement = [ShardSpec(sip=0, cube=0, pe=0, offset_bytes=0, nbytes=256)] th = deploy_tensor( name="small", shape=(128,), @@ -271,7 +277,7 @@ def test_deploy_tensor_overflow(): """Allocation exceeding PE HBM capacity raises AllocationError.""" allocs = _make_allocators() # 6 GB per PE slice, try to allocate 7 GB - big_shard = ShardSpec(pe_index=0, offset_bytes=0, nbytes=7 * _GB) + big_shard = ShardSpec(sip=0, cube=0, pe=0, offset_bytes=0, nbytes=7 * _GB) with pytest.raises(AllocationError): deploy_tensor( name="toobig", diff --git a/tests/test_tl_recv_async.py b/tests/test_tl_recv_async.py index 37aae56..881d203 100644 --- a/tests/test_tl_recv_async.py +++ b/tests/test_tl_recv_async.py @@ -75,8 +75,8 @@ def test_recv_async_simpy_runner(): (1, 8 * 8), dtype="f16", dp=DPPolicy( - sip="replicate", cube="replicate", pe="column_wise", - num_sips=1, num_cubes=1, + cube="replicate", pe="column_wise", + num_cubes=1, ), name="async_in", ) diff --git a/tests/test_va_integration.py b/tests/test_va_integration.py index 3ecbe6b..2998bde 100644 --- a/tests/test_va_integration.py +++ b/tests/test_va_integration.py @@ -12,7 +12,7 @@ import pytest from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator from kernbench.policy.address.pe_mmu import PeMMU from kernbench.policy.address.va_allocator import VirtualAllocator -from kernbench.policy.placement.dp import column_wise, ShardSpec +from kernbench.policy.placement.dp import DPPolicy, ShardSpec, resolve_dp_policy from kernbench.runtime_api.tensor import ( TensorHandle, TensorShard, @@ -37,9 +37,9 @@ _CFG = AddressConfig( ) -def _make_allocators(num_pe: int = 8) -> dict[int, PEMemAllocator]: +def _make_allocators(num_pe: int = 8) -> dict[tuple[int, int, int], PEMemAllocator]: return { - i: PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=i, cfg=_CFG) + (0, 0, i): PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=i, cfg=_CFG) for i in range(num_pe) } @@ -88,7 +88,11 @@ def test_deploy_tensor_assigns_va_base(): """deploy_tensor with VA allocator assigns va_base to TensorHandle.""" allocs = _make_allocators() va_alloc = _make_va_allocator() - placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8) + placement = resolve_dp_policy( + DPPolicy(cube="replicate", pe="column_wise"), + shape=(1024, 512), itemsize=2, + num_pe=8, num_cubes=1, target_sip=0, + ) th = deploy_tensor( name="W", @@ -107,7 +111,11 @@ def test_deploy_tensor_va_covers_all_shards(): """VA allocation covers the entire tensor; each shard is at va_base + offset.""" allocs = _make_allocators() va_alloc = _make_va_allocator() - placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8) + placement = resolve_dp_policy( + DPPolicy(cube="replicate", pe="column_wise"), + shape=(1024, 512), itemsize=2, + num_pe=8, num_cubes=1, target_sip=0, + ) th = deploy_tensor( name="W", @@ -128,7 +136,11 @@ def test_deploy_tensor_does_not_install_mmu_mappings(): allocs = _make_allocators() va_alloc = _make_va_allocator() mmus = _make_mmus() - placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8) + placement = resolve_dp_policy( + DPPolicy(cube="replicate", pe="column_wise"), + shape=(1024, 512), itemsize=2, + num_pe=8, num_cubes=1, target_sip=0, + ) deploy_tensor( name="W", @@ -153,7 +165,7 @@ def test_tensor_va_property(): allocs = _make_allocators(1) va_alloc = _make_va_allocator() - placement = [ShardSpec(pe_index=0, offset_bytes=0, nbytes=4096)] + placement = [ShardSpec(sip=0, cube=0, pe=0, offset_bytes=0, nbytes=4096)] t = Tensor(shape=(2048,), dtype="f16", name="test") t._handle = deploy_tensor( diff --git a/tests/test_va_offset.py b/tests/test_va_offset.py index 8537874..d7f71d7 100644 --- a/tests/test_va_offset.py +++ b/tests/test_va_offset.py @@ -20,7 +20,7 @@ from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator from kernbench.policy.address.pe_mmu import PeMMU from kernbench.policy.address.phyaddr import PhysAddr from kernbench.policy.address.va_allocator import VirtualAllocator -from kernbench.policy.placement.dp import DPPolicy, column_wise +from kernbench.policy.placement.dp import DPPolicy, resolve_dp_policy from kernbench.runtime_api.tensor import deploy_tensor from kernbench.sim_engine.engine import GraphEngine from kernbench.runtime_api.context import RuntimeContext @@ -70,7 +70,7 @@ def _make_standalone(shape, num_pe=NUM_PE): sram_bytes_per_cube=32 * _MB, ) allocators = { - i: PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=i, cfg=cfg) + (0, 0, i): PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=i, cfg=cfg) for i in range(num_pe) } va_alloc = VirtualAllocator(va_base=0x1_0000_0000, va_size=64 * _GB, page_size=4096) @@ -110,7 +110,11 @@ def test_2d_va_translates_to_local_hbm(): cols_per_pe = K // NUM_PE block_bytes = M * cols_per_pe * ELEM_BYTES - placement = column_wise(shape=(M, K), itemsize=ELEM_BYTES, num_pe=NUM_PE) + placement = resolve_dp_policy( + DPPolicy(cube="replicate", pe="column_wise"), + shape=(M, K), itemsize=ELEM_BYTES, + num_pe=NUM_PE, num_cubes=1, target_sip=0, + ) handle = deploy_tensor( name="src", shape=(M, K), dtype="fp16", placement=placement, allocators=allocators, va_allocator=va_alloc, @@ -178,7 +182,11 @@ def test_1d_va_translates_to_local_hbm(): elems_per_pe = N_1D // NUM_PE block_bytes = elems_per_pe * ELEM_BYTES - placement = column_wise(shape=(1, N_1D), itemsize=ELEM_BYTES, num_pe=NUM_PE) + placement = resolve_dp_policy( + DPPolicy(cube="replicate", pe="column_wise"), + shape=(1, N_1D), itemsize=ELEM_BYTES, + num_pe=NUM_PE, num_cubes=1, target_sip=0, + ) handle = deploy_tensor( name="src_1d", shape=(N_1D,), dtype="fp16", placement=placement, allocators=allocators, va_allocator=va_alloc, @@ -207,7 +215,9 @@ def test_1d_e2e_completes(): correlation_id="vo6", spec=graph.spec, ) - dp = DPPolicy(sip="column_wise", cube="column_wise", pe="column_wise") + # ADR-0026: DPPolicy is intra-device only; SIP scoping comes from the + # RuntimeContext's target_device. This 1D e2e runs on a single SIP. + dp = DPPolicy(cube="column_wise", pe="column_wise") src = ctx.zeros((N_1D,), dtype=DTYPE, dp=dp, name="src_1d") dst = ctx.empty((N_1D,), dtype=DTYPE, dp=dp, name="dst_1d")