eval: fold GEMM/allreduce harnesses into self-contained milestone benches

Move the GEMM + allreduce sweep/render logic out of scripts/ and tests/
into two self-contained eval benches so a user can regenerate every
result + figure with one command:

  kernbench run --bench milestone-1h-gemm   (MILESTONE_FAST=1 reuses JSON)
  kernbench run --bench milestone-1h-ccl

- benches/milestone_1h_{gemm,ccl}.py: single home for each domain; the
  run(torch) entry drives the sweeps and writes figures into
  benches/1H_milestone_output/{gemm,ccl}/ (gitignored), then submits a
  sentinel tensor to satisfy the run_bench contract.
- tests/gemm + tests/sccl helpers and scripts/gemm_sweep.py become thin
  re-export/wrapper shims over the benches (single source preserved); the
  pytest-only param builders + _run_distributed wrapper stay in the shim.
- eval-bench pattern: a bench may drive many configs + build its own
  per-config engines (extends ADR-0045 D5; reverses ADR-0044 D1/D2).

ADR-0054 (EN+KO) records the design; ADR-0043/0044/0045 + CLAUDE.md CLI
Semantics amended; ADR INDEX regenerated. Verified: milestone benches run
clean (ok=True, all artifacts), full suite 67 passed, lang-pairs OK.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-22 15:19:32 -07:00
parent e33e76f2d1
commit cc1bbd0ab7
19 changed files with 2189 additions and 1465 deletions
File diff suppressed because it is too large Load Diff
+568
View File
@@ -0,0 +1,568 @@
"""milestone-1h-gemm bench: GEMM evaluation harness (sweep + figures).
Self-contained milestone bench (ADR-0054). Holds the shape×variant sweep
and the figure renderers; the ``run(torch)`` entry at the bottom runs the
sweep (or reuses the committed JSON when ``MILESTONE_FAST=1``) and writes
every figure into ``benches/1H_milestone_output/gemm/``.
This is the single home for the GEMM eval logic: the figure tests import a
thin re-export shim (``tests/gemm/_gemm_plot_helpers.py``), as does the
``scripts/gemm_sweep.py`` wrapper.
The sweep drives ``matmul-composite`` across shapes×variants through the
same ``run_bench`` path the CLI uses, harvests ``result.engine.op_log``,
and writes the sweep JSON. The renderers read that JSON and emit matplotlib
PNGs. No simulation in the renderers — they are fast.
Chart set (mirrors the GEMM MAC slides in scripts/build_overview_slides.py):
- stage breakdown (load_ref operand staging)
- MAC utilization — measured (load_ref)
- MAC utilization — theoretical vs measured (load_ref)
"""
from __future__ import annotations
import json
import os
import sys
import time
from pathlib import Path
from kernbench.benches.registry import bench
from kernbench.policy.placement.dp import DPPolicy
ROOT = Path(__file__).resolve().parents[3]
DEFAULT_SWEEP_JSON = ROOT / "docs" / "diagrams" / "gemm_sweep.json"
DEFAULT_PLOTS_DIR = ROOT / "docs" / "diagrams" / "gemm_plots"
_OUTPUT_DIR = Path(__file__).resolve().parent / "1H_milestone_output" / "gemm"
# ── sweep configuration ────────────────────────────────────────────────
# Default sweep covering under-tile, single-tile, multi-tile, and asymmetric
# regimes. Each entry is "MxKxN" or a single int (square M=K=N).
# Override via env: SWEEP_SHAPES="16,32,16x2048x16,..."
DEFAULT_SHAPES = [
"32x32x32", # 1 tile, K=32 < TILE_K=64 → under-tile in K
"32x64x32", # 1 tile, exact single-tile fit
"32x128x32", # 2 tiles, aligned
"32x128x128", # 8 tiles, aligned
"32x3072x32", # 48 tiles, all K-axis (tall-skinny)
"8x128x128", # 8 tiles, but M=8 < TILE_M=32 → MAC array under-fed
"128x8x128", # 16 tiles, but K=8 < TILE_K=64 → MAC array under-fed
"512", # 2048 tiles, fully aligned — "well-pipelined" reference
]
# Operand-staging variants exercised per shape.
VARIANTS = ["ref_ref", "load_ref", "load_load"]
# Engines whose timings we collect (component_id suffix match).
ENGINES = ["pe_dma", "pe_fetch_store", "pe_gemm", "pe_math"]
# Per-stage breakdown labels (StageType enum names from pe_types.py).
STAGES = ["DMA_READ", "DMA_WRITE", "FETCH", "STORE", "GEMM", "MATH"]
# Scheduler tile sizes (mirror of PeSchedulerComponent.TILE_M/K/N).
TILE_M, TILE_K, TILE_N = 32, 64, 32
def _ceil(a: int, b: int) -> int:
return (a + b - 1) // b
def _engine_wall_ns(records, suffix: str) -> float:
"""Wall-clock interval the engine was active (union of overlapping ops)."""
intervals = [(r.t_start, r.t_end) for r in records
if r.component_id.endswith("." + suffix)]
if not intervals:
return 0.0
intervals.sort()
merged_end = intervals[0][1]
merged_start = intervals[0][0]
total = 0.0
for s, e in intervals[1:]:
if s <= merged_end:
merged_end = max(merged_end, e)
else:
total += merged_end - merged_start
merged_start, merged_end = s, e
total += merged_end - merged_start
return total
def _engine_occupancy_ns(records, suffix: str) -> float:
return sum(r.t_end - r.t_start for r in records
if r.component_id.endswith("." + suffix))
def _engine_count(records, suffix: str) -> int:
return sum(1 for r in records if r.component_id.endswith("." + suffix))
def _stage_occupancy_ns(records, stage_type: str) -> float:
return sum(
r.t_end - r.t_start
for r in records
if r.params.get("stage_type") == stage_type
)
def _stage_wall_ns(records, stage_type: str) -> float:
"""Interval-union wall-clock for records whose stage_type matches."""
intervals = sorted(
(r.t_start, r.t_end) for r in records
if r.params.get("stage_type") == stage_type
)
if not intervals:
return 0.0
total = 0.0
cs, ce = intervals[0]
for s, e in intervals[1:]:
if s <= ce:
ce = max(ce, e)
else:
total += ce - cs
cs, ce = s, e
total += ce - cs
return total
def _stage_count(records, stage_type: str) -> int:
return sum(1 for r in records if r.params.get("stage_type") == stage_type)
def _run_one(M: int, K: int, N: int, topology: str, variant: str = "ref_ref") -> dict:
os.environ["MATMUL_M"] = str(M)
os.environ["MATMUL_K"] = str(K)
os.environ["MATMUL_N"] = str(N)
os.environ["MATMUL_VARIANT"] = variant
# Late imports so env vars are read by matmul_composite at module load.
# Force re-import to pick up new env values.
for mod_name in [m for m in list(sys.modules)
if m.startswith("kernbench.benches.matmul_composite")]:
del sys.modules[mod_name]
from kernbench.benches.registry import resolve as resolve_bench
from kernbench.runtime_api.bench_runner import run_bench
from kernbench.runtime_api.types import resolve_device
from kernbench.sim_engine.engine import GraphEngine
from kernbench.topology.builder import resolve_topology
topo = resolve_topology(topology)
bench = resolve_bench("matmul-composite").run
device = resolve_device(None)
t0 = time.time()
result = run_bench(
topology=topo, bench_fn=bench, device=device,
engine_factory=lambda t, d: GraphEngine(
getattr(t, "topology_obj", t), enable_data=True,
),
)
wall = time.time() - t0
op_log = result.engine.op_log
if not result.completion.ok:
raise RuntimeError(f"bench failed at M={M},K={K},N={N}: {result.completion}")
# Bytes touched at f16 (2 B): full A + full B + full out (each operand
# streamed once through HBM by the composite plan).
bytes_total = (M * K + K * N + M * N) * 2
row = {
"M": M, "K": K, "N": N,
"variant": variant,
"flops": 2 * M * K * N,
"bytes_hbm": bytes_total,
"arith_intensity": (2 * M * K * N) / bytes_total, # flops/byte
"tile_count_expected": _ceil(M, TILE_M) * _ceil(N, TILE_N) * _ceil(K, TILE_K),
"sim_wall_clock_s": round(wall, 3),
"engines": {},
}
for eng in ENGINES:
row["engines"][eng] = {
"occupancy_ns": _engine_occupancy_ns(op_log, eng),
"wall_ns": _engine_wall_ns(op_log, eng),
"record_count": _engine_count(op_log, eng),
}
row["stages"] = {}
for stage in STAGES:
row["stages"][stage] = {
"occupancy_ns": _stage_occupancy_ns(op_log, stage),
"wall_ns": _stage_wall_ns(op_log, stage),
"record_count": _stage_count(op_log, stage),
}
# Kernel-window wall-clock = max t_end - min t_start over PE engine records.
pe_records = [r for r in op_log
if any(r.component_id.endswith("." + e) for e in ENGINES)]
if pe_records:
row["pe_window_ns"] = max(r.t_end for r in pe_records) \
- min(r.t_start for r in pe_records)
else:
row["pe_window_ns"] = 0.0
stage_records = [r for r in op_log
if r.params.get("stage_type") in STAGES]
if stage_records:
row["composite_window_ns"] = max(r.t_end for r in stage_records) \
- min(r.t_start for r in stage_records)
else:
row["composite_window_ns"] = 0.0
return row
def _parse_shapes(raw) -> list[tuple[int, int, int]]:
shapes: list[tuple[int, int, int]] = []
for s in raw:
s = s.strip()
if not s:
continue
if "x" in s.lower():
parts = s.lower().split("x")
shapes.append((int(parts[0]), int(parts[1]), int(parts[2])))
else:
v = int(s)
shapes.append((v, v, v))
return shapes
def run_sweep(out_json: Path | str = DEFAULT_SWEEP_JSON) -> Path:
"""Drive matmul-composite across shapes×variants; write the sweep JSON.
Honors ``SWEEP_SHAPES`` / ``SWEEP_TOPOLOGY`` env overrides (same as the
historical ``scripts/gemm_sweep.py``). Returns the JSON path written.
"""
shapes_env = os.environ.get("SWEEP_SHAPES")
raw = (shapes_env.split(",") if shapes_env else DEFAULT_SHAPES)
shapes = _parse_shapes(raw)
topology = os.environ.get("SWEEP_TOPOLOGY", "topology.yaml")
rows = []
for M, K, N in shapes:
for variant in VARIANTS:
print(f"[sweep] M={M} K={K} N={N} variant={variant} ...", flush=True)
row = _run_one(M, K, N, topology, variant=variant)
rows.append(row)
eng_dma = row["engines"]["pe_dma"]
eng_gem = row["engines"]["pe_gemm"]
print(f" tiles={row['tile_count_expected']:>6} "
f"pe_window={row['pe_window_ns']:8.1f}ns "
f"dma_occ={eng_dma['occupancy_ns']:9.1f} "
f"gemm_occ={eng_gem['occupancy_ns']:8.1f} "
f"(sim {row['sim_wall_clock_s']:.1f}s)")
out_json = Path(out_json)
out_json.parent.mkdir(parents=True, exist_ok=True)
out_json.write_text(json.dumps({
"tile_sizes": {"M": TILE_M, "K": TILE_K, "N": TILE_N},
"engines": ENGINES,
"stages": STAGES,
"variants": VARIANTS,
"rows": rows,
}, indent=2))
print(f"\n[sweep] wrote {out_json}")
return out_json
# ── figure rendering ───────────────────────────────────────────────────
# Shapes excluded from the figures (mirrors build_overview_slides).
EXCLUDED_SHAPES = {(512, 512, 512)}
# Stage bars shown (raw op_log stage_type keys) + display names + colors.
STAGE_KEYS = ["DMA_READ", "FETCH", "GEMM", "DMA_WRITE"]
STAGE_DISPLAY = {
"DMA_READ": "DMA in",
"FETCH": "Fetch",
"GEMM": "GEMM",
"DMA_WRITE": "DMA out",
}
STAGE_COLORS = {
"DMA_READ": "#3B82F6",
"FETCH": "#10B981",
"GEMM": "#F59E0B",
"DMA_WRITE": "#A855F7",
}
# MAC-utilization model constants (mirror build_overview_slides).
_HBM_GBS = 256.0
_BPE = 2
_T_STAGE = 16.0
_D_STAGES = 3
_PLOT_VARIANT = "load_ref"
def _load_sweep_data(sweep_json: Path | str = DEFAULT_SWEEP_JSON) -> dict:
sweep_json = Path(sweep_json)
if not sweep_json.exists():
return {"rows": []}
data = json.loads(sweep_json.read_text())
data["rows"] = [
r for r in data.get("rows", [])
if (r["M"], r["K"], r["N"]) not in EXCLUDED_SHAPES
]
return data
def _shape_label(r: dict) -> str:
if r["M"] == r["K"] == r["N"]:
return f"M=K=N={r['M']}"
return f"M={r['M']} K={r['K']} N={r['N']}"
def _under_tile(M, K, N, tile_M, tile_K, tile_N) -> bool:
return M < tile_M or K < tile_K or N < tile_N
def _xtick_labels(shape_labels, tile_counts, flagged) -> list[str]:
out = []
for lbl, tc, fl in zip(shape_labels, tile_counts, flagged):
s = f"{lbl}\n({tc} tiles)"
if fl:
s += " *"
out.append(s)
return out
def _grouped_bar_png(
out_name: str, *, out_dir: Path, title: str, subtitle: str | None,
shape_labels, tile_counts, flagged, series: dict, colors: dict,
y_label: str, threshold: float | None = None, footnote: str | None = None,
) -> str:
"""Render one grouped-bar chart to out_dir/out_name; return the path."""
import matplotlib.pyplot as plt
import numpy as np
n_groups = len(shape_labels)
n_series = max(1, len(series))
x = np.arange(n_groups)
width = 0.8 / n_series
fig, ax = plt.subplots(figsize=(11, 6))
for i, (name, vals) in enumerate(series.items()):
offset = (i - (n_series - 1) / 2) * width
ax.bar(x + offset, vals, width, label=name, color=colors.get(name))
ax.set_xticks(x)
ax.set_xticklabels(
_xtick_labels(shape_labels, tile_counts, flagged), fontsize=8,
)
ax.set_ylabel(y_label)
ax.set_title(title, fontsize=13, fontweight="bold")
if subtitle:
ax.text(0.5, 1.01, subtitle, transform=ax.transAxes, ha="center",
va="bottom", fontsize=8, color="#475569")
if threshold is not None:
ax.axhline(threshold, ls="--", color="gray", lw=1.0)
ax.legend(fontsize=8, loc="upper right")
ax.grid(True, axis="y", alpha=0.3)
caption = "* = under-tile shape (M<TILE_M, K<TILE_K, or N<TILE_N)"
if footnote:
caption = footnote + "\n" + caption
fig.text(0.5, 0.01, caption, ha="center", fontsize=7, color="gray",
wrap=True)
fig.tight_layout(rect=(0, 0.05, 1, 1))
out_dir = Path(out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
out = out_dir / out_name
fig.savefig(out, dpi=120)
plt.close(fig)
return str(out)
def emit_stage_breakdown(
sweep_json: Path | str = DEFAULT_SWEEP_JSON,
out_dir: Path | str = DEFAULT_PLOTS_DIR,
) -> str | None:
"""Per-stage engine wall-clock per shape (load_ref operand staging)."""
data = _load_sweep_data(sweep_json)
rows = [r for r in data["rows"] if r.get("variant") == _PLOT_VARIANT]
if not rows:
return None
tile = data["tile_sizes"]
shape_labels = [_shape_label(r) for r in rows]
flagged = [_under_tile(r["M"], r["K"], r["N"], tile["M"], tile["K"], tile["N"])
for r in rows]
tile_counts = [r["tile_count_expected"] for r in rows]
series = {
STAGE_DISPLAY[s]: [r.get("stages", {}).get(s, {}).get("wall_ns", 0.0)
for r in rows]
for s in STAGE_KEYS
}
colors = {STAGE_DISPLAY[s]: STAGE_COLORS[s] for s in STAGE_KEYS}
return _grouped_bar_png(
"gemm_stage_breakdown.png", out_dir=Path(out_dir),
title="GEMM stage breakdown",
subtitle=(f"Per-stage engine wall-clock (DMA in / Fetch / GEMM / "
f"DMA out), {_PLOT_VARIANT} staging. "
f"Tile {tile['M']}x{tile['K']}x{tile['N']}."),
shape_labels=shape_labels, tile_counts=tile_counts, flagged=flagged,
series=series, colors=colors, y_label="ns",
footnote="Bars = engine wall-clock interval (merged overlaps).",
)
def emit_mac_utilization_measured(
sweep_json: Path | str = DEFAULT_SWEEP_JSON,
out_dir: Path | str = DEFAULT_PLOTS_DIR,
) -> str | None:
"""GEMM util % and useful pipeline-eff % (analytical model, load_ref)."""
data = _load_sweep_data(sweep_json)
rows = data["rows"]
if not rows:
return None
tile = data["tile_sizes"]
TILE_M, TILE_K, TILE_N = tile["M"], tile["K"], tile["N"]
tile_flops = 2 * TILE_M * TILE_K * TILE_N
dma_w_per_pair = (TILE_M * TILE_N * _BPE) / _HBM_GBS
head_ns = (_D_STAGES - 1) * _T_STAGE
by_shape = {(r["M"], r["K"], r["N"]): r
for r in rows if r["variant"] == _PLOT_VARIANT}
shapes = list(by_shape)
if not shapes:
return None
shape_labels = [_shape_label(by_shape[k]) for k in shapes]
flagged = [_under_tile(*k, TILE_M, TILE_K, TILE_N) for k in shapes]
tile_counts = [by_shape[k]["tile_count_expected"] for k in shapes]
gemm_util, useful_eff = [], []
for k in shapes:
r = by_shape[k]
M, K, N = r["M"], r["K"], r["N"]
useful = 2 * M * K * N
tiles = r["tile_count_expected"]
gu = useful / (tile_flops * tiles) * 100
gemm_util.append(gu)
m_tiles = (M + TILE_M - 1) // TILE_M
n_tiles = (N + TILE_N - 1) // TILE_N
n_mn = m_tiles * n_tiles
compute_total = tiles * _T_STAGE
wall = head_ns + tiles * _T_STAGE + max(0, n_mn - 1) * dma_w_per_pair
ueff = (compute_total * (gu / 100.0) / wall) * 100 if wall > 0 else 0.0
useful_eff.append(ueff)
series = {"GEMM util %": gemm_util, "Useful eff %": useful_eff}
colors = {"GEMM util %": "#10B981", "Useful eff %": "#F59E0B"}
return _grouped_bar_png(
"gemm_mac_utilization_measured.png", out_dir=Path(out_dir),
title="GEMM MAC utilization — load_ref",
subtitle=("GEMM util = useful FLOPs / (tile FLOPs x tiles); "
"Useful eff = GEMM util x ideal pipeline efficiency."),
shape_labels=shape_labels, tile_counts=tile_counts, flagged=flagged,
series=series, colors=colors, y_label="%", threshold=100.0,
footnote="Theoretical ideal-pipeline model (not simulator data).",
)
def emit_mac_utilization_theoretical_vs_measured(
sweep_json: Path | str = DEFAULT_SWEEP_JSON,
out_dir: Path | str = DEFAULT_PLOTS_DIR,
) -> str | None:
"""Theoretical vs simulator-measured GEMM util / useful eff (load_ref)."""
data = _load_sweep_data(sweep_json)
rows = data["rows"]
if not rows:
return None
tile = data["tile_sizes"]
TILE_M, TILE_K, TILE_N = tile["M"], tile["K"], tile["N"]
tile_flops = 2 * TILE_M * TILE_K * TILE_N
dma_w_per_pair = (TILE_M * TILE_N * _BPE) / _HBM_GBS
head_ns = (_D_STAGES - 1) * _T_STAGE
peak_per_ns = tile_flops / _T_STAGE
by_shape = {(r["M"], r["K"], r["N"]): r
for r in rows if r["variant"] == _PLOT_VARIANT}
shapes = list(by_shape)
if not shapes:
return None
shape_labels = [_shape_label(by_shape[k]) for k in shapes]
flagged = [_under_tile(*k, TILE_M, TILE_K, TILE_N) for k in shapes]
tile_counts = [by_shape[k]["tile_count_expected"] for k in shapes]
gu_t, gu_m, eff_t, eff_m = [], [], [], []
for k in shapes:
r = by_shape[k]
M, K, N = r["M"], r["K"], r["N"]
useful = 2 * M * K * N
tiles = r["tile_count_expected"]
gut = useful / (tile_flops * tiles)
gu_t.append(gut * 100)
rec = r.get("stages", {}).get("GEMM", {}).get("record_count", 0) or tiles
gu_m.append((useful / (tile_flops * rec) * 100) if rec else 0.0)
m_tiles = (M + TILE_M - 1) // TILE_M
n_tiles = (N + TILE_N - 1) // TILE_N
n_mn = m_tiles * n_tiles
compute_total = tiles * _T_STAGE
wall_t = head_ns + compute_total + max(0, n_mn - 1) * dma_w_per_pair
eff_t.append((compute_total * gut / wall_t * 100) if wall_t > 0 else 0.0)
cw = r.get("composite_window_ns", 0.0) or 0.0
eff_m.append((useful / cw / peak_per_ns * 100) if cw > 0 else 0.0)
series = {
"GEMM util % (theoretical)": gu_t,
"GEMM util % (measured)": gu_m,
"Theoretical eff %": eff_t,
"Measured eff %": eff_m,
}
colors = {
"GEMM util % (theoretical)": "#10B981",
"GEMM util % (measured)": "#6EE7B7",
"Theoretical eff %": "#F59E0B",
"Measured eff %": "#3B82F6",
}
return _grouped_bar_png(
"gemm_mac_utilization_theoretical_vs_measured.png", out_dir=Path(out_dir),
title="GEMM MAC utilization — theoretical vs measured (load_ref)",
subtitle=("theoretical model vs simulator op_log; agreement "
"validates the analytical pipeline model."),
shape_labels=shape_labels, tile_counts=tile_counts, flagged=flagged,
series=series, colors=colors, y_label="%", threshold=100.0,
)
def emit_all_gemm_plots(
sweep_json: Path | str = DEFAULT_SWEEP_JSON,
out_dir: Path | str = DEFAULT_PLOTS_DIR,
) -> list[str]:
"""Render every GEMM figure that has data; return the paths written."""
paths = []
for fn in (emit_stage_breakdown,
emit_mac_utilization_measured,
emit_mac_utilization_theoretical_vs_measured):
p = fn(sweep_json, out_dir)
if p:
paths.append(p)
return paths
# ── bench entry ────────────────────────────────────────────────────────
@bench(
name="milestone-1h-gemm",
description="1H milestone: regenerate all GEMM results + figures.",
)
def run(torch) -> None:
"""Run the GEMM sweep (or reuse committed JSON) and render every figure.
``MILESTONE_FAST=1`` reuses the committed ``DEFAULT_SWEEP_JSON`` (seconds);
otherwise the full sweep runs into ``out_dir/gemm_sweep.json`` (minutes).
The sweep drives its own engines, so a sentinel tensor is submitted at the
end to satisfy the run_bench contract (ADR-0045 D4).
"""
_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
fast = bool(os.environ.get("MILESTONE_FAST"))
if fast:
sweep_json = DEFAULT_SWEEP_JSON
else:
sweep_json = run_sweep(out_json=_OUTPUT_DIR / "gemm_sweep.json")
paths = emit_all_gemm_plots(sweep_json=sweep_json, out_dir=_OUTPUT_DIR)
print(f" milestone-1h-gemm: {len(paths)} figures -> {_OUTPUT_DIR} "
f"(fast={fast})")
torch.zeros(
(1, 1), dtype="f16",
dp=DPPolicy(cube="row_wise", pe="replicate", num_cubes=1, num_pes=1),
name="milestone_gemm_sentinel",
)