Intercube allreduce: pe0 cube-mesh reduce + multi-SIP ring/torus/mesh

New intercube allreduce kernel replacing the old flat ring algorithms.
Reduces across the 4x4 cube mesh within each SIP (pe0-only, same-lane),
then inter-SIP exchange on root cube, then broadcast back. Supports
ring_1d, torus_2d, and mesh_2d_no_wrap SIP topologies driven by
topology.yaml. Integrated with dist.init_process_group / dist.all_reduce.

New files:
- src/kernbench/ccl/algorithms/intercube_allreduce.py (kernel)
- src/kernbench/ccl/sfr_config.py (configure_sfr_intercube_multisip)
- tests/test_allreduce_multidevice.py (config-driven, 3 topologies)
- tests/test_distributed_intercube_allreduce.py (full distributed path)
- tests/test_intercube_sfr_config.py (SFR wiring verification)

Modified:
- distributed.py: AhbmCCLBackend uses configure_sfr_intercube_multisip
- topologies.py: added torus_2d, mesh_2d_no_wrap
- install.py: global_E/W/N/S in _OPPOSITE_DIR
- topology.yaml: added system.sips.topology
- ccl.yaml: single intercube_allreduce algorithm
- benches/ccl_allreduce.py: row_wise cube-mesh tensor layout

Removed old flat-ring algorithms and their tests.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-04-16 17:33:42 -07:00
parent cfc2d74ec4
commit 1d8b9401e5
30 changed files with 876 additions and 2892 deletions
@@ -0,0 +1,189 @@
"""Intercube all-reduce kernel (pe0-only, same-lane across cubes).
Reduces across the 4×4 cube mesh within each SIP, then exchanges
between SIPs using the configured SIP topology, and broadcasts back.
Supported SIP topologies (selected via ``sip_topo_kind``):
0 — ring_1d: global_E/global_W ring, n_sips-1 rounds
1 — torus_2d: row ring (global_E/W) + col ring (global_S/N)
2 — mesh_2d: row chain reduce+broadcast + col chain reduce+broadcast
IPCQ wiring is handled by ``configure_sfr_intercube_multisip``.
"""
from __future__ import annotations
SIP_TOPO_RING = 0
SIP_TOPO_TORUS = 1
SIP_TOPO_MESH = 2
TOPO_NAME_TO_KIND = {
"ring_1d": SIP_TOPO_RING,
"torus_2d": SIP_TOPO_TORUS,
"mesh_2d": SIP_TOPO_TORUS,
"mesh_2d_no_wrap": SIP_TOPO_MESH,
}
def kernel_args(world_size: int, n_elem: int) -> tuple:
cube_w = 4
cube_h = 4
return (n_elem, cube_w, cube_h, world_size)
def _inter_sip_ring(acc, n_sips, n_elem, tl):
current = acc
for _ in range(n_sips - 1):
tl.send(dir="global_E", src=current)
recv = tl.recv(dir="global_W", shape=(n_elem,), dtype="f16")
acc = acc + recv
current = recv
return acc
def _inter_sip_torus_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl):
# Row ring (global_E / global_W)
current = acc
for _ in range(sip_topo_w - 1):
tl.send(dir="global_E", src=current)
recv = tl.recv(dir="global_W", shape=(n_elem,), dtype="f16")
acc = acc + recv
current = recv
# Col ring (global_S / global_N)
current = acc
for _ in range(sip_topo_h - 1):
tl.send(dir="global_S", src=current)
recv = tl.recv(dir="global_N", shape=(n_elem,), dtype="f16")
acc = acc + recv
current = recv
return acc
def _inter_sip_mesh_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl):
sip_row = sip_rank // sip_topo_w
sip_col = sip_rank % sip_topo_w
# Row reduce W → E
if sip_col == 0:
tl.send(dir="global_E", src=acc)
elif sip_col < sip_topo_w - 1:
recv = tl.recv(dir="global_W", shape=(n_elem,), dtype="f16")
acc = acc + recv
tl.send(dir="global_E", src=acc)
else:
recv = tl.recv(dir="global_W", shape=(n_elem,), dtype="f16")
acc = acc + recv
# Row broadcast E → W
if sip_col == sip_topo_w - 1:
tl.send(dir="global_W", src=acc)
elif sip_col > 0:
acc = tl.recv(dir="global_E", shape=(n_elem,), dtype="f16")
tl.send(dir="global_W", src=acc)
else:
acc = tl.recv(dir="global_E", shape=(n_elem,), dtype="f16")
# Col reduce N → S
if sip_row == 0:
tl.send(dir="global_S", src=acc)
elif sip_row < sip_topo_h - 1:
recv = tl.recv(dir="global_N", shape=(n_elem,), dtype="f16")
acc = acc + recv
tl.send(dir="global_S", src=acc)
else:
recv = tl.recv(dir="global_N", shape=(n_elem,), dtype="f16")
acc = acc + recv
# Col broadcast S → N
if sip_row == sip_topo_h - 1:
tl.send(dir="global_N", src=acc)
elif sip_row > 0:
acc = tl.recv(dir="global_S", shape=(n_elem,), dtype="f16")
tl.send(dir="global_N", src=acc)
else:
acc = tl.recv(dir="global_S", shape=(n_elem,), dtype="f16")
return acc
def allreduce_intercube_multidevice(
t_ptr, n_elem, cube_w, cube_h, n_sips, sip_rank,
sip_topo_kind, sip_topo_w, sip_topo_h, tl,
):
"""Intercube all-reduce (pe0-only) with configurable SIP topology.
Args:
t_ptr: VA base of the row-wise-sharded tensor on this SIP.
n_elem: f16 elements per cube tile.
cube_w: cube mesh width (columns).
cube_h: cube mesh height (rows).
n_sips: number of SIPs.
sip_rank: this SIP's rank (0-based).
sip_topo_kind: 0=ring, 1=torus_2d, 2=mesh_2d.
sip_topo_w: SIP mesh width (for 2D topologies, 0 for ring).
sip_topo_h: SIP mesh height (for 2D topologies, 0 for ring).
tl: TLContext (auto-injected).
"""
cube_id = tl.program_id(axis=1)
row = cube_id // cube_w
col = cube_id % cube_w
nbytes = n_elem * 2
pe_addr = t_ptr + cube_id * nbytes
acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16")
# ── Phase 1: row reduce W → E ──
if col == 0:
tl.send(dir="E", src=acc)
elif col < cube_w - 1:
recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
acc = acc + recv
tl.send(dir="E", src=acc)
else:
recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
acc = acc + recv
# ── Phase 2: col reduce N → S on rightmost column ──
if col == cube_w - 1:
if row == 0:
tl.send(dir="S", src=acc)
elif row < cube_h - 1:
recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
acc = acc + recv
tl.send(dir="S", src=acc)
else:
recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
acc = acc + recv
# ── Phase 3: inter-SIP exchange on root cube ──
root_cube = (cube_h - 1) * cube_w + (cube_w - 1)
if cube_id == root_cube and n_sips > 1:
if sip_topo_kind == SIP_TOPO_RING:
acc = _inter_sip_ring(acc, n_sips, n_elem, tl)
elif sip_topo_kind == SIP_TOPO_TORUS:
acc = _inter_sip_torus_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
elif sip_topo_kind == SIP_TOPO_MESH:
acc = _inter_sip_mesh_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
# ── Phase 4: col broadcast S → N on rightmost column ──
if col == cube_w - 1:
if row == cube_h - 1:
tl.send(dir="N", src=acc)
elif row > 0:
acc = tl.recv(dir="S", shape=(n_elem,), dtype="f16")
tl.send(dir="N", src=acc)
else:
acc = tl.recv(dir="S", shape=(n_elem,), dtype="f16")
# ── Phase 5: row broadcast E → W ──
if col == cube_w - 1:
tl.send(dir="W", src=acc)
elif col > 0:
acc = tl.recv(dir="E", shape=(n_elem,), dtype="f16")
tl.send(dir="W", src=acc)
else:
acc = tl.recv(dir="E", shape=(n_elem,), dtype="f16")
tl.store(pe_addr, acc)
kernel = allreduce_intercube_multidevice