"""CCL all-reduce bench (ADR-0024 + ADR-0027). Pure TP launcher model: rank = SIP. Each rank owns a ``(N_CUBES, n_elem)`` tensor sharded row-wise across the cube mesh (pe0 per cube). After ``dist.all_reduce(op="sum")`` every cube on every rank must hold ``N_CUBES * sum(1..world_size)``. Rank 0 prints the pass/fail line. Driven by ``ccl.yaml`` (``defaults.algorithm``, ``n_elem``) + ``topology.yaml`` (SIP count → world_size, cube_mesh → N_CUBES). """ from __future__ import annotations from dataclasses import dataclass import numpy as np from kernbench.ccl.install import load_ccl_config, resolve_algorithm_config from kernbench.policy.placement.dp import DPPolicy DEFAULT_N_ELEM = 8 @dataclass(frozen=True) class _BenchCfg: algorithm: str n_elem: int n_cubes: int world_size: int def _resolve_cfg(torch) -> _BenchCfg: """Read ccl.yaml + topology once at host side.""" merged = resolve_algorithm_config(load_ccl_config()) ws = torch.distributed.get_world_size() spec = torch.spec or {} n_sips = int(spec.get("system", {}).get("sips", {}).get("count", 1)) if ws != n_sips: raise RuntimeError( f"ccl_allreduce bench requires world_size == topology SIP count " f"(world_size={ws}, n_sips={n_sips})." ) cm = spec.get("sip", {}).get("cube_mesh", {}) n_cubes = int(cm.get("w", 4)) * int(cm.get("h", 4)) return _BenchCfg( algorithm=merged["algorithm"], n_elem=int(merged.get("n_elem", DEFAULT_N_ELEM)), n_cubes=n_cubes, world_size=ws, ) def _rank_dp(n_cubes: int) -> DPPolicy: return DPPolicy(cube="row_wise", pe="replicate", num_cubes=n_cubes, num_pes=1) def _allocate_rank_tensor(torch, rank: int, cfg: _BenchCfg): """Allocate this rank's ``(n_cubes, n_elem)`` tensor on its SIP.""" return torch.zeros( (cfg.n_cubes, cfg.n_elem), dtype="f16", dp=_rank_dp(cfg.n_cubes), name=f"ccl_in_r{rank}", ) def _init_with_rank_value(torch, tensor, rank: int, cfg: _BenchCfg) -> None: """Fill all cubes with the scalar ``rank + 1``.""" arr = np.full((cfg.n_cubes, cfg.n_elem), float(rank + 1), dtype=np.float16) tensor.copy_(torch.from_numpy(arr)) def _report(result: np.ndarray, cfg: _BenchCfg) -> None: """Single-line pass/fail printer (rank 0 only).""" expected = float(cfg.n_cubes * sum(range(1, cfg.world_size + 1))) ok = True for cube_id in range(cfg.n_cubes): if not np.allclose(result[cube_id], expected, rtol=1e-1, atol=1e-1): ok = False break if ok: total = cfg.world_size * cfg.n_cubes print(f" {cfg.algorithm} (ws={cfg.world_size}): {total} OK") return got = float(result.reshape(-1).mean()) print( f" [FAIL] {cfg.algorithm} (ws={cfg.world_size}): " f"got mean={got:.3f}, expected={expected:.3f}" ) def _worker(rank: int, cfg: _BenchCfg, torch) -> None: torch.ahbm.set_device(rank) tensor = _allocate_rank_tensor(torch, rank, cfg) _init_with_rank_value(torch, tensor, rank, cfg) torch.distributed.all_reduce(tensor, op="sum") if rank == 0: _report(tensor.numpy(), cfg) def run(torch) -> None: torch.distributed.init_process_group(backend="ahbm") cfg = _resolve_cfg(torch) torch.multiprocessing.spawn( _worker, args=(cfg, torch), nprocs=cfg.world_size, )