Files
kernbench2/tests/test_ccl_ddp_launcher.py
T
ywkang 4ba0a83e71 Implement ADR-0024 Phase A: SIP-level TP launcher MVP
Scope (Phase A):
- D1: world_size fallback = SIP count (rank = SIP, TP boundary)
- D9: greenlet-local get_rank + _bind_rank (single-driver fallback = 0)
- D10: torch.ahbm.set_device + torch.accelerator.set_device_index alias
- D11: tensor placement scoped to current-device SIP (post-hoc pe_index
  shift — ADR-0026 replaces with structural coords)
- D12/D13: multi-greenlet run() with simple round-robin scheduler;
  hybrid dispatch (ws == SIP count → multi-greenlet, else legacy
  single-worker for ccl.yaml override compat)
- D7 partial: backend.all_reduce submit + yield + wait via launch()'s
  new _defer_wait flag; parent-less greenlets skip yield
- Relaxed shard-count check (len(shards) > 0 instead of == world_size)
- rank_to_pe = SIP-representative [(r, 0, 0)] when ws <= n_sips

Deferred to Phase B:
- Engine-routed install (D2) — keeps sideband
- install_plan.py module (D6) — keeps install.py
- Epoch barrier (D7 full) — simple yield is sufficient for ring ws=2 mock
- Validator registry (D8)
- Cross-SIP multi-greenlet + real kernel integration — matrix
  ring_default_ws hangs in SimPy drain despite ADR-0025 direction fix;
  marked xfail(run=False) pending Phase B diagnosis (suspected per-rank
  kernel_args / program_id mismatch)

Tests:
- test_ccl_ddp_launcher.py (6 new tests) — D1/D9/D10/D11/D12/D13
- test_ccl_allreduce_matrix.py — ring_default_ws xfail'd, override
  cases (ring_tcm_8 / hbm_8 / sram_8 / multi_cube / mesh_2x2 /
  tree_binary_7) all pass via legacy path

514 tests pass, 1 xfail.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-14 09:00:28 -07:00

245 lines
8.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Phase 1 tests for ADR-0024 SIP-level TP launcher (MVP scope).
Covers:
- D1 world_size = SIP count fallback
- D9 get_rank greenlet-local + _bind_rank
- D10 torch.ahbm.set_device + torch.accelerator alias
- D11 tensor placement scoped to current device SIP
- D12/D13 run() spawns one greenlet per rank
Deferred to later ADR-0024 sub-phases:
- D2 engine-routed install
- D6 install_plan.py
- D7 epoch barrier (this phase uses simple submit+yield+wait)
- D8 validator registry
"""
from __future__ import annotations
import os
import textwrap
import pytest
from greenlet import greenlet
from kernbench.runtime_api.distributed import AhbmCCLBackend, DistributedContext
# ── Fixtures / helpers ────────────────────────────────────────────────
class _FakeCtx:
    """Lightweight ctx double for constructing ``AhbmCCLBackend`` in tests.

    It carries just a ``spec`` mapping plus a recording ``install_ipcq``
    stub, so ``_resolve_world_size`` can be unit-tested without touching
    the engine stack.
    """

    def __init__(self, spec: dict) -> None:
        self.spec = spec
        # One entry per install_ipcq() call: a copy of that call's kwargs.
        self.install_calls: list[dict] = []

    def install_ipcq(self, **kwargs) -> dict:
        """Record this call's keyword arguments and return an empty dict."""
        recorded = {**kwargs}
        self.install_calls.append(recorded)
        return {}
def _write_minimal_ccl_yaml(tmp_path) -> str:
    """Write a ccl.yaml with NO world_size override — forces topology derivation.

    Args:
        tmp_path: Directory (a ``pathlib.Path``-like object) that receives
            the ``ccl.yaml`` file.

    Returns:
        ``str(tmp_path)`` — the directory containing the file, not the file
        path itself. The tests in this module pass it to
        ``monkeypatch.chdir``.
    """
    # NOTE(review): confirm the nesting below (keys under ``defaults`` and
    # under ``algorithms.ring_allreduce_tcm``) against the ccl.yaml loader.
    body = textwrap.dedent("""\
        defaults:
          algorithm: ring_allreduce_tcm
          buffer_kind: tcm
          backpressure: sleep
          n_slots: 4
          slot_size: 4096
          vc_chunk_size: 256
          ipcq_credit_size_bytes: 16
        algorithms:
          ring_allreduce_tcm:
            module: kernbench.ccl.algorithms.ring_allreduce
            topology: ring_1d
            buffer_kind: tcm
            n_elem: 8
    """)
    yaml_path = tmp_path / "ccl.yaml"
    # Pin the encoding so the fixture file does not depend on the locale.
    yaml_path.write_text(body, encoding="utf-8")
    return str(tmp_path)
# ── D1: world_size = SIP count fallback ───────────────────────────────
def test_world_size_equals_sip_count(tmp_path, monkeypatch, spec):
    """Backend world_size falls back to the SIP count when ccl.yaml has no override.

    The TP/DP model puts the collective group at the SIP boundary, so the
    derived world_size must match ``spec["system"]["sips"]["count"]`` and
    not the sips × cubes × pes product.
    """
    workdir = _write_minimal_ccl_yaml(tmp_path)
    monkeypatch.chdir(workdir)

    fake_ctx = _FakeCtx(spec=spec)
    backend = AhbmCCLBackend(torch_ctx=fake_ctx)

    expected = int(spec["system"]["sips"]["count"])
    observed = backend.world_size
    failure_msg = (
        f"expected world_size == SIP count ({expected}); "
        f"got {observed} — still deriving sips × cubes × pes"
    )
    assert observed == expected, failure_msg
# ── D9: get_rank greenlet-local + _bind_rank ──────────────────────────
def test_get_rank_is_greenlet_local(tmp_path, monkeypatch, spec):
    """``get_rank()`` returns the rank bound to the calling greenlet.

    Two greenlets are bound to ranks 0 and 1 through ``_bind_rank``; each
    records what ``get_rank()`` reports while it is running.
    """
    monkeypatch.chdir(_write_minimal_ccl_yaml(tmp_path))

    dc = DistributedContext()
    dc._ctx_ref = _FakeCtx(spec=spec)
    dc.init_process_group(backend="ahbm")

    world = dc.get_world_size()
    assert world == int(spec["system"]["sips"]["count"])
    assert hasattr(dc, "_bind_rank"), (
        "DistributedContext must expose _bind_rank(g, rank) hook"
    )

    seen: dict[int, int] = {}

    def _record(rank: int) -> None:
        seen[rank] = dc.get_rank()

    # Create both greenlets first, then bind, then run — in rank order.
    # ``rank=rank`` freezes the loop value inside each lambda.
    workers = [
        (rank, greenlet(lambda rank=rank: _record(rank)))
        for rank in (0, 1)
    ]
    for rank, worker in workers:
        dc._bind_rank(worker, rank)
    for _, worker in workers:
        worker.switch()

    assert seen == {0: 0, 1: 1}, (
        f"expected each greenlet to see its own rank; got {seen}"
    )
def test_get_rank_fallback_without_bind(tmp_path, monkeypatch, spec):
    """A greenlet that was never bound reports rank 0 (single-driver compat)."""
    workdir = _write_minimal_ccl_yaml(tmp_path)
    monkeypatch.chdir(workdir)

    dc = DistributedContext()
    dc._ctx_ref = _FakeCtx(spec=spec)
    dc.init_process_group(backend="ahbm")

    # The test body runs in the main greenlet, which has no rank binding.
    rank_in_main = dc.get_rank()
    assert rank_in_main == 0
# ── D10/D11: torch.ahbm.set_device + tensor scoping ───────────────────
def test_ahbm_set_device_binds_tensor_to_single_sip(topology):
    """``ctx.ahbm.set_device(1)`` scopes a default-sip DPPolicy tensor to SIP 1.

    The DPPolicy only sets the cube and pe dimensions; every shard of the
    tensor allocated afterwards must report ``sip == 1``.
    """
    from kernbench.policy.placement.dp import DPPolicy
    from kernbench.runtime_api.context import RuntimeContext
    from kernbench.runtime_api.types import DeviceSelector
    from kernbench.sim_engine.engine import GraphEngine

    engine = GraphEngine(topology.topology_obj, enable_data=True)
    runtime = RuntimeContext(
        engine=engine,
        target_device=DeviceSelector("all"),
        correlation_id="test_ahbm_set_device",
        spec=topology.topology_obj.spec,
    )
    with runtime as ctx:
        assert hasattr(ctx, "ahbm"), (
            "RuntimeContext must expose .ahbm namespace (ADR-0024 D10)"
        )
        ctx.ahbm.set_device(1)

        # SIP dimension deliberately left at the policy default.
        policy = DPPolicy(cube="column_wise", pe="column_wise")
        probe = ctx.zeros((1, 128), dtype="f16", dp=policy, name="probe")

        shard_sips = set()
        for shard in probe._handle.shards:
            shard_sips.add(shard.sip)

        assert shard_sips == {1}, (
            "after ahbm.set_device(1), all shards should live on SIP 1; "
            f"got sips={sorted(shard_sips)}"
        )
def test_accelerator_alias_mirrors_ahbm(topology):
    """``ctx.accelerator.set_device_index(r)`` is reflected by both namespaces.

    After selecting index 1 through the accelerator namespace, both
    ``ctx.ahbm.current_device()`` and
    ``ctx.accelerator.current_device_index()`` must return 1.
    """
    from kernbench.runtime_api.context import RuntimeContext
    from kernbench.runtime_api.types import DeviceSelector
    from kernbench.sim_engine.engine import GraphEngine

    engine = GraphEngine(topology.topology_obj, enable_data=True)
    runtime = RuntimeContext(
        engine=engine,
        target_device=DeviceSelector("all"),
        correlation_id="test_accelerator_alias",
        spec=topology.topology_obj.spec,
    )
    with runtime as ctx:
        assert hasattr(ctx, "accelerator"), (
            "RuntimeContext must expose .accelerator namespace (ADR-0024 D10)"
        )
        ctx.accelerator.set_device_index(1)

        # Read back through each namespace, ahbm first.
        via_ahbm = ctx.ahbm.current_device()
        assert via_ahbm == 1
        via_accelerator = ctx.accelerator.current_device_index()
        assert via_accelerator == 1
# ── D12/D13: run() spawns one worker per rank ─────────────────────────
def test_run_spawns_one_worker_per_rank(tmp_path, monkeypatch, spec):
    """``bench.run(ctx)`` calls ``worker`` exactly once for every rank.

    ``benches.ccl_allreduce.worker`` is replaced by a recorder. After
    ``run()`` the recorded ranks must be ``0..SIP count - 1`` and every
    call must have received ``world_size == SIP count``.
    """
    # Resolve the project root before chdir so a relative __file__ still works.
    tests_dir = os.path.dirname(__file__)
    project_root = os.path.abspath(os.path.join(tests_dir, ".."))
    monkeypatch.chdir(_write_minimal_ccl_yaml(tmp_path))

    import benches.ccl_allreduce as bench

    calls: list[tuple[int, int]] = []

    def _fake_worker(rank: int, world_size: int, torch) -> None:
        calls.append((rank, world_size))

    monkeypatch.setattr(bench, "worker", _fake_worker)

    from kernbench.runtime_api.context import RuntimeContext
    from kernbench.runtime_api.types import DeviceSelector
    from kernbench.sim_engine.engine import GraphEngine
    from kernbench.topology.builder import resolve_topology

    topo = resolve_topology(os.path.join(project_root, "topology.yaml"))
    engine = GraphEngine(topo.topology_obj, enable_data=True)
    runtime = RuntimeContext(
        engine=engine,
        target_device=DeviceSelector("all"),
        correlation_id="test_run_spawns",
        spec=topo.topology_obj.spec,
    )
    with runtime as ctx:
        bench.run(ctx)

    ranks = sorted(rank for rank, _ws in calls)
    ws_values = {ws for _rank, ws in calls}
    expected_ws = int(spec["system"]["sips"]["count"])

    rank_msg = (
        f"run() should invoke worker for ranks 0..{expected_ws - 1}; "
        f"saw ranks={ranks}"
    )
    assert ranks == list(range(expected_ws)), rank_msg

    ws_msg = f"each worker should see world_size={expected_ws}; saw {ws_values}"
    assert ws_values == {expected_ws}, ws_msg