Add SIP-level tensor parallelism, component registry YAML, VA offset verification
- DPPolicy: 3-level (sip/cube/pe), unified naming (column_wise/row_wise) - PE_CPU: auto num_programs from cube shard count - context.launch(): per-SIP KernelLaunchMsg with local va_base + auto local shape - deploy_tensor: removed mmus param, MMU mapping is context-only responsibility - ComponentRegistry: YAML-based lazy loading (components.yaml), impls→builtin rename - VA offset bench + tests: 2D/1D, standard Triton kernel pattern Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -20,7 +20,6 @@ class RuntimeContext:
|
||||
_completed: set[RequestHandle] = field(default_factory=set, init=False)
|
||||
_allocators: dict[int, Any] = field(default_factory=dict, init=False)
|
||||
_va_allocator: Any = field(default=None, init=False)
|
||||
_mmus: dict[int, Any] = field(default_factory=dict, init=False)
|
||||
_tensor_counter: int = field(default=0, init=False)
|
||||
_traces: list[dict] = field(default_factory=list, init=False)
|
||||
_tensors: list[Any] = field(default_factory=list, init=False)
|
||||
@@ -208,24 +207,17 @@ class RuntimeContext:
|
||||
rack_id=0, sip_id=sip_id, cube_id=cube_id, pe_id=pe_id, cfg=cfg,
|
||||
)
|
||||
|
||||
# Initialize VA allocator and per-PE MMUs
|
||||
from kernbench.policy.address.pe_mmu import PeMMU
|
||||
# Initialize VA allocator (MMU mappings are installed via fabric MmuMapMsg)
|
||||
from kernbench.policy.address.va_allocator import VirtualAllocator
|
||||
|
||||
pe_mmu_attrs = pe_comps.get("pe_mmu", {}).get("attrs", {})
|
||||
page_size = int(pe_mmu_attrs.get("page_size", 4096))
|
||||
tlb_overhead_ns = float(pe_mmu_attrs.get("tlb_overhead_ns", 0.0))
|
||||
|
||||
self._va_allocator = VirtualAllocator(
|
||||
va_base=0x1_0000_0000,
|
||||
va_size=64 * (1 << 30), # 64 GB VA space
|
||||
page_size=page_size,
|
||||
)
|
||||
total_pes = sip_count * cubes_per_sip * pes_per_cube
|
||||
for flat_idx in range(total_pes):
|
||||
self._mmus[flat_idx] = PeMMU(
|
||||
page_size=page_size, overhead_ns=tlb_overhead_ns,
|
||||
)
|
||||
|
||||
return self._allocators
|
||||
|
||||
@@ -276,11 +268,11 @@ class RuntimeContext:
|
||||
dp_policy = dp
|
||||
allocators = self._ensure_allocators()
|
||||
itemsize = dtype_itemsize(dtype)
|
||||
shape_2d = (shape[0], shape[1]) # type: tuple[int, int]
|
||||
total_cubes = self._num_sips * self._num_cubes
|
||||
shape_2d = (shape[0], shape[1]) if len(shape) >= 2 else (1, shape[0])
|
||||
placement = resolve_dp_policy(
|
||||
dp, shape=shape_2d, itemsize=itemsize,
|
||||
num_pe=self._pes_per_cube, num_cubes=total_cubes,
|
||||
num_pe=self._pes_per_cube, num_cubes=self._num_cubes,
|
||||
num_sips=self._num_sips,
|
||||
)
|
||||
|
||||
# Infer target_pe from placement: multi-PE → "all", single PE → pe_index
|
||||
@@ -297,7 +289,6 @@ class RuntimeContext:
|
||||
placement=placement,
|
||||
allocators=allocators,
|
||||
va_allocator=self._va_allocator,
|
||||
mmus=self._mmus,
|
||||
)
|
||||
t._handle = handle
|
||||
import weakref
|
||||
@@ -305,6 +296,8 @@ class RuntimeContext:
|
||||
self._tensors.append(weakref.ref(t))
|
||||
|
||||
# Install VA→PA mappings via fabric MmuMapMsg
|
||||
# Strategy: always SIP-scoped (each SIP gets only its own shards).
|
||||
# Within each SIP: cube="replicate" → per-cube, else broadcast within SIP.
|
||||
if handle.va_base:
|
||||
from collections import defaultdict
|
||||
from kernbench.runtime_api.kernel import MmuMapMsg
|
||||
@@ -313,47 +306,52 @@ class RuntimeContext:
|
||||
dp_policy is not None and dp_policy.cube == "replicate"
|
||||
)
|
||||
|
||||
if is_cube_replicate:
|
||||
# Replicate: each (sip, cube) gets only its own local PA mappings
|
||||
cube_groups: dict[tuple[int, int], list] = defaultdict(list)
|
||||
for shard in handle.shards:
|
||||
cube_groups[(shard.sip, shard.cube)].append(shard)
|
||||
# Group shards by SIP
|
||||
sip_groups: dict[int, list] = defaultdict(list)
|
||||
for shard in handle.shards:
|
||||
sip_groups[shard.sip].append(shard)
|
||||
|
||||
for (sip, cube), group_shards in cube_groups.items():
|
||||
for sip, sip_shards in sip_groups.items():
|
||||
if is_cube_replicate:
|
||||
# Cube replicate: per-(sip, cube) local mapping
|
||||
cube_groups: dict[int, list] = defaultdict(list)
|
||||
for s in sip_shards:
|
||||
cube_groups[s.cube].append(s)
|
||||
|
||||
for cube, group_shards in cube_groups.items():
|
||||
entries = tuple(
|
||||
{"va": handle.va_base + s.offset_bytes,
|
||||
"pa": s.pa, "size": s.nbytes}
|
||||
for s in group_shards
|
||||
)
|
||||
msg = MmuMapMsg(
|
||||
correlation_id=self.correlation_id,
|
||||
request_id=f"mmu_{tensor_name}_s{sip}c{cube}",
|
||||
entries=entries,
|
||||
target_sips=(sip,),
|
||||
target_cubes=(cube,),
|
||||
target_pe="all",
|
||||
)
|
||||
h = self.submit(msg)
|
||||
self.wait(h)
|
||||
else:
|
||||
# Cube sharded: broadcast all cubes within this SIP
|
||||
entries = tuple(
|
||||
{"va": handle.va_base + s.offset_bytes,
|
||||
"pa": s.pa, "size": s.nbytes}
|
||||
for s in group_shards
|
||||
for s in sip_shards
|
||||
)
|
||||
cube_set = sorted({s.cube for s in sip_shards})
|
||||
msg = MmuMapMsg(
|
||||
correlation_id=self.correlation_id,
|
||||
request_id=f"mmu_{tensor_name}_s{sip}c{cube}",
|
||||
request_id=f"mmu_{tensor_name}_s{sip}",
|
||||
entries=entries,
|
||||
target_sips=(sip,),
|
||||
target_cubes=(cube,),
|
||||
target_cubes=tuple(cube_set),
|
||||
target_pe="all",
|
||||
)
|
||||
h = self.submit(msg)
|
||||
self.wait(h)
|
||||
else:
|
||||
# Sharded: broadcast all mappings to all target (sip, cube)s
|
||||
entries = tuple(
|
||||
{"va": handle.va_base + s.offset_bytes,
|
||||
"pa": s.pa, "size": s.nbytes}
|
||||
for s in handle.shards
|
||||
)
|
||||
sip_set = sorted({s.sip for s in handle.shards})
|
||||
cube_set = sorted({s.cube for s in handle.shards})
|
||||
msg = MmuMapMsg(
|
||||
correlation_id=self.correlation_id,
|
||||
request_id=f"mmu_{tensor_name}",
|
||||
entries=entries,
|
||||
target_sips=tuple(sip_set),
|
||||
target_cubes=tuple(cube_set),
|
||||
target_pe="all",
|
||||
)
|
||||
h = self.submit(msg)
|
||||
self.wait(h)
|
||||
|
||||
# Submit MemoryWriteMsg per shard (deploy data to device)
|
||||
if pattern is not None:
|
||||
@@ -384,11 +382,18 @@ class RuntimeContext:
|
||||
|
||||
Positional args: Tensor objects become TensorArg, int/float become ScalarArg.
|
||||
Keyword args: become ScalarArg (name is discarded, order preserved).
|
||||
|
||||
Creates per-SIP KernelLaunchMsg with local va_base per tensor
|
||||
(like host driver sending per-rank launch commands).
|
||||
"""
|
||||
from collections import defaultdict
|
||||
|
||||
from kernbench.runtime_api.kernel import (
|
||||
KernelLaunchMsg,
|
||||
KernelRef,
|
||||
ScalarArg,
|
||||
TensorArg,
|
||||
TensorArgShard,
|
||||
)
|
||||
from kernbench.runtime_api.tensor import Tensor
|
||||
from kernbench.triton_emu.registry import register_kernel
|
||||
@@ -399,14 +404,14 @@ class RuntimeContext:
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Build kernel args from positional + keyword args
|
||||
kernel_args: list = []
|
||||
# Collect tensors and scalars
|
||||
tensor_args: list[Tensor] = []
|
||||
scalar_args: list = []
|
||||
target_pe: int | str = 0
|
||||
|
||||
for a in args:
|
||||
if isinstance(a, Tensor):
|
||||
kernel_args.append(a.to_tensor_arg())
|
||||
# Infer target_pe from tensor DP metadata
|
||||
tensor_args.append(a)
|
||||
if a._dp_metadata is not None:
|
||||
dp_target = a._dp_metadata.target_pe
|
||||
if dp_target == "all":
|
||||
@@ -415,34 +420,121 @@ class RuntimeContext:
|
||||
target_pe = dp_target
|
||||
elif isinstance(a, (int, float)):
|
||||
dtype_str = "f32" if isinstance(a, float) else "i32"
|
||||
kernel_args.append(ScalarArg(dtype=dtype_str, value=a))
|
||||
scalar_args.append(ScalarArg(dtype=dtype_str, value=a))
|
||||
|
||||
for v in kwargs.values():
|
||||
if isinstance(v, (int, float)):
|
||||
dtype_str = "f32" if isinstance(v, float) else "i32"
|
||||
kernel_args.append(ScalarArg(dtype=dtype_str, value=v))
|
||||
scalar_args.append(ScalarArg(dtype=dtype_str, value=v))
|
||||
|
||||
# Determine target cubes from all tensor shards
|
||||
cube_set: set[int] = set()
|
||||
for a in args:
|
||||
if isinstance(a, Tensor) and a._handle is not None:
|
||||
for s in a._handle.shards:
|
||||
cube_set.add(s.cube)
|
||||
target_cubes = tuple(sorted(cube_set)) if cube_set else (0,)
|
||||
# Determine all target SIPs from tensor shards
|
||||
sip_set: set[int] = set()
|
||||
for t in tensor_args:
|
||||
if t._handle is not None:
|
||||
for s in t._handle.shards:
|
||||
sip_set.add(s.sip)
|
||||
if not sip_set:
|
||||
sip_set = {0}
|
||||
|
||||
# Collect scalar values for GEMM FLOP calculation
|
||||
scalar_vals = [a.value for a in kernel_args if hasattr(a, "value")]
|
||||
# Build global→local dimension mapping from tensor DPPolicies.
|
||||
# Scalar args matching a tensor's global dimension get replaced
|
||||
# with the cube-local value (what the kernel actually operates on).
|
||||
def _compute_local_shape(t: Tensor) -> tuple[int, ...]:
|
||||
"""Compute cube-local shape from DPPolicy."""
|
||||
shape = t.shape
|
||||
if len(shape) < 2:
|
||||
shape = (1, shape[0])
|
||||
M, K = shape[0], shape[1]
|
||||
dp = t._dp_metadata.dp_policy if t._dp_metadata else None
|
||||
if dp is None:
|
||||
return t.shape
|
||||
if dp.sip != "replicate":
|
||||
if dp.sip == "column_wise":
|
||||
K = K // self._num_sips
|
||||
elif dp.sip == "row_wise":
|
||||
M = M // self._num_sips
|
||||
if dp.cube != "replicate":
|
||||
if dp.cube == "column_wise":
|
||||
K = K // self._num_cubes
|
||||
elif dp.cube == "row_wise":
|
||||
M = M // self._num_cubes
|
||||
if len(t.shape) < 2:
|
||||
return (K,)
|
||||
return (M, K)
|
||||
|
||||
h = self.submit(KernelLaunchMsg(
|
||||
correlation_id=self.correlation_id,
|
||||
request_id=kernel_name,
|
||||
kernel_ref=KernelRef(name=kernel_name, kind="builtin"),
|
||||
args=tuple(kernel_args),
|
||||
target_cubes=target_cubes,
|
||||
target_pe=target_pe,
|
||||
))
|
||||
self.wait(h, _meta={
|
||||
"phase": "kernel", "name": kernel_name,
|
||||
"target_pe": target_pe, "scalars": scalar_vals,
|
||||
})
|
||||
return h
|
||||
dim_map: dict[int, int] = {} # global_dim → local_dim
|
||||
for t in tensor_args:
|
||||
local = _compute_local_shape(t)
|
||||
for g, l in zip(t.shape if len(t.shape) >= 2 else (1, t.shape[0]), local if len(local) >= 2 else (1, local[0])):
|
||||
if g != l:
|
||||
dim_map[g] = l
|
||||
|
||||
# Per-SIP kernel launch: each SIP gets TensorArgs with local va_base
|
||||
last_handle = None
|
||||
for sip_id in sorted(sip_set):
|
||||
sip_kernel_args: list = []
|
||||
sip_cube_set: set[int] = set()
|
||||
|
||||
for t in tensor_args:
|
||||
if t._handle is None:
|
||||
continue
|
||||
sip_shards = [s for s in t._handle.shards if s.sip == sip_id]
|
||||
if not sip_shards:
|
||||
sip_shards = list(t._handle.shards)
|
||||
|
||||
local_va_base = 0
|
||||
if t._handle.va_base:
|
||||
min_offset = min(s.offset_bytes for s in sip_shards)
|
||||
local_va_base = t._handle.va_base + min_offset
|
||||
|
||||
sip_kernel_args.append(TensorArg(
|
||||
shards=tuple(
|
||||
TensorArgShard(
|
||||
sip=s.sip, cube=s.cube, pe=s.pe,
|
||||
pa=s.pa, nbytes=s.nbytes, offset_bytes=s.offset_bytes,
|
||||
)
|
||||
for s in sip_shards
|
||||
),
|
||||
va_base=local_va_base,
|
||||
))
|
||||
|
||||
for s in sip_shards:
|
||||
sip_cube_set.add(s.cube)
|
||||
|
||||
# Interleave tensor args and scalar args, replacing global dims with local
|
||||
final_args: list = []
|
||||
t_idx, s_idx = 0, 0
|
||||
for a in args:
|
||||
if isinstance(a, Tensor):
|
||||
final_args.append(sip_kernel_args[t_idx])
|
||||
t_idx += 1
|
||||
elif isinstance(a, (int, float)):
|
||||
sa = scalar_args[s_idx]
|
||||
if isinstance(a, int) and a in dim_map:
|
||||
sa = ScalarArg(dtype=sa.dtype, value=dim_map[a])
|
||||
final_args.append(sa)
|
||||
s_idx += 1
|
||||
while s_idx < len(scalar_args):
|
||||
sa = scalar_args[s_idx]
|
||||
if isinstance(sa.value, int) and int(sa.value) in dim_map:
|
||||
sa = ScalarArg(dtype=sa.dtype, value=dim_map[int(sa.value)])
|
||||
final_args.append(sa)
|
||||
s_idx += 1
|
||||
|
||||
target_cubes = tuple(sorted(sip_cube_set)) if sip_cube_set else (0,)
|
||||
|
||||
h = self.submit(KernelLaunchMsg(
|
||||
correlation_id=self.correlation_id,
|
||||
request_id=f"{kernel_name}_sip{sip_id}",
|
||||
kernel_ref=KernelRef(name=kernel_name, kind="builtin"),
|
||||
args=tuple(final_args),
|
||||
target_cubes=target_cubes,
|
||||
target_pe=target_pe,
|
||||
))
|
||||
self.wait(h, _meta={
|
||||
"phase": "kernel", "name": kernel_name,
|
||||
"sip": sip_id, "target_pe": target_pe,
|
||||
})
|
||||
last_handle = h
|
||||
|
||||
return last_handle
|
||||
|
||||
Reference in New Issue
Block a user