Fix all remaining test failures: single-cube allreduce + matplotlib dep

- intercube_allreduce: add single-cube fast path that skips intra-SIP
  mesh reduce and goes directly to inter-SIP exchange. Fixes IPCQ
  deadlock when TP launches kernel on one cube per SIP.
- distributed.py: derive effective cube dims from tensor shard placement
  instead of hardcoding topology mesh size.
- pyproject.toml: add matplotlib>=3.7 to dependencies.
- pe_dma.py (prior commit): add MMU translation in pipeline DMA path.

577 passed, 0 failed (was 529 passed, 10 failed).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-04-27 21:25:31 -07:00
parent d55dc6cb4f
commit fca24feac5
15 changed files with 112 additions and 95 deletions
@@ -24,9 +24,7 @@ TOPO_NAME_TO_KIND = {
}
def kernel_args(world_size: int, n_elem: int) -> tuple:
cube_w = 4
cube_h = 4
def kernel_args(world_size: int, n_elem: int, *, cube_w: int = 4, cube_h: int = 4) -> tuple:
return (n_elem, cube_w, cube_h, world_size)
@@ -127,61 +125,79 @@ def allreduce_intercube_multidevice(
row = cube_id // cube_w
col = cube_id % cube_w
nbytes = n_elem * 2
single_cube = (cube_w == 1 and cube_h == 1)
pe_addr = t_ptr + cube_id * nbytes
acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16")
# ── Phase 1: row reduce W → E ──
if col == 0:
tl.send(dir="E", src=acc)
elif col < cube_w - 1:
recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
acc = acc + recv
tl.send(dir="E", src=acc)
if single_cube:
# ── Single-cube mode: skip intra-SIP reduce, go directly to
# inter-SIP exchange (TP use case: one cube per rank). ──
if n_sips > 1:
if sip_topo_kind == SIP_TOPO_RING:
acc = _inter_sip_ring(acc, n_sips, n_elem, tl)
elif sip_topo_kind == SIP_TOPO_TORUS:
acc = _inter_sip_torus_2d(
acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
elif sip_topo_kind == SIP_TOPO_MESH:
acc = _inter_sip_mesh_2d(
acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
else:
recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
acc = acc + recv
# ── Multi-cube mode: full mesh reduce + inter-SIP + broadcast ──
# ── Phase 2: col reduce NS on rightmost column ──
if col == cube_w - 1:
if row == 0:
tl.send(dir="S", src=acc)
elif row < cube_h - 1:
recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
# Phase 1: row reduce WE
if col == 0:
tl.send(dir="E", src=acc)
elif col < cube_w - 1:
recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
acc = acc + recv
tl.send(dir="S", src=acc)
tl.send(dir="E", src=acc)
else:
recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
acc = acc + recv
# ── Phase 3: inter-SIP exchange on root cube ──
root_cube = (cube_h - 1) * cube_w + (cube_w - 1)
if cube_id == root_cube and n_sips > 1:
if sip_topo_kind == SIP_TOPO_RING:
acc = _inter_sip_ring(acc, n_sips, n_elem, tl)
elif sip_topo_kind == SIP_TOPO_TORUS:
acc = _inter_sip_torus_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
elif sip_topo_kind == SIP_TOPO_MESH:
acc = _inter_sip_mesh_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
# Phase 2: col reduce N → S on rightmost column
if col == cube_w - 1:
if row == 0:
tl.send(dir="S", src=acc)
elif row < cube_h - 1:
recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
acc = acc + recv
tl.send(dir="S", src=acc)
else:
recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
acc = acc + recv
# ── Phase 4: col broadcast S → N on rightmost column ──
if col == cube_w - 1:
if row == cube_h - 1:
tl.send(dir="N", src=acc)
elif row > 0:
acc = tl.recv(dir="S", shape=(n_elem,), dtype="f16")
tl.send(dir="N", src=acc)
# Phase 3: inter-SIP exchange on root cube
root_cube = (cube_h - 1) * cube_w + (cube_w - 1)
if cube_id == root_cube and n_sips > 1:
if sip_topo_kind == SIP_TOPO_RING:
acc = _inter_sip_ring(acc, n_sips, n_elem, tl)
elif sip_topo_kind == SIP_TOPO_TORUS:
acc = _inter_sip_torus_2d(
acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
elif sip_topo_kind == SIP_TOPO_MESH:
acc = _inter_sip_mesh_2d(
acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
# Phase 4: col broadcast S → N on rightmost column
if col == cube_w - 1:
if row == cube_h - 1:
tl.send(dir="N", src=acc)
elif row > 0:
acc = tl.recv(dir="S", shape=(n_elem,), dtype="f16")
tl.send(dir="N", src=acc)
else:
acc = tl.recv(dir="S", shape=(n_elem,), dtype="f16")
# Phase 5: row broadcast E → W
if col == cube_w - 1:
tl.send(dir="W", src=acc)
elif col > 0:
acc = tl.recv(dir="E", shape=(n_elem,), dtype="f16")
tl.send(dir="W", src=acc)
else:
acc = tl.recv(dir="S", shape=(n_elem,), dtype="f16")
# ── Phase 5: row broadcast E → W ──
if col == cube_w - 1:
tl.send(dir="W", src=acc)
elif col > 0:
acc = tl.recv(dir="E", shape=(n_elem,), dtype="f16")
tl.send(dir="W", src=acc)
else:
acc = tl.recv(dir="E", shape=(n_elem,), dtype="f16")
acc = tl.recv(dir="E", shape=(n_elem,), dtype="f16")
tl.store(pe_addr, acc)