"""CCL all-reduce bench — single unified entry point.

Driven entirely by ``ccl.yaml`` + ``topology.yaml``:

- ``defaults.algorithm`` in ``ccl.yaml`` picks which kernel to run
  (``ring_allreduce_{tcm,hbm,sram}`` / ``mesh_allreduce_4`` /
  ``tree_allreduce_7``).
- ``world_size`` is derived from the algorithm entry's override or from the
  topology spec (``sips × cubes_per_sip × pes_per_cube``).
- The host code uses only real PyTorch ``torch.distributed`` names:
  ``init_process_group``, ``get_world_size``, ``get_rank``, ``all_reduce``.

The bench is split into ``worker(rank, world_size, torch)`` — the per-rank
business logic, designed to look like a real PyTorch DDP training worker so
future model benches can reuse the same skeleton — and ``run(torch)`` — the
kernbench-specific launcher that initializes the process group and invokes
the worker.
"""

from __future__ import annotations

import numpy as np

from kernbench.ccl.install import load_ccl_config, resolve_algorithm_config
from kernbench.policy.placement.dp import DPPolicy

# Default per-rank tile size if ccl.yaml doesn't override it. Real
# pytorch benches hardcode batch/feature dims similarly.
DEFAULT_N_ELEM = 32


def _derive_dp(spec: dict, world_size: int) -> DPPolicy:
    """Choose the DPPolicy used to spread the bench tensor over PEs.

    The topology sizes are read from ``spec`` and compared with
    ``world_size`` to select one of three placements:

    * every PE of the machine (all levels ``column_wise``);
    * a subset of PEs inside one cube (``num_pes=world_size``);
    * a subset of cubes inside one SIP
      (``num_cubes=world_size // pes_per_cube``).

    Args:
        spec: Topology mapping. The keys read are
            ``system.sips.count``, ``sip.cube_mesh.{w,h}`` and
            ``cube.pe_layout.{pe_per_corner,corners}``.
        world_size: Number of ranks the collective should see.

    Returns:
        The ``DPPolicy`` matching the first applicable case above, or the
        all-``column_wise`` policy when none of the narrower cases applies.
    """
    # The lookups stay in this order so a malformed spec fails on the same
    # key it always did.
    n_sips = int(spec["system"]["sips"]["count"])
    cube_mesh = spec["sip"]["cube_mesh"]
    pe_layout = spec["cube"]["pe_layout"]
    pes_per_cube = int(pe_layout["pe_per_corner"]) * len(pe_layout["corners"])
    cubes_per_sip = int(cube_mesh["w"]) * int(cube_mesh["h"])
    pes_per_sip = cubes_per_sip * pes_per_cube
    all_pes = n_sips * pes_per_sip

    if world_size != all_pes:
        if world_size <= pes_per_cube:
            # Small group: a single cube, limited to ``world_size`` PEs.
            return DPPolicy(
                sip="replicate",
                cube="replicate",
                pe="column_wise",
                num_sips=1,
                num_cubes=1,
                num_pes=world_size,
            )
        if world_size <= pes_per_sip:
            # NOTE(review): the floor division drops the remainder when
            # world_size is not a multiple of pes_per_cube — presumably the
            # policy then covers fewer PEs than world_size; confirm against
            # DPPolicy before relying on such sizes.
            return DPPolicy(
                sip="replicate",
                cube="column_wise",
                pe="column_wise",
                num_sips=1,
                num_cubes=world_size // pes_per_cube,
            )

    # Exact full-machine match, or no narrower case applied.
    # NOTE(review): this is also reached when world_size exceeds one SIP but
    # differs from the machine total — verify that case is intended.
    return DPPolicy(sip="column_wise", cube="column_wise", pe="column_wise")


def worker(rank: int, world_size: int, torch) -> None:
    """Run one all-reduce over a rank-tagged tensor and report the outcome.

    Structured like a PyTorch DDP worker. Per the original author's notes,
    kernbench is single-process multi-device, so this is invoked once with
    ``rank=0`` and the per-PE fan-out is driven by the tensor's DPPolicy;
    ``rank`` is kept for parity with real DDP workers.

    Steps:
        1. Resolve the algorithm name and per-rank element count from the
           CCL config (``n_elem`` falls back to ``DEFAULT_N_ELEM``).
        2. Allocate a ``(1, world_size * n_elem)`` f16 tensor placed with
           the policy from ``_derive_dp``.
        3. Fill rank ``r``'s slice with the value ``r + 1``.
        4. Call ``torch.distributed.all_reduce(..., op="sum")``.
        5. Compare every element with ``sum(1..world_size)`` and, on rank 0
           only, print a summary line (plus up to five per-rank failure
           lines when the check fails).

    Args:
        rank: Rank of this worker; only rank 0 prints.
        world_size: Number of participating ranks.
        torch: The torch-like facade supplied by the launcher.

    Returns:
        None. A failed verification is reported by printing, not raising.
    """
    algo_cfg = resolve_algorithm_config(load_ccl_config())
    algo_name = algo_cfg["algorithm"]
    n_elem = int(algo_cfg.get("n_elem", DEFAULT_N_ELEM))
    total_elems = world_size * n_elem

    # Placement is derived first, then the device tensor is allocated.
    dp = _derive_dp(torch.spec, world_size)
    tensor = torch.zeros((1, total_elems), dtype="f16", dp=dp, name="ccl_in")

    # Host-side seed data: one row per rank, row r filled with (r + 1), then
    # flattened to the (1, world_size * n_elem) layout of the device tensor.
    rank_values = np.arange(1, world_size + 1, dtype=np.float64)
    init = np.zeros((world_size, n_elem), dtype=np.float16)
    init[:] = rank_values[:, np.newaxis]
    init = init.reshape(1, total_elems)

    # Real PyTorch idiom: target.copy_(torch.from_numpy(source))
    tensor.copy_(torch.from_numpy(init))

    # The main act: one all_reduce call — the backend installs IPCQ at
    # init_process_group time and here only dispatches the kernel.
    torch.distributed.all_reduce(tensor, op="sum")

    # After the reduction every element should equal 1 + 2 + ... + world_size.
    result = tensor.numpy()
    expected = float(sum(range(1, world_size + 1)))
    tolerance = {"rtol": 1e-1, "atol": 1e-1}
    all_ok = bool(np.allclose(result, expected, **tolerance))

    # Single-source logging: only rank 0 reports.
    if rank != 0:
        return

    if all_ok:
        print(f" {algo_name} (ws={world_size}): {world_size} OK")
        return

    # Something is off: re-check rank by rank, printing details for the first
    # five failing ranks and counting the rest silently.
    flat = result.reshape(-1)
    n_fail = 0
    for r in range(world_size):
        start = r * n_elem
        slice_r = flat[start : start + n_elem]
        if np.allclose(slice_r, expected, **tolerance):
            continue
        n_fail += 1
        if n_fail > 5:
            continue
        got_mean = float(slice_r.mean())
        print(
            f" [FAIL] rank {r} "
            f"(ws={world_size}, algo={algo_name}): "
            f"got mean={got_mean:.3f}, "
            f"expected={expected:.3f}"
        )

    n_pass = world_size - n_fail
    print(f" {algo_name} (ws={world_size}): {n_pass} OK / {n_fail} FAIL")


def run(torch) -> None:
    """CLI entry point: set up the process group, then call ``worker``.

    Initializes ``torch.distributed`` with the ``"ahbm"`` backend and passes
    the rank and world size it reports on to ``worker``.

    Args:
        torch: The torch-like facade supplied by the launcher.
    """
    process_group = torch.distributed
    process_group.init_process_group(backend="ahbm")
    my_rank = process_group.get_rank()
    n_ranks = process_group.get_world_size()
    worker(rank=my_rank, world_size=n_ranks, torch=torch)