Add SIP-level tensor parallelism, component registry YAML, VA offset verification

- DPPolicy: 3-level (sip/cube/pe), unified naming (column_wise/row_wise)
- PE_CPU: auto num_programs from cube shard count
- context.launch(): per-SIP KernelLaunchMsg with local va_base + auto local shape
- deploy_tensor: removed mmus param, MMU mapping is context-only responsibility
- ComponentRegistry: YAML-based lazy loading (components.yaml), impls→builtin rename
- VA offset bench + tests: 2D/1D, standard Triton kernel pattern

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-26 01:13:17 -07:00
parent 08812eda58
commit 63669f82cb
35 changed files with 813 additions and 219 deletions
+161 -69
View File
@@ -20,7 +20,6 @@ class RuntimeContext:
_completed: set[RequestHandle] = field(default_factory=set, init=False)
_allocators: dict[int, Any] = field(default_factory=dict, init=False)
_va_allocator: Any = field(default=None, init=False)
_mmus: dict[int, Any] = field(default_factory=dict, init=False)
_tensor_counter: int = field(default=0, init=False)
_traces: list[dict] = field(default_factory=list, init=False)
_tensors: list[Any] = field(default_factory=list, init=False)
@@ -208,24 +207,17 @@ class RuntimeContext:
rack_id=0, sip_id=sip_id, cube_id=cube_id, pe_id=pe_id, cfg=cfg,
)
# Initialize VA allocator and per-PE MMUs
from kernbench.policy.address.pe_mmu import PeMMU
# Initialize VA allocator (MMU mappings are installed via fabric MmuMapMsg)
from kernbench.policy.address.va_allocator import VirtualAllocator
pe_mmu_attrs = pe_comps.get("pe_mmu", {}).get("attrs", {})
page_size = int(pe_mmu_attrs.get("page_size", 4096))
tlb_overhead_ns = float(pe_mmu_attrs.get("tlb_overhead_ns", 0.0))
self._va_allocator = VirtualAllocator(
va_base=0x1_0000_0000,
va_size=64 * (1 << 30), # 64 GB VA space
page_size=page_size,
)
total_pes = sip_count * cubes_per_sip * pes_per_cube
for flat_idx in range(total_pes):
self._mmus[flat_idx] = PeMMU(
page_size=page_size, overhead_ns=tlb_overhead_ns,
)
return self._allocators
@@ -276,11 +268,11 @@ class RuntimeContext:
dp_policy = dp
allocators = self._ensure_allocators()
itemsize = dtype_itemsize(dtype)
shape_2d = (shape[0], shape[1]) # type: tuple[int, int]
total_cubes = self._num_sips * self._num_cubes
shape_2d = (shape[0], shape[1]) if len(shape) >= 2 else (1, shape[0])
placement = resolve_dp_policy(
dp, shape=shape_2d, itemsize=itemsize,
num_pe=self._pes_per_cube, num_cubes=total_cubes,
num_pe=self._pes_per_cube, num_cubes=self._num_cubes,
num_sips=self._num_sips,
)
# Infer target_pe from placement: multi-PE → "all", single PE → pe_index
@@ -297,7 +289,6 @@ class RuntimeContext:
placement=placement,
allocators=allocators,
va_allocator=self._va_allocator,
mmus=self._mmus,
)
t._handle = handle
import weakref
@@ -305,6 +296,8 @@ class RuntimeContext:
self._tensors.append(weakref.ref(t))
# Install VA→PA mappings via fabric MmuMapMsg
# Strategy: always SIP-scoped (each SIP gets only its own shards).
# Within each SIP: cube="replicate" → per-cube, else broadcast within SIP.
if handle.va_base:
from collections import defaultdict
from kernbench.runtime_api.kernel import MmuMapMsg
@@ -313,47 +306,52 @@ class RuntimeContext:
dp_policy is not None and dp_policy.cube == "replicate"
)
if is_cube_replicate:
# Replicate: each (sip, cube) gets only its own local PA mappings
cube_groups: dict[tuple[int, int], list] = defaultdict(list)
for shard in handle.shards:
cube_groups[(shard.sip, shard.cube)].append(shard)
# Group shards by SIP
sip_groups: dict[int, list] = defaultdict(list)
for shard in handle.shards:
sip_groups[shard.sip].append(shard)
for (sip, cube), group_shards in cube_groups.items():
for sip, sip_shards in sip_groups.items():
if is_cube_replicate:
# Cube replicate: per-(sip, cube) local mapping
cube_groups: dict[int, list] = defaultdict(list)
for s in sip_shards:
cube_groups[s.cube].append(s)
for cube, group_shards in cube_groups.items():
entries = tuple(
{"va": handle.va_base + s.offset_bytes,
"pa": s.pa, "size": s.nbytes}
for s in group_shards
)
msg = MmuMapMsg(
correlation_id=self.correlation_id,
request_id=f"mmu_{tensor_name}_s{sip}c{cube}",
entries=entries,
target_sips=(sip,),
target_cubes=(cube,),
target_pe="all",
)
h = self.submit(msg)
self.wait(h)
else:
# Cube sharded: broadcast all cubes within this SIP
entries = tuple(
{"va": handle.va_base + s.offset_bytes,
"pa": s.pa, "size": s.nbytes}
for s in group_shards
for s in sip_shards
)
cube_set = sorted({s.cube for s in sip_shards})
msg = MmuMapMsg(
correlation_id=self.correlation_id,
request_id=f"mmu_{tensor_name}_s{sip}c{cube}",
request_id=f"mmu_{tensor_name}_s{sip}",
entries=entries,
target_sips=(sip,),
target_cubes=(cube,),
target_cubes=tuple(cube_set),
target_pe="all",
)
h = self.submit(msg)
self.wait(h)
else:
# Sharded: broadcast all mappings to all target (sip, cube)s
entries = tuple(
{"va": handle.va_base + s.offset_bytes,
"pa": s.pa, "size": s.nbytes}
for s in handle.shards
)
sip_set = sorted({s.sip for s in handle.shards})
cube_set = sorted({s.cube for s in handle.shards})
msg = MmuMapMsg(
correlation_id=self.correlation_id,
request_id=f"mmu_{tensor_name}",
entries=entries,
target_sips=tuple(sip_set),
target_cubes=tuple(cube_set),
target_pe="all",
)
h = self.submit(msg)
self.wait(h)
# Submit MemoryWriteMsg per shard (deploy data to device)
if pattern is not None:
@@ -384,11 +382,18 @@ class RuntimeContext:
Positional args: Tensor objects become TensorArg, int/float become ScalarArg.
Keyword args: become ScalarArg (name is discarded, order preserved).
Creates per-SIP KernelLaunchMsg with local va_base per tensor
(like host driver sending per-rank launch commands).
"""
from collections import defaultdict
from kernbench.runtime_api.kernel import (
KernelLaunchMsg,
KernelRef,
ScalarArg,
TensorArg,
TensorArgShard,
)
from kernbench.runtime_api.tensor import Tensor
from kernbench.triton_emu.registry import register_kernel
@@ -399,14 +404,14 @@ class RuntimeContext:
except ValueError:
pass
# Build kernel args from positional + keyword args
kernel_args: list = []
# Collect tensors and scalars
tensor_args: list[Tensor] = []
scalar_args: list = []
target_pe: int | str = 0
for a in args:
if isinstance(a, Tensor):
kernel_args.append(a.to_tensor_arg())
# Infer target_pe from tensor DP metadata
tensor_args.append(a)
if a._dp_metadata is not None:
dp_target = a._dp_metadata.target_pe
if dp_target == "all":
@@ -415,34 +420,121 @@ class RuntimeContext:
target_pe = dp_target
elif isinstance(a, (int, float)):
dtype_str = "f32" if isinstance(a, float) else "i32"
kernel_args.append(ScalarArg(dtype=dtype_str, value=a))
scalar_args.append(ScalarArg(dtype=dtype_str, value=a))
for v in kwargs.values():
if isinstance(v, (int, float)):
dtype_str = "f32" if isinstance(v, float) else "i32"
kernel_args.append(ScalarArg(dtype=dtype_str, value=v))
scalar_args.append(ScalarArg(dtype=dtype_str, value=v))
# Determine target cubes from all tensor shards
cube_set: set[int] = set()
for a in args:
if isinstance(a, Tensor) and a._handle is not None:
for s in a._handle.shards:
cube_set.add(s.cube)
target_cubes = tuple(sorted(cube_set)) if cube_set else (0,)
# Determine all target SIPs from tensor shards
sip_set: set[int] = set()
for t in tensor_args:
if t._handle is not None:
for s in t._handle.shards:
sip_set.add(s.sip)
if not sip_set:
sip_set = {0}
# Collect scalar values for GEMM FLOP calculation
scalar_vals = [a.value for a in kernel_args if hasattr(a, "value")]
# Build global→local dimension mapping from tensor DPPolicies.
# Scalar args matching a tensor's global dimension get replaced
# with the cube-local value (what the kernel actually operates on).
def _compute_local_shape(t: Tensor) -> tuple[int, ...]:
"""Compute cube-local shape from DPPolicy."""
shape = t.shape
if len(shape) < 2:
shape = (1, shape[0])
M, K = shape[0], shape[1]
dp = t._dp_metadata.dp_policy if t._dp_metadata else None
if dp is None:
return t.shape
if dp.sip != "replicate":
if dp.sip == "column_wise":
K = K // self._num_sips
elif dp.sip == "row_wise":
M = M // self._num_sips
if dp.cube != "replicate":
if dp.cube == "column_wise":
K = K // self._num_cubes
elif dp.cube == "row_wise":
M = M // self._num_cubes
if len(t.shape) < 2:
return (K,)
return (M, K)
h = self.submit(KernelLaunchMsg(
correlation_id=self.correlation_id,
request_id=kernel_name,
kernel_ref=KernelRef(name=kernel_name, kind="builtin"),
args=tuple(kernel_args),
target_cubes=target_cubes,
target_pe=target_pe,
))
self.wait(h, _meta={
"phase": "kernel", "name": kernel_name,
"target_pe": target_pe, "scalars": scalar_vals,
})
return h
dim_map: dict[int, int] = {} # global_dim → local_dim
for t in tensor_args:
local = _compute_local_shape(t)
for g, l in zip(t.shape if len(t.shape) >= 2 else (1, t.shape[0]), local if len(local) >= 2 else (1, local[0])):
if g != l:
dim_map[g] = l
# Per-SIP kernel launch: each SIP gets TensorArgs with local va_base
last_handle = None
for sip_id in sorted(sip_set):
sip_kernel_args: list = []
sip_cube_set: set[int] = set()
for t in tensor_args:
if t._handle is None:
continue
sip_shards = [s for s in t._handle.shards if s.sip == sip_id]
if not sip_shards:
sip_shards = list(t._handle.shards)
local_va_base = 0
if t._handle.va_base:
min_offset = min(s.offset_bytes for s in sip_shards)
local_va_base = t._handle.va_base + min_offset
sip_kernel_args.append(TensorArg(
shards=tuple(
TensorArgShard(
sip=s.sip, cube=s.cube, pe=s.pe,
pa=s.pa, nbytes=s.nbytes, offset_bytes=s.offset_bytes,
)
for s in sip_shards
),
va_base=local_va_base,
))
for s in sip_shards:
sip_cube_set.add(s.cube)
# Interleave tensor args and scalar args, replacing global dims with local
final_args: list = []
t_idx, s_idx = 0, 0
for a in args:
if isinstance(a, Tensor):
final_args.append(sip_kernel_args[t_idx])
t_idx += 1
elif isinstance(a, (int, float)):
sa = scalar_args[s_idx]
if isinstance(a, int) and a in dim_map:
sa = ScalarArg(dtype=sa.dtype, value=dim_map[a])
final_args.append(sa)
s_idx += 1
while s_idx < len(scalar_args):
sa = scalar_args[s_idx]
if isinstance(sa.value, int) and int(sa.value) in dim_map:
sa = ScalarArg(dtype=sa.dtype, value=dim_map[int(sa.value)])
final_args.append(sa)
s_idx += 1
target_cubes = tuple(sorted(sip_cube_set)) if sip_cube_set else (0,)
h = self.submit(KernelLaunchMsg(
correlation_id=self.correlation_id,
request_id=f"{kernel_name}_sip{sip_id}",
kernel_ref=KernelRef(name=kernel_name, kind="builtin"),
args=tuple(final_args),
target_cubes=target_cubes,
target_pe=target_pe,
))
self.wait(h, _meta={
"phase": "kernel", "name": kernel_name,
"sip": sip_id, "target_pe": target_pe,
})
last_handle = h
return last_handle
+1 -11
View File
@@ -59,10 +59,7 @@ def deploy_tensor(
allocators: dict[int, PEMemAllocator],
mem_kind: Literal["hbm", "tcm"] = "hbm",
va_allocator=None,
mmus: dict | None = None,
) -> TensorHandle:
from kernbench.policy.address.pe_mmu import PeMMU
isize = dtype_itemsize(dtype)
total_nbytes = math.prod(shape) * isize
@@ -78,22 +75,15 @@ def deploy_tensor(
pa = alloc.alloc_hbm(spec.nbytes)
else:
pa = alloc.alloc_tcm(spec.nbytes)
encoded_pa = pa.encode()
shards.append(TensorShard(
sip=alloc._sip_id,
cube=alloc._cube_id,
pe=alloc._pe_id,
pa=encoded_pa,
pa=pa.encode(),
nbytes=spec.nbytes,
offset_bytes=spec.offset_bytes,
))
# Register VA→PA mapping in all MMUs (broadcast)
if va_base and mmus is not None:
shard_va = va_base + spec.offset_bytes
for mmu in mmus.values():
mmu.map(va=shard_va, pa=encoded_pa, size=spec.nbytes)
return TensorHandle(
name=name,
shape=shape,