Files
kernbench2/tests/test_allreduce_multidevice.py
T
mukesh 04c912f53e Allreduce sweep: parametrized + xdist parallelism + topology diagram
Refactor the latency sweep from one giant test into 36 parametrized
cases that run in parallel under xdist (~6-8x faster: 1:49 instead of
~10 min). Each case writes a JSON row to a staging dir; conftest
sessionfinish hook aggregates rows on the controller node into
summary.csv and the per-topology + overview plots.

Aggregator gains a CSV fallback so plot-only tweaks no longer require
re-running the sweep.

Overview plot updates:
- 96 KB explicit x-axis marker with vertical dotted line
- horizontal theoretical 2D-torus reference (10600 ns)
- annotation showing both theoretical and simulated values at 96 KB
- drop overlapping 128 KB tick

New topology.png: 2x2 panel diagram showing device-level topology
(ring, torus 2x3, mesh 2x3) and the cube-level reduction inside SIP 0.
Wrap arrows anchor on box edges and arc outside rows/columns so they
do not overlap any SIP.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-27 16:43:19 -07:00

851 lines
30 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Config-driven multi-device allreduce test application.
Reads ``ccl.yaml`` + ``topology.yaml``, dynamically loads the kernel
module from ``ccl.yaml → module``, and picks the inter-SIP exchange
pattern from ``topology.yaml → system.sips.topology``.
Run directly::
    python -m pytest tests/test_allreduce_multidevice.py -v -s
"""
from __future__ import annotations
import importlib
import math
from pathlib import Path
from typing import Any
import numpy as np
from kernbench.ccl.install import load_ccl_config, resolve_algorithm_config
from kernbench.ccl.sfr_config import configure_sfr_intercube_multisip
from kernbench.policy.placement.dp import DPPolicy
def _sip_topo_dims(
    sip_topo: str, n_sips: int,
    spec_w: int | None = None, spec_h: int | None = None,
) -> tuple[int, int]:
    """Resolve the ``(w, h)`` SIP-grid dimensions for a topology name.

    ``ring_1d`` always yields ``(0, 0)``. For any other topology an explicit
    ``spec_w`` x ``spec_h`` layout is used when both are given; otherwise
    ``n_sips`` must be a perfect square and its side is used for both
    dimensions.

    Raises:
        ValueError: if the explicit layout does not multiply out to
            ``n_sips``, or no layout is given and ``n_sips`` is not square.
    """
    if sip_topo == "ring_1d":
        return (0, 0)

    # Only a complete (w, h) pair counts as an explicit layout.
    has_explicit_layout = not (spec_w is None or spec_h is None)
    if has_explicit_layout:
        if spec_w * spec_h == n_sips:
            return (spec_w, spec_h)
        raise ValueError(
            f"sip layout {spec_w}x{spec_h} != n_sips ({n_sips})"
        )

    edge = int(round(math.sqrt(n_sips)))
    if edge * edge == n_sips:
        return (edge, edge)
    raise ValueError(
        f"SIP topology '{sip_topo}' requires square n_sips or "
        f"explicit w/h in spec, got {n_sips}"
    )
def run_allreduce(
    ctx: Any,
    engine: Any,
    spec: dict,
    *,
    algorithm: str | None = None,
    ccl_yaml: str | None = None,
) -> dict:
    """Run one config-driven allreduce over every SIP and verify the result.

    The kernel module, element count and algorithm name are taken from the
    CCL config; SIP count/topology and the cube mesh are taken from ``spec``.
    No kernel is imported statically.

    Args:
        ctx: Runtime context used to allocate tensors, launch and wait.
        engine: Simulation engine; ``_env.now`` and ``_results`` are read
            for the timing report.
        spec: Topology spec (``system.sips`` and ``sip.cube_mesh`` are read).
        algorithm: Forwarded to ``resolve_algorithm_config``.
        ccl_yaml: Forwarded to ``load_ccl_config``.

    Returns:
        Dict with ``expected`` (value every element must hold),
        ``latency_ns`` (engine clock delta between first launch and last
        wait) and ``ok_cubes`` (number of cube rows that passed the check).

    Raises:
        AssertionError: if any cube row is not close to ``expected``.
    """
    cfg = resolve_algorithm_config(load_ccl_config(ccl_yaml), algorithm)

    # The kernel lives in whichever module ccl.yaml names.
    algo_module = importlib.import_module(cfg["module"])
    kernel_fn = algo_module.kernel
    kind_by_topo_name = algo_module.TOPO_NAME_TO_KIND
    n_elem = int(cfg.get("n_elem", 8))

    sips_cfg = spec.get("system", {}).get("sips", {})
    n_sips = int(sips_cfg.get("count", 1))
    sip_topo = str(sips_cfg.get("topology", "ring_1d"))
    raw_w = sips_cfg.get("w")
    raw_h = sips_cfg.get("h")
    spec_sip_w = None if raw_w is None else int(raw_w)
    spec_sip_h = None if raw_h is None else int(raw_h)

    cube_mesh = spec["sip"]["cube_mesh"]
    cube_w = int(cube_mesh["w"])
    cube_h = int(cube_mesh["h"])
    n_cubes = cube_w * cube_h

    # Topology names missing from the module's table map to kind 0.
    sip_topo_kind = kind_by_topo_name.get(sip_topo, 0)
    sip_topo_w, sip_topo_h = _sip_topo_dims(
        sip_topo, n_sips, spec_w=spec_sip_w, spec_h=spec_sip_h,
    )
    algo_name = cfg.get("algorithm", "allreduce")

    print(f"\n{'=' * 60}")
    print(f"algorithm: {algo_name}")
    print(f"module: {cfg['module']}")
    print(f"sip_topology: {sip_topo}")
    print(f"kernel: {kernel_fn.__name__}")
    print(f"n_sips: {n_sips}")
    print(f"n_cubes: {n_cubes}")
    print(f"n_elem: {n_elem}")
    print(f"{'=' * 60}")

    configure_sfr_intercube_multisip(engine, spec, cfg)
    dp = DPPolicy(
        cube="row_wise", pe="replicate",
        num_pes=1, num_cubes=n_cubes,
    )

    # One (n_cubes, n_elem) tensor per SIP, filled with the value sip + 1.
    shape = (n_cubes, n_elem)
    tensors = []
    for sip in range(n_sips):
        ctx.ahbm.set_device(sip)
        tensor = ctx.zeros(shape, dtype="f16", dp=dp, name=f"sip{sip}")
        host_values = np.full(shape, float(sip + 1), dtype=np.float16)
        tensor.copy_(ctx.from_numpy(host_values))
        tensors.append(tensor)

    for sip, tensor in enumerate(tensors):
        arr = tensor.numpy()
        print(f"[SIP {sip}] input cube0[:4] = {arr[0][:4].tolist()} "
              f"cube{n_cubes - 1}[:4] = {arr[-1][:4].tolist()}")

    # Launch on every SIP first (waits deferred), then wait on all handles.
    t_start = engine._env.now
    all_pending = []
    for sip_rank, tensor in enumerate(tensors):
        all_pending.extend(ctx.launch(
            algo_name, kernel_fn, tensor,
            n_elem, cube_w, cube_h, n_sips, sip_rank,
            sip_topo_kind, sip_topo_w, sip_topo_h,
            _defer_wait=True,
        ))
    for handle, _sip_id, meta in all_pending:
        ctx.wait(handle, _meta=meta)
    latency_ns = engine._env.now - t_start

    print(f"\n[{algo_name} ws={n_sips}] sim latency = "
          f"{latency_ns:.1f} ns ({latency_ns / 1000:.3f} us)")
    for key, (_, trace) in engine._results.items():
        if isinstance(trace, dict):
            total = trace.get("total_ns", 0.0)
            pe_exec = trace.get("pe_exec_ns", 0.0) or 0.0
            network = total - pe_exec
            print(f" [{key}] total={total:.1f} ns "
                  f"pe_exec={pe_exec:.1f} ns network={network:.1f} ns")

    # Every SIP contributes n_cubes rows of (sip + 1).
    expected = float(n_cubes * sum(range(1, n_sips + 1)))

    print()
    for sip, tensor in enumerate(tensors):
        arr = tensor.numpy()
        print(f"[SIP {sip}] output cube0[:4] = {arr[0][:4].tolist()}")
        print(f"[SIP {sip}] output cube{n_cubes - 1}[:4] = {arr[-1][:4].tolist()}")

    ok_cubes = 0
    for sip, tensor in enumerate(tensors):
        arr = tensor.numpy()
        for cube_id in range(n_cubes):
            row = arr[cube_id]
            assert np.allclose(row, expected, rtol=1e-1, atol=1e-1), (
                f"SIP{sip} cube {cube_id}: "
                f"got {arr[cube_id][:4]}, expected {expected}"
            )
            ok_cubes += 1

    print(f"\n {algo_name} (ws={n_sips}): {ok_cubes} OK")
    return {
        "expected": expected,
        "latency_ns": latency_ns,
        "ok_cubes": ok_cubes,
    }
# ── pytest entry point ───────────────────────────────────────────────
import pytest
import yaml
from kernbench.runtime_api.context import RuntimeContext
from kernbench.runtime_api.types import DeviceSelector
from kernbench.sim_engine.engine import GraphEngine
from kernbench.topology.builder import resolve_topology
# Base topology spec, located one directory above this file's directory.
TOPOLOGY_PATH = Path(__file__).parent.parent.joinpath("topology.yaml")

# One functional case per inter-SIP topology. Each row is
# (pytest id, (algorithm, sip_topology, n_sips, sip_w, sip_h)).
CONFIGS = [
    pytest.param(*case_args, id=case_id)
    for case_id, case_args in (
        ("ring_6sip", ("intercube_allreduce", "ring_1d", 6, None, None)),
        ("torus_6sip_2x3", ("intercube_allreduce", "torus_2d", 6, 2, 3)),
        ("mesh_6sip_2x3", ("intercube_allreduce", "mesh_2d_no_wrap", 6, 2, 3)),
    )
]
def _write_temp_configs(
    tmp_path, sip_topology, n_sips, algorithm, n_elem_override=None,
    sip_w=None, sip_h=None,
):
    """Write per-test copies of topology.yaml and ccl.yaml with overrides.

    Args:
        tmp_path: Directory (a ``Path``) that receives both files.
        sip_topology: Value stored at ``system.sips.topology``.
        n_sips: Value stored at ``system.sips.count``.
        algorithm: Value stored at ``defaults.algorithm`` in ccl.yaml.
        n_elem_override: When given, stored as the algorithm's ``n_elem``;
            ``defaults.slot_size`` is raised if one f16 message of that
            many elements would not fit.
        sip_w, sip_h: Explicit SIP grid layout. Unless both are given, any
            ``w``/``h`` already present in the base topology is removed.

    Returns:
        ``(topology_path, ccl_path)`` as strings.
    """
    # All YAML I/O uses an explicit UTF-8 encoding so the result does not
    # depend on the platform's locale default.
    with open(TOPOLOGY_PATH, encoding="utf-8") as f:
        topo_cfg = yaml.safe_load(f)
    topo_cfg["system"]["sips"]["count"] = n_sips
    topo_cfg["system"]["sips"]["topology"] = sip_topology
    if sip_w is not None and sip_h is not None:
        topo_cfg["system"]["sips"]["w"] = int(sip_w)
        topo_cfg["system"]["sips"]["h"] = int(sip_h)
    else:
        topo_cfg["system"]["sips"].pop("w", None)
        topo_cfg["system"]["sips"].pop("h", None)
    topo_path = tmp_path / "topology.yaml"
    with open(topo_path, "w", encoding="utf-8") as f:
        yaml.dump(topo_cfg, f, default_flow_style=False)

    ccl_path = Path(__file__).parent.parent / "ccl.yaml"
    with open(ccl_path, encoding="utf-8") as f:
        ccl_cfg = yaml.safe_load(f)
    ccl_cfg["defaults"]["algorithm"] = algorithm
    if n_elem_override is not None:
        ccl_cfg.setdefault("algorithms", {}).setdefault(
            algorithm, {},
        )["n_elem"] = int(n_elem_override)
        # Ensure IPCQ slot is big enough for the per-message payload.
        per_msg_bytes = int(n_elem_override) * 2  # f16
        default_slot = int(ccl_cfg["defaults"].get("slot_size", 4096))
        if per_msg_bytes > default_slot:
            ccl_cfg["defaults"]["slot_size"] = per_msg_bytes
    tmp_ccl = tmp_path / "ccl.yaml"
    with open(tmp_ccl, "w", encoding="utf-8") as f:
        yaml.dump(ccl_cfg, f, default_flow_style=False)
    return str(topo_path), str(tmp_ccl)
@pytest.mark.parametrize(
    "algorithm,sip_topology,n_sips,sip_w,sip_h", CONFIGS,
)
def test_allreduce(
    tmp_path, algorithm, sip_topology, n_sips, sip_w, sip_h,
):
    """Run one allreduce per topology case and require verified cubes.

    Value checking happens inside ``run_allreduce``; this test only builds
    the temporary configs, engine and runtime context around it.
    """
    topo_path, ccl_path = _write_temp_configs(
        tmp_path, sip_topology, n_sips, algorithm,
        sip_w=sip_w, sip_h=sip_h,
    )
    topo = resolve_topology(topo_path)
    engine = GraphEngine(topo.topology_obj, enable_data=True)
    spec = topo.topology_obj.spec

    runtime = RuntimeContext(
        engine=engine,
        target_device=DeviceSelector("all"),
        correlation_id=f"test_{algorithm}_{sip_topology}",
        spec=spec,
    )
    with runtime as ctx:
        result = run_allreduce(
            ctx, engine, spec,
            algorithm=algorithm, ccl_yaml=ccl_path,
        )
        assert result["ok_cubes"] > 0
# ── Latency sweep (parametrized + xdist-friendly) ─────────────────────
# Size of one f16 element in bytes; bytes_per_pe = n_elem * _ELEM_BYTES_F16.
_ELEM_BYTES_F16 = 2

# Element counts swept per topology. 16 is skipped on purpose (== n_cubes,
# dim_map collision). The largest entry is 96 KB per PE:
# 49152 elem * 2 bytes = 96 KB / PE.
_SWEEP_N_ELEM = [
    8, 32, 64, 128, 512,
    1024, 2048, 4096, 8192, 16384, 32768,
    49152,
]

# (algorithm, sip_topology, n_sips, sip_w, sip_h) for each swept topology.
_SWEEP_TOPOLOGIES = [
    ("intercube_allreduce", "ring_1d", 6, None, None),
    ("intercube_allreduce", "torus_2d", 6, 2, 3),
    ("intercube_allreduce", "mesh_2d_no_wrap", 6, 2, 3),
]

# Shared on-disk staging area for the sweep. Every parametrized case drops
# one JSON row into _SWEEP_ROWS_DIR; the aggregator (run from
# conftest.pytest_sessionfinish) reads the rows and emits the combined CSV
# and PNG plots into _SWEEP_OUT_DIR.
_SWEEP_OUT_DIR = Path(__file__).parent.joinpath("allreduce_latency_plots")
_SWEEP_ROWS_DIR = _SWEEP_OUT_DIR.joinpath("_rows")
def _sweep_params():
    """Return the topology x n_elem cross product as ``pytest.param`` cases.

    Topologies vary slowest; each case id is ``<sip_topology>-n_elem<N>``.
    """
    return [
        pytest.param(
            algorithm, sip_topology, n_sips, sip_w, sip_h, n_elem,
            id=f"{sip_topology}-n_elem{n_elem}",
        )
        for algorithm, sip_topology, n_sips, sip_w, sip_h in _SWEEP_TOPOLOGIES
        for n_elem in _SWEEP_N_ELEM
    ]
@pytest.mark.parametrize(
    "algorithm,sip_topology,n_sips,sip_w,sip_h,n_elem", _sweep_params(),
)
def test_allreduce_latency_one(
    tmp_path, algorithm, sip_topology, n_sips, sip_w, sip_h, n_elem,
):
    """Run a single sweep point and stage its result row.

    Each parametrized case is independent, so xdist can spread them over
    workers. The case writes exactly one JSON file into ``_SWEEP_ROWS_DIR``;
    the conftest sessionfinish hook later folds all rows into the CSV and
    plots.
    """
    import json

    topo_path, ccl_path = _write_temp_configs(
        tmp_path, sip_topology, n_sips, algorithm,
        sip_w=sip_w, sip_h=sip_h,
        n_elem_override=n_elem,
    )
    topo = resolve_topology(topo_path)
    engine = GraphEngine(topo.topology_obj, enable_data=True)
    spec = topo.topology_obj.spec

    runtime = RuntimeContext(
        engine=engine,
        target_device=DeviceSelector("all"),
        correlation_id=f"sweep_{algorithm}_{sip_topology}_{n_elem}",
        spec=spec,
    )
    with runtime as ctx:
        result = run_allreduce(
            ctx, engine, spec,
            algorithm=algorithm, ccl_yaml=ccl_path,
        )
        assert result["ok_cubes"] > 0

        # The recorded latency is the largest pe_exec_ns over all result
        # traces (not the wall-clock delta returned by run_allreduce).
        pe_exec_vals = []
        for _, (_, trace) in engine._results.items():
            if isinstance(trace, dict):
                pe_exec_vals.append(float(trace.get("pe_exec_ns", 0.0) or 0.0))
        crit_ns = max(pe_exec_vals) if pe_exec_vals else 0.0

    cube_mesh = spec["sip"]["cube_mesh"]
    n_cubes = int(cube_mesh["w"]) * int(cube_mesh["h"])
    record = {
        "algorithm": algorithm,
        "sip_topology": sip_topology,
        "n_sips": n_sips,
        "n_elem": n_elem,
        "bytes_per_pe": n_elem * _ELEM_BYTES_F16,
        "bytes_per_sip": n_cubes * n_elem * _ELEM_BYTES_F16,
        "latency_ns": crit_ns,
    }

    _SWEEP_ROWS_DIR.mkdir(parents=True, exist_ok=True)
    row_path = _SWEEP_ROWS_DIR / f"{sip_topology}_{n_elem}.json"
    with open(row_path, "w", encoding="utf-8") as f:
        json.dump(record, f)
def _aggregate_sweep_plots() -> bool:
    """Combine the staged sweep rows into summary.csv and the latency plots.

    Called by ``conftest.pytest_sessionfinish`` (controller node only).
    Rows come from the JSON files in ``_SWEEP_ROWS_DIR``; when none are
    staged, an existing ``summary.csv`` is re-read instead so the plots can
    be regenerated without re-running the sweep.

    Returns:
        True if any rows were aggregated, False otherwise.
    """
    row_files = (
        sorted(_SWEEP_ROWS_DIR.glob("*.json"))
        if _SWEEP_ROWS_DIR.exists() else []
    )
    records = _load_sweep_records(row_files)
    if not records:
        return False

    # matplotlib is imported only once there is something to plot.
    import matplotlib.pyplot as plt
    from matplotlib.ticker import FuncFormatter

    bytes_fmt = FuncFormatter(_fmt_bytes_tick)
    _SWEEP_OUT_DIR.mkdir(parents=True, exist_ok=True)
    _write_sweep_summary(records)

    curves = _sweep_curves_by_topology(records)
    for topo_name, rows in curves.items():
        _plot_topology_latency(plt, bytes_fmt, topo_name, rows)
    _plot_sweep_overview(plt, bytes_fmt, records, curves)

    _remove_staged_rows(row_files)
    print(f"\nWrote {_SWEEP_OUT_DIR / 'overview.png'} "
          f"from {len(records)} rows")
    return True


def _load_sweep_records(row_files) -> list[dict]:
    """Load records from staged JSON rows, else from an existing summary.csv."""
    import csv
    import json

    if row_files:
        records = []
        for path in row_files:
            with open(path, encoding="utf-8") as f:
                records.append(json.load(f))
        return records

    # Fallback: replot from existing summary.csv (skip sweep re-run).
    summary_path = _SWEEP_OUT_DIR / "summary.csv"
    if not summary_path.exists():
        return []
    # newline="" is what the csv module asks for on files it reads.
    with open(summary_path, newline="", encoding="utf-8") as f:
        return [
            {
                "algorithm": row["algorithm"],
                "sip_topology": row["sip_topology"],
                "n_sips": int(row["n_sips"]),
                "n_elem": int(row["n_elem"]),
                "bytes_per_pe": int(row["bytes_per_pe"]),
                "bytes_per_sip": int(row["bytes_per_sip"]),
                "latency_ns": float(row["latency_ns"]),
            }
            for row in csv.DictReader(f)
        ]


def _fmt_bytes_tick(x, _pos):
    """Format a byte count as a B / KB / MB axis tick label."""
    if x <= 0:
        return "0"
    if x >= 1024 * 1024:
        return f"{x / (1024 * 1024):.0f} MB"
    if x >= 1024:
        return f"{x / 1024:.0f} KB"
    return f"{x:.0f} B"


def _write_sweep_summary(records) -> None:
    """Write summary.csv, sorted by (sip_topology, bytes_per_pe)."""
    import csv

    with open(_SWEEP_OUT_DIR / "summary.csv", "w",
              newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=[
            "algorithm", "sip_topology", "n_sips", "n_elem",
            "bytes_per_pe", "bytes_per_sip", "latency_ns",
        ])
        writer.writeheader()
        for rec in sorted(records, key=lambda r: (
            r["sip_topology"], r["bytes_per_pe"],
        )):
            writer.writerow(rec)


def _sweep_curves_by_topology(records) -> dict:
    """Group records by topology name (sorted), each sorted by bytes_per_pe."""
    names = sorted({r["sip_topology"] for r in records})
    return {
        name: sorted(
            (r for r in records if r["sip_topology"] == name),
            key=lambda r: r["bytes_per_pe"],
        )
        for name in names
    }


def _plot_topology_latency(plt, bytes_fmt, topo_name, rows) -> None:
    """Save ``<topo_name>.png``: latency versus bytes-per-PE for one topology."""
    title = (
        f"Allreduce latency — {topo_name} "
        f"(n_sips={rows[0]['n_sips']})"
    )
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(
        [r["bytes_per_pe"] for r in rows],
        [r["latency_ns"] for r in rows],
        marker="o", color="tab:blue",
    )
    ax.set_xscale("log", base=2)
    ax.set_xlabel("Bytes per PE (log scale)")
    ax.set_ylabel("max pe_exec_ns (critical path)")
    ax.set_title(title)
    ax.grid(True, alpha=0.3)
    ax.xaxis.set_major_formatter(bytes_fmt)
    fig.tight_layout()
    fig.savefig(_SWEEP_OUT_DIR / f"{topo_name}.png", dpi=120)
    plt.close(fig)


def _plot_sweep_overview(plt, bytes_fmt, records, curves) -> None:
    """Save overview.png: all topologies plus the 96 KB reference markers."""
    colors = {"ring_1d": "tab:blue", "torus_2d": "tab:orange",
              "mesh_2d_no_wrap": "tab:green"}
    theoretical_ns = 10600.0  # reference value for torus_2d with 6 SIPs
    bytes_96kb = 96 * 1024
    bytes_128kb = 128 * 1024

    fig, ax = plt.subplots(figsize=(9, 6))
    for topo_name, rows in curves.items():
        ax.plot(
            [r["bytes_per_pe"] for r in rows],
            [r["latency_ns"] for r in rows],
            marker="o",
            label=f"{topo_name} (n_sips={rows[0]['n_sips']})",
            color=colors.get(topo_name),
        )

    # Horizontal reference line, vertical 96 KB line, and their crossing.
    ax.axhline(
        y=theoretical_ns,
        color="tab:red", linestyle="--", linewidth=1.5,
        label=f"theoretical torus_2d (6 SIPs) = "
              f"{theoretical_ns:.0f} ns",
    )
    ax.axvline(
        x=bytes_96kb, ymin=0, ymax=1,
        color="tab:red", linestyle=":", linewidth=1.2,
    )
    ax.plot(
        [bytes_96kb], [theoretical_ns],
        marker="x", color="tab:red", markersize=10, markeredgewidth=2,
    )

    # Simulated torus_2d latency at 96 KB (if present) for direct
    # comparison with the theoretical value.
    sim_torus_at_96kb = next(
        (r["latency_ns"] for r in records
         if r["sip_topology"] == "torus_2d"
         and r["bytes_per_pe"] == bytes_96kb),
        None,
    )
    if sim_torus_at_96kb is not None:
        ax.plot(
            [bytes_96kb], [sim_torus_at_96kb],
            marker="o", color="tab:orange",
            markersize=10, markeredgecolor="black", markeredgewidth=1.2,
        )
        ax.annotate(
            f"96 KB\n"
            f"theoretical = {theoretical_ns:.0f} ns\n"
            f"simulated = {sim_torus_at_96kb:.0f} ns",
            xy=(bytes_96kb, sim_torus_at_96kb),
            xytext=(10, -20), textcoords="offset points",
            color="tab:red", fontsize=9,
        )
    else:
        ax.annotate(
            f"96 KB\n→ theoretical {theoretical_ns:.0f} ns",
            xy=(bytes_96kb, theoretical_ns),
            xytext=(8, -20), textcoords="offset points",
            color="tab:red", fontsize=9,
        )

    ax.set_xscale("log", base=2)
    ax.set_xlabel("Bytes per PE (log scale)")
    ax.set_ylabel("max pe_exec_ns (critical path)")
    ax.set_title("Multi-device allreduce latency by topology")
    ax.grid(True, alpha=0.3)

    # Drop 128 KB tick (overlaps visually with the explicit 96 KB marker)
    # and add 96 KB.
    ticks = [t for t in ax.get_xticks() if int(t) != bytes_128kb]
    if bytes_96kb not in ticks:
        ticks.append(bytes_96kb)
    ax.set_xticks(sorted(ticks))
    ax.set_xlim(left=min(r["bytes_per_pe"] for r in records) / 2,
                right=bytes_96kb * 1.5)
    ax.legend()
    ax.xaxis.set_major_formatter(bytes_fmt)
    fig.tight_layout()
    fig.savefig(_SWEEP_OUT_DIR / "overview.png", dpi=120)
    plt.close(fig)


def _remove_staged_rows(row_files) -> None:
    """Best-effort removal of consumed row files and the staging directory.

    Keeps a later partial run from picking up stale rows. Failures are
    ignored on purpose: cleanup must not fail the aggregation.
    """
    for path in row_files:
        try:
            path.unlink()
        except OSError:
            pass
    try:
        _SWEEP_ROWS_DIR.rmdir()
    except OSError:
        pass
# ── Topology diagram (device-level + cube-level reduction) ────────────
# Layout convention used throughout: "rows × cols", with ranks assigned
# row-major (rank = row * n_cols + col). The 2×3 inter-SIP grid is thus
# 2 rows × 3 columns: SIP 0 1 2 / SIP 3 4 5.

# Panel background, outer frame and text.
_PALETTE_BG = "#fafbfd"
_PALETTE_FRAME = "#3a3f4a"
_PALETTE_TEXT = "#1f2530"

# Arrow / legend accents.
_PALETTE_BLUE = "#2c6fb6"
_PALETTE_GREEN = "#2e8a4e"

# Regular boxes versus the highlighted root cube.
_PALETTE_BOX_FILL = "#eaf2fb"
_PALETTE_BOX_EDGE = "#2c4a78"
_PALETTE_ROOT_FILL = "#ffd9b8"
_PALETTE_ROOT_EDGE = "#bd5a14"
def _arrow(ax, xy_from, xy_to, color="black", lw=1.4, alpha=1.0,
           style="-|>", curve=0.0):
    """Add an arrow patch from ``xy_from`` to ``xy_to`` on ``ax``.

    ``style`` is used as the patch's arrowstyle and ``curve`` as the ``rad``
    of an ``arc3`` connection style.
    """
    from matplotlib.patches import FancyArrowPatch

    patch_kwargs = dict(
        arrowstyle=style, mutation_scale=12,
        color=color, lw=lw, alpha=alpha,
        connectionstyle=f"arc3,rad={curve}",
    )
    ax.add_patch(FancyArrowPatch(xy_from, xy_to, **patch_kwargs))
def _draw_sip_box(ax, cx, cy, w, h, label, *, fill=_PALETTE_BOX_FILL,
                  edge=_PALETTE_BOX_EDGE, text_color=_PALETTE_TEXT,
                  font=10):
    """Draw a rounded ``w`` x ``h`` box centred on ``(cx, cy)`` with a bold label."""
    from matplotlib.patches import FancyBboxPatch

    # FancyBboxPatch is positioned by its lower-left corner.
    lower_left = (cx - w / 2, cy - h / 2)
    ax.add_patch(FancyBboxPatch(
        lower_left, w, h,
        boxstyle="round,pad=0.02,rounding_size=0.10",
        linewidth=1.4, edgecolor=edge, facecolor=fill,
    ))
    ax.text(
        cx, cy, label,
        ha="center", va="center",
        color=text_color, fontsize=font, fontweight="bold",
    )
def _frame_panel(ax, title, lim_x=10.0, lim_y=6.0):
    """Prepare ``ax`` as a titled diagram panel with a visible outer border.

    Sets the data limits to ``[0, lim_x] x [0, lim_y]`` with equal aspect,
    hides the axes, and draws a rounded frame just inside the limits.
    """
    from matplotlib.patches import FancyBboxPatch

    ax.set_xlim(0, lim_x)
    ax.set_ylim(0, lim_y)
    ax.set_aspect("equal")
    ax.axis("off")
    ax.set_facecolor(_PALETTE_BG)

    # zorder=0 keeps the frame underneath everything drawn afterwards.
    ax.add_patch(FancyBboxPatch(
        (0.05, 0.05), lim_x - 0.10, lim_y - 0.10,
        boxstyle="round,pad=0.01,rounding_size=0.12",
        linewidth=1.4, edgecolor=_PALETTE_FRAME, facecolor=_PALETTE_BG,
        zorder=0,
    ))
    ax.set_title(
        title,
        fontsize=12, fontweight="bold", color=_PALETTE_TEXT, pad=8,
    )
def _draw_ring_topology(ax):
    """Draw the six-SIP ``ring_1d`` panel: a row of boxes, links and a wrap."""
    _frame_panel(ax, "ring_1d (6 SIPs)", lim_x=10.0, lim_y=6.0)

    centers_x = [1.2, 2.7, 4.2, 5.7, 7.2, 8.7]
    row_y = 3.1
    box_w, box_h = 1.05, 0.9

    for rank, cx in enumerate(centers_x):
        _draw_sip_box(ax, cx, row_y, box_w, box_h, f"SIP {rank}")

    # Forward ring (global_E): one arrow per adjacent pair, running from the
    # right edge of one box to the left edge of the next.
    for left_cx, right_cx in zip(centers_x, centers_x[1:]):
        _arrow(
            ax,
            (left_cx + box_w / 2, row_y),
            (right_cx - box_w / 2, row_y),
            color=_PALETTE_BLUE, lw=1.6,
        )

    # Wrap link (SIP 5 → SIP 0): right-centre of the last box to left-centre
    # of the first, drawn curved so it leaves the row instead of running
    # straight through the boxes in between.
    _arrow(
        ax,
        (centers_x[-1] + box_w / 2, row_y),
        (centers_x[0] - box_w / 2, row_y),
        color=_PALETTE_BLUE, lw=1.6, curve=-0.40,
    )

    ax.text(5.0, row_y + 2.0, "global_E (ring)", ha="center",
            color=_PALETTE_BLUE, fontsize=10, style="italic")
    ax.text(5.0, row_y - 1.5,
            "(global_W = reverse direction, used by the algorithm)",
            ha="center", color="gray", fontsize=8, style="italic")
def _draw_grid_topology(ax, kind, *, n_rows=2, n_cols=3):
    """Draw a 2D inter-SIP grid panel (torus or mesh).

    ``kind`` ∈ {'torus', 'mesh'}; any value other than ``'torus'`` is drawn
    as a mesh (no wrap links). SIPs are laid out ``n_rows`` × ``n_cols`` in
    row-major order. For the sweep we use 2 rows × 3 cols → SIP layout::

        row 0: SIP 0 SIP 1 SIP 2
        row 1: SIP 3 SIP 4 SIP 5

    Raises:
        ValueError: if the grid is larger than the fixed panel layout
            (at most 2 rows × 3 columns).
    """
    col_xs = [2.0, 5.0, 8.0]  # box centre x per column
    row_ys = [4.3, 1.8]       # box centre y per row (row 0 first)
    # Box coordinates are a fixed table; reject bigger grids up front rather
    # than failing with an IndexError halfway through drawing.
    if n_rows > len(row_ys) or n_cols > len(col_xs):
        raise ValueError(
            f"grid {n_rows}x{n_cols} exceeds the fixed panel layout "
            f"(max {len(row_ys)}x{len(col_xs)})"
        )

    # The SIP count in the title follows the requested grid size.
    topo_label = "torus_2d" if kind == "torus" else "mesh_2d_no_wrap"
    title = f"{topo_label} ({n_rows}×{n_cols}, {n_rows * n_cols} SIPs)"
    _frame_panel(ax, title, lim_x=10.0, lim_y=6.0)

    box_w, box_h = 1.3, 0.95
    pos: dict[tuple[int, int], tuple[float, float]] = {}
    for r in range(n_rows):
        for c in range(n_cols):
            rank = r * n_cols + c
            x, y = col_xs[c], row_ys[r]
            pos[(r, c)] = (x, y)
            _draw_sip_box(ax, x, y, box_w, box_h, f"SIP {rank}")

    # Row edges (E↔W) — an arrow pair between adjacent columns in each row,
    # offset ±0.10 vertically so the two directions do not coincide.
    for r in range(n_rows):
        for c in range(n_cols - 1):
            x0, y0 = pos[(r, c)]
            x1, y1 = pos[(r, c + 1)]
            _arrow(ax, (x0 + box_w / 2, y0 + 0.10),
                   (x1 - box_w / 2, y1 + 0.10),
                   color=_PALETTE_BLUE, lw=1.5)
            _arrow(ax, (x1 - box_w / 2, y1 - 0.10),
                   (x0 + box_w / 2, y0 - 0.10),
                   color=_PALETTE_BLUE, lw=1.5)

    # Col edges (N↔S) — an arrow pair between adjacent rows in each column,
    # offset ±0.12 horizontally.
    for c in range(n_cols):
        for r in range(n_rows - 1):
            x0, y0 = pos[(r, c)]
            x1, y1 = pos[(r + 1, c)]
            _arrow(ax, (x0 - 0.12, y0 - box_h / 2),
                   (x1 - 0.12, y1 + box_h / 2),
                   color=_PALETTE_GREEN, lw=1.5)
            _arrow(ax, (x1 + 0.12, y1 + box_h / 2),
                   (x0 + 0.12, y0 - box_h / 2),
                   color=_PALETTE_GREEN, lw=1.5)

    # Wrap arrows for torus only — anchored at the centre of the outer edge
    # of the end SIPs and drawn curved.
    if kind == "torus":
        # Row wrap: last col → first col. Row 0 and the other row(s) use
        # opposite curve signs so their arcs bulge to opposite sides
        # (intended: each wrap sits outside its own row).
        for r in range(n_rows):
            x0, y0 = pos[(r, 0)]
            x1, y1 = pos[(r, n_cols - 1)]
            curve = -0.45 if r == 0 else 0.45
            _arrow(
                ax,
                (x1 + box_w / 2, y1),
                (x0 - box_w / 2, y0),
                color=_PALETTE_BLUE, lw=1.5,
                curve=curve, alpha=0.9,
            )
        # Col wrap: last row → first row. Only the first and last columns
        # get one (with opposite curve signs); middle columns are skipped
        # and covered by the note below, since drawing them through the
        # panel would collide with the row arrows.
        for c in range(n_cols):
            x0, y0 = pos[(0, c)]
            x1, y1 = pos[(n_rows - 1, c)]
            if c == 0:
                curve = 0.55
            elif c == n_cols - 1:
                curve = -0.55
            else:
                continue  # skip middle col — see legend note
            _arrow(
                ax,
                (x1, y1 - box_h / 2),
                (x0, y0 + box_h / 2),
                color=_PALETTE_GREEN, lw=1.5,
                curve=curve, alpha=0.9,
            )

    ax.text(0.7, 5.6, "global_E/W (row)", color=_PALETTE_BLUE,
            fontsize=9, style="italic", fontweight="bold")
    ax.text(0.7, 5.25, "global_N/S (col)", color=_PALETTE_GREEN,
            fontsize=9, style="italic", fontweight="bold")
    ax.text(0.7, 4.92,
            "wrap = torus" if kind == "torus" else "no wrap = mesh",
            color="gray", fontsize=8, style="italic")
    if kind == "torus" and n_cols > 2:
        ax.text(0.7, 0.3,
                "(middle-col wrap omitted for clarity — every row "
                "and every column wraps)",
                color="gray", fontsize=7.5, style="italic")
def _draw_cube_reduction(ax):
    """Draw the 4×4 cube grid of SIP 0 with reduction arrows and a phase legend.

    The grid is anchored near the lower-left of the panel; cube ``c15``
    (row 3, col 3) is highlighted as the root, and the legend sits to the
    right of the grid.
    """
    from matplotlib.patches import Rectangle

    _frame_panel(ax, "Cube-level reduction inside SIP 0 (4×4 cubes)",
                 lim_x=10.0, lim_y=6.0)

    cube_w = 0.65
    cube_gap = 0.18
    grid_total = 4 * cube_w + 3 * cube_gap  # full width of the 4×4 grid
    grid_x0 = 0.7
    grid_y0 = 0.7
    half = cube_w / 2

    centers: dict[tuple[int, int], tuple[float, float]] = {}
    for r in range(4):
        for c in range(4):
            # Row 0 is drawn at the top, hence the (3 - r) flip on y.
            cx = grid_x0 + c * (cube_w + cube_gap) + cube_w / 2
            cy = grid_y0 + (3 - r) * (cube_w + cube_gap) + cube_w / 2
            centers[(r, c)] = (cx, cy)

            is_root = (r == 3 and c == 3)
            if is_root:
                face, edge, label_color = (
                    _PALETTE_ROOT_FILL, _PALETTE_ROOT_EDGE, _PALETTE_ROOT_EDGE,
                )
            else:
                face, edge, label_color = (
                    _PALETTE_BOX_FILL, _PALETTE_BOX_EDGE, _PALETTE_TEXT,
                )
            ax.add_patch(Rectangle(
                (cx - cube_w / 2, cy - cube_w / 2), cube_w, cube_w,
                linewidth=1.2, edgecolor=edge, facecolor=face,
            ))
            ax.text(cx, cy, f"c{r * 4 + c}", ha="center", va="center",
                    fontsize=7.5, fontweight="bold", color=label_color)

    # Phase 1: row reduce W→E — arrows between horizontal neighbours.
    for r in range(4):
        for c in range(3):
            src_x, src_y = centers[(r, c)]
            dst_x, dst_y = centers[(r, c + 1)]
            _arrow(ax, (src_x + half, src_y), (dst_x - half, dst_y),
                   color=_PALETTE_BLUE, lw=1.5)

    # Phase 2: col reduce N→S along rightmost column.
    for r in range(3):
        src_x, src_y = centers[(r, 3)]
        dst_x, dst_y = centers[(r + 1, 3)]
        _arrow(ax, (src_x, src_y - half), (dst_x, dst_y + half),
               color=_PALETTE_GREEN, lw=1.7)

    # Phase legend on the right side.
    legend_x = grid_x0 + grid_total + 0.55
    bold_entries = (
        (5.0, "Phase 1: row reduce (W → E)", _PALETTE_BLUE),
        (4.55, "Phase 2: col reduce (N → S, rightmost col)", _PALETTE_GREEN),
        (4.10, "Phase 3: inter-SIP exchange at root cube", _PALETTE_ROOT_EDGE),
    )
    for text_y, text, text_color in bold_entries:
        ax.text(legend_x, text_y, text,
                color=text_color, fontsize=10, fontweight="bold")
    italic_entries = (
        (3.65, "Phase 4: col broadcast (S → N)", _PALETTE_GREEN),
        (3.20, "Phase 5: row broadcast (E → W)", _PALETTE_BLUE),
    )
    for text_y, text, text_color in italic_entries:
        ax.text(legend_x, text_y, text,
                color=text_color, fontsize=10, style="italic")
    ax.text(legend_x, 2.55,
            "(broadcast phases reverse phases 2 & 1)",
            color="gray", fontsize=8.5, style="italic")
    ax.text(legend_x, 1.7,
            "Root cube (c15, bottom-right) is the only\n"
            "cube that performs the inter-SIP exchange.",
            color=_PALETTE_ROOT_EDGE, fontsize=9, style="italic")
def emit_topology_diagram() -> str:
    """Render the 2×2-panel topology diagram and return the PNG path.

    The figure is saved as ``topology.png`` in ``_SWEEP_OUT_DIR``::

        Top row: ring_1d | torus_2d (2×3)
        Bot row: mesh_2d_no_wrap (2×3) | cube-level reduction in SIP 0
    """
    import matplotlib.gridspec as gridspec
    import matplotlib.pyplot as plt

    _SWEEP_OUT_DIR.mkdir(parents=True, exist_ok=True)

    fig = plt.figure(figsize=(16, 10), facecolor="white")
    gs = gridspec.GridSpec(2, 2, figure=fig, hspace=0.30, wspace=0.10)

    # Create all four axes first (row-major), then fill them in.
    ax_ring, ax_torus, ax_mesh, ax_cube = [
        fig.add_subplot(gs[row, col])
        for row in range(2)
        for col in range(2)
    ]
    _draw_ring_topology(ax_ring)
    _draw_grid_topology(ax_torus, "torus", n_rows=2, n_cols=3)
    _draw_grid_topology(ax_mesh, "mesh", n_rows=2, n_cols=3)
    _draw_cube_reduction(ax_cube)

    fig.suptitle(
        "Allreduce topology — device-level (top: ring, torus, mesh) "
        "and cube-level reduction in SIP 0",
        fontsize=14, fontweight="bold", color=_PALETTE_TEXT, y=0.98,
    )

    out_path = _SWEEP_OUT_DIR / "topology.png"
    fig.savefig(
        out_path,
        dpi=130, bbox_inches="tight", facecolor=fig.get_facecolor(),
    )
    plt.close(fig)
    return str(out_path)
def test_emit_topology_diagram():
    """Render topology.png next to the sweep plots; plotting only, no simulation."""
    diagram_path = Path(emit_topology_diagram())
    assert diagram_path.exists()