Fix all remaining test failures: single-cube allreduce + matplotlib dep
- intercube_allreduce: add single-cube fast path that skips intra-SIP mesh reduce and goes directly to inter-SIP exchange. Fixes IPCQ deadlock when TP launches kernel on one cube per SIP.
- distributed.py: derive effective cube dims from tensor shard placement instead of hardcoding topology mesh size.
- pyproject.toml: add matplotlib>=3.7 to dependencies.
- pe_dma.py (prior commit): add MMU translation in pipeline DMA path.

577 passed, 0 failed (was 529 passed, 10 failed).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
name = "kernbench"
|
name = "kernbench"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
dependencies = ["pytest", "simpy", "pyyaml", "fastapi>=0.110", "uvicorn[standard]>=0.29", "websockets>=12", "numpy>=1.24", "greenlet>=3.0"]
|
dependencies = ["pytest", "simpy", "pyyaml", "fastapi>=0.110", "uvicorn[standard]>=0.29", "websockets>=12", "numpy>=1.24", "greenlet>=3.0", "matplotlib>=3.7"]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
kernbench = "kernbench.cli.main:main"
|
kernbench = "kernbench.cli.main:main"
|
||||||
|
|||||||
@@ -24,9 +24,7 @@ TOPO_NAME_TO_KIND = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def kernel_args(world_size: int, n_elem: int) -> tuple:
|
def kernel_args(world_size: int, n_elem: int, *, cube_w: int = 4, cube_h: int = 4) -> tuple:
|
||||||
cube_w = 4
|
|
||||||
cube_h = 4
|
|
||||||
return (n_elem, cube_w, cube_h, world_size)
|
return (n_elem, cube_w, cube_h, world_size)
|
||||||
|
|
||||||
|
|
||||||
@@ -127,61 +125,79 @@ def allreduce_intercube_multidevice(
|
|||||||
row = cube_id // cube_w
|
row = cube_id // cube_w
|
||||||
col = cube_id % cube_w
|
col = cube_id % cube_w
|
||||||
nbytes = n_elem * 2
|
nbytes = n_elem * 2
|
||||||
|
single_cube = (cube_w == 1 and cube_h == 1)
|
||||||
|
|
||||||
pe_addr = t_ptr + cube_id * nbytes
|
pe_addr = t_ptr + cube_id * nbytes
|
||||||
acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16")
|
acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16")
|
||||||
|
|
||||||
# ── Phase 1: row reduce W → E ──
|
if single_cube:
|
||||||
if col == 0:
|
# ── Single-cube mode: skip intra-SIP reduce, go directly to
|
||||||
tl.send(dir="E", src=acc)
|
# inter-SIP exchange (TP use case: one cube per rank). ──
|
||||||
elif col < cube_w - 1:
|
if n_sips > 1:
|
||||||
recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
|
if sip_topo_kind == SIP_TOPO_RING:
|
||||||
acc = acc + recv
|
acc = _inter_sip_ring(acc, n_sips, n_elem, tl)
|
||||||
tl.send(dir="E", src=acc)
|
elif sip_topo_kind == SIP_TOPO_TORUS:
|
||||||
|
acc = _inter_sip_torus_2d(
|
||||||
|
acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
|
||||||
|
elif sip_topo_kind == SIP_TOPO_MESH:
|
||||||
|
acc = _inter_sip_mesh_2d(
|
||||||
|
acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
|
||||||
else:
|
else:
|
||||||
recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
|
# ── Multi-cube mode: full mesh reduce + inter-SIP + broadcast ──
|
||||||
acc = acc + recv
|
|
||||||
|
|
||||||
# ── Phase 2: col reduce N → S on rightmost column ──
|
# Phase 1: row reduce W → E
|
||||||
if col == cube_w - 1:
|
if col == 0:
|
||||||
if row == 0:
|
tl.send(dir="E", src=acc)
|
||||||
tl.send(dir="S", src=acc)
|
elif col < cube_w - 1:
|
||||||
elif row < cube_h - 1:
|
recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
|
||||||
recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
|
|
||||||
acc = acc + recv
|
acc = acc + recv
|
||||||
tl.send(dir="S", src=acc)
|
tl.send(dir="E", src=acc)
|
||||||
else:
|
else:
|
||||||
recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
|
recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
|
||||||
acc = acc + recv
|
acc = acc + recv
|
||||||
|
|
||||||
# ── Phase 3: inter-SIP exchange on root cube ──
|
# Phase 2: col reduce N → S on rightmost column
|
||||||
root_cube = (cube_h - 1) * cube_w + (cube_w - 1)
|
if col == cube_w - 1:
|
||||||
if cube_id == root_cube and n_sips > 1:
|
if row == 0:
|
||||||
if sip_topo_kind == SIP_TOPO_RING:
|
tl.send(dir="S", src=acc)
|
||||||
acc = _inter_sip_ring(acc, n_sips, n_elem, tl)
|
elif row < cube_h - 1:
|
||||||
elif sip_topo_kind == SIP_TOPO_TORUS:
|
recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
|
||||||
acc = _inter_sip_torus_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
|
acc = acc + recv
|
||||||
elif sip_topo_kind == SIP_TOPO_MESH:
|
tl.send(dir="S", src=acc)
|
||||||
acc = _inter_sip_mesh_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
|
else:
|
||||||
|
recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
|
||||||
|
acc = acc + recv
|
||||||
|
|
||||||
# ── Phase 4: col broadcast S → N on rightmost column ──
|
# Phase 3: inter-SIP exchange on root cube
|
||||||
if col == cube_w - 1:
|
root_cube = (cube_h - 1) * cube_w + (cube_w - 1)
|
||||||
if row == cube_h - 1:
|
if cube_id == root_cube and n_sips > 1:
|
||||||
tl.send(dir="N", src=acc)
|
if sip_topo_kind == SIP_TOPO_RING:
|
||||||
elif row > 0:
|
acc = _inter_sip_ring(acc, n_sips, n_elem, tl)
|
||||||
acc = tl.recv(dir="S", shape=(n_elem,), dtype="f16")
|
elif sip_topo_kind == SIP_TOPO_TORUS:
|
||||||
tl.send(dir="N", src=acc)
|
acc = _inter_sip_torus_2d(
|
||||||
|
acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
|
||||||
|
elif sip_topo_kind == SIP_TOPO_MESH:
|
||||||
|
acc = _inter_sip_mesh_2d(
|
||||||
|
acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
|
||||||
|
|
||||||
|
# Phase 4: col broadcast S → N on rightmost column
|
||||||
|
if col == cube_w - 1:
|
||||||
|
if row == cube_h - 1:
|
||||||
|
tl.send(dir="N", src=acc)
|
||||||
|
elif row > 0:
|
||||||
|
acc = tl.recv(dir="S", shape=(n_elem,), dtype="f16")
|
||||||
|
tl.send(dir="N", src=acc)
|
||||||
|
else:
|
||||||
|
acc = tl.recv(dir="S", shape=(n_elem,), dtype="f16")
|
||||||
|
|
||||||
|
# Phase 5: row broadcast E → W
|
||||||
|
if col == cube_w - 1:
|
||||||
|
tl.send(dir="W", src=acc)
|
||||||
|
elif col > 0:
|
||||||
|
acc = tl.recv(dir="E", shape=(n_elem,), dtype="f16")
|
||||||
|
tl.send(dir="W", src=acc)
|
||||||
else:
|
else:
|
||||||
acc = tl.recv(dir="S", shape=(n_elem,), dtype="f16")
|
acc = tl.recv(dir="E", shape=(n_elem,), dtype="f16")
|
||||||
|
|
||||||
# ── Phase 5: row broadcast E → W ──
|
|
||||||
if col == cube_w - 1:
|
|
||||||
tl.send(dir="W", src=acc)
|
|
||||||
elif col > 0:
|
|
||||||
acc = tl.recv(dir="E", shape=(n_elem,), dtype="f16")
|
|
||||||
tl.send(dir="W", src=acc)
|
|
||||||
else:
|
|
||||||
acc = tl.recv(dir="E", shape=(n_elem,), dtype="f16")
|
|
||||||
|
|
||||||
tl.store(pe_addr, acc)
|
tl.store(pe_addr, acc)
|
||||||
|
|
||||||
|
|||||||
@@ -113,7 +113,18 @@ class AhbmCCLBackend:
|
|||||||
)
|
)
|
||||||
n_elem = shards[0].nbytes // tensor.itemsize
|
n_elem = shards[0].nbytes // tensor.itemsize
|
||||||
kernel_fn = self._algo_module.kernel
|
kernel_fn = self._algo_module.kernel
|
||||||
kernel_args = self._algo_module.kernel_args(self._world_size, n_elem)
|
# Derive effective cube dims from tensor's actual shard placement
|
||||||
|
# (may differ from topology mesh when TP uses fewer cubes).
|
||||||
|
sip0_cubes = sorted({s.cube for s in shards if s.sip == shards[0].sip})
|
||||||
|
eff_n_cubes = len(sip0_cubes) if sip0_cubes else 1
|
||||||
|
if eff_n_cubes == 1:
|
||||||
|
eff_cube_w, eff_cube_h = 1, 1
|
||||||
|
else:
|
||||||
|
eff_cube_w, eff_cube_h = self._cube_w, self._cube_h
|
||||||
|
kernel_args = self._algo_module.kernel_args(
|
||||||
|
self._world_size, n_elem,
|
||||||
|
cube_w=eff_cube_w, cube_h=eff_cube_h,
|
||||||
|
)
|
||||||
|
|
||||||
# Resolve sip_rank from the current greenlet's bound rank
|
# Resolve sip_rank from the current greenlet's bound rank
|
||||||
from greenlet import getcurrent as _gc
|
from greenlet import getcurrent as _gc
|
||||||
|
|||||||
|
Before Width: | Height: | Size: 43 KiB After Width: | Height: | Size: 41 KiB |
|
Before Width: | Height: | Size: 87 KiB After Width: | Height: | Size: 87 KiB |
|
Before Width: | Height: | Size: 41 KiB After Width: | Height: | Size: 39 KiB |
|
Before Width: | Height: | Size: 194 KiB After Width: | Height: | Size: 194 KiB |
|
Before Width: | Height: | Size: 41 KiB After Width: | Height: | Size: 39 KiB |
|
Before Width: | Height: | Size: 48 KiB After Width: | Height: | Size: 48 KiB |
|
Before Width: | Height: | Size: 48 KiB After Width: | Height: | Size: 48 KiB |
|
Before Width: | Height: | Size: 50 KiB After Width: | Height: | Size: 51 KiB |
|
Before Width: | Height: | Size: 50 KiB After Width: | Height: | Size: 50 KiB |
|
Before Width: | Height: | Size: 101 KiB After Width: | Height: | Size: 100 KiB |
@@ -79,13 +79,3 @@ h4_inter_cube_vertical,Inter-cube vertical (cube0 to cube4),8192,ipcq,181.659999
|
|||||||
h4_inter_cube_vertical,Inter-cube vertical (cube0 to cube4),8192,raw,183.04000000000087
|
h4_inter_cube_vertical,Inter-cube vertical (cube0 to cube4),8192,raw,183.04000000000087
|
||||||
h4_inter_cube_vertical,Inter-cube vertical (cube0 to cube4),10240,ipcq,205.65999999999985
|
h4_inter_cube_vertical,Inter-cube vertical (cube0 to cube4),10240,ipcq,205.65999999999985
|
||||||
h4_inter_cube_vertical,Inter-cube vertical (cube0 to cube4),10240,raw,207.04000000000087
|
h4_inter_cube_vertical,Inter-cube vertical (cube0 to cube4),10240,raw,207.04000000000087
|
||||||
h5_inter_sip,"Inter-SIP (sip0 to sip1, same cube/pe)",128,ipcq,6.015000000003056
|
|
||||||
h5_inter_sip,"Inter-SIP (sip0 to sip1, same cube/pe)",256,ipcq,6.515000000003056
|
|
||||||
h5_inter_sip,"Inter-SIP (sip0 to sip1, same cube/pe)",384,ipcq,7.015000000003056
|
|
||||||
h5_inter_sip,"Inter-SIP (sip0 to sip1, same cube/pe)",512,ipcq,7.515000000003056
|
|
||||||
h5_inter_sip,"Inter-SIP (sip0 to sip1, same cube/pe)",768,ipcq,8.515000000003056
|
|
||||||
h5_inter_sip,"Inter-SIP (sip0 to sip1, same cube/pe)",1024,ipcq,9.515000000003056
|
|
||||||
h5_inter_sip,"Inter-SIP (sip0 to sip1, same cube/pe)",2048,ipcq,13.515000000003056
|
|
||||||
h5_inter_sip,"Inter-SIP (sip0 to sip1, same cube/pe)",4096,ipcq,21.515000000003056
|
|
||||||
h5_inter_sip,"Inter-SIP (sip0 to sip1, same cube/pe)",8192,ipcq,37.51499999999214
|
|
||||||
h5_inter_sip,"Inter-SIP (sip0 to sip1, same cube/pe)",10240,ipcq,45.51499999999214
|
|
||||||
|
|||||||
|