ADR-0026: DPPolicy intra-device only + ShardSpec structural coords
DPPolicy no longer carries a cross-SIP axis. SIP-level placement is solely controlled by torch.ahbm.set_device(rank) (ADR-0024); DPPolicy itself describes only the cube × PE layout within one SIP. ShardSpec switches to structural (sip, cube, pe) coordinates; the flat pe_index field/property is fully removed — silent drift between global-flat and SIP-local interpretations was a foot-gun flagged by ADR-0024 D11. Breaking API (explicit TypeError / AttributeError): - DPPolicy(sip=...) / DPPolicy(num_sips=...) -> TypeError - ShardSpec.pe_index -> AttributeError - ShardSpec(pe_index=...) -> TypeError - resolve_dp_policy now takes target_sip= (required), no num_sips. Downstream migration: - PE allocator dict keyed by (sip, cube, pe) tuples, in both _ensure_allocators and _free_tensor. deploy_tensor uses tuple lookup. - _create_tensor passes target_sip=current_sip; post-hoc pe_index shifting removed entirely. - launch._compute_local_shape drops the dp.sip branch. - Internal resolvers (column_wise / row_wise / replicate / tiled_*) return _LocalPeShard (cube-local identifier) instead of ShardSpec — resolve_dp_policy lifts them to full structural coords. Tests: - New tests/test_adr0026_dppolicy_intra_device.py (12 tests) pins the contract end-to-end. - test_sip_parallel.py rewritten: SIP composition now modeled as two resolve_dp_policy(target_sip=...) calls (ADR-0024 launcher style). - Call-site migration: test_tensor, test_va_integration, test_va_offset, test_runtime_api_tensor, test_tl_recv_async, test_ccl_* and benches gemm_single_pe, gpt3_qkv, va_offset_verify, ccl_allreduce (legacy branch) all use intra-device DPPolicy and structural ShardSpec. Result: 523 passed, 1 strict xfail (ring_default_ws — unchanged ADR-0024 Phase B blocker; architectural fix deferred to ADR-0027). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
+91
-128
@@ -1,157 +1,120 @@
|
||||
"""Tests for SIP-level tensor parallelism.
|
||||
"""Tests for SIP-level tensor parallelism — ADR-0026 structural model.
|
||||
|
||||
Validates:
|
||||
SP1. DPPolicy accepts sip field (default "replicate", backward compat)
|
||||
SP2. sip="column_wise": tensor K-axis split across SIPs, each SIP gets K//num_sips
|
||||
SP3. sip="row_wise": tensor M-axis split across SIPs
|
||||
SP4. 3-level resolve: sip × cube × pe produces correct flat indices and offsets
|
||||
SP5. sip="replicate": all SIPs get full copy (existing behavior)
|
||||
SP6. PE_CPU sets num_programs from shard count per cube
|
||||
SP7. End-to-end: TP kernel with sip="column_wise" completes on multi-SIP topology
|
||||
DPPolicy no longer carries a ``sip`` axis (ADR-0026 D1). SIP placement is
|
||||
now expressed structurally: each call to ``resolve_dp_policy(target_sip=N)``
|
||||
emits shards pinned to SIP N. Multi-SIP parallelism is composed by calling
|
||||
the resolver once per SIP (typically driven by the ADR-0024 launcher, one
|
||||
worker greenlet per rank, each worker using ``torch.ahbm.set_device(rank)``).
|
||||
|
||||
Covered here:
|
||||
SP1. ``target_sip`` stamps every shard.
|
||||
SP2. Two-SIP placement: union of two resolver calls covers the whole
|
||||
tensor K-axis when the combined bench treats them as column-split.
|
||||
SP3. Same for row-wise.
|
||||
SP4. Cube + PE sharding within a SIP remains correct across SIPs.
|
||||
SP5. PE_CPU num_programs contract (unchanged by ADR-0026).
|
||||
"""
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from __future__ import annotations
|
||||
|
||||
from kernbench.policy.placement.dp import DPPolicy, ShardSpec, resolve_dp_policy
|
||||
from kernbench.policy.placement.dp import DPPolicy, resolve_dp_policy
|
||||
|
||||
|
||||
# ── SP1. DPPolicy sip field ──────────────────────────────────────────
|
||||
# ── SP1. target_sip stamps shards ────────────────────────────────────
|
||||
|
||||
|
||||
def test_dp_policy_sip_default_replicate():
|
||||
"""DPPolicy without sip= defaults to 'replicate'."""
|
||||
def test_target_sip_stamps_all_shards():
|
||||
dp = DPPolicy(cube="replicate", pe="column_wise")
|
||||
assert dp.sip == "replicate"
|
||||
|
||||
|
||||
def test_dp_policy_sip_column_wise():
|
||||
"""DPPolicy accepts sip='column_wise'."""
|
||||
dp = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise")
|
||||
assert dp.sip == "column_wise"
|
||||
|
||||
|
||||
# ── SP2. sip="column_wise" ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_sip_column_wise_splits_across_sips():
|
||||
"""sip='column_wise' with 2 SIPs: each SIP gets K//2 columns."""
|
||||
dp = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise")
|
||||
shards = resolve_dp_policy(
|
||||
dp, shape=(128, 256), itemsize=2,
|
||||
num_pe=8, num_cubes=1, num_sips=2,
|
||||
num_pe=8, num_cubes=1, target_sip=3,
|
||||
)
|
||||
# 2 SIPs × 1 cube × 8 PEs = 16 shards
|
||||
assert len(shards) == 16
|
||||
|
||||
# SIP0 shards: first half of K (0 to K//2)
|
||||
# SIP1 shards: second half of K (K//2 to K)
|
||||
total_bytes = 128 * 256 * 2 # 64KB
|
||||
sip0_shards = [s for s in shards if s.pe_index < 8]
|
||||
sip1_shards = [s for s in shards if s.pe_index >= 8]
|
||||
|
||||
# SIP0 offsets start at 0
|
||||
assert sip0_shards[0].offset_bytes == 0
|
||||
# SIP1 offsets start at half
|
||||
assert sip1_shards[0].offset_bytes == total_bytes // 2
|
||||
|
||||
# Total coverage
|
||||
assert sum(s.nbytes for s in sip0_shards) == total_bytes // 2
|
||||
assert sum(s.nbytes for s in sip1_shards) == total_bytes // 2
|
||||
assert all(s.sip == 3 for s in shards)
|
||||
assert all(0 <= s.pe < 8 for s in shards)
|
||||
assert all(s.cube == 0 for s in shards)
|
||||
|
||||
|
||||
# ── SP3. sip="row_wise" ──────────────────────────────────────────────
|
||||
# ── SP2. column-wise placement composed across two SIPs ─────────────
|
||||
|
||||
|
||||
def test_sip_row_wise_splits_across_sips():
|
||||
"""sip='row_wise' with 2 SIPs: each SIP gets M//2 rows."""
|
||||
dp = DPPolicy(sip="row_wise", cube="replicate", pe="column_wise")
|
||||
shards = resolve_dp_policy(
|
||||
def test_compose_two_sips_column_wise_covers_tensor():
|
||||
"""Bench splits K-axis across 2 SIPs by calling resolve twice and
|
||||
giving each SIP half of the tensor (half-shape + offset). Shards
|
||||
from both SIPs together cover the whole K axis."""
|
||||
full_shape = (128, 256)
|
||||
itemsize = 2
|
||||
# Per-SIP half-shape (K split across SIPs).
|
||||
half_shape = (128, 128)
|
||||
dp = DPPolicy(cube="replicate", pe="column_wise")
|
||||
|
||||
shards_sip0 = resolve_dp_policy(
|
||||
dp, shape=half_shape, itemsize=itemsize,
|
||||
num_pe=8, num_cubes=1, target_sip=0,
|
||||
)
|
||||
shards_sip1 = resolve_dp_policy(
|
||||
dp, shape=half_shape, itemsize=itemsize,
|
||||
num_pe=8, num_cubes=1, target_sip=1,
|
||||
)
|
||||
|
||||
total_bytes = full_shape[0] * full_shape[1] * itemsize
|
||||
sip0_bytes = sum(s.nbytes for s in shards_sip0)
|
||||
sip1_bytes = sum(s.nbytes for s in shards_sip1)
|
||||
assert sip0_bytes + sip1_bytes == total_bytes
|
||||
assert all(s.sip == 0 for s in shards_sip0)
|
||||
assert all(s.sip == 1 for s in shards_sip1)
|
||||
|
||||
|
||||
# ── SP3. row-wise placement composed across two SIPs ────────────────
|
||||
|
||||
|
||||
def test_compose_two_sips_row_wise_covers_tensor():
|
||||
full_shape = (128, 256)
|
||||
itemsize = 2
|
||||
half_shape = (64, 256) # per-SIP half of M
|
||||
dp = DPPolicy(cube="replicate", pe="column_wise")
|
||||
|
||||
shards_sip0 = resolve_dp_policy(
|
||||
dp, shape=half_shape, itemsize=itemsize,
|
||||
num_pe=8, num_cubes=1, target_sip=0,
|
||||
)
|
||||
shards_sip1 = resolve_dp_policy(
|
||||
dp, shape=half_shape, itemsize=itemsize,
|
||||
num_pe=8, num_cubes=1, target_sip=1,
|
||||
)
|
||||
|
||||
total_bytes = full_shape[0] * full_shape[1] * itemsize
|
||||
assert sum(s.nbytes for s in shards_sip0) + sum(s.nbytes for s in shards_sip1) == total_bytes
|
||||
|
||||
|
||||
# ── SP4. cube × PE sharding is independent per SIP ──────────────────
|
||||
|
||||
|
||||
def test_cube_pe_sharding_independent_per_sip():
|
||||
"""Intra-SIP cube + PE layout matches across SIPs; only sip field differs."""
|
||||
dp = DPPolicy(cube="column_wise", pe="column_wise")
|
||||
s0 = resolve_dp_policy(
|
||||
dp, shape=(128, 256), itemsize=2,
|
||||
num_pe=8, num_cubes=1, num_sips=2,
|
||||
num_pe=4, num_cubes=2, target_sip=0,
|
||||
)
|
||||
assert len(shards) == 16
|
||||
|
||||
sip0_shards = [s for s in shards if s.pe_index < 8]
|
||||
sip1_shards = [s for s in shards if s.pe_index >= 8]
|
||||
|
||||
# SIP0: rows 0..63, SIP1: rows 64..127
|
||||
total_bytes = 128 * 256 * 2
|
||||
assert sip0_shards[0].offset_bytes == 0
|
||||
assert sip1_shards[0].offset_bytes == total_bytes // 2
|
||||
|
||||
|
||||
# ── SP4. 3-level resolve ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_3level_resolve_flat_index():
|
||||
"""3-level: sip × cube × pe produces correct flat indices."""
|
||||
dp = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise")
|
||||
shards = resolve_dp_policy(
|
||||
s1 = resolve_dp_policy(
|
||||
dp, shape=(128, 256), itemsize=2,
|
||||
num_pe=8, num_cubes=2, num_sips=2,
|
||||
num_pe=4, num_cubes=2, target_sip=1,
|
||||
)
|
||||
# 2 SIPs × 2 cubes × 8 PEs = 32 shards
|
||||
assert len(shards) == 32
|
||||
|
||||
# Flat index: sip_id * cubes_per_sip * num_pe + cube_id * num_pe + pe_id
|
||||
indices = [s.pe_index for s in shards]
|
||||
# SIP0: 0..15, SIP1: 16..31
|
||||
assert min(indices) == 0
|
||||
assert max(indices) == 31
|
||||
assert len(set(indices)) == 32 # all unique
|
||||
assert len(s0) == len(s1) == 2 * 4
|
||||
for a, b in zip(s0, s1):
|
||||
assert a.sip == 0 and b.sip == 1
|
||||
assert (a.cube, a.pe, a.offset_bytes, a.nbytes) == (
|
||||
b.cube, b.pe, b.offset_bytes, b.nbytes
|
||||
)
|
||||
|
||||
|
||||
def test_3level_offsets_cover_full_tensor():
|
||||
"""3-level sharding covers the entire tensor with no gaps."""
|
||||
dp = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise")
|
||||
shards = resolve_dp_policy(
|
||||
dp, shape=(128, 256), itemsize=2,
|
||||
num_pe=4, num_cubes=1, num_sips=2,
|
||||
)
|
||||
# 2 SIPs × 1 cube × 4 PEs = 8 shards
|
||||
# sip="column_wise": K=128 per SIP, pe="column_wise": 32 cols per PE
|
||||
total = 128 * 256 * 2
|
||||
# For non-replicate, total shard bytes == tensor bytes
|
||||
# (replicate within cube means cube shards overlap, but sip shards don't)
|
||||
sip0_bytes = sum(s.nbytes for s in shards if s.pe_index < 4)
|
||||
sip1_bytes = sum(s.nbytes for s in shards if s.pe_index >= 4)
|
||||
assert sip0_bytes + sip1_bytes == total
|
||||
|
||||
|
||||
# ── SP5. sip="replicate" backward compat ─────────────────────────────
|
||||
|
||||
|
||||
def test_sip_replicate_backward_compat():
|
||||
"""sip='replicate' produces same result as before (2-level)."""
|
||||
dp_old = DPPolicy(cube="replicate", pe="column_wise")
|
||||
dp_new = DPPolicy(sip="replicate", cube="replicate", pe="column_wise")
|
||||
|
||||
shards_old = resolve_dp_policy(
|
||||
dp_old, shape=(128, 256), itemsize=2,
|
||||
num_pe=8, num_cubes=2, num_sips=2,
|
||||
)
|
||||
shards_new = resolve_dp_policy(
|
||||
dp_new, shape=(128, 256), itemsize=2,
|
||||
num_pe=8, num_cubes=2, num_sips=2,
|
||||
)
|
||||
assert len(shards_old) == len(shards_new)
|
||||
for a, b in zip(shards_old, shards_new):
|
||||
assert a.pe_index == b.pe_index
|
||||
assert a.offset_bytes == b.offset_bytes
|
||||
assert a.nbytes == b.nbytes
|
||||
|
||||
|
||||
# ── SP6. PE_CPU num_programs ──────────────────────────────────────────
|
||||
# ── SP5. PE_CPU num_programs (contract unchanged) ───────────────────
|
||||
|
||||
|
||||
def test_pe_cpu_sets_num_programs():
|
||||
"""PE_CPU should create TLContext with num_programs = PEs per cube."""
|
||||
# This test validates the interface contract.
|
||||
# After implementation, PE_CPU should derive num_programs from the
|
||||
# number of PE shards in the kernel launch's target cube.
|
||||
"""TLContext reports num_programs from its initializer — used by PE_CPU
|
||||
when it launches a kernel on behalf of its shards."""
|
||||
from kernbench.triton_emu.tl_context import TLContext
|
||||
|
||||
# With 8 PEs per cube, num_programs should be 8
|
||||
tl = TLContext(pe_id=3, num_programs=8)
|
||||
assert tl.program_id(0) == 3
|
||||
assert tl.num_programs(0) == 8
|
||||
|
||||
Reference in New Issue
Block a user