Add deck builder + overview-with-ref diagram scripts
scripts/build_overview_slides.py renders a 5-slide PPTX (kernbench2_overview.pptx) summarizing architecture, model correctness, IPCQ, allreduce, and buffer-kind tier comparison. scripts/emit_overview_with_external_ref.py renders log-y and broken-y variants of the allreduce overview (overview_log.png, overview_broken.png) including a 366 µs ext-sim reference marker at 96 KB / PE. Also includes cube_mesh_view.png rendered from the SVG. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,192 @@
|
||||
"""One-shot: render overview.png with an external 366 µs reference, in two
|
||||
variants — log scale and broken y-axis. Reads docs/diagrams/allreduce_latency_plots/summary.csv
|
||||
and writes overview_log.png and overview_broken.png alongside it.
|
||||
|
||||
This is a derived-artifact generator (per CLAUDE.md): plotting only, no production
|
||||
or test logic touched.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.ticker as mticker
|
||||
|
||||
ROOT = Path(__file__).resolve().parent.parent
|
||||
PLOT_DIR = ROOT / "docs" / "diagrams" / "allreduce_latency_plots"
|
||||
CSV_PATH = PLOT_DIR / "summary.csv"
|
||||
|
||||
EXT_LABEL = "ext-sim single-device reduce: 366 µs"
|
||||
EXT_LATENCY_NS = 366_000.0
|
||||
|
||||
COLORS = {
|
||||
"ring_1d": "tab:blue",
|
||||
"torus_2d": "tab:orange",
|
||||
"mesh_2d_no_wrap": "tab:green",
|
||||
}
|
||||
|
||||
# Hand-derived theoretical model for torus_2d (6 SIPs). Mirrors
|
||||
# _aggregate_sweep_plots in tests/test_allreduce_multidevice.py.
|
||||
NOC_PACKET_BYTES = 128
|
||||
PES_PER_CUBE = 8
|
||||
T_STARTUP_NS = 1346.0
|
||||
TAU_NS = (8741.0 - 1346.0) / (6144 - 1)
|
||||
|
||||
|
||||
def _theoretical_torus_2d_ns(bytes_per_pe: int) -> float:
|
||||
bytes_per_cube = int(bytes_per_pe) * PES_PER_CUBE
|
||||
n_packets = max(1, -(-bytes_per_cube // NOC_PACKET_BYTES))
|
||||
return T_STARTUP_NS + (n_packets - 1) * TAU_NS
|
||||
|
||||
|
||||
def _plot_theoretical(ax, records):
|
||||
torus_rs = sorted(
|
||||
[r for r in records if r["sip_topology"] == "torus_2d"],
|
||||
key=lambda r: r["bytes_per_pe"],
|
||||
)
|
||||
if not torus_rs:
|
||||
return
|
||||
ax.plot(
|
||||
[r["bytes_per_pe"] for r in torus_rs],
|
||||
[_theoretical_torus_2d_ns(r["bytes_per_pe"]) for r in torus_rs],
|
||||
color="tab:red", linestyle="--", linewidth=1.6, marker="x",
|
||||
label="theoretical torus_2d (6 SIPs)",
|
||||
)
|
||||
|
||||
|
||||
def _bytes_fmt(x, _pos):
|
||||
if x >= 1024 * 1024:
|
||||
return f"{x / (1024 * 1024):.0f}M"
|
||||
if x >= 1024:
|
||||
return f"{x / 1024:.0f}K"
|
||||
return f"{int(x)}"
|
||||
|
||||
|
||||
def _load_records():
|
||||
rows = []
|
||||
with open(CSV_PATH, newline="") as f:
|
||||
r = csv.DictReader(f)
|
||||
for row in r:
|
||||
rows.append({
|
||||
"sip_topology": row["sip_topology"],
|
||||
"bytes_per_pe": int(row["bytes_per_pe"]),
|
||||
"latency_ns": float(row["latency_ns"]),
|
||||
})
|
||||
return rows
|
||||
|
||||
|
||||
def _ext_x(records):
|
||||
"""Anchor the external reference at the largest payload (96 KB / PE)."""
|
||||
return max(r["bytes_per_pe"] for r in records)
|
||||
|
||||
|
||||
def _plot_curves(ax, records, topologies):
|
||||
for topo in topologies:
|
||||
rs = sorted([r for r in records if r["sip_topology"] == topo],
|
||||
key=lambda r: r["bytes_per_pe"])
|
||||
if not rs:
|
||||
continue
|
||||
ax.plot(
|
||||
[r["bytes_per_pe"] for r in rs],
|
||||
[r["latency_ns"] for r in rs],
|
||||
marker="o",
|
||||
label=f"{topo}",
|
||||
color=COLORS.get(topo),
|
||||
)
|
||||
|
||||
|
||||
def emit_log(records):
|
||||
topologies = sorted({r["sip_topology"] for r in records})
|
||||
fig, ax = plt.subplots(figsize=(9, 6))
|
||||
_plot_curves(ax, records, topologies)
|
||||
_plot_theoretical(ax, records)
|
||||
ax.scatter(
|
||||
[_ext_x(records)], [EXT_LATENCY_NS],
|
||||
marker="*", s=220, color="tab:red", zorder=5,
|
||||
label=EXT_LABEL,
|
||||
)
|
||||
ax.set_xscale("log", base=2)
|
||||
ax.set_yscale("log")
|
||||
ax.set_xlabel("Bytes per PE (log scale)")
|
||||
ax.set_ylabel("Time (ns) — log scale")
|
||||
ax.set_title("Multi-device allreduce latency vs external single-device reference")
|
||||
ax.grid(True, which="both", alpha=0.3)
|
||||
ax.xaxis.set_major_formatter(mticker.FuncFormatter(_bytes_fmt))
|
||||
ax.legend(loc="upper left")
|
||||
fig.tight_layout()
|
||||
out = PLOT_DIR / "overview_log.png"
|
||||
fig.savefig(out, dpi=120)
|
||||
plt.close(fig)
|
||||
print(f"wrote {out}")
|
||||
|
||||
|
||||
def emit_broken(records):
|
||||
topologies = sorted({r["sip_topology"] for r in records})
|
||||
max_local = max(r["latency_ns"] for r in records)
|
||||
|
||||
fig, (ax_top, ax_bot) = plt.subplots(
|
||||
2, 1, sharex=True,
|
||||
gridspec_kw={"height_ratios": [1, 4], "hspace": 0.05},
|
||||
figsize=(9, 6.5),
|
||||
)
|
||||
|
||||
# Bottom panel: today's three curves + theoretical, linear y.
|
||||
_plot_curves(ax_bot, records, topologies)
|
||||
_plot_theoretical(ax_bot, records)
|
||||
ax_bot.set_ylim(0, max_local * 1.10)
|
||||
|
||||
# Top panel: only the external reference marker, linear y around 366 µs.
|
||||
ax_top.scatter(
|
||||
[_ext_x(records)], [EXT_LATENCY_NS],
|
||||
marker="*", s=240, color="tab:red", zorder=5,
|
||||
label=EXT_LABEL,
|
||||
)
|
||||
ax_top.set_ylim(EXT_LATENCY_NS * 0.93, EXT_LATENCY_NS * 1.05)
|
||||
|
||||
# Hide the spine between the two panels and draw diagonal "break" ticks.
|
||||
ax_top.spines["bottom"].set_visible(False)
|
||||
ax_bot.spines["top"].set_visible(False)
|
||||
ax_top.tick_params(labeltop=False, bottom=False)
|
||||
ax_bot.xaxis.tick_bottom()
|
||||
|
||||
d = 0.012 # diagonal-tick size, in axis-fraction
|
||||
kw = dict(transform=ax_top.transAxes, color="k", clip_on=False, lw=1)
|
||||
ax_top.plot((-d, +d), (-d, +d), **kw)
|
||||
ax_top.plot((1 - d, 1 + d), (-d, +d), **kw)
|
||||
kw.update(transform=ax_bot.transAxes)
|
||||
ax_bot.plot((-d, +d), (1 - d * 4, 1 + d * 4), **kw)
|
||||
ax_bot.plot((1 - d, 1 + d), (1 - d * 4, 1 + d * 4), **kw)
|
||||
|
||||
ax_bot.set_xscale("log", base=2)
|
||||
ax_bot.set_xlabel("Bytes per PE (log scale)")
|
||||
ax_bot.set_ylabel("Time (ns)")
|
||||
ax_top.set_ylabel("Time (ns)")
|
||||
ax_bot.grid(True, alpha=0.3)
|
||||
ax_top.grid(True, alpha=0.3)
|
||||
ax_bot.xaxis.set_major_formatter(mticker.FuncFormatter(_bytes_fmt))
|
||||
|
||||
# One legend covering both axes.
|
||||
handles_bot, labels_bot = ax_bot.get_legend_handles_labels()
|
||||
handles_top, labels_top = ax_top.get_legend_handles_labels()
|
||||
ax_bot.legend(handles_bot + handles_top, labels_bot + labels_top,
|
||||
loc="upper left")
|
||||
|
||||
fig.suptitle("Multi-device allreduce latency vs external single-device reference (broken y-axis)")
|
||||
fig.tight_layout()
|
||||
out = PLOT_DIR / "overview_broken.png"
|
||||
fig.savefig(out, dpi=120)
|
||||
plt.close(fig)
|
||||
print(f"wrote {out}")
|
||||
|
||||
|
||||
def main():
|
||||
records = _load_records()
|
||||
if not records:
|
||||
raise SystemExit(f"no rows in {CSV_PATH}")
|
||||
emit_log(records)
|
||||
emit_broken(records)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user