Intercube allreduce: pe0 cube-mesh reduce + multi-SIP ring/torus/mesh
New intercube allreduce kernel replacing the old flat ring algorithms. It reduces across the 4x4 cube mesh within each SIP (pe0-only, same-lane), then performs the inter-SIP exchange on the root cube, then broadcasts the result back. Supports the ring_1d, torus_2d, and mesh_2d_no_wrap SIP topologies, driven by topology.yaml. Integrated with dist.init_process_group / dist.all_reduce.

New files:
- src/kernbench/ccl/algorithms/intercube_allreduce.py (kernel)
- src/kernbench/ccl/sfr_config.py (configure_sfr_intercube_multisip)
- tests/test_allreduce_multidevice.py (config-driven, 3 topologies)
- tests/test_distributed_intercube_allreduce.py (full distributed path)
- tests/test_intercube_sfr_config.py (SFR wiring verification)

Modified:
- distributed.py: AhbmCCLBackend uses configure_sfr_intercube_multisip
- topologies.py: added torus_2d, mesh_2d_no_wrap
- install.py: global_E/W/N/S in _OPPOSITE_DIR
- topology.yaml: added system.sips.topology
- ccl.yaml: single intercube_allreduce algorithm
- benches/ccl_allreduce.py: row_wise cube-mesh tensor layout

Removed the old flat-ring algorithms and their tests.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,189 @@
|
||||
"""Intercube all-reduce kernel (pe0-only, same-lane across cubes).
|
||||
|
||||
Reduces across the 4×4 cube mesh within each SIP, then exchanges
|
||||
between SIPs using the configured SIP topology, and broadcasts back.
|
||||
|
||||
Supported SIP topologies (selected via ``sip_topo_kind``):
|
||||
0 — ring_1d: global_E/global_W ring, n_sips-1 rounds
|
||||
1 — torus_2d: row ring (global_E/W) + col ring (global_S/N)
|
||||
2 — mesh_2d_no_wrap: row chain reduce+broadcast + col chain reduce+broadcast
|
||||
|
||||
IPCQ wiring is handled by ``configure_sfr_intercube_multisip``.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
# Integer codes for the inter-SIP exchange algorithm; the kernel compares its
# ``sip_topo_kind`` argument against these.
SIP_TOPO_RING, SIP_TOPO_TORUS, SIP_TOPO_MESH = 0, 1, 2

# Topology name -> integer code.
# NOTE(review): "mesh_2d" maps to the torus code while "mesh_2d_no_wrap" maps
# to the chain-mesh code; presumably plain "mesh_2d" has wrap-around links —
# confirm against the topology definitions.
TOPO_NAME_TO_KIND = {
    "ring_1d": SIP_TOPO_RING,
    "torus_2d": SIP_TOPO_TORUS,
    "mesh_2d": SIP_TOPO_TORUS,
    "mesh_2d_no_wrap": SIP_TOPO_MESH,
}
|
||||
|
||||
|
||||
def kernel_args(world_size: int, n_elem: int) -> tuple:
    """Return the tuple ``(n_elem, cube_w, cube_h, world_size)``.

    The cube mesh dimensions are fixed at 4 x 4 here.

    Args:
        world_size: passed through unchanged as the last tuple element
            (presumably consumed as the kernel's ``n_sips`` — confirm
            against the launcher).
        n_elem: passed through unchanged as the first tuple element.
    """
    cube_mesh_dims = (4, 4)  # (width, height)
    return (n_elem, *cube_mesh_dims, world_size)
|
||||
|
||||
|
||||
def _inter_sip_ring(acc, n_sips, n_elem, tl):
|
||||
current = acc
|
||||
for _ in range(n_sips - 1):
|
||||
tl.send(dir="global_E", src=current)
|
||||
recv = tl.recv(dir="global_W", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
current = recv
|
||||
return acc
|
||||
|
||||
|
||||
def _inter_sip_torus_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl):
|
||||
# Row ring (global_E / global_W)
|
||||
current = acc
|
||||
for _ in range(sip_topo_w - 1):
|
||||
tl.send(dir="global_E", src=current)
|
||||
recv = tl.recv(dir="global_W", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
current = recv
|
||||
# Col ring (global_S / global_N)
|
||||
current = acc
|
||||
for _ in range(sip_topo_h - 1):
|
||||
tl.send(dir="global_S", src=current)
|
||||
recv = tl.recv(dir="global_N", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
current = recv
|
||||
return acc
|
||||
|
||||
|
||||
def _inter_sip_mesh_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl):
|
||||
sip_row = sip_rank // sip_topo_w
|
||||
sip_col = sip_rank % sip_topo_w
|
||||
|
||||
# Row reduce W → E
|
||||
if sip_col == 0:
|
||||
tl.send(dir="global_E", src=acc)
|
||||
elif sip_col < sip_topo_w - 1:
|
||||
recv = tl.recv(dir="global_W", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
tl.send(dir="global_E", src=acc)
|
||||
else:
|
||||
recv = tl.recv(dir="global_W", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
|
||||
# Row broadcast E → W
|
||||
if sip_col == sip_topo_w - 1:
|
||||
tl.send(dir="global_W", src=acc)
|
||||
elif sip_col > 0:
|
||||
acc = tl.recv(dir="global_E", shape=(n_elem,), dtype="f16")
|
||||
tl.send(dir="global_W", src=acc)
|
||||
else:
|
||||
acc = tl.recv(dir="global_E", shape=(n_elem,), dtype="f16")
|
||||
|
||||
# Col reduce N → S
|
||||
if sip_row == 0:
|
||||
tl.send(dir="global_S", src=acc)
|
||||
elif sip_row < sip_topo_h - 1:
|
||||
recv = tl.recv(dir="global_N", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
tl.send(dir="global_S", src=acc)
|
||||
else:
|
||||
recv = tl.recv(dir="global_N", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
|
||||
# Col broadcast S → N
|
||||
if sip_row == sip_topo_h - 1:
|
||||
tl.send(dir="global_N", src=acc)
|
||||
elif sip_row > 0:
|
||||
acc = tl.recv(dir="global_S", shape=(n_elem,), dtype="f16")
|
||||
tl.send(dir="global_N", src=acc)
|
||||
else:
|
||||
acc = tl.recv(dir="global_S", shape=(n_elem,), dtype="f16")
|
||||
|
||||
return acc
|
||||
|
||||
|
||||
def allreduce_intercube_multidevice(
    t_ptr, n_elem, cube_w, cube_h, n_sips, sip_rank,
    sip_topo_kind, sip_topo_w, sip_topo_h, tl,
):
    """Intercube all-reduce (pe0-only) with a selectable inter-SIP exchange.

    Each cube loads its tile, partial sums are chained west-to-east along
    every row and north-to-south down the last column, the bottom-right cube
    optionally exchanges with other SIPs, and the result is chained back
    south-to-north and east-to-west before being stored over the input tile.

    Args:
        t_ptr: VA base of the row-wise-sharded tensor on this SIP.
        n_elem: f16 elements per cube tile (tiles are ``n_elem * 2`` bytes
            apart).
        cube_w: cube mesh width (columns).
        cube_h: cube mesh height (rows).
        n_sips: number of SIPs; the inter-SIP phase runs only when > 1.
        sip_rank: this SIP's rank (0-based).
        sip_topo_kind: topology code compared against ``SIP_TOPO_RING``,
            ``SIP_TOPO_TORUS`` and ``SIP_TOPO_MESH``; any other value skips
            the inter-SIP phase.
        sip_topo_w: SIP mesh width (for 2D topologies, 0 for ring).
        sip_topo_h: SIP mesh height (for 2D topologies, 0 for ring).
        tl: TLContext (auto-injected).
    """
    tile_shape = (n_elem,)
    cube_id = tl.program_id(axis=1)
    row = cube_id // cube_w
    col = cube_id % cube_w
    east_col = cube_w - 1
    south_row = cube_h - 1
    on_east_edge = col == east_col

    tile_addr = t_ptr + cube_id * (n_elem * 2)
    acc = tl.load(tile_addr, shape=tile_shape, dtype="f16")

    # ── Phase 1: row reduce W → E ──
    # Every column but the first receives and adds; column 0 and the
    # interior columns forward the partial sum.
    if col != 0:
        incoming = tl.recv(dir="W", shape=tile_shape, dtype="f16")
        acc = acc + incoming
    if col == 0 or col < east_col:
        tl.send(dir="E", src=acc)

    # ── Phase 2: col reduce N → S on rightmost column ──
    if on_east_edge:
        if row != 0:
            incoming = tl.recv(dir="N", shape=tile_shape, dtype="f16")
            acc = acc + incoming
        if row == 0 or row < south_row:
            tl.send(dir="S", src=acc)

    # ── Phase 3: inter-SIP exchange on root cube ──
    is_root_cube = cube_id == south_row * cube_w + east_col
    if is_root_cube and n_sips > 1:
        if sip_topo_kind == SIP_TOPO_RING:
            acc = _inter_sip_ring(acc, n_sips, n_elem, tl)
        elif sip_topo_kind == SIP_TOPO_TORUS:
            acc = _inter_sip_torus_2d(
                acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl
            )
        elif sip_topo_kind == SIP_TOPO_MESH:
            acc = _inter_sip_mesh_2d(
                acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl
            )

    # ── Phase 4: col broadcast S → N on rightmost column ──
    # The received value replaces acc; it is not added.
    if on_east_edge:
        if row != south_row:
            acc = tl.recv(dir="S", shape=tile_shape, dtype="f16")
        if row == south_row or row > 0:
            tl.send(dir="N", src=acc)

    # ── Phase 5: row broadcast E → W ──
    if not on_east_edge:
        acc = tl.recv(dir="E", shape=tile_shape, dtype="f16")
    if on_east_edge or col > 0:
        tl.send(dir="W", src=acc)

    tl.store(tile_addr, acc)
|
||||
|
||||
|
||||
# Module-level alias for the all-reduce kernel function.
# NOTE(review): presumably the generic name a kernel loader looks up —
# confirm against the caller.
kernel = allreduce_intercube_multidevice
|
||||
Reference in New Issue
Block a user