attention: milestone-gqa-llama70b figures + MILESTONE_FAST (sub-cycle 4c, 5/6)
Add 5 of the 6 figure renderers ADR-0057 D3 sub-cycle 4c specifies:
- gqa_op_log_{panel}.png × 4 — per-panel bar chart of the 5 op_log
counts (gemm, ipcq_send, ipcq_recv, dma_read, dma_write).
- gqa_comparison.png — cross-panel grouped bars over the same 5 series.
Sixth figure (gqa_scaling.png) depends on sub-cycle 4b's Q/cube ∈
{1, 2, 4} sweep on multi_user_* panels and is deferred until that
data exists; emit_all_gqa_plots returns just the 5 in-scope paths.
Add MILESTONE_FAST=1 mode to run(): skip the panel sweep, reuse the
committed sweep.json, render figures only. Validation mode unchanged.
The runtime errors clearly when neither env var is set, listing the
two supported modes.
Renderers live in the bench module (the milestone-1h-gemm pattern);
tests/gqa/_gqa_plot_helpers.py re-exports them for figure tests.
Tests: tests/gqa/test_plot_gqa_figures.py — 7 tests, all green:
- 4 parametrized per-panel emit assertions
- 1 comparison emit assertion
- 1 emit_all returns exactly 5 PNG paths
- 1 default out_dir matches the bench _OUTPUT_DIR
Commits the 5 PNG baselines under the bench output dir alongside
sweep.json, mirroring milestone-1h-gemm's committed-figures pattern.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Binary file not shown.
|
After Width: | Height: | Size: 33 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 21 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 20 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 24 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 22 KiB |
@@ -196,6 +196,168 @@ def _run_panel(panel: str, topology: str) -> dict:
|
||||
}
|
||||
|
||||
|
||||
# ── Figure renderers (sub-cycle 4c, 5 of 6 figures) ──────────────────
|
||||
#
|
||||
# Sixth figure ``gqa_scaling.png`` is deferred to after sub-cycle 4b
|
||||
# lands the Q/cube ∈ {1, 2, 4} sweep on multi_user_* panels — it needs
|
||||
# multiple sweep.json rows per multi_user panel to be meaningful.
|
||||
|
||||
_OP_LOG_KEYS = (
|
||||
"gemm_count",
|
||||
"ipcq_send_count",
|
||||
"ipcq_recv_count",
|
||||
"dma_read_count",
|
||||
"dma_write_count",
|
||||
)
|
||||
_OP_LOG_DISPLAY = {
|
||||
"gemm_count": "GEMM",
|
||||
"ipcq_send_count": "IPCQ send",
|
||||
"ipcq_recv_count": "IPCQ recv",
|
||||
"dma_read_count": "DMA read",
|
||||
"dma_write_count": "DMA write",
|
||||
}
|
||||
_OP_LOG_COLORS = {
|
||||
"gemm_count": "#F59E0B",
|
||||
"ipcq_send_count": "#3B82F6",
|
||||
"ipcq_recv_count": "#10B981",
|
||||
"dma_read_count": "#A855F7",
|
||||
"dma_write_count": "#EF4444",
|
||||
}
|
||||
_PANEL_DISPLAY = {
|
||||
"single_user_prefill": "single_user / prefill",
|
||||
"multi_user_prefill": "multi_user / prefill",
|
||||
"single_user_decode": "single_user / decode",
|
||||
"multi_user_decode": "multi_user / decode",
|
||||
}
|
||||
|
||||
|
||||
def _load_sweep_data(sweep_json: Path | str) -> dict:
|
||||
sweep_json = Path(sweep_json)
|
||||
if not sweep_json.exists():
|
||||
return {"rows": [], "config": {}, "panels": []}
|
||||
return json.loads(sweep_json.read_text())
|
||||
|
||||
|
||||
def _row_for(rows: list, panel: str) -> dict | None:
|
||||
for r in rows:
|
||||
if r.get("panel") == panel:
|
||||
return r
|
||||
return None
|
||||
|
||||
|
||||
def emit_panel_op_log_summary(
|
||||
panel: str,
|
||||
sweep_json: Path | str = _SWEEP_JSON,
|
||||
out_dir: Path | str = _OUTPUT_DIR,
|
||||
) -> str | None:
|
||||
"""One bar chart of the 5 op_log counts for ``panel``.
|
||||
|
||||
Returns the written PNG path, or ``None`` when sweep.json is empty
|
||||
or the requested panel is absent.
|
||||
"""
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
data = _load_sweep_data(sweep_json)
|
||||
row = _row_for(data.get("rows", []), panel)
|
||||
if row is None:
|
||||
return None
|
||||
summary = row.get("op_log_summary", {})
|
||||
n_ranks = row.get("n_ranks")
|
||||
|
||||
labels = [_OP_LOG_DISPLAY[k] for k in _OP_LOG_KEYS]
|
||||
values = [summary.get(k, 0) for k in _OP_LOG_KEYS]
|
||||
colors = [_OP_LOG_COLORS[k] for k in _OP_LOG_KEYS]
|
||||
|
||||
fig, ax = plt.subplots(figsize=(8, 5))
|
||||
bars = ax.bar(labels, values, color=colors)
|
||||
for b, v in zip(bars, values):
|
||||
ax.text(b.get_x() + b.get_width() / 2, b.get_height(),
|
||||
f"{int(v)}", ha="center", va="bottom", fontsize=9)
|
||||
ax.set_title(
|
||||
f"{_PANEL_DISPLAY.get(panel, panel)} (n_ranks={n_ranks})",
|
||||
fontsize=12, fontweight="bold",
|
||||
)
|
||||
ax.set_ylabel("count")
|
||||
ax.grid(True, axis="y", alpha=0.3)
|
||||
fig.tight_layout()
|
||||
|
||||
out_dir = Path(out_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
out = out_dir / f"gqa_op_log_{panel}.png"
|
||||
fig.savefig(out, dpi=120)
|
||||
plt.close(fig)
|
||||
return str(out)
|
||||
|
||||
|
||||
def emit_gqa_comparison(
|
||||
sweep_json: Path | str = _SWEEP_JSON,
|
||||
out_dir: Path | str = _OUTPUT_DIR,
|
||||
) -> str | None:
|
||||
"""Grouped-bar chart comparing the 5 op_log counts across all panels."""
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
data = _load_sweep_data(sweep_json)
|
||||
panels_in = data.get("panels") or list(_PANELS_V1)
|
||||
rows = data.get("rows", [])
|
||||
panels = [p for p in panels_in if _row_for(rows, p) is not None]
|
||||
if not panels:
|
||||
return None
|
||||
|
||||
n_groups = len(panels)
|
||||
n_series = len(_OP_LOG_KEYS)
|
||||
x = np.arange(n_groups)
|
||||
width = 0.8 / n_series
|
||||
|
||||
fig, ax = plt.subplots(figsize=(11, 6))
|
||||
for i, key in enumerate(_OP_LOG_KEYS):
|
||||
offset = (i - (n_series - 1) / 2) * width
|
||||
vals = [_row_for(rows, p)["op_log_summary"].get(key, 0)
|
||||
for p in panels]
|
||||
ax.bar(x + offset, vals, width,
|
||||
label=_OP_LOG_DISPLAY[key], color=_OP_LOG_COLORS[key])
|
||||
|
||||
ax.set_xticks(x)
|
||||
ax.set_xticklabels(
|
||||
[f"{_PANEL_DISPLAY.get(p, p)}\n(n_ranks={_row_for(rows, p)['n_ranks']})"
|
||||
for p in panels],
|
||||
fontsize=8,
|
||||
)
|
||||
ax.set_ylabel("count")
|
||||
ax.set_title("GQA Llama-70B — op_log summary across panels",
|
||||
fontsize=13, fontweight="bold")
|
||||
ax.legend(fontsize=8, loc="upper right")
|
||||
ax.grid(True, axis="y", alpha=0.3)
|
||||
fig.tight_layout()
|
||||
|
||||
out_dir = Path(out_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
out = out_dir / "gqa_comparison.png"
|
||||
fig.savefig(out, dpi=120)
|
||||
plt.close(fig)
|
||||
return str(out)
|
||||
|
||||
|
||||
def emit_all_gqa_plots(
|
||||
sweep_json: Path | str = _SWEEP_JSON,
|
||||
out_dir: Path | str = _OUTPUT_DIR,
|
||||
) -> list[str]:
|
||||
"""Render all 5 in-scope figures and return the written paths.
|
||||
|
||||
Sub-cycle 4c v1 emits 5 of the 6 figures ADR-0057 D3 lists; the
|
||||
6th (gqa_scaling.png) needs sub-cycle 4b's Q/cube sweep data.
|
||||
"""
|
||||
paths: list[str] = []
|
||||
for panel in _PANELS_V1:
|
||||
p = emit_panel_op_log_summary(panel, sweep_json, out_dir)
|
||||
if p is not None:
|
||||
paths.append(p)
|
||||
comp = emit_gqa_comparison(sweep_json, out_dir)
|
||||
if comp is not None:
|
||||
paths.append(comp)
|
||||
return paths
|
||||
|
||||
|
||||
# ── Bench entry ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -204,42 +366,58 @@ def _run_panel(panel: str, topology: str) -> dict:
|
||||
description="1H milestone: GQA Llama-70B 4-panel sweep (ADR-0057 v1).",
|
||||
)
|
||||
def run(torch) -> None:
|
||||
"""Drive the four GQA panels at validation scale; write sweep.json.
|
||||
"""Drive the four GQA panels at validation scale; write sweep.json and figures.
|
||||
|
||||
v1 only supports validation mode (``GQA_VALIDATION=1``). Headline
|
||||
mode and figures are deferred to sub-cycles 4b and 4c per ADR-0057 D3.
|
||||
Modes (mutually exclusive):
|
||||
MILESTONE_FAST=1 Skip the sweep; re-render figures from the
|
||||
committed sweep.json. Seconds, no simulator.
|
||||
GQA_VALIDATION=1 Run the four-panel validation sweep + figures.
|
||||
~1-2h on the full simulator.
|
||||
|
||||
Headline-scale mode is deferred to sub-cycle 4c (figures landed
|
||||
here; headline-scale + scaling figure await sub-cycle 4b).
|
||||
A sentinel tensor is submitted at the end so run_bench's ADR-0045 D4
|
||||
"at least one request" contract is satisfied even though the panels
|
||||
drive their own engines.
|
||||
"at least one request" contract is satisfied even when the panels
|
||||
are skipped via MILESTONE_FAST=1.
|
||||
"""
|
||||
if not os.environ.get("GQA_VALIDATION"):
|
||||
raise RuntimeError(
|
||||
"milestone-gqa-llama70b v1 only supports validation mode. "
|
||||
"Set GQA_VALIDATION=1 to run. Headline mode is deferred to "
|
||||
"sub-cycle 4b/4c per ADR-0057 D3."
|
||||
)
|
||||
_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
topology = os.environ.get("GQA_TOPOLOGY", "topology.yaml")
|
||||
fast = bool(os.environ.get("MILESTONE_FAST"))
|
||||
if not fast and not os.environ.get("GQA_VALIDATION"):
|
||||
raise RuntimeError(
|
||||
"milestone-gqa-llama70b v1 needs GQA_VALIDATION=1 (run the "
|
||||
"sweep) or MILESTONE_FAST=1 (reuse committed sweep.json). "
|
||||
"Headline mode is deferred to sub-cycle 4b/4c per ADR-0057 D3."
|
||||
)
|
||||
|
||||
rows = [_run_panel(panel, topology) for panel in _PANELS_V1]
|
||||
if not fast:
|
||||
topology = os.environ.get("GQA_TOPOLOGY", "topology.yaml")
|
||||
rows = [_run_panel(panel, topology) for panel in _PANELS_V1]
|
||||
sweep = {
|
||||
"version": 1,
|
||||
"validation_scale": True,
|
||||
"panels": list(_PANELS_V1),
|
||||
"config": {
|
||||
"S_q_prefill": _S_Q_PREFILL,
|
||||
"S_kv_per_rank": _S_KV_PER_RANK,
|
||||
"h_q": _H_Q,
|
||||
"h_kv": _H_KV,
|
||||
"d_head": _D_HEAD,
|
||||
"n_ranks_single_user": _N_RANKS_SINGLE_USER,
|
||||
"n_ranks_multi_user": _N_RANKS_MULTI_USER,
|
||||
},
|
||||
"rows": rows,
|
||||
}
|
||||
_SWEEP_JSON.write_text(json.dumps(sweep, indent=2))
|
||||
print(f" milestone-gqa-llama70b: {len(rows)} rows -> {_SWEEP_JSON}")
|
||||
elif not _SWEEP_JSON.exists():
|
||||
raise RuntimeError(
|
||||
f"MILESTONE_FAST=1 requires {_SWEEP_JSON} to exist; "
|
||||
"run with GQA_VALIDATION=1 once to seed it."
|
||||
)
|
||||
|
||||
sweep = {
|
||||
"version": 1,
|
||||
"validation_scale": True,
|
||||
"panels": list(_PANELS_V1),
|
||||
"config": {
|
||||
"S_q_prefill": _S_Q_PREFILL,
|
||||
"S_kv_per_rank": _S_KV_PER_RANK,
|
||||
"h_q": _H_Q,
|
||||
"h_kv": _H_KV,
|
||||
"d_head": _D_HEAD,
|
||||
"n_ranks_single_user": _N_RANKS_SINGLE_USER,
|
||||
"n_ranks_multi_user": _N_RANKS_MULTI_USER,
|
||||
},
|
||||
"rows": rows,
|
||||
}
|
||||
_SWEEP_JSON.write_text(json.dumps(sweep, indent=2))
|
||||
print(f" milestone-gqa-llama70b: {len(rows)} rows -> {_SWEEP_JSON}")
|
||||
paths = emit_all_gqa_plots()
|
||||
print(f" milestone-gqa-llama70b: {len(paths)} figures -> {_OUTPUT_DIR} "
|
||||
f"(fast={fast})")
|
||||
|
||||
# Sentinel tensor (ADR-0045 D4 / ADR-0054 D2 carve-out).
|
||||
torch.zeros(
|
||||
|
||||
Reference in New Issue
Block a user