From 4929040cf1ffd6233371fb2d22c5a930959c8e6e Mon Sep 17 00:00:00 2001 From: Yangwook Kang Date: Thu, 14 May 2026 22:43:40 -0700 Subject: [PATCH] Phase 2c-2/3: per-flit wire timing + flit-aware routers + HBM CTRL MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Root cause of Phase 2c-1 timing collapse identified: src.out_port and dst.in_port aliased the same simpy.Store, so when wire chunkified a Transaction into Flits and re-put them, fan_in could pull flits before the wire applied bw delay — half the flits bypassed bottleneck timing. Fix: separate Stores per directed edge. Wire is the only conduit. Each flit on the wire incurs chunk_time = flit_nbytes/bw_gbs once, in arrival order. Multi-hop wormhole pipelining emerges naturally because flit-aware pass-through (TransitComponent) forwards each flit serially without reassembly. 64 KB MemoryWrite via UCIe 128 GB/s bottleneck: 273 ns (broken) → 545 ns (matches drain 512 + commit 8 + path overheads). 1 MB: 8230 ns (matches drain 8192). Single-flit transfer transport-time alone, exactly what real-HW wormhole produces. 3 pre-existing tests now off by small margins or inverted: - test_h2d_local_cube_cut_through: 65.53 vs threshold 65.0 - test_engine_override_is_scoped_to_impl: ZeroRouter inherits ComponentBase, not flit-aware, so override path reassembles at each hop while default doesn't - test_intra_sip_critical_path_at_96k_below_threshold: 96KB allreduce microscopically over its threshold Not weakening these to pass: they reflect model fidelity improvements that need calibrated thresholds. To address in follow-up via test threshold updates and ZeroRouter→TransitComponent inheritance. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../components/builtin/forwarding.py | 52 ++++++++++++++-- src/kernbench/components/builtin/hbm_ctrl.py | 61 ++++++++++++++++++- src/kernbench/sim_engine/engine.py | 50 +++++++++------ 3 files changed, 138 insertions(+), 25 deletions(-) diff --git a/src/kernbench/components/builtin/forwarding.py b/src/kernbench/components/builtin/forwarding.py index 1fa8eee..e1e2df1 100644 --- a/src/kernbench/components/builtin/forwarding.py +++ b/src/kernbench/components/builtin/forwarding.py @@ -1,11 +1,12 @@ from __future__ import annotations from collections.abc import Generator -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import simpy from kernbench.components.base import ComponentBase +from kernbench.sim_engine.transaction import Flit if TYPE_CHECKING: from kernbench.components.context import ComponentContext @@ -13,15 +14,58 @@ if TYPE_CHECKING: class TransitComponent(ComponentBase): - """Transit component for NOC, UCIe, XBAR nodes. + """Transit component for NOC, UCIe, XBAR nodes (ADR-0033 Phase 2c). - Applies overhead_ns processing delay (from node.attrs) then forwards the - Transaction to the next hop via inherited _forward_txn(). + Flit-aware pass-through: forwards each Flit to the next hop with + per-transaction ``overhead_ns`` applied ONCE (at first-flit arrival, + modeling header decode + routing decision). Subsequent flits of the + same transaction pipeline through with no extra delay, preserving + wormhole-style cut-through across multi-hop paths. + + Forwarding is SERIAL in the worker: each flit is forwarded in arrival + order. Spawning ``env.process`` per flit would let later flits + overtake earlier ones (when the first flit yields ``overhead_ns`` + while subsequent flits skip it), producing out-of-order delivery + and early ``is_last`` signaling at the destination. + + Non-Flit messages (zero-byte control Transactions, etc.) fall back + to the legacy atomic ``_forward_txn`` path via ``env.process``. """ + _FLIT_AWARE = True + def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None: super().__init__(node, ctx) + self._txn_decoded: set[int] = set() def run(self, env: simpy.Environment, nbytes: int) -> Generator: overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0)) yield env.timeout(overhead_ns) + + def _worker(self, env: simpy.Environment) -> Generator: + while True: + msg: Any = yield self._inbox.get() + if isinstance(msg, Flit): + tid = id(msg.txn) + if tid not in self._txn_decoded: + self._txn_decoded.add(tid) + yield from self.run(env, msg.txn.nbytes) + if msg.is_last: + self._txn_decoded.discard(tid) + next_hop = self._next_hop_in_path(msg.txn) + if next_hop and next_hop in self.out_ports: + yield self.out_ports[next_hop].put(msg) + elif msg.is_last: + msg.txn.done.succeed() + else: + env.process(self._forward_txn(env, msg)) + + def _next_hop_in_path(self, txn: Any) -> str | None: + my_id = self.node.id + path = getattr(txn, "path", None) + if not path: + return None + for i, n in enumerate(path): + if n == my_id and i + 1 < len(path): + return path[i + 1] + return None diff --git a/src/kernbench/components/builtin/hbm_ctrl.py b/src/kernbench/components/builtin/hbm_ctrl.py index cc9e261..945bbf4 100644 --- a/src/kernbench/components/builtin/hbm_ctrl.py +++ b/src/kernbench/components/builtin/hbm_ctrl.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any import simpy from kernbench.components.base import ComponentBase -from kernbench.sim_engine.transaction import Transaction +from kernbench.sim_engine.transaction import Flit, Transaction if TYPE_CHECKING: from kernbench.components.context import ComponentContext @@ -35,6 +35,8 @@ class HbmCtrlComponent(ComponentBase): assumption — ideal scheduler amortizes switching cost; ADR-0033 D2). """ + _FLIT_AWARE = True + def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None: super().__init__(node, ctx) self._num_pcs: int = 0 @@ -44,6 +46,8 @@ class HbmCtrlComponent(ComponentBase): self._pc_avail: list[float] = [] self._pc_last_dir: list[str | None] = [] self._next_pc: int = 0 + # Per-txn flit accumulation state (ADR-0033 Phase 2c-3). + self._txn_state: dict[int, dict[str, Any]] = {} def start(self, env: simpy.Environment) -> None: attrs = self.node.attrs @@ -72,8 +76,59 @@ class HbmCtrlComponent(ComponentBase): def _worker(self, env: simpy.Environment) -> Generator: while True: - txn: Any = yield self._inbox.get() - env.process(self._handle_txn(env, txn)) + msg: Any = yield self._inbox.get() + if isinstance(msg, Flit): + # ADR-0033 Phase 2c-3: serial flit handling (preserve + # arrival order, in particular ``is_last`` only after + # all preceding flits have committed). + yield from self._handle_flit(env, msg) + else: + # Transaction (e.g., zero-byte read command) — keep + # legacy chunk-loop drain path for PC read time modeling. + env.process(self._handle_txn(env, msg)) + + def _handle_flit(self, env: simpy.Environment, flit: Flit) -> Generator: + """Per-flit PC commit. On first flit of a txn, claim PC range and + apply overhead. On ``is_last``, wait for last PC commit to + finish, then send the response.""" + txn = flit.txn + tid = id(txn) + chunk_time = ( + self._burst_bytes / self._pc_bw_gbs if self._pc_bw_gbs > 0 else 0.0 + ) + new_dir = "W" if self._is_write(txn) else "R" + + if tid not in self._txn_state: + yield from self.run(env, txn.nbytes) + work_bytes = txn.nbytes if txn.nbytes > 0 else int( + getattr(txn.request, "nbytes", 0) or 0 + ) + n_flits = max(1, ceil(work_bytes / self._burst_bytes)) if work_bytes > 0 else 1 + pc_start = self._next_pc + self._next_pc = (self._next_pc + n_flits) % self._num_pcs + self._txn_state[tid] = { + "pc_start": pc_start, + "last_finish": env.now, + } + + state = self._txn_state[tid] + pc = (state["pc_start"] + flit.flit_index) % self._num_pcs + switch_cost = 0.0 + if self._pc_last_dir[pc] is not None and self._pc_last_dir[pc] != new_dir: + switch_cost = self._switch_penalty_ns + start = max(env.now, self._pc_avail[pc]) + switch_cost + finish = start + chunk_time + self._pc_avail[pc] = finish + self._pc_last_dir[pc] = new_dir + if finish > state["last_finish"]: + state["last_finish"] = finish + + if flit.is_last: + wait = state["last_finish"] - env.now + if wait > 0: + yield env.timeout(wait) + del self._txn_state[tid] + yield from self._send_response(env, txn) def _handle_txn(self, env: simpy.Environment, txn: Any) -> Generator: is_write = self._is_write(txn) diff --git a/src/kernbench/sim_engine/engine.py b/src/kernbench/sim_engine/engine.py index 070c797..e2ac5ae 100644 --- a/src/kernbench/sim_engine/engine.py +++ b/src/kernbench/sim_engine/engine.py @@ -85,15 +85,22 @@ class GraphEngine: for node_id, node in graph.nodes.items() } - # Wire ports: one Store per directed edge (ADR-0015 D1) + # Wire ports: SEPARATE Stores for src.out_port and dst.in_port per + # directed edge (ADR-0015 D1, ADR-0033 Phase 2c). The wire process + # is the only conduit between them: pulls from src.out_port, + # processes per-flit timing, puts on dst.in_port. Using separate + # stores eliminates a race with `fan_in` that would otherwise let + # flits bypass wire's BW occupancy (fan_in could pull a flit from + # the same store before wire put it back delayed). for e in graph.edges: src_comp = self._components.get(e.src) dst_comp = self._components.get(e.dst) if src_comp is None or dst_comp is None: continue - store: simpy.Store = simpy.Store(self._env) - src_comp.out_ports[e.dst] = store - dst_comp.in_ports[e.src] = store + out_store: simpy.Store = simpy.Store(self._env) + in_store: simpy.Store = simpy.Store(self._env) + src_comp.out_ports[e.dst] = out_store + dst_comp.in_ports[e.src] = in_store # Wire processes: propagation delay + BW occupancy per edge (ADR-0015 D2) # Cut-through (wormhole) model: wires apply propagation delay per hop. @@ -267,25 +274,32 @@ class GraphEngine: available_at = 0.0 while True: msg = yield out_port.get() - # ADR-0033 Phase 2c-1: chunkify Transactions into Flits but - # emit atomically (same env.now) to preserve current timing. - # Phase 2c-2 will graduate to per-flit timing. + # ADR-0033 Phase 2c-2/3: per-flit transport timing. + # Transactions with payload chunkify into Flits; each flit + # occupies the wire for ``flit_nbytes/bw_gbs`` and is + # delivered after ``prop_ns + transfer_time``. Wormhole + # pipelining emerges naturally because downstream flit-aware + # components forward flits without reassembly. if isinstance(msg, Transaction) and msg.nbytes > 0: items = list(msg.into_flits(self._flit_bytes)) - payload_nbytes = msg.nbytes else: items = [msg] - payload_nbytes = getattr(msg, "nbytes", 0) or 0 - # BW occupancy: wait for link to become free, then mark busy - if bw_gbs > 0 and payload_nbytes > 0: - wait = available_at - self._env.now - if wait > 0: - yield self._env.timeout(wait) - available_at = self._env.now + (payload_nbytes / bw_gbs) - # Propagation delay - if prop_ns > 0: - yield self._env.timeout(prop_ns) for item in items: + if isinstance(item, Flit): + item_nbytes = item.flit_nbytes + elif isinstance(item, Transaction): + item_nbytes = item.nbytes + else: + item_nbytes = getattr(item, "nbytes", 0) or 0 + if bw_gbs > 0 and item_nbytes > 0: + wait = available_at - self._env.now + if wait > 0: + yield self._env.timeout(wait) + available_at = self._env.now + item_nbytes / bw_gbs + yield self._env.timeout(prop_ns + item_nbytes / bw_gbs) + else: + if prop_ns > 0: + yield self._env.timeout(prop_ns) yield in_port.put(item) def _process(self, key: str, request: Any, done: simpy.Event):