Files
kernbench2/tests/test_hbm_address_based_pc.py
T
ywkang aaa1cbfaf6 ADR-0033 D6: address-based PC selection at HBM CTRL
Replaces global round-robin with deterministic address-derived PC
striping:

    pc_shift = log2(burst_bytes)
    pc_mask  = num_pcs - 1
    pc       = (flit.address >> pc_shift) & pc_mask

Each Transaction carries base_address (HBM byte offset of the first
chunk); each Flit derives its own address as base + i*flit_bytes.
HBM CTRL routes flits to PCs via this formula, replacing the
arrival-order RR pointer. Also splits the is_last wait into an
asynchronous _finalize_txn process so the worker isn't blocked on
PC commit, exposing true PC parallelism for disjoint addresses.

phyaddr.py documents the canonical bit layout (bits [10:8] for the
default burst=256, num_pcs=8 case). ADR-0033 D6 records the
derivation and the workload scenarios where address-striping
matters (strided streams, offset-disjoint parallel transfers).

Adds tests/test_hbm_address_based_pc.py: canonical bit mapping,
strided 8-way load distribution, same-address PC-0 serialization,
PC-aligned 2KB pair collision, dynamic pc_shift from burst_bytes,
and power-of-2 attr validation. Integration tests inspect the
_pc_avail ledger directly: at the default config, UCIe's 8 ns per-txn
overhead exactly matches chunk_time (256 B / 32 GB/s = 8 ns), masking
PC contention at the makespan level even though the ledger correctly
distinguishes the cases.

Full suite: 631 passed, 1 skipped.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-15 00:18:46 -07:00

217 lines
8.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Tests for address-based PC selection at HBM CTRL (ADR-0033 D6).
Replaces the prior global round-robin PC selection. PC index is now derived
from each chunk's HBM byte-address:
pc_shift = log2(burst_bytes) # default 8 for 256B
pc_mask = num_pcs - 1 # default 7 for 8 PCs
pc = (address >> pc_shift) & pc_mask
Most assertions inspect ``HbmCtrlComponent._pc_avail`` directly rather than
end-to-end makespan: at small payloads UCIe's per-txn overhead (8 ns) is
identical to one chunk_time at the default settings (256 B / 32 GB/s = 8 ns),
so PC contention is fully masked by upstream serialization in the makespan view.
The PC ledger is the authoritative signal of which PCs were charged.
"""
from __future__ import annotations
from pathlib import Path
import pytest
import simpy
from kernbench.components.builtin.hbm_ctrl import HbmCtrlComponent
from kernbench.policy.address.phyaddr import PhysAddr
from kernbench.runtime_api.kernel import MemoryWriteMsg
from kernbench.sim_engine.engine import GraphEngine
from kernbench.topology.builder import load_topology
from kernbench.topology.types import Node
# topology.yaml located one directory above this test file's directory
# (i.e. the parent of tests/).
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
def _hbm_pa(pe_id: int = 0, offset: int = 0) -> int:
    """Return the encoded physical address of ``offset`` in PE ``pe_id``'s HBM slice.

    sip 0 / die 0 are fixed. The slice size is 48 GiB divided by 8
    (6 GiB); presumably 48 GiB of HBM shared evenly by 8 PEs — confirm
    against topology.yaml.
    """
    total_hbm_bytes = 48 * (1 << 30)
    per_pe_slice = total_hbm_bytes // 8
    addr = PhysAddr.pe_hbm_addr(
        sip_id=0,
        die_id=0,
        pe_id=pe_id,
        pe_local_hbm_offset=offset,
        slice_size_bytes=per_pe_slice,
    )
    return addr.encode()
def _write_msg(req_id: str, pe_id: int, offset: int, nbytes: int) -> MemoryWriteMsg:
    """Build a zero-pattern write of ``nbytes`` at ``offset`` in PE ``pe_id``'s HBM.

    The destination is always sip 0 / cube 0, and ``pe_id`` is used both as
    the destination PE and as the target PE. All messages share the
    correlation id ``"addr-pc"``; ``req_id`` distinguishes them.
    """
    destination_pa = _hbm_pa(pe_id=pe_id, offset=offset)
    return MemoryWriteMsg(
        correlation_id="addr-pc",
        request_id=req_id,
        dst_sip=0,
        dst_cube=0,
        dst_pe=pe_id,
        dst_pa=destination_pa,
        nbytes=nbytes,
        pattern="zero",
        target_pe=pe_id,
    )
def _engine() -> GraphEngine:
    """Build a fresh GraphEngine from the project's topology.yaml."""
    topology = load_topology(TOPOLOGY_PATH)
    return GraphEngine(topology)
def _hbm_ctrl(eng: GraphEngine, cube_id: int = 0) -> HbmCtrlComponent:
return eng._components[f"sip0.cube{cube_id}.hbm_ctrl"]
def _run(eng: GraphEngine, msgs: list[MemoryWriteMsg]) -> None:
handles = [eng.submit(m) for m in msgs]
for h in handles:
eng.wait(h)
# ── 1. Canonical bit mapping ─────────────────────────────────────────
def test_canonical_bit_mapping_256_8():
    """Default geometry: burst_bytes=256 and num_pcs=8 must give pc_shift=8
    and pc_mask=7, i.e. PC selection on address bits [10:8]."""
    attrs = {"num_pcs": 8, "burst_bytes": 256, "pc_bw_gbs": 32.0}
    node = Node(id="t", kind="hbm_ctrl", impl="builtin.hbm_ctrl",
                attrs=attrs, pos_mm=None)
    ctrl = HbmCtrlComponent(node, None)
    ctrl.start(simpy.Environment())

    assert ctrl._pc_shift == 8
    assert ctrl._pc_mask == 7

    # Burst k (starting at byte k * 256) must map to PC k.
    for expected_pc in range(8):
        address = expected_pc * 256
        actual_pc = ctrl._pc_for_address(address)
        assert actual_pc == expected_pc, (
            f"addr=0x{address:x} expected PC{expected_pc}, got PC{actual_pc}"
        )

    # 0x800 = 8 * 256 wraps back to PC 0; any byte inside a burst shares
    # that burst's PC.
    spot_checks = (
        (0x800, 0), (0x900, 1),
        (0x000, 0), (0x0FF, 0),
        (0x100, 1), (0x1FF, 1),
    )
    for address, expected_pc in spot_checks:
        assert ctrl._pc_for_address(address) == expected_pc
# ── 2. Strided 8 writes → all 8 PCs touched, balanced ────────────────
def test_strided_8_writes_charge_all_pcs():
    """8 concurrent 256B writes at offsets 0, 256, ..., 1792 must charge
    each of the 8 PCs exactly once.

    With the default geometry (burst_bytes=256, num_pcs=8), write ``i`` is a
    single chunk that lands on PC ``i``. ``_pc_avail[pc]`` is the PC's
    finish timestamp in ns (roughly arrival_at_PC + chunk_time), not a
    work counter. A PC that committed at least one chunk starting from
    t=0 therefore satisfies ``_pc_avail[pc] >= chunk_time``.

    There are exactly 8 single-chunk writes and 8 PCs. So if every PC
    reaches that bound, each PC received exactly one chunk. Absolute
    values still differ per PC because arrivals are staggered by UCIe's
    per-txn overhead.
    """
    eng = _engine()
    ctrl = _hbm_ctrl(eng)
    chunk_time = ctrl._burst_bytes / ctrl._pc_bw_gbs
    msgs = [_write_msg(f"s-{i}", pe_id=0, offset=i * 256, nbytes=256)
            for i in range(8)]
    _run(eng, msgs)
    # The offsets, the one-chunk-per-write size and the PC range above are
    # written for the default geometry. If topology.yaml changes it, fail
    # loudly instead of reporting a misleading PC-selection error.
    assert ctrl._burst_bytes == 256 and ctrl._pc_mask == 7, (
        f"test assumes burst_bytes=256, num_pcs=8 (pc_mask=7); got "
        f"burst_bytes={ctrl._burst_bytes}, pc_mask={ctrl._pc_mask}"
    )
    for pc in range(8):
        assert ctrl._pc_avail[pc] >= chunk_time, (
            f"PC {pc} must be charged ≥ 1 chunk of work; "
            f"got {ctrl._pc_avail[pc]:.2f}ns chunk_time={chunk_time:.2f}ns "
            f"pc_avail={ctrl._pc_avail}"
        )
# ── 3. Same address → only PC 0 advances ─────────────────────────────
def test_same_address_only_charges_pc0():
    """4 concurrent 256B writes to the same offset 0x1000 must all land on
    PC 0 ((0x1000 >> 8) & 7 == 0) and leave every other PC untouched.

    ``_pc_avail[pc]`` is the PC's finish timestamp in ns, and it stays 0
    for a PC that never committed a chunk. Four chunks serialized on PC 0
    starting from t=0 cannot finish before 4 * chunk_time. That is the
    lower bound checked below.
    """
    eng = _engine()
    ctrl = _hbm_ctrl(eng)
    chunk_time = ctrl._burst_bytes / ctrl._pc_bw_gbs
    msgs = [_write_msg(f"c-{i}", pe_id=0, offset=0x1000, nbytes=256)
            for i in range(4)]
    _run(eng, msgs)
    # Expecting PC 0, checking PCs 1..7, and treating 256B as one chunk all
    # assume the default geometry. Fail clearly if topology.yaml changed it.
    assert ctrl._burst_bytes == 256 and ctrl._pc_mask == 7, (
        f"test assumes burst_bytes=256, num_pcs=8 (pc_mask=7); got "
        f"burst_bytes={ctrl._burst_bytes}, pc_mask={ctrl._pc_mask}"
    )
    # Only PC 0 should be non-zero
    assert ctrl._pc_avail[0] > 0, f"PC 0 must be charged; pc_avail={ctrl._pc_avail}"
    for pc in range(1, 8):
        assert ctrl._pc_avail[pc] == 0, (
            f"PC {pc} must not be charged (same-address only hits PC 0); "
            f"pc_avail={ctrl._pc_avail}"
        )
    # PC 0 chained 4 commits back-to-back. Its last finish time must be
    # at least the cumulative chunk_time, because the commits are
    # serialized on PC 0.
    assert ctrl._pc_avail[0] >= 4 * chunk_time, (
        f"PC 0 should chain 4 chunk_time commits; "
        f"pc_avail[0]={ctrl._pc_avail[0]:.2f}ns expected ≥ {4*chunk_time:.2f}ns"
    )
# ── 4. PC-aligned multiples collide (Scenario A from ADR-0033 D6) ────
def test_2kb_pairs_with_pc_aligned_offset_collide():
    """Two 2KB writes at offsets 0 and 2048 (= num_pcs * burst_bytes) each
    span PCs 0..7 and both start at PC 0 (Scenario A, ADR-0033 D6). Every
    PC is therefore charged two chunks, and the two writes' first chunks
    contend for PC 0.

    ``_pc_avail[pc]`` is a finish timestamp, not a cumulative-work counter.
    Two chunks serialized on a PC that was idle at t=0 cannot finish before
    2 * chunk_time, so ``>= 2 * chunk_time`` is a necessary lower bound for
    the double charge. It is not sufficient on its own: a single chunk that
    arrives after chunk_time also finishes past 2 * chunk_time. This check
    guards against gross mis-routing; it does not prove the exact per-PC
    count.

    Any burst-aligned 2KB write covers all 8 PCs, so the per-PC count alone
    does not distinguish the PC-aligned pair from e.g. offsets 0 and 256.
    What the alignment changes is that both writes start on the same PC.
    """
    eng = _engine()
    ctrl = _hbm_ctrl(eng)
    chunk_time = ctrl._burst_bytes / ctrl._pc_bw_gbs
    msgs = [
        _write_msg("a", pe_id=0, offset=0, nbytes=2048),
        _write_msg("b", pe_id=0, offset=2048, nbytes=2048),
    ]
    _run(eng, msgs)
    # 2048 = 8 * 256 only under the default geometry. Fail clearly if
    # topology.yaml changed it.
    assert ctrl._burst_bytes == 256 and ctrl._pc_mask == 7, (
        f"test assumes burst_bytes=256, num_pcs=8 (pc_mask=7); got "
        f"burst_bytes={ctrl._burst_bytes}, pc_mask={ctrl._pc_mask}"
    )
    # Every PC's finish time must be at least two chunk_times.
    for pc in range(8):
        assert ctrl._pc_avail[pc] >= 2 * chunk_time, (
            f"PC {pc} should have ≥ 2 chunks of work after PC-aligned "
            f"2KB pair; got {ctrl._pc_avail[pc]:.2f}ns "
            f"(2*chunk_time={2*chunk_time:.2f}ns); pc_avail={ctrl._pc_avail}"
        )
# ── 5. Dynamic pc_shift from burst_bytes ─────────────────────────────
def test_dynamic_pc_shift_when_burst_changes():
    """burst_bytes=128 must yield pc_shift=7 rather than the default 8.

    Checked through _pc_for_address: 0x080 maps to PC 1 here, whereas the
    default shift of 8 would put it on PC 0.
    """
    attrs = {"num_pcs": 8, "burst_bytes": 128, "pc_bw_gbs": 32.0}
    node = Node(id="t", kind="hbm_ctrl", impl="builtin.hbm_ctrl",
                attrs=attrs, pos_mm=None)
    ctrl = HbmCtrlComponent(node, None)
    ctrl.start(simpy.Environment())

    assert ctrl._pc_shift == 7
    assert ctrl._pc_mask == 7

    expectations = {
        0x000: 0,
        0x080: 1,
        0x100: 2,
        0x400: 0,  # wraps after 8 * 128 = 1024 bytes
    }
    for address, expected_pc in expectations.items():
        assert ctrl._pc_for_address(address) == expected_pc
# ── 6. Power-of-2 validation ─────────────────────────────────────────
def test_non_power_of_two_num_pcs_rejected():
    """num_pcs=6 is not a power of two. Construction succeeds, but start()
    must raise a ValueError whose message mentions num_pcs."""
    bad_attrs = {"num_pcs": 6, "burst_bytes": 256}
    ctrl = HbmCtrlComponent(
        Node(id="t", kind="hbm_ctrl", impl="builtin.hbm_ctrl",
             attrs=bad_attrs, pos_mm=None),
        None,
    )
    with pytest.raises(ValueError, match="num_pcs"):
        ctrl.start(simpy.Environment())
def test_non_power_of_two_burst_bytes_rejected():
    """burst_bytes=300 is not a power of two. Construction succeeds, but
    start() must raise a ValueError whose message mentions burst_bytes."""
    bad_attrs = {"num_pcs": 8, "burst_bytes": 300}
    ctrl = HbmCtrlComponent(
        Node(id="t", kind="hbm_ctrl", impl="builtin.hbm_ctrl",
             attrs=bad_attrs, pos_mm=None),
        None,
    )
    with pytest.raises(ValueError, match="burst_bytes"):
        ctrl.start(simpy.Environment())