Add Tensor indexing + hierarchical 3-level all-reduce kernel
Tensor.__setitem__ / __getitem__: - Shard-aligned slice assignment and read on deployed tensors. - Scalar broadcast and numpy array assignment supported. - Cross-shard slices raise NotImplementedError (use copy_ for that). - 3 new tests: single-PE, multi-PE, cross-shard error case. Hierarchical all-reduce kernel (src/kernbench/ccl/algorithms/): - 3-level reduce: intra-cube (E/W) → inter-cube (N/S) → inter-SIP (parent). - Bidirectional ring reduce at each level: ceil((N-1)/2) rounds. Left half sends via dir_dec, right half via dir_inc (wrap). Representative receives from both sides. - Chain broadcast for reverse path: cube 0 PE 0 → all PE 0s → all PEs. - Registered in ccl.yaml as "hierarchical_allreduce" with topology: none (neighbors() override builds the full 3-level neighbor map). - kernel_args derives pes_per_cube/cubes_per_sip/num_sips from world_size. - Mock-verified at 8/16/32/64/128 ranks. Mock runtime fixes: - Direction pairing: explicit N↔S, E↔W, parent↔parent instead of "first matching reverse". Fixes 2-element rings where N and S both point to the same peer. - Deadlock detection: send-counter based (not just queue-depth-total) to catch chain reductions where send+recv pairs net to zero. - Multi-cube program_id: pes_per_cube parameter enables program_id(axis=0) = PE within cube, program_id(axis=1) = cube id. Legacy single-cube tests unaffected (default = world_size). 504 tests pass in 12s. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -78,3 +78,11 @@ algorithms:
|
||||
buffer_kind: tcm
|
||||
world_size: 7
|
||||
n_elem: 16
|
||||
|
||||
# ── hierarchical all-reduce (3-level: intra-cube → inter-cube → inter-SIP) ──
|
||||
# Uses bidirectional ring reduce + chain broadcast. ~25 rounds vs 255 flat.
|
||||
hierarchical_allreduce:
|
||||
module: kernbench.ccl.algorithms.hierarchical_allreduce
|
||||
topology: none
|
||||
buffer_kind: tcm
|
||||
n_elem: 16
|
||||
|
||||
@@ -0,0 +1,192 @@
|
||||
"""Hierarchical all-reduce kernel (ADR-0023).
|
||||
|
||||
3-level reduce + broadcast exploiting the topology hierarchy:
|
||||
|
||||
Level 1 — Intra-cube (8 PEs, E/W, fastest link):
|
||||
Bidirectional ring reduce to PE 0.
|
||||
Level 2 — Inter-cube within SIP (16 cubes, N/S, UCIe):
|
||||
Bidirectional ring reduce of PE 0s to cube 0 PE 0.
|
||||
Level 3 — Inter-SIP (2 SIPs, parent):
|
||||
Pair exchange between SIP representatives.
|
||||
Broadcast — Reverse chain through levels 2 and 1.
|
||||
|
||||
Bidirectional reduce: left-half sends toward node 0 via dir_dec,
|
||||
right-half sends via dir_inc (wrapping). Representative receives from
|
||||
both sides. Rounds per level = ceil((group_size - 1) / 2).
|
||||
|
||||
Direction pairing (ring):
|
||||
Send via dir_dec at PE K → recv via dir_inc at PE K-1
|
||||
Send via dir_inc at PE K → recv via dir_dec at PE K+1
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def kernel_args(world_size: int, n_elem: int) -> tuple:
    """Return the positional kernel arguments for the ahbm backend.

    The result is ``(n_elem, pes_per_cube, cubes_per_sip, num_sips)``.
    The cube width is fixed at 8 PEs.  More than one SIP is assumed only
    once ``world_size`` exceeds 128 ranks; the cube count per SIP is the
    integer quotient of ``world_size`` by ``pes_per_cube * num_sips``.
    """
    pes_per_cube = 8
    # world_size > 128 already implies world_size // 128 >= 1.
    num_sips = world_size // 128 if world_size > 128 else 1
    ranks_per_cube_slot = pes_per_cube * num_sips
    return (n_elem, pes_per_cube, world_size // ranks_per_cube_slot, num_sips)
|
||||
|
||||
|
||||
def neighbors(rank: int, world_size: int, neighbor_map: dict) -> dict:
    """Build the 3-level neighbor map for one rank.

    Ranks are laid out as ``rank = cube_global * 8 + pe_id`` and cubes are
    grouped into SIPs of ``cubes_per_sip`` consecutive cubes.

    Args:
        rank: Global rank whose neighbors are requested.
        world_size: Total number of ranks.  Must be a positive multiple of
            ``8 * num_sips`` so that every cube and SIP is fully populated.
        neighbor_map: Accepted for interface compatibility; this override
            does not read it.

    Returns:
        Mapping of direction name to peer rank:
          * ``"E"`` / ``"W"`` - next / previous PE on the intra-cube ring
            (every rank).
          * ``"N"`` / ``"S"`` - PE 0 of the next / previous cube on the
            intra-SIP ring (PE 0 of each cube, only when a SIP holds more
            than one cube).
          * ``"parent"`` - PE 0 of cube 0 in the next SIP (PE 0 of cube 0
            only, when there is more than one SIP).

    Raises:
        ValueError: If ``world_size`` cannot be tiled into whole cubes and
            SIPs, or ``rank`` lies outside ``[0, world_size)``.
    """
    pes_per_cube = 8
    num_sips = world_size // 128 if world_size > 128 else 1

    # Reject layouts that would otherwise divide by zero (world_size < 8)
    # or hand out E/W peers beyond the last rank (partially filled cube).
    if world_size <= 0 or world_size % (pes_per_cube * num_sips) != 0:
        raise ValueError(
            f"hierarchical_allreduce: world_size={world_size} must be a "
            f"positive multiple of {pes_per_cube * num_sips}"
        )
    if not 0 <= rank < world_size:
        raise ValueError(
            f"hierarchical_allreduce: rank={rank} is outside "
            f"[0, {world_size})"
        )

    cubes_per_sip = world_size // (pes_per_cube * num_sips)
    ranks_per_sip = cubes_per_sip * pes_per_cube

    cube_global, pe_id = divmod(rank, pes_per_cube)
    sip_id, local_cube_id = divmod(cube_global, cubes_per_sip)

    # Level 1: intra-cube ring (E/W, all PEs).
    cube_base = cube_global * pes_per_cube
    result = {
        "E": cube_base + (pe_id + 1) % pes_per_cube,
        "W": cube_base + (pe_id - 1) % pes_per_cube,
    }

    if pe_id != 0:
        return result

    # Level 2: inter-cube ring (N/S, PE 0 only).
    if cubes_per_sip > 1:
        sip_base = sip_id * ranks_per_sip
        result["N"] = sip_base + ((local_cube_id + 1) % cubes_per_sip) * pes_per_cube
        result["S"] = sip_base + ((local_cube_id - 1) % cubes_per_sip) * pes_per_cube

    # Level 3: inter-SIP link (parent, PE 0 of cube 0 only).
    if local_cube_id == 0 and num_sips > 1:
        result["parent"] = ((sip_id + 1) % num_sips) * ranks_per_sip

    return result
|
||||
|
||||
|
||||
def _bidir_reduce(tl, acc, my_id, group_size, dir_inc, dir_dec, shape, dtype):
    """Reduce a ring of ``group_size`` nodes onto node 0 from both sides.

    The ring is split into two chains that both end at node 0:

    * Left chain, nodes ``1 .. group_size // 2``: partial sums travel
      toward lower ids.  A node receives via ``dir_inc`` and forwards via
      ``dir_dec``.
    * Right chain, the remaining nodes: partial sums travel toward higher
      ids and wrap onto node 0.  A node receives via ``dir_dec`` and
      forwards via ``dir_inc``.

    The node farthest from node 0 on each chain has nothing to receive and
    only sends.  Node 0 adds the left total (arriving on ``dir_inc``) and,
    when the right chain is non-empty, the right total (arriving on
    ``dir_dec``).

    Returns the caller's running value: the full sum on node 0, the
    forwarded partial sum elsewhere.  Groups of size <= 1 are a no-op.
    """
    if group_size <= 1:
        return acc

    left_tail = group_size // 2                 # last node of the left chain
    right_len = group_size - left_tail - 1      # node count of the right chain

    def _absorb(direction, total):
        # Pull one tile from `direction` and fold it into the running sum.
        return total + tl.recv(dir=direction, shape=shape, dtype=dtype)

    if my_id == 0:
        acc = _absorb(dir_inc, acc)
        if right_len >= 1:
            acc = _absorb(dir_dec, acc)
        return acc

    if my_id <= left_tail:
        recv_dir, send_dir = dir_inc, dir_dec
        has_upstream = my_id < left_tail
    else:
        recv_dir, send_dir = dir_dec, dir_inc
        has_upstream = my_id > left_tail + 1

    if has_upstream:
        acc = _absorb(recv_dir, acc)
    tl.send(dir=send_dir, src=acc)
    return acc
|
||||
|
||||
|
||||
def _chain_broadcast(tl, acc, my_id, group_size, dir_inc, shape, dtype):
    """Broadcast node 0's value down a linear chain along ``dir_inc``.

    Node 0 sends its value via ``dir_inc``.  Every other node first
    replaces its value with what it receives from the opposite direction
    (``E``<->``W``, ``N``<->``S``; any other direction name is used
    unchanged for the receive), then forwards it via ``dir_inc`` unless it
    is the last node (``group_size - 1``).

    Returns the value held after the broadcast.  Groups of size <= 1 are a
    no-op.
    """
    if group_size <= 1:
        return acc

    # A tile sent via dir_inc at node K is received at node K+1 on the
    # paired opposite direction.
    opposite = {"E": "W", "W": "E", "N": "S", "S": "N"}
    recv_dir = opposite.get(dir_inc, dir_inc)

    is_root = my_id == 0
    if not is_root:
        acc = tl.recv(dir=recv_dir, shape=shape, dtype=dtype)
    if is_root or my_id < group_size - 1:
        tl.send(dir=dir_inc, src=acc)
    return acc
|
||||
|
||||
|
||||
def kernel(t_ptr, n_elem, pes_per_cube, cubes_per_sip, num_sips, tl):
    """Hierarchical all-reduce over PEs, cubes and SIPs.

    Each rank loads its own tile, the tiles are summed up the hierarchy
    (intra-cube, then inter-cube, then inter-SIP), and the total is pushed
    back down with chain broadcasts before being stored over the rank's
    own tile.

    Args:
        t_ptr: HBM base address (column-sharded VA).
        n_elem: f16 elements per tile.
        pes_per_cube: PEs per cube (typically 8).
        cubes_per_sip: cubes per SIP (typically 16).
        num_sips: number of SIPs (typically 2).
        tl: TLContext (auto-injected).
    """
    dtype = "f16"
    shape = (n_elem,)

    pe_id = tl.program_id(axis=0)          # PE index within its cube
    cube_global = tl.program_id(axis=1)    # cube index across all SIPs
    _, local_cube_id = divmod(cube_global, cubes_per_sip)

    # Tiles are laid out back to back by global rank, 2 bytes per element.
    tile_addr = t_ptr + (cube_global * pes_per_cube + pe_id) * (n_elem * 2)
    acc = tl.load(tile_addr, shape=shape, dtype=dtype)

    is_cube_rep = pe_id == 0
    on_cube_ring = is_cube_rep and cubes_per_sip > 1
    is_sip_rep = is_cube_rep and local_cube_id == 0 and num_sips > 1

    # -- Reduce: level 1, every PE folds into PE 0 of its cube (E/W ring).
    acc = _bidir_reduce(
        tl, acc,
        my_id=pe_id, group_size=pes_per_cube,
        dir_inc="E", dir_dec="W",
        shape=shape, dtype=dtype,
    )

    # -- Reduce: level 2, cube representatives fold into cube 0 (N/S ring).
    if on_cube_ring:
        acc = _bidir_reduce(
            tl, acc,
            my_id=local_cube_id, group_size=cubes_per_sip,
            dir_inc="N", dir_dec="S",
            shape=shape, dtype=dtype,
        )

    # -- Reduce: level 3, SIP representatives swap totals over "parent".
    if is_sip_rep:
        tl.send(dir="parent", src=acc)
        acc = acc + tl.recv(dir="parent", shape=shape, dtype=dtype)

    # -- Broadcast: level 2, cube 0 PE 0 -> every cube's PE 0.
    if on_cube_ring:
        acc = _chain_broadcast(
            tl, acc,
            my_id=local_cube_id, group_size=cubes_per_sip,
            dir_inc="N", shape=shape, dtype=dtype,
        )

    # -- Broadcast: level 1, PE 0 -> every PE in the cube.
    acc = _chain_broadcast(
        tl, acc,
        my_id=pe_id, group_size=pes_per_cube,
        dir_inc="E", shape=shape, dtype=dtype,
    )

    tl.store(tile_addr, acc)
|
||||
@@ -46,9 +46,13 @@ class _MockRankState:
|
||||
world_size: int,
|
||||
neighbors: dict[str, int],
|
||||
input_arr: np.ndarray,
|
||||
pes_per_cube: int = 0,
|
||||
) -> None:
|
||||
self.rank = rank
|
||||
self.world_size = world_size
|
||||
# PEs per cube for program_id(axis=0/1). If 0 or world_size,
|
||||
# all ranks are in one cube (legacy single-cube behavior).
|
||||
self.pes_per_cube = pes_per_cube if pes_per_cube > 0 else world_size
|
||||
self.neighbors = neighbors # direction → peer rank
|
||||
# HBM "memory": addr → ndarray. Per-rank, no cross-rank sharing.
|
||||
self._hbm: dict[int, np.ndarray] = {}
|
||||
@@ -99,10 +103,19 @@ class _MockTL:
|
||||
|
||||
# axis-aware
|
||||
def program_id(self, axis: int = 0) -> int:
|
||||
return self._state.rank if axis == 0 else 0
|
||||
# Multi-cube: axis=0 = PE within cube, axis=1 = global cube id.
|
||||
# Falls back to flat (all ranks in one cube) if pes_per_cube
|
||||
# is not set (legacy single-cube tests).
|
||||
ppc = self._state.pes_per_cube
|
||||
if axis == 1:
|
||||
return self._state.rank // ppc
|
||||
return self._state.rank % ppc
|
||||
|
||||
def num_programs(self, axis: int = 0) -> int:
|
||||
return self._state.world_size if axis == 0 else 1
|
||||
ppc = self._state.pes_per_cube
|
||||
if axis == 1:
|
||||
return self._state.world_size // ppc
|
||||
return ppc
|
||||
|
||||
# ── arithmetic ops (called by TensorHandle.__add__ etc.) ──
|
||||
|
||||
@@ -272,18 +285,27 @@ class _MockTL:
|
||||
if data is None:
|
||||
raise RuntimeError("mock tl.send: src is None")
|
||||
peer_rank = self._state.neighbors[dir]
|
||||
# Find the reverse direction in peer's neighbors that points back to me
|
||||
# Find the reverse direction at the peer, mirroring real IPCQ
|
||||
# install pairing: N↔S, E↔W, parent↔parent, child_left↔child_left, etc.
|
||||
_REVERSE = {"N": "S", "S": "N", "E": "W", "W": "E",
|
||||
"parent": "parent", "child_left": "child_left",
|
||||
"child_right": "child_right"}
|
||||
peer_state = self._scheduler.states[peer_rank]
|
||||
reverse_dir = None
|
||||
for d, target in peer_state.neighbors.items():
|
||||
if target == self._state.rank:
|
||||
reverse_dir = d
|
||||
break
|
||||
reverse_dir = _REVERSE.get(dir)
|
||||
# Fall back to "first direction pointing at me" if the explicit
|
||||
# reverse doesn't exist at the peer (e.g. custom directions).
|
||||
if reverse_dir is None or reverse_dir not in peer_state.neighbors:
|
||||
reverse_dir = None
|
||||
for d, target in peer_state.neighbors.items():
|
||||
if target == self._state.rank:
|
||||
reverse_dir = d
|
||||
break
|
||||
if reverse_dir is None:
|
||||
raise RuntimeError(
|
||||
f"mock tl.send: peer rank {peer_rank} has no reverse direction"
|
||||
)
|
||||
peer_state.recv_q[reverse_dir].append(data.copy())
|
||||
self._scheduler._send_counter += 1
|
||||
# After delivering, hand control back to scheduler so the receiver
|
||||
# can wake up.
|
||||
self._scheduler.yield_()
|
||||
@@ -388,33 +410,34 @@ class _MockScheduler:
|
||||
state.g = _spawn(state.rank)
|
||||
|
||||
# Drive each rank round-robin until all dead. Detect global deadlock.
|
||||
max_rounds = 10_000
|
||||
round_no = 0
|
||||
# A global send counter tracks whether any greenlet delivered data
|
||||
# in the current round. This is more reliable than queue-depth
|
||||
# tracking because a recv+send pair in the same round nets to zero
|
||||
# depth change yet still represents real progress.
|
||||
self._send_counter = 0
|
||||
max_idle_rounds = 10_000
|
||||
idle_rounds = 0
|
||||
while True:
|
||||
alive = [s for s in self.states if s.g is not None and not s.g.dead]
|
||||
if not alive:
|
||||
break
|
||||
progressed = False
|
||||
counter_before = self._send_counter
|
||||
for s in self.states:
|
||||
if s.g is None or s.g.dead:
|
||||
continue
|
||||
# Multi-rank greenlets share TLContext active state via the
|
||||
# module-level thread-local; restore this rank's tl before
|
||||
# resuming so TensorHandle operator overloads dispatch to
|
||||
# the right _MockTL.
|
||||
TLContext._set_active(tls[s.rank]) # type: ignore[attr-defined]
|
||||
s.g.switch()
|
||||
if s.g.dead:
|
||||
progressed = True
|
||||
TLContext._set_active(None) # type: ignore[attr-defined]
|
||||
# Loose progress check: if no greenlet died and queues didn't grow,
|
||||
# advance round counter; abort after too many idle rounds.
|
||||
round_no += 1
|
||||
if round_no > max_rounds and not progressed:
|
||||
raise RuntimeError(
|
||||
"mock CCL runtime: deadlock detected (no progress for "
|
||||
f"{max_rounds} rounds)"
|
||||
)
|
||||
any_died = any(s.g is not None and s.g.dead for s in self.states)
|
||||
if self._send_counter > counter_before or any_died:
|
||||
idle_rounds = 0
|
||||
else:
|
||||
idle_rounds += 1
|
||||
if idle_rounds >= max_idle_rounds:
|
||||
raise RuntimeError(
|
||||
"mock CCL runtime: deadlock detected (no progress for "
|
||||
f"{max_idle_rounds} rounds)"
|
||||
)
|
||||
|
||||
return [
|
||||
s.output if s.output is not None else s._hbm.get(s._slice_addr)
|
||||
@@ -432,6 +455,7 @@ def run_kernel_in_mock(
|
||||
inputs: list[np.ndarray],
|
||||
kernel_args: tuple = (),
|
||||
algo_module: Any | None = None,
|
||||
pes_per_cube: int = 0,
|
||||
) -> list[np.ndarray]:
|
||||
"""Run a CCL kernel under the mock runtime with no SimPy/fabric.
|
||||
|
||||
@@ -443,6 +467,8 @@ def run_kernel_in_mock(
|
||||
local tile at HBM address 0.
|
||||
kernel_args: extra positional args after t_ptr
|
||||
algo_module: optional module providing ``neighbors()`` override
|
||||
pes_per_cube: PEs per cube for multi-cube program_id mapping.
|
||||
0 → single-cube legacy (all ranks in one cube).
|
||||
|
||||
Returns:
|
||||
Per-rank output ndarrays — whatever the kernel wrote via tl.store
|
||||
@@ -457,6 +483,7 @@ def run_kernel_in_mock(
|
||||
rank=r, world_size=world_size,
|
||||
neighbors=topo_fn(r, world_size),
|
||||
input_arr=inputs[r],
|
||||
pes_per_cube=pes_per_cube,
|
||||
)
|
||||
for r in range(world_size)
|
||||
]
|
||||
|
||||
@@ -159,6 +159,116 @@ class Tensor:
|
||||
if ctx is not None:
|
||||
ctx._free_tensor(self)
|
||||
|
||||
# ── Indexing (shard-aligned slices) ────────────────────────────
|
||||
|
||||
def _resolve_shard_index(self, key) -> tuple[int, int | None]:
    """Map a numpy-style index key to ``(flat_start_elem, flat_stop_elem)``.

    Only a contiguous run on the last dimension is supported: every
    leading index must be an ``int`` and the last index is either an
    ``int`` or a unit-step ``slice``.  The returned half-open range is in
    element units of the row-major flat layout, so leading indices
    contribute their row offset (``t[1, 0:4]`` on a ``(2, 8)`` tensor is
    ``(8, 12)``, not ``(0, 4)``).

    NOTE(review): this assumes shard ``offset_bytes`` are expressed in the
    same row-major flat layout - confirm for multi-row sharded tensors.

    Raises:
        RuntimeError: the tensor has no deployment handle.
        IndexError: wrong number of indices, or an integer index outside
            its dimension (negative indices count from the end).
        NotImplementedError: non-integer leading index, slice step other
            than 1, or an unsupported index type in the last position.
    """
    if self._handle is None:
        raise RuntimeError(f"Tensor '{self.name}' is not deployed")

    shape = tuple(self.shape)
    ndim = len(shape)
    if not isinstance(key, tuple):
        key = (key,)
    if len(key) != ndim:
        raise IndexError(
            f"expected {ndim} indices, got {len(key)}"
        )

    def _normalize(idx: int, dim: int, axis: int) -> int:
        # Resolve a possibly negative index and bounds-check it.
        pos = idx + dim if idx < 0 else idx
        if not 0 <= pos < dim:
            raise IndexError(
                f"index {idx} is out of bounds for axis {axis} "
                f"with size {dim}"
            )
        return pos

    # Accumulate the row-major offset of the selected row/plane, expressed
    # in elements of the last dimension (Horner form over leading axes).
    base = 0
    for axis, k in enumerate(key[:-1]):
        if not isinstance(k, int):
            raise NotImplementedError(
                "only integer indices are supported for leading dims"
            )
        base = (base + _normalize(k, shape[axis], axis)) * shape[axis + 1]

    last = key[-1]
    if isinstance(last, int):
        # Single element.
        pos = _normalize(last, shape[-1], ndim - 1)
        return (base + pos, base + pos + 1)
    if isinstance(last, slice):
        start, stop, step = last.indices(shape[-1])
        if step != 1:
            raise NotImplementedError("step != 1 not supported")
        # An empty slice such as 5:2 must not yield a negative length.
        stop = max(start, stop)
        return (base + start, base + stop)
    raise NotImplementedError(f"unsupported index type: {type(last)}")
|
||||
|
||||
def _shard_for_range(self, start_elem: int, stop_elem: int) -> TensorShard:
    """Find the one shard whose byte span contains ``[start_elem, stop_elem)``.

    The element range is converted to bytes with ``self.itemsize`` and
    compared against each shard's ``[offset_bytes, offset_bytes + nbytes)``
    span, in handle order; the first shard that contains the whole range
    is returned.

    Raises:
        NotImplementedError: no single shard contains the whole range.
    """
    elem_bytes = self.itemsize
    lo = start_elem * elem_bytes
    hi = stop_elem * elem_bytes

    covering = next(
        (
            candidate
            for candidate in self._handle.shards
            if candidate.offset_bytes <= lo
            and hi <= candidate.offset_bytes + candidate.nbytes
        ),
        None,
    )
    if covering is None:
        raise NotImplementedError(
            f"slice [{start_elem}:{stop_elem}] spans multiple shards "
            "(only shard-aligned slices are supported)"
        )
    return covering
|
||||
|
||||
def __getitem__(self, key):
    """Read a shard-aligned slice and return it as a 1-D numpy array.

    Mirrors ``torch.Tensor.__getitem__`` for the shard-aligned case.  The
    key is resolved to a flat element range, the single shard covering it
    is located, and that shard's contents are read from the ``"hbm"``
    space of the memory store.

    Zeros of the requested length are returned when the tensor has no
    memory store attached or when the store raises ``KeyError`` for the
    shard's address (treated as not yet written).

    The result is always an independent copy: modifying it never alters
    what the memory store holds.  Use ``__setitem__`` to write.

    Raises:
        RuntimeError / IndexError / NotImplementedError: propagated from
            index resolution and shard lookup.
    """
    start, stop = self._resolve_shard_index(key)
    shard = self._shard_for_range(start, stop)
    np_dtype = _numpy_dtype(self.dtype)
    local_count = stop - start

    if self._memory_store is None:
        return np.zeros(local_count, dtype=np_dtype)

    # Only the store lookup is expected to raise KeyError; keep the
    # conversion and slicing outside the handler so real bugs surface.
    try:
        raw = self._memory_store.read(
            "hbm", self._shard_store_addr(shard),
        )
    except KeyError:
        return np.zeros(local_count, dtype=np_dtype)

    isize = self.itemsize
    local_start = (start * isize - shard.offset_bytes) // isize
    flat = np.asarray(raw, dtype=np_dtype).reshape(-1)
    # asarray/reshape/slicing can all be views of the stored array, so
    # copy before handing the data to the caller.
    return flat[local_start : local_start + local_count].copy()
|
||||
|
||||
def __setitem__(self, key, value):
    """Write a shard-aligned slice.

    Mirrors ``torch.Tensor.__setitem__``.  ``value`` may be a Python
    scalar (broadcast over the slice) or anything ``np.asarray`` accepts;
    array values are flattened and must contain either exactly one element
    (broadcast) or exactly as many elements as the slice.

    The covering shard is read from the ``"hbm"`` space of the memory
    store (zeros if the store raises ``KeyError`` for it), the slice is
    patched in a private copy, and the whole shard is written back as a
    flat array.

    Raises:
        RuntimeError: the tensor is not deployed or has no memory store.
        ValueError: an array value whose element count is neither 1 nor
            the slice length (previously an over-long value was silently
            truncated).
        IndexError / NotImplementedError: propagated from index
            resolution and shard lookup.
    """
    if self._handle is None or self._memory_store is None:
        raise RuntimeError(
            f"Tensor '{self.name}' must be deployed before assignment"
        )
    start, stop = self._resolve_shard_index(key)
    shard = self._shard_for_range(start, stop)
    np_dtype = _numpy_dtype(self.dtype)
    isize = self.itemsize
    local_start = (start * isize - shard.offset_bytes) // isize
    local_count = stop - start
    addr = self._shard_store_addr(shard)

    # Read current shard data (or zeros if uninitialized).  np.array()
    # copies, so the stored array is never patched in place.
    try:
        stored = self._memory_store.read("hbm", addr)
    except KeyError:
        arr = np.zeros(shard.nbytes // isize, dtype=np_dtype)
    else:
        arr = np.array(stored, dtype=np_dtype).reshape(-1)

    # Write the slice.
    target = slice(local_start, local_start + local_count)
    if isinstance(value, (int, float)):
        arr[target] = np_dtype.type(value)
    else:
        v = np.asarray(value, dtype=np_dtype).reshape(-1)
        if v.size not in (1, local_count):
            raise ValueError(
                f"cannot assign {v.size} values to a slice of "
                f"{local_count} elements"
            )
        arr[target] = v

    self._memory_store.write("hbm", addr, arr)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
parts = [f"tensor(name={self.name}, shape={self.shape}, dtype={self.dtype}"]
|
||||
if self._memory_store is not None and self._handle is not None:
|
||||
|
||||
@@ -134,3 +134,73 @@ def test_copy_shape_mismatch_raises():
|
||||
t.copy_(torch.from_numpy(src))
|
||||
|
||||
_run_with(body)
|
||||
|
||||
|
||||
# ── __setitem__ / __getitem__ (shard-aligned) ───────────────────────
|
||||
|
||||
|
||||
def test_setitem_getitem_single_pe():
    """Scalar and array slice writes on a single-PE tensor read back intact."""

    def body(torch):
        policy = DPPolicy(
            sip="replicate", cube="replicate", pe="replicate",
            num_sips=1, num_cubes=1, num_pes=1,
        )
        t = torch.zeros((1, 8), dtype="f16", dp=policy, name="t")
        front, back = slice(0, 4), slice(4, 8)

        # A scalar is broadcast over the first half; the rest stays zero.
        t[0, front] = 3.0
        assert np.allclose(t[0, front], 3.0)
        assert np.allclose(t[0, back], 0.0)

        # An array is written element-wise into the second half.
        tail = [10, 20, 30, 40]
        t[0, back] = np.array(tail, dtype=np.float16)
        assert np.array_equal(t[0, back], tail)

        # The whole tensor read via numpy() reflects both writes.
        whole = np.array([[3, 3, 3, 3] + tail], dtype=np.float16)
        assert np.array_equal(t.numpy(), whole)

    _run_with(body)
|
||||
|
||||
|
||||
def test_setitem_getitem_multi_pe_shard_aligned():
    """Per-shard slice writes on an 8-PE column-wise tensor round-trip."""

    def body(torch):
        n_pe, n_elem = 8, 4  # n_elem is the per-shard width
        policy = DPPolicy(
            sip="replicate", cube="replicate", pe="column_wise",
            num_sips=1, num_cubes=1, num_pes=n_pe,
        )
        t = torch.zeros((1, n_pe * n_elem), dtype="f16", dp=policy, name="t")
        shard_cols = [slice(r * n_elem, (r + 1) * n_elem) for r in range(n_pe)]

        # Fill every shard with (rank + 1).
        for r, cols in enumerate(shard_cols):
            t[0, cols] = float(r + 1)

        # Each shard reads back its own fill value.
        for r, cols in enumerate(shard_cols):
            expected = float(r + 1)
            arr = t[0, cols]
            assert np.allclose(arr, expected), f"shard {r}: {arr} != {expected}"

        # The gathered tensor shows the same values shard by shard.
        full = t.numpy().reshape(-1)
        for r, cols in enumerate(shard_cols):
            assert np.allclose(full[cols], float(r + 1))

    _run_with(body)
|
||||
|
||||
|
||||
def test_setitem_cross_shard_raises():
    """Assigning to a slice that straddles two shards is rejected."""

    def body(torch):
        n_pe = n_elem = 4
        policy = DPPolicy(
            sip="replicate", cube="replicate", pe="column_wise",
            num_sips=1, num_cubes=1, num_pes=n_pe,
        )
        t = torch.zeros((1, n_pe * n_elem), dtype="f16", dp=policy, name="t")

        # Columns 2..5 overlap shard 0 (0:4) and shard 1 (4:8).
        straddling = slice(2, 6)
        with pytest.raises(NotImplementedError, match="spans multiple shards"):
            t[0, straddling] = 1.0

    _run_with(body)
|
||||
|
||||
Reference in New Issue
Block a user