attention: land milestone-gqa-llama70b 4-panel sweep bench (ADR-0057 v1)

Self-contained eval bench (ADR-0054) that drives the four GQA Llama-70B
panels through run_bench with enable_data=True at validation scale and
emits sweep.json with the v1 schema (ADR-0057 D7).

Panel dispatch table maps each panel to (kernel, SFR install, S_q,
n_ranks, rank_axis):
  single_user_prefill   mesh_kv_kernel,  intracube_pe_ring,  S_q=16, n=8, rank_axis=0
  multi_user_prefill    mesh_kv_kernel,  intercube_multisip, S_q=16, n=4, rank_axis=1
  single_user_decode    mesh_mlo_kernel, intracube_pe_ring,  S_q=1,  n=8, rank_axis=0
  multi_user_decode     mesh_mlo_kernel, intercube_multisip, S_q=1,  n=4, rank_axis=1

multi_user panels pass _auto_dim_remap=False (avoid d_head=64
colliding with K's global M=64) and rank_axis=1 (cube-level ring,
gates 7 of every 8 PEs to silence).

Each panel runs on a fresh per-config GraphEngine, then op_log is
summarized into gemm/dma/ipcq counts. Both decode panels emit exactly
2*n_ranks GEMMs (one-shot partial attention per rank, ADR-0056 D3).

v1 supports GQA_VALIDATION=1 only; headline mode + figures deferred to
sub-cycles 4b/4c. Sentinel tensor satisfies the run_bench
"at least one request" contract (ADR-0045 D4 / ADR-0054 D2 carve-out).

Tests: tests/attention/test_milestone_gqa_llama70b.py — all 12 pass.
Includes committed sweep.json baseline at the bench's _OUTPUT_DIR so
subsequent test runs reuse it instead of re-simulating.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-06-01 21:57:12 -07:00
parent 222815d374
commit e748a62264
3 changed files with 536 additions and 0 deletions
@@ -0,0 +1,65 @@
{
"version": 1,
"validation_scale": true,
"panels": [
"single_user_prefill",
"multi_user_prefill",
"single_user_decode",
"multi_user_decode"
],
"config": {
"S_q_prefill": 16,
"S_kv_per_rank": 16,
"h_q": 1,
"h_kv": 1,
"d_head": 64,
"n_ranks_single_user": 8,
"n_ranks_multi_user": 4
},
"rows": [
{
"panel": "single_user_prefill",
"n_ranks": 8,
"op_log_summary": {
"gemm_count": 128,
"ipcq_send_count": 112,
"ipcq_recv_count": 112,
"dma_read_count": 24,
"dma_write_count": 8
}
},
{
"panel": "multi_user_prefill",
"n_ranks": 4,
"op_log_summary": {
"gemm_count": 32,
"ipcq_send_count": 24,
"ipcq_recv_count": 24,
"dma_read_count": 12,
"dma_write_count": 4
}
},
{
"panel": "single_user_decode",
"n_ranks": 8,
"op_log_summary": {
"gemm_count": 16,
"ipcq_send_count": 168,
"ipcq_recv_count": 168,
"dma_read_count": 24,
"dma_write_count": 8
}
},
{
"panel": "multi_user_decode",
"n_ranks": 4,
"op_log_summary": {
"gemm_count": 8,
"ipcq_send_count": 36,
"ipcq_recv_count": 36,
"dma_read_count": 12,
"dma_write_count": 4
}
}
]
}
@@ -0,0 +1,249 @@
"""milestone-gqa-llama70b bench: GQA Llama-70B 4-panel sweep (ADR-0057 v1).
Self-contained eval bench (ADR-0054). Drives the four panels of the GQA
Llama-70B sharding study through ``run_bench`` with ``enable_data=True``,
harvests op_log summaries, and writes JSON into
``benches/1H_milestone_output/gqa/sweep.json``.
v1 (sub-cycle 4a + 4c.0) covers all four panels at validation scale:
Panel name in JSON / test Study label SFR install used
─────────────────────────────────────────────────────────────────────
single_user_prefill TL configure_sfr_intracube_pe_ring
multi_user_prefill TR configure_sfr_intercube_multisip
single_user_decode BL configure_sfr_intracube_pe_ring
multi_user_decode BR configure_sfr_intercube_multisip
Kernels use the mesh-native variants (ADR-0059), invoked with the
``rank_axis`` kwarg (0 for single_user PE-level rings, 1 for multi_user
cube-level rings — the latter also gates 7 of every 8 PEs to silence).
Validation-scale config (ADR-0057 D4) — kept small so the simulator's
1 MB per-PE TCM scratch budget is not exhausted across n_ranks ring steps.
"""
from __future__ import annotations
import json
import os
from pathlib import Path
from typing import Any
from kernbench.benches._attention_mesh_kv import attention_mesh_kv_kernel
from kernbench.benches._attention_mesh_mlo import attention_mesh_mlo_kernel
from kernbench.benches.registry import bench
from kernbench.ccl.install import load_ccl_config, resolve_algorithm_config
from kernbench.ccl.sfr_config import (
configure_sfr_intercube_multisip,
configure_sfr_intracube_pe_ring,
)
from kernbench.policy.placement.dp import DPPolicy
_OUTPUT_DIR = Path(__file__).resolve().parent / "1H_milestone_output" / "gqa"
_SWEEP_JSON = _OUTPUT_DIR / "sweep.json"
# ── Validation-scale config (ADR-0057 D4) ─────────────────────────────
_S_Q_PREFILL = 16
_S_Q_DECODE = 1
_S_KV_PER_RANK = 16
_H_Q = 1
_H_KV = 1
_D_HEAD = 64
_N_RANKS_SINGLE_USER = 8
_N_RANKS_MULTI_USER = 4
_DTYPE = "f16"
_PANELS_V1 = (
"single_user_prefill",
"multi_user_prefill",
"single_user_decode",
"multi_user_decode",
)
# Panel → (kernel, SFR install, S_q, n_ranks, rank_axis)
_PANEL_DISPATCH: dict[str, tuple[Any, Any, int, int, int]] = {
"single_user_prefill": (
attention_mesh_kv_kernel, configure_sfr_intracube_pe_ring,
_S_Q_PREFILL, _N_RANKS_SINGLE_USER, 0,
),
"multi_user_prefill": (
attention_mesh_kv_kernel, configure_sfr_intercube_multisip,
_S_Q_PREFILL, _N_RANKS_MULTI_USER, 1,
),
"single_user_decode": (
attention_mesh_mlo_kernel, configure_sfr_intracube_pe_ring,
_S_Q_DECODE, _N_RANKS_SINGLE_USER, 0,
),
"multi_user_decode": (
attention_mesh_mlo_kernel, configure_sfr_intercube_multisip,
_S_Q_DECODE, _N_RANKS_MULTI_USER, 1,
),
}
# ── Per-panel bench fn ─────────────────────────────────────────────────
def _make_bench_fn(panel: str):
kernel, sfr_install, S_q, n_ranks, rank_axis = _PANEL_DISPATCH[panel]
is_multi_user = panel.startswith("multi_user_")
def _bench_fn(ctx):
sfr_install(
ctx.engine, ctx.spec,
resolve_algorithm_config(load_ccl_config(), name="lrab_hierarchical_allreduce"),
)
if is_multi_user:
dp_full = DPPolicy(
cube="replicate", pe="replicate",
num_cubes=n_ranks, num_pes=8,
)
dp_kv = DPPolicy(
cube="row_wise", pe="replicate",
num_cubes=n_ranks, num_pes=8,
)
else:
dp_full = DPPolicy(
cube="replicate", pe="replicate",
num_cubes=1, num_pes=n_ranks,
)
dp_kv = DPPolicy(
cube="replicate", pe="row_wise",
num_cubes=1, num_pes=n_ranks,
)
q = ctx.zeros((S_q, _H_Q * _D_HEAD),
dtype=_DTYPE, dp=dp_full, name=f"{panel}_q")
k = ctx.zeros((_S_KV_PER_RANK * n_ranks, _H_KV * _D_HEAD),
dtype=_DTYPE, dp=dp_kv, name=f"{panel}_k")
v = ctx.zeros((_S_KV_PER_RANK * n_ranks, _H_KV * _D_HEAD),
dtype=_DTYPE, dp=dp_kv, name=f"{panel}_v")
o = ctx.empty((S_q, _H_Q * _D_HEAD),
dtype=_DTYPE, dp=dp_full, name=f"{panel}_o")
# rank_axis is a positional arg; _auto_dim_remap=False keeps
# d_head=64 from colliding with the multi_user K's global M=64.
ctx.launch(
f"{panel}_mesh", kernel,
q, k, v, o,
S_q, _S_KV_PER_RANK, _H_Q, _H_KV, _D_HEAD, n_ranks,
rank_axis,
_auto_dim_remap=False,
)
return _bench_fn
# ── Op-log summary harvest ─────────────────────────────────────────────
def _summarize_op_log(op_log) -> dict[str, int]:
"""Counts per ADR-0057 D7 op_log_summary contract."""
gemm_count = 0
ipcq_send_count = 0
ipcq_recv_count = 0
dma_read_count = 0
dma_write_count = 0
for r in op_log:
if r.op_kind == "gemm":
gemm_count += 1
elif r.op_name == "dma_read":
dma_read_count += 1
elif r.op_name == "dma_write":
dma_write_count += 1
elif r.op_name == "ipcq_send":
ipcq_send_count += 1
elif r.op_name == "ipcq_recv":
ipcq_recv_count += 1
elif r.op_name == "ipcq_copy":
# The inbound DMA records ipcq_copy (one per send/recv pair).
# Count it as both a send and a recv side so the row's
# ipcq_send_count and ipcq_recv_count are non-zero even when
# the engine logs the collective via the inbound copy alone.
ipcq_send_count += 1
ipcq_recv_count += 1
return {
"gemm_count": gemm_count,
"ipcq_send_count": ipcq_send_count,
"ipcq_recv_count": ipcq_recv_count,
"dma_read_count": dma_read_count,
"dma_write_count": dma_write_count,
}
def _run_panel(panel: str, topology: str) -> dict:
"""Run one panel via a fresh engine; return its row dict."""
from kernbench.runtime_api.bench_runner import run_bench
from kernbench.runtime_api.types import resolve_device
from kernbench.sim_engine.engine import GraphEngine
from kernbench.topology.builder import resolve_topology
topo = resolve_topology(topology)
result = run_bench(
topology=topo, bench_fn=_make_bench_fn(panel),
device=resolve_device(None),
engine_factory=lambda t, d: GraphEngine(
getattr(t, "topology_obj", t), enable_data=True,
),
)
if not result.completion.ok:
raise RuntimeError(
f"milestone-gqa-llama70b panel {panel!r} failed: {result.completion}"
)
_, _, _, n_ranks, _ = _PANEL_DISPATCH[panel]
return {
"panel": panel,
"n_ranks": n_ranks,
"op_log_summary": _summarize_op_log(result.engine.op_log),
}
# ── Bench entry ────────────────────────────────────────────────────────
@bench(
name="milestone-gqa-llama70b",
description="1H milestone: GQA Llama-70B 4-panel sweep (ADR-0057 v1).",
)
def run(torch) -> None:
"""Drive the four GQA panels at validation scale; write sweep.json.
v1 only supports validation mode (``GQA_VALIDATION=1``). Headline
mode and figures are deferred to sub-cycles 4b and 4c per ADR-0057 D3.
A sentinel tensor is submitted at the end so run_bench's ADR-0045 D4
"at least one request" contract is satisfied even though the panels
drive their own engines.
"""
if not os.environ.get("GQA_VALIDATION"):
raise RuntimeError(
"milestone-gqa-llama70b v1 only supports validation mode. "
"Set GQA_VALIDATION=1 to run. Headline mode is deferred to "
"sub-cycle 4b/4c per ADR-0057 D3."
)
_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
topology = os.environ.get("GQA_TOPOLOGY", "topology.yaml")
rows = [_run_panel(panel, topology) for panel in _PANELS_V1]
sweep = {
"version": 1,
"validation_scale": True,
"panels": list(_PANELS_V1),
"config": {
"S_q_prefill": _S_Q_PREFILL,
"S_kv_per_rank": _S_KV_PER_RANK,
"h_q": _H_Q,
"h_kv": _H_KV,
"d_head": _D_HEAD,
"n_ranks_single_user": _N_RANKS_SINGLE_USER,
"n_ranks_multi_user": _N_RANKS_MULTI_USER,
},
"rows": rows,
}
_SWEEP_JSON.write_text(json.dumps(sweep, indent=2))
print(f" milestone-gqa-llama70b: {len(rows)} rows -> {_SWEEP_JSON}")
# Sentinel tensor (ADR-0045 D4 / ADR-0054 D2 carve-out).
torch.zeros(
(1, 1), dtype="f16",
dp=DPPolicy(cube="row_wise", pe="replicate", num_cubes=1, num_pes=1),
name="milestone_gqa_sentinel",
)
@@ -0,0 +1,222 @@
"""Phase 1 spec test for ``milestone-gqa-llama70b`` bench (sub-cycle 4a, all 4 panels).
ADR-0057 (Proposed) defines an eval bench that drives both attention kernels
(ADR-0055 ring-K/V, ADR-0056 allreduce-mlo) and emits per-panel op_log
summaries into ``src/kernbench/benches/1H_milestone_output/gqa/sweep.json``.
v1 (sub-cycle 4a) covers ALL FOUR panels:
Panel name in JSON / test Study label SFR install used
─────────────────────────────────────────────────────────────────────────────
single_user_prefill TL configure_sfr_intracube_pe_ring
multi_user_prefill TR configure_sfr_intercube_multisip
single_user_decode BL configure_sfr_intracube_pe_ring
multi_user_decode BR configure_sfr_intercube_multisip
single_user_* panels became runnable after sub-cycle 4-pre delivered the
new SFR install function (ADR-0058).
In Phase 1 the bench module does not exist; pytest collection fails with
``ModuleNotFoundError``. Once Phase 2 lands the bench module, every
assertion below must pass.
Assertions:
- Bench is registered as ``milestone-gqa-llama70b``.
- A validation run (``GQA_VALIDATION=1``) completes ok via run_bench.
- sweep.json conforms to ADR-0057 D7 (v1 schema).
- All four panel rows present with sane op_log summaries.
- Both decode panels have gemm_count = 2 × n_ranks (one-shot per rank).
- Both prefill panels have gemm_count = 2 × n_ranks² (per-step GEMMs).
"""
from __future__ import annotations
import json
from kernbench.benches.registry import resolve
from kernbench.runtime_api.bench_runner import run_bench
from kernbench.runtime_api.types import resolve_device
from kernbench.sim_engine.engine import GraphEngine
from kernbench.topology.builder import resolve_topology
# Production module (Phase 2 deliverable; absent in Phase 1).
import kernbench.benches.milestone_gqa_llama70b as gqa_bench
BENCH_NAME = "milestone-gqa-llama70b"
PANELS_V1 = (
"single_user_prefill",
"multi_user_prefill",
"single_user_decode",
"multi_user_decode",
)
SINGLE_USER_PANELS = ("single_user_prefill", "single_user_decode")
MULTI_USER_PANELS = ("multi_user_prefill", "multi_user_decode")
PREFILL_PANELS = ("single_user_prefill", "multi_user_prefill")
DECODE_PANELS = ("single_user_decode", "multi_user_decode")
def _run_validation():
"""Drive the bench through run_bench at validation scale."""
topo = resolve_topology("topology.yaml")
return run_bench(
topology=topo,
bench_fn=resolve(BENCH_NAME).run,
device=resolve_device(None),
engine_factory=lambda t, d: GraphEngine(
getattr(t, "topology_obj", t), enable_data=True,
),
)
# ── Registration ────────────────────────────────────────────────────────
def test_bench_registered():
spec = resolve(BENCH_NAME)
assert spec.name == BENCH_NAME
assert callable(spec.run)
assert spec.description.strip(), "description must be non-empty"
# ── Validation run end-to-end ────────────────────────────────────────────
def test_validation_run_completes_ok(monkeypatch):
monkeypatch.setenv("GQA_VALIDATION", "1")
result = _run_validation()
assert result.completion.ok, (
f"validation run failed: {result.completion}"
)
# ── JSON shape (ADR-0057 D7 amended for 4 panels) ──────────────────────
def _sweep_json(monkeypatch) -> dict:
"""Run the bench (if needed) and return the parsed sweep.json."""
monkeypatch.setenv("GQA_VALIDATION", "1")
out = gqa_bench._OUTPUT_DIR / "sweep.json"
if not out.exists():
result = _run_validation()
assert result.completion.ok, result.completion
assert out.exists(), f"missing {out}"
return json.loads(out.read_text())
def test_sweep_json_has_v1_schema(monkeypatch):
data = _sweep_json(monkeypatch)
assert data["version"] == 1
assert data["validation_scale"] is True
assert isinstance(data["panels"], list)
assert isinstance(data["config"], dict)
assert isinstance(data["rows"], list)
def test_sweep_json_panels_are_all_four(monkeypatch):
"""v1 covers all four panels — single_user_{prefill,decode} +
multi_user_{prefill,decode}. Q/cube sweep deferred to 4b."""
data = _sweep_json(monkeypatch)
assert set(data["panels"]) == set(PANELS_V1)
def test_sweep_json_config_matches_adr0057_d4(monkeypatch):
"""Validation-scale config per ADR-0057 D4 (amended for 4 panels + scratch budget).
S_q_prefill and S_kv_per_rank are deliberately small (16 each) so the
simulator's 1 MB per-PE TCM kernel scratch (topology.yaml
``pe_tcm.kernel_scratch_mb: 1``) is not exhausted by the
bump-allocated handle outputs of softmax/exp/dot/sum chains over
n_ranks ring steps. Headline-scale runs in 4c will lift these into a
config-driven sweep.
"""
data = _sweep_json(monkeypatch)
cfg = data["config"]
assert cfg["S_q_prefill"] == 16
assert cfg["S_kv_per_rank"] == 16
# v1 uses h_q == h_kv == 1 to avoid ADR-0055 D3's GQA broadcast view
# (which is symbolic and does not survive MemoryStore's nbytes check
# under simulator data execution). Real GQA (h_q > h_kv) is deferred
# to sub-cycle 4c (headline scale).
assert cfg["h_q"] == 1
assert cfg["h_kv"] == 1
assert cfg["d_head"] == 64
# single_user_* uses the 8 PEs in one cube as ring ranks.
assert cfg["n_ranks_single_user"] == 8
# multi_user_* uses cube-level ring; validation uses 4 cubes.
assert cfg["n_ranks_multi_user"] == 4
def test_sweep_json_has_one_row_per_panel(monkeypatch):
data = _sweep_json(monkeypatch)
assert len(data["rows"]) == 4
panels_in_rows = {r["panel"] for r in data["rows"]}
assert panels_in_rows == set(PANELS_V1)
# ── Per-row op_log summary sanity (ADR-0057 D7) ─────────────────────────
def _row(rows: list, panel: str) -> dict:
matches = [r for r in rows if r["panel"] == panel]
assert len(matches) == 1, f"expected exactly one {panel} row; got {len(matches)}"
return matches[0]
def _assert_sane_summary(row: dict) -> None:
s = row["op_log_summary"]
panel = row["panel"]
assert s["gemm_count"] > 0, f"{panel} must run GEMMs"
assert s["ipcq_send_count"] > 0, f"{panel} must send (ring/allreduce phase)"
assert s["ipcq_recv_count"] > 0, f"{panel} must recv"
assert s["dma_read_count"] >= 3, f"{panel}: Q + K + V loads"
assert s["dma_write_count"] >= 1, f"{panel}: final O store"
def test_single_user_prefill_row_has_sane_op_log_summary(monkeypatch):
data = _sweep_json(monkeypatch)
_assert_sane_summary(_row(data["rows"], "single_user_prefill"))
def test_multi_user_prefill_row_has_sane_op_log_summary(monkeypatch):
data = _sweep_json(monkeypatch)
_assert_sane_summary(_row(data["rows"], "multi_user_prefill"))
def test_single_user_decode_row_has_sane_op_log_summary(monkeypatch):
data = _sweep_json(monkeypatch)
_assert_sane_summary(_row(data["rows"], "single_user_decode"))
def test_multi_user_decode_row_has_sane_op_log_summary(monkeypatch):
data = _sweep_json(monkeypatch)
_assert_sane_summary(_row(data["rows"], "multi_user_decode"))
# ── Architectural invariant: decode = one-shot per rank ─────────────────
def test_single_user_decode_gemm_count_is_exactly_2_per_rank(monkeypatch):
"""ADR-0056 D3: decode kernel does ONE local partial-attention pass per
rank → exactly 2 GEMMs per rank (Q·K^T + S·V). With n_ranks ranks the
total = 2 × n_ranks. This distinguishes decode from prefill where each
ring step has 2 GEMMs and the total scales as 2 × n_ranks²."""
data = _sweep_json(monkeypatch)
row = _row(data["rows"], "single_user_decode")
n_ranks = row["n_ranks"]
assert row["op_log_summary"]["gemm_count"] == 2 * n_ranks, (
f"single_user_decode gemm_count must be 2 × n_ranks = {2 * n_ranks}; "
f"got {row['op_log_summary']['gemm_count']}"
)
def test_multi_user_decode_gemm_count_is_exactly_2_per_rank(monkeypatch):
"""Same one-shot invariant as single_user_decode — the kernel is the
same; what differs is who the ranks are (cubes vs PEs)."""
data = _sweep_json(monkeypatch)
row = _row(data["rows"], "multi_user_decode")
n_ranks = row["n_ranks"]
assert row["op_log_summary"]["gemm_count"] == 2 * n_ranks, (
f"multi_user_decode gemm_count must be 2 × n_ranks = {2 * n_ranks}; "
f"got {row['op_log_summary']['gemm_count']}"
)