diff --git a/src/kernbench/ccl/sfr_config.py b/src/kernbench/ccl/sfr_config.py index 0bbe297..7c8516e 100644 --- a/src/kernbench/ccl/sfr_config.py +++ b/src/kernbench/ccl/sfr_config.py @@ -1,22 +1,24 @@ -"""SFR configuration for intercube + inter-SIP IPCQ wiring. +"""SFR configuration for the full IPCQ hardware wiring. -Provides ``configure_sfr_intercube_multisip`` which programs PE_IPCQ -neighbor tables for: +Installs PE_IPCQ neighbor tables modeling the physical hardware. +Wiring is independent of DPPolicy / kernel choice — the kernel decides +at runtime which links to use. - 1. Intercube within each SIP — pe0 of every cube connects to pe0 of - its N/S/E/W mesh neighbors (no wrap-around). - 2. Inter-SIP on ALL cubes — pe0 of cube_c on sip_A connects to pe0 of - cube_c on each peer SIP, using ``global_E``/``global_W`` (ring) or - ``global_N``/``global_S``/``global_E``/``global_W`` (mesh/torus) - direction labels. Wiring all cubes allows the kernel to - dynamically elect the root cube at runtime. +Direction label namespaces (disjoint): -SIP-level topology is read from ``topology.yaml`` → -``system.sips.topology`` (e.g. ``ring_1d``, ``mesh_2d``). -Intercube mesh dimensions come from ``sip.cube_mesh.w/h``. + - Intra-cube PE-to-PE: ``intra_N / intra_S / intra_E / intra_W`` + Logical 2×4 PE grid within a cube (no wrap): -Internally delegates to ``install_ipcq`` with a computed ``rank_to_pe`` -(pe0-only) and a closure-captured ``neighbors()`` function. + Row 0: pe0 pe1 pe2 pe3 + Row 1: pe4 pe5 pe6 pe7 + + - Intercube same-lane: ``N / S / E / W`` + ``pe_i of cube_A ↔ pe_i of cube_B`` across the 4×4 cube mesh + (no wrap). Every PE i ∈ [0..7] wired independently. + + - Inter-SIP same-(cube, pe): ``global_N / global_S / global_E / global_W`` + ``pe_i of cube_c on sip_A ↔ pe_i of cube_c on sip_B`` per + ``topology.yaml → system.sips.topology``. """ from __future__ import annotations @@ -27,12 +29,46 @@ from kernbench.ccl.install import install_ipcq from kernbench.ccl.topologies import _BUILTIN as _TOPO_BUILTINS +# ── Intra-cube 2×4 PE grid ─────────────────────────────────────────── + +_PE_GRID_COLS = 4 +_PE_GRID_ROWS = 2 +_PES_PER_CUBE = _PE_GRID_COLS * _PE_GRID_ROWS # 8 + + +def _intra_cube_neighbors(pe: int) -> dict[str, int]: + """Logical 2×4 PE grid neighbors within a cube (no wrap). + + Returns directions in the ``intra_*`` namespace. + """ + row, col = divmod(pe, _PE_GRID_COLS) + nbrs: dict[str, int] = {} + if col < _PE_GRID_COLS - 1: + nbrs["intra_E"] = row * _PE_GRID_COLS + (col + 1) + if col > 0: + nbrs["intra_W"] = row * _PE_GRID_COLS + (col - 1) + if row < _PE_GRID_ROWS - 1: + nbrs["intra_S"] = (row + 1) * _PE_GRID_COLS + col + if row > 0: + nbrs["intra_N"] = (row - 1) * _PE_GRID_COLS + col + return nbrs + + +# ── Public entry point ─────────────────────────────────────────────── + + def configure_sfr_intercube_multisip( engine: Any, spec: dict, cfg: dict, ) -> dict[str, Any]: - """Wire IPCQ for intercube (pe0, mesh) + inter-SIP (pe0, all cubes). + """Wire the full IPCQ hardware model. + + Every PE on every cube on every SIP gets neighbor table entries for: + + - intra-cube (2×4 grid) in the ``intra_*`` namespace + - intercube same-lane (4×4 cube mesh, no wrap) in ``N/S/E/W`` + - inter-SIP same-(cube, pe) in ``global_*`` Args: engine: GraphEngine with ``_components``. @@ -46,48 +82,71 @@ def configure_sfr_intercube_multisip( mesh_w = int(cm["w"]) mesh_h = int(cm["h"]) n_cubes = mesh_w * mesh_h - n_sips = int(spec.get("system", {}).get("sips", {}).get("count", 1)) - sip_topology = str( - spec.get("system", {}).get("sips", {}).get("topology", "ring_1d") - ) + sips_cfg = spec.get("system", {}).get("sips", {}) + n_sips = int(sips_cfg.get("count", 1)) + sip_topology = str(sips_cfg.get("topology", "ring_1d")) + sip_w = sips_cfg.get("w") + sip_h = sips_cfg.get("h") + sip_w = int(sip_w) if sip_w is not None else None + sip_h = int(sip_h) if sip_h is not None else None if sip_topology not in _TOPO_BUILTINS: raise ValueError( f"Unknown sip topology '{sip_topology}'. " f"Available: {list(_TOPO_BUILTINS)}" ) - sip_topo_fn = _TOPO_BUILTINS[sip_topology] + _sip_topo_fn_raw = _TOPO_BUILTINS[sip_topology] - world_size = n_sips * n_cubes + def sip_topo_fn(rank: int, ws: int) -> dict: + if sip_w is not None and sip_h is not None: + try: + return _sip_topo_fn_raw(rank, ws, w=sip_w, h=sip_h) + except TypeError: + pass + return _sip_topo_fn_raw(rank, ws) + + pes_per_cube = _PES_PER_CUBE + world_size = n_sips * n_cubes * pes_per_cube pe_idx_to_pe: list[tuple[int, int, int]] = [ - (sip, cube, 0) + (sip, cube, pe) for sip in range(n_sips) for cube in range(n_cubes) + for pe in range(pes_per_cube) ] + def _pe_idx(sip: int, cube: int, pe: int) -> int: + return (sip * n_cubes + cube) * pes_per_cube + pe + def _neighbors(pe_idx: int, ws: int, _base: dict) -> dict[str, int]: - sip = pe_idx // n_cubes - cube = pe_idx % n_cubes + tmp = pe_idx + pe = tmp % pes_per_cube + tmp //= pes_per_cube + cube = tmp % n_cubes + sip = tmp // n_cubes row = cube // mesh_w col = cube % mesh_w nbrs: dict[str, int] = {} - # Intercube within SIP (mesh, no wrap-around) - if col < mesh_w - 1: - nbrs["E"] = sip * n_cubes + (row * mesh_w + col + 1) - if col > 0: - nbrs["W"] = sip * n_cubes + (row * mesh_w + col - 1) - if row < mesh_h - 1: - nbrs["S"] = sip * n_cubes + ((row + 1) * mesh_w + col) - if row > 0: - nbrs["N"] = sip * n_cubes + ((row - 1) * mesh_w + col) + # ── Intra-cube (intra_N/S/E/W) ── + for d, peer_pe in _intra_cube_neighbors(pe).items(): + nbrs[d] = _pe_idx(sip, cube, peer_pe) - # Inter-SIP on ALL cubes + # ── Intercube same-lane (N/S/E/W, 4×4 no wrap) ── + if col < mesh_w - 1: + nbrs["E"] = _pe_idx(sip, row * mesh_w + (col + 1), pe) + if col > 0: + nbrs["W"] = _pe_idx(sip, row * mesh_w + (col - 1), pe) + if row < mesh_h - 1: + nbrs["S"] = _pe_idx(sip, (row + 1) * mesh_w + col, pe) + if row > 0: + nbrs["N"] = _pe_idx(sip, (row - 1) * mesh_w + col, pe) + + # ── Inter-SIP same-(cube, pe) (global_*) ── if n_sips > 1: sip_nbrs = sip_topo_fn(sip, n_sips) for d, peer_sip in sip_nbrs.items(): - nbrs[f"global_{d}"] = peer_sip * n_cubes + cube + nbrs[f"global_{d}"] = _pe_idx(peer_sip, cube, pe) return nbrs diff --git a/src/kernbench/ccl/topologies.py b/src/kernbench/ccl/topologies.py index ea46019..1be821f 100644 --- a/src/kernbench/ccl/topologies.py +++ b/src/kernbench/ccl/topologies.py @@ -33,23 +33,41 @@ def ring_1d_unidir(rank: int, world_size: int) -> NeighborMap: return {"E": (rank + 1) % world_size} -def mesh_2d(rank: int, world_size: int) -> NeighborMap: - """Square 2D mesh (N/S/E/W). - - Layout: rank = row * side + col, with side = sqrt(world_size). - Wrap-around (torus) on all four edges. - """ +def _resolve_2d_dims( + world_size: int, w: int | None, h: int | None, name: str, +) -> tuple[int, int]: + if w is not None and h is not None: + if w * h != world_size: + raise ValueError( + f"{name}: w*h ({w}*{h}) != world_size ({world_size})" + ) + return w, h side = int(round(world_size ** 0.5)) if side * side != world_size: raise ValueError( - f"mesh_2d requires square world_size, got {world_size}" + f"{name} requires square world_size or explicit w,h, " + f"got {world_size}" ) - r, c = divmod(rank, side) + return side, side + + +def mesh_2d( + rank: int, world_size: int, + w: int | None = None, h: int | None = None, +) -> NeighborMap: + """2D mesh (N/S/E/W) with wrap-around on all four edges. + + Layout: rank = row * w + col. When w, h are given, supports + rectangular (e.g. 2x3) layouts. Otherwise falls back to square + side = sqrt(world_size). + """ + w, h = _resolve_2d_dims(world_size, w, h, "mesh_2d") + r, c = divmod(rank, w) return { - "N": ((r - 1) % side) * side + c, - "S": ((r + 1) % side) * side + c, - "W": r * side + (c - 1) % side, - "E": r * side + (c + 1) % side, + "N": ((r - 1) % h) * w + c, + "S": ((r + 1) % h) * w + c, + "W": r * w + (c - 1) % w, + "E": r * w + (c + 1) % w, } @@ -73,36 +91,30 @@ def tree_binary(rank: int, world_size: int) -> NeighborMap: return n -def torus_2d(rank: int, world_size: int) -> NeighborMap: - """Square 2D torus (N/S/E/W) with wrap-around on all edges. - - Alias for mesh_2d (which already wraps). Explicit name for clarity - when used as a SIP-level topology. - """ - return mesh_2d(rank, world_size) +def torus_2d( + rank: int, world_size: int, + w: int | None = None, h: int | None = None, +) -> NeighborMap: + """2D torus (N/S/E/W) with wrap-around on all edges. Alias for mesh_2d.""" + return mesh_2d(rank, world_size, w=w, h=h) -def mesh_2d_no_wrap(rank: int, world_size: int) -> NeighborMap: - """Square 2D mesh (N/S/E/W) WITHOUT wrap-around. - - Edge nodes have fewer neighbors (no wrapping). Used for SIP-level - topologies where physical links don't wrap. - """ - side = int(round(world_size ** 0.5)) - if side * side != world_size: - raise ValueError( - f"mesh_2d_no_wrap requires square world_size, got {world_size}" - ) - r, c = divmod(rank, side) +def mesh_2d_no_wrap( + rank: int, world_size: int, + w: int | None = None, h: int | None = None, +) -> NeighborMap: + """2D mesh (N/S/E/W) WITHOUT wrap-around. Supports rectangular dims.""" + w, h = _resolve_2d_dims(world_size, w, h, "mesh_2d_no_wrap") + r, c = divmod(rank, w) n: NeighborMap = {} if r > 0: - n["N"] = (r - 1) * side + c - if r < side - 1: - n["S"] = (r + 1) * side + c + n["N"] = (r - 1) * w + c + if r < h - 1: + n["S"] = (r + 1) * w + c if c > 0: - n["W"] = r * side + (c - 1) - if c < side - 1: - n["E"] = r * side + (c + 1) + n["W"] = r * w + (c - 1) + if c < w - 1: + n["E"] = r * w + (c + 1) return n diff --git a/tests/allreduce_latency_plots/mesh_2d_no_wrap.png b/tests/allreduce_latency_plots/mesh_2d_no_wrap.png new file mode 100644 index 0000000..2d57582 Binary files /dev/null and b/tests/allreduce_latency_plots/mesh_2d_no_wrap.png differ diff --git a/tests/allreduce_latency_plots/overview.png b/tests/allreduce_latency_plots/overview.png new file mode 100644 index 0000000..12c58b9 Binary files /dev/null and b/tests/allreduce_latency_plots/overview.png differ diff --git a/tests/allreduce_latency_plots/ring_1d.png b/tests/allreduce_latency_plots/ring_1d.png new file mode 100644 index 0000000..ac87a1d Binary files /dev/null and b/tests/allreduce_latency_plots/ring_1d.png differ diff --git a/tests/allreduce_latency_plots/summary.csv b/tests/allreduce_latency_plots/summary.csv new file mode 100644 index 0000000..a58d290 --- /dev/null +++ b/tests/allreduce_latency_plots/summary.csv @@ -0,0 +1,34 @@ +algorithm,sip_topology,n_sips,n_elem,bytes_per_pe,bytes_per_sip,latency_ns +intercube_allreduce,ring_1d,6,8,16,256,3073.1299999999937 +intercube_allreduce,ring_1d,6,32,64,1024,3079.8799999999947 +intercube_allreduce,ring_1d,6,64,128,2048,3088.879999999992 +intercube_allreduce,ring_1d,6,128,256,4096,3106.8799999999865 +intercube_allreduce,ring_1d,6,512,1024,16384,3225.8799999999865 +intercube_allreduce,ring_1d,6,1024,2048,32768,3391.8799999999865 +intercube_allreduce,ring_1d,6,2048,4096,65536,3723.8799999999865 +intercube_allreduce,ring_1d,6,4096,8192,131072,4387.879999999965 +intercube_allreduce,ring_1d,6,8192,16384,262144,5715.879999999957 +intercube_allreduce,ring_1d,6,16384,32768,524288,8371.879999999932 +intercube_allreduce,ring_1d,6,32768,65536,1048576,13683.879999999903 +intercube_allreduce,torus_2d,6,8,16,256,2190.4799999999923 +intercube_allreduce,torus_2d,6,32,64,1024,2196.479999999993 +intercube_allreduce,torus_2d,6,64,128,2048,2204.4799999999905 +intercube_allreduce,torus_2d,6,128,256,4096,2220.479999999985 +intercube_allreduce,torus_2d,6,512,1024,16384,2325.479999999985 +intercube_allreduce,torus_2d,6,1024,2048,32768,2471.479999999985 +intercube_allreduce,torus_2d,6,2048,4096,65536,2763.479999999985 +intercube_allreduce,torus_2d,6,4096,8192,131072,3347.4799999999777 +intercube_allreduce,torus_2d,6,8192,16384,262144,4515.4799999999705 +intercube_allreduce,torus_2d,6,16384,32768,524288,6851.479999999952 +intercube_allreduce,torus_2d,6,32768,65536,1048576,11523.479999999923 +intercube_allreduce,mesh_2d_no_wrap,6,8,16,256,3508.4249999999993 +intercube_allreduce,mesh_2d_no_wrap,6,32,64,1024,3515.55 +intercube_allreduce,mesh_2d_no_wrap,6,64,128,2048,3525.0499999999975 +intercube_allreduce,mesh_2d_no_wrap,6,128,256,4096,3544.049999999992 +intercube_allreduce,mesh_2d_no_wrap,6,512,1024,16384,3667.049999999992 +intercube_allreduce,mesh_2d_no_wrap,6,1024,2048,32768,3837.049999999992 +intercube_allreduce,mesh_2d_no_wrap,6,2048,4096,65536,4177.049999999992 +intercube_allreduce,mesh_2d_no_wrap,6,4096,8192,131072,4857.049999999959 +intercube_allreduce,mesh_2d_no_wrap,6,8192,16384,262144,6217.049999999945 +intercube_allreduce,mesh_2d_no_wrap,6,16384,32768,524288,8937.049999999937 +intercube_allreduce,mesh_2d_no_wrap,6,32768,65536,1048576,14377.049999999872 diff --git a/tests/allreduce_latency_plots/torus_2d.png b/tests/allreduce_latency_plots/torus_2d.png new file mode 100644 index 0000000..5a8bf2d Binary files /dev/null and b/tests/allreduce_latency_plots/torus_2d.png differ diff --git a/tests/test_allreduce_multidevice.py b/tests/test_allreduce_multidevice.py index 347d61e..81e1093 100644 --- a/tests/test_allreduce_multidevice.py +++ b/tests/test_allreduce_multidevice.py @@ -22,13 +22,23 @@ from kernbench.ccl.sfr_config import configure_sfr_intercube_multisip from kernbench.policy.placement.dp import DPPolicy -def _sip_topo_dims(sip_topo: str, n_sips: int) -> tuple[int, int]: +def _sip_topo_dims( + sip_topo: str, n_sips: int, + spec_w: int | None = None, spec_h: int | None = None, +) -> tuple[int, int]: if sip_topo == "ring_1d": return (0, 0) + if spec_w is not None and spec_h is not None: + if spec_w * spec_h != n_sips: + raise ValueError( + f"sip layout {spec_w}x{spec_h} != n_sips ({n_sips})" + ) + return (spec_w, spec_h) side = int(round(math.sqrt(n_sips))) if side * side != n_sips: raise ValueError( - f"SIP topology '{sip_topo}' requires square n_sips, got {n_sips}" + f"SIP topology '{sip_topo}' requires square n_sips or " + f"explicit w/h in spec, got {n_sips}" ) return (side, side) @@ -54,10 +64,13 @@ def run_allreduce( topo_name_to_kind = algo_module.TOPO_NAME_TO_KIND n_elem = int(cfg.get("n_elem", 8)) - n_sips = int(spec.get("system", {}).get("sips", {}).get("count", 1)) - sip_topo = str( - spec.get("system", {}).get("sips", {}).get("topology", "ring_1d") - ) + sips_cfg = spec.get("system", {}).get("sips", {}) + n_sips = int(sips_cfg.get("count", 1)) + sip_topo = str(sips_cfg.get("topology", "ring_1d")) + spec_sip_w = sips_cfg.get("w") + spec_sip_h = sips_cfg.get("h") + spec_sip_w = int(spec_sip_w) if spec_sip_w is not None else None + spec_sip_h = int(spec_sip_h) if spec_sip_h is not None else None cm = spec["sip"]["cube_mesh"] cube_w = int(cm["w"]) @@ -65,7 +78,9 @@ def run_allreduce( n_cubes = cube_w * cube_h sip_topo_kind = topo_name_to_kind.get(sip_topo, 0) - sip_topo_w, sip_topo_h = _sip_topo_dims(sip_topo, n_sips) + sip_topo_w, sip_topo_h = _sip_topo_dims( + sip_topo, n_sips, spec_w=spec_sip_w, spec_h=spec_sip_h, + ) algo_name = cfg.get("algorithm", "allreduce") print(f"\n{'=' * 60}") @@ -173,20 +188,36 @@ from kernbench.topology.builder import resolve_topology TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml" CONFIGS = [ - pytest.param("intercube_allreduce", "ring_1d", 2, id="ring_2sip"), - pytest.param("intercube_allreduce", "torus_2d", 4, id="torus_4sip"), - pytest.param("intercube_allreduce", "mesh_2d_no_wrap", 4, id="mesh_4sip"), + pytest.param( + "intercube_allreduce", "ring_1d", 6, None, None, + id="ring_6sip", + ), + pytest.param( + "intercube_allreduce", "torus_2d", 6, 2, 3, + id="torus_6sip_2x3", + ), + pytest.param( + "intercube_allreduce", "mesh_2d_no_wrap", 6, 2, 3, + id="mesh_6sip_2x3", + ), ] def _write_temp_configs( tmp_path, sip_topology, n_sips, algorithm, n_elem_override=None, + sip_w=None, sip_h=None, ): """Write temp topology.yaml and ccl.yaml with the given overrides.""" with open(TOPOLOGY_PATH) as f: topo_cfg = yaml.safe_load(f) topo_cfg["system"]["sips"]["count"] = n_sips topo_cfg["system"]["sips"]["topology"] = sip_topology + if sip_w is not None and sip_h is not None: + topo_cfg["system"]["sips"]["w"] = int(sip_w) + topo_cfg["system"]["sips"]["h"] = int(sip_h) + else: + topo_cfg["system"]["sips"].pop("w", None) + topo_cfg["system"]["sips"].pop("h", None) topo_path = tmp_path / "topology.yaml" with open(topo_path, "w") as f: yaml.dump(topo_cfg, f, default_flow_style=False) @@ -211,10 +242,15 @@ def _write_temp_configs( return str(topo_path), str(tmp_ccl) -@pytest.mark.parametrize("algorithm,sip_topology,n_sips", CONFIGS) -def test_allreduce(tmp_path, algorithm, sip_topology, n_sips): +@pytest.mark.parametrize( + "algorithm,sip_topology,n_sips,sip_w,sip_h", CONFIGS, +) +def test_allreduce( + tmp_path, algorithm, sip_topology, n_sips, sip_w, sip_h, +): topo_path, ccl_path = _write_temp_configs( tmp_path, sip_topology, n_sips, algorithm, + sip_w=sip_w, sip_h=sip_h, ) topo = resolve_topology(topo_path) engine = GraphEngine(topo.topology_obj, enable_data=True) @@ -271,16 +307,17 @@ def test_allreduce_latency_sweep(tmp_path): records: list[dict] = [] # Apples-to-apples: same n_sips across all three topologies. - for algorithm, sip_topology, n_sips in [ - ("intercube_allreduce", "ring_1d", 4), - ("intercube_allreduce", "torus_2d", 4), - ("intercube_allreduce", "mesh_2d_no_wrap", 4), + for algorithm, sip_topology, n_sips, sip_w, sip_h in [ + ("intercube_allreduce", "ring_1d", 6, None, None), + ("intercube_allreduce", "torus_2d", 6, 2, 3), + ("intercube_allreduce", "mesh_2d_no_wrap", 6, 2, 3), ]: for n_elem in _SWEEP_N_ELEM: sub = tmp_path / f"{sip_topology}_{n_elem}" sub.mkdir() topo_path, ccl_path = _write_temp_configs( sub, sip_topology, n_sips, algorithm, + sip_w=sip_w, sip_h=sip_h, n_elem_override=n_elem, ) topo = resolve_topology(topo_path) @@ -339,8 +376,7 @@ def test_allreduce_latency_sweep(tmp_path): w.writerow(r) topologies = sorted({r["sip_topology"] for r in records}) - # Per-topology plots: log-scale + linear-scale side-by-side. - # X-axis = bytes per PE (per-message payload size). + # Per-topology plots, log-scale x-axis = bytes per PE. for topo_name in topologies: rs = sorted( [r for r in records if r["sip_topology"] == topo_name], @@ -352,7 +388,6 @@ def test_allreduce_latency_sweep(tmp_path): f"Allreduce latency — {topo_name} " f"(n_sips={rs[0]['n_sips']})" ) - # Log-scale fig, ax = plt.subplots(figsize=(8, 5)) ax.plot(xs, ys, marker="o", color="tab:blue") ax.set_xscale("log", base=2) @@ -364,58 +399,31 @@ def test_allreduce_latency_sweep(tmp_path): fig.tight_layout() fig.savefig(out_dir / f"{topo_name}.png", dpi=120) plt.close(fig) - # Linear-scale companion - fig, ax = plt.subplots(figsize=(8, 5)) - ax.plot(xs, ys, marker="o", color="tab:blue") - ax.set_xlabel("Bytes per PE") - ax.set_ylabel("max pe_exec_ns (critical path)") - ax.set_title(title + " [linear scale]") - ax.grid(True, alpha=0.3) - ax.xaxis.set_major_formatter(_bytes_fmt) - fig.tight_layout() - fig.savefig(out_dir / f"{topo_name}_linear.png", dpi=120) - plt.close(fig) - # Combined overview — two variants: log-scale (overview.png) and - # linear-scale (overview_linear.png). colors = {"ring_1d": "tab:blue", "torus_2d": "tab:orange", "mesh_2d_no_wrap": "tab:green"} + fig, ax = plt.subplots(figsize=(9, 6)) + for topo_name in topologies: + rs = sorted( + [r for r in records if r["sip_topology"] == topo_name], + key=lambda r: r["bytes_per_pe"], + ) + ax.plot( + [r["bytes_per_pe"] for r in rs], + [r["latency_ns"] for r in rs], + marker="o", + label=f"{topo_name} (n_sips={rs[0]['n_sips']})", + color=colors.get(topo_name), + ) + ax.set_xscale("log", base=2) + ax.set_xlabel("Bytes per PE (log scale)") + ax.set_ylabel("max pe_exec_ns (critical path)") + ax.set_title("Multi-device allreduce latency by topology") + ax.grid(True, alpha=0.3) + ax.legend() + ax.xaxis.set_major_formatter(_bytes_fmt) + fig.tight_layout() + fig.savefig(out_dir / "overview.png", dpi=120) + plt.close(fig) - def _draw_overview(log_x: bool, filename: str, title_suffix: str) -> None: - fig, ax = plt.subplots(figsize=(9, 6)) - for topo_name in topologies: - rs = sorted( - [r for r in records if r["sip_topology"] == topo_name], - key=lambda r: r["bytes_per_pe"], - ) - ax.plot( - [r["bytes_per_pe"] for r in rs], - [r["latency_ns"] for r in rs], - marker="o", - label=f"{topo_name} (n_sips={rs[0]['n_sips']})", - color=colors.get(topo_name), - ) - if log_x: - ax.set_xscale("log", base=2) - ax.set_xlabel("Bytes per PE (log scale)") - else: - ax.set_xlabel("Bytes per PE") - ax.set_ylabel("max pe_exec_ns (critical path)") - ax.set_title("Multi-device allreduce latency by topology" + title_suffix) - ax.grid(True, alpha=0.3) - ax.legend() - ax.xaxis.set_major_formatter(_bytes_fmt) - fig.tight_layout() - fig.savefig(out_dir / filename, dpi=120) - plt.close(fig) - - _draw_overview(log_x=True, filename="overview.png", title_suffix="") - _draw_overview( - log_x=False, filename="overview_linear.png", - title_suffix=" [linear scale]", - ) - - print( - f"\nWrote {out_dir / 'overview.png'} + " - f"{out_dir / 'overview_linear.png'}" - ) + print(f"\nWrote {out_dir / 'overview.png'}") diff --git a/tests/test_sip_topology_rectangular.py b/tests/test_sip_topology_rectangular.py new file mode 100644 index 0000000..2a0288a --- /dev/null +++ b/tests/test_sip_topology_rectangular.py @@ -0,0 +1,106 @@ +"""Rectangular (non-square) SIP-level 2D topology support. + +Phase 1 regression target: today the 2D builtin topology functions in +``kernbench.ccl.topologies`` (``mesh_2d``, ``torus_2d``, +``mesh_2d_no_wrap``) hardcode ``side = sqrt(world_size)`` and raise +``ValueError`` for any non-square ``world_size``. This blocks running +the allreduce sweep at n_sips=6 on torus/mesh layouts. + +Phase 2 will extend these functions to accept optional ``w, h`` kwargs +so a 2×3 (or 3×2, etc.) layout works. Until then, every test below is +expected to FAIL. + +Layout convention used here (matches non-rectangular case): + rank = row * w + col for 0 <= row < h, 0 <= col < w + +For w=2, h=3, world_size=6 the layout is: + + col=0 col=1 + row=0: 0 1 + row=1: 2 3 + row=2: 4 5 +""" +from __future__ import annotations + +import pytest + +from kernbench.ccl.topologies import ( + mesh_2d, + mesh_2d_no_wrap, + torus_2d, +) + + +# ── mesh_2d_no_wrap (no wrap-around) ────────────────────────────────── + + +def test_mesh_2d_no_wrap_2x3_top_left(): + """rank 0 (top-left, no N, no W): only S and E.""" + nbrs = mesh_2d_no_wrap(rank=0, world_size=6, w=2, h=3) + assert nbrs == {"S": 2, "E": 1}, nbrs + + +def test_mesh_2d_no_wrap_2x3_top_right(): + """rank 1 (top-right, no N, no E): only S and W.""" + nbrs = mesh_2d_no_wrap(rank=1, world_size=6, w=2, h=3) + assert nbrs == {"S": 3, "W": 0}, nbrs + + +def test_mesh_2d_no_wrap_2x3_middle_left(): + """rank 2 (middle-left, no W): N, S, E.""" + nbrs = mesh_2d_no_wrap(rank=2, world_size=6, w=2, h=3) + assert nbrs == {"N": 0, "S": 4, "E": 3}, nbrs + + +def test_mesh_2d_no_wrap_2x3_bottom_right(): + """rank 5 (bottom-right, no S, no E): only N and W.""" + nbrs = mesh_2d_no_wrap(rank=5, world_size=6, w=2, h=3) + assert nbrs == {"N": 3, "W": 4}, nbrs + + +# ── torus_2d (wrap-around on all four edges) ───────────────────────── + + +def test_torus_2d_2x3_top_left(): + """rank 0: N wraps to row 2 col 0 (rank 4); W wraps to col 1 (rank 1).""" + nbrs = torus_2d(rank=0, world_size=6, w=2, h=3) + assert nbrs == {"N": 4, "S": 2, "W": 1, "E": 1}, nbrs + + +def test_torus_2d_2x3_bottom_right(): + """rank 5: S wraps to row 0 (rank 1); E wraps to col 0 (rank 4).""" + nbrs = torus_2d(rank=5, world_size=6, w=2, h=3) + assert nbrs == {"N": 3, "S": 1, "W": 4, "E": 4}, nbrs + + +# ── mesh_2d alias for torus_2d ─────────────────────────────────────── + + +def test_mesh_2d_2x3_matches_torus_2d(): + """mesh_2d is currently a torus alias; behaviour must match torus_2d.""" + for rank in range(6): + assert mesh_2d(rank=rank, world_size=6, w=2, h=3) == \ + torus_2d(rank=rank, world_size=6, w=2, h=3) + + +# ── Back-compat: square layouts still work without w/h kwargs ──────── + + +def test_square_back_compat_mesh_2d_no_wrap(): + """Calling without w, h should still work for square world_size.""" + nbrs = mesh_2d_no_wrap(rank=0, world_size=4) + assert nbrs == {"S": 2, "E": 1}, nbrs + + +def test_square_back_compat_torus_2d(): + nbrs = torus_2d(rank=0, world_size=4) + assert nbrs == {"N": 2, "S": 2, "W": 1, "E": 1}, nbrs + + +# ── Validation: w*h must match world_size ──────────────────────────── + + +def test_rectangular_dims_must_match_world_size(): + """Phase 2 contract: explicit w, h must satisfy w*h == world_size.""" + with pytest.raises(ValueError): + mesh_2d_no_wrap(rank=0, world_size=6, w=3, h=3) # 9 != 6