Merge origin/master: combine single-cube fast path + center-root reduce
Conflict resolution:
- intercube_allreduce.py: kept origin's `if single_cube:` early-exit
(TP launches kernel on one cube/rank → skip intra-SIP mesh and go
direct to inter-SIP exchange) AND replaced the multi-cube body with
the local center-root + bidirectional reduce/broadcast (8-hop
critical path on 4×4 vs 12 with corner root).
- tests/{allreduce,pe2pe}_latency_plots/: kept the local move to
docs/diagrams/; dropped origin's stale content edits to the old
paths (regenerable derived artifacts).
- docs/diagrams/pe2pe_latency_plots/summary.csv: kept local
(post-Phase-2 + center-root values).
Origin contributions retained as-is:
- pyproject.toml: matplotlib >= 3.7 dep.
- runtime_api/distributed.py: derive effective cube_w/h from tensor
shard placement so single-cube TP paths get cube_w=cube_h=1.
- kernel_args() now accepts optional cube_w/cube_h kwargs.
Verified post-merge:
- test_intercube_root_center.py: 2/2 (center-root multi-cube path).
- test_tp_layers.py + test_tp_mlp.py: 10/10 (single-cube TP path).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Binary file not shown.
|
Before Width: | Height: | Size: 194 KiB After Width: | Height: | Size: 194 KiB |
+1
-1
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
name = "kernbench"
|
name = "kernbench"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
dependencies = ["pytest", "simpy", "pyyaml", "fastapi>=0.110", "uvicorn[standard]>=0.29", "websockets>=12", "numpy>=1.24", "greenlet>=3.0"]
|
dependencies = ["pytest", "simpy", "pyyaml", "fastapi>=0.110", "uvicorn[standard]>=0.29", "websockets>=12", "numpy>=1.24", "greenlet>=3.0", "matplotlib>=3.7"]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
kernbench = "kernbench.cli.main:main"
|
kernbench = "kernbench.cli.main:main"
|
||||||
|
|||||||
@@ -24,9 +24,7 @@ TOPO_NAME_TO_KIND = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def kernel_args(world_size: int, n_elem: int) -> tuple:
|
def kernel_args(world_size: int, n_elem: int, *, cube_w: int = 4, cube_h: int = 4) -> tuple:
|
||||||
cube_w = 4
|
|
||||||
cube_h = 4
|
|
||||||
return (n_elem, cube_w, cube_h, world_size)
|
return (n_elem, cube_w, cube_h, world_size)
|
||||||
|
|
||||||
|
|
||||||
@@ -132,6 +130,7 @@ def allreduce_intercube_multidevice(
|
|||||||
row = cube_id // cube_w
|
row = cube_id // cube_w
|
||||||
col = cube_id % cube_w
|
col = cube_id % cube_w
|
||||||
nbytes = n_elem * 2
|
nbytes = n_elem * 2
|
||||||
|
single_cube = (cube_w == 1 and cube_h == 1)
|
||||||
|
|
||||||
root_col = cube_w // 2
|
root_col = cube_w // 2
|
||||||
root_row = cube_h // 2
|
root_row = cube_h // 2
|
||||||
@@ -140,90 +139,108 @@ def allreduce_intercube_multidevice(
|
|||||||
pe_addr = t_ptr + cube_id * nbytes
|
pe_addr = t_ptr + cube_id * nbytes
|
||||||
acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16")
|
acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16")
|
||||||
|
|
||||||
# ── Phase 1: row reduce — converge at col == root_col ──
|
if single_cube:
|
||||||
# Left half (col < root_col) walks W→E; right half (col > root_col)
|
# ── Single-cube mode: skip intra-SIP reduce, go directly to
|
||||||
# walks E→W; the root_col cube merges both sides.
|
# inter-SIP exchange (TP use case: one cube per rank). ──
|
||||||
if col == 0 and root_col > 0:
|
if n_sips > 1:
|
||||||
tl.send(dir="E", src=acc)
|
if sip_topo_kind == SIP_TOPO_RING:
|
||||||
elif 0 < col < root_col:
|
acc = _inter_sip_ring(acc, n_sips, n_elem, tl)
|
||||||
recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
|
elif sip_topo_kind == SIP_TOPO_TORUS:
|
||||||
acc = acc + recv
|
acc = _inter_sip_torus_2d(
|
||||||
tl.send(dir="E", src=acc)
|
acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
|
||||||
elif col == root_col:
|
elif sip_topo_kind == SIP_TOPO_MESH:
|
||||||
if root_col > 0:
|
acc = _inter_sip_mesh_2d(
|
||||||
|
acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
|
||||||
|
else:
|
||||||
|
# ── Multi-cube mode: center-root bidirectional reduce
|
||||||
|
# + inter-SIP exchange + bidirectional broadcast ──
|
||||||
|
|
||||||
|
# Phase 1: row reduce — converge at col == root_col.
|
||||||
|
# Left half (col < root_col) walks W→E; right half (col > root_col)
|
||||||
|
# walks E→W; the root_col cube merges both sides.
|
||||||
|
if col == 0 and root_col > 0:
|
||||||
|
tl.send(dir="E", src=acc)
|
||||||
|
elif 0 < col < root_col:
|
||||||
recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
|
recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
|
||||||
acc = acc + recv
|
acc = acc + recv
|
||||||
if cube_w - 1 > root_col:
|
tl.send(dir="E", src=acc)
|
||||||
|
elif col == root_col:
|
||||||
|
if root_col > 0:
|
||||||
|
recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
|
||||||
|
acc = acc + recv
|
||||||
|
if cube_w - 1 > root_col:
|
||||||
|
recv = tl.recv(dir="E", shape=(n_elem,), dtype="f16")
|
||||||
|
acc = acc + recv
|
||||||
|
elif root_col < col < cube_w - 1:
|
||||||
recv = tl.recv(dir="E", shape=(n_elem,), dtype="f16")
|
recv = tl.recv(dir="E", shape=(n_elem,), dtype="f16")
|
||||||
acc = acc + recv
|
acc = acc + recv
|
||||||
elif root_col < col < cube_w - 1:
|
tl.send(dir="W", src=acc)
|
||||||
recv = tl.recv(dir="E", shape=(n_elem,), dtype="f16")
|
elif col == cube_w - 1 and cube_w - 1 > root_col:
|
||||||
acc = acc + recv
|
tl.send(dir="W", src=acc)
|
||||||
tl.send(dir="W", src=acc)
|
|
||||||
elif col == cube_w - 1 and cube_w - 1 > root_col:
|
|
||||||
tl.send(dir="W", src=acc)
|
|
||||||
|
|
||||||
# ── Phase 2: col reduce on col == root_col — converge at row == root_row ──
|
# Phase 2: col reduce on col == root_col — converge at row == root_row.
|
||||||
if col == root_col:
|
if col == root_col:
|
||||||
if row == 0 and root_row > 0:
|
if row == 0 and root_row > 0:
|
||||||
tl.send(dir="S", src=acc)
|
tl.send(dir="S", src=acc)
|
||||||
elif 0 < row < root_row:
|
elif 0 < row < root_row:
|
||||||
recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
|
|
||||||
acc = acc + recv
|
|
||||||
tl.send(dir="S", src=acc)
|
|
||||||
elif row == root_row:
|
|
||||||
if root_row > 0:
|
|
||||||
recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
|
recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
|
||||||
acc = acc + recv
|
acc = acc + recv
|
||||||
if cube_h - 1 > root_row:
|
tl.send(dir="S", src=acc)
|
||||||
|
elif row == root_row:
|
||||||
|
if root_row > 0:
|
||||||
|
recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
|
||||||
|
acc = acc + recv
|
||||||
|
if cube_h - 1 > root_row:
|
||||||
|
recv = tl.recv(dir="S", shape=(n_elem,), dtype="f16")
|
||||||
|
acc = acc + recv
|
||||||
|
elif root_row < row < cube_h - 1:
|
||||||
recv = tl.recv(dir="S", shape=(n_elem,), dtype="f16")
|
recv = tl.recv(dir="S", shape=(n_elem,), dtype="f16")
|
||||||
acc = acc + recv
|
acc = acc + recv
|
||||||
elif root_row < row < cube_h - 1:
|
|
||||||
recv = tl.recv(dir="S", shape=(n_elem,), dtype="f16")
|
|
||||||
acc = acc + recv
|
|
||||||
tl.send(dir="N", src=acc)
|
|
||||||
elif row == cube_h - 1 and cube_h - 1 > root_row:
|
|
||||||
tl.send(dir="N", src=acc)
|
|
||||||
|
|
||||||
# ── Phase 3: inter-SIP exchange on root cube ──
|
|
||||||
if cube_id == root_cube and n_sips > 1:
|
|
||||||
if sip_topo_kind == SIP_TOPO_RING:
|
|
||||||
acc = _inter_sip_ring(acc, n_sips, n_elem, tl)
|
|
||||||
elif sip_topo_kind == SIP_TOPO_TORUS:
|
|
||||||
acc = _inter_sip_torus_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
|
|
||||||
elif sip_topo_kind == SIP_TOPO_MESH:
|
|
||||||
acc = _inter_sip_mesh_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
|
|
||||||
|
|
||||||
# ── Phase 4: col broadcast on col == root_col, outward from root_row ──
|
|
||||||
if col == root_col:
|
|
||||||
if row == root_row:
|
|
||||||
if root_row > 0:
|
|
||||||
tl.send(dir="N", src=acc)
|
tl.send(dir="N", src=acc)
|
||||||
if cube_h - 1 > root_row:
|
elif row == cube_h - 1 and cube_h - 1 > root_row:
|
||||||
tl.send(dir="S", src=acc)
|
|
||||||
elif row < root_row:
|
|
||||||
acc = tl.recv(dir="S", shape=(n_elem,), dtype="f16")
|
|
||||||
if row > 0:
|
|
||||||
tl.send(dir="N", src=acc)
|
tl.send(dir="N", src=acc)
|
||||||
elif row > root_row:
|
|
||||||
acc = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
|
|
||||||
if row < cube_h - 1:
|
|
||||||
tl.send(dir="S", src=acc)
|
|
||||||
|
|
||||||
# ── Phase 5: row broadcast outward from root_col ──
|
# Phase 3: inter-SIP exchange on root cube.
|
||||||
if col == root_col:
|
if cube_id == root_cube and n_sips > 1:
|
||||||
if root_col > 0:
|
if sip_topo_kind == SIP_TOPO_RING:
|
||||||
tl.send(dir="W", src=acc)
|
acc = _inter_sip_ring(acc, n_sips, n_elem, tl)
|
||||||
if cube_w - 1 > root_col:
|
elif sip_topo_kind == SIP_TOPO_TORUS:
|
||||||
tl.send(dir="E", src=acc)
|
acc = _inter_sip_torus_2d(
|
||||||
elif col < root_col:
|
acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
|
||||||
acc = tl.recv(dir="E", shape=(n_elem,), dtype="f16")
|
elif sip_topo_kind == SIP_TOPO_MESH:
|
||||||
if col > 0:
|
acc = _inter_sip_mesh_2d(
|
||||||
tl.send(dir="W", src=acc)
|
acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
|
||||||
elif col > root_col:
|
|
||||||
acc = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
|
# Phase 4: col broadcast on col == root_col, outward from root_row.
|
||||||
if col < cube_w - 1:
|
if col == root_col:
|
||||||
tl.send(dir="E", src=acc)
|
if row == root_row:
|
||||||
|
if root_row > 0:
|
||||||
|
tl.send(dir="N", src=acc)
|
||||||
|
if cube_h - 1 > root_row:
|
||||||
|
tl.send(dir="S", src=acc)
|
||||||
|
elif row < root_row:
|
||||||
|
acc = tl.recv(dir="S", shape=(n_elem,), dtype="f16")
|
||||||
|
if row > 0:
|
||||||
|
tl.send(dir="N", src=acc)
|
||||||
|
elif row > root_row:
|
||||||
|
acc = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
|
||||||
|
if row < cube_h - 1:
|
||||||
|
tl.send(dir="S", src=acc)
|
||||||
|
|
||||||
|
# Phase 5: row broadcast outward from root_col.
|
||||||
|
if col == root_col:
|
||||||
|
if root_col > 0:
|
||||||
|
tl.send(dir="W", src=acc)
|
||||||
|
if cube_w - 1 > root_col:
|
||||||
|
tl.send(dir="E", src=acc)
|
||||||
|
elif col < root_col:
|
||||||
|
acc = tl.recv(dir="E", shape=(n_elem,), dtype="f16")
|
||||||
|
if col > 0:
|
||||||
|
tl.send(dir="W", src=acc)
|
||||||
|
elif col > root_col:
|
||||||
|
acc = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
|
||||||
|
if col < cube_w - 1:
|
||||||
|
tl.send(dir="E", src=acc)
|
||||||
|
|
||||||
tl.store(pe_addr, acc)
|
tl.store(pe_addr, acc)
|
||||||
|
|
||||||
|
|||||||
@@ -113,7 +113,18 @@ class AhbmCCLBackend:
|
|||||||
)
|
)
|
||||||
n_elem = shards[0].nbytes // tensor.itemsize
|
n_elem = shards[0].nbytes // tensor.itemsize
|
||||||
kernel_fn = self._algo_module.kernel
|
kernel_fn = self._algo_module.kernel
|
||||||
kernel_args = self._algo_module.kernel_args(self._world_size, n_elem)
|
# Derive effective cube dims from tensor's actual shard placement
|
||||||
|
# (may differ from topology mesh when TP uses fewer cubes).
|
||||||
|
sip0_cubes = sorted({s.cube for s in shards if s.sip == shards[0].sip})
|
||||||
|
eff_n_cubes = len(sip0_cubes) if sip0_cubes else 1
|
||||||
|
if eff_n_cubes == 1:
|
||||||
|
eff_cube_w, eff_cube_h = 1, 1
|
||||||
|
else:
|
||||||
|
eff_cube_w, eff_cube_h = self._cube_w, self._cube_h
|
||||||
|
kernel_args = self._algo_module.kernel_args(
|
||||||
|
self._world_size, n_elem,
|
||||||
|
cube_w=eff_cube_w, cube_h=eff_cube_h,
|
||||||
|
)
|
||||||
|
|
||||||
# Resolve sip_rank from the current greenlet's bound rank
|
# Resolve sip_rank from the current greenlet's bound rank
|
||||||
from greenlet import getcurrent as _gc
|
from greenlet import getcurrent as _gc
|
||||||
|
|||||||
Reference in New Issue
Block a user