Add SIP-level tensor parallelism, component registry YAML, VA offset verification
- DPPolicy: 3-level (sip/cube/pe), unified naming (column_wise/row_wise) - PE_CPU: auto num_programs from cube shard count - context.launch(): per-SIP KernelLaunchMsg with local va_base + auto local shape - deploy_tensor: removed mmus param, MMU mapping is context-only responsibility - ComponentRegistry: YAML-based lazy loading (components.yaml), impls→builtin rename - VA offset bench + tests: 2D/1D, standard Triton kernel pattern Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -140,16 +140,65 @@ class ComponentRegistry:
|
||||
|
||||
Resolution order for ComponentRegistry.create(node, overrides, ctx):
|
||||
1. overrides[node.impl] — caller-injected override
|
||||
2. _registry[node.impl] — globally registered impl
|
||||
2. _registry[node.impl] — globally registered impl (lazy import)
|
||||
3. Error — no fallback; every node must have an impl
|
||||
|
||||
Registry is populated from components.yaml via load_components_yaml().
|
||||
Manual register() is still supported for tests and overrides.
|
||||
"""
|
||||
|
||||
_registry: dict[str, type[ComponentBase]] = {}
|
||||
_lazy: dict[str, str] = {} # impl → "module.path:ClassName"
|
||||
_loaded: bool = False
|
||||
|
||||
@classmethod
|
||||
def register(cls, impl: str, component_cls: type[ComponentBase]) -> None:
    """Bind ``impl`` to ``component_cls`` in the eager registry.

    An existing binding for the same ``impl`` name is replaced.
    """
    registry = cls._registry
    registry[impl] = component_cls
|
||||
|
||||
@classmethod
|
||||
def load_components_yaml(cls, path: str | None = None) -> None:
|
||||
"""Load impl→class mappings from components.yaml. Lazy imports on first use."""
|
||||
if cls._loaded:
|
||||
return
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
|
||||
if path is None:
|
||||
# Search: project root (cwd), then relative to this file
|
||||
candidates = [
|
||||
Path.cwd() / "components.yaml",
|
||||
Path(__file__).parent.parent.parent.parent / "components.yaml",
|
||||
]
|
||||
for p in candidates:
|
||||
if p.exists():
|
||||
path = str(p)
|
||||
break
|
||||
if path is None:
|
||||
return
|
||||
|
||||
with open(path) as f:
|
||||
spec = yaml.safe_load(f)
|
||||
for impl, class_path in (spec.get("components") or {}).items():
|
||||
cls._lazy[impl] = class_path
|
||||
cls._loaded = True
|
||||
|
||||
@classmethod
|
||||
def _resolve(cls, impl: str) -> type[ComponentBase] | None:
    """Return the component class bound to ``impl``, importing it on first use.

    Lookup order: the eager ``_registry`` first, then the ``_lazy`` table of
    "module.path:ClassName" strings (components.yaml is loaded first if that
    has not happened yet). A lazily imported class is cached in ``_registry``
    so later lookups skip the import machinery.

    Returns:
        The component class, or None when ``impl`` is not known.

    Raises:
        ValueError: if the lazy entry has no ":" separating module and class.
    """
    if impl in cls._registry:
        return cls._registry[impl]
    if not cls._loaded:
        cls.load_components_yaml()
    class_path = cls._lazy.get(impl)
    if class_path is None:
        return None

    import importlib

    # Split on the last ":" only; report a malformed entry with enough
    # context to find it, rather than a bare tuple-unpacking error.
    module_path, sep, class_name = class_path.rpartition(":")
    if not sep:
        raise ValueError(
            f"Invalid class path {class_path!r} for impl '{impl}': "
            f"expected 'module.path:ClassName'."
        )
    mod = importlib.import_module(module_path)
    component_cls = getattr(mod, class_name)
    cls._registry[impl] = component_cls  # cache for next lookup
    return component_cls
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
@@ -159,9 +208,10 @@ class ComponentRegistry:
|
||||
) -> ComponentBase:
|
||||
if overrides and node.impl in overrides:
|
||||
return overrides[node.impl](node, ctx)
|
||||
if node.impl in cls._registry:
|
||||
return cls._registry[node.impl](node, ctx)
|
||||
component_cls = cls._resolve(node.impl)
|
||||
if component_cls is not None:
|
||||
return component_cls(node, ctx)
|
||||
raise ValueError(
|
||||
f"No component registered for impl '{node.impl}' (node: {node.id}). "
|
||||
f"Register it in kernbench.components.impls.__init__."
|
||||
f"Add it to components.yaml or call ComponentRegistry.register()."
|
||||
)
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
"""Concrete component implementations.
|
||||
|
||||
Loaded from components.yaml via ComponentRegistry.load_components_yaml().
|
||||
Manual imports are no longer needed — add new impls to components.yaml.
|
||||
|
||||
Classes are still importable from this package via lazy __getattr__.
|
||||
"""
|
||||
|
||||
from kernbench.components.base import ComponentRegistry
|
||||
|
||||
ComponentRegistry.load_components_yaml()
|
||||
|
||||
# Lazy re-export: allow `from kernbench.components.builtin import FooComponent`
|
||||
# without eagerly importing every module.
|
||||
_CLASS_MAP: dict[str, str] = {} # ClassName → "module.path:ClassName"
|
||||
|
||||
|
||||
def _build_class_map() -> None:
    """Fill ``_CLASS_MAP`` (ClassName -> class path) from the registry's lazy table.

    Does nothing when the map already has entries.
    """
    if not _CLASS_MAP:
        for dotted in ComponentRegistry._lazy.values():
            # Only the class-name part after the last ":" is used as the key.
            _, cls_name = dotted.rsplit(":", 1)
            _CLASS_MAP[cls_name] = dotted
|
||||
|
||||
|
||||
def __getattr__(name: str):
    """Lazily resolve a component class by name on module attribute access.

    Invoked only for names not found as regular module attributes. If
    ``name`` matches a class name listed in the component registry's lazy
    table, its module is imported and the class is returned.

    Raises:
        AttributeError: if ``name`` is not a known component class. This must
            be AttributeError rather than ImportError so that ``hasattr()``
            and ``getattr(module, name, default)`` work, and with them
            ``from package import submodule``. A failed
            ``from ... import Name`` is still reported to the caller as
            ImportError by the import system.
    """
    _build_class_map()
    class_path = _CLASS_MAP.get(name)
    if class_path is None:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    import importlib

    module_path, class_name = class_path.rsplit(":", 1)
    mod = importlib.import_module(module_path)
    return getattr(mod, class_name)
|
||||
+14
-2
@@ -81,10 +81,22 @@ class PeCpuComponent(ComponentBase):
|
||||
yield from self.run(env, 0)
|
||||
|
||||
kernel_fn = get_kernel(request.kernel_ref.name)
|
||||
tl = TLContext(pe_id=self._pe_idx, dispatch_cycles=0)
|
||||
|
||||
# Derive num_programs from the number of PE shards in this cube
|
||||
num_programs = 1
|
||||
for arg in request.args:
|
||||
if arg.arg_kind == "tensor":
|
||||
cube_pe_count = sum(
|
||||
1 for s in arg.shards
|
||||
if s.sip == self._sip_idx and s.cube == self._cube_idx
|
||||
)
|
||||
if cube_pe_count > num_programs:
|
||||
num_programs = cube_pe_count
|
||||
|
||||
tl = TLContext(pe_id=self._pe_idx, num_programs=num_programs, dispatch_cycles=0)
|
||||
|
||||
# Unpack KernelLaunchMsg.args into positional args for kernel function
|
||||
# TensorArg → VA base (or PA fallback), ScalarArg → value
|
||||
# TensorArg → va_base (already local, set by runtime) or PA fallback
|
||||
kernel_args: list = []
|
||||
for arg in request.args:
|
||||
if arg.arg_kind == "tensor":
|
||||
@@ -0,0 +1,5 @@
|
||||
"""Custom component implementations.
|
||||
|
||||
Place your component files here and register them in components.yaml.
|
||||
See components.yaml header for instructions.
|
||||
"""
|
||||
@@ -1,59 +0,0 @@
|
||||
"""Concrete component implementations.
|
||||
|
||||
Each module registers its component(s) with ComponentRegistry on import.
|
||||
Import this package to activate all built-in implementations.
|
||||
"""
|
||||
|
||||
from kernbench.components.base import ComponentRegistry
|
||||
from kernbench.components.impls.forwarding import TransitComponent
|
||||
from kernbench.components.impls.hbm_ctrl import HbmCtrlComponent
|
||||
from kernbench.components.impls.io_cpu import IoCpuComponent
|
||||
from kernbench.components.impls.m_cpu import MCpuComponent
|
||||
from kernbench.components.impls.noc import TwoDMeshNocComponent
|
||||
from kernbench.components.impls.pcie_ep import PcieEpComponent
|
||||
from kernbench.components.impls.pe_cpu import PeCpuComponent
|
||||
from kernbench.components.impls.pe_dma import PeDmaComponent
|
||||
from kernbench.components.impls.pe_gemm import PeGemmComponent
|
||||
from kernbench.components.impls.pe_math import PeMathComponent
|
||||
from kernbench.components.impls.pe_scheduler import PeSchedulerComponent
|
||||
from kernbench.components.impls.pe_mmu import PeMmuComponent
|
||||
from kernbench.components.impls.pe_tcm import PeTcmComponent
|
||||
from kernbench.components.impls.sram import SramComponent
|
||||
from kernbench.components.impls.xbar import PositionAwareXbarComponent
|
||||
|
||||
ComponentRegistry.register("forwarding_v1", TransitComponent)
|
||||
ComponentRegistry.register("switch_v1", TransitComponent)
|
||||
ComponentRegistry.register("noc_v1", TransitComponent)
|
||||
ComponentRegistry.register("noc_2d_mesh_v1", TwoDMeshNocComponent)
|
||||
ComponentRegistry.register("ucie_v1", TransitComponent)
|
||||
ComponentRegistry.register("xbar_v1", PositionAwareXbarComponent)
|
||||
ComponentRegistry.register("pcie_ep_v1", PcieEpComponent)
|
||||
ComponentRegistry.register("io_cpu_v1", IoCpuComponent)
|
||||
ComponentRegistry.register("m_cpu_v1", MCpuComponent)
|
||||
ComponentRegistry.register("hbm_ctrl_v1", HbmCtrlComponent)
|
||||
ComponentRegistry.register("sram_v1", SramComponent)
|
||||
ComponentRegistry.register("pe_cpu_v1", PeCpuComponent)
|
||||
ComponentRegistry.register("pe_scheduler_v1", PeSchedulerComponent)
|
||||
ComponentRegistry.register("pe_dma_v1", PeDmaComponent)
|
||||
ComponentRegistry.register("pe_gemm_v1", PeGemmComponent)
|
||||
ComponentRegistry.register("pe_math_v1", PeMathComponent)
|
||||
ComponentRegistry.register("pe_mmu_v1", PeMmuComponent)
|
||||
ComponentRegistry.register("pe_tcm_v1", PeTcmComponent)
|
||||
|
||||
__all__ = [
|
||||
"HbmCtrlComponent",
|
||||
"IoCpuComponent",
|
||||
"MCpuComponent",
|
||||
"PcieEpComponent",
|
||||
"PeCpuComponent",
|
||||
"PeDmaComponent",
|
||||
"PeGemmComponent",
|
||||
"PeMathComponent",
|
||||
"PeMmuComponent",
|
||||
"PeSchedulerComponent",
|
||||
"PeTcmComponent",
|
||||
"TransitComponent",
|
||||
"TwoDMeshNocComponent",
|
||||
"PositionAwareXbarComponent",
|
||||
"SramComponent",
|
||||
]
|
||||
Reference in New Issue
Block a user