Add Tensor indexing + hierarchical 3-level all-reduce kernel

Tensor.__setitem__ / __getitem__:
- Shard-aligned slice assignment and read on deployed tensors.
- Scalar broadcast and numpy array assignment supported.
- Cross-shard slices raise NotImplementedError (use copy_ for that).
- 3 new tests: single-PE, multi-PE, cross-shard error case.

Hierarchical all-reduce kernel (src/kernbench/ccl/algorithms/):
- 3-level reduce: intra-cube (E/W) → inter-cube (N/S) → inter-SIP (parent).
- Bidirectional ring reduce at each level: ceil((N-1)/2) rounds.
  Left half sends via dir_dec, right half via dir_inc (wrap).
  Representative receives from both sides.
- Chain broadcast for reverse path: cube 0 PE 0 → all PE 0s → all PEs.
- Registered in ccl.yaml as "hierarchical_allreduce" with topology: none
  (neighbors() override builds the full 3-level neighbor map).
- kernel_args derives pes_per_cube/cubes_per_sip/num_sips from world_size.
- Mock-verified at 8/16/32/64/128 ranks.

Mock runtime fixes:
- Direction pairing: explicit N↔S, E↔W, parent↔parent instead of
  "first matching reverse". Fixes 2-element rings where N and S both
  point to the same peer.
- Deadlock detection: send-counter based (not just queue-depth-total)
  to catch chain reductions where send+recv pairs net to zero.
- Multi-cube program_id: pes_per_cube parameter enables
  program_id(axis=0) = PE within cube, program_id(axis=1) = cube id.
  Legacy single-cube tests unaffected (default = world_size).

504 tests pass in 12s.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-04-12 23:52:04 -07:00
parent 1c8ddc2d03
commit 10b33b44ba
5 changed files with 432 additions and 25 deletions
+8
View File
@@ -78,3 +78,11 @@ algorithms:
buffer_kind: tcm buffer_kind: tcm
world_size: 7 world_size: 7
n_elem: 16 n_elem: 16
# ── hierarchical all-reduce (3-level: intra-cube → inter-cube → inter-SIP) ──
# Uses bidirectional ring reduce + chain broadcast. ~25 rounds vs 255 flat.
hierarchical_allreduce:
module: kernbench.ccl.algorithms.hierarchical_allreduce
topology: none
buffer_kind: tcm
n_elem: 16
@@ -0,0 +1,192 @@
"""Hierarchical all-reduce kernel (ADR-0023).
3-level reduce + broadcast exploiting the topology hierarchy:
Level 1 — Intra-cube (8 PEs, E/W, fastest link):
Bidirectional ring reduce to PE 0.
Level 2 — Inter-cube within SIP (16 cubes, N/S, UCIe):
Bidirectional ring reduce of PE 0s to cube 0 PE 0.
Level 3 — Inter-SIP (2 SIPs, parent):
Pair exchange between SIP representatives.
Broadcast — Reverse chain through levels 2 and 1.
Bidirectional reduce: left-half sends toward node 0 via dir_dec,
right-half sends via dir_inc (wrapping). Representative receives from
both sides. Rounds per level = ceil((group_size - 1) / 2).
Direction pairing (ring):
Send via dir_dec at PE K → recv via dir_inc at PE K-1
Send via dir_inc at PE K → recv via dir_dec at PE K+1
"""
from __future__ import annotations
def kernel_args(world_size: int, n_elem: int) -> tuple:
"""Positional kernel args for the ahbm backend."""
pes_per_cube = 8
num_sips = max(1, world_size // 128) if world_size > 128 else 1
cubes_per_sip = world_size // (pes_per_cube * num_sips)
return (n_elem, pes_per_cube, cubes_per_sip, num_sips)
def neighbors(rank: int, world_size: int, neighbor_map: dict) -> dict:
"""Build the 3-level neighbor map."""
pes_per_cube = 8
num_sips = max(1, world_size // 128) if world_size > 128 else 1
cubes_per_sip = world_size // (pes_per_cube * num_sips)
pe_id = rank % pes_per_cube
cube_global = rank // pes_per_cube
sip_id = cube_global // cubes_per_sip
local_cube_id = cube_global % cubes_per_sip
result = {}
# Level 1: intra-cube ring (E/W, all PEs)
cube_base = cube_global * pes_per_cube
result["E"] = cube_base + (pe_id + 1) % pes_per_cube
result["W"] = cube_base + (pe_id - 1) % pes_per_cube
# Level 2: inter-cube ring (N/S, PE 0 only)
if pe_id == 0 and cubes_per_sip > 1:
sip_base = sip_id * cubes_per_sip * pes_per_cube
next_cube_pe0 = sip_base + ((local_cube_id + 1) % cubes_per_sip) * pes_per_cube
prev_cube_pe0 = sip_base + ((local_cube_id - 1) % cubes_per_sip) * pes_per_cube
result["N"] = next_cube_pe0
result["S"] = prev_cube_pe0
# Level 3: inter-SIP (parent, PE 0 cube 0 only)
if pe_id == 0 and local_cube_id == 0 and num_sips > 1:
other_sip_pe0 = ((sip_id + 1) % num_sips) * cubes_per_sip * pes_per_cube
result["parent"] = other_sip_pe0
return result
def _bidir_reduce(tl, acc, my_id, group_size, dir_inc, dir_dec, shape, dtype):
"""Bidirectional ring reduce to node 0.
Left half (1..half): chain reduces via dir_dec (toward lower IDs).
Each PE recvs from higher PE (via dir_inc) and sends to lower (via dir_dec).
Right half (half+1..N-1): chain reduces via dir_inc (wraps to node 0).
Each PE recvs from lower PE (via dir_dec) and sends to higher (via dir_inc).
Node 0: recvs left sum via dir_inc, right sum via dir_dec.
Direction pairing: send dir_dec at K → recv dir_inc at K-1.
send dir_inc at K → recv dir_dec at K+1.
"""
if group_size <= 1:
return acc
half = group_size // 2
if my_id == 0:
# Representative: recv left-half sum via dir_inc (from PE 1)
recv = tl.recv(dir=dir_inc, shape=shape, dtype=dtype)
acc = acc + recv
# Recv right-half sum via dir_dec (from PE N-1, wrapped)
if group_size - half - 1 >= 1:
recv = tl.recv(dir=dir_dec, shape=shape, dtype=dtype)
acc = acc + recv
elif my_id <= half:
# Left half: recv from PE my_id+1 via dir_inc, send to PE my_id-1 via dir_dec
if my_id < half: # not the far-edge
recv = tl.recv(dir=dir_inc, shape=shape, dtype=dtype)
acc = acc + recv
tl.send(dir=dir_dec, src=acc)
else:
# Right half: recv from PE my_id-1 via dir_dec, send to PE my_id+1 via dir_inc
if my_id > half + 1: # not the near-edge
recv = tl.recv(dir=dir_dec, shape=shape, dtype=dtype)
acc = acc + recv
tl.send(dir=dir_inc, src=acc)
return acc
def _chain_broadcast(tl, acc, my_id, group_size, dir_inc, shape, dtype):
"""Linear chain broadcast from node 0 via dir_inc.
Node 0 sends via dir_inc → node 1. Node 1 recvs via dir_dec (implicit
from the ring pairing), stores, sends via dir_inc → node 2. Etc.
Recv direction = the opposite: send dir_inc at K → recv dir_dec at K+1.
"""
if group_size <= 1:
return acc
# In ring pairing: send via dir_inc at K → recv via dir_dec at K+1.
# dir_dec is the "other" direction. We infer it from the ring:
# if dir_inc is "E", peer recvs via "W"; if "N", peer recvs via "S".
_recv_dir = {"E": "W", "W": "E", "N": "S", "S": "N"}.get(dir_inc, dir_inc)
if my_id == 0:
tl.send(dir=dir_inc, src=acc)
else:
acc = tl.recv(dir=_recv_dir, shape=shape, dtype=dtype)
if my_id < group_size - 1:
tl.send(dir=dir_inc, src=acc)
return acc
def kernel(t_ptr, n_elem, pes_per_cube, cubes_per_sip, num_sips, tl):
"""Hierarchical all-reduce.
Args:
t_ptr: HBM base address (column-sharded VA).
n_elem: f16 elements per tile.
pes_per_cube: PEs per cube (typically 8).
cubes_per_sip: cubes per SIP (typically 16).
num_sips: number of SIPs (typically 2).
tl: TLContext (auto-injected).
"""
pe_id = tl.program_id(axis=0)
cube_global = tl.program_id(axis=1)
sip_id = cube_global // cubes_per_sip
local_cube_id = cube_global % cubes_per_sip
rank = cube_global * pes_per_cube + pe_id
nbytes = n_elem * 2
pe_addr = t_ptr + rank * nbytes
acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16")
shape = (n_elem,)
dtype = "f16"
# ── Level 1: intra-cube bidirectional reduce to PE 0 ──
acc = _bidir_reduce(
tl, acc, my_id=pe_id, group_size=pes_per_cube,
dir_inc="E", dir_dec="W", shape=shape, dtype=dtype,
)
# ── Level 2: inter-cube bidirectional reduce to cube 0 (PE 0 only) ──
if pe_id == 0 and cubes_per_sip > 1:
acc = _bidir_reduce(
tl, acc, my_id=local_cube_id, group_size=cubes_per_sip,
dir_inc="N", dir_dec="S", shape=shape, dtype=dtype,
)
# ── Level 3: inter-SIP exchange (PE 0 cube 0 only) ──
if pe_id == 0 and local_cube_id == 0 and num_sips > 1:
tl.send(dir="parent", src=acc)
recv = tl.recv(dir="parent", shape=shape, dtype=dtype)
acc = acc + recv
# ── Broadcast back ──
# Level 2: cube 0 PE 0 → all PE 0s via chain
if pe_id == 0 and cubes_per_sip > 1:
acc = _chain_broadcast(
tl, acc, my_id=local_cube_id, group_size=cubes_per_sip,
dir_inc="N", shape=shape, dtype=dtype,
)
# Level 1: PE 0 → all PEs in cube via chain
acc = _chain_broadcast(
tl, acc, my_id=pe_id, group_size=pes_per_cube,
dir_inc="E", shape=shape, dtype=dtype,
)
tl.store(pe_addr, acc)
+52 -25
View File
@@ -46,9 +46,13 @@ class _MockRankState:
world_size: int, world_size: int,
neighbors: dict[str, int], neighbors: dict[str, int],
input_arr: np.ndarray, input_arr: np.ndarray,
pes_per_cube: int = 0,
) -> None: ) -> None:
self.rank = rank self.rank = rank
self.world_size = world_size self.world_size = world_size
# PEs per cube for program_id(axis=0/1). If 0 or world_size,
# all ranks are in one cube (legacy single-cube behavior).
self.pes_per_cube = pes_per_cube if pes_per_cube > 0 else world_size
self.neighbors = neighbors # direction → peer rank self.neighbors = neighbors # direction → peer rank
# HBM "memory": addr → ndarray. Per-rank, no cross-rank sharing. # HBM "memory": addr → ndarray. Per-rank, no cross-rank sharing.
self._hbm: dict[int, np.ndarray] = {} self._hbm: dict[int, np.ndarray] = {}
@@ -99,10 +103,19 @@ class _MockTL:
# axis-aware # axis-aware
def program_id(self, axis: int = 0) -> int: def program_id(self, axis: int = 0) -> int:
return self._state.rank if axis == 0 else 0 # Multi-cube: axis=0 = PE within cube, axis=1 = global cube id.
# Falls back to flat (all ranks in one cube) if pes_per_cube
# is not set (legacy single-cube tests).
ppc = self._state.pes_per_cube
if axis == 1:
return self._state.rank // ppc
return self._state.rank % ppc
def num_programs(self, axis: int = 0) -> int: def num_programs(self, axis: int = 0) -> int:
return self._state.world_size if axis == 0 else 1 ppc = self._state.pes_per_cube
if axis == 1:
return self._state.world_size // ppc
return ppc
# ── arithmetic ops (called by TensorHandle.__add__ etc.) ── # ── arithmetic ops (called by TensorHandle.__add__ etc.) ──
@@ -272,18 +285,27 @@ class _MockTL:
if data is None: if data is None:
raise RuntimeError("mock tl.send: src is None") raise RuntimeError("mock tl.send: src is None")
peer_rank = self._state.neighbors[dir] peer_rank = self._state.neighbors[dir]
# Find the reverse direction in peer's neighbors that points back to me # Find the reverse direction at the peer, mirroring real IPCQ
# install pairing: N↔S, E↔W, parent↔parent, child_left↔child_left, etc.
_REVERSE = {"N": "S", "S": "N", "E": "W", "W": "E",
"parent": "parent", "child_left": "child_left",
"child_right": "child_right"}
peer_state = self._scheduler.states[peer_rank] peer_state = self._scheduler.states[peer_rank]
reverse_dir = None reverse_dir = _REVERSE.get(dir)
for d, target in peer_state.neighbors.items(): # Fall back to "first direction pointing at me" if the explicit
if target == self._state.rank: # reverse doesn't exist at the peer (e.g. custom directions).
reverse_dir = d if reverse_dir is None or reverse_dir not in peer_state.neighbors:
break reverse_dir = None
for d, target in peer_state.neighbors.items():
if target == self._state.rank:
reverse_dir = d
break
if reverse_dir is None: if reverse_dir is None:
raise RuntimeError( raise RuntimeError(
f"mock tl.send: peer rank {peer_rank} has no reverse direction" f"mock tl.send: peer rank {peer_rank} has no reverse direction"
) )
peer_state.recv_q[reverse_dir].append(data.copy()) peer_state.recv_q[reverse_dir].append(data.copy())
self._scheduler._send_counter += 1
# After delivering, hand control back to scheduler so the receiver # After delivering, hand control back to scheduler so the receiver
# can wake up. # can wake up.
self._scheduler.yield_() self._scheduler.yield_()
@@ -388,33 +410,34 @@ class _MockScheduler:
state.g = _spawn(state.rank) state.g = _spawn(state.rank)
# Drive each rank round-robin until all dead. Detect global deadlock. # Drive each rank round-robin until all dead. Detect global deadlock.
max_rounds = 10_000 # A global send counter tracks whether any greenlet delivered data
round_no = 0 # in the current round. This is more reliable than queue-depth
# tracking because a recv+send pair in the same round nets to zero
# depth change yet still represents real progress.
self._send_counter = 0
max_idle_rounds = 10_000
idle_rounds = 0
while True: while True:
alive = [s for s in self.states if s.g is not None and not s.g.dead] alive = [s for s in self.states if s.g is not None and not s.g.dead]
if not alive: if not alive:
break break
progressed = False counter_before = self._send_counter
for s in self.states: for s in self.states:
if s.g is None or s.g.dead: if s.g is None or s.g.dead:
continue continue
# Multi-rank greenlets share TLContext active state via the
# module-level thread-local; restore this rank's tl before
# resuming so TensorHandle operator overloads dispatch to
# the right _MockTL.
TLContext._set_active(tls[s.rank]) # type: ignore[attr-defined] TLContext._set_active(tls[s.rank]) # type: ignore[attr-defined]
s.g.switch() s.g.switch()
if s.g.dead:
progressed = True
TLContext._set_active(None) # type: ignore[attr-defined] TLContext._set_active(None) # type: ignore[attr-defined]
# Loose progress check: if no greenlet died and queues didn't grow, any_died = any(s.g is not None and s.g.dead for s in self.states)
# advance round counter; abort after too many idle rounds. if self._send_counter > counter_before or any_died:
round_no += 1 idle_rounds = 0
if round_no > max_rounds and not progressed: else:
raise RuntimeError( idle_rounds += 1
"mock CCL runtime: deadlock detected (no progress for " if idle_rounds >= max_idle_rounds:
f"{max_rounds} rounds)" raise RuntimeError(
) "mock CCL runtime: deadlock detected (no progress for "
f"{max_idle_rounds} rounds)"
)
return [ return [
s.output if s.output is not None else s._hbm.get(s._slice_addr) s.output if s.output is not None else s._hbm.get(s._slice_addr)
@@ -432,6 +455,7 @@ def run_kernel_in_mock(
inputs: list[np.ndarray], inputs: list[np.ndarray],
kernel_args: tuple = (), kernel_args: tuple = (),
algo_module: Any | None = None, algo_module: Any | None = None,
pes_per_cube: int = 0,
) -> list[np.ndarray]: ) -> list[np.ndarray]:
"""Run a CCL kernel under the mock runtime with no SimPy/fabric. """Run a CCL kernel under the mock runtime with no SimPy/fabric.
@@ -443,6 +467,8 @@ def run_kernel_in_mock(
local tile at HBM address 0. local tile at HBM address 0.
kernel_args: extra positional args after t_ptr kernel_args: extra positional args after t_ptr
algo_module: optional module providing ``neighbors()`` override algo_module: optional module providing ``neighbors()`` override
pes_per_cube: PEs per cube for multi-cube program_id mapping.
0 → single-cube legacy (all ranks in one cube).
Returns: Returns:
Per-rank output ndarrays — whatever the kernel wrote via tl.store Per-rank output ndarrays — whatever the kernel wrote via tl.store
@@ -457,6 +483,7 @@ def run_kernel_in_mock(
rank=r, world_size=world_size, rank=r, world_size=world_size,
neighbors=topo_fn(r, world_size), neighbors=topo_fn(r, world_size),
input_arr=inputs[r], input_arr=inputs[r],
pes_per_cube=pes_per_cube,
) )
for r in range(world_size) for r in range(world_size)
] ]
+110
View File
@@ -159,6 +159,116 @@ class Tensor:
if ctx is not None: if ctx is not None:
ctx._free_tensor(self) ctx._free_tensor(self)
# ── Indexing (shard-aligned slices) ────────────────────────────
def _resolve_shard_index(self, key) -> tuple[int, int | None]:
"""Map a numpy-style index key to (flat_start_elem, flat_stop_elem).
Only shard-aligned slices on the last dimension are supported.
Returns (start, stop) in element units from the flat layout, or
raises IndexError / NotImplementedError for unsupported keys.
"""
if self._handle is None:
raise RuntimeError(f"Tensor '{self.name}' is not deployed")
ndim = len(self.shape)
if not isinstance(key, tuple):
key = (key,)
if len(key) != ndim:
raise IndexError(
f"expected {ndim} indices, got {len(key)}"
)
# All leading dims must be int (selecting a single row/plane).
for i, k in enumerate(key[:-1]):
if not isinstance(k, int):
raise NotImplementedError(
"only integer indices are supported for leading dims"
)
last = key[-1]
total_elems = math.prod(self.shape)
if isinstance(last, int):
# Single element
return (last, last + 1)
if isinstance(last, slice):
start, stop, step = last.indices(self.shape[-1])
if step != 1:
raise NotImplementedError("step != 1 not supported")
return (start, stop)
raise NotImplementedError(f"unsupported index type: {type(last)}")
def _shard_for_range(self, start_elem: int, stop_elem: int) -> TensorShard:
"""Return the single shard that fully covers [start_elem, stop_elem).
Raises NotImplementedError if the range spans multiple shards.
"""
isize = self.itemsize
start_byte = start_elem * isize
stop_byte = stop_elem * isize
for shard in self._handle.shards:
s_start = shard.offset_bytes
s_end = shard.offset_bytes + shard.nbytes
if start_byte >= s_start and stop_byte <= s_end:
return shard
raise NotImplementedError(
f"slice [{start_elem}:{stop_elem}] spans multiple shards "
f"(only shard-aligned slices are supported)"
)
def __getitem__(self, key):
"""Read a shard-aligned slice. Returns a numpy array.
Mirrors ``torch.Tensor.__getitem__`` for the shard-aligned case.
"""
start, stop = self._resolve_shard_index(key)
shard = self._shard_for_range(start, stop)
if self._memory_store is None:
return np.zeros(stop - start, dtype=_numpy_dtype(self.dtype))
isize = self.itemsize
local_start = (start * isize - shard.offset_bytes) // isize
local_count = stop - start
try:
arr = self._memory_store.read(
"hbm", self._shard_store_addr(shard),
)
flat = np.asarray(arr, dtype=_numpy_dtype(self.dtype)).reshape(-1)
return flat[local_start : local_start + local_count]
except KeyError:
return np.zeros(local_count, dtype=_numpy_dtype(self.dtype))
def __setitem__(self, key, value):
"""Write a shard-aligned slice.
Mirrors ``torch.Tensor.__setitem__``. Scalar broadcast and
numpy array assignment are both supported.
"""
if self._handle is None or self._memory_store is None:
raise RuntimeError(
f"Tensor '{self.name}' must be deployed before assignment"
)
start, stop = self._resolve_shard_index(key)
shard = self._shard_for_range(start, stop)
np_dtype = _numpy_dtype(self.dtype)
isize = self.itemsize
local_start = (start * isize - shard.offset_bytes) // isize
local_count = stop - start
shard_elems = shard.nbytes // isize
addr = self._shard_store_addr(shard)
# Read current shard data (or zeros if uninitialized)
try:
arr = self._memory_store.read("hbm", addr)
arr = np.array(arr, dtype=np_dtype).reshape(-1).copy()
except KeyError:
arr = np.zeros(shard_elems, dtype=np_dtype)
# Write the slice
if isinstance(value, (int, float)):
arr[local_start : local_start + local_count] = np_dtype.type(value)
else:
v = np.asarray(value, dtype=np_dtype).reshape(-1)
arr[local_start : local_start + local_count] = v[:local_count]
self._memory_store.write("hbm", addr, arr)
def __repr__(self) -> str: def __repr__(self) -> str:
parts = [f"tensor(name={self.name}, shape={self.shape}, dtype={self.dtype}"] parts = [f"tensor(name={self.name}, shape={self.shape}, dtype={self.dtype}"]
if self._memory_store is not None and self._handle is not None: if self._memory_store is not None and self._handle is not None:
+70
View File
@@ -134,3 +134,73 @@ def test_copy_shape_mismatch_raises():
t.copy_(torch.from_numpy(src)) t.copy_(torch.from_numpy(src))
_run_with(body) _run_with(body)
# ── __setitem__ / __getitem__ (shard-aligned) ───────────────────────
def test_setitem_getitem_single_pe():
"""Scalar and slice assignment on a single-PE tensor round-trips."""
def body(torch):
dp = DPPolicy(sip="replicate", cube="replicate", pe="replicate",
num_sips=1, num_cubes=1, num_pes=1)
t = torch.zeros((1, 8), dtype="f16", dp=dp, name="t")
# Scalar broadcast
t[0, 0:4] = 3.0
assert np.allclose(t[0, 0:4], 3.0)
assert np.allclose(t[0, 4:8], 0.0) # untouched
# Array assignment
t[0, 4:8] = np.array([10, 20, 30, 40], dtype=np.float16)
assert np.array_equal(t[0, 4:8], [10, 20, 30, 40])
# Full read-back via numpy()
expected = np.array([[3, 3, 3, 3, 10, 20, 30, 40]], dtype=np.float16)
assert np.array_equal(t.numpy(), expected)
_run_with(body)
def test_setitem_getitem_multi_pe_shard_aligned():
"""Shard-aligned slice assignment on an 8-PE column-wise tensor."""
def body(torch):
n_pe = 8
n_elem = 4 # per shard
dp = DPPolicy(sip="replicate", cube="replicate", pe="column_wise",
num_sips=1, num_cubes=1, num_pes=n_pe)
t = torch.zeros((1, n_pe * n_elem), dtype="f16", dp=dp, name="t")
# Write each shard with its rank value
for r in range(n_pe):
t[0, r * n_elem : (r + 1) * n_elem] = float(r + 1)
# Read back each shard
for r in range(n_pe):
expected = float(r + 1)
arr = t[0, r * n_elem : (r + 1) * n_elem]
assert np.allclose(arr, expected), f"shard {r}: {arr} != {expected}"
# Full gather
full = t.numpy().reshape(-1)
for r in range(n_pe):
assert np.allclose(full[r * n_elem : (r + 1) * n_elem], float(r + 1))
_run_with(body)
def test_setitem_cross_shard_raises():
"""Slice spanning two shards raises NotImplementedError."""
def body(torch):
n_pe = 4
n_elem = 4
dp = DPPolicy(sip="replicate", cube="replicate", pe="column_wise",
num_sips=1, num_cubes=1, num_pes=n_pe)
t = torch.zeros((1, n_pe * n_elem), dtype="f16", dp=dp, name="t")
with pytest.raises(NotImplementedError, match="spans multiple shards"):
t[0, 2:6] = 1.0 # crosses shard 0 (0:4) and shard 1 (4:8)
_run_with(body)