ADR-0026: DPPolicy intra-device only + ShardSpec structural coords

DPPolicy no longer carries a cross-SIP axis. SIP-level placement is
solely controlled by torch.ahbm.set_device(rank) (ADR-0024); DPPolicy
itself describes only the cube × PE layout within one SIP. ShardSpec
switches to structural (sip, cube, pe) coordinates; the flat pe_index
field/property is fully removed — silent drift between global-flat and
SIP-local interpretations was a foot-gun flagged by ADR-0024 D11.

Breaking API (explicit TypeError / AttributeError):
- DPPolicy(sip=...) / DPPolicy(num_sips=...) -> TypeError
- ShardSpec.pe_index -> AttributeError
- ShardSpec(pe_index=...) -> TypeError
- resolve_dp_policy now takes target_sip= (required), no num_sips.

Downstream migration:
- PE allocator dict keyed by (sip, cube, pe) tuples, in both
  _ensure_allocators and _free_tensor. deploy_tensor uses tuple lookup.
- _create_tensor passes target_sip=current_sip; post-hoc pe_index
  shifting removed entirely.
- launch._compute_local_shape drops the dp.sip branch.
- Internal resolvers (column_wise / row_wise / replicate / tiled_*)
  return _LocalPeShard (cube-local identifier) instead of ShardSpec —
  resolve_dp_policy lifts them to full structural coords.

Tests:
- New tests/test_adr0026_dppolicy_intra_device.py (12 tests) pins the
  contract end-to-end.
- test_sip_parallel.py rewritten: SIP composition now modeled as two
  resolve_dp_policy(target_sip=...) calls (ADR-0024 launcher style).
- Call-site migration: test_tensor, test_va_integration, test_va_offset,
  test_runtime_api_tensor, test_tl_recv_async, test_ccl_* and benches
  gemm_single_pe, gpt3_qkv, va_offset_verify, ccl_allreduce (legacy
  branch) all use intra-device DPPolicy and structural ShardSpec.

Result: 523 passed, 1 strict xfail (ring_default_ws — unchanged
ADR-0024 Phase B blocker; architectural fix deferred to ADR-0027).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-04-14 13:02:19 -07:00
parent 787409ced1
commit 357cab525b
20 changed files with 549 additions and 328 deletions
+91 -128
View File
@@ -1,157 +1,120 @@
"""Tests for SIP-level tensor parallelism.
"""Tests for SIP-level tensor parallelism — ADR-0026 structural model.
Validates:
SP1. DPPolicy accepts sip field (default "replicate", backward compat)
SP2. sip="column_wise": tensor K-axis split across SIPs, each SIP gets K//num_sips
SP3. sip="row_wise": tensor M-axis split across SIPs
SP4. 3-level resolve: sip × cube × pe produces correct flat indices and offsets
SP5. sip="replicate": all SIPs get full copy (existing behavior)
SP6. PE_CPU sets num_programs from shard count per cube
SP7. End-to-end: TP kernel with sip="column_wise" completes on multi-SIP topology
DPPolicy no longer carries a ``sip`` axis (ADR-0026 D1). SIP placement is
now expressed structurally: each call to ``resolve_dp_policy(target_sip=N)``
emits shards pinned to SIP N. Multi-SIP parallelism is composed by calling
the resolver once per SIP (typically driven by the ADR-0024 launcher, one
worker greenlet per rank, each worker using ``torch.ahbm.set_device(rank)``).
Covered here:
SP1. ``target_sip`` stamps every shard.
SP2. Two-SIP placement: union of two resolver calls covers the whole
tensor K-axis when the combined bench treats them as column-split.
SP3. Same for row-wise.
SP4. Cube + PE sharding within a SIP remains correct across SIPs.
SP5. PE_CPU num_programs contract (unchanged by ADR-0026).
"""
import pytest
from pathlib import Path
from __future__ import annotations
from kernbench.policy.placement.dp import DPPolicy, ShardSpec, resolve_dp_policy
from kernbench.policy.placement.dp import DPPolicy, resolve_dp_policy
# ── SP1. DPPolicy sip field ──────────────────────────────────────────
# ── SP1. target_sip stamps shards ────────────────────────────────────
def test_dp_policy_sip_default_replicate():
"""DPPolicy without sip= defaults to 'replicate'."""
def test_target_sip_stamps_all_shards():
dp = DPPolicy(cube="replicate", pe="column_wise")
assert dp.sip == "replicate"
def test_dp_policy_sip_column_wise():
"""DPPolicy accepts sip='column_wise'."""
dp = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise")
assert dp.sip == "column_wise"
# ── SP2. sip="column_wise" ──────────────────────────────────────────────
def test_sip_column_wise_splits_across_sips():
"""sip='column_wise' with 2 SIPs: each SIP gets K//2 columns."""
dp = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise")
shards = resolve_dp_policy(
dp, shape=(128, 256), itemsize=2,
num_pe=8, num_cubes=1, num_sips=2,
num_pe=8, num_cubes=1, target_sip=3,
)
# 2 SIPs × 1 cube × 8 PEs = 16 shards
assert len(shards) == 16
# SIP0 shards: first half of K (0 to K//2)
# SIP1 shards: second half of K (K//2 to K)
total_bytes = 128 * 256 * 2 # 64KB
sip0_shards = [s for s in shards if s.pe_index < 8]
sip1_shards = [s for s in shards if s.pe_index >= 8]
# SIP0 offsets start at 0
assert sip0_shards[0].offset_bytes == 0
# SIP1 offsets start at half
assert sip1_shards[0].offset_bytes == total_bytes // 2
# Total coverage
assert sum(s.nbytes for s in sip0_shards) == total_bytes // 2
assert sum(s.nbytes for s in sip1_shards) == total_bytes // 2
assert all(s.sip == 3 for s in shards)
assert all(0 <= s.pe < 8 for s in shards)
assert all(s.cube == 0 for s in shards)
# ── SP3. sip="row_wise" ──────────────────────────────────────────────
# ── SP2. column-wise placement composed across two SIPs ─────────────
def test_sip_row_wise_splits_across_sips():
"""sip='row_wise' with 2 SIPs: each SIP gets M//2 rows."""
dp = DPPolicy(sip="row_wise", cube="replicate", pe="column_wise")
shards = resolve_dp_policy(
def test_compose_two_sips_column_wise_covers_tensor():
"""Bench splits K-axis across 2 SIPs by calling resolve twice and
giving each SIP half of the tensor (half-shape + offset). Shards
from both SIPs together cover the whole K axis."""
full_shape = (128, 256)
itemsize = 2
# Per-SIP half-shape (K split across SIPs).
half_shape = (128, 128)
dp = DPPolicy(cube="replicate", pe="column_wise")
shards_sip0 = resolve_dp_policy(
dp, shape=half_shape, itemsize=itemsize,
num_pe=8, num_cubes=1, target_sip=0,
)
shards_sip1 = resolve_dp_policy(
dp, shape=half_shape, itemsize=itemsize,
num_pe=8, num_cubes=1, target_sip=1,
)
total_bytes = full_shape[0] * full_shape[1] * itemsize
sip0_bytes = sum(s.nbytes for s in shards_sip0)
sip1_bytes = sum(s.nbytes for s in shards_sip1)
assert sip0_bytes + sip1_bytes == total_bytes
assert all(s.sip == 0 for s in shards_sip0)
assert all(s.sip == 1 for s in shards_sip1)
# ── SP3. row-wise placement composed across two SIPs ────────────────
def test_compose_two_sips_row_wise_covers_tensor():
full_shape = (128, 256)
itemsize = 2
half_shape = (64, 256) # per-SIP half of M
dp = DPPolicy(cube="replicate", pe="column_wise")
shards_sip0 = resolve_dp_policy(
dp, shape=half_shape, itemsize=itemsize,
num_pe=8, num_cubes=1, target_sip=0,
)
shards_sip1 = resolve_dp_policy(
dp, shape=half_shape, itemsize=itemsize,
num_pe=8, num_cubes=1, target_sip=1,
)
total_bytes = full_shape[0] * full_shape[1] * itemsize
assert sum(s.nbytes for s in shards_sip0) + sum(s.nbytes for s in shards_sip1) == total_bytes
# ── SP4. cube × PE sharding is independent per SIP ──────────────────
def test_cube_pe_sharding_independent_per_sip():
"""Intra-SIP cube + PE layout matches across SIPs; only sip field differs."""
dp = DPPolicy(cube="column_wise", pe="column_wise")
s0 = resolve_dp_policy(
dp, shape=(128, 256), itemsize=2,
num_pe=8, num_cubes=1, num_sips=2,
num_pe=4, num_cubes=2, target_sip=0,
)
assert len(shards) == 16
sip0_shards = [s for s in shards if s.pe_index < 8]
sip1_shards = [s for s in shards if s.pe_index >= 8]
# SIP0: rows 0..63, SIP1: rows 64..127
total_bytes = 128 * 256 * 2
assert sip0_shards[0].offset_bytes == 0
assert sip1_shards[0].offset_bytes == total_bytes // 2
# ── SP4. 3-level resolve ─────────────────────────────────────────────
def test_3level_resolve_flat_index():
"""3-level: sip × cube × pe produces correct flat indices."""
dp = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise")
shards = resolve_dp_policy(
s1 = resolve_dp_policy(
dp, shape=(128, 256), itemsize=2,
num_pe=8, num_cubes=2, num_sips=2,
num_pe=4, num_cubes=2, target_sip=1,
)
# 2 SIPs × 2 cubes × 8 PEs = 32 shards
assert len(shards) == 32
# Flat index: sip_id * cubes_per_sip * num_pe + cube_id * num_pe + pe_id
indices = [s.pe_index for s in shards]
# SIP0: 0..15, SIP1: 16..31
assert min(indices) == 0
assert max(indices) == 31
assert len(set(indices)) == 32 # all unique
assert len(s0) == len(s1) == 2 * 4
for a, b in zip(s0, s1):
assert a.sip == 0 and b.sip == 1
assert (a.cube, a.pe, a.offset_bytes, a.nbytes) == (
b.cube, b.pe, b.offset_bytes, b.nbytes
)
def test_3level_offsets_cover_full_tensor():
"""3-level sharding covers the entire tensor with no gaps."""
dp = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise")
shards = resolve_dp_policy(
dp, shape=(128, 256), itemsize=2,
num_pe=4, num_cubes=1, num_sips=2,
)
# 2 SIPs × 1 cube × 4 PEs = 8 shards
# sip="column_wise": K=128 per SIP, pe="column_wise": 32 cols per PE
total = 128 * 256 * 2
# For non-replicate, total shard bytes == tensor bytes
# (replicate within cube means cube shards overlap, but sip shards don't)
sip0_bytes = sum(s.nbytes for s in shards if s.pe_index < 4)
sip1_bytes = sum(s.nbytes for s in shards if s.pe_index >= 4)
assert sip0_bytes + sip1_bytes == total
# ── SP5. sip="replicate" backward compat ─────────────────────────────
def test_sip_replicate_backward_compat():
"""sip='replicate' produces same result as before (2-level)."""
dp_old = DPPolicy(cube="replicate", pe="column_wise")
dp_new = DPPolicy(sip="replicate", cube="replicate", pe="column_wise")
shards_old = resolve_dp_policy(
dp_old, shape=(128, 256), itemsize=2,
num_pe=8, num_cubes=2, num_sips=2,
)
shards_new = resolve_dp_policy(
dp_new, shape=(128, 256), itemsize=2,
num_pe=8, num_cubes=2, num_sips=2,
)
assert len(shards_old) == len(shards_new)
for a, b in zip(shards_old, shards_new):
assert a.pe_index == b.pe_index
assert a.offset_bytes == b.offset_bytes
assert a.nbytes == b.nbytes
# ── SP6. PE_CPU num_programs ──────────────────────────────────────────
# ── SP5. PE_CPU num_programs (contract unchanged) ───────────────────
def test_pe_cpu_sets_num_programs():
"""PE_CPU should create TLContext with num_programs = PEs per cube."""
# This test validates the interface contract.
# After implementation, PE_CPU should derive num_programs from the
# number of PE shards in the kernel launch's target cube.
"""TLContext reports num_programs from its initializer — used by PE_CPU
when it launches a kernel on behalf of its shards."""
from kernbench.triton_emu.tl_context import TLContext
# With 8 PEs per cube, num_programs should be 8
tl = TLContext(pe_id=3, num_programs=8)
assert tl.program_id(0) == 3
assert tl.num_programs(0) == 8