Phase 2c-1: wire chunkifies into Flits + reassembly compat layer
Wire decomposes Transactions into Flits per `_flit_bytes` but emits all flits atomically at the same env.now — preserves single-msg timing as infrastructure for Phase 2c-2 (per-flit timing + flit-aware routers). Non-flit-aware components reassemble Flits in `_fan_in`; `_update_step` sets txn.step to current component's path position so legacy step-based routing continues working when upstream is flit-aware. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -53,11 +53,51 @@ class ComponentBase(ABC):
|
||||
env.process(self._fan_in(port))
|
||||
env.process(self._worker(env))
|
||||
|
||||
# ADR-0033 Phase 2c: flit-aware components consume Flits directly;
|
||||
# non-flit-aware components reassemble Flits into the parent
|
||||
# Transaction before delivery to _inbox. Default False preserves
|
||||
# legacy single-msg semantics during incremental rollout.
|
||||
_FLIT_AWARE: bool = False
|
||||
|
||||
def _fan_in(self, port: simpy.Store) -> Generator:
|
||||
"""Relay messages from one in_port into the shared inbox."""
|
||||
"""Relay messages from in_port to _inbox. For non-flit-aware
|
||||
components (default), Flits are accumulated by parent Transaction
|
||||
and only the reassembled Transaction is placed on _inbox once
|
||||
``is_last`` arrives. Step is updated to this component's path
|
||||
position for legacy step-based routing."""
|
||||
from kernbench.sim_engine.transaction import Flit
|
||||
|
||||
if self._FLIT_AWARE:
|
||||
while True:
|
||||
msg = yield port.get()
|
||||
yield self._inbox.put(msg)
|
||||
return
|
||||
|
||||
flit_buffers: dict[int, list[Any]] = {}
|
||||
while True:
|
||||
msg = yield port.get()
|
||||
yield self._inbox.put(msg)
|
||||
if isinstance(msg, Flit):
|
||||
tid = id(msg.txn)
|
||||
flit_buffers.setdefault(tid, []).append(msg)
|
||||
if msg.is_last:
|
||||
flit_buffers.pop(tid, None)
|
||||
self._update_step(msg.txn)
|
||||
yield self._inbox.put(msg.txn)
|
||||
else:
|
||||
yield self._inbox.put(msg)
|
||||
|
||||
def _update_step(self, txn: Any) -> None:
|
||||
"""Set txn.step to this component's index in txn.path (if found).
|
||||
Allows legacy step-based routing to work even when flit-aware
|
||||
upstream components don't call txn.advance()."""
|
||||
my_id = self.node.id
|
||||
path = getattr(txn, "path", None)
|
||||
if not path:
|
||||
return
|
||||
for i, n in enumerate(path):
|
||||
if n == my_id:
|
||||
txn.step = i
|
||||
return
|
||||
|
||||
def _worker(self, env: simpy.Environment) -> Generator:
|
||||
"""Generic forwarding worker: spawns _forward_txn per message (pipeline)."""
|
||||
|
||||
@@ -11,7 +11,7 @@ from kernbench.components.context import ComponentContext
|
||||
from kernbench.policy.address.phyaddr import PhysAddr
|
||||
from kernbench.policy.routing.router import AddressResolver, PathRouter
|
||||
from kernbench.runtime_api.kernel import KernelLaunchMsg, MemoryReadMsg, MemoryWriteMsg, PeDmaMsg
|
||||
from kernbench.sim_engine.transaction import Transaction
|
||||
from kernbench.sim_engine.transaction import Flit, Transaction
|
||||
from kernbench.topology.types import Edge, TopologyGraph
|
||||
|
||||
|
||||
@@ -41,6 +41,14 @@ class GraphEngine:
|
||||
for e in graph.edges:
|
||||
self._edge_map[(e.src, e.dst)] = e
|
||||
self._ns_per_mm: float = graph.spec.get("system", {}).get("ns_per_mm", 0.01)
|
||||
# ADR-0033 Phase 2c-1: wire chunkifies into Flits (Phase 2c-2/3
|
||||
# will graduate to per-flit timing + flit-aware components). At
|
||||
# 2c-1 stage all flits of a Transaction are emitted atomically
|
||||
# at the same env.now to preserve current single-msg timing —
|
||||
# Flit transport is in place but behaviorally equivalent.
|
||||
self._flit_bytes: int = int(
|
||||
graph.spec.get("system", {}).get("flit_bytes", 256)
|
||||
)
|
||||
self._results: dict[str, tuple[Completion, Trace]] = {}
|
||||
self._events: dict[str, simpy.Event] = {}
|
||||
self._counter = 0
|
||||
@@ -259,18 +267,26 @@ class GraphEngine:
|
||||
available_at = 0.0
|
||||
while True:
|
||||
msg = yield out_port.get()
|
||||
# ADR-0033 Phase 2c-1: chunkify Transactions into Flits but
|
||||
# emit atomically (same env.now) to preserve current timing.
|
||||
# Phase 2c-2 will graduate to per-flit timing.
|
||||
if isinstance(msg, Transaction) and msg.nbytes > 0:
|
||||
items = list(msg.into_flits(self._flit_bytes))
|
||||
payload_nbytes = msg.nbytes
|
||||
else:
|
||||
items = [msg]
|
||||
payload_nbytes = getattr(msg, "nbytes", 0) or 0
|
||||
# BW occupancy: wait for link to become free, then mark busy
|
||||
if bw_gbs > 0:
|
||||
nbytes = getattr(msg, "nbytes", 0)
|
||||
if nbytes > 0:
|
||||
wait = available_at - self._env.now
|
||||
if wait > 0:
|
||||
yield self._env.timeout(wait)
|
||||
available_at = self._env.now + (nbytes / bw_gbs)
|
||||
if bw_gbs > 0 and payload_nbytes > 0:
|
||||
wait = available_at - self._env.now
|
||||
if wait > 0:
|
||||
yield self._env.timeout(wait)
|
||||
available_at = self._env.now + (payload_nbytes / bw_gbs)
|
||||
# Propagation delay
|
||||
if prop_ns > 0:
|
||||
yield self._env.timeout(prop_ns)
|
||||
yield in_port.put(msg)
|
||||
for item in items:
|
||||
yield in_port.put(item)
|
||||
|
||||
def _process(self, key: str, request: Any, done: simpy.Event):
|
||||
if isinstance(request, PeDmaMsg):
|
||||
|
||||
Reference in New Issue
Block a user