Merge origin/master: combine single-cube fast path + center-root reduce

Conflict resolution:
- intercube_allreduce.py: kept origin's `if single_cube:` early-exit
  (TP launches kernel on one cube/rank → skip intra-SIP mesh and go
  direct to inter-SIP exchange) AND replaced the multi-cube body with
  the local center-root + bidirectional reduce/broadcast (8-hop
  critical path on 4×4 vs 12 with corner root).
- tests/{allreduce,pe2pe}_latency_plots/: kept the local move to
  docs/diagrams/; dropped origin's stale content edits to the old
  paths (regenerable derived artifacts).
- docs/diagrams/pe2pe_latency_plots/summary.csv: kept local
  (post-Phase-2 + center-root values).

Origin contributions retained as-is:
- pyproject.toml: matplotlib >= 3.7 dep.
- runtime_api/distributed.py: derive effective cube_w/h from tensor
  shard placement so single-cube TP paths get cube_w=cube_h=1.
- kernel_args() now accepts optional cube_w/cube_h kwargs.

Verified post-merge:
- test_intercube_root_center.py: 2/2 (center-root multi-cube path).
- test_tp_layers.py + test_tp_mlp.py: 10/10 (single-cube TP path).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-04-27 21:41:46 -07:00
4 changed files with 105 additions and 77 deletions
Binary file not shown.

Before

Width:  |  Height:  |  Size: 194 KiB

After

Width:  |  Height:  |  Size: 194 KiB

+1 -1
View File
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
name = "kernbench"
version = "0.1.0"
requires-python = ">=3.10"
dependencies = ["pytest", "simpy", "pyyaml", "fastapi>=0.110", "uvicorn[standard]>=0.29", "websockets>=12", "numpy>=1.24", "greenlet>=3.0"]
dependencies = ["pytest", "simpy", "pyyaml", "fastapi>=0.110", "uvicorn[standard]>=0.29", "websockets>=12", "numpy>=1.24", "greenlet>=3.0", "matplotlib>=3.7"]
[project.scripts]
kernbench = "kernbench.cli.main:main"
@@ -24,9 +24,7 @@ TOPO_NAME_TO_KIND = {
}
def kernel_args(world_size: int, n_elem: int) -> tuple:
cube_w = 4
cube_h = 4
def kernel_args(world_size: int, n_elem: int, *, cube_w: int = 4, cube_h: int = 4) -> tuple:
return (n_elem, cube_w, cube_h, world_size)
@@ -132,6 +130,7 @@ def allreduce_intercube_multidevice(
row = cube_id // cube_w
col = cube_id % cube_w
nbytes = n_elem * 2
single_cube = (cube_w == 1 and cube_h == 1)
root_col = cube_w // 2
root_row = cube_h // 2
@@ -140,7 +139,23 @@ def allreduce_intercube_multidevice(
pe_addr = t_ptr + cube_id * nbytes
acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16")
# ── Phase 1: row reduce — converge at col == root_col ──
if single_cube:
# ── Single-cube mode: skip intra-SIP reduce, go directly to
# inter-SIP exchange (TP use case: one cube per rank). ──
if n_sips > 1:
if sip_topo_kind == SIP_TOPO_RING:
acc = _inter_sip_ring(acc, n_sips, n_elem, tl)
elif sip_topo_kind == SIP_TOPO_TORUS:
acc = _inter_sip_torus_2d(
acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
elif sip_topo_kind == SIP_TOPO_MESH:
acc = _inter_sip_mesh_2d(
acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
else:
# ── Multi-cube mode: center-root bidirectional reduce
# + inter-SIP exchange + bidirectional broadcast ──
# Phase 1: row reduce — converge at col == root_col.
# Left half (col < root_col) walks W→E; right half (col > root_col)
# walks E→W; the root_col cube merges both sides.
if col == 0 and root_col > 0:
@@ -163,7 +178,7 @@ def allreduce_intercube_multidevice(
elif col == cube_w - 1 and cube_w - 1 > root_col:
tl.send(dir="W", src=acc)
# ── Phase 2: col reduce on col == root_col — converge at row == root_row ──
# Phase 2: col reduce on col == root_col — converge at row == root_row.
if col == root_col:
if row == 0 and root_row > 0:
tl.send(dir="S", src=acc)
@@ -185,16 +200,18 @@ def allreduce_intercube_multidevice(
elif row == cube_h - 1 and cube_h - 1 > root_row:
tl.send(dir="N", src=acc)
# ── Phase 3: inter-SIP exchange on root cube ──
# Phase 3: inter-SIP exchange on root cube.
if cube_id == root_cube and n_sips > 1:
if sip_topo_kind == SIP_TOPO_RING:
acc = _inter_sip_ring(acc, n_sips, n_elem, tl)
elif sip_topo_kind == SIP_TOPO_TORUS:
acc = _inter_sip_torus_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
acc = _inter_sip_torus_2d(
acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
elif sip_topo_kind == SIP_TOPO_MESH:
acc = _inter_sip_mesh_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
acc = _inter_sip_mesh_2d(
acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
# ── Phase 4: col broadcast on col == root_col, outward from root_row ──
# Phase 4: col broadcast on col == root_col, outward from root_row.
if col == root_col:
if row == root_row:
if root_row > 0:
@@ -210,7 +227,7 @@ def allreduce_intercube_multidevice(
if row < cube_h - 1:
tl.send(dir="S", src=acc)
# ── Phase 5: row broadcast outward from root_col ──
# Phase 5: row broadcast outward from root_col.
if col == root_col:
if root_col > 0:
tl.send(dir="W", src=acc)
+12 -1
View File
@@ -113,7 +113,18 @@ class AhbmCCLBackend:
)
n_elem = shards[0].nbytes // tensor.itemsize
kernel_fn = self._algo_module.kernel
kernel_args = self._algo_module.kernel_args(self._world_size, n_elem)
# Derive effective cube dims from tensor's actual shard placement
# (may differ from topology mesh when TP uses fewer cubes).
sip0_cubes = sorted({s.cube for s in shards if s.sip == shards[0].sip})
eff_n_cubes = len(sip0_cubes) if sip0_cubes else 1
if eff_n_cubes == 1:
eff_cube_w, eff_cube_h = 1, 1
else:
eff_cube_w, eff_cube_h = self._cube_w, self._cube_h
kernel_args = self._algo_module.kernel_args(
self._world_size, n_elem,
cube_w=eff_cube_w, cube_h=eff_cube_h,
)
# Resolve sip_rank from the current greenlet's bound rank
from greenlet import getcurrent as _gc