cfc2d74ec4
The unified ccl_allreduce bench previously carried two execution models
in one worker with ``if world_size == n_sips:`` branching:
- TP mode (rank = SIP, ADR-0024/0027): proper ProcessGroup semantics.
- Legacy rank = PE mode: single-driver worker allocating one big tensor
distributed across all PEs via _derive_dp, with kernel-level SPMD via
program_id.
The second model is unnecessary — intra-SIP PE-level collectives are
expressed inside the kernel (tl.send/tl.recv with program_id, IPCQ) and
do not need a host-side ProcessGroup. Removing it lets the bench be a
clean reference implementation of the TP launcher.
benches/ccl_allreduce.py:
- Config resolved once in run() via _resolve_cfg -> _BenchCfg dataclass.
- world_size != n_sips (rank count differs from SIP count) now raises RuntimeError explicitly.
- _worker / _allocate_rank_tile / _init_with_rank_value / _report each
have one concern; duplicated init + verification paths collapsed.
- _derive_dp and the second verify+print block deleted.
- 166 lines -> 91 lines.
ccl.yaml:
- mesh_allreduce_4 (world_size: 4) and tree_allreduce_7 (world_size: 7)
algorithm entries removed (rank = PE only).
- Algorithm kernel files (kernbench.ccl.algorithms.mesh_allreduce,
tree_allreduce) kept as-is for direct-dispatch future use.
tests/test_ccl_allreduce_matrix.py:
- Matrix shrinks from 7 cases to 3: ring × {tcm, hbm, sram} at ws =
topology SIP count (= 2). mesh_2x2, tree_binary_7, ring_multi_cube,
and the three ring_*_8 cases removed.
tests/test_ccl_performance.py:
- _run_8rank renamed to _run_ring; world_size: 8 override dropped; now
exercises rank = SIP ring all-reduce.
tests/test_mp_spawn.py, tests/test_ccl_ddp_launcher.py:
- Monkeypatch target updated from bench.worker to bench._worker
(signature now takes BenchCfg instead of (rank, world_size)).
555 passed, 1 intentional skip. Tests that directly call
install_ipcq(world_size_override=N) for kernel-level sanity
(test_ccl_hello_world_guide, test_recv_copy_to_dst, test_tl_recv_async,
test_ccl_deadlock_detection) are unchanged — they never went through
the bench and still exercise the kernel-only path.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
245 lines
8.3 KiB
Python
245 lines
8.3 KiB
Python
"""Phase 1 tests for ADR-0024 SIP-level TP launcher (MVP scope).
|
||
|
||
Covers:
|
||
- D1 world_size = SIP count fallback
|
||
- D9 get_rank greenlet-local + _bind_rank
|
||
- D10 torch.ahbm.set_device + torch.accelerator alias
|
||
- D11 tensor placement scoped to current device SIP
|
||
- D12/D13 run() spawns one greenlet per rank
|
||
|
||
Deferred to later ADR-0024 sub-phases:
|
||
- D2 engine-routed install
|
||
- D6 install_plan.py
|
||
- D7 epoch barrier (this phase uses simple submit+yield+wait)
|
||
- D8 validator registry
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import os
|
||
import textwrap
|
||
|
||
import pytest
|
||
from greenlet import greenlet
|
||
|
||
from kernbench.runtime_api.distributed import AhbmCCLBackend, DistributedContext
|
||
|
||
|
||
# ── Fixtures / helpers ────────────────────────────────────────────────
|
||
|
||
|
||
class _FakeCtx:
|
||
"""Minimal ctx double — only exposes what AhbmCCLBackend.__init__ uses.
|
||
|
||
Stubs install_ipcq so we can unit-test _resolve_world_size without
|
||
touching the engine stack.
|
||
"""
|
||
|
||
def __init__(self, spec: dict) -> None:
|
||
self.spec = spec
|
||
self.install_calls: list[dict] = []
|
||
|
||
def install_ipcq(self, **kwargs) -> dict:
|
||
self.install_calls.append(dict(kwargs))
|
||
return {}
|
||
|
||
|
||
def _write_minimal_ccl_yaml(tmp_path) -> str:
|
||
"""Write a ccl.yaml with NO world_size override — forces topology derivation."""
|
||
body = textwrap.dedent("""\
|
||
defaults:
|
||
algorithm: ring_allreduce_tcm
|
||
buffer_kind: tcm
|
||
backpressure: sleep
|
||
n_slots: 4
|
||
slot_size: 4096
|
||
vc_chunk_size: 256
|
||
ipcq_credit_size_bytes: 16
|
||
|
||
algorithms:
|
||
ring_allreduce_tcm:
|
||
module: kernbench.ccl.algorithms.ring_allreduce
|
||
topology: ring_1d
|
||
buffer_kind: tcm
|
||
n_elem: 8
|
||
""")
|
||
yaml_path = tmp_path / "ccl.yaml"
|
||
yaml_path.write_text(body)
|
||
return str(tmp_path)
|
||
|
||
|
||
# ── D1: world_size = SIP count fallback ───────────────────────────────
|
||
|
||
|
||
def test_world_size_equals_sip_count(tmp_path, monkeypatch, spec):
    """Backend world_size falls back to the SIP count when nothing overrides it.

    The ccl.yaml written for this test has no world_size entry, so the
    backend must derive the value itself.  Topology has 2 SIPs × 16 cubes
    × 8 PEs = 256 PEs.  The TP/DP model places the collective group at the
    SIP boundary, so the derived world_size must be the SIP count (2), not
    the total PE count (256).
    """
    workdir = _write_minimal_ccl_yaml(tmp_path)
    monkeypatch.chdir(workdir)

    fake_ctx = _FakeCtx(spec=spec)
    backend = AhbmCCLBackend(torch_ctx=fake_ctx)

    sip_count = int(spec["system"]["sips"]["count"])
    derived = backend.world_size
    assert derived == sip_count, (
        f"expected world_size == SIP count ({sip_count}); "
        f"got {derived} — still deriving sips × cubes × pes"
    )
|
||
|
||
|
||
# ── D9: get_rank greenlet-local + _bind_rank ──────────────────────────
|
||
|
||
|
||
def test_get_rank_is_greenlet_local(tmp_path, monkeypatch, spec):
    """dist.get_rank() reports the rank bound to the calling greenlet.

    The framework-level launcher binds greenlet → rank via ``_bind_rank``;
    ``get_rank()`` then resolves the current greenlet and returns that
    rank.  Two greenlets are bound to ranks 0 and 1 and each records what
    it observes.
    """
    monkeypatch.chdir(_write_minimal_ccl_yaml(tmp_path))
    fake_ctx = _FakeCtx(spec=spec)
    dist = DistributedContext()
    dist._ctx_ref = fake_ctx
    dist.init_process_group(backend="ahbm")

    world = dist.get_world_size()
    assert world == int(spec["system"]["sips"]["count"])

    assert hasattr(dist, "_bind_rank"), (
        "DistributedContext must expose _bind_rank(g, rank) hook"
    )

    observed: dict[int, int] = {}

    def _record(rank: int) -> None:
        observed[rank] = dist.get_rank()

    # Create both greenlets, bind both, then run both — same order as a
    # launcher that sets up every rank before starting any of them.
    # ``bound=rank`` freezes the loop value for each lambda.
    workers = [greenlet(lambda bound=rank: _record(bound)) for rank in (0, 1)]
    for rank, worker in enumerate(workers):
        dist._bind_rank(worker, rank)
    for worker in workers:
        worker.switch()

    assert observed == {0: 0, 1: 1}, (
        f"expected each greenlet to see its own rank; got {observed}"
    )
|
||
|
||
|
||
def test_get_rank_fallback_without_bind(tmp_path, monkeypatch, spec):
    """get_rank() from a greenlet with no bound rank yields 0 (single-driver compat)."""
    workdir = _write_minimal_ccl_yaml(tmp_path)
    monkeypatch.chdir(workdir)

    fake_ctx = _FakeCtx(spec=spec)
    dist = DistributedContext()
    dist._ctx_ref = fake_ctx
    dist.init_process_group(backend="ahbm")

    # The test body runs on the main greenlet, which was never bound.
    rank_seen = dist.get_rank()
    assert rank_seen == 0
|
||
|
||
|
||
# ── D10/D11: torch.ahbm.set_device + tensor scoping ───────────────────
|
||
|
||
|
||
def test_ahbm_set_device_binds_tensor_to_single_sip(topology):
    """``ahbm.set_device(rank)`` scopes default-SIP tensors to that SIP.

    After ``set_device(1)``, a tensor created with a DPPolicy that leaves
    the SIP dimension at its default must have every shard on SIP 1.
    """
    from kernbench.policy.placement.dp import DPPolicy
    from kernbench.runtime_api.context import RuntimeContext
    from kernbench.runtime_api.types import DeviceSelector
    from kernbench.sim_engine.engine import GraphEngine

    engine = GraphEngine(topology.topology_obj, enable_data=True)
    ctx_kwargs = dict(
        engine=engine,
        target_device=DeviceSelector("all"),
        correlation_id="test_ahbm_set_device",
        spec=topology.topology_obj.spec,
    )
    with RuntimeContext(**ctx_kwargs) as ctx:
        assert hasattr(ctx, "ahbm"), (
            "RuntimeContext must expose .ahbm namespace (ADR-0024 D10)"
        )
        ctx.ahbm.set_device(1)

        # Only cube/pe are specified; the SIP dimension stays at its default.
        policy = DPPolicy(cube="column_wise", pe="column_wise")
        probe = ctx.zeros((1, 128), dtype="f16", dp=policy, name="probe")

        placed_on = set()
        for shard in probe._handle.shards:
            placed_on.add(shard.sip)

        assert placed_on == {1}, (
            "after ahbm.set_device(1), all shards should live on SIP 1; "
            f"got sips={sorted(placed_on)}"
        )
|
||
|
||
|
||
def test_accelerator_alias_mirrors_ahbm(topology):
    """``accelerator.set_device_index(r)`` is an alias for ``ahbm.set_device(r)``.

    This is the PyTorch 2.x device-agnostic surface: after selecting
    index 1 through ``accelerator``, both namespaces must report 1.
    """
    from kernbench.runtime_api.context import RuntimeContext
    from kernbench.runtime_api.types import DeviceSelector
    from kernbench.sim_engine.engine import GraphEngine

    engine = GraphEngine(topology.topology_obj, enable_data=True)
    ctx_kwargs = dict(
        engine=engine,
        target_device=DeviceSelector("all"),
        correlation_id="test_accelerator_alias",
        spec=topology.topology_obj.spec,
    )
    with RuntimeContext(**ctx_kwargs) as ctx:
        assert hasattr(ctx, "accelerator"), (
            "RuntimeContext must expose .accelerator namespace (ADR-0024 D10)"
        )
        ctx.accelerator.set_device_index(1)

        # The selection made via `accelerator` must be visible through both
        # namespaces.
        via_ahbm = ctx.ahbm.current_device()
        assert via_ahbm == 1
        via_accelerator = ctx.accelerator.current_device_index()
        assert via_accelerator == 1
|
||
|
||
|
||
# ── D12/D13: run() spawns one worker per rank ─────────────────────────
|
||
|
||
|
||
def test_run_spawns_one_worker_per_rank(tmp_path, monkeypatch, spec):
    """The bench's ``run()`` invokes its worker exactly once per rank.

    The bench's ``_worker`` is replaced with a recorder that notes the
    rank and the ``world_size`` carried by the config it receives.  With
    world_size = SIP count = 2 (topology), the recorder must be called for
    ranks 0 and 1, and every call must see world_size=2.
    """
    # Resolve the repository root before chdir so the real topology.yaml
    # can still be located afterwards.
    tests_dir = os.path.dirname(__file__)
    project_root = os.path.abspath(os.path.join(tests_dir, ".."))
    monkeypatch.chdir(_write_minimal_ccl_yaml(tmp_path))

    import benches.ccl_allreduce as bench

    calls: list[tuple[int, int]] = []

    def _fake_worker(rank, cfg, torch) -> None:
        calls.append((rank, cfg.world_size))

    monkeypatch.setattr(bench, "_worker", _fake_worker)

    from kernbench.runtime_api.context import RuntimeContext
    from kernbench.runtime_api.types import DeviceSelector
    from kernbench.sim_engine.engine import GraphEngine
    from kernbench.topology.builder import resolve_topology

    topology_file = os.path.join(project_root, "topology.yaml")
    topo = resolve_topology(topology_file)
    engine = GraphEngine(topo.topology_obj, enable_data=True)
    ctx_kwargs = dict(
        engine=engine,
        target_device=DeviceSelector("all"),
        correlation_id="test_run_spawns",
        spec=topo.topology_obj.spec,
    )
    with RuntimeContext(**ctx_kwargs) as ctx:
        bench.run(ctx)

    seen_ranks = sorted(entry[0] for entry in calls)
    seen_world_sizes = set(entry[1] for entry in calls)
    sip_count = int(spec["system"]["sips"]["count"])

    assert seen_ranks == list(range(sip_count)), (
        f"run() should invoke worker for ranks 0..{sip_count - 1}; "
        f"saw ranks={seen_ranks}"
    )
    assert seen_world_sizes == {sip_count}, (
        f"each worker should see world_size={sip_count}; saw {seen_world_sizes}"
    )
|