299 lines
12 KiB
Python
299 lines
12 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import Any
|
|
|
|
import simpy
|
|
|
|
from kernbench.common.types import Completion, RequestHandle, Trace
|
|
import kernbench.components.impls # noqa: F401 — registers built-in implementations
|
|
from kernbench.components.base import ComponentBase, ComponentRegistry
|
|
from kernbench.components.context import ComponentContext
|
|
from kernbench.policy.address.phyaddr import PhysAddr
|
|
from kernbench.policy.routing.router import AddressResolver, PathRouter
|
|
from kernbench.runtime_api.kernel import KernelLaunchMsg, MemoryReadMsg, MemoryWriteMsg, PeDmaMsg
|
|
from kernbench.sim_engine.transaction import Transaction
|
|
from kernbench.topology.types import Edge, TopologyGraph
|
|
|
|
|
|
class GraphEngine:
    """simpy-based discrete-event simulation engine.

    Phase B: the engine injects a Transaction into the PCIE_EP host queue for
    each request. Components then handle their own routing:

      Path 1: PCIE_EP → IO_CPU   (engine-computed path, pre-loaded in the Transaction)
      Path 2: IO_CPU → M_CPU     (IO_CPU dispatches, fire-and-forget callback)
      Path 3: M_CPU.DMA → HBM    (M_CPU dispatches, fire-and-forget callback)

    ``PeDmaMsg`` requests bypass PCIE_EP and are injected directly into the
    source PE's PE_DMA host queue.

    Component implementations are DI-injectable via ``component_overrides``
    (ADR-0007 D3).
    """

    def __init__(
        self,
        graph: TopologyGraph,
        *,
        component_overrides: dict[str, type[ComponentBase]] | None = None,
    ) -> None:
        """Build components, wire them together, attach host queues, then start them.

        Args:
            graph: Topology to simulate. ``graph.spec["system"]["ns_per_mm"]``
                (default 0.01) scales edge ``distance_mm`` into propagation ns.
            component_overrides: Optional per-kind implementation overrides,
                forwarded to ``ComponentRegistry.create`` (ADR-0007 D3).
        """
        self._env = simpy.Environment()
        self._resolver = AddressResolver(graph)
        self._router = PathRouter(graph)
        self._nodes = graph.nodes
        # Keyed by (src, dst); a later duplicate edge overwrites an earlier one.
        self._edge_map: dict[tuple[str, str], Edge] = {(e.src, e.dst): e for e in graph.edges}
        self._ns_per_mm: float = graph.spec.get("system", {}).get("ns_per_mm", 0.01)
        # handle key -> (Completion, trace), filled when a request finishes.
        self._results: dict[str, tuple[Completion, Trace]] = {}
        # handle key -> event triggered when that request finishes.
        self._events: dict[str, simpy.Event] = {}
        self._counter = 0

        overrides = component_overrides or {}
        ctx = ComponentContext(
            router=self._router,
            resolver=self._resolver,
            positions={nid: n.pos_mm for nid, n in graph.nodes.items()},
            ns_per_mm=self._ns_per_mm,
            edge_map=self._edge_map,
            spec=graph.spec,
        )
        self._components: dict[str, ComponentBase] = {
            node_id: ComponentRegistry.create(node, overrides, ctx)
            for node_id, node in graph.nodes.items()
        }

        # Wire ports: one Store per directed edge (ADR-0015 D1)
        for e, src_comp, dst_comp in self._connected_edges(graph):
            store: simpy.Store = simpy.Store(self._env)
            src_comp.out_ports[e.dst] = store
            dst_comp.in_ports[e.src] = store

        # Wire processes: propagation delay per edge (ADR-0015 D2).
        # Cut-through (wormhole) model: wires apply propagation only;
        # serialization (drain) is computed per-path and applied once at the terminal.
        # Done in a second pass so every wire relays the Store finally
        # registered for its (src, dst) pair.
        for e, src_comp, dst_comp in self._connected_edges(graph):
            prop_ns = e.distance_mm * self._ns_per_mm
            self._env.process(
                self._wire(src_comp.out_ports[e.dst], dst_comp.in_ports[e.src], prop_ns)
            )

        # Attach host queues to PCIE_EP in_ports before start() (ADR-0015 D3)
        self._host_queues: dict[str, simpy.Store] = {}
        for pcie_ep_id in self._resolver.find_all_pcie_eps():
            host_q: simpy.Store = simpy.Store(self._env)
            self._components[pcie_ep_id].in_ports["host"] = host_q
            self._host_queues[pcie_ep_id] = host_q

        # Attach host queues to PE_DMA nodes for direct PE DMA injection
        self._pe_dma_queues: dict[str, simpy.Store] = {}
        for node_id, node in graph.nodes.items():
            if node.kind != "pe_dma":
                continue
            pe_q: simpy.Store = simpy.Store(self._env)
            self._components[node_id].in_ports["host"] = pe_q
            self._pe_dma_queues[node_id] = pe_q

        # Start components after all ports are wired (ADR-0015 D3)
        for comp in self._components.values():
            comp.start(self._env)

    def submit(self, request: Any) -> RequestHandle:
        """Schedule *request* as a new SimPy process and return its handle.

        Nothing is simulated here; the request advances only when the
        environment is run (e.g. by ``wait``).
        """
        self._counter += 1
        handle = RequestHandle(f"h{self._counter}")
        key = str(handle)
        done = self._env.event()
        self._events[key] = done
        self._env.process(self._process(key, request, done))
        return handle

    def wait(self, handle: RequestHandle) -> None:
        """Run the simulation until the request behind *handle* has completed.

        Returns immediately if it already completed. Running the environment
        also advances every other in-flight request up to that point.
        Raises KeyError for a handle not returned by ``submit``.
        """
        event = self._events[str(handle)]
        if not event.triggered:
            self._env.run(until=event)

    def get_completion(self, handle: RequestHandle) -> tuple[Completion, Trace | None]:
        """Return the ``(Completion, trace)`` recorded for *handle*.

        Raises KeyError if the request has not completed yet (call ``wait`` first).
        """
        return self._results[str(handle)]

    # ── internal ────────────────────────────────────────────────────

    def _connected_edges(self, graph: TopologyGraph):
        """Yield ``(edge, src_component, dst_component)`` for edges whose endpoints both have components."""
        for e in graph.edges:
            src_comp = self._components.get(e.src)
            dst_comp = self._components.get(e.dst)
            if src_comp is None or dst_comp is None:
                continue
            yield e, src_comp, dst_comp

    def _wire(
        self,
        out_port: simpy.Store,
        in_port: simpy.Store,
        prop_ns: float,
    ):
        """SimPy process: relay messages from *out_port* to *in_port* after *prop_ns*.

        Cut-through (wormhole) model: serialization (drain) is computed per-path
        and applied once at the terminal component, not at every wire hop.

        NOTE(review): messages are relayed one at a time, so a message taken
        while an earlier one is still propagating waits for it; back-to-back
        messages are spaced by prop_ns. Confirm this is intended rather than a
        fully pipelined wire.
        """
        while True:
            msg = yield out_port.get()
            if prop_ns > 0:
                yield self._env.timeout(prop_ns)
            yield in_port.put(msg)

    def _new_host_txn(
        self, request: Any, pcie_ep_id: str, io_cpu_id: str, nbytes: int
    ) -> tuple[Transaction, simpy.Event]:
        """Create a PCIE_EP → IO_CPU Transaction and the event it triggers when done."""
        path = self._router.find_node_path(pcie_ep_id, io_cpu_id)
        txn_done = self._env.event()
        txn = Transaction(request=request, path=path, step=0, nbytes=nbytes, done=txn_done)
        return txn, txn_done

    def _process(self, key: str, request: Any, done: simpy.Event):
        """SimPy process: run *request* to completion and record its result under *key*.

        ``PeDmaMsg`` is delegated to ``_process_pe_dma``. Any other request is
        injected into the PCIE_EP host queue of each target SIP returned by
        ``_entry_points``. The process waits for every resulting transaction,
        stores ``(Completion, trace)`` in ``self._results`` and triggers *done*.
        The trace holds ``total_ns`` and ``nbytes`` plus the transaction's
        ``result_data`` (see below for the multi-SIP case).
        """
        if isinstance(request, PeDmaMsg):
            yield from self._process_pe_dma(key, request, done)
            return

        entries = self._entry_points(request)
        if not entries:
            # Nothing to inject: complete immediately with an empty trace.
            self._results[key] = (
                Completion(ok=True),
                {"total_ns": 0.0, "nbytes": 0},
            )
            done.succeed()
            return

        start_ns = self._env.now
        extra: dict[str, Any] = {}

        if len(entries) == 1:
            # Single-SIP: direct inject (common path)
            pcie_ep_id, io_cpu_id, nbytes = entries[0]
            total_nbytes = nbytes
            txn, txn_done = self._new_host_txn(request, pcie_ep_id, io_cpu_id, nbytes)
            yield self._host_queues[pcie_ep_id].put(txn)
            yield txn_done
            extra = txn.result_data or {}
        else:
            # Multi-SIP: inject per SIP, aggregate completions (ADR-0007)
            total_nbytes = 0
            sub_txns: list[Transaction] = []
            sub_dones: list[simpy.Event] = []
            for pcie_ep_id, io_cpu_id, nbytes in entries:
                # NOTE(review): reported nbytes is the largest per-SIP amount,
                # not the sum across SIPs — confirm intended.
                total_nbytes = max(total_nbytes, nbytes)
                txn, txn_done = self._new_host_txn(request, pcie_ep_id, io_cpu_id, nbytes)
                yield self._host_queues[pcie_ep_id].put(txn)
                sub_txns.append(txn)
                sub_dones.append(txn_done)
            for sub_done in sub_dones:
                yield sub_done

            # Aggregate pe_exec_ns across SIPs (max = slowest SIP). The first
            # SIP's result_data is reported, with pe_exec_ns replaced by the
            # aggregate; built as a copy so the transaction is not mutated.
            pe_vals = [
                v
                for v in (st.result_data.get("pe_exec_ns") for st in sub_txns)
                if v is not None
            ]
            if pe_vals:
                extra = {**sub_txns[0].result_data, "pe_exec_ns": max(pe_vals)}

        trace: dict[str, Any] = {"total_ns": self._env.now - start_ns, "nbytes": total_nbytes}
        trace.update(extra)
        self._results[key] = (Completion(ok=True), trace)
        done.succeed()

    def _process_pe_dma(self, key: str, request: PeDmaMsg, done: simpy.Event):
        """Inject a Transaction directly at PE_DMA for PE→HBM latency measurement.

        The trace records the simulated ``total_ns``, the lower-bound
        ``formula_ns`` (see ``_formula_latency``) and ``nbytes``.
        Raises KeyError if the source PE has no PE_DMA node in the topology.
        """
        pe_prefix = f"sip{request.src_sip}.cube{request.src_cube}.pe{request.src_pe}"
        pe_dma_id = f"{pe_prefix}.pe_dma"
        dst_node = self._resolver.resolve(PhysAddr.decode(request.dst_pa))
        path = self._router.find_path(pe_prefix, dst_node)
        drain_ns = self._path_drain_ns(path, request.nbytes)

        start_ns = self._env.now
        txn_done = self._env.event()
        txn = Transaction(request=request, path=path, step=0, nbytes=request.nbytes,
                          done=txn_done, drain_ns=drain_ns)
        try:
            pe_dma_q = self._pe_dma_queues[pe_dma_id]
        except KeyError:
            raise KeyError(f"no pe_dma node {pe_dma_id!r} in topology") from None
        yield pe_dma_q.put(txn)
        yield txn_done

        total_ns = self._env.now - start_ns
        formula_ns = self._formula_latency(path, request.nbytes)
        self._results[key] = (
            Completion(ok=True),
            {"total_ns": total_ns, "formula_ns": formula_ns, "nbytes": request.nbytes},
        )
        done.succeed()

    def _path_edges(self, path: list[str]):
        """Yield the Edge for each consecutive hop of *path*; hops with no edge are skipped."""
        for src, dst in zip(path, path[1:]):
            edge = self._edge_map.get((src, dst))
            if edge is not None:
                yield edge

    def _path_drain_ns(self, path: list[str], nbytes: int) -> float:
        """Wormhole drain time: nbytes / bottleneck_bw along *path*.

        Edges with a falsy ``bw_gbs`` are ignored; returns 0.0 when no edge on
        the path has a bandwidth. Presumably bw_gbs is GB/s (1e9 B/s), making
        bytes / bw_gbs a value in ns — confirm against the topology spec.
        """
        bottleneck = min(
            (edge.bw_gbs for edge in self._path_edges(path) if edge.bw_gbs),
            default=None,
        )
        if bottleneck is None:
            return 0.0
        return nbytes / bottleneck

    def _formula_latency(self, path: list[str], nbytes: int) -> float:
        """Lower-bound formula latency (ADR-0015 D7).

            formula = Σ(wire propagation) + Σ(component overhead_ns) + drain_ns

        Phase 0: formula == actual (no contention).
        Phase 1+: formula <= actual (contention adds queueing).
        """
        # Single running total, summed in this fixed order (wires, overheads,
        # drain) so the floating-point result is reproducible.
        total = 0.0
        # Wire propagation delays (hops missing from the edge map contribute 0)
        for edge in self._path_edges(path):
            total += edge.distance_mm * self._ns_per_mm
        # Component overhead_ns (node ids unknown to the topology contribute 0)
        for node_id in path:
            node = self._nodes.get(node_id)
            if node is not None:
                total += float(node.attrs.get("overhead_ns", 0.0))
        # Drain
        total += self._path_drain_ns(path, nbytes)
        return total

    def _sip_entry(self, sip: int, nbytes: int) -> tuple[str, str, int]:
        """Return ``(pcie_ep_id, io_cpu_id, nbytes)`` for injecting into *sip*."""
        return (
            self._resolver.find_pcie_ep(sip),
            self._resolver.find_io_cpu(sip),
            nbytes,
        )

    def _entry_points(self, request: Any) -> list[tuple[str, str, int]]:
        """Return list of (pcie_ep_id, io_cpu_id, nbytes) per target SIP.

        For MemoryWriteMsg (dst_sip) / MemoryReadMsg (src_sip): a single entry.
        For KernelLaunchMsg: one entry per distinct SIP among tensor-arg shards,
        in first-seen order.

        Raises:
            ValueError: for any other request type.
        """
        if isinstance(request, MemoryWriteMsg):
            return [self._sip_entry(request.dst_sip, request.nbytes)]

        if isinstance(request, MemoryReadMsg):
            return [self._sip_entry(request.src_sip, request.nbytes)]

        if isinstance(request, KernelLaunchMsg):
            seen: set[int] = set()
            entries: list[tuple[str, str, int]] = []
            for arg in request.args:
                if arg.arg_kind != "tensor":
                    continue
                for shard in arg.shards:
                    if shard.sip in seen:
                        # NOTE(review): later shards on an already-seen SIP
                        # do not add to that SIP's nbytes — confirm intended.
                        continue
                    seen.add(shard.sip)
                    entries.append(self._sip_entry(shard.sip, shard.nbytes))
            return entries

        raise ValueError(f"unsupported request type: {type(request)}")
|