"""Tests for the mock CCL runtime (ADR-0023 D15).""" from __future__ import annotations import numpy as np from kernbench.ccl.algorithms import ring_allreduce from kernbench.ccl.testing import run_kernel_in_mock def test_ring_allreduce_4_ranks(): """Run the ring all-reduce kernel under the mock runtime, no SimPy.""" n_elem = 8 inputs = [ np.full((n_elem,), float(r + 1), dtype=np.float16) for r in range(4) ] expected = sum(inputs) # [10, 10, ..., 10] outputs = run_kernel_in_mock( kernel_fn=ring_allreduce.kernel, world_size=4, topology="ring_1d", inputs=inputs, kernel_args=(n_elem, 4), ) assert len(outputs) == 4 for r in range(4): assert np.allclose(outputs[r], expected) def test_ring_allreduce_8_ranks(): n_elem = 16 inputs = [ np.full((n_elem,), float(r + 1), dtype=np.float16) for r in range(8) ] expected = sum(inputs) # [36, 36, ...] outputs = run_kernel_in_mock( kernel_fn=ring_allreduce.kernel, world_size=8, topology="ring_1d", inputs=inputs, kernel_args=(n_elem, 8), ) for r in range(8): assert np.allclose(outputs[r], expected) def test_ring_allreduce_random_data(): n_elem = 32 rng = np.random.default_rng(42) inputs = [rng.standard_normal(n_elem).astype(np.float16) for _ in range(4)] expected = sum(inputs) outputs = run_kernel_in_mock( kernel_fn=ring_allreduce.kernel, world_size=4, topology="ring_1d", inputs=inputs, kernel_args=(n_elem, 4), ) for r in range(4): assert np.allclose(outputs[r], expected, rtol=1e-2, atol=1e-2) def test_mock_runtime_invalid_direction_raises(): """A kernel that uses an unsupported direction should raise.""" import pytest def bad_kernel(t_ptr, n_elem, tl): tl.send(dir="N", src_addr=0, nbytes=2, shape=(1,), dtype="f16", space="hbm") inputs = [np.array([1.0], dtype=np.float16) for _ in range(2)] with pytest.raises(Exception): run_kernel_in_mock( kernel_fn=bad_kernel, world_size=2, topology="ring_1d", inputs=inputs, kernel_args=(1,), )