"""Tests for the torch.distributed-compat facade (ADR-0023 D11).

These tests verify the public API surface of ``DistributedContext`` +
``AhbmCCLBackend``. End-to-end correctness of the allreduce itself is
covered by tests/test_ccl_allreduce_matrix.py.
"""

from __future__ import annotations

from kernbench.runtime_api.distributed import AhbmCCLBackend, DistributedContext


def test_init_process_group_requires_ctx_ref():
    """Using DistributedContext without RuntimeContext binding should fail."""
    dist = DistributedContext()
    # Not bound to a RuntimeContext -> init should raise.
    try:
        dist.init_process_group(backend="ahbm")
    except RuntimeError:
        pass
    else:
        # Explicit raise instead of ``assert False``: plain asserts are
        # stripped under ``python -O``, which would make this test vacuous.
        raise AssertionError("expected RuntimeError")


def test_init_process_group_rejects_unknown_backend():
    """Unknown backend raises ValueError (matches pytorch behavior)."""
    dist = DistributedContext()
    dist._ctx_ref = object()  # dummy; won't be reached before the check
    try:
        dist.init_process_group(backend="nccl")
    except ValueError:
        pass
    else:
        # Explicit raise instead of ``assert False`` (survives ``python -O``).
        raise AssertionError("expected ValueError")


def test_distributed_pytorch_compat_surface():
    """DistributedContext only exposes real torch.distributed API names."""
    # Every public attribute should either be a real pytorch name or private.
    allowed = frozenset({
        "init_process_group",
        "is_initialized",
        "get_world_size",
        "get_rank",
        "get_backend",
        "all_reduce",
        "barrier",
    })
    dc = DistributedContext()
    # Collect every offender so a single failure reports all of them,
    # rather than stopping at the first one found.
    unexpected = sorted(
        attr
        for attr in dir(dc)
        if not attr.startswith("_") and attr not in allowed
    )
    assert not unexpected, (
        f"DistributedContext exposes non-pytorch API: {unexpected!r}"
    )


def test_backend_class_surface():
    """AhbmCCLBackend exposes at least all_reduce, barrier and world_size.

    This is a minimum-surface check only; it does not reject extra public
    names on the backend class.
    """
    # Inspect the class itself (no instance), so the names must be defined at
    # class level (methods/properties/class attributes); attributes assigned
    # only inside ``__init__`` would not appear in ``dir`` here.
    public = {m for m in dir(AhbmCCLBackend) if not m.startswith("_")}
    required = {"all_reduce", "barrier", "world_size"}
    missing = required - public
    assert not missing, (
        f"AhbmCCLBackend is missing required API: {sorted(missing)!r}"
    )