Add SIP-level tensor parallelism, component registry YAML, VA offset verification

- DPPolicy: 3-level (sip/cube/pe), unified naming (column_wise/row_wise)
- PE_CPU: auto num_programs from cube shard count
- context.launch(): per-SIP KernelLaunchMsg with local va_base + auto local shape
- deploy_tensor: removed mmus param, MMU mapping is context-only responsibility
- ComponentRegistry: YAML-based lazy loading (components.yaml), impls→builtin rename
- VA offset bench + tests: 2D/1D, standard Triton kernel pattern

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-26 01:13:17 -07:00
parent 08812eda58
commit 63669f82cb
35 changed files with 813 additions and 219 deletions
+42
View File
@@ -0,0 +1,42 @@
"""VA offset verification benchmark.
Verifies that Triton-style base_ptr + pid * stride addressing works correctly
with full TP sharding (sip/cube/pe all column_wise). Each PE loads its own
block from a sharded tensor and stores it back.
The kernel uses standard Triton patterns:
- tl.program_id(0) for PE index within cube
- tl.num_programs(0) for PE count within cube
- Shape args are automatically localized by launch()
"""
from kernbench.policy.placement.dp import DPPolicy
M, K = 128, 256
DTYPE = "f16"
def _copy_kernel(src_ptr, dst_ptr, M, K, tl, DTYPE="f16"):
"""Standard Triton copy kernel. M and K are cube-local (set by launch)."""
pid = tl.program_id(0)
num_pe = tl.num_programs(0)
cols_per_pe = K // num_pe
elem_bytes = 2 # f16
offset = pid * M * cols_per_pe * elem_bytes
data = tl.load(src_ptr + offset, shape=(M, cols_per_pe), dtype=DTYPE)
tl.store(dst_ptr + offset, data)
def run(torch):
"""Run the VA offset verification benchmark with full TP sharding."""
dp = DPPolicy(sip="column_wise", cube="column_wise", pe="column_wise")
src = torch.zeros((M, K), dtype=DTYPE, dp=dp, name="src")
dst = torch.empty((M, K), dtype=DTYPE, dp=dp, name="dst")
# launch() automatically converts M, K to cube-local values
torch.launch("va_offset_copy", _copy_kernel, src, dst, M, K)
# Sanity check: kernel completed with non-zero latency
kernel_traces = [t for t in torch._traces if t["phase"] == "kernel"]
assert len(kernel_traces) > 0, "No kernel traces recorded"
for kt in kernel_traces:
assert kt["total_ns"] > 0, f"Kernel latency is zero for {kt}"
+53
View File
@@ -0,0 +1,53 @@
# Component implementation registry.
# Maps impl names (used in topology.yaml) to Python class paths.
# Format: impl_name: module.path:ClassName
#
# ── Adding custom components ──────────────────────────────────────────
#
# 1. Create your implementation in:
# src/kernbench/components/custom/<your_component>.py
#
# Your class must inherit from ComponentBase (or PeEngineBase for PE engines).
#
# 2. Register it below under "Custom" with a unique impl name:
# my_pe_cpu_v2: kernbench.components.custom.my_pe_cpu:MyPeCpuComponent
#
# 3. Reference it in topology.yaml:
# pe_cpu: { kind: pe_cpu, impl: my_pe_cpu_v2, attrs: { ... } }
#
# 4. Add unit tests in:
# tests/custom/test_<your_component>.py
#
# External packages also work — use the full module path:
# fast_gemm_v1: my_team.accel.fast_gemm:FastGemmComponent
# ──────────────────────────────────────────────────────────────────────
components:
# Infrastructure
forwarding_v1: kernbench.components.builtin.forwarding:TransitComponent
switch_v1: kernbench.components.builtin.forwarding:TransitComponent
noc_v1: kernbench.components.builtin.forwarding:TransitComponent
ucie_v1: kernbench.components.builtin.forwarding:TransitComponent
noc_2d_mesh_v1: kernbench.components.builtin.noc:TwoDMeshNocComponent
xbar_v1: kernbench.components.builtin.xbar:PositionAwareXbarComponent
# IO / Host interface
pcie_ep_v1: kernbench.components.builtin.pcie_ep:PcieEpComponent
io_cpu_v1: kernbench.components.builtin.io_cpu:IoCpuComponent
# Cube-level
m_cpu_v1: kernbench.components.builtin.m_cpu:MCpuComponent
hbm_ctrl_v1: kernbench.components.builtin.hbm_ctrl:HbmCtrlComponent
sram_v1: kernbench.components.builtin.sram:SramComponent
# PE-level
pe_cpu_v1: kernbench.components.builtin.pe_cpu:PeCpuComponent
pe_scheduler_v1: kernbench.components.builtin.pe_scheduler:PeSchedulerComponent
pe_dma_v1: kernbench.components.builtin.pe_dma:PeDmaComponent
pe_gemm_v1: kernbench.components.builtin.pe_gemm:PeGemmComponent
pe_math_v1: kernbench.components.builtin.pe_math:PeMathComponent
pe_mmu_v1: kernbench.components.builtin.pe_mmu:PeMmuComponent
pe_tcm_v1: kernbench.components.builtin.pe_tcm:PeTcmComponent
# Custom — add your implementations here
# pe_cpu_v2: kernbench.components.custom.my_pe_cpu:MyPeCpuComponent
@@ -58,7 +58,7 @@ sufficient to execute kernels and issue DMA requests.
- Mapping strategy based on `DPPolicy.cube`:
- **Replicate** (`cube="replicate"`): per-(sip, cube) local mapping only.
Each cube's PEs see only their local PA. No cross-cube mapping installed.
- **Sharded** (`cube="shard_m"`, etc.): broadcast all shard mappings to all
- **Sharded** (`cube="column_wise"`, etc.): broadcast all shard mappings to all
target cubes. Enables cross-PE and cross-cube DMA.
#### D3.4 Tensor Lifecycle
+3 -3
View File
@@ -163,11 +163,11 @@ DefaultComponent ← 안전한 fallback
## 슬라이드 7 — Registry 등록 방식
```python
# kernbench/components/impls/__init__.py
# kernbench/components/builtin/__init__.py
from kernbench.components.base import ComponentRegistry
from kernbench.components.impls.noc import TwoDMeshNocComponent
from kernbench.components.impls.io_cpu import IoCpuComponent
from kernbench.components.builtin.noc import TwoDMeshNocComponent
from kernbench.components.builtin.io_cpu import IoCpuComponent
# ...
ComponentRegistry.register("noc_2d_mesh_v1", TwoDMeshNocComponent)
+54 -4
View File
@@ -140,16 +140,65 @@ class ComponentRegistry:
Resolution order for ComponentRegistry.create(node, overrides, ctx):
1. overrides[node.impl] — caller-injected override
2. _registry[node.impl] — globally registered impl
2. _registry[node.impl] — globally registered impl (lazy import)
3. Error — no fallback; every node must have an impl
Registry is populated from components.yaml via load_components_yaml().
Manual register() is still supported for tests and overrides.
"""
_registry: dict[str, type[ComponentBase]] = {}
_lazy: dict[str, str] = {} # impl → "module.path:ClassName"
_loaded: bool = False
@classmethod
def register(cls, impl: str, component_cls: type[ComponentBase]) -> None:
cls._registry[impl] = component_cls
@classmethod
def load_components_yaml(cls, path: str | None = None) -> None:
"""Load impl→class mappings from components.yaml. Lazy imports on first use."""
if cls._loaded:
return
import yaml
from pathlib import Path
if path is None:
# Search: project root (cwd), then relative to this file
candidates = [
Path.cwd() / "components.yaml",
Path(__file__).parent.parent.parent.parent / "components.yaml",
]
for p in candidates:
if p.exists():
path = str(p)
break
if path is None:
return
with open(path) as f:
spec = yaml.safe_load(f)
for impl, class_path in (spec.get("components") or {}).items():
cls._lazy[impl] = class_path
cls._loaded = True
@classmethod
def _resolve(cls, impl: str) -> type[ComponentBase] | None:
"""Resolve impl name: check _registry first, then lazy import from _lazy."""
if impl in cls._registry:
return cls._registry[impl]
if not cls._loaded:
cls.load_components_yaml()
class_path = cls._lazy.get(impl)
if class_path is None:
return None
import importlib
module_path, class_name = class_path.rsplit(":", 1)
mod = importlib.import_module(module_path)
component_cls = getattr(mod, class_name)
cls._registry[impl] = component_cls # cache for next lookup
return component_cls
@classmethod
def create(
cls,
@@ -159,9 +208,10 @@ class ComponentRegistry:
) -> ComponentBase:
if overrides and node.impl in overrides:
return overrides[node.impl](node, ctx)
if node.impl in cls._registry:
return cls._registry[node.impl](node, ctx)
component_cls = cls._resolve(node.impl)
if component_cls is not None:
return component_cls(node, ctx)
raise ValueError(
f"No component registered for impl '{node.impl}' (node: {node.id}). "
f"Register it in kernbench.components.impls.__init__."
f"Add it to components.yaml or call ComponentRegistry.register()."
)
@@ -0,0 +1,34 @@
"""Concrete component implementations.
Loaded from components.yaml via ComponentRegistry.load_components_yaml().
Manual imports are no longer needed — add new impls to components.yaml.
Classes are still importable from this package via lazy __getattr__.
"""
from kernbench.components.base import ComponentRegistry
ComponentRegistry.load_components_yaml()
# Lazy re-export: allow `from kernbench.components.builtin import FooComponent`
# without eagerly importing every module.
_CLASS_MAP: dict[str, str] = {} # ClassName → "module.path:ClassName"
def _build_class_map() -> None:
if _CLASS_MAP:
return
for class_path in ComponentRegistry._lazy.values():
module_path, class_name = class_path.rsplit(":", 1)
_CLASS_MAP[class_name] = class_path
def __getattr__(name: str):
_build_class_map()
class_path = _CLASS_MAP.get(name)
if class_path is None:
raise ImportError(f"cannot import name '{name}' from 'kernbench.components.builtin'")
import importlib
module_path, class_name = class_path.rsplit(":", 1)
mod = importlib.import_module(module_path)
return getattr(mod, class_name)
@@ -81,10 +81,22 @@ class PeCpuComponent(ComponentBase):
yield from self.run(env, 0)
kernel_fn = get_kernel(request.kernel_ref.name)
tl = TLContext(pe_id=self._pe_idx, dispatch_cycles=0)
# Derive num_programs from the number of PE shards in this cube
num_programs = 1
for arg in request.args:
if arg.arg_kind == "tensor":
cube_pe_count = sum(
1 for s in arg.shards
if s.sip == self._sip_idx and s.cube == self._cube_idx
)
if cube_pe_count > num_programs:
num_programs = cube_pe_count
tl = TLContext(pe_id=self._pe_idx, num_programs=num_programs, dispatch_cycles=0)
# Unpack KernelLaunchMsg.args into positional args for kernel function
# TensorArg → VA base (or PA fallback), ScalarArg → value
# TensorArg → va_base (already local, set by runtime) or PA fallback
kernel_args: list = []
for arg in request.args:
if arg.arg_kind == "tensor":
@@ -0,0 +1,5 @@
"""Custom component implementations.
Place your component files here and register them in components.yaml.
See components.yaml header for instructions.
"""
@@ -1,59 +0,0 @@
"""Concrete component implementations.
Each module registers its component(s) with ComponentRegistry on import.
Import this package to activate all built-in implementations.
"""
from kernbench.components.base import ComponentRegistry
from kernbench.components.impls.forwarding import TransitComponent
from kernbench.components.impls.hbm_ctrl import HbmCtrlComponent
from kernbench.components.impls.io_cpu import IoCpuComponent
from kernbench.components.impls.m_cpu import MCpuComponent
from kernbench.components.impls.noc import TwoDMeshNocComponent
from kernbench.components.impls.pcie_ep import PcieEpComponent
from kernbench.components.impls.pe_cpu import PeCpuComponent
from kernbench.components.impls.pe_dma import PeDmaComponent
from kernbench.components.impls.pe_gemm import PeGemmComponent
from kernbench.components.impls.pe_math import PeMathComponent
from kernbench.components.impls.pe_scheduler import PeSchedulerComponent
from kernbench.components.impls.pe_mmu import PeMmuComponent
from kernbench.components.impls.pe_tcm import PeTcmComponent
from kernbench.components.impls.sram import SramComponent
from kernbench.components.impls.xbar import PositionAwareXbarComponent
ComponentRegistry.register("forwarding_v1", TransitComponent)
ComponentRegistry.register("switch_v1", TransitComponent)
ComponentRegistry.register("noc_v1", TransitComponent)
ComponentRegistry.register("noc_2d_mesh_v1", TwoDMeshNocComponent)
ComponentRegistry.register("ucie_v1", TransitComponent)
ComponentRegistry.register("xbar_v1", PositionAwareXbarComponent)
ComponentRegistry.register("pcie_ep_v1", PcieEpComponent)
ComponentRegistry.register("io_cpu_v1", IoCpuComponent)
ComponentRegistry.register("m_cpu_v1", MCpuComponent)
ComponentRegistry.register("hbm_ctrl_v1", HbmCtrlComponent)
ComponentRegistry.register("sram_v1", SramComponent)
ComponentRegistry.register("pe_cpu_v1", PeCpuComponent)
ComponentRegistry.register("pe_scheduler_v1", PeSchedulerComponent)
ComponentRegistry.register("pe_dma_v1", PeDmaComponent)
ComponentRegistry.register("pe_gemm_v1", PeGemmComponent)
ComponentRegistry.register("pe_math_v1", PeMathComponent)
ComponentRegistry.register("pe_mmu_v1", PeMmuComponent)
ComponentRegistry.register("pe_tcm_v1", PeTcmComponent)
__all__ = [
"HbmCtrlComponent",
"IoCpuComponent",
"MCpuComponent",
"PcieEpComponent",
"PeCpuComponent",
"PeDmaComponent",
"PeGemmComponent",
"PeMathComponent",
"PeMmuComponent",
"PeSchedulerComponent",
"PeTcmComponent",
"TransitComponent",
"TwoDMeshNocComponent",
"PositionAwareXbarComponent",
"SramComponent",
]
+53 -35
View File
@@ -7,12 +7,36 @@ from typing import Literal
@dataclass(frozen=True)
class DPPolicy:
"""Two-level data-parallel policy: cube-level + pe-level."""
"""Three-level data-parallel policy: sip-level + cube-level + pe-level.
cube: Literal["replicate", "shard_m", "shard_k"] = "replicate"
Policies:
- "replicate": full copy at each unit
- "column_wise": split K (column) axis across units
- "row_wise": split M (row) axis across units
"""
sip: Literal["replicate", "column_wise", "row_wise"] = "replicate"
cube: Literal["replicate", "column_wise", "row_wise"] = "replicate"
pe: Literal["replicate", "column_wise", "row_wise"] = "replicate"
def _split_shape(
policy: str, shape: tuple[int, int], count: int, itemsize: int,
) -> list[tuple[tuple[int, int], int]]:
"""Split shape by policy into (sub_shape, byte_offset) pairs."""
M, K = shape
if policy == "replicate":
return [((M, K), 0)] * count
elif policy == "column_wise":
chunk_k = K // count
return [((M, chunk_k), i * M * chunk_k * itemsize) for i in range(count)]
elif policy == "row_wise":
chunk_m = M // count
return [((chunk_m, K), i * chunk_m * K * itemsize) for i in range(count)]
else:
raise ValueError(f"Unknown policy: {policy}")
def resolve_dp_policy(
policy: DPPolicy,
*,
@@ -20,11 +44,14 @@ def resolve_dp_policy(
itemsize: int,
num_pe: int,
num_cubes: int = 1,
num_sips: int = 1,
) -> list[ShardSpec]:
"""Resolve a DPPolicy into a list[ShardSpec] with two-level resolution.
"""Resolve a DPPolicy into a list[ShardSpec] with three-level resolution.
Cube-level policy distributes across cubes, pe-level distributes within
each cube. ShardSpec.pe_index uses flat indexing: cube_id * num_pe + pe_id.
SIP-level → cube-level → pe-level.
num_cubes is cubes per SIP (not total).
ShardSpec.pe_index uses flat indexing:
sip_id * num_cubes * num_pe + cube_id * num_pe + pe_id
"""
_PE_RESOLVERS = {
"replicate": replicate,
@@ -35,40 +62,31 @@ def resolve_dp_policy(
if resolver is None:
raise ValueError(f"Unknown pe-level policy: {policy.pe}")
if num_cubes <= 1:
return resolver(shape=shape, itemsize=itemsize, num_pe=num_pe)
# Two-level resolution: cube-level → pe-level
M, K = shape
cubes_per_sip = num_cubes
all_shards: list[ShardSpec] = []
for cube_id in range(num_cubes):
# Determine per-cube shape based on cube-level policy
if policy.cube == "replicate":
cube_shape = (M, K)
cube_offset = 0
elif policy.cube == "shard_m":
chunk_m = M // num_cubes
cube_shape = (chunk_m, K)
cube_offset = cube_id * chunk_m * K * itemsize
elif policy.cube == "shard_k":
chunk_k = K // num_cubes
cube_shape = (M, chunk_k)
cube_offset = cube_id * M * chunk_k * itemsize
else:
raise ValueError(f"Unknown cube-level policy: {policy.cube}")
# Level 1: SIP
sip_splits = _split_shape(policy.sip, shape, num_sips, itemsize)
# Resolve pe-level within this cube's shape
pe_shards = resolver(shape=cube_shape, itemsize=itemsize, num_pe=num_pe)
for sip_id, (sip_shape, sip_offset) in enumerate(sip_splits):
# Level 2: Cube within SIP
cube_splits = _split_shape(policy.cube, sip_shape, cubes_per_sip, itemsize)
# Remap pe_index to flat index and adjust offset
for ps in pe_shards:
flat_idx = cube_id * num_pe + ps.pe_index
all_shards.append(ShardSpec(
pe_index=flat_idx,
offset_bytes=cube_offset + ps.offset_bytes,
nbytes=ps.nbytes,
))
for cube_id, (cube_shape, cube_offset) in enumerate(cube_splits):
# Level 3: PE within cube
pe_shards = resolver(shape=cube_shape, itemsize=itemsize, num_pe=num_pe)
for ps in pe_shards:
flat_idx = (
sip_id * cubes_per_sip * num_pe
+ cube_id * num_pe
+ ps.pe_index
)
all_shards.append(ShardSpec(
pe_index=flat_idx,
offset_bytes=sip_offset + cube_offset + ps.offset_bytes,
nbytes=ps.nbytes,
))
return all_shards
+161 -69
View File
@@ -20,7 +20,6 @@ class RuntimeContext:
_completed: set[RequestHandle] = field(default_factory=set, init=False)
_allocators: dict[int, Any] = field(default_factory=dict, init=False)
_va_allocator: Any = field(default=None, init=False)
_mmus: dict[int, Any] = field(default_factory=dict, init=False)
_tensor_counter: int = field(default=0, init=False)
_traces: list[dict] = field(default_factory=list, init=False)
_tensors: list[Any] = field(default_factory=list, init=False)
@@ -208,24 +207,17 @@ class RuntimeContext:
rack_id=0, sip_id=sip_id, cube_id=cube_id, pe_id=pe_id, cfg=cfg,
)
# Initialize VA allocator and per-PE MMUs
from kernbench.policy.address.pe_mmu import PeMMU
# Initialize VA allocator (MMU mappings are installed via fabric MmuMapMsg)
from kernbench.policy.address.va_allocator import VirtualAllocator
pe_mmu_attrs = pe_comps.get("pe_mmu", {}).get("attrs", {})
page_size = int(pe_mmu_attrs.get("page_size", 4096))
tlb_overhead_ns = float(pe_mmu_attrs.get("tlb_overhead_ns", 0.0))
self._va_allocator = VirtualAllocator(
va_base=0x1_0000_0000,
va_size=64 * (1 << 30), # 64 GB VA space
page_size=page_size,
)
total_pes = sip_count * cubes_per_sip * pes_per_cube
for flat_idx in range(total_pes):
self._mmus[flat_idx] = PeMMU(
page_size=page_size, overhead_ns=tlb_overhead_ns,
)
return self._allocators
@@ -276,11 +268,11 @@ class RuntimeContext:
dp_policy = dp
allocators = self._ensure_allocators()
itemsize = dtype_itemsize(dtype)
shape_2d = (shape[0], shape[1]) # type: tuple[int, int]
total_cubes = self._num_sips * self._num_cubes
shape_2d = (shape[0], shape[1]) if len(shape) >= 2 else (1, shape[0])
placement = resolve_dp_policy(
dp, shape=shape_2d, itemsize=itemsize,
num_pe=self._pes_per_cube, num_cubes=total_cubes,
num_pe=self._pes_per_cube, num_cubes=self._num_cubes,
num_sips=self._num_sips,
)
# Infer target_pe from placement: multi-PE → "all", single PE → pe_index
@@ -297,7 +289,6 @@ class RuntimeContext:
placement=placement,
allocators=allocators,
va_allocator=self._va_allocator,
mmus=self._mmus,
)
t._handle = handle
import weakref
@@ -305,6 +296,8 @@ class RuntimeContext:
self._tensors.append(weakref.ref(t))
# Install VA→PA mappings via fabric MmuMapMsg
# Strategy: always SIP-scoped (each SIP gets only its own shards).
# Within each SIP: cube="replicate" → per-cube, else broadcast within SIP.
if handle.va_base:
from collections import defaultdict
from kernbench.runtime_api.kernel import MmuMapMsg
@@ -313,47 +306,52 @@ class RuntimeContext:
dp_policy is not None and dp_policy.cube == "replicate"
)
if is_cube_replicate:
# Replicate: each (sip, cube) gets only its own local PA mappings
cube_groups: dict[tuple[int, int], list] = defaultdict(list)
for shard in handle.shards:
cube_groups[(shard.sip, shard.cube)].append(shard)
# Group shards by SIP
sip_groups: dict[int, list] = defaultdict(list)
for shard in handle.shards:
sip_groups[shard.sip].append(shard)
for (sip, cube), group_shards in cube_groups.items():
for sip, sip_shards in sip_groups.items():
if is_cube_replicate:
# Cube replicate: per-(sip, cube) local mapping
cube_groups: dict[int, list] = defaultdict(list)
for s in sip_shards:
cube_groups[s.cube].append(s)
for cube, group_shards in cube_groups.items():
entries = tuple(
{"va": handle.va_base + s.offset_bytes,
"pa": s.pa, "size": s.nbytes}
for s in group_shards
)
msg = MmuMapMsg(
correlation_id=self.correlation_id,
request_id=f"mmu_{tensor_name}_s{sip}c{cube}",
entries=entries,
target_sips=(sip,),
target_cubes=(cube,),
target_pe="all",
)
h = self.submit(msg)
self.wait(h)
else:
# Cube sharded: broadcast all cubes within this SIP
entries = tuple(
{"va": handle.va_base + s.offset_bytes,
"pa": s.pa, "size": s.nbytes}
for s in group_shards
for s in sip_shards
)
cube_set = sorted({s.cube for s in sip_shards})
msg = MmuMapMsg(
correlation_id=self.correlation_id,
request_id=f"mmu_{tensor_name}_s{sip}c{cube}",
request_id=f"mmu_{tensor_name}_s{sip}",
entries=entries,
target_sips=(sip,),
target_cubes=(cube,),
target_cubes=tuple(cube_set),
target_pe="all",
)
h = self.submit(msg)
self.wait(h)
else:
# Sharded: broadcast all mappings to all target (sip, cube)s
entries = tuple(
{"va": handle.va_base + s.offset_bytes,
"pa": s.pa, "size": s.nbytes}
for s in handle.shards
)
sip_set = sorted({s.sip for s in handle.shards})
cube_set = sorted({s.cube for s in handle.shards})
msg = MmuMapMsg(
correlation_id=self.correlation_id,
request_id=f"mmu_{tensor_name}",
entries=entries,
target_sips=tuple(sip_set),
target_cubes=tuple(cube_set),
target_pe="all",
)
h = self.submit(msg)
self.wait(h)
# Submit MemoryWriteMsg per shard (deploy data to device)
if pattern is not None:
@@ -384,11 +382,18 @@ class RuntimeContext:
Positional args: Tensor objects become TensorArg, int/float become ScalarArg.
Keyword args: become ScalarArg (name is discarded, order preserved).
Creates per-SIP KernelLaunchMsg with local va_base per tensor
(like host driver sending per-rank launch commands).
"""
from collections import defaultdict
from kernbench.runtime_api.kernel import (
KernelLaunchMsg,
KernelRef,
ScalarArg,
TensorArg,
TensorArgShard,
)
from kernbench.runtime_api.tensor import Tensor
from kernbench.triton_emu.registry import register_kernel
@@ -399,14 +404,14 @@ class RuntimeContext:
except ValueError:
pass
# Build kernel args from positional + keyword args
kernel_args: list = []
# Collect tensors and scalars
tensor_args: list[Tensor] = []
scalar_args: list = []
target_pe: int | str = 0
for a in args:
if isinstance(a, Tensor):
kernel_args.append(a.to_tensor_arg())
# Infer target_pe from tensor DP metadata
tensor_args.append(a)
if a._dp_metadata is not None:
dp_target = a._dp_metadata.target_pe
if dp_target == "all":
@@ -415,34 +420,121 @@ class RuntimeContext:
target_pe = dp_target
elif isinstance(a, (int, float)):
dtype_str = "f32" if isinstance(a, float) else "i32"
kernel_args.append(ScalarArg(dtype=dtype_str, value=a))
scalar_args.append(ScalarArg(dtype=dtype_str, value=a))
for v in kwargs.values():
if isinstance(v, (int, float)):
dtype_str = "f32" if isinstance(v, float) else "i32"
kernel_args.append(ScalarArg(dtype=dtype_str, value=v))
scalar_args.append(ScalarArg(dtype=dtype_str, value=v))
# Determine target cubes from all tensor shards
cube_set: set[int] = set()
for a in args:
if isinstance(a, Tensor) and a._handle is not None:
for s in a._handle.shards:
cube_set.add(s.cube)
target_cubes = tuple(sorted(cube_set)) if cube_set else (0,)
# Determine all target SIPs from tensor shards
sip_set: set[int] = set()
for t in tensor_args:
if t._handle is not None:
for s in t._handle.shards:
sip_set.add(s.sip)
if not sip_set:
sip_set = {0}
# Collect scalar values for GEMM FLOP calculation
scalar_vals = [a.value for a in kernel_args if hasattr(a, "value")]
# Build global→local dimension mapping from tensor DPPolicies.
# Scalar args matching a tensor's global dimension get replaced
# with the cube-local value (what the kernel actually operates on).
def _compute_local_shape(t: Tensor) -> tuple[int, ...]:
"""Compute cube-local shape from DPPolicy."""
shape = t.shape
if len(shape) < 2:
shape = (1, shape[0])
M, K = shape[0], shape[1]
dp = t._dp_metadata.dp_policy if t._dp_metadata else None
if dp is None:
return t.shape
if dp.sip != "replicate":
if dp.sip == "column_wise":
K = K // self._num_sips
elif dp.sip == "row_wise":
M = M // self._num_sips
if dp.cube != "replicate":
if dp.cube == "column_wise":
K = K // self._num_cubes
elif dp.cube == "row_wise":
M = M // self._num_cubes
if len(t.shape) < 2:
return (K,)
return (M, K)
h = self.submit(KernelLaunchMsg(
correlation_id=self.correlation_id,
request_id=kernel_name,
kernel_ref=KernelRef(name=kernel_name, kind="builtin"),
args=tuple(kernel_args),
target_cubes=target_cubes,
target_pe=target_pe,
))
self.wait(h, _meta={
"phase": "kernel", "name": kernel_name,
"target_pe": target_pe, "scalars": scalar_vals,
})
return h
dim_map: dict[int, int] = {} # global_dim → local_dim
for t in tensor_args:
local = _compute_local_shape(t)
for g, l in zip(t.shape if len(t.shape) >= 2 else (1, t.shape[0]), local if len(local) >= 2 else (1, local[0])):
if g != l:
dim_map[g] = l
# Per-SIP kernel launch: each SIP gets TensorArgs with local va_base
last_handle = None
for sip_id in sorted(sip_set):
sip_kernel_args: list = []
sip_cube_set: set[int] = set()
for t in tensor_args:
if t._handle is None:
continue
sip_shards = [s for s in t._handle.shards if s.sip == sip_id]
if not sip_shards:
sip_shards = list(t._handle.shards)
local_va_base = 0
if t._handle.va_base:
min_offset = min(s.offset_bytes for s in sip_shards)
local_va_base = t._handle.va_base + min_offset
sip_kernel_args.append(TensorArg(
shards=tuple(
TensorArgShard(
sip=s.sip, cube=s.cube, pe=s.pe,
pa=s.pa, nbytes=s.nbytes, offset_bytes=s.offset_bytes,
)
for s in sip_shards
),
va_base=local_va_base,
))
for s in sip_shards:
sip_cube_set.add(s.cube)
# Interleave tensor args and scalar args, replacing global dims with local
final_args: list = []
t_idx, s_idx = 0, 0
for a in args:
if isinstance(a, Tensor):
final_args.append(sip_kernel_args[t_idx])
t_idx += 1
elif isinstance(a, (int, float)):
sa = scalar_args[s_idx]
if isinstance(a, int) and a in dim_map:
sa = ScalarArg(dtype=sa.dtype, value=dim_map[a])
final_args.append(sa)
s_idx += 1
while s_idx < len(scalar_args):
sa = scalar_args[s_idx]
if isinstance(sa.value, int) and int(sa.value) in dim_map:
sa = ScalarArg(dtype=sa.dtype, value=dim_map[int(sa.value)])
final_args.append(sa)
s_idx += 1
target_cubes = tuple(sorted(sip_cube_set)) if sip_cube_set else (0,)
h = self.submit(KernelLaunchMsg(
correlation_id=self.correlation_id,
request_id=f"{kernel_name}_sip{sip_id}",
kernel_ref=KernelRef(name=kernel_name, kind="builtin"),
args=tuple(final_args),
target_cubes=target_cubes,
target_pe=target_pe,
))
self.wait(h, _meta={
"phase": "kernel", "name": kernel_name,
"sip": sip_id, "target_pe": target_pe,
})
last_handle = h
return last_handle
+1 -11
View File
@@ -59,10 +59,7 @@ def deploy_tensor(
allocators: dict[int, PEMemAllocator],
mem_kind: Literal["hbm", "tcm"] = "hbm",
va_allocator=None,
mmus: dict | None = None,
) -> TensorHandle:
from kernbench.policy.address.pe_mmu import PeMMU
isize = dtype_itemsize(dtype)
total_nbytes = math.prod(shape) * isize
@@ -78,22 +75,15 @@ def deploy_tensor(
pa = alloc.alloc_hbm(spec.nbytes)
else:
pa = alloc.alloc_tcm(spec.nbytes)
encoded_pa = pa.encode()
shards.append(TensorShard(
sip=alloc._sip_id,
cube=alloc._cube_id,
pe=alloc._pe_id,
pa=encoded_pa,
pa=pa.encode(),
nbytes=spec.nbytes,
offset_bytes=spec.offset_bytes,
))
# Register VA→PA mapping in all MMUs (broadcast)
if va_base and mmus is not None:
shard_va = va_base + spec.offset_bytes
for mmu in mmus.values():
mmu.map(va=shard_va, pa=encoded_pa, size=spec.nbytes)
return TensorHandle(
name=name,
shape=shape,
+1 -1
View File
@@ -5,7 +5,7 @@ from typing import Any
import simpy
from kernbench.common.types import Completion, RequestHandle, Trace
import kernbench.components.impls # noqa: F401 — registers built-in implementations
import kernbench.components.builtin # noqa: F401 — registers built-in implementations
from kernbench.components.base import ComponentBase, ComponentRegistry
from kernbench.components.context import ComponentContext
from kernbench.policy.address.phyaddr import PhysAddr
View File
+1 -1
View File
@@ -13,7 +13,7 @@ import simpy
from pathlib import Path
from kernbench.components.base import ComponentBase, ComponentRegistry
from kernbench.components.impls.forwarding import TransitComponent
from kernbench.components.builtin.forwarding import TransitComponent
from kernbench.policy.address.phyaddr import PhysAddr
from kernbench.runtime_api.kernel import MemoryReadMsg
from kernbench.sim_engine.engine import GraphEngine
+3 -3
View File
@@ -73,7 +73,7 @@ def test_mmu_unmap_msg_fields():
def test_pe_mmu_registry():
"""pe_mmu_v1 impl resolves in ComponentRegistry."""
from kernbench.components.base import ComponentRegistry
from kernbench.components.impls.pe_mmu import PeMmuComponent
from kernbench.components.builtin.pe_mmu import PeMmuComponent
from kernbench.topology.types import Node
node = Node(
@@ -93,7 +93,7 @@ def test_pe_mmu_registry():
def test_pe_mmu_processes_map_msg():
"""PE_MMU component receives MmuMapMsg → translate works."""
import simpy
from kernbench.components.impls.pe_mmu import PeMmuComponent
from kernbench.components.builtin.pe_mmu import PeMmuComponent
from kernbench.sim_engine.transaction import Transaction
from kernbench.topology.types import Node
@@ -152,7 +152,7 @@ def test_pe_dma_translates_va():
# This test validates the interface contract. Full integration test
# requires the engine wiring which is validated in test_engine.
# Here we check that PE_DMA has an mmu attribute it can call.
from kernbench.components.impls.pe_dma import PeDmaComponent
from kernbench.components.builtin.pe_dma import PeDmaComponent
from kernbench.topology.types import Node
node = Node(
+8 -10
View File
@@ -20,12 +20,12 @@ from kernbench.common.pe_commands import (
TensorHandle,
)
from kernbench.components.base import ComponentRegistry
from kernbench.components.impls.pe_cpu import PeCpuComponent
from kernbench.components.impls.pe_dma import PeDmaComponent
from kernbench.components.impls.pe_gemm import PeGemmComponent
from kernbench.components.impls.pe_math import PeMathComponent
from kernbench.components.impls.pe_scheduler import PeSchedulerComponent
from kernbench.components.impls.pe_tcm import PeTcmComponent
from kernbench.components.builtin.pe_cpu import PeCpuComponent
from kernbench.components.builtin.pe_dma import PeDmaComponent
from kernbench.components.builtin.pe_gemm import PeGemmComponent
from kernbench.components.builtin.pe_math import PeMathComponent
from kernbench.components.builtin.pe_scheduler import PeSchedulerComponent
from kernbench.components.builtin.pe_tcm import PeTcmComponent
from kernbench.policy.address.phyaddr import PhysAddr
from kernbench.runtime_api.kernel import (
KernelLaunchMsg,
@@ -888,11 +888,9 @@ def test_qkv_gemm_bench_completes():
deploy_traces = [t for t in ctx._traces if t["phase"] in ("deploy", "memory_write")]
kernel_traces = [t for t in ctx._traces if t["phase"] == "kernel"]
assert len(deploy_traces) >= 2 # at least a, b (out is empty, no deploy)
assert len(kernel_traces) == 1
assert len(kernel_traces) >= 1 # one per SIP (2 SIPs in topology)
assert kernel_traces[0]["name"] == "qkv_gemm"
assert kernel_traces[0]["total_ns"] > 0
# Scalars should contain M, K, N
assert len(kernel_traces[0]["scalars"]) >= 3
clear_registry()
@@ -982,7 +980,7 @@ def test_qkv_gemm_bench_multi_pe_completes():
deploy_traces = [t for t in ctx._traces if t["phase"] in ("deploy", "memory_write")]
kernel_traces = [t for t in ctx._traces if t["phase"] == "kernel"]
assert len(deploy_traces) >= 8 # replicate(a)*8 + column_wise(b)*8
assert len(kernel_traces) == 1
assert len(kernel_traces) >= 1 # one per SIP
assert kernel_traces[0]["target_pe"] == "all"
clear_registry()
+1 -1
View File
@@ -19,7 +19,7 @@ import simpy
from kernbench.components.base import ComponentBase, ComponentRegistry
from kernbench.components.context import ComponentContext
from kernbench.components.impls import (
from kernbench.components.builtin import (
HbmCtrlComponent,
IoCpuComponent,
MCpuComponent,
+157
View File
@@ -0,0 +1,157 @@
"""Tests for SIP-level tensor parallelism.
Validates:
SP1. DPPolicy accepts sip field (default "replicate", backward compat)
SP2. sip="column_wise": tensor K-axis split across SIPs, each SIP gets K//num_sips
SP3. sip="row_wise": tensor M-axis split across SIPs
SP4. 3-level resolve: sip × cube × pe produces correct flat indices and offsets
SP5. sip="replicate": all SIPs get full copy (existing behavior)
SP6. PE_CPU sets num_programs from shard count per cube
SP7. End-to-end: TP kernel with sip="column_wise" completes on multi-SIP topology
"""
import pytest
from pathlib import Path
from kernbench.policy.placement.dp import DPPolicy, ShardSpec, resolve_dp_policy
# ── SP1. DPPolicy sip field ──────────────────────────────────────────
def test_dp_policy_sip_default_replicate():
"""DPPolicy without sip= defaults to 'replicate'."""
dp = DPPolicy(cube="replicate", pe="column_wise")
assert dp.sip == "replicate"
def test_dp_policy_sip_column_wise():
"""DPPolicy accepts sip='column_wise'."""
dp = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise")
assert dp.sip == "column_wise"
# ── SP2. sip="column_wise" ──────────────────────────────────────────────
def test_sip_column_wise_splits_across_sips():
"""sip='column_wise' with 2 SIPs: each SIP gets K//2 columns."""
dp = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise")
shards = resolve_dp_policy(
dp, shape=(128, 256), itemsize=2,
num_pe=8, num_cubes=1, num_sips=2,
)
# 2 SIPs × 1 cube × 8 PEs = 16 shards
assert len(shards) == 16
# SIP0 shards: first half of K (0 to K//2)
# SIP1 shards: second half of K (K//2 to K)
total_bytes = 128 * 256 * 2 # 64KB
sip0_shards = [s for s in shards if s.pe_index < 8]
sip1_shards = [s for s in shards if s.pe_index >= 8]
# SIP0 offsets start at 0
assert sip0_shards[0].offset_bytes == 0
# SIP1 offsets start at half
assert sip1_shards[0].offset_bytes == total_bytes // 2
# Total coverage
assert sum(s.nbytes for s in sip0_shards) == total_bytes // 2
assert sum(s.nbytes for s in sip1_shards) == total_bytes // 2
# ── SP3. sip="row_wise" ──────────────────────────────────────────────
def test_sip_row_wise_splits_across_sips():
"""sip='row_wise' with 2 SIPs: each SIP gets M//2 rows."""
dp = DPPolicy(sip="row_wise", cube="replicate", pe="column_wise")
shards = resolve_dp_policy(
dp, shape=(128, 256), itemsize=2,
num_pe=8, num_cubes=1, num_sips=2,
)
assert len(shards) == 16
sip0_shards = [s for s in shards if s.pe_index < 8]
sip1_shards = [s for s in shards if s.pe_index >= 8]
# SIP0: rows 0..63, SIP1: rows 64..127
total_bytes = 128 * 256 * 2
assert sip0_shards[0].offset_bytes == 0
assert sip1_shards[0].offset_bytes == total_bytes // 2
# ── SP4. 3-level resolve ─────────────────────────────────────────────
def test_3level_resolve_flat_index():
"""3-level: sip × cube × pe produces correct flat indices."""
dp = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise")
shards = resolve_dp_policy(
dp, shape=(128, 256), itemsize=2,
num_pe=8, num_cubes=2, num_sips=2,
)
# 2 SIPs × 2 cubes × 8 PEs = 32 shards
assert len(shards) == 32
# Flat index: sip_id * cubes_per_sip * num_pe + cube_id * num_pe + pe_id
indices = [s.pe_index for s in shards]
# SIP0: 0..15, SIP1: 16..31
assert min(indices) == 0
assert max(indices) == 31
assert len(set(indices)) == 32 # all unique
def test_3level_offsets_cover_full_tensor():
"""3-level sharding covers the entire tensor with no gaps."""
dp = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise")
shards = resolve_dp_policy(
dp, shape=(128, 256), itemsize=2,
num_pe=4, num_cubes=1, num_sips=2,
)
# 2 SIPs × 1 cube × 4 PEs = 8 shards
# sip="column_wise": K=128 per SIP, pe="column_wise": 32 cols per PE
total = 128 * 256 * 2
# For non-replicate, total shard bytes == tensor bytes
# (replicate within cube means cube shards overlap, but sip shards don't)
sip0_bytes = sum(s.nbytes for s in shards if s.pe_index < 4)
sip1_bytes = sum(s.nbytes for s in shards if s.pe_index >= 4)
assert sip0_bytes + sip1_bytes == total
# ── SP5. sip="replicate" backward compat ─────────────────────────────
def test_sip_replicate_backward_compat():
"""sip='replicate' produces same result as before (2-level)."""
dp_old = DPPolicy(cube="replicate", pe="column_wise")
dp_new = DPPolicy(sip="replicate", cube="replicate", pe="column_wise")
shards_old = resolve_dp_policy(
dp_old, shape=(128, 256), itemsize=2,
num_pe=8, num_cubes=2, num_sips=2,
)
shards_new = resolve_dp_policy(
dp_new, shape=(128, 256), itemsize=2,
num_pe=8, num_cubes=2, num_sips=2,
)
assert len(shards_old) == len(shards_new)
for a, b in zip(shards_old, shards_new):
assert a.pe_index == b.pe_index
assert a.offset_bytes == b.offset_bytes
assert a.nbytes == b.nbytes
# ── SP6. PE_CPU num_programs ──────────────────────────────────────────
def test_pe_cpu_sets_num_programs():
"""PE_CPU should create TLContext with num_programs = PEs per cube."""
# This test validates the interface contract.
# After implementation, PE_CPU should derive num_programs from the
# number of PE shards in the kernel launch's target cube.
from kernbench.triton_emu.tl_context import TLContext
# With 8 PEs per cube, num_programs should be 8
tl = TLContext(pe_id=3, num_programs=8)
assert tl.program_id(0) == 3
assert tl.num_programs(0) == 8
+5 -19
View File
@@ -88,7 +88,6 @@ def test_deploy_tensor_assigns_va_base():
"""deploy_tensor with VA allocator assigns va_base to TensorHandle."""
allocs = _make_allocators()
va_alloc = _make_va_allocator()
mmus = _make_mmus()
placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8)
th = deploy_tensor(
@@ -98,7 +97,6 @@ def test_deploy_tensor_assigns_va_base():
placement=placement,
allocators=allocs,
va_allocator=va_alloc,
mmus=mmus,
)
assert th.va_base is not None
@@ -109,7 +107,6 @@ def test_deploy_tensor_va_covers_all_shards():
"""VA allocation covers the entire tensor; each shard is at va_base + offset."""
allocs = _make_allocators()
va_alloc = _make_va_allocator()
mmus = _make_mmus()
placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8)
th = deploy_tensor(
@@ -119,41 +116,32 @@ def test_deploy_tensor_va_covers_all_shards():
placement=placement,
allocators=allocs,
va_allocator=va_alloc,
mmus=mmus,
)
# Each shard's VA is derivable: va_base + offset_bytes
for s in th.shards:
shard_va = th.va_base + s.offset_bytes
assert shard_va > 0
def test_deploy_tensor_registers_mmu_mappings():
"""deploy_tensor registers VA→PA mappings in all PE MMUs."""
def test_deploy_tensor_does_not_install_mmu_mappings():
"""deploy_tensor does NOT install MMU mappings — that's context's job."""
allocs = _make_allocators()
va_alloc = _make_va_allocator()
mmus = _make_mmus()
placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8)
th = deploy_tensor(
deploy_tensor(
name="W",
shape=(1024, 512),
dtype="fp16",
placement=placement,
allocators=allocs,
va_allocator=va_alloc,
mmus=mmus,
)
# Every MMU should have entries (broadcast)
# No MMU should have any entries (mappings come from fabric MmuMapMsg)
for mmu in mmus.values():
assert mmu.num_entries > 0
# Each shard's derived VA should translate to its PA in every MMU
for mmu in mmus.values():
for s in th.shards:
shard_va = th.va_base + s.offset_bytes
assert mmu.translate(shard_va) == s.pa
assert mmu.num_entries == 0
# ── T12. Tensor.va property ──────────────────────────────────────────
@@ -165,7 +153,6 @@ def test_tensor_va_property():
allocs = _make_allocators(1)
va_alloc = _make_va_allocator()
mmus = _make_mmus(1)
placement = [ShardSpec(pe_index=0, offset_bytes=0, nbytes=4096)]
t = Tensor(shape=(2048,), dtype="f16", name="test")
@@ -176,7 +163,6 @@ def test_tensor_va_property():
placement=placement,
allocators=allocs,
va_allocator=va_alloc,
mmus=mmus,
)
assert t.va > 0
assert t.va == t._handle.va_base
+216
View File
@@ -0,0 +1,216 @@
"""VA offset verification: each PE accesses its own local HBM slice.
Verifies that column-wise sharding + VA offset calculation produces DMA
addresses that translate to the correct PE's local HBM — not a remote PE.
Tests:
VO1. Per-PE DMA addresses are correct VAs (2D)
VO2. Each VA translates to the executing PE's own HBM slice (2D)
VO3. End-to-end bench completes (2D, full TP)
VO4. Per-PE DMA addresses are correct VAs (1D)
VO5. Each VA translates to local HBM (1D)
VO6. End-to-end 1D bench completes
"""
from pathlib import Path
import pytest
from kernbench.common.pe_commands import DmaReadCmd, DmaWriteCmd
from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator
from kernbench.policy.address.pe_mmu import PeMMU
from kernbench.policy.address.phyaddr import PhysAddr
from kernbench.policy.address.va_allocator import VirtualAllocator
from kernbench.policy.placement.dp import DPPolicy, column_wise
from kernbench.runtime_api.tensor import deploy_tensor
from kernbench.sim_engine.engine import GraphEngine
from kernbench.runtime_api.context import RuntimeContext
from kernbench.runtime_api.types import DeviceSelector
from kernbench.topology.builder import load_topology
from kernbench.triton_emu.tl_context import TLContext, run_kernel
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
_MB = 1 << 20
_GB = 1 << 30
M, K = 128, 256
DTYPE = "f16"
NUM_PE = 8
ELEM_BYTES = 2
def _copy_kernel_2d(src_ptr, dst_ptr, M, K, tl, DTYPE="f16"):
"""Standard Triton 2D copy. M, K are cube-local."""
pid = tl.program_id(0)
num_pe = tl.num_programs(0)
cols_per_pe = K // num_pe
elem_bytes = 2
offset = pid * M * cols_per_pe * elem_bytes
data = tl.load(src_ptr + offset, shape=(M, cols_per_pe), dtype=DTYPE)
tl.store(dst_ptr + offset, data)
def _copy_kernel_1d(src_ptr, dst_ptr, N, tl, DTYPE="f16"):
"""Standard Triton 1D copy. N is cube-local."""
pid = tl.program_id(0)
num_pe = tl.num_programs(0)
elems_per_pe = N // num_pe
elem_bytes = 2
offset = pid * elems_per_pe * elem_bytes
data = tl.load(src_ptr + offset, shape=(elems_per_pe,), dtype=DTYPE)
tl.store(dst_ptr + offset, data)
def _make_standalone(shape, num_pe=NUM_PE):
"""Create standalone allocators + MMUs for unit testing."""
cfg = AddressConfig(
sip_count=1, cubes_per_sip=1, pes_per_cube=num_pe,
hbm_bytes_per_cube=48 * _GB, hbm_slices_per_cube=num_pe,
tcm_bytes_per_pe=16 * _MB, tcm_scheduler_reserved_bytes=4 * _MB,
sram_bytes_per_cube=32 * _MB,
)
allocators = {
i: PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=i, cfg=cfg)
for i in range(num_pe)
}
va_alloc = VirtualAllocator(va_base=0x1_0000_0000, va_size=64 * _GB, page_size=4096)
mmus = {i: PeMMU(page_size=4096) for i in range(num_pe)}
return cfg, allocators, va_alloc, mmus
# ── VO1. 2D: Per-PE DMA addresses are correct VAs ────────────────────
def test_2d_each_pe_computes_correct_va_offset():
"""2D: each PE generates DMA at va_base + pid * block_bytes."""
src_va = 0x1_0000_0000
dst_va = 0x2_0000_0000
cols_per_pe = K // NUM_PE
block_bytes = M * cols_per_pe * ELEM_BYTES
for pe_id in range(NUM_PE):
tl = TLContext(pe_id=pe_id, num_programs=NUM_PE, dispatch_cycles=0)
run_kernel(_copy_kernel_2d, tl, src_ptr=src_va, dst_ptr=dst_va, M=M, K=K)
reads = [c for c in tl.commands if isinstance(c, DmaReadCmd)]
writes = [c for c in tl.commands if isinstance(c, DmaWriteCmd)]
expected_offset = pe_id * block_bytes
assert reads[0].src_addr == src_va + expected_offset
assert writes[0].dst_addr == dst_va + expected_offset
# ── VO2. 2D: Each VA translates to local HBM ─────────────────────────
def test_2d_va_translates_to_local_hbm():
"""2D: each PE's DMA VA translates to its own HBM slice."""
cfg, allocators, va_alloc, mmus = _make_standalone((M, K))
slice_size = cfg.hbm_slice_bytes
cols_per_pe = K // NUM_PE
block_bytes = M * cols_per_pe * ELEM_BYTES
placement = column_wise(shape=(M, K), itemsize=ELEM_BYTES, num_pe=NUM_PE)
handle = deploy_tensor(
name="src", shape=(M, K), dtype="fp16",
placement=placement, allocators=allocators, va_allocator=va_alloc,
)
# Install per-PE mappings (simulating what context does via MmuMapMsg)
for s in handle.shards:
mmus[s.pe].map(va=handle.va_base + s.offset_bytes, pa=s.pa, size=s.nbytes)
for pe_id in range(NUM_PE):
va = handle.va_base + pe_id * block_bytes
pa = mmus[pe_id].translate(va)
decoded = PhysAddr.decode(pa)
hbm_pe = PhysAddr.hbm_pe_id(decoded.hbm_offset, slice_size)
assert hbm_pe == pe_id, f"PE{pe_id} accessed PE{hbm_pe}'s HBM"
# ── VO3. 2D: End-to-end bench completes ──────────────────────────────
def test_2d_bench_completes():
"""2D: full TP bench with standard Triton kernel pattern."""
graph = load_topology(TOPOLOGY_PATH)
engine = GraphEngine(graph)
ctx = RuntimeContext(
engine=engine, target_device=DeviceSelector("sip:0"),
correlation_id="vo3", spec=graph.spec,
)
from benches.va_offset_verify import run as bench_run
bench_run(ctx)
ctx.wait_all()
# ── VO4. 1D: Per-PE DMA addresses ────────────────────────────────────
N_1D = 1024
def test_1d_each_pe_computes_correct_offset():
"""1D: each PE generates DMA at correct offset."""
src_va = 0x1_0000_0000
dst_va = 0x2_0000_0000
elems_per_pe = N_1D // NUM_PE
block_bytes = elems_per_pe * ELEM_BYTES
for pe_id in range(NUM_PE):
tl = TLContext(pe_id=pe_id, num_programs=NUM_PE, dispatch_cycles=0)
run_kernel(_copy_kernel_1d, tl, src_ptr=src_va, dst_ptr=dst_va, N=N_1D)
reads = [c for c in tl.commands if isinstance(c, DmaReadCmd)]
writes = [c for c in tl.commands if isinstance(c, DmaWriteCmd)]
expected_offset = pe_id * block_bytes
assert reads[0].src_addr == src_va + expected_offset
assert writes[0].dst_addr == dst_va + expected_offset
# ── VO5. 1D: VA translates to local HBM ──────────────────────────────
def test_1d_va_translates_to_local_hbm():
"""1D: each PE's DMA VA translates to its own HBM slice."""
cfg, allocators, va_alloc, mmus = _make_standalone((1, N_1D))
slice_size = cfg.hbm_slice_bytes
elems_per_pe = N_1D // NUM_PE
block_bytes = elems_per_pe * ELEM_BYTES
placement = column_wise(shape=(1, N_1D), itemsize=ELEM_BYTES, num_pe=NUM_PE)
handle = deploy_tensor(
name="src_1d", shape=(N_1D,), dtype="fp16",
placement=placement, allocators=allocators, va_allocator=va_alloc,
)
for s in handle.shards:
mmus[s.pe].map(va=handle.va_base + s.offset_bytes, pa=s.pa, size=s.nbytes)
for pe_id in range(NUM_PE):
va = handle.va_base + pe_id * block_bytes
pa = mmus[pe_id].translate(va)
decoded = PhysAddr.decode(pa)
hbm_pe = PhysAddr.hbm_pe_id(decoded.hbm_offset, slice_size)
assert hbm_pe == pe_id, f"1D PE{pe_id} accessed PE{hbm_pe}'s HBM"
# ── VO6. 1D: End-to-end ──────────────────────────────────────────────
def test_1d_e2e_completes():
"""1D: full engine run with column_wise TP sharding."""
graph = load_topology(TOPOLOGY_PATH)
engine = GraphEngine(graph)
ctx = RuntimeContext(
engine=engine, target_device=DeviceSelector("sip:0"),
correlation_id="vo6", spec=graph.spec,
)
dp = DPPolicy(sip="column_wise", cube="column_wise", pe="column_wise")
src = ctx.zeros((N_1D,), dtype=DTYPE, dp=dp, name="src_1d")
dst = ctx.empty((N_1D,), dtype=DTYPE, dp=dp, name="dst_1d")
# launch() auto-localizes N_1D → cube-local N
ctx.launch("va_1d_copy", _copy_kernel_1d, src, dst, N_1D)
ctx.wait_all()