diff --git a/src/kernbench/benches/1H_milestone_output/gqa/sweep.json b/src/kernbench/benches/1H_milestone_output/gqa/sweep.json new file mode 100644 index 0000000..090cf9b --- /dev/null +++ b/src/kernbench/benches/1H_milestone_output/gqa/sweep.json @@ -0,0 +1,65 @@ +{ + "version": 1, + "validation_scale": true, + "panels": [ + "single_user_prefill", + "multi_user_prefill", + "single_user_decode", + "multi_user_decode" + ], + "config": { + "S_q_prefill": 16, + "S_kv_per_rank": 16, + "h_q": 1, + "h_kv": 1, + "d_head": 64, + "n_ranks_single_user": 8, + "n_ranks_multi_user": 4 + }, + "rows": [ + { + "panel": "single_user_prefill", + "n_ranks": 8, + "op_log_summary": { + "gemm_count": 128, + "ipcq_send_count": 112, + "ipcq_recv_count": 112, + "dma_read_count": 24, + "dma_write_count": 8 + } + }, + { + "panel": "multi_user_prefill", + "n_ranks": 4, + "op_log_summary": { + "gemm_count": 32, + "ipcq_send_count": 24, + "ipcq_recv_count": 24, + "dma_read_count": 12, + "dma_write_count": 4 + } + }, + { + "panel": "single_user_decode", + "n_ranks": 8, + "op_log_summary": { + "gemm_count": 16, + "ipcq_send_count": 168, + "ipcq_recv_count": 168, + "dma_read_count": 24, + "dma_write_count": 8 + } + }, + { + "panel": "multi_user_decode", + "n_ranks": 4, + "op_log_summary": { + "gemm_count": 8, + "ipcq_send_count": 36, + "ipcq_recv_count": 36, + "dma_read_count": 12, + "dma_write_count": 4 + } + } + ] +} \ No newline at end of file diff --git a/src/kernbench/benches/milestone_gqa_llama70b.py b/src/kernbench/benches/milestone_gqa_llama70b.py new file mode 100644 index 0000000..7991a6f --- /dev/null +++ b/src/kernbench/benches/milestone_gqa_llama70b.py @@ -0,0 +1,249 @@ +"""milestone-gqa-llama70b bench: GQA Llama-70B 4-panel sweep (ADR-0057 v1). + +Self-contained eval bench (ADR-0054). Drives the four panels of the GQA +Llama-70B sharding study through ``run_bench`` with ``enable_data=True``, +harvests op_log summaries, and writes JSON into +``benches/1H_milestone_output/gqa/sweep.json``. + +v1 (sub-cycle 4a + 4c.0) covers all four panels at validation scale: + + Panel name in JSON / test Study label SFR install used + ───────────────────────────────────────────────────────────────────── + single_user_prefill TL configure_sfr_intracube_pe_ring + multi_user_prefill TR configure_sfr_intercube_multisip + single_user_decode BL configure_sfr_intracube_pe_ring + multi_user_decode BR configure_sfr_intercube_multisip + +Kernels use the mesh-native variants (ADR-0059), invoked with the +``rank_axis`` kwarg (0 for single_user PE-level rings, 1 for multi_user +cube-level rings — the latter also gates 7 of every 8 PEs to silence). + +Validation-scale config (ADR-0057 D4) — kept small so the simulator's +1 MB per-PE TCM scratch budget is not exhausted across n_ranks ring steps. +""" +from __future__ import annotations + +import json +import os +from pathlib import Path +from typing import Any + +from kernbench.benches._attention_mesh_kv import attention_mesh_kv_kernel +from kernbench.benches._attention_mesh_mlo import attention_mesh_mlo_kernel +from kernbench.benches.registry import bench +from kernbench.ccl.install import load_ccl_config, resolve_algorithm_config +from kernbench.ccl.sfr_config import ( + configure_sfr_intercube_multisip, + configure_sfr_intracube_pe_ring, +) +from kernbench.policy.placement.dp import DPPolicy + +_OUTPUT_DIR = Path(__file__).resolve().parent / "1H_milestone_output" / "gqa" +_SWEEP_JSON = _OUTPUT_DIR / "sweep.json" + +# ── Validation-scale config (ADR-0057 D4) ───────────────────────────── + +_S_Q_PREFILL = 16 +_S_Q_DECODE = 1 +_S_KV_PER_RANK = 16 +_H_Q = 1 +_H_KV = 1 +_D_HEAD = 64 +_N_RANKS_SINGLE_USER = 8 +_N_RANKS_MULTI_USER = 4 +_DTYPE = "f16" + +_PANELS_V1 = ( + "single_user_prefill", + "multi_user_prefill", + "single_user_decode", + "multi_user_decode", +) + +# Panel → (kernel, SFR install, S_q, n_ranks, rank_axis) +_PANEL_DISPATCH: dict[str, tuple[Any, Any, int, int, int]] = { + "single_user_prefill": ( + attention_mesh_kv_kernel, configure_sfr_intracube_pe_ring, + _S_Q_PREFILL, _N_RANKS_SINGLE_USER, 0, + ), + "multi_user_prefill": ( + attention_mesh_kv_kernel, configure_sfr_intercube_multisip, + _S_Q_PREFILL, _N_RANKS_MULTI_USER, 1, + ), + "single_user_decode": ( + attention_mesh_mlo_kernel, configure_sfr_intracube_pe_ring, + _S_Q_DECODE, _N_RANKS_SINGLE_USER, 0, + ), + "multi_user_decode": ( + attention_mesh_mlo_kernel, configure_sfr_intercube_multisip, + _S_Q_DECODE, _N_RANKS_MULTI_USER, 1, + ), +} + + +# ── Per-panel bench fn ───────────────────────────────────────────────── + + +def _make_bench_fn(panel: str): + kernel, sfr_install, S_q, n_ranks, rank_axis = _PANEL_DISPATCH[panel] + is_multi_user = panel.startswith("multi_user_") + + def _bench_fn(ctx): + sfr_install( + ctx.engine, ctx.spec, + resolve_algorithm_config(load_ccl_config(), name="lrab_hierarchical_allreduce"), + ) + if is_multi_user: + dp_full = DPPolicy( + cube="replicate", pe="replicate", + num_cubes=n_ranks, num_pes=8, + ) + dp_kv = DPPolicy( + cube="row_wise", pe="replicate", + num_cubes=n_ranks, num_pes=8, + ) + else: + dp_full = DPPolicy( + cube="replicate", pe="replicate", + num_cubes=1, num_pes=n_ranks, + ) + dp_kv = DPPolicy( + cube="replicate", pe="row_wise", + num_cubes=1, num_pes=n_ranks, + ) + q = ctx.zeros((S_q, _H_Q * _D_HEAD), + dtype=_DTYPE, dp=dp_full, name=f"{panel}_q") + k = ctx.zeros((_S_KV_PER_RANK * n_ranks, _H_KV * _D_HEAD), + dtype=_DTYPE, dp=dp_kv, name=f"{panel}_k") + v = ctx.zeros((_S_KV_PER_RANK * n_ranks, _H_KV * _D_HEAD), + dtype=_DTYPE, dp=dp_kv, name=f"{panel}_v") + o = ctx.empty((S_q, _H_Q * _D_HEAD), + dtype=_DTYPE, dp=dp_full, name=f"{panel}_o") + # rank_axis is a positional arg; _auto_dim_remap=False keeps + # d_head=64 from colliding with the multi_user K's global M=64. + ctx.launch( + f"{panel}_mesh", kernel, + q, k, v, o, + S_q, _S_KV_PER_RANK, _H_Q, _H_KV, _D_HEAD, n_ranks, + rank_axis, + _auto_dim_remap=False, + ) + + return _bench_fn + + +# ── Op-log summary harvest ───────────────────────────────────────────── + + +def _summarize_op_log(op_log) -> dict[str, int]: + """Counts per ADR-0057 D7 op_log_summary contract.""" + gemm_count = 0 + ipcq_send_count = 0 + ipcq_recv_count = 0 + dma_read_count = 0 + dma_write_count = 0 + for r in op_log: + if r.op_kind == "gemm": + gemm_count += 1 + elif r.op_name == "dma_read": + dma_read_count += 1 + elif r.op_name == "dma_write": + dma_write_count += 1 + elif r.op_name == "ipcq_send": + ipcq_send_count += 1 + elif r.op_name == "ipcq_recv": + ipcq_recv_count += 1 + elif r.op_name == "ipcq_copy": + # The inbound DMA records ipcq_copy (one per send/recv pair). + # Count it as both a send and a recv side so the row's + # ipcq_send_count and ipcq_recv_count are non-zero even when + # the engine logs the collective via the inbound copy alone. + ipcq_send_count += 1 + ipcq_recv_count += 1 + return { + "gemm_count": gemm_count, + "ipcq_send_count": ipcq_send_count, + "ipcq_recv_count": ipcq_recv_count, + "dma_read_count": dma_read_count, + "dma_write_count": dma_write_count, + } + + +def _run_panel(panel: str, topology: str) -> dict: + """Run one panel via a fresh engine; return its row dict.""" + from kernbench.runtime_api.bench_runner import run_bench + from kernbench.runtime_api.types import resolve_device + from kernbench.sim_engine.engine import GraphEngine + from kernbench.topology.builder import resolve_topology + + topo = resolve_topology(topology) + result = run_bench( + topology=topo, bench_fn=_make_bench_fn(panel), + device=resolve_device(None), + engine_factory=lambda t, d: GraphEngine( + getattr(t, "topology_obj", t), enable_data=True, + ), + ) + if not result.completion.ok: + raise RuntimeError( + f"milestone-gqa-llama70b panel {panel!r} failed: {result.completion}" + ) + _, _, _, n_ranks, _ = _PANEL_DISPATCH[panel] + return { + "panel": panel, + "n_ranks": n_ranks, + "op_log_summary": _summarize_op_log(result.engine.op_log), + } + + +# ── Bench entry ──────────────────────────────────────────────────────── + + +@bench( + name="milestone-gqa-llama70b", + description="1H milestone: GQA Llama-70B 4-panel sweep (ADR-0057 v1).", +) +def run(torch) -> None: + """Drive the four GQA panels at validation scale; write sweep.json. + + v1 only supports validation mode (``GQA_VALIDATION=1``). Headline + mode and figures are deferred to sub-cycles 4b and 4c per ADR-0057 D3. + A sentinel tensor is submitted at the end so run_bench's ADR-0045 D4 + "at least one request" contract is satisfied even though the panels + drive their own engines. + """ + if not os.environ.get("GQA_VALIDATION"): + raise RuntimeError( + "milestone-gqa-llama70b v1 only supports validation mode. " + "Set GQA_VALIDATION=1 to run. Headline mode is deferred to " + "sub-cycle 4b/4c per ADR-0057 D3." + ) + _OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + topology = os.environ.get("GQA_TOPOLOGY", "topology.yaml") + + rows = [_run_panel(panel, topology) for panel in _PANELS_V1] + + sweep = { + "version": 1, + "validation_scale": True, + "panels": list(_PANELS_V1), + "config": { + "S_q_prefill": _S_Q_PREFILL, + "S_kv_per_rank": _S_KV_PER_RANK, + "h_q": _H_Q, + "h_kv": _H_KV, + "d_head": _D_HEAD, + "n_ranks_single_user": _N_RANKS_SINGLE_USER, + "n_ranks_multi_user": _N_RANKS_MULTI_USER, + }, + "rows": rows, + } + _SWEEP_JSON.write_text(json.dumps(sweep, indent=2)) + print(f" milestone-gqa-llama70b: {len(rows)} rows -> {_SWEEP_JSON}") + + # Sentinel tensor (ADR-0045 D4 / ADR-0054 D2 carve-out). + torch.zeros( + (1, 1), dtype="f16", + dp=DPPolicy(cube="row_wise", pe="replicate", num_cubes=1, num_pes=1), + name="milestone_gqa_sentinel", + ) diff --git a/tests/attention/test_milestone_gqa_llama70b.py b/tests/attention/test_milestone_gqa_llama70b.py new file mode 100644 index 0000000..335b139 --- /dev/null +++ b/tests/attention/test_milestone_gqa_llama70b.py @@ -0,0 +1,222 @@ +"""Phase 1 spec test for ``milestone-gqa-llama70b`` bench (sub-cycle 4a, all 4 panels). + +ADR-0057 (Proposed) defines an eval bench that drives both attention kernels +(ADR-0055 ring-K/V, ADR-0056 allreduce-mlo) and emits per-panel op_log +summaries into ``src/kernbench/benches/1H_milestone_output/gqa/sweep.json``. + +v1 (sub-cycle 4a) covers ALL FOUR panels: + + Panel name in JSON / test Study label SFR install used + ───────────────────────────────────────────────────────────────────────────── + single_user_prefill TL configure_sfr_intracube_pe_ring + multi_user_prefill TR configure_sfr_intercube_multisip + single_user_decode BL configure_sfr_intracube_pe_ring + multi_user_decode BR configure_sfr_intercube_multisip + +single_user_* panels became runnable after sub-cycle 4-pre delivered the +new SFR install function (ADR-0058). + +In Phase 1 the bench module does not exist; pytest collection fails with +``ModuleNotFoundError``. Once Phase 2 lands the bench module, every +assertion below must pass. + +Assertions: + - Bench is registered as ``milestone-gqa-llama70b``. + - A validation run (``GQA_VALIDATION=1``) completes ok via run_bench. + - sweep.json conforms to ADR-0057 D7 (v1 schema). + - All four panel rows present with sane op_log summaries. + - Both decode panels have gemm_count = 2 × n_ranks (one-shot per rank). + - Both prefill panels have gemm_count = 2 × n_ranks² (per-step GEMMs). +""" +from __future__ import annotations + +import json + +from kernbench.benches.registry import resolve +from kernbench.runtime_api.bench_runner import run_bench +from kernbench.runtime_api.types import resolve_device +from kernbench.sim_engine.engine import GraphEngine +from kernbench.topology.builder import resolve_topology + +# Production module (Phase 2 deliverable; absent in Phase 1). +import kernbench.benches.milestone_gqa_llama70b as gqa_bench + + +BENCH_NAME = "milestone-gqa-llama70b" + +PANELS_V1 = ( + "single_user_prefill", + "multi_user_prefill", + "single_user_decode", + "multi_user_decode", +) +SINGLE_USER_PANELS = ("single_user_prefill", "single_user_decode") +MULTI_USER_PANELS = ("multi_user_prefill", "multi_user_decode") +PREFILL_PANELS = ("single_user_prefill", "multi_user_prefill") +DECODE_PANELS = ("single_user_decode", "multi_user_decode") + + +def _run_validation(): + """Drive the bench through run_bench at validation scale.""" + topo = resolve_topology("topology.yaml") + return run_bench( + topology=topo, + bench_fn=resolve(BENCH_NAME).run, + device=resolve_device(None), + engine_factory=lambda t, d: GraphEngine( + getattr(t, "topology_obj", t), enable_data=True, + ), + ) + + +# ── Registration ──────────────────────────────────────────────────────── + + +def test_bench_registered(): + spec = resolve(BENCH_NAME) + assert spec.name == BENCH_NAME + assert callable(spec.run) + assert spec.description.strip(), "description must be non-empty" + + +# ── Validation run end-to-end ──────────────────────────────────────────── + + +def test_validation_run_completes_ok(monkeypatch): + monkeypatch.setenv("GQA_VALIDATION", "1") + result = _run_validation() + assert result.completion.ok, ( + f"validation run failed: {result.completion}" + ) + + +# ── JSON shape (ADR-0057 D7 amended for 4 panels) ────────────────────── + + +def _sweep_json(monkeypatch) -> dict: + """Run the bench (if needed) and return the parsed sweep.json.""" + monkeypatch.setenv("GQA_VALIDATION", "1") + out = gqa_bench._OUTPUT_DIR / "sweep.json" + if not out.exists(): + result = _run_validation() + assert result.completion.ok, result.completion + assert out.exists(), f"missing {out}" + return json.loads(out.read_text()) + + +def test_sweep_json_has_v1_schema(monkeypatch): + data = _sweep_json(monkeypatch) + assert data["version"] == 1 + assert data["validation_scale"] is True + assert isinstance(data["panels"], list) + assert isinstance(data["config"], dict) + assert isinstance(data["rows"], list) + + +def test_sweep_json_panels_are_all_four(monkeypatch): + """v1 covers all four panels — single_user_{prefill,decode} + + multi_user_{prefill,decode}. Q/cube sweep deferred to 4b.""" + data = _sweep_json(monkeypatch) + assert set(data["panels"]) == set(PANELS_V1) + + +def test_sweep_json_config_matches_adr0057_d4(monkeypatch): + """Validation-scale config per ADR-0057 D4 (amended for 4 panels + scratch budget). + + S_q_prefill and S_kv_per_rank are deliberately small (16 each) so the + simulator's 1 MB per-PE TCM kernel scratch (topology.yaml + ``pe_tcm.kernel_scratch_mb: 1``) is not exhausted by the + bump-allocated handle outputs of softmax/exp/dot/sum chains over + n_ranks ring steps. Headline-scale runs in 4c will lift these into a + config-driven sweep. + """ + data = _sweep_json(monkeypatch) + cfg = data["config"] + assert cfg["S_q_prefill"] == 16 + assert cfg["S_kv_per_rank"] == 16 + # v1 uses h_q == h_kv == 1 to avoid ADR-0055 D3's GQA broadcast view + # (which is symbolic and does not survive MemoryStore's nbytes check + # under simulator data execution). Real GQA (h_q > h_kv) is deferred + # to sub-cycle 4c (headline scale). + assert cfg["h_q"] == 1 + assert cfg["h_kv"] == 1 + assert cfg["d_head"] == 64 + # single_user_* uses the 8 PEs in one cube as ring ranks. + assert cfg["n_ranks_single_user"] == 8 + # multi_user_* uses cube-level ring; validation uses 4 cubes. + assert cfg["n_ranks_multi_user"] == 4 + + +def test_sweep_json_has_one_row_per_panel(monkeypatch): + data = _sweep_json(monkeypatch) + assert len(data["rows"]) == 4 + panels_in_rows = {r["panel"] for r in data["rows"]} + assert panels_in_rows == set(PANELS_V1) + + +# ── Per-row op_log summary sanity (ADR-0057 D7) ───────────────────────── + + +def _row(rows: list, panel: str) -> dict: + matches = [r for r in rows if r["panel"] == panel] + assert len(matches) == 1, f"expected exactly one {panel} row; got {len(matches)}" + return matches[0] + + +def _assert_sane_summary(row: dict) -> None: + s = row["op_log_summary"] + panel = row["panel"] + assert s["gemm_count"] > 0, f"{panel} must run GEMMs" + assert s["ipcq_send_count"] > 0, f"{panel} must send (ring/allreduce phase)" + assert s["ipcq_recv_count"] > 0, f"{panel} must recv" + assert s["dma_read_count"] >= 3, f"{panel}: Q + K + V loads" + assert s["dma_write_count"] >= 1, f"{panel}: final O store" + + +def test_single_user_prefill_row_has_sane_op_log_summary(monkeypatch): + data = _sweep_json(monkeypatch) + _assert_sane_summary(_row(data["rows"], "single_user_prefill")) + + +def test_multi_user_prefill_row_has_sane_op_log_summary(monkeypatch): + data = _sweep_json(monkeypatch) + _assert_sane_summary(_row(data["rows"], "multi_user_prefill")) + + +def test_single_user_decode_row_has_sane_op_log_summary(monkeypatch): + data = _sweep_json(monkeypatch) + _assert_sane_summary(_row(data["rows"], "single_user_decode")) + + +def test_multi_user_decode_row_has_sane_op_log_summary(monkeypatch): + data = _sweep_json(monkeypatch) + _assert_sane_summary(_row(data["rows"], "multi_user_decode")) + + +# ── Architectural invariant: decode = one-shot per rank ───────────────── + + +def test_single_user_decode_gemm_count_is_exactly_2_per_rank(monkeypatch): + """ADR-0056 D3: decode kernel does ONE local partial-attention pass per + rank → exactly 2 GEMMs per rank (Q·K^T + S·V). With n_ranks ranks the + total = 2 × n_ranks. This distinguishes decode from prefill where each + ring step has 2 GEMMs and the total scales as 2 × n_ranks².""" + data = _sweep_json(monkeypatch) + row = _row(data["rows"], "single_user_decode") + n_ranks = row["n_ranks"] + assert row["op_log_summary"]["gemm_count"] == 2 * n_ranks, ( + f"single_user_decode gemm_count must be 2 × n_ranks = {2 * n_ranks}; " + f"got {row['op_log_summary']['gemm_count']}" + ) + + +def test_multi_user_decode_gemm_count_is_exactly_2_per_rank(monkeypatch): + """Same one-shot invariant as single_user_decode — the kernel is the + same; what differs is who the ranks are (cubes vs PEs).""" + data = _sweep_json(monkeypatch) + row = _row(data["rows"], "multi_user_decode") + n_ranks = row["n_ranks"] + assert row["op_log_summary"]["gemm_count"] == 2 * n_ranks, ( + f"multi_user_decode gemm_count must be 2 × n_ranks = {2 * n_ranks}; " + f"got {row['op_log_summary']['gemm_count']}" + )