gemm: test-generated GEMM plots under tests/gemm/ + docs/diagrams/gemm_plots/

Mirror the sccl pattern for GEMM figures: a tests/gemm/ package renders the
GEMM bar charts as PNGs from the committed docs/diagrams/gemm_sweep.json, so
the figures are fast test artifacts (run by default) while the heavy sim sweep
stays a manual script (scripts/gemm_sweep.py, kept) wrapped by a slow
regenerator test.

tests/gemm/:
- _gemm_plot_helpers.py: matplotlib renderers (series logic mirrors the
  GEMM _render_* functions in scripts/build_overview_slides.py).
- test_plot_gemm_stage_breakdown.py: gemm_stage_breakdown.png (load_ref).
- test_plot_gemm_mac_utilization.py: gemm_mac_utilization_measured.png +
  gemm_mac_utilization_theoretical_vs_measured.png (load_ref).
- test_gemm_sweep.py: @pytest.mark.slow regenerator (runs scripts/gemm_sweep.py).

Chart set trimmed to three (stage breakdown, MAC util, theoretical-vs-measured);
"formula" relabeled to "theoretical" throughout the comparison chart.

Known follow-ups (not blocking):
- gemm_mac_utilization_measured.png currently plots the theoretical ideal-
  pipeline model, not simulator-measured data; the name is a misnomer pending
  a decision to repoint its content or retitle.
- The theoretical-model constants (HBM 256 GB/s, T_stage 16 ns, 3 stages) are
  inherited verbatim from build_overview_slides.py and not yet verified against
  ADR-0033 / ADR-0014 / topology.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-21 09:58:08 -07:00
parent b610cb0d9a
commit 0e346b939d
7 changed files with 378 additions and 0 deletions
Binary file not shown.

After

Width:  |  Height:  |  Size: 40 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 46 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 42 KiB

+283
View File
@@ -0,0 +1,283 @@
"""Shared plotting plumbing for the GEMM figure tests.
Not a test module (no ``test_`` prefix -> pytest does not collect it).
Reads the committed ``docs/diagrams/gemm_sweep.json`` (produced by the heavy
``scripts/gemm_sweep.py`` sim sweep) and renders matplotlib PNGs into
``docs/diagrams/gemm_plots/``. No simulation here -> the figure tests are fast
and run by default; regenerating the underlying data stays a manual script.
Chart set (mirrors the GEMM MAC slides in scripts/build_overview_slides.py):
- stage breakdown (load_ref operand staging)
- MAC utilization — measured (load_ref)
- MAC utilization — theoretical vs measured (load_ref)
"""
from __future__ import annotations
import json
from pathlib import Path
ROOT = Path(__file__).resolve().parent.parent.parent
GEMM_SWEEP_JSON = ROOT / "docs" / "diagrams" / "gemm_sweep.json"
GEMM_PLOTS_DIR = ROOT / "docs" / "diagrams" / "gemm_plots"
# Shapes excluded from the figures (mirrors build_overview_slides).
EXCLUDED_SHAPES = {(512, 512, 512)}
# Stage bars shown (raw op_log stage_type keys) + display names + colors.
STAGE_KEYS = ["DMA_READ", "FETCH", "GEMM", "DMA_WRITE"]
STAGE_DISPLAY = {
"DMA_READ": "DMA in",
"FETCH": "Fetch",
"GEMM": "GEMM",
"DMA_WRITE": "DMA out",
}
STAGE_COLORS = {
"DMA_READ": "#3B82F6",
"FETCH": "#10B981",
"GEMM": "#F59E0B",
"DMA_WRITE": "#A855F7",
}
# MAC-utilization model constants (mirror build_overview_slides).
_HBM_GBS = 256.0
_BPE = 2
_T_STAGE = 16.0
_D_STAGES = 3
_PLOT_VARIANT = "load_ref"
def _load_sweep_data() -> dict:
if not GEMM_SWEEP_JSON.exists():
return {"rows": []}
data = json.loads(GEMM_SWEEP_JSON.read_text())
data["rows"] = [
r for r in data.get("rows", [])
if (r["M"], r["K"], r["N"]) not in EXCLUDED_SHAPES
]
return data
def _shape_label(r: dict) -> str:
if r["M"] == r["K"] == r["N"]:
return f"M=K=N={r['M']}"
return f"M={r['M']} K={r['K']} N={r['N']}"
def _under_tile(M, K, N, tile_M, tile_K, tile_N) -> bool:
return M < tile_M or K < tile_K or N < tile_N
def _xtick_labels(shape_labels, tile_counts, flagged) -> list[str]:
out = []
for lbl, tc, fl in zip(shape_labels, tile_counts, flagged):
s = f"{lbl}\n({tc} tiles)"
if fl:
s += " *"
out.append(s)
return out
def _grouped_bar_png(
out_name: str, *, title: str, subtitle: str | None,
shape_labels, tile_counts, flagged, series: dict, colors: dict,
y_label: str, threshold: float | None = None, footnote: str | None = None,
) -> str:
"""Render one grouped-bar chart to GEMM_PLOTS_DIR/out_name; return the path."""
import matplotlib.pyplot as plt
import numpy as np
n_groups = len(shape_labels)
n_series = max(1, len(series))
x = np.arange(n_groups)
width = 0.8 / n_series
fig, ax = plt.subplots(figsize=(11, 6))
for i, (name, vals) in enumerate(series.items()):
offset = (i - (n_series - 1) / 2) * width
ax.bar(x + offset, vals, width, label=name, color=colors.get(name))
ax.set_xticks(x)
ax.set_xticklabels(
_xtick_labels(shape_labels, tile_counts, flagged), fontsize=8,
)
ax.set_ylabel(y_label)
ax.set_title(title, fontsize=13, fontweight="bold")
if subtitle:
ax.text(0.5, 1.01, subtitle, transform=ax.transAxes, ha="center",
va="bottom", fontsize=8, color="#475569")
if threshold is not None:
ax.axhline(threshold, ls="--", color="gray", lw=1.0)
ax.legend(fontsize=8, loc="upper right")
ax.grid(True, axis="y", alpha=0.3)
caption = "* = under-tile shape (M<TILE_M, K<TILE_K, or N<TILE_N)"
if footnote:
caption = footnote + "\n" + caption
fig.text(0.5, 0.01, caption, ha="center", fontsize=7, color="gray",
wrap=True)
fig.tight_layout(rect=(0, 0.05, 1, 1))
GEMM_PLOTS_DIR.mkdir(parents=True, exist_ok=True)
out = GEMM_PLOTS_DIR / out_name
fig.savefig(out, dpi=120)
plt.close(fig)
return str(out)
# ── individual chart renderers (read sweep JSON, emit one PNG each) ─────
def emit_stage_breakdown() -> str | None:
"""Per-stage engine wall-clock per shape (load_ref operand staging)."""
data = _load_sweep_data()
rows = [r for r in data["rows"] if r.get("variant") == _PLOT_VARIANT]
if not rows:
return None
tile = data["tile_sizes"]
shape_labels = [_shape_label(r) for r in rows]
flagged = [_under_tile(r["M"], r["K"], r["N"], tile["M"], tile["K"], tile["N"])
for r in rows]
tile_counts = [r["tile_count_expected"] for r in rows]
series = {
STAGE_DISPLAY[s]: [r.get("stages", {}).get(s, {}).get("wall_ns", 0.0)
for r in rows]
for s in STAGE_KEYS
}
colors = {STAGE_DISPLAY[s]: STAGE_COLORS[s] for s in STAGE_KEYS}
return _grouped_bar_png(
"gemm_stage_breakdown.png",
title="GEMM stage breakdown",
subtitle=(f"Per-stage engine wall-clock (DMA in / Fetch / GEMM / "
f"DMA out), {_PLOT_VARIANT} staging. "
f"Tile {tile['M']}x{tile['K']}x{tile['N']}."),
shape_labels=shape_labels, tile_counts=tile_counts, flagged=flagged,
series=series, colors=colors, y_label="ns",
footnote="Bars = engine wall-clock interval (merged overlaps).",
)
def emit_mac_utilization_measured() -> str | None:
"""GEMM util % and useful pipeline-eff % (analytical model, load_ref)."""
data = _load_sweep_data()
rows = data["rows"]
if not rows:
return None
tile = data["tile_sizes"]
TILE_M, TILE_K, TILE_N = tile["M"], tile["K"], tile["N"]
tile_flops = 2 * TILE_M * TILE_K * TILE_N
dma_w_per_pair = (TILE_M * TILE_N * _BPE) / _HBM_GBS
head_ns = (_D_STAGES - 1) * _T_STAGE
by_shape = {(r["M"], r["K"], r["N"]): r
for r in rows if r["variant"] == _PLOT_VARIANT}
shapes = list(by_shape)
if not shapes:
return None
shape_labels = [_shape_label(by_shape[k]) for k in shapes]
flagged = [_under_tile(*k, TILE_M, TILE_K, TILE_N) for k in shapes]
tile_counts = [by_shape[k]["tile_count_expected"] for k in shapes]
gemm_util, useful_eff = [], []
for k in shapes:
r = by_shape[k]
M, K, N = r["M"], r["K"], r["N"]
useful = 2 * M * K * N
tiles = r["tile_count_expected"]
gu = useful / (tile_flops * tiles) * 100
gemm_util.append(gu)
m_tiles = (M + TILE_M - 1) // TILE_M
n_tiles = (N + TILE_N - 1) // TILE_N
n_mn = m_tiles * n_tiles
compute_total = tiles * _T_STAGE
wall = head_ns + tiles * _T_STAGE + max(0, n_mn - 1) * dma_w_per_pair
ueff = (compute_total * (gu / 100.0) / wall) * 100 if wall > 0 else 0.0
useful_eff.append(ueff)
series = {"GEMM util %": gemm_util, "Useful eff %": useful_eff}
colors = {"GEMM util %": "#10B981", "Useful eff %": "#F59E0B"}
return _grouped_bar_png(
"gemm_mac_utilization_measured.png",
title="GEMM MAC utilization — load_ref",
subtitle=("GEMM util = useful FLOPs / (tile FLOPs x tiles); "
"Useful eff = GEMM util x ideal pipeline efficiency."),
shape_labels=shape_labels, tile_counts=tile_counts, flagged=flagged,
series=series, colors=colors, y_label="%", threshold=100.0,
footnote="Theoretical ideal-pipeline model (not simulator data).",
)
def emit_mac_utilization_theoretical_vs_measured() -> str | None:
"""Theoretical vs simulator-measured GEMM util / useful eff (load_ref)."""
data = _load_sweep_data()
rows = data["rows"]
if not rows:
return None
tile = data["tile_sizes"]
TILE_M, TILE_K, TILE_N = tile["M"], tile["K"], tile["N"]
tile_flops = 2 * TILE_M * TILE_K * TILE_N
dma_w_per_pair = (TILE_M * TILE_N * _BPE) / _HBM_GBS
head_ns = (_D_STAGES - 1) * _T_STAGE
peak_per_ns = tile_flops / _T_STAGE
by_shape = {(r["M"], r["K"], r["N"]): r
for r in rows if r["variant"] == _PLOT_VARIANT}
shapes = list(by_shape)
if not shapes:
return None
shape_labels = [_shape_label(by_shape[k]) for k in shapes]
flagged = [_under_tile(*k, TILE_M, TILE_K, TILE_N) for k in shapes]
tile_counts = [by_shape[k]["tile_count_expected"] for k in shapes]
gu_t, gu_m, eff_t, eff_m = [], [], [], []
for k in shapes:
r = by_shape[k]
M, K, N = r["M"], r["K"], r["N"]
useful = 2 * M * K * N
tiles = r["tile_count_expected"]
gut = useful / (tile_flops * tiles)
gu_t.append(gut * 100)
rec = r.get("stages", {}).get("GEMM", {}).get("record_count", 0) or tiles
gu_m.append((useful / (tile_flops * rec) * 100) if rec else 0.0)
m_tiles = (M + TILE_M - 1) // TILE_M
n_tiles = (N + TILE_N - 1) // TILE_N
n_mn = m_tiles * n_tiles
compute_total = tiles * _T_STAGE
wall_t = head_ns + compute_total + max(0, n_mn - 1) * dma_w_per_pair
eff_t.append((compute_total * gut / wall_t * 100) if wall_t > 0 else 0.0)
cw = r.get("composite_window_ns", 0.0) or 0.0
eff_m.append((useful / cw / peak_per_ns * 100) if cw > 0 else 0.0)
series = {
"GEMM util % (theoretical)": gu_t,
"GEMM util % (measured)": gu_m,
"Theoretical eff %": eff_t,
"Measured eff %": eff_m,
}
colors = {
"GEMM util % (theoretical)": "#10B981",
"GEMM util % (measured)": "#6EE7B7",
"Theoretical eff %": "#F59E0B",
"Measured eff %": "#3B82F6",
}
return _grouped_bar_png(
"gemm_mac_utilization_theoretical_vs_measured.png",
title="GEMM MAC utilization — theoretical vs measured (load_ref)",
subtitle=("theoretical model vs simulator op_log; agreement "
"validates the analytical pipeline model."),
shape_labels=shape_labels, tile_counts=tile_counts, flagged=flagged,
series=series, colors=colors, y_label="%", threshold=100.0,
)
def emit_all_gemm_plots() -> list[str]:
"""Render every GEMM figure that has data; return the list of paths written."""
paths = []
for fn in (emit_stage_breakdown,
emit_mac_utilization_measured,
emit_mac_utilization_theoretical_vs_measured):
p = fn()
if p:
paths.append(p)
return paths
+36
View File
@@ -0,0 +1,36 @@
"""Regenerate docs/diagrams/gemm_sweep.json by running the GEMM sweep.
Heavy: drives matmul-composite across all shapes x variants through the
simulator (24 runs; the 512 shape alone is 2048 tiles). Marked ``slow`` so it
is excluded from the default ``pytest`` run (addopts: -m "not slow") and runs
on demand:
pytest -m slow tests/gemm/test_gemm_sweep.py
Delegates to scripts/gemm_sweep.py (the single source of the sweep logic) via
subprocess so there is no duplicated sim-driving code.
"""
from __future__ import annotations
import subprocess
import sys
from pathlib import Path
import pytest
from tests.gemm._gemm_plot_helpers import GEMM_SWEEP_JSON, ROOT
@pytest.mark.slow
def test_gemm_sweep_regenerates_json():
script = ROOT / "scripts" / "gemm_sweep.py"
assert script.exists(), f"missing {script}"
proc = subprocess.run(
[sys.executable, str(script)],
cwd=str(ROOT), capture_output=True, text=True,
)
assert proc.returncode == 0, (
f"gemm_sweep.py failed (rc={proc.returncode})\n"
f"stdout:\n{proc.stdout[-2000:]}\nstderr:\n{proc.stderr[-2000:]}"
)
assert Path(GEMM_SWEEP_JSON).exists()
@@ -0,0 +1,35 @@
"""Emit the GEMM MAC-utilization bar charts.
A measured chart (load_ref) plus the theoretical-vs-measured overlay (load_ref).
Reads docs/diagrams/gemm_sweep.json and writes gemm_mac_utilization*.png into
docs/diagrams/gemm_plots/.
"""
from __future__ import annotations
from pathlib import Path
import pytest
from tests.gemm._gemm_plot_helpers import (
GEMM_SWEEP_JSON,
emit_mac_utilization_measured,
emit_mac_utilization_theoretical_vs_measured,
)
@pytest.mark.skipif(
not GEMM_SWEEP_JSON.exists(),
reason="gemm_sweep.json absent; run scripts/gemm_sweep.py first",
)
def test_plot_gemm_mac_utilization_measured():
out = emit_mac_utilization_measured()
assert out is not None and Path(out).exists()
@pytest.mark.skipif(
not GEMM_SWEEP_JSON.exists(),
reason="gemm_sweep.json absent; run scripts/gemm_sweep.py first",
)
def test_plot_gemm_mac_utilization_theoretical_vs_measured():
out = emit_mac_utilization_theoretical_vs_measured()
assert out is not None and Path(out).exists()
@@ -0,0 +1,24 @@
"""Emit the GEMM per-stage engine wall-clock bar chart (load_ref).
Reads docs/diagrams/gemm_sweep.json (run scripts/gemm_sweep.py to refresh it)
and writes gemm_stage_breakdown.png into docs/diagrams/gemm_plots/.
"""
from __future__ import annotations
from pathlib import Path
import pytest
from tests.gemm._gemm_plot_helpers import (
GEMM_SWEEP_JSON,
emit_stage_breakdown,
)
@pytest.mark.skipif(
not GEMM_SWEEP_JSON.exists(),
reason="gemm_sweep.json absent; run scripts/gemm_sweep.py first",
)
def test_plot_gemm_stage_breakdown():
out = emit_stage_breakdown()
assert out is not None and Path(out).exists()