ad5f01ab13
Conflict resolution:
- intercube_allreduce.py: kept origin's `if single_cube:` early-exit
(TP launches kernel on one cube/rank → skip intra-SIP mesh and go
direct to inter-SIP exchange) AND replaced the multi-cube body with
the local center-root + bidirectional reduce/broadcast (8-hop
critical path on 4×4 vs 12 with corner root).
- tests/{allreduce,pe2pe}_latency_plots/: kept the local move to
docs/diagrams/; dropped origin's stale content edits to the old
paths (regenerable derived artifacts).
- docs/diagrams/pe2pe_latency_plots/summary.csv: kept local
(post-Phase-2 + center-root values).
Origin contributions retained as-is:
- pyproject.toml: matplotlib >= 3.7 dep.
- runtime_api/distributed.py: derive effective cube_w/h from tensor
shard placement so single-cube TP paths get cube_w=cube_h=1.
- kernel_args() now accepts optional cube_w/cube_h kwargs.
Verified post-merge:
- test_intercube_root_center.py: 2/2 (center-root multi-cube path).
- test_tp_layers.py + test_tp_mlp.py: 10/10 (single-cube TP path).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
249 lines
9.1 KiB
Python
249 lines
9.1 KiB
Python
"""Intercube all-reduce kernel (pe0-only, same-lane across cubes).
|
||
|
||
Reduces across the cube_w × cube_h cube mesh (4×4 by default) within each
SIP, then exchanges between SIPs using the configured SIP topology, and
broadcasts back.
|
||
|
||
Supported SIP topologies (selected via ``sip_topo_kind``):
|
||
0 — ring_1d: global_E/global_W ring, n_sips-1 rounds
|
||
1 — torus_2d: row ring (global_E/W) + col ring (global_S/N)
|
||
2 — mesh_2d: row chain reduce+broadcast + col chain reduce+broadcast
|
||
|
||
IPCQ wiring is handled by ``configure_sfr_intercube_multisip``.
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
# Integer codes for the inter-SIP exchange strategy; the kernel compares its
# ``sip_topo_kind`` argument against these.
SIP_TOPO_RING, SIP_TOPO_TORUS, SIP_TOPO_MESH = 0, 1, 2

# Topology name -> ``sip_topo_kind`` code.
# NOTE(review): "mesh_2d" maps to the torus code and only "mesh_2d_no_wrap"
# maps to the mesh (chain) code — presumably "mesh_2d" configurations have
# wrap-around links; confirm against the topology definitions.
TOPO_NAME_TO_KIND = {
    "ring_1d": SIP_TOPO_RING,
    "torus_2d": SIP_TOPO_TORUS,
    "mesh_2d": SIP_TOPO_TORUS,
    "mesh_2d_no_wrap": SIP_TOPO_MESH,
}
|
||
|
||
|
||
def kernel_args(world_size: int, n_elem: int, *, cube_w: int = 4, cube_h: int = 4) -> tuple:
    """Pack launch parameters into the positional tuple ``(n_elem, cube_w, cube_h, world_size)``.

    Args:
        world_size: placed last in the returned tuple.
        n_elem: placed first in the returned tuple.
        cube_w: cube mesh width (keyword-only, defaults to 4).
        cube_h: cube mesh height (keyword-only, defaults to 4).

    Returns:
        The 4-tuple ``(n_elem, cube_w, cube_h, world_size)``.
    """
    packed = (n_elem, cube_w, cube_h, world_size)
    return packed
|
||
|
||
|
||
def _inter_sip_ring(acc, n_sips, n_elem, tl):
|
||
current = acc
|
||
for _ in range(n_sips - 1):
|
||
tl.send(dir="global_E", src=current)
|
||
recv = tl.recv(dir="global_W", shape=(n_elem,), dtype="f16")
|
||
acc = acc + recv
|
||
current = recv
|
||
return acc
|
||
|
||
|
||
def _inter_sip_torus_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl):
|
||
# Row ring (global_E / global_W)
|
||
current = acc
|
||
for _ in range(sip_topo_w - 1):
|
||
tl.send(dir="global_E", src=current)
|
||
recv = tl.recv(dir="global_W", shape=(n_elem,), dtype="f16")
|
||
acc = acc + recv
|
||
current = recv
|
||
# Col ring (global_S / global_N)
|
||
current = acc
|
||
for _ in range(sip_topo_h - 1):
|
||
tl.send(dir="global_S", src=current)
|
||
recv = tl.recv(dir="global_N", shape=(n_elem,), dtype="f16")
|
||
acc = acc + recv
|
||
current = recv
|
||
return acc
|
||
|
||
|
||
def _inter_sip_mesh_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl):
|
||
sip_row = sip_rank // sip_topo_w
|
||
sip_col = sip_rank % sip_topo_w
|
||
|
||
# Row reduce W → E
|
||
if sip_col == 0:
|
||
tl.send(dir="global_E", src=acc)
|
||
elif sip_col < sip_topo_w - 1:
|
||
recv = tl.recv(dir="global_W", shape=(n_elem,), dtype="f16")
|
||
acc = acc + recv
|
||
tl.send(dir="global_E", src=acc)
|
||
else:
|
||
recv = tl.recv(dir="global_W", shape=(n_elem,), dtype="f16")
|
||
acc = acc + recv
|
||
|
||
# Row broadcast E → W
|
||
if sip_col == sip_topo_w - 1:
|
||
tl.send(dir="global_W", src=acc)
|
||
elif sip_col > 0:
|
||
acc = tl.recv(dir="global_E", shape=(n_elem,), dtype="f16")
|
||
tl.send(dir="global_W", src=acc)
|
||
else:
|
||
acc = tl.recv(dir="global_E", shape=(n_elem,), dtype="f16")
|
||
|
||
# Col reduce N → S
|
||
if sip_row == 0:
|
||
tl.send(dir="global_S", src=acc)
|
||
elif sip_row < sip_topo_h - 1:
|
||
recv = tl.recv(dir="global_N", shape=(n_elem,), dtype="f16")
|
||
acc = acc + recv
|
||
tl.send(dir="global_S", src=acc)
|
||
else:
|
||
recv = tl.recv(dir="global_N", shape=(n_elem,), dtype="f16")
|
||
acc = acc + recv
|
||
|
||
# Col broadcast S → N
|
||
if sip_row == sip_topo_h - 1:
|
||
tl.send(dir="global_N", src=acc)
|
||
elif sip_row > 0:
|
||
acc = tl.recv(dir="global_S", shape=(n_elem,), dtype="f16")
|
||
tl.send(dir="global_N", src=acc)
|
||
else:
|
||
acc = tl.recv(dir="global_S", shape=(n_elem,), dtype="f16")
|
||
|
||
return acc
|
||
|
||
|
||
def allreduce_intercube_multidevice(
    t_ptr, n_elem, cube_w, cube_h, n_sips, sip_rank,
    sip_topo_kind, sip_topo_w, sip_topo_h, tl,
):
    """Intercube all-reduce (pe0-only) with configurable SIP topology.

    Each cube loads its tile, the cube mesh reduces toward a root cube at
    (cube_w // 2, cube_h // 2), the root exchanges with the other SIPs, and
    the result is broadcast back outward before every cube stores it over
    its own tile. Both the reduce and the broadcast run from/to both sides
    of the root at once, so the intra-SIP critical path is about half that
    of a corner-root walk (e.g., 4×4 mesh: 4 hops reduce + 4 hops broadcast
    vs 6+6 with a corner root).

    When ``cube_w == cube_h == 1`` the intra-SIP phases are skipped and the
    cube goes straight to the inter-SIP exchange (TP use case: one cube per
    rank).

    Args:
        t_ptr: VA base of the row-wise-sharded tensor on this SIP.
        n_elem: f16 elements per cube tile.
        cube_w: cube mesh width (columns).
        cube_h: cube mesh height (rows).
        n_sips: number of SIPs; the inter-SIP exchange runs only when > 1.
        sip_rank: this SIP's rank (0-based).
        sip_topo_kind: 0=ring, 1=torus_2d, 2=mesh_2d. Any other value
            skips the inter-SIP exchange without raising.
        sip_topo_w: SIP mesh width (for 2D topologies, 0 for ring).
        sip_topo_h: SIP mesh height (for 2D topologies, 0 for ring).
        tl: TLContext (auto-injected).

    Returns:
        None; the result is written back with ``tl.store``.
    """
    cube_id = tl.program_id(axis=1)
    row = cube_id // cube_w
    col = cube_id % cube_w
    tile_shape = (n_elem,)
    tile_bytes = n_elem * 2  # f16 → 2 bytes per element

    root_col = cube_w // 2
    root_row = cube_h // 2
    last_col = cube_w - 1
    last_row = cube_h - 1
    multi_cube = cube_w != 1 or cube_h != 1
    is_root = cube_id == root_row * cube_w + root_col
    on_root_col = col == root_col

    pe_addr = t_ptr + cube_id * tile_bytes
    acc = tl.load(pe_addr, shape=tile_shape, dtype="f16")

    if multi_cube:
        # Phase 1: row reduce — converge at col == root_col.
        # Cubes west of the root pass partial sums eastward, cubes east of
        # it pass them westward, and the root column merges both sides.
        if 0 <= col < root_col:
            if col > 0:
                acc = acc + tl.recv(dir="W", shape=tile_shape, dtype="f16")
            tl.send(dir="E", src=acc)
        elif on_root_col:
            if root_col > 0:
                acc = acc + tl.recv(dir="W", shape=tile_shape, dtype="f16")
            if last_col > root_col:
                acc = acc + tl.recv(dir="E", shape=tile_shape, dtype="f16")
        elif root_col < col <= last_col:
            if col < last_col:
                acc = acc + tl.recv(dir="E", shape=tile_shape, dtype="f16")
            tl.send(dir="W", src=acc)

        # Phase 2: col reduce on col == root_col — converge at row == root_row.
        if on_root_col:
            if 0 <= row < root_row:
                if row > 0:
                    acc = acc + tl.recv(dir="N", shape=tile_shape, dtype="f16")
                tl.send(dir="S", src=acc)
            elif row == root_row:
                if root_row > 0:
                    acc = acc + tl.recv(dir="N", shape=tile_shape, dtype="f16")
                if last_row > root_row:
                    acc = acc + tl.recv(dir="S", shape=tile_shape, dtype="f16")
            elif root_row < row <= last_row:
                if row < last_row:
                    acc = acc + tl.recv(dir="S", shape=tile_shape, dtype="f16")
                tl.send(dir="N", src=acc)

    # Phase 3: inter-SIP exchange — on the root cube in multi-cube mode, or
    # on whichever cube is running in single-cube mode.
    if n_sips > 1 and (is_root or not multi_cube):
        if sip_topo_kind == SIP_TOPO_RING:
            acc = _inter_sip_ring(acc, n_sips, n_elem, tl)
        elif sip_topo_kind == SIP_TOPO_TORUS:
            acc = _inter_sip_torus_2d(
                acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
        elif sip_topo_kind == SIP_TOPO_MESH:
            acc = _inter_sip_mesh_2d(
                acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)

    if multi_cube:
        # Phase 4: col broadcast on col == root_col, outward from root_row.
        if on_root_col:
            if row == root_row:
                if root_row > 0:
                    tl.send(dir="N", src=acc)
                if last_row > root_row:
                    tl.send(dir="S", src=acc)
            elif row < root_row:
                acc = tl.recv(dir="S", shape=tile_shape, dtype="f16")
                if row > 0:
                    tl.send(dir="N", src=acc)
            else:
                acc = tl.recv(dir="N", shape=tile_shape, dtype="f16")
                if row < last_row:
                    tl.send(dir="S", src=acc)

        # Phase 5: row broadcast outward from root_col (runs in every row).
        if on_root_col:
            if root_col > 0:
                tl.send(dir="W", src=acc)
            if last_col > root_col:
                tl.send(dir="E", src=acc)
        elif col < root_col:
            acc = tl.recv(dir="E", shape=tile_shape, dtype="f16")
            if col > 0:
                tl.send(dir="W", src=acc)
        else:
            acc = tl.recv(dir="W", shape=tile_shape, dtype="f16")
            if col < last_col:
                tl.send(dir="E", src=acc)

    # Overwrite this cube's tile with the reduced result (both modes).
    tl.store(pe_addr, acc)


# Module-level alias for the kernel entry point (presumably the name the
# launcher looks up — confirm against the caller).
kernel = allreduce_intercube_multidevice
|