"""One-shot: render overview.png with an external 366 µs reference, in two variants — log scale and broken y-axis. Reads docs/diagrams/allreduce_latency_plots/summary.csv and writes overview_log.png and overview_broken.png alongside it. This is a derived-artifact generator (per CLAUDE.md): plotting only, no production or test logic touched. """ from __future__ import annotations import csv from pathlib import Path import matplotlib.pyplot as plt import matplotlib.ticker as mticker ROOT = Path(__file__).resolve().parent.parent PLOT_DIR = ROOT / "docs" / "diagrams" / "allreduce_latency_plots" CSV_PATH = PLOT_DIR / "summary.csv" EXT_LABEL = "ext-sim single-device reduce: 366 µs" EXT_LATENCY_NS = 366_000.0 COLORS = { "ring_1d": "tab:blue", "torus_2d": "tab:orange", "mesh_2d_no_wrap": "tab:green", } # Hand-derived theoretical model for torus_2d (6 SIPs). Mirrors # _aggregate_sweep_plots in tests/test_allreduce_multidevice.py. NOC_PACKET_BYTES = 128 PES_PER_CUBE = 8 T_STARTUP_NS = 1346.0 TAU_NS = (8741.0 - 1346.0) / (6144 - 1) def _theoretical_torus_2d_ns(bytes_per_pe: int) -> float: bytes_per_cube = int(bytes_per_pe) * PES_PER_CUBE n_packets = max(1, -(-bytes_per_cube // NOC_PACKET_BYTES)) return T_STARTUP_NS + (n_packets - 1) * TAU_NS def _plot_theoretical(ax, records): torus_rs = sorted( [r for r in records if r["sip_topology"] == "torus_2d"], key=lambda r: r["bytes_per_pe"], ) if not torus_rs: return ax.plot( [r["bytes_per_pe"] for r in torus_rs], [_theoretical_torus_2d_ns(r["bytes_per_pe"]) for r in torus_rs], color="tab:red", linestyle="--", linewidth=1.6, marker="x", label="theoretical torus_2d (6 SIPs)", ) def _bytes_fmt(x, _pos): if x >= 1024 * 1024: return f"{x / (1024 * 1024):.0f}M" if x >= 1024: return f"{x / 1024:.0f}K" return f"{int(x)}" def _load_records(): rows = [] with open(CSV_PATH, newline="") as f: r = csv.DictReader(f) for row in r: rows.append({ "sip_topology": row["sip_topology"], "bytes_per_pe": int(row["bytes_per_pe"]), "latency_ns": float(row["latency_ns"]), }) return rows def _ext_x(records): """Anchor the external reference at the largest payload (96 KB / PE).""" return max(r["bytes_per_pe"] for r in records) def _plot_curves(ax, records, topologies): for topo in topologies: rs = sorted([r for r in records if r["sip_topology"] == topo], key=lambda r: r["bytes_per_pe"]) if not rs: continue ax.plot( [r["bytes_per_pe"] for r in rs], [r["latency_ns"] for r in rs], marker="o", label=f"{topo}", color=COLORS.get(topo), ) def emit_log(records): topologies = sorted({r["sip_topology"] for r in records}) fig, ax = plt.subplots(figsize=(9, 6)) _plot_curves(ax, records, topologies) _plot_theoretical(ax, records) ax.scatter( [_ext_x(records)], [EXT_LATENCY_NS], marker="*", s=220, color="tab:red", zorder=5, label=EXT_LABEL, ) ax.set_xscale("log", base=2) ax.set_yscale("log") ax.set_xlabel("Bytes per PE (log scale)") ax.set_ylabel("Time (ns) — log scale") ax.set_title("Multi-device allreduce latency vs external single-device reference") ax.grid(True, which="both", alpha=0.3) ax.xaxis.set_major_formatter(mticker.FuncFormatter(_bytes_fmt)) ax.legend(loc="upper left") fig.tight_layout() out = PLOT_DIR / "overview_log.png" fig.savefig(out, dpi=120) plt.close(fig) print(f"wrote {out}") def emit_broken(records): topologies = sorted({r["sip_topology"] for r in records}) max_local = max(r["latency_ns"] for r in records) fig, (ax_top, ax_bot) = plt.subplots( 2, 1, sharex=True, gridspec_kw={"height_ratios": [1, 4], "hspace": 0.05}, figsize=(9, 6.5), ) # Bottom panel: today's three curves + theoretical, linear y. _plot_curves(ax_bot, records, topologies) _plot_theoretical(ax_bot, records) ax_bot.set_ylim(0, max_local * 1.10) # Top panel: only the external reference marker, linear y around 366 µs. ax_top.scatter( [_ext_x(records)], [EXT_LATENCY_NS], marker="*", s=240, color="tab:red", zorder=5, label=EXT_LABEL, ) ax_top.set_ylim(EXT_LATENCY_NS * 0.93, EXT_LATENCY_NS * 1.05) # Hide the spine between the two panels and draw diagonal "break" ticks. ax_top.spines["bottom"].set_visible(False) ax_bot.spines["top"].set_visible(False) ax_top.tick_params(labeltop=False, bottom=False) ax_bot.xaxis.tick_bottom() d = 0.012 # diagonal-tick size, in axis-fraction kw = dict(transform=ax_top.transAxes, color="k", clip_on=False, lw=1) ax_top.plot((-d, +d), (-d, +d), **kw) ax_top.plot((1 - d, 1 + d), (-d, +d), **kw) kw.update(transform=ax_bot.transAxes) ax_bot.plot((-d, +d), (1 - d * 4, 1 + d * 4), **kw) ax_bot.plot((1 - d, 1 + d), (1 - d * 4, 1 + d * 4), **kw) ax_bot.set_xscale("log", base=2) ax_bot.set_xlabel("Bytes per PE (log scale)") ax_bot.set_ylabel("Time (ns)") ax_top.set_ylabel("Time (ns)") ax_bot.grid(True, alpha=0.3) ax_top.grid(True, alpha=0.3) ax_bot.xaxis.set_major_formatter(mticker.FuncFormatter(_bytes_fmt)) # One legend covering both axes. handles_bot, labels_bot = ax_bot.get_legend_handles_labels() handles_top, labels_top = ax_top.get_legend_handles_labels() ax_bot.legend(handles_bot + handles_top, labels_bot + labels_top, loc="upper left") fig.suptitle("Multi-device allreduce latency vs external single-device reference (broken y-axis)") fig.tight_layout() out = PLOT_DIR / "overview_broken.png" fig.savefig(out, dpi=120) plt.close(fig) print(f"wrote {out}") def main(): records = _load_records() if not records: raise SystemExit(f"no rows in {CSV_PATH}") emit_log(records) emit_broken(records) if __name__ == "__main__": main()