Files
kernbench2/src/kernbench/ccl/algorithms/intercube_allreduce.py
T
mukesh ad5f01ab13 Merge origin/master: combine single-cube fast path + center-root reduce
Conflict resolution:
- intercube_allreduce.py: kept origin's `if single_cube:` early-exit
  (TP launches kernel on one cube/rank → skip intra-SIP mesh and go
  direct to inter-SIP exchange) AND replaced the multi-cube body with
  the local center-root + bidirectional reduce/broadcast (8-hop
  critical path on 4×4 vs 12 with corner root).
- tests/{allreduce,pe2pe}_latency_plots/: kept the local move to
  docs/diagrams/; dropped origin's stale content edits to the old
  paths (regenerable derived artifacts).
- docs/diagrams/pe2pe_latency_plots/summary.csv: kept local
  (post-Phase-2 + center-root values).

Origin contributions retained as-is:
- pyproject.toml: matplotlib >= 3.7 dep.
- runtime_api/distributed.py: derive effective cube_w/h from tensor
  shard placement so single-cube TP paths get cube_w=cube_h=1.
- kernel_args() now accepts optional cube_w/cube_h kwargs.

Verified post-merge:
- test_intercube_root_center.py: 2/2 (center-root multi-cube path).
- test_tp_layers.py + test_tp_mlp.py: 10/10 (single-cube TP path).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-27 21:41:46 -07:00

249 lines
9.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Intercube all-reduce kernel (pe0-only, same-lane across cubes).
Reduces across the ``cube_w`` × ``cube_h`` cube mesh within each SIP
(4×4 by default; a 1×1 mesh skips this phase), then exchanges between
SIPs using the configured SIP topology, and broadcasts back.
Supported SIP topologies (selected via ``sip_topo_kind``):
0 — ring_1d: global_E/global_W ring, n_sips-1 rounds
1 — torus_2d: row ring (global_E/W) + col ring (global_S/N)
2 — mesh_2d_no_wrap: row chain reduce+broadcast + col chain reduce+broadcast
Note that the topology *name* ``mesh_2d`` maps to kind 1 (torus_2d), not
kind 2, in ``TOPO_NAME_TO_KIND``.
IPCQ wiring is handled by ``configure_sfr_intercube_multisip``.
"""
from __future__ import annotations
# Integer codes accepted by the kernel's ``sip_topo_kind`` argument.
SIP_TOPO_RING = 0   # 1D ring over global_E / global_W
SIP_TOPO_TORUS = 1  # 2D torus: row ring, then column ring
SIP_TOPO_MESH = 2   # 2D mesh without wraparound: chain reduce + broadcast

# Topology name -> ``sip_topo_kind`` code.
# NOTE(review): "mesh_2d" selects the torus kernel, while "mesh_2d_no_wrap"
# selects the open-chain mesh kernel. Presumably the "mesh_2d" topology is
# wired with wraparound links; confirm against
# configure_sfr_intercube_multisip.
TOPO_NAME_TO_KIND = dict(
    ring_1d=SIP_TOPO_RING,
    torus_2d=SIP_TOPO_TORUS,
    mesh_2d=SIP_TOPO_TORUS,
    mesh_2d_no_wrap=SIP_TOPO_MESH,
)
def kernel_args(world_size: int, n_elem: int, *, cube_w: int = 4, cube_h: int = 4) -> tuple:
    """Build the scalar launch arguments for this kernel.

    Returns ``(n_elem, cube_w, cube_h, world_size)``. The order follows the
    kernel's ``n_elem, cube_w, cube_h, n_sips`` parameters. Presumably the
    runtime supplies ``t_ptr`` and the remaining SIP arguments itself;
    confirm against the caller.
    """
    launch_args = (n_elem, cube_w, cube_h, world_size)
    return launch_args
def _inter_sip_ring(acc, n_sips, n_elem, tl):
current = acc
for _ in range(n_sips - 1):
tl.send(dir="global_E", src=current)
recv = tl.recv(dir="global_W", shape=(n_elem,), dtype="f16")
acc = acc + recv
current = recv
return acc
def _inter_sip_torus_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl):
# Row ring (global_E / global_W)
current = acc
for _ in range(sip_topo_w - 1):
tl.send(dir="global_E", src=current)
recv = tl.recv(dir="global_W", shape=(n_elem,), dtype="f16")
acc = acc + recv
current = recv
# Col ring (global_S / global_N)
current = acc
for _ in range(sip_topo_h - 1):
tl.send(dir="global_S", src=current)
recv = tl.recv(dir="global_N", shape=(n_elem,), dtype="f16")
acc = acc + recv
current = recv
return acc
def _inter_sip_mesh_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl):
sip_row = sip_rank // sip_topo_w
sip_col = sip_rank % sip_topo_w
# Row reduce W → E
if sip_col == 0:
tl.send(dir="global_E", src=acc)
elif sip_col < sip_topo_w - 1:
recv = tl.recv(dir="global_W", shape=(n_elem,), dtype="f16")
acc = acc + recv
tl.send(dir="global_E", src=acc)
else:
recv = tl.recv(dir="global_W", shape=(n_elem,), dtype="f16")
acc = acc + recv
# Row broadcast E → W
if sip_col == sip_topo_w - 1:
tl.send(dir="global_W", src=acc)
elif sip_col > 0:
acc = tl.recv(dir="global_E", shape=(n_elem,), dtype="f16")
tl.send(dir="global_W", src=acc)
else:
acc = tl.recv(dir="global_E", shape=(n_elem,), dtype="f16")
# Col reduce N → S
if sip_row == 0:
tl.send(dir="global_S", src=acc)
elif sip_row < sip_topo_h - 1:
recv = tl.recv(dir="global_N", shape=(n_elem,), dtype="f16")
acc = acc + recv
tl.send(dir="global_S", src=acc)
else:
recv = tl.recv(dir="global_N", shape=(n_elem,), dtype="f16")
acc = acc + recv
# Col broadcast S → N
if sip_row == sip_topo_h - 1:
tl.send(dir="global_N", src=acc)
elif sip_row > 0:
acc = tl.recv(dir="global_S", shape=(n_elem,), dtype="f16")
tl.send(dir="global_N", src=acc)
else:
acc = tl.recv(dir="global_S", shape=(n_elem,), dtype="f16")
return acc
def _center_reduce(acc, pos, length, root, inc_dir, dec_dir, n_elem, tl):
    """Sum ``acc`` along one line of cubes into the cube at index ``root``.

    Cubes below ``root`` forward partial sums toward higher indices over
    ``inc_dir``. Cubes above ``root`` forward toward lower indices over
    ``dec_dir``. The root merges the lower arm first, then the upper arm.
    Non-root cubes return their partial sum, which the caller overwrites in
    the matching broadcast.

    Args:
        acc: this cube's local value.
        pos: this cube's index along the line (0-based).
        length: number of cubes on the line.
        root: index of the converging cube.
        inc_dir: link direction toward higher indices ("E" or "S").
        dec_dir: link direction toward lower indices ("W" or "N").
        n_elem: f16 elements per message.
        tl: TLContext.
    """
    last = length - 1
    if pos < root:
        # Lower arm: fold in the outer neighbor's partial, pass it inward.
        if pos > 0:
            acc = acc + tl.recv(dir=dec_dir, shape=(n_elem,), dtype="f16")
        tl.send(dir=inc_dir, src=acc)
    elif pos == root:
        if root > 0:
            acc = acc + tl.recv(dir=dec_dir, shape=(n_elem,), dtype="f16")
        if root < last:
            acc = acc + tl.recv(dir=inc_dir, shape=(n_elem,), dtype="f16")
    else:
        # Upper arm: fold in the outer neighbor's partial, pass it inward.
        if pos < last:
            acc = acc + tl.recv(dir=inc_dir, shape=(n_elem,), dtype="f16")
        tl.send(dir=dec_dir, src=acc)
    return acc


def _center_broadcast(acc, pos, length, root, inc_dir, dec_dir, n_elem, tl):
    """Broadcast the root cube's ``acc`` outward along one line of cubes.

    This mirrors ``_center_reduce`` (same arguments). The root sends over
    ``dec_dir`` first, then over ``inc_dir``. Every cube on the line returns
    the root's value.
    """
    last = length - 1
    if pos == root:
        if root > 0:
            tl.send(dir=dec_dir, src=acc)
        if root < last:
            tl.send(dir=inc_dir, src=acc)
    elif pos < root:
        acc = tl.recv(dir=inc_dir, shape=(n_elem,), dtype="f16")
        if pos > 0:
            tl.send(dir=dec_dir, src=acc)
    else:
        acc = tl.recv(dir=dec_dir, shape=(n_elem,), dtype="f16")
        if pos < last:
            tl.send(dir=inc_dir, src=acc)
    return acc


def _inter_sip_exchange(acc, n_sips, sip_rank, sip_topo_kind,
                        sip_topo_w, sip_topo_h, n_elem, tl):
    """Run the inter-SIP all-reduce selected by ``sip_topo_kind``.

    Raises:
        ValueError: if ``sip_topo_kind`` is not a SIP_TOPO_* code.
    """
    if sip_topo_kind == SIP_TOPO_RING:
        return _inter_sip_ring(acc, n_sips, n_elem, tl)
    if sip_topo_kind == SIP_TOPO_TORUS:
        return _inter_sip_torus_2d(
            acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
    if sip_topo_kind == SIP_TOPO_MESH:
        return _inter_sip_mesh_2d(
            acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
    raise ValueError(f"unsupported sip_topo_kind: {sip_topo_kind!r}")


def allreduce_intercube_multidevice(
    t_ptr, n_elem, cube_w, cube_h, n_sips, sip_rank,
    sip_topo_kind, sip_topo_w, sip_topo_h, tl,
):
    """Intercube all-reduce (pe0-only) with configurable SIP topology.

    The root cube sits at (cube_w//2, cube_h//2). Each phase converges
    bidirectionally, so the intra-SIP critical path is about half that of
    a corner-root walk (e.g., 4×4 mesh: 4 hops reduce + 4 hops broadcast
    vs 6+6 with a corner root).

    Args:
        t_ptr: VA base of the row-wise-sharded tensor on this SIP.
        n_elem: f16 elements per cube tile.
        cube_w: cube mesh width (columns).
        cube_h: cube mesh height (rows).
        n_sips: number of SIPs.
        sip_rank: this SIP's rank (0-based).
        sip_topo_kind: 0=ring, 1=torus_2d, 2=mesh_2d (no wraparound).
        sip_topo_w: SIP mesh width (for 2D topologies, 0 for ring).
        sip_topo_h: SIP mesh height (for 2D topologies, 0 for ring).
        tl: TLContext (auto-injected).

    Raises:
        ValueError: if ``n_sips > 1`` and ``sip_topo_kind`` is not one of
            SIP_TOPO_RING / SIP_TOPO_TORUS / SIP_TOPO_MESH.
    """
    # Validate before any cube communicates. Otherwise an unknown kind would
    # silently skip the inter-SIP exchange (the original code fell through)
    # and leave each SIP with only its local sum.
    if n_sips > 1 and sip_topo_kind not in (
            SIP_TOPO_RING, SIP_TOPO_TORUS, SIP_TOPO_MESH):
        raise ValueError(f"unsupported sip_topo_kind: {sip_topo_kind!r}")

    cube_id = tl.program_id(axis=1)
    row = cube_id // cube_w
    col = cube_id % cube_w
    nbytes = n_elem * 2  # f16 = 2 bytes per element
    single_cube = (cube_w == 1 and cube_h == 1)
    root_col = cube_w // 2
    root_row = cube_h // 2
    root_cube = root_row * cube_w + root_col

    pe_addr = t_ptr + cube_id * nbytes
    acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16")

    if single_cube:
        # TP use case: one cube per rank, so there is nothing to reduce inside
        # the SIP; go straight to the inter-SIP exchange. (The general path
        # below degenerates to the same traffic for a 1×1 mesh.)
        if n_sips > 1:
            acc = _inter_sip_exchange(acc, n_sips, sip_rank, sip_topo_kind,
                                      sip_topo_w, sip_topo_h, n_elem, tl)
    else:
        # Phase 1: row reduce on every row, converging at col == root_col.
        acc = _center_reduce(acc, col, cube_w, root_col, "E", "W", n_elem, tl)
        # Phase 2: column reduce on the root column, converging at root_row.
        if col == root_col:
            acc = _center_reduce(
                acc, row, cube_h, root_row, "S", "N", n_elem, tl)
        # Phase 3: inter-SIP exchange on the root cube only.
        if cube_id == root_cube and n_sips > 1:
            acc = _inter_sip_exchange(acc, n_sips, sip_rank, sip_topo_kind,
                                      sip_topo_w, sip_topo_h, n_elem, tl)
        # Phase 4: column broadcast on the root column, outward from root_row.
        if col == root_col:
            acc = _center_broadcast(
                acc, row, cube_h, root_row, "S", "N", n_elem, tl)
        # Phase 5: row broadcast on every row, outward from root_col.
        acc = _center_broadcast(acc, col, cube_w, root_col, "E", "W", n_elem, tl)

    tl.store(pe_addr, acc)


# Exported under the generic name ``kernel``; presumably this is the name the
# runtime looks up. Confirm against the kernel registry.
kernel = allreduce_intercube_multidevice