"""Builtin neighbor topology generators for CCL backend (ADR-0023 D11). Each generator takes ``(rank, world_size)`` and returns a ``dict[direction, peer_rank]`` for that rank. ``direction`` is one of ``"N" | "S" | "E" | "W"`` for ring/mesh, or ``"parent" | "child_left" | "child_right"`` for tree topologies. Algorithm modules may override the generated map by defining a ``neighbors(rank, world_size, neighbor_map) -> dict | None`` function in the same module (see D11 / D15). ``resolve_topology`` wires these together. """ from __future__ import annotations from typing import Any, Callable NeighborMap = dict[str, int] TopologyFn = Callable[[int, int], NeighborMap] # ── Builtin generators ─────────────────────────────────────────────── def ring_1d(rank: int, world_size: int) -> NeighborMap: """1D bidirectional ring (E/W).""" return { "E": (rank + 1) % world_size, "W": (rank - 1) % world_size, } def ring_1d_unidir(rank: int, world_size: int) -> NeighborMap: """1D unidirectional ring (E only).""" return {"E": (rank + 1) % world_size} def _resolve_2d_dims( world_size: int, w: int | None, h: int | None, name: str, ) -> tuple[int, int]: if w is not None and h is not None: if w * h != world_size: raise ValueError( f"{name}: w*h ({w}*{h}) != world_size ({world_size})" ) return w, h side = int(round(world_size ** 0.5)) if side * side != world_size: raise ValueError( f"{name} requires square world_size or explicit w,h, " f"got {world_size}" ) return side, side def mesh_2d( rank: int, world_size: int, w: int | None = None, h: int | None = None, ) -> NeighborMap: """2D mesh (N/S/E/W) with wrap-around on all four edges. Layout: rank = row * w + col. When w, h are given, supports rectangular (e.g. 2x3) layouts. Otherwise falls back to square side = sqrt(world_size). """ w, h = _resolve_2d_dims(world_size, w, h, "mesh_2d") r, c = divmod(rank, w) return { "N": ((r - 1) % h) * w + c, "S": ((r + 1) % h) * w + c, "W": r * w + (c - 1) % w, "E": r * w + (c + 1) % w, } def tree_binary(rank: int, world_size: int) -> NeighborMap: """Binary tree rooted at rank 0. Children of rank r are 2r+1 and 2r+2 (if within world_size). Parent of rank r > 0 is (r-1)//2. Returned keys (only those that exist): "parent", "child_left", "child_right" """ n: NeighborMap = {} if rank > 0: n["parent"] = (rank - 1) // 2 left = 2 * rank + 1 right = 2 * rank + 2 if left < world_size: n["child_left"] = left if right < world_size: n["child_right"] = right return n def torus_2d( rank: int, world_size: int, w: int | None = None, h: int | None = None, ) -> NeighborMap: """2D torus (N/S/E/W) with wrap-around on all edges. Alias for mesh_2d.""" return mesh_2d(rank, world_size, w=w, h=h) def mesh_2d_no_wrap( rank: int, world_size: int, w: int | None = None, h: int | None = None, ) -> NeighborMap: """2D mesh (N/S/E/W) WITHOUT wrap-around. Supports rectangular dims.""" w, h = _resolve_2d_dims(world_size, w, h, "mesh_2d_no_wrap") r, c = divmod(rank, w) n: NeighborMap = {} if r > 0: n["N"] = (r - 1) * w + c if r < h - 1: n["S"] = (r + 1) * w + c if c > 0: n["W"] = r * w + (c - 1) if c < w - 1: n["E"] = r * w + (c + 1) return n def none(rank: int, world_size: int) -> NeighborMap: """Empty map — algorithm's neighbors() must build from scratch.""" return {} _BUILTIN: dict[str, TopologyFn] = { "ring_1d": ring_1d, "ring_1d_unidir": ring_1d_unidir, "mesh_2d": mesh_2d, "torus_2d": torus_2d, "mesh_2d_no_wrap": mesh_2d_no_wrap, "tree_binary": tree_binary, "none": none, } # ── Resolution ─────────────────────────────────────────────────────── def resolve_topology( name: str, algo_module: Any | None = None, ) -> TopologyFn: """Return a callable ``(rank, world_size) -> NeighborMap``. Args: name: builtin topology name from ccl.yaml. Must be one of ``ring_1d``, ``ring_1d_unidir``, ``mesh_2d``, ``tree_binary``, or ``none``. algo_module: optional algorithm module. If it defines ``neighbors(rank, world_size, neighbor_map)``, that hook is invoked after the builtin to override the result. Returning None from neighbors() leaves the builtin map unchanged; returning a dict replaces it. Raises: ValueError: if ``name`` is not a known builtin. """ if name not in _BUILTIN: raise ValueError( f"Unknown topology '{name}'. " f"Available builtins: {list(_BUILTIN)}" ) builtin_fn = _BUILTIN[name] override_fn = getattr(algo_module, "neighbors", None) if algo_module else None if override_fn is None or not callable(override_fn): return builtin_fn def _wrapped(rank: int, world_size: int) -> NeighborMap: base = builtin_fn(rank, world_size) result = override_fn(rank, world_size, base) if result is None: return base return result return _wrapped