Add virtual memory support: PE_MMU, VA allocator, fabric MmuMapMsg

Implement VA/MMU layer (ADR-0011 Phase 1) enabling Triton kernels to use
contiguous virtual addresses on sharded tensors.

Key changes:
- PE_MMU component: hybrid inbox (MmuMapMsg) + sync translate() for PE_DMA
- VirtualAllocator + PEMemAllocator: free-list with coalescing
- MmuMapMsg/MmuUnmapMsg fabric path with SIP-level routing
- DPPolicy-based mapping: replicate=local, sharded=broadcast
- Tensor lifecycle: del + weakref cleanup, context manager
- Rename: TensorHandle.pa→addr, DmaReadCmd.src_pa→src_addr, ctx→torch

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-26 00:01:47 -07:00
parent 62fb01ae18
commit 08812eda58
34 changed files with 2131 additions and 139 deletions
+92
View File
@@ -98,6 +98,16 @@ class GraphEngine:
self._components[node_id].in_ports["host"] = host_q
self._pe_dma_queues[node_id] = host_q
# Wire PE_DMA._mmu to PE_MMU's underlying PeMMU utility object
for node_id, node in graph.nodes.items():
if node.kind == "pe_dma":
# Derive PE_MMU node ID from PE_DMA node ID
pe_prefix = node_id.rsplit(".", 1)[0] # e.g. "sip0.cube0.pe0"
mmu_id = f"{pe_prefix}.pe_mmu"
mmu_comp = self._components.get(mmu_id)
if mmu_comp is not None and hasattr(mmu_comp, "mmu"):
self._components[node_id]._mmu = mmu_comp.mmu
# Start components after all ports are wired (ADR-0015 D3)
for comp in self._components.values():
comp.start(self._env)
@@ -119,6 +129,27 @@ class GraphEngine:
def get_completion(self, handle: RequestHandle) -> tuple[Completion, Trace | None]:
return self._results[str(handle)]
def mmu_map(self, va: int, pa: int, size: int) -> None:
"""Sideband: install VA→PA mapping in all PE_MMU components."""
for node_id, comp in self._components.items():
if hasattr(comp, "mmu"):
comp.mmu.map(va=va, pa=pa, size=size)
def mmu_map_pe(
self, sip: int, cube: int, pe: int, va: int, pa: int, size: int,
) -> None:
"""Sideband: install VA→PA mapping in a specific PE's MMU only."""
mmu_id = f"sip{sip}.cube{cube}.pe{pe}.pe_mmu"
comp = self._components.get(mmu_id)
if comp is not None and hasattr(comp, "mmu"):
comp.mmu.map(va=va, pa=pa, size=size)
def mmu_unmap(self, va: int, size: int) -> None:
"""Sideband: remove VA mapping from all PE_MMU components."""
for node_id, comp in self._components.items():
if hasattr(comp, "mmu"):
comp.mmu.unmap(va=va, size=size)
# ── internal ────────────────────────────────────────────────────
def _wire(
@@ -166,6 +197,11 @@ class GraphEngine:
yield from self._process_memory_direct(key, request, done)
return
from kernbench.runtime_api.kernel import MmuMapMsg, MmuUnmapMsg
if isinstance(request, (MmuMapMsg, MmuUnmapMsg)):
yield from self._process_mmu_msg(key, request, done)
return
entries = self._entry_points(request)
if not entries:
self._results[key] = (
@@ -341,3 +377,59 @@ class GraphEngine:
return entries
raise ValueError(f"unsupported request type: {type(request)}")
def _process_mmu_msg(self, key: str, request: Any, done: simpy.Event):
"""Route MmuMapMsg/MmuUnmapMsg through fabric like KernelLaunchMsg.
Path: Host → PCIE_EP → IO_NOC → IO_CPU → (fan-out) → M_CPU → (fan-out) → PE_MMU
"""
start_ns = self._env.now
target_sips = getattr(request, "target_sips", "all")
# Determine target SIPs
sip_set: set[int] = set()
if target_sips == "all":
for ep_id in self._resolver.find_all_pcie_eps():
sip_id = int(ep_id.split(".")[0].replace("sip", ""))
sip_set.add(sip_id)
else:
sip_set = set(target_sips)
entries = []
for sip_id in sorted(sip_set):
entries.append((
self._resolver.find_pcie_ep(sip_id),
self._resolver.find_io_cpu(sip_id),
0, # MmuMapMsg has no data payload
))
if not entries:
self._results[key] = (Completion(ok=True), {"total_ns": 0.0})
done.succeed()
return
if len(entries) == 1:
pcie_ep_id, io_cpu_id, _ = entries[0]
path = self._router.find_node_path(pcie_ep_id, io_cpu_id)
txn_done = self._env.event()
txn = Transaction(request=request, path=path, step=0, nbytes=0, done=txn_done)
yield self._host_queues[pcie_ep_id].put(txn)
yield txn_done
else:
# Multi-SIP fan-out
sub_dones = []
for pcie_ep_id, io_cpu_id, _ in entries:
path = self._router.find_node_path(pcie_ep_id, io_cpu_id)
sub_done = self._env.event()
sub_txn = Transaction(request=request, path=path, step=0, nbytes=0, done=sub_done)
yield self._host_queues[pcie_ep_id].put(sub_txn)
sub_dones.append(sub_done)
for sd in sub_dones:
yield sd
elapsed = self._env.now - start_ns
self._results[key] = (
Completion(ok=True),
{"total_ns": elapsed, "msg_type": request.msg_type},
)
done.succeed()