Files
kernbench2/src/kernbench/sim_engine/transaction.py
T
ywkang aaa1cbfaf6 ADR-0033 D6: address-based PC selection at HBM CTRL
Replaces global round-robin with deterministic address-derived PC
striping:

    pc_shift = log2(burst_bytes)
    pc_mask  = num_pcs - 1
    pc       = (flit.address >> pc_shift) & pc_mask

Each Transaction carries base_address (HBM byte offset of the first
chunk); each Flit derives its own address as base + i*flit_bytes.
HBM CTRL routes flits to PCs via this formula, replacing the
arrival-order RR pointer. Also splits the is_last wait into an
asynchronous _finalize_txn process so the worker isn't blocked on
PC commit, exposing true PC parallelism for disjoint addresses.

phyaddr.py documents the canonical bit layout (bits [10:8] for the
default burst=256, num_pcs=8 case). ADR-0033 D6 records the
derivation and the workload scenarios where address-striping
matters (strided streams, offset-disjoint parallel transfers).

Adds tests/test_hbm_address_based_pc.py: canonical bit mapping,
strided 8-way load distribution, same-address PC-0 serialization,
PC-aligned 2KB pair collision, dynamic pc_shift from burst_bytes,
and power-of-2 attr validation. Integration tests inspect
_pc_avail ledger directly: at default config UCIe's 8 ns per-txn
overhead exactly matches chunk_time, masking PC contention at the
makespan level even though the ledger correctly distinguishes the
cases.

Full suite: 631 passed, 1 skipped.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-15 00:18:46 -07:00

99 lines
4.0 KiB
Python

from __future__ import annotations

from collections.abc import Iterator
from dataclasses import dataclass, field, replace
from typing import Any

import simpy


@dataclass
class Transaction:
    """In-flight request traversing the device fabric hop-by-hop (ADR-0015 D4).

    A Transaction carries a host request through one leg of the device fabric.
    Each component on the path reads from its in_port, processes (overhead_ns or
    other latency), and advances the Transaction to the next hop via out_port.
    Wire processes (ADR-0015 D2) model propagation delay between hops.

    Multi-leg flows (e.g. IO_CPU → M_CPU as leg 1, M_CPU.DMA → HBM as leg 2)
    use separate Transactions: the terminal component of leg 1 creates leg 2
    and waits for leg 2's done before succeeding leg 1's done.
    """

    request: Any  # original host request (MemoryReadMsg, KernelLaunchMsg, …)
    path: list[str]  # node_id sequence for this leg
    step: int  # index of the component currently holding this Transaction
    nbytes: int  # payload size (bytes)
    done: simpy.Event  # succeeded when this leg completes
    # Wormhole drain time: nbytes / bottleneck_bw (applied once at terminal).
    drain_ns: float = 0.0
    is_response: bool = False  # True when carrying ResponseMsg on reverse path
    # PE-level metrics (pe_exec_ns, etc.).
    result_data: dict[str, Any] = field(default_factory=dict)
    # HBM byte offset of the first chunk; per-flit addresses are derived as
    # base + flit_index * flit_bytes (ADR-0033 D6).
    base_address: int = 0

    @property
    def next_hop(self) -> str | None:
        """Node id of the next component, or None if this is the terminal hop."""
        nxt = self.step + 1
        return self.path[nxt] if nxt < len(self.path) else None

    def advance(self) -> Transaction:
        """Return a copy of this Transaction advanced one step along the path.

        The copy is shallow: ``path``, ``done`` and ``result_data`` are the
        same objects as on the original, so every hop of a leg observes the
        same completion event and metrics dict. Only ``step`` differs.

        ``dataclasses.replace`` is used instead of re-listing the fields so
        that a field added to the class later is carried across hops
        automatically, and so that a subclass instance stays a subclass
        instance.
        """
        return replace(self, step=self.step + 1)

    def into_flits(self, flit_bytes: int) -> Iterator[Flit]:
        """Decompose this Transaction's payload into Flits (ADR-0033 D1).

        Yields one Flit per ``flit_bytes`` of payload. The final flit may
        carry fewer bytes when ``nbytes`` is not a multiple of ``flit_bytes``;
        that flit has ``is_last=True``. Flit ``i`` is addressed at
        ``base_address + i * flit_bytes``.

        Nothing is yielded (and no error is raised) when ``nbytes <= 0`` or
        ``flit_bytes <= 0``.

        All yielded Flits share a reference to this Transaction.
        """
        if self.nbytes <= 0 or flit_bytes <= 0:
            return
        n_full, remainder = divmod(self.nbytes, flit_bytes)
        # A non-zero remainder adds one trailing, partially filled flit.
        n_total = n_full + (1 if remainder else 0)
        for i in range(n_total):
            size = flit_bytes if i < n_full else remainder
            yield Flit(
                txn=self,
                flit_index=i,
                flit_nbytes=size,
                is_last=(i == n_total - 1),
                address=self.base_address + i * flit_bytes,
            )
@dataclass
class Flit:
    """Atomic wire transport unit (ADR-0033 D1).

    A Flit is one slice of its parent Transaction's payload. The wire
    (``engine._wire``) decomposes Transactions into Flits on first
    transport; downstream wires pass Flits through with their own
    ``bw_gbs`` delay.

    Phase 2 constraint: ``flit_bytes`` MUST be a multiple of HBM
    ``burst_bytes`` (default they are equal). See ADR-0033 D1.
    """

    # Parent transaction reference.
    txn: Transaction
    # Position within the parent's flit sequence: 0..n_flits-1.
    flit_index: int
    # Bytes carried (usually flit_bytes; the last flit may be smaller).
    flit_nbytes: int
    # True for the terminating flit.
    is_last: bool
    # HBM byte offset for this flit's chunk (ADR-0033 D6).
    address: int = 0