"""kernbench probe: latency and BW verification utility. Runs predefined traffic patterns through the simulation engine and reports latency, effective bandwidth, bottleneck bandwidth, and utilization for each case. Validates monotonicity invariants across hop counts and access types. """ from __future__ import annotations from pathlib import Path from kernbench.policy.address.phyaddr import PhysAddr from kernbench.policy.routing.router import AddressResolver, PathRouter from kernbench.runtime_api.kernel import MemoryReadMsg, MemoryWriteMsg, PeDmaMsg from kernbench.sim_engine.engine import GraphEngine from kernbench.topology.builder import load_topology from kernbench.topology.types import TopologyGraph # -- Helpers ---------------------------------------------------------- def _hbm_pa(sip: int, cube: int, pe_id: int, spec: dict) -> int: mm = spec["cube"]["memory_map"] slice_bytes = mm["hbm_total_gb_per_cube"] * (1 << 30) // mm["hbm_slices_per_cube"] pa = PhysAddr.pe_hbm_addr( sip_id=sip, die_id=cube, pe_id=pe_id, pe_local_hbm_offset=0x1000, slice_size_bytes=slice_bytes, ) return pa.encode() def _build_edge_map(graph: TopologyGraph) -> dict[tuple[str, str], object]: return {(e.src, e.dst): e for e in graph.edges} def _formula_breakdown( path: list[str], nbytes: int, edge_map: dict, graph: TopologyGraph, ) -> tuple[float, float, float, float]: """Return (wire_ns, overhead_ns, drain_ns, formula_ns) for a path.""" ns_per_mm = graph.spec.get("system", {}).get("ns_per_mm", 0.01) wire_ns = 0.0 for i in range(len(path) - 1): e = edge_map.get((path[i], path[i + 1])) if e: wire_ns += e.distance_mm * ns_per_mm overhead_ns = 0.0 for nid in path: node = graph.nodes.get(nid) if node: overhead_ns += float(node.attrs.get("overhead_ns", 0.0)) bws = [e.bw_gbs for i in range(len(path) - 1) if (e := edge_map.get((path[i], path[i + 1]))) and e.bw_gbs] drain_ns = nbytes / min(bws) if bws else 0.0 return wire_ns, overhead_ns, drain_ns, wire_ns + overhead_ns + drain_ns def _hop_timestamps( path: list[str], nbytes: int, edge_map: dict, graph: TopologyGraph, ) -> list[tuple[str, float, str]]: """Return per-hop timestamps: [(node_short, cumulative_ns, annotation), ...]. Annotations mark bottleneck edges and significant overhead nodes. """ ns_per_mm = graph.spec.get("system", {}).get("ns_per_mm", 0.01) # Find bottleneck BW for annotation bws = [e.bw_gbs for i in range(len(path) - 1) if (e := edge_map.get((path[i], path[i + 1]))) and e.bw_gbs] bn_bw = min(bws) if bws else None cumulative = 0.0 result: list[tuple[str, float, str]] = [] result.append((_short_name(path[0]), 0.0, "")) for i in range(len(path) - 1): e = edge_map.get((path[i], path[i + 1])) ann = "" if e: cumulative += e.distance_mm * ns_per_mm if bn_bw is not None and e.bw_gbs and e.bw_gbs == bn_bw: ann = f"" node = graph.nodes.get(path[i + 1]) if node: ovhd = float(node.attrs.get("overhead_ns", 0.0)) cumulative += ovhd if ovhd > 0 and not ann: ann = f"+{ovhd:.1f}ns" result.append((_short_name(path[i + 1]), cumulative, ann)) # Add drain at terminal if bn_bw and nbytes > 0: cumulative += nbytes / bn_bw result[-1] = (result[-1][0], cumulative, result[-1][2] + f" drain:{nbytes/bn_bw:.1f}ns") return result def _bottleneck_bw(path: list[str], edge_map: dict) -> float | None: """Per-request bottleneck: single request uses one connection.""" bws: list[float] = [] for i in range(len(path) - 1): e = edge_map.get((path[i], path[i + 1])) if e and e.bw_gbs: bws.append(e.bw_gbs) return min(bws) if bws else None def _fmt_bw(bw: float | None) -> str: return f"{bw:.1f}" if bw is not None else "-" def _fmt_util(eff: float, bn: float | None) -> str: if bn is None or bn <= 0: return "-" return f"{eff / bn * 100:.1f}%" def _short_name(node_id: str) -> str: """Shorten node id: keep last 2 segments to avoid ambiguity (router.pe0 vs pe0).""" parts = node_id.split(".") return ".".join(parts[-2:]) if len(parts) >= 2 else node_id def _short_path(path: list[str]) -> str: return " -> ".join(_short_name(n) for n in path) def _print_hop_trace(timestamps: list[tuple[str, float, str]], indent: str = " ") -> None: """Print per-hop timestamp trace.""" for node, t_ns, ann in timestamps: ann_str = f" {ann}" if ann else "" print(f"{indent}{t_ns:>8.2f}ns {node}{ann_str}") SWEEP_SIZES = [4096, 16384, 65536, 262144, 1048576] SWEEP_LABELS = ["4KB", "16KB", "64KB", "256KB", "1MB"] def _sweep_util(ovhd_ns: float, wire_ns: float, bn_bw: float | None, sizes: list[int] = SWEEP_SIZES) -> list[float]: """Compute utilization % for each data size using formula model.""" if bn_bw is None or bn_bw <= 0: return [0.0] * len(sizes) result = [] for nb in sizes: drain = nb / bn_bw total = ovhd_ns + wire_ns + drain eff = nb / total if total > 0 else 0.0 result.append(eff / bn_bw * 100) return result def _print_sweep_table(case_names: list[str], sweep_data: list[list[float]]) -> None: """Print compact BW saturation table.""" hdr = f" {'Case':<26}" + "".join(f" {l:>7}" for l in SWEEP_LABELS) print(f"\n BW Saturation (Util% by data size):") print(hdr) print(" " + "-" * (26 + 8 * len(SWEEP_LABELS))) for name, utils in zip(case_names, sweep_data): cols = "".join(f" {u:>6.1f}%" for u in utils) print(f" {name:<26}{cols}") # -- Probe runner ----------------------------------------------------- def run_probe(topology_path: str, case_filter: str | None = None) -> int: path = Path(topology_path).expanduser().resolve() graph = load_topology(path) edge_map = _build_edge_map(graph) spec = graph.spec resolver = AddressResolver(graph) router = PathRouter(graph) nbytes = 32768 show_all = case_filter is None or case_filter == "all" # === Collect H2D results === h2d_cases = [ ("h2d-1hop", 0, 1), ("h2d-2hop", 4, 2), ("h2d-3hop", 8, 3), ("h2d-4hop", 12, 4), ] h2d_results: list[tuple[str, int, float, float, float | None, float, float, float, float, float]] = [] h2d_route_data: list[tuple[str, list[str], list[str], list[str], list[str]]] = [] for name, cube, hops in h2d_cases: if not show_all and case_filter != name: continue engine = GraphEngine(graph) pa = _hbm_pa(sip=0, cube=cube, pe_id=0, spec=spec) msg = MemoryWriteMsg( correlation_id="probe", request_id=name, dst_sip=0, dst_cube=cube, dst_pe=0, dst_pa=pa, nbytes=nbytes, pattern="zero", ) h = engine.submit(msg) engine.wait(h) _, trace = engine.get_completion(h) total_ns = trace["total_ns"] eff_bw = nbytes / total_ns if total_ns > 0 else 0.0 pa_obj = PhysAddr.decode(pa) dst_node = resolver.resolve(pa_obj) pcie_ep = resolver.find_pcie_ep(0) io_cpu = resolver.find_io_cpu(0) m_cpu = resolver.find_m_cpu(0, cube) leg1 = router.find_node_path(pcie_ep, io_cpu) leg2 = router.find_node_path(io_cpu, m_cpu) leg3 = router.find_mcpu_dma_path(m_cpu, dst_node) full_path = leg1 + leg2[1:] + leg3[1:] bn_bw = _bottleneck_bw(full_path, edge_map) fwd_path = leg1 + leg2[1:] + leg3[1:] wire, ovhd, drain, formula = _formula_breakdown(fwd_path, nbytes, edge_map, graph) ovhd_pct = ovhd / total_ns * 100 if total_ns > 0 else 0 drain_pct = drain / total_ns * 100 if total_ns > 0 else 0 h2d_results.append((name, hops, total_ns, eff_bw, bn_bw, ovhd, drain, wire, ovhd_pct, drain_pct)) h2d_route_data.append((name, leg1, leg2, leg3, fwd_path)) # === Collect D2H Read results === d2h_cases = [ ("d2h-1hop", 0, 1), ("d2h-2hop", 4, 2), ("d2h-3hop", 8, 3), ("d2h-4hop", 12, 4), ] d2h_results: list[tuple[str, int, float, float, float | None, float, float, float, float, float]] = [] d2h_route_data: list[tuple[str, list[str], list[str], list[str], list[str]]] = [] for name, cube, hops in d2h_cases: if not show_all and case_filter != name: continue engine = GraphEngine(graph) pa = _hbm_pa(sip=0, cube=cube, pe_id=0, spec=spec) msg = MemoryReadMsg( correlation_id="probe", request_id=name, src_sip=0, src_cube=cube, src_pe=0, src_pa=pa, nbytes=nbytes, ) h = engine.submit(msg) engine.wait(h) _, trace = engine.get_completion(h) total_ns = trace["total_ns"] eff_bw = nbytes / total_ns if total_ns > 0 else 0.0 pa_obj = PhysAddr.decode(pa) dst_node = resolver.resolve(pa_obj) pcie_ep = resolver.find_pcie_ep(0) fwd_path = router.find_memory_path(pcie_ep, dst_node) rev_path = list(reversed(fwd_path)) bn_bw = _bottleneck_bw(fwd_path, edge_map) wire, ovhd, drain, formula = _formula_breakdown(fwd_path, nbytes, edge_map, graph) ovhd_pct = ovhd / total_ns * 100 if total_ns > 0 else 0 drain_pct = drain / total_ns * 100 if total_ns > 0 else 0 d2h_results.append((name, hops, total_ns, eff_bw, bn_bw, ovhd, drain, wire, ovhd_pct, drain_pct)) d2h_route_data.append((name, fwd_path, rev_path, [], fwd_path)) # === Collect PE DMA results === pe_cases = [ ("pe-local-hbm", 0, 0, 0, 0, 0), ("pe-same-half-hbm", 0, 0, 0, 0, 1), ("pe-cross-half-hbm", 0, 0, 0, 0, 4), ("pe-cross-cube-hbm-best", 0, 0, 0, 1, 0), # adjacent cube ("pe-cross-cube-hbm-worst", 0, 0, 0, 15, 0), # diagonal far cube ] pe_results: list[tuple[str, float, float, float | None, float, float, float, float, float]] = [] pe_route_data: list[tuple[str, list[str], str]] = [] for name, sip, src_cube, src_pe, dst_cube, dst_pe in pe_cases: if not show_all and case_filter != name: continue engine = GraphEngine(graph) dst_pa = _hbm_pa(sip=sip, cube=dst_cube, pe_id=dst_pe, spec=spec) msg = PeDmaMsg( correlation_id="probe", request_id=name, src_sip=sip, src_cube=src_cube, src_pe=src_pe, dst_pa=dst_pa, nbytes=nbytes, ) h = engine.submit(msg) engine.wait(h) _, trace = engine.get_completion(h) total_ns = trace["total_ns"] eff_bw = nbytes / total_ns if total_ns > 0 else 0.0 pe_ref = f"sip{sip}.cube{src_cube}.pe{src_pe}" pa_obj = PhysAddr.decode(dst_pa) dst_node = resolver.resolve(pa_obj) dma_path = router.find_path(pe_ref, dst_node) bn_bw = _bottleneck_bw(dma_path, edge_map) wire, ovhd, drain, formula = _formula_breakdown(dma_path, nbytes, edge_map, graph) ovhd_pct = ovhd / total_ns * 100 if total_ns > 0 else 0 drain_pct = drain / total_ns * 100 if total_ns > 0 else 0 target_str = f"c{src_cube}.pe{src_pe}->c{dst_cube}.slice{dst_pe}" pe_results.append((name, total_ns, eff_bw, bn_bw, ovhd, drain, wire, ovhd_pct, drain_pct)) pe_route_data.append((name, dma_path, target_str)) # ================================================================ # OUTPUT: Summary tables first, then route details # ================================================================ # --- H2D Summary Table --- print() print(f"=== H2D Write Latency (IO->HBM, data={nbytes}B) ===") print(f" {'Case':<14} {'Target':<16} {'Hops':>4} {'Actual':>8}" f" {'Ovhd':>6} {'Drain':>6} {'Wire':>5} {'Ovhd%':>6} {'Drain%':>7}" f" {'Eff.BW':>8} {'BN.BW':>8} {'Util%':>6}") print(" " + "-" * 115) for i, (name, hops, total_ns, eff_bw, bn_bw, ovhd, drain, wire, ovhd_pct, drain_pct) in enumerate(h2d_results): cube = h2d_cases[i][1] if i < len(h2d_cases) else 0 print(f" {name:<14} cube{cube}.pe0{'':<8} {hops:>4} {total_ns:>8.2f}" f" {ovhd:>6.1f} {drain:>6.1f} {wire:>5.2f} {ovhd_pct:>5.1f}% {drain_pct:>5.1f}%" f" {eff_bw:>8.2f} {_fmt_bw(bn_bw):>8} {_fmt_util(eff_bw, bn_bw):>6}") if len(h2d_results) >= 2: lats = [r[2] for r in h2d_results] mono = all(lats[i] < lats[i + 1] for i in range(len(lats) - 1)) sym = "[v]" if mono else "[x]" print(f" {sym} Monotonic increase: {'PASS' if mono else 'FAIL'}") if h2d_results: h2d_sweep = [_sweep_util(r[5], r[7], r[4]) for r in h2d_results] _print_sweep_table([r[0] for r in h2d_results], h2d_sweep) # --- D2H Summary Table --- print() print(f"=== D2H Read Latency (HBM->IO, data={nbytes}B) ===") print(f" {'Case':<14} {'Source':<16} {'Hops':>4} {'Actual':>8}" f" {'Ovhd':>6} {'Drain':>6} {'Wire':>5} {'Ovhd%':>6} {'Drain%':>7}" f" {'Eff.BW':>8} {'BN.BW':>8} {'Util%':>6}") print(" " + "-" * 115) for i, (name, hops, total_ns, eff_bw, bn_bw, ovhd, drain, wire, ovhd_pct, drain_pct) in enumerate(d2h_results): cube = d2h_cases[i][1] if i < len(d2h_cases) else 0 print(f" {name:<14} cube{cube}.pe0{'':<8} {hops:>4} {total_ns:>8.2f}" f" {ovhd:>6.1f} {drain:>6.1f} {wire:>5.2f} {ovhd_pct:>5.1f}% {drain_pct:>5.1f}%" f" {eff_bw:>8.2f} {_fmt_bw(bn_bw):>8} {_fmt_util(eff_bw, bn_bw):>6}") if len(d2h_results) >= 2: lats = [r[2] for r in d2h_results] mono = all(lats[i] < lats[i + 1] for i in range(len(lats) - 1)) sym = "[v]" if mono else "[x]" print(f" {sym} Monotonic increase: {'PASS' if mono else 'FAIL'}") if d2h_results: # D2H fixed cost = actual_total - drain (includes fwd+rev overhead) d2h_sweep = [_sweep_util(r[2] - r[6], 0.0, r[4]) for r in d2h_results] _print_sweep_table([r[0] for r in d2h_results], d2h_sweep) # H2D vs D2H comparison if h2d_results and d2h_results and len(h2d_results) == len(d2h_results): all_gte = all(d2h_results[i][2] >= h2d_results[i][2] for i in range(len(h2d_results))) sym = "[v]" if all_gte else "[x]" print(f" {sym} D2H >= H2D (reverse data path): {'PASS' if all_gte else 'FAIL'}") # --- PE DMA Summary Table --- print() print(f"=== PE DMA Latency (pe_dma -> router -> HBM, data={nbytes}B) ===") print(f" {'Case':<26} {'Target':<28} {'Actual':>8}" f" {'Ovhd':>6} {'Drain':>6} {'Wire':>5} {'Ovhd%':>6} {'Drain%':>7}" f" {'Eff.BW':>8} {'BN.BW':>8} {'Util%':>6}") print(" " + "-" * 124) for name, total_ns, eff_bw, bn_bw, ovhd, drain, wire, ovhd_pct, drain_pct in pe_results: target_str = [t for n, _, t in pe_route_data if n == name] t_str = target_str[0] if target_str else "" print(f" {name:<26} {t_str:<28} {total_ns:>8.2f}" f" {ovhd:>6.1f} {drain:>6.1f} {wire:>5.2f} {ovhd_pct:>5.1f}% {drain_pct:>5.1f}%" f" {eff_bw:>8.2f} {_fmt_bw(bn_bw):>8} {_fmt_util(eff_bw, bn_bw):>6}") if len(pe_results) >= 2: local = [r for r in pe_results if "local" in r[0]] remote = [r for r in pe_results if "local" not in r[0]] if local and remote: print(f" * Local BN: {_fmt_bw(local[0][3])} GB/s, " f"Remote BN: {_fmt_bw(remote[0][3])} GB/s") best = [r for r in pe_results if "best" in r[0]] worst = [r for r in pe_results if "worst" in r[0]] if best and worst: sym = "[v]" if best[0][1] < worst[0][1] else "[x]" print(f" {sym} Cross-cube best < worst: {'PASS' if best[0][1] < worst[0][1] else 'FAIL'}" f" ({best[0][1]:.2f}ns < {worst[0][1]:.2f}ns)") if pe_results: pe_sweep = [_sweep_util(r[4], r[6], r[3]) for r in pe_results] _print_sweep_table([r[0] for r in pe_results], pe_sweep) # ================================================================ # ROUTE DETAILS (grouped below all tables) # ================================================================ print() print("=" * 60) print(" ROUTE DETAILS (per-hop timestamps)") print("=" * 60) # --- H2D Routes --- if h2d_route_data: print() print(" --- H2D Write Routes ---") for name, leg1, leg2, leg3, fwd_path in h2d_route_data: timestamps = _hop_timestamps(fwd_path, nbytes, edge_map, graph) print(f"\n [{name}]") print(f" Leg1: {_short_path(leg1)}") print(f" Leg2: {_short_path(leg2)}") print(f" Leg3: {_short_path(leg3)}") print(f" Per-hop trace:") _print_hop_trace(timestamps, indent=" ") # --- D2H Routes --- if d2h_route_data: print() print(" --- D2H Read Routes ---") for name, fwd_path, rev_path, _, _ in d2h_route_data: timestamps_fwd = _hop_timestamps(fwd_path, 0, edge_map, graph) timestamps_rev = _hop_timestamps(rev_path, nbytes, edge_map, graph) print(f"\n [{name}]") print(f" Fwd (cmd): {_short_path(fwd_path)}") print(f" Rev (data): {_short_path(rev_path)}") print(f" Forward cmd trace (no data):") _print_hop_trace(timestamps_fwd, indent=" ") print(f" Reverse data trace:") _print_hop_trace(timestamps_rev, indent=" ") # --- PE DMA Routes --- if pe_route_data: print() print(" --- PE DMA Routes ---") for name, dma_path, target_str in pe_route_data: timestamps = _hop_timestamps(dma_path, nbytes, edge_map, graph) print(f"\n [{name}] {target_str}") print(f" Path: {_short_path(dma_path)}") print(f" Per-hop trace:") _print_hop_trace(timestamps, indent=" ") print() return 0 def cmd_probe(args) -> int: return run_probe(args.topology, getattr(args, "case", "all"))