Add SIP-level tensor parallelism, component registry YAML, VA offset verification

- DPPolicy: 3-level (sip/cube/pe), unified naming (column_wise/row_wise)
- PE_CPU: auto num_programs from cube shard count
- context.launch(): per-SIP KernelLaunchMsg with local va_base + auto local shape
- deploy_tensor: removed mmus param, MMU mapping is context-only responsibility
- ComponentRegistry: YAML-based lazy loading (components.yaml), impls→builtin rename
- VA offset bench + tests: 2D/1D, standard Triton kernel pattern

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-26 01:13:17 -07:00
parent 08812eda58
commit 63669f82cb
35 changed files with 813 additions and 219 deletions
+54 -4
View File
@@ -140,16 +140,65 @@ class ComponentRegistry:
Resolution order for ComponentRegistry.create(node, overrides, ctx):
1. overrides[node.impl] — caller-injected override
2. _registry[node.impl] — globally registered impl
2. _registry[node.impl] — globally registered impl (lazy import)
3. Error — no fallback; every node must have an impl
Registry is populated from components.yaml via load_components_yaml().
Manual register() is still supported for tests and overrides.
"""
_registry: dict[str, type[ComponentBase]] = {}
_lazy: dict[str, str] = {} # impl → "module.path:ClassName"
_loaded: bool = False
@classmethod
def register(cls, impl: str, component_cls: type[ComponentBase]) -> None:
cls._registry[impl] = component_cls
@classmethod
def load_components_yaml(cls, path: str | None = None) -> None:
"""Load impl→class mappings from components.yaml. Lazy imports on first use."""
if cls._loaded:
return
import yaml
from pathlib import Path
if path is None:
# Search: project root (cwd), then relative to this file
candidates = [
Path.cwd() / "components.yaml",
Path(__file__).parent.parent.parent.parent / "components.yaml",
]
for p in candidates:
if p.exists():
path = str(p)
break
if path is None:
return
with open(path) as f:
spec = yaml.safe_load(f)
for impl, class_path in (spec.get("components") or {}).items():
cls._lazy[impl] = class_path
cls._loaded = True
@classmethod
def _resolve(cls, impl: str) -> type[ComponentBase] | None:
"""Resolve impl name: check _registry first, then lazy import from _lazy."""
if impl in cls._registry:
return cls._registry[impl]
if not cls._loaded:
cls.load_components_yaml()
class_path = cls._lazy.get(impl)
if class_path is None:
return None
import importlib
module_path, class_name = class_path.rsplit(":", 1)
mod = importlib.import_module(module_path)
component_cls = getattr(mod, class_name)
cls._registry[impl] = component_cls # cache for next lookup
return component_cls
@classmethod
def create(
cls,
@@ -159,9 +208,10 @@ class ComponentRegistry:
) -> ComponentBase:
if overrides and node.impl in overrides:
return overrides[node.impl](node, ctx)
if node.impl in cls._registry:
return cls._registry[node.impl](node, ctx)
component_cls = cls._resolve(node.impl)
if component_cls is not None:
return component_cls(node, ctx)
raise ValueError(
f"No component registered for impl '{node.impl}' (node: {node.id}). "
f"Register it in kernbench.components.impls.__init__."
f"Add it to components.yaml or call ComponentRegistry.register()."
)
@@ -0,0 +1,34 @@
"""Concrete component implementations.
Loaded from components.yaml via ComponentRegistry.load_components_yaml().
Manual imports are no longer needed — add new impls to components.yaml.
Classes are still importable from this package via lazy __getattr__.
"""
from kernbench.components.base import ComponentRegistry
ComponentRegistry.load_components_yaml()
# Lazy re-export: allow `from kernbench.components.builtin import FooComponent`
# without eagerly importing every module.
_CLASS_MAP: dict[str, str] = {} # ClassName → "module.path:ClassName"
def _build_class_map() -> None:
if _CLASS_MAP:
return
for class_path in ComponentRegistry._lazy.values():
module_path, class_name = class_path.rsplit(":", 1)
_CLASS_MAP[class_name] = class_path
def __getattr__(name: str):
_build_class_map()
class_path = _CLASS_MAP.get(name)
if class_path is None:
raise ImportError(f"cannot import name '{name}' from 'kernbench.components.builtin'")
import importlib
module_path, class_name = class_path.rsplit(":", 1)
mod = importlib.import_module(module_path)
return getattr(mod, class_name)
@@ -81,10 +81,22 @@ class PeCpuComponent(ComponentBase):
yield from self.run(env, 0)
kernel_fn = get_kernel(request.kernel_ref.name)
tl = TLContext(pe_id=self._pe_idx, dispatch_cycles=0)
# Derive num_programs from the number of PE shards in this cube
num_programs = 1
for arg in request.args:
if arg.arg_kind == "tensor":
cube_pe_count = sum(
1 for s in arg.shards
if s.sip == self._sip_idx and s.cube == self._cube_idx
)
if cube_pe_count > num_programs:
num_programs = cube_pe_count
tl = TLContext(pe_id=self._pe_idx, num_programs=num_programs, dispatch_cycles=0)
# Unpack KernelLaunchMsg.args into positional args for kernel function
# TensorArg → VA base (or PA fallback), ScalarArg → value
# TensorArg → va_base (already local, set by runtime) or PA fallback
kernel_args: list = []
for arg in request.args:
if arg.arg_kind == "tensor":
@@ -0,0 +1,5 @@
"""Custom component implementations.
Place your component files here and register them in components.yaml.
See components.yaml header for instructions.
"""
@@ -1,59 +0,0 @@
"""Concrete component implementations.
Each module registers its component(s) with ComponentRegistry on import.
Import this package to activate all built-in implementations.
"""
from kernbench.components.base import ComponentRegistry
from kernbench.components.impls.forwarding import TransitComponent
from kernbench.components.impls.hbm_ctrl import HbmCtrlComponent
from kernbench.components.impls.io_cpu import IoCpuComponent
from kernbench.components.impls.m_cpu import MCpuComponent
from kernbench.components.impls.noc import TwoDMeshNocComponent
from kernbench.components.impls.pcie_ep import PcieEpComponent
from kernbench.components.impls.pe_cpu import PeCpuComponent
from kernbench.components.impls.pe_dma import PeDmaComponent
from kernbench.components.impls.pe_gemm import PeGemmComponent
from kernbench.components.impls.pe_math import PeMathComponent
from kernbench.components.impls.pe_scheduler import PeSchedulerComponent
from kernbench.components.impls.pe_mmu import PeMmuComponent
from kernbench.components.impls.pe_tcm import PeTcmComponent
from kernbench.components.impls.sram import SramComponent
from kernbench.components.impls.xbar import PositionAwareXbarComponent
ComponentRegistry.register("forwarding_v1", TransitComponent)
ComponentRegistry.register("switch_v1", TransitComponent)
ComponentRegistry.register("noc_v1", TransitComponent)
ComponentRegistry.register("noc_2d_mesh_v1", TwoDMeshNocComponent)
ComponentRegistry.register("ucie_v1", TransitComponent)
ComponentRegistry.register("xbar_v1", PositionAwareXbarComponent)
ComponentRegistry.register("pcie_ep_v1", PcieEpComponent)
ComponentRegistry.register("io_cpu_v1", IoCpuComponent)
ComponentRegistry.register("m_cpu_v1", MCpuComponent)
ComponentRegistry.register("hbm_ctrl_v1", HbmCtrlComponent)
ComponentRegistry.register("sram_v1", SramComponent)
ComponentRegistry.register("pe_cpu_v1", PeCpuComponent)
ComponentRegistry.register("pe_scheduler_v1", PeSchedulerComponent)
ComponentRegistry.register("pe_dma_v1", PeDmaComponent)
ComponentRegistry.register("pe_gemm_v1", PeGemmComponent)
ComponentRegistry.register("pe_math_v1", PeMathComponent)
ComponentRegistry.register("pe_mmu_v1", PeMmuComponent)
ComponentRegistry.register("pe_tcm_v1", PeTcmComponent)
__all__ = [
"HbmCtrlComponent",
"IoCpuComponent",
"MCpuComponent",
"PcieEpComponent",
"PeCpuComponent",
"PeDmaComponent",
"PeGemmComponent",
"PeMathComponent",
"PeMmuComponent",
"PeSchedulerComponent",
"PeTcmComponent",
"TransitComponent",
"TwoDMeshNocComponent",
"PositionAwareXbarComponent",
"SramComponent",
]