Rectangular SIP topology + 6-device allreduce sweep

mesh_2d, torus_2d, and mesh_2d_no_wrap accept optional w,h kwargs;
sqrt fall-back preserved for square layouts (back-compat tests
confirm 4-SIP and 9-SIP square configs still work). sfr_config
reads system.sips.w/h from spec and threads dims through to the
topology fn.

test_allreduce_multidevice CONFIGS switched from 4 SIPs (square)
to 6 SIPs: ring_1d_6sip, torus_2d_6sip_2x3, mesh_2d_no_wrap_6sip_2x3.
_write_temp_configs writes system.sips.w/h when supplied;
_sip_topo_dims reads them back. Latency sweep loop also moved to
6-SIP layouts. Linear-scale plot variants dropped -- only log-scale
*.png + summary.csv emitted. Plots in tests/allreduce_latency_plots
regenerated.

New tests/test_sip_topology_rectangular.py asserts neighbor
correctness for 2x3 layouts and back-compat for square fallback.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-04-27 15:13:14 -07:00
parent c1a5cf3a2a
commit e9cc40f74d
9 changed files with 362 additions and 143 deletions
+49 -37
View File
@@ -33,23 +33,41 @@ def ring_1d_unidir(rank: int, world_size: int) -> NeighborMap:
return {"E": (rank + 1) % world_size}
def mesh_2d(rank: int, world_size: int) -> NeighborMap:
"""Square 2D mesh (N/S/E/W).
Layout: rank = row * side + col, with side = sqrt(world_size).
Wrap-around (torus) on all four edges.
"""
def _resolve_2d_dims(
world_size: int, w: int | None, h: int | None, name: str,
) -> tuple[int, int]:
if w is not None and h is not None:
if w * h != world_size:
raise ValueError(
f"{name}: w*h ({w}*{h}) != world_size ({world_size})"
)
return w, h
side = int(round(world_size ** 0.5))
if side * side != world_size:
raise ValueError(
f"mesh_2d requires square world_size, got {world_size}"
f"{name} requires square world_size or explicit w,h, "
f"got {world_size}"
)
r, c = divmod(rank, side)
return side, side
def mesh_2d(
rank: int, world_size: int,
w: int | None = None, h: int | None = None,
) -> NeighborMap:
"""2D mesh (N/S/E/W) with wrap-around on all four edges.
Layout: rank = row * w + col. When w, h are given, supports
rectangular (e.g. 2x3) layouts. Otherwise falls back to square
side = sqrt(world_size).
"""
w, h = _resolve_2d_dims(world_size, w, h, "mesh_2d")
r, c = divmod(rank, w)
return {
"N": ((r - 1) % side) * side + c,
"S": ((r + 1) % side) * side + c,
"W": r * side + (c - 1) % side,
"E": r * side + (c + 1) % side,
"N": ((r - 1) % h) * w + c,
"S": ((r + 1) % h) * w + c,
"W": r * w + (c - 1) % w,
"E": r * w + (c + 1) % w,
}
@@ -73,36 +91,30 @@ def tree_binary(rank: int, world_size: int) -> NeighborMap:
return n
def torus_2d(rank: int, world_size: int) -> NeighborMap:
"""Square 2D torus (N/S/E/W) with wrap-around on all edges.
Alias for mesh_2d (which already wraps). Explicit name for clarity
when used as a SIP-level topology.
"""
return mesh_2d(rank, world_size)
def torus_2d(
rank: int, world_size: int,
w: int | None = None, h: int | None = None,
) -> NeighborMap:
"""2D torus (N/S/E/W) with wrap-around on all edges. Alias for mesh_2d."""
return mesh_2d(rank, world_size, w=w, h=h)
def mesh_2d_no_wrap(rank: int, world_size: int) -> NeighborMap:
"""Square 2D mesh (N/S/E/W) WITHOUT wrap-around.
Edge nodes have fewer neighbors (no wrapping). Used for SIP-level
topologies where physical links don't wrap.
"""
side = int(round(world_size ** 0.5))
if side * side != world_size:
raise ValueError(
f"mesh_2d_no_wrap requires square world_size, got {world_size}"
)
r, c = divmod(rank, side)
def mesh_2d_no_wrap(
rank: int, world_size: int,
w: int | None = None, h: int | None = None,
) -> NeighborMap:
"""2D mesh (N/S/E/W) WITHOUT wrap-around. Supports rectangular dims."""
w, h = _resolve_2d_dims(world_size, w, h, "mesh_2d_no_wrap")
r, c = divmod(rank, w)
n: NeighborMap = {}
if r > 0:
n["N"] = (r - 1) * side + c
if r < side - 1:
n["S"] = (r + 1) * side + c
n["N"] = (r - 1) * w + c
if r < h - 1:
n["S"] = (r + 1) * w + c
if c > 0:
n["W"] = r * side + (c - 1)
if c < side - 1:
n["E"] = r * side + (c + 1)
n["W"] = r * w + (c - 1)
if c < w - 1:
n["E"] = r * w + (c + 1)
return n