"""Phase 1 spec test for ``milestone-gqa-llama70b`` bench (sub-cycle 4a, all 4 panels). ADR-0057 (Proposed) defines an eval bench that drives both attention kernels (ADR-0055 ring-K/V, ADR-0056 allreduce-mlo) and emits per-panel op_log summaries into ``src/kernbench/benches/1H_milestone_output/gqa/sweep.json``. v1 (sub-cycle 4a) covers ALL FOUR panels: Panel name in JSON / test Study label SFR install used ───────────────────────────────────────────────────────────────────────────── single_user_prefill TL configure_sfr_intracube_pe_ring multi_user_prefill TR configure_sfr_intercube_multisip single_user_decode BL configure_sfr_intracube_pe_ring multi_user_decode BR configure_sfr_intercube_multisip single_user_* panels became runnable after sub-cycle 4-pre delivered the new SFR install function (ADR-0058). In Phase 1 the bench module does not exist; pytest collection fails with ``ModuleNotFoundError``. Once Phase 2 lands the bench module, every assertion below must pass. Assertions: - Bench is registered as ``milestone-gqa-llama70b``. - A validation run (``GQA_VALIDATION=1``) completes ok via run_bench. - sweep.json conforms to ADR-0057 D7 (v1 schema). - All four panel rows present with sane op_log summaries. - Both decode panels have gemm_count = 2 × n_ranks (one-shot per rank). - Both prefill panels have gemm_count = 2 × n_ranks² (per-step GEMMs). """ from __future__ import annotations import json from kernbench.benches.registry import resolve from kernbench.runtime_api.bench_runner import run_bench from kernbench.runtime_api.types import resolve_device from kernbench.sim_engine.engine import GraphEngine from kernbench.topology.builder import resolve_topology # Production module (Phase 2 deliverable; absent in Phase 1). 
import kernbench.benches.milestone_gqa_llama70b as gqa_bench

BENCH_NAME = "milestone-gqa-llama70b"

# Panel names as they appear in sweep.json, plus the groupings used to
# slice them by user count and by kernel phase.
PANELS_V1 = (
    "single_user_prefill",
    "multi_user_prefill",
    "single_user_decode",
    "multi_user_decode",
)
SINGLE_USER_PANELS = ("single_user_prefill", "single_user_decode")
MULTI_USER_PANELS = ("multi_user_prefill", "multi_user_decode")
PREFILL_PANELS = ("single_user_prefill", "multi_user_prefill")
DECODE_PANELS = ("single_user_decode", "multi_user_decode")


def _run_validation():
    """Execute the registered bench once through ``run_bench``.

    The topology is resolved from ``topology.yaml`` and the engine is a
    ``GraphEngine`` built with ``enable_data=True``. This helper does not
    set ``GQA_VALIDATION`` itself; the callers in this module set it via
    ``monkeypatch`` before calling.

    Returns:
        Whatever ``run_bench`` returns; callers inspect ``.completion``.
    """

    def _make_engine(t, d):
        # Use the wrapped ``topology_obj`` when the topology exposes one,
        # otherwise hand the topology itself to the engine.
        return GraphEngine(getattr(t, "topology_obj", t), enable_data=True)

    return run_bench(
        topology=resolve_topology("topology.yaml"),
        bench_fn=resolve(BENCH_NAME).run,
        device=resolve_device(None),
        engine_factory=_make_engine,
    )


# ── Registration ────────────────────────────────────────────────────────


def test_bench_registered():
    """The registry entry carries the expected name, a callable, and a description."""
    entry = resolve(BENCH_NAME)
    assert entry.name == BENCH_NAME
    assert callable(entry.run)
    assert entry.description.strip(), "description must be non-empty"


# ── Validation run end-to-end ────────────────────────────────────────────


def test_validation_run_completes_ok(monkeypatch):
    """With ``GQA_VALIDATION=1`` the bench run reports an ok completion."""
    monkeypatch.setenv("GQA_VALIDATION", "1")
    result = _run_validation()
    assert result.completion.ok, (
        f"validation run failed: {result.completion}"
    )


# ── JSON shape (ADR-0057 D7 amended for 4 panels) ──────────────────────


def _sweep_json(monkeypatch) -> dict:
    """Return the parsed ``sweep.json``, running the bench only if it is absent.

    ``GQA_VALIDATION`` is set to ``"1"`` for the calling test in either case.

    NOTE(review): a ``sweep.json`` that already exists is read as-is and is
    never regenerated here, so a file left by an earlier run is what the
    assertions will see.
    """
    monkeypatch.setenv("GQA_VALIDATION", "1")
    out = gqa_bench._OUTPUT_DIR / "sweep.json"
    if out.exists():
        return json.loads(out.read_text())

    # No artefact yet: produce it, then insist that the run both succeeded
    # and actually wrote the file.
    result = _run_validation()
    assert result.completion.ok, result.completion
    assert out.exists(), f"missing {out}"
    return json.loads(out.read_text())


def test_sweep_json_has_v1_schema(monkeypatch):
    """Top-level keys of sweep.json have the v1 values and container types."""
    payload = _sweep_json(monkeypatch)
    assert payload["version"] == 1
    assert payload["validation_scale"] is True
    for key, container in (("panels", list), ("config", dict), ("rows", list)):
        assert isinstance(payload[key], container)
def test_sweep_json_panels_are_all_four(monkeypatch):
    """The ``panels`` list names exactly the four v1 panels.

    That is single_user_{prefill,decode} + multi_user_{prefill,decode};
    the Q/cube sweep is deferred to 4b.
    """
    listed = set(_sweep_json(monkeypatch)["panels"])
    assert listed == set(PANELS_V1)


def test_sweep_json_config_matches_adr0057_d4(monkeypatch):
    """Validation-scale config per ADR-0057 D4 (amended for 4 panels + scratch budget).

    S_q_prefill and S_kv_per_rank are deliberately small (16 each) so the
    simulator's 1 MB per-PE TCM kernel scratch (topology.yaml
    ``pe_tcm.kernel_scratch_mb: 1``) is not exhausted by the bump-allocated
    handle outputs of softmax/exp/dot/sum chains over n_ranks ring steps.
    Headline-scale runs in 4c will lift these into a config-driven sweep.
    """
    config = _sweep_json(monkeypatch)["config"]

    assert config["S_q_prefill"] == 16
    assert config["S_kv_per_rank"] == 16

    # v1 uses h_q == h_kv == 1 to avoid ADR-0055 D3's GQA broadcast view
    # (which is symbolic and does not survive MemoryStore's nbytes check
    # under simulator data execution). Real GQA (h_q > h_kv) is deferred
    # to sub-cycle 4c (headline scale).
    assert config["h_q"] == 1
    assert config["h_kv"] == 1
    assert config["d_head"] == 64

    # single_user_* uses the 8 PEs in one cube as ring ranks.
    assert config["n_ranks_single_user"] == 8
    # multi_user_* uses cube-level ring; validation uses 4 cubes.
    assert config["n_ranks_multi_user"] == 4


def test_sweep_json_has_one_row_per_panel(monkeypatch):
    """There are four rows and their ``panel`` values are the four v1 panels."""
    rows = _sweep_json(monkeypatch)["rows"]
    assert len(rows) == 4
    assert {entry["panel"] for entry in rows} == set(PANELS_V1)


# ── Per-row op_log summary sanity (ADR-0057 D7) ─────────────────────────


def _row(rows: list, panel: str) -> dict:
    """Return the single row whose ``panel`` equals *panel*.

    Asserts that exactly one row matches.
    """
    matches = [candidate for candidate in rows if candidate["panel"] == panel]
    assert len(matches) == 1, f"expected exactly one {panel} row; got {len(matches)}"
    return matches[0]


def _assert_sane_summary(row: dict) -> None:
    """Check the lower bounds every panel's ``op_log_summary`` must meet.

    GEMM, IPCQ send and IPCQ recv counts must be positive; at least three
    DMA reads and at least one DMA write must be recorded.
    """
    summary = row["op_log_summary"]
    panel = row["panel"]
    assert summary["gemm_count"] > 0, f"{panel} must run GEMMs"
    assert summary["ipcq_send_count"] > 0, f"{panel} must send (ring/allreduce phase)"
    assert summary["ipcq_recv_count"] > 0, f"{panel} must recv"
    assert summary["dma_read_count"] >= 3, f"{panel}: Q + K + V loads"
    assert summary["dma_write_count"] >= 1, f"{panel}: final O store"


def test_single_user_prefill_row_has_sane_op_log_summary(monkeypatch):
    """The single_user_prefill row passes the shared summary sanity checks."""
    rows = _sweep_json(monkeypatch)["rows"]
    _assert_sane_summary(_row(rows, "single_user_prefill"))


def test_multi_user_prefill_row_has_sane_op_log_summary(monkeypatch):
    """The multi_user_prefill row passes the shared summary sanity checks."""
    rows = _sweep_json(monkeypatch)["rows"]
    _assert_sane_summary(_row(rows, "multi_user_prefill"))


def test_single_user_decode_row_has_sane_op_log_summary(monkeypatch):
    """The single_user_decode row passes the shared summary sanity checks."""
    rows = _sweep_json(monkeypatch)["rows"]
    _assert_sane_summary(_row(rows, "single_user_decode"))


def test_multi_user_decode_row_has_sane_op_log_summary(monkeypatch):
    """The multi_user_decode row passes the shared summary sanity checks."""
    rows = _sweep_json(monkeypatch)["rows"]
    _assert_sane_summary(_row(rows, "multi_user_decode"))


# ── Architectural invariant: decode = one-shot per rank ─────────────────


def test_single_user_decode_gemm_count_is_exactly_2_per_rank(monkeypatch):
    """single_user_decode records exactly 2 × n_ranks GEMMs.

    ADR-0056 D3: the decode kernel does ONE local partial-attention pass
    per rank → exactly 2 GEMMs per rank (Q·K^T + S·V), so the total over
    n_ranks ranks is 2 × n_ranks. This distinguishes decode from prefill,
    where each ring step has 2 GEMMs and the total scales as 2 × n_ranks².
    """
    row = _row(_sweep_json(monkeypatch)["rows"], "single_user_decode")
    n_ranks = row["n_ranks"]
    assert row["op_log_summary"]["gemm_count"] == 2 * n_ranks, (
        f"single_user_decode gemm_count must be 2 × n_ranks = {2 * n_ranks}; "
        f"got {row['op_log_summary']['gemm_count']}"
    )


def test_multi_user_decode_gemm_count_is_exactly_2_per_rank(monkeypatch):
    """multi_user_decode records exactly 2 × n_ranks GEMMs.

    Same one-shot invariant as single_user_decode — the kernel is the
    same; what differs is who the ranks are (cubes vs PEs).
    """
    row = _row(_sweep_json(monkeypatch)["rows"], "multi_user_decode")
    n_ranks = row["n_ranks"]
    assert row["op_log_summary"]["gemm_count"] == 2 * n_ranks, (
        f"multi_user_decode gemm_count must be 2 × n_ranks = {2 * n_ranks}; "
        f"got {row['op_log_summary']['gemm_count']}"
    )