"""Hierarchical all-reduce kernel (ADR-0023). 3-level reduce + broadcast exploiting the topology hierarchy: Level 1 — Intra-cube (8 PEs, E/W, fastest link): Bidirectional ring reduce to PE 0. Level 2 — Inter-cube within SIP (16 cubes, N/S, UCIe): Bidirectional ring reduce of PE 0s to cube 0 PE 0. Level 3 — Inter-SIP (2 SIPs, parent): Pair exchange between SIP representatives. Broadcast — Reverse chain through levels 2 and 1. Bidirectional reduce: left-half sends toward node 0 via dir_dec, right-half sends via dir_inc (wrapping). Representative receives from both sides. Rounds per level = ceil((group_size - 1) / 2). Direction pairing (ring): Send via dir_dec at PE K → recv via dir_inc at PE K-1 Send via dir_inc at PE K → recv via dir_dec at PE K+1 """ from __future__ import annotations def kernel_args(world_size: int, n_elem: int) -> tuple: """Positional kernel args for the ahbm backend.""" pes_per_cube = 8 num_sips = max(1, world_size // 128) if world_size > 128 else 1 cubes_per_sip = world_size // (pes_per_cube * num_sips) return (n_elem, pes_per_cube, cubes_per_sip, num_sips) def neighbors(rank: int, world_size: int, neighbor_map: dict) -> dict: """Build the 3-level neighbor map.""" pes_per_cube = 8 num_sips = max(1, world_size // 128) if world_size > 128 else 1 cubes_per_sip = world_size // (pes_per_cube * num_sips) pe_id = rank % pes_per_cube cube_global = rank // pes_per_cube sip_id = cube_global // cubes_per_sip local_cube_id = cube_global % cubes_per_sip result = {} # Level 1: intra-cube ring (E/W, all PEs) cube_base = cube_global * pes_per_cube result["E"] = cube_base + (pe_id + 1) % pes_per_cube result["W"] = cube_base + (pe_id - 1) % pes_per_cube # Level 2: inter-cube ring (N/S, PE 0 only) if pe_id == 0 and cubes_per_sip > 1: sip_base = sip_id * cubes_per_sip * pes_per_cube next_cube_pe0 = sip_base + ((local_cube_id + 1) % cubes_per_sip) * pes_per_cube prev_cube_pe0 = sip_base + ((local_cube_id - 1) % cubes_per_sip) * pes_per_cube result["N"] = next_cube_pe0 result["S"] = prev_cube_pe0 # Level 3: inter-SIP (parent, PE 0 cube 0 only) if pe_id == 0 and local_cube_id == 0 and num_sips > 1: other_sip_pe0 = ((sip_id + 1) % num_sips) * cubes_per_sip * pes_per_cube result["parent"] = other_sip_pe0 return result def _bidir_reduce(tl, acc, my_id, group_size, dir_inc, dir_dec, shape, dtype): """Bidirectional ring reduce to node 0. Left half (1..half): chain reduces via dir_dec (toward lower IDs). Each PE recvs from higher PE (via dir_inc) and sends to lower (via dir_dec). Right half (half+1..N-1): chain reduces via dir_inc (wraps to node 0). Each PE recvs from lower PE (via dir_dec) and sends to higher (via dir_inc). Node 0: recvs left sum via dir_inc, right sum via dir_dec. Direction pairing: send dir_dec at K → recv dir_inc at K-1. send dir_inc at K → recv dir_dec at K+1. """ if group_size <= 1: return acc half = group_size // 2 if my_id == 0: # Representative: recv left-half sum via dir_inc (from PE 1) recv = tl.recv(dir=dir_inc, shape=shape, dtype=dtype) acc = acc + recv # Recv right-half sum via dir_dec (from PE N-1, wrapped) if group_size - half - 1 >= 1: recv = tl.recv(dir=dir_dec, shape=shape, dtype=dtype) acc = acc + recv elif my_id <= half: # Left half: recv from PE my_id+1 via dir_inc, send to PE my_id-1 via dir_dec if my_id < half: # not the far-edge recv = tl.recv(dir=dir_inc, shape=shape, dtype=dtype) acc = acc + recv tl.send(dir=dir_dec, src=acc) else: # Right half: recv from PE my_id-1 via dir_dec, send to PE my_id+1 via dir_inc if my_id > half + 1: # not the near-edge recv = tl.recv(dir=dir_dec, shape=shape, dtype=dtype) acc = acc + recv tl.send(dir=dir_inc, src=acc) return acc def _chain_broadcast(tl, acc, my_id, group_size, dir_inc, shape, dtype): """Linear chain broadcast from node 0 via dir_inc. Node 0 sends via dir_inc → node 1. Node 1 recvs via dir_dec (implicit from the ring pairing), stores, sends via dir_inc → node 2. Etc. Recv direction = the opposite: send dir_inc at K → recv dir_dec at K+1. """ if group_size <= 1: return acc # In ring pairing: send via dir_inc at K → recv via dir_dec at K+1. # dir_dec is the "other" direction. We infer it from the ring: # if dir_inc is "E", peer recvs via "W"; if "N", peer recvs via "S". _recv_dir = {"E": "W", "W": "E", "N": "S", "S": "N"}.get(dir_inc, dir_inc) if my_id == 0: tl.send(dir=dir_inc, src=acc) else: acc = tl.recv(dir=_recv_dir, shape=shape, dtype=dtype) if my_id < group_size - 1: tl.send(dir=dir_inc, src=acc) return acc def kernel(t_ptr, n_elem, pes_per_cube, cubes_per_sip, num_sips, tl): """Hierarchical all-reduce. Args: t_ptr: HBM base address (column-sharded VA). n_elem: f16 elements per tile. pes_per_cube: PEs per cube (typically 8). cubes_per_sip: cubes per SIP (typically 16). num_sips: number of SIPs (typically 2). tl: TLContext (auto-injected). """ pe_id = tl.program_id(axis=0) cube_global = tl.program_id(axis=1) sip_id = cube_global // cubes_per_sip local_cube_id = cube_global % cubes_per_sip rank = cube_global * pes_per_cube + pe_id nbytes = n_elem * 2 pe_addr = t_ptr + rank * nbytes acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16") shape = (n_elem,) dtype = "f16" # ── Level 1: intra-cube bidirectional reduce to PE 0 ── acc = _bidir_reduce( tl, acc, my_id=pe_id, group_size=pes_per_cube, dir_inc="E", dir_dec="W", shape=shape, dtype=dtype, ) # ── Level 2: inter-cube bidirectional reduce to cube 0 (PE 0 only) ── if pe_id == 0 and cubes_per_sip > 1: acc = _bidir_reduce( tl, acc, my_id=local_cube_id, group_size=cubes_per_sip, dir_inc="N", dir_dec="S", shape=shape, dtype=dtype, ) # ── Level 3: inter-SIP exchange (PE 0 cube 0 only) ── if pe_id == 0 and local_cube_id == 0 and num_sips > 1: tl.send(dir="parent", src=acc) recv = tl.recv(dir="parent", shape=shape, dtype=dtype) acc = acc + recv # ── Broadcast back ── # Level 2: cube 0 PE 0 → all PE 0s via chain if pe_id == 0 and cubes_per_sip > 1: acc = _chain_broadcast( tl, acc, my_id=local_cube_id, group_size=cubes_per_sip, dir_inc="N", shape=shape, dtype=dtype, ) # Level 1: PE 0 → all PEs in cube via chain acc = _chain_broadcast( tl, acc, my_id=pe_id, group_size=pes_per_cube, dir_inc="E", shape=shape, dtype=dtype, ) tl.store(pe_addr, acc)