"""ADR-0027 T1: tests for TP ``parallel_state`` (D3).

Phase 1: the ``kernbench.tp`` module does not exist yet, so these tests fail
at import. Phase 2 (D2/D3) lands the package and these pass.
"""

from __future__ import annotations

import pytest


def _make_ctx(topology):
    """Build a ``RuntimeContext`` over a fresh ``GraphEngine`` for *topology*.

    The engine is created from ``topology.topology_obj`` with
    ``enable_data=True``. The returned context targets
    ``DeviceSelector("all")``, uses the correlation id ``"test_t1"`` and takes
    its spec from ``topology.topology_obj.spec``.
    """
    # NOTE(review): imports are function-local, presumably so that collecting
    # this module does not itself require kernbench -- confirm.
    from kernbench.runtime_api.context import RuntimeContext
    from kernbench.runtime_api.types import DeviceSelector
    from kernbench.sim_engine.engine import GraphEngine

    graph_engine = GraphEngine(topology.topology_obj, enable_data=True)
    every_device = DeviceSelector("all")
    topo_spec = topology.topology_obj.spec
    return RuntimeContext(
        engine=graph_engine,
        target_device=every_device,
        correlation_id="test_t1",
        spec=topo_spec,
    )


def test_tp_package_importable():
    """D2: ``kernbench.tp`` imports and exposes the parallel-state API."""
    import kernbench.tp as tp

    required_api = (
        "initialize_model_parallel",
        "get_tensor_model_parallel_world_size",
        "get_tensor_model_parallel_rank",
    )
    for attr_name in required_api:
        assert hasattr(tp, attr_name)


def test_initialize_model_parallel_matches_world_size(topology, tmp_path, monkeypatch):
    """D3: TP size must equal dist world_size; otherwise NotImplementedError.

    This test covers the matching case: after initialising with the dist
    world size, the TP world-size getter reports the same value.
    ``tmp_path`` and ``monkeypatch`` are requested but not referenced below.
    """
    import kernbench.tp as tp

    with _make_ctx(topology) as runtime:
        runtime.distributed.init_process_group(backend="ahbm")
        world_size = runtime.distributed.get_world_size()

        tp.initialize_model_parallel(world_size)

        reported_size = tp.get_tensor_model_parallel_world_size()
        assert reported_size == world_size


def test_initialize_mismatched_ws_raises(topology):
    """D3: a ``tp_size`` different from world_size raises NotImplementedError."""
    import kernbench.tp as tp

    with _make_ctx(topology) as runtime:
        runtime.distributed.init_process_group(backend="ahbm")
        world_size = runtime.distributed.get_world_size()

        # One more than the world size is guaranteed to be a mismatch.
        with pytest.raises(NotImplementedError):
            tp.initialize_model_parallel(world_size + 1)


def test_get_tp_rank_is_greenlet_local(topology):
    """D3: every spawned worker reports a distinct TP rank in ``0..ws-1``.

    ``get_tensor_model_parallel_rank`` is expected to return the
    greenlet-local rank (delegating to ``torch.distributed.get_rank``,
    ADR-0024 D9).
    """
    import kernbench.tp as tp

    with _make_ctx(topology) as runtime:
        runtime.distributed.init_process_group(backend="ahbm")
        world_size = runtime.distributed.get_world_size()
        tp.initialize_model_parallel(world_size)

        seen_ranks: list[int] = []

        def _record_rank(rank: int):
            # ``rank`` is accepted but unused; what gets recorded is the rank
            # that the TP API itself reports inside this worker.
            seen_ranks.append(tp.get_tensor_model_parallel_rank())

        runtime.multiprocessing.spawn(_record_rank, args=(), nprocs=world_size)

        # Workers may append in any order, so compare the sorted ranks.
        expected_ranks = list(range(world_size))
        assert sorted(seen_ranks) == expected_ranks


def test_get_world_size_before_init_raises():
    """D3: uninitialised TP group → accessing world_size fails informatively."""
    from kernbench.tp import parallel_state

    # Clear any TP state that earlier tests (or parallel workers) left set.
    parallel_state._reset_for_tests()

    acceptable_errors = (RuntimeError, AssertionError, TypeError)
    with pytest.raises(acceptable_errors):
        # The ``+ 0`` forces the result to be used as a number, so a
        # non-numeric return value (e.g. ``None``) surfaces as TypeError.
        _ = parallel_state.get_tensor_model_parallel_world_size() + 0