diff --git a/docs/diagrams/gemm_plots/gemm_mac_utilization_measured.png b/docs/diagrams/gemm_plots/gemm_mac_utilization_measured.png
new file mode 100644
index 0000000..a896aaa
Binary files /dev/null and b/docs/diagrams/gemm_plots/gemm_mac_utilization_measured.png differ
diff --git a/docs/diagrams/gemm_plots/gemm_mac_utilization_theoretical_vs_measured.png b/docs/diagrams/gemm_plots/gemm_mac_utilization_theoretical_vs_measured.png
new file mode 100644
index 0000000..9279134
Binary files /dev/null and b/docs/diagrams/gemm_plots/gemm_mac_utilization_theoretical_vs_measured.png differ
diff --git a/docs/diagrams/gemm_plots/gemm_stage_breakdown.png b/docs/diagrams/gemm_plots/gemm_stage_breakdown.png
new file mode 100644
index 0000000..1130b3c
Binary files /dev/null and b/docs/diagrams/gemm_plots/gemm_stage_breakdown.png differ
diff --git a/tests/gemm/_gemm_plot_helpers.py b/tests/gemm/_gemm_plot_helpers.py
new file mode 100644
index 0000000..b7a97ce
--- /dev/null
+++ b/tests/gemm/_gemm_plot_helpers.py
@@ -0,0 +1,300 @@
+"""Shared plotting plumbing for the GEMM figure tests.
+
+Not a test module (no ``test_`` prefix -> pytest does not collect it).
+
+Reads the committed ``docs/diagrams/gemm_sweep.json`` (produced by the heavy
+``scripts/gemm_sweep.py`` sim sweep) and renders matplotlib PNGs into
+``docs/diagrams/gemm_plots/``. No simulation here -> the figure tests are fast
+and run by default; regenerating the underlying data stays a manual script.
+
+Chart set (mirrors the GEMM MAC slides in scripts/build_overview_slides.py):
+  - stage breakdown (load_ref operand staging)
+  - MAC utilization — measured (load_ref)
+  - MAC utilization — theoretical vs measured (load_ref)
+"""
+from __future__ import annotations
+
+import json
+from pathlib import Path
+
+ROOT = Path(__file__).resolve().parent.parent.parent
+GEMM_SWEEP_JSON = ROOT / "docs" / "diagrams" / "gemm_sweep.json"
+GEMM_PLOTS_DIR = ROOT / "docs" / "diagrams" / "gemm_plots"
+
+# Shapes excluded from the figures (mirrors build_overview_slides).
+EXCLUDED_SHAPES = {(512, 512, 512)}
+
+# Stage bars shown (raw op_log stage_type keys) + display names + colors.
+STAGE_KEYS = ["DMA_READ", "FETCH", "GEMM", "DMA_WRITE"]
+STAGE_DISPLAY = {
+    "DMA_READ": "DMA in",
+    "FETCH": "Fetch",
+    "GEMM": "GEMM",
+    "DMA_WRITE": "DMA out",
+}
+STAGE_COLORS = {
+    "DMA_READ": "#3B82F6",
+    "FETCH": "#10B981",
+    "GEMM": "#F59E0B",
+    "DMA_WRITE": "#A855F7",
+}
+
+# MAC-utilization model constants (mirror build_overview_slides).
+# NOTE(review): units below are inferred from how the constants are combined
+# (bytes / _HBM_GBS is added to multiples of _T_STAGE) -- confirm upstream.
+_HBM_GBS = 256.0   # presumably HBM bandwidth in GB/s (= bytes per ns)
+_BPE = 2           # presumably bytes per element
+_T_STAGE = 16.0    # presumably ns per pipeline stage / tile
+_D_STAGES = 3      # presumably pipeline depth in stages
+
+_PLOT_VARIANT = "load_ref"
+
+
+def _load_sweep_data() -> dict:
+    """Load the sweep JSON and drop rows whose shape is in EXCLUDED_SHAPES.
+
+    Returns ``{"rows": []}`` when the JSON file does not exist.
+    """
+    if not GEMM_SWEEP_JSON.exists():
+        return {"rows": []}
+    # Explicit encoding: the platform default is not UTF-8 everywhere.
+    data = json.loads(GEMM_SWEEP_JSON.read_text(encoding="utf-8"))
+    data["rows"] = [
+        r for r in data.get("rows", [])
+        if (r["M"], r["K"], r["N"]) not in EXCLUDED_SHAPES
+    ]
+    return data
+
+
+def _shape_label(r: dict) -> str:
+    """Return a compact label for a row's shape (``M=K=N=...`` when cubic)."""
+    if r["M"] == r["K"] == r["N"]:
+        return f"M=K=N={r['M']}"
+    return f"M={r['M']} K={r['K']} N={r['N']}"
+
+
+def _under_tile(M, K, N, tile_M, tile_K, tile_N) -> bool:
+    """Return True when any dimension is smaller than its tile size."""
+    return M < tile_M or K < tile_K or N < tile_N
+
+
+def _plot_rows_by_shape(rows) -> dict:
+    """Index the _PLOT_VARIANT rows by (M, K, N); a later duplicate wins.
+
+    Uses ``.get("variant")`` (as emit_stage_breakdown does) so a row without
+    a variant key is skipped instead of raising KeyError.
+    """
+    return {(r["M"], r["K"], r["N"]): r
+            for r in rows if r.get("variant") == _PLOT_VARIANT}
+
+
+def _ideal_pipeline_pct(r: dict, tile: dict) -> tuple[float, float]:
+    """Return (GEMM util %, useful eff %) from the analytical model.
+
+    GEMM util = useful FLOPs / (tile FLOPs x tiles). Useful eff scales it by
+    compute time over modelled wall time, where wall time is a head of
+    ``_D_STAGES - 1`` stages, one ``_T_STAGE`` per tile, and ``n_mn - 1``
+    transfers of one M x N tile at ``_HBM_GBS``.
+    """
+    M, K, N = r["M"], r["K"], r["N"]
+    tiles = r["tile_count_expected"]
+    issued_flops = 2 * tile["M"] * tile["K"] * tile["N"] * tiles
+    # A row with zero tiles would otherwise raise ZeroDivisionError.
+    util = (2 * M * K * N) / issued_flops if issued_flops else 0.0
+    m_tiles = (M + tile["M"] - 1) // tile["M"]
+    n_tiles = (N + tile["N"] - 1) // tile["N"]
+    n_mn = m_tiles * n_tiles
+    dma_w_per_pair = (tile["M"] * tile["N"] * _BPE) / _HBM_GBS
+    compute_total = tiles * _T_STAGE
+    wall = ((_D_STAGES - 1) * _T_STAGE + compute_total
+            + max(0, n_mn - 1) * dma_w_per_pair)
+    eff = compute_total * util / wall if wall > 0 else 0.0
+    return util * 100, eff * 100
+
+
+def _xtick_labels(shape_labels, tile_counts, flagged) -> list[str]:
+    """Build two-line tick labels; flagged shapes get a trailing ' *'."""
+    out = []
+    for lbl, tc, fl in zip(shape_labels, tile_counts, flagged):
+        s = f"{lbl}\n({tc} tiles)"
+        if fl:
+            s += " *"
+        out.append(s)
+    return out
+
+
+def _grouped_bar_png(
+    out_name: str, *, title: str, subtitle: str | None,
+    shape_labels, tile_counts, flagged, series: dict, colors: dict,
+    y_label: str, threshold: float | None = None, footnote: str | None = None,
+) -> str:
+    """Render one grouped-bar chart to GEMM_PLOTS_DIR/out_name; return the path.
+
+    ``series`` maps a legend name to one value per shape; ``colors`` maps the
+    same names to bar colors. ``threshold`` draws a dashed horizontal line.
+    """
+    import matplotlib.pyplot as plt
+    import numpy as np
+
+    n_groups = len(shape_labels)
+    n_series = max(1, len(series))
+    x = np.arange(n_groups)
+    width = 0.8 / n_series
+
+    fig, ax = plt.subplots(figsize=(11, 6))
+    for i, (name, vals) in enumerate(series.items()):
+        offset = (i - (n_series - 1) / 2) * width
+        ax.bar(x + offset, vals, width, label=name, color=colors.get(name))
+
+    ax.set_xticks(x)
+    ax.set_xticklabels(
+        _xtick_labels(shape_labels, tile_counts, flagged), fontsize=8,
+    )
+    ax.set_ylabel(y_label)
+    ax.set_title(title, fontsize=13, fontweight="bold")
+    if subtitle:
+        ax.text(0.5, 1.01, subtitle, transform=ax.transAxes, ha="center",
+                va="bottom", fontsize=8, color="#475569")
+    if threshold is not None:
+        ax.axhline(threshold, ls="--", color="gray", lw=1.0)
+    ax.legend(fontsize=8, loc="upper right")
+    ax.grid(True,
+            axis="y", alpha=0.3)
+
+    # NOTE(review): the patch text was cut off from the middle of the caption
+    # string to the next function header. The caption wording, footnote
+    # placement and save/close steps below are reconstructed -- confirm them.
+    caption = "* = under-tile shape (M<tile_M or K<tile_K or N<tile_N)"
+    if footnote:
+        caption = f"{caption}   {footnote}"
+    fig.text(0.01, 0.01, caption, ha="left", va="bottom", fontsize=7,
+             color="#475569")
+    fig.tight_layout(rect=(0, 0.03, 1, 1))
+
+    GEMM_PLOTS_DIR.mkdir(parents=True, exist_ok=True)
+    out_path = GEMM_PLOTS_DIR / out_name
+    fig.savefig(out_path, dpi=150)
+    plt.close(fig)  # pyplot keeps figures alive until they are closed
+    return str(out_path)
+
+
+def emit_stage_breakdown() -> str | None:
+    """Per-stage engine wall-clock per shape (load_ref operand staging)."""
+    data = _load_sweep_data()
+    rows = [r for r in data["rows"] if r.get("variant") == _PLOT_VARIANT]
+    if not rows:
+        return None
+    tile = data["tile_sizes"]
+    shape_labels = [_shape_label(r) for r in rows]
+    flagged = [_under_tile(r["M"], r["K"], r["N"], tile["M"], tile["K"], tile["N"])
+               for r in rows]
+    tile_counts = [r["tile_count_expected"] for r in rows]
+    series = {
+        STAGE_DISPLAY[s]: [r.get("stages", {}).get(s, {}).get("wall_ns", 0.0)
+                           for r in rows]
+        for s in STAGE_KEYS
+    }
+    colors = {STAGE_DISPLAY[s]: STAGE_COLORS[s] for s in STAGE_KEYS}
+    return _grouped_bar_png(
+        "gemm_stage_breakdown.png",
+        title="GEMM stage breakdown",
+        subtitle=(f"Per-stage engine wall-clock (DMA in / Fetch / GEMM / "
+                  f"DMA out), {_PLOT_VARIANT} staging. "
+                  f"Tile {tile['M']}x{tile['K']}x{tile['N']}."),
+        shape_labels=shape_labels, tile_counts=tile_counts, flagged=flagged,
+        series=series, colors=colors, y_label="ns",
+        footnote="Bars = engine wall-clock interval (merged overlaps).",
+    )
+
+
+def emit_mac_utilization_measured() -> str | None:
+    """GEMM util % and useful pipeline-eff % (analytical model, load_ref)."""
+    data = _load_sweep_data()
+    by_shape = _plot_rows_by_shape(data["rows"])
+    if not by_shape:
+        return None
+    tile = data["tile_sizes"]
+    TILE_M, TILE_K, TILE_N = tile["M"], tile["K"], tile["N"]
+    shapes = list(by_shape)
+    shape_labels = [_shape_label(by_shape[k]) for k in shapes]
+    flagged = [_under_tile(*k, TILE_M, TILE_K, TILE_N) for k in shapes]
+    tile_counts = [by_shape[k]["tile_count_expected"] for k in shapes]
+
+    model = [_ideal_pipeline_pct(by_shape[k], tile) for k in shapes]
+    series = {"GEMM util %": [u for u, _ in model],
+              "Useful eff %": [e for _, e in model]}
+    colors = {"GEMM util %": "#10B981", "Useful eff %": "#F59E0B"}
+    return _grouped_bar_png(
+        "gemm_mac_utilization_measured.png",
+        title="GEMM MAC utilization — load_ref",
+        subtitle=("GEMM util = useful FLOPs / (tile FLOPs x tiles); "
+                  "Useful eff = GEMM util x ideal pipeline efficiency."),
+        shape_labels=shape_labels, tile_counts=tile_counts, flagged=flagged,
+        series=series, colors=colors, y_label="%", threshold=100.0,
+        footnote="Theoretical ideal-pipeline model (not simulator data).",
+    )
+
+
+def emit_mac_utilization_theoretical_vs_measured() -> str | None:
+    """Theoretical vs simulator-measured GEMM util / useful eff (load_ref)."""
+    data = _load_sweep_data()
+    by_shape = _plot_rows_by_shape(data["rows"])
+    if not by_shape:
+        return None
+    tile = data["tile_sizes"]
+    TILE_M, TILE_K, TILE_N = tile["M"], tile["K"], tile["N"]
+    tile_flops = 2 * TILE_M * TILE_K * TILE_N
+    peak_per_ns = tile_flops / _T_STAGE
+    shapes = list(by_shape)
+    shape_labels = [_shape_label(by_shape[k]) for k in shapes]
+    flagged = [_under_tile(*k, TILE_M, TILE_K, TILE_N) for k in shapes]
+    tile_counts = [by_shape[k]["tile_count_expected"] for k in shapes]
+
+    gu_t, gu_m, eff_t, eff_m = [], [], [], []
+    for k in shapes:
+        r = by_shape[k]
+        useful = 2 * r["M"] * r["K"] * r["N"]
+        util_t, ueff_t = _ideal_pipeline_pct(r, tile)
+        gu_t.append(util_t)
+        eff_t.append(ueff_t)
+        # record_count missing or 0 -> fall back to the expected tile count.
+        rec = (r.get("stages", {}).get("GEMM", {}).get("record_count", 0)
+               or r["tile_count_expected"])
+        measured_flops = tile_flops * rec
+        gu_m.append(useful / measured_flops * 100 if measured_flops else 0.0)
+        cw = r.get("composite_window_ns", 0.0) or 0.0
+        measurable = cw > 0 and peak_per_ns > 0
+        eff_m.append(useful / cw / peak_per_ns * 100 if measurable else 0.0)
+
+    series = {
+        "GEMM util % (theoretical)": gu_t,
+        "GEMM util % (measured)": gu_m,
+        "Theoretical eff %": eff_t,
+        "Measured eff %": eff_m,
+    }
+    colors = {
+        "GEMM util % (theoretical)": "#10B981",
+        "GEMM util % (measured)": "#6EE7B7",
+        "Theoretical eff %": "#F59E0B",
+        "Measured eff %": "#3B82F6",
+    }
+    return _grouped_bar_png(
+        "gemm_mac_utilization_theoretical_vs_measured.png",
+        title="GEMM MAC utilization — theoretical vs measured (load_ref)",
+        subtitle=("theoretical model vs simulator op_log; agreement "
+                  "validates the analytical pipeline model."),
+        shape_labels=shape_labels, tile_counts=tile_counts, flagged=flagged,
+        series=series, colors=colors, y_label="%", threshold=100.0,
+    )
+
+
+def emit_all_gemm_plots() -> list[str]:
+    """Render every GEMM figure that has data; return the list of paths written."""
+    paths = []
+    for fn in (emit_stage_breakdown,
+               emit_mac_utilization_measured,
+               emit_mac_utilization_theoretical_vs_measured):
+        p = fn()
+        if p:
+            paths.append(p)
+    return paths
diff --git a/tests/gemm/test_gemm_sweep.py b/tests/gemm/test_gemm_sweep.py
new file mode 100644
index 0000000..39fa646
--- /dev/null
+++ b/tests/gemm/test_gemm_sweep.py
@@ -0,0 +1,36 @@
+"""Regenerate docs/diagrams/gemm_sweep.json by running the GEMM sweep.
+
+Heavy: drives matmul-composite across all shapes x variants through the
+simulator (24 runs; the 512 shape alone is 2048 tiles). Marked ``slow`` so it
+is excluded from the default ``pytest`` run (addopts: -m "not slow") and runs
+on demand:
+
+    pytest -m slow tests/gemm/test_gemm_sweep.py
+
+Delegates to scripts/gemm_sweep.py (the single source of the sweep logic) via
+subprocess so there is no duplicated sim-driving code.
+"""
+from __future__ import annotations
+
+import subprocess
+import sys
+from pathlib import Path
+
+import pytest
+
+from tests.gemm._gemm_plot_helpers import GEMM_SWEEP_JSON, ROOT
+
+
+@pytest.mark.slow
+def test_gemm_sweep_regenerates_json():
+    script = ROOT / "scripts" / "gemm_sweep.py"
+    assert script.exists(), f"missing {script}"
+    proc = subprocess.run(
+        [sys.executable, str(script)],
+        cwd=str(ROOT), capture_output=True, text=True,
+    )
+    assert proc.returncode == 0, (
+        f"gemm_sweep.py failed (rc={proc.returncode})\n"
+        f"stdout:\n{proc.stdout[-2000:]}\nstderr:\n{proc.stderr[-2000:]}"
+    )
+    assert Path(GEMM_SWEEP_JSON).exists()
diff --git a/tests/gemm/test_plot_gemm_mac_utilization.py b/tests/gemm/test_plot_gemm_mac_utilization.py
new file mode 100644
index 0000000..8a06cb3
--- /dev/null
+++ b/tests/gemm/test_plot_gemm_mac_utilization.py
@@ -0,0 +1,35 @@
+"""Emit the GEMM MAC-utilization bar charts.
+
+A measured chart (load_ref) plus the theoretical-vs-measured overlay (load_ref).
+Reads docs/diagrams/gemm_sweep.json and writes gemm_mac_utilization*.png into
+docs/diagrams/gemm_plots/.
+"""
+from __future__ import annotations
+
+from pathlib import Path
+
+import pytest
+
+from tests.gemm._gemm_plot_helpers import (
+    GEMM_SWEEP_JSON,
+    emit_mac_utilization_measured,
+    emit_mac_utilization_theoretical_vs_measured,
+)
+
+
+@pytest.mark.skipif(
+    not GEMM_SWEEP_JSON.exists(),
+    reason="gemm_sweep.json absent; run scripts/gemm_sweep.py first",
+)
+def test_plot_gemm_mac_utilization_measured():
+    out = emit_mac_utilization_measured()
+    assert out is not None and Path(out).exists()
+
+
+@pytest.mark.skipif(
+    not GEMM_SWEEP_JSON.exists(),
+    reason="gemm_sweep.json absent; run scripts/gemm_sweep.py first",
+)
+def test_plot_gemm_mac_utilization_theoretical_vs_measured():
+    out = emit_mac_utilization_theoretical_vs_measured()
+    assert out is not None and Path(out).exists()
diff --git a/tests/gemm/test_plot_gemm_stage_breakdown.py b/tests/gemm/test_plot_gemm_stage_breakdown.py
new file mode 100644
index 0000000..99664af
--- /dev/null
+++ b/tests/gemm/test_plot_gemm_stage_breakdown.py
@@ -0,0 +1,24 @@
+"""Emit the GEMM per-stage engine wall-clock bar chart (load_ref).
+
+Reads docs/diagrams/gemm_sweep.json (run scripts/gemm_sweep.py to refresh it)
+and writes gemm_stage_breakdown.png into docs/diagrams/gemm_plots/.
+"""
+from __future__ import annotations
+
+from pathlib import Path
+
+import pytest
+
+from tests.gemm._gemm_plot_helpers import (
+    GEMM_SWEEP_JSON,
+    emit_stage_breakdown,
+)
+
+
+@pytest.mark.skipif(
+    not GEMM_SWEEP_JSON.exists(),
+    reason="gemm_sweep.json absent; run scripts/gemm_sweep.py first",
+)
+def test_plot_gemm_stage_breakdown():
+    out = emit_stage_breakdown()
+    assert out is not None and Path(out).exists()