63669f82cb
- DPPolicy: 3-level (sip/cube/pe), unified naming (column_wise/row_wise) - PE_CPU: auto num_programs from cube shard count - context.launch(): per-SIP KernelLaunchMsg with local va_base + auto local shape - deploy_tensor: removed mmus param, MMU mapping is context-only responsibility - ComponentRegistry: YAML-based lazy loading (components.yaml), impls→builtin rename - VA offset bench + tests: 2D/1D, standard Triton kernel pattern Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
158 lines
6.1 KiB
Python
158 lines
6.1 KiB
Python
"""Tests for SIP-level tensor parallelism.

Validates:
SP1. DPPolicy accepts a sip field (default "replicate", for backward compatibility)
SP2. sip="column_wise": the tensor's K axis is split across SIPs; each SIP gets K//num_sips
SP3. sip="row_wise": the tensor's M axis is split across SIPs
SP4. 3-level resolve: sip × cube × pe produces correct flat indices and offsets
SP5. sip="replicate": every SIP gets a full copy (existing behavior)
SP6. PE_CPU sets num_programs from the shard count per cube
SP7. End-to-end: a TP kernel with sip="column_wise" completes on a multi-SIP topology
     (TODO: no SP7 test is present in this file yet)
"""
import pytest
|
||
from pathlib import Path
|
||
|
||
from kernbench.policy.placement.dp import DPPolicy, ShardSpec, resolve_dp_policy
|
||
|
||
|
||
# ── SP1. DPPolicy sip field ──────────────────────────────────────────


def test_dp_policy_sip_default_replicate():
    """Omitting sip= must fall back to 'replicate' (backward compatible)."""
    policy = DPPolicy(pe="column_wise", cube="replicate")
    assert policy.sip == "replicate"


def test_dp_policy_sip_column_wise():
    """An explicit sip='column_wise' is kept on the policy as given."""
    assert DPPolicy(
        sip="column_wise", cube="replicate", pe="column_wise"
    ).sip == "column_wise"


# ── SP2. sip="column_wise" ──────────────────────────────────────────────


def test_sip_column_wise_splits_across_sips():
    """sip='column_wise' with 2 SIPs: each SIP gets K//2 columns.

    Shards are bucketed by flat pe_index: with 1 cube of 8 PEs per SIP,
    indices below 8 are taken as SIP0 and the rest as SIP1.
    """
    m, k, itemsize = 128, 256, 2
    num_pe, num_cubes, num_sips = 8, 1, 2

    policy = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise")
    shards = resolve_dp_policy(
        policy, shape=(m, k), itemsize=itemsize,
        num_pe=num_pe, num_cubes=num_cubes, num_sips=num_sips,
    )
    # One shard per (SIP, cube, PE): 2 × 1 × 8 = 16
    assert len(shards) == num_sips * num_cubes * num_pe

    tensor_bytes = m * k * itemsize  # 64 KiB
    half = tensor_bytes // 2
    pes_per_sip = num_cubes * num_pe
    first_sip = [s for s in shards if s.pe_index < pes_per_sip]
    second_sip = [s for s in shards if s.pe_index >= pes_per_sip]

    # SIP0 starts at the tensor base, SIP1 at the halfway point (K//2 onward).
    assert first_sip[0].offset_bytes == 0
    assert second_sip[0].offset_bytes == half

    # Each SIP's shards together account for exactly half the tensor.
    assert sum(s.nbytes for s in first_sip) == half
    assert sum(s.nbytes for s in second_sip) == half


# ── SP3. sip="row_wise" ──────────────────────────────────────────────


def test_sip_row_wise_splits_across_sips():
    """sip='row_wise' with 2 SIPs: each SIP gets M//2 rows.

    Checks both where each SIP's data starts and how many bytes it holds.
    Checking only start offsets would let a policy that placed SIP1 correctly
    but sized the SIPs wrongly pass. The byte checks match the ones used for
    the column_wise case.
    """
    dp = DPPolicy(sip="row_wise", cube="replicate", pe="column_wise")
    shards = resolve_dp_policy(
        dp, shape=(128, 256), itemsize=2,
        num_pe=8, num_cubes=1, num_sips=2,
    )
    assert len(shards) == 16

    # Flat pe_index 0..7 -> SIP0, 8..15 -> SIP1 (1 cube × 8 PEs per SIP).
    sip0_shards = [s for s in shards if s.pe_index < 8]
    sip1_shards = [s for s in shards if s.pe_index >= 8]
    # Each SIP must own 8 shards. This also guards the [0] lookups below, so
    # they fail as assertions rather than as IndexError.
    assert len(sip0_shards) == 8
    assert len(sip1_shards) == 8

    # SIP0: rows 0..63, SIP1: rows 64..127
    total_bytes = 128 * 256 * 2
    assert sip0_shards[0].offset_bytes == 0
    assert sip1_shards[0].offset_bytes == total_bytes // 2

    # M//2 rows per SIP = 64 × 256 × 2 B = half the tensor's bytes.
    assert sum(s.nbytes for s in sip0_shards) == total_bytes // 2
    assert sum(s.nbytes for s in sip1_shards) == total_bytes // 2


# ── SP4. 3-level resolve ─────────────────────────────────────────────


def test_3level_resolve_flat_index():
    """3-level: sip × cube × pe produces correct flat indices.

    Expected layout:
        flat = sip_id * cubes_per_sip * num_pe + cube_id * num_pe + pe_id
    so with 2 SIPs × 2 cubes × 8 PEs, SIP0 occupies 0..15 and SIP1 16..31.
    """
    num_sips, num_cubes, num_pe = 2, 2, 8
    expected = num_sips * num_cubes * num_pe  # 32

    shards = resolve_dp_policy(
        DPPolicy(sip="column_wise", cube="replicate", pe="column_wise"),
        shape=(128, 256), itemsize=2,
        num_pe=num_pe, num_cubes=num_cubes, num_sips=num_sips,
    )
    assert len(shards) == expected

    # A dense, duplicate-free 0..expected-1 range of flat indices.
    flat_ids = [shard.pe_index for shard in shards]
    assert min(flat_ids) == 0
    assert max(flat_ids) == expected - 1
    assert len(set(flat_ids)) == expected


def test_3level_offsets_cover_full_tensor():
    """3-level sharding accounts for every tensor byte, split evenly by SIP.

    With sip="column_wise" over 2 SIPs, each SIP must own exactly half of the
    tensor's bytes. Asserting only the grand total (sip0 + sip1 == total)
    would also pass for a policy that handed every byte to a single SIP, so
    each half is checked separately.

    NOTE(review): this checks byte totals only; it does not prove that shard
    ranges are gap-free or non-overlapping within a SIP.
    """
    dp = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise")
    shards = resolve_dp_policy(
        dp, shape=(128, 256), itemsize=2,
        num_pe=4, num_cubes=1, num_sips=2,
    )
    # 2 SIPs × 1 cube × 4 PEs = 8 shards
    assert len(shards) == 8

    # sip="column_wise": K=128 per SIP, pe="column_wise": 32 cols per PE
    total = 128 * 256 * 2
    # There is only one cube per SIP here, so cube="replicate" should add no
    # duplicate copies; shard bytes should sum to exactly the tensor size.
    sip0_bytes = sum(s.nbytes for s in shards if s.pe_index < 4)
    sip1_bytes = sum(s.nbytes for s in shards if s.pe_index >= 4)

    # Each SIP holds 128 rows × 128 cols × 2 B = total // 2.
    assert sip0_bytes == total // 2
    assert sip1_bytes == total // 2
    assert sip0_bytes + sip1_bytes == total


# ── SP5. sip="replicate" backward compat ─────────────────────────────


def test_sip_replicate_backward_compat():
    """sip='replicate' is backward compatible with policies that omit sip.

    A DPPolicy written without sip= must resolve to exactly the same shards
    as one that passes sip='replicate' explicitly. Shards are compared field
    by field (pe_index, offset_bytes, nbytes) under identical geometry.
    """
    legacy_policy = DPPolicy(cube="replicate", pe="column_wise")
    explicit_policy = DPPolicy(sip="replicate", cube="replicate", pe="column_wise")

    def resolve(policy):
        # Same geometry for both policies: 2 SIPs × 2 cubes × 8 PEs.
        return resolve_dp_policy(
            policy, shape=(128, 256), itemsize=2,
            num_pe=8, num_cubes=2, num_sips=2,
        )

    legacy_shards = resolve(legacy_policy)
    explicit_shards = resolve(explicit_policy)

    assert len(legacy_shards) == len(explicit_shards)
    for legacy, explicit in zip(legacy_shards, explicit_shards):
        for field in ("pe_index", "offset_bytes", "nbytes"):
            assert getattr(legacy, field) == getattr(explicit, field)


# ── SP6. PE_CPU num_programs ──────────────────────────────────────────


def test_pe_cpu_sets_num_programs():
    """PE_CPU should create TLContext with num_programs = PEs per cube.

    Interface-contract check: after implementation, PE_CPU is expected to
    derive num_programs from the number of PE shards in the kernel launch's
    target cube.

    NOTE(review): PE_CPU itself is not exercised here. This only pins the
    TLContext behaviour that the contract relies on.
    """
    from kernbench.triton_emu.tl_context import TLContext

    pe_id, pes_per_cube = 3, 8  # 8 PEs per cube -> num_programs == 8
    ctx = TLContext(pe_id=pe_id, num_programs=pes_per_cube)

    assert ctx.program_id(0) == pe_id
    assert ctx.num_programs(0) == pes_per_cube