Refactor ccl_allreduce bench: rank=SIP only, remove rank=PE legacy path
The unified ccl_allreduce bench previously carried two execution models
in one worker with ``if world_size == n_sips:`` branching:
- TP mode (rank = SIP, ADR-0024/0027): proper ProcessGroup semantics.
- Legacy rank = PE mode: single-driver worker allocating one big tensor
distributed across all PEs via _derive_dp, with kernel-level SPMD via
program_id.
The second model is unnecessary — intra-SIP PE-level collectives are
expressed inside the kernel (tl.send/tl.recv with program_id, IPCQ) and
do not need a host-side ProcessGroup. Removing it lets the bench be a
clean reference implementation of the TP launcher.
benches/ccl_allreduce.py:
- Config resolved once in run() via _resolve_cfg -> _BenchCfg dataclass.
- rank != n_sips now raises RuntimeError explicitly.
- _worker / _allocate_rank_tile / _init_with_rank_value / _report each
have one concern; duplicated init + verification paths collapsed.
- _derive_dp and the second verify+print block deleted.
- 166 lines -> 91 lines.
ccl.yaml:
- mesh_allreduce_4 (world_size: 4) and tree_allreduce_7 (world_size: 7)
algorithm entries removed (rank = PE only).
- Algorithm kernel files (kernbench.ccl.algorithms.mesh_allreduce,
tree_allreduce) kept as-is for direct-dispatch future use.
tests/test_ccl_allreduce_matrix.py:
- Matrix shrinks from 7 cases to 3: ring × {tcm, hbm, sram} at ws =
topology SIP count (= 2). mesh_2x2, tree_binary_7, ring_multi_cube,
and the three ring_*_8 cases removed.
tests/test_ccl_performance.py:
- _run_8rank renamed to _run_ring; world_size: 8 override dropped; now
exercises rank = SIP ring all-reduce.
tests/test_mp_spawn.py, tests/test_ccl_ddp_launcher.py:
- Monkeypatch target updated from bench.worker to bench._worker
(signature now takes BenchCfg instead of (rank, world_size)).
555 passed, 1 intentional skip. Tests that directly call
install_ipcq(world_size_override=N) for kernel-level sanity
(test_ccl_hello_world_guide, test_recv_copy_to_dst, test_tl_recv_async,
test_ccl_deadlock_detection) are unchanged — they never went through
the bench and still exercise the kernel-only path.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
+78
-141
@@ -1,165 +1,102 @@
|
||||
"""CCL all-reduce bench (ADR-0024 Phase A).
|
||||
"""CCL all-reduce bench (ADR-0024 + ADR-0027).
|
||||
|
||||
Driven entirely by ``ccl.yaml`` + ``topology.yaml``:
|
||||
Pure TP launcher model: rank = SIP. Each rank owns a ``(1, n_elem)`` tile
|
||||
initialised to ``rank + 1``; after ``dist.all_reduce(op="sum")`` every rank
|
||||
must see ``sum(1..world_size)``. Rank 0 prints the pass/fail line.
|
||||
|
||||
- ``defaults.algorithm`` in ``ccl.yaml`` picks which kernel to run.
|
||||
- ``world_size`` resolution: explicit override in ccl.yaml > defaults >
|
||||
topology's SIP count. ADR-0024 D1: topology fallback is the SIP count
|
||||
(each rank = one SIP, TP boundary).
|
||||
- ``run()`` is hybrid:
|
||||
- If ``world_size == topology SIP count`` (the intended new path):
|
||||
spawn one greenlet per rank, bind it via ``dist._bind_rank``, and
|
||||
each worker calls ``torch.ahbm.set_device(rank)`` + runs its portion
|
||||
of the collective. Cross-rank IPCQ exchange handles the reduce.
|
||||
- Legacy path (``world_size > SIP count``, via explicit ccl.yaml
|
||||
override): single worker at rank 0 with the full tensor distributed
|
||||
across all participating PEs via ``_derive_dp``. Retained for
|
||||
backward compatibility with existing kernel / topology tests.
|
||||
Driven by ``ccl.yaml`` (``defaults.algorithm``, ``n_elem``) + ``topology.yaml``
|
||||
(SIP count → world_size).
|
||||
|
||||
Legacy ``rank = PE`` single-driver path was removed — intra-SIP PE-level
|
||||
collective is expressed by the kernel itself via ``tl.program_id`` and
|
||||
does not need a host-side ``ProcessGroup``.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
|
||||
from kernbench.ccl.install import load_ccl_config, resolve_algorithm_config
|
||||
from kernbench.policy.placement.dp import DPPolicy
|
||||
|
||||
# Default per-rank tile size if ccl.yaml doesn't override it.
|
||||
DEFAULT_N_ELEM = 32
|
||||
|
||||
|
||||
def _derive_dp(spec: dict, world_size: int) -> DPPolicy:
|
||||
"""Legacy DPPolicy for world_size > SIP count (rank = flat PE index).
|
||||
|
||||
Used only in the ccl.yaml-override path so the existing matrix tests
|
||||
with explicit world_size (8, 16, 7 etc.) keep working. ADR-0026:
|
||||
DPPolicy is intra-device only, so this legacy path now always stays
|
||||
within a single SIP and distributes the override world_size across
|
||||
that SIP's cubes and PEs.
|
||||
"""
|
||||
pl = spec["cube"]["pe_layout"]
|
||||
pes_per_cube = int(pl["pe_per_corner"]) * len(pl["corners"])
|
||||
cm = spec["sip"]["cube_mesh"]
|
||||
cubes_per_sip = int(cm["w"]) * int(cm["h"])
|
||||
if world_size <= pes_per_cube:
|
||||
return DPPolicy(
|
||||
cube="replicate", pe="column_wise",
|
||||
num_cubes=1, num_pes=world_size,
|
||||
)
|
||||
if world_size <= cubes_per_sip * pes_per_cube:
|
||||
return DPPolicy(
|
||||
cube="column_wise", pe="column_wise",
|
||||
num_cubes=world_size // pes_per_cube,
|
||||
)
|
||||
return DPPolicy(cube="column_wise", pe="column_wise")
|
||||
@dataclass(frozen=True)
|
||||
class _BenchCfg:
|
||||
algorithm: str
|
||||
n_elem: int
|
||||
world_size: int
|
||||
|
||||
|
||||
def worker(rank: int, world_size: int, torch) -> None:
|
||||
"""Per-rank worker (new TP path) OR single-worker legacy driver.
|
||||
|
||||
Behaviour depends on whether this call originates from the
|
||||
multi-greenlet launcher (new path) or from the legacy single-call
|
||||
fallback; distinguished by which ``dp`` layout applies.
|
||||
"""
|
||||
cfg = resolve_algorithm_config(load_ccl_config())
|
||||
algo_name = cfg["algorithm"]
|
||||
n_elem = int(cfg.get("n_elem", DEFAULT_N_ELEM))
|
||||
|
||||
def _resolve_cfg(torch) -> _BenchCfg:
|
||||
"""Read ccl.yaml once at host side; enforce rank = SIP contract."""
|
||||
merged = resolve_algorithm_config(load_ccl_config())
|
||||
ws = torch.distributed.get_world_size()
|
||||
spec = torch.spec or {}
|
||||
n_sips = int(spec.get("system", {}).get("sips", {}).get("count", 1))
|
||||
|
||||
if world_size == n_sips:
|
||||
# ADR-0024 new path: rank = SIP, worker sees its SIP's
|
||||
# representative PE via torch.ahbm.set_device.
|
||||
torch.ahbm.set_device(rank)
|
||||
dp = DPPolicy(cube="replicate", pe="replicate",
|
||||
num_cubes=1, num_pes=1)
|
||||
tensor = torch.zeros(
|
||||
(1, n_elem), dtype="f16", dp=dp, name=f"ccl_in_r{rank}",
|
||||
if ws != n_sips:
|
||||
raise RuntimeError(
|
||||
f"ccl_allreduce bench requires world_size == topology SIP count "
|
||||
f"(world_size={ws}, n_sips={n_sips}). rank = PE mode was removed "
|
||||
f"(intra-SIP collectives are expressed inside the kernel)."
|
||||
)
|
||||
# Each rank initialises its tile with (rank + 1); after all_reduce
|
||||
# every rank sees sum(1..world_size).
|
||||
init = np.full((1, n_elem), float(rank + 1), dtype=np.float16)
|
||||
tensor.copy_(torch.from_numpy(init))
|
||||
torch.distributed.all_reduce(tensor, op="sum")
|
||||
result = tensor.numpy()
|
||||
expected = float(sum(range(1, world_size + 1)))
|
||||
all_ok = bool(np.allclose(result, expected, rtol=1e-1, atol=1e-1))
|
||||
if rank == 0:
|
||||
if all_ok:
|
||||
print(f" {algo_name} (ws={world_size}): {world_size} OK")
|
||||
else:
|
||||
print(
|
||||
f" [FAIL] rank {rank} "
|
||||
f"(ws={world_size}, algo={algo_name}): "
|
||||
f"got mean={float(result.reshape(-1).mean()):.3f}, "
|
||||
f"expected={expected:.3f}"
|
||||
)
|
||||
print(
|
||||
f" {algo_name} (ws={world_size}): "
|
||||
f"0 OK / {world_size} FAIL"
|
||||
)
|
||||
return
|
||||
|
||||
# Legacy path: world_size overridden via ccl.yaml to exceed SIP count.
|
||||
# Single-worker at rank 0; whole tensor distributed across all
|
||||
# participating PEs using the derived DPPolicy. Matches pre-ADR-0024
|
||||
# behaviour.
|
||||
dp = _derive_dp(spec, world_size)
|
||||
tensor = torch.zeros(
|
||||
(1, world_size * n_elem), dtype="f16", dp=dp, name="ccl_in",
|
||||
return _BenchCfg(
|
||||
algorithm=merged["algorithm"],
|
||||
n_elem=int(merged.get("n_elem", DEFAULT_N_ELEM)),
|
||||
world_size=ws,
|
||||
)
|
||||
init = np.zeros((1, world_size * n_elem), dtype=np.float16)
|
||||
for r in range(world_size):
|
||||
init[0, r * n_elem : (r + 1) * n_elem] = float(r + 1)
|
||||
tensor.copy_(torch.from_numpy(init))
|
||||
torch.distributed.all_reduce(tensor, op="sum")
|
||||
|
||||
result = tensor.numpy()
|
||||
expected = float(sum(range(1, world_size + 1)))
|
||||
all_ok = bool(np.allclose(result, expected, rtol=1e-1, atol=1e-1))
|
||||
|
||||
def _rank_local_dp() -> DPPolicy:
|
||||
return DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=1)
|
||||
|
||||
|
||||
def _allocate_rank_tile(torch, rank: int, cfg: _BenchCfg):
|
||||
"""Allocate this rank's ``(1, n_elem)`` tile on its SIP."""
|
||||
return torch.zeros(
|
||||
(1, cfg.n_elem), dtype="f16",
|
||||
dp=_rank_local_dp(), name=f"ccl_in_r{rank}",
|
||||
)
|
||||
|
||||
|
||||
def _init_with_rank_value(torch, tensor, rank: int, cfg: _BenchCfg) -> None:
|
||||
"""Fill the tile with the scalar ``rank + 1`` (deterministic + easy to verify)."""
|
||||
arr = np.full((1, cfg.n_elem), float(rank + 1), dtype=np.float16)
|
||||
tensor.copy_(torch.from_numpy(arr))
|
||||
|
||||
|
||||
def _report(result: np.ndarray, cfg: _BenchCfg) -> None:
|
||||
"""Single-line pass/fail printer (rank 0 only, called after all_reduce)."""
|
||||
expected = float(sum(range(1, cfg.world_size + 1)))
|
||||
ok = bool(np.allclose(result, expected, rtol=1e-1, atol=1e-1))
|
||||
if ok:
|
||||
print(f" {cfg.algorithm} (ws={cfg.world_size}): {cfg.world_size} OK")
|
||||
return
|
||||
got = float(result.reshape(-1).mean())
|
||||
print(
|
||||
f" [FAIL] {cfg.algorithm} (ws={cfg.world_size}): "
|
||||
f"got mean={got:.3f}, expected={expected:.3f}"
|
||||
)
|
||||
print(
|
||||
f" {cfg.algorithm} (ws={cfg.world_size}): "
|
||||
f"0 OK / {cfg.world_size} FAIL"
|
||||
)
|
||||
|
||||
|
||||
def _worker(rank: int, cfg: _BenchCfg, torch) -> None:
|
||||
torch.ahbm.set_device(rank)
|
||||
tensor = _allocate_rank_tile(torch, rank, cfg)
|
||||
_init_with_rank_value(torch, tensor, rank, cfg)
|
||||
torch.distributed.all_reduce(tensor, op="sum")
|
||||
if rank == 0:
|
||||
if all_ok:
|
||||
print(f" {algo_name} (ws={world_size}): {world_size} OK")
|
||||
else:
|
||||
flat = result.reshape(-1)
|
||||
n_fail = 0
|
||||
for r in range(world_size):
|
||||
slice_r = flat[r * n_elem : (r + 1) * n_elem]
|
||||
if not np.allclose(slice_r, expected, rtol=1e-1, atol=1e-1):
|
||||
n_fail += 1
|
||||
if n_fail <= 5:
|
||||
print(
|
||||
f" [FAIL] rank {r} "
|
||||
f"(ws={world_size}, algo={algo_name}): "
|
||||
f"got mean={float(slice_r.mean()):.3f}, "
|
||||
f"expected={expected:.3f}"
|
||||
)
|
||||
print(
|
||||
f" {algo_name} (ws={world_size}): "
|
||||
f"{world_size - n_fail} OK / {n_fail} FAIL"
|
||||
)
|
||||
_report(tensor.numpy(), cfg)
|
||||
|
||||
|
||||
def run(torch) -> None:
|
||||
"""CLI entry — dispatch to multi-greenlet path when ws == SIP count,
|
||||
else fall back to single-worker legacy path for ccl.yaml override compat.
|
||||
"""
|
||||
dist = torch.distributed
|
||||
dist.init_process_group(backend="ahbm")
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
spec = torch.spec or {}
|
||||
n_sips = int(spec.get("system", {}).get("sips", {}).get("count", 1))
|
||||
|
||||
if world_size == n_sips:
|
||||
# ADR-0027 D1: ``torch.multiprocessing.spawn`` replaces the prior
|
||||
# hand-rolled greenlet loop. The spawn namespace absorbs the
|
||||
# scheduler drain (D0.4) so kernel_runner's spawned kernel greenlets
|
||||
# correctly get main as their parent (ADR-0024 Phase B blocker
|
||||
# resolved via D0 worker-wait generalisation).
|
||||
torch.multiprocessing.spawn(
|
||||
worker, args=(world_size, torch), nprocs=world_size,
|
||||
)
|
||||
else:
|
||||
# Legacy single-worker path (ccl.yaml world_size override).
|
||||
worker(rank=dist.get_rank(), world_size=world_size, torch=torch)
|
||||
torch.distributed.init_process_group(backend="ahbm")
|
||||
cfg = _resolve_cfg(torch)
|
||||
torch.multiprocessing.spawn(
|
||||
_worker, args=(cfg, torch), nprocs=cfg.world_size,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user