e748a62264
Self-contained eval bench (ADR-0054) that drives the four GQA Llama-70B panels through run_bench with enable_data=True at validation scale and emits sweep.json with the v1 schema (ADR-0057 D7).
Panel dispatch table maps each panel to (kernel, SFR install, S_q, n_ranks, rank_axis):
- single_user_prefill: mesh_kv_kernel, intracube_pe_ring, S_q=16, n=8, rank_axis=0
- multi_user_prefill: mesh_kv_kernel, intercube_multisip, S_q=16, n=4, rank_axis=1
- single_user_decode: mesh_mlo_kernel, intracube_pe_ring, S_q=1, n=8, rank_axis=0
- multi_user_decode: mesh_mlo_kernel, intercube_multisip, S_q=1, n=4, rank_axis=1
multi_user panels pass _auto_dim_remap=False (avoid d_head=64 colliding with K's global M=64) and rank_axis=1 (cube-level ring, gates 7 of every 8 PEs to silence). Each panel runs on a fresh per-config GraphEngine, then op_log is summarized into gemm/dma/ipcq counts. Both decode panels emit exactly 2*n_ranks GEMMs (one-shot partial attention per rank, ADR-0056 D3). v1 supports GQA_VALIDATION=1 only; headline mode + figures deferred to sub-cycles 4b/4c. Sentinel tensor satisfies the run_bench "at least one request" contract (ADR-0045 D4 / ADR-0054 D2 carve-out).
Tests: tests/attention/test_milestone_gqa_llama70b.py — all 12 pass. Includes committed sweep.json baseline at the bench's _OUTPUT_DIR so subsequent test runs reuse it instead of re-simulating.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
223 lines
8.8 KiB
Python
223 lines
8.8 KiB
Python
"""Phase 1 spec test for ``milestone-gqa-llama70b`` bench (sub-cycle 4a, all 4 panels).
|
||
|
||
ADR-0057 (Proposed) defines an eval bench that drives both attention kernels
|
||
(ADR-0055 ring-K/V, ADR-0056 allreduce-mlo) and emits per-panel op_log
|
||
summaries into ``src/kernbench/benches/1H_milestone_output/gqa/sweep.json``.
|
||
|
||
v1 (sub-cycle 4a) covers ALL FOUR panels:
|
||
|
||
  Panel name in JSON / test   Study label   SFR install used
  ─────────────────────────────────────────────────────────────────────────────
  single_user_prefill         TL            configure_sfr_intracube_pe_ring
  multi_user_prefill          TR            configure_sfr_intercube_multisip
  single_user_decode          BL            configure_sfr_intracube_pe_ring
  multi_user_decode           BR            configure_sfr_intercube_multisip
|
||
|
||
single_user_* panels became runnable after sub-cycle 4-pre delivered the
|
||
new SFR install function (ADR-0058).
|
||
|
||
In Phase 1 the bench module does not exist; pytest collection fails with
|
||
``ModuleNotFoundError``. Once Phase 2 lands the bench module, every
|
||
assertion below must pass.
|
||
|
||
Assertions:
|
||
- Bench is registered as ``milestone-gqa-llama70b``.
|
||
- A validation run (``GQA_VALIDATION=1``) completes ok via run_bench.
|
||
- sweep.json conforms to ADR-0057 D7 (v1 schema).
|
||
- All four panel rows present with sane op_log summaries.
|
||
- Both decode panels have gemm_count = 2 × n_ranks (one-shot per rank).
|
||
- Both prefill panels have gemm_count = 2 × n_ranks² (per-step GEMMs).
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
|
||
from kernbench.benches.registry import resolve
|
||
from kernbench.runtime_api.bench_runner import run_bench
|
||
from kernbench.runtime_api.types import resolve_device
|
||
from kernbench.sim_engine.engine import GraphEngine
|
||
from kernbench.topology.builder import resolve_topology
|
||
|
||
# Production module (Phase 2 deliverable; absent in Phase 1).
|
||
import kernbench.benches.milestone_gqa_llama70b as gqa_bench
|
||
|
||
|
||
# Name under which the bench is looked up via ``resolve``.
BENCH_NAME = "milestone-gqa-llama70b"

# All four v1 panel names, as they appear in sweep.json.
PANELS_V1 = (
    "single_user_prefill", "multi_user_prefill",
    "single_user_decode", "multi_user_decode",
)

# The same panel names grouped by user mode ...
SINGLE_USER_PANELS = ("single_user_prefill", "single_user_decode")
MULTI_USER_PANELS = ("multi_user_prefill", "multi_user_decode")
# ... and by attention phase.
PREFILL_PANELS = ("single_user_prefill", "multi_user_prefill")
DECODE_PANELS = ("single_user_decode", "multi_user_decode")
|
||
|
||
|
||
def _run_validation():
    """Execute the registered bench through ``run_bench`` and return its result.

    The topology comes from ``topology.yaml`` and the device from
    ``resolve_device(None)``. Engines are created with ``enable_data=True``.
    This helper does not set ``GQA_VALIDATION`` itself; callers export it
    before invoking the helper.
    """

    def _make_engine(t, d):
        # Prefer the ``topology_obj`` attribute when the argument has one;
        # otherwise hand the argument to GraphEngine unchanged.
        topology = getattr(t, "topology_obj", t)
        return GraphEngine(topology, enable_data=True)

    topo = resolve_topology("topology.yaml")
    bench_fn = resolve(BENCH_NAME).run
    device = resolve_device(None)
    return run_bench(
        topology=topo,
        bench_fn=bench_fn,
        device=device,
        engine_factory=_make_engine,
    )
|
||
|
||
|
||
# ── Registration ────────────────────────────────────────────────────────
|
||
|
||
|
||
def test_bench_registered():
    """``resolve(BENCH_NAME)`` yields a spec with matching name, a callable
    ``run`` and a non-blank description."""
    entry = resolve(BENCH_NAME)
    assert entry.name == BENCH_NAME
    assert callable(entry.run)
    description = entry.description
    assert description.strip(), "description must be non-empty"
|
||
|
||
|
||
# ── Validation run end-to-end ────────────────────────────────────────────
|
||
|
||
|
||
def test_validation_run_completes_ok(monkeypatch):
    """With ``GQA_VALIDATION=1`` exported, the bench run reports ``completion.ok``."""
    monkeypatch.setenv("GQA_VALIDATION", "1")
    completion = _run_validation().completion
    assert completion.ok, f"validation run failed: {completion}"
|
||
|
||
|
||
# ── JSON shape (ADR-0057 D7 amended for 4 panels) ──────────────────────
|
||
|
||
|
||
def _sweep_json(monkeypatch) -> dict:
    """Return the parsed contents of the bench's ``sweep.json``.

    ``GQA_VALIDATION=1`` is exported through *monkeypatch* first. The bench is
    only executed when ``sweep.json`` is absent from ``gqa_bench._OUTPUT_DIR``;
    an existing file is loaded as-is.

    NOTE: because an existing file short-circuits the run, this helper does
    not detect a ``sweep.json`` that is out of date with the bench code.
    """
    monkeypatch.setenv("GQA_VALIDATION", "1")
    sweep_path = gqa_bench._OUTPUT_DIR / "sweep.json"

    if not sweep_path.exists():
        # No file on disk yet: produce it by running the bench once.
        outcome = _run_validation()
        assert outcome.completion.ok, outcome.completion

    assert sweep_path.exists(), f"missing {sweep_path}"
    with sweep_path.open() as handle:
        return json.load(handle)
|
||
|
||
|
||
def test_sweep_json_has_v1_schema(monkeypatch):
    """Top-level sweep.json fields: ``version`` is 1, ``validation_scale`` is
    ``True``, and ``panels`` / ``config`` / ``rows`` have list / dict / list type."""
    sweep = _sweep_json(monkeypatch)
    assert sweep["version"] == 1
    assert sweep["validation_scale"] is True
    containers = (("panels", list), ("config", dict), ("rows", list))
    for key, expected_type in containers:
        assert isinstance(sweep[key], expected_type)
|
||
|
||
|
||
def test_sweep_json_panels_are_all_four(monkeypatch):
    """The ``panels`` list names exactly the four v1 panels
    (single_user and multi_user, each with prefill and decode).

    The Q/cube sweep is deferred to sub-cycle 4b.
    """
    listed = set(_sweep_json(monkeypatch)["panels"])
    expected = set(PANELS_V1)
    assert listed == expected
|
||
|
||
|
||
def test_sweep_json_config_matches_adr0057_d4(monkeypatch):
    """Check the validation-scale ``config`` section against ADR-0057 D4
    (amended for 4 panels + scratch budget).

    ``S_q_prefill`` and ``S_kv_per_rank`` are kept at 16 on purpose: the
    simulator's 1 MB per-PE TCM kernel scratch (topology.yaml
    ``pe_tcm.kernel_scratch_mb: 1``) must not be exhausted by the
    bump-allocated handle outputs of the softmax/exp/dot/sum chains across
    n_ranks ring steps. Headline-scale runs in 4c will turn these into a
    config-driven sweep.
    """
    config = _sweep_json(monkeypatch)["config"]

    expected = {
        "S_q_prefill": 16,
        "S_kv_per_rank": 16,
        # v1 pins h_q == h_kv == 1 to sidestep ADR-0055 D3's GQA broadcast
        # view, which is symbolic and does not survive MemoryStore's nbytes
        # check under simulator data execution. Real GQA (h_q > h_kv) is
        # deferred to sub-cycle 4c (headline scale).
        "h_q": 1,
        "h_kv": 1,
        "d_head": 64,
        # single_user_*: the 8 PEs of one cube serve as ring ranks.
        "n_ranks_single_user": 8,
        # multi_user_*: cube-level ring; validation uses 4 cubes.
        "n_ranks_multi_user": 4,
    }
    # Checked in declaration order, so the first mismatch reported is the
    # same one a sequence of individual asserts would hit.
    for key, wanted in expected.items():
        actual = config[key]
        assert actual == wanted, f"config[{key!r}] is {actual!r}; expected {wanted!r}"
|
||
|
||
|
||
def test_sweep_json_has_one_row_per_panel(monkeypatch):
    """``rows`` holds four entries whose ``panel`` values are the v1 panel set."""
    rows = _sweep_json(monkeypatch)["rows"]
    assert len(rows) == 4

    seen = set()
    for entry in rows:
        seen.add(entry["panel"])
    assert seen == set(PANELS_V1)
|
||
|
||
|
||
# ── Per-row op_log summary sanity (ADR-0057 D7) ─────────────────────────
|
||
|
||
|
||
def _row(rows: list, panel: str) -> dict:
|
||
matches = [r for r in rows if r["panel"] == panel]
|
||
assert len(matches) == 1, f"expected exactly one {panel} row; got {len(matches)}"
|
||
return matches[0]
|
||
|
||
|
||
def _assert_sane_summary(row: dict) -> None:
|
||
s = row["op_log_summary"]
|
||
panel = row["panel"]
|
||
assert s["gemm_count"] > 0, f"{panel} must run GEMMs"
|
||
assert s["ipcq_send_count"] > 0, f"{panel} must send (ring/allreduce phase)"
|
||
assert s["ipcq_recv_count"] > 0, f"{panel} must recv"
|
||
assert s["dma_read_count"] >= 3, f"{panel}: Q + K + V loads"
|
||
assert s["dma_write_count"] >= 1, f"{panel}: final O store"
|
||
|
||
|
||
def test_single_user_prefill_row_has_sane_op_log_summary(monkeypatch):
    """The ``single_user_prefill`` row passes the op_log summary sanity checks."""
    rows = _sweep_json(monkeypatch)["rows"]
    _assert_sane_summary(_row(rows, "single_user_prefill"))
|
||
|
||
|
||
def test_multi_user_prefill_row_has_sane_op_log_summary(monkeypatch):
    """The ``multi_user_prefill`` row passes the op_log summary sanity checks."""
    rows = _sweep_json(monkeypatch)["rows"]
    _assert_sane_summary(_row(rows, "multi_user_prefill"))
|
||
|
||
|
||
def test_single_user_decode_row_has_sane_op_log_summary(monkeypatch):
    """The ``single_user_decode`` row passes the op_log summary sanity checks."""
    rows = _sweep_json(monkeypatch)["rows"]
    _assert_sane_summary(_row(rows, "single_user_decode"))
|
||
|
||
|
||
def test_multi_user_decode_row_has_sane_op_log_summary(monkeypatch):
    """The ``multi_user_decode`` row passes the op_log summary sanity checks."""
    rows = _sweep_json(monkeypatch)["rows"]
    _assert_sane_summary(_row(rows, "multi_user_decode"))
|
||
|
||
|
||
# ── Architectural invariant: decode = one-shot per rank ─────────────────
|
||
|
||
|
||
def test_single_user_decode_gemm_count_is_exactly_2_per_rank(monkeypatch):
    """Decode is one-shot per rank (ADR-0056 D3).

    Each rank performs a single local partial-attention pass, i.e. exactly
    two GEMMs (Q·K^T and S·V), so the row total must equal 2 × n_ranks.
    Prefill differs: every ring step has 2 GEMMs and the total scales as
    2 × n_ranks².
    """
    rows = _sweep_json(monkeypatch)["rows"]
    decode_row = _row(rows, "single_user_decode")

    expected = 2 * decode_row["n_ranks"]
    actual = decode_row["op_log_summary"]["gemm_count"]
    assert actual == expected, (
        f"single_user_decode gemm_count must be 2 × n_ranks = {expected}; "
        f"got {actual}"
    )
|
||
|
||
|
||
def test_multi_user_decode_gemm_count_is_exactly_2_per_rank(monkeypatch):
    """The one-shot decode invariant (gemm_count == 2 × n_ranks) also holds
    for ``multi_user_decode``: the kernel is the same as for
    single_user_decode; only the identity of the ranks differs (cubes
    rather than PEs)."""
    rows = _sweep_json(monkeypatch)["rows"]
    decode_row = _row(rows, "multi_user_decode")

    expected = 2 * decode_row["n_ranks"]
    actual = decode_row["op_log_summary"]["gemm_count"]
    assert actual == expected, (
        f"multi_user_decode gemm_count must be 2 × n_ranks = {expected}; "
        f"got {actual}"
    )
|