Allreduce sweep: parametrized + xdist parallelism + topology diagram

Refactor the latency sweep from one giant test into 36 parametrized
cases that run in parallel under xdist (~6-8x faster: 1:49 instead of
~10 min). Each case writes a JSON row to a staging dir; conftest
sessionfinish hook aggregates rows on the controller node into
summary.csv and the per-topology + overview plots.

Aggregator gains a CSV fallback so plot-only tweaks no longer require
re-running the sweep.

Overview plot updates:
- 96 KB explicit x-axis marker with vertical dotted line
- horizontal theoretical 2D-torus reference (10600 ns)
- annotation showing both theoretical and simulated values at 96 KB
- drop overlapping 128 KB tick

New topology.png: 2x2 panel diagram showing device-level topology
(ring, torus 2x3, mesh 2x3) and the cube-level reduction inside SIP 0.
Wrap arrows anchor on box edges and arc outside rows/columns so they
do not overlap any SIP.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-04-27 16:43:19 -07:00
parent 1c33afec55
commit 04c912f53e
8 changed files with 559 additions and 101 deletions
Binary file not shown.

Before

Width:  |  Height:  |  Size: 39 KiB

After

Width:  |  Height:  |  Size: 43 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 71 KiB

After

Width:  |  Height:  |  Size: 87 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 38 KiB

After

Width:  |  Height:  |  Size: 41 KiB

+25 -22
View File
@@ -1,26 +1,4 @@
algorithm,sip_topology,n_sips,n_elem,bytes_per_pe,bytes_per_sip,latency_ns algorithm,sip_topology,n_sips,n_elem,bytes_per_pe,bytes_per_sip,latency_ns
intercube_allreduce,ring_1d,6,8,16,256,3073.1299999999937
intercube_allreduce,ring_1d,6,32,64,1024,3079.8799999999947
intercube_allreduce,ring_1d,6,64,128,2048,3088.879999999992
intercube_allreduce,ring_1d,6,128,256,4096,3106.8799999999865
intercube_allreduce,ring_1d,6,512,1024,16384,3225.8799999999865
intercube_allreduce,ring_1d,6,1024,2048,32768,3391.8799999999865
intercube_allreduce,ring_1d,6,2048,4096,65536,3723.8799999999865
intercube_allreduce,ring_1d,6,4096,8192,131072,4387.879999999965
intercube_allreduce,ring_1d,6,8192,16384,262144,5715.879999999957
intercube_allreduce,ring_1d,6,16384,32768,524288,8371.879999999932
intercube_allreduce,ring_1d,6,32768,65536,1048576,13683.879999999903
intercube_allreduce,torus_2d,6,8,16,256,2190.4799999999923
intercube_allreduce,torus_2d,6,32,64,1024,2196.479999999993
intercube_allreduce,torus_2d,6,64,128,2048,2204.4799999999905
intercube_allreduce,torus_2d,6,128,256,4096,2220.479999999985
intercube_allreduce,torus_2d,6,512,1024,16384,2325.479999999985
intercube_allreduce,torus_2d,6,1024,2048,32768,2471.479999999985
intercube_allreduce,torus_2d,6,2048,4096,65536,2763.479999999985
intercube_allreduce,torus_2d,6,4096,8192,131072,3347.4799999999777
intercube_allreduce,torus_2d,6,8192,16384,262144,4515.4799999999705
intercube_allreduce,torus_2d,6,16384,32768,524288,6851.479999999952
intercube_allreduce,torus_2d,6,32768,65536,1048576,11523.479999999923
intercube_allreduce,mesh_2d_no_wrap,6,8,16,256,3508.4249999999993 intercube_allreduce,mesh_2d_no_wrap,6,8,16,256,3508.4249999999993
intercube_allreduce,mesh_2d_no_wrap,6,32,64,1024,3515.55 intercube_allreduce,mesh_2d_no_wrap,6,32,64,1024,3515.55
intercube_allreduce,mesh_2d_no_wrap,6,64,128,2048,3525.0499999999975 intercube_allreduce,mesh_2d_no_wrap,6,64,128,2048,3525.0499999999975
@@ -32,3 +10,28 @@ intercube_allreduce,mesh_2d_no_wrap,6,4096,8192,131072,4857.049999999959
intercube_allreduce,mesh_2d_no_wrap,6,8192,16384,262144,6217.049999999945 intercube_allreduce,mesh_2d_no_wrap,6,8192,16384,262144,6217.049999999945
intercube_allreduce,mesh_2d_no_wrap,6,16384,32768,524288,8937.049999999937 intercube_allreduce,mesh_2d_no_wrap,6,16384,32768,524288,8937.049999999937
intercube_allreduce,mesh_2d_no_wrap,6,32768,65536,1048576,14377.049999999872 intercube_allreduce,mesh_2d_no_wrap,6,32768,65536,1048576,14377.049999999872
intercube_allreduce,mesh_2d_no_wrap,6,49152,98304,1572864,19817.049999999872
intercube_allreduce,ring_1d,6,8,16,256,3073.1299999999937
intercube_allreduce,ring_1d,6,32,64,1024,3079.8799999999947
intercube_allreduce,ring_1d,6,64,128,2048,3088.879999999992
intercube_allreduce,ring_1d,6,128,256,4096,3106.8799999999865
intercube_allreduce,ring_1d,6,512,1024,16384,3225.8799999999865
intercube_allreduce,ring_1d,6,1024,2048,32768,3391.8799999999865
intercube_allreduce,ring_1d,6,2048,4096,65536,3723.8799999999865
intercube_allreduce,ring_1d,6,4096,8192,131072,4387.879999999965
intercube_allreduce,ring_1d,6,8192,16384,262144,5715.879999999957
intercube_allreduce,ring_1d,6,16384,32768,524288,8371.879999999932
intercube_allreduce,ring_1d,6,32768,65536,1048576,13683.879999999903
intercube_allreduce,ring_1d,6,49152,98304,1572864,18995.879999999917
intercube_allreduce,torus_2d,6,8,16,256,2190.4799999999923
intercube_allreduce,torus_2d,6,32,64,1024,2196.479999999993
intercube_allreduce,torus_2d,6,64,128,2048,2204.4799999999905
intercube_allreduce,torus_2d,6,128,256,4096,2220.479999999985
intercube_allreduce,torus_2d,6,512,1024,16384,2325.479999999985
intercube_allreduce,torus_2d,6,1024,2048,32768,2471.479999999985
intercube_allreduce,torus_2d,6,2048,4096,65536,2763.479999999985
intercube_allreduce,torus_2d,6,4096,8192,131072,3347.4799999999777
intercube_allreduce,torus_2d,6,8192,16384,262144,4515.4799999999705
intercube_allreduce,torus_2d,6,16384,32768,524288,6851.479999999952
intercube_allreduce,torus_2d,6,32768,65536,1048576,11523.479999999923
intercube_allreduce,torus_2d,6,49152,98304,1572864,16195.479999999952
1 algorithm sip_topology n_sips n_elem bytes_per_pe bytes_per_sip latency_ns
intercube_allreduce ring_1d 6 8 16 256 3073.1299999999937
intercube_allreduce ring_1d 6 32 64 1024 3079.8799999999947
intercube_allreduce ring_1d 6 64 128 2048 3088.879999999992
intercube_allreduce ring_1d 6 128 256 4096 3106.8799999999865
intercube_allreduce ring_1d 6 512 1024 16384 3225.8799999999865
intercube_allreduce ring_1d 6 1024 2048 32768 3391.8799999999865
intercube_allreduce ring_1d 6 2048 4096 65536 3723.8799999999865
intercube_allreduce ring_1d 6 4096 8192 131072 4387.879999999965
intercube_allreduce ring_1d 6 8192 16384 262144 5715.879999999957
intercube_allreduce ring_1d 6 16384 32768 524288 8371.879999999932
intercube_allreduce ring_1d 6 32768 65536 1048576 13683.879999999903
intercube_allreduce torus_2d 6 8 16 256 2190.4799999999923
intercube_allreduce torus_2d 6 32 64 1024 2196.479999999993
intercube_allreduce torus_2d 6 64 128 2048 2204.4799999999905
intercube_allreduce torus_2d 6 128 256 4096 2220.479999999985
intercube_allreduce torus_2d 6 512 1024 16384 2325.479999999985
intercube_allreduce torus_2d 6 1024 2048 32768 2471.479999999985
intercube_allreduce torus_2d 6 2048 4096 65536 2763.479999999985
intercube_allreduce torus_2d 6 4096 8192 131072 3347.4799999999777
intercube_allreduce torus_2d 6 8192 16384 262144 4515.4799999999705
intercube_allreduce torus_2d 6 16384 32768 524288 6851.479999999952
intercube_allreduce torus_2d 6 32768 65536 1048576 11523.479999999923
2 intercube_allreduce mesh_2d_no_wrap 6 8 16 256 3508.4249999999993
3 intercube_allreduce mesh_2d_no_wrap 6 32 64 1024 3515.55
4 intercube_allreduce mesh_2d_no_wrap 6 64 128 2048 3525.0499999999975
10 intercube_allreduce mesh_2d_no_wrap 6 8192 16384 262144 6217.049999999945
11 intercube_allreduce mesh_2d_no_wrap 6 16384 32768 524288 8937.049999999937
12 intercube_allreduce mesh_2d_no_wrap 6 32768 65536 1048576 14377.049999999872
13 intercube_allreduce mesh_2d_no_wrap 6 49152 98304 1572864 19817.049999999872
14 intercube_allreduce ring_1d 6 8 16 256 3073.1299999999937
15 intercube_allreduce ring_1d 6 32 64 1024 3079.8799999999947
16 intercube_allreduce ring_1d 6 64 128 2048 3088.879999999992
17 intercube_allreduce ring_1d 6 128 256 4096 3106.8799999999865
18 intercube_allreduce ring_1d 6 512 1024 16384 3225.8799999999865
19 intercube_allreduce ring_1d 6 1024 2048 32768 3391.8799999999865
20 intercube_allreduce ring_1d 6 2048 4096 65536 3723.8799999999865
21 intercube_allreduce ring_1d 6 4096 8192 131072 4387.879999999965
22 intercube_allreduce ring_1d 6 8192 16384 262144 5715.879999999957
23 intercube_allreduce ring_1d 6 16384 32768 524288 8371.879999999932
24 intercube_allreduce ring_1d 6 32768 65536 1048576 13683.879999999903
25 intercube_allreduce ring_1d 6 49152 98304 1572864 18995.879999999917
26 intercube_allreduce torus_2d 6 8 16 256 2190.4799999999923
27 intercube_allreduce torus_2d 6 32 64 1024 2196.479999999993
28 intercube_allreduce torus_2d 6 64 128 2048 2204.4799999999905
29 intercube_allreduce torus_2d 6 128 256 4096 2220.479999999985
30 intercube_allreduce torus_2d 6 512 1024 16384 2325.479999999985
31 intercube_allreduce torus_2d 6 1024 2048 32768 2471.479999999985
32 intercube_allreduce torus_2d 6 2048 4096 65536 2763.479999999985
33 intercube_allreduce torus_2d 6 4096 8192 131072 3347.4799999999777
34 intercube_allreduce torus_2d 6 8192 16384 262144 4515.4799999999705
35 intercube_allreduce torus_2d 6 16384 32768 524288 6851.479999999952
36 intercube_allreduce torus_2d 6 32768 65536 1048576 11523.479999999923
37 intercube_allreduce torus_2d 6 49152 98304 1572864 16195.479999999952
Binary file not shown.

After

Width:  |  Height:  |  Size: 194 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 37 KiB

After

Width:  |  Height:  |  Size: 41 KiB

+34
View File
@@ -7,11 +7,45 @@ stateful/SimPy-event-consuming and MUST NOT be shared).
""" """
from __future__ import annotations from __future__ import annotations
import os
import pytest import pytest
from kernbench.topology.builder import resolve_topology from kernbench.topology.builder import resolve_topology
def pytest_sessionfinish(session, exitstatus):
"""Aggregate parametrized sweep rows into combined CSV + PNG plots.
Runs on the controller node only (xdist worker processes set
``PYTEST_XDIST_WORKER``; we skip those). Idempotent — does nothing
if no sweep rows are present (e.g., when the sweep was filtered out).
"""
if os.environ.get("PYTEST_XDIST_WORKER"):
return
import importlib.util
import sys
from pathlib import Path
mod_path = Path(__file__).parent / "test_allreduce_multidevice.py"
if not mod_path.exists():
return
spec = importlib.util.spec_from_file_location(
"_test_allreduce_multidevice_for_aggregate", mod_path,
)
if spec is None or spec.loader is None:
return
mod = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = mod
try:
spec.loader.exec_module(mod)
agg = getattr(mod, "_aggregate_sweep_plots", None)
if agg is not None:
agg()
except Exception as e:
print(f"[conftest] sweep aggregation failed: {e}")
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def topology(): def topology():
"""Session-scoped parsed topology (immutable graph + spec). """Session-scoped parsed topology (immutable graph + spec).
+500 -79
View File
@@ -269,29 +269,143 @@ def test_allreduce(
assert result["ok_cubes"] > 0 assert result["ok_cubes"] > 0
# ── Latency sweep ───────────────────────────────────────────────────── # ── Latency sweep (parametrized + xdist-friendly) ─────────────────────
# avoid 16 (== n_cubes, dim_map collision). Goes up to 1 MB per SIP: # avoid 16 (== n_cubes, dim_map collision). Goes up to 96 KB per PE:
# bytes_per_sip = n_cubes * n_elem * 2 = 32 * n_elem. # bytes_per_pe = n_elem * 2 (f16). 49152 elem * 2 = 96 KB / PE.
_SWEEP_N_ELEM = [ _SWEEP_N_ELEM = [
8, 32, 64, 128, 512, 1024, 2048, 8, 32, 64, 128, 512, 1024, 2048,
4096, 8192, 16384, 32768, 4096, 8192, 16384, 32768, 49152,
] ]
_ELEM_BYTES_F16 = 2 _ELEM_BYTES_F16 = 2
_SWEEP_TOPOLOGIES = [
("intercube_allreduce", "ring_1d", 6, None, None),
("intercube_allreduce", "torus_2d", 6, 2, 3),
("intercube_allreduce", "mesh_2d_no_wrap", 6, 2, 3),
]
def test_allreduce_latency_sweep(tmp_path): # Shared on-disk staging dir for parametrized sweep rows. Each
"""Sweep n_elem across each SIP topology; record max(pe_exec_ns) # parametrized invocation writes one JSON file here; the aggregator
as the critical-path kernel latency. Emits CSV + PNG plots to # (run from conftest.pytest_sessionfinish) reads them and emits the
tests/allreduce_latency_plots/. # combined CSV + PNG plots.
_SWEEP_OUT_DIR = Path(__file__).parent / "allreduce_latency_plots"
_SWEEP_ROWS_DIR = _SWEEP_OUT_DIR / "_rows"
def _sweep_params():
out = []
for algorithm, sip_topology, n_sips, sip_w, sip_h in _SWEEP_TOPOLOGIES:
for n_elem in _SWEEP_N_ELEM:
out.append(pytest.param(
algorithm, sip_topology, n_sips, sip_w, sip_h, n_elem,
id=f"{sip_topology}-n_elem{n_elem}",
))
return out
@pytest.mark.parametrize(
"algorithm,sip_topology,n_sips,sip_w,sip_h,n_elem", _sweep_params(),
)
def test_allreduce_latency_one(
tmp_path, algorithm, sip_topology, n_sips, sip_w, sip_h, n_elem,
):
"""One config of the latency sweep. xdist parallelizes across params.
Writes a single JSON row to ``_SWEEP_ROWS_DIR``. The conftest
sessionfinish hook aggregates rows into CSV + plots after all
parametrized cases finish.
"""
import json
topo_path, ccl_path = _write_temp_configs(
tmp_path, sip_topology, n_sips, algorithm,
sip_w=sip_w, sip_h=sip_h,
n_elem_override=n_elem,
)
topo = resolve_topology(topo_path)
engine = GraphEngine(topo.topology_obj, enable_data=True)
spec = topo.topology_obj.spec
with RuntimeContext(
engine=engine,
target_device=DeviceSelector("all"),
correlation_id=f"sweep_{algorithm}_{sip_topology}_{n_elem}",
spec=spec,
) as ctx:
result = run_allreduce(
ctx, engine, spec,
algorithm=algorithm, ccl_yaml=ccl_path,
)
assert result["ok_cubes"] > 0
pe_exec_vals = [
float(tr.get("pe_exec_ns", 0.0) or 0.0)
for _, (_, tr) in engine._results.items()
if isinstance(tr, dict)
]
crit_ns = max(pe_exec_vals) if pe_exec_vals else 0.0
cm = spec["sip"]["cube_mesh"]
n_cubes = int(cm["w"]) * int(cm["h"])
bytes_per_sip = n_cubes * n_elem * _ELEM_BYTES_F16
bytes_per_pe = n_elem * _ELEM_BYTES_F16
record = {
"algorithm": algorithm,
"sip_topology": sip_topology,
"n_sips": n_sips,
"n_elem": n_elem,
"bytes_per_pe": bytes_per_pe,
"bytes_per_sip": bytes_per_sip,
"latency_ns": crit_ns,
}
_SWEEP_ROWS_DIR.mkdir(parents=True, exist_ok=True)
row_path = _SWEEP_ROWS_DIR / f"{sip_topology}_{n_elem}.json"
with open(row_path, "w", encoding="utf-8") as f:
json.dump(record, f)
def _aggregate_sweep_plots() -> bool:
"""Read all per-config rows and emit CSV + PNG plots.
Called by ``conftest.pytest_sessionfinish`` (controller node only).
Returns True if any rows were aggregated, False otherwise.
""" """
import csv import csv
import json
row_files = sorted(_SWEEP_ROWS_DIR.glob("*.json")) \
if _SWEEP_ROWS_DIR.exists() else []
records: list[dict] = []
if row_files:
for p in row_files:
with open(p, encoding="utf-8") as f:
records.append(json.load(f))
else:
# Fallback: replot from existing summary.csv (skip sweep re-run).
summary_path = _SWEEP_OUT_DIR / "summary.csv"
if not summary_path.exists():
return False
with open(summary_path, encoding="utf-8") as f:
for row in csv.DictReader(f):
records.append({
"algorithm": row["algorithm"],
"sip_topology": row["sip_topology"],
"n_sips": int(row["n_sips"]),
"n_elem": int(row["n_elem"]),
"bytes_per_pe": int(row["bytes_per_pe"]),
"bytes_per_sip": int(row["bytes_per_sip"]),
"latency_ns": float(row["latency_ns"]),
})
if not records:
return False
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter from matplotlib.ticker import FuncFormatter
def _fmt_bytes(x, _pos): def _fmt_bytes(x, _pos):
"""Format tick as B / KB / MB."""
if x <= 0: if x <= 0:
return "0" return "0"
if x >= 1024 * 1024: if x >= 1024 * 1024:
@@ -302,86 +416,27 @@ def test_allreduce_latency_sweep(tmp_path):
_bytes_fmt = FuncFormatter(_fmt_bytes) _bytes_fmt = FuncFormatter(_fmt_bytes)
out_dir = Path(__file__).parent / "allreduce_latency_plots" _SWEEP_OUT_DIR.mkdir(parents=True, exist_ok=True)
out_dir.mkdir(parents=True, exist_ok=True) with open(_SWEEP_OUT_DIR / "summary.csv", "w",
records: list[dict] = [] newline="", encoding="utf-8") as f:
# Apples-to-apples: same n_sips across all three topologies.
for algorithm, sip_topology, n_sips, sip_w, sip_h in [
("intercube_allreduce", "ring_1d", 6, None, None),
("intercube_allreduce", "torus_2d", 6, 2, 3),
("intercube_allreduce", "mesh_2d_no_wrap", 6, 2, 3),
]:
for n_elem in _SWEEP_N_ELEM:
sub = tmp_path / f"{sip_topology}_{n_elem}"
sub.mkdir()
topo_path, ccl_path = _write_temp_configs(
sub, sip_topology, n_sips, algorithm,
sip_w=sip_w, sip_h=sip_h,
n_elem_override=n_elem,
)
topo = resolve_topology(topo_path)
engine = GraphEngine(topo.topology_obj, enable_data=True)
spec = topo.topology_obj.spec
with RuntimeContext(
engine=engine,
target_device=DeviceSelector("all"),
correlation_id=f"sweep_{algorithm}_{sip_topology}_{n_elem}",
spec=spec,
) as ctx:
result = run_allreduce(
ctx, engine, spec,
algorithm=algorithm, ccl_yaml=ccl_path,
)
assert result["ok_cubes"] > 0
pe_exec_vals = [
float(tr.get("pe_exec_ns", 0.0) or 0.0)
for _, (_, tr) in engine._results.items()
if isinstance(tr, dict)
]
crit_ns = max(pe_exec_vals) if pe_exec_vals else 0.0
cm = spec["sip"]["cube_mesh"]
n_cubes = int(cm["w"]) * int(cm["h"])
bytes_per_sip = n_cubes * n_elem * _ELEM_BYTES_F16
# pe="replicate" + num_pes=1 → one active PE per cube owns
# the whole cube row. Per-PE bytes == per-cube-tile bytes ==
# per-message bytes over the IPCQ fabric.
bytes_per_pe = n_elem * _ELEM_BYTES_F16
records.append({
"algorithm": algorithm,
"sip_topology": sip_topology,
"n_sips": n_sips,
"n_elem": n_elem,
"bytes_per_pe": bytes_per_pe,
"bytes_per_sip": bytes_per_sip,
"latency_ns": crit_ns,
})
print(
f"[{sip_topology:<16} n_sips={n_sips} n_elem={n_elem:>5} "
f"bytes/pe={bytes_per_pe:>7} bytes/sip={bytes_per_sip:>9}] "
f"pe_exec_max = {crit_ns:8.1f} ns"
)
with open(out_dir / "summary.csv", "w", newline="", encoding="utf-8") as f:
w = csv.DictWriter(f, fieldnames=[ w = csv.DictWriter(f, fieldnames=[
"algorithm", "sip_topology", "n_sips", "n_elem", "algorithm", "sip_topology", "n_sips", "n_elem",
"bytes_per_pe", "bytes_per_sip", "latency_ns", "bytes_per_pe", "bytes_per_sip", "latency_ns",
]) ])
w.writeheader() w.writeheader()
for r in records: for r in sorted(records, key=lambda r: (
r["sip_topology"], r["bytes_per_pe"],
)):
w.writerow(r) w.writerow(r)
topologies = sorted({r["sip_topology"] for r in records}) topologies = sorted({r["sip_topology"] for r in records})
# Per-topology plots, log-scale x-axis = bytes per PE.
for topo_name in topologies: for topo_name in topologies:
rs = sorted( rs = sorted(
[r for r in records if r["sip_topology"] == topo_name], [r for r in records if r["sip_topology"] == topo_name],
key=lambda r: r["bytes_per_pe"], key=lambda r: r["bytes_per_pe"],
) )
if not rs:
continue
xs = [r["bytes_per_pe"] for r in rs] xs = [r["bytes_per_pe"] for r in rs]
ys = [r["latency_ns"] for r in rs] ys = [r["latency_ns"] for r in rs]
title = ( title = (
@@ -397,17 +452,20 @@ def test_allreduce_latency_sweep(tmp_path):
ax.grid(True, alpha=0.3) ax.grid(True, alpha=0.3)
ax.xaxis.set_major_formatter(_bytes_fmt) ax.xaxis.set_major_formatter(_bytes_fmt)
fig.tight_layout() fig.tight_layout()
fig.savefig(out_dir / f"{topo_name}.png", dpi=120) fig.savefig(_SWEEP_OUT_DIR / f"{topo_name}.png", dpi=120)
plt.close(fig) plt.close(fig)
colors = {"ring_1d": "tab:blue", "torus_2d": "tab:orange", colors = {"ring_1d": "tab:blue", "torus_2d": "tab:orange",
"mesh_2d_no_wrap": "tab:green"} "mesh_2d_no_wrap": "tab:green"}
THEORETICAL_TORUS_2D_6SIP_NS = 10600.0
fig, ax = plt.subplots(figsize=(9, 6)) fig, ax = plt.subplots(figsize=(9, 6))
for topo_name in topologies: for topo_name in topologies:
rs = sorted( rs = sorted(
[r for r in records if r["sip_topology"] == topo_name], [r for r in records if r["sip_topology"] == topo_name],
key=lambda r: r["bytes_per_pe"], key=lambda r: r["bytes_per_pe"],
) )
if not rs:
continue
ax.plot( ax.plot(
[r["bytes_per_pe"] for r in rs], [r["bytes_per_pe"] for r in rs],
[r["latency_ns"] for r in rs], [r["latency_ns"] for r in rs],
@@ -415,15 +473,378 @@ def test_allreduce_latency_sweep(tmp_path):
label=f"{topo_name} (n_sips={rs[0]['n_sips']})", label=f"{topo_name} (n_sips={rs[0]['n_sips']})",
color=colors.get(topo_name), color=colors.get(topo_name),
) )
ax.axhline(
y=THEORETICAL_TORUS_2D_6SIP_NS,
color="tab:red", linestyle="--", linewidth=1.5,
label=f"theoretical torus_2d (6 SIPs) = "
f"{THEORETICAL_TORUS_2D_6SIP_NS:.0f} ns",
)
BYTES_96KB = 96 * 1024
ax.axvline(
x=BYTES_96KB, ymin=0, ymax=1,
color="tab:red", linestyle=":", linewidth=1.2,
)
ax.plot(
[BYTES_96KB], [THEORETICAL_TORUS_2D_6SIP_NS],
marker="x", color="tab:red", markersize=10, markeredgewidth=2,
)
# Find simulated torus_2d latency at 96 KB (if present) for direct
# comparison with the theoretical value.
sim_torus_at_96kb = next(
(r["latency_ns"] for r in records
if r["sip_topology"] == "torus_2d" and r["bytes_per_pe"] == BYTES_96KB),
None,
)
if sim_torus_at_96kb is not None:
ax.plot(
[BYTES_96KB], [sim_torus_at_96kb],
marker="o", color="tab:orange",
markersize=10, markeredgecolor="black", markeredgewidth=1.2,
)
ax.annotate(
f"96 KB\n"
f"theoretical = {THEORETICAL_TORUS_2D_6SIP_NS:.0f} ns\n"
f"simulated = {sim_torus_at_96kb:.0f} ns",
xy=(BYTES_96KB, sim_torus_at_96kb),
xytext=(10, -20), textcoords="offset points",
color="tab:red", fontsize=9,
)
else:
ax.annotate(
f"96 KB\n→ theoretical {THEORETICAL_TORUS_2D_6SIP_NS:.0f} ns",
xy=(BYTES_96KB, THEORETICAL_TORUS_2D_6SIP_NS),
xytext=(8, -20), textcoords="offset points",
color="tab:red", fontsize=9,
)
ax.set_xscale("log", base=2) ax.set_xscale("log", base=2)
ax.set_xlabel("Bytes per PE (log scale)") ax.set_xlabel("Bytes per PE (log scale)")
ax.set_ylabel("max pe_exec_ns (critical path)") ax.set_ylabel("max pe_exec_ns (critical path)")
ax.set_title("Multi-device allreduce latency by topology") ax.set_title("Multi-device allreduce latency by topology")
ax.grid(True, alpha=0.3) ax.grid(True, alpha=0.3)
# Drop 128 KB tick (overlaps visually with the explicit 96 KB marker)
# and add 96 KB.
BYTES_128KB = 128 * 1024
existing_ticks = [t for t in ax.get_xticks() if int(t) != BYTES_128KB]
if BYTES_96KB not in existing_ticks:
existing_ticks.append(BYTES_96KB)
ax.set_xticks(sorted(existing_ticks))
ax.set_xlim(left=min(r["bytes_per_pe"] for r in records) / 2,
right=BYTES_96KB * 1.5)
ax.legend() ax.legend()
ax.xaxis.set_major_formatter(_bytes_fmt) ax.xaxis.set_major_formatter(_bytes_fmt)
fig.tight_layout() fig.tight_layout()
fig.savefig(out_dir / "overview.png", dpi=120) fig.savefig(_SWEEP_OUT_DIR / "overview.png", dpi=120)
plt.close(fig) plt.close(fig)
print(f"\nWrote {out_dir / 'overview.png'}") # Cleanup row staging dir so a partial future run doesn't pick up
# stale rows.
for p in row_files:
try:
p.unlink()
except OSError:
pass
try:
_SWEEP_ROWS_DIR.rmdir()
except OSError:
pass
print(f"\nWrote {_SWEEP_OUT_DIR / 'overview.png'} "
f"from {len(records)} rows")
return True
# ── Topology diagram (device-level + cube-level reduction) ────────────
# Convention: "rows × cols" everywhere, row-major rank assignment
# (rank = row * n_cols + col). For the 2×3 inter-SIP grid, this means
# 2 rows × 3 columns: SIP 0 1 2 / SIP 3 4 5.
_PALETTE_BG = "#fafbfd"
_PALETTE_FRAME = "#3a3f4a"
_PALETTE_BLUE = "#2c6fb6"
_PALETTE_GREEN = "#2e8a4e"
_PALETTE_TEXT = "#1f2530"
_PALETTE_BOX_FILL = "#eaf2fb"
_PALETTE_BOX_EDGE = "#2c4a78"
_PALETTE_ROOT_FILL = "#ffd9b8"
_PALETTE_ROOT_EDGE = "#bd5a14"
def _arrow(ax, xy_from, xy_to, color="black", lw=1.4, alpha=1.0,
style="-|>", curve=0.0):
from matplotlib.patches import FancyArrowPatch
arrow = FancyArrowPatch(
xy_from, xy_to,
arrowstyle=style, mutation_scale=12,
color=color, lw=lw, alpha=alpha,
connectionstyle=f"arc3,rad={curve}",
)
ax.add_patch(arrow)
def _draw_sip_box(ax, cx, cy, w, h, label, *, fill=_PALETTE_BOX_FILL,
edge=_PALETTE_BOX_EDGE, text_color=_PALETTE_TEXT,
font=10):
from matplotlib.patches import FancyBboxPatch
box = FancyBboxPatch(
(cx - w / 2, cy - h / 2), w, h,
boxstyle="round,pad=0.02,rounding_size=0.10",
linewidth=1.4, edgecolor=edge, facecolor=fill,
)
ax.add_patch(box)
ax.text(cx, cy, label, ha="center", va="center",
color=text_color, fontsize=font, fontweight="bold")
def _frame_panel(ax, title, lim_x=10.0, lim_y=6.0):
"""Set up a square-ish panel with a visible outer border."""
from matplotlib.patches import FancyBboxPatch
ax.set_xlim(0, lim_x)
ax.set_ylim(0, lim_y)
ax.set_aspect("equal")
ax.axis("off")
ax.set_facecolor(_PALETTE_BG)
border = FancyBboxPatch(
(0.05, 0.05), lim_x - 0.10, lim_y - 0.10,
boxstyle="round,pad=0.01,rounding_size=0.12",
linewidth=1.4, edgecolor=_PALETTE_FRAME, facecolor=_PALETTE_BG,
zorder=0,
)
ax.add_patch(border)
ax.set_title(title, fontsize=12, fontweight="bold",
color=_PALETTE_TEXT, pad=8)
def _draw_ring_topology(ax):
_frame_panel(ax, "ring_1d (6 SIPs)", lim_x=10.0, lim_y=6.0)
xs = [1.2, 2.7, 4.2, 5.7, 7.2, 8.7]
y = 3.1
box_w, box_h = 1.05, 0.9
for i, x in enumerate(xs):
_draw_sip_box(ax, x, y, box_w, box_h, f"SIP {i}")
# Forward ring (global_E) — adjacent neighbours, anchored to box edges.
for i in range(5):
_arrow(ax, (xs[i] + box_w / 2, y),
(xs[i + 1] - box_w / 2, y),
color=_PALETTE_BLUE, lw=1.6)
# Wrap (SIP 5 → SIP 0). Anchor at right-CENTER of SIP 5 and
# left-CENTER of SIP 0; arc OUTSIDE (above) the row so it does not
# overlap any of the SIP boxes in between.
_arrow(
ax,
(xs[5] + box_w / 2, y),
(xs[0] - box_w / 2, y),
color=_PALETTE_BLUE, lw=1.6, curve=-0.40,
)
ax.text(5.0, y + 2.0, "global_E (ring)", ha="center",
color=_PALETTE_BLUE, fontsize=10, style="italic")
ax.text(5.0, y - 1.5,
"(global_W = reverse direction, used by the algorithm)",
ha="center", color="gray", fontsize=8, style="italic")
def _draw_grid_topology(ax, kind, *, n_rows=2, n_cols=3):
"""kind ∈ {'torus', 'mesh'}. Lays out as n_rows × n_cols (row-major).
For the sweep we use 2 rows × 3 cols → SIP layout::
row 0: SIP 0 SIP 1 SIP 2
row 1: SIP 3 SIP 4 SIP 5
"""
title = f"torus_2d ({n_rows}×{n_cols}, 6 SIPs)" if kind == "torus" \
else f"mesh_2d_no_wrap ({n_rows}×{n_cols}, 6 SIPs)"
_frame_panel(ax, title, lim_x=10.0, lim_y=6.0)
col_xs = [2.0, 5.0, 8.0] # 3 cols
row_ys = [4.3, 1.8] # 2 rows
box_w, box_h = 1.3, 0.95
pos: dict[tuple[int, int], tuple[float, float]] = {}
for r in range(n_rows):
for c in range(n_cols):
rank = r * n_cols + c
x, y = col_xs[c], row_ys[r]
pos[(r, c)] = (x, y)
_draw_sip_box(ax, x, y, box_w, box_h, f"SIP {rank}")
# Row edges (E↔W) — between adjacent columns within each row.
for r in range(n_rows):
for c in range(n_cols - 1):
x0, y0 = pos[(r, c)]
x1, y1 = pos[(r, c + 1)]
_arrow(ax, (x0 + box_w / 2, y0 + 0.10),
(x1 - box_w / 2, y1 + 0.10),
color=_PALETTE_BLUE, lw=1.5)
_arrow(ax, (x1 - box_w / 2, y1 - 0.10),
(x0 + box_w / 2, y0 - 0.10),
color=_PALETTE_BLUE, lw=1.5)
# Col edges (N↔S) — between adjacent rows within each column.
for c in range(n_cols):
for r in range(n_rows - 1):
x0, y0 = pos[(r, c)]
x1, y1 = pos[(r + 1, c)]
_arrow(ax, (x0 - 0.12, y0 - box_h / 2),
(x1 - 0.12, y1 + box_h / 2),
color=_PALETTE_GREEN, lw=1.5)
_arrow(ax, (x1 + 0.12, y1 + box_h / 2),
(x0 + 0.12, y0 - box_h / 2),
color=_PALETTE_GREEN, lw=1.5)
# Wrap arrows for torus only — anchor to the centre of the OUTER
# edge of the end SIPs and arc OUTSIDE the row/column so they do
# not overlap the SIPs in between.
if kind == "torus":
# Row wrap: last col → first col. Top row arcs UP, bottom row
# arcs DOWN, so each wrap sits clearly outside its own row.
for r in range(n_rows):
x0, y0 = pos[(r, 0)]
x1, y1 = pos[(r, n_cols - 1)]
curve = -0.45 if r == 0 else 0.45
_arrow(
ax,
(x1 + box_w / 2, y1),
(x0 - box_w / 2, y0),
color=_PALETTE_BLUE, lw=1.5,
curve=curve, alpha=0.9,
)
# Col wrap: last row → first row. Leftmost col arcs LEFT,
# rightmost col arcs RIGHT. Middle col(s) get a small inline
# marker + legend note (drawing them through the panel would
# collide with the row arrows).
for c in range(n_cols):
x0, y0 = pos[(0, c)]
x1, y1 = pos[(n_rows - 1, c)]
if c == 0:
curve = 0.55
elif c == n_cols - 1:
curve = -0.55
else:
continue # skip middle col — see legend note
_arrow(
ax,
(x1, y1 - box_h / 2),
(x0, y0 + box_h / 2),
color=_PALETTE_GREEN, lw=1.5,
curve=curve, alpha=0.9,
)
ax.text(0.7, 5.6, "global_E/W (row)", color=_PALETTE_BLUE,
fontsize=9, style="italic", fontweight="bold")
ax.text(0.7, 5.25, "global_N/S (col)", color=_PALETTE_GREEN,
fontsize=9, style="italic", fontweight="bold")
ax.text(0.7, 4.92,
"wrap = torus" if kind == "torus" else "no wrap = mesh",
color="gray", fontsize=8, style="italic")
if kind == "torus" and n_cols > 2:
ax.text(0.7, 0.3,
"(middle-col wrap omitted for clarity — every row "
"and every column wraps)",
color="gray", fontsize=7.5, style="italic")
def _draw_cube_reduction(ax):
"""4×4 cube grid inside SIP 0 — compact layout with phase legend."""
from matplotlib.patches import Rectangle
_frame_panel(ax, "Cube-level reduction inside SIP 0 (4×4 cubes)",
lim_x=10.0, lim_y=6.0)
cube_w = 0.65
cube_gap = 0.18
# Center the 4×4 grid in the left half of the panel.
grid_total = 4 * cube_w + 3 * cube_gap
grid_x0 = 0.7
grid_y0 = 0.7
centers: dict[tuple[int, int], tuple[float, float]] = {}
for r in range(4):
for c in range(4):
cx = grid_x0 + c * (cube_w + cube_gap) + cube_w / 2
cy = grid_y0 + (3 - r) * (cube_w + cube_gap) + cube_w / 2
centers[(r, c)] = (cx, cy)
cube_id = r * 4 + c
is_root = (r == 3 and c == 3)
face = _PALETTE_ROOT_FILL if is_root else _PALETTE_BOX_FILL
edge = _PALETTE_ROOT_EDGE if is_root else _PALETTE_BOX_EDGE
rect = Rectangle(
(cx - cube_w / 2, cy - cube_w / 2), cube_w, cube_w,
linewidth=1.2, edgecolor=edge, facecolor=face,
)
ax.add_patch(rect)
label = f"c{cube_id}"
ax.text(cx, cy, label, ha="center", va="center",
fontsize=7.5, fontweight="bold",
color=_PALETTE_ROOT_EDGE if is_root
else _PALETTE_TEXT)
# Phase 1: row reduce W→E.
for r in range(4):
for c in range(3):
x0, y0 = centers[(r, c)]
x1, y1 = centers[(r, c + 1)]
_arrow(ax, (x0 + cube_w / 2, y0), (x1 - cube_w / 2, y1),
color=_PALETTE_BLUE, lw=1.5)
# Phase 2: col reduce N→S along rightmost column.
for r in range(3):
x0, y0 = centers[(r, 3)]
x1, y1 = centers[(r + 1, 3)]
_arrow(ax, (x0, y0 - cube_w / 2), (x1, y1 + cube_w / 2),
color=_PALETTE_GREEN, lw=1.7)
# Phase legend on the right side.
legend_x = grid_x0 + grid_total + 0.55
ax.text(legend_x, 5.0, "Phase 1: row reduce (W → E)",
color=_PALETTE_BLUE, fontsize=10, fontweight="bold")
ax.text(legend_x, 4.55, "Phase 2: col reduce (N → S, rightmost col)",
color=_PALETTE_GREEN, fontsize=10, fontweight="bold")
ax.text(legend_x, 4.10, "Phase 3: inter-SIP exchange at root cube",
color=_PALETTE_ROOT_EDGE, fontsize=10, fontweight="bold")
ax.text(legend_x, 3.65, "Phase 4: col broadcast (S → N)",
color=_PALETTE_GREEN, fontsize=10, style="italic")
ax.text(legend_x, 3.20, "Phase 5: row broadcast (E → W)",
color=_PALETTE_BLUE, fontsize=10, style="italic")
ax.text(legend_x, 2.55,
"(broadcast phases reverse phases 2 & 1)",
color="gray", fontsize=8.5, style="italic")
ax.text(legend_x, 1.7,
"Root cube (c15, bottom-right) is the only\n"
"cube that performs the inter-SIP exchange.",
color=_PALETTE_ROOT_EDGE, fontsize=9, style="italic")
def emit_topology_diagram() -> str:
"""Emit a 2×2-panel topology diagram into allreduce_latency_plots/.
Top row: ring_1d | torus_2d (2×3)
Bot row: mesh_2d_no_wrap (2×3) | cube-level reduction in SIP 0
"""
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
_SWEEP_OUT_DIR.mkdir(parents=True, exist_ok=True)
fig = plt.figure(figsize=(16, 10), facecolor="white")
gs = gridspec.GridSpec(2, 2, figure=fig, hspace=0.30, wspace=0.10)
ax_ring = fig.add_subplot(gs[0, 0])
ax_torus = fig.add_subplot(gs[0, 1])
ax_mesh = fig.add_subplot(gs[1, 0])
ax_cube = fig.add_subplot(gs[1, 1])
_draw_ring_topology(ax_ring)
_draw_grid_topology(ax_torus, "torus", n_rows=2, n_cols=3)
_draw_grid_topology(ax_mesh, "mesh", n_rows=2, n_cols=3)
_draw_cube_reduction(ax_cube)
fig.suptitle(
"Allreduce topology — device-level (top: ring, torus, mesh) "
"and cube-level reduction in SIP 0",
fontsize=14, fontweight="bold", color=_PALETTE_TEXT, y=0.98,
)
out_path = _SWEEP_OUT_DIR / "topology.png"
fig.savefig(out_path, dpi=130, bbox_inches="tight",
facecolor=fig.get_facecolor())
plt.close(fig)
return str(out_path)
def test_emit_topology_diagram():
"""Emit topology.png alongside the sweep plots. Pure plotting; no sim."""
out = emit_topology_diagram()
assert Path(out).exists()