"""Rectangular (non-square) SIP-level 2D topology support.

Phase 1 regression target: today the 2D builtin topology functions in
``kernbench.ccl.topologies`` (``mesh_2d``, ``torus_2d``,
``mesh_2d_no_wrap``) hardcode ``side = sqrt(world_size)`` and raise
``ValueError`` for any non-square ``world_size``. This blocks running the
allreduce sweep at n_sips=6 on torus/mesh layouts.

Phase 2 will extend these functions to accept optional ``w, h`` kwargs so a
2×3 (or 3×2, etc.) layout works. Until then, every rectangular-layout and
validation test below is expected to FAIL. The square back-compat tests
(which pass no ``w``/``h``) pin existing behaviour and are expected to pass
both before and after Phase 2.

Layout convention used here (matches the existing square case):

    rank = row * w + col    for 0 <= row < h, 0 <= col < w

For w=2, h=3, world_size=6 the layout is:

            col=0  col=1
    row=0:    0      1
    row=1:    2      3
    row=2:    4      5
"""

from __future__ import annotations

import pytest

from kernbench.ccl.topologies import (
    mesh_2d,
    mesh_2d_no_wrap,
    torus_2d,
)


# Unit step (d_row, d_col) for each compass direction in the row-major layout.
_STEPS = {"N": (-1, 0), "S": (1, 0), "W": (0, -1), "E": (0, 1)}

# Non-square (w, h) shapes exercised by the exhaustive reference-model tests.
_RECT_SHAPES = [(2, 3), (3, 2)]


def _expected_neighbours(rank: int, w: int, h: int, *, wrap: bool) -> dict[str, int]:
    """Return the neighbour map the layout convention implies for *rank*.

    This is an independent reference model of ``rank = row * w + col``. It
    agrees with every hand-worked expectation in this module.

    Args:
        rank: Rank whose neighbours are wanted (0 <= rank < w * h).
        w: Grid width (number of columns).
        h: Grid height (number of rows).
        wrap: If True, neighbours wrap around every edge (torus). If False,
            directions that would leave the grid are omitted (open mesh).

    Returns:
        Mapping from direction ("N"/"S"/"W"/"E") to neighbour rank. On a
        torus, duplicate neighbours (e.g. N == S when h == 2) are kept under
        both keys, matching the square back-compat expectations below.
    """
    row, col = divmod(rank, w)
    out: dict[str, int] = {}
    for direction, (d_row, d_col) in _STEPS.items():
        r, c = row + d_row, col + d_col
        if wrap:
            r %= h
            c %= w
        elif not (0 <= r < h and 0 <= c < w):
            continue  # open mesh: no neighbour past the edge
        out[direction] = r * w + c
    return out


# ── mesh_2d_no_wrap (no wrap-around) ──────────────────────────────────


def test_mesh_2d_no_wrap_2x3_top_left():
    """rank 0 (top-left, no N, no W): only S and E."""
    nbrs = mesh_2d_no_wrap(rank=0, world_size=6, w=2, h=3)
    assert nbrs == {"S": 2, "E": 1}, nbrs


def test_mesh_2d_no_wrap_2x3_top_right():
    """rank 1 (top-right, no N, no E): only S and W."""
    nbrs = mesh_2d_no_wrap(rank=1, world_size=6, w=2, h=3)
    assert nbrs == {"S": 3, "W": 0}, nbrs


def test_mesh_2d_no_wrap_2x3_middle_left():
    """rank 2 (middle-left, no W): N, S, E."""
    nbrs = mesh_2d_no_wrap(rank=2, world_size=6, w=2, h=3)
    assert nbrs == {"N": 0, "S": 4, "E": 3}, nbrs


def test_mesh_2d_no_wrap_2x3_bottom_right():
    """rank 5 (bottom-right, no S, no E): only N and W."""
    nbrs = mesh_2d_no_wrap(rank=5, world_size=6, w=2, h=3)
    assert nbrs == {"N": 3, "W": 4}, nbrs


def test_mesh_2d_no_wrap_3x2_bottom_middle():
    """Transposed 3×2 layout (rows 0 1 2 / 3 4 5): rank 4 has N, W, E only."""
    nbrs = mesh_2d_no_wrap(rank=4, world_size=6, w=3, h=2)
    assert nbrs == {"N": 1, "W": 3, "E": 5}, nbrs


@pytest.mark.parametrize(("w", "h"), _RECT_SHAPES)
def test_mesh_2d_no_wrap_rect_all_ranks(w, h):
    """Every rank of each rectangular shape matches the reference model."""
    for rank in range(w * h):
        nbrs = mesh_2d_no_wrap(rank=rank, world_size=w * h, w=w, h=h)
        assert nbrs == _expected_neighbours(rank, w, h, wrap=False), (rank, nbrs)


# ── torus_2d (wrap-around on all four edges) ─────────────────────────


def test_torus_2d_2x3_top_left():
    """rank 0: N wraps to row 2 col 0 (rank 4); W wraps to col 1 (rank 1)."""
    nbrs = torus_2d(rank=0, world_size=6, w=2, h=3)
    assert nbrs == {"N": 4, "S": 2, "W": 1, "E": 1}, nbrs


def test_torus_2d_2x3_bottom_right():
    """rank 5: S wraps to row 0 (rank 1); E wraps to col 0 (rank 4)."""
    nbrs = torus_2d(rank=5, world_size=6, w=2, h=3)
    assert nbrs == {"N": 3, "S": 1, "W": 4, "E": 4}, nbrs


def test_torus_2d_3x2_top_left():
    """Transposed 3×2 layout: rank 0 wraps N to rank 3 and W to rank 2."""
    nbrs = torus_2d(rank=0, world_size=6, w=3, h=2)
    assert nbrs == {"N": 3, "S": 3, "W": 2, "E": 1}, nbrs


@pytest.mark.parametrize(("w", "h"), _RECT_SHAPES)
def test_torus_2d_rect_all_ranks(w, h):
    """Every rank of each rectangular shape matches the reference model."""
    for rank in range(w * h):
        nbrs = torus_2d(rank=rank, world_size=w * h, w=w, h=h)
        assert nbrs == _expected_neighbours(rank, w, h, wrap=True), (rank, nbrs)


# ── mesh_2d alias for torus_2d ───────────────────────────────────────


def test_mesh_2d_2x3_matches_torus_2d():
    """mesh_2d is currently a torus alias; behaviour must match torus_2d."""
    for rank in range(6):
        assert mesh_2d(rank=rank, world_size=6, w=2, h=3) == torus_2d(
            rank=rank, world_size=6, w=2, h=3
        )


# ── Back-compat: square layouts still work without w/h kwargs ────────


def test_square_back_compat_mesh_2d_no_wrap():
    """Calling without w, h should still work for square world_size."""
    nbrs = mesh_2d_no_wrap(rank=0, world_size=4)
    assert nbrs == {"S": 2, "E": 1}, nbrs


def test_square_back_compat_torus_2d():
    nbrs = torus_2d(rank=0, world_size=4)
    assert nbrs == {"N": 2, "S": 2, "W": 1, "E": 1}, nbrs


# ── Validation: w*h must match world_size ────────────────────────────


def test_rectangular_dims_must_match_world_size():
    """Phase 2 contract: explicit w, h must satisfy w*h == world_size."""
    with pytest.raises(ValueError):
        mesh_2d_no_wrap(rank=0, world_size=6, w=3, h=3)  # 9 != 6


@pytest.mark.parametrize("topology", [mesh_2d, mesh_2d_no_wrap, torus_2d])
def test_rectangular_dims_must_match_world_size_all_topologies(topology):
    """The w*h == world_size check applies to every 2D builtin, not just one."""
    with pytest.raises(ValueError):
        topology(rank=0, world_size=6, w=3, h=3)  # 9 != 6