"""Tests for SIP-level tensor parallelism. Validates: SP1. DPPolicy accepts sip field (default "replicate", backward compat) SP2. sip="column_wise": tensor K-axis split across SIPs, each SIP gets K//num_sips SP3. sip="row_wise": tensor M-axis split across SIPs SP4. 3-level resolve: sip × cube × pe produces correct flat indices and offsets SP5. sip="replicate": all SIPs get full copy (existing behavior) SP6. PE_CPU sets num_programs from shard count per cube SP7. End-to-end: TP kernel with sip="column_wise" completes on multi-SIP topology """ import pytest from pathlib import Path from kernbench.policy.placement.dp import DPPolicy, ShardSpec, resolve_dp_policy # ── SP1. DPPolicy sip field ────────────────────────────────────────── def test_dp_policy_sip_default_replicate(): """DPPolicy without sip= defaults to 'replicate'.""" dp = DPPolicy(cube="replicate", pe="column_wise") assert dp.sip == "replicate" def test_dp_policy_sip_column_wise(): """DPPolicy accepts sip='column_wise'.""" dp = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise") assert dp.sip == "column_wise" # ── SP2. 
# sip="column_wise" ──────────────────────────────────────────────


def test_sip_column_wise_splits_across_sips():
    """sip='column_wise' on 2 SIPs: each SIP is handed half of the tensor.

    Checks the shard count, the starting byte offset of each SIP's first
    shard, and that each SIP's shards add up to half of the tensor bytes.
    """
    policy = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise")
    placed = resolve_dp_policy(
        policy,
        shape=(128, 256),
        itemsize=2,
        num_pe=8,
        num_cubes=1,
        num_sips=2,
    )

    # 2 SIPs × 1 cube × 8 PEs = 16 shards
    assert len(placed) == 16

    # Flat PE indices 0..7 are treated as SIP0 (first half of K),
    # indices 8..15 as SIP1 (second half of K).
    tensor_bytes = 128 * 256 * 2  # 64KB
    half_bytes = tensor_bytes // 2
    first_sip = []
    second_sip = []
    for shard in placed:
        if shard.pe_index < 8:
            first_sip.append(shard)
        else:
            second_sip.append(shard)

    # SIP0 begins at the start of the tensor, SIP1 at the halfway point.
    assert first_sip[0].offset_bytes == 0
    assert second_sip[0].offset_bytes == half_bytes

    # Each SIP's shards together account for exactly half of the bytes.
    assert sum(shard.nbytes for shard in first_sip) == half_bytes
    assert sum(shard.nbytes for shard in second_sip) == half_bytes


# ── SP3. sip="row_wise" ──────────────────────────────────────────────


def test_sip_row_wise_splits_across_sips():
    """sip='row_wise' on 2 SIPs: each SIP is handed M//2 rows."""
    policy = DPPolicy(sip="row_wise", cube="replicate", pe="column_wise")
    placed = resolve_dp_policy(
        policy,
        shape=(128, 256),
        itemsize=2,
        num_pe=8,
        num_cubes=1,
        num_sips=2,
    )
    assert len(placed) == 16

    first_sip = []
    second_sip = []
    for shard in placed:
        if shard.pe_index < 8:
            first_sip.append(shard)
        else:
            second_sip.append(shard)

    # SIP0 holds rows 0..63 and SIP1 rows 64..127, so SIP1 is expected to
    # start at the byte midpoint of the tensor.
    tensor_bytes = 128 * 256 * 2
    assert first_sip[0].offset_bytes == 0
    assert second_sip[0].offset_bytes == tensor_bytes // 2


# ── SP4.
# 3-level resolve ─────────────────────────────────────────────


def test_3level_resolve_flat_index():
    """3-level resolve (sip × cube × pe) yields 32 distinct flat indices."""
    policy = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise")
    placed = resolve_dp_policy(
        policy,
        shape=(128, 256),
        itemsize=2,
        num_pe=8,
        num_cubes=2,
        num_sips=2,
    )

    # 2 SIPs × 2 cubes × 8 PEs = 32 shards
    assert len(placed) == 32

    # Flat index: sip_id * cubes_per_sip * num_pe + cube_id * num_pe + pe_id
    # → SIP0 occupies 0..15, SIP1 occupies 16..31.
    flat_ids = [shard.pe_index for shard in placed]
    assert min(flat_ids) == 0
    assert max(flat_ids) == 31
    assert len(set(flat_ids)) == 32  # all unique


def test_3level_offsets_cover_full_tensor():
    """The per-SIP shard byte counts add up to the full tensor size."""
    policy = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise")
    placed = resolve_dp_policy(
        policy,
        shape=(128, 256),
        itemsize=2,
        num_pe=4,
        num_cubes=1,
        num_sips=2,
    )

    # 2 SIPs × 1 cube × 4 PEs = 8 shards
    # sip="column_wise": K=128 per SIP, pe="column_wise": 32 cols per PE
    tensor_bytes = 128 * 256 * 2

    # For non-replicate, total shard bytes == tensor bytes
    # (replicate within cube means cube shards overlap, but sip shards don't)
    first_sip_bytes = 0
    second_sip_bytes = 0
    for shard in placed:
        if shard.pe_index < 4:
            first_sip_bytes += shard.nbytes
        else:
            second_sip_bytes += shard.nbytes
    assert first_sip_bytes + second_sip_bytes == tensor_bytes


# ── SP5.
# sip="replicate" backward compat ─────────────────────────────


def test_sip_replicate_backward_compat():
    """Explicit sip='replicate' resolves to the same shards as omitting sip."""
    implicit_policy = DPPolicy(cube="replicate", pe="column_wise")
    explicit_policy = DPPolicy(sip="replicate", cube="replicate", pe="column_wise")

    # Identical tensor/topology arguments for both resolutions.
    resolve_kwargs = {
        "shape": (128, 256),
        "itemsize": 2,
        "num_pe": 8,
        "num_cubes": 2,
        "num_sips": 2,
    }
    implicit_shards = resolve_dp_policy(implicit_policy, **resolve_kwargs)
    explicit_shards = resolve_dp_policy(explicit_policy, **resolve_kwargs)

    assert len(implicit_shards) == len(explicit_shards)
    # Compare shard-by-shard, in the order the resolver returned them.
    for before, after in zip(implicit_shards, explicit_shards):
        assert before.pe_index == after.pe_index
        assert before.offset_bytes == after.offset_bytes
        assert before.nbytes == after.nbytes


# ── SP6. PE_CPU num_programs ──────────────────────────────────────────


def test_pe_cpu_sets_num_programs():
    """PE_CPU should create TLContext with num_programs = PEs per cube."""
    # Interface-contract check only: once implemented, PE_CPU should derive
    # num_programs from the number of PE shards in the kernel launch's
    # target cube. Here the TLContext is built by hand.
    from kernbench.triton_emu.tl_context import TLContext

    # With 8 PEs per cube, num_programs should be 8
    ctx = TLContext(pe_id=3, num_programs=8)
    assert ctx.program_id(0) == 3
    assert ctx.num_programs(0) == 8