"""milestone-gqa-llama70b bench: GQA Llama-70B 4-panel sweep (ADR-0057 v1). Self-contained eval bench (ADR-0054). Drives the four panels of the GQA Llama-70B sharding study through ``run_bench`` with ``enable_data=True``, harvests op_log summaries, and writes JSON into ``benches/1H_milestone_output/gqa/sweep.json``. v1 (sub-cycle 4a + 4c.0) covers all four panels at validation scale: Panel name in JSON / test Study label SFR install used ───────────────────────────────────────────────────────────────────── single_user_prefill TL configure_sfr_intracube_pe_ring multi_user_prefill TR configure_sfr_intercube_multisip single_user_decode BL configure_sfr_intracube_pe_ring multi_user_decode BR configure_sfr_intercube_multisip Kernels use the mesh-native variants (ADR-0059), invoked with the ``rank_axis`` kwarg (0 for single_user PE-level rings, 1 for multi_user cube-level rings — the latter also gates 7 of every 8 PEs to silence). Validation-scale config (ADR-0057 D4) — kept small so the simulator's 1 MB per-PE TCM scratch budget is not exhausted across n_ranks ring steps. """ from __future__ import annotations import json import os from pathlib import Path from typing import Any from kernbench.benches._attention_mesh_kv import attention_mesh_kv_kernel from kernbench.benches._attention_mesh_mlo import attention_mesh_mlo_kernel from kernbench.benches.registry import bench from kernbench.ccl.install import load_ccl_config, resolve_algorithm_config from kernbench.ccl.sfr_config import ( configure_sfr_intercube_multisip, configure_sfr_intracube_pe_ring, ) from kernbench.policy.placement.dp import DPPolicy _OUTPUT_DIR = Path(__file__).resolve().parent / "1H_milestone_output" / "gqa" _SWEEP_JSON = _OUTPUT_DIR / "sweep.json" # ── Validation-scale config (ADR-0057 D4) ───────────────────────────── _S_Q_PREFILL = 16 _S_Q_DECODE = 1 _S_KV_PER_RANK = 16 _H_Q = 1 _H_KV = 1 _D_HEAD = 64 _N_RANKS_SINGLE_USER = 8 _N_RANKS_MULTI_USER = 4 _DTYPE = "f16" _PANELS_V1 = ( "single_user_prefill", "multi_user_prefill", "single_user_decode", "multi_user_decode", ) # Panel → (kernel, SFR install, S_q, n_ranks, rank_axis) _PANEL_DISPATCH: dict[str, tuple[Any, Any, int, int, int]] = { "single_user_prefill": ( attention_mesh_kv_kernel, configure_sfr_intracube_pe_ring, _S_Q_PREFILL, _N_RANKS_SINGLE_USER, 0, ), "multi_user_prefill": ( attention_mesh_kv_kernel, configure_sfr_intercube_multisip, _S_Q_PREFILL, _N_RANKS_MULTI_USER, 1, ), "single_user_decode": ( attention_mesh_mlo_kernel, configure_sfr_intracube_pe_ring, _S_Q_DECODE, _N_RANKS_SINGLE_USER, 0, ), "multi_user_decode": ( attention_mesh_mlo_kernel, configure_sfr_intercube_multisip, _S_Q_DECODE, _N_RANKS_MULTI_USER, 1, ), } # ── Per-panel bench fn ───────────────────────────────────────────────── def _make_bench_fn(panel: str): kernel, sfr_install, S_q, n_ranks, rank_axis = _PANEL_DISPATCH[panel] is_multi_user = panel.startswith("multi_user_") def _bench_fn(ctx): sfr_install( ctx.engine, ctx.spec, resolve_algorithm_config(load_ccl_config(), name="lrab_hierarchical_allreduce"), ) if is_multi_user: dp_full = DPPolicy( cube="replicate", pe="replicate", num_cubes=n_ranks, num_pes=8, ) dp_kv = DPPolicy( cube="row_wise", pe="replicate", num_cubes=n_ranks, num_pes=8, ) else: dp_full = DPPolicy( cube="replicate", pe="replicate", num_cubes=1, num_pes=n_ranks, ) dp_kv = DPPolicy( cube="replicate", pe="row_wise", num_cubes=1, num_pes=n_ranks, ) q = ctx.zeros((S_q, _H_Q * _D_HEAD), dtype=_DTYPE, dp=dp_full, name=f"{panel}_q") k = ctx.zeros((_S_KV_PER_RANK * n_ranks, _H_KV * _D_HEAD), dtype=_DTYPE, dp=dp_kv, name=f"{panel}_k") v = ctx.zeros((_S_KV_PER_RANK * n_ranks, _H_KV * _D_HEAD), dtype=_DTYPE, dp=dp_kv, name=f"{panel}_v") o = ctx.empty((S_q, _H_Q * _D_HEAD), dtype=_DTYPE, dp=dp_full, name=f"{panel}_o") # rank_axis is a positional arg; _auto_dim_remap=False keeps # d_head=64 from colliding with the multi_user K's global M=64. ctx.launch( f"{panel}_mesh", kernel, q, k, v, o, S_q, _S_KV_PER_RANK, _H_Q, _H_KV, _D_HEAD, n_ranks, rank_axis, _auto_dim_remap=False, ) return _bench_fn # ── Op-log summary harvest ───────────────────────────────────────────── def _summarize_op_log(op_log) -> dict[str, int]: """Counts per ADR-0057 D7 op_log_summary contract.""" gemm_count = 0 ipcq_send_count = 0 ipcq_recv_count = 0 dma_read_count = 0 dma_write_count = 0 for r in op_log: if r.op_kind == "gemm": gemm_count += 1 elif r.op_name == "dma_read": dma_read_count += 1 elif r.op_name == "dma_write": dma_write_count += 1 elif r.op_name == "ipcq_send": ipcq_send_count += 1 elif r.op_name == "ipcq_recv": ipcq_recv_count += 1 elif r.op_name == "ipcq_copy": # The inbound DMA records ipcq_copy (one per send/recv pair). # Count it as both a send and a recv side so the row's # ipcq_send_count and ipcq_recv_count are non-zero even when # the engine logs the collective via the inbound copy alone. ipcq_send_count += 1 ipcq_recv_count += 1 return { "gemm_count": gemm_count, "ipcq_send_count": ipcq_send_count, "ipcq_recv_count": ipcq_recv_count, "dma_read_count": dma_read_count, "dma_write_count": dma_write_count, } def _run_panel(panel: str, topology: str) -> dict: """Run one panel via a fresh engine; return its row dict.""" from kernbench.runtime_api.bench_runner import run_bench from kernbench.runtime_api.types import resolve_device from kernbench.sim_engine.engine import GraphEngine from kernbench.topology.builder import resolve_topology topo = resolve_topology(topology) result = run_bench( topology=topo, bench_fn=_make_bench_fn(panel), device=resolve_device(None), engine_factory=lambda t, d: GraphEngine( getattr(t, "topology_obj", t), enable_data=True, ), ) if not result.completion.ok: raise RuntimeError( f"milestone-gqa-llama70b panel {panel!r} failed: {result.completion}" ) _, _, _, n_ranks, _ = _PANEL_DISPATCH[panel] return { "panel": panel, "n_ranks": n_ranks, "op_log_summary": _summarize_op_log(result.engine.op_log), } # ── Bench entry ──────────────────────────────────────────────────────── @bench( name="milestone-gqa-llama70b", description="1H milestone: GQA Llama-70B 4-panel sweep (ADR-0057 v1).", ) def run(torch) -> None: """Drive the four GQA panels at validation scale; write sweep.json. v1 only supports validation mode (``GQA_VALIDATION=1``). Headline mode and figures are deferred to sub-cycles 4b and 4c per ADR-0057 D3. A sentinel tensor is submitted at the end so run_bench's ADR-0045 D4 "at least one request" contract is satisfied even though the panels drive their own engines. """ if not os.environ.get("GQA_VALIDATION"): raise RuntimeError( "milestone-gqa-llama70b v1 only supports validation mode. " "Set GQA_VALIDATION=1 to run. Headline mode is deferred to " "sub-cycle 4b/4c per ADR-0057 D3." ) _OUTPUT_DIR.mkdir(parents=True, exist_ok=True) topology = os.environ.get("GQA_TOPOLOGY", "topology.yaml") rows = [_run_panel(panel, topology) for panel in _PANELS_V1] sweep = { "version": 1, "validation_scale": True, "panels": list(_PANELS_V1), "config": { "S_q_prefill": _S_Q_PREFILL, "S_kv_per_rank": _S_KV_PER_RANK, "h_q": _H_Q, "h_kv": _H_KV, "d_head": _D_HEAD, "n_ranks_single_user": _N_RANKS_SINGLE_USER, "n_ranks_multi_user": _N_RANKS_MULTI_USER, }, "rows": rows, } _SWEEP_JSON.write_text(json.dumps(sweep, indent=2)) print(f" milestone-gqa-llama70b: {len(rows)} rows -> {_SWEEP_JSON}") # Sentinel tensor (ADR-0045 D4 / ADR-0054 D2 carve-out). torch.zeros( (1, 1), dtype="f16", dp=DPPolicy(cube="row_wise", pe="replicate", num_cubes=1, num_pes=1), name="milestone_gqa_sentinel", )