Add virtual memory support: PE_MMU, VA allocator, fabric MmuMapMsg
Implement VA/MMU layer (ADR-0011 Phase 1) enabling Triton kernels to use contiguous virtual addresses on sharded tensors.

Key changes:
- PE_MMU component: hybrid inbox (MmuMapMsg) + sync translate() for PE_DMA
- VirtualAllocator + PEMemAllocator: free-list with coalescing
- MmuMapMsg/MmuUnmapMsg fabric path with SIP-level routing
- DPPolicy-based mapping: replicate=local, sharded=broadcast
- Tensor lifecycle: del + weakref cleanup, context manager
- Rename: TensorHandle.pa→addr, DmaReadCmd.src_pa→src_addr, ctx→torch

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -98,6 +98,16 @@ class GraphEngine:
|
||||
self._components[node_id].in_ports["host"] = host_q
|
||||
self._pe_dma_queues[node_id] = host_q
|
||||
|
||||
# Wire PE_DMA._mmu to PE_MMU's underlying PeMMU utility object
|
||||
for node_id, node in graph.nodes.items():
|
||||
if node.kind == "pe_dma":
|
||||
# Derive PE_MMU node ID from PE_DMA node ID
|
||||
pe_prefix = node_id.rsplit(".", 1)[0] # e.g. "sip0.cube0.pe0"
|
||||
mmu_id = f"{pe_prefix}.pe_mmu"
|
||||
mmu_comp = self._components.get(mmu_id)
|
||||
if mmu_comp is not None and hasattr(mmu_comp, "mmu"):
|
||||
self._components[node_id]._mmu = mmu_comp.mmu
|
||||
|
||||
# Start components after all ports are wired (ADR-0015 D3)
|
||||
for comp in self._components.values():
|
||||
comp.start(self._env)
|
||||
@@ -119,6 +129,27 @@ class GraphEngine:
|
||||
def get_completion(self, handle: RequestHandle) -> tuple[Completion, Trace | None]:
|
||||
return self._results[str(handle)]
|
||||
|
||||
def mmu_map(self, va: int, pa: int, size: int) -> None:
    """Sideband: install VA→PA mapping in all PE_MMU components.

    Every registered component that exposes an ``mmu`` attribute receives
    the same ``mmu.map(va=..., pa=..., size=...)`` call; components without
    that attribute are skipped.

    Args:
        va: Virtual address forwarded to each ``mmu.map``.
        pa: Physical address forwarded to each ``mmu.map``.
        size: Mapping size forwarded to each ``mmu.map``.
    """
    # Only the component objects are needed, so walk values() instead of
    # items() with an unused node-id key.
    for comp in self._components.values():
        if hasattr(comp, "mmu"):
            comp.mmu.map(va=va, pa=pa, size=size)
|
||||
|
||||
def mmu_map_pe(
    self, sip: int, cube: int, pe: int, va: int, pa: int, size: int,
) -> None:
    """Sideband: install VA→PA mapping in a specific PE's MMU only.

    The target component is looked up under the node id
    ``sip{sip}.cube{cube}.pe{pe}.pe_mmu``. If no such component is
    registered, or it has no ``mmu`` attribute, the call does nothing.

    Args:
        sip: SIP index used to build the node id.
        cube: Cube index used to build the node id.
        pe: PE index used to build the node id.
        va: Virtual address forwarded to ``mmu.map``.
        pa: Physical address forwarded to ``mmu.map``.
        size: Mapping size forwarded to ``mmu.map``.
    """
    target = self._components.get(f"sip{sip}.cube{cube}.pe{pe}.pe_mmu")
    # Guard clause: an unknown PE or a component without an MMU is ignored.
    if target is None or not hasattr(target, "mmu"):
        return
    target.mmu.map(va=va, pa=pa, size=size)
|
||||
|
||||
def mmu_unmap(self, va: int, size: int) -> None:
    """Sideband: remove VA mapping from all PE_MMU components.

    Every registered component that exposes an ``mmu`` attribute receives
    the same ``mmu.unmap(va=..., size=...)`` call; components without that
    attribute are skipped.

    Args:
        va: Virtual address forwarded to each ``mmu.unmap``.
        size: Mapping size forwarded to each ``mmu.unmap``.
    """
    # Only the component objects are needed, so walk values() instead of
    # items() with an unused node-id key.
    for comp in self._components.values():
        if hasattr(comp, "mmu"):
            comp.mmu.unmap(va=va, size=size)
|
||||
|
||||
# ── internal ────────────────────────────────────────────────────
|
||||
|
||||
def _wire(
|
||||
@@ -166,6 +197,11 @@ class GraphEngine:
|
||||
yield from self._process_memory_direct(key, request, done)
|
||||
return
|
||||
|
||||
from kernbench.runtime_api.kernel import MmuMapMsg, MmuUnmapMsg
|
||||
if isinstance(request, (MmuMapMsg, MmuUnmapMsg)):
|
||||
yield from self._process_mmu_msg(key, request, done)
|
||||
return
|
||||
|
||||
entries = self._entry_points(request)
|
||||
if not entries:
|
||||
self._results[key] = (
|
||||
@@ -341,3 +377,59 @@ class GraphEngine:
|
||||
return entries
|
||||
|
||||
raise ValueError(f"unsupported request type: {type(request)}")
|
||||
|
||||
def _process_mmu_msg(self, key: str, request: Any, done: simpy.Event):
    """Route MmuMapMsg/MmuUnmapMsg through fabric like KernelLaunchMsg.

    Path: Host → PCIE_EP → IO_NOC → IO_CPU → (fan-out) → M_CPU → (fan-out) → PE_MMU

    Only the PCIE_EP → IO_CPU node path is computed in this method; the
    onward fan-out is presumably performed by the downstream components
    (NOTE(review): confirm against the IO_CPU / M_CPU implementations).

    This is a generator that yields simulation events: one queue ``put``
    per target SIP, then each transaction's completion event in turn.

    Args:
        key: Key under which the ``(Completion, trace)`` pair is stored in
            ``self._results``.
        request: The MMU message. Its optional ``target_sips`` attribute
            selects the SIPs; when absent or equal to ``"all"``, every SIP
            that has a PCIe endpoint is targeted. ``request.msg_type`` is
            recorded in the trace.
        done: Event that is succeeded once the result has been stored.
    """
    start_ns = self._env.now
    target_sips = getattr(request, "target_sips", "all")

    # Determine target SIPs.
    if target_sips == "all":
        # Endpoint ids start with a "sipN" segment; extract N.
        sip_set: set[int] = {
            int(ep_id.split(".")[0].replace("sip", ""))
            for ep_id in self._resolver.find_all_pcie_eps()
        }
    else:
        sip_set = set(target_sips)

    # Resolve every (PCIE_EP, IO_CPU) pair up front so a resolver failure
    # surfaces before any transaction has been queued.
    entries = [
        (self._resolver.find_pcie_ep(sip_id), self._resolver.find_io_cpu(sip_id))
        for sip_id in sorted(sip_set)
    ]

    if not entries:
        # Nothing to route: complete immediately with a zero-time trace.
        self._results[key] = (Completion(ok=True), {"total_ns": 0.0})
        done.succeed()
        return

    # One transaction per target SIP. A single target is just a fan-out of
    # one, so no separate code path is needed. MMU messages carry no data
    # payload, hence nbytes=0.
    pending = []
    for pcie_ep_id, io_cpu_id in entries:
        path = self._router.find_node_path(pcie_ep_id, io_cpu_id)
        txn_done = self._env.event()
        txn = Transaction(request=request, path=path, step=0, nbytes=0, done=txn_done)
        yield self._host_queues[pcie_ep_id].put(txn)
        pending.append(txn_done)

    # Wait for every SIP to finish before reporting completion.
    for txn_done in pending:
        yield txn_done

    elapsed = self._env.now - start_ns
    self._results[key] = (
        Completion(ok=True),
        {"total_ns": elapsed, "msg_type": request.msg_type},
    )
    done.succeed()
|
||||
|
||||
Reference in New Issue
Block a user