from __future__ import annotations

import math
import weakref
from dataclasses import dataclass
from typing import Literal

import numpy as np

from kernbench.policy.address.allocator import PEMemAllocator
from kernbench.policy.placement.dp import DPPolicy, ShardSpec
from kernbench.runtime_api.kernel import TensorArg, TensorArgShard


@dataclass(frozen=True)
class TensorShard:
    """One deployed piece of a tensor.

    ``sip`` / ``cube`` / ``pe`` identify the allocator the shard was taken
    from, ``pa`` is the encoded address returned by that allocator, and
    ``offset_bytes`` / ``nbytes`` give the byte range of the flat tensor
    that this shard holds.
    """

    sip: int
    cube: int
    pe: int
    pa: int
    nbytes: int
    offset_bytes: int


@dataclass(frozen=True)
class TensorHandle:
    """Immutable record of a deployed tensor and its shards."""

    name: str
    shape: tuple[int, ...]
    dtype: str
    itemsize: int
    shards: tuple[TensorShard, ...]
    va_base: int = 0  # VA base address for the entire tensor

    @property
    def nbytes(self) -> int:
        """Total size of the tensor in bytes (element count * itemsize)."""
        return math.prod(self.shape) * self.itemsize


# Size in bytes of one element, keyed by every accepted dtype spelling.
_DTYPE_ITEMSIZE = {
    "fp16": 2,
    "float16": 2,
    "f16": 2,
    "fp32": 4,
    "float32": 4,
    "f32": 4,
    "bf16": 2,
    "int8": 1,
    "i8": 1,
    "int16": 2,
    "i16": 2,
    "int32": 4,
    "i32": 4,
}


def dtype_itemsize(dtype: str) -> int:
    """Return the element size in bytes for *dtype*.

    Raises:
        ValueError: if *dtype* is not one of the supported spellings.
    """
    if dtype not in _DTYPE_ITEMSIZE:
        raise ValueError(f"unsupported dtype: {dtype}")
    return _DTYPE_ITEMSIZE[dtype]


# Host-side numpy representation for each dtype spelling.
# NOTE: "bf16" is stored as np.float16 here (same 2-byte width).
_NUMPY_DTYPE = {
    "f16": np.float16,
    "fp16": np.float16,
    "float16": np.float16,
    "f32": np.float32,
    "fp32": np.float32,
    "float32": np.float32,
    "bf16": np.float16,
    "i8": np.int8,
    "int8": np.int8,
    "i16": np.int16,
    "int16": np.int16,
    "i32": np.int32,
    "int32": np.int32,
}


def _numpy_dtype(dtype: str) -> np.dtype:
    """Return the numpy dtype for *dtype*; unknown names map to float16."""
    return np.dtype(_NUMPY_DTYPE.get(dtype, np.float16))


def deploy_tensor(
    *,
    name: str,
    shape: tuple[int, ...],
    dtype: str,
    placement: list[ShardSpec],
    allocators: dict[tuple[int, int, int], PEMemAllocator],
    mem_kind: Literal["hbm", "tcm"] = "hbm",
    va_allocator=None,
) -> TensorHandle:
    """Allocate memory for every shard in *placement* and build a handle.

    Args:
        name: Tensor name stored on the handle.
        shape: Full (unsharded) tensor shape.
        dtype: Element type name; must be known to ``dtype_itemsize``.
        placement: One ``ShardSpec`` per shard to allocate.
        allocators: Per-PE allocators keyed by ``(sip, cube, pe)``.
        mem_kind: ``"hbm"`` allocates with ``alloc_hbm``; any other value
            allocates with ``alloc_tcm``.
        va_allocator: Optional allocator; when given, one VA range covering
            the whole tensor is reserved and recorded as ``va_base``.

    Returns:
        A ``TensorHandle`` listing the allocated shards.

    Raises:
        ValueError: if *dtype* is unsupported.
        KeyError: if a shard targets a PE with no entry in *allocators*.
    """
    isize = dtype_itemsize(dtype)
    total_nbytes = math.prod(shape) * isize

    # Allocate VA range for the entire tensor (if VA allocator provided)
    va_base = 0
    if va_allocator is not None:
        va_base = va_allocator.alloc(total_nbytes)

    shards: list[TensorShard] = []
    for spec in placement:
        alloc = allocators[(spec.sip, spec.cube, spec.pe)]
        if mem_kind == "hbm":
            pa = alloc.alloc_hbm(spec.nbytes)
        else:
            pa = alloc.alloc_tcm(spec.nbytes)
        shards.append(
            TensorShard(
                sip=spec.sip,
                cube=spec.cube,
                pe=spec.pe,
                pa=pa.encode(),
                nbytes=spec.nbytes,
                offset_bytes=spec.offset_bytes,
            )
        )

    return TensorHandle(
        name=name,
        shape=shape,
        dtype=dtype,
        itemsize=isize,
        shards=tuple(shards),
        va_base=va_base,
    )


# ── PyTorch-like Tensor API ──────────────────────────────────────────


@dataclass(frozen=True)
class DPMetadata:
    """Data-parallel placement metadata (stored as Tensor._dp_metadata)."""

    placement: list[ShardSpec]
    dp_policy: DPPolicy | None = None
    sip: int = 0
    cube: int = 0
    # int → single PE, tuple → specific PEs, "all" → all PEs
    target_pe: int | tuple[int, ...] | str = 0


class Tensor:
    """PyTorch-like tensor for benchmark code.

    Usage::

        a = ctx.zeros((M, K), dtype="f16", dp=DPPolicy(cube="replicate", pe="replicate"))
        ctx.launch("kernel_name", kernel_fn, a, b, out, M=M, K=K)
    """

    def __init__(
        self,
        shape: tuple[int, ...],
        dtype: str = "f16",
        name: str = "",
    ) -> None:
        self.shape = shape
        self.dtype = dtype
        self.name = name
        self._dp_metadata: DPMetadata | None = None
        self._handle: TensorHandle | None = None
        self._ctx_ref: weakref.ref | None = None  # set by RuntimeContext
        self._memory_store = None  # set by RuntimeContext when enable_data=True
        # Host-side staging buffer for torch.from_numpy() results. A tensor
        # with a non-None _host_buffer is NOT deployed to any PE — it lives
        # only on the host. Use `target.copy_(host_tensor)` to scatter the
        # data into a deployed, sharded target tensor.
        self._host_buffer: np.ndarray | None = None

    def __del__(self) -> None:
        """Ask the owning context to free this tensor, if it is deployed."""
        # getattr: the finalizer also runs on instances whose __init__ never
        # ran to completion, where these attributes do not exist.
        ctx_ref = getattr(self, "_ctx_ref", None)
        if ctx_ref is None or getattr(self, "_handle", None) is None:
            return
        ctx = ctx_ref()
        if ctx is not None:
            ctx._free_tensor(self)

    # ── Indexing (shard-aligned slices) ────────────────────────────

    @staticmethod
    def _normalize_index(idx: int, dim: int) -> int:
        """Return *idx* as a non-negative index into an axis of size *dim*.

        Raises:
            IndexError: if *idx* lies outside ``[-dim, dim)``.
        """
        if not -dim <= idx < dim:
            raise IndexError(
                f"index {idx} is out of bounds for dimension of size {dim}"
            )
        return idx + dim if idx < 0 else idx

    def _resolve_shard_index(self, key) -> tuple[int, int | None]:
        """Map a numpy-style index key to (flat_start_elem, flat_stop_elem).

        Leading dimensions must be indexed with integers; the last dimension
        accepts an integer or a unit-step slice. The returned half-open range
        is expressed in elements of the row-major flat layout, so the leading
        indices contribute the offset of the selected row.

        Raises:
            RuntimeError: if the tensor is not deployed.
            IndexError: on a wrong number of indices, a 0-d tensor, or an
                out-of-range integer index.
            NotImplementedError: for non-integer leading indices, slice steps
                other than 1, or unsupported index types.
        """
        if self._handle is None:
            raise RuntimeError(f"Tensor '{self.name}' is not deployed")

        ndim = len(self.shape)
        if not isinstance(key, tuple):
            key = (key,)
        if len(key) != ndim:
            raise IndexError(
                f"expected {ndim} indices, got {len(key)}"
            )
        if ndim == 0:
            raise IndexError("cannot index a 0-d tensor")

        # All leading dims must be int (selecting a single row/plane).
        # Accumulate their row-major position so the result is a true flat
        # offset rather than an offset within row 0.
        row = 0
        for dim, idx in zip(self.shape[:-1], key[:-1]):
            if not isinstance(idx, (int, np.integer)):
                raise NotImplementedError(
                    "only integer indices are supported for leading dims"
                )
            row = row * dim + self._normalize_index(int(idx), dim)

        last_dim = self.shape[-1]
        base = row * last_dim
        last = key[-1]

        if isinstance(last, (int, np.integer)):
            # Single element
            pos = self._normalize_index(int(last), last_dim)
            return (base + pos, base + pos + 1)

        if isinstance(last, slice):
            start, stop, step = last.indices(last_dim)
            if step != 1:
                raise NotImplementedError("step != 1 not supported")
            # An empty slice such as 5:2 yields stop < start; clamp so the
            # element count downstream is never negative.
            stop = max(stop, start)
            return (base + start, base + stop)

        raise NotImplementedError(f"unsupported index type: {type(last)}")

    def _shard_for_range(self, start_elem: int, stop_elem: int) -> TensorShard:
        """Return the single shard that fully covers [start_elem, stop_elem).

        Shards are scanned in handle order and the first covering shard wins.

        Raises NotImplementedError if no single shard covers the range.
        """
        isize = self.itemsize
        start_byte = start_elem * isize
        stop_byte = stop_elem * isize

        for shard in self._handle.shards:
            s_start = shard.offset_bytes
            s_end = shard.offset_bytes + shard.nbytes
            if start_byte >= s_start and stop_byte <= s_end:
                return shard

        raise NotImplementedError(
            f"slice [{start_elem}:{stop_elem}] spans multiple shards "
            f"(only shard-aligned slices are supported)"
        )

    def __getitem__(self, key):
        """Read a shard-aligned slice. Returns a numpy array.

        Mirrors ``torch.Tensor.__getitem__`` for the shard-aligned case.
        Returns zeros when no memory store is attached or when the store
        has no entry for the shard.
        """
        start, stop = self._resolve_shard_index(key)
        shard = self._shard_for_range(start, stop)
        np_dtype = _numpy_dtype(self.dtype)
        local_count = stop - start

        if self._memory_store is None:
            return np.zeros(local_count, dtype=np_dtype)

        isize = self.itemsize
        # Element offset of the requested range inside the shard's buffer.
        local_start = (start * isize - shard.offset_bytes) // isize

        try:
            arr = self._memory_store.read(
                "hbm",
                self._shard_store_addr(shard),
            )
        except KeyError:
            # Nothing written at this address yet.
            return np.zeros(local_count, dtype=np_dtype)

        flat = np.asarray(arr, dtype=np_dtype).reshape(-1)
        return flat[local_start : local_start + local_count]

    def __setitem__(self, key, value):
        """Write a shard-aligned slice.

        Mirrors ``torch.Tensor.__setitem__``. Scalar broadcast and
        numpy array assignment are both supported. The whole shard is read,
        patched, and written back to the memory store.

        Raises:
            RuntimeError: if the tensor is not deployed or has no store.
        """
        if self._handle is None or self._memory_store is None:
            raise RuntimeError(
                f"Tensor '{self.name}' must be deployed before assignment"
            )

        start, stop = self._resolve_shard_index(key)
        shard = self._shard_for_range(start, stop)
        np_dtype = _numpy_dtype(self.dtype)
        isize = self.itemsize
        local_start = (start * isize - shard.offset_bytes) // isize
        local_count = stop - start
        shard_elems = shard.nbytes // isize
        addr = self._shard_store_addr(shard)

        # Read current shard data (or zeros if uninitialized)
        try:
            arr = self._memory_store.read("hbm", addr)
        except KeyError:
            arr = np.zeros(shard_elems, dtype=np_dtype)
        else:
            # np.array copies, so the stored buffer is never patched in place.
            arr = np.array(arr, dtype=np_dtype).reshape(-1)

        # Write the slice
        if isinstance(value, (int, float)):
            arr[local_start : local_start + local_count] = np_dtype.type(value)
        else:
            v = np.asarray(value, dtype=np_dtype).reshape(-1)
            arr[local_start : local_start + local_count] = v[:local_count]

        self._memory_store.write("hbm", addr, arr)

    def __repr__(self) -> str:
        """Summarise the tensor; includes mean/norm when data is available."""
        parts = [f"tensor(name={self.name}, shape={self.shape}, dtype={self.dtype}"]
        if self._memory_store is not None and self._handle is not None:
            arr = self.data
            parts.append(f", mean={float(arr.mean()):.4g}, norm={float(np.linalg.norm(arr)):.4g}")
        else:
            parts.append(", data=N/A (placeholder)")
        parts.append(")")
        return "".join(parts)

    @property
    def data(self) -> np.ndarray:
        """Tensor data as numpy array.

        Gathers all shards into a single full-shape array. Returns actual
        values when enable_data=True, zeros placeholder otherwise (like an
        uninitialized tensor). Alias of ``numpy()``.
        """
        return self.numpy()

    def _shard_store_addr(self, shard: TensorShard) -> int:
        """MemoryStore key for a shard.

        Kernels read tensors via VA (translated to PA by PE_DMA's MMU when
        a mapping exists, otherwise the addr is treated as a PA-equivalent
        key). Tensor I/O therefore writes/reads at ``va_base + offset_bytes``
        when ``va_base`` is set, falling back to ``shard.pa`` for the
        VA-less mode used by some legacy paths.
        """
        if self._handle and self._handle.va_base:
            return self._handle.va_base + shard.offset_bytes
        return shard.pa

    def numpy(self) -> np.ndarray:
        """Return a single numpy array gathered from all shards.

        Mirrors ``torch.Tensor.numpy()``. In kernbench, sharded tensors are
        gathered into a single full-shape ndarray according to each shard's
        ``offset_bytes`` / ``nbytes`` range. Shards missing from the store
        are left as zeros.
        """
        np_dtype = _numpy_dtype(self.dtype)

        # Host-side tensor (created via torch.from_numpy) has no shards.
        if self._host_buffer is not None:
            return self._host_buffer.copy()

        if self._handle is None or self._memory_store is None:
            return np.zeros(self.shape, dtype=np_dtype)

        flat = np.zeros(math.prod(self.shape), dtype=np_dtype)
        for shard in self._handle.shards:
            start = shard.offset_bytes // self.itemsize
            count = shard.nbytes // self.itemsize
            try:
                piece = self._memory_store.read(
                    "hbm",
                    self._shard_store_addr(shard),
                )
            except KeyError:
                continue
            flat[start : start + count] = (
                np.asarray(piece, dtype=np_dtype).reshape(-1)[:count]
            )
        return flat.reshape(self.shape)

    def copy_(self, source: "Tensor") -> "Tensor":
        """In-place copy from another tensor into self.

        Mirrors ``torch.Tensor.copy_()``. If ``source`` is a host tensor
        (from ``torch.from_numpy``), its ndarray is split across self's
        shards using each shard's byte range. If ``source`` is a deployed
        (sharded) tensor, its contents are gathered first and then
        re-scattered into self's shard layout.

        Shapes must match. Returns self.

        Raises:
            RuntimeError: if self is not deployed or has no memory store.
            ValueError: if the shapes differ.
        """
        if self._handle is None or self._memory_store is None:
            raise RuntimeError(
                f"Tensor '{self.name}' must be deployed before copy_()"
            )
        if source.shape != self.shape:
            raise ValueError(
                f"copy_ shape mismatch: self={self.shape} source={source.shape}"
            )

        np_dtype = _numpy_dtype(self.dtype)
        arr = source.numpy().astype(np_dtype, copy=False)
        flat = np.ascontiguousarray(arr).reshape(-1)

        for shard in self._handle.shards:
            start = shard.offset_bytes // self.itemsize
            count = shard.nbytes // self.itemsize
            # Copy so each stored shard owns its buffer instead of aliasing
            # the gathered source array.
            piece = flat[start : start + count].copy()
            self._memory_store.write(
                "hbm",
                self._shard_store_addr(shard),
                piece,
            )
        return self

    @property
    def itemsize(self) -> int:
        """Size in bytes of one element of this tensor's dtype."""
        return dtype_itemsize(self.dtype)

    @property
    def nbytes(self) -> int:
        """Total size of the tensor in bytes."""
        return math.prod(self.shape) * self.itemsize

    @property
    def pa(self) -> int:
        """Primary PA (first shard). Used as kernel pointer argument."""
        if self._handle is None or not self._handle.shards:
            raise RuntimeError(f"Tensor '{self.name}' is not deployed yet")
        return self._handle.shards[0].pa

    @property
    def va(self) -> int:
        """VA base address for the entire tensor."""
        if self._handle is None:
            raise RuntimeError(f"Tensor '{self.name}' is not deployed yet")
        return self._handle.va_base

    def to(
        self,
        placement: list[ShardSpec] | None = None,
        *,
        dp_policy: DPPolicy | None = None,
        sip: int = 0,
        cube: int = 0,
        target_pe: int | tuple[int, ...] | str = 0,
    ) -> Tensor:
        """Set DP placement metadata (like torch.Tensor.to()).

        When *placement* is omitted, a single shard covering the whole
        tensor on sip 0 / cube 0 / pe 0 is used. Returns self.
        """
        if placement is None:
            placement = [
                ShardSpec(sip=0, cube=0, pe=0, offset_bytes=0, nbytes=self.nbytes)
            ]
        self._dp_metadata = DPMetadata(
            placement=placement,
            dp_policy=dp_policy,
            sip=sip,
            cube=cube,
            target_pe=target_pe,
        )
        return self

    def to_tensor_arg(self) -> TensorArg:
        """Convert deployed shards to KernelLaunchMsg TensorArg.

        Raises:
            RuntimeError: if the tensor is not deployed yet.
        """
        if self._handle is None:
            raise RuntimeError(f"Tensor '{self.name}' is not deployed yet")
        return TensorArg(
            shards=tuple(
                TensorArgShard(
                    sip=s.sip,
                    cube=s.cube,
                    pe=s.pe,
                    pa=s.pa,
                    nbytes=s.nbytes,
                    offset_bytes=s.offset_bytes,
                )
                for s in self._handle.shards
            ),
            va_base=self._handle.va_base,
        )