e9cc40f74d
mesh_2d, torus_2d, and mesh_2d_no_wrap accept optional w,h kwargs; sqrt fall-back preserved for square layouts (back-compat tests confirm 4-SIP and 9-SIP square configs still work). sfr_config reads system.sips.w/h from spec and threads dims through to the topology fn. test_allreduce_multidevice CONFIGS switched from 4 SIPs (square) to 6 SIPs: ring_1d_6sip, torus_2d_6sip_2x3, mesh_2d_no_wrap_6sip_2x3. _write_temp_configs writes system.sips.w/h when supplied; _sip_topo_dims reads them back. Latency sweep loop also moved to 6-SIP layouts. Linear-scale plot variants dropped -- only log-scale *.png + summary.csv emitted. Plots in tests/allreduce_latency_plots regenerated. New tests/test_sip_topology_rectangular.py asserts neighbor correctness for 2x3 layouts and back-compat for square fallback. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
176 lines
5.4 KiB
Python
176 lines
5.4 KiB
Python
"""Builtin neighbor topology generators for CCL backend (ADR-0023 D11).
|
|
|
|
Each generator takes ``(rank, world_size)`` (the 2D mesh/torus generators
also accept optional ``w``/``h`` grid dimensions) and returns a
``dict[direction, peer_rank]`` for that rank. ``direction`` is one of
``"N" | "S" | "E" | "W"`` for ring/mesh (rings only use ``"E"``/``"W"``), or
``"parent" | "child_left" | "child_right"`` for tree topologies.
|
|
|
|
Algorithm modules may override the generated map by defining a
|
|
``neighbors(rank, world_size, neighbor_map) -> dict | None`` function in
|
|
the same module (see D11 / D15). ``resolve_topology`` wires these together.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
from typing import Any, Callable
|
|
|
|
# Map from a direction label (e.g. "E", "parent") to the peer's rank.
NeighborMap = dict[str, int]
# Common call shape of a topology generator: (rank, world_size) -> NeighborMap.
TopologyFn = Callable[[int, int], NeighborMap]
|
|
|
|
|
|
# ── Builtin generators ───────────────────────────────────────────────
|
|
|
|
|
|
def ring_1d(rank: int, world_size: int) -> NeighborMap:
    """Return the "E" and "W" peers of ``rank`` on a bidirectional 1D ring.

    Both peers are computed modulo ``world_size``, so the ring closes on
    itself: the last rank's "E" peer is rank 0 and rank 0's "W" peer is the
    last rank.
    """
    east = (rank + 1) % world_size
    west = (rank - 1) % world_size
    return {"E": east, "W": west}
|
|
|
|
|
|
def ring_1d_unidir(rank: int, world_size: int) -> NeighborMap:
    """Return the single "E" peer of ``rank`` on a unidirectional 1D ring.

    The peer is ``rank + 1`` modulo ``world_size``, so the last rank points
    back at rank 0.
    """
    successor = (rank + 1) % world_size
    return {"E": successor}
|
|
|
|
|
|
def _resolve_2d_dims(
|
|
world_size: int, w: int | None, h: int | None, name: str,
|
|
) -> tuple[int, int]:
|
|
if w is not None and h is not None:
|
|
if w * h != world_size:
|
|
raise ValueError(
|
|
f"{name}: w*h ({w}*{h}) != world_size ({world_size})"
|
|
)
|
|
return w, h
|
|
side = int(round(world_size ** 0.5))
|
|
if side * side != world_size:
|
|
raise ValueError(
|
|
f"{name} requires square world_size or explicit w,h, "
|
|
f"got {world_size}"
|
|
)
|
|
return side, side
|
|
|
|
|
|
def mesh_2d(
    rank: int, world_size: int,
    w: int | None = None, h: int | None = None,
) -> NeighborMap:
    """Return the N/S/E/W peers of ``rank`` on a 2D mesh that wraps on every edge.

    Ranks are laid out row-major: ``rank = row * w + col``. Passing both
    ``w`` and ``h`` allows rectangular layouts (e.g. 2x3); without them the
    grid is square with ``side = sqrt(world_size)``.

    Raises:
        ValueError: propagated from ``_resolve_2d_dims`` when the dimensions
            cannot be resolved for ``world_size``.
    """
    cols, rows = _resolve_2d_dims(world_size, w, h, "mesh_2d")
    row, col = divmod(rank, cols)

    # Step one cell in each direction, wrapping inside the row/column count.
    north_row = (row - 1) % rows
    south_row = (row + 1) % rows
    west_col = (col - 1) % cols
    east_col = (col + 1) % cols

    return {
        "N": north_row * cols + col,
        "S": south_row * cols + col,
        "W": row * cols + west_col,
        "E": row * cols + east_col,
    }
|
|
|
|
|
|
def tree_binary(rank: int, world_size: int) -> NeighborMap:
    """Return the parent/child links of ``rank`` in a binary tree rooted at rank 0.

    Rank ``r`` has children ``2r + 1`` and ``2r + 2`` (each only when it is
    below ``world_size``) and, for ``r > 0``, parent ``(r - 1) // 2``.

    Only links that exist are returned, under the keys
    "parent", "child_left", "child_right".
    """
    links: NeighborMap = {}
    if rank > 0:
        links["parent"] = (rank - 1) // 2

    # Left child sits at 2r+1, right child immediately after it.
    first_child = 2 * rank + 1
    for offset, key in enumerate(("child_left", "child_right")):
        child = first_child + offset
        if child < world_size:
            links[key] = child
    return links
|
|
|
|
|
|
def torus_2d(
    rank: int, world_size: int,
    w: int | None = None, h: int | None = None,
) -> NeighborMap:
    """Return the N/S/E/W peers of ``rank`` on a 2D torus.

    This is an alias for ``mesh_2d``: all arguments are passed straight
    through and its result is returned unchanged.
    """
    neighbor_map = mesh_2d(rank, world_size, w=w, h=h)
    return neighbor_map
|
|
|
|
|
|
def mesh_2d_no_wrap(
    rank: int, world_size: int,
    w: int | None = None, h: int | None = None,
) -> NeighborMap:
    """Return the N/S/E/W peers of ``rank`` on a 2D mesh with no wrap-around.

    Uses the same row-major layout and optional rectangular ``w``/``h``
    dimensions as the wrapping mesh, but a direction is omitted from the
    result when ``rank`` sits on that edge of the grid.

    Raises:
        ValueError: propagated from ``_resolve_2d_dims`` when the dimensions
            cannot be resolved for ``world_size``.
    """
    cols, rows = _resolve_2d_dims(world_size, w, h, "mesh_2d_no_wrap")
    row, col = divmod(rank, cols)

    # (direction, peer exists?, peer rank) -- kept in N, S, W, E order so the
    # resulting dict has the same key order as before.
    candidates = (
        ("N", row > 0, (row - 1) * cols + col),
        ("S", row < rows - 1, (row + 1) * cols + col),
        ("W", col > 0, row * cols + (col - 1)),
        ("E", col < cols - 1, row * cols + (col + 1)),
    )
    return {
        direction: peer
        for direction, on_grid, peer in candidates
        if on_grid
    }
|
|
|
|
|
|
def none(rank: int, world_size: int) -> NeighborMap:
    """Return an empty neighbor map, whatever ``rank`` and ``world_size`` are.

    Intended for algorithms whose own ``neighbors()`` hook builds the map
    from scratch.
    """
    empty: NeighborMap = {}
    return empty
|
|
|
|
|
|
# Registry of the builtin generators, keyed by the topology name that
# ``resolve_topology`` looks up. Keys are spelled out explicitly (rather than
# derived from the function names) because they are the externally visible
# names.
_BUILTIN: dict[str, TopologyFn] = {
    "ring_1d": ring_1d,
    "ring_1d_unidir": ring_1d_unidir,
    "mesh_2d": mesh_2d,
    "torus_2d": torus_2d,
    "mesh_2d_no_wrap": mesh_2d_no_wrap,
    "tree_binary": tree_binary,
    "none": none,
}
|
|
|
|
|
|
# ── Resolution ───────────────────────────────────────────────────────
|
|
|
|
|
|
def resolve_topology(
    name: str, algo_module: Any | None = None,
) -> TopologyFn:
    """Return a callable ``(rank, world_size) -> NeighborMap``.

    Args:
        name: builtin topology name from ccl.yaml. Must be one of
            ``ring_1d``, ``ring_1d_unidir``, ``mesh_2d``, ``torus_2d``,
            ``mesh_2d_no_wrap``, ``tree_binary``, or ``none``.
        algo_module: optional algorithm module. If it defines a callable
            ``neighbors(rank, world_size, neighbor_map)``, that hook is
            invoked after the builtin to override the result.
            Returning None from neighbors() leaves the builtin map
            unchanged; returning a dict replaces it.

    Returns:
        The builtin generator itself when there is no callable override
        hook. Otherwise a wrapper that forwards all of its arguments
        (including any optional ``w``/``h`` dimensions) to the builtin and
        then applies the hook.

    Raises:
        ValueError: if ``name`` is not a known builtin.
    """
    if name not in _BUILTIN:
        raise ValueError(
            f"Unknown topology '{name}'. "
            f"Available builtins: {list(_BUILTIN)}"
        )
    builtin_fn = _BUILTIN[name]

    # getattr() with a default already yields None when algo_module is None.
    override_fn = getattr(algo_module, "neighbors", None)
    if not callable(override_fn):
        return builtin_fn

    def _wrapped(
        rank: int, world_size: int, *args: Any, **kwargs: Any,
    ) -> NeighborMap:
        # Forward extra arguments so the wrapper accepts exactly what the
        # builtin accepts (e.g. w=/h= for the 2D generators); without this,
        # installing an override hook would make those calls raise TypeError.
        base = builtin_fn(rank, world_size, *args, **kwargs)
        result = override_fn(rank, world_size, base)
        return base if result is None else result

    return _wrapped
|