diff --git a/docs/diagrams/allreduce_latency_plots/topology.png b/docs/diagrams/allreduce_latency_plots/topology.png index 40e8719..1990768 100644 Binary files a/docs/diagrams/allreduce_latency_plots/topology.png and b/docs/diagrams/allreduce_latency_plots/topology.png differ diff --git a/pyproject.toml b/pyproject.toml index ef6ba8e..8a5863f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta" name = "kernbench" version = "0.1.0" requires-python = ">=3.10" -dependencies = ["pytest", "simpy", "pyyaml", "fastapi>=0.110", "uvicorn[standard]>=0.29", "websockets>=12", "numpy>=1.24", "greenlet>=3.0"] +dependencies = ["pytest", "simpy", "pyyaml", "fastapi>=0.110", "uvicorn[standard]>=0.29", "websockets>=12", "numpy>=1.24", "greenlet>=3.0", "matplotlib>=3.7"] [project.scripts] kernbench = "kernbench.cli.main:main" diff --git a/src/kernbench/ccl/algorithms/intercube_allreduce.py b/src/kernbench/ccl/algorithms/intercube_allreduce.py index a141942..f6e1055 100644 --- a/src/kernbench/ccl/algorithms/intercube_allreduce.py +++ b/src/kernbench/ccl/algorithms/intercube_allreduce.py @@ -24,9 +24,7 @@ TOPO_NAME_TO_KIND = { } -def kernel_args(world_size: int, n_elem: int) -> tuple: - cube_w = 4 - cube_h = 4 +def kernel_args(world_size: int, n_elem: int, *, cube_w: int = 4, cube_h: int = 4) -> tuple: return (n_elem, cube_w, cube_h, world_size) @@ -132,6 +130,7 @@ def allreduce_intercube_multidevice( row = cube_id // cube_w col = cube_id % cube_w nbytes = n_elem * 2 + single_cube = (cube_w == 1 and cube_h == 1) root_col = cube_w // 2 root_row = cube_h // 2 @@ -140,90 +139,108 @@ def allreduce_intercube_multidevice( pe_addr = t_ptr + cube_id * nbytes acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16") - # ── Phase 1: row reduce — converge at col == root_col ── - # Left half (col < root_col) walks W→E; right half (col > root_col) - # walks E→W; the root_col cube merges both sides. - if col == 0 and root_col > 0: - tl.send(dir="E", src=acc) - elif 0 < col < root_col: - recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16") - acc = acc + recv - tl.send(dir="E", src=acc) - elif col == root_col: - if root_col > 0: + if single_cube: + # ── Single-cube mode: skip intra-SIP reduce, go directly to + # inter-SIP exchange (TP use case: one cube per rank). ── + if n_sips > 1: + if sip_topo_kind == SIP_TOPO_RING: + acc = _inter_sip_ring(acc, n_sips, n_elem, tl) + elif sip_topo_kind == SIP_TOPO_TORUS: + acc = _inter_sip_torus_2d( + acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl) + elif sip_topo_kind == SIP_TOPO_MESH: + acc = _inter_sip_mesh_2d( + acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl) + else: + # ── Multi-cube mode: center-root bidirectional reduce + # + inter-SIP exchange + bidirectional broadcast ── + + # Phase 1: row reduce — converge at col == root_col. + # Left half (col < root_col) walks W→E; right half (col > root_col) + # walks E→W; the root_col cube merges both sides. + if col == 0 and root_col > 0: + tl.send(dir="E", src=acc) + elif 0 < col < root_col: recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16") acc = acc + recv - if cube_w - 1 > root_col: + tl.send(dir="E", src=acc) + elif col == root_col: + if root_col > 0: + recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16") + acc = acc + recv + if cube_w - 1 > root_col: + recv = tl.recv(dir="E", shape=(n_elem,), dtype="f16") + acc = acc + recv + elif root_col < col < cube_w - 1: recv = tl.recv(dir="E", shape=(n_elem,), dtype="f16") acc = acc + recv - elif root_col < col < cube_w - 1: - recv = tl.recv(dir="E", shape=(n_elem,), dtype="f16") - acc = acc + recv - tl.send(dir="W", src=acc) - elif col == cube_w - 1 and cube_w - 1 > root_col: - tl.send(dir="W", src=acc) + tl.send(dir="W", src=acc) + elif col == cube_w - 1 and cube_w - 1 > root_col: + tl.send(dir="W", src=acc) - # ── Phase 2: col reduce on col == root_col — converge at row == root_row ── - if col == root_col: - if row == 0 and root_row > 0: - tl.send(dir="S", src=acc) - elif 0 < row < root_row: - recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16") - acc = acc + recv - tl.send(dir="S", src=acc) - elif row == root_row: - if root_row > 0: + # Phase 2: col reduce on col == root_col — converge at row == root_row. + if col == root_col: + if row == 0 and root_row > 0: + tl.send(dir="S", src=acc) + elif 0 < row < root_row: recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16") acc = acc + recv - if cube_h - 1 > root_row: + tl.send(dir="S", src=acc) + elif row == root_row: + if root_row > 0: + recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16") + acc = acc + recv + if cube_h - 1 > root_row: + recv = tl.recv(dir="S", shape=(n_elem,), dtype="f16") + acc = acc + recv + elif root_row < row < cube_h - 1: recv = tl.recv(dir="S", shape=(n_elem,), dtype="f16") acc = acc + recv - elif root_row < row < cube_h - 1: - recv = tl.recv(dir="S", shape=(n_elem,), dtype="f16") - acc = acc + recv - tl.send(dir="N", src=acc) - elif row == cube_h - 1 and cube_h - 1 > root_row: - tl.send(dir="N", src=acc) - - # ── Phase 3: inter-SIP exchange on root cube ── - if cube_id == root_cube and n_sips > 1: - if sip_topo_kind == SIP_TOPO_RING: - acc = _inter_sip_ring(acc, n_sips, n_elem, tl) - elif sip_topo_kind == SIP_TOPO_TORUS: - acc = _inter_sip_torus_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl) - elif sip_topo_kind == SIP_TOPO_MESH: - acc = _inter_sip_mesh_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl) - - # ── Phase 4: col broadcast on col == root_col, outward from root_row ── - if col == root_col: - if row == root_row: - if root_row > 0: tl.send(dir="N", src=acc) - if cube_h - 1 > root_row: - tl.send(dir="S", src=acc) - elif row < root_row: - acc = tl.recv(dir="S", shape=(n_elem,), dtype="f16") - if row > 0: + elif row == cube_h - 1 and cube_h - 1 > root_row: tl.send(dir="N", src=acc) - elif row > root_row: - acc = tl.recv(dir="N", shape=(n_elem,), dtype="f16") - if row < cube_h - 1: - tl.send(dir="S", src=acc) - # ── Phase 5: row broadcast outward from root_col ── - if col == root_col: - if root_col > 0: - tl.send(dir="W", src=acc) - if cube_w - 1 > root_col: - tl.send(dir="E", src=acc) - elif col < root_col: - acc = tl.recv(dir="E", shape=(n_elem,), dtype="f16") - if col > 0: - tl.send(dir="W", src=acc) - elif col > root_col: - acc = tl.recv(dir="W", shape=(n_elem,), dtype="f16") - if col < cube_w - 1: - tl.send(dir="E", src=acc) + # Phase 3: inter-SIP exchange on root cube. + if cube_id == root_cube and n_sips > 1: + if sip_topo_kind == SIP_TOPO_RING: + acc = _inter_sip_ring(acc, n_sips, n_elem, tl) + elif sip_topo_kind == SIP_TOPO_TORUS: + acc = _inter_sip_torus_2d( + acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl) + elif sip_topo_kind == SIP_TOPO_MESH: + acc = _inter_sip_mesh_2d( + acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl) + + # Phase 4: col broadcast on col == root_col, outward from root_row. + if col == root_col: + if row == root_row: + if root_row > 0: + tl.send(dir="N", src=acc) + if cube_h - 1 > root_row: + tl.send(dir="S", src=acc) + elif row < root_row: + acc = tl.recv(dir="S", shape=(n_elem,), dtype="f16") + if row > 0: + tl.send(dir="N", src=acc) + elif row > root_row: + acc = tl.recv(dir="N", shape=(n_elem,), dtype="f16") + if row < cube_h - 1: + tl.send(dir="S", src=acc) + + # Phase 5: row broadcast outward from root_col. + if col == root_col: + if root_col > 0: + tl.send(dir="W", src=acc) + if cube_w - 1 > root_col: + tl.send(dir="E", src=acc) + elif col < root_col: + acc = tl.recv(dir="E", shape=(n_elem,), dtype="f16") + if col > 0: + tl.send(dir="W", src=acc) + elif col > root_col: + acc = tl.recv(dir="W", shape=(n_elem,), dtype="f16") + if col < cube_w - 1: + tl.send(dir="E", src=acc) tl.store(pe_addr, acc) diff --git a/src/kernbench/runtime_api/distributed.py b/src/kernbench/runtime_api/distributed.py index f87a268..a56086f 100644 --- a/src/kernbench/runtime_api/distributed.py +++ b/src/kernbench/runtime_api/distributed.py @@ -113,7 +113,18 @@ class AhbmCCLBackend: ) n_elem = shards[0].nbytes // tensor.itemsize kernel_fn = self._algo_module.kernel - kernel_args = self._algo_module.kernel_args(self._world_size, n_elem) + # Derive effective cube dims from tensor's actual shard placement + # (may differ from topology mesh when TP uses fewer cubes). + sip0_cubes = sorted({s.cube for s in shards if s.sip == shards[0].sip}) + eff_n_cubes = len(sip0_cubes) if sip0_cubes else 1 + if eff_n_cubes == 1: + eff_cube_w, eff_cube_h = 1, 1 + else: + eff_cube_w, eff_cube_h = self._cube_w, self._cube_h + kernel_args = self._algo_module.kernel_args( + self._world_size, n_elem, + cube_w=eff_cube_w, cube_h=eff_cube_h, + ) # Resolve sip_rank from the current greenlet's bound rank from greenlet import getcurrent as _gc