Fix all remaining test failures: single-cube allreduce + matplotlib dep
- intercube_allreduce: add single-cube fast path that skips intra-SIP mesh reduce and goes directly to inter-SIP exchange. Fixes IPCQ deadlock when TP launches kernel on one cube per SIP. - distributed.py: derive effective cube dims from tensor shard placement instead of hardcoding topology mesh size. - pyproject.toml: add matplotlib>=3.7 to dependencies. - pe_dma.py (prior commit): add MMU translation in pipeline DMA path. 577 passed, 0 failed (was 529 passed, 10 failed). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -24,9 +24,7 @@ TOPO_NAME_TO_KIND = {
|
||||
}
|
||||
|
||||
|
||||
def kernel_args(world_size: int, n_elem: int) -> tuple:
|
||||
cube_w = 4
|
||||
cube_h = 4
|
||||
def kernel_args(world_size: int, n_elem: int, *, cube_w: int = 4, cube_h: int = 4) -> tuple:
|
||||
return (n_elem, cube_w, cube_h, world_size)
|
||||
|
||||
|
||||
@@ -127,61 +125,79 @@ def allreduce_intercube_multidevice(
|
||||
row = cube_id // cube_w
|
||||
col = cube_id % cube_w
|
||||
nbytes = n_elem * 2
|
||||
single_cube = (cube_w == 1 and cube_h == 1)
|
||||
|
||||
pe_addr = t_ptr + cube_id * nbytes
|
||||
acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16")
|
||||
|
||||
# ── Phase 1: row reduce W → E ──
|
||||
if col == 0:
|
||||
tl.send(dir="E", src=acc)
|
||||
elif col < cube_w - 1:
|
||||
recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
tl.send(dir="E", src=acc)
|
||||
if single_cube:
|
||||
# ── Single-cube mode: skip intra-SIP reduce, go directly to
|
||||
# inter-SIP exchange (TP use case: one cube per rank). ──
|
||||
if n_sips > 1:
|
||||
if sip_topo_kind == SIP_TOPO_RING:
|
||||
acc = _inter_sip_ring(acc, n_sips, n_elem, tl)
|
||||
elif sip_topo_kind == SIP_TOPO_TORUS:
|
||||
acc = _inter_sip_torus_2d(
|
||||
acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
|
||||
elif sip_topo_kind == SIP_TOPO_MESH:
|
||||
acc = _inter_sip_mesh_2d(
|
||||
acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
|
||||
else:
|
||||
recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
# ── Multi-cube mode: full mesh reduce + inter-SIP + broadcast ──
|
||||
|
||||
# ── Phase 2: col reduce N → S on rightmost column ──
|
||||
if col == cube_w - 1:
|
||||
if row == 0:
|
||||
tl.send(dir="S", src=acc)
|
||||
elif row < cube_h - 1:
|
||||
recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
|
||||
# Phase 1: row reduce W → E
|
||||
if col == 0:
|
||||
tl.send(dir="E", src=acc)
|
||||
elif col < cube_w - 1:
|
||||
recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
tl.send(dir="S", src=acc)
|
||||
tl.send(dir="E", src=acc)
|
||||
else:
|
||||
recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
|
||||
recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
|
||||
# ── Phase 3: inter-SIP exchange on root cube ──
|
||||
root_cube = (cube_h - 1) * cube_w + (cube_w - 1)
|
||||
if cube_id == root_cube and n_sips > 1:
|
||||
if sip_topo_kind == SIP_TOPO_RING:
|
||||
acc = _inter_sip_ring(acc, n_sips, n_elem, tl)
|
||||
elif sip_topo_kind == SIP_TOPO_TORUS:
|
||||
acc = _inter_sip_torus_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
|
||||
elif sip_topo_kind == SIP_TOPO_MESH:
|
||||
acc = _inter_sip_mesh_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
|
||||
# Phase 2: col reduce N → S on rightmost column
|
||||
if col == cube_w - 1:
|
||||
if row == 0:
|
||||
tl.send(dir="S", src=acc)
|
||||
elif row < cube_h - 1:
|
||||
recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
tl.send(dir="S", src=acc)
|
||||
else:
|
||||
recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
|
||||
# ── Phase 4: col broadcast S → N on rightmost column ──
|
||||
if col == cube_w - 1:
|
||||
if row == cube_h - 1:
|
||||
tl.send(dir="N", src=acc)
|
||||
elif row > 0:
|
||||
acc = tl.recv(dir="S", shape=(n_elem,), dtype="f16")
|
||||
tl.send(dir="N", src=acc)
|
||||
# Phase 3: inter-SIP exchange on root cube
|
||||
root_cube = (cube_h - 1) * cube_w + (cube_w - 1)
|
||||
if cube_id == root_cube and n_sips > 1:
|
||||
if sip_topo_kind == SIP_TOPO_RING:
|
||||
acc = _inter_sip_ring(acc, n_sips, n_elem, tl)
|
||||
elif sip_topo_kind == SIP_TOPO_TORUS:
|
||||
acc = _inter_sip_torus_2d(
|
||||
acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
|
||||
elif sip_topo_kind == SIP_TOPO_MESH:
|
||||
acc = _inter_sip_mesh_2d(
|
||||
acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
|
||||
|
||||
# Phase 4: col broadcast S → N on rightmost column
|
||||
if col == cube_w - 1:
|
||||
if row == cube_h - 1:
|
||||
tl.send(dir="N", src=acc)
|
||||
elif row > 0:
|
||||
acc = tl.recv(dir="S", shape=(n_elem,), dtype="f16")
|
||||
tl.send(dir="N", src=acc)
|
||||
else:
|
||||
acc = tl.recv(dir="S", shape=(n_elem,), dtype="f16")
|
||||
|
||||
# Phase 5: row broadcast E → W
|
||||
if col == cube_w - 1:
|
||||
tl.send(dir="W", src=acc)
|
||||
elif col > 0:
|
||||
acc = tl.recv(dir="E", shape=(n_elem,), dtype="f16")
|
||||
tl.send(dir="W", src=acc)
|
||||
else:
|
||||
acc = tl.recv(dir="S", shape=(n_elem,), dtype="f16")
|
||||
|
||||
# ── Phase 5: row broadcast E → W ──
|
||||
if col == cube_w - 1:
|
||||
tl.send(dir="W", src=acc)
|
||||
elif col > 0:
|
||||
acc = tl.recv(dir="E", shape=(n_elem,), dtype="f16")
|
||||
tl.send(dir="W", src=acc)
|
||||
else:
|
||||
acc = tl.recv(dir="E", shape=(n_elem,), dtype="f16")
|
||||
acc = tl.recv(dir="E", shape=(n_elem,), dtype="f16")
|
||||
|
||||
tl.store(pe_addr, acc)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user