Add SIP-level tensor parallelism, component registry YAML, VA offset verification

- DPPolicy: 3-level (sip/cube/pe), unified naming (column_wise/row_wise)
- PE_CPU: auto num_programs from cube shard count
- context.launch(): per-SIP KernelLaunchMsg with local va_base + auto local shape
- deploy_tensor: removed mmus param, MMU mapping is context-only responsibility
- ComponentRegistry: YAML-based lazy loading (components.yaml), impls→builtin rename
- VA offset bench + tests: 2D/1D, standard Triton kernel pattern

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-26 01:13:17 -07:00
parent 08812eda58
commit 63669f82cb
35 changed files with 813 additions and 219 deletions
+42
View File
@@ -0,0 +1,42 @@
"""VA offset verification benchmark.
Verifies that Triton-style base_ptr + pid * stride addressing works correctly
with full TP sharding (sip/cube/pe all column_wise). Each PE loads its own
block from a sharded tensor and stores it back.
The kernel uses standard Triton patterns:
- tl.program_id(0) for PE index within cube
- tl.num_programs(0) for PE count within cube
- Shape args are automatically localized by launch()
"""
from kernbench.policy.placement.dp import DPPolicy
M, K = 128, 256
DTYPE = "f16"
def _copy_kernel(src_ptr, dst_ptr, M, K, tl, DTYPE="f16"):
    """Copy this program's column block from ``src_ptr`` to ``dst_ptr``.

    Standard Triton-style addressing: ``base_ptr + pid * block_bytes``.
    Per the benchmark's contract, ``M`` and ``K`` are the cube-local
    dimensions (localized by ``launch()``), and the K axis is divided evenly
    among the programs reported by ``tl.num_programs(0)``.

    Args:
        src_ptr: Base address of the source tensor as seen by this program.
        dst_ptr: Base address of the destination tensor.
        M: Row count of the local tile.
        K: Column count of the local tile, split across programs.
        tl: Triton-like context providing ``program_id``, ``num_programs``,
            ``load`` and ``store``.
        DTYPE: Element dtype name; determines the element size used for the
            byte offset.

    Raises:
        ValueError: If ``DTYPE`` is not one of the supported dtype names.
    """
    # The byte offset must follow the element size of DTYPE; a fixed size of
    # 2 is only correct for 16-bit elements.
    elem_sizes = {"f16": 2, "f32": 4, "i32": 4}
    try:
        elem_bytes = elem_sizes[DTYPE]
    except KeyError:
        raise ValueError(
            f"Unsupported DTYPE for copy kernel: {DTYPE!r}"
        ) from None

    pid = tl.program_id(0)
    num_pe = tl.num_programs(0)
    # NOTE: integer division; a K that is not a multiple of num_pe leaves the
    # trailing columns uncopied.
    cols_per_pe = K // num_pe
    offset = pid * M * cols_per_pe * elem_bytes
    data = tl.load(src_ptr + offset, shape=(M, cols_per_pe), dtype=DTYPE)
    tl.store(dst_ptr + offset, data)
def run(torch):
    """Run the VA offset verification benchmark with full TP sharding.

    Creates a source and a destination tensor sharded column-wise at the SIP,
    cube and PE levels, launches the copy kernel over them, and then checks
    that at least one kernel trace was recorded and that every kernel trace
    reports a positive latency.

    Args:
        torch: Runtime facade providing ``zeros``, ``empty``, ``launch`` and
            the recorded ``_traces`` list.

    Raises:
        AssertionError: If no kernel trace was recorded, or if a kernel trace
            reports a non-positive ``total_ns``.
    """
    dp = DPPolicy(sip="column_wise", cube="column_wise", pe="column_wise")
    src = torch.zeros((M, K), dtype=DTYPE, dp=dp, name="src")
    dst = torch.empty((M, K), dtype=DTYPE, dp=dp, name="dst")

    # Global M, K are passed here; launch() is expected to hand the kernel
    # the cube-local values.
    torch.launch("va_offset_copy", _copy_kernel, src, dst, M, K)

    # Sanity check: kernel completed with non-zero latency.
    # Explicit raises (not `assert`) so the checks survive `python -O`;
    # .get() tolerates trace records that carry no "phase" key.
    kernel_traces = [t for t in torch._traces if t.get("phase") == "kernel"]
    if not kernel_traces:
        raise AssertionError("No kernel traces recorded")
    for kt in kernel_traces:
        if not (kt["total_ns"] > 0):
            raise AssertionError(f"Kernel latency is zero for {kt}")
+53
View File
@@ -0,0 +1,53 @@
# Component implementation registry.
# Maps impl names (used in topology.yaml) to Python class paths.
# Format: impl_name: module.path:ClassName
#
# ── Adding custom components ──────────────────────────────────────────
#
# 1. Create your implementation in:
# src/kernbench/components/custom/<your_component>.py
#
# Your class must inherit from ComponentBase (or PeEngineBase for PE engines).
#
# 2. Register it below under "Custom" with a unique impl name:
# my_pe_cpu_v2: kernbench.components.custom.my_pe_cpu:MyPeCpuComponent
#
# 3. Reference it in topology.yaml:
# pe_cpu: { kind: pe_cpu, impl: my_pe_cpu_v2, attrs: { ... } }
#
# 4. Add unit tests in:
# tests/custom/test_<your_component>.py
#
# External packages also work — use the full module path:
# fast_gemm_v1: my_team.accel.fast_gemm:FastGemmComponent
# ──────────────────────────────────────────────────────────────────────
components:
# Infrastructure
forwarding_v1: kernbench.components.builtin.forwarding:TransitComponent
switch_v1: kernbench.components.builtin.forwarding:TransitComponent
noc_v1: kernbench.components.builtin.forwarding:TransitComponent
ucie_v1: kernbench.components.builtin.forwarding:TransitComponent
noc_2d_mesh_v1: kernbench.components.builtin.noc:TwoDMeshNocComponent
xbar_v1: kernbench.components.builtin.xbar:PositionAwareXbarComponent
# IO / Host interface
pcie_ep_v1: kernbench.components.builtin.pcie_ep:PcieEpComponent
io_cpu_v1: kernbench.components.builtin.io_cpu:IoCpuComponent
# Cube-level
m_cpu_v1: kernbench.components.builtin.m_cpu:MCpuComponent
hbm_ctrl_v1: kernbench.components.builtin.hbm_ctrl:HbmCtrlComponent
sram_v1: kernbench.components.builtin.sram:SramComponent
# PE-level
pe_cpu_v1: kernbench.components.builtin.pe_cpu:PeCpuComponent
pe_scheduler_v1: kernbench.components.builtin.pe_scheduler:PeSchedulerComponent
pe_dma_v1: kernbench.components.builtin.pe_dma:PeDmaComponent
pe_gemm_v1: kernbench.components.builtin.pe_gemm:PeGemmComponent
pe_math_v1: kernbench.components.builtin.pe_math:PeMathComponent
pe_mmu_v1: kernbench.components.builtin.pe_mmu:PeMmuComponent
pe_tcm_v1: kernbench.components.builtin.pe_tcm:PeTcmComponent
# Custom — add your implementations here
# pe_cpu_v2: kernbench.components.custom.my_pe_cpu:MyPeCpuComponent
@@ -58,7 +58,7 @@ sufficient to execute kernels and issue DMA requests.
- Mapping strategy based on `DPPolicy.cube`: - Mapping strategy based on `DPPolicy.cube`:
- **Replicate** (`cube="replicate"`): per-(sip, cube) local mapping only. - **Replicate** (`cube="replicate"`): per-(sip, cube) local mapping only.
Each cube's PEs see only their local PA. No cross-cube mapping installed. Each cube's PEs see only their local PA. No cross-cube mapping installed.
- **Sharded** (`cube="shard_m"`, etc.): broadcast all shard mappings to all - **Sharded** (`cube="column_wise"`, etc.): broadcast all shard mappings to all
target cubes. Enables cross-PE and cross-cube DMA. target cubes. Enables cross-PE and cross-cube DMA.
#### D3.4 Tensor Lifecycle #### D3.4 Tensor Lifecycle
+3 -3
View File
@@ -163,11 +163,11 @@ DefaultComponent ← 안전한 fallback
## 슬라이드 7 — Registry 등록 방식 ## 슬라이드 7 — Registry 등록 방식
```python ```python
# kernbench/components/impls/__init__.py # kernbench/components/builtin/__init__.py
from kernbench.components.base import ComponentRegistry from kernbench.components.base import ComponentRegistry
from kernbench.components.impls.noc import TwoDMeshNocComponent from kernbench.components.builtin.noc import TwoDMeshNocComponent
from kernbench.components.impls.io_cpu import IoCpuComponent from kernbench.components.builtin.io_cpu import IoCpuComponent
# ... # ...
ComponentRegistry.register("noc_2d_mesh_v1", TwoDMeshNocComponent) ComponentRegistry.register("noc_2d_mesh_v1", TwoDMeshNocComponent)
+54 -4
View File
@@ -140,16 +140,65 @@ class ComponentRegistry:
Resolution order for ComponentRegistry.create(node, overrides, ctx): Resolution order for ComponentRegistry.create(node, overrides, ctx):
1. overrides[node.impl] — caller-injected override 1. overrides[node.impl] — caller-injected override
2. _registry[node.impl] — globally registered impl 2. _registry[node.impl] — globally registered impl (lazy import)
3. Error — no fallback; every node must have an impl 3. Error — no fallback; every node must have an impl
Registry is populated from components.yaml via load_components_yaml().
Manual register() is still supported for tests and overrides.
""" """
_registry: dict[str, type[ComponentBase]] = {} _registry: dict[str, type[ComponentBase]] = {}
_lazy: dict[str, str] = {} # impl → "module.path:ClassName"
_loaded: bool = False
@classmethod @classmethod
def register(cls, impl: str, component_cls: type[ComponentBase]) -> None: def register(cls, impl: str, component_cls: type[ComponentBase]) -> None:
cls._registry[impl] = component_cls cls._registry[impl] = component_cls
@classmethod
def load_components_yaml(cls, path: str | None = None) -> None:
    """Load impl -> class-path mappings from components.yaml.

    Only the ``"module.path:ClassName"`` strings are recorded (in
    ``cls._lazy``); the modules themselves are imported on first use by
    ``_resolve``. The call is a no-op once a file has been loaded.

    Args:
        path: Explicit path to the YAML file. When ``None``, the current
            working directory is searched first, then the directory four
            levels above this file.

    Notes:
        If no file is found, nothing is loaded and ``_loaded`` stays False,
        so a later call searches again.
    """
    if cls._loaded:
        return

    import yaml
    from pathlib import Path

    if path is None:
        # Search: project root (cwd), then relative to this file
        candidates = (
            Path.cwd() / "components.yaml",
            Path(__file__).parent.parent.parent.parent / "components.yaml",
        )
        path = next((str(p) for p in candidates if p.exists()), None)

    if path is None:
        return

    with open(path, encoding="utf-8") as f:
        # safe_load returns None for an empty document; treat that as
        # "no components" instead of failing on None.get(...).
        spec = yaml.safe_load(f) or {}

    for impl, class_path in (spec.get("components") or {}).items():
        cls._lazy[impl] = class_path
    cls._loaded = True
@classmethod
def _resolve(cls, impl: str) -> type[ComponentBase] | None:
    """Resolve an impl name to its component class.

    Lookup order: the ``_registry`` cache (which also holds manual
    ``register()`` entries), then a lazy import of the class path recorded
    in ``_lazy``. A successfully imported class is cached in ``_registry``.

    Args:
        impl: Implementation name as used in the topology.

    Returns:
        The component class, or ``None`` if the impl name is unknown.

    Raises:
        ValueError: If the recorded class path is not of the form
            ``"module.path:ClassName"``.
    """
    if impl in cls._registry:
        return cls._registry[impl]

    if not cls._loaded:
        cls.load_components_yaml()

    class_path = cls._lazy.get(impl)
    if class_path is None:
        return None

    import importlib

    # Validate up front so a bad components.yaml entry is reported with the
    # offending impl name rather than as a bare tuple-unpack error.
    module_path, sep, class_name = str(class_path).rpartition(":")
    if not sep or not module_path or not class_name:
        raise ValueError(
            f"Invalid class path {class_path!r} for impl '{impl}': "
            f"expected 'module.path:ClassName'."
        )

    mod = importlib.import_module(module_path)
    component_cls = getattr(mod, class_name)
    cls._registry[impl] = component_cls  # cache for next lookup
    return component_cls
@classmethod @classmethod
def create( def create(
cls, cls,
@@ -159,9 +208,10 @@ class ComponentRegistry:
) -> ComponentBase: ) -> ComponentBase:
if overrides and node.impl in overrides: if overrides and node.impl in overrides:
return overrides[node.impl](node, ctx) return overrides[node.impl](node, ctx)
if node.impl in cls._registry: component_cls = cls._resolve(node.impl)
return cls._registry[node.impl](node, ctx) if component_cls is not None:
return component_cls(node, ctx)
raise ValueError( raise ValueError(
f"No component registered for impl '{node.impl}' (node: {node.id}). " f"No component registered for impl '{node.impl}' (node: {node.id}). "
f"Register it in kernbench.components.impls.__init__." f"Add it to components.yaml or call ComponentRegistry.register()."
) )
@@ -0,0 +1,34 @@
"""Concrete component implementations.
Loaded from components.yaml via ComponentRegistry.load_components_yaml().
Manual imports are no longer needed — add new impls to components.yaml.
Classes are still importable from this package via lazy __getattr__.
"""
from kernbench.components.base import ComponentRegistry
ComponentRegistry.load_components_yaml()
# Lazy re-export: allow `from kernbench.components.builtin import FooComponent`
# without eagerly importing every module.
_CLASS_MAP: dict[str, str] = {} # ClassName → "module.path:ClassName"
def _build_class_map() -> None:
    """Fill ``_CLASS_MAP`` (ClassName -> class path) once from the registry.

    Does nothing when the map already has entries. Several impl names may
    share one class path; they collapse to a single ClassName key.
    """
    if not _CLASS_MAP:
        for path_spec in ComponentRegistry._lazy.values():
            # Only the class-name half of "module.path:ClassName" is the key.
            _, cls_name = path_spec.rsplit(":", 1)
            _CLASS_MAP[cls_name] = path_spec
def __getattr__(name: str):
    """Module-level attribute hook that lazily re-exports component classes.

    Looks ``name`` up among the class names listed in the component
    registry, imports the owning module on demand, and returns the class.

    Args:
        name: Attribute being looked up on this package.

    Returns:
        The component class with that name.

    Raises:
        AttributeError: If ``name`` is not a known component class. A
            module-level ``__getattr__`` must signal a miss this way so that
            ``hasattr``, ``getattr`` with a default, and the import system's
            submodule fallback keep working; ``from ... import name`` still
            surfaces it to the caller as ImportError.
    """
    _build_class_map()
    class_path = _CLASS_MAP.get(name)
    if class_path is None:
        raise AttributeError(
            f"module 'kernbench.components.builtin' has no attribute '{name}'"
        )

    import importlib

    module_path, class_name = class_path.rsplit(":", 1)
    mod = importlib.import_module(module_path)
    return getattr(mod, class_name)
@@ -81,10 +81,22 @@ class PeCpuComponent(ComponentBase):
yield from self.run(env, 0) yield from self.run(env, 0)
kernel_fn = get_kernel(request.kernel_ref.name) kernel_fn = get_kernel(request.kernel_ref.name)
tl = TLContext(pe_id=self._pe_idx, dispatch_cycles=0)
# Derive num_programs from the number of PE shards in this cube
num_programs = 1
for arg in request.args:
if arg.arg_kind == "tensor":
cube_pe_count = sum(
1 for s in arg.shards
if s.sip == self._sip_idx and s.cube == self._cube_idx
)
if cube_pe_count > num_programs:
num_programs = cube_pe_count
tl = TLContext(pe_id=self._pe_idx, num_programs=num_programs, dispatch_cycles=0)
# Unpack KernelLaunchMsg.args into positional args for kernel function # Unpack KernelLaunchMsg.args into positional args for kernel function
# TensorArg → VA base (or PA fallback), ScalarArg → value # TensorArg → va_base (already local, set by runtime) or PA fallback
kernel_args: list = [] kernel_args: list = []
for arg in request.args: for arg in request.args:
if arg.arg_kind == "tensor": if arg.arg_kind == "tensor":
@@ -0,0 +1,5 @@
"""Custom component implementations.
Place your component files here and register them in components.yaml.
See components.yaml header for instructions.
"""
@@ -1,59 +0,0 @@
"""Concrete component implementations.
Each module registers its component(s) with ComponentRegistry on import.
Import this package to activate all built-in implementations.
"""
from kernbench.components.base import ComponentRegistry
from kernbench.components.impls.forwarding import TransitComponent
from kernbench.components.impls.hbm_ctrl import HbmCtrlComponent
from kernbench.components.impls.io_cpu import IoCpuComponent
from kernbench.components.impls.m_cpu import MCpuComponent
from kernbench.components.impls.noc import TwoDMeshNocComponent
from kernbench.components.impls.pcie_ep import PcieEpComponent
from kernbench.components.impls.pe_cpu import PeCpuComponent
from kernbench.components.impls.pe_dma import PeDmaComponent
from kernbench.components.impls.pe_gemm import PeGemmComponent
from kernbench.components.impls.pe_math import PeMathComponent
from kernbench.components.impls.pe_scheduler import PeSchedulerComponent
from kernbench.components.impls.pe_mmu import PeMmuComponent
from kernbench.components.impls.pe_tcm import PeTcmComponent
from kernbench.components.impls.sram import SramComponent
from kernbench.components.impls.xbar import PositionAwareXbarComponent
ComponentRegistry.register("forwarding_v1", TransitComponent)
ComponentRegistry.register("switch_v1", TransitComponent)
ComponentRegistry.register("noc_v1", TransitComponent)
ComponentRegistry.register("noc_2d_mesh_v1", TwoDMeshNocComponent)
ComponentRegistry.register("ucie_v1", TransitComponent)
ComponentRegistry.register("xbar_v1", PositionAwareXbarComponent)
ComponentRegistry.register("pcie_ep_v1", PcieEpComponent)
ComponentRegistry.register("io_cpu_v1", IoCpuComponent)
ComponentRegistry.register("m_cpu_v1", MCpuComponent)
ComponentRegistry.register("hbm_ctrl_v1", HbmCtrlComponent)
ComponentRegistry.register("sram_v1", SramComponent)
ComponentRegistry.register("pe_cpu_v1", PeCpuComponent)
ComponentRegistry.register("pe_scheduler_v1", PeSchedulerComponent)
ComponentRegistry.register("pe_dma_v1", PeDmaComponent)
ComponentRegistry.register("pe_gemm_v1", PeGemmComponent)
ComponentRegistry.register("pe_math_v1", PeMathComponent)
ComponentRegistry.register("pe_mmu_v1", PeMmuComponent)
ComponentRegistry.register("pe_tcm_v1", PeTcmComponent)
__all__ = [
"HbmCtrlComponent",
"IoCpuComponent",
"MCpuComponent",
"PcieEpComponent",
"PeCpuComponent",
"PeDmaComponent",
"PeGemmComponent",
"PeMathComponent",
"PeMmuComponent",
"PeSchedulerComponent",
"PeTcmComponent",
"TransitComponent",
"TwoDMeshNocComponent",
"PositionAwareXbarComponent",
"SramComponent",
]
+47 -29
View File
@@ -7,12 +7,36 @@ from typing import Literal
@dataclass(frozen=True) @dataclass(frozen=True)
class DPPolicy: class DPPolicy:
"""Two-level data-parallel policy: cube-level + pe-level.""" """Three-level data-parallel policy: sip-level + cube-level + pe-level.
cube: Literal["replicate", "shard_m", "shard_k"] = "replicate" Policies:
- "replicate": full copy at each unit
- "column_wise": split K (column) axis across units
- "row_wise": split M (row) axis across units
"""
sip: Literal["replicate", "column_wise", "row_wise"] = "replicate"
cube: Literal["replicate", "column_wise", "row_wise"] = "replicate"
pe: Literal["replicate", "column_wise", "row_wise"] = "replicate" pe: Literal["replicate", "column_wise", "row_wise"] = "replicate"
def _split_shape(
    policy: str, shape: tuple[int, int], count: int, itemsize: int,
) -> list[tuple[tuple[int, int], int]]:
    """Split a 2-D shape across ``count`` units according to ``policy``.

    Args:
        policy: ``"replicate"`` (every unit gets the full shape at offset 0),
            ``"column_wise"`` (split the K axis) or ``"row_wise"`` (split the
            M axis).
        shape: ``(M, K)`` of the tensor being split.
        count: Number of units to split across; must be at least 1.
        itemsize: Size of one element in bytes.

    Returns:
        One ``(sub_shape, byte_offset)`` pair per unit. For the sharded
        policies the offset of unit ``i`` is ``i`` times the byte size of one
        sub-shape.

    Raises:
        ValueError: If ``count`` is less than 1 or ``policy`` is unknown.

    Notes:
        The split uses integer division, so a dimension that is not a
        multiple of ``count`` loses its remainder.
    """
    # Reject a non-positive count up front, so every policy fails the same
    # way instead of dividing by zero or returning no splits at all.
    if count < 1:
        raise ValueError(f"count must be >= 1, got {count}")

    M, K = shape
    if policy == "replicate":
        return [((M, K), 0)] * count
    if policy == "column_wise":
        chunk_k = K // count
        return [((M, chunk_k), i * M * chunk_k * itemsize) for i in range(count)]
    if policy == "row_wise":
        chunk_m = M // count
        return [((chunk_m, K), i * chunk_m * K * itemsize) for i in range(count)]
    raise ValueError(f"Unknown policy: {policy}")
def resolve_dp_policy( def resolve_dp_policy(
policy: DPPolicy, policy: DPPolicy,
*, *,
@@ -20,11 +44,14 @@ def resolve_dp_policy(
itemsize: int, itemsize: int,
num_pe: int, num_pe: int,
num_cubes: int = 1, num_cubes: int = 1,
num_sips: int = 1,
) -> list[ShardSpec]: ) -> list[ShardSpec]:
"""Resolve a DPPolicy into a list[ShardSpec] with two-level resolution. """Resolve a DPPolicy into a list[ShardSpec] with three-level resolution.
Cube-level policy distributes across cubes, pe-level distributes within SIP-level → cube-level → pe-level.
each cube. ShardSpec.pe_index uses flat indexing: cube_id * num_pe + pe_id. num_cubes is cubes per SIP (not total).
ShardSpec.pe_index uses flat indexing:
sip_id * num_cubes * num_pe + cube_id * num_pe + pe_id
""" """
_PE_RESOLVERS = { _PE_RESOLVERS = {
"replicate": replicate, "replicate": replicate,
@@ -35,38 +62,29 @@ def resolve_dp_policy(
if resolver is None: if resolver is None:
raise ValueError(f"Unknown pe-level policy: {policy.pe}") raise ValueError(f"Unknown pe-level policy: {policy.pe}")
if num_cubes <= 1: cubes_per_sip = num_cubes
return resolver(shape=shape, itemsize=itemsize, num_pe=num_pe)
# Two-level resolution: cube-level → pe-level
M, K = shape
all_shards: list[ShardSpec] = [] all_shards: list[ShardSpec] = []
for cube_id in range(num_cubes): # Level 1: SIP
# Determine per-cube shape based on cube-level policy sip_splits = _split_shape(policy.sip, shape, num_sips, itemsize)
if policy.cube == "replicate":
cube_shape = (M, K)
cube_offset = 0
elif policy.cube == "shard_m":
chunk_m = M // num_cubes
cube_shape = (chunk_m, K)
cube_offset = cube_id * chunk_m * K * itemsize
elif policy.cube == "shard_k":
chunk_k = K // num_cubes
cube_shape = (M, chunk_k)
cube_offset = cube_id * M * chunk_k * itemsize
else:
raise ValueError(f"Unknown cube-level policy: {policy.cube}")
# Resolve pe-level within this cube's shape for sip_id, (sip_shape, sip_offset) in enumerate(sip_splits):
# Level 2: Cube within SIP
cube_splits = _split_shape(policy.cube, sip_shape, cubes_per_sip, itemsize)
for cube_id, (cube_shape, cube_offset) in enumerate(cube_splits):
# Level 3: PE within cube
pe_shards = resolver(shape=cube_shape, itemsize=itemsize, num_pe=num_pe) pe_shards = resolver(shape=cube_shape, itemsize=itemsize, num_pe=num_pe)
# Remap pe_index to flat index and adjust offset
for ps in pe_shards: for ps in pe_shards:
flat_idx = cube_id * num_pe + ps.pe_index flat_idx = (
sip_id * cubes_per_sip * num_pe
+ cube_id * num_pe
+ ps.pe_index
)
all_shards.append(ShardSpec( all_shards.append(ShardSpec(
pe_index=flat_idx, pe_index=flat_idx,
offset_bytes=cube_offset + ps.offset_bytes, offset_bytes=sip_offset + cube_offset + ps.offset_bytes,
nbytes=ps.nbytes, nbytes=ps.nbytes,
)) ))
+134 -42
View File
@@ -20,7 +20,6 @@ class RuntimeContext:
_completed: set[RequestHandle] = field(default_factory=set, init=False) _completed: set[RequestHandle] = field(default_factory=set, init=False)
_allocators: dict[int, Any] = field(default_factory=dict, init=False) _allocators: dict[int, Any] = field(default_factory=dict, init=False)
_va_allocator: Any = field(default=None, init=False) _va_allocator: Any = field(default=None, init=False)
_mmus: dict[int, Any] = field(default_factory=dict, init=False)
_tensor_counter: int = field(default=0, init=False) _tensor_counter: int = field(default=0, init=False)
_traces: list[dict] = field(default_factory=list, init=False) _traces: list[dict] = field(default_factory=list, init=False)
_tensors: list[Any] = field(default_factory=list, init=False) _tensors: list[Any] = field(default_factory=list, init=False)
@@ -208,24 +207,17 @@ class RuntimeContext:
rack_id=0, sip_id=sip_id, cube_id=cube_id, pe_id=pe_id, cfg=cfg, rack_id=0, sip_id=sip_id, cube_id=cube_id, pe_id=pe_id, cfg=cfg,
) )
# Initialize VA allocator and per-PE MMUs # Initialize VA allocator (MMU mappings are installed via fabric MmuMapMsg)
from kernbench.policy.address.pe_mmu import PeMMU
from kernbench.policy.address.va_allocator import VirtualAllocator from kernbench.policy.address.va_allocator import VirtualAllocator
pe_mmu_attrs = pe_comps.get("pe_mmu", {}).get("attrs", {}) pe_mmu_attrs = pe_comps.get("pe_mmu", {}).get("attrs", {})
page_size = int(pe_mmu_attrs.get("page_size", 4096)) page_size = int(pe_mmu_attrs.get("page_size", 4096))
tlb_overhead_ns = float(pe_mmu_attrs.get("tlb_overhead_ns", 0.0))
self._va_allocator = VirtualAllocator( self._va_allocator = VirtualAllocator(
va_base=0x1_0000_0000, va_base=0x1_0000_0000,
va_size=64 * (1 << 30), # 64 GB VA space va_size=64 * (1 << 30), # 64 GB VA space
page_size=page_size, page_size=page_size,
) )
total_pes = sip_count * cubes_per_sip * pes_per_cube
for flat_idx in range(total_pes):
self._mmus[flat_idx] = PeMMU(
page_size=page_size, overhead_ns=tlb_overhead_ns,
)
return self._allocators return self._allocators
@@ -276,11 +268,11 @@ class RuntimeContext:
dp_policy = dp dp_policy = dp
allocators = self._ensure_allocators() allocators = self._ensure_allocators()
itemsize = dtype_itemsize(dtype) itemsize = dtype_itemsize(dtype)
shape_2d = (shape[0], shape[1]) # type: tuple[int, int] shape_2d = (shape[0], shape[1]) if len(shape) >= 2 else (1, shape[0])
total_cubes = self._num_sips * self._num_cubes
placement = resolve_dp_policy( placement = resolve_dp_policy(
dp, shape=shape_2d, itemsize=itemsize, dp, shape=shape_2d, itemsize=itemsize,
num_pe=self._pes_per_cube, num_cubes=total_cubes, num_pe=self._pes_per_cube, num_cubes=self._num_cubes,
num_sips=self._num_sips,
) )
# Infer target_pe from placement: multi-PE → "all", single PE → pe_index # Infer target_pe from placement: multi-PE → "all", single PE → pe_index
@@ -297,7 +289,6 @@ class RuntimeContext:
placement=placement, placement=placement,
allocators=allocators, allocators=allocators,
va_allocator=self._va_allocator, va_allocator=self._va_allocator,
mmus=self._mmus,
) )
t._handle = handle t._handle = handle
import weakref import weakref
@@ -305,6 +296,8 @@ class RuntimeContext:
self._tensors.append(weakref.ref(t)) self._tensors.append(weakref.ref(t))
# Install VA→PA mappings via fabric MmuMapMsg # Install VA→PA mappings via fabric MmuMapMsg
# Strategy: always SIP-scoped (each SIP gets only its own shards).
# Within each SIP: cube="replicate" → per-cube, else broadcast within SIP.
if handle.va_base: if handle.va_base:
from collections import defaultdict from collections import defaultdict
from kernbench.runtime_api.kernel import MmuMapMsg from kernbench.runtime_api.kernel import MmuMapMsg
@@ -313,13 +306,19 @@ class RuntimeContext:
dp_policy is not None and dp_policy.cube == "replicate" dp_policy is not None and dp_policy.cube == "replicate"
) )
if is_cube_replicate: # Group shards by SIP
# Replicate: each (sip, cube) gets only its own local PA mappings sip_groups: dict[int, list] = defaultdict(list)
cube_groups: dict[tuple[int, int], list] = defaultdict(list)
for shard in handle.shards: for shard in handle.shards:
cube_groups[(shard.sip, shard.cube)].append(shard) sip_groups[shard.sip].append(shard)
for (sip, cube), group_shards in cube_groups.items(): for sip, sip_shards in sip_groups.items():
if is_cube_replicate:
# Cube replicate: per-(sip, cube) local mapping
cube_groups: dict[int, list] = defaultdict(list)
for s in sip_shards:
cube_groups[s.cube].append(s)
for cube, group_shards in cube_groups.items():
entries = tuple( entries = tuple(
{"va": handle.va_base + s.offset_bytes, {"va": handle.va_base + s.offset_bytes,
"pa": s.pa, "size": s.nbytes} "pa": s.pa, "size": s.nbytes}
@@ -336,19 +335,18 @@ class RuntimeContext:
h = self.submit(msg) h = self.submit(msg)
self.wait(h) self.wait(h)
else: else:
# Sharded: broadcast all mappings to all target (sip, cube)s # Cube sharded: broadcast all cubes within this SIP
entries = tuple( entries = tuple(
{"va": handle.va_base + s.offset_bytes, {"va": handle.va_base + s.offset_bytes,
"pa": s.pa, "size": s.nbytes} "pa": s.pa, "size": s.nbytes}
for s in handle.shards for s in sip_shards
) )
sip_set = sorted({s.sip for s in handle.shards}) cube_set = sorted({s.cube for s in sip_shards})
cube_set = sorted({s.cube for s in handle.shards})
msg = MmuMapMsg( msg = MmuMapMsg(
correlation_id=self.correlation_id, correlation_id=self.correlation_id,
request_id=f"mmu_{tensor_name}", request_id=f"mmu_{tensor_name}_s{sip}",
entries=entries, entries=entries,
target_sips=tuple(sip_set), target_sips=(sip,),
target_cubes=tuple(cube_set), target_cubes=tuple(cube_set),
target_pe="all", target_pe="all",
) )
@@ -384,11 +382,18 @@ class RuntimeContext:
Positional args: Tensor objects become TensorArg, int/float become ScalarArg. Positional args: Tensor objects become TensorArg, int/float become ScalarArg.
Keyword args: become ScalarArg (name is discarded, order preserved). Keyword args: become ScalarArg (name is discarded, order preserved).
Creates per-SIP KernelLaunchMsg with local va_base per tensor
(like host driver sending per-rank launch commands).
""" """
from collections import defaultdict
from kernbench.runtime_api.kernel import ( from kernbench.runtime_api.kernel import (
KernelLaunchMsg, KernelLaunchMsg,
KernelRef, KernelRef,
ScalarArg, ScalarArg,
TensorArg,
TensorArgShard,
) )
from kernbench.runtime_api.tensor import Tensor from kernbench.runtime_api.tensor import Tensor
from kernbench.triton_emu.registry import register_kernel from kernbench.triton_emu.registry import register_kernel
@@ -399,14 +404,14 @@ class RuntimeContext:
except ValueError: except ValueError:
pass pass
# Build kernel args from positional + keyword args # Collect tensors and scalars
kernel_args: list = [] tensor_args: list[Tensor] = []
scalar_args: list = []
target_pe: int | str = 0 target_pe: int | str = 0
for a in args: for a in args:
if isinstance(a, Tensor): if isinstance(a, Tensor):
kernel_args.append(a.to_tensor_arg()) tensor_args.append(a)
# Infer target_pe from tensor DP metadata
if a._dp_metadata is not None: if a._dp_metadata is not None:
dp_target = a._dp_metadata.target_pe dp_target = a._dp_metadata.target_pe
if dp_target == "all": if dp_target == "all":
@@ -415,34 +420,121 @@ class RuntimeContext:
target_pe = dp_target target_pe = dp_target
elif isinstance(a, (int, float)): elif isinstance(a, (int, float)):
dtype_str = "f32" if isinstance(a, float) else "i32" dtype_str = "f32" if isinstance(a, float) else "i32"
kernel_args.append(ScalarArg(dtype=dtype_str, value=a)) scalar_args.append(ScalarArg(dtype=dtype_str, value=a))
for v in kwargs.values(): for v in kwargs.values():
if isinstance(v, (int, float)): if isinstance(v, (int, float)):
dtype_str = "f32" if isinstance(v, float) else "i32" dtype_str = "f32" if isinstance(v, float) else "i32"
kernel_args.append(ScalarArg(dtype=dtype_str, value=v)) scalar_args.append(ScalarArg(dtype=dtype_str, value=v))
# Determine target cubes from all tensor shards # Determine all target SIPs from tensor shards
cube_set: set[int] = set() sip_set: set[int] = set()
for t in tensor_args:
if t._handle is not None:
for s in t._handle.shards:
sip_set.add(s.sip)
if not sip_set:
sip_set = {0}
# Build global→local dimension mapping from tensor DPPolicies.
# Scalar args matching a tensor's global dimension get replaced
# with the cube-local value (what the kernel actually operates on).
def _compute_local_shape(t: Tensor) -> tuple[int, ...]:
    """Return the shape of ``t`` after SIP-level and cube-level splitting.

    A 1-D tensor is treated as ``(1, N)`` for the computation and returned
    as a 1-tuple. A tensor without a DP policy keeps its own shape. The
    PE-level policy is not applied here.
    """
    is_1d = len(t.shape) < 2
    dims = (1, t.shape[0]) if is_1d else t.shape
    rows, cols = dims[0], dims[1]

    policy = t._dp_metadata.dp_policy if t._dp_metadata else None
    if policy is None:
        return t.shape

    # SIP level: "replicate" (or anything else) leaves both axes untouched.
    if policy.sip == "column_wise":
        cols //= self._num_sips
    elif policy.sip == "row_wise":
        rows //= self._num_sips

    # Cube level, applied on top of the SIP-local shape.
    if policy.cube == "column_wise":
        cols //= self._num_cubes
    elif policy.cube == "row_wise":
        rows //= self._num_cubes

    return (cols,) if is_1d else (rows, cols)
dim_map: dict[int, int] = {} # global_dim → local_dim
for t in tensor_args:
local = _compute_local_shape(t)
for g, l in zip(t.shape if len(t.shape) >= 2 else (1, t.shape[0]), local if len(local) >= 2 else (1, local[0])):
if g != l:
dim_map[g] = l
# Per-SIP kernel launch: each SIP gets TensorArgs with local va_base
last_handle = None
for sip_id in sorted(sip_set):
sip_kernel_args: list = []
sip_cube_set: set[int] = set()
for t in tensor_args:
if t._handle is None:
continue
sip_shards = [s for s in t._handle.shards if s.sip == sip_id]
if not sip_shards:
sip_shards = list(t._handle.shards)
local_va_base = 0
if t._handle.va_base:
min_offset = min(s.offset_bytes for s in sip_shards)
local_va_base = t._handle.va_base + min_offset
sip_kernel_args.append(TensorArg(
shards=tuple(
TensorArgShard(
sip=s.sip, cube=s.cube, pe=s.pe,
pa=s.pa, nbytes=s.nbytes, offset_bytes=s.offset_bytes,
)
for s in sip_shards
),
va_base=local_va_base,
))
for s in sip_shards:
sip_cube_set.add(s.cube)
# Interleave tensor args and scalar args, replacing global dims with local
final_args: list = []
t_idx, s_idx = 0, 0
for a in args: for a in args:
if isinstance(a, Tensor) and a._handle is not None: if isinstance(a, Tensor):
for s in a._handle.shards: final_args.append(sip_kernel_args[t_idx])
cube_set.add(s.cube) t_idx += 1
target_cubes = tuple(sorted(cube_set)) if cube_set else (0,) elif isinstance(a, (int, float)):
sa = scalar_args[s_idx]
if isinstance(a, int) and a in dim_map:
sa = ScalarArg(dtype=sa.dtype, value=dim_map[a])
final_args.append(sa)
s_idx += 1
while s_idx < len(scalar_args):
sa = scalar_args[s_idx]
if isinstance(sa.value, int) and int(sa.value) in dim_map:
sa = ScalarArg(dtype=sa.dtype, value=dim_map[int(sa.value)])
final_args.append(sa)
s_idx += 1
# Collect scalar values for GEMM FLOP calculation target_cubes = tuple(sorted(sip_cube_set)) if sip_cube_set else (0,)
scalar_vals = [a.value for a in kernel_args if hasattr(a, "value")]
h = self.submit(KernelLaunchMsg( h = self.submit(KernelLaunchMsg(
correlation_id=self.correlation_id, correlation_id=self.correlation_id,
request_id=kernel_name, request_id=f"{kernel_name}_sip{sip_id}",
kernel_ref=KernelRef(name=kernel_name, kind="builtin"), kernel_ref=KernelRef(name=kernel_name, kind="builtin"),
args=tuple(kernel_args), args=tuple(final_args),
target_cubes=target_cubes, target_cubes=target_cubes,
target_pe=target_pe, target_pe=target_pe,
)) ))
self.wait(h, _meta={ self.wait(h, _meta={
"phase": "kernel", "name": kernel_name, "phase": "kernel", "name": kernel_name,
"target_pe": target_pe, "scalars": scalar_vals, "sip": sip_id, "target_pe": target_pe,
}) })
return h last_handle = h
return last_handle
+1 -11
View File
@@ -59,10 +59,7 @@ def deploy_tensor(
allocators: dict[int, PEMemAllocator], allocators: dict[int, PEMemAllocator],
mem_kind: Literal["hbm", "tcm"] = "hbm", mem_kind: Literal["hbm", "tcm"] = "hbm",
va_allocator=None, va_allocator=None,
mmus: dict | None = None,
) -> TensorHandle: ) -> TensorHandle:
from kernbench.policy.address.pe_mmu import PeMMU
isize = dtype_itemsize(dtype) isize = dtype_itemsize(dtype)
total_nbytes = math.prod(shape) * isize total_nbytes = math.prod(shape) * isize
@@ -78,22 +75,15 @@ def deploy_tensor(
pa = alloc.alloc_hbm(spec.nbytes) pa = alloc.alloc_hbm(spec.nbytes)
else: else:
pa = alloc.alloc_tcm(spec.nbytes) pa = alloc.alloc_tcm(spec.nbytes)
encoded_pa = pa.encode()
shards.append(TensorShard( shards.append(TensorShard(
sip=alloc._sip_id, sip=alloc._sip_id,
cube=alloc._cube_id, cube=alloc._cube_id,
pe=alloc._pe_id, pe=alloc._pe_id,
pa=encoded_pa, pa=pa.encode(),
nbytes=spec.nbytes, nbytes=spec.nbytes,
offset_bytes=spec.offset_bytes, offset_bytes=spec.offset_bytes,
)) ))
# Register VA→PA mapping in all MMUs (broadcast)
if va_base and mmus is not None:
shard_va = va_base + spec.offset_bytes
for mmu in mmus.values():
mmu.map(va=shard_va, pa=encoded_pa, size=spec.nbytes)
return TensorHandle( return TensorHandle(
name=name, name=name,
shape=shape, shape=shape,
+1 -1
View File
@@ -5,7 +5,7 @@ from typing import Any
import simpy import simpy
from kernbench.common.types import Completion, RequestHandle, Trace from kernbench.common.types import Completion, RequestHandle, Trace
import kernbench.components.impls # noqa: F401 — registers built-in implementations import kernbench.components.builtin # noqa: F401 — registers built-in implementations
from kernbench.components.base import ComponentBase, ComponentRegistry from kernbench.components.base import ComponentBase, ComponentRegistry
from kernbench.components.context import ComponentContext from kernbench.components.context import ComponentContext
from kernbench.policy.address.phyaddr import PhysAddr from kernbench.policy.address.phyaddr import PhysAddr
View File
+1 -1
View File
@@ -13,7 +13,7 @@ import simpy
from pathlib import Path from pathlib import Path
from kernbench.components.base import ComponentBase, ComponentRegistry from kernbench.components.base import ComponentBase, ComponentRegistry
from kernbench.components.impls.forwarding import TransitComponent from kernbench.components.builtin.forwarding import TransitComponent
from kernbench.policy.address.phyaddr import PhysAddr from kernbench.policy.address.phyaddr import PhysAddr
from kernbench.runtime_api.kernel import MemoryReadMsg from kernbench.runtime_api.kernel import MemoryReadMsg
from kernbench.sim_engine.engine import GraphEngine from kernbench.sim_engine.engine import GraphEngine
+3 -3
View File
@@ -73,7 +73,7 @@ def test_mmu_unmap_msg_fields():
def test_pe_mmu_registry(): def test_pe_mmu_registry():
"""pe_mmu_v1 impl resolves in ComponentRegistry.""" """pe_mmu_v1 impl resolves in ComponentRegistry."""
from kernbench.components.base import ComponentRegistry from kernbench.components.base import ComponentRegistry
from kernbench.components.impls.pe_mmu import PeMmuComponent from kernbench.components.builtin.pe_mmu import PeMmuComponent
from kernbench.topology.types import Node from kernbench.topology.types import Node
node = Node( node = Node(
@@ -93,7 +93,7 @@ def test_pe_mmu_registry():
def test_pe_mmu_processes_map_msg(): def test_pe_mmu_processes_map_msg():
"""PE_MMU component receives MmuMapMsg → translate works.""" """PE_MMU component receives MmuMapMsg → translate works."""
import simpy import simpy
from kernbench.components.impls.pe_mmu import PeMmuComponent from kernbench.components.builtin.pe_mmu import PeMmuComponent
from kernbench.sim_engine.transaction import Transaction from kernbench.sim_engine.transaction import Transaction
from kernbench.topology.types import Node from kernbench.topology.types import Node
@@ -152,7 +152,7 @@ def test_pe_dma_translates_va():
# This test validates the interface contract. Full integration test # This test validates the interface contract. Full integration test
# requires the engine wiring which is validated in test_engine. # requires the engine wiring which is validated in test_engine.
# Here we check that PE_DMA has an mmu attribute it can call. # Here we check that PE_DMA has an mmu attribute it can call.
from kernbench.components.impls.pe_dma import PeDmaComponent from kernbench.components.builtin.pe_dma import PeDmaComponent
from kernbench.topology.types import Node from kernbench.topology.types import Node
node = Node( node = Node(
+8 -10
View File
@@ -20,12 +20,12 @@ from kernbench.common.pe_commands import (
TensorHandle, TensorHandle,
) )
from kernbench.components.base import ComponentRegistry from kernbench.components.base import ComponentRegistry
from kernbench.components.impls.pe_cpu import PeCpuComponent from kernbench.components.builtin.pe_cpu import PeCpuComponent
from kernbench.components.impls.pe_dma import PeDmaComponent from kernbench.components.builtin.pe_dma import PeDmaComponent
from kernbench.components.impls.pe_gemm import PeGemmComponent from kernbench.components.builtin.pe_gemm import PeGemmComponent
from kernbench.components.impls.pe_math import PeMathComponent from kernbench.components.builtin.pe_math import PeMathComponent
from kernbench.components.impls.pe_scheduler import PeSchedulerComponent from kernbench.components.builtin.pe_scheduler import PeSchedulerComponent
from kernbench.components.impls.pe_tcm import PeTcmComponent from kernbench.components.builtin.pe_tcm import PeTcmComponent
from kernbench.policy.address.phyaddr import PhysAddr from kernbench.policy.address.phyaddr import PhysAddr
from kernbench.runtime_api.kernel import ( from kernbench.runtime_api.kernel import (
KernelLaunchMsg, KernelLaunchMsg,
@@ -888,11 +888,9 @@ def test_qkv_gemm_bench_completes():
deploy_traces = [t for t in ctx._traces if t["phase"] in ("deploy", "memory_write")] deploy_traces = [t for t in ctx._traces if t["phase"] in ("deploy", "memory_write")]
kernel_traces = [t for t in ctx._traces if t["phase"] == "kernel"] kernel_traces = [t for t in ctx._traces if t["phase"] == "kernel"]
assert len(deploy_traces) >= 2 # at least a, b (out is empty, no deploy) assert len(deploy_traces) >= 2 # at least a, b (out is empty, no deploy)
assert len(kernel_traces) == 1 assert len(kernel_traces) >= 1 # one per SIP (2 SIPs in topology)
assert kernel_traces[0]["name"] == "qkv_gemm" assert kernel_traces[0]["name"] == "qkv_gemm"
assert kernel_traces[0]["total_ns"] > 0 assert kernel_traces[0]["total_ns"] > 0
# Scalars should contain M, K, N
assert len(kernel_traces[0]["scalars"]) >= 3
clear_registry() clear_registry()
@@ -982,7 +980,7 @@ def test_qkv_gemm_bench_multi_pe_completes():
deploy_traces = [t for t in ctx._traces if t["phase"] in ("deploy", "memory_write")] deploy_traces = [t for t in ctx._traces if t["phase"] in ("deploy", "memory_write")]
kernel_traces = [t for t in ctx._traces if t["phase"] == "kernel"] kernel_traces = [t for t in ctx._traces if t["phase"] == "kernel"]
assert len(deploy_traces) >= 8 # replicate(a)*8 + column_wise(b)*8 assert len(deploy_traces) >= 8 # replicate(a)*8 + column_wise(b)*8
assert len(kernel_traces) == 1 assert len(kernel_traces) >= 1 # one per SIP
assert kernel_traces[0]["target_pe"] == "all" assert kernel_traces[0]["target_pe"] == "all"
clear_registry() clear_registry()
+1 -1
View File
@@ -19,7 +19,7 @@ import simpy
from kernbench.components.base import ComponentBase, ComponentRegistry from kernbench.components.base import ComponentBase, ComponentRegistry
from kernbench.components.context import ComponentContext from kernbench.components.context import ComponentContext
from kernbench.components.impls import ( from kernbench.components.builtin import (
HbmCtrlComponent, HbmCtrlComponent,
IoCpuComponent, IoCpuComponent,
MCpuComponent, MCpuComponent,
+157
View File
@@ -0,0 +1,157 @@
"""Tests for SIP-level tensor parallelism.
Validates:
SP1. DPPolicy accepts sip field (default "replicate", backward compat)
SP2. sip="column_wise": tensor K-axis split across SIPs, each SIP gets K//num_sips
SP3. sip="row_wise": tensor M-axis split across SIPs
SP4. 3-level resolve: sip × cube × pe produces correct flat indices and offsets
SP5. sip="replicate": all SIPs get full copy (existing behavior)
SP6. PE_CPU sets num_programs from shard count per cube
SP7. End-to-end: TP kernel with sip="column_wise" completes on multi-SIP topology
"""
import pytest
from pathlib import Path
from kernbench.policy.placement.dp import DPPolicy, ShardSpec, resolve_dp_policy
# ── SP1. DPPolicy sip field ──────────────────────────────────────────
def test_dp_policy_sip_default_replicate():
    """Leaving out ``sip=`` yields a policy whose SIP level is 'replicate'."""
    policy = DPPolicy(cube="replicate", pe="column_wise")
    sip_mode = policy.sip
    assert sip_mode == "replicate"
def test_dp_policy_sip_column_wise():
    """An explicit ``sip='column_wise'`` is kept on the constructed policy."""
    assert (
        DPPolicy(sip="column_wise", cube="replicate", pe="column_wise").sip
        == "column_wise"
    )
# ── SP2. sip="column_wise" ──────────────────────────────────────────────
def test_sip_column_wise_splits_across_sips():
    """Column-wise SIP split over 2 SIPs: each SIP holds half of the tensor bytes."""
    policy = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise")
    all_shards = resolve_dp_policy(
        policy,
        shape=(128, 256),
        itemsize=2,
        num_pe=8,
        num_cubes=1,
        num_sips=2,
    )

    # 2 SIPs x 1 cube x 8 PEs -> 16 shards overall.
    assert len(all_shards) == 16

    tensor_bytes = 128 * 256 * 2  # 64 KiB
    half = tensor_bytes // 2

    # Flat PE indices 0..7 are treated as SIP0, 8..15 as SIP1.
    first_sip, second_sip = [], []
    for shard in all_shards:
        bucket = first_sip if shard.pe_index < 8 else second_sip
        bucket.append(shard)

    # SIP0 begins at the start of the tensor, SIP1 at the midpoint.
    assert first_sip[0].offset_bytes == 0
    assert second_sip[0].offset_bytes == half

    # Each SIP's shards add up to exactly half of the tensor.
    assert sum(shard.nbytes for shard in first_sip) == half
    assert sum(shard.nbytes for shard in second_sip) == half
# ── SP3. sip="row_wise" ──────────────────────────────────────────────
def test_sip_row_wise_splits_across_sips():
    """Row-wise SIP split over 2 SIPs: SIP1's first shard starts at the byte midpoint."""
    policy = DPPolicy(sip="row_wise", cube="replicate", pe="column_wise")
    all_shards = resolve_dp_policy(
        policy,
        shape=(128, 256),
        itemsize=2,
        num_pe=8,
        num_cubes=1,
        num_sips=2,
    )
    assert len(all_shards) == 16

    # Flat PE indices 0..7 are treated as SIP0, 8..15 as SIP1.
    first_sip, second_sip = [], []
    for shard in all_shards:
        bucket = first_sip if shard.pe_index < 8 else second_sip
        bucket.append(shard)

    # SIP0 covers rows 0..63 and SIP1 rows 64..127, so SIP1 starts halfway in.
    tensor_bytes = 128 * 256 * 2
    assert first_sip[0].offset_bytes == 0
    assert second_sip[0].offset_bytes == tensor_bytes // 2
# ── SP4. 3-level resolve ─────────────────────────────────────────────
def test_3level_resolve_flat_index():
    """sip x cube x pe resolution yields 32 distinct flat PE indices spanning 0..31."""
    policy = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise")
    resolved = resolve_dp_policy(
        policy,
        shape=(128, 256),
        itemsize=2,
        num_pe=8,
        num_cubes=2,
        num_sips=2,
    )

    # 2 SIPs x 2 cubes x 8 PEs -> 32 shards.
    assert len(resolved) == 32

    # Expected flat index: sip_id * cubes_per_sip * num_pe + cube_id * num_pe + pe_id,
    # which places SIP0 at 0..15 and SIP1 at 16..31.
    flat_indices = [shard.pe_index for shard in resolved]
    lowest, highest = min(flat_indices), max(flat_indices)
    assert lowest == 0
    assert highest == 31
    distinct = set(flat_indices)
    assert len(distinct) == 32  # no index is repeated
def test_3level_offsets_cover_full_tensor():
    """3-level sharding accounts for every byte of the tensor, half per SIP.

    The shard count and the per-SIP byte totals are asserted explicitly:
    a bare ``sip0 + sip1 == total`` check is just the sum over all shards
    and would pass even if the SIP boundary were in the wrong place.
    """
    dp = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise")
    shards = resolve_dp_policy(
        dp, shape=(128, 256), itemsize=2,
        num_pe=4, num_cubes=1, num_sips=2,
    )

    # 2 SIPs x 1 cube x 4 PEs = 8 shards
    assert len(shards) == 8

    # sip="column_wise": K=128 per SIP, pe="column_wise": 32 cols per PE
    total = 128 * 256 * 2

    # With a single cube there is no cube-level replication, so the shards
    # of the two SIPs must not overlap: flat indices 0..3 are SIP0, 4..7 SIP1.
    sip0_bytes = sum(s.nbytes for s in shards if s.pe_index < 4)
    sip1_bytes = sum(s.nbytes for s in shards if s.pe_index >= 4)

    # Each SIP owns exactly half of the tensor ...
    assert sip0_bytes == total // 2
    assert sip1_bytes == total // 2
    # ... and together they account for all of it.
    assert sip0_bytes + sip1_bytes == total
# ── SP5. sip="replicate" backward compat ─────────────────────────────
def test_sip_replicate_backward_compat():
    """Explicit ``sip='replicate'`` resolves to the same shards as omitting ``sip``."""
    implicit_policy = DPPolicy(cube="replicate", pe="column_wise")
    explicit_policy = DPPolicy(sip="replicate", cube="replicate", pe="column_wise")

    # Both policies are resolved against the same tensor and topology sizes.
    resolve_kwargs = {
        "shape": (128, 256),
        "itemsize": 2,
        "num_pe": 8,
        "num_cubes": 2,
        "num_sips": 2,
    }
    implicit_shards = resolve_dp_policy(implicit_policy, **resolve_kwargs)
    explicit_shards = resolve_dp_policy(explicit_policy, **resolve_kwargs)

    assert len(implicit_shards) == len(explicit_shards)

    # Compare shard by shard, in resolution order.
    for legacy, current in zip(implicit_shards, explicit_shards):
        assert legacy.pe_index == current.pe_index
        assert legacy.offset_bytes == current.offset_bytes
        assert legacy.nbytes == current.nbytes
# ── SP6. PE_CPU num_programs ──────────────────────────────────────────
def test_pe_cpu_sets_num_programs():
    """Interface contract: TLContext reports the pe_id / num_programs it was given.

    NOTE(review): despite the name, this only constructs a TLContext directly;
    PE_CPU itself is not instantiated here. The intent is that PE_CPU derives
    num_programs from the number of PE shards in the launch's target cube.
    """
    from kernbench.triton_emu.tl_context import TLContext

    # 8 PEs per cube -> num_programs should be 8; PE 3 is an arbitrary member.
    pe_id, pes_per_cube = 3, 8
    context = TLContext(pe_id=pe_id, num_programs=pes_per_cube)

    assert context.program_id(0) == 3
    assert context.num_programs(0) == 8
+5 -19
View File
@@ -88,7 +88,6 @@ def test_deploy_tensor_assigns_va_base():
"""deploy_tensor with VA allocator assigns va_base to TensorHandle.""" """deploy_tensor with VA allocator assigns va_base to TensorHandle."""
allocs = _make_allocators() allocs = _make_allocators()
va_alloc = _make_va_allocator() va_alloc = _make_va_allocator()
mmus = _make_mmus()
placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8) placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8)
th = deploy_tensor( th = deploy_tensor(
@@ -98,7 +97,6 @@ def test_deploy_tensor_assigns_va_base():
placement=placement, placement=placement,
allocators=allocs, allocators=allocs,
va_allocator=va_alloc, va_allocator=va_alloc,
mmus=mmus,
) )
assert th.va_base is not None assert th.va_base is not None
@@ -109,7 +107,6 @@ def test_deploy_tensor_va_covers_all_shards():
"""VA allocation covers the entire tensor; each shard is at va_base + offset.""" """VA allocation covers the entire tensor; each shard is at va_base + offset."""
allocs = _make_allocators() allocs = _make_allocators()
va_alloc = _make_va_allocator() va_alloc = _make_va_allocator()
mmus = _make_mmus()
placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8) placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8)
th = deploy_tensor( th = deploy_tensor(
@@ -119,41 +116,32 @@ def test_deploy_tensor_va_covers_all_shards():
placement=placement, placement=placement,
allocators=allocs, allocators=allocs,
va_allocator=va_alloc, va_allocator=va_alloc,
mmus=mmus,
) )
# Each shard's VA is derivable: va_base + offset_bytes
for s in th.shards: for s in th.shards:
shard_va = th.va_base + s.offset_bytes shard_va = th.va_base + s.offset_bytes
assert shard_va > 0 assert shard_va > 0
def test_deploy_tensor_registers_mmu_mappings(): def test_deploy_tensor_does_not_install_mmu_mappings():
"""deploy_tensor registers VA→PA mappings in all PE MMUs.""" """deploy_tensor does NOT install MMU mappings — that's context's job."""
allocs = _make_allocators() allocs = _make_allocators()
va_alloc = _make_va_allocator() va_alloc = _make_va_allocator()
mmus = _make_mmus() mmus = _make_mmus()
placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8) placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8)
th = deploy_tensor( deploy_tensor(
name="W", name="W",
shape=(1024, 512), shape=(1024, 512),
dtype="fp16", dtype="fp16",
placement=placement, placement=placement,
allocators=allocs, allocators=allocs,
va_allocator=va_alloc, va_allocator=va_alloc,
mmus=mmus,
) )
# Every MMU should have entries (broadcast) # No MMU should have any entries (mappings come from fabric MmuMapMsg)
for mmu in mmus.values(): for mmu in mmus.values():
assert mmu.num_entries > 0 assert mmu.num_entries == 0
# Each shard's derived VA should translate to its PA in every MMU
for mmu in mmus.values():
for s in th.shards:
shard_va = th.va_base + s.offset_bytes
assert mmu.translate(shard_va) == s.pa
# ── T12. Tensor.va property ────────────────────────────────────────── # ── T12. Tensor.va property ──────────────────────────────────────────
@@ -165,7 +153,6 @@ def test_tensor_va_property():
allocs = _make_allocators(1) allocs = _make_allocators(1)
va_alloc = _make_va_allocator() va_alloc = _make_va_allocator()
mmus = _make_mmus(1)
placement = [ShardSpec(pe_index=0, offset_bytes=0, nbytes=4096)] placement = [ShardSpec(pe_index=0, offset_bytes=0, nbytes=4096)]
t = Tensor(shape=(2048,), dtype="f16", name="test") t = Tensor(shape=(2048,), dtype="f16", name="test")
@@ -176,7 +163,6 @@ def test_tensor_va_property():
placement=placement, placement=placement,
allocators=allocs, allocators=allocs,
va_allocator=va_alloc, va_allocator=va_alloc,
mmus=mmus,
) )
assert t.va > 0 assert t.va > 0
assert t.va == t._handle.va_base assert t.va == t._handle.va_base
+216
View File
@@ -0,0 +1,216 @@
"""VA offset verification: each PE accesses its own local HBM slice.
Verifies that column-wise sharding + VA offset calculation produces DMA
addresses that translate to the correct PE's local HBM — not a remote PE.
Tests:
VO1. Per-PE DMA addresses are correct VAs (2D)
VO2. Each VA translates to the executing PE's own HBM slice (2D)
VO3. End-to-end bench completes (2D, full TP)
VO4. Per-PE DMA addresses are correct VAs (1D)
VO5. Each VA translates to local HBM (1D)
VO6. End-to-end 1D bench completes
"""
from pathlib import Path
import pytest
from kernbench.common.pe_commands import DmaReadCmd, DmaWriteCmd
from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator
from kernbench.policy.address.pe_mmu import PeMMU
from kernbench.policy.address.phyaddr import PhysAddr
from kernbench.policy.address.va_allocator import VirtualAllocator
from kernbench.policy.placement.dp import DPPolicy, column_wise
from kernbench.runtime_api.tensor import deploy_tensor
from kernbench.sim_engine.engine import GraphEngine
from kernbench.runtime_api.context import RuntimeContext
from kernbench.runtime_api.types import DeviceSelector
from kernbench.topology.builder import load_topology
from kernbench.triton_emu.tl_context import TLContext, run_kernel
# Topology file loaded by the end-to-end tests; resolved one directory above
# the directory containing this test file.
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
# Byte-size multipliers used when building the standalone AddressConfig.
_MB = 1 << 20
_GB = 1 << 30
# Shape and dtype of the 2-D tensor used by the 2-D tests.
M, K = 128, 256
DTYPE = "f16"
# PE count used as the default for the standalone helper and the per-PE loops.
NUM_PE = 8
# Bytes per element used in the expected-offset arithmetic and as the
# itemsize passed to column_wise().
ELEM_BYTES = 2
def _copy_kernel_2d(src_ptr, dst_ptr, M, K, tl, DTYPE="f16"):
    """Copy this program's (M, K // num_programs) block from src to dst.

    The block's byte offset is ``pid * M * cols * 2`` from both base pointers,
    where ``pid`` and the program count come from ``tl`` on axis 0.
    M and K are expected to be cube-local.
    """
    program = tl.program_id(0)
    program_count = tl.num_programs(0)

    block_cols = K // program_count
    bytes_per_elem = 2
    block_offset = program * M * block_cols * bytes_per_elem

    src_addr = src_ptr + block_offset
    dst_addr = dst_ptr + block_offset

    block = tl.load(src_addr, shape=(M, block_cols), dtype=DTYPE)
    tl.store(dst_addr, block)
def _copy_kernel_1d(src_ptr, dst_ptr, N, tl, DTYPE="f16"):
    """Copy this program's run of ``N // num_programs`` elements from src to dst.

    The run's byte offset is ``pid * elems * 2`` from both base pointers,
    where ``pid`` and the program count come from ``tl`` on axis 0.
    N is expected to be cube-local.
    """
    program = tl.program_id(0)
    program_count = tl.num_programs(0)

    run_length = N // program_count
    bytes_per_elem = 2
    run_offset = program * run_length * bytes_per_elem

    src_addr = src_ptr + run_offset
    dst_addr = dst_ptr + run_offset

    chunk = tl.load(src_addr, shape=(run_length,), dtype=DTYPE)
    tl.store(dst_addr, chunk)
def _make_standalone(shape, num_pe=NUM_PE):
    """Build an address config, per-PE allocators, a VA allocator and per-PE MMUs.

    ``shape`` is accepted for the callers' convenience but is not read here.

    Returns a ``(cfg, allocators, va_alloc, mmus)`` tuple in which
    ``allocators`` and ``mmus`` are dicts keyed by PE index ``0..num_pe-1``.
    """
    address_cfg = AddressConfig(
        sip_count=1,
        cubes_per_sip=1,
        pes_per_cube=num_pe,
        hbm_bytes_per_cube=48 * _GB,
        hbm_slices_per_cube=num_pe,
        tcm_bytes_per_pe=16 * _MB,
        tcm_scheduler_reserved_bytes=4 * _MB,
        sram_bytes_per_cube=32 * _MB,
    )

    # One memory allocator per PE, all on rack 0 / SIP 0 / cube 0.
    pe_allocators = {}
    for pe in range(num_pe):
        pe_allocators[pe] = PEMemAllocator(
            rack_id=0, sip_id=0, cube_id=0, pe_id=pe, cfg=address_cfg,
        )

    virtual_alloc = VirtualAllocator(
        va_base=0x1_0000_0000, va_size=64 * _GB, page_size=4096,
    )

    # One initially empty MMU per PE.
    pe_mmus = {}
    for pe in range(num_pe):
        pe_mmus[pe] = PeMMU(page_size=4096)

    return address_cfg, pe_allocators, virtual_alloc, pe_mmus
# ── VO1. 2D: Per-PE DMA addresses are correct VAs ────────────────────
def test_2d_each_pe_computes_correct_va_offset():
    """2D: the first DMA read/write of each PE sits at base + pid * block_bytes."""
    src_base, dst_base = 0x1_0000_0000, 0x2_0000_0000
    bytes_per_block = M * (K // NUM_PE) * ELEM_BYTES

    for pid in range(NUM_PE):
        context = TLContext(pe_id=pid, num_programs=NUM_PE, dispatch_cycles=0)
        run_kernel(
            _copy_kernel_2d, context,
            src_ptr=src_base, dst_ptr=dst_base, M=M, K=K,
        )

        # Split the recorded commands by DMA direction.
        dma_reads = [cmd for cmd in context.commands if isinstance(cmd, DmaReadCmd)]
        dma_writes = [cmd for cmd in context.commands if isinstance(cmd, DmaWriteCmd)]

        shift = pid * bytes_per_block
        assert dma_reads[0].src_addr == src_base + shift
        assert dma_writes[0].dst_addr == dst_base + shift
# ── VO2. 2D: Each VA translates to local HBM ─────────────────────────
def test_2d_va_translates_to_local_hbm():
    """2D: the VA each PE would access translates into that PE's own HBM slice."""
    address_cfg, pe_allocators, virtual_alloc, pe_mmus = _make_standalone((M, K))
    hbm_slice_bytes = address_cfg.hbm_slice_bytes
    bytes_per_block = M * (K // NUM_PE) * ELEM_BYTES

    tensor = deploy_tensor(
        name="src",
        shape=(M, K),
        dtype="fp16",
        placement=column_wise(shape=(M, K), itemsize=ELEM_BYTES, num_pe=NUM_PE),
        allocators=pe_allocators,
        va_allocator=virtual_alloc,
    )

    # Install per-PE mappings by hand, standing in for what the context does
    # via MmuMapMsg: each shard is mapped only in its owning PE's MMU.
    for shard in tensor.shards:
        shard_va = tensor.va_base + shard.offset_bytes
        pe_mmus[shard.pe].map(va=shard_va, pa=shard.pa, size=shard.nbytes)

    for pid in range(NUM_PE):
        physical = pe_mmus[pid].translate(tensor.va_base + pid * bytes_per_block)
        hbm_offset = PhysAddr.decode(physical).hbm_offset
        owner = PhysAddr.hbm_pe_id(hbm_offset, hbm_slice_bytes)
        assert owner == pid, f"PE{pid} accessed PE{owner}'s HBM"
# ── VO3. 2D: End-to-end bench completes ──────────────────────────────
def test_2d_bench_completes():
    """2D: the va_offset_verify bench runs to completion on the loaded topology."""
    topology = load_topology(TOPOLOGY_PATH)
    runtime = RuntimeContext(
        engine=GraphEngine(topology),
        target_device=DeviceSelector("sip:0"),
        correlation_id="vo3",
        spec=topology.spec,
    )

    # Imported here, as in the original, so the bench module is only loaded
    # once the runtime context exists.
    from benches.va_offset_verify import run as bench_run

    bench_run(runtime)
    runtime.wait_all()
# ── VO4. 1D: Per-PE DMA addresses ────────────────────────────────────
# Element count of the 1-D tensor used by the 1-D tests (VO4-VO6).
N_1D = 1024
def test_1d_each_pe_computes_correct_offset():
    """1D: the first DMA read/write of each PE sits at base + pid * block_bytes."""
    src_base, dst_base = 0x1_0000_0000, 0x2_0000_0000
    bytes_per_block = (N_1D // NUM_PE) * ELEM_BYTES

    for pid in range(NUM_PE):
        context = TLContext(pe_id=pid, num_programs=NUM_PE, dispatch_cycles=0)
        run_kernel(
            _copy_kernel_1d, context,
            src_ptr=src_base, dst_ptr=dst_base, N=N_1D,
        )

        # Split the recorded commands by DMA direction.
        dma_reads = [cmd for cmd in context.commands if isinstance(cmd, DmaReadCmd)]
        dma_writes = [cmd for cmd in context.commands if isinstance(cmd, DmaWriteCmd)]

        shift = pid * bytes_per_block
        assert dma_reads[0].src_addr == src_base + shift
        assert dma_writes[0].dst_addr == dst_base + shift
# ── VO5. 1D: VA translates to local HBM ──────────────────────────────
def test_1d_va_translates_to_local_hbm():
    """1D: the VA each PE would access translates into that PE's own HBM slice."""
    address_cfg, pe_allocators, virtual_alloc, pe_mmus = _make_standalone((1, N_1D))
    hbm_slice_bytes = address_cfg.hbm_slice_bytes
    bytes_per_block = (N_1D // NUM_PE) * ELEM_BYTES

    # NOTE(review): the placement is computed for shape (1, N_1D) while
    # deploy_tensor receives (N_1D,). Both describe the same number of bytes;
    # confirm the mismatch is intentional.
    tensor = deploy_tensor(
        name="src_1d",
        shape=(N_1D,),
        dtype="fp16",
        placement=column_wise(shape=(1, N_1D), itemsize=ELEM_BYTES, num_pe=NUM_PE),
        allocators=pe_allocators,
        va_allocator=virtual_alloc,
    )

    # Map each shard only in its owning PE's MMU.
    for shard in tensor.shards:
        shard_va = tensor.va_base + shard.offset_bytes
        pe_mmus[shard.pe].map(va=shard_va, pa=shard.pa, size=shard.nbytes)

    for pid in range(NUM_PE):
        physical = pe_mmus[pid].translate(tensor.va_base + pid * bytes_per_block)
        hbm_offset = PhysAddr.decode(physical).hbm_offset
        owner = PhysAddr.hbm_pe_id(hbm_offset, hbm_slice_bytes)
        assert owner == pid, f"1D PE{pid} accessed PE{owner}'s HBM"
# ── VO6. 1D: End-to-end ──────────────────────────────────────────────
def test_1d_e2e_completes():
    """1D: a copy kernel launched with full column_wise sharding runs to completion."""
    topology = load_topology(TOPOLOGY_PATH)
    runtime = RuntimeContext(
        engine=GraphEngine(topology),
        target_device=DeviceSelector("sip:0"),
        correlation_id="vo6",
        spec=topology.spec,
    )

    policy = DPPolicy(sip="column_wise", cube="column_wise", pe="column_wise")
    source = runtime.zeros((N_1D,), dtype=DTYPE, dp=policy, name="src_1d")
    target = runtime.empty((N_1D,), dtype=DTYPE, dp=policy, name="dst_1d")

    # launch() is expected to convert N_1D into the cube-local N for the kernel.
    runtime.launch("va_1d_copy", _copy_kernel_1d, source, target, N_1D)
    runtime.wait_all()