"""Ring all-reduce kernel for IPCQ-based PE collective (ADR-0023).

Algorithm: 1D ring of N PEs, each PE starts with one tile of data. After
``world_size - 1`` rounds, every PE's accumulator holds the sum of all PE
tiles.

Strategy
--------
Each PE starts with its own tile in HBM. The kernel:

1. Loads the local tile into a TensorHandle (the accumulator).
2. In each of ``world_size - 1`` rounds:
   - Sends the current tile (the local tile in the first round, the most
     recently received tile afterwards) to the E neighbor.
   - Receives a tile from the W neighbor; the recv handle points into the
     per-direction TCM slot.
   - Adds the received tile to the accumulator using the TensorHandle
     operator overload, which dispatches to ``MathCmd`` (PE_MATH).
3. Stores the final accumulator back to HBM via ``tl.store``. The store is
   recorded in op_log with both src and dst, so Phase 2 will copy the
   replayed math result from PE-local scratch into HBM.

ADR-0020 D3 split: Phase 1 simulates timing only; math results are not yet
computed, so the accumulator data flowing through Phase 1 may be stale.
Phase 2's DataExecutor replays math + IPCQ copies + dma_write in stable
t_start order, producing correct final HBM contents.
"""

from __future__ import annotations


def kernel_args(world_size: int, n_elem: int) -> tuple:
    """Return the positional kernel arguments for the ahbm backend.

    Ring all-reduce takes (n_elem, world_size) after the tensor pointer.
    """
    return (n_elem, world_size)


def kernel(t_ptr, n_elem, world_size, tl):
    """Ring all-reduce.

    Args:
        t_ptr: HBM base address of the column-sharded tensor; all PEs share
            this base. The per-PE slice lives at
            ``t_ptr + global_rank * n_elem * 2``.
        n_elem: number of f16 elements per tile.
        world_size: total number of participating ranks (passed by host).
        tl: TLContext (auto-injected, ADR-0022).

    The kernel derives the global rank from ``program_id(axis=0)``
    (local PE) and ``program_id(axis=1)`` (cube id):

        rank = cube_id * pes_per_cube + local_pe
    """
    local_pe = tl.program_id(axis=0)
    cube_id = tl.program_id(axis=1)
    pes_per_cube = tl.num_programs(axis=0)
    rank = cube_id * pes_per_cube + local_pe

    nbytes = n_elem * 2  # f16: 2 bytes per element
    # Each PE reads from its own slice of the shared base address.
    pe_addr = t_ptr + rank * nbytes

    # Load the local tile; the handle points at HBM[pe_addr].
    acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16")

    # The ring forwards each received tile to the next neighbor (NOT the
    # cumulative accumulator), so every rank's tile passes through every
    # rank exactly once. The accumulator sums the new arrival each round.
    current = acc
    for _step in range(world_size - 1):
        tl.send(dir="E", src=current)
        recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
        # TensorHandle add → MathCmd → PE_MATH (timing in Phase 1, real
        # numpy in Phase 2 via DataExecutor). The result handle lives at
        # an auto-allocated PE-local scratch addr.
        acc = acc + recv
        current = recv  # forward W's tile to E next round

    # Final result back to this PE's HBM slice. Op_log captures the
    # source (scratch addr) and dst (HBM slice) so Phase 2 copies the
    # accumulated value into HBM for verification.
    tl.store(pe_addr, acc)
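

# --- Illustrative reference model (not part of the kernel) -----------------
# A minimal host-side sketch of the forwarding scheme used in ``kernel``.
# Assumptions, none of which come from the IPCQ runtime itself:
#   * plain Python lists stand in for f16 tiles;
#   * "send E / recv W" is modelled as rank r receiving whatever rank
#     (r - 1) % world_size sent in the same round;
#   * ``_ring_all_reduce_reference`` is a hypothetical helper name.
# It only shows why ``world_size - 1`` rounds of forwarding the most recently
# received tile should leave the full sum in every rank's accumulator.
def _ring_all_reduce_reference(tiles):
    """Return the per-rank accumulators after a simulated ring all-reduce."""
    world_size = len(tiles)
    acc = [list(t) for t in tiles]      # per-rank accumulator, starts as local tile
    current = [list(t) for t in tiles]  # tile each rank sends in the next round
    for _step in range(world_size - 1):
        # All ranks send at once; rank r receives its W neighbor's tile.
        received = [current[(r - 1) % world_size] for r in range(world_size)]
        acc = [
            [a + b for a, b in zip(acc[r], received[r])]
            for r in range(world_size)
        ]
        current = received  # forward the new arrival, not the accumulator
    return acc


# Usage sketch: with three ranks holding [1, 2], [10, 20] and [100, 200],
# every returned accumulator is the element-wise sum of all three tiles.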