"""Binary-tree all-reduce kernel for the IPCQ-based PE collective (ADR-0023).

The all-reduce runs in two phases over a binary tree of ranks.

Phase 1 (reduce up):
  - a leaf sends its value to ``parent``
  - an internal node receives from each child, sums, then sends to ``parent``
  - the root accumulates its children's contributions; its final acc holds
    the global sum

Phase 2 (broadcast down):
  - the root sends acc to ``child_left`` and ``child_right`` (when present)
  - an internal node receives from ``parent`` and forwards to its children
  - every rank stores the final acc to HBM

Accumulation uses TensorHandle math (PE_MATH). Op_log captures the data flow
so Phase 2 produces correct final HBM contents.

The kernel deliberately avoids the store→reload→send pattern: math/recv
handles are handed straight to the next send so that PE_DMA snapshots a
deterministic source addr that Phase 2 can replay.
"""

from __future__ import annotations


def kernel_args(world_size: int, n_elem: int) -> tuple:
    """Build the positional kernel arguments for the ahbm backend.

    The result is ordered ``(n_elem, world_size)``, i.e. the order in which
    ``kernel`` declares those two parameters.
    """
    return n_elem, world_size


def kernel(t_ptr, n_elem, world_size, tl):
    """Run the two-phase binary-tree all-reduce for the calling rank.

    Args:
        t_ptr: HBM base address.
        n_elem: number of f16 elements per tile.
        world_size: total number of participating ranks (passed by host).
        tl: TLContext (ADR-0022). The global rank is derived from
            ``program_id`` on axes 0 and 1.
    """
    pe_in_cube = tl.program_id(axis=0)
    cube = tl.program_id(axis=1)
    cube_width = tl.num_programs(axis=0)
    rank = cube * cube_width + pe_in_cube

    # Each rank owns one tile of n_elem f16 values (2 bytes per element).
    nbytes = n_elem * 2
    tile_addr = t_ptr + rank * nbytes
    tile_shape = (n_elem,)

    acc = tl.load(tile_addr, shape=tile_shape, dtype="f16")

    # Parent/child existence (matches the tree_binary topology generator):
    # children of rank r are 2r+1 and 2r+2, and every rank but 0 has a parent.
    below_root = rank > 0
    child_dirs = tuple(
        direction
        for direction, child_rank in (
            ("child_left", 2 * rank + 1),
            ("child_right", 2 * rank + 2),
        )
        if child_rank < world_size
    )

    # ── Phase 1: reduce up ──
    # Fold in each existing child's contribution, left before right.
    for direction in child_dirs:
        incoming = tl.recv(dir=direction, shape=tile_shape, dtype="f16")
        acc = acc + incoming

    if below_root:
        # Hand the math/load handle to send as-is: its addr is either the
        # original HBM tile (leaf) or the PE-local scratch holding the
        # accumulator. Phase 2 ipcq_copy replays from that same addr.
        tl.send(dir="parent", src=acc)

        # ── Phase 2: broadcast down ──
        # The parent's broadcast (the global sum) supersedes the local
        # partial sum. The recv handle points at the parent-direction TCM
        # slot.
        acc = tl.recv(dir="parent", shape=tile_shape, dtype="f16")

    # Phase 2 (continued): forward the global sum to whichever children exist.
    for direction in child_dirs:
        tl.send(dir=direction, src=acc)

    # Final store to HBM for the bench's verification path.
    tl.store(tile_addr, acc)