Add virtual memory support: PE_MMU, VA allocator, fabric MmuMapMsg
Implement the VA/MMU layer (ADR-0011 Phase 1), enabling Triton kernels to use contiguous virtual addresses on sharded tensors.

Key changes:
- PE_MMU component: hybrid inbox (MmuMapMsg) + sync translate() for PE_DMA
- VirtualAllocator + PEMemAllocator: free-list with coalescing
- MmuMapMsg/MmuUnmapMsg fabric path with SIP-level routing
- DPPolicy-based mapping: replicate=local, sharded=broadcast
- Tensor lifecycle: del + weakref cleanup, context manager
- Rename: TensorHandle.pa→addr, DmaReadCmd.src_pa→src_addr, ctx→torch

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -16,6 +16,7 @@ from kernbench.components.impls.pe_dma import PeDmaComponent
|
||||
from kernbench.components.impls.pe_gemm import PeGemmComponent
|
||||
from kernbench.components.impls.pe_math import PeMathComponent
|
||||
from kernbench.components.impls.pe_scheduler import PeSchedulerComponent
|
||||
from kernbench.components.impls.pe_mmu import PeMmuComponent
|
||||
from kernbench.components.impls.pe_tcm import PeTcmComponent
|
||||
from kernbench.components.impls.sram import SramComponent
|
||||
from kernbench.components.impls.xbar import PositionAwareXbarComponent
|
||||
@@ -36,6 +37,7 @@ ComponentRegistry.register("pe_scheduler_v1", PeSchedulerComponent)
|
||||
ComponentRegistry.register("pe_dma_v1", PeDmaComponent)
|
||||
ComponentRegistry.register("pe_gemm_v1", PeGemmComponent)
|
||||
ComponentRegistry.register("pe_math_v1", PeMathComponent)
|
||||
ComponentRegistry.register("pe_mmu_v1", PeMmuComponent)
|
||||
ComponentRegistry.register("pe_tcm_v1", PeTcmComponent)
|
||||
|
||||
__all__ = [
|
||||
@@ -47,6 +49,7 @@ __all__ = [
|
||||
"PeDmaComponent",
|
||||
"PeGemmComponent",
|
||||
"PeMathComponent",
|
||||
"PeMmuComponent",
|
||||
"PeSchedulerComponent",
|
||||
"PeTcmComponent",
|
||||
"TransitComponent",
|
||||
|
||||
@@ -93,7 +93,9 @@ class IoCpuComponent(ComponentBase):
|
||||
|
||||
def _resolve_cube_targets(self, request: Any) -> list[tuple[int, int]]:
|
||||
"""Return list of (sip, cube) pairs to fan out to."""
|
||||
from kernbench.runtime_api.kernel import KernelLaunchMsg, MemoryReadMsg, MemoryWriteMsg
|
||||
from kernbench.runtime_api.kernel import (
|
||||
KernelLaunchMsg, MemoryReadMsg, MemoryWriteMsg, MmuMapMsg, MmuUnmapMsg,
|
||||
)
|
||||
|
||||
target_cubes = getattr(request, "target_cubes", "all")
|
||||
|
||||
@@ -130,6 +132,16 @@ class IoCpuComponent(ComponentBase):
|
||||
targets.append(key)
|
||||
return targets
|
||||
|
||||
if isinstance(request, (MmuMapMsg, MmuUnmapMsg)):
|
||||
my_sip = self._my_sip()
|
||||
if target_cubes == "all":
|
||||
n_cubes = 16
|
||||
if self.ctx and self.ctx.spec:
|
||||
sips = self.ctx.spec.get("system", {}).get("sips", {})
|
||||
n_cubes = sips.get("cubes_per_sip", 16)
|
||||
return [(my_sip, c) for c in range(n_cubes)]
|
||||
return [(my_sip, c) for c in target_cubes]
|
||||
|
||||
return []
|
||||
|
||||
def _cube_from_pa(self, pa_val: int, fallback: int) -> int:
|
||||
|
||||
@@ -52,7 +52,7 @@ class MCpuComponent(ComponentBase):
|
||||
|
||||
def _worker(self, env: simpy.Environment) -> Generator:
|
||||
"""Dispatch forward txns, collect response txns."""
|
||||
from kernbench.runtime_api.kernel import KernelLaunchMsg
|
||||
from kernbench.runtime_api.kernel import KernelLaunchMsg, MmuMapMsg, MmuUnmapMsg
|
||||
|
||||
while True:
|
||||
txn: Any = yield self._inbox.get()
|
||||
@@ -66,6 +66,8 @@ class MCpuComponent(ComponentBase):
|
||||
elif self.ctx is not None and txn.request is not None:
|
||||
if isinstance(txn.request, KernelLaunchMsg):
|
||||
env.process(self._kernel_launch_fanout(env, txn))
|
||||
elif isinstance(txn.request, (MmuMapMsg, MmuUnmapMsg)):
|
||||
env.process(self._mmu_msg_fanout(env, txn))
|
||||
else:
|
||||
env.process(self._dma_fanout(env, txn))
|
||||
else:
|
||||
@@ -261,6 +263,63 @@ class MCpuComponent(ComponentBase):
|
||||
n_slices = mm.get("hbm_slices_per_cube", 8)
|
||||
return [f"{cube_prefix}.hbm_ctrl.slice{i}" for i in range(n_slices)]
|
||||
|
||||
def _mmu_msg_fanout(self, env: simpy.Environment, txn: Any) -> Generator:
    """Fan out an MmuMapMsg/MmuUnmapMsg to the target PE_MMU(s) of this cube.

    Each target PE_MMU is reached via ``self.ctx.router.find_node_path``
    (M_CPU → NOC → PE_MMU command edges). PE_MMU is a terminal node and
    completes each sub-transaction directly.

    Args:
        env: SimPy environment used to create completion events.
        txn: Incoming transaction; ``txn.request`` carries the MMU message
            and ``txn.path`` is reversed to route the aggregate response.

    Yields:
        SimPy events: one port ``put`` per delivered sub-transaction, each
        sub-transaction's ``done`` event, and the ``put`` of the response.

    Delivery is best-effort: a target that cannot be routed to is skipped
    and logged rather than aborting the whole fan-out.
    """
    import logging

    logger = logging.getLogger(__name__)

    request = txn.request
    target_pe = getattr(request, "target_pe", "all")
    cube_prefix = self.node.id.rsplit(".", 1)[0]  # e.g. "sip0.cube0"
    pe_ids = self._resolve_pe_ids(target_pe)

    if not pe_ids:
        # Nothing to deliver: complete the incoming transaction immediately.
        txn.done.succeed()
        return

    # Fan out to each PE_MMU.
    sub_dones: list[simpy.Event] = []
    for pe_id in pe_ids:
        pe_mmu_id = f"{cube_prefix}.pe{pe_id}.pe_mmu"
        try:
            path = self.ctx.router.find_node_path(self.node.id, pe_mmu_id)
        except Exception as exc:
            # The router's exception types are not visible from this block,
            # so the broad catch is kept; the skip is recorded, not hidden.
            logger.warning(
                "%s: no route to %s, MMU message not delivered (%s)",
                self.node.id, pe_mmu_id, exc,
            )
            continue
        if len(path) < 2:
            # A path without a next hop cannot be forwarded through a port.
            logger.warning(
                "%s: degenerate path to %s, MMU message not delivered",
                self.node.id, pe_mmu_id,
            )
            continue
        sub_done = env.event()
        sub_txn = Transaction(
            request=request, path=path, step=0,
            nbytes=0, done=sub_done,
        )
        yield self.out_ports[path[1]].put(sub_txn.advance())
        sub_dones.append(sub_done)

    # Wait for all reached PE_MMUs to complete.
    for sd in sub_dones:
        yield sd

    # Send the aggregate response on the reverse path.
    reverse_path = list(reversed(txn.path))
    if len(reverse_path) >= 2:
        from kernbench.runtime_api.kernel import ResponseMsg

        # Node id looks like "sipX.cubeY.<name>"; the second field gives the cube.
        parts = self.node.id.split(".")
        cube_id = int(parts[1].replace("cube", ""))
        # NOTE(review): success stays True even when some targets were
        # skipped above - confirm whether callers need partial-failure status.
        resp_msg = ResponseMsg(
            correlation_id=request.correlation_id,
            request_id=request.request_id,
            src_cube=cube_id, src_pe=-1, success=True,
        )
        resp_txn = Transaction(
            request=resp_msg, path=reverse_path, step=0,
            nbytes=0, done=env.event(), is_response=True,
        )
        yield self.out_ports[reverse_path[1]].put(resp_txn.advance())
    else:
        # No upstream hop to respond to: complete the transaction here.
        txn.done.succeed()
|
||||
|
||||
def _resolve_pe_ids(self, target_pe: int | str) -> list[int]:
|
||||
"""Return list of PE IDs to fan out to (used by kernel launch fan-out)."""
|
||||
if isinstance(target_pe, int):
|
||||
|
||||
@@ -84,12 +84,15 @@ class PeCpuComponent(ComponentBase):
|
||||
tl = TLContext(pe_id=self._pe_idx, dispatch_cycles=0)
|
||||
|
||||
# Unpack KernelLaunchMsg.args into positional args for kernel function
|
||||
# TensorArg → PA (pointer), ScalarArg → value
|
||||
# TensorArg → VA base (or PA fallback), ScalarArg → value
|
||||
kernel_args: list = []
|
||||
for arg in request.args:
|
||||
if arg.arg_kind == "tensor":
|
||||
shard = self._find_shard(arg.shards)
|
||||
kernel_args.append(shard.pa)
|
||||
if arg.va_base:
|
||||
kernel_args.append(arg.va_base)
|
||||
else:
|
||||
shard = self._find_shard(arg.shards)
|
||||
kernel_args.append(shard.pa)
|
||||
elif arg.arg_kind == "scalar":
|
||||
kernel_args.append(arg.value)
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ class PeDmaComponent(PeEngineBase):
|
||||
super().__init__(node, ctx)
|
||||
self._dma_read: simpy.Resource | None = None
|
||||
self._dma_write: simpy.Resource | None = None
|
||||
self._mmu = None # PeMMU instance, set by engine wiring
|
||||
|
||||
def init_resources(self, env: simpy.Environment) -> None:
|
||||
self._dma_read = simpy.Resource(env, capacity=1)
|
||||
@@ -48,20 +49,32 @@ class PeDmaComponent(PeEngineBase):
|
||||
cmd = pe_txn.command
|
||||
assert self._dma_read is not None and self._dma_write is not None
|
||||
|
||||
# Determine direction and target PA
|
||||
# Determine direction and target address (VA → PA via MMU)
|
||||
if isinstance(cmd, DmaReadCmd):
|
||||
dma_res = self._dma_read
|
||||
target_pa = cmd.src_pa
|
||||
raw_addr = cmd.src_addr
|
||||
is_write = False
|
||||
elif isinstance(cmd, DmaWriteCmd):
|
||||
dma_res = self._dma_write
|
||||
target_pa = cmd.dst_pa
|
||||
raw_addr = cmd.dst_addr
|
||||
is_write = True
|
||||
else:
|
||||
pe_txn.done.succeed()
|
||||
return
|
||||
|
||||
# Resolve PA → HBM node and compute path
|
||||
# Translate VA → PA via MMU (if available), then resolve HBM node
|
||||
# If MMU has no mapping for this address (PageFault), treat as PA directly
|
||||
# (backward-compatible with PA-only mode)
|
||||
if self._mmu is not None:
|
||||
from kernbench.policy.address.pe_mmu import PageFault
|
||||
try:
|
||||
target_pa = self._mmu.translate(raw_addr)
|
||||
if self._mmu.overhead_ns > 0:
|
||||
yield env.timeout(self._mmu.overhead_ns)
|
||||
except PageFault:
|
||||
target_pa = raw_addr
|
||||
else:
|
||||
target_pa = raw_addr # fallback: treat as PA directly
|
||||
pa = PhysAddr.decode(target_pa)
|
||||
dst_node = self.ctx.resolver.resolve(pa)
|
||||
path = self.ctx.router.find_path(self._pe_prefix, dst_node)
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
"""PE_MMU component: address translation unit.
|
||||
|
||||
Component role: receives MmuMapMsg/MmuUnmapMsg via inbox (independent of PE_CPU).
|
||||
Utility role: PE_DMA/PE_GEMM call mmu.translate() directly (no SimPy overhead).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import ComponentBase, ComponentRegistry
|
||||
from kernbench.policy.address.pe_mmu import PeMMU
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class PeMmuComponent(ComponentBase):
    """PE_MMU: per-PE virtual-to-physical address translation.

    MmuMapMsg/MmuUnmapMsg requests arriving on the inbox update the wrapped
    ``PeMMU`` page table. Other engines obtain that object through the
    ``mmu`` property and call it synchronously for VA→PA translation.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        """Build the wrapped PeMMU from the node attributes.

        ``page_size`` defaults to 2 MiB and ``tlb_overhead_ns`` to 0.0 when
        the node does not provide them.
        """
        super().__init__(node, ctx)
        attrs = node.attrs
        self._mmu = PeMMU(
            page_size=int(attrs.get("page_size", 2 * 1024 * 1024)),
            overhead_ns=float(attrs.get("tlb_overhead_ns", 0.0)),
        )

    @property
    def mmu(self) -> PeMMU:
        """The underlying PeMMU utility object for direct translate() calls."""
        return self._mmu

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Yield a single zero-length timeout; ``nbytes`` is not used."""
        yield env.timeout(0)

    def _worker(self, env: simpy.Environment) -> Generator:
        """Serve the inbox: apply MMU map/unmap requests, forward the rest."""
        from kernbench.runtime_api.kernel import MmuMapMsg, MmuUnmapMsg

        while True:
            txn: Any = yield self._inbox.get()
            # Items without a ``request`` attribute fall through to forwarding.
            request = getattr(txn, "request", None)

            if isinstance(request, MmuMapMsg):
                for entry in request.entries:
                    self._mmu.map(
                        va=entry["va"], pa=entry["pa"], size=entry["size"],
                    )
            elif isinstance(request, MmuUnmapMsg):
                for entry in request.entries:
                    self._mmu.unmap(va=entry["va"], size=entry["size"])
            else:
                # Not an MMU message: hand it to the normal forwarding path.
                yield from self._forward_txn(env, txn)
                continue

            # Map/unmap handled here; complete the transaction in place.
            txn.done.succeed()
|
||||
@@ -155,12 +155,12 @@ class PeSchedulerComponent(ComponentBase):
|
||||
|
||||
# --- Stage 1: DMA_READ b_tile from HBM ---
|
||||
read_done = env.event()
|
||||
b_tile_pa = b.pa + (k_start * N + n_start) * dtype_bytes
|
||||
b_tile_addr = b.addr + (k_start * N + n_start) * dtype_bytes
|
||||
b_tile_handle = TensorHandle(
|
||||
id=f"b_tile_{tile_idx}", pa=b_tile_pa,
|
||||
id=f"b_tile_{tile_idx}", addr=b_tile_addr,
|
||||
shape=(tile_k, tile_n), dtype=dtype, nbytes=tile_nbytes,
|
||||
)
|
||||
read_cmd = DmaReadCmd(handle=b_tile_handle, src_pa=b_tile_pa, nbytes=tile_nbytes)
|
||||
read_cmd = DmaReadCmd(handle=b_tile_handle, src_addr=b_tile_addr, nbytes=tile_nbytes)
|
||||
read_txn = PeTxn(command=read_cmd, done=read_done, pe_prefix=pp)
|
||||
t0 = env.now
|
||||
yield self.out_ports[f"{pp}.pe_dma"].put(read_txn)
|
||||
@@ -176,7 +176,7 @@ class PeSchedulerComponent(ComponentBase):
|
||||
# --- Stage 2: COMPUTE (GEMM) ---
|
||||
compute_done = env.event()
|
||||
out_handle = TensorHandle(
|
||||
id=f"out_tile_{tile_idx}", pa=0,
|
||||
id=f"out_tile_{tile_idx}", addr=0,
|
||||
shape=(M, tile_n), dtype=dtype,
|
||||
nbytes=M * tile_n * dtype_bytes,
|
||||
)
|
||||
@@ -197,9 +197,9 @@ class PeSchedulerComponent(ComponentBase):
|
||||
|
||||
# --- Stage 3: DMA_WRITE out_tile to HBM ---
|
||||
write_done = env.event()
|
||||
out_tile_pa = cmd.out_pa + n_start * dtype_bytes
|
||||
out_tile_pa = cmd.out_addr + n_start * dtype_bytes
|
||||
write_nbytes = M * tile_n * dtype_bytes
|
||||
write_cmd = DmaWriteCmd(handle=out_handle, dst_pa=out_tile_pa, nbytes=write_nbytes)
|
||||
write_cmd = DmaWriteCmd(handle=out_handle, dst_addr=out_tile_pa, nbytes=write_nbytes)
|
||||
write_txn = PeTxn(command=write_cmd, done=write_done, pe_prefix=pp)
|
||||
t0 = env.now
|
||||
yield self.out_ports[f"{pp}.pe_dma"].put(write_txn)
|
||||
@@ -237,7 +237,7 @@ class PeSchedulerComponent(ComponentBase):
|
||||
|
||||
# Step 2: DMA_WRITE result to HBM
|
||||
write_done = env.event()
|
||||
write_cmd = DmaWriteCmd(handle=cmd.a, dst_pa=cmd.out_pa, nbytes=cmd.out_nbytes)
|
||||
write_cmd = DmaWriteCmd(handle=cmd.a, dst_addr=cmd.out_addr, nbytes=cmd.out_nbytes)
|
||||
write_txn = PeTxn(command=write_cmd, done=write_done, pe_prefix=pp)
|
||||
yield self.out_ports[f"{pp}.pe_dma"].put(write_txn)
|
||||
yield write_done
|
||||
|
||||
Reference in New Issue
Block a user