Rectangular SIP topology + 6-device allreduce sweep
mesh_2d, torus_2d, and mesh_2d_no_wrap accept optional w,h kwargs; sqrt fall-back preserved for square layouts (back-compat tests confirm 4-SIP and 9-SIP square configs still work). sfr_config reads system.sips.w/h from spec and threads dims through to the topology fn. test_allreduce_multidevice CONFIGS switched from 4 SIPs (square) to 6 SIPs: ring_1d_6sip, torus_2d_6sip_2x3, mesh_2d_no_wrap_6sip_2x3. _write_temp_configs writes system.sips.w/h when supplied; _sip_topo_dims reads them back. Latency sweep loop also moved to 6-SIP layouts. Linear-scale plot variants dropped -- only log-scale *.png + summary.csv emitted. Plots in tests/allreduce_latency_plots regenerated. New tests/test_sip_topology_rectangular.py asserts neighbor correctness for 2x3 layouts and back-compat for square fallback. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,106 @@
|
||||
"""Rectangular (non-square) SIP-level 2D topology support.
|
||||
|
||||
Phase 1 regression target: before rectangular support, the 2D builtin
topology functions in ``kernbench.ccl.topologies`` (``mesh_2d``,
``torus_2d``, ``mesh_2d_no_wrap``) hardcoded ``side = sqrt(world_size)``
and raised ``ValueError`` for any non-square ``world_size``. This blocked
running the allreduce sweep at n_sips=6 on torus/mesh layouts.

Phase 2 extends these functions to accept optional ``w, h`` kwargs so a
2×3 (or 3×2, etc.) layout works. Without Phase 2, every test below that
passes ``w``/``h`` is expected to FAIL; the square back-compat tests do
not pass ``w``/``h`` and do not depend on it.
|
||||
|
||||
Layout convention used here (matches the square case):

    rank = row * w + col        for 0 <= row < h, 0 <= col < w
|
||||
|
||||
For w=2, h=3, world_size=6 the layout is:

            col=0  col=1
    row=0:    0      1
    row=1:    2      3
    row=2:    4      5
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from kernbench.ccl.topologies import (
|
||||
mesh_2d,
|
||||
mesh_2d_no_wrap,
|
||||
torus_2d,
|
||||
)
|
||||
|
||||
|
||||
# ── mesh_2d_no_wrap (no wrap-around) ──────────────────────────────────
|
||||
|
||||
|
||||
def test_mesh_2d_no_wrap_2x3_top_left():
    """Top-left corner (rank 0) has no N or W neighbour: only S and E."""
    expected = {"S": 2, "E": 1}
    observed = mesh_2d_no_wrap(rank=0, world_size=6, w=2, h=3)
    assert observed == expected, observed
|
||||
|
||||
|
||||
def test_mesh_2d_no_wrap_2x3_top_right():
    """Top-right corner (rank 1) has no N or E neighbour: only S and W."""
    expected = {"S": 3, "W": 0}
    observed = mesh_2d_no_wrap(rank=1, world_size=6, w=2, h=3)
    assert observed == expected, observed
|
||||
|
||||
|
||||
def test_mesh_2d_no_wrap_2x3_middle_left():
    """Middle-left cell (rank 2) has no W neighbour: N, S and E remain."""
    expected = {"N": 0, "S": 4, "E": 3}
    observed = mesh_2d_no_wrap(rank=2, world_size=6, w=2, h=3)
    assert observed == expected, observed
|
||||
|
||||
|
||||
def test_mesh_2d_no_wrap_2x3_bottom_right():
    """Bottom-right corner (rank 5) has no S or E neighbour: only N and W."""
    expected = {"N": 3, "W": 4}
    observed = mesh_2d_no_wrap(rank=5, world_size=6, w=2, h=3)
    assert observed == expected, observed
|
||||
|
||||
|
||||
# ── torus_2d (wrap-around on all four edges) ─────────────────────────
|
||||
|
||||
|
||||
def test_torus_2d_2x3_top_left():
    """Rank 0 on a 2x3 torus: N wraps to row 2 (rank 4), W wraps to col 1 (rank 1)."""
    expected = {"N": 4, "S": 2, "W": 1, "E": 1}
    observed = torus_2d(rank=0, world_size=6, w=2, h=3)
    assert observed == expected, observed
|
||||
|
||||
|
||||
def test_torus_2d_2x3_bottom_right():
    """Rank 5 on a 2x3 torus: S wraps to row 0 (rank 1), E wraps to col 0 (rank 4)."""
    expected = {"N": 3, "S": 1, "W": 4, "E": 4}
    observed = torus_2d(rank=5, world_size=6, w=2, h=3)
    assert observed == expected, observed
|
||||
|
||||
|
||||
# ── mesh_2d alias for torus_2d ───────────────────────────────────────
|
||||
|
||||
|
||||
def test_mesh_2d_2x3_matches_torus_2d():
    """For every rank of the 2x3 layout, mesh_2d must return what torus_2d returns."""
    for r in range(6):
        # Evaluate mesh_2d first, then torus_2d, with identical arguments.
        via_mesh = mesh_2d(rank=r, world_size=6, w=2, h=3)
        via_torus = torus_2d(rank=r, world_size=6, w=2, h=3)
        assert via_mesh == via_torus
|
||||
|
||||
|
||||
# ── Back-compat: square layouts still work without w/h kwargs ────────
|
||||
|
||||
|
||||
def test_square_back_compat_mesh_2d_no_wrap():
    """Omitting w and h must keep working when world_size is square (4)."""
    expected = {"S": 2, "E": 1}
    observed = mesh_2d_no_wrap(rank=0, world_size=4)
    assert observed == expected, observed
|
||||
|
||||
|
||||
def test_square_back_compat_torus_2d():
    """Omitting w and h on a square world_size (4): rank 0 sees N=2, S=2, W=1, E=1."""
    expected = {"N": 2, "S": 2, "W": 1, "E": 1}
    observed = torus_2d(rank=0, world_size=4)
    assert observed == expected, observed
|
||||
|
||||
|
||||
# ── Validation: w*h must match world_size ────────────────────────────
|
||||
|
||||
|
||||
def test_rectangular_dims_must_match_world_size():
    """Phase 2 contract: explicit w and h whose product differs from world_size raise ValueError."""
    mismatched_dims = {"w": 3, "h": 3}  # 3 * 3 == 9, but world_size is 6
    with pytest.raises(ValueError):
        mesh_2d_no_wrap(rank=0, world_size=6, **mismatched_dims)
|
||||
Reference in New Issue
Block a user