from __future__ import annotations

from typing import Any

import simpy

from kernbench.common.types import Completion, RequestHandle, Trace
import kernbench.components.builtin  # noqa: F401 — registers built-in implementations
from kernbench.components.base import ComponentBase, ComponentRegistry
from kernbench.components.context import ComponentContext
from kernbench.policy.address.phyaddr import PhysAddr
from kernbench.policy.routing.router import AddressResolver, PathRouter
from kernbench.runtime_api.kernel import (
    KernelLaunchMsg,
    MemoryReadMsg,
    MemoryWriteMsg,
    MmuMapMsg,
    MmuUnmapMsg,
    PeDmaMsg,
)
from kernbench.sim_engine.transaction import Transaction
from kernbench.topology.types import Edge, TopologyGraph


class GraphEngine:
    """simpy-based discrete-event simulation engine.

    Request routing:
      MemoryWrite/Read: pcie_ep → io_noc → cube → router mesh → hbm_ctrl (m_cpu bypass)
      KernelLaunch:     pcie_ep → io_noc → io_cpu → io_noc → cube → m_cpu → PE
      PeDmaMsg:         pe_dma → router mesh → hbm_ctrl (direct probe)
      MmuMap/Unmap:     pcie_ep → io_noc → io_cpu → (fan-out) → m_cpu → (fan-out) → pe_mmu

    Usage: ``submit()`` schedules a request and returns a handle, ``wait()``
    advances the simulation until that request completes, and
    ``get_completion()`` returns its ``(Completion, trace)`` pair.

    Component implementations are DI-injectable via component_overrides (ADR-0007 D3).
    """

    def __init__(
        self,
        graph: TopologyGraph,
        *,
        component_overrides: dict[str, type[ComponentBase]] | None = None,
        enable_data: bool = False,
    ) -> None:
        """Instantiate one component per graph node, wire the edges and start everything.

        Args:
            graph: Topology whose nodes become components and whose directed
                edges become wires.
            component_overrides: Optional implementation overrides, forwarded
                to ``ComponentRegistry.create`` (ADR-0007 D3).
            enable_data: When True, create a ``MemoryStore`` and an
                ``OpLogger`` and hand them to components (ADR-0020).
        """
        self._env = simpy.Environment()
        self._resolver = AddressResolver(graph)
        self._router = PathRouter(graph)
        self._nodes = graph.nodes
        # Keyed by (src, dst); a later duplicate edge overrides an earlier one.
        self._edge_map: dict[tuple[str, str], Edge] = {(e.src, e.dst): e for e in graph.edges}
        self._ns_per_mm: float = graph.spec.get("system", {}).get("ns_per_mm", 0.01)
        # Per-request bookkeeping, keyed by str(RequestHandle).
        self._results: dict[str, tuple[Completion, Trace]] = {}
        self._events: dict[str, simpy.Event] = {}
        self._counter = 0
        overrides = component_overrides or {}

        # ADR-0020: optional data execution support
        self._op_logger = None
        self._memory_store = None
        if enable_data:
            from kernbench.sim_engine.memory_store import MemoryStore
            from kernbench.sim_engine.op_log import OpLogger

            self._op_logger = OpLogger()
            self._memory_store = MemoryStore()

        ctx = ComponentContext(
            router=self._router,
            resolver=self._resolver,
            positions={nid: n.pos_mm for nid, n in graph.nodes.items()},
            ns_per_mm=self._ns_per_mm,
            edge_map=self._edge_map,
            spec=graph.spec,
            memory_store=self._memory_store,
            op_logger=self._op_logger,
        )
        self._components: dict[str, ComponentBase] = {
            node_id: ComponentRegistry.create(node, overrides, ctx)
            for node_id, node in graph.nodes.items()
        }

        # Wire ports and wire processes per directed edge (ADR-0015 D1/D2).
        # The source writes into its own out-Store and the destination reads
        # from its own in-Store; the _wire process relays between the two.
        # The Stores must be distinct: if the wire read from and wrote into
        # the same Store it would compete with the destination for its own
        # output, re-relaying (re-delaying, possibly reordering) any message
        # the destination had not picked up yet.
        # Cut-through (wormhole) model: wires apply propagation delay per hop.
        # BW occupancy (available_at) tracks when each directed link becomes free
        # for the next transaction, modeling back-to-back serialization contention.
        for e in graph.edges:
            src_comp = self._components.get(e.src)
            dst_comp = self._components.get(e.dst)
            if src_comp is None or dst_comp is None:
                continue
            out_store: simpy.Store = simpy.Store(self._env)
            in_store: simpy.Store = simpy.Store(self._env)
            src_comp.out_ports[e.dst] = out_store
            dst_comp.in_ports[e.src] = in_store
            prop_ns = e.distance_mm * self._ns_per_mm
            bw_gbs = e.bw_gbs or 0.0
            self._env.process(self._wire(out_store, in_store, prop_ns, bw_gbs))

        # Attach host queues to PCIE_EP in_ports before start() (ADR-0015 D3)
        self._host_queues: dict[str, simpy.Store] = {}
        for pcie_ep_id in self._resolver.find_all_pcie_eps():
            host_q: simpy.Store = simpy.Store(self._env)
            self._components[pcie_ep_id].in_ports["host"] = host_q
            self._host_queues[pcie_ep_id] = host_q

        # PE_DMA nodes: attach a host queue for direct PE DMA injection, and
        # point PE_DMA._mmu at the sibling PE_MMU's underlying PeMMU object.
        self._pe_dma_queues: dict[str, simpy.Store] = {}
        for node_id, node in graph.nodes.items():
            if node.kind != "pe_dma":
                continue
            pe_dma = self._components[node_id]
            dma_q: simpy.Store = simpy.Store(self._env)
            pe_dma.in_ports["host"] = dma_q
            self._pe_dma_queues[node_id] = dma_q

            # PE_MMU id derives from the PE_DMA id, e.g. "sip0.cube0.pe0.pe_dma"
            # -> "sip0.cube0.pe0.pe_mmu".
            pe_prefix = node_id.rsplit(".", 1)[0]
            mmu_comp = self._components.get(f"{pe_prefix}.pe_mmu")
            if mmu_comp is not None and hasattr(mmu_comp, "mmu"):
                pe_dma._mmu = mmu_comp.mmu

        # Inject op_logger into all components (ADR-0020 D2). Test for
        # presence, not truthiness: an empty logger must still be injected.
        if self._op_logger is not None:
            for comp in self._components.values():
                comp._op_logger = self._op_logger

        # Start components after all ports are wired (ADR-0015 D3)
        for comp in self._components.values():
            comp.start(self._env)

    @property
    def op_log(self):
        """Op log records from Phase 1 (ADR-0020); ``[]`` when data support is disabled."""
        return self._op_logger.records if self._op_logger is not None else []

    @property
    def memory_store(self):
        """MemoryStore from Phase 1 (ADR-0020); ``None`` when data support is disabled."""
        return self._memory_store

    def submit(self, request: Any) -> RequestHandle:
        """Schedule *request* as a simpy process and return its handle.

        The simulation clock is not advanced here; call ``wait()`` for that.
        """
        self._counter += 1
        handle = RequestHandle(f"h{self._counter}")
        event = self._env.event()
        self._events[str(handle)] = event
        self._env.process(self._process(str(handle), request, event))
        return handle

    def wait(self, handle: RequestHandle) -> None:
        """Run the shared simulation until *handle*'s request has completed.

        Other in-flight requests advance as well, since the whole environment runs.
        """
        key = str(handle)
        event = self._events[key]
        if not event.triggered:
            self._env.run(until=event)

    def get_completion(self, handle: RequestHandle) -> tuple[Completion, Trace | None]:
        """Return ``(Completion, trace)`` for a completed request.

        Raises:
            KeyError: the request has not completed yet (or was never submitted).
        """
        key = str(handle)
        try:
            return self._results[key]
        except KeyError:
            raise KeyError(
                f"no completion for request {key!r}; has it been submitted and waited on?"
            ) from None

    def mmu_map(self, va: int, pa: int, size: int) -> None:
        """Sideband: install VA→PA mapping in every component exposing an ``mmu``."""
        for comp in self._components.values():
            if hasattr(comp, "mmu"):
                comp.mmu.map(va=va, pa=pa, size=size)

    def mmu_map_pe(
        self, sip: int, cube: int, pe: int, va: int, pa: int, size: int,
    ) -> None:
        """Sideband: install VA→PA mapping in a specific PE's MMU only.

        Silently does nothing if ``sip{sip}.cube{cube}.pe{pe}.pe_mmu`` does not
        exist or has no ``mmu`` attribute.
        """
        mmu_id = f"sip{sip}.cube{cube}.pe{pe}.pe_mmu"
        comp = self._components.get(mmu_id)
        if comp is not None and hasattr(comp, "mmu"):
            comp.mmu.map(va=va, pa=pa, size=size)

    def mmu_unmap(self, va: int, size: int) -> None:
        """Sideband: remove VA mapping from every component exposing an ``mmu``."""
        for comp in self._components.values():
            if hasattr(comp, "mmu"):
                comp.mmu.unmap(va=va, size=size)

    # ── internal ────────────────────────────────────────────────────

    def _wire(
        self,
        out_port: simpy.Store,
        in_port: simpy.Store,
        prop_ns: float,
        bw_gbs: float = 0.0,
    ):
        """SimPy process: relay messages with propagation delay and BW occupancy.

        Messages are taken from *out_port* (the source component's output
        Store) and delivered into *in_port* (the destination component's input
        Store). The two must be different Stores.

        Each directed edge maintains an ``available_at`` timestamp tracking
        when the link becomes free for the next transaction. When a
        transaction of ``nbytes`` uses a link with ``bw_gbs``, the link is
        occupied for ``nbytes / bw_gbs`` ns. The *next* transaction on the same
        directed link must wait until ``available_at`` passes (back-to-back
        serialization). The *current* transaction is NOT delayed by its own
        occupancy — only by a prior transaction's occupancy that has not yet
        cleared. This avoids double-drain: terminal drain_ns handles
        single-transaction serialization, while available_at handles
        inter-transaction BW contention. Messages without a positive
        ``nbytes`` neither wait for nor occupy the link.
        """
        available_at = 0.0
        while True:
            msg = yield out_port.get()
            # BW occupancy: wait for link to become free, then mark busy
            if bw_gbs > 0:
                nbytes = getattr(msg, "nbytes", 0)
                if nbytes > 0:
                    wait = available_at - self._env.now
                    if wait > 0:
                        yield self._env.timeout(wait)
                    available_at = self._env.now + (nbytes / bw_gbs)
            # Propagation delay.
            # NOTE(review): this loop handles one message at a time, so the next
            # message is not taken until this one's propagation delay has elapsed
            # (propagation is not pipelined across messages) — confirm intended.
            if prop_ns > 0:
                yield self._env.timeout(prop_ns)
            yield in_port.put(msg)

    def _process(self, key: str, request: Any, done: simpy.Event):
        """SimPy process: run *request* to completion, record its result, trigger *done*.

        PE DMA, direct memory and MMU messages are delegated to their own
        handlers. Anything else goes through ``_entry_points`` (KernelLaunchMsg
        only) and is injected at the PCIE_EP of every target SIP.
        """
        if isinstance(request, PeDmaMsg):
            yield from self._process_pe_dma(key, request, done)
            return
        if isinstance(request, (MemoryWriteMsg, MemoryReadMsg)):
            yield from self._process_memory_direct(key, request, done)
            return
        if isinstance(request, (MmuMapMsg, MmuUnmapMsg)):
            yield from self._process_mmu_msg(key, request, done)
            return

        entries = self._entry_points(request)
        if not entries:
            # NOTE(review): a launch without tensor args has no entry point and
            # completes at once without being injected into the fabric.
            self._results[key] = (
                Completion(ok=True),
                {"total_ns": 0.0, "nbytes": 0},
            )
            done.succeed()
            return

        start_ns = self._env.now
        txns = yield from self._inject_per_sip(request, entries)
        total_ns = self._env.now - start_ns

        # SIPs are injected concurrently; report the largest per-SIP payload.
        total_nbytes = max(nbytes for _, _, nbytes in entries)
        result_trace: dict[str, Any] = {"total_ns": total_ns, "nbytes": total_nbytes}
        if len(txns) == 1:
            result_trace.update(txns[0].result_data or {})
        else:
            # Multi-SIP (ADR-0007): aggregate pe_exec_ns as the max over SIPs and
            # report it alongside the first SIP's result data.
            pe_vals = [t.result_data.get("pe_exec_ns") for t in txns]
            pe_vals = [v for v in pe_vals if v is not None]
            if pe_vals:
                result_trace.update(txns[0].result_data)
                result_trace["pe_exec_ns"] = max(pe_vals)

        self._results[key] = (
            Completion(ok=True),
            result_trace,
        )
        done.succeed()

    def _inject_per_sip(self, request: Any, entries: list[tuple[str, str, int]]):
        """SimPy sub-process: inject one Transaction per entry, then await them all.

        Each entry is ``(pcie_ep_id, io_cpu_id, nbytes)``. All host-queue puts
        are issued before any completion is awaited, so SIPs run concurrently.
        Use via ``yield from``; returns the Transactions in ``entries`` order.
        """
        txns: list[Transaction] = []
        dones: list[simpy.Event] = []
        for pcie_ep_id, io_cpu_id, nbytes in entries:
            path = self._router.find_node_path(pcie_ep_id, io_cpu_id)
            txn_done = self._env.event()
            txn = Transaction(request=request, path=path, step=0, nbytes=nbytes, done=txn_done)
            yield self._host_queues[pcie_ep_id].put(txn)
            txns.append(txn)
            dones.append(txn_done)
        for txn_done in dones:
            yield txn_done
        return txns

    def _process_memory_direct(self, key: str, request: Any, done: simpy.Event):
        """Direct memory path: pcie_ep → io_noc → cube → router mesh → hbm_ctrl.

        MemoryWrite: data flows forward (nbytes on wires), drain at hbm_ctrl terminal.
        MemoryRead: command flows forward (nbytes=0), hbm_ctrl sends data back
        on reverse path with nbytes=request.nbytes.
        """
        is_write = isinstance(request, MemoryWriteMsg)
        if is_write:
            sip, pa_val = request.dst_sip, request.dst_pa
        else:
            sip, pa_val = request.src_sip, request.src_pa
        pcie_ep_id = self._resolver.find_pcie_ep(sip)
        pa = PhysAddr.decode(pa_val)
        hbm_node = self._resolver.resolve(pa)
        path = self._router.find_memory_path(pcie_ep_id, hbm_node)
        # NOTE(review): for reads the drain uses the forward path's bottleneck
        # even though data returns on the reverse path — confirm links are symmetric.
        drain_ns = self._path_drain_ns(path, request.nbytes)

        start_ns = self._env.now
        txn_done = self._env.event()
        txn = Transaction(
            request=request,
            path=path,
            step=0,
            nbytes=request.nbytes if is_write else 0,
            done=txn_done,
            drain_ns=drain_ns,
        )
        yield self._host_queues[pcie_ep_id].put(txn)
        yield txn_done

        total_ns = self._env.now - start_ns
        self._results[key] = (
            Completion(ok=True),
            {"total_ns": total_ns, "nbytes": request.nbytes},
        )
        done.succeed()

    def _process_pe_dma(self, key: str, request: PeDmaMsg, done: simpy.Event):
        """Inject a Transaction directly at PE_DMA for PE→HBM latency measurement.

        The trace reports both the simulated ``total_ns`` and the contention-free
        lower bound ``formula_ns`` from ``_formula_latency``.
        """
        pe_prefix = f"sip{request.src_sip}.cube{request.src_cube}.pe{request.src_pe}"
        pe_dma_id = f"{pe_prefix}.pe_dma"
        pa = PhysAddr.decode(request.dst_pa)
        dst_node = self._resolver.resolve(pa)
        path = self._router.find_path(pe_prefix, dst_node)
        drain_ns = self._path_drain_ns(path, request.nbytes)

        start_ns = self._env.now
        txn_done = self._env.event()
        txn = Transaction(
            request=request, path=path, step=0, nbytes=request.nbytes,
            done=txn_done, drain_ns=drain_ns,
        )
        yield self._pe_dma_queues[pe_dma_id].put(txn)
        yield txn_done

        total_ns = self._env.now - start_ns
        formula_ns = self._formula_latency(path, request.nbytes)
        self._results[key] = (
            Completion(ok=True),
            {"total_ns": total_ns, "formula_ns": formula_ns, "nbytes": request.nbytes},
        )
        done.succeed()

    def _path_drain_ns(self, path: list[str], nbytes: int) -> float:
        """Wormhole drain time: nbytes / bottleneck_bw along path.

        Hops with no known edge or no (zero/None) bandwidth are ignored; returns
        0.0 when no hop has a bandwidth.
        """
        bandwidths = [
            edge.bw_gbs
            for src, dst in zip(path, path[1:])
            if (edge := self._edge_map.get((src, dst))) and edge.bw_gbs
        ]
        min_bw = min(bandwidths, default=float("inf"))
        if min_bw == float("inf"):
            return 0.0
        return nbytes / min_bw

    def _formula_latency(self, path: list[str], nbytes: int) -> float:
        """Lower-bound formula latency (ADR-0015 D7).

        formula = Σ(wire propagation) + Σ(component overhead_ns) + drain_ns
        Phase 0: formula == actual (no contention).
        Phase 1+: formula <= actual (contention adds queueing).
        """
        total = 0.0
        # Wire propagation delays
        for src, dst in zip(path, path[1:]):
            edge = self._edge_map.get((src, dst))
            if edge:
                total += edge.distance_mm * self._ns_per_mm
        # Component overhead_ns
        for node_id in path:
            node = self._nodes.get(node_id)
            if node:
                total += float(node.attrs.get("overhead_ns", 0.0))
        # Drain
        total += self._path_drain_ns(path, nbytes)
        return total

    def _entry_points(self, request: Any) -> list[tuple[str, str, int]]:
        """Return list of (pcie_ep_id, io_cpu_id, nbytes) per target SIP.

        Only handles KernelLaunchMsg. MemoryWrite/Read use _process_memory_direct.
        SIPs are listed in first-seen order across the tensor args' shards.

        Raises:
            ValueError: *request* is not a KernelLaunchMsg.
        """
        if not isinstance(request, KernelLaunchMsg):
            raise ValueError(f"unsupported request type: {type(request)}")

        seen: set[int] = set()
        entries: list[tuple[str, str, int]] = []
        for arg in request.args:
            if arg.arg_kind != "tensor":
                continue
            for shard in arg.shards:
                if shard.sip in seen:
                    # NOTE(review): only the first shard's nbytes per SIP is
                    # counted; later shards on the same SIP are ignored.
                    continue
                seen.add(shard.sip)
                entries.append((
                    self._resolver.find_pcie_ep(shard.sip),
                    self._resolver.find_io_cpu(shard.sip),
                    shard.nbytes,
                ))
        return entries

    def _process_mmu_msg(self, key: str, request: Any, done: simpy.Event):
        """Route MmuMapMsg/MmuUnmapMsg through fabric like KernelLaunchMsg.

        Path: Host → PCIE_EP → IO_NOC → IO_CPU → (fan-out) → M_CPU → (fan-out) → PE_MMU

        ``request.target_sips`` (default ``"all"``) selects the SIPs; ``"all"``
        means every SIP that has a PCIE_EP. The message carries no data payload.
        """
        start_ns = self._env.now
        target_sips = getattr(request, "target_sips", "all")

        # Determine target SIPs (PCIE_EP ids look like "sip<N>.…").
        if target_sips == "all":
            sip_set = {
                int(ep_id.split(".")[0].replace("sip", ""))
                for ep_id in self._resolver.find_all_pcie_eps()
            }
        else:
            sip_set = set(target_sips)

        entries = [
            (
                self._resolver.find_pcie_ep(sip_id),
                self._resolver.find_io_cpu(sip_id),
                0,  # MmuMapMsg has no data payload
            )
            for sip_id in sorted(sip_set)
        ]
        if not entries:
            self._results[key] = (
                Completion(ok=True),
                {"total_ns": 0.0, "msg_type": request.msg_type},
            )
            done.succeed()
            return

        yield from self._inject_per_sip(request, entries)

        elapsed = self._env.now - start_ns
        self._results[key] = (
            Completion(ok=True),
            {"total_ns": elapsed, "msg_type": request.msg_type},
        )
        done.succeed()