Add virtual memory support: PE_MMU, VA allocator, fabric MmuMapMsg

Implement VA/MMU layer (ADR-0011 Phase 1) enabling Triton kernels to use
contiguous virtual addresses on sharded tensors.

Key changes:
- PE_MMU component: hybrid inbox (MmuMapMsg) + sync translate() for PE_DMA
- VirtualAllocator + PEMemAllocator: free-list with coalescing
- MmuMapMsg/MmuUnmapMsg fabric path with SIP-level routing
- DPPolicy-based mapping: replicate=local, sharded=broadcast
- Tensor lifecycle: del + weakref cleanup, context manager
- Rename: TensorHandle.pa→addr, DmaReadCmd.src_pa→src_addr, ctx→torch

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-26 00:01:47 -07:00
parent 62fb01ae18
commit 08812eda58
34 changed files with 2131 additions and 139 deletions
+188 -22
View File
@@ -19,8 +19,18 @@ class RuntimeContext:
_handles: list[RequestHandle] = field(default_factory=list, init=False)
_completed: set[RequestHandle] = field(default_factory=set, init=False)
_allocators: dict[int, Any] = field(default_factory=dict, init=False)
_va_allocator: Any = field(default=None, init=False)
_mmus: dict[int, Any] = field(default_factory=dict, init=False)
_tensor_counter: int = field(default=0, init=False)
_traces: list[dict] = field(default_factory=list, init=False)
_tensors: list[Any] = field(default_factory=list, init=False)
def __enter__(self):
return self
def __exit__(self, *exc):
self.cleanup()
return False
def submit(self, request: Any) -> RequestHandle:
submit_fn = getattr(self.engine, "submit", None)
@@ -58,6 +68,92 @@ class RuntimeContext:
def handles(self) -> list[RequestHandle]:
return list(self._handles)
# ── Tensor lifecycle ─────────────────────────────────────────────
def _free_tensor(self, tensor: Any) -> None:
"""Free a single tensor: unmap MMU, return VA and PA."""
handle = tensor._handle
if handle is None:
return
tensor._handle = None
if not handle.va_base:
return
from kernbench.runtime_api.kernel import MmuUnmapMsg
dp_policy = None
if tensor._dp_metadata is not None:
dp_policy = tensor._dp_metadata.dp_policy
is_cube_replicate = (
dp_policy is not None and dp_policy.cube == "replicate"
)
# Send MmuUnmapMsg through fabric
from collections import defaultdict
if is_cube_replicate:
cube_groups: dict[tuple[int, int], list] = defaultdict(list)
for shard in handle.shards:
cube_groups[(shard.sip, shard.cube)].append(shard)
for (sip, cube), group_shards in cube_groups.items():
entries = tuple(
{"va": handle.va_base + s.offset_bytes, "size": s.nbytes}
for s in group_shards
)
msg = MmuUnmapMsg(
correlation_id=self.correlation_id,
request_id=f"unmap_{tensor.name}_s{sip}c{cube}",
entries=entries,
target_sips=(sip,),
target_cubes=(cube,),
target_pe="all",
)
h = self.submit(msg)
self.wait(h)
else:
entries = tuple(
{"va": handle.va_base + s.offset_bytes, "size": s.nbytes}
for s in handle.shards
)
sip_set = sorted({s.sip for s in handle.shards})
cube_set = sorted({s.cube for s in handle.shards})
msg = MmuUnmapMsg(
correlation_id=self.correlation_id,
request_id=f"unmap_{tensor.name}",
entries=entries,
target_sips=tuple(sip_set),
target_cubes=tuple(cube_set),
target_pe="all",
)
h = self.submit(msg)
self.wait(h)
# Return VA space
if self._va_allocator is not None:
self._va_allocator.free(handle.va_base, handle.nbytes)
# Return PA space
if self._allocators:
for shard in handle.shards:
flat_idx = (
shard.sip * self._num_cubes * self._pes_per_cube
+ shard.cube * self._pes_per_cube
+ shard.pe
)
alloc = self._allocators.get(flat_idx)
if alloc is not None:
from kernbench.policy.address.phyaddr import PhysAddr
alloc.free_hbm(PhysAddr.decode(shard.pa), shard.nbytes)
def cleanup(self) -> None:
"""Free all tensors created by this context."""
for ref in self._tensors:
t = ref()
if t is not None and t._handle is not None:
self._free_tensor(t)
self._tensors.clear()
# ── PyTorch-like tensor API ──────────────────────────────────────
def _ensure_allocators(self) -> dict:
@@ -111,6 +207,26 @@ class RuntimeContext:
self._allocators[flat_idx] = PEMemAllocator(
rack_id=0, sip_id=sip_id, cube_id=cube_id, pe_id=pe_id, cfg=cfg,
)
# Initialize VA allocator and per-PE MMUs
from kernbench.policy.address.pe_mmu import PeMMU
from kernbench.policy.address.va_allocator import VirtualAllocator
pe_mmu_attrs = pe_comps.get("pe_mmu", {}).get("attrs", {})
page_size = int(pe_mmu_attrs.get("page_size", 4096))
tlb_overhead_ns = float(pe_mmu_attrs.get("tlb_overhead_ns", 0.0))
self._va_allocator = VirtualAllocator(
va_base=0x1_0000_0000,
va_size=64 * (1 << 30), # 64 GB VA space
page_size=page_size,
)
total_pes = sip_count * cubes_per_sip * pes_per_cube
for flat_idx in range(total_pes):
self._mmus[flat_idx] = PeMMU(
page_size=page_size, overhead_ns=tlb_overhead_ns,
)
return self._allocators
def _next_tensor_name(self) -> str:
@@ -122,63 +238,57 @@ class RuntimeContext:
shape: tuple[int, ...],
dtype: str = "f16",
*,
placement: list | None = None,
dp: Any = None,
name: str | None = None,
):
"""Create a tensor and deploy to HBM with zero-fill (like torch.zeros)."""
return self._create_tensor(shape, dtype, placement, name, pattern="zero", dp=dp)
return self._create_tensor(shape, dtype, name, pattern="zero", dp=dp)
def empty(
self,
shape: tuple[int, ...],
dtype: str = "f16",
*,
placement: list | None = None,
dp: Any = None,
name: str | None = None,
):
"""Allocate a tensor in HBM without initialization (like torch.empty)."""
return self._create_tensor(shape, dtype, placement, name, pattern=None, dp=dp)
return self._create_tensor(shape, dtype, name, pattern=None, dp=dp)
def _create_tensor(
self,
shape: tuple[int, ...],
dtype: str,
placement: list | None,
name: str | None,
pattern: str | None,
dp: Any = None,
):
from kernbench.policy.placement.dp import DPPolicy, ShardSpec, resolve_dp_policy
from kernbench.policy.placement.dp import DPPolicy, resolve_dp_policy
from kernbench.runtime_api.kernel import MemoryWriteMsg
from kernbench.runtime_api.tensor import Tensor, deploy_tensor, dtype_itemsize
if not isinstance(dp, DPPolicy):
raise ValueError("dp=DPPolicy(...) is required for tensor creation")
tensor_name = name or self._next_tensor_name()
t = Tensor(shape=shape, dtype=dtype, name=tensor_name)
dp_policy: DPPolicy | None = None
# Resolve placement: dp= takes priority over placement=
if dp is not None and isinstance(dp, DPPolicy):
dp_policy = dp
allocators = self._ensure_allocators()
itemsize = dtype_itemsize(dtype)
shape_2d = (shape[0], shape[1]) # type: tuple[int, int]
total_cubes = self._num_sips * self._num_cubes
placement = resolve_dp_policy(
dp, shape=shape_2d, itemsize=itemsize,
num_pe=self._pes_per_cube, num_cubes=total_cubes,
)
elif placement is None:
placement = [ShardSpec(pe_index=0, offset_bytes=0, nbytes=t.nbytes)]
dp_policy = dp
allocators = self._ensure_allocators()
itemsize = dtype_itemsize(dtype)
shape_2d = (shape[0], shape[1]) # type: tuple[int, int]
total_cubes = self._num_sips * self._num_cubes
placement = resolve_dp_policy(
dp, shape=shape_2d, itemsize=itemsize,
num_pe=self._pes_per_cube, num_cubes=total_cubes,
)
# Infer target_pe from placement: multi-PE → "all", single PE → pe_index
pe_indices = {s.pe_index for s in placement}
target_pe: int | str = "all" if len(pe_indices) > 1 else next(iter(pe_indices))
t.to(placement=placement, target_pe=target_pe, dp_policy=dp_policy)
# Allocate PAs via PEMemAllocator
# Allocate PAs via PEMemAllocator + VA via VirtualAllocator
allocators = self._ensure_allocators()
handle = deploy_tensor(
name=tensor_name,
@@ -186,8 +296,64 @@ class RuntimeContext:
dtype=dtype,
placement=placement,
allocators=allocators,
va_allocator=self._va_allocator,
mmus=self._mmus,
)
t._handle = handle
import weakref
t._ctx_ref = weakref.ref(self)
self._tensors.append(weakref.ref(t))
# Install VA→PA mappings via fabric MmuMapMsg
if handle.va_base:
from collections import defaultdict
from kernbench.runtime_api.kernel import MmuMapMsg
is_cube_replicate = (
dp_policy is not None and dp_policy.cube == "replicate"
)
if is_cube_replicate:
# Replicate: each (sip, cube) gets only its own local PA mappings
cube_groups: dict[tuple[int, int], list] = defaultdict(list)
for shard in handle.shards:
cube_groups[(shard.sip, shard.cube)].append(shard)
for (sip, cube), group_shards in cube_groups.items():
entries = tuple(
{"va": handle.va_base + s.offset_bytes,
"pa": s.pa, "size": s.nbytes}
for s in group_shards
)
msg = MmuMapMsg(
correlation_id=self.correlation_id,
request_id=f"mmu_{tensor_name}_s{sip}c{cube}",
entries=entries,
target_sips=(sip,),
target_cubes=(cube,),
target_pe="all",
)
h = self.submit(msg)
self.wait(h)
else:
# Sharded: broadcast all mappings to all target (sip, cube)s
entries = tuple(
{"va": handle.va_base + s.offset_bytes,
"pa": s.pa, "size": s.nbytes}
for s in handle.shards
)
sip_set = sorted({s.sip for s in handle.shards})
cube_set = sorted({s.cube for s in handle.shards})
msg = MmuMapMsg(
correlation_id=self.correlation_id,
request_id=f"mmu_{tensor_name}",
entries=entries,
target_sips=tuple(sip_set),
target_cubes=tuple(cube_set),
target_pe="all",
)
h = self.submit(msg)
self.wait(h)
# Submit MemoryWriteMsg per shard (deploy data to device)
if pattern is not None: