Intercube allreduce: pe0 cube-mesh reduce + multi-SIP ring/torus/mesh
New intercube allreduce kernel replacing the old flat ring algorithms.

Reduces across the 4x4 cube mesh within each SIP (pe0-only, same-lane), then performs the inter-SIP exchange on the root cube, then broadcasts the result back. Supports ring_1d, torus_2d, and mesh_2d_no_wrap SIP topologies driven by topology.yaml. Integrated with dist.init_process_group / dist.all_reduce.

New files:
- src/kernbench/ccl/algorithms/intercube_allreduce.py (kernel)
- src/kernbench/ccl/sfr_config.py (configure_sfr_intercube_multisip)
- tests/test_allreduce_multidevice.py (config-driven, 3 topologies)
- tests/test_distributed_intercube_allreduce.py (full distributed path)
- tests/test_intercube_sfr_config.py (SFR wiring verification)

Modified:
- distributed.py: AhbmCCLBackend uses configure_sfr_intercube_multisip
- topologies.py: added torus_2d, mesh_2d_no_wrap
- install.py: global_E/W/N/S in _OPPOSITE_DIR
- topology.yaml: added system.sips.topology
- ccl.yaml: single intercube_allreduce algorithm
- benches/ccl_allreduce.py: row_wise cube-mesh tensor layout

Removed old flat-ring algorithms and their tests.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -1,29 +0,0 @@
|
||||
"""Hello-world CCL kernel for the docs/ccl-author-guide.md walkthrough.
|
||||
|
||||
Each PE sends its tile to the E neighbor and receives one tile from W,
|
||||
then stores the received tile back into its own HBM slice. The simplest
|
||||
possible demonstration of ``tl.send`` / ``tl.recv``.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def kernel_args(world_size: int, n_elem: int) -> tuple:
    """Build the positional arguments that follow the tensor pointer.

    Args:
        world_size: accepted for signature parity; this kernel does not use it.
        n_elem: number of elements per tile, forwarded unchanged.

    Returns:
        The one-element tuple ``(n_elem,)``.
    """
    positional = (n_elem,)
    return positional
|
||||
|
||||
|
||||
def kernel(t_ptr, n_elem, tl):
    """Pass one tile along: send the local tile to E, keep what arrives from W.

    Args:
        t_ptr: base address of the tensor; this PE's tile starts at
            ``t_ptr + rank * n_elem * 2``.
        n_elem: number of f16 elements per tile.
        tl: kernel context supplying ``program_id``, ``num_programs``,
            ``load``, ``send``, ``recv`` and ``store``.
    """
    pe_in_cube = tl.program_id(axis=0)
    cube = tl.program_id(axis=1)
    cube_size = tl.num_programs(axis=0)

    tile_shape = (n_elem,)
    # Flatten (cube, PE-in-cube) into one rank; each f16 tile is n_elem * 2 bytes.
    tile_addr = t_ptr + (cube * cube_size + pe_in_cube) * (n_elem * 2)

    # Outbound: our own tile goes to the E neighbor.
    outgoing = tl.load(tile_addr, shape=tile_shape, dtype="f16")
    tl.send(dir="E", src=outgoing)

    # Inbound: the tile arriving from W overwrites our slice.
    incoming = tl.recv(dir="W", shape=tile_shape, dtype="f16")
    tl.store(tile_addr, incoming)
|
||||
@@ -1,192 +0,0 @@
|
||||
"""Hierarchical all-reduce kernel (ADR-0023).
|
||||
|
||||
3-level reduce + broadcast exploiting the topology hierarchy:
|
||||
|
||||
Level 1 — Intra-cube (8 PEs, E/W, fastest link):
|
||||
Bidirectional ring reduce to PE 0.
|
||||
Level 2 — Inter-cube within SIP (16 cubes, N/S, UCIe):
|
||||
Bidirectional ring reduce of PE 0s to cube 0 PE 0.
|
||||
Level 3 — Inter-SIP (2 SIPs, parent):
|
||||
Pair exchange between SIP representatives.
|
||||
Broadcast — Reverse chain through levels 2 and 1.
|
||||
|
||||
Bidirectional reduce: left-half sends toward node 0 via dir_dec,
|
||||
right-half sends via dir_inc (wrapping). Representative receives from
|
||||
both sides. Rounds per level = ceil((group_size - 1) / 2).
|
||||
|
||||
Direction pairing (ring):
|
||||
Send via dir_dec at PE K → recv via dir_inc at PE K-1
|
||||
Send via dir_inc at PE K → recv via dir_dec at PE K+1
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def kernel_args(world_size: int, n_elem: int) -> tuple:
    """Derive the hierarchy sizes passed to the hierarchical kernel.

    The cube size is fixed at 8 PEs. A world of more than 128 ranks is split
    into ``world_size // 128`` SIPs; anything smaller is treated as one SIP.

    Args:
        world_size: total number of participating ranks.
        n_elem: number of elements per tile, forwarded unchanged.

    Returns:
        ``(n_elem, pes_per_cube, cubes_per_sip, num_sips)``.
    """
    pes_per_cube = 8
    if world_size > 128:
        num_sips = max(1, world_size // 128)
    else:
        num_sips = 1
    ranks_per_cube_index = pes_per_cube * num_sips
    return (n_elem, pes_per_cube, world_size // ranks_per_cube_index, num_sips)
|
||||
|
||||
|
||||
def neighbors(rank: int, world_size: int, neighbor_map: dict) -> dict:
    """Return the direction -> peer-rank table for one rank of the hierarchy.

    Every rank gets ``E``/``W`` links forming a ring inside its 8-PE cube.
    PE 0 of each cube additionally gets ``N``/``S`` links forming a ring over
    the cubes of its SIP, and PE 0 of cube 0 gets a ``parent`` link to the
    same position on the next SIP (wrapping).

    Args:
        rank: global rank being wired.
        world_size: total number of ranks.
        neighbor_map: unused by this implementation.

    Returns:
        Dict mapping direction labels to global peer ranks.
    """
    pes_per_cube = 8
    num_sips = 1
    if world_size > 128:
        num_sips = max(1, world_size // 128)
    cubes_per_sip = world_size // (pes_per_cube * num_sips)

    cube_global, pe_id = divmod(rank, pes_per_cube)
    sip_id, local_cube_id = divmod(cube_global, cubes_per_sip)

    # Level 1: E/W ring over the PEs of this cube.
    first_pe_of_cube = cube_global * pes_per_cube
    table = {
        "E": first_pe_of_cube + (pe_id + 1) % pes_per_cube,
        "W": first_pe_of_cube + (pe_id - 1) % pes_per_cube,
    }

    # Levels 2 and 3 only involve each cube's PE 0.
    if pe_id != 0:
        return table

    # Level 2: N/S ring over the PE 0s of the cubes in this SIP.
    if cubes_per_sip > 1:
        first_pe_of_sip = sip_id * cubes_per_sip * pes_per_cube
        table["N"] = first_pe_of_sip + ((local_cube_id + 1) % cubes_per_sip) * pes_per_cube
        table["S"] = first_pe_of_sip + ((local_cube_id - 1) % cubes_per_sip) * pes_per_cube

    # Level 3: the SIP representative (cube 0, PE 0) links to the next SIP's.
    if local_cube_id == 0 and num_sips > 1:
        table["parent"] = ((sip_id + 1) % num_sips) * cubes_per_sip * pes_per_cube

    return table
|
||||
|
||||
|
||||
def _bidir_reduce(tl, acc, my_id, group_size, dir_inc, dir_dec, shape, dtype):
    """Reduce a ring of ``group_size`` nodes onto node 0 from both sides.

    Nodes ``1 .. group_size // 2`` form a chain that forwards partial sums
    toward node 0 through ``dir_dec``. The remaining nodes form a second chain
    that forwards through ``dir_inc`` and reaches node 0 by wrapping around.
    A send through ``dir_dec`` is received by the peer on ``dir_inc`` and
    vice versa.

    Args:
        tl: kernel context providing ``send`` and ``recv``.
        acc: this node's current partial sum.
        my_id: position of this node within the group.
        group_size: number of nodes in the ring.
        dir_inc: direction label toward the next-higher node.
        dir_dec: direction label toward the next-lower node.
        shape: tile shape passed to ``recv``.
        dtype: element type passed to ``recv``.

    Returns:
        The updated accumulator. Only node 0 ends up with the full group sum.
    """
    if group_size <= 1:
        return acc

    midpoint = group_size // 2

    if my_id == 0:
        # Sink: the low-side chain arrives on dir_inc (from node 1) ...
        acc = acc + tl.recv(dir=dir_inc, shape=shape, dtype=dtype)
        # ... and the high-side chain, if it has any members, on dir_dec.
        if group_size - midpoint - 1 >= 1:
            acc = acc + tl.recv(dir=dir_dec, shape=shape, dtype=dtype)
        return acc

    if my_id <= midpoint:
        # Low-side chain: data flows from higher ids down to node 0.
        inbound, outbound = dir_inc, dir_dec
        starts_chain = my_id >= midpoint
    else:
        # High-side chain: data flows from lower ids up and wraps to node 0.
        inbound, outbound = dir_dec, dir_inc
        starts_chain = my_id <= midpoint + 1

    # The first node of each chain has nothing to wait for.
    if not starts_chain:
        acc = acc + tl.recv(dir=inbound, shape=shape, dtype=dtype)
    tl.send(dir=outbound, src=acc)
    return acc
|
||||
|
||||
|
||||
def _chain_broadcast(tl, acc, my_id, group_size, dir_inc, shape, dtype):
    """Broadcast node 0's value down a linear chain through ``dir_inc``.

    Node 0 sends; every other node receives from the opposite direction and,
    unless it is the last node, forwards the value onward.

    Args:
        tl: kernel context providing ``send`` and ``recv``.
        acc: value to broadcast (meaningful on node 0; replaced elsewhere).
        my_id: position of this node within the chain.
        group_size: number of nodes in the chain.
        dir_inc: direction label toward the next-higher node.
        shape: tile shape passed to ``recv``.
        dtype: element type passed to ``recv``.

    Returns:
        The broadcast value as seen by this node.
    """
    if group_size <= 1:
        return acc

    # A send through E/N arrives at the peer on W/S. Labels outside this
    # table are used for receiving as-is.
    opposite = {"E": "W", "W": "E", "N": "S", "S": "N"}
    upstream = opposite.get(dir_inc, dir_inc)

    if my_id != 0:
        acc = tl.recv(dir=upstream, shape=shape, dtype=dtype)

    # Everyone but the tail of the chain passes the value on.
    if my_id == 0 or my_id < group_size - 1:
        tl.send(dir=dir_inc, src=acc)
    return acc
|
||||
|
||||
|
||||
def kernel(t_ptr, n_elem, pes_per_cube, cubes_per_sip, num_sips, tl):
    """Run the three-level hierarchical all-reduce for one PE.

    Reduce: PEs of a cube onto PE 0 (E/W), cube representatives onto cube 0
    of the SIP (N/S), then the SIP representatives swap sums over ``parent``.
    Broadcast: the result is chained back out over N and then over E.

    Args:
        t_ptr: base address of the tensor; this PE's tile starts at
            ``t_ptr + rank * n_elem * 2``.
        n_elem: number of f16 elements per tile.
        pes_per_cube: number of PEs in one cube.
        cubes_per_sip: number of cubes in one SIP.
        num_sips: number of SIPs. The exchange is a single send/recv pair, so
            it only yields the global sum for two SIPs.
        tl: kernel context (``program_id``, ``load``, ``send``, ``recv``,
            ``store``).
    """
    pe_id = tl.program_id(axis=0)
    cube_global = tl.program_id(axis=1)
    local_cube_id = cube_global % cubes_per_sip

    shape = (n_elem,)
    dtype = "f16"
    tile_addr = t_ptr + (cube_global * pes_per_cube + pe_id) * (n_elem * 2)
    acc = tl.load(tile_addr, shape=shape, dtype=dtype)

    is_cube_rep = pe_id == 0
    spans_cubes = is_cube_rep and cubes_per_sip > 1
    is_sip_rep = is_cube_rep and local_cube_id == 0

    # Level 1 reduce: every PE takes part, the sum lands on PE 0.
    acc = _bidir_reduce(
        tl, acc, my_id=pe_id, group_size=pes_per_cube,
        dir_inc="E", dir_dec="W", shape=shape, dtype=dtype,
    )

    # Level 2 reduce: cube representatives only, the sum lands on cube 0.
    if spans_cubes:
        acc = _bidir_reduce(
            tl, acc, my_id=local_cube_id, group_size=cubes_per_sip,
            dir_inc="N", dir_dec="S", shape=shape, dtype=dtype,
        )

    # Level 3: SIP representatives swap their sums.
    if is_sip_rep and num_sips > 1:
        tl.send(dir="parent", src=acc)
        acc = acc + tl.recv(dir="parent", shape=shape, dtype=dtype)

    # Broadcast back out, first across cubes and then inside each cube.
    if spans_cubes:
        acc = _chain_broadcast(
            tl, acc, my_id=local_cube_id, group_size=cubes_per_sip,
            dir_inc="N", shape=shape, dtype=dtype,
        )
    acc = _chain_broadcast(
        tl, acc, my_id=pe_id, group_size=pes_per_cube,
        dir_inc="E", shape=shape, dtype=dtype,
    )

    tl.store(tile_addr, acc)
|
||||
@@ -0,0 +1,189 @@
|
||||
"""Intercube all-reduce kernel (pe0-only, same-lane across cubes).
|
||||
|
||||
Reduces across the 4×4 cube mesh within each SIP, then exchanges
|
||||
between SIPs using the configured SIP topology, and broadcasts back.
|
||||
|
||||
Supported SIP topologies (selected via ``sip_topo_kind``):
|
||||
0 — ring_1d: global_E/global_W ring, n_sips-1 rounds
|
||||
1 — torus_2d: row ring (global_E/W) + col ring (global_S/N)
|
||||
2 — mesh_2d_no_wrap: row chain reduce+broadcast + col chain reduce+broadcast
|
||||
|
||||
IPCQ wiring is handled by ``configure_sfr_intercube_multisip``.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
# Integer codes the kernel receives as ``sip_topo_kind``.
SIP_TOPO_RING, SIP_TOPO_TORUS, SIP_TOPO_MESH = 0, 1, 2

# SIP topology name -> kernel code. ``mesh_2d`` shares the torus code; only
# ``mesh_2d_no_wrap`` selects the chain-based mesh exchange.
TOPO_NAME_TO_KIND = dict(
    ring_1d=SIP_TOPO_RING,
    torus_2d=SIP_TOPO_TORUS,
    mesh_2d=SIP_TOPO_TORUS,
    mesh_2d_no_wrap=SIP_TOPO_MESH,
)
|
||||
|
||||
|
||||
def kernel_args(world_size: int, n_elem: int) -> tuple:
    """Return ``(n_elem, cube_w, cube_h, world_size)`` for a fixed 4x4 cube mesh.

    Args:
        world_size: forwarded unchanged as the last element.
        n_elem: number of elements per tile, forwarded unchanged.
    """
    mesh_dims = (4, 4)  # (cube_w, cube_h)
    return (n_elem, *mesh_dims, world_size)
|
||||
|
||||
|
||||
def _inter_sip_ring(acc, n_sips, n_elem, tl):
    """All-reduce ``acc`` around a ring of ``n_sips`` SIPs.

    Each round sends a tile through ``global_E`` and receives one through
    ``global_W``. The tile forwarded next is the one just received, not the
    running sum, so after ``n_sips - 1`` rounds each original tile has been
    added exactly once.

    Args:
        acc: this SIP's partial sum.
        n_sips: number of SIPs in the ring.
        n_elem: number of f16 elements per tile.
        tl: kernel context providing ``send`` and ``recv``.

    Returns:
        ``acc`` plus every tile received.
    """
    tile_shape = (n_elem,)
    in_flight = acc
    for _hop in range(n_sips - 1):
        tl.send(dir="global_E", src=in_flight)
        in_flight = tl.recv(dir="global_W", shape=tile_shape, dtype="f16")
        acc = acc + in_flight
    return acc
|
||||
|
||||
|
||||
def _inter_sip_torus_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl):
    """All-reduce ``acc`` over a 2D torus of SIPs: row ring, then column ring.

    The row pass circulates tiles through ``global_E``/``global_W`` for
    ``sip_topo_w - 1`` rounds. The column pass then circulates the row sums
    through ``global_S``/``global_N`` for ``sip_topo_h - 1`` rounds.

    Args:
        acc: this SIP's partial sum.
        sip_rank: unused here; kept for a signature parallel to the mesh variant.
        sip_topo_w: number of SIPs per row.
        sip_topo_h: number of SIPs per column.
        n_elem: number of f16 elements per tile.
        tl: kernel context providing ``send`` and ``recv``.

    Returns:
        ``acc`` plus every tile received in both passes.
    """
    tile_shape = (n_elem,)
    ring_passes = (
        ("global_E", "global_W", sip_topo_w),
        ("global_S", "global_N", sip_topo_h),
    )
    for send_dir, recv_dir, ring_len in ring_passes:
        # Each pass starts by circulating the sum accumulated so far.
        in_flight = acc
        for _hop in range(ring_len - 1):
            tl.send(dir=send_dir, src=in_flight)
            in_flight = tl.recv(dir=recv_dir, shape=tile_shape, dtype="f16")
            acc = acc + in_flight
    return acc
|
||||
|
||||
|
||||
def _mesh_chain_allreduce(acc, pos, length, fwd_dir, back_dir, n_elem, tl):
    """Reduce along one open chain onto its last node, then broadcast back.

    A chain of length 1 (or less) has no peer to talk to, so it is a no-op.
    """
    if length <= 1:
        return acc

    shape = (n_elem,)
    last = length - 1

    # Reduce: partial sums flow from position 0 toward the last node.
    if pos == 0:
        tl.send(dir=fwd_dir, src=acc)
    elif pos < last:
        acc = acc + tl.recv(dir=back_dir, shape=shape, dtype="f16")
        tl.send(dir=fwd_dir, src=acc)
    else:
        acc = acc + tl.recv(dir=back_dir, shape=shape, dtype="f16")

    # Broadcast: the full sum flows back from the last node toward position 0.
    if pos == last:
        tl.send(dir=back_dir, src=acc)
    elif pos > 0:
        acc = tl.recv(dir=fwd_dir, shape=shape, dtype="f16")
        tl.send(dir=back_dir, src=acc)
    else:
        acc = tl.recv(dir=fwd_dir, shape=shape, dtype="f16")

    return acc


def _inter_sip_mesh_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl):
    """All-reduce ``acc`` over a 2D SIP mesh without wrap-around links.

    Rows go first: reduce W -> E onto the last column, then broadcast E -> W.
    Columns follow: reduce N -> S onto the last row, then broadcast S -> N.
    A dimension of size 1 is skipped entirely. Previously the single node in
    such a dimension tried to send to a peer that does not exist and then
    waited on a receive.

    Args:
        acc: this SIP's partial sum.
        sip_rank: row-major rank of this SIP in the mesh.
        sip_topo_w: number of SIPs per row.
        sip_topo_h: number of SIPs per column.
        n_elem: number of f16 elements per tile.
        tl: kernel context providing ``send`` and ``recv``.

    Returns:
        The sum over all SIPs of the mesh.
    """
    sip_row = sip_rank // sip_topo_w
    sip_col = sip_rank % sip_topo_w

    acc = _mesh_chain_allreduce(
        acc, sip_col, sip_topo_w, "global_E", "global_W", n_elem, tl,
    )
    acc = _mesh_chain_allreduce(
        acc, sip_row, sip_topo_h, "global_S", "global_N", n_elem, tl,
    )
    return acc
|
||||
|
||||
|
||||
def allreduce_intercube_multidevice(
    t_ptr, n_elem, cube_w, cube_h, n_sips, sip_rank,
    sip_topo_kind, sip_topo_w, sip_topo_h, tl,
):
    """Intercube all-reduce (pe0-only) with configurable SIP topology.

    Phases, per cube:
      1. Row reduce W -> E onto the rightmost column.
      2. Column reduce N -> S down the rightmost column onto the root cube
         (bottom-right corner).
      3. Inter-SIP exchange on the root cube, if there is more than one SIP.
      4. Column broadcast S -> N up the rightmost column.
      5. Row broadcast E -> W.

    A cube-mesh dimension of size 1 has no links in that direction, so its
    reduce and broadcast phases are skipped rather than sending to a
    neighbour that does not exist.

    Args:
        t_ptr: VA base of the row-wise-sharded tensor on this SIP.
        n_elem: f16 elements per cube tile.
        cube_w: cube mesh width (columns).
        cube_h: cube mesh height (rows).
        n_sips: number of SIPs.
        sip_rank: this SIP's rank (0-based).
        sip_topo_kind: 0=ring, 1=torus_2d, 2=mesh without wrap-around.
        sip_topo_w: SIP mesh width (for 2D topologies, 0 for ring).
        sip_topo_h: SIP mesh height (for 2D topologies, 0 for ring).
        tl: TLContext (auto-injected).

    Raises:
        ValueError: on the root cube with ``n_sips > 1`` when
            ``sip_topo_kind`` is not a known code. Skipping the exchange
            silently would leave a per-SIP partial sum in place.
    """
    cube_id = tl.program_id(axis=1)
    row = cube_id // cube_w
    col = cube_id % cube_w
    last_col = cube_w - 1
    last_row = cube_h - 1
    shape = (n_elem,)

    pe_addr = t_ptr + cube_id * (n_elem * 2)
    acc = tl.load(pe_addr, shape=shape, dtype="f16")

    has_row_links = cube_w > 1
    # Column traffic only ever runs on the rightmost column.
    on_active_col = col == last_col and cube_h > 1

    # ── Phase 1: row reduce W → E ──
    if has_row_links:
        if col == 0:
            tl.send(dir="E", src=acc)
        elif col < last_col:
            acc = acc + tl.recv(dir="W", shape=shape, dtype="f16")
            tl.send(dir="E", src=acc)
        else:
            acc = acc + tl.recv(dir="W", shape=shape, dtype="f16")

    # ── Phase 2: col reduce N → S on rightmost column ──
    if on_active_col:
        if row == 0:
            tl.send(dir="S", src=acc)
        elif row < last_row:
            acc = acc + tl.recv(dir="N", shape=shape, dtype="f16")
            tl.send(dir="S", src=acc)
        else:
            acc = acc + tl.recv(dir="N", shape=shape, dtype="f16")

    # ── Phase 3: inter-SIP exchange on root cube ──
    root_cube = last_row * cube_w + last_col
    if cube_id == root_cube and n_sips > 1:
        if sip_topo_kind == SIP_TOPO_RING:
            acc = _inter_sip_ring(acc, n_sips, n_elem, tl)
        elif sip_topo_kind == SIP_TOPO_TORUS:
            acc = _inter_sip_torus_2d(
                acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl,
            )
        elif sip_topo_kind == SIP_TOPO_MESH:
            acc = _inter_sip_mesh_2d(
                acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl,
            )
        else:
            raise ValueError(f"unsupported sip_topo_kind: {sip_topo_kind!r}")

    # ── Phase 4: col broadcast S → N on rightmost column ──
    if on_active_col:
        if row == last_row:
            tl.send(dir="N", src=acc)
        elif row > 0:
            acc = tl.recv(dir="S", shape=shape, dtype="f16")
            tl.send(dir="N", src=acc)
        else:
            acc = tl.recv(dir="S", shape=shape, dtype="f16")

    # ── Phase 5: row broadcast E → W ──
    if has_row_links:
        if col == last_col:
            tl.send(dir="W", src=acc)
        elif col > 0:
            acc = tl.recv(dir="E", shape=shape, dtype="f16")
            tl.send(dir="W", src=acc)
        else:
            acc = tl.recv(dir="E", shape=shape, dtype="f16")

    tl.store(pe_addr, acc)


# Entry point name looked up by the CCL backend.
kernel = allreduce_intercube_multidevice
|
||||
@@ -1,73 +0,0 @@
|
||||
"""2D-mesh all-reduce kernel (ADR-0023).
|
||||
|
||||
Two-phase reduce on a square mesh of side ``S`` (world_size = S*S):
|
||||
1. Row reduce: ring all-reduce along E/W within each row.
|
||||
2. Column reduce: ring all-reduce along N/S within each column.
|
||||
|
||||
After both phases, every rank holds the global sum.
|
||||
|
||||
Uses TensorHandle math (PE_MATH) for accumulation. Op_log captures the
|
||||
data flow so Phase 2 produces correct final HBM contents. Math/recv
|
||||
handles are passed directly to the next send, avoiding store→reload
|
||||
which doesn't propagate correctly with timing-only Phase 1 math.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
|
||||
|
||||
def kernel_args(world_size: int, n_elem: int) -> tuple:
    """Return ``(n_elem, side)`` for a square mesh of ``world_size`` ranks.

    Args:
        world_size: total number of ranks; must be a perfect square.
        n_elem: number of elements per tile, forwarded unchanged.

    Raises:
        ValueError: if ``world_size`` is not a perfect square.
    """
    side = int(round(math.sqrt(world_size)))
    if side * side == world_size:
        return (n_elem, side)
    raise ValueError(
        f"mesh_allreduce requires a square world_size; got {world_size}"
    )
|
||||
|
||||
|
||||
def kernel(t_ptr, n_elem, side, tl):
    """All-reduce on a square mesh: ring along each row, then along each column.

    Both passes forward the tile just received rather than the running sum,
    so each tile is added once per rank. The column pass starts from the
    row-pass accumulator directly, without storing and reloading it.

    Args:
        t_ptr: base address of the tensor; this PE's tile starts at
            ``t_ptr + rank * n_elem * 2``.
        n_elem: number of f16 elements per tile.
        side: mesh side length; each pass runs ``side - 1`` rounds.
        tl: kernel context (``program_id``, ``num_programs``, ``load``,
            ``send``, ``recv``, ``store``).
    """
    pe_in_cube = tl.program_id(axis=0)
    cube = tl.program_id(axis=1)
    cube_size = tl.num_programs(axis=0)

    tile_shape = (n_elem,)
    tile_addr = t_ptr + (cube * cube_size + pe_in_cube) * (n_elem * 2)
    total = tl.load(tile_addr, shape=tile_shape, dtype="f16")

    # (send direction, receive direction) for the row pass, then the column pass.
    for send_dir, recv_dir in (("E", "W"), ("S", "N")):
        in_flight = total
        for _round in range(side - 1):
            tl.send(dir=send_dir, src=in_flight)
            in_flight = tl.recv(dir=recv_dir, shape=tile_shape, dtype="f16")
            total = total + in_flight

    tl.store(tile_addr, total)
|
||||
@@ -1,80 +0,0 @@
|
||||
"""Ring all-reduce kernel for IPCQ-based PE collective (ADR-0023).
|
||||
|
||||
Algorithm: 1D ring of N PEs, each PE starts with one tile of data.
|
||||
After ``world_size - 1`` rounds, every PE's accumulator holds the sum
|
||||
of all PE tiles.
|
||||
|
||||
Strategy
|
||||
--------
|
||||
Each PE starts with its own tile in HBM. The kernel:
|
||||
1. Loads the local tile into a TensorHandle (the accumulator).
|
||||
2. In each of ``world_size - 1`` rounds:
|
||||
- Sends the current accumulator/recv slot to the E neighbor.
|
||||
- Receives a tile from the W neighbor — the recv handle points
|
||||
into the per-direction TCM slot.
|
||||
- Adds the received tile to the accumulator using the TensorHandle
|
||||
operator overload, which dispatches to ``MathCmd`` (PE_MATH).
|
||||
3. Stores the final accumulator back to HBM via tl.store. The store is
|
||||
recorded in op_log with both src and dst, so Phase 2 will copy the
|
||||
replayed math result from PE-local scratch into HBM.
|
||||
|
||||
ADR-0020 D3 split: Phase 1 simulates timing only — math results are
|
||||
not yet computed, so the accumulator data flowing through Phase 1 may
|
||||
be stale. Phase 2's DataExecutor replays math + IPCQ copies + dma_write
|
||||
in stable t_start order, producing correct final HBM contents.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def kernel_args(world_size: int, n_elem: int) -> tuple:
    """Build the positional arguments that follow the tensor pointer.

    The ring kernel needs the tile length and the ring size, in that order.

    Returns:
        ``(n_elem, world_size)``.
    """
    positional = (n_elem, world_size)
    return positional
|
||||
|
||||
|
||||
def kernel(t_ptr, n_elem, world_size, tl):
    """Ring all-reduce: after ``world_size - 1`` rounds every rank holds the sum.

    Args:
        t_ptr: base address of the tensor, shared by all PEs; this PE's tile
            starts at ``t_ptr + rank * n_elem * 2``.
        n_elem: number of f16 elements per tile.
        world_size: number of ranks in the ring.
        tl: kernel context. The global rank is derived as
            ``program_id(axis=1) * num_programs(axis=0) + program_id(axis=0)``.
    """
    pe_in_cube = tl.program_id(axis=0)
    cube = tl.program_id(axis=1)
    cube_size = tl.num_programs(axis=0)

    tile_shape = (n_elem,)
    tile_addr = t_ptr + (cube * cube_size + pe_in_cube) * (n_elem * 2)  # f16 = 2 bytes

    total = tl.load(tile_addr, shape=tile_shape, dtype="f16")

    # Forward the tile that just arrived, never the running sum. Each
    # original tile then visits every rank once and is added once.
    in_flight = total
    for _round in range(world_size - 1):
        tl.send(dir="E", src=in_flight)
        in_flight = tl.recv(dir="W", shape=tile_shape, dtype="f16")
        total = total + in_flight

    # Write the accumulated result back over this PE's own slice.
    tl.store(tile_addr, total)
|
||||
@@ -1,80 +0,0 @@
|
||||
"""Tree all-reduce kernel for IPCQ-based PE collective (ADR-0023).
|
||||
|
||||
Two-phase binary tree all-reduce:
|
||||
|
||||
Phase 1 (reduce up):
|
||||
- leaf nodes send their value to ``parent``
|
||||
- internal nodes recv from each child, sum, then send to ``parent``
|
||||
- root accumulates child contributions; final acc holds global sum
|
||||
|
||||
Phase 2 (broadcast down):
|
||||
- root sends acc to ``child_left`` and ``child_right`` (if present)
|
||||
- internal nodes recv from ``parent``, then forward to children
|
||||
- all ranks store the final acc to HBM
|
||||
|
||||
Uses TensorHandle math (PE_MATH) for accumulation. Op_log captures the
|
||||
data flow so Phase 2 produces correct final HBM contents. The kernel
|
||||
deliberately avoids the store→reload→send pattern: math/recv handles
|
||||
are passed directly to the next send so PE_DMA snapshots a deterministic
|
||||
source addr that Phase 2 can replay.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def kernel_args(world_size: int, n_elem: int) -> tuple:
    """Build the positional arguments that follow the tensor pointer.

    The tree kernel needs the tile length and the total rank count, in that
    order.

    Returns:
        ``(n_elem, world_size)``.
    """
    positional = (n_elem, world_size)
    return positional
|
||||
|
||||
|
||||
def kernel(t_ptr, n_elem, world_size, tl):
    """Binary-tree all-reduce: reduce up to rank 0, then broadcast back down.

    The tree is implicit: rank ``r`` has children ``2r + 1`` and ``2r + 2``
    when those are below ``world_size``, and every rank except 0 has a parent.

    Args:
        t_ptr: base address of the tensor; this PE's tile starts at
            ``t_ptr + rank * n_elem * 2``.
        n_elem: number of f16 elements per tile.
        world_size: total number of ranks in the tree.
        tl: kernel context (``program_id``, ``num_programs``, ``load``,
            ``send``, ``recv``, ``store``).
    """
    pe_in_cube = tl.program_id(axis=0)
    cube = tl.program_id(axis=1)
    cube_size = tl.num_programs(axis=0)
    rank = cube * cube_size + pe_in_cube

    tile_shape = (n_elem,)
    tile_addr = t_ptr + rank * (n_elem * 2)
    acc = tl.load(tile_addr, shape=tile_shape, dtype="f16")

    # Child links that exist for this rank, left before right.
    child_dirs = [
        label
        for label, child_rank in (
            ("child_left", 2 * rank + 1),
            ("child_right", 2 * rank + 2),
        )
        if child_rank < world_size
    ]

    # Reduce up: fold in each child's subtree sum.
    for label in child_dirs:
        acc = acc + tl.recv(dir=label, shape=tile_shape, dtype="f16")

    if rank > 0:
        # Hand the subtree sum to the parent. The global sum comes back from
        # it and replaces the partial one.
        tl.send(dir="parent", src=acc)
        acc = tl.recv(dir="parent", shape=tile_shape, dtype="f16")

    # Broadcast down: pass the global sum to each child.
    for label in child_dirs:
        tl.send(dir=label, src=acc)

    tl.store(tile_addr, acc)
|
||||
@@ -219,7 +219,11 @@ def install_ipcq(
|
||||
"neighbor_table": neighbor_table,
|
||||
}
|
||||
|
||||
# Direction label -> the label of the reverse direction. Covers the plain
# compass labels and their ``global_*`` counterparts.
_OPPOSITE_DIR = {
    "E": "W", "W": "E", "N": "S", "S": "N",
    "global_E": "global_W", "global_W": "global_E",
    "global_N": "global_S", "global_S": "global_N",
}
|
||||
|
||||
def reverse_direction(my_rank: int, peer_rank: int, my_dir: str) -> str | None:
|
||||
"""Find peer's direction that reciprocates my_dir→peer_rank.
|
||||
|
||||
@@ -0,0 +1,104 @@
|
||||
"""SFR configuration for intercube + inter-SIP IPCQ wiring.
|
||||
|
||||
Provides ``configure_sfr_intercube_multisip`` which programs PE_IPCQ
|
||||
neighbor tables for:
|
||||
|
||||
1. Intercube within each SIP — pe0 of every cube connects to pe0 of
|
||||
its N/S/E/W mesh neighbors (no wrap-around).
|
||||
2. Inter-SIP on ALL cubes — pe0 of cube_c on sip_A connects to pe0 of
|
||||
cube_c on each peer SIP, using ``global_E``/``global_W`` (ring) or
|
||||
``global_N``/``global_S``/``global_E``/``global_W`` (mesh/torus)
|
||||
direction labels. Wiring all cubes allows the kernel to
|
||||
dynamically elect the root cube at runtime.
|
||||
|
||||
SIP-level topology is read from ``topology.yaml`` →
|
||||
``system.sips.topology`` (e.g. ``ring_1d``, ``mesh_2d``).
|
||||
Intercube mesh dimensions come from ``sip.cube_mesh.w/h``.
|
||||
|
||||
Internally delegates to ``install_ipcq`` with a computed ``rank_to_pe``
|
||||
(pe0-only) and a closure-captured ``neighbors()`` function.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import types
|
||||
from typing import Any
|
||||
|
||||
from kernbench.ccl.install import install_ipcq
|
||||
from kernbench.ccl.topologies import _BUILTIN as _TOPO_BUILTINS
|
||||
|
||||
|
||||
def configure_sfr_intercube_multisip(
    engine: Any,
    spec: dict,
    cfg: dict,
) -> dict[str, Any]:
    """Wire IPCQ links for pe0 of every cube, within and across SIPs.

    One endpoint is created per ``(sip, cube)`` pair, always on PE 0. Inside
    a SIP each endpoint is linked to its N/S/E/W cube-mesh neighbours without
    wrap-around. When there is more than one SIP, each endpoint is also
    linked to the same cube on every peer SIP reported by the SIP topology
    function, under ``global_<direction>`` labels.

    Args:
        engine: passed through to ``install_ipcq``.
        spec: topology spec. Reads ``sip.cube_mesh.w/h`` (required) and
            ``system.sips.count`` / ``system.sips.topology`` (defaulting to
            ``1`` and ``"ring_1d"``).
        cfg: algorithm config. A shallow copy is passed on with
            ``world_size`` and ``topology`` overridden.

    Returns:
        Whatever ``install_ipcq`` returns.

    Raises:
        ValueError: if the SIP topology name is not a built-in topology.
    """
    cube_mesh = spec["sip"]["cube_mesh"]
    mesh_w = int(cube_mesh["w"])
    mesh_h = int(cube_mesh["h"])
    n_cubes = mesh_w * mesh_h

    sips_section = spec.get("system", {}).get("sips", {})
    n_sips = int(sips_section.get("count", 1))
    sip_topology = str(sips_section.get("topology", "ring_1d"))

    if sip_topology in _TOPO_BUILTINS:
        sip_topo_fn = _TOPO_BUILTINS[sip_topology]
    else:
        raise ValueError(
            f"Unknown sip topology '{sip_topology}'. "
            f"Available: {list(_TOPO_BUILTINS)}"
        )

    world_size = n_sips * n_cubes

    # Endpoint index -> (sip, cube, pe); SIP-major, always PE 0.
    pe_idx_to_pe: list[tuple[int, int, int]] = []
    for sip_idx in range(n_sips):
        for cube_idx in range(n_cubes):
            pe_idx_to_pe.append((sip_idx, cube_idx, 0))

    def _neighbors(pe_idx: int, ws: int, _base: dict) -> dict[str, int]:
        sip, cube = divmod(pe_idx, n_cubes)
        row, col = divmod(cube, mesh_w)
        sip_base = sip * n_cubes

        links: dict[str, int] = {}

        # Cube mesh inside the SIP: edges simply have no link outward.
        if col < mesh_w - 1:
            links["E"] = sip_base + cube + 1
        if col > 0:
            links["W"] = sip_base + cube - 1
        if row < mesh_h - 1:
            links["S"] = sip_base + cube + mesh_w
        if row > 0:
            links["N"] = sip_base + cube - mesh_w

        # Same cube index on each peer SIP, for every cube.
        if n_sips > 1:
            for direction, peer_sip in sip_topo_fn(sip, n_sips).items():
                links[f"global_{direction}"] = peer_sip * n_cubes + cube

        return links

    # install_ipcq is handed an object exposing ``neighbors`` in place of an
    # algorithm module.
    algo_stub = types.SimpleNamespace(neighbors=_neighbors)

    install_cfg = {**cfg, "world_size": world_size, "topology": "none"}

    return install_ipcq(
        engine, spec, install_cfg,
        algo_module=algo_stub,
        rank_to_pe=pe_idx_to_pe,
    )
|
||||
@@ -1,492 +0,0 @@
|
||||
"""Mock CCL runtime for fast unit tests of algorithm kernels (ADR-0023 D15).
|
||||
|
||||
Runs a kernel function once per rank with a minimal ``tl`` shim — no SimPy,
|
||||
no PE_DMA, no fabric simulation. Just enough to verify *functional*
|
||||
correctness of an IPCQ-based collective algorithm.
|
||||
|
||||
Cross-rank send/recv is implemented with greenlet cooperative scheduling
|
||||
plus per-(rank, direction) FIFO queues. Backpressure is not modeled —
|
||||
queues are unbounded.
|
||||
|
||||
Typical usage in a test::
|
||||
|
||||
from kernbench.ccl.testing import run_kernel_in_mock
|
||||
from kernbench.ccl.algorithms.ring_allreduce import kernel
|
||||
|
||||
inputs = [np.full(16, r + 1, dtype="f16") for r in range(4)]
|
||||
outputs = run_kernel_in_mock(
|
||||
kernel_fn=kernel, world_size=4, topology="ring_1d",
|
||||
inputs=inputs, kernel_args=(16,),
|
||||
)
|
||||
for r in range(4):
|
||||
assert np.allclose(outputs[r], sum(inputs))
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
from typing import Any, Callable
|
||||
|
||||
import numpy as np
|
||||
from greenlet import greenlet
|
||||
|
||||
from kernbench.ccl.topologies import resolve_topology
|
||||
from kernbench.common.ipcq_types import IpcqInvalidDirection
|
||||
from kernbench.common.pe_commands import TensorHandle
|
||||
|
||||
|
||||
# ── Per-rank fake state ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class _MockRankState:
    """Per-rank scratch state: fake HBM/TCM, inbound recv queues, output slot."""

    def __init__(
        self,
        rank: int,
        world_size: int,
        neighbors: dict[str, int],
        input_arr: np.ndarray,
        pes_per_cube: int = 0,
    ) -> None:
        """Build the state for one rank.

        Args:
            rank: this rank's index.
            world_size: total number of ranks.
            neighbors: direction name → peer rank.
            input_arr: this rank's input tile; a copy is placed in fake HBM.
            pes_per_cube: PEs per cube used for program_id/num_programs.
                A non-positive value means "all ranks in one cube".
        """
        self.rank = rank
        self.world_size = world_size
        self.neighbors = neighbors  # direction → peer rank

        # A non-positive pes_per_cube collapses to a single cube that
        # spans the whole world (legacy single-cube behaviour).
        self.pes_per_cube = world_size if pes_per_cube <= 0 else pes_per_cube

        # The kernel sees base pointer 0; this rank's slice is stored at
        # ``rank * nbytes`` to mirror a column-sharded address layout.
        tile_bytes = int(input_arr.nbytes)
        self.t_ptr = 0
        self._slice_addr = rank * tile_bytes

        # Fake memories, addr → ndarray.  They are per-rank objects, so
        # nothing is shared between ranks.
        self._hbm: dict[int, np.ndarray] = {self._slice_addr: input_arr.copy()}
        self._tcm: dict[int, np.ndarray] = {}

        # One inbound FIFO per direction this rank has a neighbor in.
        self.recv_q: dict[str, deque[np.ndarray]] = {
            direction: deque() for direction in neighbors
        }

        # Filled in when the kernel stores to this rank's slice address.
        self.output: np.ndarray | None = None

        # The greenlet running this rank's kernel; assigned by the scheduler.
        self.g: greenlet | None = None
|
||||
|
||||
|
||||
# ── Mock TLContext ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _MockTL:
    """Drop-in ``tl`` shim used by the mock runtime.

    Implements the subset of the TLContext API that algorithm kernels use:
    program_id, num_programs, load, store, send, recv, recv_async, wait,
    plus the arithmetic entry points invoked on TensorHandle operands.
    All math is executed eagerly with numpy.
    """

    # Direction a message arrives on at the peer, keyed by the direction it
    # was sent on.  Tree directions map to themselves.
    _REVERSE_DIR = {
        "N": "S", "S": "N", "E": "W", "W": "E",
        "parent": "parent", "child_left": "child_left",
        "child_right": "child_right",
    }

    # Element-wise implementations keyed by op name.
    _BINARY_OPS = {
        "add": lambda p, q: p + q,
        "sub": lambda p, q: p - q,
        "mul": lambda p, q: p * q,
        "div": lambda p, q: p / q,
        "maximum": lambda p, q: np.maximum(p, q),
        "minimum": lambda p, q: np.minimum(p, q),
    }
    _UNARY_OPS = {
        "exp": lambda v: np.exp(v),
        "log": lambda v: np.log(v),
        "sqrt": lambda v: np.sqrt(v),
        "abs": lambda v: np.abs(v),
        "sigmoid": lambda v: 1.0 / (1.0 + np.exp(-v)),
        "cos": lambda v: np.cos(v),
        "sin": lambda v: np.sin(v),
    }

    def __init__(self, state: _MockRankState, scheduler: "_MockScheduler") -> None:
        self._state = state
        self._scheduler = scheduler
        self._handle_counter = 0

    # ── small internal helpers ──

    def _next_id(self) -> str:
        """Return a fresh handle id of the form ``mt<N>``."""
        self._handle_counter += 1
        return f"mt{self._handle_counter}"

    @staticmethod
    def _array_or_none(handle: TensorHandle):
        """Return the handle's payload as an ndarray, or None if it has none."""
        payload = handle.data
        return None if payload is None else np.asarray(payload)

    def _tcm_handle(self, like: TensorHandle, result) -> TensorHandle:
        """Wrap ``result`` in a new TCM handle shaped and typed like ``like``."""
        shape = like.shape
        return TensorHandle(
            id=self._next_id(),
            addr=0, shape=shape, dtype=like.dtype,
            # Size is computed at 2 bytes per element, whatever the dtype.
            nbytes=int(np.prod(shape)) * 2 if shape else 0,
            data=result, space="tcm",
        )

    @property
    def rank(self) -> int:
        return self._state.rank

    @property
    def world_size(self) -> int:
        return self._state.world_size

    # ── launch-grid queries ──

    def program_id(self, axis: int = 0) -> int:
        """Return the cube index for ``axis == 1``, else the PE index in its cube."""
        cube, pe = divmod(self._state.rank, self._state.pes_per_cube)
        return cube if axis == 1 else pe

    def num_programs(self, axis: int = 0) -> int:
        """Return the cube count for ``axis == 1``, else the PEs per cube."""
        per_cube = self._state.pes_per_cube
        return self._state.world_size // per_cube if axis == 1 else per_cube

    # ── arithmetic ops (called by TensorHandle.__add__ etc.) ──

    def _binary_math(self, op: str, a: TensorHandle, b: TensorHandle) -> TensorHandle:
        """Apply binary ``op`` element-wise; the result handle mirrors ``a``.

        Raises:
            NotImplementedError: ``op`` is unknown and both operands carry data.
        """
        lhs = self._array_or_none(a)
        rhs = self._array_or_none(b)
        if lhs is None or rhs is None:
            # A missing payload short-circuits before the op name is checked.
            result = None
        else:
            fn = self._BINARY_OPS.get(op)
            if fn is None:
                raise NotImplementedError(f"mock _binary_math: op {op!r} not implemented")
            result = fn(lhs, rhs)
        return self._tcm_handle(a, result)

    def maximum(self, a: TensorHandle, b: TensorHandle) -> TensorHandle:
        return self._binary_math("maximum", a, b)

    def minimum(self, a: TensorHandle, b: TensorHandle) -> TensorHandle:
        return self._binary_math("minimum", a, b)

    def fma(
        self, a: TensorHandle, b: TensorHandle, c: TensorHandle,
    ) -> TensorHandle:
        """Return ``a * b + c``; the payload is None if any operand has none."""
        operands = [self._array_or_none(h) for h in (a, b, c)]
        if any(arr is None for arr in operands):
            result = None
        else:
            x, y, z = operands
            result = x * y + z
        return self._tcm_handle(a, result)

    def clamp(
        self,
        x: TensorHandle,
        min: TensorHandle,
        max: TensorHandle,
    ) -> TensorHandle:
        """Clamp ``x`` element-wise into ``[min, max]``."""
        value = self._array_or_none(x)
        lower = self._array_or_none(min)
        upper = self._array_or_none(max)
        if value is None or lower is None or upper is None:
            result = None
        else:
            result = np.minimum(np.maximum(value, lower), upper)
        return self._tcm_handle(x, result)

    def softmax(self, x: TensorHandle, axis: int = -1) -> TensorHandle:
        """Softmax along ``axis``, subtracting the per-axis max before exp."""
        value = self._array_or_none(x)
        if value is None:
            result = None
        else:
            shifted = np.exp(value - np.max(value, axis=axis, keepdims=True))
            result = shifted / np.sum(shifted, axis=axis, keepdims=True)
        return self._tcm_handle(x, result)

    @staticmethod
    def cdiv(a: int, b: int) -> int:
        """Ceiling division of ``a`` by ``b``."""
        return -(-int(a) // int(b))

    def _unary_math(self, op: str, x: TensorHandle) -> TensorHandle:
        """Apply unary ``op`` element-wise.

        Raises:
            NotImplementedError: ``op`` is unknown and ``x`` carries data.
        """
        value = self._array_or_none(x)
        if value is None:
            result = None
        else:
            fn = self._UNARY_OPS.get(op)
            if fn is None:
                raise NotImplementedError(f"mock _unary_math: op {op!r} not implemented")
            result = fn(value)
        return self._tcm_handle(x, result)

    # ── memory ──

    def load(self, ptr: int, shape: tuple[int, ...], dtype: str = "f16") -> TensorHandle:
        """Return a handle on whatever is stored at ``ptr`` in this rank's HBM.

        An address that was never written reads back as float16 zeros.
        """
        contents = self._state._hbm.get(ptr)
        if contents is None:
            contents = np.zeros(shape, dtype=np.float16)
        return TensorHandle(
            id=f"load_{ptr}", addr=ptr, shape=shape, dtype=dtype,
            nbytes=int(np.prod(shape)) * 2, data=contents, space="hbm",
        )

    def store(self, ptr: int, handle: TensorHandle) -> None:
        """Write the handle's payload to HBM at ``ptr``; handles without data are ignored.

        A store to this rank's own slice address also becomes its output.
        """
        if handle.data is None:
            return
        state = self._state
        state._hbm[ptr] = np.asarray(handle.data)
        if ptr == state._slice_addr:
            state.output = state._hbm[ptr]

    # ── IPCQ ──

    def send(
        self,
        dir: str,
        src: TensorHandle | None = None,
        *,
        src_addr: int | None = None,
        nbytes: int | None = None,
        shape: tuple[int, ...] | None = None,
        dtype: str = "f16",
        space: str = "tcm",
    ) -> None:
        """Deliver a copy of ``src``'s payload to the neighbor in direction ``dir``.

        Only ``dir`` and ``src`` are consulted; the keyword-only arguments
        are accepted but not used by the mock.

        Raises:
            IpcqInvalidDirection: this rank has no neighbor in ``dir``.
            RuntimeError: ``src`` is None, no data exists at ``src``'s
                address, or the peer has no direction to receive on.
        """
        me = self._state
        if dir not in me.neighbors:
            raise IpcqInvalidDirection(
                f"mock tl.send: direction {dir!r} not in neighbors {list(self._state.neighbors)}"
            )
        if src is None:
            raise RuntimeError("mock tl.send: src is None")

        if src.data is not None:
            payload = np.asarray(src.data)
        else:
            # No inline payload: look it up in this rank's local memory.
            memory = me._hbm if src.space == "hbm" else me._tcm
            stored = memory.get(src.addr)
            if stored is None:
                raise RuntimeError(
                    f"mock tl.send: no data at {src.space}:0x{src.addr:x}"
                )
            payload = np.asarray(stored)

        peer_rank = me.neighbors[dir]
        peer = self._scheduler.states[peer_rank]

        # Prefer the table's reverse direction; if the peer doesn't have it
        # (e.g. custom directions), fall back to the first peer direction
        # that points back at this rank.
        arrival_dir = self._REVERSE_DIR.get(dir)
        if arrival_dir is None or arrival_dir not in peer.neighbors:
            arrival_dir = next(
                (d for d, target in peer.neighbors.items() if target == me.rank),
                None,
            )
        if arrival_dir is None:
            raise RuntimeError(
                f"mock tl.send: peer rank {peer_rank} has no reverse direction"
            )

        peer.recv_q[arrival_dir].append(payload.copy())
        self._scheduler._send_counter += 1
        # Hand control back to the scheduler so the receiver can run.
        self._scheduler.yield_()

    def recv_async(
        self,
        dir: str,
        shape: tuple[int, ...] = (),
        dtype: str = "f16",
    ) -> dict:
        """Non-blocking recv. Returns a future dict to pass to tl.wait."""
        if dir not in self._state.neighbors:
            raise IpcqInvalidDirection(
                f"mock tl.recv_async: direction {dir!r} not in neighbors"
            )
        return {"_kind": "recv_future", "dir": dir, "shape": shape, "dtype": dtype}

    def wait(self, future: Any) -> TensorHandle:
        """Block until the recv future has data."""
        is_future = isinstance(future, dict) and future.get("_kind") == "recv_future"
        if not is_future:
            raise TypeError("tl.wait: expected recv future from tl.recv_async")
        direction = future["dir"]
        inbox = self._state.recv_q[direction]
        while not inbox:
            self._scheduler.yield_()
        return self._make_handle(inbox.popleft(), direction, future["dtype"])

    def recv(
        self,
        dir: str | None = None,
        shape: tuple[int, ...] = (),
        dtype: str = "f16",
    ) -> TensorHandle:
        """Block until a message is available and return it.

        With ``dir=None`` the directions are scanned in neighbor-map order
        and the first non-empty queue wins.

        Raises:
            IpcqInvalidDirection: ``dir`` is given but is not a neighbor.
        """
        state = self._state
        if dir is not None and dir not in state.neighbors:
            raise IpcqInvalidDirection(
                f"mock tl.recv: direction {dir!r} not in neighbors {list(self._state.neighbors)}"
            )
        while True:
            candidates = state.neighbors if dir is None else (dir,)
            for direction in candidates:
                inbox = state.recv_q[direction]
                if inbox:
                    return self._make_handle(inbox.popleft(), direction, dtype)
            # Nothing available yet: let other ranks run.
            self._scheduler.yield_()

    def _make_handle(self, data: np.ndarray, direction: str, dtype: str) -> TensorHandle:
        """Wrap a received array in a TCM handle tagged with its direction."""
        return TensorHandle(
            id=f"recv_{direction}",
            addr=0, shape=data.shape, dtype=dtype,
            nbytes=int(data.nbytes), data=data, space="tcm",
        )
|
||||
|
||||
|
||||
# ── Cooperative scheduler ────────────────────────────────────────────
|
||||
|
||||
|
||||
class _MockScheduler:
    """Round-robin cooperative scheduler over rank greenlets.

    Every rank's kernel runs in its own greenlet.  ``run`` switches into each
    live greenlet once per round; a kernel gives the turn back by calling
    ``yield_`` (the mock ``tl`` does so after each send and while polling
    for a recv).
    """

    def __init__(self, states: list[_MockRankState]) -> None:
        self.states = states
        self._parent: greenlet | None = None
        self._cur_idx = 0
        # Bumped by the mock tl on every delivered send.  Initialised here so
        # the attribute always exists; ``run`` resets it before driving ranks.
        self._send_counter = 0

    def yield_(self) -> None:
        """Called from inside a rank greenlet to give other ranks a turn."""
        assert self._parent is not None
        self._parent.switch()

    def run(self, kernel_fn: Callable, kernel_args: tuple) -> list[np.ndarray]:
        """Run ``kernel_fn`` once per rank until every rank has returned.

        Args:
            kernel_fn: called as ``kernel_fn(t_ptr, *kernel_args, tl=tl)``.
            kernel_args: extra positional arguments placed after ``t_ptr``.

        Returns:
            One entry per rank: the array stored at the rank's slice address,
            or whatever that address currently holds if nothing was stored.

        Raises:
            RuntimeError: no rank sent data or finished for 10_000
                consecutive rounds (treated as a deadlock).
        """
        from kernbench.triton_emu.tl_context import TLContext

        self._parent = greenlet.getcurrent()

        # Per-rank tl shim, kept so the scheduler can re-activate it on
        # every switch.
        tls: dict[int, _MockTL] = {}

        def _spawn(rank_idx: int) -> greenlet:
            state = self.states[rank_idx]
            tl = _MockTL(state, self)
            tls[rank_idx] = tl

            def _entry():
                # Activate this rank's tl for TensorHandle operator overloads
                TLContext._set_active(tl)  # type: ignore[attr-defined]
                try:
                    kernel_fn(state.t_ptr, *kernel_args, tl=tl)
                finally:
                    TLContext._set_active(None)  # type: ignore[attr-defined]

            return greenlet(_entry)

        for state in self.states:
            state.g = _spawn(state.rank)

        # Progress is measured per round as "a send was delivered" or "a rank
        # finished".  The send counter is used instead of queue depth because
        # a recv+send pair in one round nets to zero depth change yet is
        # still real progress.
        self._send_counter = 0
        max_idle_rounds = 10_000
        idle_rounds = 0
        while True:
            alive = [s for s in self.states if s.g is not None and not s.g.dead]
            if not alive:
                break
            counter_before = self._send_counter
            for s in alive:
                TLContext._set_active(tls[s.rank])  # type: ignore[attr-defined]
                s.g.switch()
                TLContext._set_active(None)  # type: ignore[attr-defined]

            # Only ranks that finished during THIS round count as progress.
            # Checking "has any rank ever finished" would stay true forever
            # after the first completion and disable deadlock detection for
            # the ranks that are still blocked.
            finished_this_round = any(s.g.dead for s in alive)
            if self._send_counter > counter_before or finished_this_round:
                idle_rounds = 0
                continue
            idle_rounds += 1
            if idle_rounds >= max_idle_rounds:
                raise RuntimeError(
                    "mock CCL runtime: deadlock detected (no progress for "
                    f"{max_idle_rounds} rounds)"
                )

        return [
            s.output if s.output is not None else s._hbm.get(s._slice_addr)
            for s in self.states
        ]
|
||||
|
||||
|
||||
# ── Public entry ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def run_kernel_in_mock(
    kernel_fn: Callable,
    world_size: int,
    topology: str,
    inputs: list[np.ndarray],
    kernel_args: tuple = (),
    algo_module: Any | None = None,
    pes_per_cube: int = 0,
) -> list[np.ndarray]:
    """Execute a CCL kernel on every rank under the mock runtime.

    No SimPy or fabric simulation is involved: each rank gets its own
    ``_MockRankState`` and the ranks are driven by ``_MockScheduler``.

    Args:
        kernel_fn: ``kernel(t_ptr, *kernel_args, tl=...)``
        world_size: number of ranks
        topology: topology name handed to ``resolve_topology``
            (e.g. "ring_1d")
        inputs: per-rank input ndarrays. ``inputs[r]`` is copied into
            rank r's fake HBM at address ``r * inputs[r].nbytes``.
        kernel_args: extra positional args after t_ptr
        algo_module: optional module forwarded to ``resolve_topology``
        pes_per_cube: PEs per cube for multi-cube program_id mapping.
            0 → single-cube legacy (all ranks in one cube).

    Returns:
        Per-rank output ndarrays, as produced by the scheduler's ``run``.

    Raises:
        ValueError: ``len(inputs)`` differs from ``world_size``.
    """
    if len(inputs) != world_size:
        raise ValueError(f"len(inputs)={len(inputs)} != world_size={world_size}")

    neighbors_of = resolve_topology(topology, algo_module=algo_module)

    rank_states = []
    for r in range(world_size):
        rank_states.append(
            _MockRankState(
                rank=r,
                world_size=world_size,
                neighbors=neighbors_of(r, world_size),
                input_arr=inputs[r],
                pes_per_cube=pes_per_cube,
            )
        )

    return _MockScheduler(rank_states).run(kernel_fn, kernel_args)
|
||||
@@ -73,6 +73,39 @@ def tree_binary(rank: int, world_size: int) -> NeighborMap:
|
||||
return n
|
||||
|
||||
|
||||
def torus_2d(rank: int, world_size: int) -> NeighborMap:
    """Return the neighbor map of ``rank`` in a 2D torus.

    This is a thin alias that delegates to ``mesh_2d`` unchanged; it exists
    so SIP-level topologies can name the wrap-around variant explicitly.
    """
    neighbor_map = mesh_2d(rank, world_size)
    return neighbor_map
|
||||
|
||||
|
||||
def mesh_2d_no_wrap(rank: int, world_size: int) -> NeighborMap:
    """Square 2D mesh (N/S/E/W) WITHOUT wrap-around.

    Ranks are laid out row-major on a ``side x side`` grid.  A rank on an
    edge simply has no entry for the direction that would leave the grid.

    Args:
        rank: index of the rank whose neighbors are wanted.
        world_size: total number of ranks; must be a positive perfect square.

    Returns:
        Mapping of direction ("N", "S", "W", "E") to neighbor rank, holding
        only the directions that exist for ``rank``.

    Raises:
        ValueError: ``world_size`` is not a positive perfect square.
    """
    # Local import: this block cannot rely on the module's import list.
    import math

    # Integer square root avoids float rounding.  Non-positive sizes are
    # rejected here too, instead of failing later in divmod / the root.
    side = math.isqrt(int(world_size)) if world_size > 0 else 0
    if side == 0 or side * side != world_size:
        raise ValueError(
            f"mesh_2d_no_wrap requires square world_size, got {world_size}"
        )

    r, c = divmod(rank, side)
    n: NeighborMap = {}
    if r > 0:
        n["N"] = (r - 1) * side + c
    if r < side - 1:
        n["S"] = (r + 1) * side + c
    if c > 0:
        n["W"] = r * side + (c - 1)
    if c < side - 1:
        n["E"] = r * side + (c + 1)
    return n
|
||||
|
||||
|
||||
def none(rank: int, world_size: int) -> NeighborMap:
    """Return a fresh empty neighbor map, ignoring both arguments.

    Intended for algorithms whose own neighbors() builds the map from scratch.
    """
    empty: NeighborMap = {}
    return empty
|
||||
@@ -82,6 +115,8 @@ _BUILTIN: dict[str, TopologyFn] = {
|
||||
"ring_1d": ring_1d,
|
||||
"ring_1d_unidir": ring_1d_unidir,
|
||||
"mesh_2d": mesh_2d,
|
||||
"torus_2d": torus_2d,
|
||||
"mesh_2d_no_wrap": mesh_2d_no_wrap,
|
||||
"tree_binary": tree_binary,
|
||||
"none": none,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user