Add SIP-level tensor parallelism, component registry YAML, VA offset verification

- DPPolicy: 3-level (sip/cube/pe), unified naming (column_wise/row_wise)
- PE_CPU: auto num_programs from cube shard count
- context.launch(): per-SIP KernelLaunchMsg with local va_base + auto local shape
- deploy_tensor: removed mmus param, MMU mapping is context-only responsibility
- ComponentRegistry: YAML-based lazy loading (components.yaml), impls→builtin rename
- VA offset bench + tests: 2D/1D, standard Triton kernel pattern

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-26 01:13:17 -07:00
parent 08812eda58
commit 63669f82cb
35 changed files with 813 additions and 219 deletions
+54 -4
View File
@@ -140,16 +140,65 @@ class ComponentRegistry:
Resolution order for ComponentRegistry.create(node, overrides, ctx):
1. overrides[node.impl] — caller-injected override
2. _registry[node.impl] — globally registered impl
2. _registry[node.impl] — globally registered impl (lazy import)
3. Error — no fallback; every node must have an impl
Registry is populated from components.yaml via load_components_yaml().
Manual register() is still supported for tests and overrides.
"""
_registry: dict[str, type[ComponentBase]] = {}
_lazy: dict[str, str] = {} # impl → "module.path:ClassName"
_loaded: bool = False
@classmethod
def register(cls, impl: str, component_cls: type[ComponentBase]) -> None:
cls._registry[impl] = component_cls
@classmethod
def load_components_yaml(cls, path: str | None = None) -> None:
"""Load impl→class mappings from components.yaml. Lazy imports on first use."""
if cls._loaded:
return
import yaml
from pathlib import Path
if path is None:
# Search: project root (cwd), then relative to this file
candidates = [
Path.cwd() / "components.yaml",
Path(__file__).parent.parent.parent.parent / "components.yaml",
]
for p in candidates:
if p.exists():
path = str(p)
break
if path is None:
return
with open(path) as f:
spec = yaml.safe_load(f)
for impl, class_path in (spec.get("components") or {}).items():
cls._lazy[impl] = class_path
cls._loaded = True
@classmethod
def _resolve(cls, impl: str) -> type[ComponentBase] | None:
"""Resolve impl name: check _registry first, then lazy import from _lazy."""
if impl in cls._registry:
return cls._registry[impl]
if not cls._loaded:
cls.load_components_yaml()
class_path = cls._lazy.get(impl)
if class_path is None:
return None
import importlib
module_path, class_name = class_path.rsplit(":", 1)
mod = importlib.import_module(module_path)
component_cls = getattr(mod, class_name)
cls._registry[impl] = component_cls # cache for next lookup
return component_cls
@classmethod
def create(
cls,
@@ -159,9 +208,10 @@ class ComponentRegistry:
) -> ComponentBase:
if overrides and node.impl in overrides:
return overrides[node.impl](node, ctx)
if node.impl in cls._registry:
return cls._registry[node.impl](node, ctx)
component_cls = cls._resolve(node.impl)
if component_cls is not None:
return component_cls(node, ctx)
raise ValueError(
f"No component registered for impl '{node.impl}' (node: {node.id}). "
f"Register it in kernbench.components.impls.__init__."
f"Add it to components.yaml or call ComponentRegistry.register()."
)
@@ -0,0 +1,34 @@
"""Concrete component implementations.
Loaded from components.yaml via ComponentRegistry.load_components_yaml().
Manual imports are no longer needed — add new impls to components.yaml.
Classes are still importable from this package via lazy __getattr__.
"""
from kernbench.components.base import ComponentRegistry
ComponentRegistry.load_components_yaml()
# Lazy re-export: allow `from kernbench.components.builtin import FooComponent`
# without eagerly importing every module.
_CLASS_MAP: dict[str, str] = {} # ClassName → "module.path:ClassName"
def _build_class_map() -> None:
if _CLASS_MAP:
return
for class_path in ComponentRegistry._lazy.values():
module_path, class_name = class_path.rsplit(":", 1)
_CLASS_MAP[class_name] = class_path
def __getattr__(name: str):
_build_class_map()
class_path = _CLASS_MAP.get(name)
if class_path is None:
raise ImportError(f"cannot import name '{name}' from 'kernbench.components.builtin'")
import importlib
module_path, class_name = class_path.rsplit(":", 1)
mod = importlib.import_module(module_path)
return getattr(mod, class_name)
@@ -81,10 +81,22 @@ class PeCpuComponent(ComponentBase):
yield from self.run(env, 0)
kernel_fn = get_kernel(request.kernel_ref.name)
tl = TLContext(pe_id=self._pe_idx, dispatch_cycles=0)
# Derive num_programs from the number of PE shards in this cube
num_programs = 1
for arg in request.args:
if arg.arg_kind == "tensor":
cube_pe_count = sum(
1 for s in arg.shards
if s.sip == self._sip_idx and s.cube == self._cube_idx
)
if cube_pe_count > num_programs:
num_programs = cube_pe_count
tl = TLContext(pe_id=self._pe_idx, num_programs=num_programs, dispatch_cycles=0)
# Unpack KernelLaunchMsg.args into positional args for kernel function
# TensorArg → VA base (or PA fallback), ScalarArg → value
# TensorArg → va_base (already local, set by runtime) or PA fallback
kernel_args: list = []
for arg in request.args:
if arg.arg_kind == "tensor":
@@ -0,0 +1,5 @@
"""Custom component implementations.
Place your component files here and register them in components.yaml.
See components.yaml header for instructions.
"""
@@ -1,59 +0,0 @@
"""Concrete component implementations.
Each module registers its component(s) with ComponentRegistry on import.
Import this package to activate all built-in implementations.
"""
from kernbench.components.base import ComponentRegistry
from kernbench.components.impls.forwarding import TransitComponent
from kernbench.components.impls.hbm_ctrl import HbmCtrlComponent
from kernbench.components.impls.io_cpu import IoCpuComponent
from kernbench.components.impls.m_cpu import MCpuComponent
from kernbench.components.impls.noc import TwoDMeshNocComponent
from kernbench.components.impls.pcie_ep import PcieEpComponent
from kernbench.components.impls.pe_cpu import PeCpuComponent
from kernbench.components.impls.pe_dma import PeDmaComponent
from kernbench.components.impls.pe_gemm import PeGemmComponent
from kernbench.components.impls.pe_math import PeMathComponent
from kernbench.components.impls.pe_scheduler import PeSchedulerComponent
from kernbench.components.impls.pe_mmu import PeMmuComponent
from kernbench.components.impls.pe_tcm import PeTcmComponent
from kernbench.components.impls.sram import SramComponent
from kernbench.components.impls.xbar import PositionAwareXbarComponent
ComponentRegistry.register("forwarding_v1", TransitComponent)
ComponentRegistry.register("switch_v1", TransitComponent)
ComponentRegistry.register("noc_v1", TransitComponent)
ComponentRegistry.register("noc_2d_mesh_v1", TwoDMeshNocComponent)
ComponentRegistry.register("ucie_v1", TransitComponent)
ComponentRegistry.register("xbar_v1", PositionAwareXbarComponent)
ComponentRegistry.register("pcie_ep_v1", PcieEpComponent)
ComponentRegistry.register("io_cpu_v1", IoCpuComponent)
ComponentRegistry.register("m_cpu_v1", MCpuComponent)
ComponentRegistry.register("hbm_ctrl_v1", HbmCtrlComponent)
ComponentRegistry.register("sram_v1", SramComponent)
ComponentRegistry.register("pe_cpu_v1", PeCpuComponent)
ComponentRegistry.register("pe_scheduler_v1", PeSchedulerComponent)
ComponentRegistry.register("pe_dma_v1", PeDmaComponent)
ComponentRegistry.register("pe_gemm_v1", PeGemmComponent)
ComponentRegistry.register("pe_math_v1", PeMathComponent)
ComponentRegistry.register("pe_mmu_v1", PeMmuComponent)
ComponentRegistry.register("pe_tcm_v1", PeTcmComponent)
__all__ = [
"HbmCtrlComponent",
"IoCpuComponent",
"MCpuComponent",
"PcieEpComponent",
"PeCpuComponent",
"PeDmaComponent",
"PeGemmComponent",
"PeMathComponent",
"PeMmuComponent",
"PeSchedulerComponent",
"PeTcmComponent",
"TransitComponent",
"TwoDMeshNocComponent",
"PositionAwareXbarComponent",
"SramComponent",
]
+53 -35
View File
@@ -7,12 +7,36 @@ from typing import Literal
@dataclass(frozen=True)
class DPPolicy:
"""Two-level data-parallel policy: cube-level + pe-level."""
"""Three-level data-parallel policy: sip-level + cube-level + pe-level.
cube: Literal["replicate", "shard_m", "shard_k"] = "replicate"
Policies:
- "replicate": full copy at each unit
- "column_wise": split K (column) axis across units
- "row_wise": split M (row) axis across units
"""
sip: Literal["replicate", "column_wise", "row_wise"] = "replicate"
cube: Literal["replicate", "column_wise", "row_wise"] = "replicate"
pe: Literal["replicate", "column_wise", "row_wise"] = "replicate"
def _split_shape(
policy: str, shape: tuple[int, int], count: int, itemsize: int,
) -> list[tuple[tuple[int, int], int]]:
"""Split shape by policy into (sub_shape, byte_offset) pairs."""
M, K = shape
if policy == "replicate":
return [((M, K), 0)] * count
elif policy == "column_wise":
chunk_k = K // count
return [((M, chunk_k), i * M * chunk_k * itemsize) for i in range(count)]
elif policy == "row_wise":
chunk_m = M // count
return [((chunk_m, K), i * chunk_m * K * itemsize) for i in range(count)]
else:
raise ValueError(f"Unknown policy: {policy}")
def resolve_dp_policy(
policy: DPPolicy,
*,
@@ -20,11 +44,14 @@ def resolve_dp_policy(
itemsize: int,
num_pe: int,
num_cubes: int = 1,
num_sips: int = 1,
) -> list[ShardSpec]:
"""Resolve a DPPolicy into a list[ShardSpec] with two-level resolution.
"""Resolve a DPPolicy into a list[ShardSpec] with three-level resolution.
Cube-level policy distributes across cubes, pe-level distributes within
each cube. ShardSpec.pe_index uses flat indexing: cube_id * num_pe + pe_id.
SIP-level → cube-level → pe-level.
num_cubes is cubes per SIP (not total).
ShardSpec.pe_index uses flat indexing:
sip_id * num_cubes * num_pe + cube_id * num_pe + pe_id
"""
_PE_RESOLVERS = {
"replicate": replicate,
@@ -35,40 +62,31 @@ def resolve_dp_policy(
if resolver is None:
raise ValueError(f"Unknown pe-level policy: {policy.pe}")
if num_cubes <= 1:
return resolver(shape=shape, itemsize=itemsize, num_pe=num_pe)
# Two-level resolution: cube-level → pe-level
M, K = shape
cubes_per_sip = num_cubes
all_shards: list[ShardSpec] = []
for cube_id in range(num_cubes):
# Determine per-cube shape based on cube-level policy
if policy.cube == "replicate":
cube_shape = (M, K)
cube_offset = 0
elif policy.cube == "shard_m":
chunk_m = M // num_cubes
cube_shape = (chunk_m, K)
cube_offset = cube_id * chunk_m * K * itemsize
elif policy.cube == "shard_k":
chunk_k = K // num_cubes
cube_shape = (M, chunk_k)
cube_offset = cube_id * M * chunk_k * itemsize
else:
raise ValueError(f"Unknown cube-level policy: {policy.cube}")
# Level 1: SIP
sip_splits = _split_shape(policy.sip, shape, num_sips, itemsize)
# Resolve pe-level within this cube's shape
pe_shards = resolver(shape=cube_shape, itemsize=itemsize, num_pe=num_pe)
for sip_id, (sip_shape, sip_offset) in enumerate(sip_splits):
# Level 2: Cube within SIP
cube_splits = _split_shape(policy.cube, sip_shape, cubes_per_sip, itemsize)
# Remap pe_index to flat index and adjust offset
for ps in pe_shards:
flat_idx = cube_id * num_pe + ps.pe_index
all_shards.append(ShardSpec(
pe_index=flat_idx,
offset_bytes=cube_offset + ps.offset_bytes,
nbytes=ps.nbytes,
))
for cube_id, (cube_shape, cube_offset) in enumerate(cube_splits):
# Level 3: PE within cube
pe_shards = resolver(shape=cube_shape, itemsize=itemsize, num_pe=num_pe)
for ps in pe_shards:
flat_idx = (
sip_id * cubes_per_sip * num_pe
+ cube_id * num_pe
+ ps.pe_index
)
all_shards.append(ShardSpec(
pe_index=flat_idx,
offset_bytes=sip_offset + cube_offset + ps.offset_bytes,
nbytes=ps.nbytes,
))
return all_shards
+161 -69
View File
@@ -20,7 +20,6 @@ class RuntimeContext:
_completed: set[RequestHandle] = field(default_factory=set, init=False)
_allocators: dict[int, Any] = field(default_factory=dict, init=False)
_va_allocator: Any = field(default=None, init=False)
_mmus: dict[int, Any] = field(default_factory=dict, init=False)
_tensor_counter: int = field(default=0, init=False)
_traces: list[dict] = field(default_factory=list, init=False)
_tensors: list[Any] = field(default_factory=list, init=False)
@@ -208,24 +207,17 @@ class RuntimeContext:
rack_id=0, sip_id=sip_id, cube_id=cube_id, pe_id=pe_id, cfg=cfg,
)
# Initialize VA allocator and per-PE MMUs
from kernbench.policy.address.pe_mmu import PeMMU
# Initialize VA allocator (MMU mappings are installed via fabric MmuMapMsg)
from kernbench.policy.address.va_allocator import VirtualAllocator
pe_mmu_attrs = pe_comps.get("pe_mmu", {}).get("attrs", {})
page_size = int(pe_mmu_attrs.get("page_size", 4096))
tlb_overhead_ns = float(pe_mmu_attrs.get("tlb_overhead_ns", 0.0))
self._va_allocator = VirtualAllocator(
va_base=0x1_0000_0000,
va_size=64 * (1 << 30), # 64 GB VA space
page_size=page_size,
)
total_pes = sip_count * cubes_per_sip * pes_per_cube
for flat_idx in range(total_pes):
self._mmus[flat_idx] = PeMMU(
page_size=page_size, overhead_ns=tlb_overhead_ns,
)
return self._allocators
@@ -276,11 +268,11 @@ class RuntimeContext:
dp_policy = dp
allocators = self._ensure_allocators()
itemsize = dtype_itemsize(dtype)
shape_2d = (shape[0], shape[1]) # type: tuple[int, int]
total_cubes = self._num_sips * self._num_cubes
shape_2d = (shape[0], shape[1]) if len(shape) >= 2 else (1, shape[0])
placement = resolve_dp_policy(
dp, shape=shape_2d, itemsize=itemsize,
num_pe=self._pes_per_cube, num_cubes=total_cubes,
num_pe=self._pes_per_cube, num_cubes=self._num_cubes,
num_sips=self._num_sips,
)
# Infer target_pe from placement: multi-PE → "all", single PE → pe_index
@@ -297,7 +289,6 @@ class RuntimeContext:
placement=placement,
allocators=allocators,
va_allocator=self._va_allocator,
mmus=self._mmus,
)
t._handle = handle
import weakref
@@ -305,6 +296,8 @@ class RuntimeContext:
self._tensors.append(weakref.ref(t))
# Install VA→PA mappings via fabric MmuMapMsg
# Strategy: always SIP-scoped (each SIP gets only its own shards).
# Within each SIP: cube="replicate" → per-cube, else broadcast within SIP.
if handle.va_base:
from collections import defaultdict
from kernbench.runtime_api.kernel import MmuMapMsg
@@ -313,47 +306,52 @@ class RuntimeContext:
dp_policy is not None and dp_policy.cube == "replicate"
)
if is_cube_replicate:
# Replicate: each (sip, cube) gets only its own local PA mappings
cube_groups: dict[tuple[int, int], list] = defaultdict(list)
for shard in handle.shards:
cube_groups[(shard.sip, shard.cube)].append(shard)
# Group shards by SIP
sip_groups: dict[int, list] = defaultdict(list)
for shard in handle.shards:
sip_groups[shard.sip].append(shard)
for (sip, cube), group_shards in cube_groups.items():
for sip, sip_shards in sip_groups.items():
if is_cube_replicate:
# Cube replicate: per-(sip, cube) local mapping
cube_groups: dict[int, list] = defaultdict(list)
for s in sip_shards:
cube_groups[s.cube].append(s)
for cube, group_shards in cube_groups.items():
entries = tuple(
{"va": handle.va_base + s.offset_bytes,
"pa": s.pa, "size": s.nbytes}
for s in group_shards
)
msg = MmuMapMsg(
correlation_id=self.correlation_id,
request_id=f"mmu_{tensor_name}_s{sip}c{cube}",
entries=entries,
target_sips=(sip,),
target_cubes=(cube,),
target_pe="all",
)
h = self.submit(msg)
self.wait(h)
else:
# Cube sharded: broadcast all cubes within this SIP
entries = tuple(
{"va": handle.va_base + s.offset_bytes,
"pa": s.pa, "size": s.nbytes}
for s in group_shards
for s in sip_shards
)
cube_set = sorted({s.cube for s in sip_shards})
msg = MmuMapMsg(
correlation_id=self.correlation_id,
request_id=f"mmu_{tensor_name}_s{sip}c{cube}",
request_id=f"mmu_{tensor_name}_s{sip}",
entries=entries,
target_sips=(sip,),
target_cubes=(cube,),
target_cubes=tuple(cube_set),
target_pe="all",
)
h = self.submit(msg)
self.wait(h)
else:
# Sharded: broadcast all mappings to all target (sip, cube)s
entries = tuple(
{"va": handle.va_base + s.offset_bytes,
"pa": s.pa, "size": s.nbytes}
for s in handle.shards
)
sip_set = sorted({s.sip for s in handle.shards})
cube_set = sorted({s.cube for s in handle.shards})
msg = MmuMapMsg(
correlation_id=self.correlation_id,
request_id=f"mmu_{tensor_name}",
entries=entries,
target_sips=tuple(sip_set),
target_cubes=tuple(cube_set),
target_pe="all",
)
h = self.submit(msg)
self.wait(h)
# Submit MemoryWriteMsg per shard (deploy data to device)
if pattern is not None:
@@ -384,11 +382,18 @@ class RuntimeContext:
Positional args: Tensor objects become TensorArg, int/float become ScalarArg.
Keyword args: become ScalarArg (name is discarded, order preserved).
Creates per-SIP KernelLaunchMsg with local va_base per tensor
(like host driver sending per-rank launch commands).
"""
from collections import defaultdict
from kernbench.runtime_api.kernel import (
KernelLaunchMsg,
KernelRef,
ScalarArg,
TensorArg,
TensorArgShard,
)
from kernbench.runtime_api.tensor import Tensor
from kernbench.triton_emu.registry import register_kernel
@@ -399,14 +404,14 @@ class RuntimeContext:
except ValueError:
pass
# Build kernel args from positional + keyword args
kernel_args: list = []
# Collect tensors and scalars
tensor_args: list[Tensor] = []
scalar_args: list = []
target_pe: int | str = 0
for a in args:
if isinstance(a, Tensor):
kernel_args.append(a.to_tensor_arg())
# Infer target_pe from tensor DP metadata
tensor_args.append(a)
if a._dp_metadata is not None:
dp_target = a._dp_metadata.target_pe
if dp_target == "all":
@@ -415,34 +420,121 @@ class RuntimeContext:
target_pe = dp_target
elif isinstance(a, (int, float)):
dtype_str = "f32" if isinstance(a, float) else "i32"
kernel_args.append(ScalarArg(dtype=dtype_str, value=a))
scalar_args.append(ScalarArg(dtype=dtype_str, value=a))
for v in kwargs.values():
if isinstance(v, (int, float)):
dtype_str = "f32" if isinstance(v, float) else "i32"
kernel_args.append(ScalarArg(dtype=dtype_str, value=v))
scalar_args.append(ScalarArg(dtype=dtype_str, value=v))
# Determine target cubes from all tensor shards
cube_set: set[int] = set()
for a in args:
if isinstance(a, Tensor) and a._handle is not None:
for s in a._handle.shards:
cube_set.add(s.cube)
target_cubes = tuple(sorted(cube_set)) if cube_set else (0,)
# Determine all target SIPs from tensor shards
sip_set: set[int] = set()
for t in tensor_args:
if t._handle is not None:
for s in t._handle.shards:
sip_set.add(s.sip)
if not sip_set:
sip_set = {0}
# Collect scalar values for GEMM FLOP calculation
scalar_vals = [a.value for a in kernel_args if hasattr(a, "value")]
# Build global→local dimension mapping from tensor DPPolicies.
# Scalar args matching a tensor's global dimension get replaced
# with the cube-local value (what the kernel actually operates on).
def _compute_local_shape(t: Tensor) -> tuple[int, ...]:
"""Compute cube-local shape from DPPolicy."""
shape = t.shape
if len(shape) < 2:
shape = (1, shape[0])
M, K = shape[0], shape[1]
dp = t._dp_metadata.dp_policy if t._dp_metadata else None
if dp is None:
return t.shape
if dp.sip != "replicate":
if dp.sip == "column_wise":
K = K // self._num_sips
elif dp.sip == "row_wise":
M = M // self._num_sips
if dp.cube != "replicate":
if dp.cube == "column_wise":
K = K // self._num_cubes
elif dp.cube == "row_wise":
M = M // self._num_cubes
if len(t.shape) < 2:
return (K,)
return (M, K)
h = self.submit(KernelLaunchMsg(
correlation_id=self.correlation_id,
request_id=kernel_name,
kernel_ref=KernelRef(name=kernel_name, kind="builtin"),
args=tuple(kernel_args),
target_cubes=target_cubes,
target_pe=target_pe,
))
self.wait(h, _meta={
"phase": "kernel", "name": kernel_name,
"target_pe": target_pe, "scalars": scalar_vals,
})
return h
dim_map: dict[int, int] = {} # global_dim → local_dim
for t in tensor_args:
local = _compute_local_shape(t)
for g, l in zip(t.shape if len(t.shape) >= 2 else (1, t.shape[0]), local if len(local) >= 2 else (1, local[0])):
if g != l:
dim_map[g] = l
# Per-SIP kernel launch: each SIP gets TensorArgs with local va_base
last_handle = None
for sip_id in sorted(sip_set):
sip_kernel_args: list = []
sip_cube_set: set[int] = set()
for t in tensor_args:
if t._handle is None:
continue
sip_shards = [s for s in t._handle.shards if s.sip == sip_id]
if not sip_shards:
sip_shards = list(t._handle.shards)
local_va_base = 0
if t._handle.va_base:
min_offset = min(s.offset_bytes for s in sip_shards)
local_va_base = t._handle.va_base + min_offset
sip_kernel_args.append(TensorArg(
shards=tuple(
TensorArgShard(
sip=s.sip, cube=s.cube, pe=s.pe,
pa=s.pa, nbytes=s.nbytes, offset_bytes=s.offset_bytes,
)
for s in sip_shards
),
va_base=local_va_base,
))
for s in sip_shards:
sip_cube_set.add(s.cube)
# Interleave tensor args and scalar args, replacing global dims with local
final_args: list = []
t_idx, s_idx = 0, 0
for a in args:
if isinstance(a, Tensor):
final_args.append(sip_kernel_args[t_idx])
t_idx += 1
elif isinstance(a, (int, float)):
sa = scalar_args[s_idx]
if isinstance(a, int) and a in dim_map:
sa = ScalarArg(dtype=sa.dtype, value=dim_map[a])
final_args.append(sa)
s_idx += 1
while s_idx < len(scalar_args):
sa = scalar_args[s_idx]
if isinstance(sa.value, int) and int(sa.value) in dim_map:
sa = ScalarArg(dtype=sa.dtype, value=dim_map[int(sa.value)])
final_args.append(sa)
s_idx += 1
target_cubes = tuple(sorted(sip_cube_set)) if sip_cube_set else (0,)
h = self.submit(KernelLaunchMsg(
correlation_id=self.correlation_id,
request_id=f"{kernel_name}_sip{sip_id}",
kernel_ref=KernelRef(name=kernel_name, kind="builtin"),
args=tuple(final_args),
target_cubes=target_cubes,
target_pe=target_pe,
))
self.wait(h, _meta={
"phase": "kernel", "name": kernel_name,
"sip": sip_id, "target_pe": target_pe,
})
last_handle = h
return last_handle
+1 -11
View File
@@ -59,10 +59,7 @@ def deploy_tensor(
allocators: dict[int, PEMemAllocator],
mem_kind: Literal["hbm", "tcm"] = "hbm",
va_allocator=None,
mmus: dict | None = None,
) -> TensorHandle:
from kernbench.policy.address.pe_mmu import PeMMU
isize = dtype_itemsize(dtype)
total_nbytes = math.prod(shape) * isize
@@ -78,22 +75,15 @@ def deploy_tensor(
pa = alloc.alloc_hbm(spec.nbytes)
else:
pa = alloc.alloc_tcm(spec.nbytes)
encoded_pa = pa.encode()
shards.append(TensorShard(
sip=alloc._sip_id,
cube=alloc._cube_id,
pe=alloc._pe_id,
pa=encoded_pa,
pa=pa.encode(),
nbytes=spec.nbytes,
offset_bytes=spec.offset_bytes,
))
# Register VA→PA mapping in all MMUs (broadcast)
if va_base and mmus is not None:
shard_va = va_base + spec.offset_bytes
for mmu in mmus.values():
mmu.map(va=shard_va, pa=encoded_pa, size=spec.nbytes)
return TensorHandle(
name=name,
shape=shape,
+1 -1
View File
@@ -5,7 +5,7 @@ from typing import Any
import simpy
from kernbench.common.types import Completion, RequestHandle, Trace
import kernbench.components.impls # noqa: F401 — registers built-in implementations
import kernbench.components.builtin # noqa: F401 — registers built-in implementations
from kernbench.components.base import ComponentBase, ComponentRegistry
from kernbench.components.context import ComponentContext
from kernbench.policy.address.phyaddr import PhysAddr