Phase 2c-1: wire chunkifies into Flits + reassembly compat layer
Wire decomposes Transactions into Flits per `_flit_bytes` but emits all flits atomically at the same env.now — preserves single-msg timing as infrastructure for Phase 2c-2 (per-flit timing + flit-aware routers). Non-flit-aware components reassemble Flits in `_fan_in`; `_update_step` sets txn.step to current component's path position so legacy step-based routing continues working when upstream is flit-aware. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -53,11 +53,51 @@ class ComponentBase(ABC):
|
|||||||
env.process(self._fan_in(port))
|
env.process(self._fan_in(port))
|
||||||
env.process(self._worker(env))
|
env.process(self._worker(env))
|
||||||
|
|
||||||
|
# ADR-0033 Phase 2c: flit-aware components consume Flits directly;
|
||||||
|
# non-flit-aware components reassemble Flits into the parent
|
||||||
|
# Transaction before delivery to _inbox. Default False preserves
|
||||||
|
# legacy single-msg semantics during incremental rollout.
|
||||||
|
_FLIT_AWARE: bool = False
|
||||||
|
|
||||||
def _fan_in(self, port: simpy.Store) -> Generator:
|
def _fan_in(self, port: simpy.Store) -> Generator:
|
||||||
"""Relay messages from one in_port into the shared inbox."""
|
"""Relay messages from in_port to _inbox. For non-flit-aware
|
||||||
|
components (default), Flits are accumulated by parent Transaction
|
||||||
|
and only the reassembled Transaction is placed on _inbox once
|
||||||
|
``is_last`` arrives. Step is updated to this component's path
|
||||||
|
position for legacy step-based routing."""
|
||||||
|
from kernbench.sim_engine.transaction import Flit
|
||||||
|
|
||||||
|
if self._FLIT_AWARE:
|
||||||
|
while True:
|
||||||
|
msg = yield port.get()
|
||||||
|
yield self._inbox.put(msg)
|
||||||
|
return
|
||||||
|
|
||||||
|
flit_buffers: dict[int, list[Any]] = {}
|
||||||
while True:
|
while True:
|
||||||
msg = yield port.get()
|
msg = yield port.get()
|
||||||
yield self._inbox.put(msg)
|
if isinstance(msg, Flit):
|
||||||
|
tid = id(msg.txn)
|
||||||
|
flit_buffers.setdefault(tid, []).append(msg)
|
||||||
|
if msg.is_last:
|
||||||
|
flit_buffers.pop(tid, None)
|
||||||
|
self._update_step(msg.txn)
|
||||||
|
yield self._inbox.put(msg.txn)
|
||||||
|
else:
|
||||||
|
yield self._inbox.put(msg)
|
||||||
|
|
||||||
|
def _update_step(self, txn: Any) -> None:
|
||||||
|
"""Set txn.step to this component's index in txn.path (if found).
|
||||||
|
Allows legacy step-based routing to work even when flit-aware
|
||||||
|
upstream components don't call txn.advance()."""
|
||||||
|
my_id = self.node.id
|
||||||
|
path = getattr(txn, "path", None)
|
||||||
|
if not path:
|
||||||
|
return
|
||||||
|
for i, n in enumerate(path):
|
||||||
|
if n == my_id:
|
||||||
|
txn.step = i
|
||||||
|
return
|
||||||
|
|
||||||
def _worker(self, env: simpy.Environment) -> Generator:
|
def _worker(self, env: simpy.Environment) -> Generator:
|
||||||
"""Generic forwarding worker: spawns _forward_txn per message (pipeline)."""
|
"""Generic forwarding worker: spawns _forward_txn per message (pipeline)."""
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from kernbench.components.context import ComponentContext
|
|||||||
from kernbench.policy.address.phyaddr import PhysAddr
|
from kernbench.policy.address.phyaddr import PhysAddr
|
||||||
from kernbench.policy.routing.router import AddressResolver, PathRouter
|
from kernbench.policy.routing.router import AddressResolver, PathRouter
|
||||||
from kernbench.runtime_api.kernel import KernelLaunchMsg, MemoryReadMsg, MemoryWriteMsg, PeDmaMsg
|
from kernbench.runtime_api.kernel import KernelLaunchMsg, MemoryReadMsg, MemoryWriteMsg, PeDmaMsg
|
||||||
from kernbench.sim_engine.transaction import Transaction
|
from kernbench.sim_engine.transaction import Flit, Transaction
|
||||||
from kernbench.topology.types import Edge, TopologyGraph
|
from kernbench.topology.types import Edge, TopologyGraph
|
||||||
|
|
||||||
|
|
||||||
@@ -41,6 +41,14 @@ class GraphEngine:
|
|||||||
for e in graph.edges:
|
for e in graph.edges:
|
||||||
self._edge_map[(e.src, e.dst)] = e
|
self._edge_map[(e.src, e.dst)] = e
|
||||||
self._ns_per_mm: float = graph.spec.get("system", {}).get("ns_per_mm", 0.01)
|
self._ns_per_mm: float = graph.spec.get("system", {}).get("ns_per_mm", 0.01)
|
||||||
|
# ADR-0033 Phase 2c-1: wire chunkifies into Flits (Phase 2c-2/3
|
||||||
|
# will graduate to per-flit timing + flit-aware components). At
|
||||||
|
# 2c-1 stage all flits of a Transaction are emitted atomically
|
||||||
|
# at the same env.now to preserve current single-msg timing —
|
||||||
|
# Flit transport is in place but behaviorally equivalent.
|
||||||
|
self._flit_bytes: int = int(
|
||||||
|
graph.spec.get("system", {}).get("flit_bytes", 256)
|
||||||
|
)
|
||||||
self._results: dict[str, tuple[Completion, Trace]] = {}
|
self._results: dict[str, tuple[Completion, Trace]] = {}
|
||||||
self._events: dict[str, simpy.Event] = {}
|
self._events: dict[str, simpy.Event] = {}
|
||||||
self._counter = 0
|
self._counter = 0
|
||||||
@@ -259,18 +267,26 @@ class GraphEngine:
|
|||||||
available_at = 0.0
|
available_at = 0.0
|
||||||
while True:
|
while True:
|
||||||
msg = yield out_port.get()
|
msg = yield out_port.get()
|
||||||
|
# ADR-0033 Phase 2c-1: chunkify Transactions into Flits but
|
||||||
|
# emit atomically (same env.now) to preserve current timing.
|
||||||
|
# Phase 2c-2 will graduate to per-flit timing.
|
||||||
|
if isinstance(msg, Transaction) and msg.nbytes > 0:
|
||||||
|
items = list(msg.into_flits(self._flit_bytes))
|
||||||
|
payload_nbytes = msg.nbytes
|
||||||
|
else:
|
||||||
|
items = [msg]
|
||||||
|
payload_nbytes = getattr(msg, "nbytes", 0) or 0
|
||||||
# BW occupancy: wait for link to become free, then mark busy
|
# BW occupancy: wait for link to become free, then mark busy
|
||||||
if bw_gbs > 0:
|
if bw_gbs > 0 and payload_nbytes > 0:
|
||||||
nbytes = getattr(msg, "nbytes", 0)
|
wait = available_at - self._env.now
|
||||||
if nbytes > 0:
|
if wait > 0:
|
||||||
wait = available_at - self._env.now
|
yield self._env.timeout(wait)
|
||||||
if wait > 0:
|
available_at = self._env.now + (payload_nbytes / bw_gbs)
|
||||||
yield self._env.timeout(wait)
|
|
||||||
available_at = self._env.now + (nbytes / bw_gbs)
|
|
||||||
# Propagation delay
|
# Propagation delay
|
||||||
if prop_ns > 0:
|
if prop_ns > 0:
|
||||||
yield self._env.timeout(prop_ns)
|
yield self._env.timeout(prop_ns)
|
||||||
yield in_port.put(msg)
|
for item in items:
|
||||||
|
yield in_port.put(item)
|
||||||
|
|
||||||
def _process(self, key: str, request: Any, done: simpy.Event):
|
def _process(self, key: str, request: Any, done: simpy.Event):
|
||||||
if isinstance(request, PeDmaMsg):
|
if isinstance(request, PeDmaMsg):
|
||||||
|
|||||||
Reference in New Issue
Block a user