Latency model: HBM PC striping + chunk-loop drain (ADR-0033)

Previous model double-counted slow-upstream paths (e.g., 64KB via UCIe
128 GB/s was ~2x pessimistic). HBM CTRL now distributes bursts across
8 pseudo-channels via global round-robin, with per-chunk commit timing
that pipelines correctly against the bottleneck link's data arrival.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-14 21:59:07 -07:00
parent f6d262e359
commit 5fdb6f8797
11 changed files with 1192 additions and 52 deletions
+330
View File
@@ -0,0 +1,330 @@
"""Tests for HBM CTRL per-pseudo-channel (PC) striping model (ADR-0033).
Replaces the prior dual-channel `simpy.Resource(capacity=1)` model with a
stateless per-PC `available_at[N]` array, global round-robin chunking, and
read/write sharing per PC. Burst granularity is `burst_bytes` (default 256B).
These tests are written BEFORE the production change and are expected to
FAIL on current code (which serializes via Resource cap=1). Phase 2 must
make them PASS without weakening assertions.
Verification matrix references ADR-0033 D1 (modeled) and D2 (approximated).
"""
from __future__ import annotations
from pathlib import Path
import pytest
from kernbench.policy.address.phyaddr import PhysAddr
from kernbench.runtime_api.kernel import MemoryReadMsg, MemoryWriteMsg
from kernbench.sim_engine.engine import GraphEngine
from kernbench.topology.builder import load_topology, resolve_topology
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
def _engine() -> GraphEngine:
return GraphEngine(load_topology(TOPOLOGY_PATH))
def _hbm_pa(sip: int = 0, cube: int = 0, pe_id: int = 0, offset: int = 0x1000) -> int:
slice_bytes = 48 * (1 << 30) // 8
return PhysAddr.pe_hbm_addr(
sip_id=sip, die_id=cube, pe_id=pe_id,
pe_local_hbm_offset=offset, slice_size_bytes=slice_bytes,
).encode()
def _write_msg(req_id: str, *, cube: int, pe: int, nbytes: int) -> MemoryWriteMsg:
return MemoryWriteMsg(
correlation_id="pc-striping", request_id=req_id,
dst_sip=0, dst_cube=cube, dst_pe=pe,
dst_pa=_hbm_pa(sip=0, cube=cube, pe_id=pe), nbytes=nbytes,
pattern="zero", target_pe=pe,
)
def _single_write_ns(nbytes: int, cube: int = 0, pe: int = 0) -> float:
eng = _engine()
msg = _write_msg(f"single-{cube}-{pe}-{nbytes}", cube=cube, pe=pe, nbytes=nbytes)
h = eng.submit(msg)
eng.wait(h)
_, t = eng.get_completion(h)
return t["total_ns"]
def _path_drain_for_write(eng: GraphEngine, msg: MemoryWriteMsg) -> float:
"""Compute engine path drain dynamically (test-time access to engine internals)."""
pcie_ep_id = eng._resolver.find_pcie_ep(msg.dst_sip)
pa = PhysAddr.decode(msg.dst_pa)
hbm_node = eng._resolver.resolve(pa)
path = eng._router.find_memory_path(pcie_ep_id, hbm_node)
return eng._path_drain_ns(path, msg.nbytes)
# ── 1. Builder derives pc_bw_gbs ──────────────────────────────────
def test_builder_derives_pc_bw_gbs():
"""Topology builder must inject `pc_bw_gbs = hbm_to_router_bw_gbs / num_pcs`
as an attr on every hbm_ctrl node. Enforces ADR-0019 D9 invariant
(channels_per_PE × per-PC BW = aggregated link BW) at build time.
"""
handle = resolve_topology(str(TOPOLOGY_PATH))
topo = handle.topology_obj
spec = topo.spec
expected_total_bw = float(spec["cube"]["links"]["hbm_to_router_bw_gbs"])
expected_num_pcs = int(spec["cube"]["memory_map"]["hbm_channels_per_pe"])
expected_pc_bw = expected_total_bw / expected_num_pcs
hbm_nodes = [n for n in topo.nodes.values() if "hbm_ctrl" in n.id]
assert hbm_nodes, "no hbm_ctrl nodes found in topology"
for node in hbm_nodes:
assert "num_pcs" in node.attrs, f"{node.id} missing num_pcs"
assert int(node.attrs["num_pcs"]) == expected_num_pcs, (
f"{node.id} num_pcs={node.attrs['num_pcs']} != {expected_num_pcs}"
)
assert "pc_bw_gbs" in node.attrs, f"{node.id} missing builder-derived pc_bw_gbs"
assert abs(float(node.attrs["pc_bw_gbs"]) - expected_pc_bw) < 1e-6, (
f"{node.id} pc_bw_gbs={node.attrs['pc_bw_gbs']} != {expected_pc_bw}"
)
# ── 2. PC parallelism: concurrent writes do NOT serialize at HBM CTRL ──
def test_two_concurrent_writes_parallel_across_pcs():
"""Two concurrent writes to the same cube (different PEs) must use
different PCs (via global round-robin) and finish in less than 2x
the single-write latency.
Current model (Resource cap=1) serializes them → max ≈ 2x single.
PC striping must give max < 1.7x single (allowing for shared wire BW
occupancy, which remains).
"""
nbytes = 1024
single_ns = _single_write_ns(nbytes)
eng = _engine()
msg_a = _write_msg("conc-a", cube=0, pe=0, nbytes=nbytes)
msg_b = _write_msg("conc-b", cube=0, pe=1, nbytes=nbytes)
ha = eng.submit(msg_a)
hb = eng.submit(msg_b)
eng.wait(ha)
eng.wait(hb)
_, ta = eng.get_completion(ha)
_, tb = eng.get_completion(hb)
max_ns = max(ta["total_ns"], tb["total_ns"])
assert max_ns < single_ns * 1.7, (
f"PC striping: 2 concurrent 1KB writes should not serialize at HBM CTRL. "
f"single={single_ns:.2f}ns, concurrent max={max_ns:.2f}ns, "
f"ratio={max_ns/single_ns:.2f} (expected < 1.7)"
)
def test_eight_concurrent_writes_makespan():
"""8 concurrent 1KB writes (one per PE in cube0) must achieve makespan
significantly less than 8x single-write latency.
With 8 PCs and global round-robin, each write maps to a distinct set of
PCs; the makespan is dominated by wire BW (shared 256 GB/s pipe), not
by HBM-side serialization.
Current cap=1 model: makespan ≈ 8x single. Target: < 4x single.
"""
nbytes = 1024
single_ns = _single_write_ns(nbytes)
eng = _engine()
handles = []
for pe in range(8):
msg = _write_msg(f"8way-{pe}", cube=0, pe=pe, nbytes=nbytes)
handles.append(eng.submit(msg))
for h in handles:
eng.wait(h)
times = [eng.get_completion(h)[1]["total_ns"] for h in handles]
makespan = max(times)
assert makespan < single_ns * 4.0, (
f"8 concurrent 1KB writes: makespan={makespan:.2f}ns, "
f"single={single_ns:.2f}ns, ratio={makespan/single_ns:.2f} "
f"(expected < 4.0 with PC striping; current cap=1 gives ~8x)"
)
# ── 3. Large transfer not 2x pessimistic ──────────────────────────
def test_large_transfer_not_double_counted():
"""64KB write must not be ~2x the wire transfer time.
With cut-through (head_arrived event) + PC striping, the HBM PC commit
time overlaps with wire arrival. For 64KB at 256 GB/s aggregate:
- Wire transfer: ~256ns
- PC commit (parallel across 8 PCs, 32 chunks each): ~256ns
- Overlapped real-HW total: ~256ns (one of them dominates)
- Current sequential model: ~512ns (~2x)
Assert: total < 1.5x of (wire transfer time alone).
"""
nbytes = 65536 # 64KB
# Path bottleneck (dynamic) — for MemoryWrite this is UCIe 128 GB/s.
eng = _engine()
msg = _write_msg("64kb-probe", cube=0, pe=0, nbytes=nbytes)
drain = _path_drain_for_write(eng, msg)
total = _single_write_ns(nbytes)
assert total < drain * 1.5, (
f"64KB write should not be ~2x path bottleneck transfer time. "
f"drain={drain:.2f}ns, total={total:.2f}ns, "
f"ratio={total/drain:.2f} (expected < 1.5)"
)
# ── 4. Read/write share per-PC available_at ──────────────────────
def test_read_write_share_pc_array():
"""Read and write requests targeting overlapping PC regions must
serialize on the shared `pc_avail` array (NOT proceed in parallel like
the prior dual-channel model).
Strategy: a read and a write to the same PE/cube should land on the
same set of PCs (since global round-robin advances by chunk count, and
chunk count of 256B == 1 chunk consumes 1 PC). With single-chunk read+write
submitted concurrently, the second to acquire its chunk's PC must wait.
We assert: makespan of (concurrent read + write) > single_write_ns.
If they ran in parallel on disjoint resources (old dual-channel),
makespan ≈ single. With shared PC, makespan > single.
"""
nbytes = 256 # 1 chunk
pa = _hbm_pa(sip=0, cube=0, pe_id=0)
single_w = _single_write_ns(nbytes)
eng = _engine()
w_msg = _write_msg("rw-write", cube=0, pe=0, nbytes=nbytes)
r_msg = MemoryReadMsg(
correlation_id="pc-striping", request_id="rw-read",
src_sip=0, src_cube=0, src_pe=0,
src_pa=pa, nbytes=nbytes,
)
hw = eng.submit(w_msg)
hr = eng.submit(r_msg)
eng.wait(hw)
eng.wait(hr)
_, tw = eng.get_completion(hw)
_, tr = eng.get_completion(hr)
makespan = max(tw["total_ns"], tr["total_ns"])
# When R and W share the same first PC, the second one to acquire pays
# the burst time of the first. Assert makespan strictly > single,
# demonstrating sharing (vs the prior dual-channel parallelism).
assert makespan > single_w * 1.05, (
f"Read+Write should share per-PC slot when targeting the same starting "
f"PC. single_write={single_w:.2f}ns, R+W makespan={makespan:.2f}ns "
f"(expected > 1.05x single, demonstrating PC sharing)"
)
# ── 5. Switch penalty: default 0, mechanism wired up ─────────────
def _makespan(eng: GraphEngine, handles: list) -> float:
for h in handles:
eng.wait(h)
return max(eng.get_completion(h)[1]["total_ns"] for h in handles)
def _engine_with_switch_penalty(switch_penalty_ns: float) -> GraphEngine:
"""Build a GraphEngine, overriding switch_penalty_ns on every hbm_ctrl
node. None means leave the attr absent (i.e., test the default)."""
graph = load_topology(TOPOLOGY_PATH)
if switch_penalty_ns is not None:
for node in graph.nodes.values():
if "hbm_ctrl" in node.id:
node.attrs["switch_penalty_ns"] = switch_penalty_ns
return GraphEngine(graph)
def _rw_write_time(eng: GraphEngine, nbytes: int) -> float:
"""Submit one read followed by one write of the same size; return the
write's completion time. With `nbytes >= num_pcs * burst_bytes`, the
read populates PCs 0..N-1 with last_dir='R' and the write then wraps
back to PC 0, so every chunk of the write sees an R→W direction
switch. The write's completion time is the direct observable for the
switch-penalty mechanism (the read's time is dominated by the
response-path latency and would mask the effect)."""
r = MemoryReadMsg(
correlation_id="pc-striping", request_id="rw-1",
src_sip=0, src_cube=0, src_pe=0,
src_pa=_hbm_pa(sip=0, cube=0, pe_id=0), nbytes=nbytes,
)
w = _write_msg("rw-2", cube=0, pe=0, nbytes=nbytes)
hr = eng.submit(r)
hw = eng.submit(w)
eng.wait(hr); eng.wait(hw)
return eng.get_completion(hw)[1]["total_ns"]
def test_switch_penalty_default_zero():
"""Default (no `switch_penalty_ns` attr) must behave identically to
explicit `switch_penalty_ns=0`.
This documents Tier 0 (ADR-0033 D2): we assume an ideal HBM scheduler
amortizes switching cost; the mechanism exists but is dormant.
"""
nbytes = 2048
rw_default = _rw_write_time(_engine_with_switch_penalty(None), nbytes)
rw_zero = _rw_write_time(_engine_with_switch_penalty(0.0), nbytes)
diff = abs(rw_default - rw_zero)
assert diff < 0.01, (
f"Default (no attr) must match explicit switch_penalty_ns=0. "
f"default={rw_default:.2f}ns, explicit_zero={rw_zero:.2f}ns, "
f"diff={diff:.4f}ns"
)
def test_switch_penalty_mechanism_when_enabled():
"""When `switch_penalty_ns` is set non-zero via attr, R→W on the same
PC must show that extra delay.
Phase 2 must wire up the mechanism so that overriding the attr at
runtime (or via a modified topology) produces the expected delay.
Default config keeps it 0; this test creates an engine with an
explicit override.
"""
# Use nbytes that span all 8 PCs so the write back-wraps to PCs that
# were just touched by the read, forcing an R→W switch on each PC.
# 8 PCs × 256B burst = 2048B fills every PC exactly once.
nbytes = 2048
switch_penalty = 20.0 # large enough to be visible
# R+W with explicit switch_penalty=0: baseline (W observed time)
rw_zero = _rw_write_time(_engine_with_switch_penalty(0.0), nbytes)
# R+W with explicit switch_penalty=20: mechanism engaged
rw_pen = _rw_write_time(_engine_with_switch_penalty(switch_penalty), nbytes)
delta = rw_pen - rw_zero
# The switch penalty applies once on the second txn's first chunk.
# Conservative: assert at least half the switch_penalty shows up.
assert delta >= switch_penalty * 0.4, (
f"switch_penalty_ns={switch_penalty} should add measurable delay "
f"when R→W on same PC. R+W@0={rw_zero:.2f}ns, "
f"R+W@{switch_penalty}={rw_pen:.2f}ns, delta={delta:.2f}ns "
f"(expected >= {switch_penalty*0.4:.2f}ns)"
)
# ── 6. Backwards compat sanity ───────────────────────────────────
def test_existing_single_txn_latency_positive():
"""Sanity: single write still produces positive latency (no regression
of basic engine behavior). Companion to test_bw_occupancy.py."""
t = _single_write_ns(4096)
assert t > 0