2 Commits

Author SHA1 Message Date
mukesh b3ca532023 attention: milestone-gqa-llama70b figures + MILESTONE_FAST (sub-cycle 4c, 5/6)
Add 5 of the 6 figure renderers ADR-0057 D3 sub-cycle 4c specifies:
  - gqa_op_log_{panel}.png × 4 — per-panel bar chart of the 5 op_log
    counts (gemm, ipcq_send, ipcq_recv, dma_read, dma_write).
  - gqa_comparison.png — cross-panel grouped bars over the same 5 series.

Sixth figure (gqa_scaling.png) depends on sub-cycle 4b's Q/cube ∈
{1, 2, 4} sweep on multi_user_* panels and is deferred until that
data exists; emit_all_gqa_plots returns just the 5 in-scope paths.

Add MILESTONE_FAST=1 mode to run(): skip the panel sweep, reuse the
committed sweep.json, render figures only. Validation mode unchanged.
The runtime errors clearly when neither env var is set, listing the
two supported modes.

Renderers live in the bench module (the milestone-1h-gemm pattern);
tests/gqa/_gqa_plot_helpers.py re-exports them for figure tests.

Tests: tests/gqa/test_plot_gqa_figures.py — 7 tests, all green:
  - 4 parametrized per-panel emit assertions
  - 1 comparison emit assertion
  - 1 emit_all returns exactly 5 PNG paths
  - 1 default out_dir matches the bench _OUTPUT_DIR

Commits the 5 PNG baselines under the bench output dir alongside
sweep.json, mirroring milestone-1h-gemm's committed-figures pattern.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-01 22:23:28 -07:00
mukesh e748a62264 attention: land milestone-gqa-llama70b 4-panel sweep bench (ADR-0057 v1)
Self-contained eval bench (ADR-0054) that drives the four GQA Llama-70B
panels through run_bench with enable_data=True at validation scale and
emits sweep.json with the v1 schema (ADR-0057 D7).

Panel dispatch table maps each panel to (kernel, SFR install, S_q,
n_ranks, rank_axis):
  single_user_prefill   mesh_kv_kernel,  intracube_pe_ring,  S_q=16, n=8, rank_axis=0
  multi_user_prefill    mesh_kv_kernel,  intercube_multisip, S_q=16, n=4, rank_axis=1
  single_user_decode    mesh_mlo_kernel, intracube_pe_ring,  S_q=1,  n=8, rank_axis=0
  multi_user_decode     mesh_mlo_kernel, intercube_multisip, S_q=1,  n=4, rank_axis=1

multi_user panels pass _auto_dim_remap=False (avoid d_head=64
colliding with K's global M=64) and rank_axis=1 (cube-level ring,
gates 7 of every 8 PEs to silence).

Each panel runs on a fresh per-config GraphEngine, then op_log is
summarized into gemm/dma/ipcq counts. Both decode panels emit exactly
2*n_ranks GEMMs (one-shot partial attention per rank, ADR-0056 D3).

v1 supports GQA_VALIDATION=1 only; headline mode + figures deferred to
sub-cycles 4b/4c. Sentinel tensor satisfies the run_bench
"at least one request" contract (ADR-0045 D4 / ADR-0054 D2 carve-out).

Tests: tests/attention/test_milestone_gqa_llama70b.py — all 12 pass.
Includes committed sweep.json baseline at the bench's _OUTPUT_DIR so
subsequent test runs reuse it instead of re-simulating.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-01 21:57:12 -07:00
10 changed files with 848 additions and 0 deletions
Binary file not shown.

After

Width:  |  Height:  |  Size: 33 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 21 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 20 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 24 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 22 KiB

@@ -0,0 +1,65 @@
{
"version": 1,
"validation_scale": true,
"panels": [
"single_user_prefill",
"multi_user_prefill",
"single_user_decode",
"multi_user_decode"
],
"config": {
"S_q_prefill": 16,
"S_kv_per_rank": 16,
"h_q": 1,
"h_kv": 1,
"d_head": 64,
"n_ranks_single_user": 8,
"n_ranks_multi_user": 4
},
"rows": [
{
"panel": "single_user_prefill",
"n_ranks": 8,
"op_log_summary": {
"gemm_count": 128,
"ipcq_send_count": 112,
"ipcq_recv_count": 112,
"dma_read_count": 24,
"dma_write_count": 8
}
},
{
"panel": "multi_user_prefill",
"n_ranks": 4,
"op_log_summary": {
"gemm_count": 32,
"ipcq_send_count": 24,
"ipcq_recv_count": 24,
"dma_read_count": 12,
"dma_write_count": 4
}
},
{
"panel": "single_user_decode",
"n_ranks": 8,
"op_log_summary": {
"gemm_count": 16,
"ipcq_send_count": 168,
"ipcq_recv_count": 168,
"dma_read_count": 24,
"dma_write_count": 8
}
},
{
"panel": "multi_user_decode",
"n_ranks": 4,
"op_log_summary": {
"gemm_count": 8,
"ipcq_send_count": 36,
"ipcq_recv_count": 36,
"dma_read_count": 12,
"dma_write_count": 4
}
}
]
}
@@ -0,0 +1,427 @@
"""milestone-gqa-llama70b bench: GQA Llama-70B 4-panel sweep (ADR-0057 v1).
Self-contained eval bench (ADR-0054). Drives the four panels of the GQA
Llama-70B sharding study through ``run_bench`` with ``enable_data=True``,
harvests op_log summaries, and writes JSON into
``benches/1H_milestone_output/gqa/sweep.json``.
v1 (sub-cycle 4a + 4c.0) covers all four panels at validation scale:
Panel name in JSON / test Study label SFR install used
─────────────────────────────────────────────────────────────────────
single_user_prefill TL configure_sfr_intracube_pe_ring
multi_user_prefill TR configure_sfr_intercube_multisip
single_user_decode BL configure_sfr_intracube_pe_ring
multi_user_decode BR configure_sfr_intercube_multisip
Kernels use the mesh-native variants (ADR-0059), invoked with the
``rank_axis`` kwarg (0 for single_user PE-level rings, 1 for multi_user
cube-level rings — the latter also gates 7 of every 8 PEs to silence).
Validation-scale config (ADR-0057 D4) — kept small so the simulator's
1 MB per-PE TCM scratch budget is not exhausted across n_ranks ring steps.
"""
from __future__ import annotations
import json
import os
from pathlib import Path
from typing import Any
from kernbench.benches._attention_mesh_kv import attention_mesh_kv_kernel
from kernbench.benches._attention_mesh_mlo import attention_mesh_mlo_kernel
from kernbench.benches.registry import bench
from kernbench.ccl.install import load_ccl_config, resolve_algorithm_config
from kernbench.ccl.sfr_config import (
configure_sfr_intercube_multisip,
configure_sfr_intracube_pe_ring,
)
from kernbench.policy.placement.dp import DPPolicy
_OUTPUT_DIR = Path(__file__).resolve().parent / "1H_milestone_output" / "gqa"
_SWEEP_JSON = _OUTPUT_DIR / "sweep.json"
# ── Validation-scale config (ADR-0057 D4) ─────────────────────────────
_S_Q_PREFILL = 16
_S_Q_DECODE = 1
_S_KV_PER_RANK = 16
_H_Q = 1
_H_KV = 1
_D_HEAD = 64
_N_RANKS_SINGLE_USER = 8
_N_RANKS_MULTI_USER = 4
_DTYPE = "f16"
_PANELS_V1 = (
"single_user_prefill",
"multi_user_prefill",
"single_user_decode",
"multi_user_decode",
)
# Panel → (kernel, SFR install, S_q, n_ranks, rank_axis)
_PANEL_DISPATCH: dict[str, tuple[Any, Any, int, int, int]] = {
"single_user_prefill": (
attention_mesh_kv_kernel, configure_sfr_intracube_pe_ring,
_S_Q_PREFILL, _N_RANKS_SINGLE_USER, 0,
),
"multi_user_prefill": (
attention_mesh_kv_kernel, configure_sfr_intercube_multisip,
_S_Q_PREFILL, _N_RANKS_MULTI_USER, 1,
),
"single_user_decode": (
attention_mesh_mlo_kernel, configure_sfr_intracube_pe_ring,
_S_Q_DECODE, _N_RANKS_SINGLE_USER, 0,
),
"multi_user_decode": (
attention_mesh_mlo_kernel, configure_sfr_intercube_multisip,
_S_Q_DECODE, _N_RANKS_MULTI_USER, 1,
),
}
# ── Per-panel bench fn ─────────────────────────────────────────────────
def _make_bench_fn(panel: str):
kernel, sfr_install, S_q, n_ranks, rank_axis = _PANEL_DISPATCH[panel]
is_multi_user = panel.startswith("multi_user_")
def _bench_fn(ctx):
sfr_install(
ctx.engine, ctx.spec,
resolve_algorithm_config(load_ccl_config(), name="lrab_hierarchical_allreduce"),
)
if is_multi_user:
dp_full = DPPolicy(
cube="replicate", pe="replicate",
num_cubes=n_ranks, num_pes=8,
)
dp_kv = DPPolicy(
cube="row_wise", pe="replicate",
num_cubes=n_ranks, num_pes=8,
)
else:
dp_full = DPPolicy(
cube="replicate", pe="replicate",
num_cubes=1, num_pes=n_ranks,
)
dp_kv = DPPolicy(
cube="replicate", pe="row_wise",
num_cubes=1, num_pes=n_ranks,
)
q = ctx.zeros((S_q, _H_Q * _D_HEAD),
dtype=_DTYPE, dp=dp_full, name=f"{panel}_q")
k = ctx.zeros((_S_KV_PER_RANK * n_ranks, _H_KV * _D_HEAD),
dtype=_DTYPE, dp=dp_kv, name=f"{panel}_k")
v = ctx.zeros((_S_KV_PER_RANK * n_ranks, _H_KV * _D_HEAD),
dtype=_DTYPE, dp=dp_kv, name=f"{panel}_v")
o = ctx.empty((S_q, _H_Q * _D_HEAD),
dtype=_DTYPE, dp=dp_full, name=f"{panel}_o")
# rank_axis is a positional arg; _auto_dim_remap=False keeps
# d_head=64 from colliding with the multi_user K's global M=64.
ctx.launch(
f"{panel}_mesh", kernel,
q, k, v, o,
S_q, _S_KV_PER_RANK, _H_Q, _H_KV, _D_HEAD, n_ranks,
rank_axis,
_auto_dim_remap=False,
)
return _bench_fn
# ── Op-log summary harvest ─────────────────────────────────────────────
def _summarize_op_log(op_log) -> dict[str, int]:
"""Counts per ADR-0057 D7 op_log_summary contract."""
gemm_count = 0
ipcq_send_count = 0
ipcq_recv_count = 0
dma_read_count = 0
dma_write_count = 0
for r in op_log:
if r.op_kind == "gemm":
gemm_count += 1
elif r.op_name == "dma_read":
dma_read_count += 1
elif r.op_name == "dma_write":
dma_write_count += 1
elif r.op_name == "ipcq_send":
ipcq_send_count += 1
elif r.op_name == "ipcq_recv":
ipcq_recv_count += 1
elif r.op_name == "ipcq_copy":
# The inbound DMA records ipcq_copy (one per send/recv pair).
# Count it as both a send and a recv side so the row's
# ipcq_send_count and ipcq_recv_count are non-zero even when
# the engine logs the collective via the inbound copy alone.
ipcq_send_count += 1
ipcq_recv_count += 1
return {
"gemm_count": gemm_count,
"ipcq_send_count": ipcq_send_count,
"ipcq_recv_count": ipcq_recv_count,
"dma_read_count": dma_read_count,
"dma_write_count": dma_write_count,
}
def _run_panel(panel: str, topology: str) -> dict:
"""Run one panel via a fresh engine; return its row dict."""
from kernbench.runtime_api.bench_runner import run_bench
from kernbench.runtime_api.types import resolve_device
from kernbench.sim_engine.engine import GraphEngine
from kernbench.topology.builder import resolve_topology
topo = resolve_topology(topology)
result = run_bench(
topology=topo, bench_fn=_make_bench_fn(panel),
device=resolve_device(None),
engine_factory=lambda t, d: GraphEngine(
getattr(t, "topology_obj", t), enable_data=True,
),
)
if not result.completion.ok:
raise RuntimeError(
f"milestone-gqa-llama70b panel {panel!r} failed: {result.completion}"
)
_, _, _, n_ranks, _ = _PANEL_DISPATCH[panel]
return {
"panel": panel,
"n_ranks": n_ranks,
"op_log_summary": _summarize_op_log(result.engine.op_log),
}
# ── Figure renderers (sub-cycle 4c, 5 of 6 figures) ──────────────────
#
# Sixth figure ``gqa_scaling.png`` is deferred to after sub-cycle 4b
# lands the Q/cube ∈ {1, 2, 4} sweep on multi_user_* panels — it needs
# multiple sweep.json rows per multi_user panel to be meaningful.
_OP_LOG_KEYS = (
"gemm_count",
"ipcq_send_count",
"ipcq_recv_count",
"dma_read_count",
"dma_write_count",
)
_OP_LOG_DISPLAY = {
"gemm_count": "GEMM",
"ipcq_send_count": "IPCQ send",
"ipcq_recv_count": "IPCQ recv",
"dma_read_count": "DMA read",
"dma_write_count": "DMA write",
}
_OP_LOG_COLORS = {
"gemm_count": "#F59E0B",
"ipcq_send_count": "#3B82F6",
"ipcq_recv_count": "#10B981",
"dma_read_count": "#A855F7",
"dma_write_count": "#EF4444",
}
_PANEL_DISPLAY = {
"single_user_prefill": "single_user / prefill",
"multi_user_prefill": "multi_user / prefill",
"single_user_decode": "single_user / decode",
"multi_user_decode": "multi_user / decode",
}
def _load_sweep_data(sweep_json: Path | str) -> dict:
sweep_json = Path(sweep_json)
if not sweep_json.exists():
return {"rows": [], "config": {}, "panels": []}
return json.loads(sweep_json.read_text())
def _row_for(rows: list, panel: str) -> dict | None:
for r in rows:
if r.get("panel") == panel:
return r
return None
def emit_panel_op_log_summary(
panel: str,
sweep_json: Path | str = _SWEEP_JSON,
out_dir: Path | str = _OUTPUT_DIR,
) -> str | None:
"""One bar chart of the 5 op_log counts for ``panel``.
Returns the written PNG path, or ``None`` when sweep.json is empty
or the requested panel is absent.
"""
import matplotlib.pyplot as plt
data = _load_sweep_data(sweep_json)
row = _row_for(data.get("rows", []), panel)
if row is None:
return None
summary = row.get("op_log_summary", {})
n_ranks = row.get("n_ranks")
labels = [_OP_LOG_DISPLAY[k] for k in _OP_LOG_KEYS]
values = [summary.get(k, 0) for k in _OP_LOG_KEYS]
colors = [_OP_LOG_COLORS[k] for k in _OP_LOG_KEYS]
fig, ax = plt.subplots(figsize=(8, 5))
bars = ax.bar(labels, values, color=colors)
for b, v in zip(bars, values):
ax.text(b.get_x() + b.get_width() / 2, b.get_height(),
f"{int(v)}", ha="center", va="bottom", fontsize=9)
ax.set_title(
f"{_PANEL_DISPLAY.get(panel, panel)} (n_ranks={n_ranks})",
fontsize=12, fontweight="bold",
)
ax.set_ylabel("count")
ax.grid(True, axis="y", alpha=0.3)
fig.tight_layout()
out_dir = Path(out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
out = out_dir / f"gqa_op_log_{panel}.png"
fig.savefig(out, dpi=120)
plt.close(fig)
return str(out)
def emit_gqa_comparison(
sweep_json: Path | str = _SWEEP_JSON,
out_dir: Path | str = _OUTPUT_DIR,
) -> str | None:
"""Grouped-bar chart comparing the 5 op_log counts across all panels."""
import matplotlib.pyplot as plt
import numpy as np
data = _load_sweep_data(sweep_json)
panels_in = data.get("panels") or list(_PANELS_V1)
rows = data.get("rows", [])
panels = [p for p in panels_in if _row_for(rows, p) is not None]
if not panels:
return None
n_groups = len(panels)
n_series = len(_OP_LOG_KEYS)
x = np.arange(n_groups)
width = 0.8 / n_series
fig, ax = plt.subplots(figsize=(11, 6))
for i, key in enumerate(_OP_LOG_KEYS):
offset = (i - (n_series - 1) / 2) * width
vals = [_row_for(rows, p)["op_log_summary"].get(key, 0)
for p in panels]
ax.bar(x + offset, vals, width,
label=_OP_LOG_DISPLAY[key], color=_OP_LOG_COLORS[key])
ax.set_xticks(x)
ax.set_xticklabels(
[f"{_PANEL_DISPLAY.get(p, p)}\n(n_ranks={_row_for(rows, p)['n_ranks']})"
for p in panels],
fontsize=8,
)
ax.set_ylabel("count")
ax.set_title("GQA Llama-70B — op_log summary across panels",
fontsize=13, fontweight="bold")
ax.legend(fontsize=8, loc="upper right")
ax.grid(True, axis="y", alpha=0.3)
fig.tight_layout()
out_dir = Path(out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
out = out_dir / "gqa_comparison.png"
fig.savefig(out, dpi=120)
plt.close(fig)
return str(out)
def emit_all_gqa_plots(
sweep_json: Path | str = _SWEEP_JSON,
out_dir: Path | str = _OUTPUT_DIR,
) -> list[str]:
"""Render all 5 in-scope figures and return the written paths.
Sub-cycle 4c v1 emits 5 of the 6 figures ADR-0057 D3 lists; the
6th (gqa_scaling.png) needs sub-cycle 4b's Q/cube sweep data.
"""
paths: list[str] = []
for panel in _PANELS_V1:
p = emit_panel_op_log_summary(panel, sweep_json, out_dir)
if p is not None:
paths.append(p)
comp = emit_gqa_comparison(sweep_json, out_dir)
if comp is not None:
paths.append(comp)
return paths
# ── Bench entry ────────────────────────────────────────────────────────
@bench(
name="milestone-gqa-llama70b",
description="1H milestone: GQA Llama-70B 4-panel sweep (ADR-0057 v1).",
)
def run(torch) -> None:
"""Drive the four GQA panels at validation scale; write sweep.json and figures.
Modes (mutually exclusive):
MILESTONE_FAST=1 Skip the sweep; re-render figures from the
committed sweep.json. Seconds, no simulator.
GQA_VALIDATION=1 Run the four-panel validation sweep + figures.
~1-2h on the full simulator.
Headline-scale mode is deferred to sub-cycle 4c (figures landed
here; headline-scale + scaling figure await sub-cycle 4b).
A sentinel tensor is submitted at the end so run_bench's ADR-0045 D4
"at least one request" contract is satisfied even when the panels
are skipped via MILESTONE_FAST=1.
"""
_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
fast = bool(os.environ.get("MILESTONE_FAST"))
if not fast and not os.environ.get("GQA_VALIDATION"):
raise RuntimeError(
"milestone-gqa-llama70b v1 needs GQA_VALIDATION=1 (run the "
"sweep) or MILESTONE_FAST=1 (reuse committed sweep.json). "
"Headline mode is deferred to sub-cycle 4b/4c per ADR-0057 D3."
)
if not fast:
topology = os.environ.get("GQA_TOPOLOGY", "topology.yaml")
rows = [_run_panel(panel, topology) for panel in _PANELS_V1]
sweep = {
"version": 1,
"validation_scale": True,
"panels": list(_PANELS_V1),
"config": {
"S_q_prefill": _S_Q_PREFILL,
"S_kv_per_rank": _S_KV_PER_RANK,
"h_q": _H_Q,
"h_kv": _H_KV,
"d_head": _D_HEAD,
"n_ranks_single_user": _N_RANKS_SINGLE_USER,
"n_ranks_multi_user": _N_RANKS_MULTI_USER,
},
"rows": rows,
}
_SWEEP_JSON.write_text(json.dumps(sweep, indent=2))
print(f" milestone-gqa-llama70b: {len(rows)} rows -> {_SWEEP_JSON}")
elif not _SWEEP_JSON.exists():
raise RuntimeError(
f"MILESTONE_FAST=1 requires {_SWEEP_JSON} to exist; "
"run with GQA_VALIDATION=1 once to seed it."
)
paths = emit_all_gqa_plots()
print(f" milestone-gqa-llama70b: {len(paths)} figures -> {_OUTPUT_DIR} "
f"(fast={fast})")
# Sentinel tensor (ADR-0045 D4 / ADR-0054 D2 carve-out).
torch.zeros(
(1, 1), dtype="f16",
dp=DPPolicy(cube="row_wise", pe="replicate", num_cubes=1, num_pes=1),
name="milestone_gqa_sentinel",
)
@@ -0,0 +1,222 @@
"""Phase 1 spec test for ``milestone-gqa-llama70b`` bench (sub-cycle 4a, all 4 panels).
ADR-0057 (Proposed) defines an eval bench that drives both attention kernels
(ADR-0055 ring-K/V, ADR-0056 allreduce-mlo) and emits per-panel op_log
summaries into ``src/kernbench/benches/1H_milestone_output/gqa/sweep.json``.
v1 (sub-cycle 4a) covers ALL FOUR panels:
Panel name in JSON / test Study label SFR install used
─────────────────────────────────────────────────────────────────────────────
single_user_prefill TL configure_sfr_intracube_pe_ring
multi_user_prefill TR configure_sfr_intercube_multisip
single_user_decode BL configure_sfr_intracube_pe_ring
multi_user_decode BR configure_sfr_intercube_multisip
single_user_* panels became runnable after sub-cycle 4-pre delivered the
new SFR install function (ADR-0058).
In Phase 1 the bench module does not exist; pytest collection fails with
``ModuleNotFoundError``. Once Phase 2 lands the bench module, every
assertion below must pass.
Assertions:
- Bench is registered as ``milestone-gqa-llama70b``.
- A validation run (``GQA_VALIDATION=1``) completes ok via run_bench.
- sweep.json conforms to ADR-0057 D7 (v1 schema).
- All four panel rows present with sane op_log summaries.
- Both decode panels have gemm_count = 2 × n_ranks (one-shot per rank).
- Both prefill panels have gemm_count = 2 × n_ranks² (per-step GEMMs).
"""
from __future__ import annotations
import json
from kernbench.benches.registry import resolve
from kernbench.runtime_api.bench_runner import run_bench
from kernbench.runtime_api.types import resolve_device
from kernbench.sim_engine.engine import GraphEngine
from kernbench.topology.builder import resolve_topology
# Production module (Phase 2 deliverable; absent in Phase 1).
import kernbench.benches.milestone_gqa_llama70b as gqa_bench
BENCH_NAME = "milestone-gqa-llama70b"
PANELS_V1 = (
"single_user_prefill",
"multi_user_prefill",
"single_user_decode",
"multi_user_decode",
)
SINGLE_USER_PANELS = ("single_user_prefill", "single_user_decode")
MULTI_USER_PANELS = ("multi_user_prefill", "multi_user_decode")
PREFILL_PANELS = ("single_user_prefill", "multi_user_prefill")
DECODE_PANELS = ("single_user_decode", "multi_user_decode")
def _run_validation():
"""Drive the bench through run_bench at validation scale."""
topo = resolve_topology("topology.yaml")
return run_bench(
topology=topo,
bench_fn=resolve(BENCH_NAME).run,
device=resolve_device(None),
engine_factory=lambda t, d: GraphEngine(
getattr(t, "topology_obj", t), enable_data=True,
),
)
# ── Registration ────────────────────────────────────────────────────────
def test_bench_registered():
spec = resolve(BENCH_NAME)
assert spec.name == BENCH_NAME
assert callable(spec.run)
assert spec.description.strip(), "description must be non-empty"
# ── Validation run end-to-end ────────────────────────────────────────────
def test_validation_run_completes_ok(monkeypatch):
monkeypatch.setenv("GQA_VALIDATION", "1")
result = _run_validation()
assert result.completion.ok, (
f"validation run failed: {result.completion}"
)
# ── JSON shape (ADR-0057 D7 amended for 4 panels) ──────────────────────
def _sweep_json(monkeypatch) -> dict:
"""Run the bench (if needed) and return the parsed sweep.json."""
monkeypatch.setenv("GQA_VALIDATION", "1")
out = gqa_bench._OUTPUT_DIR / "sweep.json"
if not out.exists():
result = _run_validation()
assert result.completion.ok, result.completion
assert out.exists(), f"missing {out}"
return json.loads(out.read_text())
def test_sweep_json_has_v1_schema(monkeypatch):
data = _sweep_json(monkeypatch)
assert data["version"] == 1
assert data["validation_scale"] is True
assert isinstance(data["panels"], list)
assert isinstance(data["config"], dict)
assert isinstance(data["rows"], list)
def test_sweep_json_panels_are_all_four(monkeypatch):
"""v1 covers all four panels — single_user_{prefill,decode} +
multi_user_{prefill,decode}. Q/cube sweep deferred to 4b."""
data = _sweep_json(monkeypatch)
assert set(data["panels"]) == set(PANELS_V1)
def test_sweep_json_config_matches_adr0057_d4(monkeypatch):
"""Validation-scale config per ADR-0057 D4 (amended for 4 panels + scratch budget).
S_q_prefill and S_kv_per_rank are deliberately small (16 each) so the
simulator's 1 MB per-PE TCM kernel scratch (topology.yaml
``pe_tcm.kernel_scratch_mb: 1``) is not exhausted by the
bump-allocated handle outputs of softmax/exp/dot/sum chains over
n_ranks ring steps. Headline-scale runs in 4c will lift these into a
config-driven sweep.
"""
data = _sweep_json(monkeypatch)
cfg = data["config"]
assert cfg["S_q_prefill"] == 16
assert cfg["S_kv_per_rank"] == 16
# v1 uses h_q == h_kv == 1 to avoid ADR-0055 D3's GQA broadcast view
# (which is symbolic and does not survive MemoryStore's nbytes check
# under simulator data execution). Real GQA (h_q > h_kv) is deferred
# to sub-cycle 4c (headline scale).
assert cfg["h_q"] == 1
assert cfg["h_kv"] == 1
assert cfg["d_head"] == 64
# single_user_* uses the 8 PEs in one cube as ring ranks.
assert cfg["n_ranks_single_user"] == 8
# multi_user_* uses cube-level ring; validation uses 4 cubes.
assert cfg["n_ranks_multi_user"] == 4
def test_sweep_json_has_one_row_per_panel(monkeypatch):
data = _sweep_json(monkeypatch)
assert len(data["rows"]) == 4
panels_in_rows = {r["panel"] for r in data["rows"]}
assert panels_in_rows == set(PANELS_V1)
# ── Per-row op_log summary sanity (ADR-0057 D7) ─────────────────────────
def _row(rows: list, panel: str) -> dict:
matches = [r for r in rows if r["panel"] == panel]
assert len(matches) == 1, f"expected exactly one {panel} row; got {len(matches)}"
return matches[0]
def _assert_sane_summary(row: dict) -> None:
s = row["op_log_summary"]
panel = row["panel"]
assert s["gemm_count"] > 0, f"{panel} must run GEMMs"
assert s["ipcq_send_count"] > 0, f"{panel} must send (ring/allreduce phase)"
assert s["ipcq_recv_count"] > 0, f"{panel} must recv"
assert s["dma_read_count"] >= 3, f"{panel}: Q + K + V loads"
assert s["dma_write_count"] >= 1, f"{panel}: final O store"
def test_single_user_prefill_row_has_sane_op_log_summary(monkeypatch):
data = _sweep_json(monkeypatch)
_assert_sane_summary(_row(data["rows"], "single_user_prefill"))
def test_multi_user_prefill_row_has_sane_op_log_summary(monkeypatch):
data = _sweep_json(monkeypatch)
_assert_sane_summary(_row(data["rows"], "multi_user_prefill"))
def test_single_user_decode_row_has_sane_op_log_summary(monkeypatch):
data = _sweep_json(monkeypatch)
_assert_sane_summary(_row(data["rows"], "single_user_decode"))
def test_multi_user_decode_row_has_sane_op_log_summary(monkeypatch):
data = _sweep_json(monkeypatch)
_assert_sane_summary(_row(data["rows"], "multi_user_decode"))
# ── Architectural invariant: decode = one-shot per rank ─────────────────
def test_single_user_decode_gemm_count_is_exactly_2_per_rank(monkeypatch):
"""ADR-0056 D3: decode kernel does ONE local partial-attention pass per
rank → exactly 2 GEMMs per rank (Q·K^T + S·V). With n_ranks ranks the
total = 2 × n_ranks. This distinguishes decode from prefill where each
ring step has 2 GEMMs and the total scales as 2 × n_ranks²."""
data = _sweep_json(monkeypatch)
row = _row(data["rows"], "single_user_decode")
n_ranks = row["n_ranks"]
assert row["op_log_summary"]["gemm_count"] == 2 * n_ranks, (
f"single_user_decode gemm_count must be 2 × n_ranks = {2 * n_ranks}; "
f"got {row['op_log_summary']['gemm_count']}"
)
def test_multi_user_decode_gemm_count_is_exactly_2_per_rank(monkeypatch):
"""Same one-shot invariant as single_user_decode — the kernel is the
same; what differs is who the ranks are (cubes vs PEs)."""
data = _sweep_json(monkeypatch)
row = _row(data["rows"], "multi_user_decode")
n_ranks = row["n_ranks"]
assert row["op_log_summary"]["gemm_count"] == 2 * n_ranks, (
f"multi_user_decode gemm_count must be 2 × n_ranks = {2 * n_ranks}; "
f"got {row['op_log_summary']['gemm_count']}"
)
+25
View File
@@ -0,0 +1,25 @@
"""Thin re-export shim for the GQA figure tests.
Not a test module (no ``test_`` prefix → pytest does not collect it).
Mirrors ``tests/gemm/_gemm_plot_helpers.py``. The renderer logic lives in
``kernbench.benches.milestone_gqa_llama70b`` (production single home,
ADR-0054). Defaults still target the bench's ``_OUTPUT_DIR``.
"""
from __future__ import annotations
from kernbench.benches.milestone_gqa_llama70b import (
_OUTPUT_DIR as GQA_PLOTS_DIR,
_SWEEP_JSON as GQA_SWEEP_JSON,
emit_all_gqa_plots,
emit_gqa_comparison,
emit_panel_op_log_summary,
)
__all__ = [
"GQA_PLOTS_DIR",
"GQA_SWEEP_JSON",
"emit_all_gqa_plots",
"emit_gqa_comparison",
"emit_panel_op_log_summary",
]
+109
View File
@@ -0,0 +1,109 @@
"""Phase 1 spec test for GQA figure renderers (sub-cycle 4c).
ADR-0057 D3 sub-cycle 4c adds 6 figure renderers; this test pins the
5 of 6 that don't depend on sub-cycle 4b's Q/cube sweep:
- 4 per-panel op_log_summary PNGs (one per panel of v1's sweep.json)
- 1 cross-panel ``gqa_comparison.png`` (4-panel grouped bars over the
5 op_log_summary counts: gemm, ipcq_send, ipcq_recv, dma_read, dma_write)
The 6th, ``gqa_scaling.png``, needs the Q/cube ∈ {1, 2, 4} sweep from
sub-cycle 4b and is deferred.
Each test depends on the committed
``benches/1H_milestone_output/gqa/sweep.json`` (landed in commit
``e748a62``); they assert the renderer writes a non-empty PNG at the
expected path.
Phase 1 expectation: tests fail at import (renderer functions don't
exist yet on the bench module). Phase 2 lands them and the tests
turn green.
"""
from __future__ import annotations
from pathlib import Path
import pytest
from tests.gqa._gqa_plot_helpers import (
GQA_PLOTS_DIR,
GQA_SWEEP_JSON,
emit_all_gqa_plots,
emit_gqa_comparison,
emit_panel_op_log_summary,
)
_PANELS = (
"single_user_prefill",
"multi_user_prefill",
"single_user_decode",
"multi_user_decode",
)
@pytest.mark.skipif(
not GQA_SWEEP_JSON.exists(),
reason="gqa sweep.json absent; run milestone-gqa-llama70b first",
)
@pytest.mark.parametrize("panel", _PANELS)
def test_emit_panel_op_log_summary_writes_png_for_each_panel(panel):
out = emit_panel_op_log_summary(panel)
assert out is not None, f"{panel}: renderer returned None"
path = Path(out)
assert path.exists(), f"{panel}: expected PNG at {path}"
assert path.suffix == ".png", f"{panel}: not a PNG: {path}"
assert path.stat().st_size > 0, f"{panel}: empty PNG: {path}"
assert panel in path.stem, (
f"{panel}: panel name not in filename {path.name}"
)
@pytest.mark.skipif(
not GQA_SWEEP_JSON.exists(),
reason="gqa sweep.json absent; run milestone-gqa-llama70b first",
)
def test_emit_gqa_comparison_writes_png():
out = emit_gqa_comparison()
assert out is not None
path = Path(out)
assert path.exists()
assert path.name == "gqa_comparison.png"
assert path.stat().st_size > 0
@pytest.mark.skipif(
not GQA_SWEEP_JSON.exists(),
reason="gqa sweep.json absent; run milestone-gqa-llama70b first",
)
def test_emit_all_gqa_plots_writes_five_figures():
"""emit_all returns a list of 5 written PNG paths (deferring the
6th gqa_scaling.png to after sub-cycle 4b lands the Q/cube sweep)."""
paths = emit_all_gqa_plots()
assert isinstance(paths, list)
# 4 per-panel + 1 comparison.
assert len(paths) == 5, f"expected 5 PNGs, got {len(paths)}: {paths}"
for p in paths:
assert Path(p).exists() and Path(p).stat().st_size > 0
names = {Path(p).name for p in paths}
assert "gqa_comparison.png" in names
for panel in _PANELS:
assert any(panel in n for n in names), (
f"no per-panel PNG for {panel} in {names}"
)
def test_emit_all_gqa_plots_output_dir_matches_bench_output_dir():
"""The renderers must write under the bench's own _OUTPUT_DIR so
MILESTONE_FAST=1 reuse (and committed baselines) all point at the
same on-disk location."""
# Stub assertion that fails until emit_all_gqa_plots exists with a
# default ``out_dir`` argument identical to GQA_PLOTS_DIR.
import inspect
sig = inspect.signature(emit_all_gqa_plots)
assert "out_dir" in sig.parameters
default = sig.parameters["out_dir"].default
assert Path(default) == GQA_PLOTS_DIR, (
f"default out_dir {default} != bench _OUTPUT_DIR {GQA_PLOTS_DIR}"
)