Files
kernbench2/src/kernbench/cli/probe.py
T
ywkang 81cc32c46b ADR-0001 Rev 2: 51-bit PhysAddr layout with concrete sub-unit tables
Remove rack_id (4 bits), rename sip_seg→die_id, shift fields to enable
42-bit local_offset (4 TB per die). Define PE_LOCAL/MCPU_LOCAL/CUBE_SRAM
sub-unit tables for AHBM dies and IOCPU sub-unit table for IOCHIPLET
dies (1 TB window). Supersedes ADR-0031.

Also fixes latent VA/PA confusion in pe_dma pipeline DMA path where
virtual addresses were decoded as physical addresses without MMU
translation — previously masked by coincidental bit-position alignment.

529 passed (+6 recovered), 10 pre-existing failures unchanged.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-27 15:52:29 -07:00

452 lines
18 KiB
Python

"""kernbench probe: latency and BW verification utility.
Runs predefined traffic patterns through the simulation engine and reports
latency, effective bandwidth, bottleneck bandwidth, and utilization for each
case. Validates monotonicity invariants across hop counts and access types.
"""
from __future__ import annotations
from pathlib import Path
from kernbench.policy.address.phyaddr import PhysAddr
from kernbench.policy.routing.router import AddressResolver, PathRouter
from kernbench.runtime_api.kernel import MemoryReadMsg, MemoryWriteMsg, PeDmaMsg
from kernbench.sim_engine.engine import GraphEngine
from kernbench.topology.builder import load_topology
from kernbench.topology.types import TopologyGraph
# -- Helpers ----------------------------------------------------------
def _hbm_pa(sip: int, cube: int, pe_id: int, spec: dict) -> int:
    """Return the encoded physical address of the probe location in a PE's HBM slice.

    The slice size is taken from ``spec["cube"]["memory_map"]``: the cube's
    total HBM (given in GiB) floor-divided by the number of slices. The probe
    always targets offset 0x1000 inside the PE-local slice.
    """
    memory_map = spec["cube"]["memory_map"]
    cube_hbm_bytes = memory_map["hbm_total_gb_per_cube"] * (1 << 30)
    slice_bytes = cube_hbm_bytes // memory_map["hbm_slices_per_cube"]
    addr = PhysAddr.pe_hbm_addr(
        sip_id=sip,
        die_id=cube,
        pe_id=pe_id,
        pe_local_hbm_offset=0x1000,
        slice_size_bytes=slice_bytes,
    )
    return addr.encode()
def _build_edge_map(graph: TopologyGraph) -> dict[tuple[str, str], object]:
    """Index the graph's edges by their ``(src, dst)`` endpoint pair.

    When several edges share the same endpoints, the last one listed wins.
    """
    by_endpoints: dict[tuple[str, str], object] = {}
    for edge in graph.edges:
        by_endpoints[(edge.src, edge.dst)] = edge
    return by_endpoints
def _formula_breakdown(
    path: list[str], nbytes: int, edge_map: dict, graph: TopologyGraph,
) -> tuple[float, float, float, float]:
    """Return ``(wire_ns, overhead_ns, drain_ns, formula_ns)`` for *path*.

    * wire: sum of ``distance_mm * ns_per_mm`` over the hops present in
      *edge_map* (``ns_per_mm`` is read from ``graph.spec["system"]``,
      defaulting to 0.01).
    * overhead: sum of the ``overhead_ns`` attribute of every path node
      known to the graph (missing attribute counts as 0).
    * drain: *nbytes* divided by the smallest truthy edge bandwidth on the
      path, or 0.0 when no hop carries a bandwidth.
    * formula: the sum of the three components.
    """
    ns_per_mm = graph.spec.get("system", {}).get("ns_per_mm", 0.01)

    # Single pass over consecutive node pairs: accumulate wire delay and
    # remember every usable link bandwidth for the bottleneck lookup.
    wire_ns = 0.0
    link_bws = []
    for hop in zip(path, path[1:]):
        edge = edge_map.get(hop)
        if not edge:
            continue
        wire_ns += edge.distance_mm * ns_per_mm
        if edge.bw_gbs:
            link_bws.append(edge.bw_gbs)

    overhead_ns = 0.0
    for node_id in path:
        node = graph.nodes.get(node_id)
        if node:
            overhead_ns += float(node.attrs.get("overhead_ns", 0.0))

    if link_bws:
        drain_ns = nbytes / min(link_bws)
    else:
        drain_ns = 0.0
    return wire_ns, overhead_ns, drain_ns, wire_ns + overhead_ns + drain_ns
def _hop_timestamps(
    path: list[str], nbytes: int, edge_map: dict, graph: TopologyGraph,
) -> list[tuple[str, float, str]]:
    """Return the per-hop trace ``[(short_node_name, cumulative_ns, annotation), ...]``.

    The first entry is the source node at 0.0 ns (its own overhead is not
    counted). Every following entry adds the hop's wire delay, when the edge
    is in *edge_map*, and the destination node's ``overhead_ns``.

    Annotations: a hop whose edge bandwidth equals the path minimum is marked
    as the bottleneck; otherwise a positive node overhead is shown. When a
    bottleneck exists and *nbytes* is positive, the drain time is added to the
    final entry and appended to its annotation.
    """
    ns_per_mm = graph.spec.get("system", {}).get("ns_per_mm", 0.01)

    # Resolve each hop's edge once; reused for the bottleneck scan and the walk.
    hops = [(dst, edge_map.get((src, dst))) for src, dst in zip(path, path[1:])]
    rated = [edge.bw_gbs for _, edge in hops if edge and edge.bw_gbs]
    bn_bw = min(rated) if rated else None

    elapsed = 0.0
    trace: list[tuple[str, float, str]] = [(_short_name(path[0]), 0.0, "")]
    for dst, edge in hops:
        note = ""
        if edge:
            elapsed += edge.distance_mm * ns_per_mm
            if bn_bw is not None and edge.bw_gbs and edge.bw_gbs == bn_bw:
                note = f"<BN:{edge.bw_gbs:.0f}GB/s>"
        node = graph.nodes.get(dst)
        if node:
            node_overhead = float(node.attrs.get("overhead_ns", 0.0))
            elapsed += node_overhead
            # The bottleneck marker takes precedence over the overhead note.
            if node_overhead > 0 and not note:
                note = f"+{node_overhead:.1f}ns"
        trace.append((_short_name(dst), elapsed, note))

    # The payload drains through the bottleneck link; charge it to the terminal hop.
    if bn_bw and nbytes > 0:
        drain = nbytes / bn_bw
        elapsed += drain
        last_name, _, last_note = trace[-1]
        trace[-1] = (last_name, elapsed, last_note + f" drain:{drain:.1f}ns")
    return trace
def _bottleneck_bw(path: list[str], edge_map: dict) -> float | None:
    """Return the smallest truthy edge bandwidth along *path*, or ``None``.

    This is the per-request bottleneck: a single request occupies one
    connection per hop. Hops missing from *edge_map*, or whose bandwidth is
    falsy, are ignored.
    """
    edges = (edge_map.get(hop) for hop in zip(path, path[1:]))
    return min(
        (edge.bw_gbs for edge in edges if edge and edge.bw_gbs),
        default=None,
    )
def _fmt_bw(bw: float | None) -> str:
    """Format a bandwidth with one decimal place, or ``"-"`` when it is ``None``."""
    if bw is None:
        return "-"
    return f"{bw:.1f}"
def _fmt_util(eff: float, bn: float | None) -> str:
    """Format ``eff / bn`` as a percentage with one decimal place.

    Returns ``"-"`` when the bottleneck bandwidth is ``None`` or not positive.
    """
    no_reference = bn is None or bn <= 0
    return "-" if no_reference else f"{eff / bn * 100:.1f}%"
def _short_name(node_id: str) -> str:
    """Return the last two dot-separated segments of *node_id*.

    Two segments are kept so that e.g. ``router.pe0`` and ``pe0`` remain
    distinguishable; ids with fewer than two segments are returned unchanged.
    """
    # Split at most twice from the right: the final two pieces are the result.
    tail = node_id.rsplit(".", 2)[-2:]
    return ".".join(tail)
def _short_path(path: list[str]) -> str:
    """Render *path* as its shortened node names joined by ``" -> "``."""
    return " -> ".join(map(_short_name, path))
def _print_hop_trace(timestamps: list[tuple[str, float, str]], indent: str = " ") -> None:
    """Print one line per hop: cumulative time, node name and optional annotation.

    Each line starts with *indent*; the time is right-aligned in 8 columns
    with two decimals.
    """
    for hop_name, at_ns, note in timestamps:
        line = f"{indent}{at_ns:>8.2f}ns {hop_name}"
        if note:
            line += f" {note}"
        print(line)
# Payload sizes in bytes used for the BW-saturation sweep, smallest first.
SWEEP_SIZES = [4 * 1024, 16 * 1024, 64 * 1024, 256 * 1024, 1024 * 1024]
# Column headings for the sweep table; entry i labels SWEEP_SIZES[i].
SWEEP_LABELS = ["4KB", "16KB", "64KB", "256KB", "1MB"]
def _sweep_util(
    ovhd_ns: float,
    wire_ns: float,
    bn_bw: float | None,
    sizes: list[int] = SWEEP_SIZES,
) -> list[float]:
    """Return the modelled utilization (%) of the bottleneck for each size in *sizes*.

    Formula model: transferring ``nb`` bytes takes the fixed cost
    (``ovhd_ns + wire_ns``) plus the drain time ``nb / bn_bw``; utilization is
    the resulting effective bandwidth relative to *bn_bw*. When *bn_bw* is
    ``None`` or not positive, every entry is 0.0.
    """
    if bn_bw is None or bn_bw <= 0:
        return [0.0] * len(sizes)

    def _util_pct(nb: int) -> float:
        total = ovhd_ns + wire_ns + nb / bn_bw
        eff = nb / total if total > 0 else 0.0
        return eff / bn_bw * 100

    return [_util_pct(nb) for nb in sizes]
def _print_sweep_table(case_names: list[str], sweep_data: list[list[float]]) -> None:
    """Print the compact BW-saturation table: one row per case, one column per sweep size.

    *case_names* and *sweep_data* are paired positionally; column headings
    come from ``SWEEP_LABELS``.
    """
    label_cols = "".join(f" {label:>7}" for label in SWEEP_LABELS)
    print("\n BW Saturation (Util% by data size):")
    print(f" {'Case':<26}" + label_cols)
    # Rule width: 26 for the case column plus 8 per sweep column.
    print(" " + "-" * (26 + 8 * len(SWEEP_LABELS)))
    for case, utils in zip(case_names, sweep_data):
        row = "".join(f" {pct:>6.1f}%" for pct in utils)
        print(f" {case:<26}{row}")
# -- Probe runner -----------------------------------------------------
def run_probe(topology_path: str, case_filter: str | None = None) -> int:
    """Run the predefined probe cases against a topology and print the report.

    Three groups are simulated with a fixed 32 KiB payload: H2D writes, D2H
    reads and PE DMA transfers. Each case runs on a fresh ``GraphEngine``.
    Summary tables (with ordering checks and a BW-saturation sweep) are
    printed first, followed by per-hop route traces.

    Args:
        topology_path: Path to the topology file; ``~`` is expanded.
        case_filter: Name of a single case to run. ``None`` or ``"all"``
            runs every case.

    Returns:
        Always 0; check outcomes are reported in the printed output only.
    """
    path = Path(topology_path).expanduser().resolve()
    graph = load_topology(path)
    edge_map = _build_edge_map(graph)
    spec = graph.spec
    resolver = AddressResolver(graph)
    router = PathRouter(graph)
    nbytes = 32768
    show_all = case_filter is None or case_filter == "all"

    def _pct(part: float, whole: float) -> float:
        """Share of *whole* taken by *part*, in percent; 0 when *whole* is not positive."""
        return part / whole * 100 if whole > 0 else 0

    # === Collect H2D results ===
    # (case name, target cube, expected hop count)
    h2d_cases = [
        ("h2d-1hop", 0, 1),
        ("h2d-2hop", 4, 2),
        ("h2d-3hop", 8, 3),
        ("h2d-4hop", 12, 4),
    ]
    # Rows must be labelled by the case that produced them, not by their
    # position: with a case filter the result list is shorter than the case
    # list, so positional lookup would report the wrong cube.
    h2d_cube = {name: cube for name, cube, _ in h2d_cases}
    h2d_results: list[tuple[str, int, float, float, float | None, float, float, float, float, float]] = []
    h2d_route_data: list[tuple[str, list[str], list[str], list[str], list[str]]] = []
    for name, cube, hops in h2d_cases:
        if not show_all and case_filter != name:
            continue
        engine = GraphEngine(graph)
        pa = _hbm_pa(sip=0, cube=cube, pe_id=0, spec=spec)
        msg = MemoryWriteMsg(
            correlation_id="probe", request_id=name,
            dst_sip=0, dst_cube=cube, dst_pe=0,
            dst_pa=pa, nbytes=nbytes, pattern="zero",
        )
        h = engine.submit(msg)
        engine.wait(h)
        _, trace = engine.get_completion(h)
        total_ns = trace["total_ns"]
        eff_bw = nbytes / total_ns if total_ns > 0 else 0.0
        pa_obj = PhysAddr.decode(pa)
        dst_node = resolver.resolve(pa_obj)
        pcie_ep = resolver.find_pcie_ep(0)
        io_cpu = resolver.find_io_cpu(0)
        m_cpu = resolver.find_m_cpu(0, cube)
        leg1 = router.find_node_path(pcie_ep, io_cpu)
        leg2 = router.find_node_path(io_cpu, m_cpu)
        leg3 = router.find_mcpu_dma_path(m_cpu, dst_node)
        # Consecutive legs share their junction node; drop the duplicate.
        fwd_path = leg1 + leg2[1:] + leg3[1:]
        bn_bw = _bottleneck_bw(fwd_path, edge_map)
        wire, ovhd, drain, _ = _formula_breakdown(fwd_path, nbytes, edge_map, graph)
        ovhd_pct = _pct(ovhd, total_ns)
        drain_pct = _pct(drain, total_ns)
        h2d_results.append((name, hops, total_ns, eff_bw, bn_bw, ovhd, drain, wire, ovhd_pct, drain_pct))
        h2d_route_data.append((name, leg1, leg2, leg3, fwd_path))

    # === Collect D2H Read results ===
    d2h_cases = [
        ("d2h-1hop", 0, 1),
        ("d2h-2hop", 4, 2),
        ("d2h-3hop", 8, 3),
        ("d2h-4hop", 12, 4),
    ]
    d2h_cube = {name: cube for name, cube, _ in d2h_cases}
    d2h_results: list[tuple[str, int, float, float, float | None, float, float, float, float, float]] = []
    d2h_route_data: list[tuple[str, list[str], list[str], list[str], list[str]]] = []
    for name, cube, hops in d2h_cases:
        if not show_all and case_filter != name:
            continue
        engine = GraphEngine(graph)
        pa = _hbm_pa(sip=0, cube=cube, pe_id=0, spec=spec)
        msg = MemoryReadMsg(
            correlation_id="probe", request_id=name,
            src_sip=0, src_cube=cube, src_pe=0,
            src_pa=pa, nbytes=nbytes,
        )
        h = engine.submit(msg)
        engine.wait(h)
        _, trace = engine.get_completion(h)
        total_ns = trace["total_ns"]
        eff_bw = nbytes / total_ns if total_ns > 0 else 0.0
        pa_obj = PhysAddr.decode(pa)
        dst_node = resolver.resolve(pa_obj)
        pcie_ep = resolver.find_pcie_ep(0)
        fwd_path = router.find_memory_path(pcie_ep, dst_node)
        rev_path = list(reversed(fwd_path))
        bn_bw = _bottleneck_bw(fwd_path, edge_map)
        wire, ovhd, drain, _ = _formula_breakdown(fwd_path, nbytes, edge_map, graph)
        ovhd_pct = _pct(ovhd, total_ns)
        drain_pct = _pct(drain, total_ns)
        d2h_results.append((name, hops, total_ns, eff_bw, bn_bw, ovhd, drain, wire, ovhd_pct, drain_pct))
        d2h_route_data.append((name, fwd_path, rev_path, [], fwd_path))

    # === Collect PE DMA results ===
    # (case name, sip, source cube, source pe, destination cube, destination pe/slice)
    pe_cases = [
        ("pe-local-hbm", 0, 0, 0, 0, 0),
        ("pe-same-half-hbm", 0, 0, 0, 0, 1),
        ("pe-cross-half-hbm", 0, 0, 0, 0, 4),
        ("pe-cross-cube-hbm-best", 0, 0, 0, 1, 0),  # adjacent cube
        ("pe-cross-cube-hbm-worst", 0, 0, 0, 15, 0),  # diagonal far cube
    ]
    pe_results: list[tuple[str, float, float, float | None, float, float, float, float, float]] = []
    pe_route_data: list[tuple[str, list[str], str]] = []
    for name, sip, src_cube, src_pe, dst_cube, dst_pe in pe_cases:
        if not show_all and case_filter != name:
            continue
        engine = GraphEngine(graph)
        dst_pa = _hbm_pa(sip=sip, cube=dst_cube, pe_id=dst_pe, spec=spec)
        msg = PeDmaMsg(
            correlation_id="probe", request_id=name,
            src_sip=sip, src_cube=src_cube, src_pe=src_pe,
            dst_pa=dst_pa, nbytes=nbytes,
        )
        h = engine.submit(msg)
        engine.wait(h)
        _, trace = engine.get_completion(h)
        total_ns = trace["total_ns"]
        eff_bw = nbytes / total_ns if total_ns > 0 else 0.0
        pe_ref = f"sip{sip}.cube{src_cube}.pe{src_pe}"
        pa_obj = PhysAddr.decode(dst_pa)
        dst_node = resolver.resolve(pa_obj)
        dma_path = router.find_path(pe_ref, dst_node)
        bn_bw = _bottleneck_bw(dma_path, edge_map)
        wire, ovhd, drain, _ = _formula_breakdown(dma_path, nbytes, edge_map, graph)
        ovhd_pct = _pct(ovhd, total_ns)
        drain_pct = _pct(drain, total_ns)
        target_str = f"c{src_cube}.pe{src_pe}->c{dst_cube}.slice{dst_pe}"
        pe_results.append((name, total_ns, eff_bw, bn_bw, ovhd, drain, wire, ovhd_pct, drain_pct))
        pe_route_data.append((name, dma_path, target_str))

    # ================================================================
    # OUTPUT: Summary tables first, then route details
    # ================================================================
    # --- H2D Summary Table ---
    print()
    print(f"=== H2D Write Latency (IO->HBM, data={nbytes}B) ===")
    print(f" {'Case':<14} {'Target':<16} {'Hops':>4} {'Actual':>8}"
          f" {'Ovhd':>6} {'Drain':>6} {'Wire':>5} {'Ovhd%':>6} {'Drain%':>7}"
          f" {'Eff.BW':>8} {'BN.BW':>8} {'Util%':>6}")
    print(" " + "-" * 115)
    for name, hops, total_ns, eff_bw, bn_bw, ovhd, drain, wire, ovhd_pct, drain_pct in h2d_results:
        cube = h2d_cube[name]
        print(f" {name:<14} cube{cube}.pe0{'':<8} {hops:>4} {total_ns:>8.2f}"
              f" {ovhd:>6.1f} {drain:>6.1f} {wire:>5.2f} {ovhd_pct:>5.1f}% {drain_pct:>5.1f}%"
              f" {eff_bw:>8.2f} {_fmt_bw(bn_bw):>8} {_fmt_util(eff_bw, bn_bw):>6}")
    if len(h2d_results) >= 2:
        # Latency must grow strictly with hop count.
        lats = [r[2] for r in h2d_results]
        mono = all(a < b for a, b in zip(lats, lats[1:]))
        sym = "[v]" if mono else "[x]"
        print(f" {sym} Monotonic increase: {'PASS' if mono else 'FAIL'}")
    if h2d_results:
        h2d_sweep = [_sweep_util(r[5], r[7], r[4]) for r in h2d_results]
        _print_sweep_table([r[0] for r in h2d_results], h2d_sweep)

    # --- D2H Summary Table ---
    print()
    print(f"=== D2H Read Latency (HBM->IO, data={nbytes}B) ===")
    print(f" {'Case':<14} {'Source':<16} {'Hops':>4} {'Actual':>8}"
          f" {'Ovhd':>6} {'Drain':>6} {'Wire':>5} {'Ovhd%':>6} {'Drain%':>7}"
          f" {'Eff.BW':>8} {'BN.BW':>8} {'Util%':>6}")
    print(" " + "-" * 115)
    for name, hops, total_ns, eff_bw, bn_bw, ovhd, drain, wire, ovhd_pct, drain_pct in d2h_results:
        cube = d2h_cube[name]
        print(f" {name:<14} cube{cube}.pe0{'':<8} {hops:>4} {total_ns:>8.2f}"
              f" {ovhd:>6.1f} {drain:>6.1f} {wire:>5.2f} {ovhd_pct:>5.1f}% {drain_pct:>5.1f}%"
              f" {eff_bw:>8.2f} {_fmt_bw(bn_bw):>8} {_fmt_util(eff_bw, bn_bw):>6}")
    if len(d2h_results) >= 2:
        lats = [r[2] for r in d2h_results]
        mono = all(a < b for a, b in zip(lats, lats[1:]))
        sym = "[v]" if mono else "[x]"
        print(f" {sym} Monotonic increase: {'PASS' if mono else 'FAIL'}")
    if d2h_results:
        # D2H fixed cost = actual_total - drain (includes fwd+rev overhead)
        d2h_sweep = [_sweep_util(r[2] - r[6], 0.0, r[4]) for r in d2h_results]
        _print_sweep_table([r[0] for r in d2h_results], d2h_sweep)

    # H2D vs D2H comparison (only when both groups ran the same number of cases)
    if h2d_results and d2h_results and len(h2d_results) == len(d2h_results):
        all_gte = all(d2h[2] >= h2d[2] for h2d, d2h in zip(h2d_results, d2h_results))
        sym = "[v]" if all_gte else "[x]"
        print(f" {sym} D2H >= H2D (reverse data path): {'PASS' if all_gte else 'FAIL'}")

    # --- PE DMA Summary Table ---
    print()
    print(f"=== PE DMA Latency (pe_dma -> router -> HBM, data={nbytes}B) ===")
    print(f" {'Case':<26} {'Target':<28} {'Actual':>8}"
          f" {'Ovhd':>6} {'Drain':>6} {'Wire':>5} {'Ovhd%':>6} {'Drain%':>7}"
          f" {'Eff.BW':>8} {'BN.BW':>8} {'Util%':>6}")
    print(" " + "-" * 124)
    pe_targets = {n: t for n, _, t in pe_route_data}
    for name, total_ns, eff_bw, bn_bw, ovhd, drain, wire, ovhd_pct, drain_pct in pe_results:
        t_str = pe_targets.get(name, "")
        print(f" {name:<26} {t_str:<28} {total_ns:>8.2f}"
              f" {ovhd:>6.1f} {drain:>6.1f} {wire:>5.2f} {ovhd_pct:>5.1f}% {drain_pct:>5.1f}%"
              f" {eff_bw:>8.2f} {_fmt_bw(bn_bw):>8} {_fmt_util(eff_bw, bn_bw):>6}")
    if len(pe_results) >= 2:
        local = [r for r in pe_results if "local" in r[0]]
        remote = [r for r in pe_results if "local" not in r[0]]
        if local and remote:
            print(f" * Local BN: {_fmt_bw(local[0][3])} GB/s, "
                  f"Remote BN: {_fmt_bw(remote[0][3])} GB/s")
        best = [r for r in pe_results if "best" in r[0]]
        worst = [r for r in pe_results if "worst" in r[0]]
        if best and worst:
            best_ns = best[0][1]
            worst_ns = worst[0][1]
            ordered = best_ns < worst_ns
            sym = "[v]" if ordered else "[x]"
            print(f" {sym} Cross-cube best < worst: {'PASS' if ordered else 'FAIL'}"
                  f" ({best_ns:.2f}ns < {worst_ns:.2f}ns)")
    if pe_results:
        pe_sweep = [_sweep_util(r[4], r[6], r[3]) for r in pe_results]
        _print_sweep_table([r[0] for r in pe_results], pe_sweep)

    # ================================================================
    # ROUTE DETAILS (grouped below all tables)
    # ================================================================
    print()
    print("=" * 60)
    print(" ROUTE DETAILS (per-hop timestamps)")
    print("=" * 60)
    # --- H2D Routes ---
    if h2d_route_data:
        print()
        print(" --- H2D Write Routes ---")
        for name, leg1, leg2, leg3, fwd_path in h2d_route_data:
            timestamps = _hop_timestamps(fwd_path, nbytes, edge_map, graph)
            print(f"\n [{name}]")
            print(f" Leg1: {_short_path(leg1)}")
            print(f" Leg2: {_short_path(leg2)}")
            print(f" Leg3: {_short_path(leg3)}")
            print(" Per-hop trace:")
            _print_hop_trace(timestamps, indent=" ")
    # --- D2H Routes ---
    if d2h_route_data:
        print()
        print(" --- D2H Read Routes ---")
        for name, fwd_path, rev_path, _, _ in d2h_route_data:
            # The command travels forward without payload; the data returns on the reverse path.
            timestamps_fwd = _hop_timestamps(fwd_path, 0, edge_map, graph)
            timestamps_rev = _hop_timestamps(rev_path, nbytes, edge_map, graph)
            print(f"\n [{name}]")
            print(f" Fwd (cmd): {_short_path(fwd_path)}")
            print(f" Rev (data): {_short_path(rev_path)}")
            print(" Forward cmd trace (no data):")
            _print_hop_trace(timestamps_fwd, indent=" ")
            print(" Reverse data trace:")
            _print_hop_trace(timestamps_rev, indent=" ")
    # --- PE DMA Routes ---
    if pe_route_data:
        print()
        print(" --- PE DMA Routes ---")
        for name, dma_path, target_str in pe_route_data:
            timestamps = _hop_timestamps(dma_path, nbytes, edge_map, graph)
            print(f"\n [{name}] {target_str}")
            print(f" Path: {_short_path(dma_path)}")
            print(" Per-hop trace:")
            _print_hop_trace(timestamps, indent=" ")
    print()
    return 0
def cmd_probe(args) -> int:
    """CLI entry point: run the probe for ``args.topology`` and return its exit code.

    The case filter is read from ``args.case``; when *args* has no such
    attribute, ``"all"`` is used.
    """
    topology = args.topology
    selected_case = getattr(args, "case", "all")
    return run_probe(topology, selected_case)