from __future__ import annotations

from collections.abc import Iterator
from dataclasses import dataclass, field, replace
from typing import Any

import simpy


@dataclass
class Transaction:
    """In-flight request traversing the device fabric hop-by-hop (ADR-0015 D4).

    A Transaction carries a host request through one leg of the device
    fabric. Each component on the path reads from its in_port, processes
    (overhead_ns or other latency), and advances the Transaction to the
    next hop via out_port. Wire processes (ADR-0015 D2) model propagation
    delay between hops.

    Multi-leg flows (e.g. IO_CPU → M_CPU as leg 1, M_CPU.DMA → HBM as
    leg 2) use separate Transactions: the terminal component of leg 1
    creates leg 2 and waits for leg 2's done before succeeding leg 1's done.
    """

    request: Any  # original host request (MemoryReadMsg, KernelLaunchMsg, …)
    path: list[str]  # node_id sequence for this leg
    step: int  # index of the component currently holding this Transaction
    nbytes: int  # payload size (bytes)
    done: simpy.Event  # succeeded when this leg completes
    drain_ns: float = 0.0  # wormhole drain time: nbytes / bottleneck_bw (applied once at terminal)
    is_response: bool = False  # True when carrying ResponseMsg on reverse path
    result_data: dict[str, Any] = field(default_factory=dict)  # PE-level metrics (pe_exec_ns, etc.)
    base_address: int = 0  # HBM byte offset of the first chunk; per-flit addresses
    # derived as base + flit_index * flit_bytes (ADR-0033 D6)

    @property
    def next_hop(self) -> str | None:
        """Node id of the next component, or None if this is the terminal hop."""
        nxt = self.step + 1
        return self.path[nxt] if nxt < len(self.path) else None

    def advance(self) -> Transaction:
        """Return a copy of this Transaction advanced one step along the path.

        Only ``step`` changes (incremented by one). Every other field is
        carried over as-is. The copy is shallow: ``path``, ``done``,
        ``request`` and ``result_data`` are the same objects as on ``self``.
        As a result, metrics written into ``result_data`` at any hop are
        visible through every copy of this leg.

        No bounds check is performed. Advancing from the terminal hop yields
        ``step >= len(path)``, for which ``next_hop`` is ``None``.
        """
        # dataclasses.replace carries over *every* field automatically, so a
        # field added to the dataclass later (as base_address was) cannot be
        # silently reset to its default here. It also keeps the concrete type
        # when Transaction is subclassed.
        return replace(self, step=self.step + 1)

    def into_flits(self, flit_bytes: int) -> Iterator[Flit]:
        """Decompose this Transaction's payload into Flits (ADR-0033 D1).

        Yields ``ceil(nbytes / flit_bytes)`` Flits with ``flit_index``
        ``0, 1, ...``. Every flit carries ``flit_bytes`` bytes, except the
        final one, which carries only the remainder when ``nbytes`` is not a
        multiple of ``flit_bytes``. The final flit, whether short or
        full-sized, has ``is_last=True``; all earlier flits have
        ``is_last=False``. Flit ``i`` has
        ``address = base_address + i * flit_bytes``.

        Yields nothing when ``nbytes <= 0`` or ``flit_bytes <= 0``. All
        yielded Flits share a reference to this Transaction.

        This is a generator: the Flits are produced lazily on iteration.
        """
        if self.nbytes <= 0 or flit_bytes <= 0:
            # NOTE(review): a non-positive flit_bytes yields no flits even for
            # a non-empty payload (no is_last flit is ever produced) — confirm
            # callers never pass such a value for a real transfer.
            return
        n_full, remainder = divmod(self.nbytes, flit_bytes)
        n_total = n_full + (1 if remainder else 0)
        last_index = n_total - 1
        for i in range(n_total):
            yield Flit(
                txn=self,
                flit_index=i,
                flit_nbytes=flit_bytes if i < n_full else remainder,
                is_last=(i == last_index),
                address=self.base_address + i * flit_bytes,
            )


@dataclass
class Flit:
    """Atomic wire transport unit (ADR-0033 D1).

    Carries a slice of a parent Transaction's payload. The wire
    (``engine._wire``) decomposes Transactions into Flits on first
    transport; downstream wires pass Flits through with their own
    ``bw_gbs`` delay.

    Phase 2 constraint: ``flit_bytes`` MUST be a multiple of HBM
    ``burst_bytes`` (by default they are equal). See ADR-0033 D1.
    ``Transaction.into_flits`` does not enforce this; the caller must.
    """

    txn: Transaction  # parent transaction reference
    flit_index: int  # 0..n_flits-1
    flit_nbytes: int  # bytes carried (usually flit_bytes; last may be smaller)
    is_last: bool  # True for the terminating flit
    address: int = 0  # HBM byte offset for this flit's chunk (ADR-0033 D6)