Files
kernbench2/src/kernbench/ccl/topologies.py
T
mukesh e9cc40f74d Rectangular SIP topology + 6-device allreduce sweep
mesh_2d, torus_2d, and mesh_2d_no_wrap accept optional w,h kwargs;
sqrt fall-back preserved for square layouts (back-compat tests
confirm 4-SIP and 9-SIP square configs still work). sfr_config
reads system.sips.w/h from spec and threads dims through to the
topology fn.

test_allreduce_multidevice CONFIGS switched from 4 SIPs (square)
to 6 SIPs: ring_1d_6sip, torus_2d_6sip_2x3, mesh_2d_no_wrap_6sip_2x3.
_write_temp_configs writes system.sips.w/h when supplied;
_sip_topo_dims reads them back. Latency sweep loop also moved to
6-SIP layouts. Linear-scale plot variants dropped -- only log-scale
*.png + summary.csv emitted. Plots in tests/allreduce_latency_plots
regenerated.

New tests/test_sip_topology_rectangular.py asserts neighbor
correctness for 2x3 layouts and back-compat for square fallback.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-27 15:13:14 -07:00

176 lines
5.4 KiB
Python

"""Builtin neighbor topology generators for CCL backend (ADR-0023 D11).
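Each generator takes ``(rank, world_size)`` and returns a
``dict[direction, peer_rank]`` for that rank; the 2D generators also
accept optional ``w``/``h`` keyword arguments for rectangular layouts.
``direction`` is one of ``"N" | "S" | "E" | "W"`` for ring/mesh/torus
(rings only emit ``"E"``/``"W"``), or
``"parent" | "child_left" | "child_right"`` for tree topologies.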
Algorithm modules may override the generated map by defining a
``neighbors(rank, world_size, neighbor_map) -> dict | None`` function in
the same module (see D11 / D15). ``resolve_topology`` wires these together.
"""
from __future__ import annotations
from typing import Any, Callable
NeighborMap = dict[str, int]
TopologyFn = Callable[[int, int], NeighborMap]
# ── Builtin generators ───────────────────────────────────────────────
def ring_1d(rank: int, world_size: int) -> NeighborMap:
    """Bidirectional 1D ring: link ``rank`` to both adjacent ranks.

    ``"E"`` is the next rank and ``"W"`` the previous one, each taken
    modulo ``world_size`` so the two ends of the ring join up.
    """
    east = (rank + 1) % world_size
    west = (rank - 1) % world_size
    return {"E": east, "W": west}
def ring_1d_unidir(rank: int, world_size: int) -> NeighborMap:
    """One-way 1D ring: the only link is ``"E"``, the next rank.

    The successor is taken modulo ``world_size`` so the last rank points
    back at rank 0.
    """
    successor = (rank + 1) % world_size
    return {"E": successor}
def _resolve_2d_dims(
    world_size: int, w: int | None, h: int | None, name: str,
) -> tuple[int, int]:
    """Resolve the ``(w, h)`` grid dimensions of a 2D topology.

    Args:
        world_size: total number of ranks in the grid; must be >= 1.
        w: grid width (number of columns), or ``None`` if unspecified.
        h: grid height (number of rows), or ``None`` if unspecified.
        name: topology name, used only to prefix error messages.

    Returns:
        ``(w, h)`` such that ``w * h == world_size``. When both dims are
        given they are returned as-is; when exactly one is given the other
        is derived from ``world_size``; when neither is given the grid is
        square with ``side = sqrt(world_size)``.

    Raises:
        ValueError: if ``world_size`` or a supplied dim is < 1, if the
            supplied dims do not multiply to ``world_size``, if a single
            supplied dim does not divide ``world_size``, or if no dims are
            given and ``world_size`` is not a perfect square.
    """
    # Function-scope import so this helper does not depend on the module's
    # import block.
    from math import isqrt

    if world_size < 1:
        raise ValueError(f"{name}: world_size must be >= 1, got {world_size}")
    # Reject non-positive dims up front: e.g. w=-2, h=-3 would otherwise
    # satisfy w*h == 6 and yield a nonsensical grid.
    for label, dim in (("w", w), ("h", h)):
        if dim is not None and dim < 1:
            raise ValueError(f"{name}: {label} must be >= 1, got {dim}")

    if w is not None and h is not None:
        if w * h != world_size:
            raise ValueError(
                f"{name}: w*h ({w}*{h}) != world_size ({world_size})"
            )
        return w, h

    if w is not None or h is not None:
        # Exactly one dim supplied: honor it and derive the other instead
        # of silently discarding it in favor of the square fallback.
        width_given = w is not None
        known = w if width_given else h
        derived, remainder = divmod(world_size, known)
        if remainder:
            label = "w" if width_given else "h"
            raise ValueError(
                f"{name}: {label}={known} does not divide "
                f"world_size ({world_size})"
            )
        return (known, derived) if width_given else (derived, known)

    # Integer square root: exact for any size, unlike a float sqrt.
    side = isqrt(world_size)
    if side * side != world_size:
        raise ValueError(
            f"{name} requires square world_size or explicit w,h, "
            f"got {world_size}"
        )
    return side, side
def mesh_2d(
    rank: int, world_size: int,
    w: int | None = None, h: int | None = None,
) -> NeighborMap:
    """Wrap-around 2D mesh: N/S/E/W neighbors of ``rank``.

    Ranks are laid out row-major (``rank = row * w + col``). Dimensions
    come from ``_resolve_2d_dims``: explicit ``w``/``h`` allow rectangular
    grids such as 2x3; without them the grid is square with
    ``side = sqrt(world_size)``. Every edge wraps, so all four directions
    are always present in the result.
    """
    cols, rows = _resolve_2d_dims(world_size, w, h, "mesh_2d")
    row, col = divmod(rank, cols)
    row_base = row * cols
    # Step one cell in each direction, wrapping within the row/column.
    north_row = (row - 1) % rows
    south_row = (row + 1) % rows
    west_col = (col - 1) % cols
    east_col = (col + 1) % cols
    return {
        "N": north_row * cols + col,
        "S": south_row * cols + col,
        "W": row_base + west_col,
        "E": row_base + east_col,
    }
def tree_binary(rank: int, world_size: int) -> NeighborMap:
    """Binary tree with rank 0 as the root.

    Rank ``r`` has children ``2r+1`` and ``2r+2`` and, when ``r > 0``,
    parent ``(r-1)//2``.

    Only links that exist are returned, under the keys ``"parent"``,
    ``"child_left"`` and ``"child_right"``: the root has no parent, and a
    child is omitted when its rank would be ``>= world_size``.
    """
    links: NeighborMap = {}
    if rank > 0:
        links["parent"] = (rank - 1) // 2
    first_child = 2 * rank + 1
    for key, child in (
        ("child_left", first_child),
        ("child_right", first_child + 1),
    ):
        if child < world_size:
            links[key] = child
    return links
def torus_2d(
    rank: int, world_size: int,
    w: int | None = None, h: int | None = None,
) -> NeighborMap:
    """2D torus (N/S/E/W, every edge wraps).

    Pure alias: forwards all arguments to ``mesh_2d`` and returns its
    result unchanged.
    """
    return mesh_2d(rank, world_size, w, h)
def mesh_2d_no_wrap(
    rank: int, world_size: int,
    w: int | None = None, h: int | None = None,
) -> NeighborMap:
    """Open-edged 2D mesh: N/S/E/W neighbors with no wrap-around.

    Ranks are laid out row-major (``rank = row * w + col``) on a grid
    whose dimensions come from ``_resolve_2d_dims``, so rectangular
    layouts are supported. A direction is present in the result only when
    a neighbor exists on that side; ranks on the border get fewer than
    four entries.
    """
    cols, rows = _resolve_2d_dims(world_size, w, h, "mesh_2d_no_wrap")
    row, col = divmod(rank, cols)
    # (direction, neighbor exists on that side?, neighbor rank)
    candidates = (
        ("N", row > 0, (row - 1) * cols + col),
        ("S", row < rows - 1, (row + 1) * cols + col),
        ("W", col > 0, row * cols + (col - 1)),
        ("E", col < cols - 1, row * cols + (col + 1)),
    )
    return {
        direction: peer
        for direction, on_grid, peer in candidates
        if on_grid
    }
def none(rank: int, world_size: int) -> NeighborMap:
    """No builtin links: return a fresh empty map for every rank.

    Intended for algorithms whose ``neighbors()`` hook builds the whole
    map itself; both arguments are ignored.
    """
    empty: NeighborMap = {}
    return empty
# Registry of the builtin topology generators, keyed by the name that
# resolve_topology() looks up. Insertion order is observable: it is the
# order in which names appear in the "Unknown topology" error message.
_BUILTIN: dict[str, TopologyFn] = {
"ring_1d": ring_1d,
"ring_1d_unidir": ring_1d_unidir,
"mesh_2d": mesh_2d,
"torus_2d": torus_2d,
"mesh_2d_no_wrap": mesh_2d_no_wrap,
"tree_binary": tree_binary,
"none": none,
}
# ── Resolution ───────────────────────────────────────────────────────
def resolve_topology(
    name: str, algo_module: Any | None = None,
) -> TopologyFn:
    """Return a callable ``(rank, world_size) -> NeighborMap``.

    Args:
        name: builtin topology name from ccl.yaml. Must be one of
            ``ring_1d``, ``ring_1d_unidir``, ``mesh_2d``, ``torus_2d``,
            ``mesh_2d_no_wrap``, ``tree_binary``, or ``none``.
        algo_module: optional algorithm module. If it defines a callable
            ``neighbors(rank, world_size, neighbor_map)``, that hook is
            invoked after the builtin to override the result.
            Returning None from neighbors() leaves the builtin map
            unchanged; returning a dict replaces it.

    Returns:
        The builtin generator itself when there is no override hook;
        otherwise a wrapper that runs the builtin and then the hook. The
        wrapper forwards any extra keyword arguments (e.g. ``w``/``h`` for
        the 2D topologies) to the builtin, so both forms accept the same
        calls.

    Raises:
        ValueError: if ``name`` is not a known builtin.
    """
    if name not in _BUILTIN:
        raise ValueError(
            f"Unknown topology '{name}'. "
            f"Available builtins: {list(_BUILTIN)}"
        )
    builtin_fn = _BUILTIN[name]

    override_fn = None
    if algo_module is not None:
        override_fn = getattr(algo_module, "neighbors", None)
    if override_fn is None or not callable(override_fn):
        return builtin_fn

    def _wrapped(rank: int, world_size: int, **dims: Any) -> NeighborMap:
        # Pass keyword-only extras (w/h) through to the builtin; the hook
        # keeps its three-argument contract and only sees the built map.
        base = builtin_fn(rank, world_size, **dims)
        result = override_fn(rank, world_size, base)
        return base if result is None else result

    return _wrapped