"""Tests for address-based PC selection at HBM CTRL (ADR-0033 D6).

Replaces the prior global round-robin PC selection. The PC index is now
derived from each chunk's HBM byte-address::

    pc_shift = log2(burst_bytes)   # default 8 for 256 B
    pc_mask  = num_pcs - 1         # default 7 for 8 PCs
    pc       = (address >> pc_shift) & pc_mask

Most assertions inspect ``HbmCtrlComponent._pc_avail`` directly rather than
end-to-end makespan. At small payloads, UCIe's per-txn overhead (8 ns) is
identical to one chunk_time at the default pc_bw_gbs (256 B / 32 GB/s =
8 ns). PC contention is therefore fully masked by upstream serialization in
the makespan view. The PC ledger is the authoritative signal of which PCs
were charged.
"""

from __future__ import annotations

from pathlib import Path

import pytest
import simpy

from kernbench.components.builtin.hbm_ctrl import HbmCtrlComponent
from kernbench.policy.address.phyaddr import PhysAddr
from kernbench.runtime_api.kernel import MemoryWriteMsg
from kernbench.sim_engine.engine import GraphEngine
from kernbench.topology.builder import load_topology
from kernbench.topology.types import Node

TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"

# Per-PE HBM slice size used when encoding test addresses: 48 GiB split into
# 8 equal slices. NOTE(review): presumably mirrors the HBM capacity and PE
# count in topology.yaml — confirm if that file changes.
_PE_HBM_SLICE_BYTES = 48 * (1 << 30) // 8


def _hbm_pa(pe_id: int = 0, offset: int = 0) -> int:
    """Return the encoded physical address of *offset* in PE *pe_id*'s HBM slice.

    The address always targets sip 0 / die 0.
    """
    return PhysAddr.pe_hbm_addr(
        sip_id=0,
        die_id=0,
        pe_id=pe_id,
        pe_local_hbm_offset=offset,
        slice_size_bytes=_PE_HBM_SLICE_BYTES,
    ).encode()


def _write_msg(req_id: str, pe_id: int, offset: int, nbytes: int) -> MemoryWriteMsg:
    """Build a zero-pattern write of *nbytes* at *offset* in PE *pe_id*'s HBM slice.

    All messages share one correlation_id. *req_id* is what distinguishes
    them.
    """
    return MemoryWriteMsg(
        correlation_id="addr-pc",
        request_id=req_id,
        dst_sip=0,
        dst_cube=0,
        dst_pe=pe_id,
        dst_pa=_hbm_pa(pe_id=pe_id, offset=offset),
        nbytes=nbytes,
        pattern="zero",
        target_pe=pe_id,
    )


def _engine() -> GraphEngine:
    """Return a fresh GraphEngine built from the shared test topology."""
    return GraphEngine(load_topology(TOPOLOGY_PATH))


def _hbm_ctrl(eng: GraphEngine, cube_id: int = 0) -> HbmCtrlComponent:
    """Return the HBM controller component of ``sip0.cube{cube_id}``.

    This reaches into the engine's private component registry so tests can
    inspect the controller's PC ledger.
    """
    return eng._components[f"sip0.cube{cube_id}.hbm_ctrl"]


def _run(eng: GraphEngine, msgs: list[MemoryWriteMsg]) -> None:
    """Submit every message first, then wait on each handle in order.

    All messages are submitted before any wait. Presumably this keeps them
    in flight concurrently rather than serialized by the caller.
    """
    handles = [eng.submit(m) for m in msgs]
    for h in handles:
        eng.wait(h)
# ── 1. Canonical bit mapping ─────────────────────────────────────────


def test_canonical_bit_mapping_256_8():
    """burst_bytes=256, num_pcs=8 must derive pc_shift=8, pc_mask=7.

    PC selection therefore uses bits [10:8] of the address.
    """
    node = Node(
        id="t",
        kind="hbm_ctrl",
        impl="builtin.hbm_ctrl",
        attrs={"num_pcs": 8, "burst_bytes": 256, "pc_bw_gbs": 32.0},
        pos_mm=None,
    )
    comp = HbmCtrlComponent(node, None)
    comp.start(simpy.Environment())

    assert comp._pc_shift == 8
    assert comp._pc_mask == 7

    for i in range(8):
        addr = i * 256
        # Compute once so the failure message reports the value that was checked.
        got = comp._pc_for_address(addr)
        assert got == i, f"addr=0x{addr:x} expected PC{i}, got PC{got}"

    # Wrap at 8 * burst
    assert comp._pc_for_address(0x800) == 0
    assert comp._pc_for_address(0x900) == 1
    # Within-burst addresses share PC
    assert comp._pc_for_address(0x000) == 0
    assert comp._pc_for_address(0x0FF) == 0
    assert comp._pc_for_address(0x100) == 1
    assert comp._pc_for_address(0x1FF) == 1


# ── 2. Strided 8 writes → all 8 PCs touched, balanced ────────────────


def test_strided_8_writes_charge_all_pcs():
    """8 concurrent 256B writes at offsets 0, 256, ..., 1792 must touch all 8 PCs.

    Consecutive offsets differ by one burst, so ``address >> pc_shift``
    takes 8 consecutive values. Under address-based selection each write
    therefore lands on a different PC, and every PC is charged one chunk.

    Verified via the ``_pc_avail`` ledger: every PC's entry must be at least
    one chunk_time.

    ``_pc_avail`` is an absolute finish time, not accumulated work. Arrival
    times are staggered by UCIe's per-txn overhead, so the entries differ
    from PC to PC. The per-PC work is ``pc_avail - arrival_at_PC``, which
    for a single chunk equals chunk_time.

    The ``>= chunk_time`` check is therefore a lower bound. It proves every
    PC was charged, but on its own it does not rule out a PC being charged
    more than once.
    """
    eng = _engine()
    ctrl = _hbm_ctrl(eng)
    # bytes / (GB/s) == ns, because 1 GB/s is 1 byte per ns.
    chunk_time = ctrl._burst_bytes / ctrl._pc_bw_gbs
    msgs = [
        _write_msg(f"s-{i}", pe_id=0, offset=i * 256, nbytes=256)
        for i in range(8)
    ]
    _run(eng, msgs)
    for pc in range(8):
        assert ctrl._pc_avail[pc] >= chunk_time, (
            f"PC {pc} must be charged ≥ 1 chunk of work; "
            f"got {ctrl._pc_avail[pc]:.2f}ns chunk_time={chunk_time:.2f}ns "
            f"pc_avail={ctrl._pc_avail}"
        )
# ── 3. Same address → only PC 0 advances ─────────────────────────────


def test_same_address_only_charges_pc0():
    """4 concurrent 256B writes to identical offset 0x1000 must all charge PC 0.

    (0x1000 >> 8) & 7 = 0, so no other PC may be charged. PC 0 must run the
    4 chunks back-to-back, giving a finish time of at least 4 × chunk_time.
    """
    eng = _engine()
    ctrl = _hbm_ctrl(eng)
    # bytes / (GB/s) == ns, because 1 GB/s is 1 byte per ns.
    chunk_time = ctrl._burst_bytes / ctrl._pc_bw_gbs
    msgs = [
        _write_msg(f"c-{i}", pe_id=0, offset=0x1000, nbytes=256)
        for i in range(4)
    ]
    _run(eng, msgs)

    # Only PC 0 should be non-zero
    assert ctrl._pc_avail[0] > 0, f"PC 0 must be charged; pc_avail={ctrl._pc_avail}"
    for pc in range(1, 8):
        assert ctrl._pc_avail[pc] == 0, (
            f"PC {pc} must not be charged (same-address only hits PC 0); "
            f"pc_avail={ctrl._pc_avail}"
        )

    # PC 0 chained 4 commits back-to-back. The last finish time must be
    # at least the cumulative chunk_time (commits are serialized on PC 0).
    assert ctrl._pc_avail[0] >= 4 * chunk_time, (
        f"PC 0 should chain 4 chunk_time commits; "
        f"pc_avail[0]={ctrl._pc_avail[0]:.2f}ns expected ≥ {4*chunk_time:.2f}ns"
    )


# ── 4. PC-aligned multiples collide (Scenario A from ADR-0033 D6) ────


def test_2kb_pairs_with_pc_aligned_offset_collide():
    """Two 2KB writes at offsets 0 and 2048 must collide on every PC.

    2048 = num_pcs * burst_bytes, so each write is 8 bursts and spans PCs
    0..7, starting at PC 0 in both cases. All 8 PCs must therefore be
    charged TWICE (2 chunks each).

    ``_pc_avail`` holds each PC's absolute finish time, which also includes
    arrival latency; it is not a sum of work. The ``>= 2 * chunk_time``
    assertion is therefore a necessary lower bound on each finish time, not
    an exact measure of the work charged to the PC.
    """
    eng = _engine()
    ctrl = _hbm_ctrl(eng)
    # bytes / (GB/s) == ns, because 1 GB/s is 1 byte per ns.
    chunk_time = ctrl._burst_bytes / ctrl._pc_bw_gbs
    msgs = [
        _write_msg("a", pe_id=0, offset=0, nbytes=2048),
        _write_msg("b", pe_id=0, offset=2048, nbytes=2048),
    ]
    _run(eng, msgs)

    # All 8 PCs charged, each finishing no earlier than 2 chunk_times.
    for pc in range(8):
        assert ctrl._pc_avail[pc] >= 2 * chunk_time, (
            f"PC {pc} should have ≥ 2 chunks of work after PC-aligned "
            f"2KB pair; got {ctrl._pc_avail[pc]:.2f}ns "
            f"(2*chunk_time={2*chunk_time:.2f}ns); pc_avail={ctrl._pc_avail}"
        )
# ── 5. Dynamic pc_shift from burst_bytes ─────────────────────────────


def test_dynamic_pc_shift_when_burst_changes():
    """A 128-byte burst must yield pc_shift=7 instead of the default 8.

    Checked through ``_pc_for_address``: 0x080 maps to PC 1 here, whereas a
    shift of 8 would send it to PC 0.
    """
    ctrl = HbmCtrlComponent(
        Node(
            id="t",
            kind="hbm_ctrl",
            impl="builtin.hbm_ctrl",
            attrs={"num_pcs": 8, "burst_bytes": 128, "pc_bw_gbs": 32.0},
            pos_mm=None,
        ),
        None,
    )
    ctrl.start(simpy.Environment())

    assert ctrl._pc_shift == 7
    assert ctrl._pc_mask == 7

    expected_pc_by_address = {
        0x000: 0,
        0x080: 1,
        0x100: 2,
        0x400: 0,  # wrap at 8 * 128 = 1024
    }
    for address, want in expected_pc_by_address.items():
        assert ctrl._pc_for_address(address) == want


# ── 6. Power-of-2 validation ─────────────────────────────────────────


def test_non_power_of_two_num_pcs_rejected():
    """num_pcs=6 is not a power of two, so start() must raise ValueError.

    Construction itself must succeed; only start() sits inside the raises
    block. The error message must mention ``num_pcs``.
    """
    bad_attrs = {"num_pcs": 6, "burst_bytes": 256}
    ctrl = HbmCtrlComponent(
        Node(
            id="t",
            kind="hbm_ctrl",
            impl="builtin.hbm_ctrl",
            attrs=bad_attrs,
            pos_mm=None,
        ),
        None,
    )
    with pytest.raises(ValueError, match="num_pcs"):
        ctrl.start(simpy.Environment())


def test_non_power_of_two_burst_bytes_rejected():
    """burst_bytes=300 is not a power of two, so start() must raise ValueError.

    Construction itself must succeed; only start() sits inside the raises
    block. The error message must mention ``burst_bytes``.
    """
    bad_attrs = {"num_pcs": 8, "burst_bytes": 300}
    ctrl = HbmCtrlComponent(
        Node(
            id="t",
            kind="hbm_ctrl",
            impl="builtin.hbm_ctrl",
            attrs=bad_attrs,
            pos_mm=None,
        ),
        None,
    )
    with pytest.raises(ValueError, match="burst_bytes"):
        ctrl.start(simpy.Environment())