Rectangular SIP topology + 6-device allreduce sweep

mesh_2d, torus_2d, and mesh_2d_no_wrap accept optional w,h kwargs;
sqrt fall-back preserved for square layouts (back-compat tests
confirm 4-SIP and 9-SIP square configs still work). sfr_config
reads system.sips.w/h from spec and threads dims through to the
topology fn.

test_allreduce_multidevice CONFIGS switched from 4 SIPs (square)
to 6 SIPs: ring_1d_6sip, torus_2d_6sip_2x3, mesh_2d_no_wrap_6sip_2x3.
_write_temp_configs writes system.sips.w/h when supplied;
_sip_topo_dims reads them back. Latency sweep loop also moved to
6-SIP layouts. Linear-scale plot variants dropped -- only log-scale
*.png + summary.csv emitted. Plots in tests/allreduce_latency_plots
regenerated.

New tests/test_sip_topology_rectangular.py asserts neighbor
correctness for 2x3 layouts and back-compat for square fallback.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-04-27 15:13:14 -07:00
parent c1a5cf3a2a
commit e9cc40f74d
9 changed files with 362 additions and 143 deletions
Binary file not shown.

After

Width:  |  Height:  |  Size: 39 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 71 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 38 KiB

+34
View File
@@ -0,0 +1,34 @@
algorithm,sip_topology,n_sips,n_elem,bytes_per_pe,bytes_per_sip,latency_ns
intercube_allreduce,ring_1d,6,8,16,256,3073.1299999999937
intercube_allreduce,ring_1d,6,32,64,1024,3079.8799999999947
intercube_allreduce,ring_1d,6,64,128,2048,3088.879999999992
intercube_allreduce,ring_1d,6,128,256,4096,3106.8799999999865
intercube_allreduce,ring_1d,6,512,1024,16384,3225.8799999999865
intercube_allreduce,ring_1d,6,1024,2048,32768,3391.8799999999865
intercube_allreduce,ring_1d,6,2048,4096,65536,3723.8799999999865
intercube_allreduce,ring_1d,6,4096,8192,131072,4387.879999999965
intercube_allreduce,ring_1d,6,8192,16384,262144,5715.879999999957
intercube_allreduce,ring_1d,6,16384,32768,524288,8371.879999999932
intercube_allreduce,ring_1d,6,32768,65536,1048576,13683.879999999903
intercube_allreduce,torus_2d,6,8,16,256,2190.4799999999923
intercube_allreduce,torus_2d,6,32,64,1024,2196.479999999993
intercube_allreduce,torus_2d,6,64,128,2048,2204.4799999999905
intercube_allreduce,torus_2d,6,128,256,4096,2220.479999999985
intercube_allreduce,torus_2d,6,512,1024,16384,2325.479999999985
intercube_allreduce,torus_2d,6,1024,2048,32768,2471.479999999985
intercube_allreduce,torus_2d,6,2048,4096,65536,2763.479999999985
intercube_allreduce,torus_2d,6,4096,8192,131072,3347.4799999999777
intercube_allreduce,torus_2d,6,8192,16384,262144,4515.4799999999705
intercube_allreduce,torus_2d,6,16384,32768,524288,6851.479999999952
intercube_allreduce,torus_2d,6,32768,65536,1048576,11523.479999999923
intercube_allreduce,mesh_2d_no_wrap,6,8,16,256,3508.4249999999993
intercube_allreduce,mesh_2d_no_wrap,6,32,64,1024,3515.55
intercube_allreduce,mesh_2d_no_wrap,6,64,128,2048,3525.0499999999975
intercube_allreduce,mesh_2d_no_wrap,6,128,256,4096,3544.049999999992
intercube_allreduce,mesh_2d_no_wrap,6,512,1024,16384,3667.049999999992
intercube_allreduce,mesh_2d_no_wrap,6,1024,2048,32768,3837.049999999992
intercube_allreduce,mesh_2d_no_wrap,6,2048,4096,65536,4177.049999999992
intercube_allreduce,mesh_2d_no_wrap,6,4096,8192,131072,4857.049999999959
intercube_allreduce,mesh_2d_no_wrap,6,8192,16384,262144,6217.049999999945
intercube_allreduce,mesh_2d_no_wrap,6,16384,32768,524288,8937.049999999937
intercube_allreduce,mesh_2d_no_wrap,6,32768,65536,1048576,14377.049999999872
1 algorithm sip_topology n_sips n_elem bytes_per_pe bytes_per_sip latency_ns
2 intercube_allreduce ring_1d 6 8 16 256 3073.1299999999937
3 intercube_allreduce ring_1d 6 32 64 1024 3079.8799999999947
4 intercube_allreduce ring_1d 6 64 128 2048 3088.879999999992
5 intercube_allreduce ring_1d 6 128 256 4096 3106.8799999999865
6 intercube_allreduce ring_1d 6 512 1024 16384 3225.8799999999865
7 intercube_allreduce ring_1d 6 1024 2048 32768 3391.8799999999865
8 intercube_allreduce ring_1d 6 2048 4096 65536 3723.8799999999865
9 intercube_allreduce ring_1d 6 4096 8192 131072 4387.879999999965
10 intercube_allreduce ring_1d 6 8192 16384 262144 5715.879999999957
11 intercube_allreduce ring_1d 6 16384 32768 524288 8371.879999999932
12 intercube_allreduce ring_1d 6 32768 65536 1048576 13683.879999999903
13 intercube_allreduce torus_2d 6 8 16 256 2190.4799999999923
14 intercube_allreduce torus_2d 6 32 64 1024 2196.479999999993
15 intercube_allreduce torus_2d 6 64 128 2048 2204.4799999999905
16 intercube_allreduce torus_2d 6 128 256 4096 2220.479999999985
17 intercube_allreduce torus_2d 6 512 1024 16384 2325.479999999985
18 intercube_allreduce torus_2d 6 1024 2048 32768 2471.479999999985
19 intercube_allreduce torus_2d 6 2048 4096 65536 2763.479999999985
20 intercube_allreduce torus_2d 6 4096 8192 131072 3347.4799999999777
21 intercube_allreduce torus_2d 6 8192 16384 262144 4515.4799999999705
22 intercube_allreduce torus_2d 6 16384 32768 524288 6851.479999999952
23 intercube_allreduce torus_2d 6 32768 65536 1048576 11523.479999999923
24 intercube_allreduce mesh_2d_no_wrap 6 8 16 256 3508.4249999999993
25 intercube_allreduce mesh_2d_no_wrap 6 32 64 1024 3515.55
26 intercube_allreduce mesh_2d_no_wrap 6 64 128 2048 3525.0499999999975
27 intercube_allreduce mesh_2d_no_wrap 6 128 256 4096 3544.049999999992
28 intercube_allreduce mesh_2d_no_wrap 6 512 1024 16384 3667.049999999992
29 intercube_allreduce mesh_2d_no_wrap 6 1024 2048 32768 3837.049999999992
30 intercube_allreduce mesh_2d_no_wrap 6 2048 4096 65536 4177.049999999992
31 intercube_allreduce mesh_2d_no_wrap 6 4096 8192 131072 4857.049999999959
32 intercube_allreduce mesh_2d_no_wrap 6 8192 16384 262144 6217.049999999945
33 intercube_allreduce mesh_2d_no_wrap 6 16384 32768 524288 8937.049999999937
34 intercube_allreduce mesh_2d_no_wrap 6 32768 65536 1048576 14377.049999999872
Binary file not shown.

After

Width:  |  Height:  |  Size: 37 KiB

+78 -70
View File
@@ -22,13 +22,23 @@ from kernbench.ccl.sfr_config import configure_sfr_intercube_multisip
from kernbench.policy.placement.dp import DPPolicy
def _sip_topo_dims(sip_topo: str, n_sips: int) -> tuple[int, int]:
def _sip_topo_dims(
sip_topo: str, n_sips: int,
spec_w: int | None = None, spec_h: int | None = None,
) -> tuple[int, int]:
if sip_topo == "ring_1d":
return (0, 0)
if spec_w is not None and spec_h is not None:
if spec_w * spec_h != n_sips:
raise ValueError(
f"sip layout {spec_w}x{spec_h} != n_sips ({n_sips})"
)
return (spec_w, spec_h)
side = int(round(math.sqrt(n_sips)))
if side * side != n_sips:
raise ValueError(
f"SIP topology '{sip_topo}' requires square n_sips, got {n_sips}"
f"SIP topology '{sip_topo}' requires square n_sips or "
f"explicit w/h in spec, got {n_sips}"
)
return (side, side)
@@ -54,10 +64,13 @@ def run_allreduce(
topo_name_to_kind = algo_module.TOPO_NAME_TO_KIND
n_elem = int(cfg.get("n_elem", 8))
n_sips = int(spec.get("system", {}).get("sips", {}).get("count", 1))
sip_topo = str(
spec.get("system", {}).get("sips", {}).get("topology", "ring_1d")
)
sips_cfg = spec.get("system", {}).get("sips", {})
n_sips = int(sips_cfg.get("count", 1))
sip_topo = str(sips_cfg.get("topology", "ring_1d"))
spec_sip_w = sips_cfg.get("w")
spec_sip_h = sips_cfg.get("h")
spec_sip_w = int(spec_sip_w) if spec_sip_w is not None else None
spec_sip_h = int(spec_sip_h) if spec_sip_h is not None else None
cm = spec["sip"]["cube_mesh"]
cube_w = int(cm["w"])
@@ -65,7 +78,9 @@ def run_allreduce(
n_cubes = cube_w * cube_h
sip_topo_kind = topo_name_to_kind.get(sip_topo, 0)
sip_topo_w, sip_topo_h = _sip_topo_dims(sip_topo, n_sips)
sip_topo_w, sip_topo_h = _sip_topo_dims(
sip_topo, n_sips, spec_w=spec_sip_w, spec_h=spec_sip_h,
)
algo_name = cfg.get("algorithm", "allreduce")
print(f"\n{'=' * 60}")
@@ -173,20 +188,36 @@ from kernbench.topology.builder import resolve_topology
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
CONFIGS = [
pytest.param("intercube_allreduce", "ring_1d", 2, id="ring_2sip"),
pytest.param("intercube_allreduce", "torus_2d", 4, id="torus_4sip"),
pytest.param("intercube_allreduce", "mesh_2d_no_wrap", 4, id="mesh_4sip"),
pytest.param(
"intercube_allreduce", "ring_1d", 6, None, None,
id="ring_6sip",
),
pytest.param(
"intercube_allreduce", "torus_2d", 6, 2, 3,
id="torus_6sip_2x3",
),
pytest.param(
"intercube_allreduce", "mesh_2d_no_wrap", 6, 2, 3,
id="mesh_6sip_2x3",
),
]
def _write_temp_configs(
tmp_path, sip_topology, n_sips, algorithm, n_elem_override=None,
sip_w=None, sip_h=None,
):
"""Write temp topology.yaml and ccl.yaml with the given overrides."""
with open(TOPOLOGY_PATH) as f:
topo_cfg = yaml.safe_load(f)
topo_cfg["system"]["sips"]["count"] = n_sips
topo_cfg["system"]["sips"]["topology"] = sip_topology
if sip_w is not None and sip_h is not None:
topo_cfg["system"]["sips"]["w"] = int(sip_w)
topo_cfg["system"]["sips"]["h"] = int(sip_h)
else:
topo_cfg["system"]["sips"].pop("w", None)
topo_cfg["system"]["sips"].pop("h", None)
topo_path = tmp_path / "topology.yaml"
with open(topo_path, "w") as f:
yaml.dump(topo_cfg, f, default_flow_style=False)
@@ -211,10 +242,15 @@ def _write_temp_configs(
return str(topo_path), str(tmp_ccl)
@pytest.mark.parametrize("algorithm,sip_topology,n_sips", CONFIGS)
def test_allreduce(tmp_path, algorithm, sip_topology, n_sips):
@pytest.mark.parametrize(
"algorithm,sip_topology,n_sips,sip_w,sip_h", CONFIGS,
)
def test_allreduce(
tmp_path, algorithm, sip_topology, n_sips, sip_w, sip_h,
):
topo_path, ccl_path = _write_temp_configs(
tmp_path, sip_topology, n_sips, algorithm,
sip_w=sip_w, sip_h=sip_h,
)
topo = resolve_topology(topo_path)
engine = GraphEngine(topo.topology_obj, enable_data=True)
@@ -271,16 +307,17 @@ def test_allreduce_latency_sweep(tmp_path):
records: list[dict] = []
# Apples-to-apples: same n_sips across all three topologies.
for algorithm, sip_topology, n_sips in [
("intercube_allreduce", "ring_1d", 4),
("intercube_allreduce", "torus_2d", 4),
("intercube_allreduce", "mesh_2d_no_wrap", 4),
for algorithm, sip_topology, n_sips, sip_w, sip_h in [
("intercube_allreduce", "ring_1d", 6, None, None),
("intercube_allreduce", "torus_2d", 6, 2, 3),
("intercube_allreduce", "mesh_2d_no_wrap", 6, 2, 3),
]:
for n_elem in _SWEEP_N_ELEM:
sub = tmp_path / f"{sip_topology}_{n_elem}"
sub.mkdir()
topo_path, ccl_path = _write_temp_configs(
sub, sip_topology, n_sips, algorithm,
sip_w=sip_w, sip_h=sip_h,
n_elem_override=n_elem,
)
topo = resolve_topology(topo_path)
@@ -339,8 +376,7 @@ def test_allreduce_latency_sweep(tmp_path):
w.writerow(r)
topologies = sorted({r["sip_topology"] for r in records})
# Per-topology plots: log-scale + linear-scale side-by-side.
# X-axis = bytes per PE (per-message payload size).
# Per-topology plots, log-scale x-axis = bytes per PE.
for topo_name in topologies:
rs = sorted(
[r for r in records if r["sip_topology"] == topo_name],
@@ -352,7 +388,6 @@ def test_allreduce_latency_sweep(tmp_path):
f"Allreduce latency — {topo_name} "
f"(n_sips={rs[0]['n_sips']})"
)
# Log-scale
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(xs, ys, marker="o", color="tab:blue")
ax.set_xscale("log", base=2)
@@ -364,58 +399,31 @@ def test_allreduce_latency_sweep(tmp_path):
fig.tight_layout()
fig.savefig(out_dir / f"{topo_name}.png", dpi=120)
plt.close(fig)
# Linear-scale companion
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(xs, ys, marker="o", color="tab:blue")
ax.set_xlabel("Bytes per PE")
ax.set_ylabel("max pe_exec_ns (critical path)")
ax.set_title(title + " [linear scale]")
ax.grid(True, alpha=0.3)
ax.xaxis.set_major_formatter(_bytes_fmt)
fig.tight_layout()
fig.savefig(out_dir / f"{topo_name}_linear.png", dpi=120)
plt.close(fig)
# Combined overview — two variants: log-scale (overview.png) and
# linear-scale (overview_linear.png).
colors = {"ring_1d": "tab:blue", "torus_2d": "tab:orange",
"mesh_2d_no_wrap": "tab:green"}
fig, ax = plt.subplots(figsize=(9, 6))
for topo_name in topologies:
rs = sorted(
[r for r in records if r["sip_topology"] == topo_name],
key=lambda r: r["bytes_per_pe"],
)
ax.plot(
[r["bytes_per_pe"] for r in rs],
[r["latency_ns"] for r in rs],
marker="o",
label=f"{topo_name} (n_sips={rs[0]['n_sips']})",
color=colors.get(topo_name),
)
ax.set_xscale("log", base=2)
ax.set_xlabel("Bytes per PE (log scale)")
ax.set_ylabel("max pe_exec_ns (critical path)")
ax.set_title("Multi-device allreduce latency by topology")
ax.grid(True, alpha=0.3)
ax.legend()
ax.xaxis.set_major_formatter(_bytes_fmt)
fig.tight_layout()
fig.savefig(out_dir / "overview.png", dpi=120)
plt.close(fig)
def _draw_overview(log_x: bool, filename: str, title_suffix: str) -> None:
fig, ax = plt.subplots(figsize=(9, 6))
for topo_name in topologies:
rs = sorted(
[r for r in records if r["sip_topology"] == topo_name],
key=lambda r: r["bytes_per_pe"],
)
ax.plot(
[r["bytes_per_pe"] for r in rs],
[r["latency_ns"] for r in rs],
marker="o",
label=f"{topo_name} (n_sips={rs[0]['n_sips']})",
color=colors.get(topo_name),
)
if log_x:
ax.set_xscale("log", base=2)
ax.set_xlabel("Bytes per PE (log scale)")
else:
ax.set_xlabel("Bytes per PE")
ax.set_ylabel("max pe_exec_ns (critical path)")
ax.set_title("Multi-device allreduce latency by topology" + title_suffix)
ax.grid(True, alpha=0.3)
ax.legend()
ax.xaxis.set_major_formatter(_bytes_fmt)
fig.tight_layout()
fig.savefig(out_dir / filename, dpi=120)
plt.close(fig)
_draw_overview(log_x=True, filename="overview.png", title_suffix="")
_draw_overview(
log_x=False, filename="overview_linear.png",
title_suffix=" [linear scale]",
)
print(
f"\nWrote {out_dir / 'overview.png'} + "
f"{out_dir / 'overview_linear.png'}"
)
print(f"\nWrote {out_dir / 'overview.png'}")
+106
View File
@@ -0,0 +1,106 @@
"""Rectangular (non-square) SIP-level 2D topology support.
Phase 1 regression target: today the 2D builtin topology functions in
``kernbench.ccl.topologies`` (``mesh_2d``, ``torus_2d``,
``mesh_2d_no_wrap``) hardcode ``side = sqrt(world_size)`` and raise
``ValueError`` for any non-square ``world_size``. This blocks running
the allreduce sweep at n_sips=6 on torus/mesh layouts.
Phase 2 will extend these functions to accept optional ``w, h`` kwargs
so a 2×3 (or 3×2, etc.) layout works. Until then, every test below is
expected to FAIL.
Layout convention used here (matches non-rectangular case):
rank = row * w + col for 0 <= row < h, 0 <= col < w
For w=2, h=3, world_size=6 the layout is:
col=0 col=1
row=0: 0 1
row=1: 2 3
row=2: 4 5
"""
from __future__ import annotations
import pytest
from kernbench.ccl.topologies import (
mesh_2d,
mesh_2d_no_wrap,
torus_2d,
)
# ── mesh_2d_no_wrap (no wrap-around) ──────────────────────────────────
def test_mesh_2d_no_wrap_2x3_top_left():
"""rank 0 (top-left, no N, no W): only S and E."""
nbrs = mesh_2d_no_wrap(rank=0, world_size=6, w=2, h=3)
assert nbrs == {"S": 2, "E": 1}, nbrs
def test_mesh_2d_no_wrap_2x3_top_right():
"""rank 1 (top-right, no N, no E): only S and W."""
nbrs = mesh_2d_no_wrap(rank=1, world_size=6, w=2, h=3)
assert nbrs == {"S": 3, "W": 0}, nbrs
def test_mesh_2d_no_wrap_2x3_middle_left():
"""rank 2 (middle-left, no W): N, S, E."""
nbrs = mesh_2d_no_wrap(rank=2, world_size=6, w=2, h=3)
assert nbrs == {"N": 0, "S": 4, "E": 3}, nbrs
def test_mesh_2d_no_wrap_2x3_bottom_right():
"""rank 5 (bottom-right, no S, no E): only N and W."""
nbrs = mesh_2d_no_wrap(rank=5, world_size=6, w=2, h=3)
assert nbrs == {"N": 3, "W": 4}, nbrs
# ── torus_2d (wrap-around on all four edges) ─────────────────────────
def test_torus_2d_2x3_top_left():
"""rank 0: N wraps to row 2 col 0 (rank 4); W wraps to col 1 (rank 1)."""
nbrs = torus_2d(rank=0, world_size=6, w=2, h=3)
assert nbrs == {"N": 4, "S": 2, "W": 1, "E": 1}, nbrs
def test_torus_2d_2x3_bottom_right():
"""rank 5: S wraps to row 0 (rank 1); E wraps to col 0 (rank 4)."""
nbrs = torus_2d(rank=5, world_size=6, w=2, h=3)
assert nbrs == {"N": 3, "S": 1, "W": 4, "E": 4}, nbrs
# ── mesh_2d alias for torus_2d ───────────────────────────────────────
def test_mesh_2d_2x3_matches_torus_2d():
"""mesh_2d is currently a torus alias; behaviour must match torus_2d."""
for rank in range(6):
assert mesh_2d(rank=rank, world_size=6, w=2, h=3) == \
torus_2d(rank=rank, world_size=6, w=2, h=3)
# ── Back-compat: square layouts still work without w/h kwargs ────────
def test_square_back_compat_mesh_2d_no_wrap():
"""Calling without w, h should still work for square world_size."""
nbrs = mesh_2d_no_wrap(rank=0, world_size=4)
assert nbrs == {"S": 2, "E": 1}, nbrs
def test_square_back_compat_torus_2d():
nbrs = torus_2d(rank=0, world_size=4)
assert nbrs == {"N": 2, "S": 2, "W": 1, "E": 1}, nbrs
# ── Validation: w*h must match world_size ────────────────────────────
def test_rectangular_dims_must_match_world_size():
"""Phase 2 contract: explicit w, h must satisfy w*h == world_size."""
with pytest.raises(ValueError):
mesh_2d_no_wrap(rank=0, world_size=6, w=3, h=3) # 9 != 6