Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b3ca532023 | |||
| e748a62264 |
Binary file not shown.
|
After Width: | Height: | Size: 33 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 21 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 20 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 24 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 22 KiB |
@@ -0,0 +1,65 @@
|
|||||||
|
{
|
||||||
|
"version": 1,
|
||||||
|
"validation_scale": true,
|
||||||
|
"panels": [
|
||||||
|
"single_user_prefill",
|
||||||
|
"multi_user_prefill",
|
||||||
|
"single_user_decode",
|
||||||
|
"multi_user_decode"
|
||||||
|
],
|
||||||
|
"config": {
|
||||||
|
"S_q_prefill": 16,
|
||||||
|
"S_kv_per_rank": 16,
|
||||||
|
"h_q": 1,
|
||||||
|
"h_kv": 1,
|
||||||
|
"d_head": 64,
|
||||||
|
"n_ranks_single_user": 8,
|
||||||
|
"n_ranks_multi_user": 4
|
||||||
|
},
|
||||||
|
"rows": [
|
||||||
|
{
|
||||||
|
"panel": "single_user_prefill",
|
||||||
|
"n_ranks": 8,
|
||||||
|
"op_log_summary": {
|
||||||
|
"gemm_count": 128,
|
||||||
|
"ipcq_send_count": 112,
|
||||||
|
"ipcq_recv_count": 112,
|
||||||
|
"dma_read_count": 24,
|
||||||
|
"dma_write_count": 8
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"panel": "multi_user_prefill",
|
||||||
|
"n_ranks": 4,
|
||||||
|
"op_log_summary": {
|
||||||
|
"gemm_count": 32,
|
||||||
|
"ipcq_send_count": 24,
|
||||||
|
"ipcq_recv_count": 24,
|
||||||
|
"dma_read_count": 12,
|
||||||
|
"dma_write_count": 4
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"panel": "single_user_decode",
|
||||||
|
"n_ranks": 8,
|
||||||
|
"op_log_summary": {
|
||||||
|
"gemm_count": 16,
|
||||||
|
"ipcq_send_count": 168,
|
||||||
|
"ipcq_recv_count": 168,
|
||||||
|
"dma_read_count": 24,
|
||||||
|
"dma_write_count": 8
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"panel": "multi_user_decode",
|
||||||
|
"n_ranks": 4,
|
||||||
|
"op_log_summary": {
|
||||||
|
"gemm_count": 8,
|
||||||
|
"ipcq_send_count": 36,
|
||||||
|
"ipcq_recv_count": 36,
|
||||||
|
"dma_read_count": 12,
|
||||||
|
"dma_write_count": 4
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
@@ -0,0 +1,427 @@
|
|||||||
|
"""milestone-gqa-llama70b bench: GQA Llama-70B 4-panel sweep (ADR-0057 v1).
|
||||||
|
|
||||||
|
Self-contained eval bench (ADR-0054). Drives the four panels of the GQA
|
||||||
|
Llama-70B sharding study through ``run_bench`` with ``enable_data=True``,
|
||||||
|
harvests op_log summaries, and writes JSON into
|
||||||
|
``benches/1H_milestone_output/gqa/sweep.json``.
|
||||||
|
|
||||||
|
v1 (sub-cycle 4a + 4c.0) covers all four panels at validation scale:
|
||||||
|
|
||||||
|
Panel name in JSON / test Study label SFR install used
|
||||||
|
─────────────────────────────────────────────────────────────────────
|
||||||
|
single_user_prefill TL configure_sfr_intracube_pe_ring
|
||||||
|
multi_user_prefill TR configure_sfr_intercube_multisip
|
||||||
|
single_user_decode BL configure_sfr_intracube_pe_ring
|
||||||
|
multi_user_decode BR configure_sfr_intercube_multisip
|
||||||
|
|
||||||
|
Kernels use the mesh-native variants (ADR-0059), invoked with the
|
||||||
|
``rank_axis`` kwarg (0 for single_user PE-level rings, 1 for multi_user
|
||||||
|
cube-level rings — the latter also gates 7 of every 8 PEs to silence).
|
||||||
|
|
||||||
|
Validation-scale config (ADR-0057 D4) — kept small so the simulator's
|
||||||
|
1 MB per-PE TCM scratch budget is not exhausted across n_ranks ring steps.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from kernbench.benches._attention_mesh_kv import attention_mesh_kv_kernel
|
||||||
|
from kernbench.benches._attention_mesh_mlo import attention_mesh_mlo_kernel
|
||||||
|
from kernbench.benches.registry import bench
|
||||||
|
from kernbench.ccl.install import load_ccl_config, resolve_algorithm_config
|
||||||
|
from kernbench.ccl.sfr_config import (
|
||||||
|
configure_sfr_intercube_multisip,
|
||||||
|
configure_sfr_intracube_pe_ring,
|
||||||
|
)
|
||||||
|
from kernbench.policy.placement.dp import DPPolicy
|
||||||
|
|
||||||
|
_OUTPUT_DIR = Path(__file__).resolve().parent / "1H_milestone_output" / "gqa"
|
||||||
|
_SWEEP_JSON = _OUTPUT_DIR / "sweep.json"
|
||||||
|
|
||||||
|
# ── Validation-scale config (ADR-0057 D4) ─────────────────────────────
|
||||||
|
|
||||||
|
_S_Q_PREFILL = 16
|
||||||
|
_S_Q_DECODE = 1
|
||||||
|
_S_KV_PER_RANK = 16
|
||||||
|
_H_Q = 1
|
||||||
|
_H_KV = 1
|
||||||
|
_D_HEAD = 64
|
||||||
|
_N_RANKS_SINGLE_USER = 8
|
||||||
|
_N_RANKS_MULTI_USER = 4
|
||||||
|
_DTYPE = "f16"
|
||||||
|
|
||||||
|
_PANELS_V1 = (
|
||||||
|
"single_user_prefill",
|
||||||
|
"multi_user_prefill",
|
||||||
|
"single_user_decode",
|
||||||
|
"multi_user_decode",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Panel → (kernel, SFR install, S_q, n_ranks, rank_axis)
|
||||||
|
_PANEL_DISPATCH: dict[str, tuple[Any, Any, int, int, int]] = {
|
||||||
|
"single_user_prefill": (
|
||||||
|
attention_mesh_kv_kernel, configure_sfr_intracube_pe_ring,
|
||||||
|
_S_Q_PREFILL, _N_RANKS_SINGLE_USER, 0,
|
||||||
|
),
|
||||||
|
"multi_user_prefill": (
|
||||||
|
attention_mesh_kv_kernel, configure_sfr_intercube_multisip,
|
||||||
|
_S_Q_PREFILL, _N_RANKS_MULTI_USER, 1,
|
||||||
|
),
|
||||||
|
"single_user_decode": (
|
||||||
|
attention_mesh_mlo_kernel, configure_sfr_intracube_pe_ring,
|
||||||
|
_S_Q_DECODE, _N_RANKS_SINGLE_USER, 0,
|
||||||
|
),
|
||||||
|
"multi_user_decode": (
|
||||||
|
attention_mesh_mlo_kernel, configure_sfr_intercube_multisip,
|
||||||
|
_S_Q_DECODE, _N_RANKS_MULTI_USER, 1,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Per-panel bench fn ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _make_bench_fn(panel: str):
|
||||||
|
kernel, sfr_install, S_q, n_ranks, rank_axis = _PANEL_DISPATCH[panel]
|
||||||
|
is_multi_user = panel.startswith("multi_user_")
|
||||||
|
|
||||||
|
def _bench_fn(ctx):
|
||||||
|
sfr_install(
|
||||||
|
ctx.engine, ctx.spec,
|
||||||
|
resolve_algorithm_config(load_ccl_config(), name="lrab_hierarchical_allreduce"),
|
||||||
|
)
|
||||||
|
if is_multi_user:
|
||||||
|
dp_full = DPPolicy(
|
||||||
|
cube="replicate", pe="replicate",
|
||||||
|
num_cubes=n_ranks, num_pes=8,
|
||||||
|
)
|
||||||
|
dp_kv = DPPolicy(
|
||||||
|
cube="row_wise", pe="replicate",
|
||||||
|
num_cubes=n_ranks, num_pes=8,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
dp_full = DPPolicy(
|
||||||
|
cube="replicate", pe="replicate",
|
||||||
|
num_cubes=1, num_pes=n_ranks,
|
||||||
|
)
|
||||||
|
dp_kv = DPPolicy(
|
||||||
|
cube="replicate", pe="row_wise",
|
||||||
|
num_cubes=1, num_pes=n_ranks,
|
||||||
|
)
|
||||||
|
q = ctx.zeros((S_q, _H_Q * _D_HEAD),
|
||||||
|
dtype=_DTYPE, dp=dp_full, name=f"{panel}_q")
|
||||||
|
k = ctx.zeros((_S_KV_PER_RANK * n_ranks, _H_KV * _D_HEAD),
|
||||||
|
dtype=_DTYPE, dp=dp_kv, name=f"{panel}_k")
|
||||||
|
v = ctx.zeros((_S_KV_PER_RANK * n_ranks, _H_KV * _D_HEAD),
|
||||||
|
dtype=_DTYPE, dp=dp_kv, name=f"{panel}_v")
|
||||||
|
o = ctx.empty((S_q, _H_Q * _D_HEAD),
|
||||||
|
dtype=_DTYPE, dp=dp_full, name=f"{panel}_o")
|
||||||
|
# rank_axis is a positional arg; _auto_dim_remap=False keeps
|
||||||
|
# d_head=64 from colliding with the multi_user K's global M=64.
|
||||||
|
ctx.launch(
|
||||||
|
f"{panel}_mesh", kernel,
|
||||||
|
q, k, v, o,
|
||||||
|
S_q, _S_KV_PER_RANK, _H_Q, _H_KV, _D_HEAD, n_ranks,
|
||||||
|
rank_axis,
|
||||||
|
_auto_dim_remap=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
return _bench_fn
|
||||||
|
|
||||||
|
|
||||||
|
# ── Op-log summary harvest ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _summarize_op_log(op_log) -> dict[str, int]:
|
||||||
|
"""Counts per ADR-0057 D7 op_log_summary contract."""
|
||||||
|
gemm_count = 0
|
||||||
|
ipcq_send_count = 0
|
||||||
|
ipcq_recv_count = 0
|
||||||
|
dma_read_count = 0
|
||||||
|
dma_write_count = 0
|
||||||
|
for r in op_log:
|
||||||
|
if r.op_kind == "gemm":
|
||||||
|
gemm_count += 1
|
||||||
|
elif r.op_name == "dma_read":
|
||||||
|
dma_read_count += 1
|
||||||
|
elif r.op_name == "dma_write":
|
||||||
|
dma_write_count += 1
|
||||||
|
elif r.op_name == "ipcq_send":
|
||||||
|
ipcq_send_count += 1
|
||||||
|
elif r.op_name == "ipcq_recv":
|
||||||
|
ipcq_recv_count += 1
|
||||||
|
elif r.op_name == "ipcq_copy":
|
||||||
|
# The inbound DMA records ipcq_copy (one per send/recv pair).
|
||||||
|
# Count it as both a send and a recv side so the row's
|
||||||
|
# ipcq_send_count and ipcq_recv_count are non-zero even when
|
||||||
|
# the engine logs the collective via the inbound copy alone.
|
||||||
|
ipcq_send_count += 1
|
||||||
|
ipcq_recv_count += 1
|
||||||
|
return {
|
||||||
|
"gemm_count": gemm_count,
|
||||||
|
"ipcq_send_count": ipcq_send_count,
|
||||||
|
"ipcq_recv_count": ipcq_recv_count,
|
||||||
|
"dma_read_count": dma_read_count,
|
||||||
|
"dma_write_count": dma_write_count,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _run_panel(panel: str, topology: str) -> dict:
|
||||||
|
"""Run one panel via a fresh engine; return its row dict."""
|
||||||
|
from kernbench.runtime_api.bench_runner import run_bench
|
||||||
|
from kernbench.runtime_api.types import resolve_device
|
||||||
|
from kernbench.sim_engine.engine import GraphEngine
|
||||||
|
from kernbench.topology.builder import resolve_topology
|
||||||
|
|
||||||
|
topo = resolve_topology(topology)
|
||||||
|
result = run_bench(
|
||||||
|
topology=topo, bench_fn=_make_bench_fn(panel),
|
||||||
|
device=resolve_device(None),
|
||||||
|
engine_factory=lambda t, d: GraphEngine(
|
||||||
|
getattr(t, "topology_obj", t), enable_data=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if not result.completion.ok:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"milestone-gqa-llama70b panel {panel!r} failed: {result.completion}"
|
||||||
|
)
|
||||||
|
_, _, _, n_ranks, _ = _PANEL_DISPATCH[panel]
|
||||||
|
return {
|
||||||
|
"panel": panel,
|
||||||
|
"n_ranks": n_ranks,
|
||||||
|
"op_log_summary": _summarize_op_log(result.engine.op_log),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Figure renderers (sub-cycle 4c, 5 of 6 figures) ──────────────────
|
||||||
|
#
|
||||||
|
# Sixth figure ``gqa_scaling.png`` is deferred to after sub-cycle 4b
|
||||||
|
# lands the Q/cube ∈ {1, 2, 4} sweep on multi_user_* panels — it needs
|
||||||
|
# multiple sweep.json rows per multi_user panel to be meaningful.
|
||||||
|
|
||||||
|
_OP_LOG_KEYS = (
|
||||||
|
"gemm_count",
|
||||||
|
"ipcq_send_count",
|
||||||
|
"ipcq_recv_count",
|
||||||
|
"dma_read_count",
|
||||||
|
"dma_write_count",
|
||||||
|
)
|
||||||
|
_OP_LOG_DISPLAY = {
|
||||||
|
"gemm_count": "GEMM",
|
||||||
|
"ipcq_send_count": "IPCQ send",
|
||||||
|
"ipcq_recv_count": "IPCQ recv",
|
||||||
|
"dma_read_count": "DMA read",
|
||||||
|
"dma_write_count": "DMA write",
|
||||||
|
}
|
||||||
|
_OP_LOG_COLORS = {
|
||||||
|
"gemm_count": "#F59E0B",
|
||||||
|
"ipcq_send_count": "#3B82F6",
|
||||||
|
"ipcq_recv_count": "#10B981",
|
||||||
|
"dma_read_count": "#A855F7",
|
||||||
|
"dma_write_count": "#EF4444",
|
||||||
|
}
|
||||||
|
_PANEL_DISPLAY = {
|
||||||
|
"single_user_prefill": "single_user / prefill",
|
||||||
|
"multi_user_prefill": "multi_user / prefill",
|
||||||
|
"single_user_decode": "single_user / decode",
|
||||||
|
"multi_user_decode": "multi_user / decode",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _load_sweep_data(sweep_json: Path | str) -> dict:
|
||||||
|
sweep_json = Path(sweep_json)
|
||||||
|
if not sweep_json.exists():
|
||||||
|
return {"rows": [], "config": {}, "panels": []}
|
||||||
|
return json.loads(sweep_json.read_text())
|
||||||
|
|
||||||
|
|
||||||
|
def _row_for(rows: list, panel: str) -> dict | None:
|
||||||
|
for r in rows:
|
||||||
|
if r.get("panel") == panel:
|
||||||
|
return r
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def emit_panel_op_log_summary(
|
||||||
|
panel: str,
|
||||||
|
sweep_json: Path | str = _SWEEP_JSON,
|
||||||
|
out_dir: Path | str = _OUTPUT_DIR,
|
||||||
|
) -> str | None:
|
||||||
|
"""One bar chart of the 5 op_log counts for ``panel``.
|
||||||
|
|
||||||
|
Returns the written PNG path, or ``None`` when sweep.json is empty
|
||||||
|
or the requested panel is absent.
|
||||||
|
"""
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
data = _load_sweep_data(sweep_json)
|
||||||
|
row = _row_for(data.get("rows", []), panel)
|
||||||
|
if row is None:
|
||||||
|
return None
|
||||||
|
summary = row.get("op_log_summary", {})
|
||||||
|
n_ranks = row.get("n_ranks")
|
||||||
|
|
||||||
|
labels = [_OP_LOG_DISPLAY[k] for k in _OP_LOG_KEYS]
|
||||||
|
values = [summary.get(k, 0) for k in _OP_LOG_KEYS]
|
||||||
|
colors = [_OP_LOG_COLORS[k] for k in _OP_LOG_KEYS]
|
||||||
|
|
||||||
|
fig, ax = plt.subplots(figsize=(8, 5))
|
||||||
|
bars = ax.bar(labels, values, color=colors)
|
||||||
|
for b, v in zip(bars, values):
|
||||||
|
ax.text(b.get_x() + b.get_width() / 2, b.get_height(),
|
||||||
|
f"{int(v)}", ha="center", va="bottom", fontsize=9)
|
||||||
|
ax.set_title(
|
||||||
|
f"{_PANEL_DISPLAY.get(panel, panel)} (n_ranks={n_ranks})",
|
||||||
|
fontsize=12, fontweight="bold",
|
||||||
|
)
|
||||||
|
ax.set_ylabel("count")
|
||||||
|
ax.grid(True, axis="y", alpha=0.3)
|
||||||
|
fig.tight_layout()
|
||||||
|
|
||||||
|
out_dir = Path(out_dir)
|
||||||
|
out_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
out = out_dir / f"gqa_op_log_{panel}.png"
|
||||||
|
fig.savefig(out, dpi=120)
|
||||||
|
plt.close(fig)
|
||||||
|
return str(out)
|
||||||
|
|
||||||
|
|
||||||
|
def emit_gqa_comparison(
|
||||||
|
sweep_json: Path | str = _SWEEP_JSON,
|
||||||
|
out_dir: Path | str = _OUTPUT_DIR,
|
||||||
|
) -> str | None:
|
||||||
|
"""Grouped-bar chart comparing the 5 op_log counts across all panels."""
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
data = _load_sweep_data(sweep_json)
|
||||||
|
panels_in = data.get("panels") or list(_PANELS_V1)
|
||||||
|
rows = data.get("rows", [])
|
||||||
|
panels = [p for p in panels_in if _row_for(rows, p) is not None]
|
||||||
|
if not panels:
|
||||||
|
return None
|
||||||
|
|
||||||
|
n_groups = len(panels)
|
||||||
|
n_series = len(_OP_LOG_KEYS)
|
||||||
|
x = np.arange(n_groups)
|
||||||
|
width = 0.8 / n_series
|
||||||
|
|
||||||
|
fig, ax = plt.subplots(figsize=(11, 6))
|
||||||
|
for i, key in enumerate(_OP_LOG_KEYS):
|
||||||
|
offset = (i - (n_series - 1) / 2) * width
|
||||||
|
vals = [_row_for(rows, p)["op_log_summary"].get(key, 0)
|
||||||
|
for p in panels]
|
||||||
|
ax.bar(x + offset, vals, width,
|
||||||
|
label=_OP_LOG_DISPLAY[key], color=_OP_LOG_COLORS[key])
|
||||||
|
|
||||||
|
ax.set_xticks(x)
|
||||||
|
ax.set_xticklabels(
|
||||||
|
[f"{_PANEL_DISPLAY.get(p, p)}\n(n_ranks={_row_for(rows, p)['n_ranks']})"
|
||||||
|
for p in panels],
|
||||||
|
fontsize=8,
|
||||||
|
)
|
||||||
|
ax.set_ylabel("count")
|
||||||
|
ax.set_title("GQA Llama-70B — op_log summary across panels",
|
||||||
|
fontsize=13, fontweight="bold")
|
||||||
|
ax.legend(fontsize=8, loc="upper right")
|
||||||
|
ax.grid(True, axis="y", alpha=0.3)
|
||||||
|
fig.tight_layout()
|
||||||
|
|
||||||
|
out_dir = Path(out_dir)
|
||||||
|
out_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
out = out_dir / "gqa_comparison.png"
|
||||||
|
fig.savefig(out, dpi=120)
|
||||||
|
plt.close(fig)
|
||||||
|
return str(out)
|
||||||
|
|
||||||
|
|
||||||
|
def emit_all_gqa_plots(
|
||||||
|
sweep_json: Path | str = _SWEEP_JSON,
|
||||||
|
out_dir: Path | str = _OUTPUT_DIR,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Render all 5 in-scope figures and return the written paths.
|
||||||
|
|
||||||
|
Sub-cycle 4c v1 emits 5 of the 6 figures ADR-0057 D3 lists; the
|
||||||
|
6th (gqa_scaling.png) needs sub-cycle 4b's Q/cube sweep data.
|
||||||
|
"""
|
||||||
|
paths: list[str] = []
|
||||||
|
for panel in _PANELS_V1:
|
||||||
|
p = emit_panel_op_log_summary(panel, sweep_json, out_dir)
|
||||||
|
if p is not None:
|
||||||
|
paths.append(p)
|
||||||
|
comp = emit_gqa_comparison(sweep_json, out_dir)
|
||||||
|
if comp is not None:
|
||||||
|
paths.append(comp)
|
||||||
|
return paths
|
||||||
|
|
||||||
|
|
||||||
|
# ── Bench entry ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@bench(
|
||||||
|
name="milestone-gqa-llama70b",
|
||||||
|
description="1H milestone: GQA Llama-70B 4-panel sweep (ADR-0057 v1).",
|
||||||
|
)
|
||||||
|
def run(torch) -> None:
|
||||||
|
"""Drive the four GQA panels at validation scale; write sweep.json and figures.
|
||||||
|
|
||||||
|
Modes (mutually exclusive):
|
||||||
|
MILESTONE_FAST=1 Skip the sweep; re-render figures from the
|
||||||
|
committed sweep.json. Seconds, no simulator.
|
||||||
|
GQA_VALIDATION=1 Run the four-panel validation sweep + figures.
|
||||||
|
~1-2h on the full simulator.
|
||||||
|
|
||||||
|
Headline-scale mode is deferred to sub-cycle 4c (figures landed
|
||||||
|
here; headline-scale + scaling figure await sub-cycle 4b).
|
||||||
|
A sentinel tensor is submitted at the end so run_bench's ADR-0045 D4
|
||||||
|
"at least one request" contract is satisfied even when the panels
|
||||||
|
are skipped via MILESTONE_FAST=1.
|
||||||
|
"""
|
||||||
|
_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
fast = bool(os.environ.get("MILESTONE_FAST"))
|
||||||
|
if not fast and not os.environ.get("GQA_VALIDATION"):
|
||||||
|
raise RuntimeError(
|
||||||
|
"milestone-gqa-llama70b v1 needs GQA_VALIDATION=1 (run the "
|
||||||
|
"sweep) or MILESTONE_FAST=1 (reuse committed sweep.json). "
|
||||||
|
"Headline mode is deferred to sub-cycle 4b/4c per ADR-0057 D3."
|
||||||
|
)
|
||||||
|
|
||||||
|
if not fast:
|
||||||
|
topology = os.environ.get("GQA_TOPOLOGY", "topology.yaml")
|
||||||
|
rows = [_run_panel(panel, topology) for panel in _PANELS_V1]
|
||||||
|
sweep = {
|
||||||
|
"version": 1,
|
||||||
|
"validation_scale": True,
|
||||||
|
"panels": list(_PANELS_V1),
|
||||||
|
"config": {
|
||||||
|
"S_q_prefill": _S_Q_PREFILL,
|
||||||
|
"S_kv_per_rank": _S_KV_PER_RANK,
|
||||||
|
"h_q": _H_Q,
|
||||||
|
"h_kv": _H_KV,
|
||||||
|
"d_head": _D_HEAD,
|
||||||
|
"n_ranks_single_user": _N_RANKS_SINGLE_USER,
|
||||||
|
"n_ranks_multi_user": _N_RANKS_MULTI_USER,
|
||||||
|
},
|
||||||
|
"rows": rows,
|
||||||
|
}
|
||||||
|
_SWEEP_JSON.write_text(json.dumps(sweep, indent=2))
|
||||||
|
print(f" milestone-gqa-llama70b: {len(rows)} rows -> {_SWEEP_JSON}")
|
||||||
|
elif not _SWEEP_JSON.exists():
|
||||||
|
raise RuntimeError(
|
||||||
|
f"MILESTONE_FAST=1 requires {_SWEEP_JSON} to exist; "
|
||||||
|
"run with GQA_VALIDATION=1 once to seed it."
|
||||||
|
)
|
||||||
|
|
||||||
|
paths = emit_all_gqa_plots()
|
||||||
|
print(f" milestone-gqa-llama70b: {len(paths)} figures -> {_OUTPUT_DIR} "
|
||||||
|
f"(fast={fast})")
|
||||||
|
|
||||||
|
# Sentinel tensor (ADR-0045 D4 / ADR-0054 D2 carve-out).
|
||||||
|
torch.zeros(
|
||||||
|
(1, 1), dtype="f16",
|
||||||
|
dp=DPPolicy(cube="row_wise", pe="replicate", num_cubes=1, num_pes=1),
|
||||||
|
name="milestone_gqa_sentinel",
|
||||||
|
)
|
||||||
@@ -0,0 +1,222 @@
|
|||||||
|
"""Phase 1 spec test for ``milestone-gqa-llama70b`` bench (sub-cycle 4a, all 4 panels).
|
||||||
|
|
||||||
|
ADR-0057 (Proposed) defines an eval bench that drives both attention kernels
|
||||||
|
(ADR-0055 ring-K/V, ADR-0056 allreduce-mlo) and emits per-panel op_log
|
||||||
|
summaries into ``src/kernbench/benches/1H_milestone_output/gqa/sweep.json``.
|
||||||
|
|
||||||
|
v1 (sub-cycle 4a) covers ALL FOUR panels:
|
||||||
|
|
||||||
|
Panel name in JSON / test Study label SFR install used
|
||||||
|
─────────────────────────────────────────────────────────────────────────────
|
||||||
|
single_user_prefill TL configure_sfr_intracube_pe_ring
|
||||||
|
multi_user_prefill TR configure_sfr_intercube_multisip
|
||||||
|
single_user_decode BL configure_sfr_intracube_pe_ring
|
||||||
|
multi_user_decode BR configure_sfr_intercube_multisip
|
||||||
|
|
||||||
|
single_user_* panels became runnable after sub-cycle 4-pre delivered the
|
||||||
|
new SFR install function (ADR-0058).
|
||||||
|
|
||||||
|
In Phase 1 the bench module does not exist; pytest collection fails with
|
||||||
|
``ModuleNotFoundError``. Once Phase 2 lands the bench module, every
|
||||||
|
assertion below must pass.
|
||||||
|
|
||||||
|
Assertions:
|
||||||
|
- Bench is registered as ``milestone-gqa-llama70b``.
|
||||||
|
- A validation run (``GQA_VALIDATION=1``) completes ok via run_bench.
|
||||||
|
- sweep.json conforms to ADR-0057 D7 (v1 schema).
|
||||||
|
- All four panel rows present with sane op_log summaries.
|
||||||
|
- Both decode panels have gemm_count = 2 × n_ranks (one-shot per rank).
|
||||||
|
- Both prefill panels have gemm_count = 2 × n_ranks² (per-step GEMMs).
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
from kernbench.benches.registry import resolve
|
||||||
|
from kernbench.runtime_api.bench_runner import run_bench
|
||||||
|
from kernbench.runtime_api.types import resolve_device
|
||||||
|
from kernbench.sim_engine.engine import GraphEngine
|
||||||
|
from kernbench.topology.builder import resolve_topology
|
||||||
|
|
||||||
|
# Production module (Phase 2 deliverable; absent in Phase 1).
|
||||||
|
import kernbench.benches.milestone_gqa_llama70b as gqa_bench
|
||||||
|
|
||||||
|
|
||||||
|
BENCH_NAME = "milestone-gqa-llama70b"
|
||||||
|
|
||||||
|
PANELS_V1 = (
|
||||||
|
"single_user_prefill",
|
||||||
|
"multi_user_prefill",
|
||||||
|
"single_user_decode",
|
||||||
|
"multi_user_decode",
|
||||||
|
)
|
||||||
|
SINGLE_USER_PANELS = ("single_user_prefill", "single_user_decode")
|
||||||
|
MULTI_USER_PANELS = ("multi_user_prefill", "multi_user_decode")
|
||||||
|
PREFILL_PANELS = ("single_user_prefill", "multi_user_prefill")
|
||||||
|
DECODE_PANELS = ("single_user_decode", "multi_user_decode")
|
||||||
|
|
||||||
|
|
||||||
|
def _run_validation():
|
||||||
|
"""Drive the bench through run_bench at validation scale."""
|
||||||
|
topo = resolve_topology("topology.yaml")
|
||||||
|
return run_bench(
|
||||||
|
topology=topo,
|
||||||
|
bench_fn=resolve(BENCH_NAME).run,
|
||||||
|
device=resolve_device(None),
|
||||||
|
engine_factory=lambda t, d: GraphEngine(
|
||||||
|
getattr(t, "topology_obj", t), enable_data=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Registration ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_bench_registered():
|
||||||
|
spec = resolve(BENCH_NAME)
|
||||||
|
assert spec.name == BENCH_NAME
|
||||||
|
assert callable(spec.run)
|
||||||
|
assert spec.description.strip(), "description must be non-empty"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Validation run end-to-end ────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_validation_run_completes_ok(monkeypatch):
|
||||||
|
monkeypatch.setenv("GQA_VALIDATION", "1")
|
||||||
|
result = _run_validation()
|
||||||
|
assert result.completion.ok, (
|
||||||
|
f"validation run failed: {result.completion}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── JSON shape (ADR-0057 D7 amended for 4 panels) ──────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _sweep_json(monkeypatch) -> dict:
|
||||||
|
"""Run the bench (if needed) and return the parsed sweep.json."""
|
||||||
|
monkeypatch.setenv("GQA_VALIDATION", "1")
|
||||||
|
out = gqa_bench._OUTPUT_DIR / "sweep.json"
|
||||||
|
if not out.exists():
|
||||||
|
result = _run_validation()
|
||||||
|
assert result.completion.ok, result.completion
|
||||||
|
assert out.exists(), f"missing {out}"
|
||||||
|
return json.loads(out.read_text())
|
||||||
|
|
||||||
|
|
||||||
|
def test_sweep_json_has_v1_schema(monkeypatch):
|
||||||
|
data = _sweep_json(monkeypatch)
|
||||||
|
assert data["version"] == 1
|
||||||
|
assert data["validation_scale"] is True
|
||||||
|
assert isinstance(data["panels"], list)
|
||||||
|
assert isinstance(data["config"], dict)
|
||||||
|
assert isinstance(data["rows"], list)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sweep_json_panels_are_all_four(monkeypatch):
|
||||||
|
"""v1 covers all four panels — single_user_{prefill,decode} +
|
||||||
|
multi_user_{prefill,decode}. Q/cube sweep deferred to 4b."""
|
||||||
|
data = _sweep_json(monkeypatch)
|
||||||
|
assert set(data["panels"]) == set(PANELS_V1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sweep_json_config_matches_adr0057_d4(monkeypatch):
|
||||||
|
"""Validation-scale config per ADR-0057 D4 (amended for 4 panels + scratch budget).
|
||||||
|
|
||||||
|
S_q_prefill and S_kv_per_rank are deliberately small (16 each) so the
|
||||||
|
simulator's 1 MB per-PE TCM kernel scratch (topology.yaml
|
||||||
|
``pe_tcm.kernel_scratch_mb: 1``) is not exhausted by the
|
||||||
|
bump-allocated handle outputs of softmax/exp/dot/sum chains over
|
||||||
|
n_ranks ring steps. Headline-scale runs in 4c will lift these into a
|
||||||
|
config-driven sweep.
|
||||||
|
"""
|
||||||
|
data = _sweep_json(monkeypatch)
|
||||||
|
cfg = data["config"]
|
||||||
|
assert cfg["S_q_prefill"] == 16
|
||||||
|
assert cfg["S_kv_per_rank"] == 16
|
||||||
|
# v1 uses h_q == h_kv == 1 to avoid ADR-0055 D3's GQA broadcast view
|
||||||
|
# (which is symbolic and does not survive MemoryStore's nbytes check
|
||||||
|
# under simulator data execution). Real GQA (h_q > h_kv) is deferred
|
||||||
|
# to sub-cycle 4c (headline scale).
|
||||||
|
assert cfg["h_q"] == 1
|
||||||
|
assert cfg["h_kv"] == 1
|
||||||
|
assert cfg["d_head"] == 64
|
||||||
|
# single_user_* uses the 8 PEs in one cube as ring ranks.
|
||||||
|
assert cfg["n_ranks_single_user"] == 8
|
||||||
|
# multi_user_* uses cube-level ring; validation uses 4 cubes.
|
||||||
|
assert cfg["n_ranks_multi_user"] == 4
|
||||||
|
|
||||||
|
|
||||||
|
def test_sweep_json_has_one_row_per_panel(monkeypatch):
|
||||||
|
data = _sweep_json(monkeypatch)
|
||||||
|
assert len(data["rows"]) == 4
|
||||||
|
panels_in_rows = {r["panel"] for r in data["rows"]}
|
||||||
|
assert panels_in_rows == set(PANELS_V1)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Per-row op_log summary sanity (ADR-0057 D7) ─────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _row(rows: list, panel: str) -> dict:
|
||||||
|
matches = [r for r in rows if r["panel"] == panel]
|
||||||
|
assert len(matches) == 1, f"expected exactly one {panel} row; got {len(matches)}"
|
||||||
|
return matches[0]
|
||||||
|
|
||||||
|
|
||||||
|
def _assert_sane_summary(row: dict) -> None:
|
||||||
|
s = row["op_log_summary"]
|
||||||
|
panel = row["panel"]
|
||||||
|
assert s["gemm_count"] > 0, f"{panel} must run GEMMs"
|
||||||
|
assert s["ipcq_send_count"] > 0, f"{panel} must send (ring/allreduce phase)"
|
||||||
|
assert s["ipcq_recv_count"] > 0, f"{panel} must recv"
|
||||||
|
assert s["dma_read_count"] >= 3, f"{panel}: Q + K + V loads"
|
||||||
|
assert s["dma_write_count"] >= 1, f"{panel}: final O store"
|
||||||
|
|
||||||
|
|
||||||
|
def test_single_user_prefill_row_has_sane_op_log_summary(monkeypatch):
|
||||||
|
data = _sweep_json(monkeypatch)
|
||||||
|
_assert_sane_summary(_row(data["rows"], "single_user_prefill"))
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi_user_prefill_row_has_sane_op_log_summary(monkeypatch):
|
||||||
|
data = _sweep_json(monkeypatch)
|
||||||
|
_assert_sane_summary(_row(data["rows"], "multi_user_prefill"))
|
||||||
|
|
||||||
|
|
||||||
|
def test_single_user_decode_row_has_sane_op_log_summary(monkeypatch):
|
||||||
|
data = _sweep_json(monkeypatch)
|
||||||
|
_assert_sane_summary(_row(data["rows"], "single_user_decode"))
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi_user_decode_row_has_sane_op_log_summary(monkeypatch):
|
||||||
|
data = _sweep_json(monkeypatch)
|
||||||
|
_assert_sane_summary(_row(data["rows"], "multi_user_decode"))
|
||||||
|
|
||||||
|
|
||||||
|
# ── Architectural invariant: decode = one-shot per rank ─────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_single_user_decode_gemm_count_is_exactly_2_per_rank(monkeypatch):
|
||||||
|
"""ADR-0056 D3: decode kernel does ONE local partial-attention pass per
|
||||||
|
rank → exactly 2 GEMMs per rank (Q·K^T + S·V). With n_ranks ranks the
|
||||||
|
total = 2 × n_ranks. This distinguishes decode from prefill where each
|
||||||
|
ring step has 2 GEMMs and the total scales as 2 × n_ranks²."""
|
||||||
|
data = _sweep_json(monkeypatch)
|
||||||
|
row = _row(data["rows"], "single_user_decode")
|
||||||
|
n_ranks = row["n_ranks"]
|
||||||
|
assert row["op_log_summary"]["gemm_count"] == 2 * n_ranks, (
|
||||||
|
f"single_user_decode gemm_count must be 2 × n_ranks = {2 * n_ranks}; "
|
||||||
|
f"got {row['op_log_summary']['gemm_count']}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi_user_decode_gemm_count_is_exactly_2_per_rank(monkeypatch):
|
||||||
|
"""Same one-shot invariant as single_user_decode — the kernel is the
|
||||||
|
same; what differs is who the ranks are (cubes vs PEs)."""
|
||||||
|
data = _sweep_json(monkeypatch)
|
||||||
|
row = _row(data["rows"], "multi_user_decode")
|
||||||
|
n_ranks = row["n_ranks"]
|
||||||
|
assert row["op_log_summary"]["gemm_count"] == 2 * n_ranks, (
|
||||||
|
f"multi_user_decode gemm_count must be 2 × n_ranks = {2 * n_ranks}; "
|
||||||
|
f"got {row['op_log_summary']['gemm_count']}"
|
||||||
|
)
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
"""Thin re-export shim for the GQA figure tests.
|
||||||
|
|
||||||
|
Not a test module (no ``test_`` prefix → pytest does not collect it).
|
||||||
|
|
||||||
|
Mirrors ``tests/gemm/_gemm_plot_helpers.py``. The renderer logic lives in
|
||||||
|
``kernbench.benches.milestone_gqa_llama70b`` (production single home,
|
||||||
|
ADR-0054). Defaults still target the bench's ``_OUTPUT_DIR``.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from kernbench.benches.milestone_gqa_llama70b import (
|
||||||
|
_OUTPUT_DIR as GQA_PLOTS_DIR,
|
||||||
|
_SWEEP_JSON as GQA_SWEEP_JSON,
|
||||||
|
emit_all_gqa_plots,
|
||||||
|
emit_gqa_comparison,
|
||||||
|
emit_panel_op_log_summary,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"GQA_PLOTS_DIR",
|
||||||
|
"GQA_SWEEP_JSON",
|
||||||
|
"emit_all_gqa_plots",
|
||||||
|
"emit_gqa_comparison",
|
||||||
|
"emit_panel_op_log_summary",
|
||||||
|
]
|
||||||
@@ -0,0 +1,109 @@
|
|||||||
|
"""Phase 1 spec test for GQA figure renderers (sub-cycle 4c).
|
||||||
|
|
||||||
|
ADR-0057 D3 sub-cycle 4c adds 6 figure renderers; this test pins the
|
||||||
|
5 of 6 that don't depend on sub-cycle 4b's Q/cube sweep:
|
||||||
|
|
||||||
|
- 4 per-panel op_log_summary PNGs (one per panel of v1's sweep.json)
|
||||||
|
- 1 cross-panel ``gqa_comparison.png`` (4-panel grouped bars over the
|
||||||
|
5 op_log_summary counts: gemm, ipcq_send, ipcq_recv, dma_read, dma_write)
|
||||||
|
|
||||||
|
The 6th, ``gqa_scaling.png``, needs the Q/cube ∈ {1, 2, 4} sweep from
|
||||||
|
sub-cycle 4b and is deferred.
|
||||||
|
|
||||||
|
Each test depends on the committed
|
||||||
|
``benches/1H_milestone_output/gqa/sweep.json`` (landed in commit
|
||||||
|
``e748a62``); they assert the renderer writes a non-empty PNG at the
|
||||||
|
expected path.
|
||||||
|
|
||||||
|
Phase 1 expectation: tests fail at import (renderer functions don't
|
||||||
|
exist yet on the bench module). Phase 2 lands them and the tests
|
||||||
|
turn green.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tests.gqa._gqa_plot_helpers import (
|
||||||
|
GQA_PLOTS_DIR,
|
||||||
|
GQA_SWEEP_JSON,
|
||||||
|
emit_all_gqa_plots,
|
||||||
|
emit_gqa_comparison,
|
||||||
|
emit_panel_op_log_summary,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_PANELS = (
|
||||||
|
"single_user_prefill",
|
||||||
|
"multi_user_prefill",
|
||||||
|
"single_user_decode",
|
||||||
|
"multi_user_decode",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not GQA_SWEEP_JSON.exists(),
|
||||||
|
reason="gqa sweep.json absent; run milestone-gqa-llama70b first",
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize("panel", _PANELS)
|
||||||
|
def test_emit_panel_op_log_summary_writes_png_for_each_panel(panel):
|
||||||
|
out = emit_panel_op_log_summary(panel)
|
||||||
|
assert out is not None, f"{panel}: renderer returned None"
|
||||||
|
path = Path(out)
|
||||||
|
assert path.exists(), f"{panel}: expected PNG at {path}"
|
||||||
|
assert path.suffix == ".png", f"{panel}: not a PNG: {path}"
|
||||||
|
assert path.stat().st_size > 0, f"{panel}: empty PNG: {path}"
|
||||||
|
assert panel in path.stem, (
|
||||||
|
f"{panel}: panel name not in filename {path.name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not GQA_SWEEP_JSON.exists(),
|
||||||
|
reason="gqa sweep.json absent; run milestone-gqa-llama70b first",
|
||||||
|
)
|
||||||
|
def test_emit_gqa_comparison_writes_png():
|
||||||
|
out = emit_gqa_comparison()
|
||||||
|
assert out is not None
|
||||||
|
path = Path(out)
|
||||||
|
assert path.exists()
|
||||||
|
assert path.name == "gqa_comparison.png"
|
||||||
|
assert path.stat().st_size > 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not GQA_SWEEP_JSON.exists(),
|
||||||
|
reason="gqa sweep.json absent; run milestone-gqa-llama70b first",
|
||||||
|
)
|
||||||
|
def test_emit_all_gqa_plots_writes_five_figures():
|
||||||
|
"""emit_all returns a list of 5 written PNG paths (deferring the
|
||||||
|
6th gqa_scaling.png to after sub-cycle 4b lands the Q/cube sweep)."""
|
||||||
|
paths = emit_all_gqa_plots()
|
||||||
|
assert isinstance(paths, list)
|
||||||
|
# 4 per-panel + 1 comparison.
|
||||||
|
assert len(paths) == 5, f"expected 5 PNGs, got {len(paths)}: {paths}"
|
||||||
|
for p in paths:
|
||||||
|
assert Path(p).exists() and Path(p).stat().st_size > 0
|
||||||
|
names = {Path(p).name for p in paths}
|
||||||
|
assert "gqa_comparison.png" in names
|
||||||
|
for panel in _PANELS:
|
||||||
|
assert any(panel in n for n in names), (
|
||||||
|
f"no per-panel PNG for {panel} in {names}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_emit_all_gqa_plots_output_dir_matches_bench_output_dir():
|
||||||
|
"""The renderers must write under the bench's own _OUTPUT_DIR so
|
||||||
|
MILESTONE_FAST=1 reuse (and committed baselines) all point at the
|
||||||
|
same on-disk location."""
|
||||||
|
# Stub assertion that fails until emit_all_gqa_plots exists with a
|
||||||
|
# default ``out_dir`` argument identical to GQA_PLOTS_DIR.
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
sig = inspect.signature(emit_all_gqa_plots)
|
||||||
|
assert "out_dir" in sig.parameters
|
||||||
|
default = sig.parameters["out_dir"].default
|
||||||
|
assert Path(default) == GQA_PLOTS_DIR, (
|
||||||
|
f"default out_dir {default} != bench _OUTPUT_DIR {GQA_PLOTS_DIR}"
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user