attention: land milestone-gqa-llama70b 4-panel sweep bench (ADR-0057 v1)
Self-contained eval bench (ADR-0054) that drives the four GQA Llama-70B panels through run_bench with enable_data=True at validation scale and emits sweep.json with the v1 schema (ADR-0057 D7).

Panel dispatch table maps each panel to (kernel, SFR install, S_q, n_ranks, rank_axis):

  single_user_prefill  mesh_kv_kernel,  intracube_pe_ring,  S_q=16, n=8, rank_axis=0
  multi_user_prefill   mesh_kv_kernel,  intercube_multisip, S_q=16, n=4, rank_axis=1
  single_user_decode   mesh_mlo_kernel, intracube_pe_ring,  S_q=1,  n=8, rank_axis=0
  multi_user_decode    mesh_mlo_kernel, intercube_multisip, S_q=1,  n=4, rank_axis=1

multi_user panels pass _auto_dim_remap=False (avoid d_head=64 colliding with K's global M=64) and rank_axis=1 (cube-level ring, gates 7 of every 8 PEs to silence). Each panel runs on a fresh per-config GraphEngine, then op_log is summarized into gemm/dma/ipcq counts. Both decode panels emit exactly 2*n_ranks GEMMs (one-shot partial attention per rank, ADR-0056 D3).

v1 supports GQA_VALIDATION=1 only; headline mode + figures deferred to sub-cycles 4b/4c. Sentinel tensor satisfies the run_bench "at least one request" contract (ADR-0045 D4 / ADR-0054 D2 carve-out).

Tests: tests/attention/test_milestone_gqa_llama70b.py — all 12 pass.

Includes committed sweep.json baseline at the bench's _OUTPUT_DIR so subsequent test runs reuse it instead of re-simulating.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,65 @@
|
||||
{
|
||||
"version": 1,
|
||||
"validation_scale": true,
|
||||
"panels": [
|
||||
"single_user_prefill",
|
||||
"multi_user_prefill",
|
||||
"single_user_decode",
|
||||
"multi_user_decode"
|
||||
],
|
||||
"config": {
|
||||
"S_q_prefill": 16,
|
||||
"S_kv_per_rank": 16,
|
||||
"h_q": 1,
|
||||
"h_kv": 1,
|
||||
"d_head": 64,
|
||||
"n_ranks_single_user": 8,
|
||||
"n_ranks_multi_user": 4
|
||||
},
|
||||
"rows": [
|
||||
{
|
||||
"panel": "single_user_prefill",
|
||||
"n_ranks": 8,
|
||||
"op_log_summary": {
|
||||
"gemm_count": 128,
|
||||
"ipcq_send_count": 112,
|
||||
"ipcq_recv_count": 112,
|
||||
"dma_read_count": 24,
|
||||
"dma_write_count": 8
|
||||
}
|
||||
},
|
||||
{
|
||||
"panel": "multi_user_prefill",
|
||||
"n_ranks": 4,
|
||||
"op_log_summary": {
|
||||
"gemm_count": 32,
|
||||
"ipcq_send_count": 24,
|
||||
"ipcq_recv_count": 24,
|
||||
"dma_read_count": 12,
|
||||
"dma_write_count": 4
|
||||
}
|
||||
},
|
||||
{
|
||||
"panel": "single_user_decode",
|
||||
"n_ranks": 8,
|
||||
"op_log_summary": {
|
||||
"gemm_count": 16,
|
||||
"ipcq_send_count": 168,
|
||||
"ipcq_recv_count": 168,
|
||||
"dma_read_count": 24,
|
||||
"dma_write_count": 8
|
||||
}
|
||||
},
|
||||
{
|
||||
"panel": "multi_user_decode",
|
||||
"n_ranks": 4,
|
||||
"op_log_summary": {
|
||||
"gemm_count": 8,
|
||||
"ipcq_send_count": 36,
|
||||
"ipcq_recv_count": 36,
|
||||
"dma_read_count": 12,
|
||||
"dma_write_count": 4
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -0,0 +1,249 @@
|
||||
"""milestone-gqa-llama70b bench: GQA Llama-70B 4-panel sweep (ADR-0057 v1).
|
||||
|
||||
Self-contained eval bench (ADR-0054). Drives the four panels of the GQA
|
||||
Llama-70B sharding study through ``run_bench`` with ``enable_data=True``,
|
||||
harvests op_log summaries, and writes JSON into
|
||||
``benches/1H_milestone_output/gqa/sweep.json``.
|
||||
|
||||
v1 (sub-cycle 4a + 4c.0) covers all four panels at validation scale:
|
||||
|
||||
Panel name in JSON / test   Study label   SFR install used
─────────────────────────────────────────────────────────────────────
single_user_prefill         TL            configure_sfr_intracube_pe_ring
multi_user_prefill          TR            configure_sfr_intercube_multisip
single_user_decode          BL            configure_sfr_intracube_pe_ring
multi_user_decode           BR            configure_sfr_intercube_multisip
|
||||
|
||||
Kernels use the mesh-native variants (ADR-0059), invoked with the
|
||||
``rank_axis`` argument (0 for single_user PE-level rings, 1 for multi_user
|
||||
cube-level rings — the latter also gates 7 of every 8 PEs to silence).
|
||||
|
||||
Validation-scale config (ADR-0057 D4) — kept small so the simulator's
|
||||
1 MB per-PE TCM scratch budget is not exhausted across n_ranks ring steps.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from kernbench.benches._attention_mesh_kv import attention_mesh_kv_kernel
|
||||
from kernbench.benches._attention_mesh_mlo import attention_mesh_mlo_kernel
|
||||
from kernbench.benches.registry import bench
|
||||
from kernbench.ccl.install import load_ccl_config, resolve_algorithm_config
|
||||
from kernbench.ccl.sfr_config import (
|
||||
configure_sfr_intercube_multisip,
|
||||
configure_sfr_intracube_pe_ring,
|
||||
)
|
||||
from kernbench.policy.placement.dp import DPPolicy
|
||||
|
||||
# Sweep artefacts live next to this module, under 1H_milestone_output/gqa/.
_OUTPUT_DIR = Path(__file__).resolve().parent.joinpath("1H_milestone_output", "gqa")
_SWEEP_JSON = _OUTPUT_DIR.joinpath("sweep.json")

# ── Validation-scale config (ADR-0057 D4) ─────────────────────────────

# S_q for the *_prefill and *_decode panels respectively.
_S_Q_PREFILL = 16
_S_Q_DECODE = 1
# K and V are allocated with _S_KV_PER_RANK * n_ranks rows.
_S_KV_PER_RANK = 16
# Head counts and per-head width; tensor columns are heads * _D_HEAD.
_H_Q = 1
_H_KV = 1
_D_HEAD = 64
# Ring sizes for the single_user_* and multi_user_* panels.
_N_RANKS_SINGLE_USER = 8
_N_RANKS_MULTI_USER = 4
# dtype string handed to every tensor allocation of the panels.
_DTYPE = "f16"

# Panels covered by v1, in the order their rows appear in sweep.json.
_PANELS_V1 = (
    "single_user_prefill",
    "multi_user_prefill",
    "single_user_decode",
    "multi_user_decode",
)
|
||||
|
||||
# Dispatch table: panel name → (kernel, SFR install, S_q, n_ranks, rank_axis).
# Prefill panels use the mesh-KV kernel, decode panels the mesh-MLO kernel;
# single_user panels use the intracube PE-ring install with rank_axis 0,
# multi_user panels the intercube multi-SIP install with rank_axis 1.
_PANEL_DISPATCH: dict[str, tuple[Any, Any, int, int, int]] = {
    "single_user_prefill": (
        attention_mesh_kv_kernel,
        configure_sfr_intracube_pe_ring,
        _S_Q_PREFILL,
        _N_RANKS_SINGLE_USER,
        0,
    ),
    "multi_user_prefill": (
        attention_mesh_kv_kernel,
        configure_sfr_intercube_multisip,
        _S_Q_PREFILL,
        _N_RANKS_MULTI_USER,
        1,
    ),
    "single_user_decode": (
        attention_mesh_mlo_kernel,
        configure_sfr_intracube_pe_ring,
        _S_Q_DECODE,
        _N_RANKS_SINGLE_USER,
        0,
    ),
    "multi_user_decode": (
        attention_mesh_mlo_kernel,
        configure_sfr_intercube_multisip,
        _S_Q_DECODE,
        _N_RANKS_MULTI_USER,
        1,
    ),
}
|
||||
|
||||
|
||||
# ── Per-panel bench fn ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _make_bench_fn(panel: str):
    """Build the ``bench_fn`` callback for a single panel.

    ``panel`` is looked up in ``_PANEL_DISPATCH``; an unknown name
    propagates the resulting ``KeyError``. The returned closure takes the
    bench context ``ctx`` and, when called, installs the panel's SFR
    configuration, allocates the Q/K/V/O tensors and launches the panel's
    kernel.
    """
    kernel, sfr_install, S_q, n_ranks, rank_axis = _PANEL_DISPATCH[panel]
    is_multi_user = panel.startswith("multi_user_")

    def _bench_fn(ctx):
        engine, spec = ctx.engine, ctx.spec
        algo_cfg = resolve_algorithm_config(
            load_ccl_config(), name="lrab_hierarchical_allreduce",
        )
        sfr_install(engine, spec, algo_cfg)

        # multi_user: the ring spans n_ranks cubes (8 PEs each) and K/V are
        # split row-wise across cubes. single_user: one cube whose n_ranks
        # PEs form the ring, with K/V split row-wise across PEs.
        if is_multi_user:
            mesh = {"num_cubes": n_ranks, "num_pes": 8}
            kv_split = {"cube": "row_wise", "pe": "replicate"}
        else:
            mesh = {"num_cubes": 1, "num_pes": n_ranks}
            kv_split = {"cube": "replicate", "pe": "row_wise"}
        dp_full = DPPolicy(cube="replicate", pe="replicate", **mesh)
        dp_kv = DPPolicy(**kv_split, **mesh)

        q_shape = (S_q, _H_Q * _D_HEAD)
        kv_shape = (_S_KV_PER_RANK * n_ranks, _H_KV * _D_HEAD)
        q = ctx.zeros(q_shape, dtype=_DTYPE, dp=dp_full, name=f"{panel}_q")
        k = ctx.zeros(kv_shape, dtype=_DTYPE, dp=dp_kv, name=f"{panel}_k")
        v = ctx.zeros(kv_shape, dtype=_DTYPE, dp=dp_kv, name=f"{panel}_v")
        o = ctx.empty(q_shape, dtype=_DTYPE, dp=dp_full, name=f"{panel}_o")

        # rank_axis is handed over positionally, after n_ranks.
        # _auto_dim_remap=False keeps d_head=64 from colliding with the
        # multi_user K's global M=64.
        ctx.launch(
            f"{panel}_mesh",
            kernel,
            q,
            k,
            v,
            o,
            S_q,
            _S_KV_PER_RANK,
            _H_Q,
            _H_KV,
            _D_HEAD,
            n_ranks,
            rank_axis,
            _auto_dim_remap=False,
        )

    return _bench_fn
|
||||
|
||||
|
||||
# ── Op-log summary harvest ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def _summarize_op_log(op_log) -> dict[str, int]:
    """Tally an op log into the ADR-0057 D7 ``op_log_summary`` counters.

    Each record is classified once. A record whose ``op_kind`` is
    ``"gemm"`` counts as a GEMM and its ``op_name`` is not consulted.
    Every other record is matched on ``op_name``; names outside the
    rule table below are ignored.

    Returns a dict with the keys ``gemm_count``, ``ipcq_send_count``,
    ``ipcq_recv_count``, ``dma_read_count`` and ``dma_write_count``, in
    that order.
    """
    tallies = dict.fromkeys(
        (
            "gemm_count",
            "ipcq_send_count",
            "ipcq_recv_count",
            "dma_read_count",
            "dma_write_count",
        ),
        0,
    )
    # op_name → counters bumped by one matching record.
    name_rules = (
        ("dma_read", ("dma_read_count",)),
        ("dma_write", ("dma_write_count",)),
        ("ipcq_send", ("ipcq_send_count",)),
        ("ipcq_recv", ("ipcq_recv_count",)),
        # The inbound DMA records ipcq_copy (one per send/recv pair), so it
        # feeds both sides; this keeps ipcq_send_count and ipcq_recv_count
        # non-zero even when the engine logs the collective via the
        # inbound copy alone.
        ("ipcq_copy", ("ipcq_send_count", "ipcq_recv_count")),
    )
    for record in op_log:
        if record.op_kind == "gemm":
            tallies["gemm_count"] += 1
            continue
        for op_name, counters in name_rules:
            if record.op_name == op_name:
                for counter in counters:
                    tallies[counter] += 1
                break
    return tallies
|
||||
|
||||
|
||||
def _run_panel(panel: str, topology: str) -> dict:
    """Simulate one panel on its own engine and return its sweep row.

    ``topology`` is resolved through ``resolve_topology`` and the panel's
    bench fn is run via ``run_bench`` with a data-enabled ``GraphEngine``.

    Returns a dict with ``panel``, ``n_ranks`` and ``op_log_summary``.

    Raises:
        RuntimeError: if the run's ``completion.ok`` flag is falsy.
    """
    from kernbench.runtime_api.bench_runner import run_bench
    from kernbench.runtime_api.types import resolve_device
    from kernbench.sim_engine.engine import GraphEngine
    from kernbench.topology.builder import resolve_topology

    def _data_engine(t, d):
        # Two-argument factory; the second argument is not used here.
        # Unwrap ``topology_obj`` when the topology carries one.
        return GraphEngine(getattr(t, "topology_obj", t), enable_data=True)

    outcome = run_bench(
        topology=resolve_topology(topology),
        bench_fn=_make_bench_fn(panel),
        device=resolve_device(None),
        engine_factory=_data_engine,
    )
    if outcome.completion.ok:
        return {
            "panel": panel,
            # Fourth slot of the dispatch tuple is the panel's n_ranks.
            "n_ranks": _PANEL_DISPATCH[panel][3],
            "op_log_summary": _summarize_op_log(outcome.engine.op_log),
        }
    raise RuntimeError(
        f"milestone-gqa-llama70b panel {panel!r} failed: {outcome.completion}"
    )
|
||||
|
||||
|
||||
# ── Bench entry ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@bench(
    name="milestone-gqa-llama70b",
    description="1H milestone: GQA Llama-70B 4-panel sweep (ADR-0057 v1).",
)
def run(torch) -> None:
    """Run every v1 panel at validation scale and write ``sweep.json``.

    The bench refuses to run unless the ``GQA_VALIDATION`` environment
    variable is set to a non-empty value (documented usage:
    ``GQA_VALIDATION=1``); headline mode and figures are deferred to
    sub-cycles 4b and 4c per ADR-0057 D3. The topology is taken from
    ``GQA_TOPOLOGY`` and defaults to ``topology.yaml``.

    Each panel drives its own engine, so a sentinel tensor is created on
    ``torch`` at the end to satisfy run_bench's ADR-0045 D4 "at least one
    request" contract.

    Raises:
        RuntimeError: if ``GQA_VALIDATION`` is unset or empty, or if a
            panel run fails.
    """
    validation_requested = bool(os.environ.get("GQA_VALIDATION"))
    if not validation_requested:
        raise RuntimeError(
            "milestone-gqa-llama70b v1 only supports validation mode. Set "
            "GQA_VALIDATION=1 to run. Headline mode is deferred to sub-cycle "
            "4b/4c per ADR-0057 D3."
        )

    _OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    topology = os.environ.get("GQA_TOPOLOGY", "topology.yaml")

    rows = []
    for panel in _PANELS_V1:
        rows.append(_run_panel(panel, topology))

    config = {
        "S_q_prefill": _S_Q_PREFILL,
        "S_kv_per_rank": _S_KV_PER_RANK,
        "h_q": _H_Q,
        "h_kv": _H_KV,
        "d_head": _D_HEAD,
        "n_ranks_single_user": _N_RANKS_SINGLE_USER,
        "n_ranks_multi_user": _N_RANKS_MULTI_USER,
    }
    sweep = {
        "version": 1,
        "validation_scale": True,
        "panels": list(_PANELS_V1),
        "config": config,
        "rows": rows,
    }
    _SWEEP_JSON.write_text(json.dumps(sweep, indent=2))
    print(f" milestone-gqa-llama70b: {len(rows)} rows -> {_SWEEP_JSON}")

    # Sentinel tensor (ADR-0045 D4 / ADR-0054 D2 carve-out).
    sentinel_dp = DPPolicy(
        cube="row_wise", pe="replicate", num_cubes=1, num_pes=1,
    )
    torch.zeros(
        (1, 1),
        dtype="f16",
        dp=sentinel_dp,
        name="milestone_gqa_sentinel",
    )
|
||||
Reference in New Issue
Block a user