Add Tensor indexing + hierarchical 3-level all-reduce kernel

Tensor.__setitem__ / __getitem__:
- Shard-aligned slice assignment and read on deployed tensors.
- Scalar broadcast and numpy array assignment supported.
- Cross-shard slices raise NotImplementedError (use copy_ for that).
- 3 new tests: single-PE, multi-PE, cross-shard error case.

Hierarchical all-reduce kernel (src/kernbench/ccl/algorithms/):
- 3-level reduce: intra-cube (E/W) → inter-cube (N/S) → inter-SIP (parent).
- Bidirectional ring reduce at each level: ceil((N-1)/2) rounds.
  Left half sends via dir_dec, right half via dir_inc (wrap).
  Representative receives from both sides.
- Chain broadcast for reverse path: cube 0 PE 0 → all PE 0s → all PEs.
- Registered in ccl.yaml as "hierarchical_allreduce" with topology: none
  (neighbors() override builds the full 3-level neighbor map).
- kernel_args derives pes_per_cube/cubes_per_sip/num_sips from world_size.
- Mock-verified at 8/16/32/64/128 ranks.

Mock runtime fixes:
- Direction pairing: explicit N↔S, E↔W, parent↔parent instead of
  "first matching reverse". Fixes 2-element rings where N and S both
  point to the same peer.
- Deadlock detection: send-counter based (not just queue-depth-total)
  to catch chain reductions where send+recv pairs net to zero.
- Multi-cube program_id: pes_per_cube parameter enables
  program_id(axis=0) = PE within cube, program_id(axis=1) = cube id.
  Legacy single-cube tests unaffected (default = world_size).

504 tests pass in 12s.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-04-12 23:52:04 -07:00
parent 1c8ddc2d03
commit 10b33b44ba
5 changed files with 432 additions and 25 deletions
+110
View File
@@ -159,6 +159,116 @@ class Tensor:
if ctx is not None:
ctx._free_tensor(self)
# ── Indexing (shard-aligned slices) ────────────────────────────
def _resolve_shard_index(self, key) -> tuple[int, int | None]:
"""Map a numpy-style index key to (flat_start_elem, flat_stop_elem).
Only shard-aligned slices on the last dimension are supported.
Returns (start, stop) in element units from the flat layout, or
raises IndexError / NotImplementedError for unsupported keys.
"""
if self._handle is None:
raise RuntimeError(f"Tensor '{self.name}' is not deployed")
ndim = len(self.shape)
if not isinstance(key, tuple):
key = (key,)
if len(key) != ndim:
raise IndexError(
f"expected {ndim} indices, got {len(key)}"
)
# All leading dims must be int (selecting a single row/plane).
for i, k in enumerate(key[:-1]):
if not isinstance(k, int):
raise NotImplementedError(
"only integer indices are supported for leading dims"
)
last = key[-1]
total_elems = math.prod(self.shape)
if isinstance(last, int):
# Single element
return (last, last + 1)
if isinstance(last, slice):
start, stop, step = last.indices(self.shape[-1])
if step != 1:
raise NotImplementedError("step != 1 not supported")
return (start, stop)
raise NotImplementedError(f"unsupported index type: {type(last)}")
def _shard_for_range(self, start_elem: int, stop_elem: int) -> TensorShard:
"""Return the single shard that fully covers [start_elem, stop_elem).
Raises NotImplementedError if the range spans multiple shards.
"""
isize = self.itemsize
start_byte = start_elem * isize
stop_byte = stop_elem * isize
for shard in self._handle.shards:
s_start = shard.offset_bytes
s_end = shard.offset_bytes + shard.nbytes
if start_byte >= s_start and stop_byte <= s_end:
return shard
raise NotImplementedError(
f"slice [{start_elem}:{stop_elem}] spans multiple shards "
f"(only shard-aligned slices are supported)"
)
def __getitem__(self, key):
"""Read a shard-aligned slice. Returns a numpy array.
Mirrors ``torch.Tensor.__getitem__`` for the shard-aligned case.
"""
start, stop = self._resolve_shard_index(key)
shard = self._shard_for_range(start, stop)
if self._memory_store is None:
return np.zeros(stop - start, dtype=_numpy_dtype(self.dtype))
isize = self.itemsize
local_start = (start * isize - shard.offset_bytes) // isize
local_count = stop - start
try:
arr = self._memory_store.read(
"hbm", self._shard_store_addr(shard),
)
flat = np.asarray(arr, dtype=_numpy_dtype(self.dtype)).reshape(-1)
return flat[local_start : local_start + local_count]
except KeyError:
return np.zeros(local_count, dtype=_numpy_dtype(self.dtype))
def __setitem__(self, key, value):
"""Write a shard-aligned slice.
Mirrors ``torch.Tensor.__setitem__``. Scalar broadcast and
numpy array assignment are both supported.
"""
if self._handle is None or self._memory_store is None:
raise RuntimeError(
f"Tensor '{self.name}' must be deployed before assignment"
)
start, stop = self._resolve_shard_index(key)
shard = self._shard_for_range(start, stop)
np_dtype = _numpy_dtype(self.dtype)
isize = self.itemsize
local_start = (start * isize - shard.offset_bytes) // isize
local_count = stop - start
shard_elems = shard.nbytes // isize
addr = self._shard_store_addr(shard)
# Read current shard data (or zeros if uninitialized)
try:
arr = self._memory_store.read("hbm", addr)
arr = np.array(arr, dtype=np_dtype).reshape(-1).copy()
except KeyError:
arr = np.zeros(shard_elems, dtype=np_dtype)
# Write the slice
if isinstance(value, (int, float)):
arr[local_start : local_start + local_count] = np_dtype.type(value)
else:
v = np.asarray(value, dtype=np_dtype).reshape(-1)
arr[local_start : local_start + local_count] = v[:local_count]
self._memory_store.write("hbm", addr, arr)
def __repr__(self) -> str:
parts = [f"tensor(name={self.name}, shape={self.shape}, dtype={self.dtype}"]
if self._memory_store is not None and self._handle is not None: