Intercube allreduce: center root + bidirectional reduce

Move the algorithmic root cube from the corner (cube_w-1,
cube_h-1) to the geometric center (cube_w//2, cube_h//2) and
have each phase converge bidirectionally so the intra-SIP
critical path drops from ~12 hops to ~8 hops on a 4×4 mesh
(left half W→E + right half E→W in row reduce; top half N→S +
bottom half S→N in col reduce; mirrored on broadcast).

Result on torus_2d 6 SIPs at 96 KB / PE on TCM:
  before (corner root)  : 22.0 µs
  after  (center root)  : 17.2 µs   (−22%)

Same shape on ring_1d (−7%) and mesh_2d_no_wrap (−12%); also
holds across SRAM and HBM (~−20% each).

Phase 1 test (test_intercube_root_center.py) asserts the
torus_2d 96 KB latency drops below 20.5 µs and that all 96
cubes still validate (correctness preserved).

Plot updates:
- overview.png: replace constant 10.6 µs theoretical line with
  user-supplied hand-derived curve (per-cube packet count =
  bytes_per_pe × 8 PEs ÷ 128 B; 1346 ns startup + 1.20 ns/pkt).
- All summary.csv numbers and per-topology PNGs regenerated.
- pe2pe_latency_plots and ipcq diagram emitter PNGs refreshed.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-04-27 21:28:58 -07:00
parent 84a1325e5c
commit 1c5752a9ec
16 changed files with 324 additions and 157 deletions
@@ -111,6 +111,11 @@ def allreduce_intercube_multidevice(
):
"""Intercube all-reduce (pe0-only) with configurable SIP topology.
Root cube sits at the geometric center (cube_w//2, cube_h//2) and
each phase converges bidirectionally so the intra-SIP critical path
is ~half what a corner-root walk would be (e.g., 4×4 mesh: 4 hops
reduce + 4 hops broadcast vs 6+6 with corner root).
Args:
t_ptr: VA base of the row-wise-sharded tensor on this SIP.
n_elem: f16 elements per cube tile.
@@ -128,34 +133,59 @@ def allreduce_intercube_multidevice(
col = cube_id % cube_w
nbytes = n_elem * 2
root_col = cube_w // 2
root_row = cube_h // 2
root_cube = root_row * cube_w + root_col
pe_addr = t_ptr + cube_id * nbytes
acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16")
# ── Phase 1: row reduce W → E ──
if col == 0:
# ── Phase 1: row reduce — converge at col == root_col ──
# Left half (col < root_col) walks W→E; right half (col > root_col)
# walks E→W; the root_col cube merges both sides.
if col == 0 and root_col > 0:
tl.send(dir="E", src=acc)
elif col < cube_w - 1:
elif 0 < col < root_col:
recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
acc = acc + recv
tl.send(dir="E", src=acc)
else:
recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
elif col == root_col:
if root_col > 0:
recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
acc = acc + recv
if cube_w - 1 > root_col:
recv = tl.recv(dir="E", shape=(n_elem,), dtype="f16")
acc = acc + recv
elif root_col < col < cube_w - 1:
recv = tl.recv(dir="E", shape=(n_elem,), dtype="f16")
acc = acc + recv
tl.send(dir="W", src=acc)
elif col == cube_w - 1 and cube_w - 1 > root_col:
tl.send(dir="W", src=acc)
# ── Phase 2: col reduce N → S on rightmost column ──
if col == cube_w - 1:
if row == 0:
# ── Phase 2: col reduce on col == root_col — converge at row == root_row ──
if col == root_col:
if row == 0 and root_row > 0:
tl.send(dir="S", src=acc)
elif row < cube_h - 1:
elif 0 < row < root_row:
recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
acc = acc + recv
tl.send(dir="S", src=acc)
else:
recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
elif row == root_row:
if root_row > 0:
recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
acc = acc + recv
if cube_h - 1 > root_row:
recv = tl.recv(dir="S", shape=(n_elem,), dtype="f16")
acc = acc + recv
elif root_row < row < cube_h - 1:
recv = tl.recv(dir="S", shape=(n_elem,), dtype="f16")
acc = acc + recv
tl.send(dir="N", src=acc)
elif row == cube_h - 1 and cube_h - 1 > root_row:
tl.send(dir="N", src=acc)
# ── Phase 3: inter-SIP exchange on root cube ──
root_cube = (cube_h - 1) * cube_w + (cube_w - 1)
if cube_id == root_cube and n_sips > 1:
if sip_topo_kind == SIP_TOPO_RING:
acc = _inter_sip_ring(acc, n_sips, n_elem, tl)
@@ -164,24 +194,36 @@ def allreduce_intercube_multidevice(
elif sip_topo_kind == SIP_TOPO_MESH:
acc = _inter_sip_mesh_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
# ── Phase 4: col broadcast S → N on rightmost column ──
if col == cube_w - 1:
if row == cube_h - 1:
tl.send(dir="N", src=acc)
elif row > 0:
acc = tl.recv(dir="S", shape=(n_elem,), dtype="f16")
tl.send(dir="N", src=acc)
else:
# ── Phase 4: col broadcast on col == root_col, outward from root_row ──
if col == root_col:
if row == root_row:
if root_row > 0:
tl.send(dir="N", src=acc)
if cube_h - 1 > root_row:
tl.send(dir="S", src=acc)
elif row < root_row:
acc = tl.recv(dir="S", shape=(n_elem,), dtype="f16")
if row > 0:
tl.send(dir="N", src=acc)
elif row > root_row:
acc = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
if row < cube_h - 1:
tl.send(dir="S", src=acc)
# ── Phase 5: row broadcast E → W ──
if col == cube_w - 1:
tl.send(dir="W", src=acc)
elif col > 0:
acc = tl.recv(dir="E", shape=(n_elem,), dtype="f16")
tl.send(dir="W", src=acc)
else:
# ── Phase 5: row broadcast outward from root_col ──
if col == root_col:
if root_col > 0:
tl.send(dir="W", src=acc)
if cube_w - 1 > root_col:
tl.send(dir="E", src=acc)
elif col < root_col:
acc = tl.recv(dir="E", shape=(n_elem,), dtype="f16")
if col > 0:
tl.send(dir="W", src=acc)
elif col > root_col:
acc = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
if col < cube_w - 1:
tl.send(dir="E", src=acc)
tl.store(pe_addr, acc)