kernbench2/benches/ccl_allreduce.py
ywkang cfc2d74ec4 Refactor ccl_allreduce bench: rank=SIP only, remove rank=PE legacy path
The unified ccl_allreduce bench previously carried two execution models
in one worker with ``if world_size == n_sips:`` branching:
  - TP mode (rank = SIP, ADR-0024/0027): proper ProcessGroup semantics.
  - Legacy rank = PE mode: single-driver worker allocating one big tensor
    distributed across all PEs via _derive_dp, with kernel-level SPMD via
    program_id.

The second model is unnecessary: intra-SIP PE-level collectives are
expressed inside the kernel (tl.send/tl.recv with program_id, IPCQ) and
do not need a host-side ProcessGroup. Removing it lets the bench serve as
a clean reference implementation of the TP launcher.

benches/ccl_allreduce.py:
- Config resolved once in run() via _resolve_cfg -> _BenchCfg dataclass.
- world_size != n_sips now raises RuntimeError explicitly.
- _worker / _allocate_rank_tile / _init_with_rank_value / _report each
  have one concern; duplicated init + verification paths collapsed.
- _derive_dp and the second verify+print block deleted.
- 166 lines -> 91 lines.
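
The world_size == SIP-count contract can be exercised on its own, without a simulator. The following is a minimal sketch. It assumes only the spec lookup that `_resolve_cfg` performs (`system.sips.count`, defaulting to 1). `check_rank_sip_contract` is a hypothetical helper and is not part of kernbench:

```python
from typing import Any, Mapping, Optional


def check_rank_sip_contract(world_size: int, spec: Optional[Mapping[str, Any]]) -> int:
    """Hypothetical helper mirroring _resolve_cfg's rank = SIP check.

    Returns the SIP count when world_size matches it; otherwise raises RuntimeError.
    """
    spec = spec or {}
    # Same lookup path as the bench: system -> sips -> count, defaulting to 1.
    n_sips = int(spec.get("system", {}).get("sips", {}).get("count", 1))
    if world_size != n_sips:
        raise RuntimeError(
            f"world_size={world_size} != n_sips={n_sips}; rank = PE mode was removed"
        )
    return n_sips


# Usage: a two-SIP topology accepts world_size=2 and rejects world_size=8.
topology = {"system": {"sips": {"count": 2}}}
assert check_rank_sip_contract(2, topology) == 2
try:
    check_rank_sip_contract(8, topology)
except RuntimeError as exc:
    print(exc)
```

Failing fast on the host means a mismatched launch never reaches the simulator. Before this commit, the same mismatch silently selected the rank = PE path.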

ccl.yaml:
- mesh_allreduce_4 (world_size: 4) and tree_allreduce_7 (world_size: 7)
  algorithm entries removed (they only ran in rank = PE mode).
- Algorithm kernel files (kernbench.ccl.algorithms.mesh_allreduce,
  tree_allreduce) kept as-is for direct-dispatch future use.

tests/test_ccl_allreduce_matrix.py:
- Matrix shrinks from 7 cases to 3: ring × {tcm, hbm, sram} at ws =
  topology SIP count (= 2). mesh_2x2, tree_binary_7, ring_multi_cube,
  and the three ring_*_8 cases removed.
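
The bench initialises each rank's tile to `rank + 1`. At ws = 2, every case in the reduced matrix therefore expects each element to equal 1 + 2 = 3. The sketch below enumerates the three cases and that expectation. The tier labels come from this commit message, but how the real test parametrizes them is an assumption, and all names in the sketch are hypothetical:

```python
from itertools import product

# Hypothetical enumeration; the real test's parameter names and ids may differ.
ALGORITHMS = ("ring",)
MEMORY_TIERS = ("tcm", "hbm", "sram")
TOPOLOGY_SIPS = 2  # ws = topology SIP count in the test topology


def expected_allreduce_value(world_size: int) -> float:
    """Rank r contributes r + 1, so every element sums to 1 + 2 + ... + world_size."""
    return float(world_size * (world_size + 1) // 2)


CASES = [
    (algo, tier, TOPOLOGY_SIPS) for algo, tier in product(ALGORITHMS, MEMORY_TIERS)
]

for algo, tier, ws in CASES:
    print(f"{algo}/{tier} ws={ws}: expect {expected_allreduce_value(ws)}")
```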

tests/test_ccl_performance.py:
- _run_8rank renamed to _run_ring; world_size: 8 override dropped; now
  exercises rank = SIP ring all-reduce.

tests/test_mp_spawn.py, tests/test_ccl_ddp_launcher.py:
- Monkeypatch target updated from bench.worker to bench._worker
  (signature is now _worker(rank, cfg, torch), taking a _BenchCfg instead
  of (rank, world_size)).

555 passed, 1 intentional skip. Tests that call
install_ipcq(world_size_override=N) directly for kernel-level sanity checks
(test_ccl_hello_world_guide, test_recv_copy_to_dst, test_tl_recv_async,
test_ccl_deadlock_detection) are unchanged: they never went through
the bench and still exercise the kernel-only path.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-14 16:45:27 -07:00


"""CCL all-reduce bench (ADR-0024 + ADR-0027).

Pure TP launcher model: rank = SIP. Each rank owns a ``(1, n_elem)`` tile
initialised to ``rank + 1``; after ``dist.all_reduce(op="sum")`` every rank
must see ``sum(1..world_size)``. Rank 0 prints the pass/fail line.

Driven by ``ccl.yaml`` (``defaults.algorithm``, ``n_elem``) + ``topology.yaml``
(SIP count → world_size).

The legacy ``rank = PE`` single-driver path was removed: intra-SIP PE-level
collectives are expressed by the kernel itself via ``tl.program_id`` and
do not need a host-side ``ProcessGroup``.
"""

from __future__ import annotations

from dataclasses import dataclass

import numpy as np

from kernbench.ccl.install import load_ccl_config, resolve_algorithm_config
from kernbench.policy.placement.dp import DPPolicy

DEFAULT_N_ELEM = 32


@dataclass(frozen=True)
class _BenchCfg:
    """Host-side bench configuration, resolved once in ``run()`` by ``_resolve_cfg``."""

    algorithm: str
    n_elem: int
    world_size: int


def _resolve_cfg(torch) -> _BenchCfg:
    """Read ccl.yaml once on the host and enforce the rank = SIP contract."""
    merged = resolve_algorithm_config(load_ccl_config())
    ws = torch.distributed.get_world_size()
    spec = torch.spec or {}
    n_sips = int(spec.get("system", {}).get("sips", {}).get("count", 1))
    if ws != n_sips:
        raise RuntimeError(
            "ccl_allreduce bench requires world_size == topology SIP count "
            f"(world_size={ws}, n_sips={n_sips}). rank = PE mode was removed "
            "(intra-SIP collectives are expressed inside the kernel)."
        )
    return _BenchCfg(
        algorithm=merged["algorithm"],
        n_elem=int(merged.get("n_elem", DEFAULT_N_ELEM)),
        world_size=ws,
    )


def _rank_local_dp() -> DPPolicy:
    """Placement for a rank-local tile: one cube and one PE, no intra-SIP split."""
    return DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=1)


def _allocate_rank_tile(torch, rank: int, cfg: _BenchCfg):
    """Allocate this rank's ``(1, n_elem)`` tile on its SIP."""
    return torch.zeros(
        (1, cfg.n_elem), dtype="f16",
        dp=_rank_local_dp(), name=f"ccl_in_r{rank}",
    )


def _init_with_rank_value(torch, tensor, rank: int, cfg: _BenchCfg) -> None:
    """Fill the tile with the scalar ``rank + 1`` (deterministic + easy to verify)."""
    arr = np.full((1, cfg.n_elem), float(rank + 1), dtype=np.float16)
    tensor.copy_(torch.from_numpy(arr))


def _report(result: np.ndarray, cfg: _BenchCfg) -> None:
    """Print the pass line, or a failure detail plus summary (rank 0, after all_reduce)."""
    expected = float(sum(range(1, cfg.world_size + 1)))
    ok = bool(np.allclose(result, expected, rtol=1e-1, atol=1e-1))
    if ok:
        print(f" {cfg.algorithm} (ws={cfg.world_size}): {cfg.world_size} OK")
        return
    got = float(result.reshape(-1).mean())
    print(
        f" [FAIL] {cfg.algorithm} (ws={cfg.world_size}): "
        f"got mean={got:.3f}, expected={expected:.3f}"
    )
    print(
        f" {cfg.algorithm} (ws={cfg.world_size}): "
        f"0 OK / {cfg.world_size} FAIL"
    )


def _worker(rank: int, cfg: _BenchCfg, torch) -> None:
    """Per-rank body: select device ``rank``, fill the tile, all-reduce, verify on rank 0."""
    torch.ahbm.set_device(rank)
    tensor = _allocate_rank_tile(torch, rank, cfg)
    _init_with_rank_value(torch, tensor, rank, cfg)
    torch.distributed.all_reduce(tensor, op="sum")
    if rank == 0:
        _report(tensor.numpy(), cfg)
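

def run(torch) -> None:
    """Bench entry point: join the ``ahbm`` process group, resolve the config,
    then spawn one worker per SIP (spawn passes the rank as the first argument)."""
    torch.distributed.init_process_group(backend="ahbm")
    cfg = _resolve_cfg(torch)
    torch.multiprocessing.spawn(
        _worker, args=(cfg, torch), nprocs=cfg.world_size,
    )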
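
The bench's pass condition can be checked against a host-side model of the ring algorithm. The following is a minimal NumPy sketch that assumes the textbook ring all-reduce: a reduce-scatter followed by an all-gather over `world_size` chunks. `ring_allreduce` is a hypothetical reference model, not the kernbench kernel. It also accumulates in float32, and the real kernel's accumulation precision is not stated here.

```python
import numpy as np


def ring_allreduce(tiles):
    """Hypothetical host-side model of a ring all-reduce (sum).

    tiles[r] is rank r's buffer; all tiles share one shape and dtype.
    Returns one reduced copy per rank, cast back to the input dtype.
    """
    n = len(tiles)
    shape, dtype = tiles[0].shape, tiles[0].dtype
    # Flatten and accumulate in float32 (an assumption of this sketch).
    bufs = [np.asarray(t, dtype=np.float32).reshape(-1).copy() for t in tiles]
    # Split element indices into n (possibly uneven or empty) chunks.
    chunks = np.array_split(np.arange(bufs[0].size), n)

    # Reduce-scatter: after n - 1 steps rank r holds the full sum of chunk (r + 1) % n.
    for step in range(n - 1):
        sends = [(r, (r - step) % n) for r in range(n)]
        # Snapshot every payload first so all ranks "transmit" simultaneously.
        payloads = [bufs[r][chunks[c]].copy() for r, c in sends]
        for (r, c), data in zip(sends, payloads):
            bufs[(r + 1) % n][chunks[c]] += data

    # All-gather: circulate each fully reduced chunk once around the ring.
    for step in range(n - 1):
        sends = [(r, (r + 1 - step) % n) for r in range(n)]
        payloads = [bufs[r][chunks[c]].copy() for r, c in sends]
        for (r, c), data in zip(sends, payloads):
            bufs[(r + 1) % n][chunks[c]] = data

    return [b.reshape(shape).astype(dtype) for b in bufs]


# Mirror the bench: rank r's (1, n_elem) f16 tile is filled with r + 1.
world_size, n_elem = 4, 32
tiles = [np.full((1, n_elem), float(r + 1), dtype=np.float16) for r in range(world_size)]
expected = float(sum(range(1, world_size + 1)))
results = ring_allreduce(tiles)
print(all(np.allclose(out, expected, rtol=1e-1, atol=1e-1) for out in results))
```

The check reuses `_report`'s tolerances (`rtol=atol=1e-1`). Feeding the same tiles to the simulator and to this model can help tell an algorithm bug apart from a placement or transport bug.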