10b33b44ba
Tensor.__setitem__ / __getitem__: - Shard-aligned slice assignment and read on deployed tensors. - Scalar broadcast and numpy array assignment supported. - Cross-shard slices raise NotImplementedError (use copy_ for that). - 3 new tests: single-PE, multi-PE, cross-shard error case. Hierarchical all-reduce kernel (src/kernbench/ccl/algorithms/): - 3-level reduce: intra-cube (E/W) → inter-cube (N/S) → inter-SIP (parent). - Bidirectional ring reduce at each level: ceil((N-1)/2) rounds. Left half sends via dir_dec, right half via dir_inc (wrap). Representative receives from both sides. - Chain broadcast for reverse path: cube 0 PE 0 → all PE 0s → all PEs. - Registered in ccl.yaml as "hierarchical_allreduce" with topology: none (neighbors() override builds the full 3-level neighbor map). - kernel_args derives pes_per_cube/cubes_per_sip/num_sips from world_size. - Mock-verified at 8/16/32/64/128 ranks. Mock runtime fixes: - Direction pairing: explicit N↔S, E↔W, parent↔parent instead of "first matching reverse". Fixes 2-element rings where N and S both point to the same peer. - Deadlock detection: send-counter based (not just queue-depth-total) to catch chain reductions where send+recv pairs net to zero. - Multi-cube program_id: pes_per_cube parameter enables program_id(axis=0) = PE within cube, program_id(axis=1) = cube id. Legacy single-cube tests unaffected (default = world_size). 504 tests pass in 12s. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
193 lines
6.9 KiB
Python
193 lines
6.9 KiB
Python
"""Hierarchical all-reduce kernel (ADR-0023).
|
|
|
|
3-level reduce + broadcast exploiting the topology hierarchy:
|
|
|
|
Level 1 — Intra-cube (8 PEs, E/W, fastest link):
|
|
Bidirectional ring reduce to PE 0.
|
|
Level 2 — Inter-cube within SIP (16 cubes, N/S, UCIe):
|
|
Bidirectional ring reduce of PE 0s to cube 0 PE 0.
|
|
Level 3 — Inter-SIP (2 SIPs, parent):
|
|
Pair exchange between SIP representatives.
|
|
Broadcast — Reverse chain through levels 2 and 1.
|
|
|
|
Bidirectional reduce: left-half sends toward node 0 via dir_dec,
|
|
right-half sends via dir_inc (wrapping). Representative receives from
|
|
both sides. Rounds per level = ceil((group_size - 1) / 2).
|
|
|
|
Direction pairing (ring):
|
|
Send via dir_dec at PE K → recv via dir_inc at PE K-1
|
|
Send via dir_inc at PE K → recv via dir_dec at PE K+1
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
|
|
def kernel_args(world_size: int, n_elem: int) -> tuple:
    """Return the positional kernel arguments for the ahbm backend.

    The hierarchy is derived from ``world_size`` alone:

    * a cube always holds 8 PEs;
    * there is one SIP per full 128 ranks, and never fewer than one;
    * the cubes are split evenly (floor division) across the SIPs.

    Args:
        world_size: total number of ranks taking part in the collective.
        n_elem: number of elements per tile, passed through unchanged.

    Returns:
        ``(n_elem, pes_per_cube, cubes_per_sip, num_sips)``.
    """
    cube_width = 8
    # Floor division yields 0 or 1 for anything up to 128 ranks, so the
    # clamp alone covers the "single SIP" case.
    sip_count = max(1, world_size // 128)
    cubes_in_each_sip = world_size // (cube_width * sip_count)
    return (n_elem, cube_width, cubes_in_each_sip, sip_count)
|
|
|
|
|
|
def neighbors(rank: int, world_size: int, neighbor_map: dict) -> dict:
    """Build the direction -> peer-rank map for the 3-level hierarchy.

    Rank layout (derived from ``world_size`` exactly as ``kernel_args`` does):
    8 consecutive ranks form a cube, ``cubes_per_sip`` consecutive cubes form
    a SIP, and there is one SIP per full 128 ranks (at least one).

    Directions produced:

    * ``"E"`` / ``"W"`` - next / previous PE in the same cube (ring, every PE).
    * ``"N"`` / ``"S"`` - PE 0 of the next / previous cube in the same SIP
      (ring, only for PE 0 and only when the SIP has more than one cube).
    * ``"parent"`` - PE 0 of cube 0 of the next SIP (ring over SIPs, only for
      PE 0 of local cube 0 and only when there is more than one SIP).

    Args:
        rank: global rank whose neighbours are requested.
        world_size: total number of ranks.
        neighbor_map: accepted as part of the call signature; this function
            never reads it and builds the whole map itself.

    Returns:
        A new dict mapping direction name to peer global rank.

    Raises:
        ValueError: if ``world_size`` is too small to hold one full cube.
            Previously this surfaced as a ``ZeroDivisionError`` from the
            SIP-index computation.
    """
    pes_per_cube = 8
    num_sips = max(1, world_size // 128) if world_size > 128 else 1
    cubes_per_sip = world_size // (pes_per_cube * num_sips)

    if cubes_per_sip < 1:
        # Without this guard the ``// cubes_per_sip`` below divides by zero.
        raise ValueError(
            f"hierarchical all-reduce needs at least {pes_per_cube} ranks "
            f"(one full cube); got world_size={world_size}"
        )
    # NOTE(review): a world_size that is not an exact multiple of
    # pes_per_cube * num_sips is truncated by the floor division above, so
    # some ranks end up with peers >= world_size. Confirm callers only pass
    # exact multiples.

    pe_id = rank % pes_per_cube
    cube_global = rank // pes_per_cube
    sip_id = cube_global // cubes_per_sip
    local_cube_id = cube_global % cubes_per_sip

    result = {}

    # Level 1: intra-cube ring (E/W, all PEs).
    cube_base = cube_global * pes_per_cube
    result["E"] = cube_base + (pe_id + 1) % pes_per_cube
    result["W"] = cube_base + (pe_id - 1) % pes_per_cube

    # Level 2: inter-cube ring (N/S, PE 0 only).
    if pe_id == 0 and cubes_per_sip > 1:
        sip_base = sip_id * cubes_per_sip * pes_per_cube
        next_cube = (local_cube_id + 1) % cubes_per_sip
        prev_cube = (local_cube_id - 1) % cubes_per_sip
        result["N"] = sip_base + next_cube * pes_per_cube
        result["S"] = sip_base + prev_cube * pes_per_cube

    # Level 3: inter-SIP link (parent, PE 0 of local cube 0 only).
    if pe_id == 0 and local_cube_id == 0 and num_sips > 1:
        next_sip = (sip_id + 1) % num_sips
        result["parent"] = next_sip * cubes_per_sip * pes_per_cube

    return result
|
|
|
|
|
|
def _bidir_reduce(tl, acc, my_id, group_size, dir_inc, dir_dec, shape, dtype):
    """Reduce a ring of ``group_size`` nodes onto node 0 from both sides.

    The ring is split at ``group_size // 2``:

    * Nodes ``1 .. half`` form the lower chain. Partial sums flow towards
      node 0 through ``dir_dec``; each node except the one at ``half`` first
      receives from its higher neighbour through ``dir_inc``.
    * Nodes ``half+1 .. group_size-1`` form the upper chain. Partial sums flow
      through ``dir_inc`` (wrapping around to node 0); each node except the
      one at ``half+1`` first receives from its lower neighbour through
      ``dir_dec``.
    * Node 0 receives the lower-chain sum through ``dir_inc`` and then, if the
      upper chain is non-empty, the upper-chain sum through ``dir_dec``.

    Direction pairing: a send through ``dir_dec`` at node K is received
    through ``dir_inc`` at node K-1, and a send through ``dir_inc`` at node K
    is received through ``dir_dec`` at node K+1.

    Args:
        tl: context exposing ``send(dir=, src=)`` and
            ``recv(dir=, shape=, dtype=)``.
        acc: this node's local contribution.
        my_id: this node's index inside the group.
        group_size: number of nodes in the ring.
        dir_inc: direction name that leads to ``my_id + 1``.
        dir_dec: direction name that leads to ``my_id - 1``.
        shape: shape forwarded to every ``recv``.
        dtype: dtype forwarded to every ``recv``.

    Returns:
        The accumulated value held by this node when it is done; only node 0
        holds the full group sum.
    """
    if group_size <= 1:
        return acc

    midpoint = group_size // 2

    if my_id == 0:
        # Representative: lower-chain total arrives first.
        acc = acc + tl.recv(dir=dir_inc, shape=shape, dtype=dtype)
        upper_chain_len = group_size - midpoint - 1
        if upper_chain_len >= 1:
            # Upper-chain total arrives over the wrap-around link.
            acc = acc + tl.recv(dir=dir_dec, shape=shape, dtype=dtype)
        return acc

    if my_id <= midpoint:
        # Lower chain: data moves towards smaller ids.
        upstream, downstream = dir_inc, dir_dec
        starts_chain = my_id >= midpoint
    else:
        # Upper chain: data moves towards larger ids and wraps to node 0.
        upstream, downstream = dir_dec, dir_inc
        starts_chain = my_id <= midpoint + 1

    if not starts_chain:
        acc = acc + tl.recv(dir=upstream, shape=shape, dtype=dtype)
    tl.send(dir=downstream, src=acc)
    return acc
|
|
|
|
|
|
def _chain_broadcast(tl, acc, my_id, group_size, dir_inc, shape, dtype):
    """Propagate node 0's value down a linear chain through ``dir_inc``.

    Node 0 sends its value through ``dir_inc``. Every other node receives
    from the opposite direction, replaces its value with what it received,
    and forwards it through ``dir_inc`` unless it is the last node.

    The receive direction is derived from ``dir_inc``: "E" pairs with "W" and
    "N" pairs with "S" (both ways). Any other direction name is used
    unchanged for the receive.

    Args:
        tl: context exposing ``send(dir=, src=)`` and
            ``recv(dir=, shape=, dtype=)``.
        acc: value to broadcast (node 0) or to be overwritten (other nodes).
        my_id: this node's index inside the group.
        group_size: number of nodes in the chain.
        dir_inc: direction name that leads to ``my_id + 1``.
        shape: shape forwarded to ``recv``.
        dtype: dtype forwarded to ``recv``.

    Returns:
        The broadcast value as held by this node.
    """
    if group_size <= 1:
        return acc

    opposite_of = {"E": "W", "W": "E", "N": "S", "S": "N"}
    incoming_dir = opposite_of[dir_inc] if dir_inc in opposite_of else dir_inc

    is_root = my_id == 0
    is_tail = my_id >= group_size - 1

    if not is_root:
        acc = tl.recv(dir=incoming_dir, shape=shape, dtype=dtype)
    if is_root or not is_tail:
        tl.send(dir=dir_inc, src=acc)
    return acc
|
|
|
|
|
|
def kernel(t_ptr, n_elem, pes_per_cube, cubes_per_sip, num_sips, tl):
    """Hierarchical all-reduce: reduce up three levels, then broadcast back.

    Steps executed by each program instance, in order:

    1. Load this rank's tile from ``t_ptr + rank * n_elem * 2``.
    2. Level 1 - bidirectional reduce over E/W among the ``pes_per_cube``
       PEs of the cube, onto PE 0.
    3. Level 2 - PE 0 only, when ``cubes_per_sip > 1``: bidirectional reduce
       over N/S among the cubes of the SIP, onto local cube 0.
    4. Level 3 - PE 0 of local cube 0 only, when ``num_sips > 1``: one send
       and one recv on ``"parent"``, with the received value added.
    5. Broadcast - chain over N among the PE 0s, then chain over E inside
       each cube.
    6. Store the result back to the address it was loaded from.

    NOTE(review): Level 3 is a single pairwise exchange, so it looks like it
    only yields the full sum when ``num_sips == 2``. Confirm before running
    with more SIPs.

    Args:
        t_ptr: HBM base address (column-sharded VA).
        n_elem: f16 elements per tile.
        pes_per_cube: PEs per cube (typically 8).
        cubes_per_sip: cubes per SIP (typically 16).
        num_sips: number of SIPs (typically 2).
        tl: TLContext (auto-injected).
    """
    pe_id = tl.program_id(axis=0)
    cube_global = tl.program_id(axis=1)
    local_cube_id = cube_global % cubes_per_sip

    # Tile description shared by the load and every recv below.
    shape = (n_elem,)
    dtype = "f16"
    bytes_per_elem = 2  # f16

    rank = cube_global * pes_per_cube + pe_id
    pe_addr = t_ptr + rank * (n_elem * bytes_per_elem)

    acc = tl.load(pe_addr, shape=shape, dtype=dtype)

    is_cube_rep = pe_id == 0
    in_cube_ring = is_cube_rep and cubes_per_sip > 1
    is_sip_rep = is_cube_rep and local_cube_id == 0 and num_sips > 1

    # -- Level 1: intra-cube bidirectional reduce to PE 0 --
    acc = _bidir_reduce(
        tl, acc, my_id=pe_id, group_size=pes_per_cube,
        dir_inc="E", dir_dec="W", shape=shape, dtype=dtype,
    )

    # -- Level 2: inter-cube bidirectional reduce to cube 0 (PE 0 only) --
    if in_cube_ring:
        acc = _bidir_reduce(
            tl, acc, my_id=local_cube_id, group_size=cubes_per_sip,
            dir_inc="N", dir_dec="S", shape=shape, dtype=dtype,
        )

    # -- Level 3: inter-SIP exchange (PE 0 of cube 0 only) --
    if is_sip_rep:
        # Send before recv: both representatives run this same sequence.
        tl.send(dir="parent", src=acc)
        acc = acc + tl.recv(dir="parent", shape=shape, dtype=dtype)

    # -- Broadcast back --

    # Level 2: cube 0 PE 0 -> all PE 0s via chain.
    if in_cube_ring:
        acc = _chain_broadcast(
            tl, acc, my_id=local_cube_id, group_size=cubes_per_sip,
            dir_inc="N", shape=shape, dtype=dtype,
        )

    # Level 1: PE 0 -> all PEs in the cube via chain.
    acc = _chain_broadcast(
        tl, acc, my_id=pe_id, group_size=pes_per_cube,
        dir_inc="E", shape=shape, dtype=dtype,
    )

    tl.store(pe_addr, acc)
|