"""Tests for tl.recv_async + tl.wait (ADR-0023 D4)."""
from __future__ import annotations

import numpy as np

from kernbench.ccl.testing import run_kernel_in_mock


def kernel_async_recv(t_ptr, n_elem, tl):
    """Ring-style all-reduce kernel that posts its receive before sending.

    Each PE issues recv_async first, then send, then wait — this
    exercises the non-blocking path. Uses TensorHandle math (PE_MATH)
    for accumulation so Phase 2 produces correct final HBM contents.

    Args:
        t_ptr: Base address of the tensor. Rank ``r`` reads and writes the
            ``n_elem`` f16 values starting at ``t_ptr + r * n_elem * 2``.
        n_elem: Number of f16 elements in each rank's tile.
        tl: Kernel-language handle providing ``program_id``,
            ``num_programs``, ``load``, ``store``, ``send``, ``recv_async``
            and ``wait``.
    """
    rank = tl.program_id(axis=0)
    world_size = tl.num_programs(axis=0)

    # The tile dtype is "f16" everywhere below, i.e. two bytes per element.
    elem_bytes = 2
    nbytes = n_elem * elem_bytes
    pe_addr = t_ptr + rank * nbytes

    acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16")
    current = acc
    # world_size - 1 exchange rounds: after them every rank has accumulated
    # one tile from each of the other ranks.
    for _step in range(world_size - 1):
        # Post the receive *before* sending so the send never has to wait
        # for a matching receive to be issued.
        future = tl.recv_async(dir="W", shape=(n_elem,), dtype="f16")
        tl.send(dir="E", src=current)
        recv = tl.wait(future)
        acc = acc + recv
        current = recv  # forward W's tile to E next round
    tl.store(pe_addr, acc)


def test_recv_async_mock_runtime():
    """Run the async kernel on a 4-rank ``ring_1d`` mock runtime.

    Rank ``r`` starts with a tile filled with ``r + 1``; every output is
    expected to equal the element-wise sum of all input tiles.
    """
    world_size = 4
    n_elem = 8
    inputs = [
        np.full((n_elem,), float(r + 1), dtype=np.float16)
        for r in range(world_size)
    ]
    expected = sum(inputs)

    outputs = run_kernel_in_mock(
        kernel_fn=kernel_async_recv,
        world_size=world_size,
        topology="ring_1d",
        inputs=inputs,
        kernel_args=(n_elem,),
    )

    for r in range(world_size):
        assert np.allclose(outputs[r], expected), \
            f"rank {r}: got {outputs[r]}, expected {expected}"


def test_recv_async_simpy_runner():
    """Run the async kernel through the real SimPy stack via the
    install_ipcq + launch path.

    Seeds HBM so that rank ``r`` holds a tile filled with ``r + 1``,
    launches ``kernel_async_recv`` on 8 ranks, and checks that every
    rank's tile ends up equal to 1 + 2 + ... + 8 = 36.
    """
    import types

    from kernbench.runtime_api.bench_runner import run_bench
    from kernbench.runtime_api.types import resolve_device
    from kernbench.sim_engine.engine import GraphEngine
    from kernbench.topology.builder import resolve_topology
    from kernbench.policy.placement.dp import DPPolicy

    world_size = 8
    n_elem = 8
    nbytes = n_elem * 2  # f16 tile: two bytes per element
    expected = float(sum(range(1, world_size + 1)))  # 36

    # Re-use the standard 8-PE bench skeleton but swap in the async kernel.
    topo = resolve_topology("topology.yaml")

    # Build a tiny inline bench module
    mod = types.ModuleType("inline_bench_async")

    def run(torch):
        # Called only for its side effect; the returned plan is not needed.
        torch.install_ipcq(
            algorithm="ring_allreduce_tcm",
            world_size_override=world_size,
        )
        a = torch.zeros(
            (1, world_size * n_elem), dtype="f16",
            dp=DPPolicy(
                sip="replicate", cube="replicate", pe="column_wise",
                num_sips=1, num_cubes=1,
            ),
            name="async_in",
        )
        store = torch.engine.memory_store
        # Use va_base when it is truthy, otherwise the first shard's pa.
        # NOTE(review): a va_base of 0 also takes the fallback — confirm
        # that 0 is never a valid base address here.
        base = a._handle.va_base or a._handle.shards[0].pa

        # Seed each rank's tile with the constant value rank + 1.
        for r in range(world_size):
            store.write("hbm", base + r * nbytes,
                        np.full((n_elem,), float(r + 1), dtype=np.float16))

        torch.launch("ring_allreduce_tcm", kernel_async_recv, a, n_elem)

        for r in range(world_size):
            result = store.read("hbm", base + r * nbytes,
                                shape=(n_elem,), dtype="f16")
            assert np.allclose(result, expected, rtol=1e-2, atol=1e-2), \
                f"rank {r}: got {result}, expected {expected}"

    mod.run = run

    result = run_bench(
        topology=topo,
        bench_fn=mod.run,
        device=resolve_device("all"),
        engine_factory=lambda t, d: GraphEngine(
            getattr(t, "topology_obj", t), enable_data=True
        ),
    )
    assert result.completion.ok