Files
kernbench2/src/kernbench/cli/probe.py
T
2026-03-18 11:47:48 -07:00

249 lines
9.4 KiB
Python

"""kernbench probe: latency and BW verification utility.
Runs predefined traffic patterns through the simulation engine and reports
latency, effective bandwidth, bottleneck bandwidth, and utilization for each
case. Validates monotonicity invariants across hop counts and access types.
"""
from __future__ import annotations
from pathlib import Path
from kernbench.policy.address.phyaddr import PhysAddr
from kernbench.policy.routing.router import AddressResolver, PathRouter
from kernbench.runtime_api.kernel import MemoryWriteMsg, PeDmaMsg
from kernbench.sim_engine.engine import GraphEngine
from kernbench.topology.builder import load_topology
from kernbench.topology.types import TopologyGraph
# -- Helpers ----------------------------------------------------------
def _hbm_pa(sip: int, cube: int, pe_id: int, spec: dict) -> int:
    """Return the encoded physical address of a probe target in a PE's HBM slice.

    The per-slice size is derived from the cube memory map in *spec*
    (total HBM per cube in GiB, split evenly across the configured number of
    slices). The address is built for rack 0 at a fixed local offset of 0x1000.
    """
    memory_map = spec["cube"]["memory_map"]
    cube_hbm_bytes = memory_map["hbm_total_gb_per_cube"] * (1 << 30)
    bytes_per_slice = cube_hbm_bytes // memory_map["hbm_slices_per_cube"]
    return PhysAddr.pe_hbm_addr(
        rack_id=0,
        sip_id=sip,
        cube_id=cube,
        pe_id=pe_id,
        pe_local_hbm_offset=0x1000,
        slice_size_bytes=bytes_per_slice,
    ).encode()
def _build_edge_map(graph: TopologyGraph) -> dict[tuple[str, str], object]:
return {(e.src, e.dst): e for e in graph.edges}
def _formula_breakdown(
path: list[str], nbytes: int, edge_map: dict, graph: TopologyGraph,
) -> tuple[float, float, float, float]:
"""Return (wire_ns, overhead_ns, drain_ns, formula_ns) for a path."""
ns_per_mm = graph.spec.get("system", {}).get("ns_per_mm", 0.01)
wire_ns = 0.0
for i in range(len(path) - 1):
e = edge_map.get((path[i], path[i + 1]))
if e:
wire_ns += e.distance_mm * ns_per_mm
overhead_ns = 0.0
for nid in path:
node = graph.nodes.get(nid)
if node:
overhead_ns += float(node.attrs.get("overhead_ns", 0.0))
bws = [e.bw_gbs for i in range(len(path) - 1)
if (e := edge_map.get((path[i], path[i + 1]))) and e.bw_gbs]
drain_ns = nbytes / min(bws) if bws else 0.0
return wire_ns, overhead_ns, drain_ns, wire_ns + overhead_ns + drain_ns
def _bottleneck_bw(path: list[str], edge_map: dict) -> float | None:
"""Per-request bottleneck: single request uses one connection."""
bws: list[float] = []
for i in range(len(path) - 1):
e = edge_map.get((path[i], path[i + 1]))
if e and e.bw_gbs:
bws.append(e.bw_gbs)
return min(bws) if bws else None
def _fmt_bw(bw: float | None) -> str:
return f"{bw:.1f}" if bw is not None else "-"
def _fmt_util(eff: float, bn: float | None) -> str:
if bn is None or bn <= 0:
return "-"
return f"{eff / bn * 100:.1f}%"
def _short_name(node_id: str) -> str:
"""Shorten node id: keep last 2 segments to avoid ambiguity (xbar.pe0 vs pe0)."""
parts = node_id.split(".")
return ".".join(parts[-2:]) if len(parts) >= 2 else node_id
def _short_path(path: list[str]) -> str:
return " -> ".join(_short_name(n) for n in path)
# -- Probe runner -----------------------------------------------------
def _case_selected(name: str, case_filter: str | None) -> bool:
    """Return True when case *name* should run (a filter of None or "all" runs everything)."""
    return case_filter is None or case_filter == "all" or case_filter == name


def _submit_and_time(engine: GraphEngine, msg: object, nbytes: int) -> tuple[float, float]:
    """Run *msg* to completion on *engine* and return (total_ns, effective bandwidth).

    Effective bandwidth is ``nbytes / total_ns`` (bytes per nanosecond), or 0.0
    when the reported latency is not positive.
    """
    handle = engine.submit(msg)
    engine.wait(handle)
    _, trace = engine.get_completion(handle)
    total_ns = trace["total_ns"]
    eff_bw = nbytes / total_ns if total_ns > 0 else 0.0
    return total_ns, eff_bw


def _share_pct(part_ns: float, total_ns: float) -> float:
    """Return *part_ns* as a percentage of *total_ns* (0 when the total is not positive)."""
    return part_ns / total_ns * 100 if total_ns > 0 else 0


def _probe_h2d_writes(
    graph: TopologyGraph,
    edge_map: dict,
    resolver: AddressResolver,
    router: PathRouter,
    nbytes: int,
    case_filter: str | None,
) -> None:
    """Print the H2D write latency table, its monotonicity verdict and the routes."""
    spec = graph.spec
    # (case name, destination cube, hop count shown in the table)
    h2d_cases = [
        ("h2d-1hop", 0, 1),
        ("h2d-2hop", 4, 2),
        ("h2d-3hop", 8, 3),
        ("h2d-4hop", 12, 4),
    ]
    h2d_results: list[tuple[str, int, float, float, float | None]] = []
    h2d_paths: list[tuple[str, list[str], list[str], list[str]]] = []

    print()
    print("=== H2D Write Latency (IO->HBM, varying hop count) ===")
    print(f" {'Case':<14} {'Target':<16} {'Hops':>4} {'Actual':>8}"
          f" {'Ovhd':>6} {'Drain':>6} {'Wire':>5} {'Ovhd%':>6} {'Drain%':>7}"
          f" {'Eff.BW':>8} {'BN.BW':>8} {'Util%':>6}")
    print(" " + "-" * 115)

    for name, cube, hops in h2d_cases:
        if not _case_selected(name, case_filter):
            continue

        # A fresh engine is created for every case.
        engine = GraphEngine(graph)
        pa = _hbm_pa(sip=0, cube=cube, pe_id=0, spec=spec)
        msg = MemoryWriteMsg(
            correlation_id="probe", request_id=name,
            dst_sip=0, dst_cube=cube, dst_pe=0,
            dst_pa=pa, nbytes=nbytes, pattern="zero",
        )
        total_ns, eff_bw = _submit_and_time(engine, msg, nbytes)

        # Rebuild the route as three legs: PCIe EP -> IO CPU -> M CPU -> HBM node.
        dst_node = resolver.resolve(PhysAddr.decode(pa))
        pcie_ep = resolver.find_pcie_ep(0)
        io_cpu = resolver.find_io_cpu(0)
        m_cpu = resolver.find_m_cpu(0, cube)
        leg1 = router.find_node_path(pcie_ep, io_cpu)
        leg2 = router.find_node_path(io_cpu, m_cpu)
        leg3 = router.find_mcpu_dma_path(m_cpu, dst_node)
        # Later legs have their first node dropped when stitched together
        # (presumably it repeats the previous leg's last node -- confirm).
        full_path = leg1 + leg2[1:] + leg3[1:]

        bn_bw = _bottleneck_bw(full_path, edge_map)
        # Forward path breakdown only (response path is implicit in the
        # measured total_ns).
        wire, ovhd, drain, _ = _formula_breakdown(full_path, nbytes, edge_map, graph)
        ovhd_pct = _share_pct(ovhd, total_ns)
        drain_pct = _share_pct(drain, total_ns)

        h2d_results.append((name, hops, total_ns, eff_bw, bn_bw))
        h2d_paths.append((name, leg1, leg2, leg3))
        print(f" {name:<14} cube{cube}.pe0{'':<8} {hops:>4} {total_ns:>8.2f}"
              f" {ovhd:>6.1f} {drain:>6.1f} {wire:>5.2f} {ovhd_pct:>5.1f}% {drain_pct:>5.1f}%"
              f" {eff_bw:>8.2f} {_fmt_bw(bn_bw):>8} {_fmt_util(eff_bw, bn_bw):>6}")

    # Latency must grow strictly with hop count; needs at least two rows.
    if len(h2d_results) >= 2:
        lats = [r[2] for r in h2d_results]
        mono = all(lats[i] < lats[i + 1] for i in range(len(lats) - 1))
        sym = "[v]" if mono else "[x]"
        print(f" {sym} Monotonic increase: {'PASS' if mono else 'FAIL'}")

    if h2d_paths:
        print()
        print(" Route Details:")
        print(f" {'Case':<14} {'Leg':>4} Path")
        print(" " + "-" * 80)
        for name, leg1, leg2, leg3 in h2d_paths:
            print(f" {name:<14} {'L1':>4} {_short_path(leg1)}")
            print(f" {'':<14} {'L2':>4} {_short_path(leg2)}")
            print(f" {'':<14} {'L3':>4} {_short_path(leg3)}")


def _probe_pe_dma(
    graph: TopologyGraph,
    edge_map: dict,
    resolver: AddressResolver,
    router: PathRouter,
    nbytes: int,
    case_filter: str | None,
) -> None:
    """Print the PE DMA -> HBM latency table (direct PE-level injection) and the routes."""
    spec = graph.spec
    # (name, sip, src_cube, src_pe, dst_cube, dst_pe)
    pe_cases = [
        ("pe-local-hbm", 0, 0, 0, 0, 0),       # pe0 -> slice0 (local, 256 GB/s)
        ("pe-same-half-hbm", 0, 0, 0, 0, 1),   # pe0 -> slice1 (xbar chain, 128 GB/s)
        ("pe-cross-half-hbm", 0, 0, 0, 0, 4),  # pe0 -> slice4 (xbar chain, 128 GB/s)
        ("pe-cross-cube-hbm", 0, 0, 0, 1, 0),  # cube0.pe0 -> cube1.slice0 (NOC, 128 GB/s)
    ]
    pe_results: list[tuple[str, float, float, float | None]] = []
    pe_paths: list[tuple[str, list[str]]] = []

    print()
    print("=== PE DMA Latency (pe_dma -> xbar -> HBM, direct injection) ===")
    print(f" {'Case':<22} {'Target':<28} {'Actual':>8}"
          f" {'Ovhd':>6} {'Drain':>6} {'Wire':>5} {'Ovhd%':>6} {'Drain%':>7}"
          f" {'Eff.BW':>8} {'BN.BW':>8} {'Util%':>6}")
    print(" " + "-" * 120)

    for name, sip, src_cube, src_pe, dst_cube, dst_pe in pe_cases:
        if not _case_selected(name, case_filter):
            continue

        # A fresh engine is created for every case.
        engine = GraphEngine(graph)
        dst_pa = _hbm_pa(sip=sip, cube=dst_cube, pe_id=dst_pe, spec=spec)
        msg = PeDmaMsg(
            correlation_id="probe", request_id=name,
            src_sip=sip, src_cube=src_cube, src_pe=src_pe,
            dst_pa=dst_pa, nbytes=nbytes,
        )
        total_ns, eff_bw = _submit_and_time(engine, msg, nbytes)

        pe_ref = f"sip{sip}.cube{src_cube}.pe{src_pe}"
        dst_node = resolver.resolve(PhysAddr.decode(dst_pa))
        dma_path = router.find_path(pe_ref, dst_node)

        bn_bw = _bottleneck_bw(dma_path, edge_map)
        wire, ovhd, drain, _ = _formula_breakdown(dma_path, nbytes, edge_map, graph)
        ovhd_pct = _share_pct(ovhd, total_ns)
        drain_pct = _share_pct(drain, total_ns)
        target_str = f"c{src_cube}.pe{src_pe}->c{dst_cube}.slice{dst_pe}"

        pe_results.append((name, total_ns, eff_bw, bn_bw))
        pe_paths.append((name, dma_path))
        print(f" {name:<22} {target_str:<28} {total_ns:>8.2f}"
              f" {ovhd:>6.1f} {drain:>6.1f} {wire:>5.2f} {ovhd_pct:>5.1f}% {drain_pct:>5.1f}%"
              f" {eff_bw:>8.2f} {_fmt_bw(bn_bw):>8} {_fmt_util(eff_bw, bn_bw):>6}")

    # Contrast the bottleneck of the local case with the first non-local case.
    if len(pe_results) >= 2:
        local = [r for r in pe_results if "local" in r[0]]
        chain = [r for r in pe_results if "local" not in r[0]]
        if local and chain:
            print(f" * Local BN: {_fmt_bw(local[0][3])} GB/s, "
                  f"Chain/NOC BN: {_fmt_bw(chain[0][3])} GB/s")

    if pe_paths:
        print()
        print(" Route Details:")
        print(f" {'Case':<22} Path")
        print(" " + "-" * 80)
        for name, dma_path in pe_paths:
            print(f" {name:<22} {_short_path(dma_path)}")


def run_probe(topology_path: str, case_filter: str | None = None) -> int:
    """Run the predefined probe cases against a topology and print the report.

    Two families of cases are run, each with a fixed 4096-byte request:
    H2D writes to HBM at increasing hop counts, and PE DMA transfers to HBM.

    Args:
        topology_path: Path to the topology description; ``~`` is expanded and
            the path is resolved before loading.
        case_filter: Name of a single case to run. ``None`` or ``"all"`` runs
            every case. Section headers are printed even when the filter
            selects nothing in a section.

    Returns:
        Always 0. A failed monotonicity check is reported in the printed output
        but does not change the return value.
    """
    path = Path(topology_path).expanduser().resolve()
    graph = load_topology(path)
    edge_map = _build_edge_map(graph)
    resolver = AddressResolver(graph)
    router = PathRouter(graph)
    nbytes = 4096  # payload size used by every probe request

    _probe_h2d_writes(graph, edge_map, resolver, router, nbytes, case_filter)
    _probe_pe_dma(graph, edge_map, resolver, router, nbytes, case_filter)

    print()
    return 0
def cmd_probe(args) -> int:
    """CLI entry point: run the probe for ``args.topology``.

    The case filter is read from ``args.case`` and falls back to "all" when the
    attribute is absent. Returns the exit code produced by ``run_probe``.
    """
    topology = args.topology
    selected_case = getattr(args, "case", "all")
    return run_probe(topology, selected_case)