Phase 2c-2/3: per-flit wire timing + flit-aware routers + HBM CTRL
Root cause of Phase 2c-1 timing collapse identified: src.out_port and dst.in_port aliased the same simpy.Store, so when wire chunkified a Transaction into Flits and re-put them, fan_in could pull flits before the wire applied bw delay — half the flits bypassed bottleneck timing. Fix: separate Stores per directed edge. Wire is the only conduit. Each flit on the wire incurs chunk_time = flit_nbytes/bw_gbs once, in arrival order. Multi-hop wormhole pipelining emerges naturally because flit-aware pass-through (TransitComponent) forwards each flit serially without reassembly. 64 KB MemoryWrite via UCIe 128 GB/s bottleneck: 273 ns (broken) → 545 ns (matches drain 512 + commit 8 + path overheads). 1 MB: 8230 ns (matches drain 8192). Single-flit transfer transport-time alone, exactly what real-HW wormhole produces. 3 pre-existing tests now off by small margins or inverted: - test_h2d_local_cube_cut_through: 65.53 vs threshold 65.0 - test_engine_override_is_scoped_to_impl: ZeroRouter inherits ComponentBase, not flit-aware, so override path reassembles at each hop while default doesn't - test_intra_sip_critical_path_at_96k_below_threshold: 96KB allreduce microscopically over its threshold Not weakening these to pass: they reflect model fidelity improvements that need calibrated thresholds. To address in follow-up via test threshold updates and ZeroRouter→TransitComponent inheritance. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -1,11 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import ComponentBase
|
||||
from kernbench.sim_engine.transaction import Flit
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.components.context import ComponentContext
|
||||
@@ -13,15 +14,58 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class TransitComponent(ComponentBase):
|
||||
"""Transit component for NOC, UCIe, XBAR nodes.
|
||||
"""Transit component for NOC, UCIe, XBAR nodes (ADR-0033 Phase 2c).
|
||||
|
||||
Applies overhead_ns processing delay (from node.attrs) then forwards the
|
||||
Transaction to the next hop via inherited _forward_txn().
|
||||
Flit-aware pass-through: forwards each Flit to the next hop with
|
||||
per-transaction ``overhead_ns`` applied ONCE (at first-flit arrival,
|
||||
modeling header decode + routing decision). Subsequent flits of the
|
||||
same transaction pipeline through with no extra delay, preserving
|
||||
wormhole-style cut-through across multi-hop paths.
|
||||
|
||||
Forwarding is SERIAL in the worker: each flit is forwarded in arrival
|
||||
order. Spawning ``env.process`` per flit would let later flits
|
||||
overtake earlier ones (when the first flit yields ``overhead_ns``
|
||||
while subsequent flits skip it), producing out-of-order delivery
|
||||
and early ``is_last`` signaling at the destination.
|
||||
|
||||
Non-Flit messages (zero-byte control Transactions, etc.) fall back
|
||||
to the legacy atomic ``_forward_txn`` path via ``env.process``.
|
||||
"""
|
||||
|
||||
_FLIT_AWARE = True
|
||||
|
||||
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
|
||||
super().__init__(node, ctx)
|
||||
self._txn_decoded: set[int] = set()
|
||||
|
||||
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
|
||||
overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
|
||||
yield env.timeout(overhead_ns)
|
||||
|
||||
def _worker(self, env: simpy.Environment) -> Generator:
|
||||
while True:
|
||||
msg: Any = yield self._inbox.get()
|
||||
if isinstance(msg, Flit):
|
||||
tid = id(msg.txn)
|
||||
if tid not in self._txn_decoded:
|
||||
self._txn_decoded.add(tid)
|
||||
yield from self.run(env, msg.txn.nbytes)
|
||||
if msg.is_last:
|
||||
self._txn_decoded.discard(tid)
|
||||
next_hop = self._next_hop_in_path(msg.txn)
|
||||
if next_hop and next_hop in self.out_ports:
|
||||
yield self.out_ports[next_hop].put(msg)
|
||||
elif msg.is_last:
|
||||
msg.txn.done.succeed()
|
||||
else:
|
||||
env.process(self._forward_txn(env, msg))
|
||||
|
||||
def _next_hop_in_path(self, txn: Any) -> str | None:
|
||||
my_id = self.node.id
|
||||
path = getattr(txn, "path", None)
|
||||
if not path:
|
||||
return None
|
||||
for i, n in enumerate(path):
|
||||
if n == my_id and i + 1 < len(path):
|
||||
return path[i + 1]
|
||||
return None
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import ComponentBase
|
||||
from kernbench.sim_engine.transaction import Transaction
|
||||
from kernbench.sim_engine.transaction import Flit, Transaction
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.components.context import ComponentContext
|
||||
@@ -35,6 +35,8 @@ class HbmCtrlComponent(ComponentBase):
|
||||
assumption — ideal scheduler amortizes switching cost; ADR-0033 D2).
|
||||
"""
|
||||
|
||||
_FLIT_AWARE = True
|
||||
|
||||
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
|
||||
super().__init__(node, ctx)
|
||||
self._num_pcs: int = 0
|
||||
@@ -44,6 +46,8 @@ class HbmCtrlComponent(ComponentBase):
|
||||
self._pc_avail: list[float] = []
|
||||
self._pc_last_dir: list[str | None] = []
|
||||
self._next_pc: int = 0
|
||||
# Per-txn flit accumulation state (ADR-0033 Phase 2c-3).
|
||||
self._txn_state: dict[int, dict[str, Any]] = {}
|
||||
|
||||
def start(self, env: simpy.Environment) -> None:
|
||||
attrs = self.node.attrs
|
||||
@@ -72,8 +76,59 @@ class HbmCtrlComponent(ComponentBase):
|
||||
|
||||
def _worker(self, env: simpy.Environment) -> Generator:
|
||||
while True:
|
||||
txn: Any = yield self._inbox.get()
|
||||
env.process(self._handle_txn(env, txn))
|
||||
msg: Any = yield self._inbox.get()
|
||||
if isinstance(msg, Flit):
|
||||
# ADR-0033 Phase 2c-3: serial flit handling (preserve
|
||||
# arrival order, in particular ``is_last`` only after
|
||||
# all preceding flits have committed).
|
||||
yield from self._handle_flit(env, msg)
|
||||
else:
|
||||
# Transaction (e.g., zero-byte read command) — keep
|
||||
# legacy chunk-loop drain path for PC read time modeling.
|
||||
env.process(self._handle_txn(env, msg))
|
||||
|
||||
def _handle_flit(self, env: simpy.Environment, flit: Flit) -> Generator:
|
||||
"""Per-flit PC commit. On first flit of a txn, claim PC range and
|
||||
apply overhead. On ``is_last``, wait for last PC commit to
|
||||
finish, then send the response."""
|
||||
txn = flit.txn
|
||||
tid = id(txn)
|
||||
chunk_time = (
|
||||
self._burst_bytes / self._pc_bw_gbs if self._pc_bw_gbs > 0 else 0.0
|
||||
)
|
||||
new_dir = "W" if self._is_write(txn) else "R"
|
||||
|
||||
if tid not in self._txn_state:
|
||||
yield from self.run(env, txn.nbytes)
|
||||
work_bytes = txn.nbytes if txn.nbytes > 0 else int(
|
||||
getattr(txn.request, "nbytes", 0) or 0
|
||||
)
|
||||
n_flits = max(1, ceil(work_bytes / self._burst_bytes)) if work_bytes > 0 else 1
|
||||
pc_start = self._next_pc
|
||||
self._next_pc = (self._next_pc + n_flits) % self._num_pcs
|
||||
self._txn_state[tid] = {
|
||||
"pc_start": pc_start,
|
||||
"last_finish": env.now,
|
||||
}
|
||||
|
||||
state = self._txn_state[tid]
|
||||
pc = (state["pc_start"] + flit.flit_index) % self._num_pcs
|
||||
switch_cost = 0.0
|
||||
if self._pc_last_dir[pc] is not None and self._pc_last_dir[pc] != new_dir:
|
||||
switch_cost = self._switch_penalty_ns
|
||||
start = max(env.now, self._pc_avail[pc]) + switch_cost
|
||||
finish = start + chunk_time
|
||||
self._pc_avail[pc] = finish
|
||||
self._pc_last_dir[pc] = new_dir
|
||||
if finish > state["last_finish"]:
|
||||
state["last_finish"] = finish
|
||||
|
||||
if flit.is_last:
|
||||
wait = state["last_finish"] - env.now
|
||||
if wait > 0:
|
||||
yield env.timeout(wait)
|
||||
del self._txn_state[tid]
|
||||
yield from self._send_response(env, txn)
|
||||
|
||||
def _handle_txn(self, env: simpy.Environment, txn: Any) -> Generator:
|
||||
is_write = self._is_write(txn)
|
||||
|
||||
@@ -85,15 +85,22 @@ class GraphEngine:
|
||||
for node_id, node in graph.nodes.items()
|
||||
}
|
||||
|
||||
# Wire ports: one Store per directed edge (ADR-0015 D1)
|
||||
# Wire ports: SEPARATE Stores for src.out_port and dst.in_port per
|
||||
# directed edge (ADR-0015 D1, ADR-0033 Phase 2c). The wire process
|
||||
# is the only conduit between them: pulls from src.out_port,
|
||||
# processes per-flit timing, puts on dst.in_port. Using separate
|
||||
# stores eliminates a race with `fan_in` that would otherwise let
|
||||
# flits bypass wire's BW occupancy (fan_in could pull a flit from
|
||||
# the same store before wire put it back delayed).
|
||||
for e in graph.edges:
|
||||
src_comp = self._components.get(e.src)
|
||||
dst_comp = self._components.get(e.dst)
|
||||
if src_comp is None or dst_comp is None:
|
||||
continue
|
||||
store: simpy.Store = simpy.Store(self._env)
|
||||
src_comp.out_ports[e.dst] = store
|
||||
dst_comp.in_ports[e.src] = store
|
||||
out_store: simpy.Store = simpy.Store(self._env)
|
||||
in_store: simpy.Store = simpy.Store(self._env)
|
||||
src_comp.out_ports[e.dst] = out_store
|
||||
dst_comp.in_ports[e.src] = in_store
|
||||
|
||||
# Wire processes: propagation delay + BW occupancy per edge (ADR-0015 D2)
|
||||
# Cut-through (wormhole) model: wires apply propagation delay per hop.
|
||||
@@ -267,25 +274,32 @@ class GraphEngine:
|
||||
available_at = 0.0
|
||||
while True:
|
||||
msg = yield out_port.get()
|
||||
# ADR-0033 Phase 2c-1: chunkify Transactions into Flits but
|
||||
# emit atomically (same env.now) to preserve current timing.
|
||||
# Phase 2c-2 will graduate to per-flit timing.
|
||||
# ADR-0033 Phase 2c-2/3: per-flit transport timing.
|
||||
# Transactions with payload chunkify into Flits; each flit
|
||||
# occupies the wire for ``flit_nbytes/bw_gbs`` and is
|
||||
# delivered after ``prop_ns + transfer_time``. Wormhole
|
||||
# pipelining emerges naturally because downstream flit-aware
|
||||
# components forward flits without reassembly.
|
||||
if isinstance(msg, Transaction) and msg.nbytes > 0:
|
||||
items = list(msg.into_flits(self._flit_bytes))
|
||||
payload_nbytes = msg.nbytes
|
||||
else:
|
||||
items = [msg]
|
||||
payload_nbytes = getattr(msg, "nbytes", 0) or 0
|
||||
# BW occupancy: wait for link to become free, then mark busy
|
||||
if bw_gbs > 0 and payload_nbytes > 0:
|
||||
wait = available_at - self._env.now
|
||||
if wait > 0:
|
||||
yield self._env.timeout(wait)
|
||||
available_at = self._env.now + (payload_nbytes / bw_gbs)
|
||||
# Propagation delay
|
||||
if prop_ns > 0:
|
||||
yield self._env.timeout(prop_ns)
|
||||
for item in items:
|
||||
if isinstance(item, Flit):
|
||||
item_nbytes = item.flit_nbytes
|
||||
elif isinstance(item, Transaction):
|
||||
item_nbytes = item.nbytes
|
||||
else:
|
||||
item_nbytes = getattr(item, "nbytes", 0) or 0
|
||||
if bw_gbs > 0 and item_nbytes > 0:
|
||||
wait = available_at - self._env.now
|
||||
if wait > 0:
|
||||
yield self._env.timeout(wait)
|
||||
available_at = self._env.now + item_nbytes / bw_gbs
|
||||
yield self._env.timeout(prop_ns + item_nbytes / bw_gbs)
|
||||
else:
|
||||
if prop_ns > 0:
|
||||
yield self._env.timeout(prop_ns)
|
||||
yield in_port.put(item)
|
||||
|
||||
def _process(self, key: str, request: Any, done: simpy.Event):
|
||||
|
||||
Reference in New Issue
Block a user