Intercube allreduce: pe0 cube-mesh reduce + multi-SIP ring/torus/mesh
New intercube allreduce kernel replacing the old flat ring algorithms. Reduces across the 4x4 cube mesh within each SIP (pe0-only, same-lane), then inter-SIP exchange on root cube, then broadcast back. Supports ring_1d, torus_2d, and mesh_2d_no_wrap SIP topologies driven by topology.yaml. Integrated with dist.init_process_group / dist.all_reduce. New files: - src/kernbench/ccl/algorithms/intercube_allreduce.py (kernel) - src/kernbench/ccl/sfr_config.py (configure_sfr_intercube_multisip) - tests/test_allreduce_multidevice.py (config-driven, 3 topologies) - tests/test_distributed_intercube_allreduce.py (full distributed path) - tests/test_intercube_sfr_config.py (SFR wiring verification) Modified: - distributed.py: AhbmCCLBackend uses configure_sfr_intercube_multisip - topologies.py: added torus_2d, mesh_2d_no_wrap - install.py: global_E/W/N/S in _OPPOSITE_DIR - topology.yaml: added system.sips.topology - ccl.yaml: single intercube_allreduce algorithm - benches/ccl_allreduce.py: row_wise cube-mesh tensor layout Removed old flat-ring algorithms and their tests. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
+30
-29
@@ -1,15 +1,12 @@
|
||||
"""CCL all-reduce bench (ADR-0024 + ADR-0027).
|
||||
|
||||
Pure TP launcher model: rank = SIP. Each rank owns a ``(1, n_elem)`` tile
|
||||
initialised to ``rank + 1``; after ``dist.all_reduce(op="sum")`` every rank
|
||||
must see ``sum(1..world_size)``. Rank 0 prints the pass/fail line.
|
||||
Pure TP launcher model: rank = SIP. Each rank owns a ``(N_CUBES, n_elem)``
|
||||
tensor sharded row-wise across the cube mesh (pe0 per cube). After
|
||||
``dist.all_reduce(op="sum")`` every cube on every rank must hold
|
||||
``N_CUBES * sum(1..world_size)``. Rank 0 prints the pass/fail line.
|
||||
|
||||
Driven by ``ccl.yaml`` (``defaults.algorithm``, ``n_elem``) + ``topology.yaml``
|
||||
(SIP count → world_size).
|
||||
|
||||
Legacy ``rank = PE`` single-driver path was removed — intra-SIP PE-level
|
||||
collective is expressed by the kernel itself via ``tl.program_id`` and
|
||||
does not need a host-side ``ProcessGroup``.
|
||||
(SIP count → world_size, cube_mesh → N_CUBES).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -20,18 +17,19 @@ import numpy as np
|
||||
from kernbench.ccl.install import load_ccl_config, resolve_algorithm_config
|
||||
from kernbench.policy.placement.dp import DPPolicy
|
||||
|
||||
DEFAULT_N_ELEM = 32
|
||||
DEFAULT_N_ELEM = 8
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _BenchCfg:
|
||||
algorithm: str
|
||||
n_elem: int
|
||||
n_cubes: int
|
||||
world_size: int
|
||||
|
||||
|
||||
def _resolve_cfg(torch) -> _BenchCfg:
|
||||
"""Read ccl.yaml once at host side; enforce rank = SIP contract."""
|
||||
"""Read ccl.yaml + topology once at host side."""
|
||||
merged = resolve_algorithm_config(load_ccl_config())
|
||||
ws = torch.distributed.get_world_size()
|
||||
spec = torch.spec or {}
|
||||
@@ -39,55 +37,58 @@ def _resolve_cfg(torch) -> _BenchCfg:
|
||||
if ws != n_sips:
|
||||
raise RuntimeError(
|
||||
f"ccl_allreduce bench requires world_size == topology SIP count "
|
||||
f"(world_size={ws}, n_sips={n_sips}). rank = PE mode was removed "
|
||||
f"(intra-SIP collectives are expressed inside the kernel)."
|
||||
f"(world_size={ws}, n_sips={n_sips})."
|
||||
)
|
||||
cm = spec.get("sip", {}).get("cube_mesh", {})
|
||||
n_cubes = int(cm.get("w", 4)) * int(cm.get("h", 4))
|
||||
return _BenchCfg(
|
||||
algorithm=merged["algorithm"],
|
||||
n_elem=int(merged.get("n_elem", DEFAULT_N_ELEM)),
|
||||
n_cubes=n_cubes,
|
||||
world_size=ws,
|
||||
)
|
||||
|
||||
|
||||
def _rank_local_dp() -> DPPolicy:
|
||||
return DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=1)
|
||||
def _rank_dp(n_cubes: int) -> DPPolicy:
|
||||
return DPPolicy(cube="row_wise", pe="replicate", num_cubes=n_cubes, num_pes=1)
|
||||
|
||||
|
||||
def _allocate_rank_tile(torch, rank: int, cfg: _BenchCfg):
|
||||
"""Allocate this rank's ``(1, n_elem)`` tile on its SIP."""
|
||||
def _allocate_rank_tensor(torch, rank: int, cfg: _BenchCfg):
|
||||
"""Allocate this rank's ``(n_cubes, n_elem)`` tensor on its SIP."""
|
||||
return torch.zeros(
|
||||
(1, cfg.n_elem), dtype="f16",
|
||||
dp=_rank_local_dp(), name=f"ccl_in_r{rank}",
|
||||
(cfg.n_cubes, cfg.n_elem), dtype="f16",
|
||||
dp=_rank_dp(cfg.n_cubes), name=f"ccl_in_r{rank}",
|
||||
)
|
||||
|
||||
|
||||
def _init_with_rank_value(torch, tensor, rank: int, cfg: _BenchCfg) -> None:
|
||||
"""Fill the tile with the scalar ``rank + 1`` (deterministic + easy to verify)."""
|
||||
arr = np.full((1, cfg.n_elem), float(rank + 1), dtype=np.float16)
|
||||
"""Fill all cubes with the scalar ``rank + 1``."""
|
||||
arr = np.full((cfg.n_cubes, cfg.n_elem), float(rank + 1), dtype=np.float16)
|
||||
tensor.copy_(torch.from_numpy(arr))
|
||||
|
||||
|
||||
def _report(result: np.ndarray, cfg: _BenchCfg) -> None:
|
||||
"""Single-line pass/fail printer (rank 0 only, called after all_reduce)."""
|
||||
expected = float(sum(range(1, cfg.world_size + 1)))
|
||||
ok = bool(np.allclose(result, expected, rtol=1e-1, atol=1e-1))
|
||||
"""Single-line pass/fail printer (rank 0 only)."""
|
||||
expected = float(cfg.n_cubes * sum(range(1, cfg.world_size + 1)))
|
||||
ok = True
|
||||
for cube_id in range(cfg.n_cubes):
|
||||
if not np.allclose(result[cube_id], expected, rtol=1e-1, atol=1e-1):
|
||||
ok = False
|
||||
break
|
||||
if ok:
|
||||
print(f" {cfg.algorithm} (ws={cfg.world_size}): {cfg.world_size} OK")
|
||||
total = cfg.world_size * cfg.n_cubes
|
||||
print(f" {cfg.algorithm} (ws={cfg.world_size}): {total} OK")
|
||||
return
|
||||
got = float(result.reshape(-1).mean())
|
||||
print(
|
||||
f" [FAIL] {cfg.algorithm} (ws={cfg.world_size}): "
|
||||
f"got mean={got:.3f}, expected={expected:.3f}"
|
||||
)
|
||||
print(
|
||||
f" {cfg.algorithm} (ws={cfg.world_size}): "
|
||||
f"0 OK / {cfg.world_size} FAIL"
|
||||
)
|
||||
|
||||
|
||||
def _worker(rank: int, cfg: _BenchCfg, torch) -> None:
|
||||
torch.ahbm.set_device(rank)
|
||||
tensor = _allocate_rank_tile(torch, rank, cfg)
|
||||
tensor = _allocate_rank_tensor(torch, rank, cfg)
|
||||
_init_with_rank_value(torch, tensor, rank, cfg)
|
||||
torch.distributed.all_reduce(tensor, op="sum")
|
||||
if rank == 0:
|
||||
|
||||
@@ -6,12 +6,7 @@
|
||||
|
||||
defaults:
|
||||
# Algorithm to run for this benchmark execution.
|
||||
algorithm: ring_allreduce_tcm
|
||||
|
||||
# NOTE: world_size is not set here by default. AhbmCCLBackend derives it
|
||||
# from the chosen algorithm's entry (if it sets ``world_size``) or from
|
||||
# topology.yaml (``sips × cubes_per_sip × pes_per_cube``). This mirrors
|
||||
# real PyTorch DDP where ranks/world_size come from env vars, not code.
|
||||
algorithm: intercube_allreduce
|
||||
|
||||
# IPCQ ring buffer location.
|
||||
# tcm — PE-local TCM (fast, small, conflicts with compute TCM access)
|
||||
@@ -30,43 +25,21 @@ defaults:
|
||||
# Slot size in bytes (must hold one tile worth of data).
|
||||
slot_size: 4096
|
||||
|
||||
# PE_DMA virtual channel chunk size (D8). First implementation does not
|
||||
# use chunk-level interleave; this is reserved for future precision.
|
||||
# PE_DMA virtual channel chunk size (D8).
|
||||
vc_chunk_size: 256
|
||||
|
||||
# Credit return fast path message size (D9). Used by bottleneck-BW
|
||||
# latency calculation. 16-64 bytes typical.
|
||||
# Credit return fast path message size (D9).
|
||||
ipcq_credit_size_bytes: 16
|
||||
|
||||
algorithms:
|
||||
# ── ring all-reduce, buffer in PE_TCM ──
|
||||
# Defaults to topology-derived world_size (full system, 256 ranks).
|
||||
# Use a smaller tile size at high rank counts so f16 sums stay within
|
||||
# the verification tolerance and op_log replay scales.
|
||||
ring_allreduce_tcm:
|
||||
module: kernbench.ccl.algorithms.ring_allreduce
|
||||
topology: ring_1d
|
||||
buffer_kind: tcm
|
||||
n_elem: 8
|
||||
|
||||
# ── ring all-reduce, buffer in PE-local HBM ──
|
||||
ring_allreduce_hbm:
|
||||
module: kernbench.ccl.algorithms.ring_allreduce
|
||||
topology: ring_1d
|
||||
buffer_kind: hbm
|
||||
n_elem: 8
|
||||
|
||||
# ── ring all-reduce, buffer in cube SRAM ──
|
||||
ring_allreduce_sram:
|
||||
module: kernbench.ccl.algorithms.ring_allreduce
|
||||
topology: ring_1d
|
||||
buffer_kind: sram
|
||||
n_elem: 8
|
||||
|
||||
# ── hierarchical all-reduce (3-level: intra-cube → inter-cube → inter-SIP) ──
|
||||
# Uses bidirectional ring reduce + chain broadcast. ~25 rounds vs 255 flat.
|
||||
hierarchical_allreduce:
|
||||
module: kernbench.ccl.algorithms.hierarchical_allreduce
|
||||
# ── intercube all-reduce (pe0-only, cube mesh + inter-SIP) ──
|
||||
# Reduces across the 4×4 cube mesh within each SIP, then inter-SIP
|
||||
# exchange on root cube, then broadcast back. SIP topology is read
|
||||
# from topology.yaml → system.sips.topology. Kernel auto-selects
|
||||
# ring / torus / mesh inter-SIP exchange pattern.
|
||||
intercube_allreduce:
|
||||
module: kernbench.ccl.algorithms.intercube_allreduce
|
||||
topology: none
|
||||
buffer_kind: tcm
|
||||
n_elem: 16
|
||||
n_elem: 8
|
||||
root_cube: 15
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
"""Hello-world CCL kernel for the docs/ccl-author-guide.md walkthrough.
|
||||
|
||||
Each PE sends its tile to the E neighbor and receives one tile from W,
|
||||
then stores the received tile back into its own HBM slice. The simplest
|
||||
possible demonstration of ``tl.send`` / ``tl.recv``.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def kernel_args(world_size: int, n_elem: int) -> tuple:
|
||||
"""Return the positional kernel arguments for the ahbm backend."""
|
||||
return (n_elem,)
|
||||
|
||||
|
||||
def kernel(t_ptr, n_elem, tl):
|
||||
local_pe = tl.program_id(axis=0)
|
||||
cube_id = tl.program_id(axis=1)
|
||||
pes_per_cube = tl.num_programs(axis=0)
|
||||
rank = cube_id * pes_per_cube + local_pe
|
||||
nbytes = n_elem * 2
|
||||
pe_addr = t_ptr + rank * nbytes
|
||||
|
||||
# Send our local HBM tile to the E neighbor.
|
||||
src = tl.load(pe_addr, shape=(n_elem,), dtype="f16")
|
||||
tl.send(dir="E", src=src)
|
||||
|
||||
# Receive a tile from W and store it into our slice (overwrite).
|
||||
recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
|
||||
tl.store(pe_addr, recv)
|
||||
@@ -1,192 +0,0 @@
|
||||
"""Hierarchical all-reduce kernel (ADR-0023).
|
||||
|
||||
3-level reduce + broadcast exploiting the topology hierarchy:
|
||||
|
||||
Level 1 — Intra-cube (8 PEs, E/W, fastest link):
|
||||
Bidirectional ring reduce to PE 0.
|
||||
Level 2 — Inter-cube within SIP (16 cubes, N/S, UCIe):
|
||||
Bidirectional ring reduce of PE 0s to cube 0 PE 0.
|
||||
Level 3 — Inter-SIP (2 SIPs, parent):
|
||||
Pair exchange between SIP representatives.
|
||||
Broadcast — Reverse chain through levels 2 and 1.
|
||||
|
||||
Bidirectional reduce: left-half sends toward node 0 via dir_dec,
|
||||
right-half sends via dir_inc (wrapping). Representative receives from
|
||||
both sides. Rounds per level = ceil((group_size - 1) / 2).
|
||||
|
||||
Direction pairing (ring):
|
||||
Send via dir_dec at PE K → recv via dir_inc at PE K-1
|
||||
Send via dir_inc at PE K → recv via dir_dec at PE K+1
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def kernel_args(world_size: int, n_elem: int) -> tuple:
|
||||
"""Positional kernel args for the ahbm backend."""
|
||||
pes_per_cube = 8
|
||||
num_sips = max(1, world_size // 128) if world_size > 128 else 1
|
||||
cubes_per_sip = world_size // (pes_per_cube * num_sips)
|
||||
return (n_elem, pes_per_cube, cubes_per_sip, num_sips)
|
||||
|
||||
|
||||
def neighbors(rank: int, world_size: int, neighbor_map: dict) -> dict:
|
||||
"""Build the 3-level neighbor map."""
|
||||
pes_per_cube = 8
|
||||
num_sips = max(1, world_size // 128) if world_size > 128 else 1
|
||||
cubes_per_sip = world_size // (pes_per_cube * num_sips)
|
||||
|
||||
pe_id = rank % pes_per_cube
|
||||
cube_global = rank // pes_per_cube
|
||||
sip_id = cube_global // cubes_per_sip
|
||||
local_cube_id = cube_global % cubes_per_sip
|
||||
|
||||
result = {}
|
||||
|
||||
# Level 1: intra-cube ring (E/W, all PEs)
|
||||
cube_base = cube_global * pes_per_cube
|
||||
result["E"] = cube_base + (pe_id + 1) % pes_per_cube
|
||||
result["W"] = cube_base + (pe_id - 1) % pes_per_cube
|
||||
|
||||
# Level 2: inter-cube ring (N/S, PE 0 only)
|
||||
if pe_id == 0 and cubes_per_sip > 1:
|
||||
sip_base = sip_id * cubes_per_sip * pes_per_cube
|
||||
next_cube_pe0 = sip_base + ((local_cube_id + 1) % cubes_per_sip) * pes_per_cube
|
||||
prev_cube_pe0 = sip_base + ((local_cube_id - 1) % cubes_per_sip) * pes_per_cube
|
||||
result["N"] = next_cube_pe0
|
||||
result["S"] = prev_cube_pe0
|
||||
|
||||
# Level 3: inter-SIP (parent, PE 0 cube 0 only)
|
||||
if pe_id == 0 and local_cube_id == 0 and num_sips > 1:
|
||||
other_sip_pe0 = ((sip_id + 1) % num_sips) * cubes_per_sip * pes_per_cube
|
||||
result["parent"] = other_sip_pe0
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _bidir_reduce(tl, acc, my_id, group_size, dir_inc, dir_dec, shape, dtype):
|
||||
"""Bidirectional ring reduce to node 0.
|
||||
|
||||
Left half (1..half): chain reduces via dir_dec (toward lower IDs).
|
||||
Each PE recvs from higher PE (via dir_inc) and sends to lower (via dir_dec).
|
||||
Right half (half+1..N-1): chain reduces via dir_inc (wraps to node 0).
|
||||
Each PE recvs from lower PE (via dir_dec) and sends to higher (via dir_inc).
|
||||
Node 0: recvs left sum via dir_inc, right sum via dir_dec.
|
||||
|
||||
Direction pairing: send dir_dec at K → recv dir_inc at K-1.
|
||||
send dir_inc at K → recv dir_dec at K+1.
|
||||
"""
|
||||
if group_size <= 1:
|
||||
return acc
|
||||
|
||||
half = group_size // 2
|
||||
|
||||
if my_id == 0:
|
||||
# Representative: recv left-half sum via dir_inc (from PE 1)
|
||||
recv = tl.recv(dir=dir_inc, shape=shape, dtype=dtype)
|
||||
acc = acc + recv
|
||||
# Recv right-half sum via dir_dec (from PE N-1, wrapped)
|
||||
if group_size - half - 1 >= 1:
|
||||
recv = tl.recv(dir=dir_dec, shape=shape, dtype=dtype)
|
||||
acc = acc + recv
|
||||
|
||||
elif my_id <= half:
|
||||
# Left half: recv from PE my_id+1 via dir_inc, send to PE my_id-1 via dir_dec
|
||||
if my_id < half: # not the far-edge
|
||||
recv = tl.recv(dir=dir_inc, shape=shape, dtype=dtype)
|
||||
acc = acc + recv
|
||||
tl.send(dir=dir_dec, src=acc)
|
||||
|
||||
else:
|
||||
# Right half: recv from PE my_id-1 via dir_dec, send to PE my_id+1 via dir_inc
|
||||
if my_id > half + 1: # not the near-edge
|
||||
recv = tl.recv(dir=dir_dec, shape=shape, dtype=dtype)
|
||||
acc = acc + recv
|
||||
tl.send(dir=dir_inc, src=acc)
|
||||
|
||||
return acc
|
||||
|
||||
|
||||
def _chain_broadcast(tl, acc, my_id, group_size, dir_inc, shape, dtype):
|
||||
"""Linear chain broadcast from node 0 via dir_inc.
|
||||
|
||||
Node 0 sends via dir_inc → node 1. Node 1 recvs via dir_dec (implicit
|
||||
from the ring pairing), stores, sends via dir_inc → node 2. Etc.
|
||||
|
||||
Recv direction = the opposite: send dir_inc at K → recv dir_dec at K+1.
|
||||
"""
|
||||
if group_size <= 1:
|
||||
return acc
|
||||
|
||||
# In ring pairing: send via dir_inc at K → recv via dir_dec at K+1.
|
||||
# dir_dec is the "other" direction. We infer it from the ring:
|
||||
# if dir_inc is "E", peer recvs via "W"; if "N", peer recvs via "S".
|
||||
_recv_dir = {"E": "W", "W": "E", "N": "S", "S": "N"}.get(dir_inc, dir_inc)
|
||||
|
||||
if my_id == 0:
|
||||
tl.send(dir=dir_inc, src=acc)
|
||||
else:
|
||||
acc = tl.recv(dir=_recv_dir, shape=shape, dtype=dtype)
|
||||
if my_id < group_size - 1:
|
||||
tl.send(dir=dir_inc, src=acc)
|
||||
return acc
|
||||
|
||||
|
||||
def kernel(t_ptr, n_elem, pes_per_cube, cubes_per_sip, num_sips, tl):
|
||||
"""Hierarchical all-reduce.
|
||||
|
||||
Args:
|
||||
t_ptr: HBM base address (column-sharded VA).
|
||||
n_elem: f16 elements per tile.
|
||||
pes_per_cube: PEs per cube (typically 8).
|
||||
cubes_per_sip: cubes per SIP (typically 16).
|
||||
num_sips: number of SIPs (typically 2).
|
||||
tl: TLContext (auto-injected).
|
||||
"""
|
||||
pe_id = tl.program_id(axis=0)
|
||||
cube_global = tl.program_id(axis=1)
|
||||
sip_id = cube_global // cubes_per_sip
|
||||
local_cube_id = cube_global % cubes_per_sip
|
||||
|
||||
rank = cube_global * pes_per_cube + pe_id
|
||||
nbytes = n_elem * 2
|
||||
pe_addr = t_ptr + rank * nbytes
|
||||
|
||||
acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16")
|
||||
shape = (n_elem,)
|
||||
dtype = "f16"
|
||||
|
||||
# ── Level 1: intra-cube bidirectional reduce to PE 0 ──
|
||||
acc = _bidir_reduce(
|
||||
tl, acc, my_id=pe_id, group_size=pes_per_cube,
|
||||
dir_inc="E", dir_dec="W", shape=shape, dtype=dtype,
|
||||
)
|
||||
|
||||
# ── Level 2: inter-cube bidirectional reduce to cube 0 (PE 0 only) ──
|
||||
if pe_id == 0 and cubes_per_sip > 1:
|
||||
acc = _bidir_reduce(
|
||||
tl, acc, my_id=local_cube_id, group_size=cubes_per_sip,
|
||||
dir_inc="N", dir_dec="S", shape=shape, dtype=dtype,
|
||||
)
|
||||
|
||||
# ── Level 3: inter-SIP exchange (PE 0 cube 0 only) ──
|
||||
if pe_id == 0 and local_cube_id == 0 and num_sips > 1:
|
||||
tl.send(dir="parent", src=acc)
|
||||
recv = tl.recv(dir="parent", shape=shape, dtype=dtype)
|
||||
acc = acc + recv
|
||||
|
||||
# ── Broadcast back ──
|
||||
|
||||
# Level 2: cube 0 PE 0 → all PE 0s via chain
|
||||
if pe_id == 0 and cubes_per_sip > 1:
|
||||
acc = _chain_broadcast(
|
||||
tl, acc, my_id=local_cube_id, group_size=cubes_per_sip,
|
||||
dir_inc="N", shape=shape, dtype=dtype,
|
||||
)
|
||||
|
||||
# Level 1: PE 0 → all PEs in cube via chain
|
||||
acc = _chain_broadcast(
|
||||
tl, acc, my_id=pe_id, group_size=pes_per_cube,
|
||||
dir_inc="E", shape=shape, dtype=dtype,
|
||||
)
|
||||
|
||||
tl.store(pe_addr, acc)
|
||||
@@ -0,0 +1,189 @@
|
||||
"""Intercube all-reduce kernel (pe0-only, same-lane across cubes).
|
||||
|
||||
Reduces across the 4×4 cube mesh within each SIP, then exchanges
|
||||
between SIPs using the configured SIP topology, and broadcasts back.
|
||||
|
||||
Supported SIP topologies (selected via ``sip_topo_kind``):
|
||||
0 — ring_1d: global_E/global_W ring, n_sips-1 rounds
|
||||
1 — torus_2d: row ring (global_E/W) + col ring (global_S/N)
|
||||
2 — mesh_2d: row chain reduce+broadcast + col chain reduce+broadcast
|
||||
|
||||
IPCQ wiring is handled by ``configure_sfr_intercube_multisip``.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
SIP_TOPO_RING = 0
|
||||
SIP_TOPO_TORUS = 1
|
||||
SIP_TOPO_MESH = 2
|
||||
|
||||
TOPO_NAME_TO_KIND = {
|
||||
"ring_1d": SIP_TOPO_RING,
|
||||
"torus_2d": SIP_TOPO_TORUS,
|
||||
"mesh_2d": SIP_TOPO_TORUS,
|
||||
"mesh_2d_no_wrap": SIP_TOPO_MESH,
|
||||
}
|
||||
|
||||
|
||||
def kernel_args(world_size: int, n_elem: int) -> tuple:
|
||||
cube_w = 4
|
||||
cube_h = 4
|
||||
return (n_elem, cube_w, cube_h, world_size)
|
||||
|
||||
|
||||
def _inter_sip_ring(acc, n_sips, n_elem, tl):
|
||||
current = acc
|
||||
for _ in range(n_sips - 1):
|
||||
tl.send(dir="global_E", src=current)
|
||||
recv = tl.recv(dir="global_W", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
current = recv
|
||||
return acc
|
||||
|
||||
|
||||
def _inter_sip_torus_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl):
|
||||
# Row ring (global_E / global_W)
|
||||
current = acc
|
||||
for _ in range(sip_topo_w - 1):
|
||||
tl.send(dir="global_E", src=current)
|
||||
recv = tl.recv(dir="global_W", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
current = recv
|
||||
# Col ring (global_S / global_N)
|
||||
current = acc
|
||||
for _ in range(sip_topo_h - 1):
|
||||
tl.send(dir="global_S", src=current)
|
||||
recv = tl.recv(dir="global_N", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
current = recv
|
||||
return acc
|
||||
|
||||
|
||||
def _inter_sip_mesh_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl):
|
||||
sip_row = sip_rank // sip_topo_w
|
||||
sip_col = sip_rank % sip_topo_w
|
||||
|
||||
# Row reduce W → E
|
||||
if sip_col == 0:
|
||||
tl.send(dir="global_E", src=acc)
|
||||
elif sip_col < sip_topo_w - 1:
|
||||
recv = tl.recv(dir="global_W", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
tl.send(dir="global_E", src=acc)
|
||||
else:
|
||||
recv = tl.recv(dir="global_W", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
|
||||
# Row broadcast E → W
|
||||
if sip_col == sip_topo_w - 1:
|
||||
tl.send(dir="global_W", src=acc)
|
||||
elif sip_col > 0:
|
||||
acc = tl.recv(dir="global_E", shape=(n_elem,), dtype="f16")
|
||||
tl.send(dir="global_W", src=acc)
|
||||
else:
|
||||
acc = tl.recv(dir="global_E", shape=(n_elem,), dtype="f16")
|
||||
|
||||
# Col reduce N → S
|
||||
if sip_row == 0:
|
||||
tl.send(dir="global_S", src=acc)
|
||||
elif sip_row < sip_topo_h - 1:
|
||||
recv = tl.recv(dir="global_N", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
tl.send(dir="global_S", src=acc)
|
||||
else:
|
||||
recv = tl.recv(dir="global_N", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
|
||||
# Col broadcast S → N
|
||||
if sip_row == sip_topo_h - 1:
|
||||
tl.send(dir="global_N", src=acc)
|
||||
elif sip_row > 0:
|
||||
acc = tl.recv(dir="global_S", shape=(n_elem,), dtype="f16")
|
||||
tl.send(dir="global_N", src=acc)
|
||||
else:
|
||||
acc = tl.recv(dir="global_S", shape=(n_elem,), dtype="f16")
|
||||
|
||||
return acc
|
||||
|
||||
|
||||
def allreduce_intercube_multidevice(
|
||||
t_ptr, n_elem, cube_w, cube_h, n_sips, sip_rank,
|
||||
sip_topo_kind, sip_topo_w, sip_topo_h, tl,
|
||||
):
|
||||
"""Intercube all-reduce (pe0-only) with configurable SIP topology.
|
||||
|
||||
Args:
|
||||
t_ptr: VA base of the row-wise-sharded tensor on this SIP.
|
||||
n_elem: f16 elements per cube tile.
|
||||
cube_w: cube mesh width (columns).
|
||||
cube_h: cube mesh height (rows).
|
||||
n_sips: number of SIPs.
|
||||
sip_rank: this SIP's rank (0-based).
|
||||
sip_topo_kind: 0=ring, 1=torus_2d, 2=mesh_2d.
|
||||
sip_topo_w: SIP mesh width (for 2D topologies, 0 for ring).
|
||||
sip_topo_h: SIP mesh height (for 2D topologies, 0 for ring).
|
||||
tl: TLContext (auto-injected).
|
||||
"""
|
||||
cube_id = tl.program_id(axis=1)
|
||||
row = cube_id // cube_w
|
||||
col = cube_id % cube_w
|
||||
nbytes = n_elem * 2
|
||||
|
||||
pe_addr = t_ptr + cube_id * nbytes
|
||||
acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16")
|
||||
|
||||
# ── Phase 1: row reduce W → E ──
|
||||
if col == 0:
|
||||
tl.send(dir="E", src=acc)
|
||||
elif col < cube_w - 1:
|
||||
recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
tl.send(dir="E", src=acc)
|
||||
else:
|
||||
recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
|
||||
# ── Phase 2: col reduce N → S on rightmost column ──
|
||||
if col == cube_w - 1:
|
||||
if row == 0:
|
||||
tl.send(dir="S", src=acc)
|
||||
elif row < cube_h - 1:
|
||||
recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
tl.send(dir="S", src=acc)
|
||||
else:
|
||||
recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
|
||||
# ── Phase 3: inter-SIP exchange on root cube ──
|
||||
root_cube = (cube_h - 1) * cube_w + (cube_w - 1)
|
||||
if cube_id == root_cube and n_sips > 1:
|
||||
if sip_topo_kind == SIP_TOPO_RING:
|
||||
acc = _inter_sip_ring(acc, n_sips, n_elem, tl)
|
||||
elif sip_topo_kind == SIP_TOPO_TORUS:
|
||||
acc = _inter_sip_torus_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
|
||||
elif sip_topo_kind == SIP_TOPO_MESH:
|
||||
acc = _inter_sip_mesh_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
|
||||
|
||||
# ── Phase 4: col broadcast S → N on rightmost column ──
|
||||
if col == cube_w - 1:
|
||||
if row == cube_h - 1:
|
||||
tl.send(dir="N", src=acc)
|
||||
elif row > 0:
|
||||
acc = tl.recv(dir="S", shape=(n_elem,), dtype="f16")
|
||||
tl.send(dir="N", src=acc)
|
||||
else:
|
||||
acc = tl.recv(dir="S", shape=(n_elem,), dtype="f16")
|
||||
|
||||
# ── Phase 5: row broadcast E → W ──
|
||||
if col == cube_w - 1:
|
||||
tl.send(dir="W", src=acc)
|
||||
elif col > 0:
|
||||
acc = tl.recv(dir="E", shape=(n_elem,), dtype="f16")
|
||||
tl.send(dir="W", src=acc)
|
||||
else:
|
||||
acc = tl.recv(dir="E", shape=(n_elem,), dtype="f16")
|
||||
|
||||
tl.store(pe_addr, acc)
|
||||
|
||||
|
||||
kernel = allreduce_intercube_multidevice
|
||||
@@ -1,73 +0,0 @@
|
||||
"""2D-mesh all-reduce kernel (ADR-0023).
|
||||
|
||||
Two-phase reduce on a square mesh of side ``S`` (world_size = S*S):
|
||||
1. Row reduce: ring all-reduce along E/W within each row.
|
||||
2. Column reduce: ring all-reduce along N/S within each column.
|
||||
|
||||
After both phases, every rank holds the global sum.
|
||||
|
||||
Uses TensorHandle math (PE_MATH) for accumulation. Op_log captures the
|
||||
data flow so Phase 2 produces correct final HBM contents. Math/recv
|
||||
handles are passed directly to the next send, avoiding store→reload
|
||||
which doesn't propagate correctly with timing-only Phase 1 math.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
|
||||
|
||||
def kernel_args(world_size: int, n_elem: int) -> tuple:
|
||||
"""Return the positional kernel arguments for the ahbm backend.
|
||||
|
||||
Mesh all-reduce requires ``world_size`` to be a perfect square —
|
||||
the mesh side length is ``sqrt(world_size)``.
|
||||
"""
|
||||
side = int(round(math.sqrt(world_size)))
|
||||
if side * side != world_size:
|
||||
raise ValueError(
|
||||
f"mesh_allreduce requires a square world_size; got {world_size}"
|
||||
)
|
||||
return (n_elem, side)
|
||||
|
||||
|
||||
def kernel(t_ptr, n_elem, side, tl):
|
||||
"""All-reduce on a square mesh.
|
||||
|
||||
Args:
|
||||
t_ptr: HBM base address (column-sharded VA shared across ranks)
|
||||
n_elem: number of f16 elements per tile
|
||||
side: mesh side length (sqrt(world_size))
|
||||
tl: TLContext (ADR-0022).
|
||||
"""
|
||||
local_pe = tl.program_id(axis=0)
|
||||
cube_id = tl.program_id(axis=1)
|
||||
pes_per_cube = tl.num_programs(axis=0)
|
||||
rank = cube_id * pes_per_cube + local_pe
|
||||
nbytes = n_elem * 2
|
||||
|
||||
pe_addr = t_ptr + rank * nbytes
|
||||
acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16")
|
||||
current = acc
|
||||
|
||||
# ── Phase 1: row ring (E direction) ──
|
||||
# Ring forwards each received tile (not the cumulative acc) so every
|
||||
# tile passes through every rank exactly once.
|
||||
for _ in range(side - 1):
|
||||
tl.send(dir="E", src=current)
|
||||
recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
current = recv
|
||||
|
||||
# Phase 2 column ring starts from the row-phase accumulator. We do NOT
|
||||
# store/reload here — the math handle's scratch addr is the source for
|
||||
# the first column send and Phase 2 ipcq_copy replays from there.
|
||||
current = acc
|
||||
|
||||
# ── Phase 2: column ring (S direction) ──
|
||||
for _ in range(side - 1):
|
||||
tl.send(dir="S", src=current)
|
||||
recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
current = recv
|
||||
|
||||
tl.store(pe_addr, acc)
|
||||
@@ -1,80 +0,0 @@
|
||||
"""Ring all-reduce kernel for IPCQ-based PE collective (ADR-0023).
|
||||
|
||||
Algorithm: 1D ring of N PEs, each PE starts with one tile of data.
|
||||
After ``world_size - 1`` rounds, every PE's accumulator holds the sum
|
||||
of all PE tiles.
|
||||
|
||||
Strategy
|
||||
--------
|
||||
Each PE starts with its own tile in HBM. The kernel:
|
||||
1. Loads the local tile into a TensorHandle (the accumulator).
|
||||
2. In each of ``world_size - 1`` rounds:
|
||||
- Sends the current accumulator/recv slot to the E neighbor.
|
||||
- Receives a tile from the W neighbor — the recv handle points
|
||||
into the per-direction TCM slot.
|
||||
- Adds the received tile to the accumulator using the TensorHandle
|
||||
operator overload, which dispatches to ``MathCmd`` (PE_MATH).
|
||||
3. Stores the final accumulator back to HBM via tl.store. The store is
|
||||
recorded in op_log with both src and dst, so Phase 2 will copy the
|
||||
replayed math result from PE-local scratch into HBM.
|
||||
|
||||
ADR-0020 D3 split: Phase 1 simulates timing only — math results are
|
||||
not yet computed, so the accumulator data flowing through Phase 1 may
|
||||
be stale. Phase 2's DataExecutor replays math + IPCQ copies + dma_write
|
||||
in stable t_start order, producing correct final HBM contents.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def kernel_args(world_size: int, n_elem: int) -> tuple:
|
||||
"""Return the positional kernel arguments for the ahbm backend.
|
||||
|
||||
Ring all-reduce takes (n_elem, world_size) after the tensor pointer.
|
||||
"""
|
||||
return (n_elem, world_size)
|
||||
|
||||
|
||||
def kernel(t_ptr, n_elem, world_size, tl):
|
||||
"""Ring all-reduce.
|
||||
|
||||
Args:
|
||||
t_ptr: HBM base address of the column-sharded tensor — all PEs
|
||||
share this base. The per-PE slice lives at
|
||||
``t_ptr + global_rank * n_elem * 2``.
|
||||
n_elem: number of f16 elements per tile.
|
||||
world_size: total number of participating ranks (passed by host).
|
||||
tl: TLContext (auto-injected, ADR-0022). The kernel derives the
|
||||
global rank from ``program_id(axis=0)`` (local PE) and
|
||||
``program_id(axis=1)`` (cube id):
|
||||
|
||||
rank = cube_id * pes_per_cube + local_pe
|
||||
"""
|
||||
local_pe = tl.program_id(axis=0)
|
||||
cube_id = tl.program_id(axis=1)
|
||||
pes_per_cube = tl.num_programs(axis=0)
|
||||
rank = cube_id * pes_per_cube + local_pe
|
||||
nbytes = n_elem * 2 # f16
|
||||
|
||||
# Each PE reads from its own slice of the shared base address
|
||||
pe_addr = t_ptr + rank * nbytes
|
||||
|
||||
# Load the local tile — handle points at HBM[pe_addr].
|
||||
acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16")
|
||||
# The ring forwards each received tile to the next neighbor (NOT the
|
||||
# cumulative accumulator), so every rank's tile passes through every
|
||||
# rank exactly once. The accumulator sums the new arrival each round.
|
||||
current = acc
|
||||
|
||||
for _step in range(world_size - 1):
|
||||
tl.send(dir="E", src=current)
|
||||
recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
|
||||
# TensorHandle add → MathCmd → PE_MATH (timing in Phase 1, real
|
||||
# numpy in Phase 2 via DataExecutor). The result handle lives at
|
||||
# an auto-allocated PE-local scratch addr.
|
||||
acc = acc + recv
|
||||
current = recv # forward W's tile to E next round
|
||||
|
||||
# Final result back to this PE's HBM slice. Op_log captures the
|
||||
# source (scratch addr) and dst (HBM slice) so Phase 2 copies the
|
||||
# accumulated value into HBM for verification.
|
||||
tl.store(pe_addr, acc)
|
||||
@@ -1,80 +0,0 @@
|
||||
"""Tree all-reduce kernel for IPCQ-based PE collective (ADR-0023).
|
||||
|
||||
Two-phase binary tree all-reduce:
|
||||
|
||||
Phase 1 (reduce up):
|
||||
- leaf nodes send their value to ``parent``
|
||||
- internal nodes recv from each child, sum, then send to ``parent``
|
||||
- root accumulates child contributions; final acc holds global sum
|
||||
|
||||
Phase 2 (broadcast down):
|
||||
- root sends acc to ``child_left`` and ``child_right`` (if present)
|
||||
- internal nodes recv from ``parent``, then forward to children
|
||||
- all ranks store the final acc to HBM
|
||||
|
||||
Uses TensorHandle math (PE_MATH) for accumulation. Op_log captures the
|
||||
data flow so Phase 2 produces correct final HBM contents. The kernel
|
||||
deliberately avoids the store→reload→send pattern: math/recv handles
|
||||
are passed directly to the next send so PE_DMA snapshots a deterministic
|
||||
source addr that Phase 2 can replay.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def kernel_args(world_size: int, n_elem: int) -> tuple:
|
||||
"""Return the positional kernel arguments for the ahbm backend."""
|
||||
return (n_elem, world_size)
|
||||
|
||||
|
||||
def kernel(t_ptr, n_elem, world_size, tl):
|
||||
"""Tree all-reduce.
|
||||
|
||||
Args:
|
||||
t_ptr: HBM base address.
|
||||
n_elem: number of f16 elements per tile.
|
||||
world_size: total number of participating ranks (passed by host).
|
||||
tl: TLContext (ADR-0022). Global rank from program_id(0/1).
|
||||
"""
|
||||
local_pe = tl.program_id(axis=0)
|
||||
cube_id = tl.program_id(axis=1)
|
||||
pes_per_cube = tl.num_programs(axis=0)
|
||||
rank = cube_id * pes_per_cube + local_pe
|
||||
nbytes = n_elem * 2
|
||||
|
||||
pe_addr = t_ptr + rank * nbytes
|
||||
acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16")
|
||||
|
||||
# Compute children/parent existence (matches tree_binary topology generator)
|
||||
has_parent = rank > 0
|
||||
left = 2 * rank + 1
|
||||
right = 2 * rank + 2
|
||||
has_left = left < world_size
|
||||
has_right = right < world_size
|
||||
|
||||
# ── Phase 1: reduce up ──
|
||||
if has_left:
|
||||
recv = tl.recv(dir="child_left", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
if has_right:
|
||||
recv = tl.recv(dir="child_right", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
|
||||
if has_parent:
|
||||
# Send the math/load handle directly — its addr is either the
|
||||
# original HBM tile (leaf) or the PE-local scratch where the
|
||||
# accumulator lives. Phase 2 ipcq_copy replays from the same addr.
|
||||
tl.send(dir="parent", src=acc)
|
||||
|
||||
# ── Phase 2: broadcast down ──
|
||||
if has_parent:
|
||||
# Replace acc with the value broadcast from the parent (the global
|
||||
# sum). The recv handle points at the parent-direction TCM slot.
|
||||
acc = tl.recv(dir="parent", shape=(n_elem,), dtype="f16")
|
||||
|
||||
if has_left:
|
||||
tl.send(dir="child_left", src=acc)
|
||||
if has_right:
|
||||
tl.send(dir="child_right", src=acc)
|
||||
|
||||
# Final store to HBM for the bench's verification path.
|
||||
tl.store(pe_addr, acc)
|
||||
@@ -219,7 +219,11 @@ def install_ipcq(
|
||||
"neighbor_table": neighbor_table,
|
||||
}
|
||||
|
||||
_OPPOSITE_DIR = {"E": "W", "W": "E", "N": "S", "S": "N"}
|
||||
_OPPOSITE_DIR = {
|
||||
"E": "W", "W": "E", "N": "S", "S": "N",
|
||||
"global_E": "global_W", "global_W": "global_E",
|
||||
"global_N": "global_S", "global_S": "global_N",
|
||||
}
|
||||
|
||||
def reverse_direction(my_rank: int, peer_rank: int, my_dir: str) -> str | None:
|
||||
"""Find peer's direction that reciprocates my_dir→peer_rank.
|
||||
|
||||
@@ -0,0 +1,104 @@
|
||||
"""SFR configuration for intercube + inter-SIP IPCQ wiring.
|
||||
|
||||
Provides ``configure_sfr_intercube_multisip`` which programs PE_IPCQ
|
||||
neighbor tables for:
|
||||
|
||||
1. Intercube within each SIP — pe0 of every cube connects to pe0 of
|
||||
its N/S/E/W mesh neighbors (no wrap-around).
|
||||
2. Inter-SIP on ALL cubes — pe0 of cube_c on sip_A connects to pe0 of
|
||||
cube_c on each peer SIP, using ``global_E``/``global_W`` (ring) or
|
||||
``global_N``/``global_S``/``global_E``/``global_W`` (mesh/torus)
|
||||
direction labels. Wiring all cubes allows the kernel to
|
||||
dynamically elect the root cube at runtime.
|
||||
|
||||
SIP-level topology is read from ``topology.yaml`` →
|
||||
``system.sips.topology`` (e.g. ``ring_1d``, ``mesh_2d``).
|
||||
Intercube mesh dimensions come from ``sip.cube_mesh.w/h``.
|
||||
|
||||
Internally delegates to ``install_ipcq`` with a computed ``rank_to_pe``
|
||||
(pe0-only) and a closure-captured ``neighbors()`` function.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import types
|
||||
from typing import Any
|
||||
|
||||
from kernbench.ccl.install import install_ipcq
|
||||
from kernbench.ccl.topologies import _BUILTIN as _TOPO_BUILTINS
|
||||
|
||||
|
||||
def configure_sfr_intercube_multisip(
|
||||
engine: Any,
|
||||
spec: dict,
|
||||
cfg: dict,
|
||||
) -> dict[str, Any]:
|
||||
"""Wire IPCQ for intercube (pe0, mesh) + inter-SIP (pe0, all cubes).
|
||||
|
||||
Args:
|
||||
engine: GraphEngine with ``_components``.
|
||||
spec: topology spec dict (from topology.yaml).
|
||||
cfg: merged algorithm config (from ``resolve_algorithm_config``).
|
||||
|
||||
Returns:
|
||||
The install plan dict from ``install_ipcq``.
|
||||
"""
|
||||
cm = spec["sip"]["cube_mesh"]
|
||||
mesh_w = int(cm["w"])
|
||||
mesh_h = int(cm["h"])
|
||||
n_cubes = mesh_w * mesh_h
|
||||
n_sips = int(spec.get("system", {}).get("sips", {}).get("count", 1))
|
||||
sip_topology = str(
|
||||
spec.get("system", {}).get("sips", {}).get("topology", "ring_1d")
|
||||
)
|
||||
|
||||
if sip_topology not in _TOPO_BUILTINS:
|
||||
raise ValueError(
|
||||
f"Unknown sip topology '{sip_topology}'. "
|
||||
f"Available: {list(_TOPO_BUILTINS)}"
|
||||
)
|
||||
sip_topo_fn = _TOPO_BUILTINS[sip_topology]
|
||||
|
||||
world_size = n_sips * n_cubes
|
||||
pe_idx_to_pe: list[tuple[int, int, int]] = [
|
||||
(sip, cube, 0)
|
||||
for sip in range(n_sips)
|
||||
for cube in range(n_cubes)
|
||||
]
|
||||
|
||||
def _neighbors(pe_idx: int, ws: int, _base: dict) -> dict[str, int]:
|
||||
sip = pe_idx // n_cubes
|
||||
cube = pe_idx % n_cubes
|
||||
row = cube // mesh_w
|
||||
col = cube % mesh_w
|
||||
|
||||
nbrs: dict[str, int] = {}
|
||||
|
||||
# Intercube within SIP (mesh, no wrap-around)
|
||||
if col < mesh_w - 1:
|
||||
nbrs["E"] = sip * n_cubes + (row * mesh_w + col + 1)
|
||||
if col > 0:
|
||||
nbrs["W"] = sip * n_cubes + (row * mesh_w + col - 1)
|
||||
if row < mesh_h - 1:
|
||||
nbrs["S"] = sip * n_cubes + ((row + 1) * mesh_w + col)
|
||||
if row > 0:
|
||||
nbrs["N"] = sip * n_cubes + ((row - 1) * mesh_w + col)
|
||||
|
||||
# Inter-SIP on ALL cubes
|
||||
if n_sips > 1:
|
||||
sip_nbrs = sip_topo_fn(sip, n_sips)
|
||||
for d, peer_sip in sip_nbrs.items():
|
||||
nbrs[f"global_{d}"] = peer_sip * n_cubes + cube
|
||||
|
||||
return nbrs
|
||||
|
||||
mock_module = types.SimpleNamespace(neighbors=_neighbors)
|
||||
|
||||
cfg_copy = dict(cfg)
|
||||
cfg_copy["world_size"] = world_size
|
||||
cfg_copy["topology"] = "none"
|
||||
|
||||
return install_ipcq(
|
||||
engine, spec, cfg_copy,
|
||||
algo_module=mock_module,
|
||||
rank_to_pe=pe_idx_to_pe,
|
||||
)
|
||||
@@ -1,492 +0,0 @@
|
||||
"""Mock CCL runtime for fast unit tests of algorithm kernels (ADR-0023 D15).
|
||||
|
||||
Runs a kernel function once per rank with a minimal ``tl`` shim — no SimPy,
|
||||
no PE_DMA, no fabric simulation. Just enough to verify *functional*
|
||||
correctness of an IPCQ-based collective algorithm.
|
||||
|
||||
Cross-rank send/recv is implemented with greenlet cooperative scheduling
|
||||
plus per-(rank, direction) FIFO queues. Backpressure is not modeled —
|
||||
queues are unbounded.
|
||||
|
||||
Typical usage in a test::
|
||||
|
||||
from kernbench.ccl.testing import run_kernel_in_mock
|
||||
from kernbench.ccl.algorithms.ring_allreduce import kernel
|
||||
|
||||
inputs = [np.full(16, r + 1, dtype="f16") for r in range(4)]
|
||||
outputs = run_kernel_in_mock(
|
||||
kernel_fn=kernel, world_size=4, topology="ring_1d",
|
||||
inputs=inputs, kernel_args=(16,),
|
||||
)
|
||||
for r in range(4):
|
||||
assert np.allclose(outputs[r], sum(inputs))
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
from typing import Any, Callable
|
||||
|
||||
import numpy as np
|
||||
from greenlet import greenlet
|
||||
|
||||
from kernbench.ccl.topologies import resolve_topology
|
||||
from kernbench.common.ipcq_types import IpcqInvalidDirection
|
||||
from kernbench.common.pe_commands import TensorHandle
|
||||
|
||||
|
||||
# ── Per-rank fake state ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class _MockRankState:
|
||||
"""Per-rank scratch holding HBM/recv slots and tl shim hooks."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
neighbors: dict[str, int],
|
||||
input_arr: np.ndarray,
|
||||
pes_per_cube: int = 0,
|
||||
) -> None:
|
||||
self.rank = rank
|
||||
self.world_size = world_size
|
||||
# PEs per cube for program_id(axis=0/1). If 0 or world_size,
|
||||
# all ranks are in one cube (legacy single-cube behavior).
|
||||
self.pes_per_cube = pes_per_cube if pes_per_cube > 0 else world_size
|
||||
self.neighbors = neighbors # direction → peer rank
|
||||
# HBM "memory": addr → ndarray. Per-rank, no cross-rank sharing.
|
||||
self._hbm: dict[int, np.ndarray] = {}
|
||||
self._tcm: dict[int, np.ndarray] = {}
|
||||
# ``t_ptr`` is the address the kernel sees. Real benches use a
|
||||
# column-sharded VA so each rank reads from ``t_ptr + rank*nbytes``.
|
||||
# Mirror that here: each rank's slice lives at the rank-specific addr.
|
||||
nbytes = int(input_arr.nbytes)
|
||||
self.t_ptr = 0 # base; per-rank offset is rank * nbytes
|
||||
self._slice_addr = rank * nbytes
|
||||
self._hbm[self._slice_addr] = input_arr.copy()
|
||||
# Inbound recv FIFOs: direction → deque[ndarray]
|
||||
self.recv_q: dict[str, deque[np.ndarray]] = {d: deque() for d in neighbors}
|
||||
# Output (set when kernel calls tl.store at slice address)
|
||||
self.output: np.ndarray | None = None
|
||||
# Greenlet for this rank — set later
|
||||
self.g: greenlet | None = None
|
||||
|
||||
|
||||
# ── Mock TLContext ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _MockTL:
|
||||
"""Drop-in tl shim for mock runtime.
|
||||
|
||||
Supports the subset of TLContext API that algorithm authors use:
|
||||
program_id, num_programs, load, store, send, recv, recv_async, wait,
|
||||
plus arithmetic operations on TensorHandle (eager numpy execution,
|
||||
no SimPy involved).
|
||||
"""
|
||||
|
||||
def __init__(self, state: _MockRankState, scheduler: "_MockScheduler") -> None:
|
||||
self._state = state
|
||||
self._scheduler = scheduler
|
||||
self._handle_counter = 0
|
||||
|
||||
def _next_id(self) -> str:
|
||||
self._handle_counter += 1
|
||||
return f"mt{self._handle_counter}"
|
||||
|
||||
@property
|
||||
def rank(self) -> int:
|
||||
return self._state.rank
|
||||
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return self._state.world_size
|
||||
|
||||
# axis-aware
|
||||
def program_id(self, axis: int = 0) -> int:
|
||||
# Multi-cube: axis=0 = PE within cube, axis=1 = global cube id.
|
||||
# Falls back to flat (all ranks in one cube) if pes_per_cube
|
||||
# is not set (legacy single-cube tests).
|
||||
ppc = self._state.pes_per_cube
|
||||
if axis == 1:
|
||||
return self._state.rank // ppc
|
||||
return self._state.rank % ppc
|
||||
|
||||
def num_programs(self, axis: int = 0) -> int:
|
||||
ppc = self._state.pes_per_cube
|
||||
if axis == 1:
|
||||
return self._state.world_size // ppc
|
||||
return ppc
|
||||
|
||||
# ── arithmetic ops (called by TensorHandle.__add__ etc.) ──
|
||||
|
||||
def _binary_math(self, op: str, a: TensorHandle, b: TensorHandle) -> TensorHandle:
|
||||
a_data = np.asarray(a.data) if a.data is not None else None
|
||||
b_data = np.asarray(b.data) if b.data is not None else None
|
||||
if a_data is None or b_data is None:
|
||||
result = None
|
||||
elif op == "add":
|
||||
result = a_data + b_data
|
||||
elif op == "sub":
|
||||
result = a_data - b_data
|
||||
elif op == "mul":
|
||||
result = a_data * b_data
|
||||
elif op == "div":
|
||||
result = a_data / b_data
|
||||
elif op == "maximum":
|
||||
result = np.maximum(a_data, b_data)
|
||||
elif op == "minimum":
|
||||
result = np.minimum(a_data, b_data)
|
||||
else:
|
||||
raise NotImplementedError(f"mock _binary_math: op {op!r} not implemented")
|
||||
return TensorHandle(
|
||||
id=self._next_id(),
|
||||
addr=0, shape=a.shape, dtype=a.dtype,
|
||||
nbytes=int(np.prod(a.shape)) * 2 if a.shape else 0,
|
||||
data=result, space="tcm",
|
||||
)
|
||||
|
||||
def maximum(self, a: TensorHandle, b: TensorHandle) -> TensorHandle:
|
||||
return self._binary_math("maximum", a, b)
|
||||
|
||||
def minimum(self, a: TensorHandle, b: TensorHandle) -> TensorHandle:
|
||||
return self._binary_math("minimum", a, b)
|
||||
|
||||
def fma(
|
||||
self, a: TensorHandle, b: TensorHandle, c: TensorHandle,
|
||||
) -> TensorHandle:
|
||||
a_data = np.asarray(a.data) if a.data is not None else None
|
||||
b_data = np.asarray(b.data) if b.data is not None else None
|
||||
c_data = np.asarray(c.data) if c.data is not None else None
|
||||
result = (
|
||||
a_data * b_data + c_data
|
||||
if (a_data is not None and b_data is not None and c_data is not None)
|
||||
else None
|
||||
)
|
||||
return TensorHandle(
|
||||
id=self._next_id(),
|
||||
addr=0, shape=a.shape, dtype=a.dtype,
|
||||
nbytes=int(np.prod(a.shape)) * 2 if a.shape else 0,
|
||||
data=result, space="tcm",
|
||||
)
|
||||
|
||||
def clamp(
|
||||
self,
|
||||
x: TensorHandle,
|
||||
min: TensorHandle,
|
||||
max: TensorHandle,
|
||||
) -> TensorHandle:
|
||||
x_data = np.asarray(x.data) if x.data is not None else None
|
||||
lo = np.asarray(min.data) if min.data is not None else None
|
||||
hi = np.asarray(max.data) if max.data is not None else None
|
||||
result = (
|
||||
np.minimum(np.maximum(x_data, lo), hi)
|
||||
if (x_data is not None and lo is not None and hi is not None)
|
||||
else None
|
||||
)
|
||||
return TensorHandle(
|
||||
id=self._next_id(),
|
||||
addr=0, shape=x.shape, dtype=x.dtype,
|
||||
nbytes=int(np.prod(x.shape)) * 2 if x.shape else 0,
|
||||
data=result, space="tcm",
|
||||
)
|
||||
|
||||
def softmax(self, x: TensorHandle, axis: int = -1) -> TensorHandle:
|
||||
x_data = np.asarray(x.data) if x.data is not None else None
|
||||
if x_data is None:
|
||||
result = None
|
||||
else:
|
||||
x_max = np.max(x_data, axis=axis, keepdims=True)
|
||||
e = np.exp(x_data - x_max)
|
||||
s = np.sum(e, axis=axis, keepdims=True)
|
||||
result = e / s
|
||||
return TensorHandle(
|
||||
id=self._next_id(),
|
||||
addr=0, shape=x.shape, dtype=x.dtype,
|
||||
nbytes=int(np.prod(x.shape)) * 2 if x.shape else 0,
|
||||
data=result, space="tcm",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def cdiv(a: int, b: int) -> int:
|
||||
return -(-int(a) // int(b))
|
||||
|
||||
def _unary_math(self, op: str, x: TensorHandle) -> TensorHandle:
|
||||
x_data = np.asarray(x.data) if x.data is not None else None
|
||||
if x_data is None:
|
||||
result = None
|
||||
elif op == "exp":
|
||||
result = np.exp(x_data)
|
||||
elif op == "log":
|
||||
result = np.log(x_data)
|
||||
elif op == "sqrt":
|
||||
result = np.sqrt(x_data)
|
||||
elif op == "abs":
|
||||
result = np.abs(x_data)
|
||||
elif op == "sigmoid":
|
||||
result = 1.0 / (1.0 + np.exp(-x_data))
|
||||
elif op == "cos":
|
||||
result = np.cos(x_data)
|
||||
elif op == "sin":
|
||||
result = np.sin(x_data)
|
||||
else:
|
||||
raise NotImplementedError(f"mock _unary_math: op {op!r} not implemented")
|
||||
return TensorHandle(
|
||||
id=self._next_id(),
|
||||
addr=0, shape=x.shape, dtype=x.dtype,
|
||||
nbytes=int(np.prod(x.shape)) * 2 if x.shape else 0,
|
||||
data=result, space="tcm",
|
||||
)
|
||||
|
||||
def load(self, ptr: int, shape: tuple[int, ...], dtype: str = "f16") -> TensorHandle:
|
||||
data = self._state._hbm.get(ptr)
|
||||
if data is None:
|
||||
data = np.zeros(shape, dtype=np.float16)
|
||||
return TensorHandle(
|
||||
id=f"load_{ptr}", addr=ptr, shape=shape, dtype=dtype,
|
||||
nbytes=int(np.prod(shape)) * 2, data=data, space="hbm",
|
||||
)
|
||||
|
||||
def store(self, ptr: int, handle: TensorHandle) -> None:
|
||||
if handle.data is not None:
|
||||
self._state._hbm[ptr] = np.asarray(handle.data)
|
||||
if ptr == self._state._slice_addr:
|
||||
self._state.output = self._state._hbm[ptr]
|
||||
|
||||
# IPCQ
|
||||
def send(
|
||||
self,
|
||||
dir: str,
|
||||
src: TensorHandle | None = None,
|
||||
*,
|
||||
src_addr: int | None = None,
|
||||
nbytes: int | None = None,
|
||||
shape: tuple[int, ...] | None = None,
|
||||
dtype: str = "f16",
|
||||
space: str = "tcm",
|
||||
) -> None:
|
||||
if dir not in self._state.neighbors:
|
||||
raise IpcqInvalidDirection(
|
||||
f"mock tl.send: direction {dir!r} not in neighbors {list(self._state.neighbors)}"
|
||||
)
|
||||
if src is not None:
|
||||
if src.data is not None:
|
||||
data = np.asarray(src.data)
|
||||
else:
|
||||
# Resolve from this rank's local memory at src.addr
|
||||
space_dict = self._state._hbm if src.space == "hbm" else self._state._tcm
|
||||
stored = space_dict.get(src.addr)
|
||||
if stored is None:
|
||||
raise RuntimeError(
|
||||
f"mock tl.send: no data at {src.space}:0x{src.addr:x}"
|
||||
)
|
||||
data = np.asarray(stored)
|
||||
else:
|
||||
data = None
|
||||
if data is None:
|
||||
raise RuntimeError("mock tl.send: src is None")
|
||||
peer_rank = self._state.neighbors[dir]
|
||||
# Find the reverse direction at the peer, mirroring real IPCQ
|
||||
# install pairing: N↔S, E↔W, parent↔parent, child_left↔child_left, etc.
|
||||
_REVERSE = {"N": "S", "S": "N", "E": "W", "W": "E",
|
||||
"parent": "parent", "child_left": "child_left",
|
||||
"child_right": "child_right"}
|
||||
peer_state = self._scheduler.states[peer_rank]
|
||||
reverse_dir = _REVERSE.get(dir)
|
||||
# Fall back to "first direction pointing at me" if the explicit
|
||||
# reverse doesn't exist at the peer (e.g. custom directions).
|
||||
if reverse_dir is None or reverse_dir not in peer_state.neighbors:
|
||||
reverse_dir = None
|
||||
for d, target in peer_state.neighbors.items():
|
||||
if target == self._state.rank:
|
||||
reverse_dir = d
|
||||
break
|
||||
if reverse_dir is None:
|
||||
raise RuntimeError(
|
||||
f"mock tl.send: peer rank {peer_rank} has no reverse direction"
|
||||
)
|
||||
peer_state.recv_q[reverse_dir].append(data.copy())
|
||||
self._scheduler._send_counter += 1
|
||||
# After delivering, hand control back to scheduler so the receiver
|
||||
# can wake up.
|
||||
self._scheduler.yield_()
|
||||
|
||||
def recv_async(
|
||||
self,
|
||||
dir: str,
|
||||
shape: tuple[int, ...] = (),
|
||||
dtype: str = "f16",
|
||||
) -> dict:
|
||||
"""Non-blocking recv. Returns a future dict to pass to tl.wait."""
|
||||
if dir not in self._state.neighbors:
|
||||
raise IpcqInvalidDirection(
|
||||
f"mock tl.recv_async: direction {dir!r} not in neighbors"
|
||||
)
|
||||
return {"_kind": "recv_future", "dir": dir, "shape": shape, "dtype": dtype}
|
||||
|
||||
def wait(self, future: Any) -> TensorHandle:
|
||||
"""Block until the recv future has data."""
|
||||
if not isinstance(future, dict) or future.get("_kind") != "recv_future":
|
||||
raise TypeError("tl.wait: expected recv future from tl.recv_async")
|
||||
d = future["dir"]
|
||||
while not self._state.recv_q[d]:
|
||||
self._scheduler.yield_()
|
||||
data = self._state.recv_q[d].popleft()
|
||||
return self._make_handle(data, d, future["dtype"])
|
||||
|
||||
def recv(
|
||||
self,
|
||||
dir: str | None = None,
|
||||
shape: tuple[int, ...] = (),
|
||||
dtype: str = "f16",
|
||||
) -> TensorHandle:
|
||||
if dir is not None and dir not in self._state.neighbors:
|
||||
raise IpcqInvalidDirection(
|
||||
f"mock tl.recv: direction {dir!r} not in neighbors {list(self._state.neighbors)}"
|
||||
)
|
||||
# Wait for data
|
||||
while True:
|
||||
if dir is None:
|
||||
# round-robin over directions
|
||||
for d in self._state.neighbors:
|
||||
if self._state.recv_q[d]:
|
||||
data = self._state.recv_q[d].popleft()
|
||||
return self._make_handle(data, d, dtype)
|
||||
else:
|
||||
if self._state.recv_q[dir]:
|
||||
data = self._state.recv_q[dir].popleft()
|
||||
return self._make_handle(data, dir, dtype)
|
||||
# Yield to other ranks
|
||||
self._scheduler.yield_()
|
||||
|
||||
def _make_handle(self, data: np.ndarray, direction: str, dtype: str) -> TensorHandle:
|
||||
return TensorHandle(
|
||||
id=f"recv_{direction}",
|
||||
addr=0, shape=data.shape, dtype=dtype,
|
||||
nbytes=int(data.nbytes), data=data, space="tcm",
|
||||
)
|
||||
|
||||
|
||||
# ── Cooperative scheduler ────────────────────────────────────────────
|
||||
|
||||
|
||||
class _MockScheduler:
|
||||
"""Round-robin cooperative scheduler over rank greenlets."""
|
||||
|
||||
def __init__(self, states: list[_MockRankState]) -> None:
|
||||
self.states = states
|
||||
self._parent: greenlet | None = None
|
||||
self._cur_idx = 0
|
||||
|
||||
def yield_(self) -> None:
|
||||
"""Called from inside a rank greenlet to give other ranks a turn."""
|
||||
assert self._parent is not None
|
||||
self._parent.switch()
|
||||
|
||||
def run(self, kernel_fn: Callable, kernel_args: tuple) -> list[np.ndarray]:
|
||||
from kernbench.triton_emu.tl_context import TLContext
|
||||
|
||||
self._parent = greenlet.getcurrent()
|
||||
n = len(self.states)
|
||||
|
||||
# Per-rank tl shim
|
||||
tls: dict[int, _MockTL] = {}
|
||||
|
||||
def _spawn(rank_idx: int) -> greenlet:
|
||||
state = self.states[rank_idx]
|
||||
tl = _MockTL(state, self)
|
||||
tls[rank_idx] = tl
|
||||
|
||||
def _entry():
|
||||
# Activate this rank's tl for TensorHandle operator overloads
|
||||
TLContext._set_active(tl) # type: ignore[attr-defined]
|
||||
try:
|
||||
kernel_fn(state.t_ptr, *kernel_args, tl=tl)
|
||||
finally:
|
||||
TLContext._set_active(None) # type: ignore[attr-defined]
|
||||
|
||||
return greenlet(_entry)
|
||||
|
||||
for state in self.states:
|
||||
state.g = _spawn(state.rank)
|
||||
|
||||
# Drive each rank round-robin until all dead. Detect global deadlock.
|
||||
# A global send counter tracks whether any greenlet delivered data
|
||||
# in the current round. This is more reliable than queue-depth
|
||||
# tracking because a recv+send pair in the same round nets to zero
|
||||
# depth change yet still represents real progress.
|
||||
self._send_counter = 0
|
||||
max_idle_rounds = 10_000
|
||||
idle_rounds = 0
|
||||
while True:
|
||||
alive = [s for s in self.states if s.g is not None and not s.g.dead]
|
||||
if not alive:
|
||||
break
|
||||
counter_before = self._send_counter
|
||||
for s in self.states:
|
||||
if s.g is None or s.g.dead:
|
||||
continue
|
||||
TLContext._set_active(tls[s.rank]) # type: ignore[attr-defined]
|
||||
s.g.switch()
|
||||
TLContext._set_active(None) # type: ignore[attr-defined]
|
||||
any_died = any(s.g is not None and s.g.dead for s in self.states)
|
||||
if self._send_counter > counter_before or any_died:
|
||||
idle_rounds = 0
|
||||
else:
|
||||
idle_rounds += 1
|
||||
if idle_rounds >= max_idle_rounds:
|
||||
raise RuntimeError(
|
||||
"mock CCL runtime: deadlock detected (no progress for "
|
||||
f"{max_idle_rounds} rounds)"
|
||||
)
|
||||
|
||||
return [
|
||||
s.output if s.output is not None else s._hbm.get(s._slice_addr)
|
||||
for s in self.states
|
||||
]
|
||||
|
||||
|
||||
# ── Public entry ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def run_kernel_in_mock(
|
||||
kernel_fn: Callable,
|
||||
world_size: int,
|
||||
topology: str,
|
||||
inputs: list[np.ndarray],
|
||||
kernel_args: tuple = (),
|
||||
algo_module: Any | None = None,
|
||||
pes_per_cube: int = 0,
|
||||
) -> list[np.ndarray]:
|
||||
"""Run a CCL kernel under the mock runtime with no SimPy/fabric.
|
||||
|
||||
Args:
|
||||
kernel_fn: ``kernel(t_ptr, *kernel_args, tl=...)``
|
||||
world_size: number of ranks
|
||||
topology: builtin topology name (e.g. "ring_1d")
|
||||
inputs: per-rank input ndarrays. ``inputs[r]`` becomes rank r's
|
||||
local tile at HBM address 0.
|
||||
kernel_args: extra positional args after t_ptr
|
||||
algo_module: optional module providing ``neighbors()`` override
|
||||
pes_per_cube: PEs per cube for multi-cube program_id mapping.
|
||||
0 → single-cube legacy (all ranks in one cube).
|
||||
|
||||
Returns:
|
||||
Per-rank output ndarrays — whatever the kernel wrote via tl.store
|
||||
(or the original input if the kernel didn't store).
|
||||
"""
|
||||
if len(inputs) != world_size:
|
||||
raise ValueError(f"len(inputs)={len(inputs)} != world_size={world_size}")
|
||||
|
||||
topo_fn = resolve_topology(topology, algo_module=algo_module)
|
||||
states = [
|
||||
_MockRankState(
|
||||
rank=r, world_size=world_size,
|
||||
neighbors=topo_fn(r, world_size),
|
||||
input_arr=inputs[r],
|
||||
pes_per_cube=pes_per_cube,
|
||||
)
|
||||
for r in range(world_size)
|
||||
]
|
||||
|
||||
sched = _MockScheduler(states)
|
||||
return sched.run(kernel_fn, kernel_args)
|
||||
@@ -73,6 +73,39 @@ def tree_binary(rank: int, world_size: int) -> NeighborMap:
|
||||
return n
|
||||
|
||||
|
||||
def torus_2d(rank: int, world_size: int) -> NeighborMap:
|
||||
"""Square 2D torus (N/S/E/W) with wrap-around on all edges.
|
||||
|
||||
Alias for mesh_2d (which already wraps). Explicit name for clarity
|
||||
when used as a SIP-level topology.
|
||||
"""
|
||||
return mesh_2d(rank, world_size)
|
||||
|
||||
|
||||
def mesh_2d_no_wrap(rank: int, world_size: int) -> NeighborMap:
|
||||
"""Square 2D mesh (N/S/E/W) WITHOUT wrap-around.
|
||||
|
||||
Edge nodes have fewer neighbors (no wrapping). Used for SIP-level
|
||||
topologies where physical links don't wrap.
|
||||
"""
|
||||
side = int(round(world_size ** 0.5))
|
||||
if side * side != world_size:
|
||||
raise ValueError(
|
||||
f"mesh_2d_no_wrap requires square world_size, got {world_size}"
|
||||
)
|
||||
r, c = divmod(rank, side)
|
||||
n: NeighborMap = {}
|
||||
if r > 0:
|
||||
n["N"] = (r - 1) * side + c
|
||||
if r < side - 1:
|
||||
n["S"] = (r + 1) * side + c
|
||||
if c > 0:
|
||||
n["W"] = r * side + (c - 1)
|
||||
if c < side - 1:
|
||||
n["E"] = r * side + (c + 1)
|
||||
return n
|
||||
|
||||
|
||||
def none(rank: int, world_size: int) -> NeighborMap:
|
||||
"""Empty map — algorithm's neighbors() must build from scratch."""
|
||||
return {}
|
||||
@@ -82,6 +115,8 @@ _BUILTIN: dict[str, TopologyFn] = {
|
||||
"ring_1d": ring_1d,
|
||||
"ring_1d_unidir": ring_1d_unidir,
|
||||
"mesh_2d": mesh_2d,
|
||||
"torus_2d": torus_2d,
|
||||
"mesh_2d_no_wrap": mesh_2d_no_wrap,
|
||||
"tree_binary": tree_binary,
|
||||
"none": none,
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ Host bench code uses only real-PyTorch names:
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import math
|
||||
from typing import Any
|
||||
|
||||
|
||||
@@ -40,31 +41,35 @@ class AhbmCCLBackend:
|
||||
self._merged = resolve_algorithm_config(self._cfg_all)
|
||||
self._algo_module = importlib.import_module(self._merged["module"])
|
||||
self._world_size = self._resolve_world_size()
|
||||
# ADR-0024 D7: handles pending drain by the main scheduler.
|
||||
# Worker greenlets extend this list after submitting their collective
|
||||
# kernel, then yield. The bench `run()` loop drains the list after
|
||||
# all workers yielded (so all sibling kernels are live in SimPy
|
||||
# before any rank waits, avoiding cross-rank deadlock).
|
||||
self._pending_collective_handles: list = []
|
||||
self._dist_ctx: Any = None
|
||||
|
||||
# Eager IPCQ install — ``init_process_group`` time. Mirrors NCCL
|
||||
# communicator creation: done once, reused across every subsequent
|
||||
# collective call on the same process group.
|
||||
# ADR-0024 D2: rank → SIP representative PE mapping when world_size
|
||||
# fits in the topology's SIP count. Legacy "rank = flat PE index" is
|
||||
# preserved when ccl.yaml explicitly overrides world_size > SIP count
|
||||
# (backward compat path).
|
||||
spec = self.ctx.spec or {}
|
||||
n_sips = int(spec.get("system", {}).get("sips", {}).get("count", 1))
|
||||
if self._world_size <= n_sips:
|
||||
rank_to_pe = [(r, 0, 0) for r in range(self._world_size)]
|
||||
else:
|
||||
rank_to_pe = None
|
||||
self.ctx.install_ipcq(
|
||||
algorithm=self._merged["algorithm"],
|
||||
world_size_override=self._world_size,
|
||||
rank_to_pe=rank_to_pe,
|
||||
self._n_sips = int(spec.get("system", {}).get("sips", {}).get("count", 1))
|
||||
self._sip_topo = str(
|
||||
spec.get("system", {}).get("sips", {}).get("topology", "ring_1d")
|
||||
)
|
||||
cm = spec.get("sip", {}).get("cube_mesh", {})
|
||||
self._cube_w = int(cm.get("w", 4))
|
||||
self._cube_h = int(cm.get("h", 4))
|
||||
|
||||
# Resolve SIP topology dims for the kernel
|
||||
topo_map = getattr(self._algo_module, "TOPO_NAME_TO_KIND", None)
|
||||
if topo_map is not None:
|
||||
self._sip_topo_kind = topo_map.get(self._sip_topo, 0)
|
||||
else:
|
||||
self._sip_topo_kind = 0
|
||||
if self._sip_topo == "ring_1d":
|
||||
self._sip_topo_w, self._sip_topo_h = 0, 0
|
||||
else:
|
||||
side = int(round(math.sqrt(self._n_sips)))
|
||||
self._sip_topo_w, self._sip_topo_h = side, side
|
||||
|
||||
# IPCQ install: wire all pe0s across all cubes and SIPs
|
||||
engine = getattr(self.ctx, "engine", None)
|
||||
if engine is not None:
|
||||
from kernbench.ccl.sfr_config import configure_sfr_intercube_multisip
|
||||
configure_sfr_intercube_multisip(engine, spec, self._merged)
|
||||
|
||||
def _resolve_world_size(self) -> int:
|
||||
"""Derive world_size (priority: algorithm override > defaults > topology).
|
||||
@@ -109,15 +114,26 @@ class AhbmCCLBackend:
|
||||
n_elem = shards[0].nbytes // tensor.itemsize
|
||||
kernel_fn = self._algo_module.kernel
|
||||
kernel_args = self._algo_module.kernel_args(self._world_size, n_elem)
|
||||
# ADR-0024 D7: submit + yield. When running under the multi-greenlet
|
||||
# bench launcher, the scheduler (not the worker) drains the pending
|
||||
# handles. This is required because env.run must be invoked from the
|
||||
# MAIN greenlet — otherwise kernel_runner's spawned kernel-greenlet
|
||||
# captures the worker-greenlet as its `_parent`, and kernel
|
||||
# switch_to_simpy() returns control to the main scheduler loop
|
||||
# mid-wait, causing nested re-entry and the scheduler to spin.
|
||||
|
||||
# Resolve sip_rank from the current greenlet's bound rank
|
||||
from greenlet import getcurrent as _gc
|
||||
g = _gc()
|
||||
dist_ctx = getattr(self, "_dist_ctx", None)
|
||||
if dist_ctx is not None:
|
||||
sip_rank = int(dist_ctx._rank_by_greenlet.get(g, 0))
|
||||
else:
|
||||
sip_rank = 0
|
||||
|
||||
extra_args = (
|
||||
sip_rank,
|
||||
self._sip_topo_kind,
|
||||
self._sip_topo_w,
|
||||
self._sip_topo_h,
|
||||
)
|
||||
|
||||
pending = self.ctx.launch(
|
||||
self._merged["algorithm"], kernel_fn, tensor, *kernel_args,
|
||||
self._merged["algorithm"], kernel_fn, tensor,
|
||||
*kernel_args, *extra_args,
|
||||
_defer_wait=True,
|
||||
)
|
||||
from greenlet import getcurrent
|
||||
@@ -181,6 +197,7 @@ class DistributedContext:
|
||||
"DistributedContext not bound to a RuntimeContext"
|
||||
)
|
||||
self._backend = AhbmCCLBackend(torch_ctx=ctx)
|
||||
self._backend._dist_ctx = self
|
||||
|
||||
def is_initialized(self) -> bool:
|
||||
return self._backend is not None
|
||||
|
||||
@@ -0,0 +1,222 @@
|
||||
"""Config-driven multi-device allreduce test application.
|
||||
|
||||
Reads ``ccl.yaml`` + ``topology.yaml``, dynamically loads the kernel
|
||||
module from ``ccl.yaml → module``, and picks the inter-SIP exchange
|
||||
pattern from ``topology.yaml → system.sips.topology``.
|
||||
|
||||
Run directly::
|
||||
|
||||
python -m pytest tests/allreduce_app.py -v -s
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from kernbench.ccl.install import load_ccl_config, resolve_algorithm_config
|
||||
from kernbench.ccl.sfr_config import configure_sfr_intercube_multisip
|
||||
from kernbench.policy.placement.dp import DPPolicy
|
||||
|
||||
|
||||
def _sip_topo_dims(sip_topo: str, n_sips: int) -> tuple[int, int]:
|
||||
if sip_topo == "ring_1d":
|
||||
return (0, 0)
|
||||
side = int(round(math.sqrt(n_sips)))
|
||||
if side * side != n_sips:
|
||||
raise ValueError(
|
||||
f"SIP topology '{sip_topo}' requires square n_sips, got {n_sips}"
|
||||
)
|
||||
return (side, side)
|
||||
|
||||
|
||||
def run_allreduce(
|
||||
ctx: Any,
|
||||
engine: Any,
|
||||
spec: dict,
|
||||
*,
|
||||
algorithm: str | None = None,
|
||||
ccl_yaml: str | None = None,
|
||||
) -> dict:
|
||||
"""Config-driven allreduce: read yaml, load kernel, run.
|
||||
|
||||
Everything is resolved from config — no hardcoded kernel imports.
|
||||
"""
|
||||
cfg_all = load_ccl_config(ccl_yaml)
|
||||
cfg = resolve_algorithm_config(cfg_all, algorithm)
|
||||
|
||||
# Dynamic import from ccl.yaml → module
|
||||
algo_module = importlib.import_module(cfg["module"])
|
||||
kernel_fn = algo_module.kernel
|
||||
topo_name_to_kind = algo_module.TOPO_NAME_TO_KIND
|
||||
|
||||
n_elem = int(cfg.get("n_elem", 8))
|
||||
n_sips = int(spec.get("system", {}).get("sips", {}).get("count", 1))
|
||||
sip_topo = str(
|
||||
spec.get("system", {}).get("sips", {}).get("topology", "ring_1d")
|
||||
)
|
||||
|
||||
cm = spec["sip"]["cube_mesh"]
|
||||
cube_w = int(cm["w"])
|
||||
cube_h = int(cm["h"])
|
||||
n_cubes = cube_w * cube_h
|
||||
|
||||
sip_topo_kind = topo_name_to_kind.get(sip_topo, 0)
|
||||
sip_topo_w, sip_topo_h = _sip_topo_dims(sip_topo, n_sips)
|
||||
|
||||
algo_name = cfg.get("algorithm", "allreduce")
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"algorithm: {algo_name}")
|
||||
print(f"module: {cfg['module']}")
|
||||
print(f"sip_topology: {sip_topo}")
|
||||
print(f"kernel: {kernel_fn.__name__}")
|
||||
print(f"n_sips: {n_sips}")
|
||||
print(f"n_cubes: {n_cubes}")
|
||||
print(f"n_elem: {n_elem}")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
configure_sfr_intercube_multisip(engine, spec, cfg)
|
||||
|
||||
dp = DPPolicy(
|
||||
cube="row_wise", pe="replicate",
|
||||
num_pes=1, num_cubes=n_cubes,
|
||||
)
|
||||
|
||||
tensors = []
|
||||
for sip in range(n_sips):
|
||||
ctx.ahbm.set_device(sip)
|
||||
t = ctx.zeros(
|
||||
(n_cubes, n_elem), dtype="f16", dp=dp,
|
||||
name=f"sip{sip}",
|
||||
)
|
||||
t.copy_(ctx.from_numpy(
|
||||
np.full((n_cubes, n_elem), float(sip + 1), dtype=np.float16)
|
||||
))
|
||||
tensors.append(t)
|
||||
|
||||
for sip in range(n_sips):
|
||||
arr = tensors[sip].numpy()
|
||||
print(f"[SIP {sip}] input cube0[:4] = {arr[0][:4].tolist()} "
|
||||
f"cube{n_cubes - 1}[:4] = {arr[-1][:4].tolist()}")
|
||||
|
||||
t_start = engine._env.now
|
||||
|
||||
all_pending = []
|
||||
for sip_rank, t in enumerate(tensors):
|
||||
pending = ctx.launch(
|
||||
algo_name, kernel_fn, t,
|
||||
n_elem, cube_w, cube_h, n_sips, sip_rank,
|
||||
sip_topo_kind, sip_topo_w, sip_topo_h,
|
||||
_defer_wait=True,
|
||||
)
|
||||
all_pending.extend(pending)
|
||||
|
||||
for h, sip_id, meta in all_pending:
|
||||
ctx.wait(h, _meta=meta)
|
||||
|
||||
t_end = engine._env.now
|
||||
latency_ns = t_end - t_start
|
||||
print(f"\n[{algo_name} ws={n_sips}] sim latency = "
|
||||
f"{latency_ns:.1f} ns ({latency_ns / 1000:.3f} us)")
|
||||
|
||||
for key, (_, trace) in engine._results.items():
|
||||
if not isinstance(trace, dict):
|
||||
continue
|
||||
total = trace.get("total_ns", 0.0)
|
||||
pe_exec = trace.get("pe_exec_ns", 0.0) or 0.0
|
||||
network = total - pe_exec
|
||||
print(f" [{key}] total={total:.1f} ns "
|
||||
f"pe_exec={pe_exec:.1f} ns network={network:.1f} ns")
|
||||
|
||||
expected = float(n_cubes * sum(range(1, n_sips + 1)))
|
||||
|
||||
print()
|
||||
for sip in range(n_sips):
|
||||
arr = tensors[sip].numpy()
|
||||
print(f"[SIP {sip}] output cube0[:4] = {arr[0][:4].tolist()}")
|
||||
print(f"[SIP {sip}] output cube{n_cubes - 1}[:4] = {arr[-1][:4].tolist()}")
|
||||
|
||||
ok_cubes = 0
|
||||
for sip in range(n_sips):
|
||||
arr = tensors[sip].numpy()
|
||||
for cube_id in range(n_cubes):
|
||||
assert np.allclose(
|
||||
arr[cube_id], expected, rtol=1e-1, atol=1e-1,
|
||||
), (
|
||||
f"SIP{sip} cube {cube_id}: "
|
||||
f"got {arr[cube_id][:4]}, expected {expected}"
|
||||
)
|
||||
ok_cubes += 1
|
||||
|
||||
print(f"\n {algo_name} (ws={n_sips}): {ok_cubes} OK")
|
||||
|
||||
return {
|
||||
"expected": expected,
|
||||
"latency_ns": latency_ns,
|
||||
"ok_cubes": ok_cubes,
|
||||
}
|
||||
|
||||
|
||||
# ── pytest entry point ───────────────────────────────────────────────
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from kernbench.runtime_api.context import RuntimeContext
|
||||
from kernbench.runtime_api.types import DeviceSelector
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
from kernbench.topology.builder import resolve_topology
|
||||
|
||||
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
|
||||
|
||||
CONFIGS = [
|
||||
pytest.param("intercube_allreduce", "ring_1d", 2, id="ring_2sip"),
|
||||
pytest.param("intercube_allreduce", "torus_2d", 4, id="torus_4sip"),
|
||||
pytest.param("intercube_allreduce", "mesh_2d_no_wrap", 4, id="mesh_4sip"),
|
||||
]
|
||||
|
||||
|
||||
def _write_temp_configs(tmp_path, sip_topology, n_sips, algorithm):
|
||||
"""Write temp topology.yaml and ccl.yaml with the given overrides."""
|
||||
with open(TOPOLOGY_PATH) as f:
|
||||
topo_cfg = yaml.safe_load(f)
|
||||
topo_cfg["system"]["sips"]["count"] = n_sips
|
||||
topo_cfg["system"]["sips"]["topology"] = sip_topology
|
||||
topo_path = tmp_path / "topology.yaml"
|
||||
with open(topo_path, "w") as f:
|
||||
yaml.dump(topo_cfg, f, default_flow_style=False)
|
||||
|
||||
ccl_path = Path(__file__).parent.parent / "ccl.yaml"
|
||||
with open(ccl_path) as f:
|
||||
ccl_cfg = yaml.safe_load(f)
|
||||
ccl_cfg["defaults"]["algorithm"] = algorithm
|
||||
tmp_ccl = tmp_path / "ccl.yaml"
|
||||
with open(tmp_ccl, "w") as f:
|
||||
yaml.dump(ccl_cfg, f, default_flow_style=False)
|
||||
|
||||
return str(topo_path), str(tmp_ccl)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("algorithm,sip_topology,n_sips", CONFIGS)
|
||||
def test_allreduce(tmp_path, algorithm, sip_topology, n_sips):
|
||||
topo_path, ccl_path = _write_temp_configs(
|
||||
tmp_path, sip_topology, n_sips, algorithm,
|
||||
)
|
||||
topo = resolve_topology(topo_path)
|
||||
engine = GraphEngine(topo.topology_obj, enable_data=True)
|
||||
spec = topo.topology_obj.spec
|
||||
|
||||
with RuntimeContext(
|
||||
engine=engine,
|
||||
target_device=DeviceSelector("all"),
|
||||
correlation_id=f"test_{algorithm}_{sip_topology}",
|
||||
spec=spec,
|
||||
) as ctx:
|
||||
result = run_allreduce(
|
||||
ctx, engine, spec,
|
||||
algorithm=algorithm, ccl_yaml=ccl_path,
|
||||
)
|
||||
assert result["ok_cubes"] > 0
|
||||
@@ -1,108 +0,0 @@
|
||||
"""End-to-end matrix tests for the unified ``ccl_allreduce`` bench.
|
||||
|
||||
Only covers the rank = SIP TP launcher path (ADR-0024 + ADR-0027). Each
|
||||
case writes a tmp ``ccl.yaml`` that selects a specific (algorithm,
|
||||
buffer_kind) pair; ``world_size`` is always derived from topology SIP
|
||||
count (2 in the shipped topology).
|
||||
|
||||
The legacy rank = PE single-driver path was removed; intra-SIP PE-level
|
||||
collectives are expressed inside the kernel via ``tl.program_id`` and do
|
||||
not require a host-side ``ProcessGroup``.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import textwrap
|
||||
|
||||
import pytest
|
||||
|
||||
import kernbench.cli.main as cli_main
|
||||
|
||||
|
||||
CCL_YAML_TEMPLATE = textwrap.dedent("""\
|
||||
defaults:
|
||||
algorithm: {algorithm}
|
||||
buffer_kind: {buffer_kind}
|
||||
backpressure: sleep
|
||||
n_slots: 4
|
||||
slot_size: 4096
|
||||
vc_chunk_size: 256
|
||||
ipcq_credit_size_bytes: 16
|
||||
|
||||
algorithms:
|
||||
{algorithm}:
|
||||
module: {module}
|
||||
topology: {topology}
|
||||
buffer_kind: {buffer_kind}
|
||||
""")
|
||||
|
||||
|
||||
def _write_ccl_yaml(
|
||||
tmp_path,
|
||||
*,
|
||||
algorithm: str,
|
||||
module: str,
|
||||
topology: str,
|
||||
buffer_kind: str,
|
||||
) -> str:
|
||||
body = CCL_YAML_TEMPLATE.format(
|
||||
algorithm=algorithm,
|
||||
module=module,
|
||||
topology=topology,
|
||||
buffer_kind=buffer_kind,
|
||||
)
|
||||
(tmp_path / "ccl.yaml").write_text(body)
|
||||
return str(tmp_path)
|
||||
|
||||
|
||||
CASES = [
|
||||
# Ring all-reduce across SIPs (ws == topology SIP count = 2),
|
||||
# one case per IPCQ buffer location.
|
||||
pytest.param(
|
||||
"ring_allreduce_tcm", "kernbench.ccl.algorithms.ring_allreduce",
|
||||
"ring_1d", "tcm",
|
||||
id="ring_tcm",
|
||||
),
|
||||
pytest.param(
|
||||
"ring_allreduce_hbm", "kernbench.ccl.algorithms.ring_allreduce",
|
||||
"ring_1d", "hbm",
|
||||
id="ring_hbm",
|
||||
),
|
||||
pytest.param(
|
||||
"ring_allreduce_sram", "kernbench.ccl.algorithms.ring_allreduce",
|
||||
"ring_1d", "sram",
|
||||
id="ring_sram",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("algorithm,module,topology,buffer_kind", CASES)
|
||||
def test_ccl_allreduce_matrix(
|
||||
tmp_path, capsys, monkeypatch,
|
||||
algorithm, module, topology, buffer_kind,
|
||||
):
|
||||
"""Each (algorithm × buffer_kind) combo passes through the unified
|
||||
rank = SIP bench and yields ``ws OK`` where ``ws == topology SIP count``."""
|
||||
project_root = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "..")
|
||||
)
|
||||
yaml_dir = _write_ccl_yaml(
|
||||
tmp_path,
|
||||
algorithm=algorithm,
|
||||
module=module,
|
||||
topology=topology,
|
||||
buffer_kind=buffer_kind,
|
||||
)
|
||||
monkeypatch.chdir(yaml_dir)
|
||||
rc = cli_main.main([
|
||||
"run",
|
||||
"--topology", os.path.join(project_root, "topology.yaml"),
|
||||
"--bench", "ccl_allreduce",
|
||||
"--verify-data",
|
||||
])
|
||||
assert rc == 0
|
||||
out = capsys.readouterr().out
|
||||
assert "FAIL" not in out, f"unexpected FAIL in output:\n{out}"
|
||||
assert f"{algorithm}" in out and "OK" in out, (
|
||||
f"expected pass line for '{algorithm}' in output:\n{out}"
|
||||
)
|
||||
@@ -1,244 +0,0 @@
|
||||
"""Phase 1 tests for ADR-0024 SIP-level TP launcher (MVP scope).
|
||||
|
||||
Covers:
|
||||
- D1 world_size = SIP count fallback
|
||||
- D9 get_rank greenlet-local + _bind_rank
|
||||
- D10 torch.ahbm.set_device + torch.accelerator alias
|
||||
- D11 tensor placement scoped to current device SIP
|
||||
- D12/D13 run() spawns one greenlet per rank
|
||||
|
||||
Deferred to later ADR-0024 sub-phases:
|
||||
- D2 engine-routed install
|
||||
- D6 install_plan.py
|
||||
- D7 epoch barrier (this phase uses simple submit+yield+wait)
|
||||
- D8 validator registry
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import textwrap
|
||||
|
||||
import pytest
|
||||
from greenlet import greenlet
|
||||
|
||||
from kernbench.runtime_api.distributed import AhbmCCLBackend, DistributedContext
|
||||
|
||||
|
||||
# ── Fixtures / helpers ────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _FakeCtx:
|
||||
"""Minimal ctx double — only exposes what AhbmCCLBackend.__init__ uses.
|
||||
|
||||
Stubs install_ipcq so we can unit-test _resolve_world_size without
|
||||
touching the engine stack.
|
||||
"""
|
||||
|
||||
def __init__(self, spec: dict) -> None:
|
||||
self.spec = spec
|
||||
self.install_calls: list[dict] = []
|
||||
|
||||
def install_ipcq(self, **kwargs) -> dict:
|
||||
self.install_calls.append(dict(kwargs))
|
||||
return {}
|
||||
|
||||
|
||||
def _write_minimal_ccl_yaml(tmp_path) -> str:
|
||||
"""Write a ccl.yaml with NO world_size override — forces topology derivation."""
|
||||
body = textwrap.dedent("""\
|
||||
defaults:
|
||||
algorithm: ring_allreduce_tcm
|
||||
buffer_kind: tcm
|
||||
backpressure: sleep
|
||||
n_slots: 4
|
||||
slot_size: 4096
|
||||
vc_chunk_size: 256
|
||||
ipcq_credit_size_bytes: 16
|
||||
|
||||
algorithms:
|
||||
ring_allreduce_tcm:
|
||||
module: kernbench.ccl.algorithms.ring_allreduce
|
||||
topology: ring_1d
|
||||
buffer_kind: tcm
|
||||
n_elem: 8
|
||||
""")
|
||||
yaml_path = tmp_path / "ccl.yaml"
|
||||
yaml_path.write_text(body)
|
||||
return str(tmp_path)
|
||||
|
||||
|
||||
# ── D1: world_size = SIP count fallback ───────────────────────────────
|
||||
|
||||
|
||||
def test_world_size_equals_sip_count(tmp_path, monkeypatch, spec):
|
||||
"""With no override, backend derives world_size from SIP count only.
|
||||
|
||||
Topology has 2 SIPs × 16 cubes × 8 PEs = 256 PEs. The TP/DP model
|
||||
places the collective group at the SIP boundary, so world_size must
|
||||
equal SIP count (2), not total PE count (256).
|
||||
"""
|
||||
monkeypatch.chdir(_write_minimal_ccl_yaml(tmp_path))
|
||||
ctx = _FakeCtx(spec=spec)
|
||||
backend = AhbmCCLBackend(torch_ctx=ctx)
|
||||
expected = int(spec["system"]["sips"]["count"])
|
||||
assert backend.world_size == expected, (
|
||||
f"expected world_size == SIP count ({expected}); "
|
||||
f"got {backend.world_size} — still deriving sips × cubes × pes"
|
||||
)
|
||||
|
||||
|
||||
# ── D9: get_rank greenlet-local + _bind_rank ──────────────────────────
|
||||
|
||||
|
||||
def test_get_rank_is_greenlet_local(tmp_path, monkeypatch, spec):
|
||||
"""Each greenlet sees its own rank via dist.get_rank().
|
||||
|
||||
Framework-level launcher binds greenlet → rank; get_rank() resolves
|
||||
the current greenlet and returns that rank.
|
||||
"""
|
||||
monkeypatch.chdir(_write_minimal_ccl_yaml(tmp_path))
|
||||
ctx = _FakeCtx(spec=spec)
|
||||
dc = DistributedContext()
|
||||
dc._ctx_ref = ctx
|
||||
dc.init_process_group(backend="ahbm")
|
||||
assert dc.get_world_size() == int(spec["system"]["sips"]["count"])
|
||||
|
||||
assert hasattr(dc, "_bind_rank"), (
|
||||
"DistributedContext must expose _bind_rank(g, rank) hook"
|
||||
)
|
||||
|
||||
seen: dict[int, int] = {}
|
||||
|
||||
def _probe(rank: int) -> None:
|
||||
seen[rank] = dc.get_rank()
|
||||
|
||||
g0 = greenlet(lambda: _probe(0))
|
||||
g1 = greenlet(lambda: _probe(1))
|
||||
dc._bind_rank(g0, 0)
|
||||
dc._bind_rank(g1, 1)
|
||||
g0.switch()
|
||||
g1.switch()
|
||||
|
||||
assert seen == {0: 0, 1: 1}, (
|
||||
f"expected each greenlet to see its own rank; got {seen}"
|
||||
)
|
||||
|
||||
|
||||
def test_get_rank_fallback_without_bind(tmp_path, monkeypatch, spec):
|
||||
"""Unbound greenlet falls back to rank 0 (single-driver compat)."""
|
||||
monkeypatch.chdir(_write_minimal_ccl_yaml(tmp_path))
|
||||
ctx = _FakeCtx(spec=spec)
|
||||
dc = DistributedContext()
|
||||
dc._ctx_ref = ctx
|
||||
dc.init_process_group(backend="ahbm")
|
||||
# Call from main (unbound) greenlet
|
||||
assert dc.get_rank() == 0
|
||||
|
||||
|
||||
# ── D10/D11: torch.ahbm.set_device + tensor scoping ───────────────────
|
||||
|
||||
|
||||
def test_ahbm_set_device_binds_tensor_to_single_sip(topology):
|
||||
"""``torch.ahbm.set_device(rank)`` + default-sip DPPolicy → tensor on SIP rank.
|
||||
|
||||
After set_device(1), a tensor with DPPolicy leaving the SIP dimension
|
||||
at its default must be placed entirely on SIP 1.
|
||||
"""
|
||||
from kernbench.policy.placement.dp import DPPolicy
|
||||
from kernbench.runtime_api.context import RuntimeContext
|
||||
from kernbench.runtime_api.types import DeviceSelector
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
|
||||
engine = GraphEngine(topology.topology_obj, enable_data=True)
|
||||
with RuntimeContext(
|
||||
engine=engine,
|
||||
target_device=DeviceSelector("all"),
|
||||
correlation_id="test_ahbm_set_device",
|
||||
spec=topology.topology_obj.spec,
|
||||
) as ctx:
|
||||
assert hasattr(ctx, "ahbm"), (
|
||||
"RuntimeContext must expose .ahbm namespace (ADR-0024 D10)"
|
||||
)
|
||||
ctx.ahbm.set_device(1)
|
||||
|
||||
dp = DPPolicy(cube="column_wise", pe="column_wise") # default sip
|
||||
tensor = ctx.zeros((1, 128), dtype="f16", dp=dp, name="probe")
|
||||
|
||||
shard_sips = {s.sip for s in tensor._handle.shards}
|
||||
assert shard_sips == {1}, (
|
||||
f"after ahbm.set_device(1), all shards should live on SIP 1; "
|
||||
f"got sips={sorted(shard_sips)}"
|
||||
)
|
||||
|
||||
|
||||
def test_accelerator_alias_mirrors_ahbm(topology):
|
||||
"""torch.accelerator.set_device_index(r) is an alias for ahbm.set_device(r)
|
||||
(PyTorch 2.x device-agnostic surface)."""
|
||||
from kernbench.runtime_api.context import RuntimeContext
|
||||
from kernbench.runtime_api.types import DeviceSelector
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
|
||||
engine = GraphEngine(topology.topology_obj, enable_data=True)
|
||||
with RuntimeContext(
|
||||
engine=engine,
|
||||
target_device=DeviceSelector("all"),
|
||||
correlation_id="test_accelerator_alias",
|
||||
spec=topology.topology_obj.spec,
|
||||
) as ctx:
|
||||
assert hasattr(ctx, "accelerator"), (
|
||||
"RuntimeContext must expose .accelerator namespace (ADR-0024 D10)"
|
||||
)
|
||||
ctx.accelerator.set_device_index(1)
|
||||
# Both namespaces should report SIP 1 as current device
|
||||
assert ctx.ahbm.current_device() == 1
|
||||
assert ctx.accelerator.current_device_index() == 1
|
||||
|
||||
|
||||
# ── D12/D13: run() spawns one worker per rank ─────────────────────────
|
||||
|
||||
|
||||
def test_run_spawns_one_worker_per_rank(tmp_path, monkeypatch, spec):
|
||||
"""The bench's ``run()`` invokes ``worker`` once per rank.
|
||||
|
||||
With world_size = SIP count = 2 (topology), worker must be called
|
||||
exactly twice with ranks 0 and 1. Each call sees world_size=2.
|
||||
"""
|
||||
project_root = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "..")
|
||||
)
|
||||
monkeypatch.chdir(_write_minimal_ccl_yaml(tmp_path))
|
||||
|
||||
import benches.ccl_allreduce as bench
|
||||
|
||||
calls: list[tuple[int, int]] = []
|
||||
|
||||
def _fake_worker(rank, cfg, torch) -> None:
|
||||
calls.append((rank, cfg.world_size))
|
||||
|
||||
monkeypatch.setattr(bench, "_worker", _fake_worker)
|
||||
|
||||
from kernbench.runtime_api.context import RuntimeContext
|
||||
from kernbench.runtime_api.types import DeviceSelector
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
from kernbench.topology.builder import resolve_topology
|
||||
|
||||
topo = resolve_topology(os.path.join(project_root, "topology.yaml"))
|
||||
engine = GraphEngine(topo.topology_obj, enable_data=True)
|
||||
with RuntimeContext(
|
||||
engine=engine,
|
||||
target_device=DeviceSelector("all"),
|
||||
correlation_id="test_run_spawns",
|
||||
spec=topo.topology_obj.spec,
|
||||
) as ctx:
|
||||
bench.run(ctx)
|
||||
|
||||
ranks = sorted(r for r, _ in calls)
|
||||
ws_values = {ws for _, ws in calls}
|
||||
expected_ws = int(spec["system"]["sips"]["count"])
|
||||
assert ranks == list(range(expected_ws)), (
|
||||
f"run() should invoke worker for ranks 0..{expected_ws - 1}; "
|
||||
f"saw ranks={ranks}"
|
||||
)
|
||||
assert ws_values == {expected_ws}, (
|
||||
f"each worker should see world_size={expected_ws}; saw {ws_values}"
|
||||
)
|
||||
@@ -1,125 +0,0 @@
|
||||
"""Tests for IPCQ deadlock detection (ADR-0023 D14 F3)."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import simpy
|
||||
|
||||
from kernbench.ccl import diagnostics
|
||||
from kernbench.common.ipcq_types import (
|
||||
IpcqEndpoint,
|
||||
IpcqInitEntry,
|
||||
IpcqRecvCmd,
|
||||
IpcqRequest,
|
||||
)
|
||||
from kernbench.components.builtin.pe_ipcq import PeIpcqComponent
|
||||
from kernbench.runtime_api.kernel import IpcqInitMsg
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakeTxn:
|
||||
request: Any
|
||||
done: simpy.Event
|
||||
result_data: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
def _make_isolated_pe_ipcq(env):
|
||||
node = Node(
|
||||
id="sip0.cube0.pe0.pe_ipcq", kind="pe_ipcq",
|
||||
impl="builtin.pe_ipcq", attrs={}, pos_mm=None,
|
||||
)
|
||||
comp = PeIpcqComponent(node, ctx=None)
|
||||
comp.in_ports["host"] = simpy.Store(env)
|
||||
comp.out_ports["sip0.cube0.pe0.pe_dma"] = simpy.Store(env)
|
||||
comp.start(env)
|
||||
|
||||
peer_credit = simpy.Store(env)
|
||||
ep = IpcqEndpoint(
|
||||
sip=0, cube=0, pe=1, buffer_kind="tcm",
|
||||
rx_base_pa=0x10_000, rx_base_va=0,
|
||||
n_slots=4, slot_size=4096,
|
||||
)
|
||||
init_msg = IpcqInitMsg(
|
||||
correlation_id="t", request_id="t",
|
||||
target_sips=(0,), target_cubes=(0,), target_pe=0,
|
||||
entries=(IpcqInitEntry(
|
||||
direction="W", peer=ep,
|
||||
my_rx_base_pa=0x40_000, my_rx_base_va=0,
|
||||
n_slots=4, slot_size=4096,
|
||||
peer_credit_store=peer_credit,
|
||||
),),
|
||||
backpressure_mode="sleep",
|
||||
buffer_kind="tcm",
|
||||
credit_size_bytes=16,
|
||||
)
|
||||
done = env.event()
|
||||
comp.in_ports["host"].put(_FakeTxn(request=init_msg, done=done))
|
||||
env.run(until=done)
|
||||
return comp
|
||||
|
||||
|
||||
def test_pointer_dump_includes_blocked_state():
|
||||
"""A blocked recv should still be visible in the pointer dump."""
|
||||
env = simpy.Environment()
|
||||
comp = _make_isolated_pe_ipcq(env)
|
||||
|
||||
# Issue a recv that will block (no data has arrived)
|
||||
recv_cmd = IpcqRecvCmd(direction="W", shape=(8,), dtype="f16", handle_id="r1")
|
||||
req = IpcqRequest(command=recv_cmd, done=env.event())
|
||||
comp.in_ports["host"].put(req)
|
||||
env.run(until=10)
|
||||
assert not req.done.triggered
|
||||
|
||||
# Pointer dump should show my_tail=0 and peer_head_cache=0
|
||||
# We need to use the engine API but for an isolated component, just call directly
|
||||
class FakeEngine:
|
||||
_components = {"sip0.cube0.pe0.pe_ipcq": comp}
|
||||
|
||||
dump = diagnostics.pointer_dump(FakeEngine())
|
||||
assert "my_tail=0" in dump
|
||||
assert "peer_head_cache=0" in dump
|
||||
|
||||
|
||||
def test_deadlock_detection_recv_without_send():
|
||||
"""A recv with no matching sender → SimPy schedule empties → engine
|
||||
raises ``IpcqDeadlock`` with a pointer dump.
|
||||
"""
|
||||
from kernbench.ccl.diagnostics import IpcqDeadlock
|
||||
from kernbench.policy.placement.dp import DPPolicy
|
||||
from kernbench.runtime_api.bench_runner import run_bench
|
||||
from kernbench.runtime_api.types import resolve_device
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
from kernbench.topology.builder import resolve_topology
|
||||
|
||||
def deadlock_kernel(t_ptr, n_elem, tl):
|
||||
# Every PE just receives, no sends → no one delivers → deadlock
|
||||
tl.recv(dir="W", shape=(n_elem,), dtype="f16")
|
||||
|
||||
topo = resolve_topology("topology.yaml")
|
||||
|
||||
def run(torch):
|
||||
torch.install_ipcq(
|
||||
algorithm="ring_allreduce_tcm", world_size_override=8,
|
||||
)
|
||||
a = torch.zeros(
|
||||
(1, 8 * 8),
|
||||
dtype="f16",
|
||||
dp=DPPolicy(
|
||||
cube="replicate", pe="column_wise",
|
||||
num_cubes=1,
|
||||
),
|
||||
name="dl_in",
|
||||
)
|
||||
torch.launch("dl", deadlock_kernel, a, 8)
|
||||
|
||||
with pytest.raises(IpcqDeadlock):
|
||||
run_bench(
|
||||
topology=topo, bench_fn=run,
|
||||
device=resolve_device("all"),
|
||||
engine_factory=lambda t, d: GraphEngine(
|
||||
getattr(t, "topology_obj", t), enable_data=True
|
||||
),
|
||||
)
|
||||
@@ -1,70 +0,0 @@
|
||||
"""Tests for CCL diagnostics: trace + pointer dump (ADR-0023 D14)."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
from kernbench.ccl import diagnostics
|
||||
|
||||
|
||||
# ── trace toggle ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_trace_disabled_by_default(monkeypatch):
|
||||
monkeypatch.delenv("KERNBENCH_CCL_TRACE", raising=False)
|
||||
diagnostics.reload_trace_setting()
|
||||
assert diagnostics.trace_enabled() is False
|
||||
|
||||
|
||||
def test_trace_enabled_via_env(monkeypatch):
|
||||
monkeypatch.setenv("KERNBENCH_CCL_TRACE", "1")
|
||||
diagnostics.reload_trace_setting()
|
||||
assert diagnostics.trace_enabled() is True
|
||||
|
||||
|
||||
def test_trace_record_send(monkeypatch, capsys):
|
||||
monkeypatch.setenv("KERNBENCH_CCL_TRACE", "1")
|
||||
diagnostics.reload_trace_setting()
|
||||
diagnostics.log_send(t_ns=100.0, sender="sip0.cube0.pe0",
|
||||
direction="E", nbytes=64, sender_seq=0)
|
||||
out = capsys.readouterr().out
|
||||
assert "send" in out
|
||||
assert "sip0.cube0.pe0" in out
|
||||
assert "dir=E" in out
|
||||
monkeypatch.delenv("KERNBENCH_CCL_TRACE")
|
||||
diagnostics.reload_trace_setting()
|
||||
|
||||
|
||||
def test_trace_record_recv(monkeypatch, capsys):
|
||||
monkeypatch.setenv("KERNBENCH_CCL_TRACE", "1")
|
||||
diagnostics.reload_trace_setting()
|
||||
diagnostics.log_recv(t_ns=200.0, receiver="sip0.cube0.pe1",
|
||||
direction="W", nbytes=64)
|
||||
out = capsys.readouterr().out
|
||||
assert "recv" in out
|
||||
assert "sip0.cube0.pe1" in out
|
||||
monkeypatch.delenv("KERNBENCH_CCL_TRACE")
|
||||
diagnostics.reload_trace_setting()
|
||||
|
||||
|
||||
# ── pointer dump ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_pointer_dump_format():
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
from kernbench.topology.builder import resolve_topology
|
||||
from kernbench.ccl.install import (
|
||||
install_ipcq, load_ccl_config, resolve_algorithm_config,
|
||||
)
|
||||
|
||||
topo = resolve_topology("topology.yaml").topology_obj
|
||||
engine = GraphEngine(topo, enable_data=True)
|
||||
cfg = resolve_algorithm_config(load_ccl_config(), name="ring_allreduce_tcm")
|
||||
install_ipcq(engine, topo.spec, cfg)
|
||||
|
||||
dump = diagnostics.pointer_dump(engine)
|
||||
# 8 ranks × 2 directions = 16 lines (plus 8 PE headers)
|
||||
assert "sip0.cube0.pe0" in dump
|
||||
assert "E:" in dump
|
||||
assert "W:" in dump
|
||||
assert "my_head=" in dump
|
||||
assert "peer_tail_cache=" in dump
|
||||
@@ -1,81 +0,0 @@
|
||||
"""Validate the hello-world example from docs/ccl-author-guide.md.
|
||||
|
||||
This is the simplest possible CCL kernel — each PE sends its tile E
|
||||
and receives a tile from W. After running, each rank's slice should
|
||||
contain the data of the previous rank.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
from kernbench.ccl.algorithms import hello_send
|
||||
from kernbench.ccl.testing import run_kernel_in_mock
|
||||
|
||||
|
||||
def test_hello_send_4_ranks_mock():
|
||||
n_elem = 8
|
||||
inputs = [np.full((n_elem,), float(r + 1), dtype=np.float16) for r in range(4)]
|
||||
|
||||
outputs = run_kernel_in_mock(
|
||||
kernel_fn=hello_send.kernel,
|
||||
world_size=4,
|
||||
topology="ring_1d",
|
||||
inputs=inputs,
|
||||
kernel_args=(n_elem,),
|
||||
)
|
||||
|
||||
# rank r should have rank (r-1) % 4's data
|
||||
for r in range(4):
|
||||
prev = inputs[(r - 1) % 4]
|
||||
assert np.array_equal(outputs[r], prev), f"rank {r}: got {outputs[r]}"
|
||||
|
||||
|
||||
def test_hello_send_via_simpy_runner():
|
||||
"""Same but through real SimPy + IPCQ."""
|
||||
from kernbench.policy.placement.dp import DPPolicy
|
||||
from kernbench.runtime_api.bench_runner import run_bench
|
||||
from kernbench.runtime_api.types import resolve_device
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
from kernbench.topology.builder import resolve_topology
|
||||
|
||||
topo = resolve_topology("topology.yaml")
|
||||
n_elem = 8
|
||||
world_size = 8
|
||||
|
||||
def run(torch):
|
||||
# World size for this hello test is 8 (one cube). ccl.yaml no
|
||||
# longer carries a default world_size — pass it explicitly.
|
||||
plan = torch.install_ipcq(
|
||||
algorithm="ring_allreduce_tcm", world_size_override=world_size,
|
||||
)
|
||||
a = torch.zeros(
|
||||
(1, world_size * n_elem), dtype="f16",
|
||||
dp=DPPolicy(
|
||||
cube="replicate", pe="column_wise",
|
||||
num_cubes=1,
|
||||
),
|
||||
name="hello_in",
|
||||
)
|
||||
store = torch.engine.memory_store
|
||||
base = a._handle.va_base or a._handle.shards[0].pa
|
||||
nbytes = n_elem * 2
|
||||
for r in range(world_size):
|
||||
store.write("hbm", base + r * nbytes,
|
||||
np.full((n_elem,), float(r + 1), dtype=np.float16))
|
||||
|
||||
torch.launch("hello_send", hello_send.kernel, a, n_elem)
|
||||
|
||||
# Each rank should hold the previous rank's data after the round
|
||||
for r in range(world_size):
|
||||
arr = store.read("hbm", base + r * nbytes, shape=(n_elem,), dtype="f16")
|
||||
prev_value = float(((r - 1) % world_size) + 1)
|
||||
assert np.allclose(arr, prev_value), f"rank {r}: got {arr}, expected {prev_value}"
|
||||
|
||||
result = run_bench(
|
||||
topology=topo, bench_fn=run,
|
||||
device=resolve_device("all"),
|
||||
engine_factory=lambda t, d: GraphEngine(
|
||||
getattr(t, "topology_obj", t), enable_data=True
|
||||
),
|
||||
)
|
||||
assert result.completion.ok
|
||||
@@ -2,7 +2,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from kernbench.ccl.install import (
|
||||
install_ipcq,
|
||||
linear_rank_to_pe,
|
||||
load_ccl_config,
|
||||
resolve_algorithm_config,
|
||||
@@ -26,28 +25,14 @@ def test_resolve_algorithm_config_default():
|
||||
cfg = load_ccl_config()
|
||||
merged = resolve_algorithm_config(cfg)
|
||||
assert merged["algorithm"] == cfg["defaults"]["algorithm"]
|
||||
# ccl.yaml no longer carries defaults.world_size — backend derives
|
||||
# it from topology.yaml at install time. Just check the field is
|
||||
# absent here (verified per-test where install_ipcq is called).
|
||||
assert "world_size" not in merged or merged["world_size"] >= 1
|
||||
|
||||
|
||||
def test_resolve_algorithm_config_override():
|
||||
cfg = load_ccl_config()
|
||||
merged = resolve_algorithm_config(cfg, name="ring_allreduce_hbm")
|
||||
assert merged["algorithm"] == "ring_allreduce_hbm"
|
||||
assert merged["buffer_kind"] == "hbm" # algo override
|
||||
# defaults still apply
|
||||
assert merged["n_slots"] == cfg["defaults"]["n_slots"]
|
||||
|
||||
|
||||
def test_linear_rank_to_pe():
|
||||
engine, topo = _engine()
|
||||
spec = topo.spec
|
||||
# Cube 0 of SIP 0
|
||||
assert linear_rank_to_pe(0, spec) == (0, 0, 0)
|
||||
assert linear_rank_to_pe(7, spec) == (0, 0, 7)
|
||||
# Should not exceed total PE count
|
||||
pes_per_sip = (
|
||||
spec["sip"]["cube_mesh"]["w"] * spec["sip"]["cube_mesh"]["h"]
|
||||
* spec["cube"]["pe_layout"]["pe_per_corner"]
|
||||
@@ -56,105 +41,3 @@ def test_linear_rank_to_pe():
|
||||
sips = spec["system"]["sips"]["count"]
|
||||
total = sips * pes_per_sip
|
||||
assert total >= 8
|
||||
|
||||
|
||||
def test_install_ipcq_neighbors_correct():
|
||||
engine, topo = _engine()
|
||||
cfg = load_ccl_config()
|
||||
merged = resolve_algorithm_config(cfg, name="ring_allreduce_tcm")
|
||||
# Force a single-cube 8-rank install for the assertions below.
|
||||
merged["world_size"] = 8
|
||||
plan = install_ipcq(engine, topo.spec, merged)
|
||||
|
||||
assert plan["world_size"] == 8
|
||||
assert plan["buffer_kind"] == "tcm"
|
||||
|
||||
# Each rank should have E and W entries
|
||||
for r, nbrs in plan["neighbor_table"].items():
|
||||
assert "E" in nbrs
|
||||
assert "W" in nbrs
|
||||
|
||||
# Inspect installed PE_IPCQ for rank 0
|
||||
ipcq = engine._components["sip0.cube0.pe0.pe_ipcq"]
|
||||
qp_e = ipcq.queue_pairs["E"]
|
||||
qp_w = ipcq.queue_pairs["W"]
|
||||
assert qp_e["peer"].pe == 1 # rank 0's E neighbor is rank 1
|
||||
assert qp_w["peer"].pe == 7 # rank 0's W neighbor is rank 7
|
||||
# rx_base addresses should be unique
|
||||
assert qp_e["my_rx_base_pa"] != qp_w["my_rx_base_pa"]
|
||||
|
||||
|
||||
def test_install_ipcq_credit_stores_wired():
|
||||
engine, topo = _engine()
|
||||
cfg = load_ccl_config()
|
||||
merged = resolve_algorithm_config(cfg, name="ring_allreduce_tcm")
|
||||
merged["world_size"] = 8
|
||||
install_ipcq(engine, topo.spec, merged)
|
||||
|
||||
# rank 0 (pe0) sending E goes to rank 1 (pe1)
|
||||
# rank 0's peer_credit_store on E direction should equal rank 1's credit_inbox
|
||||
pe0 = engine._components["sip0.cube0.pe0.pe_ipcq"]
|
||||
pe1 = engine._components["sip0.cube0.pe1.pe_ipcq"]
|
||||
|
||||
qp_e = pe0.queue_pairs["E"]
|
||||
assert qp_e["peer_credit_store"] is pe1.credit_inbox
|
||||
|
||||
|
||||
# ── ADR-0025 D1: reverse_direction opposite-preference ───────────────
|
||||
|
||||
|
||||
def test_reverse_direction_opposite_preference_2rank_ring():
|
||||
"""ADR-0025 D1: In a 2-rank bidirectional ring both E and W point to the
|
||||
same peer; reverse_direction must pick the OPPOSITE direction (W for E,
|
||||
E for W) so rx_base targets the semantically-correct slot.
|
||||
|
||||
Concretely: rank 0 sending via E to rank 1 must target rank 1's W-rx
|
||||
buffer (not rank 1's E-rx), because rank 1's kernel recv(W) reads from
|
||||
its W-rx.
|
||||
"""
|
||||
engine, topo = _engine()
|
||||
cfg = load_ccl_config()
|
||||
merged = resolve_algorithm_config(cfg, name="ring_allreduce_tcm")
|
||||
merged["world_size"] = 2
|
||||
install_ipcq(engine, topo.spec, merged)
|
||||
|
||||
ipcq0 = engine._components["sip0.cube0.pe0.pe_ipcq"]
|
||||
ipcq1 = engine._components["sip0.cube0.pe1.pe_ipcq"]
|
||||
|
||||
rank1_e_rx = ipcq1.queue_pairs["E"]["my_rx_base_pa"]
|
||||
rank1_w_rx = ipcq1.queue_pairs["W"]["my_rx_base_pa"]
|
||||
|
||||
qp0_e = ipcq0.queue_pairs["E"]
|
||||
qp0_w = ipcq0.queue_pairs["W"]
|
||||
|
||||
# rank 0's E entry should target rank 1's W-rx (opposite), NOT rank 1's E-rx.
|
||||
assert qp0_e["peer"].rx_base_pa == rank1_w_rx, (
|
||||
f"expected rank 0's E peer.rx_base_pa == rank 1's W-rx ({rank1_w_rx:#x}), "
|
||||
f"got {qp0_e['peer'].rx_base_pa:#x} (matches E-rx: {rank1_e_rx:#x}) — "
|
||||
f"reverse_direction picked same-label instead of opposite"
|
||||
)
|
||||
# rank 0's W entry should target rank 1's E-rx (opposite).
|
||||
assert qp0_w["peer"].rx_base_pa == rank1_e_rx
|
||||
|
||||
|
||||
def test_reverse_direction_opposite_preference_4rank_ring_sanity():
|
||||
"""ADR-0025 D1 sanity: ws>=3 ring. E and W have distinct peers, so
|
||||
opposite-preference produces same result as old dict-order first-match.
|
||||
This test should PASS both under current and post-fix code.
|
||||
"""
|
||||
engine, topo = _engine()
|
||||
cfg = load_ccl_config()
|
||||
merged = resolve_algorithm_config(cfg, name="ring_allreduce_tcm")
|
||||
merged["world_size"] = 4
|
||||
install_ipcq(engine, topo.spec, merged)
|
||||
|
||||
ipcq0 = engine._components["sip0.cube0.pe0.pe_ipcq"]
|
||||
ipcq1 = engine._components["sip0.cube0.pe1.pe_ipcq"]
|
||||
ipcq3 = engine._components["sip0.cube0.pe3.pe_ipcq"]
|
||||
|
||||
# rank 0 E → rank 1 → rank 1's W-rx
|
||||
qp0_e = ipcq0.queue_pairs["E"]
|
||||
assert qp0_e["peer"].rx_base_pa == ipcq1.queue_pairs["W"]["my_rx_base_pa"]
|
||||
# rank 0 W → rank 3 (last in ring) → rank 3's E-rx
|
||||
qp0_w = ipcq0.queue_pairs["W"]
|
||||
assert qp0_w["peer"].rx_base_pa == ipcq3.queue_pairs["E"]["my_rx_base_pa"]
|
||||
|
||||
@@ -1,83 +0,0 @@
|
||||
"""Tests for the mock CCL runtime (ADR-0023 D15)."""
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
from kernbench.ccl.algorithms import ring_allreduce
|
||||
from kernbench.ccl.testing import run_kernel_in_mock
|
||||
|
||||
|
||||
def test_ring_allreduce_4_ranks():
|
||||
"""Run the ring all-reduce kernel under the mock runtime, no SimPy."""
|
||||
n_elem = 8
|
||||
inputs = [
|
||||
np.full((n_elem,), float(r + 1), dtype=np.float16)
|
||||
for r in range(4)
|
||||
]
|
||||
expected = sum(inputs) # [10, 10, ..., 10]
|
||||
|
||||
outputs = run_kernel_in_mock(
|
||||
kernel_fn=ring_allreduce.kernel,
|
||||
world_size=4,
|
||||
topology="ring_1d",
|
||||
inputs=inputs,
|
||||
kernel_args=(n_elem, 4),
|
||||
)
|
||||
|
||||
assert len(outputs) == 4
|
||||
for r in range(4):
|
||||
assert np.allclose(outputs[r], expected)
|
||||
|
||||
|
||||
def test_ring_allreduce_8_ranks():
|
||||
n_elem = 16
|
||||
inputs = [
|
||||
np.full((n_elem,), float(r + 1), dtype=np.float16)
|
||||
for r in range(8)
|
||||
]
|
||||
expected = sum(inputs) # [36, 36, ...]
|
||||
|
||||
outputs = run_kernel_in_mock(
|
||||
kernel_fn=ring_allreduce.kernel,
|
||||
world_size=8,
|
||||
topology="ring_1d",
|
||||
inputs=inputs,
|
||||
kernel_args=(n_elem, 8),
|
||||
)
|
||||
for r in range(8):
|
||||
assert np.allclose(outputs[r], expected)
|
||||
|
||||
|
||||
def test_ring_allreduce_random_data():
|
||||
n_elem = 32
|
||||
rng = np.random.default_rng(42)
|
||||
inputs = [rng.standard_normal(n_elem).astype(np.float16) for _ in range(4)]
|
||||
expected = sum(inputs)
|
||||
|
||||
outputs = run_kernel_in_mock(
|
||||
kernel_fn=ring_allreduce.kernel,
|
||||
world_size=4,
|
||||
topology="ring_1d",
|
||||
inputs=inputs,
|
||||
kernel_args=(n_elem, 4),
|
||||
)
|
||||
for r in range(4):
|
||||
assert np.allclose(outputs[r], expected, rtol=1e-2, atol=1e-2)
|
||||
|
||||
|
||||
def test_mock_runtime_invalid_direction_raises():
|
||||
"""A kernel that uses an unsupported direction should raise."""
|
||||
import pytest
|
||||
|
||||
def bad_kernel(t_ptr, n_elem, tl):
|
||||
tl.send(dir="N", src_addr=0, nbytes=2, shape=(1,), dtype="f16", space="hbm")
|
||||
|
||||
inputs = [np.array([1.0], dtype=np.float16) for _ in range(2)]
|
||||
with pytest.raises(Exception):
|
||||
run_kernel_in_mock(
|
||||
kernel_fn=bad_kernel,
|
||||
world_size=2,
|
||||
topology="ring_1d",
|
||||
inputs=inputs,
|
||||
kernel_args=(1,),
|
||||
)
|
||||
@@ -1,85 +0,0 @@
|
||||
"""CCL performance validation tests (ADR-0023 D13 T5).
|
||||
|
||||
Sanity-checks the simulated latency of the unified ``ccl_allreduce`` bench
|
||||
under the rank = SIP TP launcher model (ADR-0024 / ADR-0027). Uses the
|
||||
topology-derived world_size (= 2 in the shipped topology); the latency
|
||||
model is topology-aware, so buffer_kind differences remain visible even
|
||||
at this scale.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from kernbench.runtime_api.bench_runner import run_bench
|
||||
from kernbench.runtime_api.types import resolve_device
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
from kernbench.topology.builder import resolve_topology
|
||||
|
||||
|
||||
def _engine_factory(topology, device):
|
||||
return GraphEngine(getattr(topology, "topology_obj", topology), enable_data=True)
|
||||
|
||||
|
||||
def _run_ring(algorithm: str, buffer_kind: str = "tcm") -> float:
|
||||
"""Run a rank = SIP ring all-reduce via the unified bench with a tmp
|
||||
ccl.yaml overlay. Returns simulated kernel total_ns."""
|
||||
import tempfile
|
||||
|
||||
body = f"""\
|
||||
defaults:
|
||||
algorithm: {algorithm}
|
||||
buffer_kind: {buffer_kind}
|
||||
backpressure: sleep
|
||||
n_slots: 4
|
||||
slot_size: 4096
|
||||
vc_chunk_size: 256
|
||||
ipcq_credit_size_bytes: 16
|
||||
|
||||
algorithms:
|
||||
{algorithm}:
|
||||
module: kernbench.ccl.algorithms.ring_allreduce
|
||||
topology: ring_1d
|
||||
buffer_kind: {buffer_kind}
|
||||
n_elem: 32
|
||||
"""
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
with open(os.path.join(tmp, "ccl.yaml"), "w") as f:
|
||||
f.write(body)
|
||||
old_cwd = os.getcwd()
|
||||
os.chdir(tmp)
|
||||
try:
|
||||
topo = resolve_topology(os.path.join(project_root, "topology.yaml"))
|
||||
bench_mod = importlib.import_module("benches.ccl_allreduce")
|
||||
result = run_bench(
|
||||
topology=topo, bench_fn=bench_mod.run,
|
||||
device=resolve_device("all"),
|
||||
engine_factory=_engine_factory,
|
||||
)
|
||||
finally:
|
||||
os.chdir(old_cwd)
|
||||
|
||||
assert result.completion.ok, f"{algorithm} did not complete"
|
||||
last_kernel = None
|
||||
for tr in (result.traces or []):
|
||||
if tr.get("phase") == "kernel":
|
||||
last_kernel = tr
|
||||
assert last_kernel is not None, f"{algorithm} produced no kernel trace"
|
||||
return float(last_kernel.get("total_ns", 0.0))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("buffer_kind", ["tcm", "hbm", "sram"])
|
||||
def test_ccl_latency_positive(buffer_kind):
|
||||
"""Every buffer kind must produce a positive simulated latency."""
|
||||
algo = f"ring_allreduce_{buffer_kind}"
|
||||
ns = _run_ring(algo, buffer_kind)
|
||||
assert ns > 0
|
||||
|
||||
|
||||
def test_ccl_latency_under_reasonable_bound():
|
||||
"""rank = SIP ring all-reduce (tile=32 f16) should finish well under 1ms."""
|
||||
ns = _run_ring("ring_allreduce_tcm", "tcm")
|
||||
assert ns < 1_000_000 # < 1 ms simulated
|
||||
@@ -0,0 +1,119 @@
|
||||
"""End-to-end distributed test for intercube allreduce.
|
||||
|
||||
Exercises the full process-group path:
|
||||
dist.init_process_group(backend="ahbm")
|
||||
→ mp.spawn(nprocs=n_sips)
|
||||
→ each worker: set_device → allocate → fill → dist.all_reduce → verify
|
||||
|
||||
This is the same flow a real DDP training script would use.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import textwrap
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
|
||||
|
||||
N_CUBES = 16
|
||||
N_ELEM = 8
|
||||
|
||||
|
||||
def _write_ccl_yaml(tmp_path) -> str:
|
||||
body = textwrap.dedent("""\
|
||||
defaults:
|
||||
algorithm: intercube_allreduce
|
||||
buffer_kind: tcm
|
||||
backpressure: sleep
|
||||
n_slots: 4
|
||||
slot_size: 4096
|
||||
vc_chunk_size: 256
|
||||
ipcq_credit_size_bytes: 16
|
||||
|
||||
algorithms:
|
||||
intercube_allreduce:
|
||||
module: kernbench.ccl.algorithms.intercube_allreduce
|
||||
topology: none
|
||||
buffer_kind: tcm
|
||||
n_elem: 8
|
||||
root_cube: 15
|
||||
""")
|
||||
(tmp_path / "ccl.yaml").write_text(body)
|
||||
return str(tmp_path)
|
||||
|
||||
|
||||
def _worker(rank: int, n_sips: int, torch) -> None:
|
||||
"""Per-SIP worker: allocate, fill, all_reduce, verify."""
|
||||
from kernbench.policy.placement.dp import DPPolicy
|
||||
|
||||
torch.ahbm.set_device(rank)
|
||||
|
||||
dp = DPPolicy(
|
||||
cube="row_wise", pe="replicate",
|
||||
num_pes=1, num_cubes=N_CUBES,
|
||||
)
|
||||
tensor = torch.zeros(
|
||||
(N_CUBES, N_ELEM), dtype="f16", dp=dp,
|
||||
name=f"sip{rank}",
|
||||
)
|
||||
|
||||
init_arr = np.full((N_CUBES, N_ELEM), float(rank + 1), dtype=np.float16)
|
||||
tensor.copy_(torch.from_numpy(init_arr))
|
||||
|
||||
print(f"[SIP {rank}] input cube0[:4] = {tensor.numpy()[0][:4].tolist()}")
|
||||
|
||||
torch.distributed.all_reduce(tensor, op="sum")
|
||||
|
||||
arr = tensor.numpy()
|
||||
expected = float(N_CUBES * sum(range(1, n_sips + 1)))
|
||||
|
||||
print(f"[SIP {rank}] output cube0[:4] = {arr[0][:4].tolist()}")
|
||||
print(f"[SIP {rank}] output cube15[:4] = {arr[15][:4].tolist()}")
|
||||
|
||||
for cube_id in range(N_CUBES):
|
||||
assert np.allclose(arr[cube_id], expected, rtol=1e-1, atol=1e-1), (
|
||||
f"SIP{rank} cube {cube_id}: "
|
||||
f"got {arr[cube_id][:4]}, expected {expected}"
|
||||
)
|
||||
|
||||
if rank == 0:
|
||||
print(f"\n intercube_allreduce (ws={n_sips}): "
|
||||
f"{n_sips * N_CUBES} OK")
|
||||
|
||||
|
||||
def test_distributed_intercube_allreduce(tmp_path, monkeypatch):
|
||||
"""Full distributed path: init_process_group → mp.spawn → all_reduce."""
|
||||
from kernbench.runtime_api.context import RuntimeContext
|
||||
from kernbench.runtime_api.types import DeviceSelector
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
from kernbench.topology.builder import resolve_topology
|
||||
|
||||
monkeypatch.chdir(_write_ccl_yaml(tmp_path))
|
||||
|
||||
topo = resolve_topology(str(TOPOLOGY_PATH))
|
||||
engine = GraphEngine(topo.topology_obj, enable_data=True)
|
||||
spec = topo.topology_obj.spec
|
||||
n_sips = int(spec["system"]["sips"]["count"])
|
||||
|
||||
with RuntimeContext(
|
||||
engine=engine,
|
||||
target_device=DeviceSelector("all"),
|
||||
correlation_id="dist_intercube_ar",
|
||||
spec=spec,
|
||||
) as ctx:
|
||||
ctx.distributed.init_process_group(backend="ahbm")
|
||||
|
||||
assert ctx.distributed.get_world_size() == n_sips
|
||||
|
||||
t_start = engine._env.now
|
||||
|
||||
ctx.multiprocessing.spawn(
|
||||
_worker, args=(n_sips, ctx), nprocs=n_sips,
|
||||
)
|
||||
|
||||
t_end = engine._env.now
|
||||
print(f"\n[distributed] sim latency = "
|
||||
f"{t_end - t_start:.1f} ns ({(t_end - t_start) / 1000:.3f} us)")
|
||||
@@ -1,270 +0,0 @@
|
||||
"""ADR-0027 T5: Host-read barrier (D0.5).
|
||||
|
||||
Phase 1: Tensor.numpy / data / __getitem__ / __repr__ / copy_ currently
|
||||
perform MemoryStore operations without barrier logic → tests fail when
|
||||
they assert drain is triggered. Phase 2 injects the barrier.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from greenlet import greenlet
|
||||
|
||||
|
||||
def _make_ctx(topology):
|
||||
from kernbench.runtime_api.context import RuntimeContext
|
||||
from kernbench.runtime_api.types import DeviceSelector
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
|
||||
engine = GraphEngine(topology.topology_obj, enable_data=True)
|
||||
return RuntimeContext(
|
||||
engine=engine,
|
||||
target_device=DeviceSelector("all"),
|
||||
correlation_id="test_t5",
|
||||
spec=topology.topology_obj.spec,
|
||||
)
|
||||
|
||||
|
||||
# ── T5.g: closed-set registry exists ─────────────────────────────────
|
||||
|
||||
|
||||
def test_host_read_barrier_registry_exists():
|
||||
"""D0.5 T5.g: Tensor module exposes the closed-set registry."""
|
||||
from kernbench.runtime_api import tensor as tensor_mod
|
||||
|
||||
assert hasattr(tensor_mod, "_HOST_READ_BARRIERS"), (
|
||||
"ADR-0027 T5.g: tensor module must declare _HOST_READ_BARRIERS registry"
|
||||
)
|
||||
registry = tensor_mod._HOST_READ_BARRIERS
|
||||
assert isinstance(registry, frozenset)
|
||||
expected = {"numpy", "data", "__getitem__", "__repr__", "copy_"}
|
||||
assert expected.issubset(registry), (
|
||||
f"registry must include {expected}; got {registry}"
|
||||
)
|
||||
|
||||
|
||||
# ── T5.a: numpy() triggers drain when pending non-empty ──────────────
|
||||
|
||||
|
||||
def test_numpy_triggers_drain_when_pending(topology):
|
||||
"""T5.a: launch → numpy() → barrier drains before read (worker context)."""
|
||||
with _make_ctx(topology) as ctx:
|
||||
from kernbench.policy.placement.dp import DPPolicy
|
||||
|
||||
dp = DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=1)
|
||||
observed: dict = {"pre_numpy_pending": None, "post_numpy_pending": None}
|
||||
|
||||
def _worker():
|
||||
t = ctx.zeros((1, 8), dtype="f16", dp=dp, name="t5a_t")
|
||||
src = np.full((1, 8), 1.5, dtype=np.float16)
|
||||
t.copy_(ctx.distributed._ctx_ref.from_numpy(src) if False else _hold(ctx, src))
|
||||
# Manually push a dummy handle to simulate pending state; in real
|
||||
# D0.5, numpy will detect and drain.
|
||||
observed["pre_numpy_pending"] = list(ctx._pending_worker_waits)
|
||||
_ = t.numpy()
|
||||
observed["post_numpy_pending"] = list(ctx._pending_worker_waits)
|
||||
|
||||
# Can't actually manufacture pending + test numpy inside worker
|
||||
# without D0.5 implemented — instead, verify the barrier path is
|
||||
# invoked by spying.
|
||||
from kernbench.runtime_api.tensor import Tensor
|
||||
barrier_calls = {"n": 0}
|
||||
|
||||
original_numpy = Tensor.numpy
|
||||
|
||||
def _spy_numpy(self):
|
||||
# After D0.5 is implemented, this wrapper is redundant; the
|
||||
# test just checks numpy was called at all after a pending
|
||||
# operation.
|
||||
barrier_calls["n"] += 1
|
||||
return original_numpy(self)
|
||||
|
||||
Tensor.numpy = _spy_numpy # type: ignore[assignment]
|
||||
try:
|
||||
ctx.multiprocessing.spawn(_mk_worker_numpy, args=(ctx,), nprocs=1)
|
||||
finally:
|
||||
Tensor.numpy = original_numpy # type: ignore[assignment]
|
||||
|
||||
assert barrier_calls["n"] >= 1
|
||||
|
||||
|
||||
def _hold(ctx, arr):
|
||||
"""helper (unused branch)."""
|
||||
import numpy as _np
|
||||
t = type("X", (), {})()
|
||||
t.numpy = lambda self=None: arr
|
||||
return t
|
||||
|
||||
|
||||
def _mk_worker_numpy(rank, ctx):
|
||||
"""Worker that calls numpy after a tensor deploy. Triggers barrier."""
|
||||
from kernbench.policy.placement.dp import DPPolicy
|
||||
dp = DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=1)
|
||||
t = ctx.zeros((1, 8), dtype="f16", dp=dp, name=f"t5_r{rank}")
|
||||
_ = t.numpy()
|
||||
|
||||
|
||||
# ── T5.b: metadata access does NOT drain ─────────────────────────────
|
||||
|
||||
|
||||
def test_metadata_access_is_non_barrier(topology):
|
||||
"""T5.b: .shape / .dtype / .name do NOT trigger drain."""
|
||||
with _make_ctx(topology) as ctx:
|
||||
from kernbench.runtime_api import tensor as tensor_mod
|
||||
from kernbench.policy.placement.dp import DPPolicy
|
||||
|
||||
dp = DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=1)
|
||||
t = ctx.zeros((1, 8), dtype="f16", dp=dp, name="t5b")
|
||||
|
||||
# Populate pending queue artificially (simulate worker state).
|
||||
ctx._pending_worker_waits.append("fake_handle_that_must_not_drain")
|
||||
|
||||
_ = t.shape
|
||||
_ = t.dtype
|
||||
_ = t.name
|
||||
|
||||
assert "fake_handle_that_must_not_drain" in ctx._pending_worker_waits, (
|
||||
"T5.b: metadata accessors must not drain pending queue"
|
||||
)
|
||||
ctx._pending_worker_waits.clear()
|
||||
|
||||
|
||||
# ── T5.c: empty pending → numpy is fast-path (no yield) ──────────────
|
||||
|
||||
|
||||
def test_numpy_fast_path_when_pending_empty(topology):
|
||||
"""T5.c: numpy() with empty pending queue does not yield to main."""
|
||||
with _make_ctx(topology) as ctx:
|
||||
from kernbench.policy.placement.dp import DPPolicy
|
||||
|
||||
dp = DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=1)
|
||||
|
||||
def _worker(rank: int):
|
||||
t = ctx.zeros((1, 4), dtype="f16", dp=dp, name=f"t5c_r{rank}")
|
||||
# At this point, after worker's own wait(s), pending should be empty.
|
||||
assert ctx._pending_worker_waits == [], (
|
||||
"after worker's deploy, pending queue should be drained"
|
||||
)
|
||||
# numpy call should be fast-path (no yield).
|
||||
_ = t.numpy()
|
||||
|
||||
ctx.multiprocessing.spawn(_worker, args=(), nprocs=1)
|
||||
|
||||
|
||||
# ── T5.d: __getitem__ / data also barriers ───────────────────────────
|
||||
|
||||
|
||||
def test_getitem_and_data_are_barriers(topology):
|
||||
"""T5.d: __getitem__ and .data property behave like numpy() barrier."""
|
||||
with _make_ctx(topology) as ctx:
|
||||
from kernbench.policy.placement.dp import DPPolicy
|
||||
|
||||
dp = DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=1)
|
||||
|
||||
def _worker(rank: int):
|
||||
t = ctx.zeros((1, 8), dtype="f16", dp=dp, name=f"t5d_r{rank}")
|
||||
# host src copied in (forces write path)
|
||||
src = np.full((1, 8), float(rank + 1), dtype=np.float16)
|
||||
from kernbench.runtime_api.tensor import Tensor
|
||||
h = Tensor(shape=src.shape, dtype="f16", name="host")
|
||||
h._host_buffer = src
|
||||
t.copy_(h)
|
||||
# Read access via __getitem__ and .data: both must fully materialize.
|
||||
slice_val = t[0, 0:4]
|
||||
data_val = t.data
|
||||
assert slice_val.shape[0] == 4
|
||||
assert data_val.shape == (1, 8)
|
||||
|
||||
ctx.multiprocessing.spawn(_worker, args=(), nprocs=2)
|
||||
|
||||
|
||||
# ── T5.e: collective pending also drained by barrier ────────────────
|
||||
|
||||
|
||||
def test_numpy_drains_collective_pending(topology, tmp_path, monkeypatch):
|
||||
"""T5.e: numpy() after all_reduce must see post-reduce data.
|
||||
|
||||
Note: in the current model, ``all_reduce`` itself yields to main so the
|
||||
collective is drained before the worker resumes; barriers at
|
||||
``numpy()`` intentionally do NOT drain collective pending (would cause
|
||||
cross-rank deadlock — see ``_host_read_barrier`` docstring). What this
|
||||
test asserts is the observable contract: post-``all_reduce`` +
|
||||
``numpy()`` sees the reduced values.
|
||||
"""
|
||||
import textwrap
|
||||
body = textwrap.dedent("""\
|
||||
defaults:
|
||||
algorithm: ring_allreduce_tcm
|
||||
buffer_kind: tcm
|
||||
backpressure: sleep
|
||||
n_slots: 4
|
||||
slot_size: 4096
|
||||
vc_chunk_size: 256
|
||||
ipcq_credit_size_bytes: 16
|
||||
|
||||
algorithms:
|
||||
ring_allreduce_tcm:
|
||||
module: kernbench.ccl.algorithms.ring_allreduce
|
||||
topology: ring_1d
|
||||
buffer_kind: tcm
|
||||
n_elem: 8
|
||||
""")
|
||||
(tmp_path / "ccl.yaml").write_text(body)
|
||||
monkeypatch.chdir(str(tmp_path))
|
||||
|
||||
with _make_ctx(topology) as ctx:
|
||||
from kernbench.policy.placement.dp import DPPolicy
|
||||
|
||||
dp = DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=1)
|
||||
|
||||
def _worker(rank: int, ws: int):
|
||||
ctx.ahbm.set_device(rank)
|
||||
t = ctx.zeros((1, 8), dtype="f16", dp=dp, name=f"t5e_r{rank}")
|
||||
src = np.full((1, 8), float(rank + 1), dtype=np.float16)
|
||||
from kernbench.runtime_api.tensor import Tensor
|
||||
h = Tensor(shape=src.shape, dtype="f16", name="host")
|
||||
h._host_buffer = src
|
||||
t.copy_(h)
|
||||
ctx.distributed.all_reduce(t, op="sum")
|
||||
# numpy() must see the reduced values even without explicit wait.
|
||||
out = t.numpy()
|
||||
expected = float(sum(range(1, ws + 1)))
|
||||
# Tolerance loose for fp16 accumulation.
|
||||
assert np.allclose(out, expected, rtol=1e-1, atol=1e-1), (
|
||||
f"rank {rank}: expected {expected}, got {out}"
|
||||
)
|
||||
|
||||
ctx.distributed.init_process_group(backend="ahbm")
|
||||
ws = ctx.distributed.get_world_size()
|
||||
ctx.multiprocessing.spawn(_worker, args=(ws,), nprocs=ws)
|
||||
|
||||
|
||||
# ── T5.f: copy_ target-side write barrier ────────────────────────────
|
||||
|
||||
|
||||
def test_copy_from_deployed_source_drains_source(topology):
|
||||
"""T5.f (revised): ``copy_(source)`` drains source-side pending via the
|
||||
``source.numpy()`` read barrier.
|
||||
|
||||
Note: the ADR originally specified a target-side write barrier as well,
|
||||
but that was removed because global-pending target barrier can cause
|
||||
cross-rank deadlock when another rank has a pending collective. Source-
|
||||
side read barrier is preserved and sufficient for the common pattern
|
||||
``target.copy_(deployed_source)``.
|
||||
"""
|
||||
with _make_ctx(topology) as ctx:
|
||||
from kernbench.policy.placement.dp import DPPolicy
|
||||
from kernbench.runtime_api.tensor import Tensor
|
||||
|
||||
dp = DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=1)
|
||||
|
||||
def _worker(rank: int):
|
||||
# Deployed source — its .numpy() will trigger the read barrier.
|
||||
source = ctx.zeros((1, 8), dtype="f16", dp=dp, name=f"src_r{rank}")
|
||||
target = ctx.zeros((1, 8), dtype="f16", dp=dp, name=f"tgt_r{rank}")
|
||||
target.copy_(source)
|
||||
# Smoke: no hang, no exception. numpy round-trip sees zeros.
|
||||
out = target.numpy()
|
||||
assert out.shape == (1, 8)
|
||||
|
||||
ctx.multiprocessing.spawn(_worker, args=(), nprocs=1)
|
||||
@@ -0,0 +1,113 @@
|
||||
"""Tests for configure_sfr_intercube_multisip neighbor table wiring.
|
||||
|
||||
Verifies that IPCQ neighbor tables are correctly installed for
|
||||
intercube (pe0, 4×4 mesh N/S/E/W) + inter-SIP (pe0, all cubes,
|
||||
global_E/global_W) communication.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from kernbench.ccl.install import load_ccl_config, resolve_algorithm_config
|
||||
from kernbench.ccl.sfr_config import configure_sfr_intercube_multisip
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
from kernbench.topology.builder import resolve_topology
|
||||
|
||||
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
|
||||
|
||||
N_CUBES = 16
|
||||
|
||||
|
||||
def _engine_and_spec():
|
||||
topo = resolve_topology(str(TOPOLOGY_PATH))
|
||||
engine = GraphEngine(topo.topology_obj, enable_data=True)
|
||||
return engine, topo.topology_obj.spec
|
||||
|
||||
|
||||
def _merged_cfg():
|
||||
cfg = load_ccl_config()
|
||||
return resolve_algorithm_config(cfg, name="intercube_allreduce")
|
||||
|
||||
|
||||
class TestConfigureSfrNeighborTables:
|
||||
def test_world_size_and_rank_to_pe(self):
|
||||
engine, spec = _engine_and_spec()
|
||||
cfg = _merged_cfg()
|
||||
plan = configure_sfr_intercube_multisip(engine, spec, cfg)
|
||||
|
||||
n_sips = int(spec["system"]["sips"]["count"])
|
||||
assert plan["world_size"] == n_sips * N_CUBES
|
||||
assert len(plan["rank_to_pe"]) == n_sips * N_CUBES
|
||||
for pe_idx, (sip, cube, pe) in enumerate(plan["rank_to_pe"]):
|
||||
assert pe == 0, f"pe_idx {pe_idx}: pe must be 0, got {pe}"
|
||||
|
||||
def test_corner_cube0_has_E_and_S_only(self):
|
||||
"""Cube 0 (row=0, col=0) is NW corner: only E and S neighbors."""
|
||||
engine, spec = _engine_and_spec()
|
||||
cfg = _merged_cfg()
|
||||
configure_sfr_intercube_multisip(engine, spec, cfg)
|
||||
|
||||
ipcq = engine._components["sip0.cube0.pe0.pe_ipcq"]
|
||||
qp = ipcq.queue_pairs
|
||||
assert "E" in qp, "cube 0 must have E neighbor"
|
||||
assert "S" in qp, "cube 0 must have S neighbor"
|
||||
assert "W" not in qp, "cube 0 (col=0) must NOT have W neighbor"
|
||||
assert "N" not in qp, "cube 0 (row=0) must NOT have N neighbor"
|
||||
assert qp["E"]["peer"].cube == 1
|
||||
assert qp["S"]["peer"].cube == 4
|
||||
|
||||
def test_interior_cube5_has_all_four(self):
|
||||
"""Cube 5 (row=1, col=1) is interior: N/S/E/W all present."""
|
||||
engine, spec = _engine_and_spec()
|
||||
cfg = _merged_cfg()
|
||||
configure_sfr_intercube_multisip(engine, spec, cfg)
|
||||
|
||||
ipcq = engine._components["sip0.cube5.pe0.pe_ipcq"]
|
||||
qp = ipcq.queue_pairs
|
||||
assert qp["N"]["peer"].cube == 1
|
||||
assert qp["S"]["peer"].cube == 9
|
||||
assert qp["E"]["peer"].cube == 6
|
||||
assert qp["W"]["peer"].cube == 4
|
||||
|
||||
def test_root_cube15_has_inter_sip(self):
|
||||
"""Cube 15 (root, SE corner) has N, W + global_E/global_W."""
|
||||
engine, spec = _engine_and_spec()
|
||||
cfg = _merged_cfg()
|
||||
configure_sfr_intercube_multisip(engine, spec, cfg)
|
||||
|
||||
ipcq0 = engine._components["sip0.cube15.pe0.pe_ipcq"]
|
||||
qp0 = ipcq0.queue_pairs
|
||||
assert "N" in qp0
|
||||
assert "W" in qp0
|
||||
assert "E" not in qp0, "cube 15 (col=3) must NOT have E"
|
||||
assert "S" not in qp0, "cube 15 (row=3) must NOT have S"
|
||||
assert "global_E" in qp0, "root cube must have global_E"
|
||||
assert "global_W" in qp0, "root cube must have global_W"
|
||||
assert qp0["global_E"]["peer"].sip == 1
|
||||
assert qp0["global_E"]["peer"].cube == 15
|
||||
|
||||
ipcq1 = engine._components["sip1.cube15.pe0.pe_ipcq"]
|
||||
qp1 = ipcq1.queue_pairs
|
||||
assert qp1["global_E"]["peer"].sip == 0
|
||||
assert qp1["global_E"]["peer"].cube == 15
|
||||
|
||||
def test_all_cubes_have_inter_sip(self):
|
||||
"""ALL cubes (not just root) are wired for inter-SIP."""
|
||||
engine, spec = _engine_and_spec()
|
||||
cfg = _merged_cfg()
|
||||
configure_sfr_intercube_multisip(engine, spec, cfg)
|
||||
|
||||
root_cube = int(cfg.get("root_cube", N_CUBES - 1))
|
||||
for cube_id in range(N_CUBES):
|
||||
ipcq = engine._components[f"sip0.cube{cube_id}.pe0.pe_ipcq"]
|
||||
qp = ipcq.queue_pairs
|
||||
assert "global_E" in qp, (
|
||||
f"sip0.cube{cube_id}.pe0 missing global_E"
|
||||
)
|
||||
assert "global_W" in qp, (
|
||||
f"sip0.cube{cube_id}.pe0 missing global_W"
|
||||
)
|
||||
if cube_id == root_cube:
|
||||
assert qp["global_E"]["peer"].sip != 0, (
|
||||
f"root cube {root_cube} global_E must point to another SIP"
|
||||
)
|
||||
@@ -1,178 +0,0 @@
|
||||
"""ADR-0027 T4: torch.multiprocessing.spawn semantics.
|
||||
|
||||
Phase 1: imports `ctx.multiprocessing.spawn` which doesn't exist yet —
|
||||
tests fail. Phase 2 (D1) lands the namespace + _MultiprocessingNamespace
|
||||
+ SpawnException, and these pass.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import textwrap
|
||||
|
||||
import pytest
|
||||
from greenlet import greenlet
|
||||
|
||||
|
||||
def _write_minimal_ccl_yaml(tmp_path) -> str:
|
||||
body = textwrap.dedent("""\
|
||||
defaults:
|
||||
algorithm: ring_allreduce_tcm
|
||||
buffer_kind: tcm
|
||||
backpressure: sleep
|
||||
n_slots: 4
|
||||
slot_size: 4096
|
||||
vc_chunk_size: 256
|
||||
ipcq_credit_size_bytes: 16
|
||||
|
||||
algorithms:
|
||||
ring_allreduce_tcm:
|
||||
module: kernbench.ccl.algorithms.ring_allreduce
|
||||
topology: ring_1d
|
||||
buffer_kind: tcm
|
||||
n_elem: 8
|
||||
""")
|
||||
yaml_path = tmp_path / "ccl.yaml"
|
||||
yaml_path.write_text(body)
|
||||
return str(tmp_path)
|
||||
|
||||
|
||||
def _make_ctx(topology):
|
||||
from kernbench.runtime_api.context import RuntimeContext
|
||||
from kernbench.runtime_api.types import DeviceSelector
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
|
||||
engine = GraphEngine(topology.topology_obj, enable_data=True)
|
||||
return RuntimeContext(
|
||||
engine=engine,
|
||||
target_device=DeviceSelector("all"),
|
||||
correlation_id="test_t4",
|
||||
spec=topology.topology_obj.spec,
|
||||
)
|
||||
|
||||
|
||||
# ── D1.3 namespace attach ────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_multiprocessing_namespace_attached(topology):
|
||||
"""RuntimeContext.__post_init__ attaches ctx.multiprocessing (D1.3)."""
|
||||
with _make_ctx(topology) as ctx:
|
||||
assert hasattr(ctx, "multiprocessing"), (
|
||||
"ADR-0027 D1.3: ctx.multiprocessing must exist"
|
||||
)
|
||||
assert hasattr(ctx.multiprocessing, "spawn"), (
|
||||
"ctx.multiprocessing must expose a spawn(fn, args, nprocs) method"
|
||||
)
|
||||
|
||||
|
||||
# ── D1.1 / D1.2: spawn shape + rank binding ──────────────────────────
|
||||
|
||||
|
||||
def test_spawn_invokes_fn_once_per_rank(topology):
|
||||
"""spawn(fn, args, nprocs) calls fn(rank, *args) once for each rank."""
|
||||
with _make_ctx(topology) as ctx:
|
||||
calls: list[tuple[int, tuple]] = []
|
||||
|
||||
def _worker(rank: int, world_size: int) -> None:
|
||||
calls.append((rank, (world_size,)))
|
||||
|
||||
ctx.multiprocessing.spawn(_worker, args=(3,), nprocs=3)
|
||||
|
||||
assert sorted(r for r, _ in calls) == [0, 1, 2]
|
||||
for _, (ws,) in calls:
|
||||
assert ws == 3
|
||||
|
||||
|
||||
def test_spawn_binds_greenlet_local_rank(topology):
|
||||
"""Inside the worker, torch.distributed.get_rank() returns the rank
|
||||
bound to the greenlet (ADR-0024 D9 + D1.2)."""
|
||||
with _make_ctx(topology) as ctx:
|
||||
# Distributed context needs to be initialised so get_rank is valid.
|
||||
# For T4 we don't run a real collective; just check rank lookup.
|
||||
observed: list[tuple[int, int]] = []
|
||||
|
||||
def _worker(rank: int):
|
||||
g = greenlet.getcurrent()
|
||||
bound = ctx.distributed._rank_by_greenlet.get(g)
|
||||
observed.append((rank, bound))
|
||||
|
||||
ctx.multiprocessing.spawn(_worker, args=(), nprocs=2)
|
||||
|
||||
for rank, bound in observed:
|
||||
assert rank == bound, (
|
||||
f"rank {rank} must be bound to greenlet-local rank {rank}; "
|
||||
f"got {bound}"
|
||||
)
|
||||
|
||||
|
||||
# ── D1.2 exception cleanup ───────────────────────────────────────────
|
||||
|
||||
|
||||
def test_spawn_exception_raises_spawn_exception_with_root_cause(topology):
|
||||
"""D0.4-(4): worker raise → siblings SystemExit + SpawnException(errors)."""
|
||||
with _make_ctx(topology) as ctx:
|
||||
from kernbench.runtime_api.multiprocessing import SpawnException
|
||||
|
||||
def _worker(rank: int):
|
||||
if rank == 1:
|
||||
raise ValueError(f"rank {rank} boom")
|
||||
|
||||
with pytest.raises(SpawnException) as exc_info:
|
||||
ctx.multiprocessing.spawn(_worker, args=(), nprocs=3)
|
||||
|
||||
# Root cause rank is captured.
|
||||
assert 1 in exc_info.value.errors
|
||||
assert isinstance(exc_info.value.errors[1], ValueError)
|
||||
|
||||
|
||||
def test_spawn_exception_clears_pending_queues(topology):
|
||||
"""D0.4-(4): on raise, _pending_worker_waits and collective queue clear."""
|
||||
with _make_ctx(topology) as ctx:
|
||||
from kernbench.runtime_api.multiprocessing import SpawnException
|
||||
|
||||
def _worker(rank: int):
|
||||
raise RuntimeError("fail")
|
||||
|
||||
with pytest.raises(SpawnException):
|
||||
ctx.multiprocessing.spawn(_worker, args=(), nprocs=2)
|
||||
|
||||
assert ctx._pending_worker_waits == []
|
||||
|
||||
|
||||
# ── D1.4 migration compat: ccl_allreduce runs via mp.spawn ───────────
|
||||
|
||||
|
||||
def test_ccl_allreduce_hand_rolled_loop_replaced_by_mp_spawn(
|
||||
topology, tmp_path, monkeypatch, spec,
|
||||
):
|
||||
"""D1.4: benches/ccl_allreduce.py's hand-rolled greenlet loop must still
|
||||
produce correct behaviour after migration to torch.multiprocessing.spawn.
|
||||
|
||||
Minimal smoke — just that ``bench.run(ctx)`` completes without the
|
||||
loop short-circuiting or leaving pending queues dirty.
|
||||
"""
|
||||
monkeypatch.chdir(_write_minimal_ccl_yaml(tmp_path))
|
||||
import benches.ccl_allreduce as bench
|
||||
|
||||
calls: list[tuple[int, int]] = []
|
||||
|
||||
def _fake_worker(rank, cfg, torch):
|
||||
calls.append((rank, cfg.world_size))
|
||||
|
||||
monkeypatch.setattr(bench, "_worker", _fake_worker)
|
||||
|
||||
with _make_ctx(topology) as ctx:
|
||||
bench.run(ctx)
|
||||
|
||||
expected_ws = int(spec["system"]["sips"]["count"])
|
||||
ranks = sorted(r for r, _ in calls)
|
||||
assert ranks == list(range(expected_ws))
|
||||
assert ctx._pending_worker_waits == []
|
||||
|
||||
|
||||
# ── _drain_pending function is exported ──────────────────────────────
|
||||
|
||||
|
||||
def test_drain_pending_exported():
|
||||
"""D0.4: _drain_pending must be importable from runtime_api.multiprocessing."""
|
||||
from kernbench.runtime_api.multiprocessing import _drain_pending
|
||||
assert callable(_drain_pending)
|
||||
@@ -1,80 +0,0 @@
|
||||
"""Tests for recv_mode='copy_to_dst' (ADR-0023 D9.5)."""
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def test_recv_copy_to_dst_via_simpy_runner():
|
||||
"""Run a kernel that uses tl.recv(..., dst_addr=..., dst_space=...).
|
||||
Verify the data is moved to the dst location after recv.
|
||||
"""
|
||||
import importlib
|
||||
|
||||
from kernbench.policy.placement.dp import DPPolicy
|
||||
from kernbench.runtime_api.bench_runner import run_bench
|
||||
from kernbench.runtime_api.types import resolve_device
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
from kernbench.topology.builder import resolve_topology
|
||||
from kernbench.common.pe_commands import TensorHandle
|
||||
|
||||
def kernel(t_ptr, n_elem, dst_buf_addr, tl):
|
||||
rank = tl.program_id(axis=0)
|
||||
ws = tl.num_programs(axis=0)
|
||||
nbytes = n_elem * 2
|
||||
# Each PE sends own data, then recv into a custom dst slot
|
||||
current = TensorHandle(
|
||||
id="loc", addr=t_ptr + rank * nbytes,
|
||||
shape=(n_elem,), dtype="f16",
|
||||
nbytes=nbytes, data=None, space="hbm",
|
||||
)
|
||||
tl.send(dir="E", src=current)
|
||||
# copy_to_dst: move into a per-rank scratch HBM addr
|
||||
recv = tl.recv(
|
||||
dir="W", shape=(n_elem,), dtype="f16",
|
||||
dst_addr=dst_buf_addr + rank * nbytes,
|
||||
dst_space="hbm",
|
||||
)
|
||||
# Sanity: recv handle should now point to our dst addr
|
||||
assert recv.addr == dst_buf_addr + rank * nbytes
|
||||
assert recv.space == "hbm"
|
||||
|
||||
topo = resolve_topology("topology.yaml")
|
||||
|
||||
def run(torch):
|
||||
plan = torch.install_ipcq(
|
||||
algorithm="ring_allreduce_tcm", world_size_override=8,
|
||||
)
|
||||
a = torch.zeros(
|
||||
(1, 8 * 8),
|
||||
dtype="f16",
|
||||
dp=DPPolicy(
|
||||
cube="replicate", pe="column_wise",
|
||||
num_cubes=1,
|
||||
),
|
||||
name="copy_in",
|
||||
)
|
||||
store = torch.engine.memory_store
|
||||
base = a._handle.va_base or a._handle.shards[0].pa
|
||||
nbytes = 8 * 2
|
||||
for r in range(8):
|
||||
store.write("hbm", base + r * nbytes,
|
||||
np.full((8,), float(r + 1), dtype=np.float16))
|
||||
|
||||
# Use a separate dst region (synthetic addresses)
|
||||
dst_buf = 0xC0FFEE_0000
|
||||
torch.launch("ring_allreduce_tcm", kernel, a, 8, dst_buf)
|
||||
|
||||
# After the kernel, dst_buf + r*16 should contain rank (r-1)%8's data
|
||||
for r in range(8):
|
||||
arr = store.read("hbm", dst_buf + r * nbytes, shape=(8,), dtype="f16")
|
||||
expected = float(((r - 1) % 8) + 1)
|
||||
assert np.allclose(arr, expected), f"rank {r}: got {arr}, expected {expected}"
|
||||
|
||||
result = run_bench(
|
||||
topology=topo, bench_fn=run,
|
||||
device=resolve_device("all"),
|
||||
engine_factory=lambda t, d: GraphEngine(
|
||||
getattr(t, "topology_obj", t), enable_data=True
|
||||
),
|
||||
)
|
||||
assert result.completion.ok
|
||||
@@ -1,106 +0,0 @@
|
||||
"""Tests for tl.recv_async + tl.wait (ADR-0023 D4)."""
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
from kernbench.ccl.testing import run_kernel_in_mock
|
||||
|
||||
|
||||
def kernel_async_recv(t_ptr, n_elem, tl):
|
||||
"""Each PE issues recv_async first, then send, then wait — this exercises
|
||||
the non-blocking path. Uses TensorHandle math (PE_MATH) for accumulation
|
||||
so Phase 2 produces correct final HBM contents."""
|
||||
rank = tl.program_id(axis=0)
|
||||
world_size = tl.num_programs(axis=0)
|
||||
nbytes = n_elem * 2
|
||||
|
||||
pe_addr = t_ptr + rank * nbytes
|
||||
acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16")
|
||||
current = acc
|
||||
|
||||
for _step in range(world_size - 1):
|
||||
future = tl.recv_async(dir="W", shape=(n_elem,), dtype="f16")
|
||||
tl.send(dir="E", src=current)
|
||||
recv = tl.wait(future)
|
||||
acc = acc + recv
|
||||
current = recv # forward W's tile to E next round
|
||||
|
||||
tl.store(pe_addr, acc)
|
||||
|
||||
|
||||
def test_recv_async_mock_runtime():
|
||||
n_elem = 8
|
||||
inputs = [
|
||||
np.full((n_elem,), float(r + 1), dtype=np.float16)
|
||||
for r in range(4)
|
||||
]
|
||||
expected = sum(inputs)
|
||||
|
||||
outputs = run_kernel_in_mock(
|
||||
kernel_fn=kernel_async_recv,
|
||||
world_size=4,
|
||||
topology="ring_1d",
|
||||
inputs=inputs,
|
||||
kernel_args=(n_elem,),
|
||||
)
|
||||
for r in range(4):
|
||||
assert np.allclose(outputs[r], expected)
|
||||
|
||||
|
||||
def test_recv_async_simpy_runner():
|
||||
"""Run the async kernel through the real SimPy stack via the
|
||||
install_ipcq + launch path.
|
||||
"""
|
||||
import importlib
|
||||
|
||||
from kernbench.runtime_api.bench_runner import run_bench
|
||||
from kernbench.runtime_api.types import resolve_device
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
from kernbench.topology.builder import resolve_topology
|
||||
|
||||
# Re-use the standard 8-PE bench skeleton but swap in the async kernel.
|
||||
topo = resolve_topology("topology.yaml")
|
||||
|
||||
# Build a tiny inline bench module
|
||||
import types
|
||||
mod = types.ModuleType("inline_bench_async")
|
||||
|
||||
from kernbench.policy.placement.dp import DPPolicy
|
||||
|
||||
def run(torch):
|
||||
plan = torch.install_ipcq(
|
||||
algorithm="ring_allreduce_tcm", world_size_override=8,
|
||||
)
|
||||
a = torch.zeros(
|
||||
(1, 8 * 8),
|
||||
dtype="f16",
|
||||
dp=DPPolicy(
|
||||
cube="replicate", pe="column_wise",
|
||||
num_cubes=1,
|
||||
),
|
||||
name="async_in",
|
||||
)
|
||||
store = torch.engine.memory_store
|
||||
base = a._handle.va_base or a._handle.shards[0].pa
|
||||
nbytes = 8 * 2
|
||||
for r in range(8):
|
||||
store.write("hbm", base + r * nbytes,
|
||||
np.full((8,), float(r + 1), dtype=np.float16))
|
||||
|
||||
torch.launch("ring_allreduce_tcm", kernel_async_recv, a, 8)
|
||||
|
||||
for r in range(8):
|
||||
result = store.read("hbm", base + r * nbytes, shape=(8,), dtype="f16")
|
||||
expected = float(sum(range(1, 9))) # 36
|
||||
assert np.allclose(result, expected, rtol=1e-2, atol=1e-2), \
|
||||
f"rank {r}: got {result}, expected {expected}"
|
||||
|
||||
mod.run = run
|
||||
result = run_bench(
|
||||
topology=topo, bench_fn=mod.run,
|
||||
device=resolve_device("all"),
|
||||
engine_factory=lambda t, d: GraphEngine(
|
||||
getattr(t, "topology_obj", t), enable_data=True
|
||||
),
|
||||
)
|
||||
assert result.completion.ok
|
||||
@@ -1,301 +0,0 @@
|
||||
"""ADR-0027 T3: Worker-wait generalization + orphan invariant.
|
||||
|
||||
Direct regression guard for ADR-0024 Phase B's kernel-greenlet orphan bug.
|
||||
Phase 1 of ADR-0027: these tests fail against the current code (no
|
||||
``_pending_worker_waits`` field, no worker-fork in ``ctx.wait``, no
|
||||
scheduler drain). Phase 2 implements D0.1/D0.2/D0.4 and these pass.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import textwrap
|
||||
|
||||
import pytest
|
||||
from greenlet import greenlet
|
||||
|
||||
|
||||
# ── helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _write_minimal_ccl_yaml(tmp_path) -> str:
|
||||
body = textwrap.dedent("""\
|
||||
defaults:
|
||||
algorithm: ring_allreduce_tcm
|
||||
buffer_kind: tcm
|
||||
backpressure: sleep
|
||||
n_slots: 4
|
||||
slot_size: 4096
|
||||
vc_chunk_size: 256
|
||||
ipcq_credit_size_bytes: 16
|
||||
|
||||
algorithms:
|
||||
ring_allreduce_tcm:
|
||||
module: kernbench.ccl.algorithms.ring_allreduce
|
||||
topology: ring_1d
|
||||
buffer_kind: tcm
|
||||
n_elem: 8
|
||||
""")
|
||||
yaml_path = tmp_path / "ccl.yaml"
|
||||
yaml_path.write_text(body)
|
||||
return str(tmp_path)
|
||||
|
||||
|
||||
def _make_ctx(topology):
|
||||
from kernbench.runtime_api.context import RuntimeContext
|
||||
from kernbench.runtime_api.types import DeviceSelector
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
|
||||
engine = GraphEngine(topology.topology_obj, enable_data=True)
|
||||
return RuntimeContext(
|
||||
engine=engine,
|
||||
target_device=DeviceSelector("all"),
|
||||
correlation_id="test_t3",
|
||||
spec=topology.topology_obj.spec,
|
||||
)
|
||||
|
||||
|
||||
# ── D0.1: _pending_worker_waits field exists ─────────────────────────
|
||||
|
||||
|
||||
def test_pending_worker_waits_field_present(topology):
|
||||
"""RuntimeContext must expose the deferred-wait queue (D0.1)."""
|
||||
with _make_ctx(topology) as ctx:
|
||||
assert hasattr(ctx, "_pending_worker_waits"), (
|
||||
"ADR-0027 D0.1: RuntimeContext must declare _pending_worker_waits"
|
||||
)
|
||||
assert ctx._pending_worker_waits == [], (
|
||||
"_pending_worker_waits should start empty"
|
||||
)
|
||||
|
||||
|
||||
# ── T3.a / T3.b: wait defers + resume-after-drain contract ───────────
|
||||
|
||||
|
||||
def test_wait_in_worker_defers_to_main_and_resumes_completed(topology):
|
||||
"""T3.a + T3.b: worker ctx.wait enqueues + yields; resume → _completed.
|
||||
|
||||
Direct test of D0.2 (worker-fork) + D0.3 resume invariant (handle must
|
||||
be in ctx._completed when worker resumes).
|
||||
"""
|
||||
with _make_ctx(topology) as ctx:
|
||||
from kernbench.policy.placement.dp import DPPolicy
|
||||
|
||||
# Worker that submits one tensor (which internally calls ctx.wait)
|
||||
# and records the pending-queue state observed before/after.
|
||||
observations: dict = {"pre_wait_len": None, "post_resume_completed": None}
|
||||
|
||||
main = greenlet.getcurrent()
|
||||
|
||||
def _worker():
|
||||
# Observation hook: patch ctx.wait to capture a single deferral.
|
||||
original_wait = ctx.wait
|
||||
|
||||
def wrapping_wait(h, *, _meta=None):
|
||||
observations["pre_wait_len"] = len(ctx._pending_worker_waits)
|
||||
result = original_wait(h, _meta=_meta)
|
||||
observations["post_resume_completed"] = h in ctx._completed
|
||||
return result
|
||||
|
||||
ctx.wait = wrapping_wait # type: ignore[assignment]
|
||||
try:
|
||||
ctx.zeros(
|
||||
(1, 8), dtype="f16",
|
||||
dp=DPPolicy(cube="replicate", pe="replicate",
|
||||
num_cubes=1, num_pes=1),
|
||||
name="t3_defer",
|
||||
)
|
||||
finally:
|
||||
ctx.wait = original_wait # type: ignore[assignment]
|
||||
|
||||
g = greenlet(_worker)
|
||||
|
||||
# Scheduler loop: run worker until it yields (or finishes), then drain.
|
||||
while not g.dead:
|
||||
g.switch()
|
||||
if not g.dead:
|
||||
# Worker yielded mid-wait → simulate D0.4 drain.
|
||||
from kernbench.runtime_api.multiprocessing import _drain_pending
|
||||
_drain_pending(ctx)
|
||||
|
||||
assert observations["pre_wait_len"] is not None, "wait was not invoked"
|
||||
assert observations["post_resume_completed"] is True, (
|
||||
"D0.3 resume invariant: handle must be in ctx._completed on resume"
|
||||
)
|
||||
|
||||
|
||||
# ── T3.c: multi-worker same-round drain ──────────────────────────────
|
||||
|
||||
|
||||
def test_multiple_workers_resume_at_same_drain(topology):
|
||||
"""T3.c: every worker yields before any drain; all resume together."""
|
||||
with _make_ctx(topology) as ctx:
|
||||
from kernbench.policy.placement.dp import DPPolicy
|
||||
|
||||
dp = DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=1)
|
||||
observations: list[int] = []
|
||||
|
||||
def _make_worker(rank: int):
|
||||
def _entry():
|
||||
# Before its wait, observe queue state so we can assert that
|
||||
# *every* worker has enqueued before any drain happened.
|
||||
ctx.zeros((1, 4), dtype="f16", dp=dp, name=f"r{rank}")
|
||||
observations.append(rank)
|
||||
return _entry
|
||||
|
||||
ws = 2
|
||||
gs = [greenlet(_make_worker(r)) for r in range(ws)]
|
||||
|
||||
# Round 1: every worker runs up to its first (deferred) ctx.wait.
|
||||
for g in gs:
|
||||
g.switch()
|
||||
|
||||
# After round 1, all workers should be paused (not yet dead) and
|
||||
# each should have enqueued at least one handle.
|
||||
assert all(not g.dead for g in gs), (
|
||||
"after round 1 switch, workers must be paused mid-wait, not dead"
|
||||
)
|
||||
assert len(ctx._pending_worker_waits) >= ws, (
|
||||
f"expected >= {ws} pending worker waits after round 1; "
|
||||
f"got {len(ctx._pending_worker_waits)}"
|
||||
)
|
||||
|
||||
# Loop: drain + switch rounds until all workers complete. A single
|
||||
# ctx.zeros() call contains multiple yield points (MmuMap, then
|
||||
# MemoryWrite), so more than one round is needed.
|
||||
from kernbench.runtime_api.multiprocessing import _drain_pending
|
||||
rounds = 0
|
||||
while any(not g.dead for g in gs):
|
||||
_drain_pending(ctx)
|
||||
for g in gs:
|
||||
if not g.dead:
|
||||
g.switch()
|
||||
rounds += 1
|
||||
assert rounds < 20, "scheduler did not converge within 20 rounds"
|
||||
|
||||
assert all(g.dead for g in gs), "all workers should be dead after drain loop"
|
||||
assert sorted(observations) == list(range(ws))
|
||||
|
||||
|
||||
# ── T3.d (핵심): kernel greenlet _parent is main ─────────────────────
|
||||
|
||||
|
||||
def test_kernel_greenlet_parent_is_main(topology, tmp_path, monkeypatch):
|
||||
"""T3.d orphan invariant: kernel_runner._parent must be main greenlet.
|
||||
|
||||
This is the direct regression guard for ADR-0024 Phase B. Runs a worker
|
||||
that invokes torch.launch (which eventually spawns a kernel greenlet).
|
||||
The kernel_runner.run() captures greenlet.getcurrent() as _parent at
|
||||
spawn time — that value MUST be the main greenlet, else the orphan
|
||||
bug is back.
|
||||
"""
|
||||
monkeypatch.chdir(_write_minimal_ccl_yaml(tmp_path))
|
||||
|
||||
from kernbench.triton_emu import kernel_runner as kr_mod
|
||||
captured_parents: list = []
|
||||
main = greenlet.getcurrent()
|
||||
|
||||
original_run = kr_mod.KernelRunner.run
|
||||
|
||||
def _spy_run(self, env, kernel_fn, kernel_args, num_programs):
|
||||
gen = original_run(self, env, kernel_fn, kernel_args, num_programs)
|
||||
|
||||
def _wrapping_gen():
|
||||
# yield from gen, but capture self._parent on first step
|
||||
try:
|
||||
value = next(gen)
|
||||
# First yield happens after _parent is set.
|
||||
captured_parents.append(self._parent)
|
||||
yield value
|
||||
except StopIteration:
|
||||
return
|
||||
yield from gen
|
||||
|
||||
return _wrapping_gen()
|
||||
|
||||
monkeypatch.setattr(kr_mod.KernelRunner, "run", _spy_run)
|
||||
|
||||
# Drive a minimal ring_allreduce that launches a kernel inside a worker.
|
||||
import benches.ccl_allreduce as bench
|
||||
|
||||
with _make_ctx(topology) as ctx:
|
||||
bench.run(ctx)
|
||||
|
||||
assert captured_parents, "no kernel_runner.run invocations observed"
|
||||
for p in captured_parents:
|
||||
assert p is main, (
|
||||
f"ADR-0027 D0.7 / T3.d: kernel greenlet _parent must be main "
|
||||
f"greenlet; got {p!r} (main={main!r})"
|
||||
)
|
||||
|
||||
|
||||
# ── T3.f: idempotency ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_wait_same_handle_twice_drives_engine_once(topology):
|
||||
"""T3.f: ctx.wait(h) + ctx.wait(h) → engine.wait called once (D0.4-(3))."""
|
||||
with _make_ctx(topology) as ctx:
|
||||
from kernbench.policy.placement.dp import DPPolicy
|
||||
|
||||
dp = DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=1)
|
||||
call_count = {"n": 0}
|
||||
original_engine_wait = ctx.engine.wait
|
||||
|
||||
def _counting_wait(h):
|
||||
call_count["n"] += 1
|
||||
return original_engine_wait(h)
|
||||
|
||||
ctx.engine.wait = _counting_wait # type: ignore[assignment]
|
||||
|
||||
def _worker():
|
||||
ctx.zeros((1, 4), dtype="f16", dp=dp, name="t3f")
|
||||
# Manually pick a completed handle and wait twice.
|
||||
assert ctx._completed, "there should be at least one completed handle"
|
||||
h = next(iter(ctx._completed))
|
||||
before = call_count["n"]
|
||||
ctx.wait(h)
|
||||
ctx.wait(h)
|
||||
assert call_count["n"] == before, (
|
||||
"already-completed handle must not re-drive engine.wait"
|
||||
)
|
||||
|
||||
g = greenlet(_worker)
|
||||
while not g.dead:
|
||||
g.switch()
|
||||
if not g.dead:
|
||||
from kernbench.runtime_api.multiprocessing import _drain_pending
|
||||
_drain_pending(ctx)
|
||||
|
||||
|
||||
# ── T3.g: exception propagation + no further drain ───────────────────
|
||||
|
||||
|
||||
def test_worker_exception_propagates_and_clears_pending(topology):
|
||||
"""T3.g: worker raise → main propagates; _pending_worker_waits cleared."""
|
||||
with _make_ctx(topology) as ctx:
|
||||
from kernbench.runtime_api.multiprocessing import SpawnException
|
||||
|
||||
def _bad_worker(rank: int):
|
||||
raise ValueError(f"rank {rank} intentional failure")
|
||||
|
||||
with pytest.raises(SpawnException) as exc_info:
|
||||
ctx.multiprocessing.spawn(_bad_worker, args=(), nprocs=2)
|
||||
|
||||
assert ctx._pending_worker_waits == [], (
|
||||
"D0.4-(4): _pending_worker_waits must be cleared on failure"
|
||||
)
|
||||
# Root-cause rank errors are present; sibling SystemExit not in dict.
|
||||
assert 0 in exc_info.value.errors or 1 in exc_info.value.errors
|
||||
|
||||
|
||||
# ── T3.e: historical failure (pre-D0) — skipped per ADR ──────────────
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="ADR-0027 T3.e: historical failure mode — reproduces only "
|
||||
"pre-D0.2. Kept as documentation; not run in Phase 2."
|
||||
)
|
||||
def test_pre_d0_orphan_reproduction():
|
||||
"""Placeholder: exercises the pre-D0.2 code path that causes GreenletExit
|
||||
from kernel_runner._parent captured in worker context. See ADR-0024
|
||||
Phase B postmortem."""
|
||||
pass
|
||||
@@ -4,6 +4,7 @@ system:
|
||||
|
||||
sips:
|
||||
count: 2
|
||||
topology: ring_1d
|
||||
|
||||
components:
|
||||
switch: { kind: switch, impl: builtin.switch, attrs: { overhead_ns: 5.0 } }
|
||||
|
||||
Reference in New Issue
Block a user