5accd98171
scripts/build_overview_slides.py renders a 5-slide PPTX (kernbench2_overview.pptx) summarizing architecture, model correctness, IPCQ, allreduce, and buffer-kind tier comparison. scripts/emit_overview_with_external_ref.py renders log-y and broken-y variants of the allreduce overview (overview_log.png, overview_broken.png) including a 366 µs ext-sim reference marker at 96 KB / PE. Also includes cube_mesh_view.png rendered from the SVG. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
193 lines
6.0 KiB
Python
193 lines
6.0 KiB
Python
"""One-shot: render overview.png with an external 366 µs reference, in two
|
|
variants — log scale and broken y-axis. Reads docs/diagrams/allreduce_latency_plots/summary.csv
|
|
and writes overview_log.png and overview_broken.png alongside it.
|
|
|
|
This is a derived-artifact generator (per CLAUDE.md): plotting only, no production
|
|
or test logic touched.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import csv
|
|
from pathlib import Path
|
|
|
|
import matplotlib.pyplot as plt
|
|
import matplotlib.ticker as mticker
|
|
|
|
# Two directory levels above this file, i.e. <ROOT>/<dir>/<this script>.
ROOT = Path(__file__).resolve().parent.parent
# Directory that holds the sweep summary CSV and receives the rendered PNGs.
PLOT_DIR = ROOT.joinpath("docs", "diagrams", "allreduce_latency_plots")
CSV_PATH = PLOT_DIR.joinpath("summary.csv")

# External reference point, drawn as a single star marker on both plots.
EXT_LATENCY_NS = 366_000.0  # 366 µs expressed in nanoseconds
EXT_LABEL = "ext-sim single-device reduce: 366 µs"

# Line colour per SIP topology. Lookups use COLORS.get(), so a topology that is
# not listed here yields None and matplotlib picks the colour itself.
COLORS = dict(
    ring_1d="tab:blue",
    torus_2d="tab:orange",
    mesh_2d_no_wrap="tab:green",
)

# Hand-derived theoretical model for torus_2d (6 SIPs). Mirrors
# _aggregate_sweep_plots in tests/test_allreduce_multidevice.py.
#
#   latency_ns(n_packets) = T_STARTUP_NS + (n_packets - 1) * TAU_NS
#
# TAU_NS is the slope of the line through two points: 1 packet -> 1346 ns and
# 6144 packets -> 8741 ns. 6144 packets * 128 B equals 8 PEs * 96 KiB, i.e.
# the 96 KB / PE payload.
PES_PER_CUBE = 8
NOC_PACKET_BYTES = 128
T_STARTUP_NS = 1346.0
TAU_NS = (8741.0 - T_STARTUP_NS) / (6144 - 1)
|
|
|
|
|
|
def _theoretical_torus_2d_ns(bytes_per_pe: int) -> float:
    """Return the modelled torus_2d latency (ns) for a per-PE payload size.

    The payload is scaled to a per-cube byte count (``PES_PER_CUBE`` PEs),
    split into ``NOC_PACKET_BYTES``-sized packets (rounded up, never fewer
    than one), and fed into the linear model
    ``T_STARTUP_NS + (n_packets - 1) * TAU_NS``.
    """
    cube_bytes = PES_PER_CUBE * int(bytes_per_pe)
    # Ceiling division: a partially filled trailing packet still counts.
    whole, leftover = divmod(cube_bytes, NOC_PACKET_BYTES)
    n_packets = whole + 1 if leftover else whole
    if n_packets < 1:
        n_packets = 1
    return T_STARTUP_NS + (n_packets - 1) * TAU_NS
|
|
|
|
|
|
def _plot_theoretical(ax, records):
    """Overlay the theoretical torus_2d curve on *ax*.

    The model is evaluated at every ``bytes_per_pe`` value found in the
    ``torus_2d`` rows of *records*, in ascending order. Nothing is drawn when
    *records* has no ``torus_2d`` rows.
    """
    payloads = [
        rec["bytes_per_pe"]
        for rec in records
        if rec["sip_topology"] == "torus_2d"
    ]
    if not payloads:
        return
    payloads.sort()
    modelled_ns = [_theoretical_torus_2d_ns(size) for size in payloads]
    ax.plot(
        payloads,
        modelled_ns,
        color="tab:red",
        linestyle="--",
        linewidth=1.6,
        marker="x",
        label="theoretical torus_2d (6 SIPs)",
    )
|
|
|
|
|
|
def _bytes_fmt(x, _pos):
|
|
if x >= 1024 * 1024:
|
|
return f"{x / (1024 * 1024):.0f}M"
|
|
if x >= 1024:
|
|
return f"{x / 1024:.0f}K"
|
|
return f"{int(x)}"
|
|
|
|
|
|
def _load_records():
|
|
rows = []
|
|
with open(CSV_PATH, newline="") as f:
|
|
r = csv.DictReader(f)
|
|
for row in r:
|
|
rows.append({
|
|
"sip_topology": row["sip_topology"],
|
|
"bytes_per_pe": int(row["bytes_per_pe"]),
|
|
"latency_ns": float(row["latency_ns"]),
|
|
})
|
|
return rows
|
|
|
|
|
|
def _ext_x(records):
|
|
"""Anchor the external reference at the largest payload (96 KB / PE)."""
|
|
return max(r["bytes_per_pe"] for r in records)
|
|
|
|
|
|
def _plot_curves(ax, records, topologies):
    """Draw one measured latency line per topology on *ax*.

    For each name in *topologies* (in the given order) the matching rows of
    *records* are sorted by ``bytes_per_pe`` and plotted as
    ``bytes_per_pe`` vs ``latency_ns`` with circle markers. Topologies with
    no rows are skipped. The colour comes from ``COLORS`` when the topology
    is listed there.
    """
    for topo in topologies:
        series = [rec for rec in records if rec["sip_topology"] == topo]
        if not series:
            continue
        series.sort(key=lambda rec: rec["bytes_per_pe"])
        xs = [rec["bytes_per_pe"] for rec in series]
        ys = [rec["latency_ns"] for rec in series]
        ax.plot(xs, ys, marker="o", label=f"{topo}", color=COLORS.get(topo))
|
|
|
|
|
|
def emit_log(records):
    """Write ``overview_log.png``: all curves plus the external marker, log-log.

    A single axes shows the measured curve of every topology present in
    *records*, the theoretical torus_2d curve, and the external reference as
    a red star at the largest payload. x is log base 2, y is log base 10.
    The image is saved into ``PLOT_DIR`` and its path is printed.
    """
    topo_names = sorted({rec["sip_topology"] for rec in records})
    fig, ax = plt.subplots(figsize=(9, 6))

    # Data layers: measured curves, model overlay, external reference star.
    _plot_curves(ax, records, topo_names)
    _plot_theoretical(ax, records)
    ax.scatter(
        [_ext_x(records)],
        [EXT_LATENCY_NS],
        marker="*",
        s=220,
        color="tab:red",
        zorder=5,
        label=EXT_LABEL,
    )

    # Axes cosmetics.
    ax.set_xscale("log", base=2)
    ax.set_yscale("log")
    ax.set(
        xlabel="Bytes per PE (log scale)",
        ylabel="Time (ns) — log scale",
        title="Multi-device allreduce latency vs external single-device reference",
    )
    ax.grid(True, which="both", alpha=0.3)
    ax.xaxis.set_major_formatter(mticker.FuncFormatter(_bytes_fmt))
    ax.legend(loc="upper left")

    fig.tight_layout()
    out = PLOT_DIR.joinpath("overview_log.png")
    fig.savefig(out, dpi=120)
    plt.close(fig)
    print(f"wrote {out}")
|
|
|
|
|
|
def emit_broken(records):
    """Write ``overview_broken.png``: linear y with a broken y-axis.

    Two stacked panels share the x-axis (height ratio 1:4):

    * bottom — measured curve of every topology in *records* plus the
      theoretical torus_2d curve, y from 0 to 110 % of the largest measured
      latency;
    * top — only the external reference star, y restricted to a narrow band
      around ``EXT_LATENCY_NS``.

    Args:
        records: non-empty list of dicts with keys ``sip_topology``,
            ``bytes_per_pe`` and ``latency_ns``.

    Side effects:
        Saves the figure into ``PLOT_DIR`` and prints the output path.
    """
    topologies = sorted({r["sip_topology"] for r in records})
    max_local = max(r["latency_ns"] for r in records)

    fig, (ax_top, ax_bot) = plt.subplots(
        2, 1, sharex=True,
        gridspec_kw={"height_ratios": [1, 4], "hspace": 0.05},
        figsize=(9, 6.5),
    )

    # Bottom panel: today's three curves + theoretical, linear y.
    _plot_curves(ax_bot, records, topologies)
    _plot_theoretical(ax_bot, records)
    ax_bot.set_ylim(0, max_local * 1.10)

    # Top panel: only the external reference marker, linear y around 366 µs.
    ax_top.scatter(
        [_ext_x(records)], [EXT_LATENCY_NS],
        marker="*", s=240, color="tab:red", zorder=5,
        label=EXT_LABEL,
    )
    ax_top.set_ylim(EXT_LATENCY_NS * 0.93, EXT_LATENCY_NS * 1.05)

    # Hide the spine between the two panels and draw diagonal "break" marks.
    ax_top.spines["bottom"].set_visible(False)
    ax_bot.spines["top"].set_visible(False)
    ax_top.tick_params(labeltop=False, bottom=False)
    ax_bot.xaxis.tick_bottom()

    # The slashes are drawn as custom markers, whose size is in points, so they
    # have the same length and slope on both panels regardless of the 1:4
    # height ratio. (Line segments given in axes-fraction units come out with
    # different sizes on panels of different height.)
    d = 0.5  # vertical-to-horizontal extent ratio of each slash
    break_kw = dict(
        marker=[(-1, -d), (1, d)], markersize=12,
        linestyle="none", color="k", mec="k", mew=1, clip_on=False,
    )
    # Bottom corners of the top panel, top corners of the bottom panel.
    ax_top.plot([0, 1], [0, 0], transform=ax_top.transAxes, **break_kw)
    ax_bot.plot([0, 1], [1, 1], transform=ax_bot.transAxes, **break_kw)

    ax_bot.set_xscale("log", base=2)
    ax_bot.set_xlabel("Bytes per PE (log scale)")
    ax_bot.set_ylabel("Time (ns)")
    ax_top.set_ylabel("Time (ns)")
    ax_bot.grid(True, alpha=0.3)
    ax_top.grid(True, alpha=0.3)
    ax_bot.xaxis.set_major_formatter(mticker.FuncFormatter(_bytes_fmt))

    # One legend covering both axes. The unlabeled break marks are not
    # returned by get_legend_handles_labels().
    handles_bot, labels_bot = ax_bot.get_legend_handles_labels()
    handles_top, labels_top = ax_top.get_legend_handles_labels()
    ax_bot.legend(handles_bot + handles_top, labels_bot + labels_top,
                  loc="upper left")

    fig.suptitle("Multi-device allreduce latency vs external single-device reference (broken y-axis)")
    fig.tight_layout()
    out = PLOT_DIR / "overview_broken.png"
    fig.savefig(out, dpi=120)
    plt.close(fig)
    print(f"wrote {out}")
|
|
|
|
|
|
def main():
    """Load the sweep summary and render both overview variants.

    Exits via ``SystemExit`` with a message naming ``CSV_PATH`` when the
    summary contains no data rows.
    """
    records = _load_records()
    if not records:
        raise SystemExit(f"no rows in {CSV_PATH}")
    # Log-scale variant first, then the broken-axis variant.
    for render in (emit_log, emit_broken):
        render(records)


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|