"""CCL all-reduce bench (ADR-0024 Phase A). Driven entirely by ``ccl.yaml`` + ``topology.yaml``: - ``defaults.algorithm`` in ``ccl.yaml`` picks which kernel to run. - ``world_size`` resolution: explicit override in ccl.yaml > defaults > topology's SIP count. ADR-0024 D1: topology fallback is the SIP count (each rank = one SIP, TP boundary). - ``run()`` is hybrid: - If ``world_size == topology SIP count`` (the intended new path): spawn one greenlet per rank, bind it via ``dist._bind_rank``, and each worker calls ``torch.ahbm.set_device(rank)`` + runs its portion of the collective. Cross-rank IPCQ exchange handles the reduce. - Legacy path (``world_size > SIP count``, via explicit ccl.yaml override): single worker at rank 0 with the full tensor distributed across all participating PEs via ``_derive_dp``. Retained for backward compatibility with existing kernel / topology tests. """ from __future__ import annotations import numpy as np from greenlet import greenlet from kernbench.ccl.install import load_ccl_config, resolve_algorithm_config from kernbench.policy.placement.dp import DPPolicy # Default per-rank tile size if ccl.yaml doesn't override it. DEFAULT_N_ELEM = 32 def _derive_dp(spec: dict, world_size: int) -> DPPolicy: """Legacy DPPolicy for world_size > SIP count (rank = flat PE index). Used only in the ccl.yaml-override path so the existing matrix tests with explicit world_size (8, 16, 7 etc.) keep working. The new ADR-0024 TP path (rank = SIP) uses a per-rank DPPolicy inside the worker instead. """ sips = int(spec["system"]["sips"]["count"]) cm = spec["sip"]["cube_mesh"] pl = spec["cube"]["pe_layout"] pes_per_cube = int(pl["pe_per_corner"]) * len(pl["corners"]) cubes_per_sip = int(cm["w"]) * int(cm["h"]) total = sips * cubes_per_sip * pes_per_cube if world_size == total: return DPPolicy(sip="column_wise", cube="column_wise", pe="column_wise") if world_size <= pes_per_cube: return DPPolicy( sip="replicate", cube="replicate", pe="column_wise", num_sips=1, num_cubes=1, num_pes=world_size, ) if world_size <= cubes_per_sip * pes_per_cube: return DPPolicy( sip="replicate", cube="column_wise", pe="column_wise", num_sips=1, num_cubes=world_size // pes_per_cube, ) return DPPolicy(sip="column_wise", cube="column_wise", pe="column_wise") def worker(rank: int, world_size: int, torch) -> None: """Per-rank worker (new TP path) OR single-worker legacy driver. Behaviour depends on whether this call originates from the multi-greenlet launcher (new path) or from the legacy single-call fallback; distinguished by which ``dp`` layout applies. """ cfg = resolve_algorithm_config(load_ccl_config()) algo_name = cfg["algorithm"] n_elem = int(cfg.get("n_elem", DEFAULT_N_ELEM)) spec = torch.spec or {} n_sips = int(spec.get("system", {}).get("sips", {}).get("count", 1)) if world_size == n_sips: # ADR-0024 new path: rank = SIP, worker sees its SIP's # representative PE via torch.ahbm.set_device. torch.ahbm.set_device(rank) dp = DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=1) tensor = torch.zeros( (1, n_elem), dtype="f16", dp=dp, name=f"ccl_in_r{rank}", ) # Each rank initialises its tile with (rank + 1); after all_reduce # every rank sees sum(1..world_size). init = np.full((1, n_elem), float(rank + 1), dtype=np.float16) tensor.copy_(torch.from_numpy(init)) torch.distributed.all_reduce(tensor, op="sum") result = tensor.numpy() expected = float(sum(range(1, world_size + 1))) all_ok = bool(np.allclose(result, expected, rtol=1e-1, atol=1e-1)) if rank == 0: if all_ok: print(f" {algo_name} (ws={world_size}): {world_size} OK") else: print( f" [FAIL] rank {rank} " f"(ws={world_size}, algo={algo_name}): " f"got mean={float(result.reshape(-1).mean()):.3f}, " f"expected={expected:.3f}" ) print( f" {algo_name} (ws={world_size}): " f"0 OK / {world_size} FAIL" ) return # Legacy path: world_size overridden via ccl.yaml to exceed SIP count. # Single-worker at rank 0; whole tensor distributed across all # participating PEs using the derived DPPolicy. Matches pre-ADR-0024 # behaviour. dp = _derive_dp(spec, world_size) tensor = torch.zeros( (1, world_size * n_elem), dtype="f16", dp=dp, name="ccl_in", ) init = np.zeros((1, world_size * n_elem), dtype=np.float16) for r in range(world_size): init[0, r * n_elem : (r + 1) * n_elem] = float(r + 1) tensor.copy_(torch.from_numpy(init)) torch.distributed.all_reduce(tensor, op="sum") result = tensor.numpy() expected = float(sum(range(1, world_size + 1))) all_ok = bool(np.allclose(result, expected, rtol=1e-1, atol=1e-1)) if rank == 0: if all_ok: print(f" {algo_name} (ws={world_size}): {world_size} OK") else: flat = result.reshape(-1) n_fail = 0 for r in range(world_size): slice_r = flat[r * n_elem : (r + 1) * n_elem] if not np.allclose(slice_r, expected, rtol=1e-1, atol=1e-1): n_fail += 1 if n_fail <= 5: print( f" [FAIL] rank {r} " f"(ws={world_size}, algo={algo_name}): " f"got mean={float(slice_r.mean()):.3f}, " f"expected={expected:.3f}" ) print( f" {algo_name} (ws={world_size}): " f"{world_size - n_fail} OK / {n_fail} FAIL" ) def run(torch) -> None: """CLI entry — dispatch to multi-greenlet path when ws == SIP count, else fall back to single-worker legacy path for ccl.yaml override compat. """ dist = torch.distributed dist.init_process_group(backend="ahbm") world_size = dist.get_world_size() spec = torch.spec or {} n_sips = int(spec.get("system", {}).get("sips", {}).get("count", 1)) if world_size == n_sips: # ADR-0024 D12/D13: one greenlet per rank. After each scheduler # round, the main greenlet drains any pending collective handles # (ADR-0024 D7) — this must happen in the main context, not inside # a worker, so env.run is invoked with main as the current greenlet # and kernel_runner's spawned kernel greenlets correctly get main # as their parent. backend = dist._backend gs: list[greenlet] = [] for rank in range(world_size): def _entry(r: int = rank) -> None: worker(r, world_size, torch) g = greenlet(_entry) dist._bind_rank(g, rank) gs.append(g) while True: alive = [g for g in gs if not g.dead] if not alive: break for g in alive: if not g.dead: g.switch() # Drain pending collective handles. All sibling workers have # either submitted (and yielded) or completed; their kernels # are live in the SimPy queue, ready to exchange via IPCQ. pending = backend._pending_collective_handles if pending: for h, _sip_id, meta in pending: torch.wait(h, _meta=meta) backend._pending_collective_handles = [] else: # Legacy single-worker path (ccl.yaml world_size override). worker(rank=dist.get_rank(), world_size=world_size, torch=torch)