diff --git a/src/kernbench/benches/1H_milestone_output/gqa/gqa_comparison.png b/src/kernbench/benches/1H_milestone_output/gqa/gqa_comparison.png new file mode 100644 index 0000000..97aae13 Binary files /dev/null and b/src/kernbench/benches/1H_milestone_output/gqa/gqa_comparison.png differ diff --git a/src/kernbench/benches/1H_milestone_output/gqa/gqa_op_log_multi_user_decode.png b/src/kernbench/benches/1H_milestone_output/gqa/gqa_op_log_multi_user_decode.png new file mode 100644 index 0000000..3c2aff5 Binary files /dev/null and b/src/kernbench/benches/1H_milestone_output/gqa/gqa_op_log_multi_user_decode.png differ diff --git a/src/kernbench/benches/1H_milestone_output/gqa/gqa_op_log_multi_user_prefill.png b/src/kernbench/benches/1H_milestone_output/gqa/gqa_op_log_multi_user_prefill.png new file mode 100644 index 0000000..25cb8b3 Binary files /dev/null and b/src/kernbench/benches/1H_milestone_output/gqa/gqa_op_log_multi_user_prefill.png differ diff --git a/src/kernbench/benches/1H_milestone_output/gqa/gqa_op_log_single_user_decode.png b/src/kernbench/benches/1H_milestone_output/gqa/gqa_op_log_single_user_decode.png new file mode 100644 index 0000000..351c3eb Binary files /dev/null and b/src/kernbench/benches/1H_milestone_output/gqa/gqa_op_log_single_user_decode.png differ diff --git a/src/kernbench/benches/1H_milestone_output/gqa/gqa_op_log_single_user_prefill.png b/src/kernbench/benches/1H_milestone_output/gqa/gqa_op_log_single_user_prefill.png new file mode 100644 index 0000000..2907472 Binary files /dev/null and b/src/kernbench/benches/1H_milestone_output/gqa/gqa_op_log_single_user_prefill.png differ diff --git a/src/kernbench/benches/milestone_gqa_llama70b.py b/src/kernbench/benches/milestone_gqa_llama70b.py index 7991a6f..86dbf54 100644 --- a/src/kernbench/benches/milestone_gqa_llama70b.py +++ b/src/kernbench/benches/milestone_gqa_llama70b.py @@ -196,6 +196,168 @@ def _run_panel(panel: str, topology: str) -> dict: } +# ── Figure renderers (sub-cycle 4c, 5 of 6 
figures) ────────────────── +# +# Sixth figure ``gqa_scaling.png`` is deferred to after sub-cycle 4b +# lands the Q/cube ∈ {1, 2, 4} sweep on multi_user_* panels — it needs +# multiple sweep.json rows per multi_user panel to be meaningful. + +_OP_LOG_KEYS = ( + "gemm_count", + "ipcq_send_count", + "ipcq_recv_count", + "dma_read_count", + "dma_write_count", +) +_OP_LOG_DISPLAY = { + "gemm_count": "GEMM", + "ipcq_send_count": "IPCQ send", + "ipcq_recv_count": "IPCQ recv", + "dma_read_count": "DMA read", + "dma_write_count": "DMA write", +} +_OP_LOG_COLORS = { + "gemm_count": "#F59E0B", + "ipcq_send_count": "#3B82F6", + "ipcq_recv_count": "#10B981", + "dma_read_count": "#A855F7", + "dma_write_count": "#EF4444", +} +_PANEL_DISPLAY = { + "single_user_prefill": "single_user / prefill", + "multi_user_prefill": "multi_user / prefill", + "single_user_decode": "single_user / decode", + "multi_user_decode": "multi_user / decode", +} + + +def _load_sweep_data(sweep_json: Path | str) -> dict: + sweep_json = Path(sweep_json) + if not sweep_json.exists(): + return {"rows": [], "config": {}, "panels": []} + return json.loads(sweep_json.read_text()) + + +def _row_for(rows: list, panel: str) -> dict | None: + for r in rows: + if r.get("panel") == panel: + return r + return None + + +def emit_panel_op_log_summary( + panel: str, + sweep_json: Path | str = _SWEEP_JSON, + out_dir: Path | str = _OUTPUT_DIR, +) -> str | None: + """One bar chart of the 5 op_log counts for ``panel``. + + Returns the written PNG path, or ``None`` when sweep.json is empty + or the requested panel is absent. 
+ """ + import matplotlib.pyplot as plt + + data = _load_sweep_data(sweep_json) + row = _row_for(data.get("rows", []), panel) + if row is None: + return None + summary = row.get("op_log_summary", {}) + n_ranks = row.get("n_ranks") + + labels = [_OP_LOG_DISPLAY[k] for k in _OP_LOG_KEYS] + values = [summary.get(k, 0) for k in _OP_LOG_KEYS] + colors = [_OP_LOG_COLORS[k] for k in _OP_LOG_KEYS] + + fig, ax = plt.subplots(figsize=(8, 5)) + bars = ax.bar(labels, values, color=colors) + for b, v in zip(bars, values): + ax.text(b.get_x() + b.get_width() / 2, b.get_height(), + f"{int(v)}", ha="center", va="bottom", fontsize=9) + ax.set_title( + f"{_PANEL_DISPLAY.get(panel, panel)} (n_ranks={n_ranks})", + fontsize=12, fontweight="bold", + ) + ax.set_ylabel("count") + ax.grid(True, axis="y", alpha=0.3) + fig.tight_layout() + + out_dir = Path(out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + out = out_dir / f"gqa_op_log_{panel}.png" + fig.savefig(out, dpi=120) + plt.close(fig) + return str(out) + + +def emit_gqa_comparison( + sweep_json: Path | str = _SWEEP_JSON, + out_dir: Path | str = _OUTPUT_DIR, +) -> str | None: + """Grouped-bar chart comparing the 5 op_log counts across all panels.""" + import matplotlib.pyplot as plt + import numpy as np + + data = _load_sweep_data(sweep_json) + panels_in = data.get("panels") or list(_PANELS_V1) + rows = data.get("rows", []) + panels = [p for p in panels_in if _row_for(rows, p) is not None] + if not panels: + return None + + n_groups = len(panels) + n_series = len(_OP_LOG_KEYS) + x = np.arange(n_groups) + width = 0.8 / n_series + + fig, ax = plt.subplots(figsize=(11, 6)) + for i, key in enumerate(_OP_LOG_KEYS): + offset = (i - (n_series - 1) / 2) * width + vals = [_row_for(rows, p)["op_log_summary"].get(key, 0) + for p in panels] + ax.bar(x + offset, vals, width, + label=_OP_LOG_DISPLAY[key], color=_OP_LOG_COLORS[key]) + + ax.set_xticks(x) + ax.set_xticklabels( + [f"{_PANEL_DISPLAY.get(p, p)}\n(n_ranks={_row_for(rows, 
p)['n_ranks']})" + for p in panels], + fontsize=8, + ) + ax.set_ylabel("count") + ax.set_title("GQA Llama-70B — op_log summary across panels", + fontsize=13, fontweight="bold") + ax.legend(fontsize=8, loc="upper right") + ax.grid(True, axis="y", alpha=0.3) + fig.tight_layout() + + out_dir = Path(out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + out = out_dir / "gqa_comparison.png" + fig.savefig(out, dpi=120) + plt.close(fig) + return str(out) + + +def emit_all_gqa_plots( + sweep_json: Path | str = _SWEEP_JSON, + out_dir: Path | str = _OUTPUT_DIR, +) -> list[str]: + """Render all 5 in-scope figures and return the written paths. + + Sub-cycle 4c v1 emits 5 of the 6 figures ADR-0057 D3 lists; the + 6th (gqa_scaling.png) needs sub-cycle 4b's Q/cube sweep data. + """ + paths: list[str] = [] + for panel in _PANELS_V1: + p = emit_panel_op_log_summary(panel, sweep_json, out_dir) + if p is not None: + paths.append(p) + comp = emit_gqa_comparison(sweep_json, out_dir) + if comp is not None: + paths.append(comp) + return paths + + # ── Bench entry ──────────────────────────────────────────────────────── @@ -204,42 +366,58 @@ def _run_panel(panel: str, topology: str) -> dict: description="1H milestone: GQA Llama-70B 4-panel sweep (ADR-0057 v1).", ) def run(torch) -> None: - """Drive the four GQA panels at validation scale; write sweep.json. + """Drive the four GQA panels at validation scale; write sweep.json and figures. - v1 only supports validation mode (``GQA_VALIDATION=1``). Headline - mode and figures are deferred to sub-cycles 4b and 4c per ADR-0057 D3. + Modes (MILESTONE_FAST=1 takes precedence if both are set): + MILESTONE_FAST=1 Skip the sweep; re-render figures from the + committed sweep.json. Seconds, no simulator. + GQA_VALIDATION=1 Run the four-panel validation sweep + figures. + ~1-2h on the full simulator. + + Headline-scale mode is deferred to sub-cycle 4b (figures landed + here in sub-cycle 4c; the scaling figure also awaits 4b). 
A sentinel tensor is submitted at the end so run_bench's ADR-0045 D4 - "at least one request" contract is satisfied even though the panels - drive their own engines. + "at least one request" contract is satisfied even when the panels + are skipped via MILESTONE_FAST=1. """ - if not os.environ.get("GQA_VALIDATION"): - raise RuntimeError( - "milestone-gqa-llama70b v1 only supports validation mode. " - "Set GQA_VALIDATION=1 to run. Headline mode is deferred to " - "sub-cycle 4b/4c per ADR-0057 D3." - ) _OUTPUT_DIR.mkdir(parents=True, exist_ok=True) - topology = os.environ.get("GQA_TOPOLOGY", "topology.yaml") + fast = bool(os.environ.get("MILESTONE_FAST")) + if not fast and not os.environ.get("GQA_VALIDATION"): + raise RuntimeError( + "milestone-gqa-llama70b v1 needs GQA_VALIDATION=1 (run the " + "sweep) or MILESTONE_FAST=1 (reuse committed sweep.json). " + "Headline mode is deferred to sub-cycle 4b/4c per ADR-0057 D3." + ) - rows = [_run_panel(panel, topology) for panel in _PANELS_V1] + if not fast: + topology = os.environ.get("GQA_TOPOLOGY", "topology.yaml") + rows = [_run_panel(panel, topology) for panel in _PANELS_V1] + sweep = { + "version": 1, + "validation_scale": True, + "panels": list(_PANELS_V1), + "config": { + "S_q_prefill": _S_Q_PREFILL, + "S_kv_per_rank": _S_KV_PER_RANK, + "h_q": _H_Q, + "h_kv": _H_KV, + "d_head": _D_HEAD, + "n_ranks_single_user": _N_RANKS_SINGLE_USER, + "n_ranks_multi_user": _N_RANKS_MULTI_USER, + }, + "rows": rows, + } + _SWEEP_JSON.write_text(json.dumps(sweep, indent=2)) + print(f" milestone-gqa-llama70b: {len(rows)} rows -> {_SWEEP_JSON}") + elif not _SWEEP_JSON.exists(): + raise RuntimeError( + f"MILESTONE_FAST=1 requires {_SWEEP_JSON} to exist; " + "run with GQA_VALIDATION=1 once to seed it." 
+ ) - sweep = { - "version": 1, - "validation_scale": True, - "panels": list(_PANELS_V1), - "config": { - "S_q_prefill": _S_Q_PREFILL, - "S_kv_per_rank": _S_KV_PER_RANK, - "h_q": _H_Q, - "h_kv": _H_KV, - "d_head": _D_HEAD, - "n_ranks_single_user": _N_RANKS_SINGLE_USER, - "n_ranks_multi_user": _N_RANKS_MULTI_USER, - }, - "rows": rows, - } - _SWEEP_JSON.write_text(json.dumps(sweep, indent=2)) - print(f" milestone-gqa-llama70b: {len(rows)} rows -> {_SWEEP_JSON}") + paths = emit_all_gqa_plots() + print(f" milestone-gqa-llama70b: {len(paths)} figures -> {_OUTPUT_DIR} " + f"(fast={fast})") # Sentinel tensor (ADR-0045 D4 / ADR-0054 D2 carve-out). torch.zeros( diff --git a/tests/gqa/_gqa_plot_helpers.py b/tests/gqa/_gqa_plot_helpers.py new file mode 100644 index 0000000..19cf155 --- /dev/null +++ b/tests/gqa/_gqa_plot_helpers.py @@ -0,0 +1,25 @@ +"""Thin re-export shim for the GQA figure tests. + +Not a test module (no ``test_`` prefix → pytest does not collect it). + +Mirrors ``tests/gemm/_gemm_plot_helpers.py``. The renderer logic lives in +``kernbench.benches.milestone_gqa_llama70b`` (production single home, +ADR-0054). Defaults still target the bench's ``_OUTPUT_DIR``. +""" +from __future__ import annotations + +from kernbench.benches.milestone_gqa_llama70b import ( + _OUTPUT_DIR as GQA_PLOTS_DIR, + _SWEEP_JSON as GQA_SWEEP_JSON, + emit_all_gqa_plots, + emit_gqa_comparison, + emit_panel_op_log_summary, +) + +__all__ = [ + "GQA_PLOTS_DIR", + "GQA_SWEEP_JSON", + "emit_all_gqa_plots", + "emit_gqa_comparison", + "emit_panel_op_log_summary", +] diff --git a/tests/gqa/test_plot_gqa_figures.py b/tests/gqa/test_plot_gqa_figures.py new file mode 100644 index 0000000..7f1e9d2 --- /dev/null +++ b/tests/gqa/test_plot_gqa_figures.py @@ -0,0 +1,109 @@ +"""Phase 1 spec test for GQA figure renderers (sub-cycle 4c). 
+ +ADR-0057 D3 sub-cycle 4c adds 6 figure renderers; this test pins the +5 of 6 that don't depend on sub-cycle 4b's Q/cube sweep: + + - 4 per-panel op_log_summary PNGs (one per panel of v1's sweep.json) + - 1 cross-panel ``gqa_comparison.png`` (4-panel grouped bars over the + 5 op_log_summary counts: gemm, ipcq_send, ipcq_recv, dma_read, dma_write) + +The 6th, ``gqa_scaling.png``, needs the Q/cube ∈ {1, 2, 4} sweep from +sub-cycle 4b and is deferred. + +Each test depends on the committed +``benches/1H_milestone_output/gqa/sweep.json`` (landed in commit +``e748a62``); they assert the renderer writes a non-empty PNG at the +expected path. + +Phase 1 expectation: tests fail at import (renderer functions don't +exist yet on the bench module). Phase 2 lands them and the tests +turn green. +""" +from __future__ import annotations + +from pathlib import Path + +import pytest + +from tests.gqa._gqa_plot_helpers import ( + GQA_PLOTS_DIR, + GQA_SWEEP_JSON, + emit_all_gqa_plots, + emit_gqa_comparison, + emit_panel_op_log_summary, +) + + +_PANELS = ( + "single_user_prefill", + "multi_user_prefill", + "single_user_decode", + "multi_user_decode", +) + + +@pytest.mark.skipif( + not GQA_SWEEP_JSON.exists(), + reason="gqa sweep.json absent; run milestone-gqa-llama70b first", +) +@pytest.mark.parametrize("panel", _PANELS) +def test_emit_panel_op_log_summary_writes_png_for_each_panel(panel): + out = emit_panel_op_log_summary(panel) + assert out is not None, f"{panel}: renderer returned None" + path = Path(out) + assert path.exists(), f"{panel}: expected PNG at {path}" + assert path.suffix == ".png", f"{panel}: not a PNG: {path}" + assert path.stat().st_size > 0, f"{panel}: empty PNG: {path}" + assert panel in path.stem, ( + f"{panel}: panel name not in filename {path.name}" + ) + + +@pytest.mark.skipif( + not GQA_SWEEP_JSON.exists(), + reason="gqa sweep.json absent; run milestone-gqa-llama70b first", +) +def test_emit_gqa_comparison_writes_png(): + out = emit_gqa_comparison() + 
assert out is not None + path = Path(out) + assert path.exists() + assert path.name == "gqa_comparison.png" + assert path.stat().st_size > 0 + + +@pytest.mark.skipif( + not GQA_SWEEP_JSON.exists(), + reason="gqa sweep.json absent; run milestone-gqa-llama70b first", +) +def test_emit_all_gqa_plots_writes_five_figures(): + """emit_all returns a list of 5 written PNG paths (deferring the + 6th gqa_scaling.png to after sub-cycle 4b lands the Q/cube sweep).""" + paths = emit_all_gqa_plots() + assert isinstance(paths, list) + # 4 per-panel + 1 comparison. + assert len(paths) == 5, f"expected 5 PNGs, got {len(paths)}: {paths}" + for p in paths: + assert Path(p).exists() and Path(p).stat().st_size > 0 + names = {Path(p).name for p in paths} + assert "gqa_comparison.png" in names + for panel in _PANELS: + assert any(panel in n for n in names), ( + f"no per-panel PNG for {panel} in {names}" + ) + + +def test_emit_all_gqa_plots_output_dir_matches_bench_output_dir(): + """The renderers must write under the bench's own _OUTPUT_DIR so + MILESTONE_FAST=1 reuse (and committed baselines) all point at the + same on-disk location.""" + # Stub assertion that fails until emit_all_gqa_plots exists with a + # default ``out_dir`` argument identical to GQA_PLOTS_DIR. + import inspect + + sig = inspect.signature(emit_all_gqa_plots) + assert "out_dir" in sig.parameters + default = sig.parameters["out_dir"].default + assert Path(default) == GQA_PLOTS_DIR, ( + f"default out_dir {default} != bench _OUTPUT_DIR {GQA_PLOTS_DIR}" + )