08812eda58
Implement VA/MMU layer (ADR-0011 Phase 1) enabling Triton kernels to use contiguous virtual addresses on sharded tensors. Key changes: - PE_MMU component: hybrid inbox (MmuMapMsg) + sync translate() for PE_DMA - VirtualAllocator + PEMemAllocator: free-list with coalescing - MmuMapMsg/MmuUnmapMsg fabric path with SIP-level routing - DPPolicy-based mapping: replicate=local, sharded=broadcast - Tensor lifecycle: del + weakref cleanup, context manager - Rename: TensorHandle.pa→addr, DmaReadCmd.src_pa→src_addr, ctx→torch Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
204 lines
5.7 KiB
Python
204 lines
5.7 KiB
Python
from __future__ import annotations
|
|
|
|
import math
|
|
import weakref
|
|
from dataclasses import dataclass
|
|
from typing import Literal
|
|
|
|
from kernbench.policy.address.allocator import PEMemAllocator
|
|
from kernbench.policy.placement.dp import DPPolicy, ShardSpec
|
|
from kernbench.runtime_api.kernel import TensorArg, TensorArgShard
|
|
|
|
|
|
@dataclass(frozen=True)
class TensorShard:
    """Immutable record of one deployed piece of a tensor.

    ``deploy_tensor`` creates one instance per placement entry, filling the
    location fields from the allocator that served the shard.
    """

    # Location of the PE whose allocator provided the memory.
    sip: int
    cube: int
    pe: int
    # Encoded physical address (the ``.encode()`` of the allocator result).
    pa: int
    # Size of this shard in bytes.
    nbytes: int
    # Byte offset of this shard within the whole tensor; the shard's virtual
    # address is the tensor's VA base plus this offset.
    offset_bytes: int
|
|
|
|
|
|
@dataclass(frozen=True)
class TensorHandle:
    """Immutable description of a deployed tensor and where its shards live."""

    name: str
    shape: tuple[int, ...]
    dtype: str
    # Bytes per element for ``dtype``.
    itemsize: int
    # One entry per deployed shard, in placement order.
    shards: tuple[TensorShard, ...]
    va_base: int = 0  # VA base address for the entire tensor

    @property
    def nbytes(self) -> int:
        """Total size of the tensor in bytes (element count × itemsize)."""
        element_count = math.prod(self.shape)
        return element_count * self.itemsize
|
|
|
|
|
|
# Bytes per element for every accepted dtype spelling; aliases of the same
# type share a size.
_DTYPE_ITEMSIZE = {
    # 16-bit float
    "fp16": 2,
    "float16": 2,
    "f16": 2,
    # 32-bit float
    "fp32": 4,
    "float32": 4,
    "f32": 4,
    # bfloat16
    "bf16": 2,
    # signed integers
    "int8": 1,
    "i8": 1,
    "int16": 2,
    "i16": 2,
    "int32": 4,
    "i32": 4,
}


def dtype_itemsize(dtype: str) -> int:
    """Return the size in bytes of one element of ``dtype``.

    Raises:
        ValueError: if ``dtype`` is not one of the known spellings.
    """
    size = _DTYPE_ITEMSIZE.get(dtype)
    if size is None:
        raise ValueError(f"unsupported dtype: {dtype}")
    return size
|
|
|
|
|
|
def deploy_tensor(
    *,
    name: str,
    shape: tuple[int, ...],
    dtype: str,
    placement: list[ShardSpec],
    allocators: dict[int, PEMemAllocator],
    mem_kind: Literal["hbm", "tcm"] = "hbm",
    va_allocator=None,
    mmus: dict | None = None,
) -> TensorHandle:
    """Allocate memory for every shard of a tensor and return its handle.

    For each entry in ``placement`` a physical range of ``spec.nbytes`` bytes
    is taken from ``allocators[spec.pe_index]``. When a VA allocator is given,
    one virtual range covering the whole tensor is allocated as well, and each
    shard is registered at ``va_base + spec.offset_bytes`` in every MMU of
    ``mmus`` (broadcast).

    Args:
        name: Tensor name stored on the handle.
        shape: Tensor shape; the element count is the product of its entries.
        dtype: Element type; must be accepted by ``dtype_itemsize``.
        placement: One ``ShardSpec`` per shard (``pe_index``, ``nbytes`` and
            ``offset_bytes`` are read).
        allocators: Per-PE allocators keyed by ``pe_index``.
        mem_kind: ``"hbm"`` allocates with ``alloc_hbm``, ``"tcm"`` with
            ``alloc_tcm``.
        va_allocator: Optional object exposing ``alloc(nbytes)``; when
            omitted the handle's ``va_base`` is 0 and nothing is mapped.
        mmus: Optional mapping whose values expose ``map(va=, pa=, size=)``.

    Returns:
        A ``TensorHandle`` listing the deployed shards in placement order.

    Raises:
        ValueError: for an unsupported ``dtype`` or ``mem_kind``.
        KeyError: if a placement entry names a ``pe_index`` that has no
            allocator.
    """
    isize = dtype_itemsize(dtype)
    # Reject unknown kinds explicitly instead of silently treating every
    # non-"hbm" value as TCM.
    if mem_kind not in ("hbm", "tcm"):
        raise ValueError(f"unsupported mem_kind: {mem_kind}")
    total_nbytes = math.prod(shape) * isize

    # Resolve all allocators before allocating anything, so an unknown
    # pe_index fails with KeyError while no VA/PA range has been taken yet.
    targets = [(spec, allocators[spec.pe_index]) for spec in placement]

    # Allocate VA range for the entire tensor (if VA allocator provided).
    has_va = va_allocator is not None
    va_base = va_allocator.alloc(total_nbytes) if has_va else 0

    # Mapping depends on whether a VA range exists, not on its numeric value:
    # a VA base of 0 is still a range that must be mapped.
    map_into = tuple(mmus.values()) if has_va and mmus is not None else ()

    use_hbm = mem_kind == "hbm"
    shards: list[TensorShard] = []
    for spec, alloc in targets:
        if use_hbm:
            pa = alloc.alloc_hbm(spec.nbytes)
        else:
            pa = alloc.alloc_tcm(spec.nbytes)
        encoded_pa = pa.encode()
        shards.append(TensorShard(
            sip=alloc._sip_id,
            cube=alloc._cube_id,
            pe=alloc._pe_id,
            pa=encoded_pa,
            nbytes=spec.nbytes,
            offset_bytes=spec.offset_bytes,
        ))

        # Register VA→PA mapping in all MMUs (broadcast).
        shard_va = va_base + spec.offset_bytes
        for mmu in map_into:
            mmu.map(va=shard_va, pa=encoded_pa, size=spec.nbytes)

    return TensorHandle(
        name=name,
        shape=shape,
        dtype=dtype,
        itemsize=isize,
        shards=tuple(shards),
        va_base=va_base,
    )
|
|
|
|
|
|
# ── PyTorch-like Tensor API ──────────────────────────────────────────
|
|
|
|
|
|
@dataclass(frozen=True)
class DPMetadata:
    """Data-parallel placement metadata (stored as Tensor._dp_metadata)."""

    # Shard layout requested for the tensor.
    placement: list[ShardSpec]
    # Optional policy that accompanies the placement.
    dp_policy: DPPolicy | None = None
    # Target SIP and cube indices.
    sip: int = 0
    cube: int = 0
    target_pe: int | str = 0  # int → single PE, "all" → all PEs
|
|
|
|
|
|
class Tensor:
    """PyTorch-like tensor for benchmark code.

    Usage::

        a = ctx.zeros((M, K), dtype="f16", dp=DPPolicy(cube="replicate", pe="replicate"))
        ctx.launch("kernel_name", kernel_fn, a, b, out, M=M, K=K)

    A freshly constructed tensor only carries ``shape``, ``dtype`` and
    ``name``. Placement metadata is attached with :meth:`to`; the deployed
    handle and the owning-context weak reference start as ``None`` and are
    filled in from outside this class.
    """

    def __init__(
        self,
        shape: tuple[int, ...],
        dtype: str = "f16",
        name: str = "",
    ) -> None:
        """Record the tensor's shape, dtype and name; nothing is allocated."""
        self.shape = shape
        self.dtype = dtype
        self.name = name
        self._dp_metadata: DPMetadata | None = None
        self._handle: TensorHandle | None = None
        self._ctx_ref: weakref.ref | None = None  # set by RuntimeContext

    def __del__(self) -> None:
        """Ask the owning context to free this tensor, if it is still alive.

        Does nothing when the tensor has no context reference, has no handle,
        or the context has already been collected.
        """
        # A finalizer also runs for instances whose __init__ never completed,
        # so the attributes may be missing; treat that like "not deployed".
        ctx_ref = getattr(self, "_ctx_ref", None)
        if ctx_ref is None or getattr(self, "_handle", None) is None:
            return
        ctx = ctx_ref()
        if ctx is not None:
            ctx._free_tensor(self)

    @property
    def itemsize(self) -> int:
        """Bytes per element for this tensor's dtype."""
        return dtype_itemsize(self.dtype)

    @property
    def nbytes(self) -> int:
        """Total size in bytes (element count × itemsize)."""
        return math.prod(self.shape) * self.itemsize

    @property
    def pa(self) -> int:
        """Primary PA (first shard). Used as kernel pointer argument.

        Raises:
            RuntimeError: if the tensor has no handle or the handle has no
                shards.
        """
        if self._handle is None or not self._handle.shards:
            raise RuntimeError(f"Tensor '{self.name}' is not deployed yet")
        return self._handle.shards[0].pa

    @property
    def va(self) -> int:
        """VA base address for the entire tensor.

        Raises:
            RuntimeError: if the tensor has no handle.
        """
        if self._handle is None:
            raise RuntimeError(f"Tensor '{self.name}' is not deployed yet")
        return self._handle.va_base

    def to(
        self,
        placement: list[ShardSpec] | None = None,
        *,
        dp_policy: DPPolicy | None = None,
        sip: int = 0,
        cube: int = 0,
        target_pe: int | str = 0,
    ) -> Tensor:
        """Set DP placement metadata (like torch.Tensor.to()).

        When ``placement`` is omitted, a single shard on ``pe_index`` 0 that
        covers the whole tensor is used. Returns ``self`` so calls can be
        chained.
        """
        if placement is None:
            placement = [ShardSpec(pe_index=0, offset_bytes=0, nbytes=self.nbytes)]
        self._dp_metadata = DPMetadata(
            placement=placement, dp_policy=dp_policy,
            sip=sip, cube=cube, target_pe=target_pe,
        )
        return self

    def to_tensor_arg(self) -> TensorArg:
        """Convert deployed shards to KernelLaunchMsg TensorArg.

        Raises:
            RuntimeError: if the tensor has no handle.
        """
        if self._handle is None:
            raise RuntimeError(f"Tensor '{self.name}' is not deployed yet")
        return TensorArg(
            shards=tuple(
                TensorArgShard(
                    sip=s.sip, cube=s.cube, pe=s.pe,
                    pa=s.pa, nbytes=s.nbytes, offset_bytes=s.offset_bytes,
                )
                for s in self._handle.shards
            ),
            va_base=self._handle.va_base,
        )
|