"""Tests for recv_mode='copy_to_dst' (ADR-0023 D9.5).""" from __future__ import annotations import numpy as np def test_recv_copy_to_dst_via_simpy_runner(): """Run a kernel that uses tl.recv(..., dst_addr=..., dst_space=...). Verify the data is moved to the dst location after recv. """ import importlib from kernbench.policy.placement.dp import DPPolicy from kernbench.runtime_api.bench_runner import run_bench from kernbench.runtime_api.types import resolve_device from kernbench.sim_engine.engine import GraphEngine from kernbench.topology.builder import resolve_topology from kernbench.common.pe_commands import TensorHandle def kernel(t_ptr, n_elem, dst_buf_addr, tl): rank = tl.program_id(axis=0) ws = tl.num_programs(axis=0) nbytes = n_elem * 2 # Each PE sends own data, then recv into a custom dst slot current = TensorHandle( id="loc", addr=t_ptr + rank * nbytes, shape=(n_elem,), dtype="f16", nbytes=nbytes, data=None, space="hbm", ) tl.send(dir="E", src=current) # copy_to_dst: move into a per-rank scratch HBM addr recv = tl.recv( dir="W", shape=(n_elem,), dtype="f16", dst_addr=dst_buf_addr + rank * nbytes, dst_space="hbm", ) # Sanity: recv handle should now point to our dst addr assert recv.addr == dst_buf_addr + rank * nbytes assert recv.space == "hbm" topo = resolve_topology("topology.yaml") def run(torch): plan = torch.install_ipcq( algorithm="ring_allreduce_tcm", world_size_override=8, ) a = torch.zeros( (1, 8 * 8), dtype="f16", dp=DPPolicy( sip="replicate", cube="replicate", pe="column_wise", num_sips=1, num_cubes=1, ), name="copy_in", ) store = torch.engine.memory_store base = a._handle.va_base or a._handle.shards[0].pa nbytes = 8 * 2 for r in range(8): store.write("hbm", base + r * nbytes, np.full((8,), float(r + 1), dtype=np.float16)) # Use a separate dst region (synthetic addresses) dst_buf = 0xC0FFEE_0000 torch.launch("ring_allreduce_tcm", kernel, a, 8, dst_buf) # After the kernel, dst_buf + r*16 should contain rank (r-1)%8's data for r in range(8): arr = store.read("hbm", dst_buf + r * nbytes, shape=(8,), dtype="f16") expected = float(((r - 1) % 8) + 1) assert np.allclose(arr, expected), f"rank {r}: got {arr}, expected {expected}" result = run_bench( topology=topo, bench_fn=run, device=resolve_device("all"), engine_factory=lambda t, d: GraphEngine( getattr(t, "topology_obj", t), enable_data=True ), ) assert result.completion.ok