Add Tensor indexing + hierarchical 3-level all-reduce kernel
Tensor.__setitem__ / __getitem__:
- Shard-aligned slice assignment and read on deployed tensors.
- Scalar broadcast and numpy array assignment supported.
- Cross-shard slices raise NotImplementedError (use copy_ for that).
- 3 new tests: single-PE, multi-PE, cross-shard error case.

Hierarchical all-reduce kernel (src/kernbench/ccl/algorithms/):
- 3-level reduce: intra-cube (E/W) → inter-cube (N/S) → inter-SIP (parent).
- Bidirectional ring reduce at each level: ceil((N-1)/2) rounds. Left half sends via dir_dec, right half via dir_inc (wrap). Representative receives from both sides.
- Chain broadcast for reverse path: cube 0 PE 0 → all PE 0s → all PEs.
- Registered in ccl.yaml as "hierarchical_allreduce" with topology: none (neighbors() override builds the full 3-level neighbor map).
- kernel_args derives pes_per_cube/cubes_per_sip/num_sips from world_size.
- Mock-verified at 8/16/32/64/128 ranks.

Mock runtime fixes:
- Direction pairing: explicit N↔S, E↔W, parent↔parent instead of "first matching reverse". Fixes 2-element rings where N and S both point to the same peer.
- Deadlock detection: send-counter based (not just queue-depth-total) to catch chain reductions where send+recv pairs net to zero.
- Multi-cube program_id: pes_per_cube parameter enables program_id(axis=0) = PE within cube, program_id(axis=1) = cube id. Legacy single-cube tests unaffected (default = world_size).

504 tests pass in 12s.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -159,6 +159,116 @@ class Tensor:
|
||||
if ctx is not None:
|
||||
ctx._free_tensor(self)
|
||||
|
||||
# ── Indexing (shard-aligned slices) ────────────────────────────
|
||||
|
||||
def _resolve_shard_index(self, key) -> tuple[int, int]:
    """Map a numpy-style index key to ``(flat_start_elem, flat_stop_elem)``.

    The key must supply one index per dimension. Every leading index must
    be an ``int`` (selecting a single row/plane); the last index may be an
    ``int`` or a unit-step ``slice``. Negative integers count from the end
    of their axis, as in numpy.

    The returned half-open range is in element units of the flat layout.
    NOTE(review): the flat offset is computed assuming a row-major
    (C-order) layout — confirm this matches how shards are placed.

    Raises:
        RuntimeError: the tensor has no handle (not deployed).
        IndexError: wrong number of indices, or an integer index that is
            out of bounds for its axis.
        NotImplementedError: a non-integer leading index, a slice step
            other than 1, or an unsupported index type on the last axis.
    """
    if self._handle is None:
        raise RuntimeError(f"Tensor '{self.name}' is not deployed")
    ndim = len(self.shape)
    if not isinstance(key, tuple):
        key = (key,)
    if len(key) != ndim:
        raise IndexError(f"expected {ndim} indices, got {len(key)}")

    def _normalize(idx: int, axis: int) -> int:
        # Resolve a possibly-negative integer index against its axis extent.
        extent = self.shape[axis]
        pos = idx + extent if idx < 0 else idx
        if not 0 <= pos < extent:
            raise IndexError(
                f"index {idx} is out of bounds for axis {axis} "
                f"with size {extent}"
            )
        return pos

    # Fold the leading integer indices into a flat base offset (Horner
    # form of sum(index * stride)), so that e.g. t[2, 0:4] starts at
    # 2 * shape[1] rather than at element 0.
    base = 0
    for axis, k in enumerate(key[:-1]):
        if not isinstance(k, int):
            raise NotImplementedError(
                "only integer indices are supported for leading dims"
            )
        base = (base + _normalize(k, axis)) * self.shape[axis + 1]

    last = key[-1]
    if isinstance(last, int):
        # Single element
        pos = _normalize(last, ndim - 1)
        return (base + pos, base + pos + 1)
    if isinstance(last, slice):
        start, stop, step = last.indices(self.shape[-1])
        if step != 1:
            raise NotImplementedError("step != 1 not supported")
        # A reversed slice such as 5:2 selects nothing; clamp it so that
        # callers never compute a negative element count.
        stop = max(stop, start)
        return (base + start, base + stop)
    raise NotImplementedError(f"unsupported index type: {type(last)}")
|
||||
|
||||
def _shard_for_range(self, start_elem: int, stop_elem: int) -> TensorShard:
    """Find the one shard whose byte span contains ``[start_elem, stop_elem)``.

    The element range is converted to bytes using ``self.itemsize`` and
    compared against each shard's ``offset_bytes`` / ``nbytes`` in the
    order the handle lists them; the first containing shard is returned.

    Raises:
        NotImplementedError: no single shard contains the whole range.
    """
    elem_bytes = self.itemsize
    lo_byte = start_elem * elem_bytes
    hi_byte = stop_elem * elem_bytes

    # Lazily scan the shards and stop at the first one that contains the
    # requested byte range in full.
    containing = (
        candidate
        for candidate in self._handle.shards
        if candidate.offset_bytes <= lo_byte
        and hi_byte <= candidate.offset_bytes + candidate.nbytes
    )
    found = next(containing, None)
    if found is not None:
        return found

    raise NotImplementedError(
        f"slice [{start_elem}:{stop_elem}] spans multiple shards "
        "(only shard-aligned slices are supported)"
    )
|
||||
|
||||
def __getitem__(self, key):
    """Read a shard-aligned slice and return it as a flat numpy array.

    Mirrors ``torch.Tensor.__getitem__`` for the shard-aligned case.

    The key is resolved to a flat element range and to the single shard
    containing it. Zeros of the requested length are returned when the
    tensor has no memory store, or when the read raises ``KeyError``.

    NOTE(review): ``np.asarray`` does not copy when the stored data
    already has the target dtype, so the result may be a view of the
    stored buffer — confirm callers do not mutate it in place.
    """
    first, last = self._resolve_shard_index(key)
    owner = self._shard_for_range(first, last)
    out_dtype = _numpy_dtype(self.dtype)

    if self._memory_store is None:
        return np.zeros(last - first, dtype=out_dtype)

    elem_bytes = self.itemsize
    # Element offset of the requested range relative to the shard start.
    offset_in_shard = (first * elem_bytes - owner.offset_bytes) // elem_bytes
    count = last - first

    try:
        raw = self._memory_store.read("hbm", self._shard_store_addr(owner))
    except KeyError:
        return np.zeros(count, dtype=out_dtype)

    contents = np.asarray(raw, dtype=out_dtype).reshape(-1)
    return contents[offset_in_shard : offset_in_shard + count]
|
||||
|
||||
def __setitem__(self, key, value):
    """Write ``value`` into a shard-aligned slice of a deployed tensor.

    Mirrors ``torch.Tensor.__setitem__``. Scalar broadcast and
    numpy array assignment are both supported.

    The shard containing the slice is read in full (or taken as zeros if
    the read raises ``KeyError``), patched in a private flat copy, and
    written back to the store under the same address.

    Args:
        key: index accepted by ``_resolve_shard_index``.
        value: an ``int``/``float`` scalar, or anything ``np.asarray`` can
            convert. Array values are flattened and must hold either one
            element (broadcast) or exactly as many as the slice selects.

    Raises:
        RuntimeError: the tensor has no handle or no memory store.
        ValueError: an array value whose size is neither 1 nor the slice
            length (previously an oversized value was silently truncated).
    """
    if self._handle is None or self._memory_store is None:
        raise RuntimeError(
            f"Tensor '{self.name}' must be deployed before assignment"
        )
    start, stop = self._resolve_shard_index(key)
    shard = self._shard_for_range(start, stop)
    np_dtype = _numpy_dtype(self.dtype)
    isize = self.itemsize
    # Element offset of the slice relative to the start of its shard.
    local_start = (start * isize - shard.offset_bytes) // isize
    local_count = stop - start
    addr = self._shard_store_addr(shard)

    # Normalize and validate the value before reading the store, so a bad
    # value fails without touching the store at all.
    if isinstance(value, (int, float)):
        payload = np_dtype.type(value)
    else:
        payload = np.asarray(value, dtype=np_dtype).reshape(-1)
        if payload.size not in (1, local_count):
            raise ValueError(
                f"cannot assign {payload.size} element(s) to a slice of "
                f"{local_count} element(s)"
            )

    # Read current shard data (or zeros if uninitialized)
    try:
        stored = self._memory_store.read("hbm", addr)
    except KeyError:
        arr = np.zeros(shard.nbytes // isize, dtype=np_dtype)
    else:
        # np.array copies by default, so the stored data is not mutated
        # until the explicit write below.
        arr = np.array(stored, dtype=np_dtype).reshape(-1)

    # Write the slice
    arr[local_start : local_start + local_count] = payload

    self._memory_store.write("hbm", addr, arr)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
parts = [f"tensor(name={self.name}, shape={self.shape}, dtype={self.dtype}"]
|
||||
if self._memory_store is not None and self._handle is not None:
|
||||
|
||||
Reference in New Issue
Block a user