Add SIP-level tensor parallelism, component registry YAML, VA offset verification

- DPPolicy: 3-level (sip/cube/pe), unified naming (column_wise/row_wise)
- PE_CPU: auto num_programs from cube shard count
- context.launch(): per-SIP KernelLaunchMsg with local va_base + auto local shape
- deploy_tensor: removed mmus param, MMU mapping is context-only responsibility
- ComponentRegistry: YAML-based lazy loading (components.yaml), impls→builtin rename
- VA offset bench + tests: 2D/1D, standard Triton kernel pattern

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-26 01:13:17 -07:00
parent 08812eda58
commit 63669f82cb
35 changed files with 813 additions and 219 deletions
View File
+1 -1
View File
@@ -13,7 +13,7 @@ import simpy
from pathlib import Path
from kernbench.components.base import ComponentBase, ComponentRegistry
from kernbench.components.impls.forwarding import TransitComponent
from kernbench.components.builtin.forwarding import TransitComponent
from kernbench.policy.address.phyaddr import PhysAddr
from kernbench.runtime_api.kernel import MemoryReadMsg
from kernbench.sim_engine.engine import GraphEngine
+3 -3
View File
@@ -73,7 +73,7 @@ def test_mmu_unmap_msg_fields():
def test_pe_mmu_registry():
"""pe_mmu_v1 impl resolves in ComponentRegistry."""
from kernbench.components.base import ComponentRegistry
from kernbench.components.impls.pe_mmu import PeMmuComponent
from kernbench.components.builtin.pe_mmu import PeMmuComponent
from kernbench.topology.types import Node
node = Node(
@@ -93,7 +93,7 @@ def test_pe_mmu_registry():
def test_pe_mmu_processes_map_msg():
"""PE_MMU component receives MmuMapMsg → translate works."""
import simpy
from kernbench.components.impls.pe_mmu import PeMmuComponent
from kernbench.components.builtin.pe_mmu import PeMmuComponent
from kernbench.sim_engine.transaction import Transaction
from kernbench.topology.types import Node
@@ -152,7 +152,7 @@ def test_pe_dma_translates_va():
# This test validates the interface contract. Full integration test
# requires the engine wiring which is validated in test_engine.
# Here we check that PE_DMA has an mmu attribute it can call.
from kernbench.components.impls.pe_dma import PeDmaComponent
from kernbench.components.builtin.pe_dma import PeDmaComponent
from kernbench.topology.types import Node
node = Node(
+8 -10
View File
@@ -20,12 +20,12 @@ from kernbench.common.pe_commands import (
TensorHandle,
)
from kernbench.components.base import ComponentRegistry
from kernbench.components.impls.pe_cpu import PeCpuComponent
from kernbench.components.impls.pe_dma import PeDmaComponent
from kernbench.components.impls.pe_gemm import PeGemmComponent
from kernbench.components.impls.pe_math import PeMathComponent
from kernbench.components.impls.pe_scheduler import PeSchedulerComponent
from kernbench.components.impls.pe_tcm import PeTcmComponent
from kernbench.components.builtin.pe_cpu import PeCpuComponent
from kernbench.components.builtin.pe_dma import PeDmaComponent
from kernbench.components.builtin.pe_gemm import PeGemmComponent
from kernbench.components.builtin.pe_math import PeMathComponent
from kernbench.components.builtin.pe_scheduler import PeSchedulerComponent
from kernbench.components.builtin.pe_tcm import PeTcmComponent
from kernbench.policy.address.phyaddr import PhysAddr
from kernbench.runtime_api.kernel import (
KernelLaunchMsg,
@@ -888,11 +888,9 @@ def test_qkv_gemm_bench_completes():
deploy_traces = [t for t in ctx._traces if t["phase"] in ("deploy", "memory_write")]
kernel_traces = [t for t in ctx._traces if t["phase"] == "kernel"]
assert len(deploy_traces) >= 2 # at least a, b (out is empty, no deploy)
assert len(kernel_traces) == 1
assert len(kernel_traces) >= 1 # one per SIP (2 SIPs in topology)
assert kernel_traces[0]["name"] == "qkv_gemm"
assert kernel_traces[0]["total_ns"] > 0
# Scalars should contain M, K, N
assert len(kernel_traces[0]["scalars"]) >= 3
clear_registry()
@@ -982,7 +980,7 @@ def test_qkv_gemm_bench_multi_pe_completes():
deploy_traces = [t for t in ctx._traces if t["phase"] in ("deploy", "memory_write")]
kernel_traces = [t for t in ctx._traces if t["phase"] == "kernel"]
assert len(deploy_traces) >= 8 # replicate(a)*8 + column_wise(b)*8
assert len(kernel_traces) == 1
assert len(kernel_traces) >= 1 # one per SIP
assert kernel_traces[0]["target_pe"] == "all"
clear_registry()
+1 -1
View File
@@ -19,7 +19,7 @@ import simpy
from kernbench.components.base import ComponentBase, ComponentRegistry
from kernbench.components.context import ComponentContext
from kernbench.components.impls import (
from kernbench.components.builtin import (
HbmCtrlComponent,
IoCpuComponent,
MCpuComponent,
+157
View File
@@ -0,0 +1,157 @@
"""Tests for SIP-level tensor parallelism.
Validates:
SP1. DPPolicy accepts sip field (default "replicate", backward compat)
SP2. sip="column_wise": tensor K-axis split across SIPs, each SIP gets K//num_sips
SP3. sip="row_wise": tensor M-axis split across SIPs
SP4. 3-level resolve: sip × cube × pe produces correct flat indices and offsets
SP5. sip="replicate": all SIPs get full copy (existing behavior)
SP6. PE_CPU sets num_programs from shard count per cube
SP7. End-to-end: TP kernel with sip="column_wise" completes on multi-SIP topology
"""
import pytest
from pathlib import Path
from kernbench.policy.placement.dp import DPPolicy, ShardSpec, resolve_dp_policy
# ── SP1. DPPolicy sip field ──────────────────────────────────────────
def test_dp_policy_sip_default_replicate():
"""DPPolicy without sip= defaults to 'replicate'."""
dp = DPPolicy(cube="replicate", pe="column_wise")
assert dp.sip == "replicate"
def test_dp_policy_sip_column_wise():
"""DPPolicy accepts sip='column_wise'."""
dp = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise")
assert dp.sip == "column_wise"
# ── SP2. sip="column_wise" ──────────────────────────────────────────────
def test_sip_column_wise_splits_across_sips():
"""sip='column_wise' with 2 SIPs: each SIP gets K//2 columns."""
dp = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise")
shards = resolve_dp_policy(
dp, shape=(128, 256), itemsize=2,
num_pe=8, num_cubes=1, num_sips=2,
)
# 2 SIPs × 1 cube × 8 PEs = 16 shards
assert len(shards) == 16
# SIP0 shards: first half of K (0 to K//2)
# SIP1 shards: second half of K (K//2 to K)
total_bytes = 128 * 256 * 2 # 64KB
sip0_shards = [s for s in shards if s.pe_index < 8]
sip1_shards = [s for s in shards if s.pe_index >= 8]
# SIP0 offsets start at 0
assert sip0_shards[0].offset_bytes == 0
# SIP1 offsets start at half
assert sip1_shards[0].offset_bytes == total_bytes // 2
# Total coverage
assert sum(s.nbytes for s in sip0_shards) == total_bytes // 2
assert sum(s.nbytes for s in sip1_shards) == total_bytes // 2
# ── SP3. sip="row_wise" ──────────────────────────────────────────────
def test_sip_row_wise_splits_across_sips():
"""sip='row_wise' with 2 SIPs: each SIP gets M//2 rows."""
dp = DPPolicy(sip="row_wise", cube="replicate", pe="column_wise")
shards = resolve_dp_policy(
dp, shape=(128, 256), itemsize=2,
num_pe=8, num_cubes=1, num_sips=2,
)
assert len(shards) == 16
sip0_shards = [s for s in shards if s.pe_index < 8]
sip1_shards = [s for s in shards if s.pe_index >= 8]
# SIP0: rows 0..63, SIP1: rows 64..127
total_bytes = 128 * 256 * 2
assert sip0_shards[0].offset_bytes == 0
assert sip1_shards[0].offset_bytes == total_bytes // 2
# ── SP4. 3-level resolve ─────────────────────────────────────────────
def test_3level_resolve_flat_index():
"""3-level: sip × cube × pe produces correct flat indices."""
dp = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise")
shards = resolve_dp_policy(
dp, shape=(128, 256), itemsize=2,
num_pe=8, num_cubes=2, num_sips=2,
)
# 2 SIPs × 2 cubes × 8 PEs = 32 shards
assert len(shards) == 32
# Flat index: sip_id * cubes_per_sip * num_pe + cube_id * num_pe + pe_id
indices = [s.pe_index for s in shards]
# SIP0: 0..15, SIP1: 16..31
assert min(indices) == 0
assert max(indices) == 31
assert len(set(indices)) == 32 # all unique
def test_3level_offsets_cover_full_tensor():
"""3-level sharding covers the entire tensor with no gaps."""
dp = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise")
shards = resolve_dp_policy(
dp, shape=(128, 256), itemsize=2,
num_pe=4, num_cubes=1, num_sips=2,
)
# 2 SIPs × 1 cube × 4 PEs = 8 shards
# sip="column_wise": K=128 per SIP, pe="column_wise": 32 cols per PE
total = 128 * 256 * 2
# For non-replicate, total shard bytes == tensor bytes
# (replicate within cube means cube shards overlap, but sip shards don't)
sip0_bytes = sum(s.nbytes for s in shards if s.pe_index < 4)
sip1_bytes = sum(s.nbytes for s in shards if s.pe_index >= 4)
assert sip0_bytes + sip1_bytes == total
# ── SP5. sip="replicate" backward compat ─────────────────────────────
def test_sip_replicate_backward_compat():
"""sip='replicate' produces same result as before (2-level)."""
dp_old = DPPolicy(cube="replicate", pe="column_wise")
dp_new = DPPolicy(sip="replicate", cube="replicate", pe="column_wise")
shards_old = resolve_dp_policy(
dp_old, shape=(128, 256), itemsize=2,
num_pe=8, num_cubes=2, num_sips=2,
)
shards_new = resolve_dp_policy(
dp_new, shape=(128, 256), itemsize=2,
num_pe=8, num_cubes=2, num_sips=2,
)
assert len(shards_old) == len(shards_new)
for a, b in zip(shards_old, shards_new):
assert a.pe_index == b.pe_index
assert a.offset_bytes == b.offset_bytes
assert a.nbytes == b.nbytes
# ── SP6. PE_CPU num_programs ──────────────────────────────────────────
def test_pe_cpu_sets_num_programs():
"""PE_CPU should create TLContext with num_programs = PEs per cube."""
# This test validates the interface contract.
# After implementation, PE_CPU should derive num_programs from the
# number of PE shards in the kernel launch's target cube.
from kernbench.triton_emu.tl_context import TLContext
# With 8 PEs per cube, num_programs should be 8
tl = TLContext(pe_id=3, num_programs=8)
assert tl.program_id(0) == 3
assert tl.num_programs(0) == 8
+5 -19
View File
@@ -88,7 +88,6 @@ def test_deploy_tensor_assigns_va_base():
"""deploy_tensor with VA allocator assigns va_base to TensorHandle."""
allocs = _make_allocators()
va_alloc = _make_va_allocator()
mmus = _make_mmus()
placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8)
th = deploy_tensor(
@@ -98,7 +97,6 @@ def test_deploy_tensor_assigns_va_base():
placement=placement,
allocators=allocs,
va_allocator=va_alloc,
mmus=mmus,
)
assert th.va_base is not None
@@ -109,7 +107,6 @@ def test_deploy_tensor_va_covers_all_shards():
"""VA allocation covers the entire tensor; each shard is at va_base + offset."""
allocs = _make_allocators()
va_alloc = _make_va_allocator()
mmus = _make_mmus()
placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8)
th = deploy_tensor(
@@ -119,41 +116,32 @@ def test_deploy_tensor_va_covers_all_shards():
placement=placement,
allocators=allocs,
va_allocator=va_alloc,
mmus=mmus,
)
# Each shard's VA is derivable: va_base + offset_bytes
for s in th.shards:
shard_va = th.va_base + s.offset_bytes
assert shard_va > 0
def test_deploy_tensor_registers_mmu_mappings():
"""deploy_tensor registers VA→PA mappings in all PE MMUs."""
def test_deploy_tensor_does_not_install_mmu_mappings():
"""deploy_tensor does NOT install MMU mappings — that's context's job."""
allocs = _make_allocators()
va_alloc = _make_va_allocator()
mmus = _make_mmus()
placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8)
th = deploy_tensor(
deploy_tensor(
name="W",
shape=(1024, 512),
dtype="fp16",
placement=placement,
allocators=allocs,
va_allocator=va_alloc,
mmus=mmus,
)
# Every MMU should have entries (broadcast)
# No MMU should have any entries (mappings come from fabric MmuMapMsg)
for mmu in mmus.values():
assert mmu.num_entries > 0
# Each shard's derived VA should translate to its PA in every MMU
for mmu in mmus.values():
for s in th.shards:
shard_va = th.va_base + s.offset_bytes
assert mmu.translate(shard_va) == s.pa
assert mmu.num_entries == 0
# ── T12. Tensor.va property ──────────────────────────────────────────
@@ -165,7 +153,6 @@ def test_tensor_va_property():
allocs = _make_allocators(1)
va_alloc = _make_va_allocator()
mmus = _make_mmus(1)
placement = [ShardSpec(pe_index=0, offset_bytes=0, nbytes=4096)]
t = Tensor(shape=(2048,), dtype="f16", name="test")
@@ -176,7 +163,6 @@ def test_tensor_va_property():
placement=placement,
allocators=allocs,
va_allocator=va_alloc,
mmus=mmus,
)
assert t.va > 0
assert t.va == t._handle.va_base
+216
View File
@@ -0,0 +1,216 @@
"""VA offset verification: each PE accesses its own local HBM slice.
Verifies that column-wise sharding + VA offset calculation produces DMA
addresses that translate to the correct PE's local HBM — not a remote PE.
Tests:
VO1. Per-PE DMA addresses are correct VAs (2D)
VO2. Each VA translates to the executing PE's own HBM slice (2D)
VO3. End-to-end bench completes (2D, full TP)
VO4. Per-PE DMA addresses are correct VAs (1D)
VO5. Each VA translates to local HBM (1D)
VO6. End-to-end 1D bench completes
"""
from pathlib import Path
import pytest
from kernbench.common.pe_commands import DmaReadCmd, DmaWriteCmd
from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator
from kernbench.policy.address.pe_mmu import PeMMU
from kernbench.policy.address.phyaddr import PhysAddr
from kernbench.policy.address.va_allocator import VirtualAllocator
from kernbench.policy.placement.dp import DPPolicy, column_wise
from kernbench.runtime_api.tensor import deploy_tensor
from kernbench.sim_engine.engine import GraphEngine
from kernbench.runtime_api.context import RuntimeContext
from kernbench.runtime_api.types import DeviceSelector
from kernbench.topology.builder import load_topology
from kernbench.triton_emu.tl_context import TLContext, run_kernel
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
_MB = 1 << 20
_GB = 1 << 30
M, K = 128, 256
DTYPE = "f16"
NUM_PE = 8
ELEM_BYTES = 2
def _copy_kernel_2d(src_ptr, dst_ptr, M, K, tl, DTYPE="f16"):
"""Standard Triton 2D copy. M, K are cube-local."""
pid = tl.program_id(0)
num_pe = tl.num_programs(0)
cols_per_pe = K // num_pe
elem_bytes = 2
offset = pid * M * cols_per_pe * elem_bytes
data = tl.load(src_ptr + offset, shape=(M, cols_per_pe), dtype=DTYPE)
tl.store(dst_ptr + offset, data)
def _copy_kernel_1d(src_ptr, dst_ptr, N, tl, DTYPE="f16"):
"""Standard Triton 1D copy. N is cube-local."""
pid = tl.program_id(0)
num_pe = tl.num_programs(0)
elems_per_pe = N // num_pe
elem_bytes = 2
offset = pid * elems_per_pe * elem_bytes
data = tl.load(src_ptr + offset, shape=(elems_per_pe,), dtype=DTYPE)
tl.store(dst_ptr + offset, data)
def _make_standalone(shape, num_pe=NUM_PE):
"""Create standalone allocators + MMUs for unit testing."""
cfg = AddressConfig(
sip_count=1, cubes_per_sip=1, pes_per_cube=num_pe,
hbm_bytes_per_cube=48 * _GB, hbm_slices_per_cube=num_pe,
tcm_bytes_per_pe=16 * _MB, tcm_scheduler_reserved_bytes=4 * _MB,
sram_bytes_per_cube=32 * _MB,
)
allocators = {
i: PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=i, cfg=cfg)
for i in range(num_pe)
}
va_alloc = VirtualAllocator(va_base=0x1_0000_0000, va_size=64 * _GB, page_size=4096)
mmus = {i: PeMMU(page_size=4096) for i in range(num_pe)}
return cfg, allocators, va_alloc, mmus
# ── VO1. 2D: Per-PE DMA addresses are correct VAs ────────────────────
def test_2d_each_pe_computes_correct_va_offset():
"""2D: each PE generates DMA at va_base + pid * block_bytes."""
src_va = 0x1_0000_0000
dst_va = 0x2_0000_0000
cols_per_pe = K // NUM_PE
block_bytes = M * cols_per_pe * ELEM_BYTES
for pe_id in range(NUM_PE):
tl = TLContext(pe_id=pe_id, num_programs=NUM_PE, dispatch_cycles=0)
run_kernel(_copy_kernel_2d, tl, src_ptr=src_va, dst_ptr=dst_va, M=M, K=K)
reads = [c for c in tl.commands if isinstance(c, DmaReadCmd)]
writes = [c for c in tl.commands if isinstance(c, DmaWriteCmd)]
expected_offset = pe_id * block_bytes
assert reads[0].src_addr == src_va + expected_offset
assert writes[0].dst_addr == dst_va + expected_offset
# ── VO2. 2D: Each VA translates to local HBM ─────────────────────────
def test_2d_va_translates_to_local_hbm():
"""2D: each PE's DMA VA translates to its own HBM slice."""
cfg, allocators, va_alloc, mmus = _make_standalone((M, K))
slice_size = cfg.hbm_slice_bytes
cols_per_pe = K // NUM_PE
block_bytes = M * cols_per_pe * ELEM_BYTES
placement = column_wise(shape=(M, K), itemsize=ELEM_BYTES, num_pe=NUM_PE)
handle = deploy_tensor(
name="src", shape=(M, K), dtype="fp16",
placement=placement, allocators=allocators, va_allocator=va_alloc,
)
# Install per-PE mappings (simulating what context does via MmuMapMsg)
for s in handle.shards:
mmus[s.pe].map(va=handle.va_base + s.offset_bytes, pa=s.pa, size=s.nbytes)
for pe_id in range(NUM_PE):
va = handle.va_base + pe_id * block_bytes
pa = mmus[pe_id].translate(va)
decoded = PhysAddr.decode(pa)
hbm_pe = PhysAddr.hbm_pe_id(decoded.hbm_offset, slice_size)
assert hbm_pe == pe_id, f"PE{pe_id} accessed PE{hbm_pe}'s HBM"
# ── VO3. 2D: End-to-end bench completes ──────────────────────────────
def test_2d_bench_completes():
"""2D: full TP bench with standard Triton kernel pattern."""
graph = load_topology(TOPOLOGY_PATH)
engine = GraphEngine(graph)
ctx = RuntimeContext(
engine=engine, target_device=DeviceSelector("sip:0"),
correlation_id="vo3", spec=graph.spec,
)
from benches.va_offset_verify import run as bench_run
bench_run(ctx)
ctx.wait_all()
# ── VO4. 1D: Per-PE DMA addresses ────────────────────────────────────
N_1D = 1024
def test_1d_each_pe_computes_correct_offset():
"""1D: each PE generates DMA at correct offset."""
src_va = 0x1_0000_0000
dst_va = 0x2_0000_0000
elems_per_pe = N_1D // NUM_PE
block_bytes = elems_per_pe * ELEM_BYTES
for pe_id in range(NUM_PE):
tl = TLContext(pe_id=pe_id, num_programs=NUM_PE, dispatch_cycles=0)
run_kernel(_copy_kernel_1d, tl, src_ptr=src_va, dst_ptr=dst_va, N=N_1D)
reads = [c for c in tl.commands if isinstance(c, DmaReadCmd)]
writes = [c for c in tl.commands if isinstance(c, DmaWriteCmd)]
expected_offset = pe_id * block_bytes
assert reads[0].src_addr == src_va + expected_offset
assert writes[0].dst_addr == dst_va + expected_offset
# ── VO5. 1D: VA translates to local HBM ──────────────────────────────
def test_1d_va_translates_to_local_hbm():
"""1D: each PE's DMA VA translates to its own HBM slice."""
cfg, allocators, va_alloc, mmus = _make_standalone((1, N_1D))
slice_size = cfg.hbm_slice_bytes
elems_per_pe = N_1D // NUM_PE
block_bytes = elems_per_pe * ELEM_BYTES
placement = column_wise(shape=(1, N_1D), itemsize=ELEM_BYTES, num_pe=NUM_PE)
handle = deploy_tensor(
name="src_1d", shape=(N_1D,), dtype="fp16",
placement=placement, allocators=allocators, va_allocator=va_alloc,
)
for s in handle.shards:
mmus[s.pe].map(va=handle.va_base + s.offset_bytes, pa=s.pa, size=s.nbytes)
for pe_id in range(NUM_PE):
va = handle.va_base + pe_id * block_bytes
pa = mmus[pe_id].translate(va)
decoded = PhysAddr.decode(pa)
hbm_pe = PhysAddr.hbm_pe_id(decoded.hbm_offset, slice_size)
assert hbm_pe == pe_id, f"1D PE{pe_id} accessed PE{hbm_pe}'s HBM"
# ── VO6. 1D: End-to-end ──────────────────────────────────────────────
def test_1d_e2e_completes():
"""1D: full engine run with column_wise TP sharding."""
graph = load_topology(TOPOLOGY_PATH)
engine = GraphEngine(graph)
ctx = RuntimeContext(
engine=engine, target_device=DeviceSelector("sip:0"),
correlation_id="vo6", spec=graph.spec,
)
dp = DPPolicy(sip="column_wise", cube="column_wise", pe="column_wise")
src = ctx.zeros((N_1D,), dtype=DTYPE, dp=dp, name="src_1d")
dst = ctx.empty((N_1D,), dtype=DTYPE, dp=dp, name="dst_1d")
# launch() auto-localizes N_1D → cube-local N
ctx.launch("va_1d_copy", _copy_kernel_1d, src, dst, N_1D)
ctx.wait_all()