diff --git a/tests/allreduce_latency_plots/mesh_2d_no_wrap.png b/tests/allreduce_latency_plots/mesh_2d_no_wrap.png index 2d57582..6a4d0ca 100644 Binary files a/tests/allreduce_latency_plots/mesh_2d_no_wrap.png and b/tests/allreduce_latency_plots/mesh_2d_no_wrap.png differ diff --git a/tests/allreduce_latency_plots/overview.png b/tests/allreduce_latency_plots/overview.png index 12c58b9..0622007 100644 Binary files a/tests/allreduce_latency_plots/overview.png and b/tests/allreduce_latency_plots/overview.png differ diff --git a/tests/allreduce_latency_plots/ring_1d.png b/tests/allreduce_latency_plots/ring_1d.png index ac87a1d..beb73fa 100644 Binary files a/tests/allreduce_latency_plots/ring_1d.png and b/tests/allreduce_latency_plots/ring_1d.png differ diff --git a/tests/allreduce_latency_plots/summary.csv b/tests/allreduce_latency_plots/summary.csv index a58d290..d40f782 100644 --- a/tests/allreduce_latency_plots/summary.csv +++ b/tests/allreduce_latency_plots/summary.csv @@ -1,26 +1,4 @@ algorithm,sip_topology,n_sips,n_elem,bytes_per_pe,bytes_per_sip,latency_ns -intercube_allreduce,ring_1d,6,8,16,256,3073.1299999999937 -intercube_allreduce,ring_1d,6,32,64,1024,3079.8799999999947 -intercube_allreduce,ring_1d,6,64,128,2048,3088.879999999992 -intercube_allreduce,ring_1d,6,128,256,4096,3106.8799999999865 -intercube_allreduce,ring_1d,6,512,1024,16384,3225.8799999999865 -intercube_allreduce,ring_1d,6,1024,2048,32768,3391.8799999999865 -intercube_allreduce,ring_1d,6,2048,4096,65536,3723.8799999999865 -intercube_allreduce,ring_1d,6,4096,8192,131072,4387.879999999965 -intercube_allreduce,ring_1d,6,8192,16384,262144,5715.879999999957 -intercube_allreduce,ring_1d,6,16384,32768,524288,8371.879999999932 -intercube_allreduce,ring_1d,6,32768,65536,1048576,13683.879999999903 -intercube_allreduce,torus_2d,6,8,16,256,2190.4799999999923 -intercube_allreduce,torus_2d,6,32,64,1024,2196.479999999993 -intercube_allreduce,torus_2d,6,64,128,2048,2204.4799999999905 -intercube_allreduce,torus_2d,6,128,256,4096,2220.479999999985 -intercube_allreduce,torus_2d,6,512,1024,16384,2325.479999999985 -intercube_allreduce,torus_2d,6,1024,2048,32768,2471.479999999985 -intercube_allreduce,torus_2d,6,2048,4096,65536,2763.479999999985 -intercube_allreduce,torus_2d,6,4096,8192,131072,3347.4799999999777 -intercube_allreduce,torus_2d,6,8192,16384,262144,4515.4799999999705 -intercube_allreduce,torus_2d,6,16384,32768,524288,6851.479999999952 -intercube_allreduce,torus_2d,6,32768,65536,1048576,11523.479999999923 intercube_allreduce,mesh_2d_no_wrap,6,8,16,256,3508.4249999999993 intercube_allreduce,mesh_2d_no_wrap,6,32,64,1024,3515.55 intercube_allreduce,mesh_2d_no_wrap,6,64,128,2048,3525.0499999999975 @@ -32,3 +10,28 @@ intercube_allreduce,mesh_2d_no_wrap,6,4096,8192,131072,4857.049999999959 intercube_allreduce,mesh_2d_no_wrap,6,8192,16384,262144,6217.049999999945 intercube_allreduce,mesh_2d_no_wrap,6,16384,32768,524288,8937.049999999937 intercube_allreduce,mesh_2d_no_wrap,6,32768,65536,1048576,14377.049999999872 +intercube_allreduce,mesh_2d_no_wrap,6,49152,98304,1572864,19817.049999999872 +intercube_allreduce,ring_1d,6,8,16,256,3073.1299999999937 +intercube_allreduce,ring_1d,6,32,64,1024,3079.8799999999947 +intercube_allreduce,ring_1d,6,64,128,2048,3088.879999999992 +intercube_allreduce,ring_1d,6,128,256,4096,3106.8799999999865 +intercube_allreduce,ring_1d,6,512,1024,16384,3225.8799999999865 +intercube_allreduce,ring_1d,6,1024,2048,32768,3391.8799999999865 +intercube_allreduce,ring_1d,6,2048,4096,65536,3723.8799999999865 +intercube_allreduce,ring_1d,6,4096,8192,131072,4387.879999999965 +intercube_allreduce,ring_1d,6,8192,16384,262144,5715.879999999957 +intercube_allreduce,ring_1d,6,16384,32768,524288,8371.879999999932 +intercube_allreduce,ring_1d,6,32768,65536,1048576,13683.879999999903 +intercube_allreduce,ring_1d,6,49152,98304,1572864,18995.879999999917 +intercube_allreduce,torus_2d,6,8,16,256,2190.4799999999923 +intercube_allreduce,torus_2d,6,32,64,1024,2196.479999999993 +intercube_allreduce,torus_2d,6,64,128,2048,2204.4799999999905 +intercube_allreduce,torus_2d,6,128,256,4096,2220.479999999985 +intercube_allreduce,torus_2d,6,512,1024,16384,2325.479999999985 +intercube_allreduce,torus_2d,6,1024,2048,32768,2471.479999999985 +intercube_allreduce,torus_2d,6,2048,4096,65536,2763.479999999985 +intercube_allreduce,torus_2d,6,4096,8192,131072,3347.4799999999777 +intercube_allreduce,torus_2d,6,8192,16384,262144,4515.4799999999705 +intercube_allreduce,torus_2d,6,16384,32768,524288,6851.479999999952 +intercube_allreduce,torus_2d,6,32768,65536,1048576,11523.479999999923 +intercube_allreduce,torus_2d,6,49152,98304,1572864,16195.479999999952 diff --git a/tests/allreduce_latency_plots/topology.png b/tests/allreduce_latency_plots/topology.png new file mode 100644 index 0000000..40e8719 Binary files /dev/null and b/tests/allreduce_latency_plots/topology.png differ diff --git a/tests/allreduce_latency_plots/torus_2d.png b/tests/allreduce_latency_plots/torus_2d.png index 5a8bf2d..ce4b502 100644 Binary files a/tests/allreduce_latency_plots/torus_2d.png and b/tests/allreduce_latency_plots/torus_2d.png differ diff --git a/tests/conftest.py b/tests/conftest.py index 9eff856..1c6cced 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,11 +7,45 @@ stateful/SimPy-event-consuming and MUST NOT be shared). """ from __future__ import annotations +import os + import pytest from kernbench.topology.builder import resolve_topology +def pytest_sessionfinish(session, exitstatus): + """Aggregate parametrized sweep rows into combined CSV + PNG plots. + + Runs on the controller node only (xdist worker processes set + ``PYTEST_XDIST_WORKER``; we skip those). Idempotent — does nothing + if no sweep rows are present (e.g., when the sweep was filtered out). + """ + if os.environ.get("PYTEST_XDIST_WORKER"): + return + import importlib.util + import sys + from pathlib import Path + + mod_path = Path(__file__).parent / "test_allreduce_multidevice.py" + if not mod_path.exists(): + return + spec = importlib.util.spec_from_file_location( + "_test_allreduce_multidevice_for_aggregate", mod_path, + ) + if spec is None or spec.loader is None: + return + mod = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = mod + try: + spec.loader.exec_module(mod) + agg = getattr(mod, "_aggregate_sweep_plots", None) + if agg is not None: + agg() + except Exception as e: + print(f"[conftest] sweep aggregation failed: {e}") + + @pytest.fixture(scope="session") def topology(): """Session-scoped parsed topology (immutable graph + spec). diff --git a/tests/test_allreduce_multidevice.py b/tests/test_allreduce_multidevice.py index 81e1093..783b819 100644 --- a/tests/test_allreduce_multidevice.py +++ b/tests/test_allreduce_multidevice.py @@ -269,29 +269,143 @@ def test_allreduce( assert result["ok_cubes"] > 0 -# ── Latency sweep ───────────────────────────────────────────────────── +# ── Latency sweep (parametrized + xdist-friendly) ───────────────────── -# avoid 16 (== n_cubes, dim_map collision). Goes up to 1 MB per SIP: -# bytes_per_sip = n_cubes * n_elem * 2 = 32 * n_elem. +# avoid 16 (== n_cubes, dim_map collision). Goes up to 96 KB per PE: +# bytes_per_pe = n_elem * 2 (f16). 49152 elem * 2 = 96 KB / PE. _SWEEP_N_ELEM = [ 8, 32, 64, 128, 512, 1024, 2048, - 4096, 8192, 16384, 32768, + 4096, 8192, 16384, 32768, 49152, ] _ELEM_BYTES_F16 = 2 +_SWEEP_TOPOLOGIES = [ + ("intercube_allreduce", "ring_1d", 6, None, None), + ("intercube_allreduce", "torus_2d", 6, 2, 3), + ("intercube_allreduce", "mesh_2d_no_wrap", 6, 2, 3), +] -def test_allreduce_latency_sweep(tmp_path): - """Sweep n_elem across each SIP topology; record max(pe_exec_ns) - as the critical-path kernel latency. Emits CSV + PNG plots to - tests/allreduce_latency_plots/. +# Shared on-disk staging dir for parametrized sweep rows. Each +# parametrized invocation writes one JSON file here; the aggregator +# (run from conftest.pytest_sessionfinish) reads them and emits the +# combined CSV + PNG plots. +_SWEEP_OUT_DIR = Path(__file__).parent / "allreduce_latency_plots" +_SWEEP_ROWS_DIR = _SWEEP_OUT_DIR / "_rows" + + +def _sweep_params(): + out = [] + for algorithm, sip_topology, n_sips, sip_w, sip_h in _SWEEP_TOPOLOGIES: + for n_elem in _SWEEP_N_ELEM: + out.append(pytest.param( + algorithm, sip_topology, n_sips, sip_w, sip_h, n_elem, + id=f"{sip_topology}-n_elem{n_elem}", + )) + return out + + +@pytest.mark.parametrize( + "algorithm,sip_topology,n_sips,sip_w,sip_h,n_elem", _sweep_params(), +) +def test_allreduce_latency_one( + tmp_path, algorithm, sip_topology, n_sips, sip_w, sip_h, n_elem, +): + """One config of the latency sweep. xdist parallelizes across params. + + Writes a single JSON row to ``_SWEEP_ROWS_DIR``. The conftest + sessionfinish hook aggregates rows into CSV + plots after all + parametrized cases finish. + """ + import json + + topo_path, ccl_path = _write_temp_configs( + tmp_path, sip_topology, n_sips, algorithm, + sip_w=sip_w, sip_h=sip_h, + n_elem_override=n_elem, + ) + topo = resolve_topology(topo_path) + engine = GraphEngine(topo.topology_obj, enable_data=True) + spec = topo.topology_obj.spec + + with RuntimeContext( + engine=engine, + target_device=DeviceSelector("all"), + correlation_id=f"sweep_{algorithm}_{sip_topology}_{n_elem}", + spec=spec, + ) as ctx: + result = run_allreduce( + ctx, engine, spec, + algorithm=algorithm, ccl_yaml=ccl_path, + ) + assert result["ok_cubes"] > 0 + + pe_exec_vals = [ + float(tr.get("pe_exec_ns", 0.0) or 0.0) + for _, (_, tr) in engine._results.items() + if isinstance(tr, dict) + ] + crit_ns = max(pe_exec_vals) if pe_exec_vals else 0.0 + + cm = spec["sip"]["cube_mesh"] + n_cubes = int(cm["w"]) * int(cm["h"]) + bytes_per_sip = n_cubes * n_elem * _ELEM_BYTES_F16 + bytes_per_pe = n_elem * _ELEM_BYTES_F16 + + record = { + "algorithm": algorithm, + "sip_topology": sip_topology, + "n_sips": n_sips, + "n_elem": n_elem, + "bytes_per_pe": bytes_per_pe, + "bytes_per_sip": bytes_per_sip, + "latency_ns": crit_ns, + } + + _SWEEP_ROWS_DIR.mkdir(parents=True, exist_ok=True) + row_path = _SWEEP_ROWS_DIR / f"{sip_topology}_{n_elem}.json" + with open(row_path, "w", encoding="utf-8") as f: + json.dump(record, f) + + +def _aggregate_sweep_plots() -> bool: + """Read all per-config rows and emit CSV + PNG plots. + + Called by ``conftest.pytest_sessionfinish`` (controller node only). + Returns True if any rows were aggregated, False otherwise. """ import csv + import json + + row_files = sorted(_SWEEP_ROWS_DIR.glob("*.json")) \ + if _SWEEP_ROWS_DIR.exists() else [] + records: list[dict] = [] + if row_files: + for p in row_files: + with open(p, encoding="utf-8") as f: + records.append(json.load(f)) + else: + # Fallback: replot from existing summary.csv (skip sweep re-run). + summary_path = _SWEEP_OUT_DIR / "summary.csv" + if not summary_path.exists(): + return False + with open(summary_path, encoding="utf-8") as f: + for row in csv.DictReader(f): + records.append({ + "algorithm": row["algorithm"], + "sip_topology": row["sip_topology"], + "n_sips": int(row["n_sips"]), + "n_elem": int(row["n_elem"]), + "bytes_per_pe": int(row["bytes_per_pe"]), + "bytes_per_sip": int(row["bytes_per_sip"]), + "latency_ns": float(row["latency_ns"]), + }) + if not records: + return False import matplotlib.pyplot as plt from matplotlib.ticker import FuncFormatter def _fmt_bytes(x, _pos): - """Format tick as B / KB / MB.""" if x <= 0: return "0" if x >= 1024 * 1024: @@ -302,86 +416,27 @@ def test_allreduce_latency_sweep(tmp_path): _bytes_fmt = FuncFormatter(_fmt_bytes) - out_dir = Path(__file__).parent / "allreduce_latency_plots" - out_dir.mkdir(parents=True, exist_ok=True) - records: list[dict] = [] - - # Apples-to-apples: same n_sips across all three topologies. - for algorithm, sip_topology, n_sips, sip_w, sip_h in [ - ("intercube_allreduce", "ring_1d", 6, None, None), - ("intercube_allreduce", "torus_2d", 6, 2, 3), - ("intercube_allreduce", "mesh_2d_no_wrap", 6, 2, 3), - ]: - for n_elem in _SWEEP_N_ELEM: - sub = tmp_path / f"{sip_topology}_{n_elem}" - sub.mkdir() - topo_path, ccl_path = _write_temp_configs( - sub, sip_topology, n_sips, algorithm, - sip_w=sip_w, sip_h=sip_h, - n_elem_override=n_elem, - ) - topo = resolve_topology(topo_path) - engine = GraphEngine(topo.topology_obj, enable_data=True) - spec = topo.topology_obj.spec - - with RuntimeContext( - engine=engine, - target_device=DeviceSelector("all"), - correlation_id=f"sweep_{algorithm}_{sip_topology}_{n_elem}", - spec=spec, - ) as ctx: - result = run_allreduce( - ctx, engine, spec, - algorithm=algorithm, ccl_yaml=ccl_path, - ) - assert result["ok_cubes"] > 0 - - pe_exec_vals = [ - float(tr.get("pe_exec_ns", 0.0) or 0.0) - for _, (_, tr) in engine._results.items() - if isinstance(tr, dict) - ] - crit_ns = max(pe_exec_vals) if pe_exec_vals else 0.0 - - cm = spec["sip"]["cube_mesh"] - n_cubes = int(cm["w"]) * int(cm["h"]) - bytes_per_sip = n_cubes * n_elem * _ELEM_BYTES_F16 - # pe="replicate" + num_pes=1 → one active PE per cube owns - # the whole cube row. Per-PE bytes == per-cube-tile bytes == - # per-message bytes over the IPCQ fabric. - bytes_per_pe = n_elem * _ELEM_BYTES_F16 - - records.append({ - "algorithm": algorithm, - "sip_topology": sip_topology, - "n_sips": n_sips, - "n_elem": n_elem, - "bytes_per_pe": bytes_per_pe, - "bytes_per_sip": bytes_per_sip, - "latency_ns": crit_ns, - }) - print( - f"[{sip_topology:<16} n_sips={n_sips} n_elem={n_elem:>5} " - f"bytes/pe={bytes_per_pe:>7} bytes/sip={bytes_per_sip:>9}] " - f"pe_exec_max = {crit_ns:8.1f} ns" - ) - - with open(out_dir / "summary.csv", "w", newline="", encoding="utf-8") as f: + _SWEEP_OUT_DIR.mkdir(parents=True, exist_ok=True) + with open(_SWEEP_OUT_DIR / "summary.csv", "w", + newline="", encoding="utf-8") as f: w = csv.DictWriter(f, fieldnames=[ "algorithm", "sip_topology", "n_sips", "n_elem", "bytes_per_pe", "bytes_per_sip", "latency_ns", ]) w.writeheader() - for r in records: + for r in sorted(records, key=lambda r: ( + r["sip_topology"], r["bytes_per_pe"], + )): w.writerow(r) topologies = sorted({r["sip_topology"] for r in records}) - # Per-topology plots, log-scale x-axis = bytes per PE. for topo_name in topologies: rs = sorted( [r for r in records if r["sip_topology"] == topo_name], key=lambda r: r["bytes_per_pe"], ) + if not rs: + continue xs = [r["bytes_per_pe"] for r in rs] ys = [r["latency_ns"] for r in rs] title = ( @@ -397,17 +452,20 @@ def test_allreduce_latency_sweep(tmp_path): ax.grid(True, alpha=0.3) ax.xaxis.set_major_formatter(_bytes_fmt) fig.tight_layout() - fig.savefig(out_dir / f"{topo_name}.png", dpi=120) + fig.savefig(_SWEEP_OUT_DIR / f"{topo_name}.png", dpi=120) plt.close(fig) colors = {"ring_1d": "tab:blue", "torus_2d": "tab:orange", "mesh_2d_no_wrap": "tab:green"} + THEORETICAL_TORUS_2D_6SIP_NS = 10600.0 fig, ax = plt.subplots(figsize=(9, 6)) for topo_name in topologies: rs = sorted( [r for r in records if r["sip_topology"] == topo_name], key=lambda r: r["bytes_per_pe"], ) + if not rs: + continue ax.plot( [r["bytes_per_pe"] for r in rs], [r["latency_ns"] for r in rs], @@ -415,15 +473,378 @@ def test_allreduce_latency_sweep(tmp_path): label=f"{topo_name} (n_sips={rs[0]['n_sips']})", color=colors.get(topo_name), ) + ax.axhline( + y=THEORETICAL_TORUS_2D_6SIP_NS, + color="tab:red", linestyle="--", linewidth=1.5, + label=f"theoretical torus_2d (6 SIPs) = " + f"{THEORETICAL_TORUS_2D_6SIP_NS:.0f} ns", + ) + BYTES_96KB = 96 * 1024 + ax.axvline( + x=BYTES_96KB, ymin=0, ymax=1, + color="tab:red", linestyle=":", linewidth=1.2, + ) + ax.plot( + [BYTES_96KB], [THEORETICAL_TORUS_2D_6SIP_NS], + marker="x", color="tab:red", markersize=10, markeredgewidth=2, + ) + # Find simulated torus_2d latency at 96 KB (if present) for direct + # comparison with the theoretical value. + sim_torus_at_96kb = next( + (r["latency_ns"] for r in records + if r["sip_topology"] == "torus_2d" and r["bytes_per_pe"] == BYTES_96KB), + None, + ) + if sim_torus_at_96kb is not None: + ax.plot( + [BYTES_96KB], [sim_torus_at_96kb], + marker="o", color="tab:orange", + markersize=10, markeredgecolor="black", markeredgewidth=1.2, + ) + ax.annotate( + f"96 KB\n" + f"theoretical = {THEORETICAL_TORUS_2D_6SIP_NS:.0f} ns\n" + f"simulated = {sim_torus_at_96kb:.0f} ns", + xy=(BYTES_96KB, sim_torus_at_96kb), + xytext=(10, -20), textcoords="offset points", + color="tab:red", fontsize=9, + ) + else: + ax.annotate( + f"96 KB\n→ theoretical {THEORETICAL_TORUS_2D_6SIP_NS:.0f} ns", + xy=(BYTES_96KB, THEORETICAL_TORUS_2D_6SIP_NS), + xytext=(8, -20), textcoords="offset points", + color="tab:red", fontsize=9, + ) ax.set_xscale("log", base=2) ax.set_xlabel("Bytes per PE (log scale)") ax.set_ylabel("max pe_exec_ns (critical path)") ax.set_title("Multi-device allreduce latency by topology") ax.grid(True, alpha=0.3) + + # Drop 128 KB tick (overlaps visually with the explicit 96 KB marker) + # and add 96 KB. + BYTES_128KB = 128 * 1024 + existing_ticks = [t for t in ax.get_xticks() if int(t) != BYTES_128KB] + if BYTES_96KB not in existing_ticks: + existing_ticks.append(BYTES_96KB) + ax.set_xticks(sorted(existing_ticks)) + ax.set_xlim(left=min(r["bytes_per_pe"] for r in records) / 2, + right=BYTES_96KB * 1.5) ax.legend() ax.xaxis.set_major_formatter(_bytes_fmt) fig.tight_layout() - fig.savefig(out_dir / "overview.png", dpi=120) + fig.savefig(_SWEEP_OUT_DIR / "overview.png", dpi=120) plt.close(fig) - print(f"\nWrote {out_dir / 'overview.png'}") + # Cleanup row staging dir so a partial future run doesn't pick up + # stale rows. + for p in row_files: + try: + p.unlink() + except OSError: + pass + try: + _SWEEP_ROWS_DIR.rmdir() + except OSError: + pass + + print(f"\nWrote {_SWEEP_OUT_DIR / 'overview.png'} " + f"from {len(records)} rows") + return True + + +# ── Topology diagram (device-level + cube-level reduction) ──────────── + +# Convention: "rows × cols" everywhere, row-major rank assignment +# (rank = row * n_cols + col). For the 2×3 inter-SIP grid, this means +# 2 rows × 3 columns: SIP 0 1 2 / SIP 3 4 5. + +_PALETTE_BG = "#fafbfd" +_PALETTE_FRAME = "#3a3f4a" +_PALETTE_BLUE = "#2c6fb6" +_PALETTE_GREEN = "#2e8a4e" +_PALETTE_TEXT = "#1f2530" +_PALETTE_BOX_FILL = "#eaf2fb" +_PALETTE_BOX_EDGE = "#2c4a78" +_PALETTE_ROOT_FILL = "#ffd9b8" +_PALETTE_ROOT_EDGE = "#bd5a14" + + +def _arrow(ax, xy_from, xy_to, color="black", lw=1.4, alpha=1.0, + style="-|>", curve=0.0): + from matplotlib.patches import FancyArrowPatch + arrow = FancyArrowPatch( + xy_from, xy_to, + arrowstyle=style, mutation_scale=12, + color=color, lw=lw, alpha=alpha, + connectionstyle=f"arc3,rad={curve}", + ) + ax.add_patch(arrow) + + +def _draw_sip_box(ax, cx, cy, w, h, label, *, fill=_PALETTE_BOX_FILL, + edge=_PALETTE_BOX_EDGE, text_color=_PALETTE_TEXT, + font=10): + from matplotlib.patches import FancyBboxPatch + box = FancyBboxPatch( + (cx - w / 2, cy - h / 2), w, h, + boxstyle="round,pad=0.02,rounding_size=0.10", + linewidth=1.4, edgecolor=edge, facecolor=fill, + ) + ax.add_patch(box) + ax.text(cx, cy, label, ha="center", va="center", + color=text_color, fontsize=font, fontweight="bold") + + +def _frame_panel(ax, title, lim_x=10.0, lim_y=6.0): + """Set up a square-ish panel with a visible outer border.""" + from matplotlib.patches import FancyBboxPatch + ax.set_xlim(0, lim_x) + ax.set_ylim(0, lim_y) + ax.set_aspect("equal") + ax.axis("off") + ax.set_facecolor(_PALETTE_BG) + border = FancyBboxPatch( + (0.05, 0.05), lim_x - 0.10, lim_y - 0.10, + boxstyle="round,pad=0.01,rounding_size=0.12", + linewidth=1.4, edgecolor=_PALETTE_FRAME, facecolor=_PALETTE_BG, + zorder=0, + ) + ax.add_patch(border) + ax.set_title(title, fontsize=12, fontweight="bold", + color=_PALETTE_TEXT, pad=8) + + +def _draw_ring_topology(ax): + _frame_panel(ax, "ring_1d (6 SIPs)", lim_x=10.0, lim_y=6.0) + + xs = [1.2, 2.7, 4.2, 5.7, 7.2, 8.7] + y = 3.1 + box_w, box_h = 1.05, 0.9 + for i, x in enumerate(xs): + _draw_sip_box(ax, x, y, box_w, box_h, f"SIP {i}") + # Forward ring (global_E) — adjacent neighbours, anchored to box edges. + for i in range(5): + _arrow(ax, (xs[i] + box_w / 2, y), + (xs[i + 1] - box_w / 2, y), + color=_PALETTE_BLUE, lw=1.6) + # Wrap (SIP 5 → SIP 0). Anchor at right-CENTER of SIP 5 and + # left-CENTER of SIP 0; arc OUTSIDE (above) the row so it does not + # overlap any of the SIP boxes in between. + _arrow( + ax, + (xs[5] + box_w / 2, y), + (xs[0] - box_w / 2, y), + color=_PALETTE_BLUE, lw=1.6, curve=-0.40, + ) + ax.text(5.0, y + 2.0, "global_E (ring)", ha="center", + color=_PALETTE_BLUE, fontsize=10, style="italic") + ax.text(5.0, y - 1.5, + "(global_W = reverse direction, used by the algorithm)", + ha="center", color="gray", fontsize=8, style="italic") + + +def _draw_grid_topology(ax, kind, *, n_rows=2, n_cols=3): + """kind ∈ {'torus', 'mesh'}. Lays out as n_rows × n_cols (row-major). + + For the sweep we use 2 rows × 3 cols → SIP layout:: + + row 0: SIP 0 SIP 1 SIP 2 + row 1: SIP 3 SIP 4 SIP 5 + """ + title = f"torus_2d ({n_rows}×{n_cols}, 6 SIPs)" if kind == "torus" \ + else f"mesh_2d_no_wrap ({n_rows}×{n_cols}, 6 SIPs)" + _frame_panel(ax, title, lim_x=10.0, lim_y=6.0) + + col_xs = [2.0, 5.0, 8.0] # 3 cols + row_ys = [4.3, 1.8] # 2 rows + box_w, box_h = 1.3, 0.95 + pos: dict[tuple[int, int], tuple[float, float]] = {} + for r in range(n_rows): + for c in range(n_cols): + rank = r * n_cols + c + x, y = col_xs[c], row_ys[r] + pos[(r, c)] = (x, y) + _draw_sip_box(ax, x, y, box_w, box_h, f"SIP {rank}") + + # Row edges (E↔W) — between adjacent columns within each row. + for r in range(n_rows): + for c in range(n_cols - 1): + x0, y0 = pos[(r, c)] + x1, y1 = pos[(r, c + 1)] + _arrow(ax, (x0 + box_w / 2, y0 + 0.10), + (x1 - box_w / 2, y1 + 0.10), + color=_PALETTE_BLUE, lw=1.5) + _arrow(ax, (x1 - box_w / 2, y1 - 0.10), + (x0 + box_w / 2, y0 - 0.10), + color=_PALETTE_BLUE, lw=1.5) + # Col edges (N↔S) — between adjacent rows within each column. + for c in range(n_cols): + for r in range(n_rows - 1): + x0, y0 = pos[(r, c)] + x1, y1 = pos[(r + 1, c)] + _arrow(ax, (x0 - 0.12, y0 - box_h / 2), + (x1 - 0.12, y1 + box_h / 2), + color=_PALETTE_GREEN, lw=1.5) + _arrow(ax, (x1 + 0.12, y1 + box_h / 2), + (x0 + 0.12, y0 - box_h / 2), + color=_PALETTE_GREEN, lw=1.5) + # Wrap arrows for torus only — anchor to the centre of the OUTER + # edge of the end SIPs and arc OUTSIDE the row/column so they do + # not overlap the SIPs in between. + if kind == "torus": + # Row wrap: last col → first col. Top row arcs UP, bottom row + # arcs DOWN, so each wrap sits clearly outside its own row. + for r in range(n_rows): + x0, y0 = pos[(r, 0)] + x1, y1 = pos[(r, n_cols - 1)] + curve = -0.45 if r == 0 else 0.45 + _arrow( + ax, + (x1 + box_w / 2, y1), + (x0 - box_w / 2, y0), + color=_PALETTE_BLUE, lw=1.5, + curve=curve, alpha=0.9, + ) + # Col wrap: last row → first row. Leftmost col arcs LEFT, + # rightmost col arcs RIGHT. Middle col(s) get a small inline + # marker + legend note (drawing them through the panel would + # collide with the row arrows). + for c in range(n_cols): + x0, y0 = pos[(0, c)] + x1, y1 = pos[(n_rows - 1, c)] + if c == 0: + curve = 0.55 + elif c == n_cols - 1: + curve = -0.55 + else: + continue # skip middle col — see legend note + _arrow( + ax, + (x1, y1 - box_h / 2), + (x0, y0 + box_h / 2), + color=_PALETTE_GREEN, lw=1.5, + curve=curve, alpha=0.9, + ) + + ax.text(0.7, 5.6, "global_E/W (row)", color=_PALETTE_BLUE, + fontsize=9, style="italic", fontweight="bold") + ax.text(0.7, 5.25, "global_N/S (col)", color=_PALETTE_GREEN, + fontsize=9, style="italic", fontweight="bold") + ax.text(0.7, 4.92, + "wrap = torus" if kind == "torus" else "no wrap = mesh", + color="gray", fontsize=8, style="italic") + if kind == "torus" and n_cols > 2: + ax.text(0.7, 0.3, + "(middle-col wrap omitted for clarity — every row " + "and every column wraps)", + color="gray", fontsize=7.5, style="italic") + + +def _draw_cube_reduction(ax): + """4×4 cube grid inside SIP 0 — compact layout with phase legend.""" + from matplotlib.patches import Rectangle + _frame_panel(ax, "Cube-level reduction inside SIP 0 (4×4 cubes)", + lim_x=10.0, lim_y=6.0) + + cube_w = 0.65 + cube_gap = 0.18 + # Center the 4×4 grid in the left half of the panel. + grid_total = 4 * cube_w + 3 * cube_gap + grid_x0 = 0.7 + grid_y0 = 0.7 + centers: dict[tuple[int, int], tuple[float, float]] = {} + for r in range(4): + for c in range(4): + cx = grid_x0 + c * (cube_w + cube_gap) + cube_w / 2 + cy = grid_y0 + (3 - r) * (cube_w + cube_gap) + cube_w / 2 + centers[(r, c)] = (cx, cy) + cube_id = r * 4 + c + is_root = (r == 3 and c == 3) + face = _PALETTE_ROOT_FILL if is_root else _PALETTE_BOX_FILL + edge = _PALETTE_ROOT_EDGE if is_root else _PALETTE_BOX_EDGE + rect = Rectangle( + (cx - cube_w / 2, cy - cube_w / 2), cube_w, cube_w, + linewidth=1.2, edgecolor=edge, facecolor=face, + ) + ax.add_patch(rect) + label = f"c{cube_id}" + ax.text(cx, cy, label, ha="center", va="center", + fontsize=7.5, fontweight="bold", + color=_PALETTE_ROOT_EDGE if is_root + else _PALETTE_TEXT) + + # Phase 1: row reduce W→E. + for r in range(4): + for c in range(3): + x0, y0 = centers[(r, c)] + x1, y1 = centers[(r, c + 1)] + _arrow(ax, (x0 + cube_w / 2, y0), (x1 - cube_w / 2, y1), + color=_PALETTE_BLUE, lw=1.5) + # Phase 2: col reduce N→S along rightmost column. + for r in range(3): + x0, y0 = centers[(r, 3)] + x1, y1 = centers[(r + 1, 3)] + _arrow(ax, (x0, y0 - cube_w / 2), (x1, y1 + cube_w / 2), + color=_PALETTE_GREEN, lw=1.7) + + # Phase legend on the right side. + legend_x = grid_x0 + grid_total + 0.55 + ax.text(legend_x, 5.0, "Phase 1: row reduce (W → E)", + color=_PALETTE_BLUE, fontsize=10, fontweight="bold") + ax.text(legend_x, 4.55, "Phase 2: col reduce (N → S, rightmost col)", + color=_PALETTE_GREEN, fontsize=10, fontweight="bold") + ax.text(legend_x, 4.10, "Phase 3: inter-SIP exchange at root cube", + color=_PALETTE_ROOT_EDGE, fontsize=10, fontweight="bold") + ax.text(legend_x, 3.65, "Phase 4: col broadcast (S → N)", + color=_PALETTE_GREEN, fontsize=10, style="italic") + ax.text(legend_x, 3.20, "Phase 5: row broadcast (E → W)", + color=_PALETTE_BLUE, fontsize=10, style="italic") + ax.text(legend_x, 2.55, + "(broadcast phases reverse phases 2 & 1)", + color="gray", fontsize=8.5, style="italic") + ax.text(legend_x, 1.7, + "Root cube (c15, bottom-right) is the only\n" + "cube that performs the inter-SIP exchange.", + color=_PALETTE_ROOT_EDGE, fontsize=9, style="italic") + + +def emit_topology_diagram() -> str: + """Emit a 2×2-panel topology diagram into allreduce_latency_plots/. + + Top row: ring_1d | torus_2d (2×3) + Bot row: mesh_2d_no_wrap (2×3) | cube-level reduction in SIP 0 + """ + import matplotlib.gridspec as gridspec + import matplotlib.pyplot as plt + + _SWEEP_OUT_DIR.mkdir(parents=True, exist_ok=True) + fig = plt.figure(figsize=(16, 10), facecolor="white") + gs = gridspec.GridSpec(2, 2, figure=fig, hspace=0.30, wspace=0.10) + ax_ring = fig.add_subplot(gs[0, 0]) + ax_torus = fig.add_subplot(gs[0, 1]) + ax_mesh = fig.add_subplot(gs[1, 0]) + ax_cube = fig.add_subplot(gs[1, 1]) + + _draw_ring_topology(ax_ring) + _draw_grid_topology(ax_torus, "torus", n_rows=2, n_cols=3) + _draw_grid_topology(ax_mesh, "mesh", n_rows=2, n_cols=3) + _draw_cube_reduction(ax_cube) + + fig.suptitle( + "Allreduce topology — device-level (top: ring, torus, mesh) " + "and cube-level reduction in SIP 0", + fontsize=14, fontweight="bold", color=_PALETTE_TEXT, y=0.98, + ) + out_path = _SWEEP_OUT_DIR / "topology.png" + fig.savefig(out_path, dpi=130, bbox_inches="tight", + facecolor=fig.get_facecolor()) + plt.close(fig) + return str(out_path) + + +def test_emit_topology_diagram(): + """Emit topology.png alongside the sweep plots. Pure plotting; no sim.""" + out = emit_topology_diagram() + assert Path(out).exists()