From 63669f82cb8645d9b59c90ff041686972c63ed68 Mon Sep 17 00:00:00 2001 From: Yangwook Kang Date: Thu, 26 Mar 2026 01:13:17 -0700 Subject: [PATCH] Add SIP-level tensor parallelism, component registry YAML, VA offset verification MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - DPPolicy: 3-level (sip/cube/pe), unified naming (column_wise/row_wise) - PE_CPU: auto num_programs from cube shard count - context.launch(): per-SIP KernelLaunchMsg with local va_base + auto local shape - deploy_tensor: removed mmus param, MMU mapping is context-only responsibility - ComponentRegistry: YAML-based lazy loading (components.yaml), impls→builtin rename - VA offset bench + tests: 2D/1D, standard Triton kernel pattern Co-Authored-By: Claude Opus 4.6 (1M context) --- benches/va_offset_verify.py | 42 ++++ components.yaml | 53 ++++ ...R-0011-memory-addressing-simplification.md | 2 +- docs/di-presentation.md | 6 +- src/kernbench/components/base.py | 58 ++++- src/kernbench/components/builtin/__init__.py | 34 +++ .../{impls => builtin}/forwarding.py | 0 .../components/{impls => builtin}/hbm_ctrl.py | 0 .../components/{impls => builtin}/io_cpu.py | 0 .../components/{impls => builtin}/m_cpu.py | 0 .../components/{impls => builtin}/noc.py | 0 .../components/{impls => builtin}/pcie_ep.py | 0 .../components/{impls => builtin}/pe_cpu.py | 16 +- .../components/{impls => builtin}/pe_dma.py | 0 .../components/{impls => builtin}/pe_gemm.py | 0 .../components/{impls => builtin}/pe_math.py | 0 .../components/{impls => builtin}/pe_mmu.py | 0 .../{impls => builtin}/pe_scheduler.py | 0 .../components/{impls => builtin}/pe_tcm.py | 0 .../components/{impls => builtin}/sram.py | 0 .../components/{impls => builtin}/xbar.py | 0 src/kernbench/components/custom/__init__.py | 5 + src/kernbench/components/impls/__init__.py | 59 ----- src/kernbench/policy/placement/dp.py | 88 ++++--- src/kernbench/runtime_api/context.py | 230 ++++++++++++------ src/kernbench/runtime_api/tensor.py | 12 +- src/kernbench/sim_engine/engine.py | 2 +- tests/custom/__init__.py | 0 tests/test_component_registry.py | 2 +- tests/test_mmu_component.py | 6 +- tests/test_pe_components.py | 18 +- tests/test_phase_a_components.py | 2 +- tests/test_sip_parallel.py | 157 ++++++++++++ tests/test_va_integration.py | 24 +- tests/test_va_offset.py | 216 ++++++++++++++++ 35 files changed, 813 insertions(+), 219 deletions(-) create mode 100644 benches/va_offset_verify.py create mode 100644 components.yaml create mode 100644 src/kernbench/components/builtin/__init__.py rename src/kernbench/components/{impls => builtin}/forwarding.py (100%) rename src/kernbench/components/{impls => builtin}/hbm_ctrl.py (100%) rename src/kernbench/components/{impls => builtin}/io_cpu.py (100%) rename src/kernbench/components/{impls => builtin}/m_cpu.py (100%) rename src/kernbench/components/{impls => builtin}/noc.py (100%) rename src/kernbench/components/{impls => builtin}/pcie_ep.py (100%) rename src/kernbench/components/{impls => builtin}/pe_cpu.py (91%) rename src/kernbench/components/{impls => builtin}/pe_dma.py (100%) rename src/kernbench/components/{impls => builtin}/pe_gemm.py (100%) rename src/kernbench/components/{impls => builtin}/pe_math.py (100%) rename src/kernbench/components/{impls => builtin}/pe_mmu.py (100%) rename src/kernbench/components/{impls => builtin}/pe_scheduler.py (100%) rename src/kernbench/components/{impls => builtin}/pe_tcm.py (100%) rename src/kernbench/components/{impls => builtin}/sram.py (100%) rename src/kernbench/components/{impls => builtin}/xbar.py (100%) create mode 100644 src/kernbench/components/custom/__init__.py delete mode 100644 src/kernbench/components/impls/__init__.py create mode 100644 tests/custom/__init__.py create mode 100644 tests/test_sip_parallel.py create mode 100644 tests/test_va_offset.py diff --git a/benches/va_offset_verify.py b/benches/va_offset_verify.py new file mode 100644 index 0000000..578ae30 --- /dev/null +++ b/benches/va_offset_verify.py @@ -0,0 +1,42 @@ +"""VA offset verification benchmark. + +Verifies that Triton-style base_ptr + pid * stride addressing works correctly +with full TP sharding (sip/cube/pe all column_wise). Each PE loads its own +block from a sharded tensor and stores it back. + +The kernel uses standard Triton patterns: + - tl.program_id(0) for PE index within cube + - tl.num_programs(0) for PE count within cube + - Shape args are automatically localized by launch() +""" +from kernbench.policy.placement.dp import DPPolicy + +M, K = 128, 256 +DTYPE = "f16" + + +def _copy_kernel(src_ptr, dst_ptr, M, K, tl, DTYPE="f16"): + """Standard Triton copy kernel. M and K are cube-local (set by launch).""" + pid = tl.program_id(0) + num_pe = tl.num_programs(0) + cols_per_pe = K // num_pe + elem_bytes = 2 # f16 + offset = pid * M * cols_per_pe * elem_bytes + data = tl.load(src_ptr + offset, shape=(M, cols_per_pe), dtype=DTYPE) + tl.store(dst_ptr + offset, data) + + +def run(torch): + """Run the VA offset verification benchmark with full TP sharding.""" + dp = DPPolicy(sip="column_wise", cube="column_wise", pe="column_wise") + src = torch.zeros((M, K), dtype=DTYPE, dp=dp, name="src") + dst = torch.empty((M, K), dtype=DTYPE, dp=dp, name="dst") + + # launch() automatically converts M, K to cube-local values + torch.launch("va_offset_copy", _copy_kernel, src, dst, M, K) + + # Sanity check: kernel completed with non-zero latency + kernel_traces = [t for t in torch._traces if t["phase"] == "kernel"] + assert len(kernel_traces) > 0, "No kernel traces recorded" + for kt in kernel_traces: + assert kt["total_ns"] > 0, f"Kernel latency is zero for {kt}" diff --git a/components.yaml b/components.yaml new file mode 100644 index 0000000..ee459d8 --- /dev/null +++ b/components.yaml @@ -0,0 +1,53 @@ +# Component implementation registry. +# Maps impl names (used in topology.yaml) to Python class paths. +# Format: impl_name: module.path:ClassName +# +# ── Adding custom components ────────────────────────────────────────── +# +# 1. Create your implementation in: +# src/kernbench/components/custom/.py +# +# Your class must inherit from ComponentBase (or PeEngineBase for PE engines). +# +# 2. Register it below under "Custom" with a unique impl name: +# my_pe_cpu_v2: kernbench.components.custom.my_pe_cpu:MyPeCpuComponent +# +# 3. Reference it in topology.yaml: +# pe_cpu: { kind: pe_cpu, impl: my_pe_cpu_v2, attrs: { ... } } +# +# 4. Add unit tests in: +# tests/custom/test_.py +# +# External packages also work — use the full module path: +# fast_gemm_v1: my_team.accel.fast_gemm:FastGemmComponent +# ────────────────────────────────────────────────────────────────────── + +components: + # Infrastructure + forwarding_v1: kernbench.components.builtin.forwarding:TransitComponent + switch_v1: kernbench.components.builtin.forwarding:TransitComponent + noc_v1: kernbench.components.builtin.forwarding:TransitComponent + ucie_v1: kernbench.components.builtin.forwarding:TransitComponent + noc_2d_mesh_v1: kernbench.components.builtin.noc:TwoDMeshNocComponent + xbar_v1: kernbench.components.builtin.xbar:PositionAwareXbarComponent + + # IO / Host interface + pcie_ep_v1: kernbench.components.builtin.pcie_ep:PcieEpComponent + io_cpu_v1: kernbench.components.builtin.io_cpu:IoCpuComponent + + # Cube-level + m_cpu_v1: kernbench.components.builtin.m_cpu:MCpuComponent + hbm_ctrl_v1: kernbench.components.builtin.hbm_ctrl:HbmCtrlComponent + sram_v1: kernbench.components.builtin.sram:SramComponent + + # PE-level + pe_cpu_v1: kernbench.components.builtin.pe_cpu:PeCpuComponent + pe_scheduler_v1: kernbench.components.builtin.pe_scheduler:PeSchedulerComponent + pe_dma_v1: kernbench.components.builtin.pe_dma:PeDmaComponent + pe_gemm_v1: kernbench.components.builtin.pe_gemm:PeGemmComponent + pe_math_v1: kernbench.components.builtin.pe_math:PeMathComponent + pe_mmu_v1: kernbench.components.builtin.pe_mmu:PeMmuComponent + pe_tcm_v1: kernbench.components.builtin.pe_tcm:PeTcmComponent + + # Custom — add your implementations here + # pe_cpu_v2: kernbench.components.custom.my_pe_cpu:MyPeCpuComponent diff --git a/docs/adr/ADR-0011-memory-addressing-simplification.md b/docs/adr/ADR-0011-memory-addressing-simplification.md index 96a4c97..0ef9330 100644 --- a/docs/adr/ADR-0011-memory-addressing-simplification.md +++ b/docs/adr/ADR-0011-memory-addressing-simplification.md @@ -58,7 +58,7 @@ sufficient to execute kernels and issue DMA requests. - Mapping strategy based on `DPPolicy.cube`: - **Replicate** (`cube="replicate"`): per-(sip, cube) local mapping only. Each cube's PEs see only their local PA. No cross-cube mapping installed. - - **Sharded** (`cube="shard_m"`, etc.): broadcast all shard mappings to all + - **Sharded** (`cube="column_wise"`, etc.): broadcast all shard mappings to all target cubes. Enables cross-PE and cross-cube DMA. #### D3.4 Tensor Lifecycle diff --git a/docs/di-presentation.md b/docs/di-presentation.md index 5f64572..3f1145a 100644 --- a/docs/di-presentation.md +++ b/docs/di-presentation.md @@ -163,11 +163,11 @@ DefaultComponent ← 안전한 fallback ## 슬라이드 7 — Registry 등록 방식 ```python -# kernbench/components/impls/__init__.py +# kernbench/components/builtin/__init__.py from kernbench.components.base import ComponentRegistry -from kernbench.components.impls.noc import TwoDMeshNocComponent -from kernbench.components.impls.io_cpu import IoCpuComponent +from kernbench.components.builtin.noc import TwoDMeshNocComponent +from kernbench.components.builtin.io_cpu import IoCpuComponent # ... ComponentRegistry.register("noc_2d_mesh_v1", TwoDMeshNocComponent) diff --git a/src/kernbench/components/base.py b/src/kernbench/components/base.py index 5d633d8..58ec12c 100644 --- a/src/kernbench/components/base.py +++ b/src/kernbench/components/base.py @@ -140,16 +140,65 @@ class ComponentRegistry: Resolution order for ComponentRegistry.create(node, overrides, ctx): 1. overrides[node.impl] — caller-injected override - 2. _registry[node.impl] — globally registered impl + 2. _registry[node.impl] — globally registered impl (lazy import) 3. Error — no fallback; every node must have an impl + + Registry is populated from components.yaml via load_components_yaml(). + Manual register() is still supported for tests and overrides. """ _registry: dict[str, type[ComponentBase]] = {} + _lazy: dict[str, str] = {} # impl → "module.path:ClassName" + _loaded: bool = False @classmethod def register(cls, impl: str, component_cls: type[ComponentBase]) -> None: cls._registry[impl] = component_cls + @classmethod + def load_components_yaml(cls, path: str | None = None) -> None: + """Load impl→class mappings from components.yaml. Lazy imports on first use.""" + if cls._loaded: + return + import yaml + from pathlib import Path + + if path is None: + # Search: project root (cwd), then relative to this file + candidates = [ + Path.cwd() / "components.yaml", + Path(__file__).parent.parent.parent.parent / "components.yaml", + ] + for p in candidates: + if p.exists(): + path = str(p) + break + if path is None: + return + + with open(path) as f: + spec = yaml.safe_load(f) + for impl, class_path in (spec.get("components") or {}).items(): + cls._lazy[impl] = class_path + cls._loaded = True + + @classmethod + def _resolve(cls, impl: str) -> type[ComponentBase] | None: + """Resolve impl name: check _registry first, then lazy import from _lazy.""" + if impl in cls._registry: + return cls._registry[impl] + if not cls._loaded: + cls.load_components_yaml() + class_path = cls._lazy.get(impl) + if class_path is None: + return None + import importlib + module_path, class_name = class_path.rsplit(":", 1) + mod = importlib.import_module(module_path) + component_cls = getattr(mod, class_name) + cls._registry[impl] = component_cls # cache for next lookup + return component_cls + @classmethod def create( cls, @@ -159,9 +208,10 @@ class ComponentRegistry: ) -> ComponentBase: if overrides and node.impl in overrides: return overrides[node.impl](node, ctx) - if node.impl in cls._registry: - return cls._registry[node.impl](node, ctx) + component_cls = cls._resolve(node.impl) + if component_cls is not None: + return component_cls(node, ctx) raise ValueError( f"No component registered for impl '{node.impl}' (node: {node.id}). " - f"Register it in kernbench.components.impls.__init__." + f"Add it to components.yaml or call ComponentRegistry.register()." ) diff --git a/src/kernbench/components/builtin/__init__.py b/src/kernbench/components/builtin/__init__.py new file mode 100644 index 0000000..9e2e26b --- /dev/null +++ b/src/kernbench/components/builtin/__init__.py @@ -0,0 +1,34 @@ +"""Concrete component implementations. + +Loaded from components.yaml via ComponentRegistry.load_components_yaml(). +Manual imports are no longer needed — add new impls to components.yaml. + +Classes are still importable from this package via lazy __getattr__. +""" + +from kernbench.components.base import ComponentRegistry + +ComponentRegistry.load_components_yaml() + +# Lazy re-export: allow `from kernbench.components.builtin import FooComponent` +# without eagerly importing every module. +_CLASS_MAP: dict[str, str] = {} # ClassName → "module.path:ClassName" + + +def _build_class_map() -> None: + if _CLASS_MAP: + return + for class_path in ComponentRegistry._lazy.values(): + module_path, class_name = class_path.rsplit(":", 1) + _CLASS_MAP[class_name] = class_path + + +def __getattr__(name: str): + _build_class_map() + class_path = _CLASS_MAP.get(name) + if class_path is None: + raise ImportError(f"cannot import name '{name}' from 'kernbench.components.builtin'") + import importlib + module_path, class_name = class_path.rsplit(":", 1) + mod = importlib.import_module(module_path) + return getattr(mod, class_name) diff --git a/src/kernbench/components/impls/forwarding.py b/src/kernbench/components/builtin/forwarding.py similarity index 100% rename from src/kernbench/components/impls/forwarding.py rename to src/kernbench/components/builtin/forwarding.py diff --git a/src/kernbench/components/impls/hbm_ctrl.py b/src/kernbench/components/builtin/hbm_ctrl.py similarity index 100% rename from src/kernbench/components/impls/hbm_ctrl.py rename to src/kernbench/components/builtin/hbm_ctrl.py diff --git a/src/kernbench/components/impls/io_cpu.py b/src/kernbench/components/builtin/io_cpu.py similarity index 100% rename from src/kernbench/components/impls/io_cpu.py rename to src/kernbench/components/builtin/io_cpu.py diff --git a/src/kernbench/components/impls/m_cpu.py b/src/kernbench/components/builtin/m_cpu.py similarity index 100% rename from src/kernbench/components/impls/m_cpu.py rename to src/kernbench/components/builtin/m_cpu.py diff --git a/src/kernbench/components/impls/noc.py b/src/kernbench/components/builtin/noc.py similarity index 100% rename from src/kernbench/components/impls/noc.py rename to src/kernbench/components/builtin/noc.py diff --git a/src/kernbench/components/impls/pcie_ep.py b/src/kernbench/components/builtin/pcie_ep.py similarity index 100% rename from src/kernbench/components/impls/pcie_ep.py rename to src/kernbench/components/builtin/pcie_ep.py diff --git a/src/kernbench/components/impls/pe_cpu.py b/src/kernbench/components/builtin/pe_cpu.py similarity index 91% rename from src/kernbench/components/impls/pe_cpu.py rename to src/kernbench/components/builtin/pe_cpu.py index 34fcf8e..f2e3c7b 100644 --- a/src/kernbench/components/impls/pe_cpu.py +++ b/src/kernbench/components/builtin/pe_cpu.py @@ -81,10 +81,22 @@ class PeCpuComponent(ComponentBase): yield from self.run(env, 0) kernel_fn = get_kernel(request.kernel_ref.name) - tl = TLContext(pe_id=self._pe_idx, dispatch_cycles=0) + + # Derive num_programs from the number of PE shards in this cube + num_programs = 1 + for arg in request.args: + if arg.arg_kind == "tensor": + cube_pe_count = sum( + 1 for s in arg.shards + if s.sip == self._sip_idx and s.cube == self._cube_idx + ) + if cube_pe_count > num_programs: + num_programs = cube_pe_count + + tl = TLContext(pe_id=self._pe_idx, num_programs=num_programs, dispatch_cycles=0) # Unpack KernelLaunchMsg.args into positional args for kernel function - # TensorArg → VA base (or PA fallback), ScalarArg → value + # TensorArg → va_base (already local, set by runtime) or PA fallback kernel_args: list = [] for arg in request.args: if arg.arg_kind == "tensor": diff --git a/src/kernbench/components/impls/pe_dma.py b/src/kernbench/components/builtin/pe_dma.py similarity index 100% rename from src/kernbench/components/impls/pe_dma.py rename to src/kernbench/components/builtin/pe_dma.py diff --git a/src/kernbench/components/impls/pe_gemm.py b/src/kernbench/components/builtin/pe_gemm.py similarity index 100% rename from src/kernbench/components/impls/pe_gemm.py rename to src/kernbench/components/builtin/pe_gemm.py diff --git a/src/kernbench/components/impls/pe_math.py b/src/kernbench/components/builtin/pe_math.py similarity index 100% rename from src/kernbench/components/impls/pe_math.py rename to src/kernbench/components/builtin/pe_math.py diff --git a/src/kernbench/components/impls/pe_mmu.py b/src/kernbench/components/builtin/pe_mmu.py similarity index 100% rename from src/kernbench/components/impls/pe_mmu.py rename to src/kernbench/components/builtin/pe_mmu.py diff --git a/src/kernbench/components/impls/pe_scheduler.py b/src/kernbench/components/builtin/pe_scheduler.py similarity index 100% rename from src/kernbench/components/impls/pe_scheduler.py rename to src/kernbench/components/builtin/pe_scheduler.py diff --git a/src/kernbench/components/impls/pe_tcm.py b/src/kernbench/components/builtin/pe_tcm.py similarity index 100% rename from src/kernbench/components/impls/pe_tcm.py rename to src/kernbench/components/builtin/pe_tcm.py diff --git a/src/kernbench/components/impls/sram.py b/src/kernbench/components/builtin/sram.py similarity index 100% rename from src/kernbench/components/impls/sram.py rename to src/kernbench/components/builtin/sram.py diff --git a/src/kernbench/components/impls/xbar.py b/src/kernbench/components/builtin/xbar.py similarity index 100% rename from src/kernbench/components/impls/xbar.py rename to src/kernbench/components/builtin/xbar.py diff --git a/src/kernbench/components/custom/__init__.py b/src/kernbench/components/custom/__init__.py new file mode 100644 index 0000000..a7998ae --- /dev/null +++ b/src/kernbench/components/custom/__init__.py @@ -0,0 +1,5 @@ +"""Custom component implementations. + +Place your component files here and register them in components.yaml. +See components.yaml header for instructions. +""" diff --git a/src/kernbench/components/impls/__init__.py b/src/kernbench/components/impls/__init__.py deleted file mode 100644 index cd170ef..0000000 --- a/src/kernbench/components/impls/__init__.py +++ /dev/null @@ -1,59 +0,0 @@ -"""Concrete component implementations. - -Each module registers its component(s) with ComponentRegistry on import. -Import this package to activate all built-in implementations. -""" - -from kernbench.components.base import ComponentRegistry -from kernbench.components.impls.forwarding import TransitComponent -from kernbench.components.impls.hbm_ctrl import HbmCtrlComponent -from kernbench.components.impls.io_cpu import IoCpuComponent -from kernbench.components.impls.m_cpu import MCpuComponent -from kernbench.components.impls.noc import TwoDMeshNocComponent -from kernbench.components.impls.pcie_ep import PcieEpComponent -from kernbench.components.impls.pe_cpu import PeCpuComponent -from kernbench.components.impls.pe_dma import PeDmaComponent -from kernbench.components.impls.pe_gemm import PeGemmComponent -from kernbench.components.impls.pe_math import PeMathComponent -from kernbench.components.impls.pe_scheduler import PeSchedulerComponent -from kernbench.components.impls.pe_mmu import PeMmuComponent -from kernbench.components.impls.pe_tcm import PeTcmComponent -from kernbench.components.impls.sram import SramComponent -from kernbench.components.impls.xbar import PositionAwareXbarComponent - -ComponentRegistry.register("forwarding_v1", TransitComponent) -ComponentRegistry.register("switch_v1", TransitComponent) -ComponentRegistry.register("noc_v1", TransitComponent) -ComponentRegistry.register("noc_2d_mesh_v1", TwoDMeshNocComponent) -ComponentRegistry.register("ucie_v1", TransitComponent) -ComponentRegistry.register("xbar_v1", PositionAwareXbarComponent) -ComponentRegistry.register("pcie_ep_v1", PcieEpComponent) -ComponentRegistry.register("io_cpu_v1", IoCpuComponent) -ComponentRegistry.register("m_cpu_v1", MCpuComponent) -ComponentRegistry.register("hbm_ctrl_v1", HbmCtrlComponent) -ComponentRegistry.register("sram_v1", SramComponent) -ComponentRegistry.register("pe_cpu_v1", PeCpuComponent) -ComponentRegistry.register("pe_scheduler_v1", PeSchedulerComponent) -ComponentRegistry.register("pe_dma_v1", PeDmaComponent) -ComponentRegistry.register("pe_gemm_v1", PeGemmComponent) -ComponentRegistry.register("pe_math_v1", PeMathComponent) -ComponentRegistry.register("pe_mmu_v1", PeMmuComponent) -ComponentRegistry.register("pe_tcm_v1", PeTcmComponent) - -__all__ = [ - "HbmCtrlComponent", - "IoCpuComponent", - "MCpuComponent", - "PcieEpComponent", - "PeCpuComponent", - "PeDmaComponent", - "PeGemmComponent", - "PeMathComponent", - "PeMmuComponent", - "PeSchedulerComponent", - "PeTcmComponent", - "TransitComponent", - "TwoDMeshNocComponent", - "PositionAwareXbarComponent", - "SramComponent", -] diff --git a/src/kernbench/policy/placement/dp.py b/src/kernbench/policy/placement/dp.py index 8860d7f..705b791 100644 --- a/src/kernbench/policy/placement/dp.py +++ b/src/kernbench/policy/placement/dp.py @@ -7,12 +7,36 @@ from typing import Literal @dataclass(frozen=True) class DPPolicy: - """Two-level data-parallel policy: cube-level + pe-level.""" + """Three-level data-parallel policy: sip-level + cube-level + pe-level. - cube: Literal["replicate", "shard_m", "shard_k"] = "replicate" + Policies: + - "replicate": full copy at each unit + - "column_wise": split K (column) axis across units + - "row_wise": split M (row) axis across units + """ + + sip: Literal["replicate", "column_wise", "row_wise"] = "replicate" + cube: Literal["replicate", "column_wise", "row_wise"] = "replicate" pe: Literal["replicate", "column_wise", "row_wise"] = "replicate" +def _split_shape( + policy: str, shape: tuple[int, int], count: int, itemsize: int, +) -> list[tuple[tuple[int, int], int]]: + """Split shape by policy into (sub_shape, byte_offset) pairs.""" + M, K = shape + if policy == "replicate": + return [((M, K), 0)] * count + elif policy == "column_wise": + chunk_k = K // count + return [((M, chunk_k), i * M * chunk_k * itemsize) for i in range(count)] + elif policy == "row_wise": + chunk_m = M // count + return [((chunk_m, K), i * chunk_m * K * itemsize) for i in range(count)] + else: + raise ValueError(f"Unknown policy: {policy}") + + def resolve_dp_policy( policy: DPPolicy, *, @@ -20,11 +44,14 @@ def resolve_dp_policy( itemsize: int, num_pe: int, num_cubes: int = 1, + num_sips: int = 1, ) -> list[ShardSpec]: - """Resolve a DPPolicy into a list[ShardSpec] with two-level resolution. + """Resolve a DPPolicy into a list[ShardSpec] with three-level resolution. - Cube-level policy distributes across cubes, pe-level distributes within - each cube. ShardSpec.pe_index uses flat indexing: cube_id * num_pe + pe_id. + SIP-level → cube-level → pe-level. + num_cubes is cubes per SIP (not total). + ShardSpec.pe_index uses flat indexing: + sip_id * num_cubes * num_pe + cube_id * num_pe + pe_id """ _PE_RESOLVERS = { "replicate": replicate, @@ -35,40 +62,31 @@ def resolve_dp_policy( if resolver is None: raise ValueError(f"Unknown pe-level policy: {policy.pe}") - if num_cubes <= 1: - return resolver(shape=shape, itemsize=itemsize, num_pe=num_pe) - - # Two-level resolution: cube-level → pe-level - M, K = shape + cubes_per_sip = num_cubes all_shards: list[ShardSpec] = [] - for cube_id in range(num_cubes): - # Determine per-cube shape based on cube-level policy - if policy.cube == "replicate": - cube_shape = (M, K) - cube_offset = 0 - elif policy.cube == "shard_m": - chunk_m = M // num_cubes - cube_shape = (chunk_m, K) - cube_offset = cube_id * chunk_m * K * itemsize - elif policy.cube == "shard_k": - chunk_k = K // num_cubes - cube_shape = (M, chunk_k) - cube_offset = cube_id * M * chunk_k * itemsize - else: - raise ValueError(f"Unknown cube-level policy: {policy.cube}") + # Level 1: SIP + sip_splits = _split_shape(policy.sip, shape, num_sips, itemsize) - # Resolve pe-level within this cube's shape - pe_shards = resolver(shape=cube_shape, itemsize=itemsize, num_pe=num_pe) + for sip_id, (sip_shape, sip_offset) in enumerate(sip_splits): + # Level 2: Cube within SIP + cube_splits = _split_shape(policy.cube, sip_shape, cubes_per_sip, itemsize) - # Remap pe_index to flat index and adjust offset - for ps in pe_shards: - flat_idx = cube_id * num_pe + ps.pe_index - all_shards.append(ShardSpec( - pe_index=flat_idx, - offset_bytes=cube_offset + ps.offset_bytes, - nbytes=ps.nbytes, - )) + for cube_id, (cube_shape, cube_offset) in enumerate(cube_splits): + # Level 3: PE within cube + pe_shards = resolver(shape=cube_shape, itemsize=itemsize, num_pe=num_pe) + + for ps in pe_shards: + flat_idx = ( + sip_id * cubes_per_sip * num_pe + + cube_id * num_pe + + ps.pe_index + ) + all_shards.append(ShardSpec( + pe_index=flat_idx, + offset_bytes=sip_offset + cube_offset + ps.offset_bytes, + nbytes=ps.nbytes, + )) return all_shards diff --git a/src/kernbench/runtime_api/context.py b/src/kernbench/runtime_api/context.py index 021babe..7e94877 100644 --- a/src/kernbench/runtime_api/context.py +++ b/src/kernbench/runtime_api/context.py @@ -20,7 +20,6 @@ class RuntimeContext: _completed: set[RequestHandle] = field(default_factory=set, init=False) _allocators: dict[int, Any] = field(default_factory=dict, init=False) _va_allocator: Any = field(default=None, init=False) - _mmus: dict[int, Any] = field(default_factory=dict, init=False) _tensor_counter: int = field(default=0, init=False) _traces: list[dict] = field(default_factory=list, init=False) _tensors: list[Any] = field(default_factory=list, init=False) @@ -208,24 +207,17 @@ class RuntimeContext: rack_id=0, sip_id=sip_id, cube_id=cube_id, pe_id=pe_id, cfg=cfg, ) - # Initialize VA allocator and per-PE MMUs - from kernbench.policy.address.pe_mmu import PeMMU + # Initialize VA allocator (MMU mappings are installed via fabric MmuMapMsg) from kernbench.policy.address.va_allocator import VirtualAllocator pe_mmu_attrs = pe_comps.get("pe_mmu", {}).get("attrs", {}) page_size = int(pe_mmu_attrs.get("page_size", 4096)) - tlb_overhead_ns = float(pe_mmu_attrs.get("tlb_overhead_ns", 0.0)) self._va_allocator = VirtualAllocator( va_base=0x1_0000_0000, va_size=64 * (1 << 30), # 64 GB VA space page_size=page_size, ) - total_pes = sip_count * cubes_per_sip * pes_per_cube - for flat_idx in range(total_pes): - self._mmus[flat_idx] = PeMMU( - page_size=page_size, overhead_ns=tlb_overhead_ns, - ) return self._allocators @@ -276,11 +268,11 @@ class RuntimeContext: dp_policy = dp allocators = self._ensure_allocators() itemsize = dtype_itemsize(dtype) - shape_2d = (shape[0], shape[1]) # type: tuple[int, int] - total_cubes = self._num_sips * self._num_cubes + shape_2d = (shape[0], shape[1]) if len(shape) >= 2 else (1, shape[0]) placement = resolve_dp_policy( dp, shape=shape_2d, itemsize=itemsize, - num_pe=self._pes_per_cube, num_cubes=total_cubes, + num_pe=self._pes_per_cube, num_cubes=self._num_cubes, + num_sips=self._num_sips, ) # Infer target_pe from placement: multi-PE → "all", single PE → pe_index @@ -297,7 +289,6 @@ class RuntimeContext: placement=placement, allocators=allocators, va_allocator=self._va_allocator, - mmus=self._mmus, ) t._handle = handle import weakref @@ -305,6 +296,8 @@ class RuntimeContext: self._tensors.append(weakref.ref(t)) # Install VA→PA mappings via fabric MmuMapMsg + # Strategy: always SIP-scoped (each SIP gets only its own shards). + # Within each SIP: cube="replicate" → per-cube, else broadcast within SIP. if handle.va_base: from collections import defaultdict from kernbench.runtime_api.kernel import MmuMapMsg @@ -313,47 +306,52 @@ class RuntimeContext: dp_policy is not None and dp_policy.cube == "replicate" ) - if is_cube_replicate: - # Replicate: each (sip, cube) gets only its own local PA mappings - cube_groups: dict[tuple[int, int], list] = defaultdict(list) - for shard in handle.shards: - cube_groups[(shard.sip, shard.cube)].append(shard) + # Group shards by SIP + sip_groups: dict[int, list] = defaultdict(list) + for shard in handle.shards: + sip_groups[shard.sip].append(shard) - for (sip, cube), group_shards in cube_groups.items(): + for sip, sip_shards in sip_groups.items(): + if is_cube_replicate: + # Cube replicate: per-(sip, cube) local mapping + cube_groups: dict[int, list] = defaultdict(list) + for s in sip_shards: + cube_groups[s.cube].append(s) + + for cube, group_shards in cube_groups.items(): + entries = tuple( + {"va": handle.va_base + s.offset_bytes, + "pa": s.pa, "size": s.nbytes} + for s in group_shards + ) + msg = MmuMapMsg( + correlation_id=self.correlation_id, + request_id=f"mmu_{tensor_name}_s{sip}c{cube}", + entries=entries, + target_sips=(sip,), + target_cubes=(cube,), + target_pe="all", + ) + h = self.submit(msg) + self.wait(h) + else: + # Cube sharded: broadcast all cubes within this SIP entries = tuple( {"va": handle.va_base + s.offset_bytes, "pa": s.pa, "size": s.nbytes} - for s in group_shards + for s in sip_shards ) + cube_set = sorted({s.cube for s in sip_shards}) msg = MmuMapMsg( correlation_id=self.correlation_id, - request_id=f"mmu_{tensor_name}_s{sip}c{cube}", + request_id=f"mmu_{tensor_name}_s{sip}", entries=entries, target_sips=(sip,), - target_cubes=(cube,), + target_cubes=tuple(cube_set), target_pe="all", ) h = self.submit(msg) self.wait(h) - else: - # Sharded: broadcast all mappings to all target (sip, cube)s - entries = tuple( - {"va": handle.va_base + s.offset_bytes, - "pa": s.pa, "size": s.nbytes} - for s in handle.shards - ) - sip_set = sorted({s.sip for s in handle.shards}) - cube_set = sorted({s.cube for s in handle.shards}) - msg = MmuMapMsg( - correlation_id=self.correlation_id, - request_id=f"mmu_{tensor_name}", - entries=entries, - target_sips=tuple(sip_set), - target_cubes=tuple(cube_set), - target_pe="all", - ) - h = self.submit(msg) - self.wait(h) # Submit MemoryWriteMsg per shard (deploy data to device) if pattern is not None: @@ -384,11 +382,18 @@ class RuntimeContext: Positional args: Tensor objects become TensorArg, int/float become ScalarArg. Keyword args: become ScalarArg (name is discarded, order preserved). + + Creates per-SIP KernelLaunchMsg with local va_base per tensor + (like host driver sending per-rank launch commands). """ + from collections import defaultdict + from kernbench.runtime_api.kernel import ( KernelLaunchMsg, KernelRef, ScalarArg, + TensorArg, + TensorArgShard, ) from kernbench.runtime_api.tensor import Tensor from kernbench.triton_emu.registry import register_kernel @@ -399,14 +404,14 @@ class RuntimeContext: except ValueError: pass - # Build kernel args from positional + keyword args - kernel_args: list = [] + # Collect tensors and scalars + tensor_args: list[Tensor] = [] + scalar_args: list = [] target_pe: int | str = 0 for a in args: if isinstance(a, Tensor): - kernel_args.append(a.to_tensor_arg()) - # Infer target_pe from tensor DP metadata + tensor_args.append(a) if a._dp_metadata is not None: dp_target = a._dp_metadata.target_pe if dp_target == "all": @@ -415,34 +420,121 @@ class RuntimeContext: target_pe = dp_target elif isinstance(a, (int, float)): dtype_str = "f32" if isinstance(a, float) else "i32" - kernel_args.append(ScalarArg(dtype=dtype_str, value=a)) + scalar_args.append(ScalarArg(dtype=dtype_str, value=a)) for v in kwargs.values(): if isinstance(v, (int, float)): dtype_str = "f32" if isinstance(v, float) else "i32" - kernel_args.append(ScalarArg(dtype=dtype_str, value=v)) + scalar_args.append(ScalarArg(dtype=dtype_str, value=v)) - # Determine target cubes from all tensor shards - cube_set: set[int] = set() - for a in args: - if isinstance(a, Tensor) and a._handle is not None: - for s in a._handle.shards: - cube_set.add(s.cube) - target_cubes = tuple(sorted(cube_set)) if cube_set else (0,) + # Determine all target SIPs from tensor shards + sip_set: set[int] = set() + for t in tensor_args: + if t._handle is not None: + for s in t._handle.shards: + sip_set.add(s.sip) + if not sip_set: + sip_set = {0} - # Collect scalar values for GEMM FLOP calculation - scalar_vals = [a.value for a in kernel_args if hasattr(a, "value")] + # Build global→local dimension mapping from tensor DPPolicies. + # Scalar args matching a tensor's global dimension get replaced + # with the cube-local value (what the kernel actually operates on). + def _compute_local_shape(t: Tensor) -> tuple[int, ...]: + """Compute cube-local shape from DPPolicy.""" + shape = t.shape + if len(shape) < 2: + shape = (1, shape[0]) + M, K = shape[0], shape[1] + dp = t._dp_metadata.dp_policy if t._dp_metadata else None + if dp is None: + return t.shape + if dp.sip != "replicate": + if dp.sip == "column_wise": + K = K // self._num_sips + elif dp.sip == "row_wise": + M = M // self._num_sips + if dp.cube != "replicate": + if dp.cube == "column_wise": + K = K // self._num_cubes + elif dp.cube == "row_wise": + M = M // self._num_cubes + if len(t.shape) < 2: + return (K,) + return (M, K) - h = self.submit(KernelLaunchMsg( - correlation_id=self.correlation_id, - request_id=kernel_name, - kernel_ref=KernelRef(name=kernel_name, kind="builtin"), - args=tuple(kernel_args), - target_cubes=target_cubes, - target_pe=target_pe, - )) - self.wait(h, _meta={ - "phase": "kernel", "name": kernel_name, - "target_pe": target_pe, "scalars": scalar_vals, - }) - return h + dim_map: dict[int, int] = {} # global_dim → local_dim + for t in tensor_args: + local = _compute_local_shape(t) + for g, l in zip(t.shape if len(t.shape) >= 2 else (1, t.shape[0]), local if len(local) >= 2 else (1, local[0])): + if g != l: + dim_map[g] = l + + # Per-SIP kernel launch: each SIP gets TensorArgs with local va_base + last_handle = None + for sip_id in sorted(sip_set): + sip_kernel_args: list = [] + sip_cube_set: set[int] = set() + + for t in tensor_args: + if t._handle is None: + continue + sip_shards = [s for s in t._handle.shards if s.sip == sip_id] + if not sip_shards: + sip_shards = list(t._handle.shards) + + local_va_base = 0 + if t._handle.va_base: + min_offset = min(s.offset_bytes for s in sip_shards) + local_va_base = t._handle.va_base + min_offset + + sip_kernel_args.append(TensorArg( + shards=tuple( + TensorArgShard( + sip=s.sip, cube=s.cube, pe=s.pe, + pa=s.pa, nbytes=s.nbytes, offset_bytes=s.offset_bytes, + ) + for s in sip_shards + ), + va_base=local_va_base, + )) + + for s in sip_shards: + sip_cube_set.add(s.cube) + + # Interleave tensor args and scalar args, replacing global dims with local + final_args: list = [] + t_idx, s_idx = 0, 0 + for a in args: + if isinstance(a, Tensor): + final_args.append(sip_kernel_args[t_idx]) + t_idx += 1 + elif isinstance(a, (int, float)): + sa = scalar_args[s_idx] + if isinstance(a, int) and a in dim_map: + sa = ScalarArg(dtype=sa.dtype, value=dim_map[a]) + final_args.append(sa) + s_idx += 1 + while s_idx < len(scalar_args): + sa = scalar_args[s_idx] + if isinstance(sa.value, int) and int(sa.value) in dim_map: + sa = ScalarArg(dtype=sa.dtype, value=dim_map[int(sa.value)]) + final_args.append(sa) + s_idx += 1 + + target_cubes = tuple(sorted(sip_cube_set)) if sip_cube_set else (0,) + + h = self.submit(KernelLaunchMsg( + correlation_id=self.correlation_id, + request_id=f"{kernel_name}_sip{sip_id}", + kernel_ref=KernelRef(name=kernel_name, kind="builtin"), + args=tuple(final_args), + target_cubes=target_cubes, + target_pe=target_pe, + )) + self.wait(h, _meta={ + "phase": "kernel", "name": kernel_name, + "sip": sip_id, "target_pe": target_pe, + }) + last_handle = h + + return last_handle diff --git a/src/kernbench/runtime_api/tensor.py b/src/kernbench/runtime_api/tensor.py index 4dde44f..51369fe 100644 --- a/src/kernbench/runtime_api/tensor.py +++ b/src/kernbench/runtime_api/tensor.py @@ -59,10 +59,7 @@ def deploy_tensor( allocators: dict[int, PEMemAllocator], mem_kind: Literal["hbm", "tcm"] = "hbm", va_allocator=None, - mmus: dict | None = None, ) -> TensorHandle: - from kernbench.policy.address.pe_mmu import PeMMU - isize = dtype_itemsize(dtype) total_nbytes = math.prod(shape) * isize @@ -78,22 +75,15 @@ def deploy_tensor( pa = alloc.alloc_hbm(spec.nbytes) else: pa = alloc.alloc_tcm(spec.nbytes) - encoded_pa = pa.encode() shards.append(TensorShard( sip=alloc._sip_id, cube=alloc._cube_id, pe=alloc._pe_id, - pa=encoded_pa, + pa=pa.encode(), nbytes=spec.nbytes, offset_bytes=spec.offset_bytes, )) - # Register VA→PA mapping in all MMUs (broadcast) - if va_base and mmus is not None: - shard_va = va_base + spec.offset_bytes - for mmu in mmus.values(): - mmu.map(va=shard_va, pa=encoded_pa, size=spec.nbytes) - return TensorHandle( name=name, shape=shape, diff --git a/src/kernbench/sim_engine/engine.py b/src/kernbench/sim_engine/engine.py index 6c25813..298e080 100644 --- a/src/kernbench/sim_engine/engine.py +++ b/src/kernbench/sim_engine/engine.py @@ -5,7 +5,7 @@ from typing import Any import simpy from kernbench.common.types import Completion, RequestHandle, Trace -import kernbench.components.impls # noqa: F401 — registers built-in implementations +import kernbench.components.builtin # noqa: F401 — registers built-in implementations from kernbench.components.base import ComponentBase, ComponentRegistry from kernbench.components.context import ComponentContext from kernbench.policy.address.phyaddr import PhysAddr diff --git a/tests/custom/__init__.py b/tests/custom/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_component_registry.py b/tests/test_component_registry.py index 6af01d8..c5d8ea9 100644 --- a/tests/test_component_registry.py +++ b/tests/test_component_registry.py @@ -13,7 +13,7 @@ import simpy from pathlib import Path from kernbench.components.base import ComponentBase, ComponentRegistry -from kernbench.components.impls.forwarding import TransitComponent +from kernbench.components.builtin.forwarding import TransitComponent from kernbench.policy.address.phyaddr import PhysAddr from kernbench.runtime_api.kernel import MemoryReadMsg from kernbench.sim_engine.engine import GraphEngine diff --git a/tests/test_mmu_component.py b/tests/test_mmu_component.py index 08ba17f..b4ec8ed 100644 --- a/tests/test_mmu_component.py +++ b/tests/test_mmu_component.py @@ -73,7 +73,7 @@ def test_mmu_unmap_msg_fields(): def test_pe_mmu_registry(): """pe_mmu_v1 impl resolves in ComponentRegistry.""" from kernbench.components.base import ComponentRegistry - from kernbench.components.impls.pe_mmu import PeMmuComponent + from kernbench.components.builtin.pe_mmu import PeMmuComponent from kernbench.topology.types import Node node = Node( @@ -93,7 +93,7 @@ def test_pe_mmu_registry(): def test_pe_mmu_processes_map_msg(): """PE_MMU component receives MmuMapMsg → translate works.""" import simpy - from kernbench.components.impls.pe_mmu import PeMmuComponent + from kernbench.components.builtin.pe_mmu import PeMmuComponent from kernbench.sim_engine.transaction import Transaction from kernbench.topology.types import Node @@ -152,7 +152,7 @@ def test_pe_dma_translates_va(): # This test validates the interface contract. Full integration test # requires the engine wiring which is validated in test_engine. # Here we check that PE_DMA has an mmu attribute it can call. - from kernbench.components.impls.pe_dma import PeDmaComponent + from kernbench.components.builtin.pe_dma import PeDmaComponent from kernbench.topology.types import Node node = Node( diff --git a/tests/test_pe_components.py b/tests/test_pe_components.py index 3149edc..6a77077 100644 --- a/tests/test_pe_components.py +++ b/tests/test_pe_components.py @@ -20,12 +20,12 @@ from kernbench.common.pe_commands import ( TensorHandle, ) from kernbench.components.base import ComponentRegistry -from kernbench.components.impls.pe_cpu import PeCpuComponent -from kernbench.components.impls.pe_dma import PeDmaComponent -from kernbench.components.impls.pe_gemm import PeGemmComponent -from kernbench.components.impls.pe_math import PeMathComponent -from kernbench.components.impls.pe_scheduler import PeSchedulerComponent -from kernbench.components.impls.pe_tcm import PeTcmComponent +from kernbench.components.builtin.pe_cpu import PeCpuComponent +from kernbench.components.builtin.pe_dma import PeDmaComponent +from kernbench.components.builtin.pe_gemm import PeGemmComponent +from kernbench.components.builtin.pe_math import PeMathComponent +from kernbench.components.builtin.pe_scheduler import PeSchedulerComponent +from kernbench.components.builtin.pe_tcm import PeTcmComponent from kernbench.policy.address.phyaddr import PhysAddr from kernbench.runtime_api.kernel import ( KernelLaunchMsg, @@ -888,11 +888,9 @@ def test_qkv_gemm_bench_completes(): deploy_traces = [t for t in ctx._traces if t["phase"] in ("deploy", "memory_write")] kernel_traces = [t for t in ctx._traces if t["phase"] == "kernel"] assert len(deploy_traces) >= 2 # at least a, b (out is empty, no deploy) - assert len(kernel_traces) == 1 + assert len(kernel_traces) >= 1 # one per SIP (2 SIPs in topology) assert kernel_traces[0]["name"] == "qkv_gemm" assert kernel_traces[0]["total_ns"] > 0 - # Scalars should contain M, K, N - assert len(kernel_traces[0]["scalars"]) >= 3 clear_registry() @@ -982,7 +980,7 @@ def test_qkv_gemm_bench_multi_pe_completes(): deploy_traces = [t for t in ctx._traces if t["phase"] in ("deploy", "memory_write")] kernel_traces = [t for t in ctx._traces if t["phase"] == "kernel"] assert len(deploy_traces) >= 8 # replicate(a)*8 + column_wise(b)*8 - assert len(kernel_traces) == 1 + assert len(kernel_traces) >= 1 # one per SIP assert kernel_traces[0]["target_pe"] == "all" clear_registry() diff --git a/tests/test_phase_a_components.py b/tests/test_phase_a_components.py index 68c82db..1e1dc2d 100644 --- a/tests/test_phase_a_components.py +++ b/tests/test_phase_a_components.py @@ -19,7 +19,7 @@ import simpy from kernbench.components.base import ComponentBase, ComponentRegistry from kernbench.components.context import ComponentContext -from kernbench.components.impls import ( +from kernbench.components.builtin import ( HbmCtrlComponent, IoCpuComponent, MCpuComponent, diff --git a/tests/test_sip_parallel.py b/tests/test_sip_parallel.py new file mode 100644 index 0000000..33d15fa --- /dev/null +++ b/tests/test_sip_parallel.py @@ -0,0 +1,157 @@ +"""Tests for SIP-level tensor parallelism. + +Validates: + SP1. DPPolicy accepts sip field (default "replicate", backward compat) + SP2. sip="column_wise": tensor K-axis split across SIPs, each SIP gets K//num_sips + SP3. sip="row_wise": tensor M-axis split across SIPs + SP4. 3-level resolve: sip × cube × pe produces correct flat indices and offsets + SP5. sip="replicate": all SIPs get full copy (existing behavior) + SP6. PE_CPU sets num_programs from shard count per cube + SP7. End-to-end: TP kernel with sip="column_wise" completes on multi-SIP topology +""" +import pytest +from pathlib import Path + +from kernbench.policy.placement.dp import DPPolicy, ShardSpec, resolve_dp_policy + + +# ── SP1. DPPolicy sip field ────────────────────────────────────────── + + +def test_dp_policy_sip_default_replicate(): + """DPPolicy without sip= defaults to 'replicate'.""" + dp = DPPolicy(cube="replicate", pe="column_wise") + assert dp.sip == "replicate" + + +def test_dp_policy_sip_column_wise(): + """DPPolicy accepts sip='column_wise'.""" + dp = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise") + assert dp.sip == "column_wise" + + +# ── SP2. sip="column_wise" ────────────────────────────────────────────── + + +def test_sip_column_wise_splits_across_sips(): + """sip='column_wise' with 2 SIPs: each SIP gets K//2 columns.""" + dp = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise") + shards = resolve_dp_policy( + dp, shape=(128, 256), itemsize=2, + num_pe=8, num_cubes=1, num_sips=2, + ) + # 2 SIPs × 1 cube × 8 PEs = 16 shards + assert len(shards) == 16 + + # SIP0 shards: first half of K (0 to K//2) + # SIP1 shards: second half of K (K//2 to K) + total_bytes = 128 * 256 * 2 # 64KB + sip0_shards = [s for s in shards if s.pe_index < 8] + sip1_shards = [s for s in shards if s.pe_index >= 8] + + # SIP0 offsets start at 0 + assert sip0_shards[0].offset_bytes == 0 + # SIP1 offsets start at half + assert sip1_shards[0].offset_bytes == total_bytes // 2 + + # Total coverage + assert sum(s.nbytes for s in sip0_shards) == total_bytes // 2 + assert sum(s.nbytes for s in sip1_shards) == total_bytes // 2 + + +# ── SP3. sip="row_wise" ────────────────────────────────────────────── + + +def test_sip_row_wise_splits_across_sips(): + """sip='row_wise' with 2 SIPs: each SIP gets M//2 rows.""" + dp = DPPolicy(sip="row_wise", cube="replicate", pe="column_wise") + shards = resolve_dp_policy( + dp, shape=(128, 256), itemsize=2, + num_pe=8, num_cubes=1, num_sips=2, + ) + assert len(shards) == 16 + + sip0_shards = [s for s in shards if s.pe_index < 8] + sip1_shards = [s for s in shards if s.pe_index >= 8] + + # SIP0: rows 0..63, SIP1: rows 64..127 + total_bytes = 128 * 256 * 2 + assert sip0_shards[0].offset_bytes == 0 + assert sip1_shards[0].offset_bytes == total_bytes // 2 + + +# ── SP4. 3-level resolve ───────────────────────────────────────────── + + +def test_3level_resolve_flat_index(): + """3-level: sip × cube × pe produces correct flat indices.""" + dp = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise") + shards = resolve_dp_policy( + dp, shape=(128, 256), itemsize=2, + num_pe=8, num_cubes=2, num_sips=2, + ) + # 2 SIPs × 2 cubes × 8 PEs = 32 shards + assert len(shards) == 32 + + # Flat index: sip_id * cubes_per_sip * num_pe + cube_id * num_pe + pe_id + indices = [s.pe_index for s in shards] + # SIP0: 0..15, SIP1: 16..31 + assert min(indices) == 0 + assert max(indices) == 31 + assert len(set(indices)) == 32 # all unique + + +def test_3level_offsets_cover_full_tensor(): + """3-level sharding covers the entire tensor with no gaps.""" + dp = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise") + shards = resolve_dp_policy( + dp, shape=(128, 256), itemsize=2, + num_pe=4, num_cubes=1, num_sips=2, + ) + # 2 SIPs × 1 cube × 4 PEs = 8 shards + # sip="column_wise": K=128 per SIP, pe="column_wise": 32 cols per PE + total = 128 * 256 * 2 + # For non-replicate, total shard bytes == tensor bytes + # (replicate within cube means cube shards overlap, but sip shards don't) + sip0_bytes = sum(s.nbytes for s in shards if s.pe_index < 4) + sip1_bytes = sum(s.nbytes for s in shards if s.pe_index >= 4) + assert sip0_bytes + sip1_bytes == total + + +# ── SP5. sip="replicate" backward compat ───────────────────────────── + + +def test_sip_replicate_backward_compat(): + """sip='replicate' produces same result as before (2-level).""" + dp_old = DPPolicy(cube="replicate", pe="column_wise") + dp_new = DPPolicy(sip="replicate", cube="replicate", pe="column_wise") + + shards_old = resolve_dp_policy( + dp_old, shape=(128, 256), itemsize=2, + num_pe=8, num_cubes=2, num_sips=2, + ) + shards_new = resolve_dp_policy( + dp_new, shape=(128, 256), itemsize=2, + num_pe=8, num_cubes=2, num_sips=2, + ) + assert len(shards_old) == len(shards_new) + for a, b in zip(shards_old, shards_new): + assert a.pe_index == b.pe_index + assert a.offset_bytes == b.offset_bytes + assert a.nbytes == b.nbytes + + +# ── SP6. PE_CPU num_programs ────────────────────────────────────────── + + +def test_pe_cpu_sets_num_programs(): + """PE_CPU should create TLContext with num_programs = PEs per cube.""" + # This test validates the interface contract. + # After implementation, PE_CPU should derive num_programs from the + # number of PE shards in the kernel launch's target cube. + from kernbench.triton_emu.tl_context import TLContext + + # With 8 PEs per cube, num_programs should be 8 + tl = TLContext(pe_id=3, num_programs=8) + assert tl.program_id(0) == 3 + assert tl.num_programs(0) == 8 diff --git a/tests/test_va_integration.py b/tests/test_va_integration.py index a173cbb..3ecbe6b 100644 --- a/tests/test_va_integration.py +++ b/tests/test_va_integration.py @@ -88,7 +88,6 @@ def test_deploy_tensor_assigns_va_base(): """deploy_tensor with VA allocator assigns va_base to TensorHandle.""" allocs = _make_allocators() va_alloc = _make_va_allocator() - mmus = _make_mmus() placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8) th = deploy_tensor( @@ -98,7 +97,6 @@ def test_deploy_tensor_assigns_va_base(): placement=placement, allocators=allocs, va_allocator=va_alloc, - mmus=mmus, ) assert th.va_base is not None @@ -109,7 +107,6 @@ def test_deploy_tensor_va_covers_all_shards(): """VA allocation covers the entire tensor; each shard is at va_base + offset.""" allocs = _make_allocators() va_alloc = _make_va_allocator() - mmus = _make_mmus() placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8) th = deploy_tensor( @@ -119,41 +116,32 @@ def test_deploy_tensor_va_covers_all_shards(): placement=placement, allocators=allocs, va_allocator=va_alloc, - mmus=mmus, ) - # Each shard's VA is derivable: va_base + offset_bytes for s in th.shards: shard_va = th.va_base + s.offset_bytes assert shard_va > 0 -def test_deploy_tensor_registers_mmu_mappings(): - """deploy_tensor registers VA→PA mappings in all PE MMUs.""" +def test_deploy_tensor_does_not_install_mmu_mappings(): + """deploy_tensor does NOT install MMU mappings — that's context's job.""" allocs = _make_allocators() va_alloc = _make_va_allocator() mmus = _make_mmus() placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8) - th = deploy_tensor( + deploy_tensor( name="W", shape=(1024, 512), dtype="fp16", placement=placement, allocators=allocs, va_allocator=va_alloc, - mmus=mmus, ) - # Every MMU should have entries (broadcast) + # No MMU should have any entries (mappings come from fabric MmuMapMsg) for mmu in mmus.values(): - assert mmu.num_entries > 0 - - # Each shard's derived VA should translate to its PA in every MMU - for mmu in mmus.values(): - for s in th.shards: - shard_va = th.va_base + s.offset_bytes - assert mmu.translate(shard_va) == s.pa + assert mmu.num_entries == 0 # ── T12. Tensor.va property ────────────────────────────────────────── @@ -165,7 +153,6 @@ def test_tensor_va_property(): allocs = _make_allocators(1) va_alloc = _make_va_allocator() - mmus = _make_mmus(1) placement = [ShardSpec(pe_index=0, offset_bytes=0, nbytes=4096)] t = Tensor(shape=(2048,), dtype="f16", name="test") @@ -176,7 +163,6 @@ def test_tensor_va_property(): placement=placement, allocators=allocs, va_allocator=va_alloc, - mmus=mmus, ) assert t.va > 0 assert t.va == t._handle.va_base diff --git a/tests/test_va_offset.py b/tests/test_va_offset.py new file mode 100644 index 0000000..8537874 --- /dev/null +++ b/tests/test_va_offset.py @@ -0,0 +1,216 @@ +"""VA offset verification: each PE accesses its own local HBM slice. + +Verifies that column-wise sharding + VA offset calculation produces DMA +addresses that translate to the correct PE's local HBM — not a remote PE. + +Tests: + VO1. Per-PE DMA addresses are correct VAs (2D) + VO2. Each VA translates to the executing PE's own HBM slice (2D) + VO3. End-to-end bench completes (2D, full TP) + VO4. Per-PE DMA addresses are correct VAs (1D) + VO5. Each VA translates to local HBM (1D) + VO6. End-to-end 1D bench completes +""" +from pathlib import Path + +import pytest + +from kernbench.common.pe_commands import DmaReadCmd, DmaWriteCmd +from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator +from kernbench.policy.address.pe_mmu import PeMMU +from kernbench.policy.address.phyaddr import PhysAddr +from kernbench.policy.address.va_allocator import VirtualAllocator +from kernbench.policy.placement.dp import DPPolicy, column_wise +from kernbench.runtime_api.tensor import deploy_tensor +from kernbench.sim_engine.engine import GraphEngine +from kernbench.runtime_api.context import RuntimeContext +from kernbench.runtime_api.types import DeviceSelector +from kernbench.topology.builder import load_topology +from kernbench.triton_emu.tl_context import TLContext, run_kernel + +TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml" + +_MB = 1 << 20 +_GB = 1 << 30 + +M, K = 128, 256 +DTYPE = "f16" +NUM_PE = 8 +ELEM_BYTES = 2 + + +def _copy_kernel_2d(src_ptr, dst_ptr, M, K, tl, DTYPE="f16"): + """Standard Triton 2D copy. M, K are cube-local.""" + pid = tl.program_id(0) + num_pe = tl.num_programs(0) + cols_per_pe = K // num_pe + elem_bytes = 2 + offset = pid * M * cols_per_pe * elem_bytes + data = tl.load(src_ptr + offset, shape=(M, cols_per_pe), dtype=DTYPE) + tl.store(dst_ptr + offset, data) + + +def _copy_kernel_1d(src_ptr, dst_ptr, N, tl, DTYPE="f16"): + """Standard Triton 1D copy. N is cube-local.""" + pid = tl.program_id(0) + num_pe = tl.num_programs(0) + elems_per_pe = N // num_pe + elem_bytes = 2 + offset = pid * elems_per_pe * elem_bytes + data = tl.load(src_ptr + offset, shape=(elems_per_pe,), dtype=DTYPE) + tl.store(dst_ptr + offset, data) + + +def _make_standalone(shape, num_pe=NUM_PE): + """Create standalone allocators + MMUs for unit testing.""" + cfg = AddressConfig( + sip_count=1, cubes_per_sip=1, pes_per_cube=num_pe, + hbm_bytes_per_cube=48 * _GB, hbm_slices_per_cube=num_pe, + tcm_bytes_per_pe=16 * _MB, tcm_scheduler_reserved_bytes=4 * _MB, + sram_bytes_per_cube=32 * _MB, + ) + allocators = { + i: PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=i, cfg=cfg) + for i in range(num_pe) + } + va_alloc = VirtualAllocator(va_base=0x1_0000_0000, va_size=64 * _GB, page_size=4096) + mmus = {i: PeMMU(page_size=4096) for i in range(num_pe)} + return cfg, allocators, va_alloc, mmus + + +# ── VO1. 2D: Per-PE DMA addresses are correct VAs ──────────────────── + + +def test_2d_each_pe_computes_correct_va_offset(): + """2D: each PE generates DMA at va_base + pid * block_bytes.""" + src_va = 0x1_0000_0000 + dst_va = 0x2_0000_0000 + cols_per_pe = K // NUM_PE + block_bytes = M * cols_per_pe * ELEM_BYTES + + for pe_id in range(NUM_PE): + tl = TLContext(pe_id=pe_id, num_programs=NUM_PE, dispatch_cycles=0) + run_kernel(_copy_kernel_2d, tl, src_ptr=src_va, dst_ptr=dst_va, M=M, K=K) + + reads = [c for c in tl.commands if isinstance(c, DmaReadCmd)] + writes = [c for c in tl.commands if isinstance(c, DmaWriteCmd)] + + expected_offset = pe_id * block_bytes + assert reads[0].src_addr == src_va + expected_offset + assert writes[0].dst_addr == dst_va + expected_offset + + +# ── VO2. 2D: Each VA translates to local HBM ───────────────────────── + + +def test_2d_va_translates_to_local_hbm(): + """2D: each PE's DMA VA translates to its own HBM slice.""" + cfg, allocators, va_alloc, mmus = _make_standalone((M, K)) + slice_size = cfg.hbm_slice_bytes + cols_per_pe = K // NUM_PE + block_bytes = M * cols_per_pe * ELEM_BYTES + + placement = column_wise(shape=(M, K), itemsize=ELEM_BYTES, num_pe=NUM_PE) + handle = deploy_tensor( + name="src", shape=(M, K), dtype="fp16", + placement=placement, allocators=allocators, va_allocator=va_alloc, + ) + + # Install per-PE mappings (simulating what context does via MmuMapMsg) + for s in handle.shards: + mmus[s.pe].map(va=handle.va_base + s.offset_bytes, pa=s.pa, size=s.nbytes) + + for pe_id in range(NUM_PE): + va = handle.va_base + pe_id * block_bytes + pa = mmus[pe_id].translate(va) + decoded = PhysAddr.decode(pa) + hbm_pe = PhysAddr.hbm_pe_id(decoded.hbm_offset, slice_size) + assert hbm_pe == pe_id, f"PE{pe_id} accessed PE{hbm_pe}'s HBM" + + +# ── VO3. 2D: End-to-end bench completes ────────────────────────────── + + +def test_2d_bench_completes(): + """2D: full TP bench with standard Triton kernel pattern.""" + graph = load_topology(TOPOLOGY_PATH) + engine = GraphEngine(graph) + ctx = RuntimeContext( + engine=engine, target_device=DeviceSelector("sip:0"), + correlation_id="vo3", spec=graph.spec, + ) + from benches.va_offset_verify import run as bench_run + bench_run(ctx) + ctx.wait_all() + + +# ── VO4. 1D: Per-PE DMA addresses ──────────────────────────────────── + +N_1D = 1024 + + +def test_1d_each_pe_computes_correct_offset(): + """1D: each PE generates DMA at correct offset.""" + src_va = 0x1_0000_0000 + dst_va = 0x2_0000_0000 + elems_per_pe = N_1D // NUM_PE + block_bytes = elems_per_pe * ELEM_BYTES + + for pe_id in range(NUM_PE): + tl = TLContext(pe_id=pe_id, num_programs=NUM_PE, dispatch_cycles=0) + run_kernel(_copy_kernel_1d, tl, src_ptr=src_va, dst_ptr=dst_va, N=N_1D) + + reads = [c for c in tl.commands if isinstance(c, DmaReadCmd)] + writes = [c for c in tl.commands if isinstance(c, DmaWriteCmd)] + + expected_offset = pe_id * block_bytes + assert reads[0].src_addr == src_va + expected_offset + assert writes[0].dst_addr == dst_va + expected_offset + + +# ── VO5. 1D: VA translates to local HBM ────────────────────────────── + + +def test_1d_va_translates_to_local_hbm(): + """1D: each PE's DMA VA translates to its own HBM slice.""" + cfg, allocators, va_alloc, mmus = _make_standalone((1, N_1D)) + slice_size = cfg.hbm_slice_bytes + elems_per_pe = N_1D // NUM_PE + block_bytes = elems_per_pe * ELEM_BYTES + + placement = column_wise(shape=(1, N_1D), itemsize=ELEM_BYTES, num_pe=NUM_PE) + handle = deploy_tensor( + name="src_1d", shape=(N_1D,), dtype="fp16", + placement=placement, allocators=allocators, va_allocator=va_alloc, + ) + + for s in handle.shards: + mmus[s.pe].map(va=handle.va_base + s.offset_bytes, pa=s.pa, size=s.nbytes) + + for pe_id in range(NUM_PE): + va = handle.va_base + pe_id * block_bytes + pa = mmus[pe_id].translate(va) + decoded = PhysAddr.decode(pa) + hbm_pe = PhysAddr.hbm_pe_id(decoded.hbm_offset, slice_size) + assert hbm_pe == pe_id, f"1D PE{pe_id} accessed PE{hbm_pe}'s HBM" + + +# ── VO6. 1D: End-to-end ────────────────────────────────────────────── + + +def test_1d_e2e_completes(): + """1D: full engine run with column_wise TP sharding.""" + graph = load_topology(TOPOLOGY_PATH) + engine = GraphEngine(graph) + ctx = RuntimeContext( + engine=engine, target_device=DeviceSelector("sip:0"), + correlation_id="vo6", spec=graph.spec, + ) + + dp = DPPolicy(sip="column_wise", cube="column_wise", pe="column_wise") + src = ctx.zeros((N_1D,), dtype=DTYPE, dp=dp, name="src_1d") + dst = ctx.empty((N_1D,), dtype=DTYPE, dp=dp, name="dst_1d") + + # launch() auto-localizes N_1D → cube-local N + ctx.launch("va_1d_copy", _copy_kernel_1d, src, dst, N_1D) + ctx.wait_all()