Files
kernbench2/tests/sccl/_allreduce_helpers.py
mukesh b610cb0d9a sccl: drive allreduce tests via torch.distributed; reorganize into tests/sccl/
Convert the multidevice allreduce correctness + latency/buffer-kind sweeps
to run through the real PyTorch-distributed path
(init_process_group(backend="ahbm") -> mp.spawn -> dist.all_reduce) instead
of direct ctx.launch, and reorganize the CCL/allreduce tests into a
tests/sccl/ package split one test per file.

Production change (required for the distributed path on non-square SIP grids):
- AhbmCCLBackend now reads explicit system.sips.w/h from the spec, with a
  square-only sqrt fallback that raises on ambiguity, instead of silently
  guessing round(sqrt(count)). This fixes the 2x3 / 3x2 torus + mesh cases,
  which previously resolved to a wrong 2x2 grid. Mirrors the test helper's
  _sip_topo_dims precedence (explicit w/h > square fallback > raise).

Test reorganization (tests/sccl/):
- _allreduce_helpers.py: shared plumbing (distributed driver, config writers,
  direct-launch run_allreduce parity reference, sweep/buffer-kind constants,
  plot aggregators, topology-diagram + FSIM-comparison emitters).
- test_allreduce_ring_torus_mesh.py: correctness across ring/torus/mesh.
- test_distributed_default_topology.py: full distributed path on topology.yaml.
- test_plot_latency_sweep.py / test_plot_buffer_kind_sweep.py: sweep rows.
- test_plot_topology_diagram.py / test_plot_comparison_fsim.py: plot emitters.
- test_intercube_root_center.py: moved in (ADR-0032 center-root latency guard).

Also:
- Move the FSIM comparison plot generator out of scripts/ into the sccl suite.
- Delete superseded test files (test_allreduce_multidevice,
  test_distributed_lrab_hierarchical_allreduce, test_allreduce_buffer_kind_sweep)
  and repoint conftest aggregators + the ipcq buffer-kind importers.
- Regenerate the allreduce_latency_plots derived artifacts from the full sweep.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-20 22:24:43 -07:00

1014 lines
35 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Shared plumbing for the sccl allreduce tests.
Not a test module (no ``test_`` prefix → pytest does not collect it).
Holds the distributed driver, the direct-launch parity reference, the
config writers, the sweep/buffer-kind constants, the plot aggregators
(called from ``conftest.pytest_sessionfinish``), and the topology-diagram
emitter. The per-test files under ``tests/sccl/`` import from here, as do
the external buffer-kind / root-center tests under ``tests/``.
"""
from __future__ import annotations
import importlib
import math
import textwrap
from pathlib import Path
from typing import Any
import numpy as np
import pytest
import yaml
from kernbench.ccl.install import load_ccl_config, resolve_algorithm_config
from kernbench.ccl.sfr_config import configure_sfr_intercube_multisip
from kernbench.policy.placement.dp import DPPolicy
from kernbench.runtime_api.context import RuntimeContext
from kernbench.runtime_api.types import DeviceSelector
from kernbench.sim_engine.engine import GraphEngine
from kernbench.topology.builder import resolve_topology
# Repo-level base topology spec: three directories above this file (the
# parent of ``tests/``). ``_write_temp_configs`` derives per-test copies
# from it.
TOPOLOGY_PATH = Path(__file__).parent.parent.parent / "topology.yaml"
# Default per-PE element count (same value as the ``n_elem: 8`` written by
# ``_write_ccl_yaml`` and the ``n_elem`` fallback in ``run_allreduce``).
DEFAULT_N_ELEM = 8
# ── config writers ────────────────────────────────────────────────────
def _write_ccl_yaml(tmp_path) -> str:
    """Write a minimal ``ccl.yaml`` into *tmp_path* and return that directory.

    The config selects ``lrab_hierarchical_allreduce`` as the default
    algorithm with a TCM-backed IPCQ (4 slots of 4096 bytes) and an
    algorithm entry using ``n_elem: 8`` and ``root_cube: 15``.

    Returns:
        ``str(tmp_path)`` — the directory, not the file path.
    """
    yaml_text = textwrap.dedent("""\
        defaults:
          algorithm: lrab_hierarchical_allreduce
          buffer_kind: tcm
          backpressure: sleep
          n_slots: 4
          slot_size: 4096
          vc_chunk_size: 256
          ipcq_credit_size_bytes: 16
        algorithms:
          lrab_hierarchical_allreduce:
            module: kernbench.ccl.algorithms.lrab_hierarchical_allreduce
            topology: none
            buffer_kind: tcm
            n_elem: 8
            root_cube: 15
        """)
    target = tmp_path / "ccl.yaml"
    target.write_text(yaml_text)
    return f"{tmp_path}"
def _write_temp_configs(
    tmp_path, sip_topology, n_sips, algorithm, n_elem_override=None,
    sip_w=None, sip_h=None,
):
    """Write temp ``topology.yaml`` and ``ccl.yaml`` with the given overrides.

    Both files are derived from the repo-level originals (``TOPOLOGY_PATH``
    and the ``ccl.yaml`` beside it) and written into *tmp_path*.

    Args:
        tmp_path: ``pathlib.Path`` directory to write into.
        sip_topology: value for ``system.sips.topology``.
        n_sips: value for ``system.sips.count``.
        algorithm: value for ``defaults.algorithm`` in ``ccl.yaml``.
        n_elem_override: if given, sets ``algorithms.<algorithm>.n_elem`` and
            grows ``defaults.slot_size`` so one f16 message of that many
            elements fits in a single IPCQ slot.
        sip_w, sip_h: explicit SIP grid dims. Give both or neither; when
            omitted, any ``w``/``h`` inherited from the base topology is
            removed so the square-grid fallback applies.

    Returns:
        ``(topology_path, ccl_path)`` as strings.

    Raises:
        ValueError: if exactly one of *sip_w* / *sip_h* is given.
    """
    if (sip_w is None) != (sip_h is None):
        # A lone w or h used to be dropped silently, quietly switching the
        # run to the square-grid inference; surface the mistake instead.
        raise ValueError(
            f"sip_w and sip_h must be given together, got "
            f"sip_w={sip_w!r}, sip_h={sip_h!r}"
        )

    with open(TOPOLOGY_PATH, encoding="utf-8") as f:
        topo_cfg = yaml.safe_load(f)
    sips = topo_cfg["system"]["sips"]
    sips["count"] = n_sips
    sips["topology"] = sip_topology
    if sip_w is not None:
        sips["w"] = int(sip_w)
        sips["h"] = int(sip_h)
    else:
        sips.pop("w", None)
        sips.pop("h", None)
    topo_path = tmp_path / "topology.yaml"
    with open(topo_path, "w", encoding="utf-8") as f:
        yaml.dump(topo_cfg, f, default_flow_style=False)

    # The base ccl.yaml lives next to the base topology.yaml.
    ccl_path = TOPOLOGY_PATH.with_name("ccl.yaml")
    with open(ccl_path, encoding="utf-8") as f:
        ccl_cfg = yaml.safe_load(f)
    ccl_cfg["defaults"]["algorithm"] = algorithm
    if n_elem_override is not None:
        n_elem = int(n_elem_override)
        algo_cfg = ccl_cfg.setdefault("algorithms", {}).setdefault(
            algorithm, {},
        )
        algo_cfg["n_elem"] = n_elem
        # Ensure IPCQ slot is big enough for the per-message payload.
        per_msg_bytes = n_elem * 2  # f16
        default_slot = int(ccl_cfg["defaults"].get("slot_size", 4096))
        if per_msg_bytes > default_slot:
            ccl_cfg["defaults"]["slot_size"] = per_msg_bytes
    tmp_ccl = tmp_path / "ccl.yaml"
    with open(tmp_ccl, "w", encoding="utf-8") as f:
        yaml.dump(ccl_cfg, f, default_flow_style=False)
    return str(topo_path), str(tmp_ccl)
# ── distributed driver (init_process_group → mp.spawn → all_reduce) ────
def _worker(rank: int, n_cubes: int, n_elem: int, n_sips: int, torch) -> None:
    """Run one SIP's share of the distributed allreduce and check the result.

    Invoked once per rank by ``ctx.multiprocessing.spawn`` in
    ``_run_distributed``; *torch* is the runtime context passed through
    ``args`` and used here as a torch-like facade.

    The rank fills an ``(n_cubes, n_elem)`` f16 tensor with ``rank + 1``,
    calls ``all_reduce(op="sum")``, and asserts (within 0.1 abs/rel
    tolerance) that every cube row equals the sum of ``rank + 1`` over all
    cubes and all ranks.
    """
    torch.ahbm.set_device(rank)
    placement = DPPolicy(
        cube="row_wise", pe="replicate",
        num_pes=1, num_cubes=n_cubes,
    )
    shape = (n_cubes, n_elem)
    buf = torch.zeros(shape, dtype="f16", dp=placement, name=f"sip{rank}")
    fill = np.full(shape, float(rank + 1), dtype=np.float16)
    buf.copy_(torch.from_numpy(fill))

    torch.distributed.all_reduce(buf, op="sum")

    result = buf.numpy()
    want = float(n_cubes * sum(range(1, n_sips + 1)))
    for cube_id in range(n_cubes):
        row = result[cube_id]
        assert np.allclose(row, want, rtol=1e-1, atol=1e-1), (
            f"SIP{rank} cube {cube_id}: "
            f"got {row[:4]}, expected {want}"
        )
    if rank == 0:
        print(f"\n lrab_hierarchical_allreduce (ws={n_sips}): "
              f"{n_sips * n_cubes} OK")
def _crit_ns(engine) -> float:
    """Return the critical-path latency: the largest recorded ``pe_exec_ns``.

    Each value in ``engine._results`` is unpacked as a 2-tuple whose second
    element is a per-result trace. Only dict traces count; a missing or falsy
    ``pe_exec_ns`` contributes 0.0. Returns 0.0 when no dict trace exists.
    """
    traces = (trace for _key, (_first, trace) in engine._results.items())
    return max(
        (
            float(trace.get("pe_exec_ns", 0.0) or 0.0)
            for trace in traces
            if isinstance(trace, dict)
        ),
        default=0.0,
    )
def _run_distributed(tmp_path, monkeypatch, topo_path, correlation_id, n_elem):
    """Build an engine and run the collective via the full distributed path.

    Steps:
      1. ``monkeypatch.chdir(tmp_path)`` so the backend's cwd-relative
         ``load_ccl_config()`` picks up the temp ``ccl.yaml``;
      2. resolve *topo_path* and build a data-enabled ``GraphEngine``;
      3. inside a ``RuntimeContext`` targeting all devices,
         ``init_process_group(backend="ahbm")``, check the world size equals
         the SIP count, and ``spawn`` one ``_worker`` per SIP.

    Returns:
        ``(engine, n_cubes)``.
    """
    monkeypatch.chdir(tmp_path)
    resolved = resolve_topology(topo_path)
    engine = GraphEngine(resolved.topology_obj, enable_data=True)
    spec = resolved.topology_obj.spec

    world_size = int(spec["system"]["sips"]["count"])
    mesh = spec["sip"]["cube_mesh"]
    n_cubes = int(mesh["w"]) * int(mesh["h"])

    with RuntimeContext(
        engine=engine,
        target_device=DeviceSelector("all"),
        correlation_id=correlation_id,
        spec=spec,
    ) as ctx:
        ctx.distributed.init_process_group(backend="ahbm")
        assert ctx.distributed.get_world_size() == world_size
        worker_args = (n_cubes, n_elem, world_size, ctx)
        ctx.multiprocessing.spawn(_worker, args=worker_args, nprocs=world_size)
    return engine, n_cubes
# ── correctness config matrix (used by test_allreduce) ─────────────────
# Correctness matrix for the ring/torus/mesh allreduce test. Each param is
# (algorithm, sip_topology, n_sips, sip_w, sip_h) — the same field order as
# _SWEEP_TOPOLOGIES. The ring needs no grid dims (None, None); the 2-D
# topologies use an explicit 2x3 SIP layout.
CONFIGS = [
    pytest.param(
        "lrab_hierarchical_allreduce", "ring_1d", 6, None, None,
        id="ring_6sip",
    ),
    pytest.param(
        "lrab_hierarchical_allreduce", "torus_2d", 6, 2, 3,
        id="torus_6sip_2x3",
    ),
    pytest.param(
        "lrab_hierarchical_allreduce", "mesh_2d_no_wrap", 6, 2, 3,
        id="mesh_6sip_2x3",
    ),
]
# ── direct-launch helper (parity reference only) ───────────────────────
def _sip_topo_dims(
    sip_topo: str, n_sips: int,
    spec_w: int | None = None, spec_h: int | None = None,
) -> tuple[int, int]:
    """Resolve the inter-SIP grid ``(w, h)`` passed to the allreduce kernel.

    Precedence:
      1. ``ring_1d`` → ``(0, 0)`` (no 2-D grid; explicit dims are ignored);
      2. explicit *spec_w* / *spec_h* → returned after validation;
      3. otherwise a square ``side x side`` grid when *n_sips* is a positive
         perfect square.

    Raises:
        ValueError: only one explicit dim given; an explicit dim is
            non-positive; the explicit product differs from *n_sips*; or, with
            no explicit dims, *n_sips* is not a positive perfect square.
    """
    if sip_topo == "ring_1d":
        return (0, 0)
    if (spec_w is None) != (spec_h is None):
        # A lone w or h is ambiguous; don't silently fall back to a square.
        raise ValueError(
            f"SIP topology '{sip_topo}' needs both w and h in spec, "
            f"got w={spec_w}, h={spec_h}"
        )
    if spec_w is not None and spec_h is not None:
        if spec_w <= 0 or spec_h <= 0:
            # e.g. -2 x -3 would otherwise pass the product check below.
            raise ValueError(
                f"sip layout {spec_w}x{spec_h} must have positive dims"
            )
        if spec_w * spec_h != n_sips:
            raise ValueError(
                f"sip layout {spec_w}x{spec_h} != n_sips ({n_sips})"
            )
        return (spec_w, spec_h)
    # n_sips <= 0 must not yield (0, 0): that is the ring sentinel above.
    side = math.isqrt(n_sips) if n_sips > 0 else 0
    if n_sips <= 0 or side * side != n_sips:
        raise ValueError(
            f"SIP topology '{sip_topo}' requires square n_sips or "
            f"explicit w/h in spec, got {n_sips}"
        )
    return (side, side)
def run_allreduce(
    ctx: Any,
    engine: Any,
    spec: dict,
    *,
    algorithm: str | None = None,
    ccl_yaml: str | None = None,
) -> dict:
    """Config-driven allreduce via direct ``ctx.launch`` (no distributed wrapper).

    Retained as the parity reference for the distributed path and reused by
    the external buffer-kind / root-center micro-tests.

    Flow: load and resolve the CCL config, import the algorithm module's
    ``kernel`` and ``TOPO_NAME_TO_KIND``, configure the SFR intercube
    routing, allocate one ``(n_cubes, n_elem)`` f16 tensor per SIP filled
    with ``sip + 1``, launch the kernel on every SIP with deferred waits,
    wait on all handles, then assert every cube row equals the global sum.

    Returns:
        dict with ``expected`` (the per-element target sum), ``latency_ns``
        (``engine._env.now`` delta across launch + wait only) and
        ``ok_cubes`` (number of cube rows checked).
    """
    cfg = resolve_algorithm_config(load_ccl_config(ccl_yaml), algorithm)
    module = importlib.import_module(cfg["module"])
    kernel_fn = module.kernel
    kind_by_name = module.TOPO_NAME_TO_KIND
    n_elem = int(cfg.get("n_elem", 8))

    sips_cfg = spec.get("system", {}).get("sips", {})
    n_sips = int(sips_cfg.get("count", 1))
    sip_topo = str(sips_cfg.get("topology", "ring_1d"))
    raw_w = sips_cfg.get("w")
    raw_h = sips_cfg.get("h")
    spec_sip_w = None if raw_w is None else int(raw_w)
    spec_sip_h = None if raw_h is None else int(raw_h)

    mesh = spec["sip"]["cube_mesh"]
    cube_w, cube_h = int(mesh["w"]), int(mesh["h"])
    n_cubes = cube_w * cube_h

    # Unknown topology names map to kind 0.
    sip_topo_kind = kind_by_name.get(sip_topo, 0)
    sip_topo_w, sip_topo_h = _sip_topo_dims(
        sip_topo, n_sips, spec_w=spec_sip_w, spec_h=spec_sip_h,
    )
    algo_name = cfg.get("algorithm", "allreduce")
    configure_sfr_intercube_multisip(engine, spec, cfg)

    placement = DPPolicy(
        cube="row_wise", pe="replicate",
        num_pes=1, num_cubes=n_cubes,
    )
    shape = (n_cubes, n_elem)
    tensors = []
    for sip in range(n_sips):
        ctx.ahbm.set_device(sip)
        tensor = ctx.zeros(shape, dtype="f16", dp=placement, name=f"sip{sip}")
        tensor.copy_(ctx.from_numpy(
            np.full(shape, float(sip + 1), dtype=np.float16)
        ))
        tensors.append(tensor)

    # Only the launch + wait window is timed; setup above is excluded.
    t_start = engine._env.now
    handles = []
    for sip_rank, tensor in enumerate(tensors):
        handles.extend(ctx.launch(
            algo_name, kernel_fn, tensor,
            n_elem, cube_w, cube_h, n_sips, sip_rank,
            sip_topo_kind, sip_topo_w, sip_topo_h,
            _defer_wait=True,
        ))
    for handle, _sip_id, meta in handles:
        ctx.wait(handle, _meta=meta)
    latency_ns = engine._env.now - t_start

    expected = float(n_cubes * sum(range(1, n_sips + 1)))
    ok_cubes = 0
    for sip, tensor in enumerate(tensors):
        arr = tensor.numpy()
        for cube_id in range(n_cubes):
            assert np.allclose(
                arr[cube_id], expected, rtol=1e-1, atol=1e-1,
            ), (
                f"SIP{sip} cube {cube_id}: "
                f"got {arr[cube_id][:4]}, expected {expected}"
            )
            ok_cubes += 1
    return {
        "expected": expected,
        "latency_ns": latency_ns,
        "ok_cubes": ok_cubes,
    }
# ── Latency sweep constants + aggregator ──────────────────────────────
# Per-PE element counts swept by the latency test. 16 is deliberately
# skipped: it equals n_cubes and collides in dim_map. With f16,
# bytes_per_pe = n_elem * 2, so the largest point (49152 elements) is
# 96 KB per PE.
_SWEEP_N_ELEM = [
    8, 32, 64, 128, 512, 1024, 2048,
    4096, 8192, 16384, 32768, 49152,
]
# Size of one f16 element in bytes.
_ELEM_BYTES_F16 = 2
# Swept configurations: (algorithm, sip_topology, n_sips, sip_w, sip_h).
# The 2-D grids use an explicit 2x3 layout; the ring needs no w/h.
_SWEEP_TOPOLOGIES = [
    ("lrab_hierarchical_allreduce", "ring_1d", 6, None, None),
    ("lrab_hierarchical_allreduce", "torus_2d", 6, 2, 3),
    ("lrab_hierarchical_allreduce", "mesh_2d_no_wrap", 6, 2, 3),
]
# Output directory for the sweep artifacts (summary.csv + PNG plots):
# docs/diagrams/allreduce_latency_plots, three directories above this file.
_SWEEP_OUT_DIR = (Path(__file__).parent.parent.parent / "docs" / "diagrams"
                  / "allreduce_latency_plots")
# Shared on-disk staging dir for parametrized sweep rows. Each parametrized
# invocation writes one JSON file here; the aggregator (run from
# conftest.pytest_sessionfinish) reads them, emits the combined CSV + PNG
# plots, and then deletes the staged rows.
_SWEEP_ROWS_DIR = _SWEEP_OUT_DIR / "_rows"
def _sweep_params():
    """Build the ``pytest.param`` list for the latency sweep.

    Topology-major cross product of ``_SWEEP_TOPOLOGIES`` × ``_SWEEP_N_ELEM``.
    Each param carries ``(algorithm, sip_topology, n_sips, sip_w, sip_h,
    n_elem)`` and an id of the form ``"<sip_topology>-n_elem<n_elem>"``.
    """
    return [
        pytest.param(
            algo, topo_name, sip_count, grid_w, grid_h, elems,
            id=f"{topo_name}-n_elem{elems}",
        )
        for algo, topo_name, sip_count, grid_w, grid_h in _SWEEP_TOPOLOGIES
        for elems in _SWEEP_N_ELEM
    ]
def _aggregate_sweep_plots() -> bool:
    """Read all per-config rows and emit CSV + PNG plots.

    Called by ``conftest.pytest_sessionfinish`` (controller node only).

    Row sources, in order of preference:
      1. JSON rows staged in ``_SWEEP_ROWS_DIR`` by the parametrized sweep;
      2. otherwise the existing ``summary.csv`` (replot without re-running
         the sweep).

    Writes ``summary.csv`` and one PNG per topology into ``_SWEEP_OUT_DIR``,
    then best-effort deletes the staged row files and the staging dir so a
    later partial run does not pick up stale rows.

    Returns:
        True if any rows were aggregated, False otherwise.
    """
    import csv
    import json

    fieldnames = [
        "algorithm", "sip_topology", "n_sips", "n_elem",
        "bytes_per_pe", "bytes_per_sip", "latency_ns",
    ]
    # Plot titles per topology key (the key stays the summary.csv value).
    per_topo_titles = {
        "ring_1d": "AllReduce_LRAB_Ring1D_6SiP(1x6)",
        "torus_2d": "AllReduce_LRAB_2Dtorus_6SiP(2x3)",
        "mesh_2d_no_wrap": "AllReduce_LRAB_2DMesh_6SiP(2x3)",
    }
    # Descriptive output filenames (parens → underscores for markdown/URL
    # safety).
    per_topo_files = {
        "ring_1d": "AllReduce_LRAB_Ring1D_6SiP_1x6",
        "torus_2d": "AllReduce_LRAB_2Dtorus_6SiP_2x3",
        "mesh_2d_no_wrap": "AllReduce_LRAB_2DMesh_6SiP_2x3",
    }

    row_files = (
        sorted(_SWEEP_ROWS_DIR.glob("*.json"))
        if _SWEEP_ROWS_DIR.exists() else []
    )
    records: list[dict] = []
    if row_files:
        for p in row_files:
            with open(p, encoding="utf-8") as f:
                records.append(json.load(f))
    else:
        # Fallback: replot from existing summary.csv (skip sweep re-run).
        summary_path = _SWEEP_OUT_DIR / "summary.csv"
        if not summary_path.exists():
            return False
        # newline="" is what the csv module expects for file objects.
        with open(summary_path, newline="", encoding="utf-8") as f:
            for row in csv.DictReader(f):
                records.append({
                    "algorithm": row["algorithm"],
                    "sip_topology": row["sip_topology"],
                    "n_sips": int(row["n_sips"]),
                    "n_elem": int(row["n_elem"]),
                    "bytes_per_pe": int(row["bytes_per_pe"]),
                    "bytes_per_sip": int(row["bytes_per_sip"]),
                    "latency_ns": float(row["latency_ns"]),
                })
    if not records:
        return False

    import matplotlib.pyplot as plt
    from matplotlib.ticker import FuncFormatter

    def _fmt_bytes(x, _pos):
        """Format an axis tick value (bytes) as B / KB / MB."""
        if x <= 0:
            return "0"
        if x >= 1024 * 1024:
            return f"{x / (1024 * 1024):.0f} MB"
        if x >= 1024:
            return f"{x / 1024:.0f} KB"
        return f"{x:.0f} B"

    bytes_fmt = FuncFormatter(_fmt_bytes)

    _SWEEP_OUT_DIR.mkdir(parents=True, exist_ok=True)
    with open(_SWEEP_OUT_DIR / "summary.csv", "w",
              newline="", encoding="utf-8") as f:
        # extrasaction="ignore": a staged row carrying any extra key must not
        # abort the whole session-end aggregation (DictWriter's default is to
        # raise ValueError).
        writer = csv.DictWriter(f, fieldnames=fieldnames,
                                extrasaction="ignore")
        writer.writeheader()
        writer.writerows(sorted(
            records, key=lambda r: (r["sip_topology"], r["bytes_per_pe"]),
        ))

    # Every topology here comes from `records`, so each has >= 1 row.
    for topo_name in sorted({r["sip_topology"] for r in records}):
        rs = sorted(
            (r for r in records if r["sip_topology"] == topo_name),
            key=lambda r: r["bytes_per_pe"],
        )
        title = per_topo_titles.get(
            topo_name, f"Allreduce latency — {topo_name}"
        )
        out_stem = per_topo_files.get(topo_name, topo_name)
        fig, ax = plt.subplots(figsize=(8, 5))
        try:
            ax.plot(
                [r["bytes_per_pe"] for r in rs],
                [r["latency_ns"] for r in rs],
                marker="o", color="tab:blue",
            )
            ax.set_xscale("log", base=2)
            ax.set_xlabel("Bytes per PE (log scale)")
            ax.set_ylabel("Time (ns)")
            ax.set_title(title)
            ax.grid(True, alpha=0.3)
            ax.xaxis.set_major_formatter(bytes_fmt)
            fig.tight_layout()
            fig.savefig(_SWEEP_OUT_DIR / f"{out_stem}.png", dpi=120)
        finally:
            # Release the figure even if rendering/saving fails.
            plt.close(fig)
    # Combined overview.png is no longer emitted — the broken-y-axis
    # comparison (emit_comparison_fsim_plot() →
    # comparison_mesh_vs_ring_vs_2DTorus_vs_theoretical_vs_fsim.png)
    # supersedes it. Per-topology plots above and summary.csv are still
    # produced.

    # Cleanup row staging dir so a partial future run doesn't pick up
    # stale rows. Best-effort: a failed delete must not fail the session.
    for p in row_files:
        try:
            p.unlink()
        except OSError:
            pass
    try:
        _SWEEP_ROWS_DIR.rmdir()
    except OSError:
        pass
    print(f"\nWrote per-topology plots + summary.csv to {_SWEEP_OUT_DIR} "
          f"from {len(records)} rows")
    return True
# ── Buffer-kind sweep constants + aggregator ──────────────────────────
#
# Parametrized over (buffer_kind, n_elem) on torus_2d with 6 SIPs in the 2×3
# layout named by _BK_OUT_STEM and the plot title (an earlier note said 3×2;
# NOTE(review): confirm against the w/h the buffer-kind test actually
# writes). Until slot latency is modelled, the three curves overlap exactly
# (slot access is latency-free today); they separate once tcm/sram/hbm carry
# distinct access costs.
_BUFFER_KINDS = ["tcm", "sram", "hbm"]
_BK_N_ELEM_GRID = [128, 1024, 8192, 32768]  # f16: 256 B → 64 KB per slot
# Staging dir for per-config JSON rows; consumed and deleted by
# aggregate_buffer_kind_plot().
_BK_ROWS_DIR = _SWEEP_OUT_DIR / "_buffer_kind_rows"
# Descriptive output stem (shared by the .png and .csv).
_BK_OUT_STEM = "AllReduce_LRAB_2Dtorus_6SiP_2x3_with_TCM_SRAM_HBM"
def _bk_params():
    """Return one ``pytest.param(buffer_kind, n_elem)`` per grid point.

    Buffer-kind-major over ``_BUFFER_KINDS`` × ``_BK_N_ELEM_GRID``; ids look
    like ``"<buffer_kind>-n_elem<n_elem>"``.
    """
    return [
        pytest.param(kind, elems, id=f"{kind}-n_elem{elems}")
        for kind in _BUFFER_KINDS
        for elems in _BK_N_ELEM_GRID
    ]
def aggregate_buffer_kind_plot() -> bool:
    """Read per-config rows and emit the descriptive .png + .csv (_BK_OUT_STEM).

    Called from conftest.pytest_sessionfinish (controller-only). Reads the
    JSON rows staged in ``_BK_ROWS_DIR``, writes ``<_BK_OUT_STEM>.csv`` and
    ``<_BK_OUT_STEM>.png`` (one latency curve per buffer kind) into
    ``_SWEEP_OUT_DIR``, then best-effort deletes the staged rows and dir.

    Returns:
        True if rows were aggregated, False if there was nothing to do.
    """
    import csv
    import json

    if not _BK_ROWS_DIR.exists():
        return False
    row_files = sorted(_BK_ROWS_DIR.glob("*.json"))
    if not row_files:
        return False
    records = []
    for p in row_files:
        with open(p, encoding="utf-8") as f:
            records.append(json.load(f))

    import matplotlib.pyplot as plt
    from matplotlib.ticker import FuncFormatter

    def _fmt_bytes(x, _pos):
        """Format an axis tick value (bytes) as B / KB / MB."""
        if x <= 0:
            return "0"
        if x >= 1024 * 1024:
            return f"{x / (1024 * 1024):.0f} MB"
        if x >= 1024:
            return f"{x / 1024:.0f} KB"
        return f"{x:.0f} B"

    bytes_fmt = FuncFormatter(_fmt_bytes)

    _SWEEP_OUT_DIR.mkdir(parents=True, exist_ok=True)
    with open(_SWEEP_OUT_DIR / f"{_BK_OUT_STEM}.csv", "w",
              newline="", encoding="utf-8") as f:
        # extrasaction="ignore": extra keys in a staged row must not abort
        # the session-end aggregation (DictWriter raises by default).
        writer = csv.DictWriter(f, fieldnames=[
            "buffer_kind", "sip_topology", "n_sips", "n_elem",
            "bytes_per_pe", "latency_ns",
        ], extrasaction="ignore")
        writer.writeheader()
        writer.writerows(sorted(
            records, key=lambda r: (r["buffer_kind"], r["bytes_per_pe"]),
        ))

    colors = {"tcm": "tab:blue", "sram": "tab:orange", "hbm": "tab:red"}
    # Plot configured kinds in their canonical order, then any other kind
    # present in the rows — previously such rows reached the CSV but were
    # silently left off the plot.
    present = {r["buffer_kind"] for r in records}
    kinds = [bk for bk in _BUFFER_KINDS if bk in present]
    kinds += sorted(present - set(_BUFFER_KINDS))

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        for bk in kinds:
            rs = sorted(
                (r for r in records if r["buffer_kind"] == bk),
                key=lambda r: r["bytes_per_pe"],
            )
            ax.plot(
                [r["bytes_per_pe"] for r in rs],
                [r["latency_ns"] for r in rs],
                marker="o", lw=2.0,
                # Unknown kinds get matplotlib's default colour cycle.
                color=colors.get(bk), label=f"buffer_kind = {bk}",
            )
        ax.set_xscale("log", base=2)
        ax.set_xlabel("Bytes per PE (log scale)")
        ax.set_ylabel("Time (ns)")
        ax.set_title(
            "AllReduce_LRAB_2Dtorus_6SiP(2x3) — IPCQ memory (SRAM, TCM, HBM)"
        )
        ax.grid(True, alpha=0.3)
        ax.legend()
        ax.xaxis.set_major_formatter(bytes_fmt)
        fig.tight_layout()
        fig.savefig(_SWEEP_OUT_DIR / f"{_BK_OUT_STEM}.png", dpi=130)
    finally:
        # Release the figure even if rendering/saving fails.
        plt.close(fig)

    # Best-effort cleanup of the staging dir so stale rows are not reused.
    for p in row_files:
        try:
            p.unlink()
        except OSError:
            pass
    try:
        _BK_ROWS_DIR.rmdir()
    except OSError:
        pass
    print(f"\nWrote {_SWEEP_OUT_DIR / f'{_BK_OUT_STEM}.png'} "
          f"from {len(records)} rows")
    return True
# ── Topology diagram (device-level + cube-level reduction) ────────────
# Convention: "rows × cols" everywhere, row-major rank assignment
# (rank = row * n_cols + col). For the 2×3 inter-SIP grid, this means
# 2 rows × 3 columns: SIP 0 1 2 / SIP 3 4 5.
#
# Colour palette shared by the topology-diagram drawing helpers.
_PALETTE_BG = "#fafbfd"         # panel background
_PALETTE_FRAME = "#3a3f4a"      # panel border
_PALETTE_BLUE = "#2c6fb6"       # row / E-W links and row-phase labels
_PALETTE_GREEN = "#2e8a4e"      # column / N-S links and column-phase labels
_PALETTE_TEXT = "#1f2530"       # titles and box labels
_PALETTE_BOX_FILL = "#eaf2fb"   # SIP / cube box fill
_PALETTE_BOX_EDGE = "#2c4a78"   # SIP / cube box outline
_PALETTE_ROOT_FILL = "#ffd9b8"  # root cube fill
_PALETTE_ROOT_EDGE = "#bd5a14"  # root cube outline and root annotations
def _arrow(ax, xy_from, xy_to, color="black", lw=1.4, alpha=1.0,
           style="-|>", curve=0.0):
    """Add an arrow patch from *xy_from* to *xy_to* (data coords) on *ax*.

    *curve* is the ``arc3`` connection radius: 0 draws a straight arrow,
    and the sign selects which side the arc bends toward.
    """
    from matplotlib.patches import FancyArrowPatch

    ax.add_patch(FancyArrowPatch(
        xy_from,
        xy_to,
        arrowstyle=style,
        mutation_scale=12,
        color=color,
        lw=lw,
        alpha=alpha,
        connectionstyle=f"arc3,rad={curve}",
    ))
def _draw_sip_box(ax, cx, cy, w, h, label, *, fill=_PALETTE_BOX_FILL,
                  edge=_PALETTE_BOX_EDGE, text_color=_PALETTE_TEXT,
                  font=10):
    """Draw a rounded ``w`` × ``h`` box centred on (*cx*, *cy*) with a bold
    centred *label*."""
    from matplotlib.patches import FancyBboxPatch

    lower_left = (cx - w / 2, cy - h / 2)
    ax.add_patch(FancyBboxPatch(
        lower_left, w, h,
        boxstyle="round,pad=0.02,rounding_size=0.10",
        linewidth=1.4,
        edgecolor=edge,
        facecolor=fill,
    ))
    ax.text(
        cx, cy, label,
        ha="center", va="center",
        color=text_color, fontsize=font, fontweight="bold",
    )
def _frame_panel(ax, title, lim_x=10.0, lim_y=6.0):
    """Set up a square-ish panel with a visible outer border.

    Fixes the data limits to ``[0, lim_x] × [0, lim_y]`` with equal aspect,
    hides the axes, paints the background, draws a rounded border inset by
    0.05 on each side at zorder 0, and sets a bold title.
    """
    from matplotlib.patches import FancyBboxPatch

    ax.set_xlim(0, lim_x)
    ax.set_ylim(0, lim_y)
    ax.set_aspect("equal")
    ax.axis("off")
    ax.set_facecolor(_PALETTE_BG)
    frame = FancyBboxPatch(
        (0.05, 0.05),
        lim_x - 0.10,
        lim_y - 0.10,
        boxstyle="round,pad=0.01,rounding_size=0.12",
        linewidth=1.4,
        edgecolor=_PALETTE_FRAME,
        facecolor=_PALETTE_BG,
        zorder=0,
    )
    ax.add_patch(frame)
    ax.set_title(
        title, fontsize=12, fontweight="bold", color=_PALETTE_TEXT, pad=8,
    )
def _draw_ring_topology(ax):
    """Draw the ``ring_1d`` panel.

    Six SIP boxes in one row, forward (global_E) arrows between neighbours,
    and the SIP 5 → SIP 0 wrap drawn as an arc so it does not cross the boxes
    in between.
    """
    _frame_panel(ax, "ring_1d (6 SIPs)", lim_x=10.0, lim_y=6.0)
    centres = [1.2, 2.7, 4.2, 5.7, 7.2, 8.7]
    row_y = 3.1
    box_w, box_h = 1.05, 0.9
    half = box_w / 2
    for rank, cx in enumerate(centres):
        _draw_sip_box(ax, cx, row_y, box_w, box_h, f"SIP {rank}")
    # Forward ring (global_E): right edge of each box → left edge of the next.
    for here, there in zip(centres, centres[1:]):
        _arrow(ax, (here + half, row_y), (there - half, row_y),
               color=_PALETTE_BLUE, lw=1.6)
    # Wrap: right-centre of the last SIP back to left-centre of SIP 0,
    # arcing outside (above) the row.
    _arrow(
        ax,
        (centres[-1] + half, row_y),
        (centres[0] - half, row_y),
        color=_PALETTE_BLUE, lw=1.6, curve=-0.40,
    )
    ax.text(5.0, row_y + 2.0, "global_E (ring)", ha="center",
            color=_PALETTE_BLUE, fontsize=10, style="italic")
    ax.text(5.0, row_y - 1.5,
            "(global_W = reverse direction, used by the algorithm)",
            ha="center", color="gray", fontsize=8, style="italic")
def _draw_grid_topology(ax, kind, *, n_rows=2, n_cols=3):
    """Draw an ``n_rows × n_cols`` SIP grid (row-major ranks) as torus or mesh.

    For the sweep we use 2 rows × 3 cols → SIP layout::

        row 0: SIP 0  SIP 1  SIP 2
        row 1: SIP 3  SIP 4  SIP 5

    Adjacent SIPs get a pair of opposing arrows (blue within a row, green
    within a column). For ``kind == "torus"`` the row wraps and the
    outer-column wraps are drawn as arcs outside the grid; middle-column
    wraps are omitted and called out in a note.

    Args:
        ax: matplotlib Axes to draw on.
        kind: ``"torus"`` or ``"mesh"``.
        n_rows, n_cols: grid shape (>= 1 each). Box centres are spread evenly
            over the panel; the default 2×3 layout is unchanged. Box sizes
            are fixed, so very wide/tall grids will look crowded.

    Raises:
        ValueError: unknown *kind* or non-positive grid dims.
    """
    if kind not in ("torus", "mesh"):
        raise ValueError(f"kind must be 'torus' or 'mesh', got {kind!r}")
    if n_rows < 1 or n_cols < 1:
        raise ValueError(f"grid must be at least 1x1, got {n_rows}x{n_cols}")

    topo_label = "torus_2d" if kind == "torus" else "mesh_2d_no_wrap"
    title = f"{topo_label} ({n_rows}×{n_cols}, {n_rows * n_cols} SIPs)"
    _frame_panel(ax, title, lim_x=10.0, lim_y=6.0)

    def _spread(start, stop, count):
        """Evenly spaced positions from start to stop (exact endpoints)."""
        if count == 1:
            return [(start + stop) / 2]
        step = (stop - start) / (count - 1)
        return [start + i * step for i in range(count - 1)] + [stop]

    # Columns span x ∈ [2.0, 8.0]; rows run top-down from y = 4.3 to 1.8.
    # For 2×3 this reproduces the old fixed layout (x = 2, 5, 8;
    # y = 4.3, 1.8); other shapes previously raised IndexError.
    col_xs = _spread(2.0, 8.0, n_cols)
    row_ys = _spread(4.3, 1.8, n_rows)
    box_w, box_h = 1.3, 0.95
    pos: dict[tuple[int, int], tuple[float, float]] = {}
    for r in range(n_rows):
        for c in range(n_cols):
            rank = r * n_cols + c
            x, y = col_xs[c], row_ys[r]
            pos[(r, c)] = (x, y)
            _draw_sip_box(ax, x, y, box_w, box_h, f"SIP {rank}")
    # Row edges (E↔W) — between adjacent columns within each row, offset
    # vertically so the two directions do not overlap.
    for r in range(n_rows):
        for c in range(n_cols - 1):
            x0, y0 = pos[(r, c)]
            x1, y1 = pos[(r, c + 1)]
            _arrow(ax, (x0 + box_w / 2, y0 + 0.10),
                   (x1 - box_w / 2, y1 + 0.10),
                   color=_PALETTE_BLUE, lw=1.5)
            _arrow(ax, (x1 - box_w / 2, y1 - 0.10),
                   (x0 + box_w / 2, y0 - 0.10),
                   color=_PALETTE_BLUE, lw=1.5)
    # Col edges (N↔S) — between adjacent rows within each column, offset
    # horizontally for the same reason.
    for c in range(n_cols):
        for r in range(n_rows - 1):
            x0, y0 = pos[(r, c)]
            x1, y1 = pos[(r + 1, c)]
            _arrow(ax, (x0 - 0.12, y0 - box_h / 2),
                   (x1 - 0.12, y1 + box_h / 2),
                   color=_PALETTE_GREEN, lw=1.5)
            _arrow(ax, (x1 + 0.12, y1 + box_h / 2),
                   (x0 + 0.12, y0 - box_h / 2),
                   color=_PALETTE_GREEN, lw=1.5)
    # Wrap arrows for torus only — anchor to the centre of the OUTER
    # edge of the end SIPs and arc OUTSIDE the row/column so they do
    # not overlap the SIPs in between.
    if kind == "torus":
        # Row wrap: last col → first col. Top row arcs UP, other rows arc
        # DOWN, so each wrap sits clearly outside its own row.
        for r in range(n_rows):
            x0, y0 = pos[(r, 0)]
            x1, y1 = pos[(r, n_cols - 1)]
            curve = -0.45 if r == 0 else 0.45
            _arrow(
                ax,
                (x1 + box_w / 2, y1),
                (x0 - box_w / 2, y0),
                color=_PALETTE_BLUE, lw=1.5,
                curve=curve, alpha=0.9,
            )
        # Col wrap: last row → first row. Leftmost col arcs LEFT,
        # rightmost col arcs RIGHT. Middle col(s) are skipped and noted
        # below (drawing them through the panel would collide with the
        # row arrows).
        for c in range(n_cols):
            x0, y0 = pos[(0, c)]
            x1, y1 = pos[(n_rows - 1, c)]
            if c == 0:
                curve = 0.55
            elif c == n_cols - 1:
                curve = -0.55
            else:
                continue  # skip middle col — see legend note
            _arrow(
                ax,
                (x1, y1 - box_h / 2),
                (x0, y0 + box_h / 2),
                color=_PALETTE_GREEN, lw=1.5,
                curve=curve, alpha=0.9,
            )
    ax.text(0.7, 5.6, "global_E/W (row)", color=_PALETTE_BLUE,
            fontsize=9, style="italic", fontweight="bold")
    ax.text(0.7, 5.25, "global_N/S (col)", color=_PALETTE_GREEN,
            fontsize=9, style="italic", fontweight="bold")
    ax.text(0.7, 4.92,
            "wrap = torus" if kind == "torus" else "no wrap = mesh",
            color="gray", fontsize=8, style="italic")
    if kind == "torus" and n_cols > 2:
        ax.text(0.7, 0.3,
                "(middle-col wrap omitted for clarity — every row "
                "and every column wraps)",
                color="gray", fontsize=7.5, style="italic")
def _draw_cube_reduction(ax):
    """4×4 cube grid inside SIP 0 — compact layout with phase legend.

    Draws cubes c0..c15 row-major with row 0 at the top, highlights the root
    cube c15 (bottom-right), overlays the reduce arrows (phase 1: W→E along
    every row; phase 2: N→S down the rightmost column), and lists all five
    phases as text to the right of the grid.
    """
    from matplotlib.patches import Rectangle

    _frame_panel(ax, "Cube-level reduction inside SIP 0 (4×4 cubes)",
                 lim_x=10.0, lim_y=6.0)
    side = 0.65   # cube edge length
    gap = 0.18    # spacing between neighbouring cubes
    pitch = side + gap
    half = side / 2
    grid_extent = 4 * side + 3 * gap
    origin_x = 0.7
    origin_y = 0.7

    centre: dict[tuple[int, int], tuple[float, float]] = {}
    for row in range(4):
        for col in range(4):
            x = origin_x + col * pitch + half
            # Row 0 is drawn at the top, hence (3 - row).
            y = origin_y + (3 - row) * pitch + half
            centre[(row, col)] = (x, y)
            is_root = row == 3 and col == 3
            ax.add_patch(Rectangle(
                (x - half, y - half), side, side,
                linewidth=1.2,
                edgecolor=_PALETTE_ROOT_EDGE if is_root else _PALETTE_BOX_EDGE,
                facecolor=_PALETTE_ROOT_FILL if is_root else _PALETTE_BOX_FILL,
            ))
            ax.text(x, y, f"c{row * 4 + col}", ha="center", va="center",
                    fontsize=7.5, fontweight="bold",
                    color=_PALETTE_ROOT_EDGE if is_root else _PALETTE_TEXT)

    # Phase 1: every row reduces west → east into its rightmost cube.
    for row in range(4):
        for col in range(3):
            (xa, ya), (xb, yb) = centre[(row, col)], centre[(row, col + 1)]
            _arrow(ax, (xa + half, ya), (xb - half, yb),
                   color=_PALETTE_BLUE, lw=1.5)
    # Phase 2: the rightmost column reduces north → south into the root.
    for row in range(3):
        (xa, ya), (xb, yb) = centre[(row, 3)], centre[(row + 1, 3)]
        _arrow(ax, (xa, ya - half), (xb, yb + half),
               color=_PALETTE_GREEN, lw=1.7)

    # Phase legend on the right side: (y, text, colour, font size, style).
    legend_x = origin_x + grid_extent + 0.55
    bold = {"fontweight": "bold"}
    italic = {"style": "italic"}
    legend_lines = [
        (5.0, "Phase 1: row reduce (W → E)", _PALETTE_BLUE, 10, bold),
        (4.55, "Phase 2: col reduce (N → S, rightmost col)",
         _PALETTE_GREEN, 10, bold),
        (4.10, "Phase 3: inter-SIP exchange at root cube",
         _PALETTE_ROOT_EDGE, 10, bold),
        (3.65, "Phase 4: col broadcast (S → N)", _PALETTE_GREEN, 10, italic),
        (3.20, "Phase 5: row broadcast (E → W)", _PALETTE_BLUE, 10, italic),
        (2.55, "(broadcast phases reverse phases 2 & 1)", "gray", 8.5, italic),
        (1.7, "Root cube (c15, bottom-right) is the only\n"
              "cube that performs the inter-SIP exchange.",
         _PALETTE_ROOT_EDGE, 9, italic),
    ]
    for y, text, colour, size, extra in legend_lines:
        ax.text(legend_x, y, text, color=colour, fontsize=size, **extra)
def emit_topology_diagram() -> str:
    """Emit a 2×2-panel topology diagram into docs/diagrams/allreduce_latency_plots/.

    Layout::

        top:    ring_1d               | torus_2d (2×3)
        bottom: mesh_2d_no_wrap (2×3) | cube-level reduction in SIP 0

    Returns:
        Path of the written ``topology.png``, as a string.
    """
    import matplotlib.gridspec as gridspec
    import matplotlib.pyplot as plt

    _SWEEP_OUT_DIR.mkdir(parents=True, exist_ok=True)
    fig = plt.figure(figsize=(16, 10), facecolor="white")
    grid = gridspec.GridSpec(2, 2, figure=fig, hspace=0.30, wspace=0.10)
    ax_ring, ax_torus, ax_mesh, ax_cube = (
        fig.add_subplot(grid[r, c])
        for r, c in ((0, 0), (0, 1), (1, 0), (1, 1))
    )
    _draw_ring_topology(ax_ring)
    _draw_grid_topology(ax_torus, "torus", n_rows=2, n_cols=3)
    _draw_grid_topology(ax_mesh, "mesh", n_rows=2, n_cols=3)
    _draw_cube_reduction(ax_cube)
    fig.suptitle(
        "Allreduce topology — device-level (top: ring, torus, mesh) "
        "and cube-level reduction in SIP 0",
        fontsize=14, fontweight="bold", color=_PALETTE_TEXT, y=0.98,
    )
    target = _SWEEP_OUT_DIR / "topology.png"
    fig.savefig(target, dpi=130, bbox_inches="tight",
                facecolor=fig.get_facecolor())
    plt.close(fig)
    return str(target)
# ── Comparison vs FSIM (broken-y-axis) ────────────────────────────────
#
# Post-processes summary.csv: today's three model curves + a hand-derived
# theoretical torus_2d line in the bottom panel, and a single external FSIM
# single-device reference marker in the top panel (hardcoded 366 µs; no
# external data file). Reads summary.csv written by _aggregate_sweep_plots.
_FSIM_EXT_LABEL = "FSIM (single device): 366 µs"
_FSIM_EXT_LATENCY_NS = 366_000.0  # 366 µs expressed in ns
# Line colours and legend labels keyed by the summary.csv sip_topology value.
_CMP_COLORS = {
    "ring_1d": "tab:blue",
    "torus_2d": "tab:orange",
    "mesh_2d_no_wrap": "tab:green",
}
_CMP_DISPLAY = {
    "ring_1d": "Ring 1x6 (6 devices)",
    "torus_2d": "2D Torus 2x3 (6 devices)",
    "mesh_2d_no_wrap": "2D Mesh 2x3 (6 devices)",
}
# Hand-derived theoretical model for torus_2d (6 SIPs):
#   latency(b) = T_STARTUP + (ceil(b * PES_PER_CUBE / PACKET_BYTES) - 1) * TAU
# where b is bytes per PE. TAU is chosen so the 6144-packet point
# (96 KB/PE * 8 PEs / 128 B, i.e. the largest sweep size) lands on 8741 ns.
# NOTE(review): 1346 / 8741 ns are presumably the simulated torus latencies
# at the 1-packet and 6144-packet sweep points — confirm against summary.csv.
_CMP_NOC_PACKET_BYTES = 128
_CMP_PES_PER_CUBE = 8
_CMP_T_STARTUP_NS = 1346.0
_CMP_TAU_NS = (8741.0 - 1346.0) / (6144 - 1)
def emit_comparison_fsim_plot() -> str | None:
    """Render comparison_mesh_vs_ring_vs_2DTorus_vs_theoretical_vs_fsim.png.

    Reads ``summary.csv`` (written by ``_aggregate_sweep_plots``) and draws
    a broken-y-axis figure: the bottom panel (linear y) holds one latency
    curve per topology plus the theoretical torus_2d model; the top panel
    holds the external FSIM single-device reference marker, placed at the
    largest bytes-per-PE value in the data.

    Returns:
        The output path, or ``None`` if summary.csv is absent / empty.
    """
    import csv

    csv_path = _SWEEP_OUT_DIR / "summary.csv"
    if not csv_path.exists():
        return None
    with open(csv_path, newline="", encoding="utf-8") as f:
        records = [
            {
                "sip_topology": row["sip_topology"],
                "bytes_per_pe": int(row["bytes_per_pe"]),
                "latency_ns": float(row["latency_ns"]),
            }
            for row in csv.DictReader(f)
        ]
    if not records:
        return None

    import matplotlib.pyplot as plt
    import matplotlib.ticker as mticker

    def _theoretical_torus_2d_ns(bytes_per_pe: int) -> float:
        """Startup + per-packet model over ceil(bytes-per-cube / packet)."""
        bytes_per_cube = int(bytes_per_pe) * _CMP_PES_PER_CUBE
        # -(-a // b) is ceiling division on ints.
        n_packets = max(1, -(-bytes_per_cube // _CMP_NOC_PACKET_BYTES))
        return _CMP_T_STARTUP_NS + (n_packets - 1) * _CMP_TAU_NS

    def _bytes_fmt(x, _pos):
        """Format an axis tick value (bytes) compactly: 512, 4K, 1M."""
        if x >= 1024 * 1024:
            return f"{x / (1024 * 1024):.0f}M"
        if x >= 1024:
            return f"{x / 1024:.0f}K"
        return f"{int(x)}"

    topologies = sorted({r["sip_topology"] for r in records})
    ext_x = max(r["bytes_per_pe"] for r in records)
    torus_rs = sorted(
        (r for r in records if r["sip_topology"] == "torus_2d"),
        key=lambda r: r["bytes_per_pe"],
    )
    theory_xs = [r["bytes_per_pe"] for r in torus_rs]
    theory_ys = [_theoretical_torus_2d_ns(x) for x in theory_xs]
    # The bottom panel must fit both measured and theoretical points;
    # previously only measured latencies set the limit, so model points
    # above them were clipped off the plot.
    y_max = max([r["latency_ns"] for r in records] + theory_ys)

    fig, (ax_top, ax_bot) = plt.subplots(
        2, 1, sharex=True,
        gridspec_kw={"height_ratios": [1, 4], "hspace": 0.05},
        figsize=(9, 6.5),
    )
    try:
        # Bottom panel: model curves + theoretical torus, linear y.
        for topo in topologies:
            rs = sorted(
                (r for r in records if r["sip_topology"] == topo),
                key=lambda r: r["bytes_per_pe"],
            )
            ax_bot.plot(
                [r["bytes_per_pe"] for r in rs],
                [r["latency_ns"] for r in rs],
                marker="o", label=_CMP_DISPLAY.get(topo, topo),
                color=_CMP_COLORS.get(topo),
            )
        if torus_rs:
            ax_bot.plot(
                theory_xs, theory_ys,
                color="tab:red", linestyle="--", linewidth=1.6, marker="x",
                label="Theoretical 2D Torus 2x3",
            )
        # Guard against an all-zero dataset (identical lower/upper limits).
        ax_bot.set_ylim(0, y_max * 1.10 if y_max > 0 else 1.0)

        # Top panel: external FSIM single-device reference marker.
        ax_top.scatter(
            [ext_x], [_FSIM_EXT_LATENCY_NS],
            marker="*", s=240, color="tab:red", zorder=5,
            label=_FSIM_EXT_LABEL,
        )
        ax_top.set_ylim(_FSIM_EXT_LATENCY_NS * 0.93,
                        _FSIM_EXT_LATENCY_NS * 1.05)

        # Hide spine between panels; draw diagonal break ticks. The bottom
        # ticks are stretched 4x in y to match the 1:4 height ratio.
        ax_top.spines["bottom"].set_visible(False)
        ax_bot.spines["top"].set_visible(False)
        ax_top.tick_params(labeltop=False, bottom=False)
        ax_bot.xaxis.tick_bottom()
        d = 0.012
        kw = dict(transform=ax_top.transAxes, color="k", clip_on=False, lw=1)
        ax_top.plot((-d, +d), (-d, +d), **kw)
        ax_top.plot((1 - d, 1 + d), (-d, +d), **kw)
        kw.update(transform=ax_bot.transAxes)
        ax_bot.plot((-d, +d), (1 - d * 4, 1 + d * 4), **kw)
        ax_bot.plot((1 - d, 1 + d), (1 - d * 4, 1 + d * 4), **kw)

        ax_bot.set_xscale("log", base=2)
        ax_bot.set_xlabel("Bytes per PE (log scale)")
        ax_bot.set_ylabel("Time (ns)")
        ax_top.set_ylabel("Time (ns)")
        ax_bot.grid(True, alpha=0.3)
        ax_top.grid(True, alpha=0.3)
        ax_bot.xaxis.set_major_formatter(mticker.FuncFormatter(_bytes_fmt))
        # One combined legend (bottom panel) covering both panels' artists.
        handles_bot, labels_bot = ax_bot.get_legend_handles_labels()
        handles_top, labels_top = ax_top.get_legend_handles_labels()
        ax_bot.legend(handles_bot + handles_top, labels_bot + labels_top,
                      loc="upper left")
        fig.suptitle(
            "Multidevice allreduce (ring, Mesh, 2DTorus) vs FSIM latency"
        )
        fig.tight_layout()
        out = (_SWEEP_OUT_DIR
               / "comparison_mesh_vs_ring_vs_2DTorus_vs_theoretical_vs_fsim.png")
        fig.savefig(out, dpi=120)
    finally:
        # Release the figure even if rendering/saving fails.
        plt.close(fig)
    return str(out)