gemm: test-generated GEMM plots under tests/gemm/ + docs/diagrams/gemm_plots/
Mirror the sccl pattern for GEMM figures: a tests/gemm/ package renders the GEMM bar charts as PNGs from the committed docs/diagrams/gemm_sweep.json, so the figures are fast test artifacts (run by default) while the heavy sim sweep stays a manual script (scripts/gemm_sweep.py, kept) wrapped by a slow regenerator test. tests/gemm/: - _gemm_plot_helpers.py: matplotlib renderers (series logic mirrors the GEMM _render_* functions in scripts/build_overview_slides.py). - test_plot_gemm_stage_breakdown.py: gemm_stage_breakdown.png (load_ref). - test_plot_gemm_mac_utilization.py: gemm_mac_utilization_measured.png + gemm_mac_utilization_theoretical_vs_measured.png (load_ref). - test_gemm_sweep.py: @pytest.mark.slow regenerator (runs scripts/gemm_sweep.py). Chart set trimmed to three (stage breakdown, MAC util, theoretical-vs-measured); "formula" relabeled to "theoretical" throughout the comparison chart. Known follow-ups (not blocking): - gemm_mac_utilization_measured.png currently plots the theoretical ideal- pipeline model, not simulator-measured data; the name is a misnomer pending a decision to repoint its content or retitle. - The theoretical-model constants (HBM 256 GB/s, T_stage 16 ns, 3 stages) are inherited verbatim from build_overview_slides.py and not yet verified against ADR-0033 / ADR-0014 / topology. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,283 @@
|
||||
"""Shared plotting plumbing for the GEMM figure tests.
|
||||
|
||||
Not a test module (no ``test_`` prefix -> pytest does not collect it).
|
||||
|
||||
Reads the committed ``docs/diagrams/gemm_sweep.json`` (produced by the heavy
|
||||
``scripts/gemm_sweep.py`` sim sweep) and renders matplotlib PNGs into
|
||||
``docs/diagrams/gemm_plots/``. No simulation here -> the figure tests are fast
|
||||
and run by default; regenerating the underlying data stays a manual script.
|
||||
|
||||
Chart set (mirrors the GEMM MAC slides in scripts/build_overview_slides.py):
|
||||
- stage breakdown (load_ref operand staging)
|
||||
- MAC utilization — measured (load_ref)
|
||||
- MAC utilization — theoretical vs measured (load_ref)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).resolve().parent.parent.parent
|
||||
GEMM_SWEEP_JSON = ROOT / "docs" / "diagrams" / "gemm_sweep.json"
|
||||
GEMM_PLOTS_DIR = ROOT / "docs" / "diagrams" / "gemm_plots"
|
||||
|
||||
# Shapes excluded from the figures (mirrors build_overview_slides).
|
||||
EXCLUDED_SHAPES = {(512, 512, 512)}
|
||||
|
||||
# Stage bars shown (raw op_log stage_type keys) + display names + colors.
|
||||
STAGE_KEYS = ["DMA_READ", "FETCH", "GEMM", "DMA_WRITE"]
|
||||
STAGE_DISPLAY = {
|
||||
"DMA_READ": "DMA in",
|
||||
"FETCH": "Fetch",
|
||||
"GEMM": "GEMM",
|
||||
"DMA_WRITE": "DMA out",
|
||||
}
|
||||
STAGE_COLORS = {
|
||||
"DMA_READ": "#3B82F6",
|
||||
"FETCH": "#10B981",
|
||||
"GEMM": "#F59E0B",
|
||||
"DMA_WRITE": "#A855F7",
|
||||
}
|
||||
|
||||
# MAC-utilization model constants (mirror build_overview_slides).
|
||||
_HBM_GBS = 256.0
|
||||
_BPE = 2
|
||||
_T_STAGE = 16.0
|
||||
_D_STAGES = 3
|
||||
|
||||
_PLOT_VARIANT = "load_ref"
|
||||
|
||||
|
||||
def _load_sweep_data() -> dict:
|
||||
if not GEMM_SWEEP_JSON.exists():
|
||||
return {"rows": []}
|
||||
data = json.loads(GEMM_SWEEP_JSON.read_text())
|
||||
data["rows"] = [
|
||||
r for r in data.get("rows", [])
|
||||
if (r["M"], r["K"], r["N"]) not in EXCLUDED_SHAPES
|
||||
]
|
||||
return data
|
||||
|
||||
|
||||
def _shape_label(r: dict) -> str:
|
||||
if r["M"] == r["K"] == r["N"]:
|
||||
return f"M=K=N={r['M']}"
|
||||
return f"M={r['M']} K={r['K']} N={r['N']}"
|
||||
|
||||
|
||||
def _under_tile(M, K, N, tile_M, tile_K, tile_N) -> bool:
|
||||
return M < tile_M or K < tile_K or N < tile_N
|
||||
|
||||
|
||||
def _xtick_labels(shape_labels, tile_counts, flagged) -> list[str]:
|
||||
out = []
|
||||
for lbl, tc, fl in zip(shape_labels, tile_counts, flagged):
|
||||
s = f"{lbl}\n({tc} tiles)"
|
||||
if fl:
|
||||
s += " *"
|
||||
out.append(s)
|
||||
return out
|
||||
|
||||
|
||||
def _grouped_bar_png(
|
||||
out_name: str, *, title: str, subtitle: str | None,
|
||||
shape_labels, tile_counts, flagged, series: dict, colors: dict,
|
||||
y_label: str, threshold: float | None = None, footnote: str | None = None,
|
||||
) -> str:
|
||||
"""Render one grouped-bar chart to GEMM_PLOTS_DIR/out_name; return the path."""
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
n_groups = len(shape_labels)
|
||||
n_series = max(1, len(series))
|
||||
x = np.arange(n_groups)
|
||||
width = 0.8 / n_series
|
||||
|
||||
fig, ax = plt.subplots(figsize=(11, 6))
|
||||
for i, (name, vals) in enumerate(series.items()):
|
||||
offset = (i - (n_series - 1) / 2) * width
|
||||
ax.bar(x + offset, vals, width, label=name, color=colors.get(name))
|
||||
|
||||
ax.set_xticks(x)
|
||||
ax.set_xticklabels(
|
||||
_xtick_labels(shape_labels, tile_counts, flagged), fontsize=8,
|
||||
)
|
||||
ax.set_ylabel(y_label)
|
||||
ax.set_title(title, fontsize=13, fontweight="bold")
|
||||
if subtitle:
|
||||
ax.text(0.5, 1.01, subtitle, transform=ax.transAxes, ha="center",
|
||||
va="bottom", fontsize=8, color="#475569")
|
||||
if threshold is not None:
|
||||
ax.axhline(threshold, ls="--", color="gray", lw=1.0)
|
||||
ax.legend(fontsize=8, loc="upper right")
|
||||
ax.grid(True, axis="y", alpha=0.3)
|
||||
|
||||
caption = "* = under-tile shape (M<TILE_M, K<TILE_K, or N<TILE_N)"
|
||||
if footnote:
|
||||
caption = footnote + "\n" + caption
|
||||
fig.text(0.5, 0.01, caption, ha="center", fontsize=7, color="gray",
|
||||
wrap=True)
|
||||
|
||||
fig.tight_layout(rect=(0, 0.05, 1, 1))
|
||||
GEMM_PLOTS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
out = GEMM_PLOTS_DIR / out_name
|
||||
fig.savefig(out, dpi=120)
|
||||
plt.close(fig)
|
||||
return str(out)
|
||||
|
||||
|
||||
# ── individual chart renderers (read sweep JSON, emit one PNG each) ─────
|
||||
|
||||
|
||||
def emit_stage_breakdown() -> str | None:
|
||||
"""Per-stage engine wall-clock per shape (load_ref operand staging)."""
|
||||
data = _load_sweep_data()
|
||||
rows = [r for r in data["rows"] if r.get("variant") == _PLOT_VARIANT]
|
||||
if not rows:
|
||||
return None
|
||||
tile = data["tile_sizes"]
|
||||
shape_labels = [_shape_label(r) for r in rows]
|
||||
flagged = [_under_tile(r["M"], r["K"], r["N"], tile["M"], tile["K"], tile["N"])
|
||||
for r in rows]
|
||||
tile_counts = [r["tile_count_expected"] for r in rows]
|
||||
series = {
|
||||
STAGE_DISPLAY[s]: [r.get("stages", {}).get(s, {}).get("wall_ns", 0.0)
|
||||
for r in rows]
|
||||
for s in STAGE_KEYS
|
||||
}
|
||||
colors = {STAGE_DISPLAY[s]: STAGE_COLORS[s] for s in STAGE_KEYS}
|
||||
return _grouped_bar_png(
|
||||
"gemm_stage_breakdown.png",
|
||||
title="GEMM stage breakdown",
|
||||
subtitle=(f"Per-stage engine wall-clock (DMA in / Fetch / GEMM / "
|
||||
f"DMA out), {_PLOT_VARIANT} staging. "
|
||||
f"Tile {tile['M']}x{tile['K']}x{tile['N']}."),
|
||||
shape_labels=shape_labels, tile_counts=tile_counts, flagged=flagged,
|
||||
series=series, colors=colors, y_label="ns",
|
||||
footnote="Bars = engine wall-clock interval (merged overlaps).",
|
||||
)
|
||||
|
||||
|
||||
def emit_mac_utilization_measured() -> str | None:
|
||||
"""GEMM util % and useful pipeline-eff % (analytical model, load_ref)."""
|
||||
data = _load_sweep_data()
|
||||
rows = data["rows"]
|
||||
if not rows:
|
||||
return None
|
||||
tile = data["tile_sizes"]
|
||||
TILE_M, TILE_K, TILE_N = tile["M"], tile["K"], tile["N"]
|
||||
tile_flops = 2 * TILE_M * TILE_K * TILE_N
|
||||
dma_w_per_pair = (TILE_M * TILE_N * _BPE) / _HBM_GBS
|
||||
head_ns = (_D_STAGES - 1) * _T_STAGE
|
||||
|
||||
by_shape = {(r["M"], r["K"], r["N"]): r
|
||||
for r in rows if r["variant"] == _PLOT_VARIANT}
|
||||
shapes = list(by_shape)
|
||||
if not shapes:
|
||||
return None
|
||||
shape_labels = [_shape_label(by_shape[k]) for k in shapes]
|
||||
flagged = [_under_tile(*k, TILE_M, TILE_K, TILE_N) for k in shapes]
|
||||
tile_counts = [by_shape[k]["tile_count_expected"] for k in shapes]
|
||||
|
||||
gemm_util, useful_eff = [], []
|
||||
for k in shapes:
|
||||
r = by_shape[k]
|
||||
M, K, N = r["M"], r["K"], r["N"]
|
||||
useful = 2 * M * K * N
|
||||
tiles = r["tile_count_expected"]
|
||||
gu = useful / (tile_flops * tiles) * 100
|
||||
gemm_util.append(gu)
|
||||
m_tiles = (M + TILE_M - 1) // TILE_M
|
||||
n_tiles = (N + TILE_N - 1) // TILE_N
|
||||
n_mn = m_tiles * n_tiles
|
||||
compute_total = tiles * _T_STAGE
|
||||
wall = head_ns + tiles * _T_STAGE + max(0, n_mn - 1) * dma_w_per_pair
|
||||
ueff = (compute_total * (gu / 100.0) / wall) * 100 if wall > 0 else 0.0
|
||||
useful_eff.append(ueff)
|
||||
|
||||
series = {"GEMM util %": gemm_util, "Useful eff %": useful_eff}
|
||||
colors = {"GEMM util %": "#10B981", "Useful eff %": "#F59E0B"}
|
||||
return _grouped_bar_png(
|
||||
"gemm_mac_utilization_measured.png",
|
||||
title="GEMM MAC utilization — load_ref",
|
||||
subtitle=("GEMM util = useful FLOPs / (tile FLOPs x tiles); "
|
||||
"Useful eff = GEMM util x ideal pipeline efficiency."),
|
||||
shape_labels=shape_labels, tile_counts=tile_counts, flagged=flagged,
|
||||
series=series, colors=colors, y_label="%", threshold=100.0,
|
||||
footnote="Theoretical ideal-pipeline model (not simulator data).",
|
||||
)
|
||||
|
||||
|
||||
def emit_mac_utilization_theoretical_vs_measured() -> str | None:
|
||||
"""Theoretical vs simulator-measured GEMM util / useful eff (load_ref)."""
|
||||
data = _load_sweep_data()
|
||||
rows = data["rows"]
|
||||
if not rows:
|
||||
return None
|
||||
tile = data["tile_sizes"]
|
||||
TILE_M, TILE_K, TILE_N = tile["M"], tile["K"], tile["N"]
|
||||
tile_flops = 2 * TILE_M * TILE_K * TILE_N
|
||||
dma_w_per_pair = (TILE_M * TILE_N * _BPE) / _HBM_GBS
|
||||
head_ns = (_D_STAGES - 1) * _T_STAGE
|
||||
peak_per_ns = tile_flops / _T_STAGE
|
||||
|
||||
by_shape = {(r["M"], r["K"], r["N"]): r
|
||||
for r in rows if r["variant"] == _PLOT_VARIANT}
|
||||
shapes = list(by_shape)
|
||||
if not shapes:
|
||||
return None
|
||||
shape_labels = [_shape_label(by_shape[k]) for k in shapes]
|
||||
flagged = [_under_tile(*k, TILE_M, TILE_K, TILE_N) for k in shapes]
|
||||
tile_counts = [by_shape[k]["tile_count_expected"] for k in shapes]
|
||||
|
||||
gu_t, gu_m, eff_t, eff_m = [], [], [], []
|
||||
for k in shapes:
|
||||
r = by_shape[k]
|
||||
M, K, N = r["M"], r["K"], r["N"]
|
||||
useful = 2 * M * K * N
|
||||
tiles = r["tile_count_expected"]
|
||||
gut = useful / (tile_flops * tiles)
|
||||
gu_t.append(gut * 100)
|
||||
rec = r.get("stages", {}).get("GEMM", {}).get("record_count", 0) or tiles
|
||||
gu_m.append((useful / (tile_flops * rec) * 100) if rec else 0.0)
|
||||
m_tiles = (M + TILE_M - 1) // TILE_M
|
||||
n_tiles = (N + TILE_N - 1) // TILE_N
|
||||
n_mn = m_tiles * n_tiles
|
||||
compute_total = tiles * _T_STAGE
|
||||
wall_t = head_ns + compute_total + max(0, n_mn - 1) * dma_w_per_pair
|
||||
eff_t.append((compute_total * gut / wall_t * 100) if wall_t > 0 else 0.0)
|
||||
cw = r.get("composite_window_ns", 0.0) or 0.0
|
||||
eff_m.append((useful / cw / peak_per_ns * 100) if cw > 0 else 0.0)
|
||||
|
||||
series = {
|
||||
"GEMM util % (theoretical)": gu_t,
|
||||
"GEMM util % (measured)": gu_m,
|
||||
"Theoretical eff %": eff_t,
|
||||
"Measured eff %": eff_m,
|
||||
}
|
||||
colors = {
|
||||
"GEMM util % (theoretical)": "#10B981",
|
||||
"GEMM util % (measured)": "#6EE7B7",
|
||||
"Theoretical eff %": "#F59E0B",
|
||||
"Measured eff %": "#3B82F6",
|
||||
}
|
||||
return _grouped_bar_png(
|
||||
"gemm_mac_utilization_theoretical_vs_measured.png",
|
||||
title="GEMM MAC utilization — theoretical vs measured (load_ref)",
|
||||
subtitle=("theoretical model vs simulator op_log; agreement "
|
||||
"validates the analytical pipeline model."),
|
||||
shape_labels=shape_labels, tile_counts=tile_counts, flagged=flagged,
|
||||
series=series, colors=colors, y_label="%", threshold=100.0,
|
||||
)
|
||||
|
||||
|
||||
def emit_all_gemm_plots() -> list[str]:
|
||||
"""Render every GEMM figure that has data; return the list of paths written."""
|
||||
paths = []
|
||||
for fn in (emit_stage_breakdown,
|
||||
emit_mac_utilization_measured,
|
||||
emit_mac_utilization_theoretical_vs_measured):
|
||||
p = fn()
|
||||
if p:
|
||||
paths.append(p)
|
||||
return paths
|
||||
Reference in New Issue
Block a user