diff --git a/docs/adr/ADR-0033-latency-model-assumptions.md b/docs/adr/ADR-0033-latency-model-assumptions.md index 4ca622d..5d4024f 100644 --- a/docs/adr/ADR-0033-latency-model-assumptions.md +++ b/docs/adr/ADR-0033-latency-model-assumptions.md @@ -111,12 +111,28 @@ below are different concerns, ordered by expected workload impact. **Higher impact (workload accuracy gap)**: - [ ] **Address-based PC selection at HBM CTRL** (replace the - address-blind global round-robin). When two transactions of size - `num_pcs × burst_bytes` (e.g., 2KB at 8 PCs × 256B) arrive - concurrently, both claim PCs 0..7 via global RR, producing full - per-PC contention even when real-HW address striping would put - them on disjoint PC sets. Directly affects multi-PE concurrent - HBM workload latencies. + address-blind global round-robin). Compute the PC index from + the HBM byte offset using parameters already in topology config: + + pc_shift = log2(burst_bytes) # default 8 (burst=256B) + pc_mask = num_pcs - 1 # default 7 (8 PCs) + pc = (hbm_offset >> pc_shift) & pc_mask + + For the default `burst_bytes=256, num_pcs=8` this places the PC + select field at HBM byte-offset bits **[10:8]**: bits [7:0] are + the within-burst offset (same PC), bits [10:8] are the 3-bit PC + index, and bits [36:11] are row/bank/column within the PC slice. + Shift/mask are derived from topology config rather than hardcoded + so alternative `(burst_bytes, num_pcs)` pairs stay consistent. + See `src/kernbench/policy/address/phyaddr.py` for the canonical + comment. + + Real-HW workloads where this matters most: (a) strided multi- + transaction streams that under global-RR collide on the same PCs + but under address-striping land on disjoint sets; (b) offset- + disjoint parallel transfers where address-striping preserves + parallelism while global-RR re-serializes them. Directly affects + multi-PE concurrent HBM workload latencies. - [ ] **Bank-level conflict modeling** within a PC (opt-in via `track_banks: true`). 
Currently we assume no same-bank reuse; random scatter/gather workloads are optimistic here. diff --git a/src/kernbench/components/builtin/hbm_ctrl.py b/src/kernbench/components/builtin/hbm_ctrl.py index 945bbf4..09bcaa6 100644 --- a/src/kernbench/components/builtin/hbm_ctrl.py +++ b/src/kernbench/components/builtin/hbm_ctrl.py @@ -45,7 +45,10 @@ class HbmCtrlComponent(ComponentBase): self._switch_penalty_ns: float = 0.0 self._pc_avail: list[float] = [] self._pc_last_dir: list[str | None] = [] - self._next_pc: int = 0 + # Address-based PC selection (ADR-0033 D6): + # pc = (address >> _pc_shift) & _pc_mask + self._pc_shift: int = 0 + self._pc_mask: int = 0 # Per-txn flit accumulation state (ADR-0033 Phase 2c-3). self._txn_state: dict[int, dict[str, Any]] = {} @@ -55,11 +58,19 @@ class HbmCtrlComponent(ComponentBase): self._pc_bw_gbs = float(attrs.get("pc_bw_gbs", 32.0)) self._burst_bytes = int(attrs.get("burst_bytes", 256)) self._switch_penalty_ns = float(attrs.get("switch_penalty_ns", 0.0)) + if self._num_pcs <= 0 or (self._num_pcs & (self._num_pcs - 1)) != 0: + raise ValueError(f"num_pcs must be a positive power of 2, got {self._num_pcs}") + if self._burst_bytes <= 0 or (self._burst_bytes & (self._burst_bytes - 1)) != 0: + raise ValueError(f"burst_bytes must be a positive power of 2, got {self._burst_bytes}") + self._pc_shift = self._burst_bytes.bit_length() - 1 + self._pc_mask = self._num_pcs - 1 self._pc_avail = [0.0] * self._num_pcs self._pc_last_dir = [None] * self._num_pcs - self._next_pc = 0 super().start(env) + def _pc_for_address(self, address: int) -> int: + return (int(address) >> self._pc_shift) & self._pc_mask + def run(self, env: simpy.Environment, nbytes: int) -> Generator: overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0)) yield env.timeout(overhead_ns) @@ -88,9 +99,10 @@ class HbmCtrlComponent(ComponentBase): env.process(self._handle_txn(env, msg)) def _handle_flit(self, env: simpy.Environment, flit: Flit) -> Generator: - """Per-flit PC 
commit. On first flit of a txn, claim PC range and - apply overhead. On ``is_last``, wait for last PC commit to - finish, then send the response.""" + """Per-flit PC commit. On first flit of a txn, apply overhead. PC is + derived from the flit's address (ADR-0033 D6 address-based striping). + On ``is_last``, wait for last PC commit to finish, then send the + response.""" txn = flit.txn tid = id(txn) chunk_time = ( @@ -100,19 +112,12 @@ class HbmCtrlComponent(ComponentBase): if tid not in self._txn_state: yield from self.run(env, txn.nbytes) - work_bytes = txn.nbytes if txn.nbytes > 0 else int( - getattr(txn.request, "nbytes", 0) or 0 - ) - n_flits = max(1, ceil(work_bytes / self._burst_bytes)) if work_bytes > 0 else 1 - pc_start = self._next_pc - self._next_pc = (self._next_pc + n_flits) % self._num_pcs self._txn_state[tid] = { - "pc_start": pc_start, "last_finish": env.now, } state = self._txn_state[tid] - pc = (state["pc_start"] + flit.flit_index) % self._num_pcs + pc = self._pc_for_address(flit.address) switch_cost = 0.0 if self._pc_last_dir[pc] is not None and self._pc_last_dir[pc] != new_dir: switch_cost = self._switch_penalty_ns @@ -124,11 +129,22 @@ class HbmCtrlComponent(ComponentBase): state["last_finish"] = finish if flit.is_last: - wait = state["last_finish"] - env.now - if wait > 0: - yield env.timeout(wait) del self._txn_state[tid] - yield from self._send_response(env, txn) + # Finalize asynchronously so the worker can pick up the next + # flit while this txn's last PC commit drains. Without this + # split, the worker's ``yield env.timeout(wait)`` would + # serialize concurrent single-flit txns at chunk_time even + # when they hit distinct PCs, hiding address-based PC + # parallelism (ADR-0033 D6). 
+ env.process(self._finalize_txn(env, txn, state["last_finish"])) + + def _finalize_txn( + self, env: simpy.Environment, txn: Any, last_finish: float, + ) -> Generator: + wait = last_finish - env.now + if wait > 0: + yield env.timeout(wait) + yield from self._send_response(env, txn) def _handle_txn(self, env: simpy.Environment, txn: Any) -> Generator: is_write = self._is_write(txn) @@ -146,11 +162,12 @@ class HbmCtrlComponent(ComponentBase): yield from self.run(env, txn.nbytes) + base_addr = int(getattr(txn, "base_address", 0)) last_finish = env.now for i in range(n_chunks): if chunk_interval > 0: yield env.timeout(chunk_interval) - pc = (self._next_pc + i) % self._num_pcs + pc = self._pc_for_address(base_addr + i * self._burst_bytes) switch_cost = 0.0 if self._pc_last_dir[pc] is not None and self._pc_last_dir[pc] != new_dir: switch_cost = self._switch_penalty_ns @@ -160,8 +177,6 @@ class HbmCtrlComponent(ComponentBase): self._pc_last_dir[pc] = new_dir if finish > last_finish: last_finish = finish - if n_chunks > 0: - self._next_pc = (self._next_pc + n_chunks) % self._num_pcs wait = last_finish - env.now if wait > 0: diff --git a/src/kernbench/policy/address/phyaddr.py b/src/kernbench/policy/address/phyaddr.py index 5abf625..4394ba6 100644 --- a/src/kernbench/policy/address/phyaddr.py +++ b/src/kernbench/policy/address/phyaddr.py @@ -19,6 +19,17 @@ _LOCAL_MASK = (1 << _LOCAL_BITS) - 1 _AHBM_SEL_BIT = 37 _AHBM_LOCAL_USED = 38 # bits actually meaningful for AHBM +# HBM-offset bit layout for PC (pseudo-channel) striping +# (ADR-0033 D6, ADR-0019). 
Given burst_bytes = 2^B and num_pcs = 2^P +# configured at hbm_ctrl, the PC index is derived from hbm_offset as +# pc_shift = B; pc_mask = (1 << P) - 1 +# pc = (hbm_offset >> pc_shift) & pc_mask +# Canonical default (burst_bytes=256, num_pcs=8 => B=8, P=3) maps: +# hbm_offset[36:11] row/bank/column within PC slice +# hbm_offset[10: 8] pc_index (3 bits, selects 1 of 8 PCs) +# hbm_offset[ 7: 0] within-burst offset (256 B, same PC) +# Shift/mask are computed at runtime from topology config, not hardcoded. + # Resource window: [36:34] resource_kind, [33:0] kind_local _RES_KIND_SHIFT = 34 _RES_KIND_MASK = 0x7 diff --git a/src/kernbench/sim_engine/engine.py b/src/kernbench/sim_engine/engine.py index e2ac5ae..8fc3bd1 100644 --- a/src/kernbench/sim_engine/engine.py +++ b/src/kernbench/sim_engine/engine.py @@ -400,6 +400,7 @@ class GraphEngine: request=request, path=path, step=0, nbytes=request.nbytes if is_write else 0, done=txn_done, drain_ns=drain_ns, + base_address=pa.hbm_offset, ) yield self._host_queues[pcie_ep_id].put(txn) @@ -424,7 +425,8 @@ class GraphEngine: start_ns = self._env.now txn_done = self._env.event() txn = Transaction(request=request, path=path, step=0, nbytes=request.nbytes, - done=txn_done, drain_ns=drain_ns) + done=txn_done, drain_ns=drain_ns, + base_address=pa.hbm_offset) yield self._pe_dma_queues[pe_dma_id].put(txn) yield txn_done total_ns = self._env.now - start_ns diff --git a/src/kernbench/sim_engine/transaction.py b/src/kernbench/sim_engine/transaction.py index 1a91548..64efcd5 100644 --- a/src/kernbench/sim_engine/transaction.py +++ b/src/kernbench/sim_engine/transaction.py @@ -29,6 +29,8 @@ class Transaction: drain_ns: float = 0.0 # wormhole drain time: nbytes / bottleneck_bw (applied once at terminal) is_response: bool = False # True when carrying ResponseMsg on reverse path result_data: dict[str, Any] = field(default_factory=dict) # PE-level metrics (pe_exec_ns, etc.) 
+ base_address: int = 0 # HBM byte offset of the first chunk; per-flit addresses + # derived as base + flit_index * flit_bytes (ADR-0033 D6) @property def next_hop(self) -> str | None: @@ -47,6 +49,7 @@ class Transaction: drain_ns=self.drain_ns, is_response=self.is_response, result_data=self.result_data, + base_address=self.base_address, ) def into_flits(self, flit_bytes: int) -> Iterator[Flit]: @@ -71,6 +74,7 @@ class Transaction: flit_index=i, flit_nbytes=size, is_last=(i == n_total - 1), + address=self.base_address + i * flit_bytes, ) @@ -91,3 +95,4 @@ class Flit: flit_index: int # 0..n_flits-1 flit_nbytes: int # bytes carried (usually flit_bytes; last may be smaller) is_last: bool # True for the terminating flit + address: int = 0 # HBM byte offset for this flit's chunk (ADR-0033 D6) diff --git a/tests/test_hbm_address_based_pc.py b/tests/test_hbm_address_based_pc.py new file mode 100644 index 0000000..89bc737 --- /dev/null +++ b/tests/test_hbm_address_based_pc.py @@ -0,0 +1,216 @@ +"""Tests for address-based PC selection at HBM CTRL (ADR-0033 D6). + +Replaces the prior global round-robin PC selection. PC index is now derived +from each chunk's HBM byte-address: + pc_shift = log2(burst_bytes) # default 8 for 256B + pc_mask = num_pcs - 1 # default 7 for 8 PCs + pc = (address >> pc_shift) & pc_mask + +Most assertions inspect ``HbmCtrlComponent._pc_avail`` directly rather than +end-to-end makespan: at small payloads UCIe's per-txn overhead (8 ns) is +identical to a chunk_time at the default pc_bw_gbs (256 B / 32 GB/s), so PC +contention is fully masked by upstream serialization in the makespan view. +The PC ledger is the authoritative signal of which PCs were charged. 
+""" +from __future__ import annotations + +from pathlib import Path + +import pytest +import simpy + +from kernbench.components.builtin.hbm_ctrl import HbmCtrlComponent +from kernbench.policy.address.phyaddr import PhysAddr +from kernbench.runtime_api.kernel import MemoryWriteMsg +from kernbench.sim_engine.engine import GraphEngine +from kernbench.topology.builder import load_topology +from kernbench.topology.types import Node + +TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml" + + +def _hbm_pa(pe_id: int = 0, offset: int = 0) -> int: + slice_bytes = 48 * (1 << 30) // 8 + return PhysAddr.pe_hbm_addr( + sip_id=0, die_id=0, pe_id=pe_id, + pe_local_hbm_offset=offset, slice_size_bytes=slice_bytes, + ).encode() + + +def _write_msg(req_id: str, pe_id: int, offset: int, nbytes: int) -> MemoryWriteMsg: + return MemoryWriteMsg( + correlation_id="addr-pc", request_id=req_id, + dst_sip=0, dst_cube=0, dst_pe=pe_id, + dst_pa=_hbm_pa(pe_id=pe_id, offset=offset), nbytes=nbytes, + pattern="zero", target_pe=pe_id, + ) + + +def _engine() -> GraphEngine: + return GraphEngine(load_topology(TOPOLOGY_PATH)) + + +def _hbm_ctrl(eng: GraphEngine, cube_id: int = 0) -> HbmCtrlComponent: + return eng._components[f"sip0.cube{cube_id}.hbm_ctrl"] + + +def _run(eng: GraphEngine, msgs: list[MemoryWriteMsg]) -> None: + handles = [eng.submit(m) for m in msgs] + for h in handles: + eng.wait(h) + + +# ── 1. Canonical bit mapping ───────────────────────────────────────── + + +def test_canonical_bit_mapping_256_8(): + """burst_bytes=256, num_pcs=8 must derive pc_shift=8, pc_mask=7. 
PC + selection on bits [10:8] of the address.""" + node = Node(id="t", kind="hbm_ctrl", impl="builtin.hbm_ctrl", + attrs={"num_pcs": 8, "burst_bytes": 256, "pc_bw_gbs": 32.0}, + pos_mm=None) + comp = HbmCtrlComponent(node, None) + comp.start(simpy.Environment()) + + assert comp._pc_shift == 8 + assert comp._pc_mask == 7 + + for i in range(8): + addr = i * 256 + assert comp._pc_for_address(addr) == i, ( + f"addr=0x{addr:x} expected PC{i}, got PC{comp._pc_for_address(addr)}" + ) + # Wrap at 8 * burst + assert comp._pc_for_address(0x800) == 0 + assert comp._pc_for_address(0x900) == 1 + # Within-burst addresses share PC + assert comp._pc_for_address(0x000) == 0 + assert comp._pc_for_address(0x0FF) == 0 + assert comp._pc_for_address(0x100) == 1 + assert comp._pc_for_address(0x1FF) == 1 + + +# ── 2. Strided 8 writes → all 8 PCs touched, balanced ──────────────── + + +def test_strided_8_writes_charge_all_pcs(): + """8 concurrent 256B writes at offsets 0, 256, ..., 1792 must charge + each of the 8 PCs exactly once. Verified via _pc_avail ledger: every + PC must be non-zero (load distributed across all 8 PCs). + + Per-PC work amount = pc_avail - arrival_at_PC. For 1 chunk on each + PC, that equals chunk_time. So pc_avail[i] should be roughly equal + to (arrival_i + chunk_time). The arrival times are staggered by + UCIe's per-txn overhead, so absolute pc_avail values differ — but + the WORK assigned to each PC is 1 chunk.""" + eng = _engine() + ctrl = _hbm_ctrl(eng) + chunk_time = ctrl._burst_bytes / ctrl._pc_bw_gbs + + msgs = [_write_msg(f"s-{i}", pe_id=0, offset=i * 256, nbytes=256) + for i in range(8)] + _run(eng, msgs) + + for pc in range(8): + assert ctrl._pc_avail[pc] >= chunk_time, ( + f"PC {pc} must be charged ≥ 1 chunk of work; " + f"got {ctrl._pc_avail[pc]:.2f}ns chunk_time={chunk_time:.2f}ns " + f"pc_avail={ctrl._pc_avail}" + ) + + +# ── 3. 
Same address → only PC 0 advances ───────────────────────────── + + +def test_same_address_only_charges_pc0(): + """4 concurrent 256B writes to identical offset 0x1000 must all charge + PC 0 ((0x1000 >> 8) & 7 = 0) and no other PC. PC 0 must have run 4 + chunks back-to-back (cumulative time ≥ 4 × chunk_time).""" + eng = _engine() + ctrl = _hbm_ctrl(eng) + chunk_time = ctrl._burst_bytes / ctrl._pc_bw_gbs + + msgs = [_write_msg(f"c-{i}", pe_id=0, offset=0x1000, nbytes=256) + for i in range(4)] + _run(eng, msgs) + + # Only PC 0 should be non-zero + assert ctrl._pc_avail[0] > 0, f"PC 0 must be charged; pc_avail={ctrl._pc_avail}" + for pc in range(1, 8): + assert ctrl._pc_avail[pc] == 0, ( + f"PC {pc} must not be charged (same-address only hits PC 0); " + f"pc_avail={ctrl._pc_avail}" + ) + # PC 0 chained 4 commits back-to-back. The last finish time must be + # at least the cumulative chunk_time (commits are serialized on PC 0). + assert ctrl._pc_avail[0] >= 4 * chunk_time, ( + f"PC 0 should chain 4 chunk_time commits; " + f"pc_avail[0]={ctrl._pc_avail[0]:.2f}ns expected ≥ {4*chunk_time:.2f}ns" + ) + + +# ── 4. PC-aligned multiples collide (Scenario A from ADR-0033 D6) ──── + + +def test_2kb_pairs_with_pc_aligned_offset_collide(): + """Two 2KB writes at offsets 0 and 2048 (= num_pcs * burst_bytes) span + PCs 0..7 each, starting at PC 0 in both cases. All 8 PCs must be + charged TWICE (2 chunks each). pc_avail[i] should hold at least + 2 * chunk_time of cumulative work on every PC.""" + eng = _engine() + ctrl = _hbm_ctrl(eng) + chunk_time = ctrl._burst_bytes / ctrl._pc_bw_gbs + + msgs = [ + _write_msg("a", pe_id=0, offset=0, nbytes=2048), + _write_msg("b", pe_id=0, offset=2048, nbytes=2048), + ] + _run(eng, msgs) + + # All 8 PCs charged, each at least 2 chunks worth. 
+ for pc in range(8): + assert ctrl._pc_avail[pc] >= 2 * chunk_time, ( + f"PC {pc} should have ≥ 2 chunks of work after PC-aligned " + f"2KB pair; got {ctrl._pc_avail[pc]:.2f}ns " + f"(2*chunk_time={2*chunk_time:.2f}ns); pc_avail={ctrl._pc_avail}" + ) + + +# ── 5. Dynamic pc_shift from burst_bytes ───────────────────────────── + + +def test_dynamic_pc_shift_when_burst_changes(): + """Override burst_bytes to 128 → pc_shift must be 7 (not the default 8). + Verified directly via _pc_for_address: 0x080 lands on PC 1 (it would + be PC 0 under the default shift=8).""" + node = Node(id="t", kind="hbm_ctrl", impl="builtin.hbm_ctrl", + attrs={"num_pcs": 8, "burst_bytes": 128, "pc_bw_gbs": 32.0}, + pos_mm=None) + comp = HbmCtrlComponent(node, None) + comp.start(simpy.Environment()) + + assert comp._pc_shift == 7 + assert comp._pc_mask == 7 + assert comp._pc_for_address(0x000) == 0 + assert comp._pc_for_address(0x080) == 1 + assert comp._pc_for_address(0x100) == 2 + assert comp._pc_for_address(0x400) == 0 # wrap at 8 * 128 = 1024 + + +# ── 6. Power-of-2 validation ───────────────────────────────────────── + + +def test_non_power_of_two_num_pcs_rejected(): + node = Node(id="t", kind="hbm_ctrl", impl="builtin.hbm_ctrl", + attrs={"num_pcs": 6, "burst_bytes": 256}, pos_mm=None) + comp = HbmCtrlComponent(node, None) + with pytest.raises(ValueError, match="num_pcs"): + comp.start(simpy.Environment()) + + +def test_non_power_of_two_burst_bytes_rejected(): + node = Node(id="t", kind="hbm_ctrl", impl="builtin.hbm_ctrl", + attrs={"num_pcs": 8, "burst_bytes": 300}, pos_mm=None) + comp = HbmCtrlComponent(node, None) + with pytest.raises(ValueError, match="burst_bytes"): + comp.start(simpy.Environment())