Files
kernbench2/src/kernbench/sim_engine/engine.py
T
mukesh 14d800b0ae Kernel-launch sync (ADR-0009 D5) and IPCQ drain at inbound (ADR-0023)
- KernelLaunchMsg gains target_start_ns: IO_CPU stamps a global barrier
  (max path latency across every target PE), M_CPU passes it through,
  PE_CPU yields until it before recording pe_exec_start. Every PE in a
  launch begins kernel execution at the same env.now regardless of its
  dispatch path length — eliminates per-PE dispatch-offset artifact in
  cross-PE and cross-cube latency measurements.

- PE_DMA._handle_ipcq_inbound now pays Transaction.drain_ns at the top,
  matching the terminal-drain behavior of ComponentBase._forward_txn for
  every non-IPCQ Transaction. SRC-side tl.send stays fire-and-forget
  (sender doesn't yield on sub_done); tl.recv now blocks until bytes
  have actually drained into its inbox.

- ComponentContext: new compute_path_latency_ns helper + node_overhead_ns
  field populated by GraphEngine.

- tests/test_kernel_launch_sync.py: asserts all PEs in one launch
  produce identical pe_exec_ns for a no-op kernel (zero spread).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-23 15:30:29 -07:00

520 lines
22 KiB
Python

from __future__ import annotations
from typing import Any
import simpy
from kernbench.common.types import Completion, RequestHandle, Trace
import kernbench.components.builtin # noqa: F401 — registers built-in implementations
from kernbench.components.base import ComponentBase, ComponentRegistry
from kernbench.components.context import ComponentContext
from kernbench.policy.address.phyaddr import PhysAddr
from kernbench.policy.routing.router import AddressResolver, PathRouter
from kernbench.runtime_api.kernel import KernelLaunchMsg, MemoryReadMsg, MemoryWriteMsg, PeDmaMsg
from kernbench.sim_engine.transaction import Transaction
from kernbench.topology.types import Edge, TopologyGraph
class GraphEngine:
"""simpy-based discrete-event simulation engine.
Request routing:
MemoryWrite/Read: pcie_ep → io_noc → cube → router mesh → hbm_ctrl (m_cpu bypass)
KernelLaunch: pcie_ep → io_noc → io_cpu → io_noc → cube → m_cpu → PE
PeDmaMsg: pe_dma → router mesh → hbm_ctrl (direct probe)
Component implementations are DI-injectable via component_overrides (ADR-0007 D3).
"""
def __init__(
self,
graph: TopologyGraph,
*,
component_overrides: dict[str, type[ComponentBase]] | None = None,
enable_data: bool = False,
) -> None:
self._env = simpy.Environment()
self._resolver = AddressResolver(graph)
self._router = PathRouter(graph)
self._nodes = graph.nodes
self._edge_map: dict[tuple[str, str], Edge] = {}
for e in graph.edges:
self._edge_map[(e.src, e.dst)] = e
self._ns_per_mm: float = graph.spec.get("system", {}).get("ns_per_mm", 0.01)
self._results: dict[str, tuple[Completion, Trace]] = {}
self._events: dict[str, simpy.Event] = {}
self._counter = 0
overrides = component_overrides or {}
# ADR-0020: optional data execution support
self._op_logger = None
self._memory_store = None
if enable_data:
from kernbench.sim_engine.memory_store import MemoryStore
from kernbench.sim_engine.op_log import OpLogger
self._memory_store = MemoryStore()
self._op_logger = OpLogger(memory_store=self._memory_store)
# Cursor for incremental Phase 2 replay (ADR-0020 D6).
# SimPy env.now is monotonic so newly logged records always sort
# to the tail; the cursor remains valid across waits.
self._data_cursor = 0
ctx = ComponentContext(
router=self._router,
resolver=self._resolver,
positions={nid: n.pos_mm for nid, n in graph.nodes.items()},
ns_per_mm=self._ns_per_mm,
edge_map=self._edge_map,
spec=graph.spec,
memory_store=self._memory_store,
op_logger=self._op_logger,
node_overhead_ns={
nid: float(n.attrs.get("overhead_ns", 0.0))
for nid, n in graph.nodes.items()
},
)
self._components: dict[str, ComponentBase] = {
node_id: ComponentRegistry.create(node, overrides, ctx)
for node_id, node in graph.nodes.items()
}
# Wire ports: one Store per directed edge (ADR-0015 D1)
for e in graph.edges:
src_comp = self._components.get(e.src)
dst_comp = self._components.get(e.dst)
if src_comp is None or dst_comp is None:
continue
store: simpy.Store = simpy.Store(self._env)
src_comp.out_ports[e.dst] = store
dst_comp.in_ports[e.src] = store
# Wire processes: propagation delay + BW occupancy per edge (ADR-0015 D2)
# Cut-through (wormhole) model: wires apply propagation delay per hop.
# BW occupancy (available_at) tracks when each directed link becomes free
# for the next transaction, modeling back-to-back serialization contention.
for e in graph.edges:
src_comp = self._components.get(e.src)
dst_comp = self._components.get(e.dst)
if src_comp is None or dst_comp is None:
continue
prop_ns = e.distance_mm * self._ns_per_mm
bw_gbs = e.bw_gbs or 0.0
self._env.process(
self._wire(src_comp.out_ports[e.dst], dst_comp.in_ports[e.src],
prop_ns, bw_gbs)
)
# Attach host queues to PCIE_EP in_ports before start() (ADR-0015 D3)
self._host_queues: dict[str, simpy.Store] = {}
for pcie_ep_id in self._resolver.find_all_pcie_eps():
host_q: simpy.Store = simpy.Store(self._env)
self._components[pcie_ep_id].in_ports["host"] = host_q
self._host_queues[pcie_ep_id] = host_q
# Attach host queues to PE_DMA nodes for direct PE DMA injection
self._pe_dma_queues: dict[str, simpy.Store] = {}
for node_id, node in graph.nodes.items():
if node.kind == "pe_dma":
host_q = simpy.Store(self._env)
self._components[node_id].in_ports["host"] = host_q
self._pe_dma_queues[node_id] = host_q
# Wire PE_DMA._mmu to PE_MMU's underlying PeMMU utility object
for node_id, node in graph.nodes.items():
if node.kind == "pe_dma":
# Derive PE_MMU node ID from PE_DMA node ID
pe_prefix = node_id.rsplit(".", 1)[0] # e.g. "sip0.cube0.pe0"
mmu_id = f"{pe_prefix}.pe_mmu"
mmu_comp = self._components.get(mmu_id)
if mmu_comp is not None and hasattr(mmu_comp, "mmu"):
self._components[node_id]._mmu = mmu_comp.mmu
# Inject op_logger into all components (ADR-0020 D2)
if self._op_logger:
for comp in self._components.values():
comp._op_logger = self._op_logger
# Start components after all ports are wired (ADR-0015 D3)
for comp in self._components.values():
comp.start(self._env)
@property
def op_log(self):
"""Op log records from Phase 1 (ADR-0020)."""
return self._op_logger.records if self._op_logger else []
@property
def memory_store(self):
"""MemoryStore from Phase 1 (ADR-0020)."""
return self._memory_store
def submit(self, request: Any) -> RequestHandle:
self._counter += 1
handle = RequestHandle(f"h{self._counter}")
event = self._env.event()
self._events[str(handle)] = event
self._env.process(self._process(str(handle), request, event))
return handle
def _flush_data_phase(self) -> None:
"""Replay newly recorded op_log entries through DataExecutor.
ADR-0020 D6 Phase 2: when data tracking is enabled, run DataExecutor
on records added since the last flush so that callers reading
MemoryStore between launches observe correct (compute-replayed)
tensor data.
Cursor-based incremental replay is necessary because Phase 2 is
NOT idempotent across full re-runs: a math op writes a TCM scratch
addr, a later dma_write copies that scratch into HBM[X], and an
even-later math op may then read HBM[X]. Re-running everything
from scratch would let the second pass's first math op read the
already-overwritten HBM[X] instead of the original input.
"""
if self._op_logger is None or self._memory_store is None:
return
records = self._op_logger.records # sorted by t_start (stable)
if self._data_cursor >= len(records):
return
new_records = records[self._data_cursor:]
from kernbench.sim_engine.data_executor import DataExecutor
DataExecutor(new_records, self._memory_store).run()
self._data_cursor = len(records)
def wait(self, handle: RequestHandle) -> None:
key = str(handle)
event = self._events[key]
if not event.triggered:
try:
self._env.run(until=event)
except (simpy.core.EmptySchedule, RuntimeError) as exc:
# SimPy raises EmptySchedule directly OR (in newer simpy)
# wraps it as a RuntimeError("No scheduled events left ...").
# Either case while our event is still pending → IPCQ deadlock.
msg = str(exc)
is_deadlock = (
isinstance(exc, simpy.core.EmptySchedule)
or "No scheduled events left" in msg
)
if not is_deadlock:
raise
from kernbench.ccl.diagnostics import IpcqDeadlock, pointer_dump
dump = pointer_dump(self)
if dump.strip():
raise IpcqDeadlock(
"IPCQ deadlock: simulation schedule empty while "
f"request {handle!r} is still pending.\n"
f"Pointer state:\n{dump}"
) from None
raise
# ADR-0020: replay newly logged ops so the caller observes
# post-Phase-2 tensor state from MemoryStore.
self._flush_data_phase()
def get_completion(self, handle: RequestHandle) -> tuple[Completion, Trace | None]:
return self._results[str(handle)]
def mmu_map(self, va: int, pa: int, size: int) -> None:
"""Sideband: install VA→PA mapping in all PE_MMU components."""
for node_id, comp in self._components.items():
if hasattr(comp, "mmu"):
comp.mmu.map(va=va, pa=pa, size=size)
def mmu_map_pe(
self, sip: int, cube: int, pe: int, va: int, pa: int, size: int,
) -> None:
"""Sideband: install VA→PA mapping in a specific PE's MMU only."""
mmu_id = f"sip{sip}.cube{cube}.pe{pe}.pe_mmu"
comp = self._components.get(mmu_id)
if comp is not None and hasattr(comp, "mmu"):
comp.mmu.map(va=va, pa=pa, size=size)
def mmu_unmap(self, va: int, size: int) -> None:
"""Sideband: remove VA mapping from all PE_MMU components."""
for node_id, comp in self._components.items():
if hasattr(comp, "mmu"):
comp.mmu.unmap(va=va, size=size)
# ── internal ────────────────────────────────────────────────────
def _wire(
self,
out_port: simpy.Store,
in_port: simpy.Store,
prop_ns: float,
bw_gbs: float = 0.0,
):
"""SimPy process: relay messages with propagation delay and BW occupancy.
Each directed edge maintains an ``available_at`` timestamp tracking when
the link becomes free for the next transaction. When a transaction of
``nbytes`` uses a link with ``bw_gbs``, the link is occupied for
``nbytes / bw_gbs`` ns. The *next* transaction on the same directed
link must wait until ``available_at`` passes (back-to-back serialization).
The *current* transaction is NOT delayed by its own occupancy — only by
a prior transaction's occupancy that has not yet cleared. This avoids
double-drain: terminal drain_ns handles single-transaction serialization,
while available_at handles inter-transaction BW contention.
"""
available_at = 0.0
while True:
msg = yield out_port.get()
# BW occupancy: wait for link to become free, then mark busy
if bw_gbs > 0:
nbytes = getattr(msg, "nbytes", 0)
if nbytes > 0:
wait = available_at - self._env.now
if wait > 0:
yield self._env.timeout(wait)
available_at = self._env.now + (nbytes / bw_gbs)
# Propagation delay
if prop_ns > 0:
yield self._env.timeout(prop_ns)
yield in_port.put(msg)
def _process(self, key: str, request: Any, done: simpy.Event):
if isinstance(request, PeDmaMsg):
yield from self._process_pe_dma(key, request, done)
return
if isinstance(request, (MemoryWriteMsg, MemoryReadMsg)):
yield from self._process_memory_direct(key, request, done)
return
from kernbench.runtime_api.kernel import MmuMapMsg, MmuUnmapMsg
if isinstance(request, (MmuMapMsg, MmuUnmapMsg)):
yield from self._process_mmu_msg(key, request, done)
return
entries = self._entry_points(request)
if not entries:
self._results[key] = (
Completion(ok=True),
{"total_ns": 0.0, "nbytes": 0},
)
done.succeed()
return
start_ns = self._env.now
total_nbytes = 0
root_txn: Transaction | None = None
if len(entries) == 1:
# Single-SIP: direct inject (common path, no extra events)
pcie_ep_id, io_cpu_id, nbytes = entries[0]
total_nbytes = nbytes
path = self._router.find_node_path(pcie_ep_id, io_cpu_id)
txn_done = self._env.event()
txn = Transaction(request=request, path=path, step=0, nbytes=nbytes, done=txn_done)
root_txn = txn
yield self._host_queues[pcie_ep_id].put(txn)
yield txn_done
else:
# Multi-SIP: inject per SIP, aggregate completions (ADR-0007)
sub_dones: list[simpy.Event] = []
sub_txns: list[Transaction] = []
for pcie_ep_id, io_cpu_id, nbytes in entries:
total_nbytes = max(total_nbytes, nbytes)
path = self._router.find_node_path(pcie_ep_id, io_cpu_id)
txn_done = self._env.event()
txn = Transaction(
request=request, path=path, step=0,
nbytes=nbytes, done=txn_done,
)
yield self._host_queues[pcie_ep_id].put(txn)
sub_dones.append(txn_done)
sub_txns.append(txn)
for sd in sub_dones:
yield sd
# Aggregate pe_exec_ns from multi-SIP (max)
pe_vals = [st.result_data.get("pe_exec_ns") for st in sub_txns]
pe_vals = [v for v in pe_vals if v is not None]
if pe_vals:
if root_txn is None:
root_txn = sub_txns[0]
root_txn.result_data["pe_exec_ns"] = max(pe_vals)
total_ns = self._env.now - start_ns
result_trace: dict[str, Any] = {"total_ns": total_ns, "nbytes": total_nbytes}
if root_txn is not None and root_txn.result_data:
result_trace.update(root_txn.result_data)
self._results[key] = (
Completion(ok=True),
result_trace,
)
done.succeed()
def _process_memory_direct(self, key: str, request: Any, done: simpy.Event):
"""Direct memory path: pcie_ep → io_noc → cube → router mesh → hbm_ctrl.
MemoryWrite: data flows forward (nbytes on wires), drain at hbm_ctrl terminal.
MemoryRead: command flows forward (nbytes=0), hbm_ctrl sends data back on
reverse path with nbytes=request.nbytes.
"""
if isinstance(request, MemoryWriteMsg):
sip, pa_val = request.dst_sip, request.dst_pa
else:
sip, pa_val = request.src_sip, request.src_pa
pcie_ep_id = self._resolver.find_pcie_ep(sip)
pa = PhysAddr.decode(pa_val)
hbm_node = self._resolver.resolve(pa)
path = self._router.find_memory_path(pcie_ep_id, hbm_node)
drain_ns = self._path_drain_ns(path, request.nbytes)
start_ns = self._env.now
txn_done = self._env.event()
is_write = isinstance(request, MemoryWriteMsg)
txn = Transaction(
request=request, path=path, step=0,
nbytes=request.nbytes if is_write else 0,
done=txn_done, drain_ns=drain_ns,
)
yield self._host_queues[pcie_ep_id].put(txn)
yield txn_done
total_ns = self._env.now - start_ns
self._results[key] = (
Completion(ok=True),
{"total_ns": total_ns, "nbytes": request.nbytes},
)
done.succeed()
def _process_pe_dma(self, key: str, request: PeDmaMsg, done: simpy.Event):
"""Inject a Transaction directly at PE_DMA for PE→HBM latency measurement."""
pe_prefix = f"sip{request.src_sip}.cube{request.src_cube}.pe{request.src_pe}"
pe_dma_id = f"{pe_prefix}.pe_dma"
pa = PhysAddr.decode(request.dst_pa)
dst_node = self._resolver.resolve(pa)
path = self._router.find_path(pe_prefix, dst_node)
drain_ns = self._path_drain_ns(path, request.nbytes)
start_ns = self._env.now
txn_done = self._env.event()
txn = Transaction(request=request, path=path, step=0, nbytes=request.nbytes,
done=txn_done, drain_ns=drain_ns)
yield self._pe_dma_queues[pe_dma_id].put(txn)
yield txn_done
total_ns = self._env.now - start_ns
formula_ns = self._formula_latency(path, request.nbytes)
self._results[key] = (
Completion(ok=True),
{"total_ns": total_ns, "formula_ns": formula_ns, "nbytes": request.nbytes},
)
done.succeed()
def _path_drain_ns(self, path: list[str], nbytes: int) -> float:
"""Wormhole drain time: nbytes / bottleneck_bw along path."""
min_bw = float("inf")
for i in range(len(path) - 1):
edge = self._edge_map.get((path[i], path[i + 1]))
if edge and edge.bw_gbs:
min_bw = min(min_bw, edge.bw_gbs)
if min_bw == float("inf"):
return 0.0
return nbytes / min_bw
def _formula_latency(self, path: list[str], nbytes: int) -> float:
"""Lower-bound formula latency (ADR-0015 D7).
formula = Σ(wire propagation) + Σ(component overhead_ns) + drain_ns
Phase 0: formula == actual (no contention).
Phase 1+: formula <= actual (contention adds queueing).
"""
total = 0.0
# Wire propagation delays
for i in range(len(path) - 1):
edge = self._edge_map.get((path[i], path[i + 1]))
if edge:
total += edge.distance_mm * self._ns_per_mm
# Component overhead_ns
for node_id in path:
node = self._nodes.get(node_id)
if node:
total += float(node.attrs.get("overhead_ns", 0.0))
# Drain
total += self._path_drain_ns(path, nbytes)
return total
def _entry_points(self, request: Any) -> list[tuple[str, str, int]]:
"""Return list of (pcie_ep_id, io_cpu_id, nbytes) per target SIP.
Only handles KernelLaunchMsg. MemoryWrite/Read use _process_memory_direct.
"""
if isinstance(request, KernelLaunchMsg):
seen: set[int] = set()
entries: list[tuple[str, str, int]] = []
for arg in request.args:
if arg.arg_kind != "tensor":
continue
for shard in arg.shards:
if shard.sip not in seen:
seen.add(shard.sip)
entries.append((
self._resolver.find_pcie_ep(shard.sip),
self._resolver.find_io_cpu(shard.sip),
shard.nbytes,
))
return entries
raise ValueError(f"unsupported request type: {type(request)}")
def _process_mmu_msg(self, key: str, request: Any, done: simpy.Event):
"""Route MmuMapMsg/MmuUnmapMsg through fabric like KernelLaunchMsg.
Path: Host → PCIE_EP → IO_NOC → IO_CPU → (fan-out) → M_CPU → (fan-out) → PE_MMU
"""
start_ns = self._env.now
target_sips = getattr(request, "target_sips", "all")
# Determine target SIPs
sip_set: set[int] = set()
if target_sips == "all":
for ep_id in self._resolver.find_all_pcie_eps():
sip_id = int(ep_id.split(".")[0].replace("sip", ""))
sip_set.add(sip_id)
else:
sip_set = set(target_sips)
entries = []
for sip_id in sorted(sip_set):
entries.append((
self._resolver.find_pcie_ep(sip_id),
self._resolver.find_io_cpu(sip_id),
0, # MmuMapMsg has no data payload
))
if not entries:
self._results[key] = (Completion(ok=True), {"total_ns": 0.0})
done.succeed()
return
if len(entries) == 1:
pcie_ep_id, io_cpu_id, _ = entries[0]
path = self._router.find_node_path(pcie_ep_id, io_cpu_id)
txn_done = self._env.event()
txn = Transaction(request=request, path=path, step=0, nbytes=0, done=txn_done)
yield self._host_queues[pcie_ep_id].put(txn)
yield txn_done
else:
# Multi-SIP fan-out
sub_dones = []
for pcie_ep_id, io_cpu_id, _ in entries:
path = self._router.find_node_path(pcie_ep_id, io_cpu_id)
sub_done = self._env.event()
sub_txn = Transaction(request=request, path=path, step=0, nbytes=0, done=sub_done)
yield self._host_queues[pcie_ep_id].put(sub_txn)
sub_dones.append(sub_done)
for sd in sub_dones:
yield sd
elapsed = self._env.now - start_ns
self._results[key] = (
Completion(ok=True),
{"total_ns": elapsed, "msg_type": request.msg_type},
)
done.succeed()