Files
kernbench2/tests/test_sip_topology_rectangular.py
mukesh e9cc40f74d Rectangular SIP topology + 6-device allreduce sweep
mesh_2d, torus_2d, and mesh_2d_no_wrap accept optional w,h kwargs;
sqrt fall-back preserved for square layouts (back-compat tests
confirm 4-SIP and 9-SIP square configs still work). sfr_config
reads system.sips.w/h from spec and threads dims through to the
topology fn.

test_allreduce_multidevice CONFIGS switched from 4 SIPs (square)
to 6 SIPs: ring_1d_6sip, torus_2d_6sip_2x3, mesh_2d_no_wrap_6sip_2x3.
_write_temp_configs writes system.sips.w/h when supplied;
_sip_topo_dims reads them back. Latency sweep loop also moved to
6-SIP layouts. Linear-scale plot variants dropped -- only log-scale
*.png + summary.csv emitted. Plots in tests/allreduce_latency_plots
regenerated.

New tests/test_sip_topology_rectangular.py asserts neighbor
correctness for 2x3 layouts and back-compat for square fallback.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-27 15:13:14 -07:00

107 lines
3.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Rectangular (non-square) SIP-level 2D topology support.
Phase 1 regression target: today the 2D builtin topology functions in
``kernbench.ccl.topologies`` (``mesh_2d``, ``torus_2d``,
``mesh_2d_no_wrap``) hardcode ``side = sqrt(world_size)`` and raise
``ValueError`` for any non-square ``world_size``. This blocks running
the allreduce sweep at n_sips=6 on torus/mesh layouts.
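Phase 2 will extend these functions to accept optional ``w, h`` kwargs
so a 2×3 (or 3×2, etc.) layout works. Until then, every test below that
passes ``w``/``h`` is expected to FAIL; the square back-compat tests do
not pass ``w``/``h`` and should already pass with the sqrt-only code.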
Layout convention used here (matches the square case):
    rank = row * w + col    for 0 <= row < h, 0 <= col < w
For w=2, h=3, world_size=6 the layout is:
           col=0  col=1
    row=0:   0      1
    row=1:   2      3
    row=2:   4      5
"""
from __future__ import annotations
import pytest
from kernbench.ccl.topologies import (
mesh_2d,
mesh_2d_no_wrap,
torus_2d,
)
# ── mesh_2d_no_wrap (no wrap-around) ──────────────────────────────────
def test_mesh_2d_no_wrap_2x3_top_left():
    """Top-left corner (rank 0) of the 2x3 no-wrap mesh links only S and E.

    Rank 0 sits at row 0, col 0, so without wrap-around it has no N or W
    neighbor.
    """
    expected = {"S": 2, "E": 1}
    actual = mesh_2d_no_wrap(rank=0, world_size=6, w=2, h=3)
    assert actual == expected, actual
def test_mesh_2d_no_wrap_2x3_top_right():
    """Top-right corner (rank 1) of the 2x3 no-wrap mesh links only S and W.

    Rank 1 sits at row 0, col 1, so without wrap-around it has no N or E
    neighbor.
    """
    expected = {"S": 3, "W": 0}
    actual = mesh_2d_no_wrap(rank=1, world_size=6, w=2, h=3)
    assert actual == expected, actual
def test_mesh_2d_no_wrap_2x3_middle_left():
    """Middle-left edge (rank 2) of the 2x3 no-wrap mesh links N, S and E.

    Rank 2 sits at row 1, col 0: it has rows above and below, but nothing
    to the west.
    """
    expected = {"N": 0, "S": 4, "E": 3}
    actual = mesh_2d_no_wrap(rank=2, world_size=6, w=2, h=3)
    assert actual == expected, actual
def test_mesh_2d_no_wrap_2x3_bottom_right():
    """Bottom-right corner (rank 5) of the 2x3 no-wrap mesh links only N and W.

    Rank 5 sits at row 2, col 1, so without wrap-around it has no S or E
    neighbor.
    """
    expected = {"N": 3, "W": 4}
    actual = mesh_2d_no_wrap(rank=5, world_size=6, w=2, h=3)
    assert actual == expected, actual
# ── torus_2d (wrap-around on all four edges) ─────────────────────────
def test_torus_2d_2x3_top_left():
    """Rank 0 of the 2x3 torus wraps on its N and W edges.

    N wraps to row 2, col 0 (rank 4); W wraps to col 1 (rank 1). With
    only two columns, W and E land on the same rank.
    """
    expected = {"N": 4, "S": 2, "W": 1, "E": 1}
    actual = torus_2d(rank=0, world_size=6, w=2, h=3)
    assert actual == expected, actual
def test_torus_2d_2x3_bottom_right():
    """Rank 5 of the 2x3 torus wraps on its S and E edges.

    S wraps to row 0, col 1 (rank 1); E wraps to col 0 (rank 4). With
    only two columns, W and E land on the same rank.
    """
    expected = {"N": 3, "S": 1, "W": 4, "E": 4}
    actual = torus_2d(rank=5, world_size=6, w=2, h=3)
    assert actual == expected, actual
# ── mesh_2d alias for torus_2d ───────────────────────────────────────
def test_mesh_2d_2x3_matches_torus_2d():
    """mesh_2d and torus_2d must agree for every rank of the 2x3 layout.

    mesh_2d is currently a torus alias, so its neighbor map has to equal
    the one torus_2d returns for the same arguments.
    """
    layout = {"world_size": 6, "w": 2, "h": 3}
    for rank in range(6):
        # Call mesh_2d first, then torus_2d, mirroring the comparison order.
        from_mesh = mesh_2d(rank=rank, **layout)
        from_torus = torus_2d(rank=rank, **layout)
        assert from_mesh == from_torus
# ── Back-compat: square layouts still work without w/h kwargs ────────
def test_square_back_compat_mesh_2d_no_wrap():
    """Omitting w and h must keep working for a square world_size.

    With world_size=4 and no explicit dims, rank 0 is expected to see
    only its S (rank 2) and E (rank 1) neighbors.
    """
    expected = {"S": 2, "E": 1}
    actual = mesh_2d_no_wrap(rank=0, world_size=4)
    assert actual == expected, actual
def test_square_back_compat_torus_2d():
    """Omitting w and h must keep working for a square torus.

    With world_size=4 and no explicit dims, rank 0 is expected to reach
    rank 2 both N and S, and rank 1 both W and E.
    """
    expected = {"N": 2, "S": 2, "W": 1, "E": 1}
    actual = torus_2d(rank=0, world_size=4)
    assert actual == expected, actual
# ── Validation: w*h must match world_size ────────────────────────────
def test_rectangular_dims_must_match_world_size():
    """Explicit dims whose product differs from world_size must be rejected.

    Phase 2 contract: explicit w, h must satisfy w * h == world_size.
    Here 3 * 3 = 9 does not equal 6, so a ValueError is expected.
    """
    bad_dims = {"w": 3, "h": 3}  # 9 != 6
    with pytest.raises(ValueError):
        mesh_2d_no_wrap(rank=0, world_size=6, **bad_dims)