"""Phase 1 tests for ADR-0024 SIP-level TP launcher (MVP scope). Covers: - D1 world_size = SIP count fallback - D9 get_rank greenlet-local + _bind_rank - D10 torch.ahbm.set_device + torch.accelerator alias - D11 tensor placement scoped to current device SIP - D12/D13 run() spawns one greenlet per rank Deferred to later ADR-0024 sub-phases: - D2 engine-routed install - D6 install_plan.py - D7 epoch barrier (this phase uses simple submit+yield+wait) - D8 validator registry """ from __future__ import annotations import os import textwrap import pytest from greenlet import greenlet from kernbench.runtime_api.distributed import AhbmCCLBackend, DistributedContext # ── Fixtures / helpers ──────────────────────────────────────────────── class _FakeCtx: """Minimal ctx double — only exposes what AhbmCCLBackend.__init__ uses. Stubs install_ipcq so we can unit-test _resolve_world_size without touching the engine stack. """ def __init__(self, spec: dict) -> None: self.spec = spec self.install_calls: list[dict] = [] def install_ipcq(self, **kwargs) -> dict: self.install_calls.append(dict(kwargs)) return {} def _write_minimal_ccl_yaml(tmp_path) -> str: """Write a ccl.yaml with NO world_size override — forces topology derivation.""" body = textwrap.dedent("""\ defaults: algorithm: ring_allreduce_tcm buffer_kind: tcm backpressure: sleep n_slots: 4 slot_size: 4096 vc_chunk_size: 256 ipcq_credit_size_bytes: 16 algorithms: ring_allreduce_tcm: module: kernbench.ccl.algorithms.ring_allreduce topology: ring_1d buffer_kind: tcm n_elem: 8 """) yaml_path = tmp_path / "ccl.yaml" yaml_path.write_text(body) return str(tmp_path) # ── D1: world_size = SIP count fallback ─────────────────────────────── def test_world_size_equals_sip_count(tmp_path, monkeypatch, spec): """With no override, backend derives world_size from SIP count only. Topology has 2 SIPs × 16 cubes × 8 PEs = 256 PEs. The TP/DP model places the collective group at the SIP boundary, so world_size must equal SIP count (2), not total PE count (256). """ monkeypatch.chdir(_write_minimal_ccl_yaml(tmp_path)) ctx = _FakeCtx(spec=spec) backend = AhbmCCLBackend(torch_ctx=ctx) expected = int(spec["system"]["sips"]["count"]) assert backend.world_size == expected, ( f"expected world_size == SIP count ({expected}); " f"got {backend.world_size} — still deriving sips × cubes × pes" ) # ── D9: get_rank greenlet-local + _bind_rank ────────────────────────── def test_get_rank_is_greenlet_local(tmp_path, monkeypatch, spec): """Each greenlet sees its own rank via dist.get_rank(). Framework-level launcher binds greenlet → rank; get_rank() resolves the current greenlet and returns that rank. """ monkeypatch.chdir(_write_minimal_ccl_yaml(tmp_path)) ctx = _FakeCtx(spec=spec) dc = DistributedContext() dc._ctx_ref = ctx dc.init_process_group(backend="ahbm") assert dc.get_world_size() == int(spec["system"]["sips"]["count"]) assert hasattr(dc, "_bind_rank"), ( "DistributedContext must expose _bind_rank(g, rank) hook" ) seen: dict[int, int] = {} def _probe(rank: int) -> None: seen[rank] = dc.get_rank() g0 = greenlet(lambda: _probe(0)) g1 = greenlet(lambda: _probe(1)) dc._bind_rank(g0, 0) dc._bind_rank(g1, 1) g0.switch() g1.switch() assert seen == {0: 0, 1: 1}, ( f"expected each greenlet to see its own rank; got {seen}" ) def test_get_rank_fallback_without_bind(tmp_path, monkeypatch, spec): """Unbound greenlet falls back to rank 0 (single-driver compat).""" monkeypatch.chdir(_write_minimal_ccl_yaml(tmp_path)) ctx = _FakeCtx(spec=spec) dc = DistributedContext() dc._ctx_ref = ctx dc.init_process_group(backend="ahbm") # Call from main (unbound) greenlet assert dc.get_rank() == 0 # ── D10/D11: torch.ahbm.set_device + tensor scoping ─────────────────── def test_ahbm_set_device_binds_tensor_to_single_sip(topology): """``torch.ahbm.set_device(rank)`` + default-sip DPPolicy → tensor on SIP rank. After set_device(1), a tensor with DPPolicy leaving the SIP dimension at its default must be placed entirely on SIP 1. """ from kernbench.policy.placement.dp import DPPolicy from kernbench.runtime_api.context import RuntimeContext from kernbench.runtime_api.types import DeviceSelector from kernbench.sim_engine.engine import GraphEngine engine = GraphEngine(topology.topology_obj, enable_data=True) with RuntimeContext( engine=engine, target_device=DeviceSelector("all"), correlation_id="test_ahbm_set_device", spec=topology.topology_obj.spec, ) as ctx: assert hasattr(ctx, "ahbm"), ( "RuntimeContext must expose .ahbm namespace (ADR-0024 D10)" ) ctx.ahbm.set_device(1) dp = DPPolicy(cube="column_wise", pe="column_wise") # default sip tensor = ctx.zeros((1, 128), dtype="f16", dp=dp, name="probe") shard_sips = {s.sip for s in tensor._handle.shards} assert shard_sips == {1}, ( f"after ahbm.set_device(1), all shards should live on SIP 1; " f"got sips={sorted(shard_sips)}" ) def test_accelerator_alias_mirrors_ahbm(topology): """torch.accelerator.set_device_index(r) is an alias for ahbm.set_device(r) (PyTorch 2.x device-agnostic surface).""" from kernbench.runtime_api.context import RuntimeContext from kernbench.runtime_api.types import DeviceSelector from kernbench.sim_engine.engine import GraphEngine engine = GraphEngine(topology.topology_obj, enable_data=True) with RuntimeContext( engine=engine, target_device=DeviceSelector("all"), correlation_id="test_accelerator_alias", spec=topology.topology_obj.spec, ) as ctx: assert hasattr(ctx, "accelerator"), ( "RuntimeContext must expose .accelerator namespace (ADR-0024 D10)" ) ctx.accelerator.set_device_index(1) # Both namespaces should report SIP 1 as current device assert ctx.ahbm.current_device() == 1 assert ctx.accelerator.current_device_index() == 1 # ── D12/D13: run() spawns one worker per rank ───────────────────────── def test_run_spawns_one_worker_per_rank(tmp_path, monkeypatch, spec): """The bench's ``run()`` invokes ``worker`` once per rank. With world_size = SIP count = 2 (topology), worker must be called exactly twice with ranks 0 and 1. Each call sees world_size=2. """ project_root = os.path.abspath( os.path.join(os.path.dirname(__file__), "..") ) monkeypatch.chdir(_write_minimal_ccl_yaml(tmp_path)) import benches.ccl_allreduce as bench calls: list[tuple[int, int]] = [] def _fake_worker(rank: int, world_size: int, torch) -> None: calls.append((rank, world_size)) monkeypatch.setattr(bench, "worker", _fake_worker) from kernbench.runtime_api.context import RuntimeContext from kernbench.runtime_api.types import DeviceSelector from kernbench.sim_engine.engine import GraphEngine from kernbench.topology.builder import resolve_topology topo = resolve_topology(os.path.join(project_root, "topology.yaml")) engine = GraphEngine(topo.topology_obj, enable_data=True) with RuntimeContext( engine=engine, target_device=DeviceSelector("all"), correlation_id="test_run_spawns", spec=topo.topology_obj.spec, ) as ctx: bench.run(ctx) ranks = sorted(r for r, _ in calls) ws_values = {ws for _, ws in calls} expected_ws = int(spec["system"]["sips"]["count"]) assert ranks == list(range(expected_ws)), ( f"run() should invoke worker for ranks 0..{expected_ws - 1}; " f"saw ranks={ranks}" ) assert ws_values == {expected_ws}, ( f"each worker should see world_size={expected_ws}; saw {ws_values}" )