Add SIP-level tensor parallelism, component registry YAML, VA offset verification

- DPPolicy: 3-level (sip/cube/pe), unified naming (column_wise/row_wise)
- PE_CPU: auto num_programs from cube shard count
- context.launch(): per-SIP KernelLaunchMsg with local va_base + auto local shape
- deploy_tensor: removed mmus param, MMU mapping is context-only responsibility
- ComponentRegistry: YAML-based lazy loading (components.yaml), impls→builtin rename
- VA offset bench + tests: 2D/1D, standard Triton kernel pattern

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-26 01:13:17 -07:00
parent 08812eda58
commit 63669f82cb
35 changed files with 813 additions and 219 deletions
+186
View File
@@ -0,0 +1,186 @@
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
import simpy
from kernbench.components.base import ComponentBase
from kernbench.sim_engine.transaction import Transaction
if TYPE_CHECKING:
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class PeCpuComponent(ComponentBase):
"""PE_CPU: kernel execution controller (Stage 2).
Two-phase kernel execution (ADR-0014 D1):
Phase 1 (compile): look up kernel from registry, run it with TLContext
to generate a PeCommand list.
Phase 2 (replay): iterate commands, dispatch to PE_SCHEDULER via
PeInternalTxn, wait for blocking commands.
Non-kernel Transactions are forwarded normally.
"""
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
super().__init__(node, ctx)
self._pe_prefix = node.id.rsplit(".", 1)[0] # e.g. "sip0.cube0.pe0"
try:
self._pe_idx = int(self._pe_prefix.rsplit("pe", 1)[1])
except (IndexError, ValueError):
self._pe_idx = 0
# Extract sip/cube index for multi-SIP/cube shard matching
parts = node.id.split(".")
try:
self._sip_idx = int(parts[0].replace("sip", ""))
except (IndexError, ValueError):
self._sip_idx = 0
try:
self._cube_idx = int(parts[1].replace("cube", ""))
except (IndexError, ValueError):
self._cube_idx = 0
def _find_shard(self, shards: tuple) -> Any:
"""Find shard matching this PE's (sip, cube, pe). Fallback to positional index."""
for s in shards:
if s.sip == self._sip_idx and s.cube == self._cube_idx and s.pe == self._pe_idx:
return s
return shards[min(self._pe_idx, len(shards) - 1)]
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
yield env.timeout(overhead_ns)
def _worker(self, env: simpy.Environment) -> Generator:
while True:
txn: Any = yield self._inbox.get()
from kernbench.runtime_api.kernel import KernelLaunchMsg
if hasattr(txn, "request") and isinstance(txn.request, KernelLaunchMsg):
yield from self._execute_kernel(env, txn)
else:
yield from self._forward_txn(env, txn)
def _execute_kernel(self, env: simpy.Environment, txn: Any) -> Generator:
"""Compile kernel function and replay command trace."""
from kernbench.common.pe_commands import (
CompositeCmd,
PeCpuOverheadCmd,
PeInternalTxn,
WaitCmd,
)
from kernbench.triton_emu.registry import get_kernel
from kernbench.triton_emu.tl_context import TLContext, run_kernel
request = txn.request
# Phase 1: Compile — apply PE_CPU setup overhead, then run kernel
yield from self.run(env, 0)
kernel_fn = get_kernel(request.kernel_ref.name)
# Derive num_programs from the number of PE shards in this cube
num_programs = 1
for arg in request.args:
if arg.arg_kind == "tensor":
cube_pe_count = sum(
1 for s in arg.shards
if s.sip == self._sip_idx and s.cube == self._cube_idx
)
if cube_pe_count > num_programs:
num_programs = cube_pe_count
tl = TLContext(pe_id=self._pe_idx, num_programs=num_programs, dispatch_cycles=0)
# Unpack KernelLaunchMsg.args into positional args for kernel function
# TensorArg → va_base (already local, set by runtime) or PA fallback
kernel_args: list = []
for arg in request.args:
if arg.arg_kind == "tensor":
if arg.va_base:
kernel_args.append(arg.va_base)
else:
shard = self._find_shard(arg.shards)
kernel_args.append(shard.pa)
elif arg.arg_kind == "scalar":
kernel_args.append(arg.value)
run_kernel(kernel_fn, tl, *kernel_args)
commands = tl.commands
# Phase 2: Replay — dispatch commands to PE_SCHEDULER
pe_exec_start = env.now
scheduler_id = f"{self._pe_prefix}.pe_scheduler"
pending: dict[str, simpy.Event] = {} # completion_id → done event
composite_results: list[dict] = [] # collect result_data from CompositeCmd txns
for cmd in commands:
if isinstance(cmd, PeCpuOverheadCmd):
yield env.timeout(cmd.cycles)
elif isinstance(cmd, WaitCmd):
if cmd.handle is not None:
evt = pending.pop(cmd.handle.id, None)
if evt:
yield evt
else:
# Wait all pending completions
for evt in pending.values():
yield evt
pending.clear()
elif isinstance(cmd, CompositeCmd):
# Non-blocking: dispatch to scheduler, track completion
done_evt = env.event()
pe_txn = PeInternalTxn(
command=cmd, done=done_evt,
pe_prefix=self._pe_prefix,
)
composite_results.append(pe_txn.result_data)
yield self.out_ports[scheduler_id].put(pe_txn)
pending[cmd.completion.id] = done_evt
else:
# Blocking: dispatch and wait for completion
done_evt = env.event()
pe_txn = PeInternalTxn(
command=cmd, done=done_evt,
pe_prefix=self._pe_prefix,
)
yield self.out_ports[scheduler_id].put(pe_txn)
yield done_evt
# Wait for any remaining pending completions
for evt in pending.values():
yield evt
# Record PE-internal execution time
txn.result_data["pe_exec_ns"] = env.now - pe_exec_start
# Aggregate dma_ns / compute_ns from CompositeCmd results
total_dma_ns = 0.0
total_compute_ns = 0.0
for rd in composite_results:
total_dma_ns += rd.get("dma_ns", 0.0)
total_compute_ns += rd.get("compute_ns", 0.0)
txn.result_data["dma_ns"] = total_dma_ns
txn.result_data["compute_ns"] = total_compute_ns
# Send ResponseMsg on reverse path (PE_CPU → NOC → M_CPU)
reverse_path = list(reversed(txn.path))
if len(reverse_path) >= 2:
from kernbench.runtime_api.kernel import ResponseMsg
resp_msg = ResponseMsg(
correlation_id=request.correlation_id,
request_id=request.request_id,
src_cube=self._cube_idx, src_pe=self._pe_idx,
success=True,
)
resp_txn = Transaction(
request=resp_msg, path=reverse_path, step=0,
nbytes=0, done=env.event(), is_response=True,
)
yield self.out_ports[reverse_path[1]].put(resp_txn.advance())
else:
txn.done.succeed()