eval: fold GEMM/allreduce harnesses into self-contained milestone benches

Move the GEMM + allreduce sweep/render logic out of scripts/ and tests/
into two self-contained eval benches so a user can regenerate every
result + figure with one command:

  kernbench run --bench milestone-1h-gemm   (MILESTONE_FAST=1 reuses JSON)
  kernbench run --bench milestone-1h-ccl

- benches/milestone_1h_{gemm,ccl}.py: single home for each domain; the
  run(torch) entry drives the sweeps and writes figures into
  benches/1H_milestone_output/{gemm,ccl}/ (gitignored), then submits a
  sentinel tensor to satisfy the run_bench contract.
- tests/gemm + tests/sccl helpers and scripts/gemm_sweep.py become thin
  re-export/wrapper shims over the benches (single source preserved); the
  pytest-only param builders + _run_distributed wrapper stay in the shim.
- eval-bench pattern: a bench may drive many configs + build its own
  per-config engines (extends ADR-0045 D5; reverses ADR-0044 D1/D2).

ADR-0054 (EN+KO) records the design; ADR-0043/0044/0045 + CLAUDE.md CLI
Semantics amended; ADR INDEX regenerated. Verified: milestone benches run
clean (ok=True, all artifacts), full suite 67 passed, lang-pairs OK.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-22 15:19:32 -07:00
parent e33e76f2d1
commit cc1bbd0ab7
19 changed files with 2189 additions and 1465 deletions
+25 -277
View File
@@ -1,283 +1,31 @@
"""Shared plotting plumbing for the GEMM figure tests.
"""Thin re-export shim for the GEMM figure tests.
Not a test module (no ``test_`` prefix -> pytest does not collect it).
Not a test module (no ``test_`` prefix pytest does not collect it).
Reads the committed ``docs/diagrams/gemm_sweep.json`` (produced by the heavy
``scripts/gemm_sweep.py`` sim sweep) and renders matplotlib PNGs into
``docs/diagrams/gemm_plots/``. No simulation here -> the figure tests are fast
and run by default; regenerating the underlying data stays a manual script.
Chart set (mirrors the GEMM MAC slides in scripts/build_overview_slides.py):
- stage breakdown (load_ref operand staging)
- MAC utilization — measured (load_ref)
- MAC utilization — theoretical vs measured (load_ref)
The sweep + renderer logic now lives in
``kernbench.benches.milestone_1h_gemm`` (production single home, ADR-0054,
also driven by ``scripts/gemm_sweep.py``). The figure tests import the same
names from here; behavior is unchanged (defaults still target
``docs/diagrams/gemm_plots/``).
"""
from __future__ import annotations
import json
from pathlib import Path
from kernbench.benches.milestone_1h_gemm import (
DEFAULT_PLOTS_DIR as GEMM_PLOTS_DIR,
DEFAULT_SWEEP_JSON as GEMM_SWEEP_JSON,
ROOT,
emit_all_gemm_plots,
emit_mac_utilization_measured,
emit_mac_utilization_theoretical_vs_measured,
emit_stage_breakdown,
)
ROOT = Path(__file__).resolve().parent.parent.parent
GEMM_SWEEP_JSON = ROOT / "docs" / "diagrams" / "gemm_sweep.json"
GEMM_PLOTS_DIR = ROOT / "docs" / "diagrams" / "gemm_plots"
# Shapes excluded from the figures (mirrors build_overview_slides).
EXCLUDED_SHAPES = {(512, 512, 512)}
# Stage bars shown (raw op_log stage_type keys) + display names + colors.
STAGE_KEYS = ["DMA_READ", "FETCH", "GEMM", "DMA_WRITE"]
STAGE_DISPLAY = {
"DMA_READ": "DMA in",
"FETCH": "Fetch",
"GEMM": "GEMM",
"DMA_WRITE": "DMA out",
}
STAGE_COLORS = {
"DMA_READ": "#3B82F6",
"FETCH": "#10B981",
"GEMM": "#F59E0B",
"DMA_WRITE": "#A855F7",
}
# MAC-utilization model constants (mirror build_overview_slides).
_HBM_GBS = 256.0
_BPE = 2
_T_STAGE = 16.0
_D_STAGES = 3
_PLOT_VARIANT = "load_ref"
def _load_sweep_data() -> dict:
if not GEMM_SWEEP_JSON.exists():
return {"rows": []}
data = json.loads(GEMM_SWEEP_JSON.read_text())
data["rows"] = [
r for r in data.get("rows", [])
if (r["M"], r["K"], r["N"]) not in EXCLUDED_SHAPES
]
return data
def _shape_label(r: dict) -> str:
if r["M"] == r["K"] == r["N"]:
return f"M=K=N={r['M']}"
return f"M={r['M']} K={r['K']} N={r['N']}"
def _under_tile(M, K, N, tile_M, tile_K, tile_N) -> bool:
return M < tile_M or K < tile_K or N < tile_N
def _xtick_labels(shape_labels, tile_counts, flagged) -> list[str]:
out = []
for lbl, tc, fl in zip(shape_labels, tile_counts, flagged):
s = f"{lbl}\n({tc} tiles)"
if fl:
s += " *"
out.append(s)
return out
def _grouped_bar_png(
out_name: str, *, title: str, subtitle: str | None,
shape_labels, tile_counts, flagged, series: dict, colors: dict,
y_label: str, threshold: float | None = None, footnote: str | None = None,
) -> str:
"""Render one grouped-bar chart to GEMM_PLOTS_DIR/out_name; return the path."""
import matplotlib.pyplot as plt
import numpy as np
n_groups = len(shape_labels)
n_series = max(1, len(series))
x = np.arange(n_groups)
width = 0.8 / n_series
fig, ax = plt.subplots(figsize=(11, 6))
for i, (name, vals) in enumerate(series.items()):
offset = (i - (n_series - 1) / 2) * width
ax.bar(x + offset, vals, width, label=name, color=colors.get(name))
ax.set_xticks(x)
ax.set_xticklabels(
_xtick_labels(shape_labels, tile_counts, flagged), fontsize=8,
)
ax.set_ylabel(y_label)
ax.set_title(title, fontsize=13, fontweight="bold")
if subtitle:
ax.text(0.5, 1.01, subtitle, transform=ax.transAxes, ha="center",
va="bottom", fontsize=8, color="#475569")
if threshold is not None:
ax.axhline(threshold, ls="--", color="gray", lw=1.0)
ax.legend(fontsize=8, loc="upper right")
ax.grid(True, axis="y", alpha=0.3)
caption = "* = under-tile shape (M<TILE_M, K<TILE_K, or N<TILE_N)"
if footnote:
caption = footnote + "\n" + caption
fig.text(0.5, 0.01, caption, ha="center", fontsize=7, color="gray",
wrap=True)
fig.tight_layout(rect=(0, 0.05, 1, 1))
GEMM_PLOTS_DIR.mkdir(parents=True, exist_ok=True)
out = GEMM_PLOTS_DIR / out_name
fig.savefig(out, dpi=120)
plt.close(fig)
return str(out)
# ── individual chart renderers (read sweep JSON, emit one PNG each) ─────
def emit_stage_breakdown() -> str | None:
"""Per-stage engine wall-clock per shape (load_ref operand staging)."""
data = _load_sweep_data()
rows = [r for r in data["rows"] if r.get("variant") == _PLOT_VARIANT]
if not rows:
return None
tile = data["tile_sizes"]
shape_labels = [_shape_label(r) for r in rows]
flagged = [_under_tile(r["M"], r["K"], r["N"], tile["M"], tile["K"], tile["N"])
for r in rows]
tile_counts = [r["tile_count_expected"] for r in rows]
series = {
STAGE_DISPLAY[s]: [r.get("stages", {}).get(s, {}).get("wall_ns", 0.0)
for r in rows]
for s in STAGE_KEYS
}
colors = {STAGE_DISPLAY[s]: STAGE_COLORS[s] for s in STAGE_KEYS}
return _grouped_bar_png(
"gemm_stage_breakdown.png",
title="GEMM stage breakdown",
subtitle=(f"Per-stage engine wall-clock (DMA in / Fetch / GEMM / "
f"DMA out), {_PLOT_VARIANT} staging. "
f"Tile {tile['M']}x{tile['K']}x{tile['N']}."),
shape_labels=shape_labels, tile_counts=tile_counts, flagged=flagged,
series=series, colors=colors, y_label="ns",
footnote="Bars = engine wall-clock interval (merged overlaps).",
)
def emit_mac_utilization_measured() -> str | None:
"""GEMM util % and useful pipeline-eff % (analytical model, load_ref)."""
data = _load_sweep_data()
rows = data["rows"]
if not rows:
return None
tile = data["tile_sizes"]
TILE_M, TILE_K, TILE_N = tile["M"], tile["K"], tile["N"]
tile_flops = 2 * TILE_M * TILE_K * TILE_N
dma_w_per_pair = (TILE_M * TILE_N * _BPE) / _HBM_GBS
head_ns = (_D_STAGES - 1) * _T_STAGE
by_shape = {(r["M"], r["K"], r["N"]): r
for r in rows if r["variant"] == _PLOT_VARIANT}
shapes = list(by_shape)
if not shapes:
return None
shape_labels = [_shape_label(by_shape[k]) for k in shapes]
flagged = [_under_tile(*k, TILE_M, TILE_K, TILE_N) for k in shapes]
tile_counts = [by_shape[k]["tile_count_expected"] for k in shapes]
gemm_util, useful_eff = [], []
for k in shapes:
r = by_shape[k]
M, K, N = r["M"], r["K"], r["N"]
useful = 2 * M * K * N
tiles = r["tile_count_expected"]
gu = useful / (tile_flops * tiles) * 100
gemm_util.append(gu)
m_tiles = (M + TILE_M - 1) // TILE_M
n_tiles = (N + TILE_N - 1) // TILE_N
n_mn = m_tiles * n_tiles
compute_total = tiles * _T_STAGE
wall = head_ns + tiles * _T_STAGE + max(0, n_mn - 1) * dma_w_per_pair
ueff = (compute_total * (gu / 100.0) / wall) * 100 if wall > 0 else 0.0
useful_eff.append(ueff)
series = {"GEMM util %": gemm_util, "Useful eff %": useful_eff}
colors = {"GEMM util %": "#10B981", "Useful eff %": "#F59E0B"}
return _grouped_bar_png(
"gemm_mac_utilization_measured.png",
title="GEMM MAC utilization — load_ref",
subtitle=("GEMM util = useful FLOPs / (tile FLOPs x tiles); "
"Useful eff = GEMM util x ideal pipeline efficiency."),
shape_labels=shape_labels, tile_counts=tile_counts, flagged=flagged,
series=series, colors=colors, y_label="%", threshold=100.0,
footnote="Theoretical ideal-pipeline model (not simulator data).",
)
def emit_mac_utilization_theoretical_vs_measured() -> str | None:
"""Theoretical vs simulator-measured GEMM util / useful eff (load_ref)."""
data = _load_sweep_data()
rows = data["rows"]
if not rows:
return None
tile = data["tile_sizes"]
TILE_M, TILE_K, TILE_N = tile["M"], tile["K"], tile["N"]
tile_flops = 2 * TILE_M * TILE_K * TILE_N
dma_w_per_pair = (TILE_M * TILE_N * _BPE) / _HBM_GBS
head_ns = (_D_STAGES - 1) * _T_STAGE
peak_per_ns = tile_flops / _T_STAGE
by_shape = {(r["M"], r["K"], r["N"]): r
for r in rows if r["variant"] == _PLOT_VARIANT}
shapes = list(by_shape)
if not shapes:
return None
shape_labels = [_shape_label(by_shape[k]) for k in shapes]
flagged = [_under_tile(*k, TILE_M, TILE_K, TILE_N) for k in shapes]
tile_counts = [by_shape[k]["tile_count_expected"] for k in shapes]
gu_t, gu_m, eff_t, eff_m = [], [], [], []
for k in shapes:
r = by_shape[k]
M, K, N = r["M"], r["K"], r["N"]
useful = 2 * M * K * N
tiles = r["tile_count_expected"]
gut = useful / (tile_flops * tiles)
gu_t.append(gut * 100)
rec = r.get("stages", {}).get("GEMM", {}).get("record_count", 0) or tiles
gu_m.append((useful / (tile_flops * rec) * 100) if rec else 0.0)
m_tiles = (M + TILE_M - 1) // TILE_M
n_tiles = (N + TILE_N - 1) // TILE_N
n_mn = m_tiles * n_tiles
compute_total = tiles * _T_STAGE
wall_t = head_ns + compute_total + max(0, n_mn - 1) * dma_w_per_pair
eff_t.append((compute_total * gut / wall_t * 100) if wall_t > 0 else 0.0)
cw = r.get("composite_window_ns", 0.0) or 0.0
eff_m.append((useful / cw / peak_per_ns * 100) if cw > 0 else 0.0)
series = {
"GEMM util % (theoretical)": gu_t,
"GEMM util % (measured)": gu_m,
"Theoretical eff %": eff_t,
"Measured eff %": eff_m,
}
colors = {
"GEMM util % (theoretical)": "#10B981",
"GEMM util % (measured)": "#6EE7B7",
"Theoretical eff %": "#F59E0B",
"Measured eff %": "#3B82F6",
}
return _grouped_bar_png(
"gemm_mac_utilization_theoretical_vs_measured.png",
title="GEMM MAC utilization — theoretical vs measured (load_ref)",
subtitle=("theoretical model vs simulator op_log; agreement "
"validates the analytical pipeline model."),
shape_labels=shape_labels, tile_counts=tile_counts, flagged=flagged,
series=series, colors=colors, y_label="%", threshold=100.0,
)
def emit_all_gemm_plots() -> list[str]:
"""Render every GEMM figure that has data; return the list of paths written."""
paths = []
for fn in (emit_stage_breakdown,
emit_mac_utilization_measured,
emit_mac_utilization_theoretical_vs_measured):
p = fn()
if p:
paths.append(p)
return paths
__all__ = [
"GEMM_PLOTS_DIR",
"GEMM_SWEEP_JSON",
"ROOT",
"emit_all_gemm_plots",
"emit_mac_utilization_measured",
"emit_mac_utilization_theoretical_vs_measured",
"emit_stage_breakdown",
]