diff --git a/src/kernbench/components/base.py b/src/kernbench/components/base.py
index 932b9e9..6ee9beb 100644
--- a/src/kernbench/components/base.py
+++ b/src/kernbench/components/base.py
@@ -53,11 +53,57 @@ class ComponentBase(ABC):
             env.process(self._fan_in(port))
         env.process(self._worker(env))
 
+    # ADR-0033 Phase 2c: flit-aware components consume Flits directly;
+    # non-flit-aware components reassemble Flits into the parent
+    # Transaction before delivery to _inbox. Default False preserves
+    # legacy single-msg semantics during incremental rollout.
+    _FLIT_AWARE: bool = False
+
     def _fan_in(self, port: simpy.Store) -> Generator:
-        """Relay messages from one in_port into the shared inbox."""
+        """Relay messages from in_port to _inbox.
+
+        Flit-aware components (``_FLIT_AWARE`` true) receive every
+        message unchanged. For non-flit-aware components (default), a
+        Flit is never placed on _inbox: every Flit carries its parent
+        Transaction in ``txn``, so intermediate Flits are dropped and
+        the parent is delivered once the ``is_last`` Flit arrives. Its
+        step is first updated to this component's path position for
+        legacy step-based routing. Non-Flit messages pass through.
+        """
+        # Imported at call time rather than at module level (presumably
+        # to avoid an import cycle -- TODO confirm).
+        from kernbench.sim_engine.transaction import Flit
+
+        flit_aware = self._FLIT_AWARE
         while True:
             msg = yield port.get()
-            yield self._inbox.put(msg)
+            if flit_aware or not isinstance(msg, Flit):
+                yield self._inbox.put(msg)
+            elif msg.is_last:
+                # Last Flit: deliver the parent; earlier Flits were dropped.
+                self._update_step(msg.txn)
+                yield self._inbox.put(msg.txn)
+
+    def _update_step(self, txn: Any) -> None:
+        """Set txn.step to this component's index in txn.path (if found).
+
+        Allows legacy step-based routing to work even when flit-aware
+        upstream components don't call txn.advance(). If this node
+        appears more than once in the path, the first occurrence at or
+        after the current txn.step is used so an already-advanced step
+        is not rewound to an earlier visit; otherwise the first
+        occurrence is used. txn is left unchanged when it has no path
+        or this node is not on it.
+        """
+        my_id = self.node.id
+        path = getattr(txn, "path", None)
+        if not path:
+            return
+        hits = [i for i, n in enumerate(path) if n == my_id]
+        if not hits:
+            return
+        current = getattr(txn, "step", 0) or 0
+        txn.step = next((i for i in hits if i >= current), hits[0])
 
     def _worker(self, env: simpy.Environment) -> Generator:
         """Generic forwarding worker: spawns _forward_txn per message (pipeline)."""
diff --git a/src/kernbench/sim_engine/engine.py b/src/kernbench/sim_engine/engine.py
index 62904a8..070c797 100644
--- a/src/kernbench/sim_engine/engine.py
+++ b/src/kernbench/sim_engine/engine.py
@@ -11,7 +11,7 @@ from kernbench.components.context import ComponentContext
 from kernbench.policy.address.phyaddr import PhysAddr
 from kernbench.policy.routing.router import AddressResolver, PathRouter
 from kernbench.runtime_api.kernel import KernelLaunchMsg, MemoryReadMsg, MemoryWriteMsg, PeDmaMsg
-from kernbench.sim_engine.transaction import Transaction
+from kernbench.sim_engine.transaction import Flit, Transaction
 from kernbench.topology.types import Edge, TopologyGraph
 
 
@@ -41,6 +41,20 @@ class GraphEngine:
         for e in graph.edges:
             self._edge_map[(e.src, e.dst)] = e
         self._ns_per_mm: float = graph.spec.get("system", {}).get("ns_per_mm", 0.01)
+        # ADR-0033 Phase 2c-1: wire chunkifies into Flits (Phase 2c-2/3
+        # will graduate to per-flit timing + flit-aware components). At
+        # 2c-1 stage all flits of a Transaction are emitted atomically
+        # at the same env.now to preserve current single-msg timing —
+        # Flit transport is in place but behaviorally equivalent.
+        self._flit_bytes: int = int(
+            graph.spec.get("system", {}).get("flit_bytes", 256)
+        )
+        # Fail fast on a misconfigured flit size rather than handing a
+        # non-positive size to Transaction.into_flits mid-simulation.
+        if self._flit_bytes <= 0:
+            raise ValueError(
+                f"system.flit_bytes must be > 0, got {self._flit_bytes}"
+            )
         self._results: dict[str, tuple[Completion, Trace]] = {}
         self._events: dict[str, simpy.Event] = {}
         self._counter = 0
@@ -259,18 +273,27 @@ class GraphEngine:
         available_at = 0.0
         while True:
             msg = yield out_port.get()
+            # ADR-0033 Phase 2c-1: chunkify Transactions into Flits but
+            # emit atomically (same env.now) to preserve current timing.
+            # Phase 2c-2 will graduate to per-flit timing.
+            payload_nbytes = getattr(msg, "nbytes", 0) or 0
+            if isinstance(msg, Transaction) and payload_nbytes > 0:
+                # Fall back to the whole message if chunking yields
+                # nothing, so a Transaction is never silently dropped.
+                items = list(msg.into_flits(self._flit_bytes)) or [msg]
+            else:
+                items = [msg]
             # BW occupancy: wait for link to become free, then mark busy
-            if bw_gbs > 0:
-                nbytes = getattr(msg, "nbytes", 0)
-                if nbytes > 0:
-                    wait = available_at - self._env.now
-                    if wait > 0:
-                        yield self._env.timeout(wait)
-                    available_at = self._env.now + (nbytes / bw_gbs)
+            if bw_gbs > 0 and payload_nbytes > 0:
+                wait = available_at - self._env.now
+                if wait > 0:
+                    yield self._env.timeout(wait)
+                available_at = self._env.now + (payload_nbytes / bw_gbs)
             # Propagation delay
             if prop_ns > 0:
                 yield self._env.timeout(prop_ns)
-            yield in_port.put(msg)
+            for item in items:
+                yield in_port.put(item)
 
     def _process(self, key: str, request: Any, done: simpy.Event):
         if isinstance(request, PeDmaMsg):