Add SIP-level tensor parallelism, component registry YAML, VA offset verification
- DPPolicy: 3-level (sip/cube/pe), unified naming (column_wise/row_wise)
- PE_CPU: auto num_programs from cube shard count
- context.launch(): per-SIP KernelLaunchMsg with local va_base + auto local shape
- deploy_tensor: removed mmus param; MMU mapping is a context-only responsibility
- ComponentRegistry: YAML-based lazy loading (components.yaml), impls→builtin rename
- VA offset bench + tests: 2D/1D, standard Triton kernel pattern

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,42 @@
|
||||
"""VA offset verification benchmark.
|
||||
|
||||
Verifies that Triton-style base_ptr + pid * stride addressing works correctly
|
||||
with full TP sharding (sip/cube/pe all column_wise). Each PE loads its own
|
||||
block from a sharded tensor and stores it back.
|
||||
|
||||
The kernel uses standard Triton patterns:
|
||||
- tl.program_id(0) for PE index within cube
|
||||
- tl.num_programs(0) for PE count within cube
|
||||
- Shape args are automatically localized by launch()
|
||||
"""
|
||||
from kernbench.policy.placement.dp import DPPolicy
|
||||
|
||||
M, K = 128, 256
|
||||
DTYPE = "f16"
|
||||
|
||||
|
||||
def _copy_kernel(src_ptr, dst_ptr, M, K, tl, DTYPE="f16"):
|
||||
"""Standard Triton copy kernel. M and K are cube-local (set by launch)."""
|
||||
pid = tl.program_id(0)
|
||||
num_pe = tl.num_programs(0)
|
||||
cols_per_pe = K // num_pe
|
||||
elem_bytes = 2 # f16
|
||||
offset = pid * M * cols_per_pe * elem_bytes
|
||||
data = tl.load(src_ptr + offset, shape=(M, cols_per_pe), dtype=DTYPE)
|
||||
tl.store(dst_ptr + offset, data)
|
||||
|
||||
|
||||
def run(torch):
    """Run the VA offset verification benchmark with full TP sharding.

    Allocates a source and a destination tensor of shape ``(M, K)`` sharded
    column-wise at the sip, cube and pe levels, launches ``_copy_kernel`` on
    them, and checks that at least one kernel trace was recorded and that
    every kernel trace reports a positive ``total_ns``.

    Args:
        torch: Runtime facade providing ``zeros``, ``empty``, ``launch`` and
            the ``_traces`` list.

    Raises:
        AssertionError: If no kernel trace was recorded, or a kernel trace has
            a non-positive latency.
    """
    dp = DPPolicy(sip="column_wise", cube="column_wise", pe="column_wise")
    src = torch.zeros((M, K), dtype=DTYPE, dp=dp, name="src")
    dst = torch.empty((M, K), dtype=DTYPE, dp=dp, name="dst")

    # launch() automatically converts M, K to cube-local values
    torch.launch("va_offset_copy", _copy_kernel, src, dst, M, K)

    # Sanity check: kernel completed with non-zero latency.
    # Raised explicitly (not via `assert`) so the check survives `python -O`;
    # .get() tolerates traces that were recorded without a "phase" key.
    kernel_traces = [t for t in torch._traces if t.get("phase") == "kernel"]
    if not kernel_traces:
        raise AssertionError("No kernel traces recorded")
    for kt in kernel_traces:
        if not kt["total_ns"] > 0:
            raise AssertionError(f"Kernel latency is zero for {kt}")
|
||||
@@ -0,0 +1,53 @@
|
||||
# Component implementation registry.
|
||||
# Maps impl names (used in topology.yaml) to Python class paths.
|
||||
# Format: impl_name: module.path:ClassName
|
||||
#
|
||||
# ── Adding custom components ──────────────────────────────────────────
|
||||
#
|
||||
# 1. Create your implementation in:
|
||||
# src/kernbench/components/custom/<your_component>.py
|
||||
#
|
||||
# Your class must inherit from ComponentBase (or PeEngineBase for PE engines).
|
||||
#
|
||||
# 2. Register it below under "Custom" with a unique impl name:
|
||||
# my_pe_cpu_v2: kernbench.components.custom.my_pe_cpu:MyPeCpuComponent
|
||||
#
|
||||
# 3. Reference it in topology.yaml:
|
||||
# pe_cpu: { kind: pe_cpu, impl: my_pe_cpu_v2, attrs: { ... } }
|
||||
#
|
||||
# 4. Add unit tests in:
|
||||
# tests/custom/test_<your_component>.py
|
||||
#
|
||||
# External packages also work — use the full module path:
|
||||
# fast_gemm_v1: my_team.accel.fast_gemm:FastGemmComponent
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
components:
|
||||
# Infrastructure
|
||||
forwarding_v1: kernbench.components.builtin.forwarding:TransitComponent
|
||||
switch_v1: kernbench.components.builtin.forwarding:TransitComponent
|
||||
noc_v1: kernbench.components.builtin.forwarding:TransitComponent
|
||||
ucie_v1: kernbench.components.builtin.forwarding:TransitComponent
|
||||
noc_2d_mesh_v1: kernbench.components.builtin.noc:TwoDMeshNocComponent
|
||||
xbar_v1: kernbench.components.builtin.xbar:PositionAwareXbarComponent
|
||||
|
||||
# IO / Host interface
|
||||
pcie_ep_v1: kernbench.components.builtin.pcie_ep:PcieEpComponent
|
||||
io_cpu_v1: kernbench.components.builtin.io_cpu:IoCpuComponent
|
||||
|
||||
# Cube-level
|
||||
m_cpu_v1: kernbench.components.builtin.m_cpu:MCpuComponent
|
||||
hbm_ctrl_v1: kernbench.components.builtin.hbm_ctrl:HbmCtrlComponent
|
||||
sram_v1: kernbench.components.builtin.sram:SramComponent
|
||||
|
||||
# PE-level
|
||||
pe_cpu_v1: kernbench.components.builtin.pe_cpu:PeCpuComponent
|
||||
pe_scheduler_v1: kernbench.components.builtin.pe_scheduler:PeSchedulerComponent
|
||||
pe_dma_v1: kernbench.components.builtin.pe_dma:PeDmaComponent
|
||||
pe_gemm_v1: kernbench.components.builtin.pe_gemm:PeGemmComponent
|
||||
pe_math_v1: kernbench.components.builtin.pe_math:PeMathComponent
|
||||
pe_mmu_v1: kernbench.components.builtin.pe_mmu:PeMmuComponent
|
||||
pe_tcm_v1: kernbench.components.builtin.pe_tcm:PeTcmComponent
|
||||
|
||||
# Custom — add your implementations here
|
||||
# pe_cpu_v2: kernbench.components.custom.my_pe_cpu:MyPeCpuComponent
|
||||
@@ -58,7 +58,7 @@ sufficient to execute kernels and issue DMA requests.
|
||||
- Mapping strategy based on `DPPolicy.cube`:
|
||||
- **Replicate** (`cube="replicate"`): per-(sip, cube) local mapping only.
|
||||
Each cube's PEs see only their local PA. No cross-cube mapping installed.
|
||||
- **Sharded** (`cube="shard_m"`, etc.): broadcast all shard mappings to all
|
||||
- **Sharded** (`cube="column_wise"`, etc.): broadcast all shard mappings to all
|
||||
target cubes. Enables cross-PE and cross-cube DMA.
|
||||
|
||||
#### D3.4 Tensor Lifecycle
|
||||
|
||||
@@ -163,11 +163,11 @@ DefaultComponent ← 안전한 fallback
|
||||
## 슬라이드 7 — Registry 등록 방식
|
||||
|
||||
```python
|
||||
# kernbench/components/impls/__init__.py
|
||||
# kernbench/components/builtin/__init__.py
|
||||
|
||||
from kernbench.components.base import ComponentRegistry
|
||||
from kernbench.components.impls.noc import TwoDMeshNocComponent
|
||||
from kernbench.components.impls.io_cpu import IoCpuComponent
|
||||
from kernbench.components.builtin.noc import TwoDMeshNocComponent
|
||||
from kernbench.components.builtin.io_cpu import IoCpuComponent
|
||||
# ...
|
||||
|
||||
ComponentRegistry.register("noc_2d_mesh_v1", TwoDMeshNocComponent)
|
||||
|
||||
@@ -140,16 +140,65 @@ class ComponentRegistry:
|
||||
|
||||
Resolution order for ComponentRegistry.create(node, overrides, ctx):
|
||||
1. overrides[node.impl] — caller-injected override
|
||||
2. _registry[node.impl] — globally registered impl
|
||||
2. _registry[node.impl] — globally registered impl (lazy import)
|
||||
3. Error — no fallback; every node must have an impl
|
||||
|
||||
Registry is populated from components.yaml via load_components_yaml().
|
||||
Manual register() is still supported for tests and overrides.
|
||||
"""
|
||||
|
||||
_registry: dict[str, type[ComponentBase]] = {}
|
||||
_lazy: dict[str, str] = {} # impl → "module.path:ClassName"
|
||||
_loaded: bool = False
|
||||
|
||||
@classmethod
|
||||
def register(cls, impl: str, component_cls: type[ComponentBase]) -> None:
|
||||
cls._registry[impl] = component_cls
|
||||
|
||||
@classmethod
def load_components_yaml(cls, path: str | None = None) -> None:
    """Load impl→class mappings from components.yaml. Lazy imports on first use.

    Only the ``impl -> "module.path:ClassName"`` strings are recorded in
    ``cls._lazy``; no component module is imported here.

    Args:
        path: YAML file to read. When ``None``, ``components.yaml`` is looked
            up in the current working directory and then four directories
            above this file; if neither exists the call is a no-op and the
            registry stays marked as not loaded.

    Raises:
        ValueError: If the file's top level, or its ``components`` key, is
            not a mapping.

    Note:
        Once a file has been loaded, later calls return immediately, even
        when a different ``path`` is passed.
    """
    if cls._loaded:
        return

    from pathlib import Path

    if path is None:
        # Search: project root (cwd), then relative to this file
        candidates = (
            Path.cwd() / "components.yaml",
            Path(__file__).parent.parent.parent.parent / "components.yaml",
        )
        path = next((str(p) for p in candidates if p.exists()), None)
        if path is None:
            return

    # Imported only once there is a file to parse, so a missing PyYAML does
    # not matter when no components.yaml exists.
    import yaml

    with open(path, encoding="utf-8") as f:
        spec = yaml.safe_load(f)

    # safe_load() returns None for an empty document; treat that as "no components".
    if spec is None:
        spec = {}
    if not isinstance(spec, dict):
        raise ValueError(
            f"{path}: expected a mapping at top level, got {type(spec).__name__}"
        )
    components = spec.get("components") or {}
    if not isinstance(components, dict):
        raise ValueError(
            f"{path}: 'components' must be a mapping of impl name to "
            f"'module.path:ClassName', got {type(components).__name__}"
        )

    cls._lazy.update(components)
    cls._loaded = True
|
||||
|
||||
@classmethod
|
||||
def _resolve(cls, impl: str) -> type[ComponentBase] | None:
|
||||
"""Resolve impl name: check _registry first, then lazy import from _lazy."""
|
||||
if impl in cls._registry:
|
||||
return cls._registry[impl]
|
||||
if not cls._loaded:
|
||||
cls.load_components_yaml()
|
||||
class_path = cls._lazy.get(impl)
|
||||
if class_path is None:
|
||||
return None
|
||||
import importlib
|
||||
module_path, class_name = class_path.rsplit(":", 1)
|
||||
mod = importlib.import_module(module_path)
|
||||
component_cls = getattr(mod, class_name)
|
||||
cls._registry[impl] = component_cls # cache for next lookup
|
||||
return component_cls
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
@@ -159,9 +208,10 @@ class ComponentRegistry:
|
||||
) -> ComponentBase:
|
||||
if overrides and node.impl in overrides:
|
||||
return overrides[node.impl](node, ctx)
|
||||
if node.impl in cls._registry:
|
||||
return cls._registry[node.impl](node, ctx)
|
||||
component_cls = cls._resolve(node.impl)
|
||||
if component_cls is not None:
|
||||
return component_cls(node, ctx)
|
||||
raise ValueError(
|
||||
f"No component registered for impl '{node.impl}' (node: {node.id}). "
|
||||
f"Register it in kernbench.components.impls.__init__."
|
||||
f"Add it to components.yaml or call ComponentRegistry.register()."
|
||||
)
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
"""Concrete component implementations.
|
||||
|
||||
Loaded from components.yaml via ComponentRegistry.load_components_yaml().
|
||||
Manual imports are no longer needed — add new impls to components.yaml.
|
||||
|
||||
Classes are still importable from this package via lazy __getattr__.
|
||||
"""
|
||||
|
||||
from kernbench.components.base import ComponentRegistry
|
||||
|
||||
ComponentRegistry.load_components_yaml()
|
||||
|
||||
# Lazy re-export: allow `from kernbench.components.builtin import FooComponent`
|
||||
# without eagerly importing every module.
|
||||
_CLASS_MAP: dict[str, str] = {} # ClassName → "module.path:ClassName"
|
||||
|
||||
|
||||
def _build_class_map() -> None:
    """Populate ``_CLASS_MAP`` (class name -> class path) from the registry's lazy table.

    Does nothing when ``_CLASS_MAP`` already has entries.
    """
    if not _CLASS_MAP:
        for dotted in ComponentRegistry._lazy.values():
            # Entries look like "module.path:ClassName"; key on the class name.
            _, cls_name = dotted.rsplit(":", 1)
            _CLASS_MAP[cls_name] = dotted
|
||||
|
||||
|
||||
def __getattr__(name: str):
    """Lazily resolve a component class requested as a package attribute.

    Looks ``name`` up in ``_CLASS_MAP`` (built on demand from the registry's
    lazy table), imports the owning module and returns the class.

    Args:
        name: Attribute name being looked up on this package.

    Returns:
        The class object named ``name``.

    Raises:
        AttributeError: If ``name`` is not a known component class. A
            module-level ``__getattr__`` must signal a miss with
            AttributeError: ``hasattr()``, ``getattr(mod, name, default)``
            and the import machinery's own probes only swallow that
            exception type, and ``from package import name`` already turns
            it into ImportError for the caller.
    """
    # Dunder probes (__all__, __path__, __wrapped__, ...) are never component
    # classes; answer them without touching the registry.
    if name.startswith("__") and name.endswith("__"):
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    _build_class_map()
    class_path = _CLASS_MAP.get(name)
    if class_path is None:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    import importlib

    module_path, class_name = class_path.rsplit(":", 1)
    mod = importlib.import_module(module_path)
    return getattr(mod, class_name)
|
||||
+14
-2
@@ -81,10 +81,22 @@ class PeCpuComponent(ComponentBase):
|
||||
yield from self.run(env, 0)
|
||||
|
||||
kernel_fn = get_kernel(request.kernel_ref.name)
|
||||
tl = TLContext(pe_id=self._pe_idx, dispatch_cycles=0)
|
||||
|
||||
# Derive num_programs from the number of PE shards in this cube
|
||||
num_programs = 1
|
||||
for arg in request.args:
|
||||
if arg.arg_kind == "tensor":
|
||||
cube_pe_count = sum(
|
||||
1 for s in arg.shards
|
||||
if s.sip == self._sip_idx and s.cube == self._cube_idx
|
||||
)
|
||||
if cube_pe_count > num_programs:
|
||||
num_programs = cube_pe_count
|
||||
|
||||
tl = TLContext(pe_id=self._pe_idx, num_programs=num_programs, dispatch_cycles=0)
|
||||
|
||||
# Unpack KernelLaunchMsg.args into positional args for kernel function
|
||||
# TensorArg → VA base (or PA fallback), ScalarArg → value
|
||||
# TensorArg → va_base (already local, set by runtime) or PA fallback
|
||||
kernel_args: list = []
|
||||
for arg in request.args:
|
||||
if arg.arg_kind == "tensor":
|
||||
@@ -0,0 +1,5 @@
|
||||
"""Custom component implementations.
|
||||
|
||||
Place your component files here and register them in components.yaml.
|
||||
See components.yaml header for instructions.
|
||||
"""
|
||||
@@ -1,59 +0,0 @@
|
||||
"""Concrete component implementations.
|
||||
|
||||
Each module registers its component(s) with ComponentRegistry on import.
|
||||
Import this package to activate all built-in implementations.
|
||||
"""
|
||||
|
||||
from kernbench.components.base import ComponentRegistry
|
||||
from kernbench.components.impls.forwarding import TransitComponent
|
||||
from kernbench.components.impls.hbm_ctrl import HbmCtrlComponent
|
||||
from kernbench.components.impls.io_cpu import IoCpuComponent
|
||||
from kernbench.components.impls.m_cpu import MCpuComponent
|
||||
from kernbench.components.impls.noc import TwoDMeshNocComponent
|
||||
from kernbench.components.impls.pcie_ep import PcieEpComponent
|
||||
from kernbench.components.impls.pe_cpu import PeCpuComponent
|
||||
from kernbench.components.impls.pe_dma import PeDmaComponent
|
||||
from kernbench.components.impls.pe_gemm import PeGemmComponent
|
||||
from kernbench.components.impls.pe_math import PeMathComponent
|
||||
from kernbench.components.impls.pe_scheduler import PeSchedulerComponent
|
||||
from kernbench.components.impls.pe_mmu import PeMmuComponent
|
||||
from kernbench.components.impls.pe_tcm import PeTcmComponent
|
||||
from kernbench.components.impls.sram import SramComponent
|
||||
from kernbench.components.impls.xbar import PositionAwareXbarComponent
|
||||
|
||||
ComponentRegistry.register("forwarding_v1", TransitComponent)
|
||||
ComponentRegistry.register("switch_v1", TransitComponent)
|
||||
ComponentRegistry.register("noc_v1", TransitComponent)
|
||||
ComponentRegistry.register("noc_2d_mesh_v1", TwoDMeshNocComponent)
|
||||
ComponentRegistry.register("ucie_v1", TransitComponent)
|
||||
ComponentRegistry.register("xbar_v1", PositionAwareXbarComponent)
|
||||
ComponentRegistry.register("pcie_ep_v1", PcieEpComponent)
|
||||
ComponentRegistry.register("io_cpu_v1", IoCpuComponent)
|
||||
ComponentRegistry.register("m_cpu_v1", MCpuComponent)
|
||||
ComponentRegistry.register("hbm_ctrl_v1", HbmCtrlComponent)
|
||||
ComponentRegistry.register("sram_v1", SramComponent)
|
||||
ComponentRegistry.register("pe_cpu_v1", PeCpuComponent)
|
||||
ComponentRegistry.register("pe_scheduler_v1", PeSchedulerComponent)
|
||||
ComponentRegistry.register("pe_dma_v1", PeDmaComponent)
|
||||
ComponentRegistry.register("pe_gemm_v1", PeGemmComponent)
|
||||
ComponentRegistry.register("pe_math_v1", PeMathComponent)
|
||||
ComponentRegistry.register("pe_mmu_v1", PeMmuComponent)
|
||||
ComponentRegistry.register("pe_tcm_v1", PeTcmComponent)
|
||||
|
||||
__all__ = [
|
||||
"HbmCtrlComponent",
|
||||
"IoCpuComponent",
|
||||
"MCpuComponent",
|
||||
"PcieEpComponent",
|
||||
"PeCpuComponent",
|
||||
"PeDmaComponent",
|
||||
"PeGemmComponent",
|
||||
"PeMathComponent",
|
||||
"PeMmuComponent",
|
||||
"PeSchedulerComponent",
|
||||
"PeTcmComponent",
|
||||
"TransitComponent",
|
||||
"TwoDMeshNocComponent",
|
||||
"PositionAwareXbarComponent",
|
||||
"SramComponent",
|
||||
]
|
||||
@@ -7,12 +7,36 @@ from typing import Literal
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DPPolicy:
|
||||
"""Two-level data-parallel policy: cube-level + pe-level."""
|
||||
"""Three-level data-parallel policy: sip-level + cube-level + pe-level.
|
||||
|
||||
cube: Literal["replicate", "shard_m", "shard_k"] = "replicate"
|
||||
Policies:
|
||||
- "replicate": full copy at each unit
|
||||
- "column_wise": split K (column) axis across units
|
||||
- "row_wise": split M (row) axis across units
|
||||
"""
|
||||
|
||||
sip: Literal["replicate", "column_wise", "row_wise"] = "replicate"
|
||||
cube: Literal["replicate", "column_wise", "row_wise"] = "replicate"
|
||||
pe: Literal["replicate", "column_wise", "row_wise"] = "replicate"
|
||||
|
||||
|
||||
def _split_shape(
|
||||
policy: str, shape: tuple[int, int], count: int, itemsize: int,
|
||||
) -> list[tuple[tuple[int, int], int]]:
|
||||
"""Split shape by policy into (sub_shape, byte_offset) pairs."""
|
||||
M, K = shape
|
||||
if policy == "replicate":
|
||||
return [((M, K), 0)] * count
|
||||
elif policy == "column_wise":
|
||||
chunk_k = K // count
|
||||
return [((M, chunk_k), i * M * chunk_k * itemsize) for i in range(count)]
|
||||
elif policy == "row_wise":
|
||||
chunk_m = M // count
|
||||
return [((chunk_m, K), i * chunk_m * K * itemsize) for i in range(count)]
|
||||
else:
|
||||
raise ValueError(f"Unknown policy: {policy}")
|
||||
|
||||
|
||||
def resolve_dp_policy(
|
||||
policy: DPPolicy,
|
||||
*,
|
||||
@@ -20,11 +44,14 @@ def resolve_dp_policy(
|
||||
itemsize: int,
|
||||
num_pe: int,
|
||||
num_cubes: int = 1,
|
||||
num_sips: int = 1,
|
||||
) -> list[ShardSpec]:
|
||||
"""Resolve a DPPolicy into a list[ShardSpec] with two-level resolution.
|
||||
"""Resolve a DPPolicy into a list[ShardSpec] with three-level resolution.
|
||||
|
||||
Cube-level policy distributes across cubes, pe-level distributes within
|
||||
each cube. ShardSpec.pe_index uses flat indexing: cube_id * num_pe + pe_id.
|
||||
SIP-level → cube-level → pe-level.
|
||||
num_cubes is cubes per SIP (not total).
|
||||
ShardSpec.pe_index uses flat indexing:
|
||||
sip_id * num_cubes * num_pe + cube_id * num_pe + pe_id
|
||||
"""
|
||||
_PE_RESOLVERS = {
|
||||
"replicate": replicate,
|
||||
@@ -35,38 +62,29 @@ def resolve_dp_policy(
|
||||
if resolver is None:
|
||||
raise ValueError(f"Unknown pe-level policy: {policy.pe}")
|
||||
|
||||
if num_cubes <= 1:
|
||||
return resolver(shape=shape, itemsize=itemsize, num_pe=num_pe)
|
||||
|
||||
# Two-level resolution: cube-level → pe-level
|
||||
M, K = shape
|
||||
cubes_per_sip = num_cubes
|
||||
all_shards: list[ShardSpec] = []
|
||||
|
||||
for cube_id in range(num_cubes):
|
||||
# Determine per-cube shape based on cube-level policy
|
||||
if policy.cube == "replicate":
|
||||
cube_shape = (M, K)
|
||||
cube_offset = 0
|
||||
elif policy.cube == "shard_m":
|
||||
chunk_m = M // num_cubes
|
||||
cube_shape = (chunk_m, K)
|
||||
cube_offset = cube_id * chunk_m * K * itemsize
|
||||
elif policy.cube == "shard_k":
|
||||
chunk_k = K // num_cubes
|
||||
cube_shape = (M, chunk_k)
|
||||
cube_offset = cube_id * M * chunk_k * itemsize
|
||||
else:
|
||||
raise ValueError(f"Unknown cube-level policy: {policy.cube}")
|
||||
# Level 1: SIP
|
||||
sip_splits = _split_shape(policy.sip, shape, num_sips, itemsize)
|
||||
|
||||
# Resolve pe-level within this cube's shape
|
||||
for sip_id, (sip_shape, sip_offset) in enumerate(sip_splits):
|
||||
# Level 2: Cube within SIP
|
||||
cube_splits = _split_shape(policy.cube, sip_shape, cubes_per_sip, itemsize)
|
||||
|
||||
for cube_id, (cube_shape, cube_offset) in enumerate(cube_splits):
|
||||
# Level 3: PE within cube
|
||||
pe_shards = resolver(shape=cube_shape, itemsize=itemsize, num_pe=num_pe)
|
||||
|
||||
# Remap pe_index to flat index and adjust offset
|
||||
for ps in pe_shards:
|
||||
flat_idx = cube_id * num_pe + ps.pe_index
|
||||
flat_idx = (
|
||||
sip_id * cubes_per_sip * num_pe
|
||||
+ cube_id * num_pe
|
||||
+ ps.pe_index
|
||||
)
|
||||
all_shards.append(ShardSpec(
|
||||
pe_index=flat_idx,
|
||||
offset_bytes=cube_offset + ps.offset_bytes,
|
||||
offset_bytes=sip_offset + cube_offset + ps.offset_bytes,
|
||||
nbytes=ps.nbytes,
|
||||
))
|
||||
|
||||
|
||||
@@ -20,7 +20,6 @@ class RuntimeContext:
|
||||
_completed: set[RequestHandle] = field(default_factory=set, init=False)
|
||||
_allocators: dict[int, Any] = field(default_factory=dict, init=False)
|
||||
_va_allocator: Any = field(default=None, init=False)
|
||||
_mmus: dict[int, Any] = field(default_factory=dict, init=False)
|
||||
_tensor_counter: int = field(default=0, init=False)
|
||||
_traces: list[dict] = field(default_factory=list, init=False)
|
||||
_tensors: list[Any] = field(default_factory=list, init=False)
|
||||
@@ -208,24 +207,17 @@ class RuntimeContext:
|
||||
rack_id=0, sip_id=sip_id, cube_id=cube_id, pe_id=pe_id, cfg=cfg,
|
||||
)
|
||||
|
||||
# Initialize VA allocator and per-PE MMUs
|
||||
from kernbench.policy.address.pe_mmu import PeMMU
|
||||
# Initialize VA allocator (MMU mappings are installed via fabric MmuMapMsg)
|
||||
from kernbench.policy.address.va_allocator import VirtualAllocator
|
||||
|
||||
pe_mmu_attrs = pe_comps.get("pe_mmu", {}).get("attrs", {})
|
||||
page_size = int(pe_mmu_attrs.get("page_size", 4096))
|
||||
tlb_overhead_ns = float(pe_mmu_attrs.get("tlb_overhead_ns", 0.0))
|
||||
|
||||
self._va_allocator = VirtualAllocator(
|
||||
va_base=0x1_0000_0000,
|
||||
va_size=64 * (1 << 30), # 64 GB VA space
|
||||
page_size=page_size,
|
||||
)
|
||||
total_pes = sip_count * cubes_per_sip * pes_per_cube
|
||||
for flat_idx in range(total_pes):
|
||||
self._mmus[flat_idx] = PeMMU(
|
||||
page_size=page_size, overhead_ns=tlb_overhead_ns,
|
||||
)
|
||||
|
||||
return self._allocators
|
||||
|
||||
@@ -276,11 +268,11 @@ class RuntimeContext:
|
||||
dp_policy = dp
|
||||
allocators = self._ensure_allocators()
|
||||
itemsize = dtype_itemsize(dtype)
|
||||
shape_2d = (shape[0], shape[1]) # type: tuple[int, int]
|
||||
total_cubes = self._num_sips * self._num_cubes
|
||||
shape_2d = (shape[0], shape[1]) if len(shape) >= 2 else (1, shape[0])
|
||||
placement = resolve_dp_policy(
|
||||
dp, shape=shape_2d, itemsize=itemsize,
|
||||
num_pe=self._pes_per_cube, num_cubes=total_cubes,
|
||||
num_pe=self._pes_per_cube, num_cubes=self._num_cubes,
|
||||
num_sips=self._num_sips,
|
||||
)
|
||||
|
||||
# Infer target_pe from placement: multi-PE → "all", single PE → pe_index
|
||||
@@ -297,7 +289,6 @@ class RuntimeContext:
|
||||
placement=placement,
|
||||
allocators=allocators,
|
||||
va_allocator=self._va_allocator,
|
||||
mmus=self._mmus,
|
||||
)
|
||||
t._handle = handle
|
||||
import weakref
|
||||
@@ -305,6 +296,8 @@ class RuntimeContext:
|
||||
self._tensors.append(weakref.ref(t))
|
||||
|
||||
# Install VA→PA mappings via fabric MmuMapMsg
|
||||
# Strategy: always SIP-scoped (each SIP gets only its own shards).
|
||||
# Within each SIP: cube="replicate" → per-cube, else broadcast within SIP.
|
||||
if handle.va_base:
|
||||
from collections import defaultdict
|
||||
from kernbench.runtime_api.kernel import MmuMapMsg
|
||||
@@ -313,13 +306,19 @@ class RuntimeContext:
|
||||
dp_policy is not None and dp_policy.cube == "replicate"
|
||||
)
|
||||
|
||||
if is_cube_replicate:
|
||||
# Replicate: each (sip, cube) gets only its own local PA mappings
|
||||
cube_groups: dict[tuple[int, int], list] = defaultdict(list)
|
||||
# Group shards by SIP
|
||||
sip_groups: dict[int, list] = defaultdict(list)
|
||||
for shard in handle.shards:
|
||||
cube_groups[(shard.sip, shard.cube)].append(shard)
|
||||
sip_groups[shard.sip].append(shard)
|
||||
|
||||
for (sip, cube), group_shards in cube_groups.items():
|
||||
for sip, sip_shards in sip_groups.items():
|
||||
if is_cube_replicate:
|
||||
# Cube replicate: per-(sip, cube) local mapping
|
||||
cube_groups: dict[int, list] = defaultdict(list)
|
||||
for s in sip_shards:
|
||||
cube_groups[s.cube].append(s)
|
||||
|
||||
for cube, group_shards in cube_groups.items():
|
||||
entries = tuple(
|
||||
{"va": handle.va_base + s.offset_bytes,
|
||||
"pa": s.pa, "size": s.nbytes}
|
||||
@@ -336,19 +335,18 @@ class RuntimeContext:
|
||||
h = self.submit(msg)
|
||||
self.wait(h)
|
||||
else:
|
||||
# Sharded: broadcast all mappings to all target (sip, cube)s
|
||||
# Cube sharded: broadcast all cubes within this SIP
|
||||
entries = tuple(
|
||||
{"va": handle.va_base + s.offset_bytes,
|
||||
"pa": s.pa, "size": s.nbytes}
|
||||
for s in handle.shards
|
||||
for s in sip_shards
|
||||
)
|
||||
sip_set = sorted({s.sip for s in handle.shards})
|
||||
cube_set = sorted({s.cube for s in handle.shards})
|
||||
cube_set = sorted({s.cube for s in sip_shards})
|
||||
msg = MmuMapMsg(
|
||||
correlation_id=self.correlation_id,
|
||||
request_id=f"mmu_{tensor_name}",
|
||||
request_id=f"mmu_{tensor_name}_s{sip}",
|
||||
entries=entries,
|
||||
target_sips=tuple(sip_set),
|
||||
target_sips=(sip,),
|
||||
target_cubes=tuple(cube_set),
|
||||
target_pe="all",
|
||||
)
|
||||
@@ -384,11 +382,18 @@ class RuntimeContext:
|
||||
|
||||
Positional args: Tensor objects become TensorArg, int/float become ScalarArg.
|
||||
Keyword args: become ScalarArg (name is discarded, order preserved).
|
||||
|
||||
Creates per-SIP KernelLaunchMsg with local va_base per tensor
|
||||
(like host driver sending per-rank launch commands).
|
||||
"""
|
||||
from collections import defaultdict
|
||||
|
||||
from kernbench.runtime_api.kernel import (
|
||||
KernelLaunchMsg,
|
||||
KernelRef,
|
||||
ScalarArg,
|
||||
TensorArg,
|
||||
TensorArgShard,
|
||||
)
|
||||
from kernbench.runtime_api.tensor import Tensor
|
||||
from kernbench.triton_emu.registry import register_kernel
|
||||
@@ -399,14 +404,14 @@ class RuntimeContext:
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Build kernel args from positional + keyword args
|
||||
kernel_args: list = []
|
||||
# Collect tensors and scalars
|
||||
tensor_args: list[Tensor] = []
|
||||
scalar_args: list = []
|
||||
target_pe: int | str = 0
|
||||
|
||||
for a in args:
|
||||
if isinstance(a, Tensor):
|
||||
kernel_args.append(a.to_tensor_arg())
|
||||
# Infer target_pe from tensor DP metadata
|
||||
tensor_args.append(a)
|
||||
if a._dp_metadata is not None:
|
||||
dp_target = a._dp_metadata.target_pe
|
||||
if dp_target == "all":
|
||||
@@ -415,34 +420,121 @@ class RuntimeContext:
|
||||
target_pe = dp_target
|
||||
elif isinstance(a, (int, float)):
|
||||
dtype_str = "f32" if isinstance(a, float) else "i32"
|
||||
kernel_args.append(ScalarArg(dtype=dtype_str, value=a))
|
||||
scalar_args.append(ScalarArg(dtype=dtype_str, value=a))
|
||||
|
||||
for v in kwargs.values():
|
||||
if isinstance(v, (int, float)):
|
||||
dtype_str = "f32" if isinstance(v, float) else "i32"
|
||||
kernel_args.append(ScalarArg(dtype=dtype_str, value=v))
|
||||
scalar_args.append(ScalarArg(dtype=dtype_str, value=v))
|
||||
|
||||
# Determine target cubes from all tensor shards
|
||||
cube_set: set[int] = set()
|
||||
# Determine all target SIPs from tensor shards
|
||||
sip_set: set[int] = set()
|
||||
for t in tensor_args:
|
||||
if t._handle is not None:
|
||||
for s in t._handle.shards:
|
||||
sip_set.add(s.sip)
|
||||
if not sip_set:
|
||||
sip_set = {0}
|
||||
|
||||
# Build global→local dimension mapping from tensor DPPolicies.
|
||||
# Scalar args matching a tensor's global dimension get replaced
|
||||
# with the cube-local value (what the kernel actually operates on).
|
||||
def _compute_local_shape(t: Tensor) -> tuple[int, ...]:
|
||||
"""Compute cube-local shape from DPPolicy."""
|
||||
shape = t.shape
|
||||
if len(shape) < 2:
|
||||
shape = (1, shape[0])
|
||||
M, K = shape[0], shape[1]
|
||||
dp = t._dp_metadata.dp_policy if t._dp_metadata else None
|
||||
if dp is None:
|
||||
return t.shape
|
||||
if dp.sip != "replicate":
|
||||
if dp.sip == "column_wise":
|
||||
K = K // self._num_sips
|
||||
elif dp.sip == "row_wise":
|
||||
M = M // self._num_sips
|
||||
if dp.cube != "replicate":
|
||||
if dp.cube == "column_wise":
|
||||
K = K // self._num_cubes
|
||||
elif dp.cube == "row_wise":
|
||||
M = M // self._num_cubes
|
||||
if len(t.shape) < 2:
|
||||
return (K,)
|
||||
return (M, K)
|
||||
|
||||
dim_map: dict[int, int] = {} # global_dim → local_dim
|
||||
for t in tensor_args:
|
||||
local = _compute_local_shape(t)
|
||||
for g, l in zip(t.shape if len(t.shape) >= 2 else (1, t.shape[0]), local if len(local) >= 2 else (1, local[0])):
|
||||
if g != l:
|
||||
dim_map[g] = l
|
||||
|
||||
# Per-SIP kernel launch: each SIP gets TensorArgs with local va_base
|
||||
last_handle = None
|
||||
for sip_id in sorted(sip_set):
|
||||
sip_kernel_args: list = []
|
||||
sip_cube_set: set[int] = set()
|
||||
|
||||
for t in tensor_args:
|
||||
if t._handle is None:
|
||||
continue
|
||||
sip_shards = [s for s in t._handle.shards if s.sip == sip_id]
|
||||
if not sip_shards:
|
||||
sip_shards = list(t._handle.shards)
|
||||
|
||||
local_va_base = 0
|
||||
if t._handle.va_base:
|
||||
min_offset = min(s.offset_bytes for s in sip_shards)
|
||||
local_va_base = t._handle.va_base + min_offset
|
||||
|
||||
sip_kernel_args.append(TensorArg(
|
||||
shards=tuple(
|
||||
TensorArgShard(
|
||||
sip=s.sip, cube=s.cube, pe=s.pe,
|
||||
pa=s.pa, nbytes=s.nbytes, offset_bytes=s.offset_bytes,
|
||||
)
|
||||
for s in sip_shards
|
||||
),
|
||||
va_base=local_va_base,
|
||||
))
|
||||
|
||||
for s in sip_shards:
|
||||
sip_cube_set.add(s.cube)
|
||||
|
||||
# Interleave tensor args and scalar args, replacing global dims with local
|
||||
final_args: list = []
|
||||
t_idx, s_idx = 0, 0
|
||||
for a in args:
|
||||
if isinstance(a, Tensor) and a._handle is not None:
|
||||
for s in a._handle.shards:
|
||||
cube_set.add(s.cube)
|
||||
target_cubes = tuple(sorted(cube_set)) if cube_set else (0,)
|
||||
if isinstance(a, Tensor):
|
||||
final_args.append(sip_kernel_args[t_idx])
|
||||
t_idx += 1
|
||||
elif isinstance(a, (int, float)):
|
||||
sa = scalar_args[s_idx]
|
||||
if isinstance(a, int) and a in dim_map:
|
||||
sa = ScalarArg(dtype=sa.dtype, value=dim_map[a])
|
||||
final_args.append(sa)
|
||||
s_idx += 1
|
||||
while s_idx < len(scalar_args):
|
||||
sa = scalar_args[s_idx]
|
||||
if isinstance(sa.value, int) and int(sa.value) in dim_map:
|
||||
sa = ScalarArg(dtype=sa.dtype, value=dim_map[int(sa.value)])
|
||||
final_args.append(sa)
|
||||
s_idx += 1
|
||||
|
||||
# Collect scalar values for GEMM FLOP calculation
|
||||
scalar_vals = [a.value for a in kernel_args if hasattr(a, "value")]
|
||||
target_cubes = tuple(sorted(sip_cube_set)) if sip_cube_set else (0,)
|
||||
|
||||
h = self.submit(KernelLaunchMsg(
|
||||
correlation_id=self.correlation_id,
|
||||
request_id=kernel_name,
|
||||
request_id=f"{kernel_name}_sip{sip_id}",
|
||||
kernel_ref=KernelRef(name=kernel_name, kind="builtin"),
|
||||
args=tuple(kernel_args),
|
||||
args=tuple(final_args),
|
||||
target_cubes=target_cubes,
|
||||
target_pe=target_pe,
|
||||
))
|
||||
self.wait(h, _meta={
|
||||
"phase": "kernel", "name": kernel_name,
|
||||
"target_pe": target_pe, "scalars": scalar_vals,
|
||||
"sip": sip_id, "target_pe": target_pe,
|
||||
})
|
||||
return h
|
||||
last_handle = h
|
||||
|
||||
return last_handle
|
||||
|
||||
@@ -59,10 +59,7 @@ def deploy_tensor(
|
||||
allocators: dict[int, PEMemAllocator],
|
||||
mem_kind: Literal["hbm", "tcm"] = "hbm",
|
||||
va_allocator=None,
|
||||
mmus: dict | None = None,
|
||||
) -> TensorHandle:
|
||||
from kernbench.policy.address.pe_mmu import PeMMU
|
||||
|
||||
isize = dtype_itemsize(dtype)
|
||||
total_nbytes = math.prod(shape) * isize
|
||||
|
||||
@@ -78,22 +75,15 @@ def deploy_tensor(
|
||||
pa = alloc.alloc_hbm(spec.nbytes)
|
||||
else:
|
||||
pa = alloc.alloc_tcm(spec.nbytes)
|
||||
encoded_pa = pa.encode()
|
||||
shards.append(TensorShard(
|
||||
sip=alloc._sip_id,
|
||||
cube=alloc._cube_id,
|
||||
pe=alloc._pe_id,
|
||||
pa=encoded_pa,
|
||||
pa=pa.encode(),
|
||||
nbytes=spec.nbytes,
|
||||
offset_bytes=spec.offset_bytes,
|
||||
))
|
||||
|
||||
# Register VA→PA mapping in all MMUs (broadcast)
|
||||
if va_base and mmus is not None:
|
||||
shard_va = va_base + spec.offset_bytes
|
||||
for mmu in mmus.values():
|
||||
mmu.map(va=shard_va, pa=encoded_pa, size=spec.nbytes)
|
||||
|
||||
return TensorHandle(
|
||||
name=name,
|
||||
shape=shape,
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Any
|
||||
import simpy
|
||||
|
||||
from kernbench.common.types import Completion, RequestHandle, Trace
|
||||
import kernbench.components.impls # noqa: F401 — registers built-in implementations
|
||||
import kernbench.components.builtin # noqa: F401 — registers built-in implementations
|
||||
from kernbench.components.base import ComponentBase, ComponentRegistry
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.policy.address.phyaddr import PhysAddr
|
||||
|
||||
@@ -13,7 +13,7 @@ import simpy
|
||||
from pathlib import Path
|
||||
|
||||
from kernbench.components.base import ComponentBase, ComponentRegistry
|
||||
from kernbench.components.impls.forwarding import TransitComponent
|
||||
from kernbench.components.builtin.forwarding import TransitComponent
|
||||
from kernbench.policy.address.phyaddr import PhysAddr
|
||||
from kernbench.runtime_api.kernel import MemoryReadMsg
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
|
||||
@@ -73,7 +73,7 @@ def test_mmu_unmap_msg_fields():
|
||||
def test_pe_mmu_registry():
|
||||
"""pe_mmu_v1 impl resolves in ComponentRegistry."""
|
||||
from kernbench.components.base import ComponentRegistry
|
||||
from kernbench.components.impls.pe_mmu import PeMmuComponent
|
||||
from kernbench.components.builtin.pe_mmu import PeMmuComponent
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
node = Node(
|
||||
@@ -93,7 +93,7 @@ def test_pe_mmu_registry():
|
||||
def test_pe_mmu_processes_map_msg():
|
||||
"""PE_MMU component receives MmuMapMsg → translate works."""
|
||||
import simpy
|
||||
from kernbench.components.impls.pe_mmu import PeMmuComponent
|
||||
from kernbench.components.builtin.pe_mmu import PeMmuComponent
|
||||
from kernbench.sim_engine.transaction import Transaction
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
@@ -152,7 +152,7 @@ def test_pe_dma_translates_va():
|
||||
# This test validates the interface contract. Full integration test
|
||||
# requires the engine wiring which is validated in test_engine.
|
||||
# Here we check that PE_DMA has an mmu attribute it can call.
|
||||
from kernbench.components.impls.pe_dma import PeDmaComponent
|
||||
from kernbench.components.builtin.pe_dma import PeDmaComponent
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
node = Node(
|
||||
|
||||
@@ -20,12 +20,12 @@ from kernbench.common.pe_commands import (
|
||||
TensorHandle,
|
||||
)
|
||||
from kernbench.components.base import ComponentRegistry
|
||||
from kernbench.components.impls.pe_cpu import PeCpuComponent
|
||||
from kernbench.components.impls.pe_dma import PeDmaComponent
|
||||
from kernbench.components.impls.pe_gemm import PeGemmComponent
|
||||
from kernbench.components.impls.pe_math import PeMathComponent
|
||||
from kernbench.components.impls.pe_scheduler import PeSchedulerComponent
|
||||
from kernbench.components.impls.pe_tcm import PeTcmComponent
|
||||
from kernbench.components.builtin.pe_cpu import PeCpuComponent
|
||||
from kernbench.components.builtin.pe_dma import PeDmaComponent
|
||||
from kernbench.components.builtin.pe_gemm import PeGemmComponent
|
||||
from kernbench.components.builtin.pe_math import PeMathComponent
|
||||
from kernbench.components.builtin.pe_scheduler import PeSchedulerComponent
|
||||
from kernbench.components.builtin.pe_tcm import PeTcmComponent
|
||||
from kernbench.policy.address.phyaddr import PhysAddr
|
||||
from kernbench.runtime_api.kernel import (
|
||||
KernelLaunchMsg,
|
||||
@@ -888,11 +888,9 @@ def test_qkv_gemm_bench_completes():
|
||||
deploy_traces = [t for t in ctx._traces if t["phase"] in ("deploy", "memory_write")]
|
||||
kernel_traces = [t for t in ctx._traces if t["phase"] == "kernel"]
|
||||
assert len(deploy_traces) >= 2 # at least a, b (out is empty, no deploy)
|
||||
assert len(kernel_traces) == 1
|
||||
assert len(kernel_traces) >= 1 # one per SIP (2 SIPs in topology)
|
||||
assert kernel_traces[0]["name"] == "qkv_gemm"
|
||||
assert kernel_traces[0]["total_ns"] > 0
|
||||
# Scalars should contain M, K, N
|
||||
assert len(kernel_traces[0]["scalars"]) >= 3
|
||||
|
||||
clear_registry()
|
||||
|
||||
@@ -982,7 +980,7 @@ def test_qkv_gemm_bench_multi_pe_completes():
|
||||
deploy_traces = [t for t in ctx._traces if t["phase"] in ("deploy", "memory_write")]
|
||||
kernel_traces = [t for t in ctx._traces if t["phase"] == "kernel"]
|
||||
assert len(deploy_traces) >= 8 # replicate(a)*8 + column_wise(b)*8
|
||||
assert len(kernel_traces) == 1
|
||||
assert len(kernel_traces) >= 1 # one per SIP
|
||||
assert kernel_traces[0]["target_pe"] == "all"
|
||||
|
||||
clear_registry()
|
||||
|
||||
@@ -19,7 +19,7 @@ import simpy
|
||||
|
||||
from kernbench.components.base import ComponentBase, ComponentRegistry
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.components.impls import (
|
||||
from kernbench.components.builtin import (
|
||||
HbmCtrlComponent,
|
||||
IoCpuComponent,
|
||||
MCpuComponent,
|
||||
|
||||
@@ -0,0 +1,157 @@
|
||||
"""Tests for SIP-level tensor parallelism.
|
||||
|
||||
Validates:
|
||||
SP1. DPPolicy accepts sip field (default "replicate", backward compat)
|
||||
SP2. sip="column_wise": tensor K-axis split across SIPs, each SIP gets K//num_sips
|
||||
SP3. sip="row_wise": tensor M-axis split across SIPs
|
||||
SP4. 3-level resolve: sip × cube × pe produces correct flat indices and offsets
|
||||
SP5. sip="replicate": all SIPs get full copy (existing behavior)
|
||||
SP6. PE_CPU sets num_programs from shard count per cube
|
||||
SP7. End-to-end: TP kernel with sip="column_wise" completes on multi-SIP topology
|
||||
"""
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
from kernbench.policy.placement.dp import DPPolicy, ShardSpec, resolve_dp_policy
|
||||
|
||||
|
||||
# ── SP1. DPPolicy sip field ──────────────────────────────────────────
|
||||
|
||||
|
||||
def test_dp_policy_sip_default_replicate():
    """A DPPolicy built without ``sip`` reports 'replicate' for that level."""
    policy = DPPolicy(pe="column_wise", cube="replicate")
    expected_default = "replicate"
    assert policy.sip == expected_default
|
||||
|
||||
|
||||
def test_dp_policy_sip_column_wise():
    """An explicit sip='column_wise' is stored unchanged on the policy."""
    requested = "column_wise"
    policy = DPPolicy(sip=requested, cube="replicate", pe="column_wise")
    assert policy.sip == "column_wise"
|
||||
|
||||
|
||||
# ── SP2. sip="column_wise" ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_sip_column_wise_splits_across_sips():
    """sip='column_wise' over 2 SIPs: each SIP owns half of the tensor bytes."""
    policy = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise")
    all_shards = resolve_dp_policy(
        policy,
        shape=(128, 256),
        itemsize=2,
        num_pe=8,
        num_cubes=1,
        num_sips=2,
    )
    # 2 SIPs x 1 cube x 8 PEs = 16 shards
    assert len(all_shards) == 16

    # The whole tensor is 128 * 256 * 2 bytes (64 KiB); each SIP should own half.
    half_bytes = (128 * 256 * 2) // 2

    # Flat PE indices 0..7 are taken as SIP0, 8 and above as SIP1.
    first_sip, second_sip = [], []
    for shard in all_shards:
        (first_sip if shard.pe_index < 8 else second_sip).append(shard)

    # SIP0 starts at the beginning of the tensor, SIP1 at the midpoint.
    assert first_sip[0].offset_bytes == 0
    assert second_sip[0].offset_bytes == half_bytes

    # Together the per-SIP shards account for exactly half the tensor each.
    assert sum(shard.nbytes for shard in first_sip) == half_bytes
    assert sum(shard.nbytes for shard in second_sip) == half_bytes
|
||||
|
||||
|
||||
# ── SP3. sip="row_wise" ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_sip_row_wise_splits_across_sips():
    """sip='row_wise' over 2 SIPs: SIP1's first shard starts at the tensor midpoint."""
    policy = DPPolicy(sip="row_wise", cube="replicate", pe="column_wise")
    all_shards = resolve_dp_policy(
        policy,
        shape=(128, 256),
        itemsize=2,
        num_pe=8,
        num_cubes=1,
        num_sips=2,
    )
    assert len(all_shards) == 16

    # Flat PE indices 0..7 are taken as SIP0, 8 and above as SIP1.
    first_sip, second_sip = [], []
    for shard in all_shards:
        (first_sip if shard.pe_index < 8 else second_sip).append(shard)

    # SIP0 covers rows 0..63 and SIP1 rows 64..127, i.e. a split at half the bytes.
    midpoint = (128 * 256 * 2) // 2
    assert first_sip[0].offset_bytes == 0
    assert second_sip[0].offset_bytes == midpoint
|
||||
|
||||
|
||||
# ── SP4. 3-level resolve ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_3level_resolve_flat_index():
    """3-level resolve (sip x cube x pe) yields 32 distinct flat indices 0..31."""
    policy = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise")
    all_shards = resolve_dp_policy(
        policy,
        shape=(128, 256),
        itemsize=2,
        num_pe=8,
        num_cubes=2,
        num_sips=2,
    )
    # 2 SIPs x 2 cubes x 8 PEs = 32 shards
    assert len(all_shards) == 32

    # Expected flat index: sip_id * cubes_per_sip * num_pe + cube_id * num_pe + pe_id,
    # so SIP0 spans 0..15 and SIP1 spans 16..31.
    distinct = {shard.pe_index for shard in all_shards}
    assert min(distinct) == 0
    assert max(distinct) == 31
    assert len(distinct) == 32  # no index is reused
|
||||
|
||||
|
||||
def test_3level_offsets_cover_full_tensor():
    """3-level sharding: per-SIP shard bytes add up to the full tensor size."""
    policy = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise")
    all_shards = resolve_dp_policy(
        policy,
        shape=(128, 256),
        itemsize=2,
        num_pe=4,
        num_cubes=1,
        num_sips=2,
    )
    # 2 SIPs x 1 cube x 4 PEs = 8 shards.
    # sip="column_wise" leaves K=128 per SIP; pe="column_wise" leaves 32 cols per PE.
    tensor_bytes = 128 * 256 * 2

    # For non-replicate levels the shard bytes must add up to the tensor bytes
    # (replicate within cube means cube shards overlap, but sip shards don't).
    bytes_per_sip = [0, 0]
    for shard in all_shards:
        sip_slot = 0 if shard.pe_index < 4 else 1
        bytes_per_sip[sip_slot] += shard.nbytes

    assert bytes_per_sip[0] + bytes_per_sip[1] == tensor_bytes
|
||||
|
||||
|
||||
# ── SP5. sip="replicate" backward compat ─────────────────────────────
|
||||
|
||||
|
||||
def test_sip_replicate_backward_compat():
    """Explicit sip='replicate' resolves to the same shards as omitting ``sip``."""
    resolve_args = dict(
        shape=(128, 256),
        itemsize=2,
        num_pe=8,
        num_cubes=2,
        num_sips=2,
    )

    # Legacy 2-level spelling first, then the explicit 3-level spelling.
    legacy = resolve_dp_policy(
        DPPolicy(cube="replicate", pe="column_wise"), **resolve_args
    )
    explicit = resolve_dp_policy(
        DPPolicy(sip="replicate", cube="replicate", pe="column_wise"), **resolve_args
    )

    assert len(legacy) == len(explicit)
    for old_shard, new_shard in zip(legacy, explicit):
        # Compare field by field, in the same order as the placement contract.
        for field in ("pe_index", "offset_bytes", "nbytes"):
            assert getattr(old_shard, field) == getattr(new_shard, field)
|
||||
|
||||
|
||||
# ── SP6. PE_CPU num_programs ──────────────────────────────────────────
|
||||
|
||||
|
||||
def test_pe_cpu_sets_num_programs():
    """TLContext reports the pe_id and num_programs it was constructed with."""
    # Interface-contract check only: PE_CPU is expected to derive num_programs
    # from the number of PE shards in the launch's target cube, but this test
    # exercises TLContext directly rather than PE_CPU.
    from kernbench.triton_emu.tl_context import TLContext

    pes_per_cube = 8
    context = TLContext(pe_id=3, num_programs=pes_per_cube)

    assert context.program_id(0) == 3
    assert context.num_programs(0) == 8
|
||||
@@ -88,7 +88,6 @@ def test_deploy_tensor_assigns_va_base():
|
||||
"""deploy_tensor with VA allocator assigns va_base to TensorHandle."""
|
||||
allocs = _make_allocators()
|
||||
va_alloc = _make_va_allocator()
|
||||
mmus = _make_mmus()
|
||||
placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8)
|
||||
|
||||
th = deploy_tensor(
|
||||
@@ -98,7 +97,6 @@ def test_deploy_tensor_assigns_va_base():
|
||||
placement=placement,
|
||||
allocators=allocs,
|
||||
va_allocator=va_alloc,
|
||||
mmus=mmus,
|
||||
)
|
||||
|
||||
assert th.va_base is not None
|
||||
@@ -109,7 +107,6 @@ def test_deploy_tensor_va_covers_all_shards():
|
||||
"""VA allocation covers the entire tensor; each shard is at va_base + offset."""
|
||||
allocs = _make_allocators()
|
||||
va_alloc = _make_va_allocator()
|
||||
mmus = _make_mmus()
|
||||
placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8)
|
||||
|
||||
th = deploy_tensor(
|
||||
@@ -119,41 +116,32 @@ def test_deploy_tensor_va_covers_all_shards():
|
||||
placement=placement,
|
||||
allocators=allocs,
|
||||
va_allocator=va_alloc,
|
||||
mmus=mmus,
|
||||
)
|
||||
|
||||
# Each shard's VA is derivable: va_base + offset_bytes
|
||||
for s in th.shards:
|
||||
shard_va = th.va_base + s.offset_bytes
|
||||
assert shard_va > 0
|
||||
|
||||
|
||||
def test_deploy_tensor_registers_mmu_mappings():
|
||||
"""deploy_tensor registers VA→PA mappings in all PE MMUs."""
|
||||
def test_deploy_tensor_does_not_install_mmu_mappings():
|
||||
"""deploy_tensor does NOT install MMU mappings — that's context's job."""
|
||||
allocs = _make_allocators()
|
||||
va_alloc = _make_va_allocator()
|
||||
mmus = _make_mmus()
|
||||
placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8)
|
||||
|
||||
th = deploy_tensor(
|
||||
deploy_tensor(
|
||||
name="W",
|
||||
shape=(1024, 512),
|
||||
dtype="fp16",
|
||||
placement=placement,
|
||||
allocators=allocs,
|
||||
va_allocator=va_alloc,
|
||||
mmus=mmus,
|
||||
)
|
||||
|
||||
# Every MMU should have entries (broadcast)
|
||||
# No MMU should have any entries (mappings come from fabric MmuMapMsg)
|
||||
for mmu in mmus.values():
|
||||
assert mmu.num_entries > 0
|
||||
|
||||
# Each shard's derived VA should translate to its PA in every MMU
|
||||
for mmu in mmus.values():
|
||||
for s in th.shards:
|
||||
shard_va = th.va_base + s.offset_bytes
|
||||
assert mmu.translate(shard_va) == s.pa
|
||||
assert mmu.num_entries == 0
|
||||
|
||||
|
||||
# ── T12. Tensor.va property ──────────────────────────────────────────
|
||||
@@ -165,7 +153,6 @@ def test_tensor_va_property():
|
||||
|
||||
allocs = _make_allocators(1)
|
||||
va_alloc = _make_va_allocator()
|
||||
mmus = _make_mmus(1)
|
||||
placement = [ShardSpec(pe_index=0, offset_bytes=0, nbytes=4096)]
|
||||
|
||||
t = Tensor(shape=(2048,), dtype="f16", name="test")
|
||||
@@ -176,7 +163,6 @@ def test_tensor_va_property():
|
||||
placement=placement,
|
||||
allocators=allocs,
|
||||
va_allocator=va_alloc,
|
||||
mmus=mmus,
|
||||
)
|
||||
assert t.va > 0
|
||||
assert t.va == t._handle.va_base
|
||||
|
||||
@@ -0,0 +1,216 @@
|
||||
"""VA offset verification: each PE accesses its own local HBM slice.
|
||||
|
||||
Verifies that column-wise sharding + VA offset calculation produces DMA
|
||||
addresses that translate to the correct PE's local HBM — not a remote PE.
|
||||
|
||||
Tests:
|
||||
VO1. Per-PE DMA addresses are correct VAs (2D)
|
||||
VO2. Each VA translates to the executing PE's own HBM slice (2D)
|
||||
VO3. End-to-end bench completes (2D, full TP)
|
||||
VO4. Per-PE DMA addresses are correct VAs (1D)
|
||||
VO5. Each VA translates to local HBM (1D)
|
||||
VO6. End-to-end 1D bench completes
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from kernbench.common.pe_commands import DmaReadCmd, DmaWriteCmd
|
||||
from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator
|
||||
from kernbench.policy.address.pe_mmu import PeMMU
|
||||
from kernbench.policy.address.phyaddr import PhysAddr
|
||||
from kernbench.policy.address.va_allocator import VirtualAllocator
|
||||
from kernbench.policy.placement.dp import DPPolicy, column_wise
|
||||
from kernbench.runtime_api.tensor import deploy_tensor
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
from kernbench.runtime_api.context import RuntimeContext
|
||||
from kernbench.runtime_api.types import DeviceSelector
|
||||
from kernbench.topology.builder import load_topology
|
||||
from kernbench.triton_emu.tl_context import TLContext, run_kernel
|
||||
|
||||
# Topology file loaded by the end-to-end tests; resolved relative to this file.
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"

# Binary size units used when building the standalone address config.
_MB = 2 ** 20
_GB = 2 ** 30

# Shape and element type of the 2D test tensor.
M = 128
K = 256
DTYPE = "f16"

# Number of PEs the standalone tests shard across, and bytes per element.
NUM_PE = 8
ELEM_BYTES = 2
|
||||
|
||||
|
||||
def _copy_kernel_2d(src_ptr, dst_ptr, M, K, tl, DTYPE="f16"):
    """Copy one (M, K // num_programs) block from src to dst for this program.

    The block is selected by ``tl.program_id(0)``; the same byte offset is
    applied to both pointers. The element size is fixed at 2 bytes regardless
    of ``DTYPE``. M and K are expected to be cube-local (per the original
    author's note).
    """
    program = tl.program_id(0)
    block_cols = K // tl.num_programs(0)
    bytes_per_elem = 2  # f16 element width; not derived from DTYPE
    byte_offset = program * M * block_cols * bytes_per_elem

    tile = tl.load(src_ptr + byte_offset, shape=(M, block_cols), dtype=DTYPE)
    tl.store(dst_ptr + byte_offset, tile)
|
||||
|
||||
|
||||
def _copy_kernel_1d(src_ptr, dst_ptr, N, tl, DTYPE="f16"):
    """Copy one contiguous run of N // num_programs elements from src to dst.

    The run is selected by ``tl.program_id(0)``; the same byte offset is
    applied to both pointers. The element size is fixed at 2 bytes regardless
    of ``DTYPE``. N is expected to be cube-local (per the original author's note).
    """
    program = tl.program_id(0)
    run_length = N // tl.num_programs(0)
    bytes_per_elem = 2  # f16 element width; not derived from DTYPE
    byte_offset = program * run_length * bytes_per_elem

    chunk = tl.load(src_ptr + byte_offset, shape=(run_length,), dtype=DTYPE)
    tl.store(dst_ptr + byte_offset, chunk)
|
||||
|
||||
|
||||
def _make_standalone(shape, num_pe=NUM_PE):
    """Build an address config, per-PE allocators, a VA allocator and per-PE MMUs.

    Everything is created for a single SIP with a single cube of ``num_pe`` PEs,
    without any engine or topology. ``shape`` is accepted but not used here.

    Returns:
        ``(cfg, allocators, va_alloc, mmus)`` where ``allocators`` and ``mmus``
        are dicts keyed by PE id ``0..num_pe-1``.
    """
    page_bytes = 4096

    cfg = AddressConfig(
        sip_count=1,
        cubes_per_sip=1,
        pes_per_cube=num_pe,
        hbm_bytes_per_cube=48 * _GB,
        hbm_slices_per_cube=num_pe,  # one HBM slice per PE
        tcm_bytes_per_pe=16 * _MB,
        tcm_scheduler_reserved_bytes=4 * _MB,
        sram_bytes_per_cube=32 * _MB,
    )

    allocators = {}
    for pe in range(num_pe):
        allocators[pe] = PEMemAllocator(
            rack_id=0, sip_id=0, cube_id=0, pe_id=pe, cfg=cfg,
        )

    va_alloc = VirtualAllocator(
        va_base=0x1_0000_0000, va_size=64 * _GB, page_size=page_bytes,
    )

    mmus = {}
    for pe in range(num_pe):
        mmus[pe] = PeMMU(page_size=page_bytes)

    return cfg, allocators, va_alloc, mmus
|
||||
|
||||
|
||||
# ── VO1. 2D: Per-PE DMA addresses are correct VAs ────────────────────
|
||||
|
||||
|
||||
def test_2d_each_pe_computes_correct_va_offset():
    """2D: every PE's first DMA read/write sits at base VA + pid * block bytes."""
    src_base = 0x1_0000_0000
    dst_base = 0x2_0000_0000
    bytes_per_block = M * (K // NUM_PE) * ELEM_BYTES

    for pid in range(NUM_PE):
        ctx = TLContext(pe_id=pid, num_programs=NUM_PE, dispatch_cycles=0)
        run_kernel(_copy_kernel_2d, ctx, src_ptr=src_base, dst_ptr=dst_base, M=M, K=K)

        # Split the recorded commands into DMA reads and DMA writes.
        dma_reads = [cmd for cmd in ctx.commands if isinstance(cmd, DmaReadCmd)]
        dma_writes = [cmd for cmd in ctx.commands if isinstance(cmd, DmaWriteCmd)]

        shift = pid * bytes_per_block
        assert dma_reads[0].src_addr == src_base + shift
        assert dma_writes[0].dst_addr == dst_base + shift
|
||||
|
||||
|
||||
# ── VO2. 2D: Each VA translates to local HBM ─────────────────────────
|
||||
|
||||
|
||||
def test_2d_va_translates_to_local_hbm():
    """2D: each PE's block VA, translated by its own MMU, lands in its own HBM slice."""
    cfg, allocators, va_alloc, mmus = _make_standalone((M, K))
    hbm_slice_bytes = cfg.hbm_slice_bytes
    bytes_per_block = M * (K // NUM_PE) * ELEM_BYTES

    handle = deploy_tensor(
        name="src",
        shape=(M, K),
        dtype="fp16",
        placement=column_wise(shape=(M, K), itemsize=ELEM_BYTES, num_pe=NUM_PE),
        allocators=allocators,
        va_allocator=va_alloc,
    )

    # Install per-PE mappings by hand (simulating what context does via MmuMapMsg).
    for shard in handle.shards:
        mmus[shard.pe].map(
            va=handle.va_base + shard.offset_bytes,
            pa=shard.pa,
            size=shard.nbytes,
        )

    for pid in range(NUM_PE):
        block_va = handle.va_base + pid * bytes_per_block
        phys = PhysAddr.decode(mmus[pid].translate(block_va))
        hbm_pe = PhysAddr.hbm_pe_id(phys.hbm_offset, hbm_slice_bytes)
        assert hbm_pe == pid, f"PE{pid} accessed PE{hbm_pe}'s HBM"
|
||||
|
||||
|
||||
# ── VO3. 2D: End-to-end bench completes ──────────────────────────────
|
||||
|
||||
|
||||
def test_2d_bench_completes():
    """2D: the va_offset_verify bench runs to completion on the loaded topology."""
    topology = load_topology(TOPOLOGY_PATH)
    runtime = RuntimeContext(
        engine=GraphEngine(topology),
        target_device=DeviceSelector("sip:0"),
        correlation_id="vo3",
        spec=topology.spec,
    )

    # Imported here, as in the original, so the bench module loads only for this test.
    from benches.va_offset_verify import run as bench_run

    bench_run(runtime)
    runtime.wait_all()
|
||||
|
||||
|
||||
# ── VO4. 1D: Per-PE DMA addresses ────────────────────────────────────
|
||||
|
||||
# Element count of the 1D tensor used by the 1D tests below.
N_1D = 1024
|
||||
|
||||
|
||||
def test_1d_each_pe_computes_correct_offset():
    """1D: every PE's first DMA read/write sits at base VA + pid * block bytes."""
    src_base = 0x1_0000_0000
    dst_base = 0x2_0000_0000
    bytes_per_block = (N_1D // NUM_PE) * ELEM_BYTES

    for pid in range(NUM_PE):
        ctx = TLContext(pe_id=pid, num_programs=NUM_PE, dispatch_cycles=0)
        run_kernel(_copy_kernel_1d, ctx, src_ptr=src_base, dst_ptr=dst_base, N=N_1D)

        # Split the recorded commands into DMA reads and DMA writes.
        dma_reads = [cmd for cmd in ctx.commands if isinstance(cmd, DmaReadCmd)]
        dma_writes = [cmd for cmd in ctx.commands if isinstance(cmd, DmaWriteCmd)]

        shift = pid * bytes_per_block
        assert dma_reads[0].src_addr == src_base + shift
        assert dma_writes[0].dst_addr == dst_base + shift
|
||||
|
||||
|
||||
# ── VO5. 1D: VA translates to local HBM ──────────────────────────────
|
||||
|
||||
|
||||
def test_1d_va_translates_to_local_hbm():
    """1D: each PE's block VA, translated by its own MMU, lands in its own HBM slice."""
    cfg, allocators, va_alloc, mmus = _make_standalone((1, N_1D))
    hbm_slice_bytes = cfg.hbm_slice_bytes
    bytes_per_block = (N_1D // NUM_PE) * ELEM_BYTES

    # Placement is computed on a (1, N) view, while the tensor itself is deployed as (N,).
    handle = deploy_tensor(
        name="src_1d",
        shape=(N_1D,),
        dtype="fp16",
        placement=column_wise(shape=(1, N_1D), itemsize=ELEM_BYTES, num_pe=NUM_PE),
        allocators=allocators,
        va_allocator=va_alloc,
    )

    # Install per-PE mappings by hand, one per shard.
    for shard in handle.shards:
        mmus[shard.pe].map(
            va=handle.va_base + shard.offset_bytes,
            pa=shard.pa,
            size=shard.nbytes,
        )

    for pid in range(NUM_PE):
        block_va = handle.va_base + pid * bytes_per_block
        phys = PhysAddr.decode(mmus[pid].translate(block_va))
        hbm_pe = PhysAddr.hbm_pe_id(phys.hbm_offset, hbm_slice_bytes)
        assert hbm_pe == pid, f"1D PE{pid} accessed PE{hbm_pe}'s HBM"
|
||||
|
||||
|
||||
# ── VO6. 1D: End-to-end ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_1d_e2e_completes():
    """1D: a fully column_wise-sharded copy kernel runs to completion on the engine."""
    topology = load_topology(TOPOLOGY_PATH)
    runtime = RuntimeContext(
        engine=GraphEngine(topology),
        target_device=DeviceSelector("sip:0"),
        correlation_id="vo6",
        spec=topology.spec,
    )

    policy = DPPolicy(sip="column_wise", cube="column_wise", pe="column_wise")
    source = runtime.zeros((N_1D,), dtype=DTYPE, dp=policy, name="src_1d")
    target = runtime.empty((N_1D,), dtype=DTYPE, dp=policy, name="dst_1d")

    # launch() auto-localizes N_1D → cube-local N
    runtime.launch("va_1d_copy", _copy_kernel_1d, source, target, N_1D)
    runtime.wait_all()
|
||||
Reference in New Issue
Block a user