"""Tensor descriptors, a deployment helper, and a PyTorch-like ``Tensor`` API."""

from __future__ import annotations

import math
import weakref
from dataclasses import dataclass
from typing import Literal

from kernbench.policy.address.allocator import PEMemAllocator
from kernbench.policy.placement.dp import DPPolicy, ShardSpec
from kernbench.runtime_api.kernel import TensorArg, TensorArgShard


@dataclass(frozen=True)
class TensorShard:
    """Immutable record of one allocated piece of a deployed tensor."""

    sip: int  # taken from the allocator that served this shard
    cube: int  # taken from the allocator that served this shard
    pe: int  # taken from the allocator that served this shard
    pa: int  # encoded physical address (``.encode()`` of the allocation)
    nbytes: int  # size of this shard, copied from its ShardSpec
    offset_bytes: int  # offset of this shard, copied from its ShardSpec


@dataclass(frozen=True)
class TensorHandle:
    """Immutable description of a deployed tensor and all of its shards."""

    name: str
    shape: tuple[int, ...]
    dtype: str
    itemsize: int
    shards: tuple[TensorShard, ...]
    va_base: int = 0  # VA base address for the entire tensor

    @property
    def nbytes(self) -> int:
        """Total size in bytes: product of ``shape`` times ``itemsize``."""
        return math.prod(self.shape) * self.itemsize


# Bytes per element for every dtype spelling accepted by ``dtype_itemsize``.
_DTYPE_ITEMSIZE = {
    "fp16": 2,
    "float16": 2,
    "f16": 2,
    "fp32": 4,
    "float32": 4,
    "f32": 4,
    "bf16": 2,
    "int8": 1,
    "i8": 1,
    "int16": 2,
    "i16": 2,
    "int32": 4,
    "i32": 4,
}


def dtype_itemsize(dtype: str) -> int:
    """Return the size in bytes of one element of *dtype*.

    Raises:
        ValueError: if *dtype* is not a key of ``_DTYPE_ITEMSIZE``.
    """
    try:
        return _DTYPE_ITEMSIZE[dtype]
    except KeyError:
        raise ValueError(f"unsupported dtype: {dtype}") from None


def deploy_tensor(
    *,
    name: str,
    shape: tuple[int, ...],
    dtype: str,
    placement: list[ShardSpec],
    allocators: dict[int, PEMemAllocator],
    mem_kind: Literal["hbm", "tcm"] = "hbm",
    va_allocator=None,
    mmus: dict | None = None,
) -> TensorHandle:
    """Allocate memory for every shard of a tensor and return its handle.

    Args:
        name: Tensor name stored on the returned handle.
        shape: Tensor shape; with *dtype* it determines the VA range size.
        dtype: Dtype string understood by ``dtype_itemsize``.
        placement: One ``ShardSpec`` per shard; each spec's ``pe_index``
            selects the allocator, and its ``nbytes`` / ``offset_bytes`` are
            copied onto the resulting ``TensorShard``.
        allocators: Allocators keyed by ``ShardSpec.pe_index``.
        mem_kind: ``"hbm"`` allocates with ``alloc_hbm``; any other value
            allocates with ``alloc_tcm``.
        va_allocator: Optional object whose ``alloc(nbytes)`` returns the VA
            base for the whole tensor. When omitted, ``va_base`` stays 0.
        mmus: Optional mapping whose values expose ``map(va=, pa=, size=)``.
            Every shard is mapped into every MMU.

    Returns:
        A ``TensorHandle`` listing the shards in *placement* order.

    Raises:
        ValueError: if *dtype* is unsupported.
        KeyError: if a spec's ``pe_index`` has no entry in *allocators*.
    """
    # NOTE(review): PeMMU is not referenced below; the import is kept in case
    # it is relied on for import side effects — confirm whether it can go.
    from kernbench.policy.address.pe_mmu import PeMMU

    isize = dtype_itemsize(dtype)
    total_nbytes = math.prod(shape) * isize

    # Allocate VA range for the entire tensor (if VA allocator provided)
    va_base = 0
    if va_allocator is not None:
        va_base = va_allocator.alloc(total_nbytes)

    # Loop-invariant: a VA base of 0 is treated as "no VA", so nothing is
    # mapped in that case even when MMUs are supplied.
    register_mappings = bool(va_base) and mmus is not None

    shards: list[TensorShard] = []
    for spec in placement:
        alloc = allocators[spec.pe_index]
        if mem_kind == "hbm":
            pa = alloc.alloc_hbm(spec.nbytes)
        else:
            pa = alloc.alloc_tcm(spec.nbytes)
        encoded_pa = pa.encode()

        shards.append(
            TensorShard(
                sip=alloc._sip_id,
                cube=alloc._cube_id,
                pe=alloc._pe_id,
                pa=encoded_pa,
                nbytes=spec.nbytes,
                offset_bytes=spec.offset_bytes,
            )
        )

        # Register VA→PA mapping in all MMUs (broadcast)
        if register_mappings:
            shard_va = va_base + spec.offset_bytes
            for mmu in mmus.values():
                mmu.map(va=shard_va, pa=encoded_pa, size=spec.nbytes)

    return TensorHandle(
        name=name,
        shape=shape,
        dtype=dtype,
        itemsize=isize,
        shards=tuple(shards),
        va_base=va_base,
    )


# ── PyTorch-like Tensor API ──────────────────────────────────────────


@dataclass(frozen=True)
class DPMetadata:
    """Data-parallel placement metadata (stored as Tensor._dp_metadata)."""

    placement: list[ShardSpec]
    dp_policy: DPPolicy | None = None
    sip: int = 0
    cube: int = 0
    target_pe: int | str = 0  # int → single PE, "all" → all PEs


class Tensor:
    """PyTorch-like tensor for benchmark code.

    Usage::

        a = ctx.zeros((M, K), dtype="f16",
                      dp=DPPolicy(cube="replicate", pe="replicate"))
        ctx.launch("kernel_name", kernel_fn, a, b, out, M=M, K=K)
    """

    def __init__(
        self,
        shape: tuple[int, ...],
        dtype: str = "f16",
        name: str = "",
    ) -> None:
        """Create an undeployed tensor; no memory is allocated here."""
        self.shape = shape
        self.dtype = dtype
        self.name = name
        self._dp_metadata: DPMetadata | None = None
        self._handle: TensorHandle | None = None
        self._ctx_ref: weakref.ref | None = None  # set by RuntimeContext

    def __del__(self) -> None:
        """Ask the owning context (if still alive) to free this tensor."""
        # The finalizer also runs on instances whose __init__ never ran or
        # did not finish, where these attributes are missing; read them with
        # defaults so that case is a silent no-op instead of AttributeError.
        ctx_ref = getattr(self, "_ctx_ref", None)
        if ctx_ref is None or getattr(self, "_handle", None) is None:
            return
        ctx = ctx_ref()
        if ctx is not None:
            ctx._free_tensor(self)

    @property
    def itemsize(self) -> int:
        """Bytes per element for this tensor's dtype."""
        return dtype_itemsize(self.dtype)

    @property
    def nbytes(self) -> int:
        """Total size in bytes: product of ``shape`` times ``itemsize``."""
        return math.prod(self.shape) * self.itemsize

    @property
    def pa(self) -> int:
        """Primary PA (first shard). Used as kernel pointer argument.

        Raises:
            RuntimeError: if there is no handle or the handle has no shards.
        """
        if self._handle is None or not self._handle.shards:
            raise RuntimeError(f"Tensor '{self.name}' is not deployed yet")
        return self._handle.shards[0].pa

    @property
    def va(self) -> int:
        """VA base address for the entire tensor.

        Raises:
            RuntimeError: if the tensor has no handle yet.
        """
        if self._handle is None:
            raise RuntimeError(f"Tensor '{self.name}' is not deployed yet")
        return self._handle.va_base

    def to(
        self,
        placement: list[ShardSpec] | None = None,
        *,
        dp_policy: DPPolicy | None = None,
        sip: int = 0,
        cube: int = 0,
        target_pe: int | str = 0,
    ) -> Tensor:
        """Set DP placement metadata (like torch.Tensor.to()).

        When *placement* is omitted, a single shard covering the whole tensor
        at ``pe_index=0`` and offset 0 is used. Returns ``self`` so calls can
        be chained.
        """
        if placement is None:
            placement = [
                ShardSpec(pe_index=0, offset_bytes=0, nbytes=self.nbytes)
            ]
        self._dp_metadata = DPMetadata(
            placement=placement,
            dp_policy=dp_policy,
            sip=sip,
            cube=cube,
            target_pe=target_pe,
        )
        return self

    def to_tensor_arg(self) -> TensorArg:
        """Convert deployed shards to KernelLaunchMsg TensorArg.

        Raises:
            RuntimeError: if the tensor has no handle yet.
        """
        if self._handle is None:
            raise RuntimeError(f"Tensor '{self.name}' is not deployed yet")
        return TensorArg(
            shards=tuple(
                TensorArgShard(
                    sip=s.sip,
                    cube=s.cube,
                    pe=s.pe,
                    pa=s.pa,
                    nbytes=s.nbytes,
                    offset_bytes=s.offset_bytes,
                )
                for s in self._handle.shards
            ),
            va_base=self._handle.va_base,
        )