from __future__ import annotations

from typing import Any

import simpy

from kernbench.common.types import Completion, RequestHandle, Trace
import kernbench.components.impls  # noqa: F401 — registers built-in implementations
from kernbench.components.base import ComponentBase, ComponentRegistry
from kernbench.components.context import ComponentContext
from kernbench.policy.address.phyaddr import PhysAddr
from kernbench.policy.routing.router import AddressResolver, PathRouter
from kernbench.runtime_api.kernel import KernelLaunchMsg, MemoryReadMsg, MemoryWriteMsg, PeDmaMsg
from kernbench.sim_engine.transaction import Transaction
from kernbench.topology.types import Edge, TopologyGraph


class GraphEngine:
    """simpy-based discrete-event simulation engine.

    Phase B: engine injects a Transaction into the PCIE_EP host queue for each request.
    Components handle their own routing:
      Path 1: PCIE_EP → IO_CPU  (engine-computed path, pre-loaded in Transaction)
      Path 2: IO_CPU → M_CPU   (IO_CPU dispatches, fire-and-forget callback)
      Path 3: M_CPU.DMA → HBM  (M_CPU dispatches, fire-and-forget callback)

    Component implementations are DI-injectable via component_overrides (ADR-0007 D3).
    """

    def __init__(
        self,
        graph: TopologyGraph,
        *,
        component_overrides: dict[str, type[ComponentBase]] | None = None,
    ) -> None:
        """Build components, ports, wires and host queues for *graph*.

        Args:
            graph: Topology whose nodes become components and whose edges
                become wired port pairs.
            component_overrides: Optional mapping handed to
                ``ComponentRegistry.create`` to substitute implementations.
        """
        self._env = simpy.Environment()
        self._resolver = AddressResolver(graph)
        self._router = PathRouter(graph)
        self._nodes = graph.nodes
        self._edge_map: dict[tuple[str, str], Edge] = {
            (e.src, e.dst): e for e in graph.edges
        }
        self._ns_per_mm: float = graph.spec.get("system", {}).get("ns_per_mm", 0.01)
        self._results: dict[str, tuple[Completion, Trace]] = {}
        self._events: dict[str, simpy.Event] = {}
        self._counter = 0

        overrides = component_overrides or {}
        ctx = ComponentContext(
            router=self._router,
            resolver=self._resolver,
            positions={nid: n.pos_mm for nid, n in graph.nodes.items()},
            ns_per_mm=self._ns_per_mm,
            edge_map=self._edge_map,
            spec=graph.spec,
        )
        self._components: dict[str, ComponentBase] = {
            node_id: ComponentRegistry.create(node, overrides, ctx)
            for node_id, node in graph.nodes.items()
        }

        # Wire ports and wire processes (ADR-0015 D1, D2).
        # Each directed edge gets a source-side Store (the sender's out_port)
        # and a distinct sink-side Store (the receiver's in_port); the wire
        # process is the only thing that moves messages between the two.
        # They must NOT be the same Store: a relay that gets from and puts
        # into one queue competes with the receiver for messages, can re-take
        # a message it has just delivered, and never advances simulated time
        # when the propagation delay is zero.
        # Cut-through (wormhole) model: wires apply propagation only.
        # Serialization (drain) is computed per-path and applied once at the terminal.
        for e in graph.edges:
            src_comp = self._components.get(e.src)
            dst_comp = self._components.get(e.dst)
            if src_comp is None or dst_comp is None:
                continue
            out_store: simpy.Store = simpy.Store(self._env)
            in_store: simpy.Store = simpy.Store(self._env)
            src_comp.out_ports[e.dst] = out_store
            dst_comp.in_ports[e.src] = in_store
            prop_ns = e.distance_mm * self._ns_per_mm
            self._env.process(self._wire(out_store, in_store, prop_ns))

        # Attach host queues to PCIE_EP in_ports before start() (ADR-0015 D3)
        self._host_queues: dict[str, simpy.Store] = {}
        for pcie_ep_id in self._resolver.find_all_pcie_eps():
            host_q: simpy.Store = simpy.Store(self._env)
            self._components[pcie_ep_id].in_ports["host"] = host_q
            self._host_queues[pcie_ep_id] = host_q

        # Attach host queues to PE_DMA nodes for direct PE DMA injection
        self._pe_dma_queues: dict[str, simpy.Store] = {}
        for node_id, node in graph.nodes.items():
            if node.kind == "pe_dma":
                host_q = simpy.Store(self._env)
                self._components[node_id].in_ports["host"] = host_q
                self._pe_dma_queues[node_id] = host_q

        # Start components after all ports are wired (ADR-0015 D3)
        for comp in self._components.values():
            comp.start(self._env)

    def submit(self, request: Any) -> RequestHandle:
        """Schedule *request* for simulation and return a handle for it.

        The request is not simulated here; a process is registered with the
        environment and advances when ``wait`` runs the simulation.
        """
        self._counter += 1
        handle = RequestHandle(f"h{self._counter}")
        event = self._env.event()
        self._events[str(handle)] = event
        self._env.process(self._process(str(handle), request, event))
        return handle

    def wait(self, handle: RequestHandle) -> None:
        """Run the simulation until the request behind *handle* has completed.

        Raises:
            KeyError: If *handle* was not returned by ``submit``.
        """
        key = str(handle)
        event = self._events[key]
        if not event.triggered:
            self._env.run(until=event)

    def get_completion(self, handle: RequestHandle) -> tuple[Completion, Trace | None]:
        """Return the ``(Completion, trace)`` pair recorded for *handle*.

        Raises:
            KeyError: If no result has been recorded for *handle* yet.
        """
        return self._results[str(handle)]

    # ── internal ────────────────────────────────────────────────────

    def _wire(
        self,
        out_port: simpy.Store,
        in_port: simpy.Store,
        prop_ns: float,
    ):
        """SimPy process: relay messages with propagation delay only.

        Cut-through (wormhole) model: serialization (drain) is computed
        per-path and applied once at the terminal component, not at every
        wire hop.

        Args:
            out_port: Store the upstream component puts messages into.
            in_port: Store the downstream component gets messages from.
            prop_ns: Delay applied to every relayed message; skipped when
                not positive.
        """
        while True:
            msg = yield out_port.get()
            if prop_ns > 0:
                yield self._env.timeout(prop_ns)
            yield in_port.put(msg)

    def _process(self, key: str, request: Any, done: simpy.Event):
        """SimPy process: drive one request to completion and record its result.

        ``PeDmaMsg`` requests are delegated to ``_process_pe_dma``. Other
        requests are injected at the PCIE_EP host queue of every entry point
        returned by ``_entry_points``; the result is stored under *key* and
        *done* is triggered once all injected transactions have finished.
        """
        if isinstance(request, PeDmaMsg):
            yield from self._process_pe_dma(key, request, done)
            return

        entries = self._entry_points(request)
        if not entries:
            # Nothing to inject: complete immediately with an empty trace.
            self._results[key] = (
                Completion(ok=True),
                {"total_ns": 0.0, "nbytes": 0},
            )
            done.succeed()
            return

        start_ns = self._env.now
        total_nbytes = 0
        root_txn: Transaction | None = None

        if len(entries) == 1:
            # Single-SIP: direct inject (common path, no extra events)
            pcie_ep_id, io_cpu_id, nbytes = entries[0]
            total_nbytes = nbytes
            path = self._router.find_node_path(pcie_ep_id, io_cpu_id)
            txn_done = self._env.event()
            txn = Transaction(request=request, path=path, step=0, nbytes=nbytes, done=txn_done)
            root_txn = txn
            yield self._host_queues[pcie_ep_id].put(txn)
            yield txn_done
        else:
            # Multi-SIP: inject per SIP, aggregate completions (ADR-0007)
            sub_dones: list[simpy.Event] = []
            sub_txns: list[Transaction] = []
            for pcie_ep_id, io_cpu_id, nbytes in entries:
                # The reported nbytes is the largest per-SIP size, not the sum.
                total_nbytes = max(total_nbytes, nbytes)
                path = self._router.find_node_path(pcie_ep_id, io_cpu_id)
                txn_done = self._env.event()
                txn = Transaction(
                    request=request,
                    path=path,
                    step=0,
                    nbytes=nbytes,
                    done=txn_done,
                )
                yield self._host_queues[pcie_ep_id].put(txn)
                sub_dones.append(txn_done)
                sub_txns.append(txn)
            for sd in sub_dones:
                yield sd
            # Aggregate pe_exec_ns from multi-SIP (max)
            pe_vals = [st.result_data.get("pe_exec_ns") for st in sub_txns]
            pe_vals = [v for v in pe_vals if v is not None]
            if pe_vals:
                # The first sub-transaction carries the aggregated result data.
                root_txn = sub_txns[0]
                root_txn.result_data["pe_exec_ns"] = max(pe_vals)

        total_ns = self._env.now - start_ns
        result_trace: dict[str, Any] = {"total_ns": total_ns, "nbytes": total_nbytes}
        if root_txn is not None and root_txn.result_data:
            result_trace.update(root_txn.result_data)
        self._results[key] = (
            Completion(ok=True),
            result_trace,
        )
        done.succeed()

    def _process_pe_dma(self, key: str, request: PeDmaMsg, done: simpy.Event):
        """Inject a Transaction directly at PE_DMA for PE→HBM latency measurement.

        Records both the simulated latency (``total_ns``) and the analytic
        lower bound from ``_formula_latency`` (``formula_ns``) under *key*,
        then triggers *done*.
        """
        pe_prefix = f"sip{request.src_sip}.cube{request.src_cube}.pe{request.src_pe}"
        pe_dma_id = f"{pe_prefix}.pe_dma"

        pa = PhysAddr.decode(request.dst_pa)
        dst_node = self._resolver.resolve(pa)
        path = self._router.find_path(pe_prefix, dst_node)
        drain_ns = self._path_drain_ns(path, request.nbytes)

        start_ns = self._env.now
        txn_done = self._env.event()
        txn = Transaction(
            request=request,
            path=path,
            step=0,
            nbytes=request.nbytes,
            done=txn_done,
            drain_ns=drain_ns,
        )
        yield self._pe_dma_queues[pe_dma_id].put(txn)
        yield txn_done

        total_ns = self._env.now - start_ns
        formula_ns = self._formula_latency(path, request.nbytes)
        self._results[key] = (
            Completion(ok=True),
            {"total_ns": total_ns, "formula_ns": formula_ns, "nbytes": request.nbytes},
        )
        done.succeed()

    def _path_drain_ns(self, path: list[str], nbytes: int) -> float:
        """Wormhole drain time: nbytes / bottleneck_bw along path.

        Hops with no known edge, or whose ``bw_gbs`` is unset or zero, are
        ignored. Returns 0.0 when no hop contributes a bandwidth.
        NOTE(review): treating the quotient as nanoseconds assumes ``bw_gbs``
        is in GB/s (i.e. bytes per ns) — confirm against the Edge definition.
        """
        bandwidths = []
        for hop in zip(path, path[1:]):
            edge = self._edge_map.get(hop)
            if edge is not None and edge.bw_gbs:
                bandwidths.append(edge.bw_gbs)
        if not bandwidths:
            return 0.0
        bottleneck = min(bandwidths)
        if bottleneck == float("inf"):
            return 0.0
        return nbytes / bottleneck

    def _formula_latency(self, path: list[str], nbytes: int) -> float:
        """Lower-bound formula latency (ADR-0015 D7).

        formula = Σ(wire propagation) + Σ(component overhead_ns) + drain_ns

        Phase 0: formula == actual (no contention).
        Phase 1+: formula <= actual (contention adds queueing).
        """
        total = 0.0
        # Wire propagation delays
        for hop in zip(path, path[1:]):
            edge = self._edge_map.get(hop)
            if edge is not None:
                total += edge.distance_mm * self._ns_per_mm
        # Component overhead_ns
        for node_id in path:
            node = self._nodes.get(node_id)
            if node is not None:
                total += float(node.attrs.get("overhead_ns", 0.0))
        # Drain
        total += self._path_drain_ns(path, nbytes)
        return total

    def _entry_points(self, request: Any) -> list[tuple[str, str, int]]:
        """Return list of (pcie_ep_id, io_cpu_id, nbytes) per target SIP.

        For Memory{Write,Read}: single SIP entry.
        For KernelLaunchMsg: one entry per distinct SIP in tensor shards,
        using the nbytes of the first shard seen for that SIP.

        Raises:
            ValueError: If *request* is none of the supported message types.
        """
        if isinstance(request, MemoryWriteMsg):
            sip = request.dst_sip
            return [(
                self._resolver.find_pcie_ep(sip),
                self._resolver.find_io_cpu(sip),
                request.nbytes,
            )]
        if isinstance(request, MemoryReadMsg):
            sip = request.src_sip
            return [(
                self._resolver.find_pcie_ep(sip),
                self._resolver.find_io_cpu(sip),
                request.nbytes,
            )]
        if isinstance(request, KernelLaunchMsg):
            seen: set[int] = set()
            entries: list[tuple[str, str, int]] = []
            for arg in request.args:
                if arg.arg_kind != "tensor":
                    continue
                for shard in arg.shards:
                    if shard.sip in seen:
                        continue
                    seen.add(shard.sip)
                    entries.append((
                        self._resolver.find_pcie_ep(shard.sip),
                        self._resolver.find_io_cpu(shard.sip),
                        shard.nbytes,
                    ))
            return entries
        raise ValueError(f"unsupported request type: {type(request)}")