Files
kernbench2/src/kernbench/sim_engine/engine.py
T
2026-03-18 11:47:48 -07:00

299 lines
12 KiB
Python

from __future__ import annotations
from typing import Any
import simpy
from kernbench.common.types import Completion, RequestHandle, Trace
import kernbench.components.impls # noqa: F401 — registers built-in implementations
from kernbench.components.base import ComponentBase, ComponentRegistry
from kernbench.components.context import ComponentContext
from kernbench.policy.address.phyaddr import PhysAddr
from kernbench.policy.routing.router import AddressResolver, PathRouter
from kernbench.runtime_api.kernel import KernelLaunchMsg, MemoryReadMsg, MemoryWriteMsg, PeDmaMsg
from kernbench.sim_engine.transaction import Transaction
from kernbench.topology.types import Edge, TopologyGraph
class GraphEngine:
    """simpy-based discrete-event simulation engine.

    Phase B: engine injects a Transaction into the PCIE_EP host queue for
    each request. Components handle their own routing:

        Path 1: PCIE_EP → IO_CPU (engine-computed path, pre-loaded in Transaction)
        Path 2: IO_CPU → M_CPU (IO_CPU dispatches, fire-and-forget callback)
        Path 3: M_CPU.DMA → HBM (M_CPU dispatches, fire-and-forget callback)

    Component implementations are DI-injectable via component_overrides (ADR-0007 D3).
    """

    def __init__(
        self,
        graph: TopologyGraph,
        *,
        component_overrides: dict[str, type[ComponentBase]] | None = None,
    ) -> None:
        """Build components, ports, wires and host queues for ``graph``.

        Args:
            graph: Topology whose nodes become components and whose directed
                edges become wired port pairs.
            component_overrides: Optional overrides passed to
                ``ComponentRegistry.create`` for every node; ``None`` means
                no overrides.
        """
        self._env = simpy.Environment()
        self._resolver = AddressResolver(graph)
        self._router = PathRouter(graph)
        self._nodes = graph.nodes
        self._edge_map: dict[tuple[str, str], Edge] = {
            (e.src, e.dst): e for e in graph.edges
        }
        self._ns_per_mm: float = graph.spec.get("system", {}).get("ns_per_mm", 0.01)
        # Keyed by str(handle): finished results and per-request completion events.
        self._results: dict[str, tuple[Completion, Trace]] = {}
        self._events: dict[str, simpy.Event] = {}
        self._counter = 0

        overrides = component_overrides or {}
        ctx = ComponentContext(
            router=self._router,
            resolver=self._resolver,
            positions={nid: n.pos_mm for nid, n in graph.nodes.items()},
            ns_per_mm=self._ns_per_mm,
            edge_map=self._edge_map,
            spec=graph.spec,
        )
        self._components: dict[str, ComponentBase] = {
            node_id: ComponentRegistry.create(node, overrides, ctx)
            for node_id, node in graph.nodes.items()
        }

        # Wire ports and wires (ADR-0015 D1/D2).
        # Each directed edge gets a sender-side Store (the source's out port)
        # and a receiver-side Store (the destination's in port), joined by a
        # wire process. Sharing ONE Store for both ends would make the wire
        # and the destination component competing consumers of the same
        # queue: a message could reach the destination without the
        # propagation delay, or be picked up by the wire again.
        # Cut-through (wormhole) model: wires apply propagation only.
        # Serialization (drain) is computed per-path and applied once at the terminal.
        for e in graph.edges:
            src_comp = self._components.get(e.src)
            dst_comp = self._components.get(e.dst)
            if src_comp is None or dst_comp is None:
                continue
            tx_store: simpy.Store = simpy.Store(self._env)
            rx_store: simpy.Store = simpy.Store(self._env)
            src_comp.out_ports[e.dst] = tx_store
            dst_comp.in_ports[e.src] = rx_store
            prop_ns = e.distance_mm * self._ns_per_mm
            self._env.process(self._wire(tx_store, rx_store, prop_ns))

        # Attach host queues to PCIE_EP in_ports before start() (ADR-0015 D3)
        self._host_queues: dict[str, simpy.Store] = {}
        for pcie_ep_id in self._resolver.find_all_pcie_eps():
            host_q: simpy.Store = simpy.Store(self._env)
            self._components[pcie_ep_id].in_ports["host"] = host_q
            self._host_queues[pcie_ep_id] = host_q

        # Attach host queues to PE_DMA nodes for direct PE DMA injection
        self._pe_dma_queues: dict[str, simpy.Store] = {}
        for node_id, node in graph.nodes.items():
            if node.kind == "pe_dma":
                host_q = simpy.Store(self._env)
                self._components[node_id].in_ports["host"] = host_q
                self._pe_dma_queues[node_id] = host_q

        # Start components after all ports are wired (ADR-0015 D3)
        for comp in self._components.values():
            comp.start(self._env)

    def submit(self, request: Any) -> RequestHandle:
        """Schedule ``request`` for simulation and return its handle.

        The request is only registered as a simpy process here; simulated
        time does not advance until ``wait`` is called.
        """
        self._counter += 1
        handle = RequestHandle(f"h{self._counter}")
        event = self._env.event()
        self._events[str(handle)] = event
        self._env.process(self._process(str(handle), request, event))
        return handle

    def wait(self, handle: RequestHandle) -> None:
        """Run the simulation until the request behind ``handle`` completes.

        Returns immediately if the completion event has already triggered.

        Raises:
            KeyError: If ``handle`` was not returned by ``submit``.
        """
        event = self._events[str(handle)]
        if not event.triggered:
            self._env.run(until=event)

    def get_completion(self, handle: RequestHandle) -> tuple[Completion, Trace | None]:
        """Return the ``(Completion, trace)`` pair recorded for ``handle``.

        Raises:
            KeyError: If no result has been recorded for ``handle`` yet.
        """
        return self._results[str(handle)]

    # ── internal ────────────────────────────────────────────────────

    def _wire(
        self,
        out_port: simpy.Store,
        in_port: simpy.Store,
        prop_ns: float,
    ):
        """SimPy process: relay messages with propagation delay only.

        Cut-through (wormhole) model: serialization (drain) is computed per-path
        and applied once at the terminal component, not at every wire hop.

        Args:
            out_port: Store the sending side puts messages into.
            in_port: Store the receiving side takes messages from.
            prop_ns: Delay applied to each message; skipped when not positive.
        """
        while True:
            msg = yield out_port.get()
            if prop_ns > 0:
                yield self._env.timeout(prop_ns)
            yield in_port.put(msg)

    def _process(self, key: str, request: Any, done: simpy.Event):
        """SimPy process: inject ``request``, await completion, record the result.

        PeDmaMsg requests are delegated to ``_process_pe_dma``. Every other
        request is injected once per entry point returned by
        ``_entry_points``; ``done`` succeeds after all injected transactions
        have completed and the result is stored under ``key``.
        """
        if isinstance(request, PeDmaMsg):
            yield from self._process_pe_dma(key, request, done)
            return

        entries = self._entry_points(request)
        if not entries:
            # Nothing to inject: complete immediately with an empty trace.
            self._results[key] = (
                Completion(ok=True),
                {"total_ns": 0.0, "nbytes": 0},
            )
            done.succeed()
            return

        start_ns = self._env.now

        # Inject one Transaction per target SIP (ADR-0007), then wait for
        # every one of them in injection order.
        txns: list[Transaction] = []
        txn_dones: list[simpy.Event] = []
        for pcie_ep_id, io_cpu_id, nbytes in entries:
            path = self._router.find_node_path(pcie_ep_id, io_cpu_id)
            txn_done = self._env.event()
            txn = Transaction(
                request=request, path=path, step=0,
                nbytes=nbytes, done=txn_done,
            )
            yield self._host_queues[pcie_ep_id].put(txn)
            txns.append(txn)
            txn_dones.append(txn_done)
        for txn_done in txn_dones:
            yield txn_done

        root_txn: Transaction | None = None
        if len(txns) == 1:
            # Single-SIP: the lone transaction's result data feeds the trace.
            root_txn = txns[0]
            total_nbytes = entries[0][2]
        else:
            # Multi-SIP: report the largest per-SIP size (floor of 0).
            total_nbytes = max([0, *(nbytes for _, _, nbytes in entries)])
            # Aggregate pe_exec_ns from multi-SIP (max). Result data is only
            # surfaced when at least one sub-transaction reported it.
            pe_vals = [
                v
                for v in (t.result_data.get("pe_exec_ns") for t in txns)
                if v is not None
            ]
            if pe_vals:
                root_txn = txns[0]
                root_txn.result_data["pe_exec_ns"] = max(pe_vals)

        total_ns = self._env.now - start_ns
        result_trace: dict[str, Any] = {"total_ns": total_ns, "nbytes": total_nbytes}
        if root_txn is not None and root_txn.result_data:
            result_trace.update(root_txn.result_data)
        self._results[key] = (
            Completion(ok=True),
            result_trace,
        )
        done.succeed()

    def _process_pe_dma(self, key: str, request: PeDmaMsg, done: simpy.Event):
        """Inject a Transaction directly at PE_DMA for PE→HBM latency measurement.

        The recorded trace holds the simulated ``total_ns``, the analytic
        ``formula_ns`` from ``_formula_latency`` and the request's ``nbytes``.
        """
        pe_prefix = f"sip{request.src_sip}.cube{request.src_cube}.pe{request.src_pe}"
        pe_dma_id = f"{pe_prefix}.pe_dma"
        pa = PhysAddr.decode(request.dst_pa)
        dst_node = self._resolver.resolve(pa)
        path = self._router.find_path(pe_prefix, dst_node)
        drain_ns = self._path_drain_ns(path, request.nbytes)

        start_ns = self._env.now
        txn_done = self._env.event()
        txn = Transaction(
            request=request, path=path, step=0, nbytes=request.nbytes,
            done=txn_done, drain_ns=drain_ns,
        )
        yield self._pe_dma_queues[pe_dma_id].put(txn)
        yield txn_done

        total_ns = self._env.now - start_ns
        formula_ns = self._formula_latency(path, request.nbytes)
        self._results[key] = (
            Completion(ok=True),
            {"total_ns": total_ns, "formula_ns": formula_ns, "nbytes": request.nbytes},
        )
        done.succeed()

    def _path_drain_ns(self, path: list[str], nbytes: int) -> float:
        """Wormhole drain time: nbytes / bottleneck_bw along path.

        Hops with no known edge or a falsy ``bw_gbs`` are ignored; returns
        0.0 when no hop contributes a bandwidth (including paths shorter
        than two nodes).
        """
        min_bw = float("inf")
        for hop in zip(path, path[1:]):
            edge = self._edge_map.get(hop)
            if edge and edge.bw_gbs:
                min_bw = min(min_bw, edge.bw_gbs)
        if min_bw == float("inf"):
            return 0.0
        return nbytes / min_bw

    def _formula_latency(self, path: list[str], nbytes: int) -> float:
        """Lower-bound formula latency (ADR-0015 D7).

        formula = Σ(wire propagation) + Σ(component overhead_ns) + drain_ns

        Phase 0: formula == actual (no contention).
        Phase 1+: formula <= actual (contention adds queueing).
        """
        total = 0.0
        # Wire propagation delays
        for hop in zip(path, path[1:]):
            edge = self._edge_map.get(hop)
            if edge:
                total += edge.distance_mm * self._ns_per_mm
        # Component overhead_ns
        for node_id in path:
            node = self._nodes.get(node_id)
            if node:
                total += float(node.attrs.get("overhead_ns", 0.0))
        # Drain
        total += self._path_drain_ns(path, nbytes)
        return total

    def _sip_entry(self, sip: int, nbytes: int) -> tuple[str, str, int]:
        """Return the (pcie_ep_id, io_cpu_id, nbytes) entry for one SIP."""
        return (
            self._resolver.find_pcie_ep(sip),
            self._resolver.find_io_cpu(sip),
            nbytes,
        )

    def _entry_points(self, request: Any) -> list[tuple[str, str, int]]:
        """Return list of (pcie_ep_id, io_cpu_id, nbytes) per target SIP.

        For Memory{Write,Read}: single SIP entry.
        For KernelLaunchMsg: one entry per distinct SIP in tensor shards,
        using the nbytes of the first shard seen for that SIP.

        Raises:
            ValueError: If ``request`` is none of the supported message types.
        """
        if isinstance(request, MemoryWriteMsg):
            return [self._sip_entry(request.dst_sip, request.nbytes)]
        if isinstance(request, MemoryReadMsg):
            return [self._sip_entry(request.src_sip, request.nbytes)]
        if isinstance(request, KernelLaunchMsg):
            seen: set[int] = set()
            entries: list[tuple[str, str, int]] = []
            for arg in request.args:
                if arg.arg_kind != "tensor":
                    continue
                for shard in arg.shards:
                    if shard.sip in seen:
                        continue
                    seen.add(shard.sip)
                    entries.append(self._sip_entry(shard.sip, shard.nbytes))
            return entries
        raise ValueError(f"unsupported request type: {type(request)}")