Add virtual memory support: PE_MMU, VA allocator, fabric MmuMapMsg
Implement VA/MMU layer (ADR-0011 Phase 1), enabling Triton kernels to use contiguous virtual addresses on sharded tensors.

Key changes:
- PE_MMU component: hybrid inbox (MmuMapMsg) + sync translate() for PE_DMA
- VirtualAllocator + PEMemAllocator: free-list with coalescing
- MmuMapMsg/MmuUnmapMsg fabric path with SIP-level routing
- DPPolicy-based mapping: replicate=local, sharded=broadcast
- Tensor lifecycle: __del__ + weakref cleanup, context manager
- Rename: TensorHandle.pa→addr, DmaReadCmd.src_pa→src_addr, ctx→torch

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import weakref
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
@@ -26,6 +27,7 @@ class TensorHandle:
|
||||
dtype: str
|
||||
itemsize: int
|
||||
shards: tuple[TensorShard, ...]
|
||||
va_base: int = 0 # VA base address for the entire tensor
|
||||
|
||||
@property
|
||||
def nbytes(self) -> int:
|
||||
@@ -56,8 +58,19 @@ def deploy_tensor(
|
||||
placement: list[ShardSpec],
|
||||
allocators: dict[int, PEMemAllocator],
|
||||
mem_kind: Literal["hbm", "tcm"] = "hbm",
|
||||
va_allocator=None,
|
||||
mmus: dict | None = None,
|
||||
) -> TensorHandle:
|
||||
from kernbench.policy.address.pe_mmu import PeMMU
|
||||
|
||||
isize = dtype_itemsize(dtype)
|
||||
total_nbytes = math.prod(shape) * isize
|
||||
|
||||
# Allocate VA range for the entire tensor (if VA allocator provided)
|
||||
va_base = 0
|
||||
if va_allocator is not None:
|
||||
va_base = va_allocator.alloc(total_nbytes)
|
||||
|
||||
shards: list[TensorShard] = []
|
||||
for spec in placement:
|
||||
alloc = allocators[spec.pe_index]
|
||||
@@ -65,20 +78,29 @@ def deploy_tensor(
|
||||
pa = alloc.alloc_hbm(spec.nbytes)
|
||||
else:
|
||||
pa = alloc.alloc_tcm(spec.nbytes)
|
||||
encoded_pa = pa.encode()
|
||||
shards.append(TensorShard(
|
||||
sip=alloc._sip_id,
|
||||
cube=alloc._cube_id,
|
||||
pe=alloc._pe_id,
|
||||
pa=pa.encode(),
|
||||
pa=encoded_pa,
|
||||
nbytes=spec.nbytes,
|
||||
offset_bytes=spec.offset_bytes,
|
||||
))
|
||||
|
||||
# Register VA→PA mapping in all MMUs (broadcast)
|
||||
if va_base and mmus is not None:
|
||||
shard_va = va_base + spec.offset_bytes
|
||||
for mmu in mmus.values():
|
||||
mmu.map(va=shard_va, pa=encoded_pa, size=spec.nbytes)
|
||||
|
||||
return TensorHandle(
|
||||
name=name,
|
||||
shape=shape,
|
||||
dtype=dtype,
|
||||
itemsize=isize,
|
||||
shards=tuple(shards),
|
||||
va_base=va_base,
|
||||
)
|
||||
|
||||
|
||||
@@ -101,8 +123,7 @@ class Tensor:
|
||||
|
||||
Usage::
|
||||
|
||||
a = ctx.zeros((M, K), dtype="f16")
|
||||
a = ctx.zeros((M, K), dtype="f16", placement=dp.replicate(num_pe=8))
|
||||
a = ctx.zeros((M, K), dtype="f16", dp=DPPolicy(cube="replicate", pe="replicate"))
|
||||
ctx.launch("kernel_name", kernel_fn, a, b, out, M=M, K=K)
|
||||
"""
|
||||
|
||||
@@ -117,6 +138,14 @@ class Tensor:
|
||||
self.name = name
|
||||
self._dp_metadata: DPMetadata | None = None
|
||||
self._handle: TensorHandle | None = None
|
||||
self._ctx_ref: weakref.ref | None = None # set by RuntimeContext
|
||||
|
||||
def __del__(self) -> None:
    """Ask the owning context to free this tensor when it is collected.

    Does nothing when the tensor was never attached to a context, was
    never deployed (no handle), or the context itself is already gone.
    """
    # __del__ also runs on instances whose __init__ raised part-way, in
    # which case these attributes may not exist yet; getattr keeps the
    # finalizer from raising AttributeError on such objects.
    ctx_ref = getattr(self, "_ctx_ref", None)
    if ctx_ref is None or getattr(self, "_handle", None) is None:
        return
    ctx = ctx_ref()  # weak reference: yields None once the context is dead
    if ctx is not None:
        ctx._free_tensor(self)
|
||||
|
||||
@property
|
||||
def itemsize(self) -> int:
|
||||
@@ -133,6 +162,13 @@ class Tensor:
|
||||
raise RuntimeError(f"Tensor '{self.name}' is not deployed yet")
|
||||
return self._handle.shards[0].pa
|
||||
|
||||
@property
def va(self) -> int:
    """Return the base virtual address covering the whole tensor.

    Raises:
        RuntimeError: if the tensor has no handle yet (not deployed).
    """
    handle = self._handle
    if handle is not None:
        return handle.va_base
    raise RuntimeError(f"Tensor '{self.name}' is not deployed yet")
|
||||
|
||||
def to(
|
||||
self,
|
||||
placement: list[ShardSpec] | None = None,
|
||||
@@ -163,4 +199,5 @@ class Tensor:
|
||||
)
|
||||
for s in self._handle.shards
|
||||
),
|
||||
va_base=self._handle.va_base,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user