Phase 2c-2/3: per-flit wire timing + flit-aware routers + HBM CTRL
Root cause of Phase 2c-1 timing collapse identified: src.out_port and dst.in_port aliased the same simpy.Store, so when wire chunkified a Transaction into Flits and re-put them, fan_in could pull flits before the wire applied bw delay — half the flits bypassed bottleneck timing. Fix: separate Stores per directed edge. Wire is the only conduit. Each flit on the wire incurs chunk_time = flit_nbytes/bw_gbs once, in arrival order. Multi-hop wormhole pipelining emerges naturally because flit-aware pass-through (TransitComponent) forwards each flit serially without reassembly. 64 KB MemoryWrite via UCIe 128 GB/s bottleneck: 273 ns (broken) → 545 ns (matches drain 512 + commit 8 + path overheads). 1 MB: 8230 ns (matches drain 8192). Single-flit transfer transport-time alone, exactly what real-HW wormhole produces. 3 pre-existing tests now off by small margins or inverted: - test_h2d_local_cube_cut_through: 65.53 vs threshold 65.0 - test_engine_override_is_scoped_to_impl: ZeroRouter inherits ComponentBase, not flit-aware, so override path reassembles at each hop while default doesn't - test_intra_sip_critical_path_at_96k_below_threshold: 96KB allreduce microscopically over its threshold Not weakening these to pass: they reflect model fidelity improvements that need calibrated thresholds. To address in follow-up via test threshold updates and ZeroRouter→TransitComponent inheritance. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import ComponentBase
|
||||
from kernbench.sim_engine.transaction import Transaction
|
||||
from kernbench.sim_engine.transaction import Flit, Transaction
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.components.context import ComponentContext
|
||||
@@ -35,6 +35,8 @@ class HbmCtrlComponent(ComponentBase):
|
||||
assumption — ideal scheduler amortizes switching cost; ADR-0033 D2).
|
||||
"""
|
||||
|
||||
_FLIT_AWARE = True
|
||||
|
||||
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
|
||||
super().__init__(node, ctx)
|
||||
self._num_pcs: int = 0
|
||||
@@ -44,6 +46,8 @@ class HbmCtrlComponent(ComponentBase):
|
||||
self._pc_avail: list[float] = []
|
||||
self._pc_last_dir: list[str | None] = []
|
||||
self._next_pc: int = 0
|
||||
# Per-txn flit accumulation state (ADR-0033 Phase 2c-3).
|
||||
self._txn_state: dict[int, dict[str, Any]] = {}
|
||||
|
||||
def start(self, env: simpy.Environment) -> None:
|
||||
attrs = self.node.attrs
|
||||
@@ -72,8 +76,59 @@ class HbmCtrlComponent(ComponentBase):
|
||||
|
||||
def _worker(self, env: simpy.Environment) -> Generator:
|
||||
while True:
|
||||
txn: Any = yield self._inbox.get()
|
||||
env.process(self._handle_txn(env, txn))
|
||||
msg: Any = yield self._inbox.get()
|
||||
if isinstance(msg, Flit):
|
||||
# ADR-0033 Phase 2c-3: serial flit handling (preserve
|
||||
# arrival order, in particular ``is_last`` only after
|
||||
# all preceding flits have committed).
|
||||
yield from self._handle_flit(env, msg)
|
||||
else:
|
||||
# Transaction (e.g., zero-byte read command) — keep
|
||||
# legacy chunk-loop drain path for PC read time modeling.
|
||||
env.process(self._handle_txn(env, msg))
|
||||
|
||||
def _handle_flit(self, env: simpy.Environment, flit: Flit) -> Generator:
|
||||
"""Per-flit PC commit. On first flit of a txn, claim PC range and
|
||||
apply overhead. On ``is_last``, wait for last PC commit to
|
||||
finish, then send the response."""
|
||||
txn = flit.txn
|
||||
tid = id(txn)
|
||||
chunk_time = (
|
||||
self._burst_bytes / self._pc_bw_gbs if self._pc_bw_gbs > 0 else 0.0
|
||||
)
|
||||
new_dir = "W" if self._is_write(txn) else "R"
|
||||
|
||||
if tid not in self._txn_state:
|
||||
yield from self.run(env, txn.nbytes)
|
||||
work_bytes = txn.nbytes if txn.nbytes > 0 else int(
|
||||
getattr(txn.request, "nbytes", 0) or 0
|
||||
)
|
||||
n_flits = max(1, ceil(work_bytes / self._burst_bytes)) if work_bytes > 0 else 1
|
||||
pc_start = self._next_pc
|
||||
self._next_pc = (self._next_pc + n_flits) % self._num_pcs
|
||||
self._txn_state[tid] = {
|
||||
"pc_start": pc_start,
|
||||
"last_finish": env.now,
|
||||
}
|
||||
|
||||
state = self._txn_state[tid]
|
||||
pc = (state["pc_start"] + flit.flit_index) % self._num_pcs
|
||||
switch_cost = 0.0
|
||||
if self._pc_last_dir[pc] is not None and self._pc_last_dir[pc] != new_dir:
|
||||
switch_cost = self._switch_penalty_ns
|
||||
start = max(env.now, self._pc_avail[pc]) + switch_cost
|
||||
finish = start + chunk_time
|
||||
self._pc_avail[pc] = finish
|
||||
self._pc_last_dir[pc] = new_dir
|
||||
if finish > state["last_finish"]:
|
||||
state["last_finish"] = finish
|
||||
|
||||
if flit.is_last:
|
||||
wait = state["last_finish"] - env.now
|
||||
if wait > 0:
|
||||
yield env.timeout(wait)
|
||||
del self._txn_state[tid]
|
||||
yield from self._send_response(env, txn)
|
||||
|
||||
def _handle_txn(self, env: simpy.Environment, txn: Any) -> Generator:
|
||||
is_write = self._is_write(txn)
|
||||
|
||||
Reference in New Issue
Block a user