Merge origin/master: combine single-cube fast path + center-root reduce
Conflict resolution:
- intercube_allreduce.py: kept origin's `if single_cube:` early-exit
(TP launches the kernel on one cube per rank, so skip the intra-SIP mesh
and go directly to the inter-SIP exchange) AND replaced the multi-cube
body with the local center-root + bidirectional reduce/broadcast (8-hop
critical path on a 4×4 cube grid vs. 12 hops with a corner root).
- tests/{allreduce,pe2pe}_latency_plots/: kept the local move to
docs/diagrams/; dropped origin's stale content edits to the old
paths (regenerable derived artifacts).
- docs/diagrams/pe2pe_latency_plots/summary.csv: kept local
(post-Phase-2 + center-root values).
Origin contributions retained as-is:
- pyproject.toml: matplotlib >= 3.7 dep.
- runtime_api/distributed.py: derive effective cube_w/h from tensor
shard placement so single-cube TP paths get cube_w=cube_h=1.
- kernel_args() now accepts optional cube_w/cube_h kwargs.
Verified post-merge:
- test_intercube_root_center.py: 2/2 (center-root multi-cube path).
- test_tp_layers.py + test_tp_mlp.py: 10/10 (single-cube TP path).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -24,9 +24,7 @@ TOPO_NAME_TO_KIND = {
|
||||
}
|
||||
|
||||
|
||||
def kernel_args(world_size: int, n_elem: int) -> tuple:
|
||||
cube_w = 4
|
||||
cube_h = 4
|
||||
def kernel_args(world_size: int, n_elem: int, *, cube_w: int = 4, cube_h: int = 4) -> tuple:
    """Pack the positional argument tuple for the all-reduce kernel.

    Args:
        world_size: Passed through as the last element of the tuple.
        n_elem: Passed through as the first element of the tuple.
        cube_w: Keyword-only cube grid width; defaults to 4.
        cube_h: Keyword-only cube grid height; defaults to 4.

    Returns:
        The tuple ``(n_elem, cube_w, cube_h, world_size)``, in that order.
    """
    # Order matters: the kernel consumes these positionally.
    return (n_elem, cube_w, cube_h, world_size)
|
||||
|
||||
|
||||
@@ -132,6 +130,7 @@ def allreduce_intercube_multidevice(
|
||||
row = cube_id // cube_w
|
||||
col = cube_id % cube_w
|
||||
nbytes = n_elem * 2
|
||||
single_cube = (cube_w == 1 and cube_h == 1)
|
||||
|
||||
root_col = cube_w // 2
|
||||
root_row = cube_h // 2
|
||||
@@ -140,90 +139,108 @@ def allreduce_intercube_multidevice(
|
||||
pe_addr = t_ptr + cube_id * nbytes
|
||||
acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16")
|
||||
|
||||
# ── Phase 1: row reduce — converge at col == root_col ──
|
||||
# Left half (col < root_col) walks W→E; right half (col > root_col)
|
||||
# walks E→W; the root_col cube merges both sides.
|
||||
if col == 0 and root_col > 0:
|
||||
tl.send(dir="E", src=acc)
|
||||
elif 0 < col < root_col:
|
||||
recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
tl.send(dir="E", src=acc)
|
||||
elif col == root_col:
|
||||
if root_col > 0:
|
||||
if single_cube:
|
||||
# ── Single-cube mode: skip intra-SIP reduce, go directly to
|
||||
# inter-SIP exchange (TP use case: one cube per rank). ──
|
||||
if n_sips > 1:
|
||||
if sip_topo_kind == SIP_TOPO_RING:
|
||||
acc = _inter_sip_ring(acc, n_sips, n_elem, tl)
|
||||
elif sip_topo_kind == SIP_TOPO_TORUS:
|
||||
acc = _inter_sip_torus_2d(
|
||||
acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
|
||||
elif sip_topo_kind == SIP_TOPO_MESH:
|
||||
acc = _inter_sip_mesh_2d(
|
||||
acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
|
||||
else:
|
||||
# ── Multi-cube mode: center-root bidirectional reduce
|
||||
# + inter-SIP exchange + bidirectional broadcast ──
|
||||
|
||||
# Phase 1: row reduce — converge at col == root_col.
|
||||
# Left half (col < root_col) walks W→E; right half (col > root_col)
|
||||
# walks E→W; the root_col cube merges both sides.
|
||||
if col == 0 and root_col > 0:
|
||||
tl.send(dir="E", src=acc)
|
||||
elif 0 < col < root_col:
|
||||
recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
if cube_w - 1 > root_col:
|
||||
tl.send(dir="E", src=acc)
|
||||
elif col == root_col:
|
||||
if root_col > 0:
|
||||
recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
if cube_w - 1 > root_col:
|
||||
recv = tl.recv(dir="E", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
elif root_col < col < cube_w - 1:
|
||||
recv = tl.recv(dir="E", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
elif root_col < col < cube_w - 1:
|
||||
recv = tl.recv(dir="E", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
tl.send(dir="W", src=acc)
|
||||
elif col == cube_w - 1 and cube_w - 1 > root_col:
|
||||
tl.send(dir="W", src=acc)
|
||||
tl.send(dir="W", src=acc)
|
||||
elif col == cube_w - 1 and cube_w - 1 > root_col:
|
||||
tl.send(dir="W", src=acc)
|
||||
|
||||
# ── Phase 2: col reduce on col == root_col — converge at row == root_row ──
|
||||
if col == root_col:
|
||||
if row == 0 and root_row > 0:
|
||||
tl.send(dir="S", src=acc)
|
||||
elif 0 < row < root_row:
|
||||
recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
tl.send(dir="S", src=acc)
|
||||
elif row == root_row:
|
||||
if root_row > 0:
|
||||
# Phase 2: col reduce on col == root_col — converge at row == root_row.
|
||||
if col == root_col:
|
||||
if row == 0 and root_row > 0:
|
||||
tl.send(dir="S", src=acc)
|
||||
elif 0 < row < root_row:
|
||||
recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
if cube_h - 1 > root_row:
|
||||
tl.send(dir="S", src=acc)
|
||||
elif row == root_row:
|
||||
if root_row > 0:
|
||||
recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
if cube_h - 1 > root_row:
|
||||
recv = tl.recv(dir="S", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
elif root_row < row < cube_h - 1:
|
||||
recv = tl.recv(dir="S", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
elif root_row < row < cube_h - 1:
|
||||
recv = tl.recv(dir="S", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
tl.send(dir="N", src=acc)
|
||||
elif row == cube_h - 1 and cube_h - 1 > root_row:
|
||||
tl.send(dir="N", src=acc)
|
||||
|
||||
# ── Phase 3: inter-SIP exchange on root cube ──
|
||||
if cube_id == root_cube and n_sips > 1:
|
||||
if sip_topo_kind == SIP_TOPO_RING:
|
||||
acc = _inter_sip_ring(acc, n_sips, n_elem, tl)
|
||||
elif sip_topo_kind == SIP_TOPO_TORUS:
|
||||
acc = _inter_sip_torus_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
|
||||
elif sip_topo_kind == SIP_TOPO_MESH:
|
||||
acc = _inter_sip_mesh_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
|
||||
|
||||
# ── Phase 4: col broadcast on col == root_col, outward from root_row ──
|
||||
if col == root_col:
|
||||
if row == root_row:
|
||||
if root_row > 0:
|
||||
tl.send(dir="N", src=acc)
|
||||
if cube_h - 1 > root_row:
|
||||
tl.send(dir="S", src=acc)
|
||||
elif row < root_row:
|
||||
acc = tl.recv(dir="S", shape=(n_elem,), dtype="f16")
|
||||
if row > 0:
|
||||
elif row == cube_h - 1 and cube_h - 1 > root_row:
|
||||
tl.send(dir="N", src=acc)
|
||||
elif row > root_row:
|
||||
acc = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
|
||||
if row < cube_h - 1:
|
||||
tl.send(dir="S", src=acc)
|
||||
|
||||
# ── Phase 5: row broadcast outward from root_col ──
|
||||
if col == root_col:
|
||||
if root_col > 0:
|
||||
tl.send(dir="W", src=acc)
|
||||
if cube_w - 1 > root_col:
|
||||
tl.send(dir="E", src=acc)
|
||||
elif col < root_col:
|
||||
acc = tl.recv(dir="E", shape=(n_elem,), dtype="f16")
|
||||
if col > 0:
|
||||
tl.send(dir="W", src=acc)
|
||||
elif col > root_col:
|
||||
acc = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
|
||||
if col < cube_w - 1:
|
||||
tl.send(dir="E", src=acc)
|
||||
# Phase 3: inter-SIP exchange on root cube.
|
||||
if cube_id == root_cube and n_sips > 1:
|
||||
if sip_topo_kind == SIP_TOPO_RING:
|
||||
acc = _inter_sip_ring(acc, n_sips, n_elem, tl)
|
||||
elif sip_topo_kind == SIP_TOPO_TORUS:
|
||||
acc = _inter_sip_torus_2d(
|
||||
acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
|
||||
elif sip_topo_kind == SIP_TOPO_MESH:
|
||||
acc = _inter_sip_mesh_2d(
|
||||
acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
|
||||
|
||||
# Phase 4: col broadcast on col == root_col, outward from root_row.
|
||||
if col == root_col:
|
||||
if row == root_row:
|
||||
if root_row > 0:
|
||||
tl.send(dir="N", src=acc)
|
||||
if cube_h - 1 > root_row:
|
||||
tl.send(dir="S", src=acc)
|
||||
elif row < root_row:
|
||||
acc = tl.recv(dir="S", shape=(n_elem,), dtype="f16")
|
||||
if row > 0:
|
||||
tl.send(dir="N", src=acc)
|
||||
elif row > root_row:
|
||||
acc = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
|
||||
if row < cube_h - 1:
|
||||
tl.send(dir="S", src=acc)
|
||||
|
||||
# Phase 5: row broadcast outward from root_col.
|
||||
if col == root_col:
|
||||
if root_col > 0:
|
||||
tl.send(dir="W", src=acc)
|
||||
if cube_w - 1 > root_col:
|
||||
tl.send(dir="E", src=acc)
|
||||
elif col < root_col:
|
||||
acc = tl.recv(dir="E", shape=(n_elem,), dtype="f16")
|
||||
if col > 0:
|
||||
tl.send(dir="W", src=acc)
|
||||
elif col > root_col:
|
||||
acc = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
|
||||
if col < cube_w - 1:
|
||||
tl.send(dir="E", src=acc)
|
||||
|
||||
tl.store(pe_addr, acc)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user