Add virtual memory support: PE_MMU, VA allocator, fabric MmuMapMsg
Implement VA/MMU layer (ADR-0011 Phase 1) enabling Triton kernels to use contiguous virtual addresses on sharded tensors. Key changes: - PE_MMU component: hybrid inbox (MmuMapMsg) + sync translate() for PE_DMA - VirtualAllocator + PEMemAllocator: free-list with coalescing - MmuMapMsg/MmuUnmapMsg fabric path with SIP-level routing - DPPolicy-based mapping: replicate=local, sharded=broadcast - Tensor lifecycle: del + weakref cleanup, context manager - Rename: TensorHandle.pa→addr, DmaReadCmd.src_pa→src_addr, ctx→torch Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -19,8 +19,18 @@ class RuntimeContext:
|
||||
_handles: list[RequestHandle] = field(default_factory=list, init=False)
|
||||
_completed: set[RequestHandle] = field(default_factory=set, init=False)
|
||||
_allocators: dict[int, Any] = field(default_factory=dict, init=False)
|
||||
_va_allocator: Any = field(default=None, init=False)
|
||||
_mmus: dict[int, Any] = field(default_factory=dict, init=False)
|
||||
_tensor_counter: int = field(default=0, init=False)
|
||||
_traces: list[dict] = field(default_factory=list, init=False)
|
||||
_tensors: list[Any] = field(default_factory=list, init=False)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc):
|
||||
self.cleanup()
|
||||
return False
|
||||
|
||||
def submit(self, request: Any) -> RequestHandle:
|
||||
submit_fn = getattr(self.engine, "submit", None)
|
||||
@@ -58,6 +68,92 @@ class RuntimeContext:
|
||||
def handles(self) -> list[RequestHandle]:
|
||||
return list(self._handles)
|
||||
|
||||
# ── Tensor lifecycle ─────────────────────────────────────────────
|
||||
|
||||
def _free_tensor(self, tensor: Any) -> None:
|
||||
"""Free a single tensor: unmap MMU, return VA and PA."""
|
||||
handle = tensor._handle
|
||||
if handle is None:
|
||||
return
|
||||
tensor._handle = None
|
||||
|
||||
if not handle.va_base:
|
||||
return
|
||||
|
||||
from kernbench.runtime_api.kernel import MmuUnmapMsg
|
||||
|
||||
dp_policy = None
|
||||
if tensor._dp_metadata is not None:
|
||||
dp_policy = tensor._dp_metadata.dp_policy
|
||||
|
||||
is_cube_replicate = (
|
||||
dp_policy is not None and dp_policy.cube == "replicate"
|
||||
)
|
||||
|
||||
# Send MmuUnmapMsg through fabric
|
||||
from collections import defaultdict
|
||||
if is_cube_replicate:
|
||||
cube_groups: dict[tuple[int, int], list] = defaultdict(list)
|
||||
for shard in handle.shards:
|
||||
cube_groups[(shard.sip, shard.cube)].append(shard)
|
||||
for (sip, cube), group_shards in cube_groups.items():
|
||||
entries = tuple(
|
||||
{"va": handle.va_base + s.offset_bytes, "size": s.nbytes}
|
||||
for s in group_shards
|
||||
)
|
||||
msg = MmuUnmapMsg(
|
||||
correlation_id=self.correlation_id,
|
||||
request_id=f"unmap_{tensor.name}_s{sip}c{cube}",
|
||||
entries=entries,
|
||||
target_sips=(sip,),
|
||||
target_cubes=(cube,),
|
||||
target_pe="all",
|
||||
)
|
||||
h = self.submit(msg)
|
||||
self.wait(h)
|
||||
else:
|
||||
entries = tuple(
|
||||
{"va": handle.va_base + s.offset_bytes, "size": s.nbytes}
|
||||
for s in handle.shards
|
||||
)
|
||||
sip_set = sorted({s.sip for s in handle.shards})
|
||||
cube_set = sorted({s.cube for s in handle.shards})
|
||||
msg = MmuUnmapMsg(
|
||||
correlation_id=self.correlation_id,
|
||||
request_id=f"unmap_{tensor.name}",
|
||||
entries=entries,
|
||||
target_sips=tuple(sip_set),
|
||||
target_cubes=tuple(cube_set),
|
||||
target_pe="all",
|
||||
)
|
||||
h = self.submit(msg)
|
||||
self.wait(h)
|
||||
|
||||
# Return VA space
|
||||
if self._va_allocator is not None:
|
||||
self._va_allocator.free(handle.va_base, handle.nbytes)
|
||||
|
||||
# Return PA space
|
||||
if self._allocators:
|
||||
for shard in handle.shards:
|
||||
flat_idx = (
|
||||
shard.sip * self._num_cubes * self._pes_per_cube
|
||||
+ shard.cube * self._pes_per_cube
|
||||
+ shard.pe
|
||||
)
|
||||
alloc = self._allocators.get(flat_idx)
|
||||
if alloc is not None:
|
||||
from kernbench.policy.address.phyaddr import PhysAddr
|
||||
alloc.free_hbm(PhysAddr.decode(shard.pa), shard.nbytes)
|
||||
|
||||
def cleanup(self) -> None:
|
||||
"""Free all tensors created by this context."""
|
||||
for ref in self._tensors:
|
||||
t = ref()
|
||||
if t is not None and t._handle is not None:
|
||||
self._free_tensor(t)
|
||||
self._tensors.clear()
|
||||
|
||||
# ── PyTorch-like tensor API ──────────────────────────────────────
|
||||
|
||||
def _ensure_allocators(self) -> dict:
|
||||
@@ -111,6 +207,26 @@ class RuntimeContext:
|
||||
self._allocators[flat_idx] = PEMemAllocator(
|
||||
rack_id=0, sip_id=sip_id, cube_id=cube_id, pe_id=pe_id, cfg=cfg,
|
||||
)
|
||||
|
||||
# Initialize VA allocator and per-PE MMUs
|
||||
from kernbench.policy.address.pe_mmu import PeMMU
|
||||
from kernbench.policy.address.va_allocator import VirtualAllocator
|
||||
|
||||
pe_mmu_attrs = pe_comps.get("pe_mmu", {}).get("attrs", {})
|
||||
page_size = int(pe_mmu_attrs.get("page_size", 4096))
|
||||
tlb_overhead_ns = float(pe_mmu_attrs.get("tlb_overhead_ns", 0.0))
|
||||
|
||||
self._va_allocator = VirtualAllocator(
|
||||
va_base=0x1_0000_0000,
|
||||
va_size=64 * (1 << 30), # 64 GB VA space
|
||||
page_size=page_size,
|
||||
)
|
||||
total_pes = sip_count * cubes_per_sip * pes_per_cube
|
||||
for flat_idx in range(total_pes):
|
||||
self._mmus[flat_idx] = PeMMU(
|
||||
page_size=page_size, overhead_ns=tlb_overhead_ns,
|
||||
)
|
||||
|
||||
return self._allocators
|
||||
|
||||
def _next_tensor_name(self) -> str:
|
||||
@@ -122,63 +238,57 @@ class RuntimeContext:
|
||||
shape: tuple[int, ...],
|
||||
dtype: str = "f16",
|
||||
*,
|
||||
placement: list | None = None,
|
||||
dp: Any = None,
|
||||
name: str | None = None,
|
||||
):
|
||||
"""Create a tensor and deploy to HBM with zero-fill (like torch.zeros)."""
|
||||
return self._create_tensor(shape, dtype, placement, name, pattern="zero", dp=dp)
|
||||
return self._create_tensor(shape, dtype, name, pattern="zero", dp=dp)
|
||||
|
||||
def empty(
|
||||
self,
|
||||
shape: tuple[int, ...],
|
||||
dtype: str = "f16",
|
||||
*,
|
||||
placement: list | None = None,
|
||||
dp: Any = None,
|
||||
name: str | None = None,
|
||||
):
|
||||
"""Allocate a tensor in HBM without initialization (like torch.empty)."""
|
||||
return self._create_tensor(shape, dtype, placement, name, pattern=None, dp=dp)
|
||||
return self._create_tensor(shape, dtype, name, pattern=None, dp=dp)
|
||||
|
||||
def _create_tensor(
|
||||
self,
|
||||
shape: tuple[int, ...],
|
||||
dtype: str,
|
||||
placement: list | None,
|
||||
name: str | None,
|
||||
pattern: str | None,
|
||||
dp: Any = None,
|
||||
):
|
||||
from kernbench.policy.placement.dp import DPPolicy, ShardSpec, resolve_dp_policy
|
||||
from kernbench.policy.placement.dp import DPPolicy, resolve_dp_policy
|
||||
from kernbench.runtime_api.kernel import MemoryWriteMsg
|
||||
from kernbench.runtime_api.tensor import Tensor, deploy_tensor, dtype_itemsize
|
||||
|
||||
if not isinstance(dp, DPPolicy):
|
||||
raise ValueError("dp=DPPolicy(...) is required for tensor creation")
|
||||
|
||||
tensor_name = name or self._next_tensor_name()
|
||||
t = Tensor(shape=shape, dtype=dtype, name=tensor_name)
|
||||
|
||||
dp_policy: DPPolicy | None = None
|
||||
|
||||
# Resolve placement: dp= takes priority over placement=
|
||||
if dp is not None and isinstance(dp, DPPolicy):
|
||||
dp_policy = dp
|
||||
allocators = self._ensure_allocators()
|
||||
itemsize = dtype_itemsize(dtype)
|
||||
shape_2d = (shape[0], shape[1]) # type: tuple[int, int]
|
||||
total_cubes = self._num_sips * self._num_cubes
|
||||
placement = resolve_dp_policy(
|
||||
dp, shape=shape_2d, itemsize=itemsize,
|
||||
num_pe=self._pes_per_cube, num_cubes=total_cubes,
|
||||
)
|
||||
elif placement is None:
|
||||
placement = [ShardSpec(pe_index=0, offset_bytes=0, nbytes=t.nbytes)]
|
||||
dp_policy = dp
|
||||
allocators = self._ensure_allocators()
|
||||
itemsize = dtype_itemsize(dtype)
|
||||
shape_2d = (shape[0], shape[1]) # type: tuple[int, int]
|
||||
total_cubes = self._num_sips * self._num_cubes
|
||||
placement = resolve_dp_policy(
|
||||
dp, shape=shape_2d, itemsize=itemsize,
|
||||
num_pe=self._pes_per_cube, num_cubes=total_cubes,
|
||||
)
|
||||
|
||||
# Infer target_pe from placement: multi-PE → "all", single PE → pe_index
|
||||
pe_indices = {s.pe_index for s in placement}
|
||||
target_pe: int | str = "all" if len(pe_indices) > 1 else next(iter(pe_indices))
|
||||
t.to(placement=placement, target_pe=target_pe, dp_policy=dp_policy)
|
||||
|
||||
# Allocate PAs via PEMemAllocator
|
||||
# Allocate PAs via PEMemAllocator + VA via VirtualAllocator
|
||||
allocators = self._ensure_allocators()
|
||||
handle = deploy_tensor(
|
||||
name=tensor_name,
|
||||
@@ -186,8 +296,64 @@ class RuntimeContext:
|
||||
dtype=dtype,
|
||||
placement=placement,
|
||||
allocators=allocators,
|
||||
va_allocator=self._va_allocator,
|
||||
mmus=self._mmus,
|
||||
)
|
||||
t._handle = handle
|
||||
import weakref
|
||||
t._ctx_ref = weakref.ref(self)
|
||||
self._tensors.append(weakref.ref(t))
|
||||
|
||||
# Install VA→PA mappings via fabric MmuMapMsg
|
||||
if handle.va_base:
|
||||
from collections import defaultdict
|
||||
from kernbench.runtime_api.kernel import MmuMapMsg
|
||||
|
||||
is_cube_replicate = (
|
||||
dp_policy is not None and dp_policy.cube == "replicate"
|
||||
)
|
||||
|
||||
if is_cube_replicate:
|
||||
# Replicate: each (sip, cube) gets only its own local PA mappings
|
||||
cube_groups: dict[tuple[int, int], list] = defaultdict(list)
|
||||
for shard in handle.shards:
|
||||
cube_groups[(shard.sip, shard.cube)].append(shard)
|
||||
|
||||
for (sip, cube), group_shards in cube_groups.items():
|
||||
entries = tuple(
|
||||
{"va": handle.va_base + s.offset_bytes,
|
||||
"pa": s.pa, "size": s.nbytes}
|
||||
for s in group_shards
|
||||
)
|
||||
msg = MmuMapMsg(
|
||||
correlation_id=self.correlation_id,
|
||||
request_id=f"mmu_{tensor_name}_s{sip}c{cube}",
|
||||
entries=entries,
|
||||
target_sips=(sip,),
|
||||
target_cubes=(cube,),
|
||||
target_pe="all",
|
||||
)
|
||||
h = self.submit(msg)
|
||||
self.wait(h)
|
||||
else:
|
||||
# Sharded: broadcast all mappings to all target (sip, cube)s
|
||||
entries = tuple(
|
||||
{"va": handle.va_base + s.offset_bytes,
|
||||
"pa": s.pa, "size": s.nbytes}
|
||||
for s in handle.shards
|
||||
)
|
||||
sip_set = sorted({s.sip for s in handle.shards})
|
||||
cube_set = sorted({s.cube for s in handle.shards})
|
||||
msg = MmuMapMsg(
|
||||
correlation_id=self.correlation_id,
|
||||
request_id=f"mmu_{tensor_name}",
|
||||
entries=entries,
|
||||
target_sips=tuple(sip_set),
|
||||
target_cubes=tuple(cube_set),
|
||||
target_pe="all",
|
||||
)
|
||||
h = self.submit(msg)
|
||||
self.wait(h)
|
||||
|
||||
# Submit MemoryWriteMsg per shard (deploy data to device)
|
||||
if pattern is not None:
|
||||
|
||||
@@ -69,6 +69,7 @@ class TensorArgShard:
|
||||
class TensorArg:
|
||||
shards: tuple[TensorArgShard, ...]
|
||||
arg_kind: Literal["tensor"] = "tensor"
|
||||
va_base: int = 0 # VA base address for the entire tensor
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -121,3 +122,33 @@ class PeDmaMsg:
|
||||
nbytes: int
|
||||
is_write: bool = False
|
||||
msg_type: Literal["pe_dma"] = "pe_dma"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MmuMapMsg:
|
||||
"""MMU mapping install: broadcast VA→PA entries to target PEs.
|
||||
|
||||
Sent via fabric: Host → PCIE_EP → IO_CPU → M_CPU → NOC → PE_MMU.
|
||||
target_sips controls which SIPs receive the message.
|
||||
"""
|
||||
|
||||
correlation_id: str
|
||||
request_id: str
|
||||
entries: tuple[dict, ...] # ({"va": int, "pa": int, "size": int}, ...)
|
||||
target_sips: tuple[int, ...] | Literal["all"] = "all"
|
||||
target_cubes: tuple[int, ...] | Literal["all"] = "all"
|
||||
target_pe: int | Literal["all"] = "all"
|
||||
msg_type: Literal["mmu_map"] = "mmu_map"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MmuUnmapMsg:
|
||||
"""MMU mapping removal: broadcast VA ranges to unmap from all PEs."""
|
||||
|
||||
correlation_id: str
|
||||
request_id: str
|
||||
entries: tuple[dict, ...] # ({"va": int, "size": int}, ...)
|
||||
target_sips: tuple[int, ...] | Literal["all"] = "all"
|
||||
target_cubes: tuple[int, ...] | Literal["all"] = "all"
|
||||
target_pe: int | Literal["all"] = "all"
|
||||
msg_type: Literal["mmu_unmap"] = "mmu_unmap"
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import weakref
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
@@ -26,6 +27,7 @@ class TensorHandle:
|
||||
dtype: str
|
||||
itemsize: int
|
||||
shards: tuple[TensorShard, ...]
|
||||
va_base: int = 0 # VA base address for the entire tensor
|
||||
|
||||
@property
|
||||
def nbytes(self) -> int:
|
||||
@@ -56,8 +58,19 @@ def deploy_tensor(
|
||||
placement: list[ShardSpec],
|
||||
allocators: dict[int, PEMemAllocator],
|
||||
mem_kind: Literal["hbm", "tcm"] = "hbm",
|
||||
va_allocator=None,
|
||||
mmus: dict | None = None,
|
||||
) -> TensorHandle:
|
||||
from kernbench.policy.address.pe_mmu import PeMMU
|
||||
|
||||
isize = dtype_itemsize(dtype)
|
||||
total_nbytes = math.prod(shape) * isize
|
||||
|
||||
# Allocate VA range for the entire tensor (if VA allocator provided)
|
||||
va_base = 0
|
||||
if va_allocator is not None:
|
||||
va_base = va_allocator.alloc(total_nbytes)
|
||||
|
||||
shards: list[TensorShard] = []
|
||||
for spec in placement:
|
||||
alloc = allocators[spec.pe_index]
|
||||
@@ -65,20 +78,29 @@ def deploy_tensor(
|
||||
pa = alloc.alloc_hbm(spec.nbytes)
|
||||
else:
|
||||
pa = alloc.alloc_tcm(spec.nbytes)
|
||||
encoded_pa = pa.encode()
|
||||
shards.append(TensorShard(
|
||||
sip=alloc._sip_id,
|
||||
cube=alloc._cube_id,
|
||||
pe=alloc._pe_id,
|
||||
pa=pa.encode(),
|
||||
pa=encoded_pa,
|
||||
nbytes=spec.nbytes,
|
||||
offset_bytes=spec.offset_bytes,
|
||||
))
|
||||
|
||||
# Register VA→PA mapping in all MMUs (broadcast)
|
||||
if va_base and mmus is not None:
|
||||
shard_va = va_base + spec.offset_bytes
|
||||
for mmu in mmus.values():
|
||||
mmu.map(va=shard_va, pa=encoded_pa, size=spec.nbytes)
|
||||
|
||||
return TensorHandle(
|
||||
name=name,
|
||||
shape=shape,
|
||||
dtype=dtype,
|
||||
itemsize=isize,
|
||||
shards=tuple(shards),
|
||||
va_base=va_base,
|
||||
)
|
||||
|
||||
|
||||
@@ -101,8 +123,7 @@ class Tensor:
|
||||
|
||||
Usage::
|
||||
|
||||
a = ctx.zeros((M, K), dtype="f16")
|
||||
a = ctx.zeros((M, K), dtype="f16", placement=dp.replicate(num_pe=8))
|
||||
a = ctx.zeros((M, K), dtype="f16", dp=DPPolicy(cube="replicate", pe="replicate"))
|
||||
ctx.launch("kernel_name", kernel_fn, a, b, out, M=M, K=K)
|
||||
"""
|
||||
|
||||
@@ -117,6 +138,14 @@ class Tensor:
|
||||
self.name = name
|
||||
self._dp_metadata: DPMetadata | None = None
|
||||
self._handle: TensorHandle | None = None
|
||||
self._ctx_ref: weakref.ref | None = None # set by RuntimeContext
|
||||
|
||||
def __del__(self) -> None:
|
||||
if self._ctx_ref is None or self._handle is None:
|
||||
return
|
||||
ctx = self._ctx_ref()
|
||||
if ctx is not None:
|
||||
ctx._free_tensor(self)
|
||||
|
||||
@property
|
||||
def itemsize(self) -> int:
|
||||
@@ -133,6 +162,13 @@ class Tensor:
|
||||
raise RuntimeError(f"Tensor '{self.name}' is not deployed yet")
|
||||
return self._handle.shards[0].pa
|
||||
|
||||
@property
|
||||
def va(self) -> int:
|
||||
"""VA base address for the entire tensor."""
|
||||
if self._handle is None:
|
||||
raise RuntimeError(f"Tensor '{self.name}' is not deployed yet")
|
||||
return self._handle.va_base
|
||||
|
||||
def to(
|
||||
self,
|
||||
placement: list[ShardSpec] | None = None,
|
||||
@@ -163,4 +199,5 @@ class Tensor:
|
||||
)
|
||||
for s in self._handle.shards
|
||||
),
|
||||
va_base=self._handle.va_base,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user