kernbench2/tests/test_sip_parallel.py
ywkang 357cab525b ADR-0026: DPPolicy intra-device only + ShardSpec structural coords
DPPolicy no longer carries a cross-SIP axis. SIP-level placement is
solely controlled by torch.ahbm.set_device(rank) (ADR-0024); DPPolicy
itself describes only the cube × PE layout within one SIP. ShardSpec
switches to structural (sip, cube, pe) coordinates; the flat pe_index
field/property is fully removed — silent drift between global-flat and
SIP-local interpretations was a foot-gun flagged by ADR-0024 D11.

Breaking API (explicit TypeError / AttributeError):
- DPPolicy(sip=...) / DPPolicy(num_sips=...) -> TypeError
- ShardSpec.pe_index -> AttributeError
- ShardSpec(pe_index=...) -> TypeError
- resolve_dp_policy now takes target_sip= (required), no num_sips.
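
A minimal sketch of how the breaking changes above surface to callers. The two dataclasses are hypothetical stand-ins, not the kernbench classes: the field names are taken from the tests in this commit, while the "replicate" defaults and the `raises` helper are assumptions made for illustration.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DPPolicy:
    """Hypothetical stand-in: intra-device layout only (cube x PE)."""
    cube: str = "replicate"  # default value is an assumption
    pe: str = "replicate"    # default value is an assumption


@dataclass(frozen=True)
class ShardSpec:
    """Hypothetical stand-in: structural coordinates, no flat pe_index."""
    sip: int
    cube: int
    pe: int
    offset_bytes: int
    nbytes: int


def raises(exc_type, fn):
    """Return True if calling fn() raises exc_type (illustrative helper)."""
    try:
        fn()
    except exc_type:
        return True
    return False


spec = ShardSpec(sip=1, cube=0, pe=3, offset_bytes=0, nbytes=4096)

# Removed keyword arguments are rejected by the generated __init__.
sip_kwarg_rejected = raises(TypeError, lambda: DPPolicy(sip="column_wise"))
pe_index_kwarg_rejected = raises(
    TypeError,
    lambda: ShardSpec(pe_index=3, offset_bytes=0, nbytes=4096),
)

# The flat index no longer exists as an attribute.
pe_index_attr_missing = raises(AttributeError, lambda: spec.pe_index)
```

In this sketch the errors fall out of the dataclass definition; whether the real classes raise them the same way is not stated in this commit message.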

Downstream migration:
- PE allocator dict keyed by (sip, cube, pe) tuples, in both
  _ensure_allocators and _free_tensor. deploy_tensor uses tuple lookup.
- _create_tensor passes target_sip=current_sip; post-hoc pe_index
  shifting removed entirely.
- launch._compute_local_shape drops the dp.sip branch.
- Internal resolvers (column_wise / row_wise / replicate / tiled_*)
  return _LocalPeShard (cube-local identifier) instead of ShardSpec —
  resolve_dp_policy lifts them to full structural coords.
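
The migration items above can be sketched as follows, under stated assumptions: `LocalPeShard`, `lift`, `BumpAllocator`, and `allocate` are hypothetical names standing in for `_LocalPeShard`, the lifting step in `resolve_dp_policy`, and the PE allocator. The sketch only shows that keying by (sip, cube, pe) keeps the same cube-local PE on different SIPs separate.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LocalPeShard:
    """Hypothetical cube-local shard: knows its PE, not its cube or SIP."""
    pe: int
    offset_bytes: int
    nbytes: int


@dataclass(frozen=True)
class ShardSpec:
    """Hypothetical structural shard with full (sip, cube, pe) coordinates."""
    sip: int
    cube: int
    pe: int
    offset_bytes: int
    nbytes: int


def lift(local_shards, *, cube, target_sip):
    """Attach cube and SIP coordinates to cube-local shards."""
    return [
        ShardSpec(target_sip, cube, s.pe, s.offset_bytes, s.nbytes)
        for s in local_shards
    ]


class BumpAllocator:
    """Toy per-PE allocator: hands out increasing addresses."""

    def __init__(self) -> None:
        self.used = 0

    def alloc(self, nbytes: int) -> int:
        addr = self.used
        self.used += nbytes
        return addr


# Allocators keyed by (sip, cube, pe) tuples rather than a flat PE index.
allocators = {}


def allocate(spec: ShardSpec) -> int:
    key = (spec.sip, spec.cube, spec.pe)
    return allocators.setdefault(key, BumpAllocator()).alloc(spec.nbytes)


local = [LocalPeShard(pe=p, offset_bytes=p * 2048, nbytes=2048) for p in range(2)]
shards = lift(local, cube=0, target_sip=0) + lift(local, cube=0, target_sip=1)
addresses = [allocate(s) for s in shards]
```

Each of the four shards lands in its own allocator, so PE 0 on SIP 0 and PE 0 on SIP 1 never share state. That is the drift between global-flat and SIP-local interpretations that a flat index made possible.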

Tests:
- New tests/test_adr0026_dppolicy_intra_device.py (12 tests) pins the
  contract end-to-end.
- test_sip_parallel.py rewritten: SIP composition now modeled as two
  resolve_dp_policy(target_sip=...) calls (ADR-0024 launcher style).
- Call-site migration: test_tensor, test_va_integration, test_va_offset,
  test_runtime_api_tensor, test_tl_recv_async, test_ccl_* and benches
  gemm_single_pe, gpt3_qkv, va_offset_verify, ccl_allreduce (legacy
  branch) all use intra-device DPPolicy and structural ShardSpec.

Result: 523 passed, 1 strict xfail (ring_default_ws — unchanged
ADR-0024 Phase B blocker; architectural fix deferred to ADR-0027).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-14 13:02:19 -07:00

121 lines
4.4 KiB
Python
"""Tests for SIP-level tensor parallelism — ADR-0026 structural model.

DPPolicy no longer carries a ``sip`` axis (ADR-0026 D1). SIP placement is
now expressed structurally: each call to ``resolve_dp_policy(target_sip=N)``
emits shards pinned to SIP N. Multi-SIP parallelism is composed by calling
the resolver once per SIP (typically driven by the ADR-0024 launcher, one
worker greenlet per rank, each worker using ``torch.ahbm.set_device(rank)``).

Covered here:
  SP1. ``target_sip`` stamps every shard.
  SP2. Two-SIP placement: union of two resolver calls covers the whole
       tensor K-axis when the combined bench treats them as column-split.
  SP3. Same for row-wise.
  SP4. Cube + PE sharding within a SIP remains correct across SIPs.
  SP5. PE_CPU num_programs contract (unchanged by ADR-0026).
"""
from __future__ import annotations

from kernbench.policy.placement.dp import DPPolicy, resolve_dp_policy


# ── SP1. target_sip stamps shards ────────────────────────────────────

def test_target_sip_stamps_all_shards():
    dp = DPPolicy(cube="replicate", pe="column_wise")
    shards = resolve_dp_policy(
        dp, shape=(128, 256), itemsize=2,
        num_pe=8, num_cubes=1, target_sip=3,
    )
    assert all(s.sip == 3 for s in shards)
    assert all(0 <= s.pe < 8 for s in shards)
    assert all(s.cube == 0 for s in shards)


# ── SP2. column-wise placement composed across two SIPs ─────────────

def test_compose_two_sips_column_wise_covers_tensor():
    """Bench splits K-axis across 2 SIPs by calling resolve twice and
    giving each SIP half of the tensor (half-shape + offset). Shards
    from both SIPs together cover the whole K axis."""
    full_shape = (128, 256)
    itemsize = 2
    # Per-SIP half-shape (K split across SIPs).
    half_shape = (128, 128)

    dp = DPPolicy(cube="replicate", pe="column_wise")
    shards_sip0 = resolve_dp_policy(
        dp, shape=half_shape, itemsize=itemsize,
        num_pe=8, num_cubes=1, target_sip=0,
    )
    shards_sip1 = resolve_dp_policy(
        dp, shape=half_shape, itemsize=itemsize,
        num_pe=8, num_cubes=1, target_sip=1,
    )

    total_bytes = full_shape[0] * full_shape[1] * itemsize
    sip0_bytes = sum(s.nbytes for s in shards_sip0)
    sip1_bytes = sum(s.nbytes for s in shards_sip1)
    assert sip0_bytes + sip1_bytes == total_bytes
    assert all(s.sip == 0 for s in shards_sip0)
    assert all(s.sip == 1 for s in shards_sip1)


# ── SP3. row-wise placement composed across two SIPs ────────────────

def test_compose_two_sips_row_wise_covers_tensor():
    full_shape = (128, 256)
    itemsize = 2
    half_shape = (64, 256)  # per-SIP half of M

    # The row split happens at the SIP level (half of M per SIP);
    # the intra-SIP PE layout stays column_wise.
    dp = DPPolicy(cube="replicate", pe="column_wise")
    shards_sip0 = resolve_dp_policy(
        dp, shape=half_shape, itemsize=itemsize,
        num_pe=8, num_cubes=1, target_sip=0,
    )
    shards_sip1 = resolve_dp_policy(
        dp, shape=half_shape, itemsize=itemsize,
        num_pe=8, num_cubes=1, target_sip=1,
    )

    total_bytes = full_shape[0] * full_shape[1] * itemsize
    sip0_bytes = sum(s.nbytes for s in shards_sip0)
    sip1_bytes = sum(s.nbytes for s in shards_sip1)
    assert sip0_bytes + sip1_bytes == total_bytes


# ── SP4. cube × PE sharding is independent per SIP ──────────────────

def test_cube_pe_sharding_independent_per_sip():
    """Intra-SIP cube + PE layout matches across SIPs; only sip field differs."""
    dp = DPPolicy(cube="column_wise", pe="column_wise")
    s0 = resolve_dp_policy(
        dp, shape=(128, 256), itemsize=2,
        num_pe=4, num_cubes=2, target_sip=0,
    )
    s1 = resolve_dp_policy(
        dp, shape=(128, 256), itemsize=2,
        num_pe=4, num_cubes=2, target_sip=1,
    )
    assert len(s0) == len(s1) == 2 * 4
    for a, b in zip(s0, s1):
        assert a.sip == 0 and b.sip == 1
        assert (a.cube, a.pe, a.offset_bytes, a.nbytes) == (
            b.cube, b.pe, b.offset_bytes, b.nbytes
        )


# ── SP5. PE_CPU num_programs (contract unchanged) ───────────────────

def test_pe_cpu_sets_num_programs():
    """TLContext reports num_programs from its initializer; PE_CPU uses it
    when it launches a kernel on behalf of its shards."""
    from kernbench.triton_emu.tl_context import TLContext

    tl = TLContext(pe_id=3, num_programs=8)
    assert tl.program_id(0) == 3
    assert tl.num_programs(0) == 8