Add virtual memory support: PE_MMU, VA allocator, fabric MmuMapMsg

Implement VA/MMU layer (ADR-0011 Phase 1) enabling Triton kernels to use
contiguous virtual addresses on sharded tensors.

Key changes:
- PE_MMU component: hybrid inbox (MmuMapMsg) + sync translate() for PE_DMA
- VirtualAllocator + PEMemAllocator: free-list with coalescing
- MmuMapMsg/MmuUnmapMsg fabric path with SIP-level routing
- DPPolicy-based mapping: replicate=local, sharded=broadcast
- Tensor lifecycle: del + weakref cleanup, context manager
- Rename: TensorHandle.pa→addr, DmaReadCmd.src_pa→src_addr, ctx→torch

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-26 00:01:47 -07:00
parent 62fb01ae18
commit 08812eda58
34 changed files with 2131 additions and 139 deletions
@@ -16,6 +16,7 @@ from kernbench.components.impls.pe_dma import PeDmaComponent
from kernbench.components.impls.pe_gemm import PeGemmComponent
from kernbench.components.impls.pe_math import PeMathComponent
from kernbench.components.impls.pe_scheduler import PeSchedulerComponent
from kernbench.components.impls.pe_mmu import PeMmuComponent
from kernbench.components.impls.pe_tcm import PeTcmComponent
from kernbench.components.impls.sram import SramComponent
from kernbench.components.impls.xbar import PositionAwareXbarComponent
@@ -36,6 +37,7 @@ ComponentRegistry.register("pe_scheduler_v1", PeSchedulerComponent)
ComponentRegistry.register("pe_dma_v1", PeDmaComponent)
ComponentRegistry.register("pe_gemm_v1", PeGemmComponent)
ComponentRegistry.register("pe_math_v1", PeMathComponent)
ComponentRegistry.register("pe_mmu_v1", PeMmuComponent)
ComponentRegistry.register("pe_tcm_v1", PeTcmComponent)
__all__ = [
@@ -47,6 +49,7 @@ __all__ = [
"PeDmaComponent",
"PeGemmComponent",
"PeMathComponent",
"PeMmuComponent",
"PeSchedulerComponent",
"PeTcmComponent",
"TransitComponent",
+13 -1
View File
@@ -93,7 +93,9 @@ class IoCpuComponent(ComponentBase):
def _resolve_cube_targets(self, request: Any) -> list[tuple[int, int]]:
"""Return list of (sip, cube) pairs to fan out to."""
from kernbench.runtime_api.kernel import KernelLaunchMsg, MemoryReadMsg, MemoryWriteMsg
from kernbench.runtime_api.kernel import (
KernelLaunchMsg, MemoryReadMsg, MemoryWriteMsg, MmuMapMsg, MmuUnmapMsg,
)
target_cubes = getattr(request, "target_cubes", "all")
@@ -130,6 +132,16 @@ class IoCpuComponent(ComponentBase):
targets.append(key)
return targets
if isinstance(request, (MmuMapMsg, MmuUnmapMsg)):
my_sip = self._my_sip()
if target_cubes == "all":
n_cubes = 16
if self.ctx and self.ctx.spec:
sips = self.ctx.spec.get("system", {}).get("sips", {})
n_cubes = sips.get("cubes_per_sip", 16)
return [(my_sip, c) for c in range(n_cubes)]
return [(my_sip, c) for c in target_cubes]
return []
def _cube_from_pa(self, pa_val: int, fallback: int) -> int:
+60 -1
View File
@@ -52,7 +52,7 @@ class MCpuComponent(ComponentBase):
def _worker(self, env: simpy.Environment) -> Generator:
"""Dispatch forward txns, collect response txns."""
from kernbench.runtime_api.kernel import KernelLaunchMsg
from kernbench.runtime_api.kernel import KernelLaunchMsg, MmuMapMsg, MmuUnmapMsg
while True:
txn: Any = yield self._inbox.get()
@@ -66,6 +66,8 @@ class MCpuComponent(ComponentBase):
elif self.ctx is not None and txn.request is not None:
if isinstance(txn.request, KernelLaunchMsg):
env.process(self._kernel_launch_fanout(env, txn))
elif isinstance(txn.request, (MmuMapMsg, MmuUnmapMsg)):
env.process(self._mmu_msg_fanout(env, txn))
else:
env.process(self._dma_fanout(env, txn))
else:
@@ -261,6 +263,63 @@ class MCpuComponent(ComponentBase):
n_slices = mm.get("hbm_slices_per_cube", 8)
return [f"{cube_prefix}.hbm_ctrl.slice{i}" for i in range(n_slices)]
def _mmu_msg_fanout(self, env: simpy.Environment, txn: Any) -> Generator:
"""Fan out MmuMapMsg/MmuUnmapMsg to target PE_MMU(s) via NOC.
Routes through find_node_path (M_CPU → NOC → PE_MMU command edges).
PE_MMU is a terminal node — completes the transaction directly.
"""
request = txn.request
target_pe = getattr(request, "target_pe", "all")
cube_prefix = self.node.id.rsplit(".", 1)[0] # e.g. "sip0.cube0"
pe_ids = self._resolve_pe_ids(target_pe)
if not pe_ids:
txn.done.succeed()
return
# Fan out to each PE_MMU
sub_dones: list[simpy.Event] = []
for pe_id in pe_ids:
pe_mmu_id = f"{cube_prefix}.pe{pe_id}.pe_mmu"
try:
path = self.ctx.router.find_node_path(self.node.id, pe_mmu_id)
except Exception:
continue
if len(path) < 2:
continue
sub_done = env.event()
sub_txn = Transaction(
request=request, path=path, step=0,
nbytes=0, done=sub_done,
)
yield self.out_ports[path[1]].put(sub_txn.advance())
sub_dones.append(sub_done)
# Wait for all PE_MMUs to complete
for sd in sub_dones:
yield sd
# Send aggregate response on reverse path
reverse_path = list(reversed(txn.path))
if len(reverse_path) >= 2:
from kernbench.runtime_api.kernel import ResponseMsg
parts = self.node.id.split(".")
cube_id = int(parts[1].replace("cube", ""))
resp_msg = ResponseMsg(
correlation_id=request.correlation_id,
request_id=request.request_id,
src_cube=cube_id, src_pe=-1, success=True,
)
resp_txn = Transaction(
request=resp_msg, path=reverse_path, step=0,
nbytes=0, done=env.event(), is_response=True,
)
yield self.out_ports[reverse_path[1]].put(resp_txn.advance())
else:
txn.done.succeed()
def _resolve_pe_ids(self, target_pe: int | str) -> list[int]:
"""Return list of PE IDs to fan out to (used by kernel launch fan-out)."""
if isinstance(target_pe, int):
+6 -3
View File
@@ -84,12 +84,15 @@ class PeCpuComponent(ComponentBase):
tl = TLContext(pe_id=self._pe_idx, dispatch_cycles=0)
# Unpack KernelLaunchMsg.args into positional args for kernel function
# TensorArg → PA (pointer), ScalarArg → value
# TensorArg → VA base (or PA fallback), ScalarArg → value
kernel_args: list = []
for arg in request.args:
if arg.arg_kind == "tensor":
shard = self._find_shard(arg.shards)
kernel_args.append(shard.pa)
if arg.va_base:
kernel_args.append(arg.va_base)
else:
shard = self._find_shard(arg.shards)
kernel_args.append(shard.pa)
elif arg.arg_kind == "scalar":
kernel_args.append(arg.value)
+17 -4
View File
@@ -31,6 +31,7 @@ class PeDmaComponent(PeEngineBase):
super().__init__(node, ctx)
self._dma_read: simpy.Resource | None = None
self._dma_write: simpy.Resource | None = None
self._mmu = None # PeMMU instance, set by engine wiring
def init_resources(self, env: simpy.Environment) -> None:
self._dma_read = simpy.Resource(env, capacity=1)
@@ -48,20 +49,32 @@ class PeDmaComponent(PeEngineBase):
cmd = pe_txn.command
assert self._dma_read is not None and self._dma_write is not None
# Determine direction and target PA
# Determine direction and target address (VA → PA via MMU)
if isinstance(cmd, DmaReadCmd):
dma_res = self._dma_read
target_pa = cmd.src_pa
raw_addr = cmd.src_addr
is_write = False
elif isinstance(cmd, DmaWriteCmd):
dma_res = self._dma_write
target_pa = cmd.dst_pa
raw_addr = cmd.dst_addr
is_write = True
else:
pe_txn.done.succeed()
return
# Resolve PA → HBM node and compute path
# Translate VA → PA via MMU (if available), then resolve HBM node
# If MMU has no mapping for this address (PageFault), treat as PA directly
# (backward-compatible with PA-only mode)
if self._mmu is not None:
from kernbench.policy.address.pe_mmu import PageFault
try:
target_pa = self._mmu.translate(raw_addr)
if self._mmu.overhead_ns > 0:
yield env.timeout(self._mmu.overhead_ns)
except PageFault:
target_pa = raw_addr
else:
target_pa = raw_addr # fallback: treat as PA directly
pa = PhysAddr.decode(target_pa)
dst_node = self.ctx.resolver.resolve(pa)
path = self.ctx.router.find_path(self._pe_prefix, dst_node)
+66
View File
@@ -0,0 +1,66 @@
"""PE_MMU component: address translation unit.
Component role: receives MmuMapMsg/MmuUnmapMsg via inbox (independent of PE_CPU).
Utility role: PE_DMA/PE_GEMM call mmu.translate() directly (no SimPy overhead).
"""
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
import simpy
from kernbench.components.base import ComponentBase, ComponentRegistry
from kernbench.policy.address.pe_mmu import PeMMU
if TYPE_CHECKING:
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class PeMmuComponent(ComponentBase):
"""PE_MMU: per-PE virtual-to-physical address translation.
Receives MmuMapMsg/MmuUnmapMsg via inbox and updates the internal
page table. PE_DMA and PE_GEMM access the underlying PeMMU object
via the ``mmu`` property for synchronous VA→PA translation.
"""
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
super().__init__(node, ctx)
page_size = int(node.attrs.get("page_size", 2 * 1024 * 1024))
overhead_ns = float(node.attrs.get("tlb_overhead_ns", 0.0))
self._mmu = PeMMU(page_size=page_size, overhead_ns=overhead_ns)
@property
def mmu(self) -> PeMMU:
"""The underlying PeMMU utility object for direct translate() calls."""
return self._mmu
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
yield env.timeout(0)
def _worker(self, env: simpy.Environment) -> Generator:
"""Process MmuMapMsg/MmuUnmapMsg from inbox."""
from kernbench.runtime_api.kernel import MmuMapMsg, MmuUnmapMsg
while True:
txn: Any = yield self._inbox.get()
if hasattr(txn, "request"):
request = txn.request
if isinstance(request, MmuMapMsg):
for entry in request.entries:
self._mmu.map(
va=entry["va"], pa=entry["pa"], size=entry["size"],
)
txn.done.succeed()
elif isinstance(request, MmuUnmapMsg):
for entry in request.entries:
self._mmu.unmap(va=entry["va"], size=entry["size"])
txn.done.succeed()
else:
# Forward non-MMU transactions normally
yield from self._forward_txn(env, txn)
else:
yield from self._forward_txn(env, txn)
@@ -155,12 +155,12 @@ class PeSchedulerComponent(ComponentBase):
# --- Stage 1: DMA_READ b_tile from HBM ---
read_done = env.event()
b_tile_pa = b.pa + (k_start * N + n_start) * dtype_bytes
b_tile_addr = b.addr + (k_start * N + n_start) * dtype_bytes
b_tile_handle = TensorHandle(
id=f"b_tile_{tile_idx}", pa=b_tile_pa,
id=f"b_tile_{tile_idx}", addr=b_tile_addr,
shape=(tile_k, tile_n), dtype=dtype, nbytes=tile_nbytes,
)
read_cmd = DmaReadCmd(handle=b_tile_handle, src_pa=b_tile_pa, nbytes=tile_nbytes)
read_cmd = DmaReadCmd(handle=b_tile_handle, src_addr=b_tile_addr, nbytes=tile_nbytes)
read_txn = PeTxn(command=read_cmd, done=read_done, pe_prefix=pp)
t0 = env.now
yield self.out_ports[f"{pp}.pe_dma"].put(read_txn)
@@ -176,7 +176,7 @@ class PeSchedulerComponent(ComponentBase):
# --- Stage 2: COMPUTE (GEMM) ---
compute_done = env.event()
out_handle = TensorHandle(
id=f"out_tile_{tile_idx}", pa=0,
id=f"out_tile_{tile_idx}", addr=0,
shape=(M, tile_n), dtype=dtype,
nbytes=M * tile_n * dtype_bytes,
)
@@ -197,9 +197,9 @@ class PeSchedulerComponent(ComponentBase):
# --- Stage 3: DMA_WRITE out_tile to HBM ---
write_done = env.event()
out_tile_pa = cmd.out_pa + n_start * dtype_bytes
out_tile_pa = cmd.out_addr + n_start * dtype_bytes
write_nbytes = M * tile_n * dtype_bytes
write_cmd = DmaWriteCmd(handle=out_handle, dst_pa=out_tile_pa, nbytes=write_nbytes)
write_cmd = DmaWriteCmd(handle=out_handle, dst_addr=out_tile_pa, nbytes=write_nbytes)
write_txn = PeTxn(command=write_cmd, done=write_done, pe_prefix=pp)
t0 = env.now
yield self.out_ports[f"{pp}.pe_dma"].put(write_txn)
@@ -237,7 +237,7 @@ class PeSchedulerComponent(ComponentBase):
# Step 2: DMA_WRITE result to HBM
write_done = env.event()
write_cmd = DmaWriteCmd(handle=cmd.a, dst_pa=cmd.out_pa, nbytes=cmd.out_nbytes)
write_cmd = DmaWriteCmd(handle=cmd.a, dst_addr=cmd.out_addr, nbytes=cmd.out_nbytes)
write_txn = PeTxn(command=write_cmd, done=write_done, pe_prefix=pp)
yield self.out_ports[f"{pp}.pe_dma"].put(write_txn)
yield write_done