"""Shared plotting plumbing for the GEMM figure tests. Not a test module (no ``test_`` prefix -> pytest does not collect it). Reads the committed ``docs/diagrams/gemm_sweep.json`` (produced by the heavy ``scripts/gemm_sweep.py`` sim sweep) and renders matplotlib PNGs into ``docs/diagrams/gemm_plots/``. No simulation here -> the figure tests are fast and run by default; regenerating the underlying data stays a manual script. Chart set (mirrors the GEMM MAC slides in scripts/build_overview_slides.py): - stage breakdown (load_ref operand staging) - MAC utilization — measured (load_ref) - MAC utilization — theoretical vs measured (load_ref) """ from __future__ import annotations import json from pathlib import Path ROOT = Path(__file__).resolve().parent.parent.parent GEMM_SWEEP_JSON = ROOT / "docs" / "diagrams" / "gemm_sweep.json" GEMM_PLOTS_DIR = ROOT / "docs" / "diagrams" / "gemm_plots" # Shapes excluded from the figures (mirrors build_overview_slides). EXCLUDED_SHAPES = {(512, 512, 512)} # Stage bars shown (raw op_log stage_type keys) + display names + colors. STAGE_KEYS = ["DMA_READ", "FETCH", "GEMM", "DMA_WRITE"] STAGE_DISPLAY = { "DMA_READ": "DMA in", "FETCH": "Fetch", "GEMM": "GEMM", "DMA_WRITE": "DMA out", } STAGE_COLORS = { "DMA_READ": "#3B82F6", "FETCH": "#10B981", "GEMM": "#F59E0B", "DMA_WRITE": "#A855F7", } # MAC-utilization model constants (mirror build_overview_slides). 
_HBM_GBS = 256.0  # presumably HBM bandwidth in GB/s (1 GB/s == 1 byte/ns) — confirm
_BPE = 2  # presumably bytes per element — confirm
_T_STAGE = 16.0  # time charged per tile per pipeline stage (summed with ns values)
_D_STAGES = 3  # pipeline depth; (depth - 1) stages are charged as fill latency
_PLOT_VARIANT = "load_ref"  # only sweep rows with this "variant" are plotted


def _load_sweep_data() -> dict:
    """Load the committed sweep JSON with EXCLUDED_SHAPES rows removed.

    Returns ``{"rows": []}`` when the JSON file does not exist, so callers can
    treat "no data" as "nothing to plot" instead of failing.
    """
    if not GEMM_SWEEP_JSON.exists():
        return {"rows": []}
    # JSON is UTF-8 by specification; do not depend on the locale default.
    data = json.loads(GEMM_SWEEP_JSON.read_text(encoding="utf-8"))
    data["rows"] = [
        r for r in data.get("rows", [])
        if (r["M"], r["K"], r["N"]) not in EXCLUDED_SHAPES
    ]
    return data


def _plot_rows(data: dict) -> list[dict]:
    """Return the rows of ``data`` whose ``variant`` is ``_PLOT_VARIANT``.

    Rows without a ``variant`` key are skipped rather than raising.
    """
    return [r for r in data["rows"] if r.get("variant") == _PLOT_VARIANT]


def _dedupe_by_shape(rows: list[dict]) -> list[dict]:
    """Keep one row per (M, K, N).

    Shapes keep their first-seen position; when a shape repeats, the last row
    for that shape is the one kept.
    """
    by_shape = {(r["M"], r["K"], r["N"]): r for r in rows}
    return list(by_shape.values())


def _shape_label(r: dict) -> str:
    """Return a compact axis label for the row's GEMM shape."""
    if r["M"] == r["K"] == r["N"]:
        return f"M=K=N={r['M']}"
    return f"M={r['M']} K={r['K']} N={r['N']}"


def _under_tile(M, K, N, tile_M, tile_K, tile_N) -> bool:
    """Return True when any dimension is smaller than its tile dimension."""
    return M < tile_M or K < tile_K or N < tile_N


def _xtick_labels(shape_labels, tile_counts, flagged) -> list[str]:
    """Build one two-line tick label per shape; flagged shapes get a `` *``."""
    return [
        f"{lbl}\n({tc} tiles)" + (" *" if fl else "")
        for lbl, tc, fl in zip(shape_labels, tile_counts, flagged)
    ]


def _axis_metadata(rows: list[dict], tile: dict) -> tuple[list, list, list]:
    """Return ``(shape_labels, tile_counts, flagged)`` for ``rows``, in order."""
    shape_labels = [_shape_label(r) for r in rows]
    tile_counts = [r["tile_count_expected"] for r in rows]
    flagged = [
        _under_tile(r["M"], r["K"], r["N"], tile["M"], tile["K"], tile["N"])
        for r in rows
    ]
    return shape_labels, tile_counts, flagged


def _grouped_bar_png(
    out_name: str,
    *,
    title: str,
    subtitle: str | None,
    shape_labels,
    tile_counts,
    flagged,
    series: dict,
    colors: dict,
    y_label: str,
    threshold: float | None = None,
    footnote: str | None = None,
) -> str:
    """Render one grouped-bar chart to GEMM_PLOTS_DIR/out_name; return the path.

    Args:
        out_name: File name of the PNG inside ``GEMM_PLOTS_DIR``.
        title: Bold chart title.
        subtitle: Optional small line drawn just above the axes.
        shape_labels: One label per bar group.
        tile_counts: One tile count per group, shown in the tick label.
        flagged: One bool per group; True appends `` *`` to the tick label.
        series: Maps legend name to one value per group; insertion order is
            the bar order inside each group.
        colors: Maps legend name to bar colour; a missing name leaves the
            colour to matplotlib.
        y_label: Y-axis label.
        threshold: Optional y value drawn as a dashed horizontal line.
        footnote: Optional text appended to the caption under the chart.
    """
    # Imported lazily so that importing this module does not need matplotlib.
    import matplotlib.pyplot as plt
    import numpy as np

    n_groups = len(shape_labels)
    n_series = max(1, len(series))  # avoid dividing by zero for an empty series
    x = np.arange(n_groups)
    width = 0.8 / n_series  # each group occupies 80% of its slot

    fig, ax = plt.subplots(figsize=(11, 6))
    for i, (name, vals) in enumerate(series.items()):
        # Centre the bars of one group around the group's tick position.
        offset = (i - (n_series - 1) / 2) * width
        ax.bar(x + offset, vals, width, label=name, color=colors.get(name))
    ax.set_xticks(x)
    ax.set_xticklabels(
        _xtick_labels(shape_labels, tile_counts, flagged),
        fontsize=8,
    )
    ax.set_ylabel(y_label)
    ax.set_title(title, fontsize=13, fontweight="bold")
    if subtitle:
        ax.text(0.5, 1.01, subtitle, transform=ax.transAxes,
                ha="center", va="bottom", fontsize=8, color="#475569")
    if threshold is not None:
        ax.axhline(threshold, ls="--", color="gray", lw=1.0)
    ax.legend(fontsize=8, loc="upper right")
    ax.grid(True, axis="y", alpha=0.3)

    # NOTE(review): everything from the caption text after "(M" to the end of
    # this function was lost from the reviewed copy of the file. The lines
    # below are a minimal reconstruction from the docstring contract (write
    # GEMM_PLOTS_DIR/out_name, return the path) and the `footnote` argument.
    # Caption wording, caption placement and savefig settings are assumptions:
    # diff against the committed file before merging.
    caption = "* = under-tile shape (M < tile M, K < tile K, or N < tile N)."
    if footnote:
        caption = f"{caption} {footnote}"
    fig.text(0.01, 0.01, caption, ha="left", va="bottom",
             fontsize=7, color="#475569")
    fig.tight_layout(rect=(0, 0.03, 1, 1))  # leave room for the caption

    GEMM_PLOTS_DIR.mkdir(parents=True, exist_ok=True)
    out_path = GEMM_PLOTS_DIR / out_name
    fig.savefig(out_path, dpi=150)
    plt.close(fig)  # release the figure; several charts render per process
    return str(out_path)


def _model_metrics(r: dict, tile: dict) -> tuple[float, float]:
    """Return ``(gemm_util_pct, useful_eff_pct)`` from the ideal-pipeline model.

    GEMM util is useful FLOPs over the FLOPs of all issued tiles. Useful eff
    scales that by compute time over the modelled wall time (pipeline fill +
    per-tile compute + one output write per extra M x N tile pair).

    Returns ``(0.0, 0.0)`` when the row reports zero tiles or a tile
    dimension is zero, instead of raising ``ZeroDivisionError``.
    """
    tile_m, tile_k, tile_n = tile["M"], tile["K"], tile["N"]
    tiles = r["tile_count_expected"]
    issued_flops = 2 * tile_m * tile_k * tile_n * tiles
    if issued_flops <= 0:
        return 0.0, 0.0

    m, k, n = r["M"], r["K"], r["N"]
    util = (2 * m * k * n) / issued_flops

    # Number of output tiles, by ceiling division of M and N.
    n_mn = ((m + tile_m - 1) // tile_m) * ((n + tile_n - 1) // tile_n)
    dma_w_per_pair = (tile_m * tile_n * _BPE) / _HBM_GBS
    head_ns = (_D_STAGES - 1) * _T_STAGE
    compute_total = tiles * _T_STAGE
    wall = head_ns + compute_total + max(0, n_mn - 1) * dma_w_per_pair
    eff = (compute_total * util / wall * 100) if wall > 0 else 0.0
    return util * 100, eff


def _measured_metrics(r: dict, tile: dict) -> tuple[float, float]:
    """Return ``(gemm_util_pct, useful_eff_pct)`` from the row's recorded data.

    Util uses the GEMM stage ``record_count`` as the issued tile count and
    falls back to ``tile_count_expected`` when that is missing or zero. Eff is
    useful FLOPs over ``composite_window_ns``, relative to the peak rate of
    one tile per ``_T_STAGE``. Any zero denominator yields 0.0 for that metric.
    """
    tile_flops = 2 * tile["M"] * tile["K"] * tile["N"]
    if tile_flops <= 0:
        return 0.0, 0.0

    useful = 2 * r["M"] * r["K"] * r["N"]
    tiles = r["tile_count_expected"]
    rec = r.get("stages", {}).get("GEMM", {}).get("record_count", 0) or tiles
    util = (useful / (tile_flops * rec) * 100) if rec else 0.0

    peak_per_ns = tile_flops / _T_STAGE
    cw = r.get("composite_window_ns", 0.0) or 0.0
    eff = (useful / cw / peak_per_ns * 100) if cw > 0 else 0.0
    return util, eff


def emit_stage_breakdown() -> str | None:
    """Per-stage engine wall-clock per shape (load_ref operand staging).

    Returns the written PNG path, or None when there are no rows to plot.
    """
    data = _load_sweep_data()
    rows = _plot_rows(data)
    if not rows:
        return None
    tile = data["tile_sizes"]
    shape_labels, tile_counts, flagged = _axis_metadata(rows, tile)
    # A stage missing from a row is drawn as a zero-height bar.
    series = {
        STAGE_DISPLAY[s]: [
            r.get("stages", {}).get(s, {}).get("wall_ns", 0.0) for r in rows
        ]
        for s in STAGE_KEYS
    }
    colors = {STAGE_DISPLAY[s]: STAGE_COLORS[s] for s in STAGE_KEYS}
    return _grouped_bar_png(
        "gemm_stage_breakdown.png",
        title="GEMM stage breakdown",
        subtitle=(f"Per-stage engine wall-clock (DMA in / Fetch / GEMM / "
                  f"DMA out), {_PLOT_VARIANT} staging. "
                  f"Tile {tile['M']}x{tile['K']}x{tile['N']}."),
        shape_labels=shape_labels,
        tile_counts=tile_counts,
        flagged=flagged,
        series=series,
        colors=colors,
        y_label="ns",
        footnote="Bars = engine wall-clock interval (merged overlaps).",
    )


def emit_mac_utilization_measured() -> str | None:
    """GEMM util % and useful pipeline-eff % (analytical model, load_ref).

    Despite the name, both series come from the ideal-pipeline model, not from
    simulator measurements (the chart footnote says so too).

    Returns the written PNG path, or None when there are no rows to plot.
    """
    data = _load_sweep_data()
    rows = _dedupe_by_shape(_plot_rows(data))
    if not rows:
        return None
    tile = data["tile_sizes"]
    shape_labels, tile_counts, flagged = _axis_metadata(rows, tile)
    metrics = [_model_metrics(r, tile) for r in rows]
    series = {
        "GEMM util %": [util for util, _ in metrics],
        "Useful eff %": [eff for _, eff in metrics],
    }
    colors = {"GEMM util %": "#10B981", "Useful eff %": "#F59E0B"}
    return _grouped_bar_png(
        "gemm_mac_utilization_measured.png",
        title="GEMM MAC utilization — load_ref",
        subtitle=("GEMM util = useful FLOPs / (tile FLOPs x tiles); "
                  "Useful eff = GEMM util x ideal pipeline efficiency."),
        shape_labels=shape_labels,
        tile_counts=tile_counts,
        flagged=flagged,
        series=series,
        colors=colors,
        y_label="%",
        threshold=100.0,
        footnote="Theoretical ideal-pipeline model (not simulator data).",
    )


def emit_mac_utilization_theoretical_vs_measured() -> str | None:
    """Theoretical vs simulator-measured GEMM util / useful eff (load_ref).

    Returns the written PNG path, or None when there are no rows to plot.
    """
    data = _load_sweep_data()
    rows = _dedupe_by_shape(_plot_rows(data))
    if not rows:
        return None
    tile = data["tile_sizes"]
    shape_labels, tile_counts, flagged = _axis_metadata(rows, tile)
    model = [_model_metrics(r, tile) for r in rows]
    measured = [_measured_metrics(r, tile) for r in rows]
    series = {
        "GEMM util % (theoretical)": [util for util, _ in model],
        "GEMM util % (measured)": [util for util, _ in measured],
        "Theoretical eff %": [eff for _, eff in model],
        "Measured eff %": [eff for _, eff in measured],
    }
    colors = {
        "GEMM util % (theoretical)": "#10B981",
        "GEMM util % (measured)": "#6EE7B7",
        "Theoretical eff %": "#F59E0B",
        "Measured eff %": "#3B82F6",
    }
    return _grouped_bar_png(
        "gemm_mac_utilization_theoretical_vs_measured.png",
        title="GEMM MAC utilization — theoretical vs measured (load_ref)",
        subtitle=("theoretical model vs simulator op_log; agreement "
                  "validates the analytical pipeline model."),
        shape_labels=shape_labels,
        tile_counts=tile_counts,
        flagged=flagged,
        series=series,
        colors=colors,
        y_label="%",
        threshold=100.0,
    )


def emit_all_gemm_plots() -> list[str]:
    """Render every GEMM figure that has data; return the list of paths written."""
    emitters = (
        emit_stage_breakdown,
        emit_mac_utilization_measured,
        emit_mac_utilization_theoretical_vs_measured,
    )
    # An emitter returns None when it has nothing to plot; drop those.
    return [path for path in (emit() for emit in emitters) if path]