ADR-0033 D6: address-based PC selection at HBM CTRL

Replaces global round-robin with deterministic address-derived PC
striping:

    pc_shift = log2(burst_bytes)
    pc_mask  = num_pcs - 1
    pc       = (flit.address >> pc_shift) & pc_mask

Each Transaction carries base_address (HBM byte offset of the first
chunk); each Flit derives its own address as base + i*flit_bytes.
HBM CTRL routes flits to PCs via this formula, replacing the
arrival-order RR pointer. Also splits the is_last wait into an
asynchronous _finalize_txn process so the worker isn't blocked on
PC commit, exposing true PC parallelism for disjoint addresses.

phyaddr.py documents the canonical bit layout (bits [10:8] for the
default burst=256, num_pcs=8 case). ADR-0033 D6 records the
derivation and the workload scenarios where address-striping
matters (strided streams, offset-disjoint parallel transfers).

Adds tests/test_hbm_address_based_pc.py: canonical bit mapping,
strided 8-way load distribution, same-address PC-0 serialization,
PC-aligned 2KB pair collision, dynamic pc_shift from burst_bytes,
and power-of-2 attr validation. Integration tests inspect
_pc_avail ledger directly: at default config UCIe's 8 ns per-txn
overhead exactly matches chunk_time, masking PC contention at the
makespan level even though the ledger correctly distinguishes the
cases.

Full suite: 631 passed, 1 skipped.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-15 00:18:46 -07:00
parent a44f832be5
commit aaa1cbfaf6
6 changed files with 292 additions and 27 deletions
+35 -20
View File
@@ -45,7 +45,10 @@ class HbmCtrlComponent(ComponentBase):
self._switch_penalty_ns: float = 0.0
self._pc_avail: list[float] = []
self._pc_last_dir: list[str | None] = []
self._next_pc: int = 0
# Address-based PC selection (ADR-0033 D6):
# pc = (address >> _pc_shift) & _pc_mask
self._pc_shift: int = 0
self._pc_mask: int = 0
# Per-txn flit accumulation state (ADR-0033 Phase 2c-3).
self._txn_state: dict[int, dict[str, Any]] = {}
@@ -55,11 +58,19 @@ class HbmCtrlComponent(ComponentBase):
self._pc_bw_gbs = float(attrs.get("pc_bw_gbs", 32.0))
self._burst_bytes = int(attrs.get("burst_bytes", 256))
self._switch_penalty_ns = float(attrs.get("switch_penalty_ns", 0.0))
if self._num_pcs <= 0 or (self._num_pcs & (self._num_pcs - 1)) != 0:
raise ValueError(f"num_pcs must be a positive power of 2, got {self._num_pcs}")
if self._burst_bytes <= 0 or (self._burst_bytes & (self._burst_bytes - 1)) != 0:
raise ValueError(f"burst_bytes must be a positive power of 2, got {self._burst_bytes}")
self._pc_shift = self._burst_bytes.bit_length() - 1
self._pc_mask = self._num_pcs - 1
self._pc_avail = [0.0] * self._num_pcs
self._pc_last_dir = [None] * self._num_pcs
self._next_pc = 0
super().start(env)
def _pc_for_address(self, address: int) -> int:
return (int(address) >> self._pc_shift) & self._pc_mask
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
yield env.timeout(overhead_ns)
@@ -88,9 +99,10 @@ class HbmCtrlComponent(ComponentBase):
env.process(self._handle_txn(env, msg))
def _handle_flit(self, env: simpy.Environment, flit: Flit) -> Generator:
"""Per-flit PC commit. On first flit of a txn, claim PC range and
apply overhead. On ``is_last``, wait for last PC commit to
finish, then send the response."""
"""Per-flit PC commit. On first flit of a txn, apply overhead. PC is
derived from the flit's address (ADR-0033 D6 address-based striping).
On ``is_last``, wait for last PC commit to finish, then send the
response."""
txn = flit.txn
tid = id(txn)
chunk_time = (
@@ -100,19 +112,12 @@ class HbmCtrlComponent(ComponentBase):
if tid not in self._txn_state:
yield from self.run(env, txn.nbytes)
work_bytes = txn.nbytes if txn.nbytes > 0 else int(
getattr(txn.request, "nbytes", 0) or 0
)
n_flits = max(1, ceil(work_bytes / self._burst_bytes)) if work_bytes > 0 else 1
pc_start = self._next_pc
self._next_pc = (self._next_pc + n_flits) % self._num_pcs
self._txn_state[tid] = {
"pc_start": pc_start,
"last_finish": env.now,
}
state = self._txn_state[tid]
pc = (state["pc_start"] + flit.flit_index) % self._num_pcs
pc = self._pc_for_address(flit.address)
switch_cost = 0.0
if self._pc_last_dir[pc] is not None and self._pc_last_dir[pc] != new_dir:
switch_cost = self._switch_penalty_ns
@@ -124,11 +129,22 @@ class HbmCtrlComponent(ComponentBase):
state["last_finish"] = finish
if flit.is_last:
wait = state["last_finish"] - env.now
if wait > 0:
yield env.timeout(wait)
del self._txn_state[tid]
yield from self._send_response(env, txn)
# Finalize asynchronously so the worker can pick up the next
# flit while this txn's last PC commit drains. Without this
# split, the worker's ``yield env.timeout(wait)`` would
# serialize concurrent single-flit txns at chunk_time even
# when they hit distinct PCs, hiding address-based PC
# parallelism (ADR-0033 D6).
env.process(self._finalize_txn(env, txn, state["last_finish"]))
def _finalize_txn(
self, env: simpy.Environment, txn: Any, last_finish: float,
) -> Generator:
wait = last_finish - env.now
if wait > 0:
yield env.timeout(wait)
yield from self._send_response(env, txn)
def _handle_txn(self, env: simpy.Environment, txn: Any) -> Generator:
is_write = self._is_write(txn)
@@ -146,11 +162,12 @@ class HbmCtrlComponent(ComponentBase):
yield from self.run(env, txn.nbytes)
base_addr = int(getattr(txn, "base_address", 0))
last_finish = env.now
for i in range(n_chunks):
if chunk_interval > 0:
yield env.timeout(chunk_interval)
pc = (self._next_pc + i) % self._num_pcs
pc = self._pc_for_address(base_addr + i * self._burst_bytes)
switch_cost = 0.0
if self._pc_last_dir[pc] is not None and self._pc_last_dir[pc] != new_dir:
switch_cost = self._switch_penalty_ns
@@ -160,8 +177,6 @@ class HbmCtrlComponent(ComponentBase):
self._pc_last_dir[pc] = new_dir
if finish > last_finish:
last_finish = finish
if n_chunks > 0:
self._next_pc = (self._next_pc + n_chunks) % self._num_pcs
wait = last_finish - env.now
if wait > 0:
+11
View File
@@ -19,6 +19,17 @@ _LOCAL_MASK = (1 << _LOCAL_BITS) - 1
_AHBM_SEL_BIT = 37
_AHBM_LOCAL_USED = 38 # bits actually meaningful for AHBM
# HBM-offset bit layout for PC (pseudo-channel) striping
# (ADR-0033 D6, ADR-0019). Given burst_bytes = 2^B and num_pcs = 2^P
# configured at hbm_ctrl, the PC index is derived from hbm_offset as
# pc_shift = B; pc_mask = (1 << P) - 1
# pc = (hbm_offset >> pc_shift) & pc_mask
# Canonical default (burst_bytes=256, num_pcs=8 => B=8, P=3) maps:
# hbm_offset[36:11] row/bank/column within PC slice
# hbm_offset[10: 8] pc_index (3 bits, selects 1 of 8 PCs)
# hbm_offset[ 7: 0] within-burst offset (256 B, same PC)
# Shift/mask are computed at runtime from topology config, not hardcoded.
# Resource window: [36:34] resource_kind, [33:0] kind_local
_RES_KIND_SHIFT = 34
_RES_KIND_MASK = 0x7
+3 -1
View File
@@ -400,6 +400,7 @@ class GraphEngine:
request=request, path=path, step=0,
nbytes=request.nbytes if is_write else 0,
done=txn_done, drain_ns=drain_ns,
base_address=pa.hbm_offset,
)
yield self._host_queues[pcie_ep_id].put(txn)
@@ -424,7 +425,8 @@ class GraphEngine:
start_ns = self._env.now
txn_done = self._env.event()
txn = Transaction(request=request, path=path, step=0, nbytes=request.nbytes,
done=txn_done, drain_ns=drain_ns)
done=txn_done, drain_ns=drain_ns,
base_address=pa.hbm_offset)
yield self._pe_dma_queues[pe_dma_id].put(txn)
yield txn_done
total_ns = self._env.now - start_ns
+5
View File
@@ -29,6 +29,8 @@ class Transaction:
drain_ns: float = 0.0 # wormhole drain time: nbytes / bottleneck_bw (applied once at terminal)
is_response: bool = False # True when carrying ResponseMsg on reverse path
result_data: dict[str, Any] = field(default_factory=dict) # PE-level metrics (pe_exec_ns, etc.)
base_address: int = 0 # HBM byte offset of the first chunk; per-flit addresses
# derived as base + flit_index * flit_bytes (ADR-0033 D6)
@property
def next_hop(self) -> str | None:
@@ -47,6 +49,7 @@ class Transaction:
drain_ns=self.drain_ns,
is_response=self.is_response,
result_data=self.result_data,
base_address=self.base_address,
)
def into_flits(self, flit_bytes: int) -> Iterator[Flit]:
@@ -71,6 +74,7 @@ class Transaction:
flit_index=i,
flit_nbytes=size,
is_last=(i == n_total - 1),
address=self.base_address + i * flit_bytes,
)
@@ -91,3 +95,4 @@ class Flit:
flit_index: int # 0..n_flits-1
flit_nbytes: int # bytes carried (usually flit_bytes; last may be smaller)
is_last: bool # True for the terminating flit
address: int = 0 # HBM byte offset for this flit's chunk (ADR-0033 D6)