Files
kernbench2/src/kernbench/components/builtin/m_cpu.py
T
mukesh 14d800b0ae Kernel-launch sync (ADR-0009 D5) and IPCQ drain at inbound (ADR-0023)
- KernelLaunchMsg gains target_start_ns: IO_CPU stamps a global barrier
  (max path latency across every target PE), M_CPU passes it through,
  and PE_CPU yields until that time before recording pe_exec_start. Every
  PE in a launch begins kernel execution at the same env.now regardless of
  its dispatch path length, which eliminates the per-PE dispatch-offset
  artifact in cross-PE and cross-cube latency measurements.

- PE_DMA._handle_ipcq_inbound now pays Transaction.drain_ns at the top,
  matching the terminal-drain behavior of ComponentBase._forward_txn for
  every non-IPCQ Transaction. SRC-side tl.send stays fire-and-forget
  (sender doesn't yield on sub_done); tl.recv now blocks until bytes
  have actually drained into its inbox.

- ComponentContext: new compute_path_latency_ns helper + node_overhead_ns
  field populated by GraphEngine.

- tests/test_kernel_launch_sync.py: asserts all PEs in one launch
  produce identical pe_exec_ns for a no-op kernel (zero spread).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-23 15:30:29 -07:00

357 lines
14 KiB
Python

from __future__ import annotations

import dataclasses
import logging
from collections.abc import Generator
from typing import TYPE_CHECKING, Any

import simpy

from kernbench.components.base import ComponentBase
from kernbench.sim_engine.transaction import Transaction

if TYPE_CHECKING:
    from kernbench.components.context import ComponentContext
    from kernbench.topology.types import Node
class MCpuComponent(ComponentBase):
    """M_CPU component: multi-PE DMA fan-out with response aggregation.

    Forward path (ADR-0015 D5):
        When a forward Transaction arrives at m_cpu (terminal hop), M_CPU fans
        out sub-Transactions. The request type selects the fan-out:
        ``KernelLaunchMsg`` goes to PE_CPUs, ``MmuMapMsg``/``MmuUnmapMsg`` go
        to PE_MMUs, anything else is treated as a DMA transfer. ``target_pe``
        on the request controls PE fan-out: int → single PE, tuple → the
        listed PEs, anything else (e.g. "all") → all PEs in the cube.

    Response path:
        Response Transactions (``is_response=True``) arrive back at m_cpu.
        Once all expected responses for a request are collected, m_cpu sends
        an aggregate ResponseMsg on the reverse command path.

    Transit:
        When m_cpu is NOT the terminal hop, the Transaction is forwarded
        normally to the next hop.
    """

    # Module logger, held on the class so methods can reach it via ``self``.
    _log = logging.getLogger(__name__)

    # Per-PE metrics merged (by max) into the parent txn after a kernel launch.
    _METRIC_KEYS: tuple[str, ...] = ("pe_exec_ns", "dma_ns", "compute_ns")

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        # Pending fan-out tracking: request_id → (expected, received, all_done_event)
        self._pending: dict[str, tuple[int, int, simpy.Event]] = {}
        # Parent txn kept while its fan-out is in flight: request_id → parent_txn
        self._parent_txns: dict[str, Any] = {}
        # DMA engine resources (ADR-0015 D5, ADR-0014 D4): capacity=1 each.
        # Created in start() because they need the simpy environment.
        self._dma_write: simpy.Resource | None = None
        self._dma_read: simpy.Resource | None = None

    def start(self, env: simpy.Environment) -> None:
        """Create the DMA resources, then delegate to the base-class start."""
        self._dma_write = simpy.Resource(env, capacity=1)
        self._dma_read = simpy.Resource(env, capacity=1)
        super().start(env)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Yield for the node's ``overhead_ns`` attribute (0 if unset).

        ``nbytes`` is accepted for interface compatibility but not used here.
        """
        overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(overhead_ns)

    def _worker(self, env: simpy.Environment) -> Generator:
        """Dispatch forward txns, collect response txns."""
        # Imported locally, as in the rest of this class.
        from kernbench.runtime_api.kernel import KernelLaunchMsg, MmuMapMsg, MmuUnmapMsg

        while True:
            txn: Any = yield self._inbox.get()
            if getattr(txn, "is_response", False):
                self._collect_response(txn)
            else:
                yield from self.run(env, txn.nbytes)
                next_hop = txn.next_hop
                if next_hop:
                    # Transit: not the terminal hop, forward unchanged.
                    yield self.out_ports[next_hop].put(txn.advance())
                elif self.ctx is not None and txn.request is not None:
                    # Terminal hop: fan out in a separate process so the
                    # worker keeps draining the inbox (responses included).
                    if isinstance(txn.request, KernelLaunchMsg):
                        env.process(self._kernel_launch_fanout(env, txn))
                    elif isinstance(txn.request, (MmuMapMsg, MmuUnmapMsg)):
                        env.process(self._mmu_msg_fanout(env, txn))
                    else:
                        env.process(self._dma_fanout(env, txn))
                else:
                    txn.done.succeed()

    def _collect_response(self, resp_txn: Any) -> None:
        """Count one response; fire the aggregation event on the last one.

        Responses whose request_id has no pending fan-out are ignored.
        """
        key = resp_txn.request.request_id
        if key not in self._pending:
            return
        expected, received, all_done = self._pending[key]
        received += 1
        if received >= expected:
            all_done.succeed()
            del self._pending[key]
        else:
            self._pending[key] = (expected, received, all_done)

    def _cube_id(self) -> int:
        """Parse the cube index from the node id's second dotted component
        (e.g. "sip0.cube3.m_cpu" → 3)."""
        parts = self.node.id.split(".")
        return int(parts[1].replace("cube", ""))

    def _send_aggregate_response(self, env: simpy.Environment, txn: Any) -> Generator:
        """Send one aggregate ResponseMsg back along the reverse of ``txn.path``.

        If the reversed path has fewer than two nodes there is nowhere to send
        a response, so the parent transaction is completed directly instead.
        """
        reverse_path = list(reversed(txn.path))
        if len(reverse_path) < 2:
            txn.done.succeed()
            return

        from kernbench.runtime_api.kernel import ResponseMsg

        request = txn.request
        resp_msg = ResponseMsg(
            correlation_id=request.correlation_id,
            request_id=request.request_id,
            src_cube=self._cube_id(), src_pe=-1, success=True,
        )
        resp_txn = Transaction(
            request=resp_msg, path=reverse_path, step=0,
            nbytes=0, done=env.event(), is_response=True,
        )
        yield self.out_ports[reverse_path[1]].put(resp_txn.advance())

    def _dma_fanout(self, env: simpy.Environment, txn: Any) -> Generator:
        """Fan out DMA sub-Transactions, wait for responses, then send the
        aggregate response on the reverse command path.

        Each DMA transfer acquires the DMA resource (capacity=1 per ADR-0014
        D4), so multi-destination fan-out is serialized through the DMA engine.

        Paths are resolved before anything is dispatched so that the expected
        response count equals the number of sub-Transactions actually sent;
        destinations with no usable path are skipped. If nothing is routable
        the parent transaction completes immediately.
        """
        from kernbench.runtime_api.kernel import MemoryWriteMsg

        request = txn.request
        target_pe = getattr(request, "target_pe", "all")
        dst_nodes = self._resolve_dma_destinations(request, target_pe)

        # Plan phase (no yields): (path, drain_ns) per routable destination.
        plans: list[tuple[list[str], float]] = []
        for dst_node in dst_nodes:
            try:
                dma_path = self.ctx.router.find_mcpu_dma_path(self.node.id, dst_node)
            except Exception:
                # Best-effort: an unroutable destination is skipped, not fatal.
                self._log.warning(
                    "%s: no DMA path to %s; skipping destination",
                    self.node.id, dst_node, exc_info=True,
                )
                continue
            if len(dma_path) < 2:
                continue
            plans.append((dma_path, self.ctx.compute_drain_ns(dma_path, txn.nbytes)))

        if not plans:
            txn.done.succeed()
            return

        # Register aggregation BEFORE dispatching so no response can arrive
        # while its request_id is still unknown to _collect_response.
        all_done = env.event()
        self._pending[request.request_id] = (len(plans), 0, all_done)
        self._parent_txns[request.request_id] = txn

        # Select DMA resource based on operation type
        dma_res = self._dma_write if isinstance(request, MemoryWriteMsg) else self._dma_read

        # Dispatch phase (serialized through the DMA resource)
        max_drain_ns = 0.0
        for dma_path, drain_ns in plans:
            max_drain_ns = max(max_drain_ns, drain_ns)
            sub_txn = Transaction(
                request=request, path=dma_path, step=0,
                nbytes=txn.nbytes, done=env.event(),
                drain_ns=drain_ns,
            )
            with dma_res.request() as req:
                yield req
                yield self.out_ports[dma_path[1]].put(sub_txn.advance())

        # Wait for all responses
        yield all_done
        txn.result_data["xfer_ns"] = max_drain_ns
        self._parent_txns.pop(request.request_id, None)

        yield from self._send_aggregate_response(env, txn)

    def _kernel_launch_fanout(self, env: simpy.Environment, txn: Any) -> Generator:
        """Fan out KernelLaunchMsg to target PE_CPU(s) via NOC (ADR-0009 D3).

        Routes through find_node_path (M_CPU → NOC → PE_CPU command edges),
        waits for one response per dispatched PE, merges per-PE metrics into
        the parent transaction, then sends the aggregate ResponseMsg on the
        reverse path.

        ADR-0009 D5: every sub-Transaction carries the same target_start_ns so
        all PEs in this fan-out share one start barrier regardless of dispatch
        path length.
        """
        request = txn.request
        target_pe = getattr(request, "target_pe", "all")
        cube_prefix = self.node.id.rsplit(".", 1)[0]  # e.g. "sip0.cube0"

        pe_ids = self._resolve_pe_ids(target_pe)
        if not pe_ids:
            txn.done.succeed()
            return

        # Resolve per-PE paths and the slowest dispatch latency among them.
        per_pe: list[tuple[int, list[str], float]] = []
        max_latency = 0.0
        for pe_id in pe_ids:
            pe_cpu_id = f"{cube_prefix}.pe{pe_id}.pe_cpu"
            try:
                path = self.ctx.router.find_node_path(self.node.id, pe_cpu_id)
            except Exception:
                # Best-effort: an unroutable PE is skipped, not fatal.
                self._log.warning(
                    "%s: no path to %s; skipping PE",
                    self.node.id, pe_cpu_id, exc_info=True,
                )
                continue
            if len(path) < 2:
                continue
            latency = self.ctx.compute_path_latency_ns(path, nbytes=0)
            per_pe.append((pe_id, path, latency))
            if latency > max_latency:
                max_latency = latency

        if not per_pe:
            txn.done.succeed()
            return

        # If a target_start_ns is already stamped on the request (ADR-0009 D5
        # extended: global barrier), pass it through unchanged so every PE
        # across every cube uses the same barrier. Otherwise (e.g.
        # direct-to-M_CPU launch in a unit test) compute a per-cube barrier
        # from env.now.
        if getattr(request, "target_start_ns", None) is not None:
            stamped_request = request
        else:
            stamped_request = dataclasses.replace(
                request, target_start_ns=float(env.now) + max_latency,
            )

        # Register aggregation BEFORE dispatching: the puts below yield, and a
        # response arriving for an unregistered request_id would be dropped by
        # _collect_response.
        all_done = env.event()
        self._pending[request.request_id] = (len(per_pe), 0, all_done)
        self._parent_txns[request.request_id] = txn

        # Fan out to each PE_CPU
        sub_txns: list[Transaction] = []
        for _pe_id, path, _lat in per_pe:
            sub_txn = Transaction(
                request=stamped_request, path=path, step=0,
                nbytes=0, done=env.event(),
            )
            yield self.out_ports[path[1]].put(sub_txn.advance())
            sub_txns.append(sub_txn)

        # Wait for all PE_CPU responses (counted by _collect_response)
        yield all_done
        self._parent_txns.pop(request.request_id, None)

        # Aggregate PE-internal metrics (max across PEs and across cubes).
        # Multiple M_CPUs share the same result_data dict via IO_CPU fanout;
        # merge against the existing value so cubes don't clobber each other.
        for key in self._METRIC_KEYS:
            best = max(st.result_data.get(key, 0.0) for st in sub_txns)
            cur = txn.result_data.get(key, 0.0) or 0.0
            txn.result_data[key] = max(cur, best)

        yield from self._send_aggregate_response(env, txn)

    def _resolve_dma_destinations(self, request: Any, target_pe: int | str) -> list[str]:
        """Return list of HBM destination node_ids for DMA fan-out.

        With single hbm_ctrl per cube (ADR-0019), always returns one node:
        the node resolved from the request's physical address when that
        succeeds, otherwise the local cube's hbm_ctrl. ``target_pe`` is
        currently unused.
        """
        cube_prefix = self.node.id.rsplit(".", 1)[0]  # e.g. "sip0.cube0"

        # PA-based resolution: extract actual target from physical address.
        # NOTE(review): `or` treats a dst_pa of 0 as absent and falls through
        # to src_pa — confirm 0 is never a valid physical address here.
        pa_val = getattr(request, "dst_pa", None) or getattr(request, "src_pa", None)
        if pa_val is not None:
            from kernbench.policy.address.phyaddr import PhysAddr
            try:
                pa = PhysAddr.decode(pa_val)
                return [self.ctx.resolver.resolve(pa)]
            except Exception:
                # Deliberate best-effort: fall back to the local hbm_ctrl.
                self._log.debug(
                    "%s: PA resolution failed for %r; using local hbm_ctrl",
                    self.node.id, pa_val, exc_info=True,
                )

        # Default: single hbm_ctrl in local cube
        return [f"{cube_prefix}.hbm_ctrl"]

    def _mmu_msg_fanout(self, env: simpy.Environment, txn: Any) -> Generator:
        """Fan out MmuMapMsg/MmuUnmapMsg to target PE_MMU(s) via NOC.

        Routes through find_node_path (M_CPU → NOC → PE_MMU command edges).
        Completion is tracked through each sub-Transaction's ``done`` event
        rather than through response messages; once every dispatched
        sub-Transaction is done, the aggregate response is sent.
        """
        request = txn.request
        target_pe = getattr(request, "target_pe", "all")
        cube_prefix = self.node.id.rsplit(".", 1)[0]  # e.g. "sip0.cube0"

        pe_ids = self._resolve_pe_ids(target_pe)
        if not pe_ids:
            txn.done.succeed()
            return

        # Fan out to each PE_MMU
        sub_dones: list[simpy.Event] = []
        for pe_id in pe_ids:
            pe_mmu_id = f"{cube_prefix}.pe{pe_id}.pe_mmu"
            try:
                path = self.ctx.router.find_node_path(self.node.id, pe_mmu_id)
            except Exception:
                # Best-effort: an unroutable PE is skipped, not fatal.
                self._log.warning(
                    "%s: no path to %s; skipping PE",
                    self.node.id, pe_mmu_id, exc_info=True,
                )
                continue
            if len(path) < 2:
                continue
            sub_done = env.event()
            sub_txn = Transaction(
                request=request, path=path, step=0,
                nbytes=0, done=sub_done,
            )
            yield self.out_ports[path[1]].put(sub_txn.advance())
            sub_dones.append(sub_done)

        # Wait for all PE_MMUs to complete
        for sd in sub_dones:
            yield sd

        yield from self._send_aggregate_response(env, txn)

    def _resolve_pe_ids(self, target_pe: int | tuple | str) -> list[int]:
        """Return list of PE IDs to fan out to.

        An int selects that single PE, a tuple selects exactly its members,
        and any other value (e.g. "all") selects ``range(n)`` where ``n`` is
        ``cube.memory_map.hbm_slices_per_cube`` from the context spec,
        defaulting to 8 when no spec is available.
        """
        if isinstance(target_pe, int):
            return [target_pe]
        if isinstance(target_pe, tuple):
            return list(target_pe)

        # "all": all PEs in local cube
        n_slices = 8
        if self.ctx and self.ctx.spec:
            mm = self.ctx.spec.get("cube", {}).get("memory_map", {})
            n_slices = mm.get("hbm_slices_per_cube", 8)
        return list(range(n_slices))