"""Intercube all-reduce kernel (pe0-only, same-lane across cubes). Reduces across the 4×4 cube mesh within each SIP, then exchanges between SIPs using the configured SIP topology, and broadcasts back. Supported SIP topologies (selected via ``sip_topo_kind``): 0 — ring_1d: global_E/global_W ring, n_sips-1 rounds 1 — torus_2d: row ring (global_E/W) + col ring (global_S/N) 2 — mesh_2d: row chain reduce+broadcast + col chain reduce+broadcast IPCQ wiring is handled by ``configure_sfr_intercube_multisip``. """ from __future__ import annotations SIP_TOPO_RING = 0 SIP_TOPO_TORUS = 1 SIP_TOPO_MESH = 2 TOPO_NAME_TO_KIND = { "ring_1d": SIP_TOPO_RING, "torus_2d": SIP_TOPO_TORUS, "mesh_2d": SIP_TOPO_TORUS, "mesh_2d_no_wrap": SIP_TOPO_MESH, } def kernel_args(world_size: int, n_elem: int, *, cube_w: int = 4, cube_h: int = 4) -> tuple: return (n_elem, cube_w, cube_h, world_size) def _inter_sip_ring(acc, n_sips, n_elem, tl): current = acc for _ in range(n_sips - 1): tl.send(dir="global_E", src=current) recv = tl.recv(dir="global_W", shape=(n_elem,), dtype="f16") acc = acc + recv current = recv return acc def _inter_sip_torus_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl): # Row ring (global_E / global_W) current = acc for _ in range(sip_topo_w - 1): tl.send(dir="global_E", src=current) recv = tl.recv(dir="global_W", shape=(n_elem,), dtype="f16") acc = acc + recv current = recv # Col ring (global_S / global_N) current = acc for _ in range(sip_topo_h - 1): tl.send(dir="global_S", src=current) recv = tl.recv(dir="global_N", shape=(n_elem,), dtype="f16") acc = acc + recv current = recv return acc def _inter_sip_mesh_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl): sip_row = sip_rank // sip_topo_w sip_col = sip_rank % sip_topo_w # Row reduce W → E if sip_col == 0: tl.send(dir="global_E", src=acc) elif sip_col < sip_topo_w - 1: recv = tl.recv(dir="global_W", shape=(n_elem,), dtype="f16") acc = acc + recv tl.send(dir="global_E", src=acc) else: recv = tl.recv(dir="global_W", shape=(n_elem,), dtype="f16") acc = acc + recv # Row broadcast E → W if sip_col == sip_topo_w - 1: tl.send(dir="global_W", src=acc) elif sip_col > 0: acc = tl.recv(dir="global_E", shape=(n_elem,), dtype="f16") tl.send(dir="global_W", src=acc) else: acc = tl.recv(dir="global_E", shape=(n_elem,), dtype="f16") # Col reduce N → S if sip_row == 0: tl.send(dir="global_S", src=acc) elif sip_row < sip_topo_h - 1: recv = tl.recv(dir="global_N", shape=(n_elem,), dtype="f16") acc = acc + recv tl.send(dir="global_S", src=acc) else: recv = tl.recv(dir="global_N", shape=(n_elem,), dtype="f16") acc = acc + recv # Col broadcast S → N if sip_row == sip_topo_h - 1: tl.send(dir="global_N", src=acc) elif sip_row > 0: acc = tl.recv(dir="global_S", shape=(n_elem,), dtype="f16") tl.send(dir="global_N", src=acc) else: acc = tl.recv(dir="global_S", shape=(n_elem,), dtype="f16") return acc def allreduce_intercube_multidevice( t_ptr, n_elem, cube_w, cube_h, n_sips, sip_rank, sip_topo_kind, sip_topo_w, sip_topo_h, tl, ): """Intercube all-reduce (pe0-only) with configurable SIP topology. Root cube sits at the geometric center (cube_w//2, cube_h//2) and each phase converges bidirectionally so the intra-SIP critical path is ~half what a corner-root walk would be (e.g., 4×4 mesh: 4 hops reduce + 4 hops broadcast vs 6+6 with corner root). Args: t_ptr: VA base of the row-wise-sharded tensor on this SIP. n_elem: f16 elements per cube tile. cube_w: cube mesh width (columns). cube_h: cube mesh height (rows). n_sips: number of SIPs. sip_rank: this SIP's rank (0-based). sip_topo_kind: 0=ring, 1=torus_2d, 2=mesh_2d. sip_topo_w: SIP mesh width (for 2D topologies, 0 for ring). sip_topo_h: SIP mesh height (for 2D topologies, 0 for ring). tl: TLContext (auto-injected). """ cube_id = tl.program_id(axis=1) row = cube_id // cube_w col = cube_id % cube_w nbytes = n_elem * 2 single_cube = (cube_w == 1 and cube_h == 1) root_col = cube_w // 2 root_row = cube_h // 2 root_cube = root_row * cube_w + root_col pe_addr = t_ptr + cube_id * nbytes acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16") if single_cube: # ── Single-cube mode: skip intra-SIP reduce, go directly to # inter-SIP exchange (TP use case: one cube per rank). ── if n_sips > 1: if sip_topo_kind == SIP_TOPO_RING: acc = _inter_sip_ring(acc, n_sips, n_elem, tl) elif sip_topo_kind == SIP_TOPO_TORUS: acc = _inter_sip_torus_2d( acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl) elif sip_topo_kind == SIP_TOPO_MESH: acc = _inter_sip_mesh_2d( acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl) else: # ── Multi-cube mode: center-root bidirectional reduce # + inter-SIP exchange + bidirectional broadcast ── # Phase 1: row reduce — converge at col == root_col. # Left half (col < root_col) walks W→E; right half (col > root_col) # walks E→W; the root_col cube merges both sides. if col == 0 and root_col > 0: tl.send(dir="E", src=acc) elif 0 < col < root_col: recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16") acc = acc + recv tl.send(dir="E", src=acc) elif col == root_col: if root_col > 0: recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16") acc = acc + recv if cube_w - 1 > root_col: recv = tl.recv(dir="E", shape=(n_elem,), dtype="f16") acc = acc + recv elif root_col < col < cube_w - 1: recv = tl.recv(dir="E", shape=(n_elem,), dtype="f16") acc = acc + recv tl.send(dir="W", src=acc) elif col == cube_w - 1 and cube_w - 1 > root_col: tl.send(dir="W", src=acc) # Phase 2: col reduce on col == root_col — converge at row == root_row. if col == root_col: if row == 0 and root_row > 0: tl.send(dir="S", src=acc) elif 0 < row < root_row: recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16") acc = acc + recv tl.send(dir="S", src=acc) elif row == root_row: if root_row > 0: recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16") acc = acc + recv if cube_h - 1 > root_row: recv = tl.recv(dir="S", shape=(n_elem,), dtype="f16") acc = acc + recv elif root_row < row < cube_h - 1: recv = tl.recv(dir="S", shape=(n_elem,), dtype="f16") acc = acc + recv tl.send(dir="N", src=acc) elif row == cube_h - 1 and cube_h - 1 > root_row: tl.send(dir="N", src=acc) # Phase 3: inter-SIP exchange on root cube. if cube_id == root_cube and n_sips > 1: if sip_topo_kind == SIP_TOPO_RING: acc = _inter_sip_ring(acc, n_sips, n_elem, tl) elif sip_topo_kind == SIP_TOPO_TORUS: acc = _inter_sip_torus_2d( acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl) elif sip_topo_kind == SIP_TOPO_MESH: acc = _inter_sip_mesh_2d( acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl) # Phase 4: col broadcast on col == root_col, outward from root_row. if col == root_col: if row == root_row: if root_row > 0: tl.send(dir="N", src=acc) if cube_h - 1 > root_row: tl.send(dir="S", src=acc) elif row < root_row: acc = tl.recv(dir="S", shape=(n_elem,), dtype="f16") if row > 0: tl.send(dir="N", src=acc) elif row > root_row: acc = tl.recv(dir="N", shape=(n_elem,), dtype="f16") if row < cube_h - 1: tl.send(dir="S", src=acc) # Phase 5: row broadcast outward from root_col. if col == root_col: if root_col > 0: tl.send(dir="W", src=acc) if cube_w - 1 > root_col: tl.send(dir="E", src=acc) elif col < root_col: acc = tl.recv(dir="E", shape=(n_elem,), dtype="f16") if col > 0: tl.send(dir="W", src=acc) elif col > root_col: acc = tl.recv(dir="W", shape=(n_elem,), dtype="f16") if col < cube_w - 1: tl.send(dir="E", src=acc) tl.store(pe_addr, acc) kernel = allreduce_intercube_multidevice