eval: fold GEMM/allreduce harnesses into self-contained milestone benches
Move the GEMM + allreduce sweep/render logic out of scripts/ and tests/
into two self-contained eval benches so a user can regenerate every
result + figure with one command:
kernbench run --bench milestone-1h-gemm (MILESTONE_FAST=1 reuses JSON)
kernbench run --bench milestone-1h-ccl
- benches/milestone_1h_{gemm,ccl}.py: single home for each domain; the
run(torch) entry drives the sweeps and writes figures into
benches/1H_milestone_output/{gemm,ccl}/ (gitignored), then submits a
sentinel tensor to satisfy the run_bench contract.
- tests/gemm + tests/sccl helpers and scripts/gemm_sweep.py become thin
re-export/wrapper shims over the benches (single source preserved); the
pytest-only param builders + _run_distributed wrapper stay in the shim.
- eval-bench pattern: a bench may drive many configs + build its own
per-config engines (extends ADR-0045 D5; reverses ADR-0044 D1/D2).
ADR-0054 (EN+KO) records the design; ADR-0043/0044/0045 + CLAUDE.md CLI
Semantics amended; ADR INDEX regenerated. Verified: milestone benches run
clean (ok=True, all artifacts), full suite 67 passed, lang-pairs OK.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,568 @@
|
||||
"""milestone-1h-gemm bench: GEMM evaluation harness (sweep + figures).
|
||||
|
||||
Self-contained milestone bench (ADR-0054). Holds the shape×variant sweep
|
||||
and the figure renderers; the ``run(torch)`` entry at the bottom runs the
|
||||
sweep (or reuses the committed JSON when ``MILESTONE_FAST=1``) and writes
|
||||
every figure into ``benches/1H_milestone_output/gemm/``.
|
||||
|
||||
This is the single home for the GEMM eval logic: the figure tests import a
|
||||
thin re-export shim (``tests/gemm/_gemm_plot_helpers.py``), as does the
|
||||
``scripts/gemm_sweep.py`` wrapper.
|
||||
|
||||
The sweep drives ``matmul-composite`` across shapes×variants through the
|
||||
same ``run_bench`` path the CLI uses, harvests ``result.engine.op_log``,
|
||||
and writes the sweep JSON. The renderers read that JSON and emit matplotlib
|
||||
PNGs. No simulation in the renderers — they are fast.
|
||||
|
||||
Chart set (mirrors the GEMM MAC slides in scripts/build_overview_slides.py):
|
||||
- stage breakdown (load_ref operand staging)
|
||||
- MAC utilization — measured (load_ref)
|
||||
- MAC utilization — theoretical vs measured (load_ref)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from kernbench.benches.registry import bench
|
||||
from kernbench.policy.placement.dp import DPPolicy
|
||||
|
||||
# Repository root: three directory levels above the directory holding this file.
ROOT = Path(__file__).resolve().parents[3]

# Default sweep JSON and figure directory, both under docs/diagrams.
DEFAULT_SWEEP_JSON = ROOT.joinpath("docs", "diagrams", "gemm_sweep.json")
DEFAULT_PLOTS_DIR = ROOT.joinpath("docs", "diagrams", "gemm_plots")

# Output directory used by the bench entry point; sits next to this module.
_OUTPUT_DIR = Path(__file__).resolve().with_name("1H_milestone_output") / "gemm"
|
||||
|
||||
# ── sweep configuration ────────────────────────────────────────────────
|
||||
|
||||
# Shapes swept by default, picked to cover the under-tile, single-tile,
# multi-tile, and asymmetric regimes. An entry is either "MxKxN" or a single
# integer string, which means a square problem (M=K=N).
# Override via env: SWEEP_SHAPES="16,32,16x2048x16,..."
DEFAULT_SHAPES = [
    "32x32x32",     # 1 tile; K=32 is below TILE_K=64, so under-tile in K
    "32x64x32",     # 1 tile; exact single-tile fit
    "32x128x32",    # 2 tiles, aligned
    "32x128x128",   # 8 tiles, aligned
    "32x3072x32",   # 48 tiles, all along K (tall-skinny)
    "8x128x128",    # 8 tiles, but M=8 < TILE_M=32: MAC array under-fed
    "128x8x128",    # 16 tiles, but K=8 < TILE_K=64: MAC array under-fed
    "512",          # 2048 tiles, fully aligned; the "well-pipelined" reference
]

# Operand-staging variants run for every shape.
VARIANTS = ["ref_ref", "load_ref", "load_load"]

# Engines whose timings are collected (matched as a component_id suffix).
ENGINES = ["pe_dma", "pe_fetch_store", "pe_gemm", "pe_math"]

# Labels for the per-stage breakdown (StageType enum names from pe_types.py).
STAGES = ["DMA_READ", "DMA_WRITE", "FETCH", "STORE", "GEMM", "MATH"]

# Scheduler tile sizes (mirror of PeSchedulerComponent.TILE_M/K/N).
TILE_M = 32
TILE_K = 64
TILE_N = 32
|
||||
|
||||
|
||||
def _ceil(a: int, b: int) -> int:
|
||||
return (a + b - 1) // b
|
||||
|
||||
|
||||
def _engine_wall_ns(records, suffix: str) -> float:
|
||||
"""Wall-clock interval the engine was active (union of overlapping ops)."""
|
||||
intervals = [(r.t_start, r.t_end) for r in records
|
||||
if r.component_id.endswith("." + suffix)]
|
||||
if not intervals:
|
||||
return 0.0
|
||||
intervals.sort()
|
||||
merged_end = intervals[0][1]
|
||||
merged_start = intervals[0][0]
|
||||
total = 0.0
|
||||
for s, e in intervals[1:]:
|
||||
if s <= merged_end:
|
||||
merged_end = max(merged_end, e)
|
||||
else:
|
||||
total += merged_end - merged_start
|
||||
merged_start, merged_end = s, e
|
||||
total += merged_end - merged_start
|
||||
return total
|
||||
|
||||
|
||||
def _engine_occupancy_ns(records, suffix: str) -> float:
|
||||
return sum(r.t_end - r.t_start for r in records
|
||||
if r.component_id.endswith("." + suffix))
|
||||
|
||||
|
||||
def _engine_count(records, suffix: str) -> int:
|
||||
return sum(1 for r in records if r.component_id.endswith("." + suffix))
|
||||
|
||||
|
||||
def _stage_occupancy_ns(records, stage_type: str) -> float:
|
||||
return sum(
|
||||
r.t_end - r.t_start
|
||||
for r in records
|
||||
if r.params.get("stage_type") == stage_type
|
||||
)
|
||||
|
||||
|
||||
def _stage_wall_ns(records, stage_type: str) -> float:
|
||||
"""Interval-union wall-clock for records whose stage_type matches."""
|
||||
intervals = sorted(
|
||||
(r.t_start, r.t_end) for r in records
|
||||
if r.params.get("stage_type") == stage_type
|
||||
)
|
||||
if not intervals:
|
||||
return 0.0
|
||||
total = 0.0
|
||||
cs, ce = intervals[0]
|
||||
for s, e in intervals[1:]:
|
||||
if s <= ce:
|
||||
ce = max(ce, e)
|
||||
else:
|
||||
total += ce - cs
|
||||
cs, ce = s, e
|
||||
total += ce - cs
|
||||
return total
|
||||
|
||||
|
||||
def _stage_count(records, stage_type: str) -> int:
|
||||
return sum(1 for r in records if r.params.get("stage_type") == stage_type)
|
||||
|
||||
|
||||
def _run_one(M: int, K: int, N: int, topology: str, variant: str = "ref_ref") -> dict:
    """Run ``matmul-composite`` for one shape/variant and summarise its op_log.

    The problem size and variant are handed to the bench through the
    ``MATMUL_M`` / ``MATMUL_K`` / ``MATMUL_N`` / ``MATMUL_VARIANT`` environment
    variables, which stay set after this function returns.

    Returns one sweep row: shape, variant, FLOP and byte counts, expected tile
    count, host wall-clock of the simulation, per-engine and per-stage
    occupancy / wall / record counts, and the two overall time windows.

    Raises:
        RuntimeError: if the bench run reports a non-ok completion.
    """
    os.environ.update({
        "MATMUL_M": str(M),
        "MATMUL_K": str(K),
        "MATMUL_N": str(N),
        "MATMUL_VARIANT": variant,
    })

    # matmul_composite reads the env vars at module load, so evict any cached
    # copy; the late imports below then pick up the new values.
    stale_modules = [
        name for name in list(sys.modules)
        if name.startswith("kernbench.benches.matmul_composite")
    ]
    for name in stale_modules:
        del sys.modules[name]

    from kernbench.benches.registry import resolve as resolve_bench
    from kernbench.runtime_api.bench_runner import run_bench
    from kernbench.runtime_api.types import resolve_device
    from kernbench.sim_engine.engine import GraphEngine
    from kernbench.topology.builder import resolve_topology

    topo = resolve_topology(topology)
    bench_fn = resolve_bench("matmul-composite").run
    device = resolve_device(None)

    def make_engine(t, d):
        # Unwrap ``topology_obj`` when the resolved topology carries one.
        return GraphEngine(getattr(t, "topology_obj", t), enable_data=True)

    started = time.time()
    result = run_bench(
        topology=topo,
        bench_fn=bench_fn,
        device=device,
        engine_factory=make_engine,
    )
    elapsed = time.time() - started

    op_log = result.engine.op_log
    if not result.completion.ok:
        raise RuntimeError(f"bench failed at M={M},K={K},N={N}: {result.completion}")

    def window_ns(recs):
        # Span from the earliest start to the latest end; 0.0 when empty.
        if not recs:
            return 0.0
        return max(r.t_end for r in recs) - min(r.t_start for r in recs)

    flops = 2 * M * K * N
    # Bytes touched at f16 (2 B): full A + full B + full out (each operand
    # streamed once through HBM by the composite plan).
    bytes_total = (M * K + K * N + M * N) * 2

    row = {
        "M": M,
        "K": K,
        "N": N,
        "variant": variant,
        "flops": flops,
        "bytes_hbm": bytes_total,
        "arith_intensity": flops / bytes_total,  # flops/byte
        "tile_count_expected": _ceil(M, TILE_M) * _ceil(N, TILE_N) * _ceil(K, TILE_K),
        "sim_wall_clock_s": round(elapsed, 3),
        "engines": {
            eng: {
                "occupancy_ns": _engine_occupancy_ns(op_log, eng),
                "wall_ns": _engine_wall_ns(op_log, eng),
                "record_count": _engine_count(op_log, eng),
            }
            for eng in ENGINES
        },
    }
    row["stages"] = {
        stage: {
            "occupancy_ns": _stage_occupancy_ns(op_log, stage),
            "wall_ns": _stage_wall_ns(op_log, stage),
            "record_count": _stage_count(op_log, stage),
        }
        for stage in STAGES
    }

    # Kernel-window wall-clock = max t_end - min t_start over PE engine records.
    engine_tails = tuple("." + eng for eng in ENGINES)
    row["pe_window_ns"] = window_ns(
        [r for r in op_log if r.component_id.endswith(engine_tails)]
    )
    # Same window, taken over records that carry one of the tracked stage types.
    row["composite_window_ns"] = window_ns(
        [r for r in op_log if r.params.get("stage_type") in STAGES]
    )
    return row
|
||||
|
||||
|
||||
def _parse_shapes(raw) -> list[tuple[int, int, int]]:
|
||||
shapes: list[tuple[int, int, int]] = []
|
||||
for s in raw:
|
||||
s = s.strip()
|
||||
if not s:
|
||||
continue
|
||||
if "x" in s.lower():
|
||||
parts = s.lower().split("x")
|
||||
shapes.append((int(parts[0]), int(parts[1]), int(parts[2])))
|
||||
else:
|
||||
v = int(s)
|
||||
shapes.append((v, v, v))
|
||||
return shapes
|
||||
|
||||
|
||||
def run_sweep(out_json: Path | str = DEFAULT_SWEEP_JSON) -> Path:
    """Run matmul-composite for every shape and variant, then write the JSON.

    Shapes come from the ``SWEEP_SHAPES`` env var (comma-separated) when it is
    set and non-empty, otherwise from ``DEFAULT_SHAPES``. The topology comes
    from ``SWEEP_TOPOLOGY`` and defaults to ``topology.yaml``. Progress is
    printed per run. Parent directories of ``out_json`` are created as needed.

    Returns:
        The path of the JSON file that was written.
    """
    override = os.environ.get("SWEEP_SHAPES")
    specs = override.split(",") if override else DEFAULT_SHAPES
    shapes = _parse_shapes(specs)
    topology = os.environ.get("SWEEP_TOPOLOGY", "topology.yaml")

    rows = []
    for M, K, N in shapes:
        for variant in VARIANTS:
            print(f"[sweep] M={M} K={K} N={N} variant={variant} ...", flush=True)
            row = _run_one(M, K, N, topology, variant=variant)
            rows.append(row)
            dma_stats = row["engines"]["pe_dma"]
            gemm_stats = row["engines"]["pe_gemm"]
            summary = (
                f" tiles={row['tile_count_expected']:>6} "
                f"pe_window={row['pe_window_ns']:8.1f}ns "
                f"dma_occ={dma_stats['occupancy_ns']:9.1f} "
                f"gemm_occ={gemm_stats['occupancy_ns']:8.1f} "
                f"(sim {row['sim_wall_clock_s']:.1f}s)"
            )
            print(summary)

    payload = {
        "tile_sizes": {"M": TILE_M, "K": TILE_K, "N": TILE_N},
        "engines": ENGINES,
        "stages": STAGES,
        "variants": VARIANTS,
        "rows": rows,
    }
    out_json = Path(out_json)
    out_json.parent.mkdir(parents=True, exist_ok=True)
    out_json.write_text(json.dumps(payload, indent=2))
    print(f"\n[sweep] wrote {out_json}")
    return out_json
|
||||
|
||||
|
||||
# ── figure rendering ───────────────────────────────────────────────────
|
||||
|
||||
# (M, K, N) shapes dropped before plotting (mirrors build_overview_slides).
EXCLUDED_SHAPES = {(512, 512, 512)}

# Stages drawn as bars, keyed by the raw op_log stage_type, in plot order.
STAGE_KEYS = ["DMA_READ", "FETCH", "GEMM", "DMA_WRITE"]

# Legend name for each plotted stage.
STAGE_DISPLAY = dict(
    DMA_READ="DMA in",
    FETCH="Fetch",
    GEMM="GEMM",
    DMA_WRITE="DMA out",
)

# Bar colour for each plotted stage.
STAGE_COLORS = dict(
    DMA_READ="#3B82F6",
    FETCH="#10B981",
    GEMM="#F59E0B",
    DMA_WRITE="#A855F7",
)

# MAC-utilization model constants (mirror build_overview_slides).
_HBM_GBS = 256.0
_BPE = 2
_T_STAGE = 16.0
_D_STAGES = 3

# Operand-staging variant that every figure is drawn for.
_PLOT_VARIANT = "load_ref"
|
||||
|
||||
|
||||
def _load_sweep_data(sweep_json: Path | str = DEFAULT_SWEEP_JSON) -> dict:
    """Load the sweep JSON and drop rows whose shape is in ``EXCLUDED_SHAPES``.

    Returns ``{"rows": []}`` when the file does not exist, so callers can
    treat a missing sweep as "no data" without handling an exception.
    """
    path = Path(sweep_json)
    if not path.exists():
        return {"rows": []}

    data = json.loads(path.read_text())
    kept = []
    for row in data.get("rows", []):
        if (row["M"], row["K"], row["N"]) in EXCLUDED_SHAPES:
            continue
        kept.append(row)
    data["rows"] = kept
    return data
|
||||
|
||||
|
||||
def _shape_label(r: dict) -> str:
|
||||
if r["M"] == r["K"] == r["N"]:
|
||||
return f"M=K=N={r['M']}"
|
||||
return f"M={r['M']} K={r['K']} N={r['N']}"
|
||||
|
||||
|
||||
def _under_tile(M, K, N, tile_M, tile_K, tile_N) -> bool:
|
||||
return M < tile_M or K < tile_K or N < tile_N
|
||||
|
||||
|
||||
def _xtick_labels(shape_labels, tile_counts, flagged) -> list[str]:
|
||||
out = []
|
||||
for lbl, tc, fl in zip(shape_labels, tile_counts, flagged):
|
||||
s = f"{lbl}\n({tc} tiles)"
|
||||
if fl:
|
||||
s += " *"
|
||||
out.append(s)
|
||||
return out
|
||||
|
||||
|
||||
def _grouped_bar_png(
    out_name: str, *, out_dir: Path, title: str, subtitle: str | None,
    shape_labels, tile_counts, flagged, series: dict, colors: dict,
    y_label: str, threshold: float | None = None, footnote: str | None = None,
) -> str:
    """Render one grouped-bar chart to ``out_dir/out_name`` and return the path.

    Args:
        out_name: file name of the PNG.
        out_dir: target directory; created when missing.
        title: bold chart title.
        subtitle: optional smaller line drawn just above the axes.
        shape_labels, tile_counts, flagged: per-group values used to build
            the x tick labels (see ``_xtick_labels``).
        series: mapping of legend name to one value per group; each entry
            becomes one bar within every group.
        colors: mapping of legend name to bar colour; a missing name falls
            back to matplotlib's default colour.
        y_label: y-axis label.
        threshold: when given, a dashed horizontal line is drawn at this value.
        footnote: optional text placed above the fixed under-tile caption.

    Returns:
        The written file path, as a string.
    """
    import matplotlib.pyplot as plt
    import numpy as np

    n_series = max(1, len(series))  # guards the width division below
    x = np.arange(len(shape_labels))
    width = 0.8 / n_series

    fig, ax = plt.subplots(figsize=(11, 6))
    try:
        for i, (name, vals) in enumerate(series.items()):
            # Centre the cluster of bars on each group's tick position.
            offset = (i - (n_series - 1) / 2) * width
            ax.bar(x + offset, vals, width, label=name, color=colors.get(name))

        ax.set_xticks(x)
        ax.set_xticklabels(
            _xtick_labels(shape_labels, tile_counts, flagged), fontsize=8,
        )
        ax.set_ylabel(y_label)
        ax.set_title(title, fontsize=13, fontweight="bold")
        if subtitle:
            ax.text(0.5, 1.01, subtitle, transform=ax.transAxes, ha="center",
                    va="bottom", fontsize=8, color="#475569")
        if threshold is not None:
            ax.axhline(threshold, ls="--", color="gray", lw=1.0)
        if series:
            # legend() with no labelled artists only produces a warning.
            ax.legend(fontsize=8, loc="upper right")
        ax.grid(True, axis="y", alpha=0.3)

        caption = "* = under-tile shape (M<TILE_M, K<TILE_K, or N<TILE_N)"
        if footnote:
            caption = footnote + "\n" + caption
        fig.text(0.5, 0.01, caption, ha="center", fontsize=7, color="gray",
                 wrap=True)

        # Reserve the bottom 5% of the figure for the caption.
        fig.tight_layout(rect=(0, 0.05, 1, 1))
        out_dir = Path(out_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        out = out_dir / out_name
        fig.savefig(out, dpi=120)
    finally:
        # pyplot keeps every open figure registered; close it even when
        # drawing or saving fails so repeated renders do not accumulate.
        plt.close(fig)
    return str(out)
|
||||
|
||||
|
||||
def emit_stage_breakdown(
    sweep_json: Path | str = DEFAULT_SWEEP_JSON,
    out_dir: Path | str = DEFAULT_PLOTS_DIR,
) -> str | None:
    """Plot per-stage wall-clock per shape for the ``_PLOT_VARIANT`` rows.

    Reads the sweep JSON, keeps the rows of the plotted variant, and draws one
    bar per stage in ``STAGE_KEYS`` for each shape, using the stage's recorded
    ``wall_ns`` (0.0 when absent).

    Returns:
        The PNG path, or ``None`` when there are no rows to plot.
    """
    data = _load_sweep_data(sweep_json)
    rows = [row for row in data["rows"] if row.get("variant") == _PLOT_VARIANT]
    if not rows:
        return None

    tile = data["tile_sizes"]
    shape_labels, flagged, tile_counts = [], [], []
    for row in rows:
        shape_labels.append(_shape_label(row))
        flagged.append(
            _under_tile(row["M"], row["K"], row["N"],
                        tile["M"], tile["K"], tile["N"])
        )
        tile_counts.append(row["tile_count_expected"])

    series, colors = {}, {}
    for key in STAGE_KEYS:
        legend = STAGE_DISPLAY[key]
        series[legend] = [
            row.get("stages", {}).get(key, {}).get("wall_ns", 0.0)
            for row in rows
        ]
        colors[legend] = STAGE_COLORS[key]

    subtitle = (
        f"Per-stage engine wall-clock (DMA in / Fetch / GEMM / "
        f"DMA out), {_PLOT_VARIANT} staging. "
        f"Tile {tile['M']}x{tile['K']}x{tile['N']}."
    )
    return _grouped_bar_png(
        "gemm_stage_breakdown.png",
        out_dir=Path(out_dir),
        title="GEMM stage breakdown",
        subtitle=subtitle,
        shape_labels=shape_labels,
        tile_counts=tile_counts,
        flagged=flagged,
        series=series,
        colors=colors,
        y_label="ns",
        footnote="Bars = engine wall-clock interval (merged overlaps).",
    )
|
||||
|
||||
|
||||
def emit_mac_utilization_measured(
    sweep_json: Path | str = DEFAULT_SWEEP_JSON,
    out_dir: Path | str = DEFAULT_PLOTS_DIR,
) -> str | None:
    """Plot GEMM util % and useful pipeline-eff % for the ``_PLOT_VARIANT`` rows.

    Both series are computed from the analytical model (tile sizes, expected
    tile count, and the ``_T_STAGE`` / ``_D_STAGES`` / ``_HBM_GBS`` / ``_BPE``
    constants); only the shape and ``tile_count_expected`` are read from each
    sweep row.

    Returns:
        The PNG path, or ``None`` when there is nothing to plot.
    """
    data = _load_sweep_data(sweep_json)
    rows = data["rows"]
    if not rows:
        return None

    tile = data["tile_sizes"]
    tm, tk, tn = tile["M"], tile["K"], tile["N"]
    tile_flops = 2 * tm * tk * tn
    dma_w_per_pair = (tm * tn * _BPE) / _HBM_GBS
    head_ns = (_D_STAGES - 1) * _T_STAGE

    # One row per shape; a later duplicate replaces an earlier one.
    by_shape = {}
    for row in rows:
        if row["variant"] == _PLOT_VARIANT:
            by_shape[(row["M"], row["K"], row["N"])] = row
    if not by_shape:
        return None

    shape_labels, flagged, tile_counts = [], [], []
    gemm_util, useful_eff = [], []
    for shape, row in by_shape.items():
        shape_labels.append(_shape_label(row))
        flagged.append(_under_tile(*shape, tm, tk, tn))
        tiles = row["tile_count_expected"]
        tile_counts.append(tiles)

        M, K, N = row["M"], row["K"], row["N"]
        useful = 2 * M * K * N
        # Share of the issued tile FLOPs that is real work (padding excluded).
        gu = useful / (tile_flops * tiles) * 100
        gemm_util.append(gu)

        n_mn = _ceil(M, tm) * _ceil(N, tn)
        compute_total = tiles * _T_STAGE
        # Modelled wall: pipeline fill + compute + a write-back term for all
        # but one output (M, N) tile.
        wall = head_ns + compute_total + max(0, n_mn - 1) * dma_w_per_pair
        if wall > 0:
            ueff = (compute_total * (gu / 100.0) / wall) * 100
        else:
            ueff = 0.0
        useful_eff.append(ueff)

    return _grouped_bar_png(
        "gemm_mac_utilization_measured.png",
        out_dir=Path(out_dir),
        title="GEMM MAC utilization — load_ref",
        subtitle=("GEMM util = useful FLOPs / (tile FLOPs x tiles); "
                  "Useful eff = GEMM util x ideal pipeline efficiency."),
        shape_labels=shape_labels,
        tile_counts=tile_counts,
        flagged=flagged,
        series={"GEMM util %": gemm_util, "Useful eff %": useful_eff},
        colors={"GEMM util %": "#10B981", "Useful eff %": "#F59E0B"},
        y_label="%",
        threshold=100.0,
        footnote="Theoretical ideal-pipeline model (not simulator data).",
    )
|
||||
|
||||
|
||||
def emit_mac_utilization_theoretical_vs_measured(
    sweep_json: Path | str = DEFAULT_SWEEP_JSON,
    out_dir: Path | str = DEFAULT_PLOTS_DIR,
) -> str | None:
    """Plot modelled vs recorded GEMM util and efficiency (``_PLOT_VARIANT``).

    The theoretical series use the expected tile count and the model
    constants. The measured series use the GEMM stage ``record_count`` and
    ``composite_window_ns`` stored in each sweep row.

    Returns:
        The PNG path, or ``None`` when there is nothing to plot.
    """
    data = _load_sweep_data(sweep_json)
    rows = data["rows"]
    if not rows:
        return None

    tile = data["tile_sizes"]
    tm, tk, tn = tile["M"], tile["K"], tile["N"]
    tile_flops = 2 * tm * tk * tn
    dma_w_per_pair = (tm * tn * _BPE) / _HBM_GBS
    head_ns = (_D_STAGES - 1) * _T_STAGE
    peak_per_ns = tile_flops / _T_STAGE

    # One row per shape; a later duplicate replaces an earlier one.
    by_shape = {}
    for row in rows:
        if row["variant"] == _PLOT_VARIANT:
            by_shape[(row["M"], row["K"], row["N"])] = row
    if not by_shape:
        return None

    shape_labels, flagged, tile_counts = [], [], []
    gu_t, gu_m, eff_t, eff_m = [], [], [], []
    for shape, row in by_shape.items():
        shape_labels.append(_shape_label(row))
        flagged.append(_under_tile(*shape, tm, tk, tn))
        tiles = row["tile_count_expected"]
        tile_counts.append(tiles)

        M, K, N = row["M"], row["K"], row["N"]
        useful = 2 * M * K * N

        gut = useful / (tile_flops * tiles)
        gu_t.append(gut * 100)

        # A missing or zero GEMM record count falls back to the expected tile
        # count, in which case the "measured" bar equals the theoretical one.
        rec = row.get("stages", {}).get("GEMM", {}).get("record_count", 0) or tiles
        if rec:
            gu_m.append(useful / (tile_flops * rec) * 100)
        else:
            gu_m.append(0.0)

        n_mn = _ceil(M, tm) * _ceil(N, tn)
        compute_total = tiles * _T_STAGE
        wall_t = head_ns + compute_total + max(0, n_mn - 1) * dma_w_per_pair
        if wall_t > 0:
            eff_t.append(compute_total * gut / wall_t * 100)
        else:
            eff_t.append(0.0)

        cw = row.get("composite_window_ns", 0.0) or 0.0
        if cw > 0:
            eff_m.append(useful / cw / peak_per_ns * 100)
        else:
            eff_m.append(0.0)

    palette = [
        ("GEMM util % (theoretical)", gu_t, "#10B981"),
        ("GEMM util % (measured)", gu_m, "#6EE7B7"),
        ("Theoretical eff %", eff_t, "#F59E0B"),
        ("Measured eff %", eff_m, "#3B82F6"),
    ]
    return _grouped_bar_png(
        "gemm_mac_utilization_theoretical_vs_measured.png",
        out_dir=Path(out_dir),
        title="GEMM MAC utilization — theoretical vs measured (load_ref)",
        subtitle=("theoretical model vs simulator op_log; agreement "
                  "validates the analytical pipeline model."),
        shape_labels=shape_labels,
        tile_counts=tile_counts,
        flagged=flagged,
        series={name: values for name, values, _ in palette},
        colors={name: color for name, _, color in palette},
        y_label="%",
        threshold=100.0,
    )
|
||||
|
||||
|
||||
def emit_all_gemm_plots(
    sweep_json: Path | str = DEFAULT_SWEEP_JSON,
    out_dir: Path | str = DEFAULT_PLOTS_DIR,
) -> list[str]:
    """Run every GEMM renderer and return the paths of the figures written.

    Renderers that return nothing (no data for them) are left out of the
    result. The order follows the chart set listed in the module docstring.
    """
    renderers = (
        emit_stage_breakdown,
        emit_mac_utilization_measured,
        emit_mac_utilization_theoretical_vs_measured,
    )
    rendered = (render(sweep_json, out_dir) for render in renderers)
    return [path for path in rendered if path]
|
||||
|
||||
|
||||
# ── bench entry ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@bench(
    name="milestone-1h-gemm",
    description="1H milestone: regenerate all GEMM results + figures.",
)
def run(torch) -> None:
    """Run the GEMM sweep (or reuse committed JSON) and render every figure.

    ``MILESTONE_FAST=1`` reuses the committed ``DEFAULT_SWEEP_JSON`` (seconds);
    otherwise the full sweep runs and is written to
    ``_OUTPUT_DIR/gemm_sweep.json`` (minutes). The flag is off when the
    variable is unset, empty, or one of ``0`` / ``false`` / ``no`` / ``off``
    (case-insensitive); any other value turns it on.

    The sweep drives its own engines, so a sentinel tensor is submitted at the
    end to satisfy the run_bench contract (ADR-0045 D4).

    Args:
        torch: the runtime handle passed in by the bench runner; only its
            ``zeros`` method is used here.
    """
    _OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # Plain truthiness would treat MILESTONE_FAST=0 as "on", because any
    # non-empty string is truthy; parse the value instead.
    flag = os.environ.get("MILESTONE_FAST", "").strip().lower()
    fast = flag not in {"", "0", "false", "no", "off"}

    if fast:
        sweep_json = DEFAULT_SWEEP_JSON
    else:
        sweep_json = run_sweep(out_json=_OUTPUT_DIR / "gemm_sweep.json")
    paths = emit_all_gemm_plots(sweep_json=sweep_json, out_dir=_OUTPUT_DIR)
    print(f" milestone-1h-gemm: {len(paths)} figures -> {_OUTPUT_DIR} "
          f"(fast={fast})")

    # Sentinel submission: the real work above ran on engines this bench
    # built itself, so nothing else has gone through ``torch``.
    torch.zeros(
        (1, 1), dtype="f16",
        dp=DPPolicy(cube="row_wise", pe="replicate", num_cubes=1, num_pes=1),
        name="milestone_gemm_sentinel",
    )
|
||||
Reference in New Issue
Block a user