From 10b33b44ba5a0cbca7fa3b0236a1b41f6f9477ae Mon Sep 17 00:00:00 2001 From: Yangwook Kang Date: Sun, 12 Apr 2026 23:52:04 -0700 Subject: [PATCH] Add Tensor indexing + hierarchical 3-level all-reduce kernel MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Tensor.__setitem__ / __getitem__: - Shard-aligned slice assignment and read on deployed tensors. - Scalar broadcast and numpy array assignment supported. - Cross-shard slices raise NotImplementedError (use copy_ for that). - 3 new tests: single-PE, multi-PE, cross-shard error case. Hierarchical all-reduce kernel (src/kernbench/ccl/algorithms/): - 3-level reduce: intra-cube (E/W) → inter-cube (N/S) → inter-SIP (parent). - Bidirectional ring reduce at each level: ceil((N-1)/2) rounds. Left half sends via dir_dec, right half via dir_inc (wrap). Representative receives from both sides. - Chain broadcast for reverse path: cube 0 PE 0 → all PE 0s → all PEs. - Registered in ccl.yaml as "hierarchical_allreduce" with topology: none (neighbors() override builds the full 3-level neighbor map). - kernel_args derives pes_per_cube/cubes_per_sip/num_sips from world_size. - Mock-verified at 8/16/32/64/128 ranks. Mock runtime fixes: - Direction pairing: explicit N↔S, E↔W, parent↔parent instead of "first matching reverse". Fixes 2-element rings where N and S both point to the same peer. - Deadlock detection: send-counter based (not just queue-depth-total) to catch chain reductions where send+recv pairs net to zero. - Multi-cube program_id: pes_per_cube parameter enables program_id(axis=0) = PE within cube, program_id(axis=1) = cube id. Legacy single-cube tests unaffected (default = world_size). 504 tests pass in 12s. Co-Authored-By: Claude Opus 4.6 (1M context) --- ccl.yaml | 8 + .../ccl/algorithms/hierarchical_allreduce.py | 192 ++++++++++++++++++ src/kernbench/ccl/testing.py | 77 ++++--- src/kernbench/runtime_api/tensor.py | 110 ++++++++++ tests/test_runtime_api_tensor.py | 70 +++++++ 5 files changed, 432 insertions(+), 25 deletions(-) create mode 100644 src/kernbench/ccl/algorithms/hierarchical_allreduce.py diff --git a/ccl.yaml b/ccl.yaml index 4bac308..c1d43f3 100644 --- a/ccl.yaml +++ b/ccl.yaml @@ -78,3 +78,11 @@ algorithms: buffer_kind: tcm world_size: 7 n_elem: 16 + + # ── hierarchical all-reduce (3-level: intra-cube → inter-cube → inter-SIP) ── + # Uses bidirectional ring reduce + chain broadcast. ~25 rounds vs 255 flat. + hierarchical_allreduce: + module: kernbench.ccl.algorithms.hierarchical_allreduce + topology: none + buffer_kind: tcm + n_elem: 16 diff --git a/src/kernbench/ccl/algorithms/hierarchical_allreduce.py b/src/kernbench/ccl/algorithms/hierarchical_allreduce.py new file mode 100644 index 0000000..b8700dd --- /dev/null +++ b/src/kernbench/ccl/algorithms/hierarchical_allreduce.py @@ -0,0 +1,192 @@ +"""Hierarchical all-reduce kernel (ADR-0023). + +3-level reduce + broadcast exploiting the topology hierarchy: + + Level 1 — Intra-cube (8 PEs, E/W, fastest link): + Bidirectional ring reduce to PE 0. + Level 2 — Inter-cube within SIP (16 cubes, N/S, UCIe): + Bidirectional ring reduce of PE 0s to cube 0 PE 0. + Level 3 — Inter-SIP (2 SIPs, parent): + Pair exchange between SIP representatives. + Broadcast — Reverse chain through levels 2 and 1. + +Bidirectional reduce: left-half sends toward node 0 via dir_dec, +right-half sends via dir_inc (wrapping). Representative receives from +both sides. Rounds per level = ceil((group_size - 1) / 2). + +Direction pairing (ring): + Send via dir_dec at PE K → recv via dir_inc at PE K-1 + Send via dir_inc at PE K → recv via dir_dec at PE K+1 +""" +from __future__ import annotations + + +def kernel_args(world_size: int, n_elem: int) -> tuple: + """Positional kernel args for the ahbm backend.""" + pes_per_cube = 8 + num_sips = max(1, world_size // 128) if world_size > 128 else 1 + cubes_per_sip = world_size // (pes_per_cube * num_sips) + return (n_elem, pes_per_cube, cubes_per_sip, num_sips) + + +def neighbors(rank: int, world_size: int, neighbor_map: dict) -> dict: + """Build the 3-level neighbor map.""" + pes_per_cube = 8 + num_sips = max(1, world_size // 128) if world_size > 128 else 1 + cubes_per_sip = world_size // (pes_per_cube * num_sips) + + pe_id = rank % pes_per_cube + cube_global = rank // pes_per_cube + sip_id = cube_global // cubes_per_sip + local_cube_id = cube_global % cubes_per_sip + + result = {} + + # Level 1: intra-cube ring (E/W, all PEs) + cube_base = cube_global * pes_per_cube + result["E"] = cube_base + (pe_id + 1) % pes_per_cube + result["W"] = cube_base + (pe_id - 1) % pes_per_cube + + # Level 2: inter-cube ring (N/S, PE 0 only) + if pe_id == 0 and cubes_per_sip > 1: + sip_base = sip_id * cubes_per_sip * pes_per_cube + next_cube_pe0 = sip_base + ((local_cube_id + 1) % cubes_per_sip) * pes_per_cube + prev_cube_pe0 = sip_base + ((local_cube_id - 1) % cubes_per_sip) * pes_per_cube + result["N"] = next_cube_pe0 + result["S"] = prev_cube_pe0 + + # Level 3: inter-SIP (parent, PE 0 cube 0 only) + if pe_id == 0 and local_cube_id == 0 and num_sips > 1: + other_sip_pe0 = ((sip_id + 1) % num_sips) * cubes_per_sip * pes_per_cube + result["parent"] = other_sip_pe0 + + return result + + +def _bidir_reduce(tl, acc, my_id, group_size, dir_inc, dir_dec, shape, dtype): + """Bidirectional ring reduce to node 0. + + Left half (1..half): chain reduces via dir_dec (toward lower IDs). + Each PE recvs from higher PE (via dir_inc) and sends to lower (via dir_dec). + Right half (half+1..N-1): chain reduces via dir_inc (wraps to node 0). + Each PE recvs from lower PE (via dir_dec) and sends to higher (via dir_inc). + Node 0: recvs left sum via dir_inc, right sum via dir_dec. + + Direction pairing: send dir_dec at K → recv dir_inc at K-1. + send dir_inc at K → recv dir_dec at K+1. + """ + if group_size <= 1: + return acc + + half = group_size // 2 + + if my_id == 0: + # Representative: recv left-half sum via dir_inc (from PE 1) + recv = tl.recv(dir=dir_inc, shape=shape, dtype=dtype) + acc = acc + recv + # Recv right-half sum via dir_dec (from PE N-1, wrapped) + if group_size - half - 1 >= 1: + recv = tl.recv(dir=dir_dec, shape=shape, dtype=dtype) + acc = acc + recv + + elif my_id <= half: + # Left half: recv from PE my_id+1 via dir_inc, send to PE my_id-1 via dir_dec + if my_id < half: # not the far-edge + recv = tl.recv(dir=dir_inc, shape=shape, dtype=dtype) + acc = acc + recv + tl.send(dir=dir_dec, src=acc) + + else: + # Right half: recv from PE my_id-1 via dir_dec, send to PE my_id+1 via dir_inc + if my_id > half + 1: # not the near-edge + recv = tl.recv(dir=dir_dec, shape=shape, dtype=dtype) + acc = acc + recv + tl.send(dir=dir_inc, src=acc) + + return acc + + +def _chain_broadcast(tl, acc, my_id, group_size, dir_inc, shape, dtype): + """Linear chain broadcast from node 0 via dir_inc. + + Node 0 sends via dir_inc → node 1. Node 1 recvs via dir_dec (implicit + from the ring pairing), stores, sends via dir_inc → node 2. Etc. + + Recv direction = the opposite: send dir_inc at K → recv dir_dec at K+1. + """ + if group_size <= 1: + return acc + + # In ring pairing: send via dir_inc at K → recv via dir_dec at K+1. + # dir_dec is the "other" direction. We infer it from the ring: + # if dir_inc is "E", peer recvs via "W"; if "N", peer recvs via "S". + _recv_dir = {"E": "W", "W": "E", "N": "S", "S": "N"}.get(dir_inc, dir_inc) + + if my_id == 0: + tl.send(dir=dir_inc, src=acc) + else: + acc = tl.recv(dir=_recv_dir, shape=shape, dtype=dtype) + if my_id < group_size - 1: + tl.send(dir=dir_inc, src=acc) + return acc + + +def kernel(t_ptr, n_elem, pes_per_cube, cubes_per_sip, num_sips, tl): + """Hierarchical all-reduce. + + Args: + t_ptr: HBM base address (column-sharded VA). + n_elem: f16 elements per tile. + pes_per_cube: PEs per cube (typically 8). + cubes_per_sip: cubes per SIP (typically 16). + num_sips: number of SIPs (typically 2). + tl: TLContext (auto-injected). + """ + pe_id = tl.program_id(axis=0) + cube_global = tl.program_id(axis=1) + sip_id = cube_global // cubes_per_sip + local_cube_id = cube_global % cubes_per_sip + + rank = cube_global * pes_per_cube + pe_id + nbytes = n_elem * 2 + pe_addr = t_ptr + rank * nbytes + + acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16") + shape = (n_elem,) + dtype = "f16" + + # ── Level 1: intra-cube bidirectional reduce to PE 0 ── + acc = _bidir_reduce( + tl, acc, my_id=pe_id, group_size=pes_per_cube, + dir_inc="E", dir_dec="W", shape=shape, dtype=dtype, + ) + + # ── Level 2: inter-cube bidirectional reduce to cube 0 (PE 0 only) ── + if pe_id == 0 and cubes_per_sip > 1: + acc = _bidir_reduce( + tl, acc, my_id=local_cube_id, group_size=cubes_per_sip, + dir_inc="N", dir_dec="S", shape=shape, dtype=dtype, + ) + + # ── Level 3: inter-SIP exchange (PE 0 cube 0 only) ── + if pe_id == 0 and local_cube_id == 0 and num_sips > 1: + tl.send(dir="parent", src=acc) + recv = tl.recv(dir="parent", shape=shape, dtype=dtype) + acc = acc + recv + + # ── Broadcast back ── + + # Level 2: cube 0 PE 0 → all PE 0s via chain + if pe_id == 0 and cubes_per_sip > 1: + acc = _chain_broadcast( + tl, acc, my_id=local_cube_id, group_size=cubes_per_sip, + dir_inc="N", shape=shape, dtype=dtype, + ) + + # Level 1: PE 0 → all PEs in cube via chain + acc = _chain_broadcast( + tl, acc, my_id=pe_id, group_size=pes_per_cube, + dir_inc="E", shape=shape, dtype=dtype, + ) + + tl.store(pe_addr, acc) diff --git a/src/kernbench/ccl/testing.py b/src/kernbench/ccl/testing.py index 2d099ef..499f6d0 100644 --- a/src/kernbench/ccl/testing.py +++ b/src/kernbench/ccl/testing.py @@ -46,9 +46,13 @@ class _MockRankState: world_size: int, neighbors: dict[str, int], input_arr: np.ndarray, + pes_per_cube: int = 0, ) -> None: self.rank = rank self.world_size = world_size + # PEs per cube for program_id(axis=0/1). If 0 or world_size, + # all ranks are in one cube (legacy single-cube behavior). + self.pes_per_cube = pes_per_cube if pes_per_cube > 0 else world_size self.neighbors = neighbors # direction → peer rank # HBM "memory": addr → ndarray. Per-rank, no cross-rank sharing. self._hbm: dict[int, np.ndarray] = {} @@ -99,10 +103,19 @@ class _MockTL: # axis-aware def program_id(self, axis: int = 0) -> int: - return self._state.rank if axis == 0 else 0 + # Multi-cube: axis=0 = PE within cube, axis=1 = global cube id. + # Falls back to flat (all ranks in one cube) if pes_per_cube + # is not set (legacy single-cube tests). + ppc = self._state.pes_per_cube + if axis == 1: + return self._state.rank // ppc + return self._state.rank % ppc def num_programs(self, axis: int = 0) -> int: - return self._state.world_size if axis == 0 else 1 + ppc = self._state.pes_per_cube + if axis == 1: + return self._state.world_size // ppc + return ppc # ── arithmetic ops (called by TensorHandle.__add__ etc.) ── @@ -272,18 +285,27 @@ class _MockTL: if data is None: raise RuntimeError("mock tl.send: src is None") peer_rank = self._state.neighbors[dir] - # Find the reverse direction in peer's neighbors that points back to me + # Find the reverse direction at the peer, mirroring real IPCQ + # install pairing: N↔S, E↔W, parent↔parent, child_left↔child_left, etc. + _REVERSE = {"N": "S", "S": "N", "E": "W", "W": "E", + "parent": "parent", "child_left": "child_left", + "child_right": "child_right"} peer_state = self._scheduler.states[peer_rank] - reverse_dir = None - for d, target in peer_state.neighbors.items(): - if target == self._state.rank: - reverse_dir = d - break + reverse_dir = _REVERSE.get(dir) + # Fall back to "first direction pointing at me" if the explicit + # reverse doesn't exist at the peer (e.g. custom directions). + if reverse_dir is None or reverse_dir not in peer_state.neighbors: + reverse_dir = None + for d, target in peer_state.neighbors.items(): + if target == self._state.rank: + reverse_dir = d + break if reverse_dir is None: raise RuntimeError( f"mock tl.send: peer rank {peer_rank} has no reverse direction" ) peer_state.recv_q[reverse_dir].append(data.copy()) + self._scheduler._send_counter += 1 # After delivering, hand control back to scheduler so the receiver # can wake up. self._scheduler.yield_() @@ -388,33 +410,34 @@ class _MockScheduler: state.g = _spawn(state.rank) # Drive each rank round-robin until all dead. Detect global deadlock. - max_rounds = 10_000 - round_no = 0 + # A global send counter tracks whether any greenlet delivered data + # in the current round. This is more reliable than queue-depth + # tracking because a recv+send pair in the same round nets to zero + # depth change yet still represents real progress. + self._send_counter = 0 + max_idle_rounds = 10_000 + idle_rounds = 0 while True: alive = [s for s in self.states if s.g is not None and not s.g.dead] if not alive: break - progressed = False + counter_before = self._send_counter for s in self.states: if s.g is None or s.g.dead: continue - # Multi-rank greenlets share TLContext active state via the - # module-level thread-local; restore this rank's tl before - # resuming so TensorHandle operator overloads dispatch to - # the right _MockTL. TLContext._set_active(tls[s.rank]) # type: ignore[attr-defined] s.g.switch() - if s.g.dead: - progressed = True TLContext._set_active(None) # type: ignore[attr-defined] - # Loose progress check: if no greenlet died and queues didn't grow, - # advance round counter; abort after too many idle rounds. - round_no += 1 - if round_no > max_rounds and not progressed: - raise RuntimeError( - "mock CCL runtime: deadlock detected (no progress for " - f"{max_rounds} rounds)" - ) + any_died = any(s.g is not None and s.g.dead for s in self.states) + if self._send_counter > counter_before or any_died: + idle_rounds = 0 + else: + idle_rounds += 1 + if idle_rounds >= max_idle_rounds: + raise RuntimeError( + "mock CCL runtime: deadlock detected (no progress for " + f"{max_idle_rounds} rounds)" + ) return [ s.output if s.output is not None else s._hbm.get(s._slice_addr) @@ -432,6 +455,7 @@ def run_kernel_in_mock( inputs: list[np.ndarray], kernel_args: tuple = (), algo_module: Any | None = None, + pes_per_cube: int = 0, ) -> list[np.ndarray]: """Run a CCL kernel under the mock runtime with no SimPy/fabric. @@ -443,6 +467,8 @@ def run_kernel_in_mock( local tile at HBM address 0. kernel_args: extra positional args after t_ptr algo_module: optional module providing ``neighbors()`` override + pes_per_cube: PEs per cube for multi-cube program_id mapping. + 0 → single-cube legacy (all ranks in one cube). Returns: Per-rank output ndarrays — whatever the kernel wrote via tl.store @@ -457,6 +483,7 @@ def run_kernel_in_mock( rank=r, world_size=world_size, neighbors=topo_fn(r, world_size), input_arr=inputs[r], + pes_per_cube=pes_per_cube, ) for r in range(world_size) ] diff --git a/src/kernbench/runtime_api/tensor.py b/src/kernbench/runtime_api/tensor.py index 05f86d2..8226f3c 100644 --- a/src/kernbench/runtime_api/tensor.py +++ b/src/kernbench/runtime_api/tensor.py @@ -159,6 +159,116 @@ class Tensor: if ctx is not None: ctx._free_tensor(self) + # ── Indexing (shard-aligned slices) ──────────────────────────── + + def _resolve_shard_index(self, key) -> tuple[int, int | None]: + """Map a numpy-style index key to (flat_start_elem, flat_stop_elem). + + Only shard-aligned slices on the last dimension are supported. + Returns (start, stop) in element units from the flat layout, or + raises IndexError / NotImplementedError for unsupported keys. + """ + if self._handle is None: + raise RuntimeError(f"Tensor '{self.name}' is not deployed") + ndim = len(self.shape) + if not isinstance(key, tuple): + key = (key,) + if len(key) != ndim: + raise IndexError( + f"expected {ndim} indices, got {len(key)}" + ) + # All leading dims must be int (selecting a single row/plane). + for i, k in enumerate(key[:-1]): + if not isinstance(k, int): + raise NotImplementedError( + "only integer indices are supported for leading dims" + ) + last = key[-1] + total_elems = math.prod(self.shape) + if isinstance(last, int): + # Single element + return (last, last + 1) + if isinstance(last, slice): + start, stop, step = last.indices(self.shape[-1]) + if step != 1: + raise NotImplementedError("step != 1 not supported") + return (start, stop) + raise NotImplementedError(f"unsupported index type: {type(last)}") + + def _shard_for_range(self, start_elem: int, stop_elem: int) -> TensorShard: + """Return the single shard that fully covers [start_elem, stop_elem). + + Raises NotImplementedError if the range spans multiple shards. + """ + isize = self.itemsize + start_byte = start_elem * isize + stop_byte = stop_elem * isize + for shard in self._handle.shards: + s_start = shard.offset_bytes + s_end = shard.offset_bytes + shard.nbytes + if start_byte >= s_start and stop_byte <= s_end: + return shard + raise NotImplementedError( + f"slice [{start_elem}:{stop_elem}] spans multiple shards " + f"(only shard-aligned slices are supported)" + ) + + def __getitem__(self, key): + """Read a shard-aligned slice. Returns a numpy array. + + Mirrors ``torch.Tensor.__getitem__`` for the shard-aligned case. + """ + start, stop = self._resolve_shard_index(key) + shard = self._shard_for_range(start, stop) + if self._memory_store is None: + return np.zeros(stop - start, dtype=_numpy_dtype(self.dtype)) + isize = self.itemsize + local_start = (start * isize - shard.offset_bytes) // isize + local_count = stop - start + try: + arr = self._memory_store.read( + "hbm", self._shard_store_addr(shard), + ) + flat = np.asarray(arr, dtype=_numpy_dtype(self.dtype)).reshape(-1) + return flat[local_start : local_start + local_count] + except KeyError: + return np.zeros(local_count, dtype=_numpy_dtype(self.dtype)) + + def __setitem__(self, key, value): + """Write a shard-aligned slice. + + Mirrors ``torch.Tensor.__setitem__``. Scalar broadcast and + numpy array assignment are both supported. + """ + if self._handle is None or self._memory_store is None: + raise RuntimeError( + f"Tensor '{self.name}' must be deployed before assignment" + ) + start, stop = self._resolve_shard_index(key) + shard = self._shard_for_range(start, stop) + np_dtype = _numpy_dtype(self.dtype) + isize = self.itemsize + local_start = (start * isize - shard.offset_bytes) // isize + local_count = stop - start + shard_elems = shard.nbytes // isize + addr = self._shard_store_addr(shard) + + # Read current shard data (or zeros if uninitialized) + try: + arr = self._memory_store.read("hbm", addr) + arr = np.array(arr, dtype=np_dtype).reshape(-1).copy() + except KeyError: + arr = np.zeros(shard_elems, dtype=np_dtype) + + # Write the slice + if isinstance(value, (int, float)): + arr[local_start : local_start + local_count] = np_dtype.type(value) + else: + v = np.asarray(value, dtype=np_dtype).reshape(-1) + arr[local_start : local_start + local_count] = v[:local_count] + + self._memory_store.write("hbm", addr, arr) + def __repr__(self) -> str: parts = [f"tensor(name={self.name}, shape={self.shape}, dtype={self.dtype}"] if self._memory_store is not None and self._handle is not None: diff --git a/tests/test_runtime_api_tensor.py b/tests/test_runtime_api_tensor.py index b06eac0..54a4698 100644 --- a/tests/test_runtime_api_tensor.py +++ b/tests/test_runtime_api_tensor.py @@ -134,3 +134,73 @@ def test_copy_shape_mismatch_raises(): t.copy_(torch.from_numpy(src)) _run_with(body) + + +# ── __setitem__ / __getitem__ (shard-aligned) ─────────────────────── + + +def test_setitem_getitem_single_pe(): + """Scalar and slice assignment on a single-PE tensor round-trips.""" + + def body(torch): + dp = DPPolicy(sip="replicate", cube="replicate", pe="replicate", + num_sips=1, num_cubes=1, num_pes=1) + t = torch.zeros((1, 8), dtype="f16", dp=dp, name="t") + + # Scalar broadcast + t[0, 0:4] = 3.0 + assert np.allclose(t[0, 0:4], 3.0) + assert np.allclose(t[0, 4:8], 0.0) # untouched + + # Array assignment + t[0, 4:8] = np.array([10, 20, 30, 40], dtype=np.float16) + assert np.array_equal(t[0, 4:8], [10, 20, 30, 40]) + + # Full read-back via numpy() + expected = np.array([[3, 3, 3, 3, 10, 20, 30, 40]], dtype=np.float16) + assert np.array_equal(t.numpy(), expected) + + _run_with(body) + + +def test_setitem_getitem_multi_pe_shard_aligned(): + """Shard-aligned slice assignment on an 8-PE column-wise tensor.""" + + def body(torch): + n_pe = 8 + n_elem = 4 # per shard + dp = DPPolicy(sip="replicate", cube="replicate", pe="column_wise", + num_sips=1, num_cubes=1, num_pes=n_pe) + t = torch.zeros((1, n_pe * n_elem), dtype="f16", dp=dp, name="t") + + # Write each shard with its rank value + for r in range(n_pe): + t[0, r * n_elem : (r + 1) * n_elem] = float(r + 1) + + # Read back each shard + for r in range(n_pe): + expected = float(r + 1) + arr = t[0, r * n_elem : (r + 1) * n_elem] + assert np.allclose(arr, expected), f"shard {r}: {arr} != {expected}" + + # Full gather + full = t.numpy().reshape(-1) + for r in range(n_pe): + assert np.allclose(full[r * n_elem : (r + 1) * n_elem], float(r + 1)) + + _run_with(body) + + +def test_setitem_cross_shard_raises(): + """Slice spanning two shards raises NotImplementedError.""" + + def body(torch): + n_pe = 4 + n_elem = 4 + dp = DPPolicy(sip="replicate", cube="replicate", pe="column_wise", + num_sips=1, num_cubes=1, num_pes=n_pe) + t = torch.zeros((1, n_pe * n_elem), dtype="f16", dp=dp, name="t") + with pytest.raises(NotImplementedError, match="spans multiple shards"): + t[0, 2:6] = 1.0 # crosses shard 0 (0:4) and shard 1 (4:8) + + _run_with(body)