Files
kernbench2/tests/test_ccl_ddp_launcher.py
T
ywkang cfc2d74ec4 Refactor ccl_allreduce bench: rank=SIP only, remove rank=PE legacy path
The unified ccl_allreduce bench previously carried two execution models
in one worker with ``if world_size == n_sips:`` branching:
  - TP mode (rank = SIP, ADR-0024/0027): proper ProcessGroup semantics.
  - Legacy rank = PE mode: single-driver worker allocating one big tensor
    distributed across all PEs via _derive_dp, with kernel-level SPMD via
    program_id.

The second model is unnecessary — intra-SIP PE-level collectives are
expressed inside the kernel (tl.send/tl.recv with program_id, IPCQ) and
do not need a host-side ProcessGroup. Removing it lets the bench be a
clean reference implementation of the TP launcher.

benches/ccl_allreduce.py:
- Config resolved once in run() via _resolve_cfg -> _BenchCfg dataclass.
- world_size != n_sips now raises RuntimeError explicitly.
- _worker / _allocate_rank_tile / _init_with_rank_value / _report each
  have one concern; duplicated init + verification paths collapsed.
- _derive_dp and the second verify+print block deleted.
- 166 lines -> 91 lines.

ccl.yaml:
- mesh_allreduce_4 (world_size: 4) and tree_allreduce_7 (world_size: 7)
  algorithm entries removed (rank = PE only).
- Algorithm kernel files (kernbench.ccl.algorithms.mesh_allreduce,
  tree_allreduce) kept as-is for direct-dispatch future use.

tests/test_ccl_allreduce_matrix.py:
- Matrix shrinks from 7 cases to 3: ring × {tcm, hbm, sram} at ws =
  topology SIP count (= 2). mesh_2x2, tree_binary_7, ring_multi_cube,
  and the three ring_*_8 cases removed.

tests/test_ccl_performance.py:
- _run_8rank renamed to _run_ring; world_size: 8 override dropped; now
  exercises rank = SIP ring all-reduce.

tests/test_mp_spawn.py, tests/test_ccl_ddp_launcher.py:
- Monkeypatch target updated from bench.worker to bench._worker
  (signature now takes BenchCfg instead of (rank, world_size)).

555 passed, 1 intentional skip. Tests that directly call
install_ipcq(world_size_override=N) for kernel-level sanity
(test_ccl_hello_world_guide, test_recv_copy_to_dst, test_tl_recv_async,
test_ccl_deadlock_detection) are unchanged — they never went through
the bench and still exercise the kernel-only path.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-14 16:45:27 -07:00

245 lines
8.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Phase 1 tests for ADR-0024 SIP-level TP launcher (MVP scope).
Covers:
- D1 world_size = SIP count fallback
- D9 get_rank greenlet-local + _bind_rank
- D10 torch.ahbm.set_device + torch.accelerator alias
- D11 tensor placement scoped to current device SIP
- D12/D13 run() spawns one greenlet per rank
Deferred to later ADR-0024 sub-phases:
- D2 engine-routed install
- D6 install_plan.py
- D7 epoch barrier (this phase uses simple submit+yield+wait)
- D8 validator registry
"""
from __future__ import annotations
import os
import textwrap
import pytest
from greenlet import greenlet
from kernbench.runtime_api.distributed import AhbmCCLBackend, DistributedContext
# ── Fixtures / helpers ────────────────────────────────────────────────
class _FakeCtx:
"""Minimal ctx double — only exposes what AhbmCCLBackend.__init__ uses.
Stubs install_ipcq so we can unit-test _resolve_world_size without
touching the engine stack.
"""
def __init__(self, spec: dict) -> None:
self.spec = spec
self.install_calls: list[dict] = []
def install_ipcq(self, **kwargs) -> dict:
self.install_calls.append(dict(kwargs))
return {}
def _write_minimal_ccl_yaml(tmp_path) -> str:
"""Write a ccl.yaml with NO world_size override — forces topology derivation."""
body = textwrap.dedent("""\
defaults:
algorithm: ring_allreduce_tcm
buffer_kind: tcm
backpressure: sleep
n_slots: 4
slot_size: 4096
vc_chunk_size: 256
ipcq_credit_size_bytes: 16
algorithms:
ring_allreduce_tcm:
module: kernbench.ccl.algorithms.ring_allreduce
topology: ring_1d
buffer_kind: tcm
n_elem: 8
""")
yaml_path = tmp_path / "ccl.yaml"
yaml_path.write_text(body)
return str(tmp_path)
# ── D1: world_size = SIP count fallback ───────────────────────────────
def test_world_size_equals_sip_count(tmp_path, monkeypatch, spec):
    """Backend world_size falls back to the SIP count when ccl.yaml has no override.

    Per the original test notes, the topology has 2 SIPs × 16 cubes × 8 PEs
    = 256 PEs; the collective group sits at the SIP boundary, so the
    expected world_size is the SIP count, not the total PE count.
    """
    yaml_dir = _write_minimal_ccl_yaml(tmp_path)
    monkeypatch.chdir(yaml_dir)

    fake_ctx = _FakeCtx(spec=spec)
    backend = AhbmCCLBackend(torch_ctx=fake_ctx)

    sip_count = int(spec["system"]["sips"]["count"])
    assert backend.world_size == sip_count, (
        f"expected world_size == SIP count ({sip_count}); "
        f"got {backend.world_size} — still deriving sips × cubes × pes"
    )
# ── D9: get_rank greenlet-local + _bind_rank ──────────────────────────
def test_get_rank_is_greenlet_local(tmp_path, monkeypatch, spec):
    """dist.get_rank() resolves to the rank bound to the calling greenlet.

    Two greenlets are bound to ranks 0 and 1 via ``_bind_rank``; each one
    records what ``get_rank()`` returns while it is running.
    """
    monkeypatch.chdir(_write_minimal_ccl_yaml(tmp_path))

    dc = DistributedContext()
    dc._ctx_ref = _FakeCtx(spec=spec)
    dc.init_process_group(backend="ahbm")

    sip_count = int(spec["system"]["sips"]["count"])
    assert dc.get_world_size() == sip_count
    assert hasattr(dc, "_bind_rank"), (
        "DistributedContext must expose _bind_rank(g, rank) hook"
    )

    seen: dict[int, int] = {}

    def _probe(rank: int) -> None:
        # Runs inside the greenlet bound to ``rank``.
        seen[rank] = dc.get_rank()

    ranks = (0, 1)
    # Create every greenlet first, then bind, then run — same order as before.
    workers = [greenlet(lambda r=r: _probe(r)) for r in ranks]
    for r, g in zip(ranks, workers):
        dc._bind_rank(g, r)
    for g in workers:
        g.switch()

    assert seen == {0: 0, 1: 1}, (
        f"expected each greenlet to see its own rank; got {seen}"
    )
def test_get_rank_fallback_without_bind(tmp_path, monkeypatch, spec):
    """get_rank() from a greenlet with no rank binding yields 0 (single-driver compat)."""
    yaml_dir = _write_minimal_ccl_yaml(tmp_path)
    monkeypatch.chdir(yaml_dir)

    dc = DistributedContext()
    dc._ctx_ref = _FakeCtx(spec=spec)
    dc.init_process_group(backend="ahbm")

    # The test body runs in the main greenlet, which was never bound.
    rank = dc.get_rank()
    assert rank == 0
# ── D10/D11: torch.ahbm.set_device + tensor scoping ───────────────────
def test_ahbm_set_device_binds_tensor_to_single_sip(topology):
    """``ctx.ahbm.set_device(1)`` scopes a default-SIP DPPolicy tensor to SIP 1.

    The DPPolicy sets only the cube and PE dimensions; with the SIP
    dimension left at its default, every shard of the allocated tensor is
    expected to report ``sip == 1``.
    """
    from kernbench.policy.placement.dp import DPPolicy
    from kernbench.runtime_api.context import RuntimeContext
    from kernbench.runtime_api.types import DeviceSelector
    from kernbench.sim_engine.engine import GraphEngine

    engine = GraphEngine(topology.topology_obj, enable_data=True)
    runtime = RuntimeContext(
        engine=engine,
        target_device=DeviceSelector("all"),
        correlation_id="test_ahbm_set_device",
        spec=topology.topology_obj.spec,
    )
    with runtime as ctx:
        assert hasattr(ctx, "ahbm"), (
            "RuntimeContext must expose .ahbm namespace (ADR-0024 D10)"
        )
        ctx.ahbm.set_device(1)

        # SIP dimension intentionally left at its default.
        policy = DPPolicy(cube="column_wise", pe="column_wise")
        tensor = ctx.zeros((1, 128), dtype="f16", dp=policy, name="probe")

        shard_sips = set()
        for shard in tensor._handle.shards:
            shard_sips.add(shard.sip)

        assert shard_sips == {1}, (
            f"after ahbm.set_device(1), all shards should live on SIP 1; "
            f"got sips={sorted(shard_sips)}"
        )
def test_accelerator_alias_mirrors_ahbm(topology):
    """``ctx.accelerator.set_device_index(r)`` is visible through ``ctx.ahbm`` too.

    Mirrors the PyTorch 2.x device-agnostic surface: after selecting
    index 1 through the accelerator namespace, both namespaces must report
    1 as the current device.
    """
    from kernbench.runtime_api.context import RuntimeContext
    from kernbench.runtime_api.types import DeviceSelector
    from kernbench.sim_engine.engine import GraphEngine

    engine = GraphEngine(topology.topology_obj, enable_data=True)
    runtime = RuntimeContext(
        engine=engine,
        target_device=DeviceSelector("all"),
        correlation_id="test_accelerator_alias",
        spec=topology.topology_obj.spec,
    )
    with runtime as ctx:
        assert hasattr(ctx, "accelerator"), (
            "RuntimeContext must expose .accelerator namespace (ADR-0024 D10)"
        )
        ctx.accelerator.set_device_index(1)

        # The selection made via the alias must be observable on both namespaces.
        ahbm_device = ctx.ahbm.current_device()
        assert ahbm_device == 1
        accel_index = ctx.accelerator.current_device_index()
        assert accel_index == 1
# ── D12/D13: run() spawns one worker per rank ─────────────────────────
def test_run_spawns_one_worker_per_rank(tmp_path, monkeypatch, spec):
    """The bench's ``run()`` calls the patched ``_worker`` hook once per rank.

    ``benches.ccl_allreduce._worker`` is replaced by a recorder that logs
    ``(rank, cfg.world_size)``. After ``run(ctx)``, the recorded ranks must
    be exactly ``0..SIP count - 1`` and every call must have seen
    ``world_size == SIP count``.
    """
    tests_dir = os.path.dirname(__file__)
    project_root = os.path.abspath(os.path.join(tests_dir, ".."))
    monkeypatch.chdir(_write_minimal_ccl_yaml(tmp_path))

    import benches.ccl_allreduce as bench

    calls: list[tuple[int, int]] = []

    def _fake_worker(rank, cfg, torch) -> None:
        # Recorder only — no collective is executed.
        calls.append((rank, cfg.world_size))

    monkeypatch.setattr(bench, "_worker", _fake_worker)

    from kernbench.runtime_api.context import RuntimeContext
    from kernbench.runtime_api.types import DeviceSelector
    from kernbench.sim_engine.engine import GraphEngine
    from kernbench.topology.builder import resolve_topology

    topology_file = os.path.join(project_root, "topology.yaml")
    topo = resolve_topology(topology_file)
    engine = GraphEngine(topo.topology_obj, enable_data=True)
    runtime = RuntimeContext(
        engine=engine,
        target_device=DeviceSelector("all"),
        correlation_id="test_run_spawns",
        spec=topo.topology_obj.spec,
    )
    with runtime as ctx:
        bench.run(ctx)

    ranks = sorted(call[0] for call in calls)
    ws_values = {call[1] for call in calls}
    expected_ws = int(spec["system"]["sips"]["count"])

    assert ranks == list(range(expected_ws)), (
        f"run() should invoke worker for ranks 0..{expected_ws - 1}; "
        f"saw ranks={ranks}"
    )
    assert ws_values == {expected_ws}, (
        f"each worker should see world_size={expected_ws}; saw {ws_values}"
    )