Intercube allreduce: center root + bidirectional reduce
Move the algorithmic root cube from the corner (cube_w-1, cube_h-1) to the geometric center (cube_w//2, cube_h//2), and have each phase converge bidirectionally, so that the intra-SIP critical path drops from ~12 hops to ~8 hops on a 4×4 mesh (left half W→E + right half E→W in row reduce; top half N→S + bottom half S→N in col reduce; mirrored on broadcast).

Result on torus_2d, 6 SIPs, 96 KB / PE on TCM:
- before (corner root): 22.0 µs
- after (center root): 17.2 µs (−22%)

The same improvement holds on ring_1d (−7%) and mesh_2d_no_wrap (−12%), and also across SRAM and HBM (~−20% each).

Phase 1 test (test_intercube_root_center.py) asserts that the torus_2d 96 KB latency drops below 20.5 µs and that all 96 cubes still validate (correctness preserved).

Plot updates:
- overview.png: replace the constant 10.6 µs theoretical line with the user-supplied hand-derived curve (per-cube packet count = bytes_per_pe × 8 PEs ÷ 128 B; 1346 ns startup + 1.20 ns/pkt).
- All summary.csv numbers and per-topology PNGs regenerated.
- pe2pe_latency_plots and ipcq diagram emitter PNGs refreshed.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -111,6 +111,11 @@ def allreduce_intercube_multidevice(
|
||||
):
|
||||
"""Intercube all-reduce (pe0-only) with configurable SIP topology.
|
||||
|
||||
Root cube sits at the geometric center (cube_w//2, cube_h//2) and
|
||||
each phase converges bidirectionally so the intra-SIP critical path
|
||||
is ~half what a corner-root walk would be (e.g., 4×4 mesh: 4 hops
|
||||
reduce + 4 hops broadcast vs 6+6 with corner root).
|
||||
|
||||
Args:
|
||||
t_ptr: VA base of the row-wise-sharded tensor on this SIP.
|
||||
n_elem: f16 elements per cube tile.
|
||||
@@ -128,34 +133,59 @@ def allreduce_intercube_multidevice(
|
||||
col = cube_id % cube_w
|
||||
nbytes = n_elem * 2
|
||||
|
||||
root_col = cube_w // 2
|
||||
root_row = cube_h // 2
|
||||
root_cube = root_row * cube_w + root_col
|
||||
|
||||
pe_addr = t_ptr + cube_id * nbytes
|
||||
acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16")
|
||||
|
||||
# ── Phase 1: row reduce W → E ──
|
||||
if col == 0:
|
||||
# ── Phase 1: row reduce — converge at col == root_col ──
|
||||
# Left half (col < root_col) walks W→E; right half (col > root_col)
|
||||
# walks E→W; the root_col cube merges both sides.
|
||||
if col == 0 and root_col > 0:
|
||||
tl.send(dir="E", src=acc)
|
||||
elif col < cube_w - 1:
|
||||
elif 0 < col < root_col:
|
||||
recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
tl.send(dir="E", src=acc)
|
||||
else:
|
||||
recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
|
||||
elif col == root_col:
|
||||
if root_col > 0:
|
||||
recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
if cube_w - 1 > root_col:
|
||||
recv = tl.recv(dir="E", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
elif root_col < col < cube_w - 1:
|
||||
recv = tl.recv(dir="E", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
tl.send(dir="W", src=acc)
|
||||
elif col == cube_w - 1 and cube_w - 1 > root_col:
|
||||
tl.send(dir="W", src=acc)
|
||||
|
||||
# ── Phase 2: col reduce N → S on rightmost column ──
|
||||
if col == cube_w - 1:
|
||||
if row == 0:
|
||||
# ── Phase 2: col reduce on col == root_col — converge at row == root_row ──
|
||||
if col == root_col:
|
||||
if row == 0 and root_row > 0:
|
||||
tl.send(dir="S", src=acc)
|
||||
elif row < cube_h - 1:
|
||||
elif 0 < row < root_row:
|
||||
recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
tl.send(dir="S", src=acc)
|
||||
else:
|
||||
recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
|
||||
elif row == root_row:
|
||||
if root_row > 0:
|
||||
recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
if cube_h - 1 > root_row:
|
||||
recv = tl.recv(dir="S", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
elif root_row < row < cube_h - 1:
|
||||
recv = tl.recv(dir="S", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
tl.send(dir="N", src=acc)
|
||||
elif row == cube_h - 1 and cube_h - 1 > root_row:
|
||||
tl.send(dir="N", src=acc)
|
||||
|
||||
# ── Phase 3: inter-SIP exchange on root cube ──
|
||||
root_cube = (cube_h - 1) * cube_w + (cube_w - 1)
|
||||
if cube_id == root_cube and n_sips > 1:
|
||||
if sip_topo_kind == SIP_TOPO_RING:
|
||||
acc = _inter_sip_ring(acc, n_sips, n_elem, tl)
|
||||
@@ -164,24 +194,36 @@ def allreduce_intercube_multidevice(
|
||||
elif sip_topo_kind == SIP_TOPO_MESH:
|
||||
acc = _inter_sip_mesh_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
|
||||
|
||||
# ── Phase 4: col broadcast S → N on rightmost column ──
|
||||
if col == cube_w - 1:
|
||||
if row == cube_h - 1:
|
||||
tl.send(dir="N", src=acc)
|
||||
elif row > 0:
|
||||
acc = tl.recv(dir="S", shape=(n_elem,), dtype="f16")
|
||||
tl.send(dir="N", src=acc)
|
||||
else:
|
||||
# ── Phase 4: col broadcast on col == root_col, outward from root_row ──
|
||||
if col == root_col:
|
||||
if row == root_row:
|
||||
if root_row > 0:
|
||||
tl.send(dir="N", src=acc)
|
||||
if cube_h - 1 > root_row:
|
||||
tl.send(dir="S", src=acc)
|
||||
elif row < root_row:
|
||||
acc = tl.recv(dir="S", shape=(n_elem,), dtype="f16")
|
||||
if row > 0:
|
||||
tl.send(dir="N", src=acc)
|
||||
elif row > root_row:
|
||||
acc = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
|
||||
if row < cube_h - 1:
|
||||
tl.send(dir="S", src=acc)
|
||||
|
||||
# ── Phase 5: row broadcast E → W ──
|
||||
if col == cube_w - 1:
|
||||
tl.send(dir="W", src=acc)
|
||||
elif col > 0:
|
||||
acc = tl.recv(dir="E", shape=(n_elem,), dtype="f16")
|
||||
tl.send(dir="W", src=acc)
|
||||
else:
|
||||
# ── Phase 5: row broadcast outward from root_col ──
|
||||
if col == root_col:
|
||||
if root_col > 0:
|
||||
tl.send(dir="W", src=acc)
|
||||
if cube_w - 1 > root_col:
|
||||
tl.send(dir="E", src=acc)
|
||||
elif col < root_col:
|
||||
acc = tl.recv(dir="E", shape=(n_elem,), dtype="f16")
|
||||
if col > 0:
|
||||
tl.send(dir="W", src=acc)
|
||||
elif col > root_col:
|
||||
acc = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
|
||||
if col < cube_w - 1:
|
||||
tl.send(dir="E", src=acc)
|
||||
|
||||
tl.store(pe_addr, acc)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user