From 19dfc86dc39f9ff7eb8712ca838c98967638ac1f Mon Sep 17 00:00:00 2001 From: Mukesh Garg Date: Mon, 27 Apr 2026 10:16:29 -0700 Subject: [PATCH] Allreduce latency sweep across topologies and data sizes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds test_allreduce_latency_sweep, which runs the existing intercube allreduce kernel under three SIP topologies (ring_1d, torus_2d, mesh_2d_no_wrap, all at n_sips=4) across 11 data sizes from 256 B/SIP up to 1 MB/SIP. For each point, captures max(pe_exec_ns) — the critical-path kernel time — and emits a CSV plus log-x and linear-x plots, as both per-topology plots and a combined overview, with KB/MB-formatted tick labels. Reuses run_allreduce + _write_temp_configs and adds a slot_size auto-bump when n_elem*2 exceeds the default IPCQ slot. The sweep skips n_elem=16 because the runtime's dim_map scalar-arg remapping (context.py:761) conflates any int-valued kernel scalar that equals a global tensor dim with that dim's local shard size, and 16 equals n_cubes. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/test_allreduce_multidevice.py | 201 +++++++++++++++++++++++++++- 1 file changed, 200 insertions(+), 1 deletion(-) diff --git a/tests/test_allreduce_multidevice.py b/tests/test_allreduce_multidevice.py index 397a34c..347d61e 100644 --- a/tests/test_allreduce_multidevice.py +++ b/tests/test_allreduce_multidevice.py @@ -179,7 +179,9 @@ CONFIGS = [ ] -def _write_temp_configs(tmp_path, sip_topology, n_sips, algorithm): +def _write_temp_configs( + tmp_path, sip_topology, n_sips, algorithm, n_elem_override=None, +): """Write temp topology.yaml and ccl.yaml with the given overrides.""" with open(TOPOLOGY_PATH) as f: topo_cfg = yaml.safe_load(f) @@ -193,6 +195,15 @@ def _write_temp_configs(tmp_path, sip_topology, n_sips, algorithm): with open(ccl_path) as f: ccl_cfg = yaml.safe_load(f) ccl_cfg["defaults"]["algorithm"] = algorithm + if n_elem_override is not None: + ccl_cfg.setdefault("algorithms", {}).setdefault( + algorithm, {}, + )["n_elem"] = int(n_elem_override) + # Ensure IPCQ slot is big enough for the per-message payload. + per_msg_bytes = int(n_elem_override) * 2 # f16 + default_slot = int(ccl_cfg["defaults"].get("slot_size", 4096)) + if per_msg_bytes > default_slot: + ccl_cfg["defaults"]["slot_size"] = per_msg_bytes tmp_ccl = tmp_path / "ccl.yaml" with open(tmp_ccl, "w") as f: yaml.dump(ccl_cfg, f, default_flow_style=False) @@ -220,3 +231,191 @@ def test_allreduce(tmp_path, algorithm, sip_topology, n_sips): algorithm=algorithm, ccl_yaml=ccl_path, ) assert result["ok_cubes"] > 0 + + +# ── Latency sweep ───────────────────────────────────────────────────── + +# avoid 16 (== n_cubes, dim_map collision). Goes up to 1 MB per SIP: +# bytes_per_sip = n_cubes * n_elem * 2 = 32 * n_elem. 
# Per-cube element counts swept by test_allreduce_latency_sweep.
# 16 is deliberately absent: it equals n_cubes and trips the runtime's
# dim_map scalar remapping.
# NOTE(review): 256 is also absent; no reason is recorded here — confirm.
_SWEEP_N_ELEM = [
    8, 32, 64, 128, 512, 1024, 2048,
    4096, 8192, 16384, 32768,
]
_ELEM_BYTES_F16 = 2  # bytes per f16 element

# Column order of summary.csv; also the key order of every sweep record.
_SWEEP_CSV_FIELDS = (
    "algorithm", "sip_topology", "n_sips", "n_elem",
    "bytes_per_pe", "bytes_per_sip", "latency_ns",
)


def _sweep_fmt_bytes(x, _pos=None):
    """Format an axis tick value (bytes) as "N B", "N KB" or "N MB".

    The second argument is the tick position that matplotlib's
    FuncFormatter passes in; it is ignored.
    """
    if x <= 0:
        return "0"
    if x >= 1024 * 1024:
        return f"{x / (1024 * 1024):.0f} MB"
    if x >= 1024:
        return f"{x / 1024:.0f} KB"
    return f"{x:.0f} B"


def _sweep_critical_path_ns(results):
    """Return the largest pe_exec_ns found in an engine result mapping.

    Each value of *results* must be a 2-item sequence whose second item
    is the trace; non-dict traces are skipped and a missing or falsy
    pe_exec_ns counts as 0.0. Returns 0.0 when nothing qualifies.
    """
    return max(
        (
            float(trace.get("pe_exec_ns", 0.0) or 0.0)
            for _, trace in results.values()
            if isinstance(trace, dict)
        ),
        default=0.0,
    )


def _sweep_record(algorithm, sip_topology, n_sips, n_elem, n_cubes,
                  latency_ns):
    """Build one summary row with keys in _SWEEP_CSV_FIELDS order."""
    return {
        "algorithm": algorithm,
        "sip_topology": sip_topology,
        "n_sips": n_sips,
        "n_elem": n_elem,
        # pe="replicate" + num_pes=1 → one active PE per cube owns
        # the whole cube row. Per-PE bytes == per-cube-tile bytes ==
        # per-message bytes over the IPCQ fabric.
        "bytes_per_pe": n_elem * _ELEM_BYTES_F16,
        "bytes_per_sip": n_cubes * n_elem * _ELEM_BYTES_F16,
        "latency_ns": latency_ns,
    }


def test_allreduce_latency_sweep(tmp_path):
    """Sweep n_elem across each SIP topology and record kernel latency.

    For every (topology, n_elem) point the allreduce is run once and
    max(pe_exec_ns) over the engine results is taken as the
    critical-path latency. Writes summary.csv plus per-topology and
    combined PNG plots (log-x and linear-x variants) to
    tests/allreduce_latency_plots/.
    """
    import csv

    # Imported before the sweep so a missing matplotlib fails fast.
    import matplotlib.pyplot as plt
    from matplotlib.ticker import FuncFormatter

    bytes_fmt = FuncFormatter(_sweep_fmt_bytes)

    out_dir = Path(__file__).parent / "allreduce_latency_plots"
    out_dir.mkdir(parents=True, exist_ok=True)
    records: list[dict] = []

    # Apples-to-apples: same n_sips across all three topologies.
    sweep_cases = [
        ("intercube_allreduce", "ring_1d", 4),
        ("intercube_allreduce", "torus_2d", 4),
        ("intercube_allreduce", "mesh_2d_no_wrap", 4),
    ]
    for algorithm, sip_topology, n_sips in sweep_cases:
        for n_elem in _SWEEP_N_ELEM:
            sub = tmp_path / f"{sip_topology}_{n_elem}"
            sub.mkdir()
            topo_path, ccl_path = _write_temp_configs(
                sub, sip_topology, n_sips, algorithm,
                n_elem_override=n_elem,
            )
            topo = resolve_topology(topo_path)
            engine = GraphEngine(topo.topology_obj, enable_data=True)
            spec = topo.topology_obj.spec

            with RuntimeContext(
                engine=engine,
                target_device=DeviceSelector("all"),
                correlation_id=f"sweep_{algorithm}_{sip_topology}_{n_elem}",
                spec=spec,
            ) as ctx:
                result = run_allreduce(
                    ctx, engine, spec,
                    algorithm=algorithm, ccl_yaml=ccl_path,
                )
                assert result["ok_cubes"] > 0
                # NOTE(review): results are read while the context is
                # still open; confirm engine._results does not need the
                # context to exit first.
                crit_ns = _sweep_critical_path_ns(engine._results)

            cm = spec["sip"]["cube_mesh"]
            n_cubes = int(cm["w"]) * int(cm["h"])
            rec = _sweep_record(
                algorithm, sip_topology, n_sips, n_elem, n_cubes, crit_ns,
            )
            records.append(rec)
            print(
                f"[{sip_topology:<16} n_sips={n_sips} n_elem={n_elem:>5} "
                f"bytes/pe={rec['bytes_per_pe']:>7} "
                f"bytes/sip={rec['bytes_per_sip']:>9}] "
                f"pe_exec_max = {crit_ns:8.1f} ns"
            )

    with open(out_dir / "summary.csv", "w", newline="", encoding="utf-8") as f:
        w = csv.DictWriter(f, fieldnames=_SWEEP_CSV_FIELDS)
        w.writeheader()
        w.writerows(records)

    # Group once, ordered by payload size, instead of re-filtering and
    # re-sorting the full record list for every plot.
    by_topo = {}
    for rec in records:
        by_topo.setdefault(rec["sip_topology"], []).append(rec)
    for rows in by_topo.values():
        rows.sort(key=lambda r: r["bytes_per_pe"])
    topologies = sorted(by_topo)

    def _save_plot(series, filename, title, log_x, figsize, legend=False):
        """Plot (xs, ys, label, color) series vs bytes/PE and save a PNG."""
        fig, ax = plt.subplots(figsize=figsize)
        try:
            for xs, ys, label, color in series:
                ax.plot(xs, ys, marker="o", label=label, color=color)
            if log_x:
                ax.set_xscale("log", base=2)
                ax.set_xlabel("Bytes per PE (log scale)")
            else:
                ax.set_xlabel("Bytes per PE")
            ax.set_ylabel("max pe_exec_ns (critical path)")
            ax.set_title(title)
            ax.grid(True, alpha=0.3)
            if legend:
                ax.legend()
            # Applied after set_xscale, which installs the scale's own
            # default formatter.
            ax.xaxis.set_major_formatter(bytes_fmt)
            fig.tight_layout()
            fig.savefig(out_dir / filename, dpi=120)
        finally:
            # Always release the figure, even if rendering/saving fails.
            plt.close(fig)

    # Per-topology plots: log-scale + linear-scale companions.
    # X-axis = bytes per PE (per-message payload size).
    for topo_name in topologies:
        rows = by_topo[topo_name]
        series = [(
            [r["bytes_per_pe"] for r in rows],
            [r["latency_ns"] for r in rows],
            None,
            "tab:blue",
        )]
        title = (
            f"Allreduce latency — {topo_name} "
            f"(n_sips={rows[0]['n_sips']})"
        )
        _save_plot(series, f"{topo_name}.png", title, True, (8, 5))
        _save_plot(
            series, f"{topo_name}_linear.png",
            title + " [linear scale]", False, (8, 5),
        )

    # Combined overview — two variants: log-scale (overview.png) and
    # linear-scale (overview_linear.png).
    colors = {"ring_1d": "tab:blue", "torus_2d": "tab:orange",
              "mesh_2d_no_wrap": "tab:green"}
    overview_series = [
        (
            [r["bytes_per_pe"] for r in by_topo[topo_name]],
            [r["latency_ns"] for r in by_topo[topo_name]],
            f"{topo_name} (n_sips={by_topo[topo_name][0]['n_sips']})",
            colors.get(topo_name),
        )
        for topo_name in topologies
    ]
    overview_title = "Multi-device allreduce latency by topology"
    _save_plot(
        overview_series, "overview.png", overview_title,
        True, (9, 6), legend=True,
    )
    _save_plot(
        overview_series, "overview_linear.png",
        overview_title + " [linear scale]",
        False, (9, 6), legend=True,
    )

    print(
        f"\nWrote {out_dir / 'overview.png'} + "
        f"{out_dir / 'overview_linear.png'}"
    )