Phase 2c-2/3: per-flit wire timing + flit-aware routers + HBM CTRL

Root cause of Phase 2c-1 timing collapse identified: src.out_port and
dst.in_port aliased the same simpy.Store, so when wire chunkified a
Transaction into Flits and re-put them, fan_in could pull flits before
the wire applied bw delay — half the flits bypassed bottleneck timing.

Fix: separate Stores per directed edge. Wire is the only conduit. Each
flit on the wire incurs chunk_time = flit_nbytes/bw_gbs once, in arrival
order. Multi-hop wormhole pipelining emerges naturally because
flit-aware pass-through (TransitComponent) forwards each flit serially
without reassembly.

64 KB MemoryWrite via UCIe 128 GB/s bottleneck: 273 ns (broken) → 545 ns
(matches drain 512 + commit 8 + path overheads). 1 MB: 8230 ns (matches
drain 8192). Single-flit transfer transport-time alone, exactly what
real-HW wormhole produces.

3 pre-existing tests now off by small margins or inverted:
- test_h2d_local_cube_cut_through: 65.53 vs threshold 65.0
- test_engine_override_is_scoped_to_impl: ZeroRouter inherits
  ComponentBase, not flit-aware, so override path reassembles at each
  hop while default doesn't
- test_intra_sip_critical_path_at_96k_below_threshold: 96KB allreduce
  microscopically over its threshold

Not weakening these to pass: they reflect model fidelity improvements
that need calibrated thresholds. To address in follow-up via test
threshold updates and ZeroRouter→TransitComponent inheritance.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-14 22:43:40 -07:00
parent b31b3e8248
commit 4929040cf1
3 changed files with 138 additions and 25 deletions
+48 -4
View File
@@ -1,11 +1,12 @@
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
import simpy
from kernbench.components.base import ComponentBase
from kernbench.sim_engine.transaction import Flit
if TYPE_CHECKING:
from kernbench.components.context import ComponentContext
@@ -13,15 +14,58 @@ if TYPE_CHECKING:
class TransitComponent(ComponentBase):
"""Transit component for NOC, UCIe, XBAR nodes.
"""Transit component for NOC, UCIe, XBAR nodes (ADR-0033 Phase 2c).
Applies overhead_ns processing delay (from node.attrs) then forwards the
Transaction to the next hop via inherited _forward_txn().
Flit-aware pass-through: forwards each Flit to the next hop with
per-transaction ``overhead_ns`` applied ONCE (at first-flit arrival,
modeling header decode + routing decision). Subsequent flits of the
same transaction pipeline through with no extra delay, preserving
wormhole-style cut-through across multi-hop paths.
Forwarding is SERIAL in the worker: each flit is forwarded in arrival
order. Spawning ``env.process`` per flit would let later flits
overtake earlier ones (when the first flit yields ``overhead_ns``
while subsequent flits skip it), producing out-of-order delivery
and early ``is_last`` signaling at the destination.
Non-Flit messages (zero-byte control Transactions, etc.) fall back
to the legacy atomic ``_forward_txn`` path via ``env.process``.
"""
_FLIT_AWARE = True
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
super().__init__(node, ctx)
self._txn_decoded: set[int] = set()
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
yield env.timeout(overhead_ns)
def _worker(self, env: simpy.Environment) -> Generator:
while True:
msg: Any = yield self._inbox.get()
if isinstance(msg, Flit):
tid = id(msg.txn)
if tid not in self._txn_decoded:
self._txn_decoded.add(tid)
yield from self.run(env, msg.txn.nbytes)
if msg.is_last:
self._txn_decoded.discard(tid)
next_hop = self._next_hop_in_path(msg.txn)
if next_hop and next_hop in self.out_ports:
yield self.out_ports[next_hop].put(msg)
elif msg.is_last:
msg.txn.done.succeed()
else:
env.process(self._forward_txn(env, msg))
def _next_hop_in_path(self, txn: Any) -> str | None:
my_id = self.node.id
path = getattr(txn, "path", None)
if not path:
return None
for i, n in enumerate(path):
if n == my_id and i + 1 < len(path):
return path[i + 1]
return None
+58 -3
View File
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any
import simpy
from kernbench.components.base import ComponentBase
from kernbench.sim_engine.transaction import Transaction
from kernbench.sim_engine.transaction import Flit, Transaction
if TYPE_CHECKING:
from kernbench.components.context import ComponentContext
@@ -35,6 +35,8 @@ class HbmCtrlComponent(ComponentBase):
assumption — ideal scheduler amortizes switching cost; ADR-0033 D2).
"""
_FLIT_AWARE = True
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
super().__init__(node, ctx)
self._num_pcs: int = 0
@@ -44,6 +46,8 @@ class HbmCtrlComponent(ComponentBase):
self._pc_avail: list[float] = []
self._pc_last_dir: list[str | None] = []
self._next_pc: int = 0
# Per-txn flit accumulation state (ADR-0033 Phase 2c-3).
self._txn_state: dict[int, dict[str, Any]] = {}
def start(self, env: simpy.Environment) -> None:
attrs = self.node.attrs
@@ -72,8 +76,59 @@ class HbmCtrlComponent(ComponentBase):
def _worker(self, env: simpy.Environment) -> Generator:
while True:
txn: Any = yield self._inbox.get()
env.process(self._handle_txn(env, txn))
msg: Any = yield self._inbox.get()
if isinstance(msg, Flit):
# ADR-0033 Phase 2c-3: serial flit handling (preserve
# arrival order, in particular ``is_last`` only after
# all preceding flits have committed).
yield from self._handle_flit(env, msg)
else:
# Transaction (e.g., zero-byte read command) — keep
# legacy chunk-loop drain path for PC read time modeling.
env.process(self._handle_txn(env, msg))
def _handle_flit(self, env: simpy.Environment, flit: Flit) -> Generator:
"""Per-flit PC commit. On first flit of a txn, claim PC range and
apply overhead. On ``is_last``, wait for last PC commit to
finish, then send the response."""
txn = flit.txn
tid = id(txn)
chunk_time = (
self._burst_bytes / self._pc_bw_gbs if self._pc_bw_gbs > 0 else 0.0
)
new_dir = "W" if self._is_write(txn) else "R"
if tid not in self._txn_state:
yield from self.run(env, txn.nbytes)
work_bytes = txn.nbytes if txn.nbytes > 0 else int(
getattr(txn.request, "nbytes", 0) or 0
)
n_flits = max(1, ceil(work_bytes / self._burst_bytes)) if work_bytes > 0 else 1
pc_start = self._next_pc
self._next_pc = (self._next_pc + n_flits) % self._num_pcs
self._txn_state[tid] = {
"pc_start": pc_start,
"last_finish": env.now,
}
state = self._txn_state[tid]
pc = (state["pc_start"] + flit.flit_index) % self._num_pcs
switch_cost = 0.0
if self._pc_last_dir[pc] is not None and self._pc_last_dir[pc] != new_dir:
switch_cost = self._switch_penalty_ns
start = max(env.now, self._pc_avail[pc]) + switch_cost
finish = start + chunk_time
self._pc_avail[pc] = finish
self._pc_last_dir[pc] = new_dir
if finish > state["last_finish"]:
state["last_finish"] = finish
if flit.is_last:
wait = state["last_finish"] - env.now
if wait > 0:
yield env.timeout(wait)
del self._txn_state[tid]
yield from self._send_response(env, txn)
def _handle_txn(self, env: simpy.Environment, txn: Any) -> Generator:
is_write = self._is_write(txn)
+28 -14
View File
@@ -85,15 +85,22 @@ class GraphEngine:
for node_id, node in graph.nodes.items()
}
# Wire ports: one Store per directed edge (ADR-0015 D1)
# Wire ports: SEPARATE Stores for src.out_port and dst.in_port per
# directed edge (ADR-0015 D1, ADR-0033 Phase 2c). The wire process
# is the only conduit between them: pulls from src.out_port,
# processes per-flit timing, puts on dst.in_port. Using separate
# stores eliminates a race with `fan_in` that would otherwise let
# flits bypass wire's BW occupancy (fan_in could pull a flit from
# the same store before wire put it back delayed).
for e in graph.edges:
src_comp = self._components.get(e.src)
dst_comp = self._components.get(e.dst)
if src_comp is None or dst_comp is None:
continue
store: simpy.Store = simpy.Store(self._env)
src_comp.out_ports[e.dst] = store
dst_comp.in_ports[e.src] = store
out_store: simpy.Store = simpy.Store(self._env)
in_store: simpy.Store = simpy.Store(self._env)
src_comp.out_ports[e.dst] = out_store
dst_comp.in_ports[e.src] = in_store
# Wire processes: propagation delay + BW occupancy per edge (ADR-0015 D2)
# Cut-through (wormhole) model: wires apply propagation delay per hop.
@@ -267,25 +274,32 @@ class GraphEngine:
available_at = 0.0
while True:
msg = yield out_port.get()
# ADR-0033 Phase 2c-1: chunkify Transactions into Flits but
# emit atomically (same env.now) to preserve current timing.
# Phase 2c-2 will graduate to per-flit timing.
# ADR-0033 Phase 2c-2/3: per-flit transport timing.
# Transactions with payload chunkify into Flits; each flit
# occupies the wire for ``flit_nbytes/bw_gbs`` and is
# delivered after ``prop_ns + transfer_time``. Wormhole
# pipelining emerges naturally because downstream flit-aware
# components forward flits without reassembly.
if isinstance(msg, Transaction) and msg.nbytes > 0:
items = list(msg.into_flits(self._flit_bytes))
payload_nbytes = msg.nbytes
else:
items = [msg]
payload_nbytes = getattr(msg, "nbytes", 0) or 0
# BW occupancy: wait for link to become free, then mark busy
if bw_gbs > 0 and payload_nbytes > 0:
for item in items:
if isinstance(item, Flit):
item_nbytes = item.flit_nbytes
elif isinstance(item, Transaction):
item_nbytes = item.nbytes
else:
item_nbytes = getattr(item, "nbytes", 0) or 0
if bw_gbs > 0 and item_nbytes > 0:
wait = available_at - self._env.now
if wait > 0:
yield self._env.timeout(wait)
available_at = self._env.now + (payload_nbytes / bw_gbs)
# Propagation delay
available_at = self._env.now + item_nbytes / bw_gbs
yield self._env.timeout(prop_ns + item_nbytes / bw_gbs)
else:
if prop_ns > 0:
yield self._env.timeout(prop_ns)
for item in items:
yield in_port.put(item)
def _process(self, key: str, request: Any, done: simpy.Event):