Add virtual memory support: PE_MMU, VA allocator, fabric MmuMapMsg

Implement VA/MMU layer (ADR-0011 Phase 1) enabling Triton kernels to use
contiguous virtual addresses on sharded tensors.

Key changes:
- PE_MMU component: hybrid inbox (MmuMapMsg) + sync translate() for PE_DMA
- VirtualAllocator + PEMemAllocator: free-list with coalescing
- MmuMapMsg/MmuUnmapMsg fabric path with SIP-level routing
- DPPolicy-based mapping: replicate=local, sharded=broadcast
- Tensor lifecycle: del + weakref cleanup, context manager
- Rename: TensorHandle.pa→addr, DmaReadCmd.src_pa→src_addr, ctx→torch

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 00:01:47 -07:00
parent 62fb01ae18
commit 08812eda58
34 changed files with 2131 additions and 139 deletions
+7 -4
@@ -207,12 +207,15 @@ benchmark instances by default.
## R10. Memory Addressing (Phase 0)
In Phase 0, the simulator uses a **PA-first memory model**:
The simulator uses a **VA/PA memory model** (ADR-0011):
- All memory operations use device physical addresses (PA) only.
- Virtual addressing, MMU/IOMMU, and address translation latency are out of scope.
- Tensors are assigned a contiguous virtual address (VA) range at deployment.
- PE_MMU translates VA→PA per access; TLB overhead is configurable.
- Mapping installation (MmuMapMsg) traverses the fabric with measured latency.
- Replicate tensors use per-cube local PA mapping; sharded tensors broadcast.
- PA-only fallback is retained for backward compatibility.
- Tensor placement is represented as a list of PA shards, each explicitly tagged
with `(sip, cube, pe)`.
with `(sip, cube, pe)`, plus a tensor-wide `va_base`.
All memory access latency MUST be modeled explicitly via graph traversal.
No implicit translation or hidden latency is allowed.
+1 -1
@@ -1,2 +1,2 @@
def run(ctx):
def run(torch):
print("IPCQ all reduce kernel bench")
+2 -2
@@ -15,7 +15,7 @@ def resolve_bench(bench_id: str) -> BenchFn:
Expected layout (repo root):
benches/<bench_id>.py
def run(ctx: RuntimeContext) -> Any
def run(torch: RuntimeContext) -> Any
"""
bench_id = bench_id.strip()
if not bench_id:
@@ -30,7 +30,7 @@ def resolve_bench(bench_id: str) -> BenchFn:
run_fn = getattr(mod, "run", None)
if run_fn is None:
raise ValueError(f"Bench module {module_path} must define a 'run(ctx)' function.")
raise ValueError(f"Bench module {module_path} must define a 'run(torch)' function.")
if not callable(run_fn):
raise ValueError(f"'run' in {module_path} is not callable.")
+5 -5
@@ -26,14 +26,14 @@ def _gemm_kernel(a_ptr, b_ptr, out_ptr, M, K, N, tl, DTYPE="f16"):
tl.wait(handle)
def run(ctx):
def run(torch):
"""Run the QKV GEMM benchmark."""
# DP placement: a=replicate (cube-level), b/out=column_wise (N-axis, single PE)
a = ctx.zeros((M, K), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="replicate"), name="a")
b = ctx.zeros((K, N), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="column_wise"), name="b")
out = ctx.empty(
a = torch.zeros((M, K), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="replicate"), name="a")
b = torch.zeros((K, N), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="column_wise"), name="b")
out = torch.empty(
(M, N), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="column_wise"), name="out",
)
# Launch GEMM kernel
ctx.launch("qkv_gemm", _gemm_kernel, a, b, out, M, K, N)
torch.launch("qkv_gemm", _gemm_kernel, a, b, out, M, K, N)
+5 -5
@@ -26,14 +26,14 @@ def _gemm_kernel(a_ptr, b_ptr, out_ptr, M, K, N, tl, DTYPE="f16"):
tl.wait(handle)
def run(ctx):
def run(torch):
"""Run the multi-PE QKV GEMM benchmark."""
# DP placement: a=replicate (cube-level), b/out=column_wise (N-axis split)
a = ctx.zeros((M, K), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="replicate"), name="a")
b = ctx.zeros((K, N), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="column_wise"), name="b")
out = ctx.empty(
a = torch.zeros((M, K), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="replicate"), name="a")
b = torch.zeros((K, N), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="column_wise"), name="b")
out = torch.empty(
(M, N), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="column_wise"), name="out",
)
# Launch GEMM kernel on all PEs
ctx.launch("qkv_gemm_multi", _gemm_kernel, a, b, out, M, K, N)
torch.launch("qkv_gemm_multi", _gemm_kernel, a, b, out, M, K, N)
@@ -1,8 +1,8 @@
# ADR-0011: Memory Addressing Simplification (PA-first)
# ADR-0011: Memory Addressing — PA-first with VA/MMU Extension
## Status
Accepted
Accepted (Phase 1 VA/MMU implemented)
## Context
@@ -11,49 +11,82 @@ translation path for DMA: host allocates physical memory at PE level, maps it
into a virtual address space, installs mappings, and DMA requests use virtual
addresses that are translated to physical addresses.
For early development, we want a minimal, deterministic model that enables:
- correct routing and latency accounting through the graph,
- stable tensor deployment and kernel execution semantics,
- future extension toward VA/MMU without rewriting workflows.
The PA-only model (Phase 0) was insufficient for running standard Triton kernels
that use `base_addr + offset` patterns on sharded tensors — each PE's shard has
a different PA, but the kernel needs a single contiguous address space.
---
## Decision
### D1. Phase 0 model is PA-only
The simulator uses a PA-first model:
### D1. Phase 0 model is PA-only (original, retained as fallback)
- All device memory accesses (MemoryRead/MemoryWrite) operate on device physical
addresses (PA) plus size.
- Tensor handles store PA-based shard mappings after deployment.
- KernelLaunch passes tensor arguments as PA-based mappings (or references to them).
- MMU/IOMMU concepts (virtual address spaces, page tables, translation latency)
are NOT modeled in Phase 0.
- PA-only mode remains functional via PageFault fallback in PE_DMA.
### D2. Allocation produces PA mappings
Device allocation selects PE-local memory regions and returns PA mappings
sufficient to execute kernels and issue DMA requests.
### D3. Extension path (non-breaking)
### D3. Phase 1: VA/MMU layer (implemented)
A future ADR MAY introduce an optional VA/MMU layer by:
#### D3.1 Virtual Address Model
- introducing virtual addresses in tensor handles,
- adding a mapping-install step,
- modeling translation latency and page granularity.
- Each tensor gets a single contiguous VA range (`TensorHandle.va_base`).
- `TensorShard` does NOT carry a `va` field — shard VA is derived as
`va_base + offset_bytes`.
- Kernels receive `va_base` as their pointer argument (via `TensorArg.va_base`).
- `DmaReadCmd.src_addr` and `DmaWriteCmd.dst_addr` carry VA (not PA).
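
A minimal sketch of the D3.1 address derivation, assuming a simplified shard record. `ShardSketch` and `shard_va` are hypothetical names used only for illustration and are not the project's `TensorShard` API; the addresses are made-up values.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ShardSketch:
    """Hypothetical stand-in for TensorShard: note there is no `va` field."""
    pe: int
    pa: int            # physical address of this shard (differs per PE)
    offset_bytes: int  # byte offset of the shard within the whole tensor
    nbytes: int


def shard_va(va_base: int, shard: ShardSketch) -> int:
    """Derive a shard's VA from the tensor-wide base, as D3.1 describes."""
    return va_base + shard.offset_bytes


va_base = 0x1_0000_0000
shards = [
    ShardSketch(pe=0, pa=0x4000_0000, offset_bytes=0, nbytes=4096),
    ShardSketch(pe=1, pa=0x9000_0000, offset_bytes=4096, nbytes=4096),
]

# The PAs are far apart, but the derived VAs form one contiguous range,
# which is what lets a kernel use plain `base_addr + offset` arithmetic.
vas = [shard_va(va_base, s) for s in shards]
```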
The Phase 0 PA model remains a valid fast-path configuration.
#### D3.2 PE_MMU Component
- Hybrid design: SimPy component (inbox for MmuMapMsg) + utility (synchronous
`translate()` called by PE_DMA).
- Page-aligned dict lookup for O(1) VA→PA translation.
- `tlb_overhead_ns` configurable per-access latency.
- PageFault fallback: if VA has no mapping, PE_DMA treats it as PA directly
(backward compatibility with PA-only tests).
#### D3.3 Mapping Installation
- `MmuMapMsg` traverses the fabric: Host → PCIE_EP → IO_CPU (cube fan-out) →
M_CPU (PE fan-out) → NOC → PE_MMU. Latency is measured end-to-end.
- `MmuMapMsg.target_sips` controls SIP-level routing to prevent cross-SIP
mapping contamination for replicated tensors.
- Mapping strategy based on `DPPolicy.cube`:
- **Replicate** (`cube="replicate"`): per-(sip, cube) local mapping only.
Each cube's PEs see only their local PA. No cross-cube mapping installed.
- **Sharded** (`cube="shard_m"`, etc.): broadcast all shard mappings to all
target cubes. Enables cross-PE and cross-cube DMA.
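
A rough sketch of how the two mapping strategies could be planned, modeled on the grouping that the unmap path in `RuntimeContext._free_tensor` performs. `ShardSketch` and `plan_mappings` are hypothetical, and the plain dicts stand in for `MmuMapMsg` fields; the real message construction may differ.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass(frozen=True)
class ShardSketch:
    sip: int
    cube: int
    pe: int
    pa: int
    offset_bytes: int
    nbytes: int


def plan_mappings(cube_policy: str, va_base: int, shards: list[ShardSketch]) -> list[dict]:
    """Return one message plan per fabric submission (illustrative only)."""
    def entry(s: ShardSketch) -> dict:
        return {"va": va_base + s.offset_bytes, "pa": s.pa, "size": s.nbytes}

    if cube_policy == "replicate":
        # Each (sip, cube) only learns about its own local copy.
        groups: dict[tuple[int, int], list[ShardSketch]] = defaultdict(list)
        for s in shards:
            groups[(s.sip, s.cube)].append(s)
        return [
            {"target_sips": (sip,), "target_cubes": (cube,),
             "entries": tuple(entry(s) for s in group)}
            for (sip, cube), group in groups.items()
        ]

    # Sharded: every target cube receives every shard mapping.
    return [{
        "target_sips": tuple(sorted({s.sip for s in shards})),
        "target_cubes": tuple(sorted({s.cube for s in shards})),
        "entries": tuple(entry(s) for s in shards),
    }]


va_base = 0x1_0000_0000
replicated = [
    ShardSketch(sip=0, cube=0, pe=0, pa=0x4000_0000, offset_bytes=0, nbytes=4096),
    ShardSketch(sip=0, cube=1, pe=0, pa=0x8000_0000, offset_bytes=0, nbytes=4096),
]
sharded = [
    ShardSketch(sip=0, cube=0, pe=0, pa=0x4000_0000, offset_bytes=0, nbytes=4096),
    ShardSketch(sip=0, cube=1, pe=0, pa=0x8000_0000, offset_bytes=4096, nbytes=4096),
]
replicate_plan = plan_mappings("replicate", va_base, replicated)
sharded_plan = plan_mappings("shard_m", va_base, sharded)
```

In the replicate case both cubes map the same VA to different local PAs, which is why the messages must not leak across cubes.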
#### D3.4 Tensor Lifecycle
- `del tensor` triggers automatic cleanup via `Tensor.__del__` + `weakref` to
RuntimeContext. Sends `MmuUnmapMsg` through fabric, returns VA and PA space.
- `with RuntimeContext(...) as ctx:` provides scope-based bulk cleanup.
- `RuntimeContext._tensors` uses `weakref.ref` to avoid preventing GC.
- `PEMemAllocator` uses free-list with coalescing (not bump allocator).
- `VirtualAllocator` uses free-list with coalescing for VA space.
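
The lifecycle above can be illustrated with a small stand-alone sketch. `ContextSketch` and `TensorSketch` are hypothetical classes, and appending to a `freed` list stands in for sending `MmuUnmapMsg` and returning VA/PA space. The immediate effect of `del` relies on CPython reference counting, which is an assumption about the interpreter.

```python
import weakref


class ContextSketch:
    def __init__(self) -> None:
        self._tensors: list[weakref.ref] = []  # weak refs: do not keep tensors alive
        self.freed: list[str] = []

    def __enter__(self) -> "ContextSketch":
        return self

    def __exit__(self, *exc) -> bool:
        self.cleanup()
        return False  # do not swallow exceptions

    def track(self, tensor: "TensorSketch") -> None:
        self._tensors.append(weakref.ref(tensor))

    def free_tensor(self, tensor: "TensorSketch") -> None:
        if tensor.handle is None:
            return  # already freed: makes double-free a no-op
        tensor.handle = None
        self.freed.append(tensor.name)

    def cleanup(self) -> None:
        for ref in self._tensors:
            t = ref()
            if t is not None:
                self.free_tensor(t)
        self._tensors.clear()


class TensorSketch:
    def __init__(self, ctx: ContextSketch, name: str) -> None:
        self.name = name
        self.handle: object | None = object()
        self._ctx = weakref.ref(ctx)  # weak back-reference avoids a cycle
        ctx.track(self)

    def __del__(self) -> None:
        ctx = self._ctx()
        if ctx is not None:
            ctx.free_tensor(self)


with ContextSketch() as ctx:
    a = TensorSketch(ctx, "a")
    b = TensorSketch(ctx, "b")
    del a                          # last strong reference gone: __del__ frees "a"
    freed_inside = list(ctx.freed)
freed_after = list(ctx.freed)      # __exit__ freed the still-live "b"
```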
#### D3.5 Allocators
- `VirtualAllocator`: device-wide VA space, page-aligned alloc/free with
coalescing.
- `PEMemAllocator`: per-PE HBM/TCM, free-list based alloc/free with coalescing.
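- Page size configurable via `topology.yaml` pe_mmu attrs (`RuntimeContext` falls back to 4096 when the attr is unset; the `PeMMU` and `VirtualAllocator` constructors themselves default to 2 MiB).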
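
As a worked illustration of page-aligned allocation with coalescing, here is a compact first-fit free list. `FreeListSketch` is a hypothetical class written for this example; it follows the same general idea as the allocators in this commit but is not their implementation.

```python
import bisect


class FreeListSketch:
    def __init__(self, base: int, size: int, page_size: int = 4096) -> None:
        self.page_size = page_size
        self.free_blocks: list[tuple[int, int]] = [(base, size)]  # sorted (start, size)

    def _align_up(self, nbytes: int) -> int:
        return -(-nbytes // self.page_size) * self.page_size

    def alloc(self, nbytes: int) -> int:
        need = self._align_up(nbytes)
        for i, (start, size) in enumerate(self.free_blocks):
            if size >= need:  # first fit: carve from the front of the block
                if size == need:
                    self.free_blocks.pop(i)
                else:
                    self.free_blocks[i] = (start + need, size - need)
                return start
        raise MemoryError(f"no free block of {need} bytes")

    def free(self, addr: int, nbytes: int) -> None:
        start, end = addr, addr + self._align_up(nbytes)
        idx = bisect.bisect_left(self.free_blocks, (start,))
        # Merge with the block that ends exactly where this one starts.
        if idx > 0:
            prev_start, prev_size = self.free_blocks[idx - 1]
            if prev_start + prev_size == start:
                start = prev_start
                idx -= 1
                self.free_blocks.pop(idx)
        # Merge with the block that starts exactly where this one ends.
        if idx < len(self.free_blocks):
            next_start, next_size = self.free_blocks[idx]
            if end == next_start:
                end = next_start + next_size
                self.free_blocks.pop(idx)
        self.free_blocks.insert(idx, (start, end - start))


fl = FreeListSketch(base=0, size=4 * 4096)
a = fl.alloc(100)    # rounded up to one page
b = fl.alloc(4096)   # exactly one page
c = fl.alloc(5000)   # rounded up to two pages; the space is now full
fl.free(b, 4096)     # isolated hole in the middle
fl.free(a, 100)      # merges with the hole after it
fl.free(c, 5000)     # merges with the block before it
```

After the three frees the list collapses back to a single block covering the whole range, which is the property a bump allocator cannot provide.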
---
## Consequences
- Early implementation stays simple and testable.
- All latency remains explicit via graph traversal, not hidden translation.
- Future VA/MMU modeling can be added without breaking existing benchmarks.
- Triton kernels use `base_addr + offset` patterns naturally on sharded tensors.
- All latency remains explicit via graph traversal, including MMU mapping
installation and per-access TLB overhead.
- PA-only mode retained as fallback (PageFault → treat as PA).
- Benchmark parameter renamed `ctx` → `torch` for PyTorch code compatibility.
- IPCQ and other fixed-address resources bypass MMU (use PA directly).
---
@@ -62,4 +95,6 @@ The Phase 0 PA model remains a valid fast-path configuration.
- ADR-0007 (runtime_api vs sim_engine boundaries)
- ADR-0008 (tensor deployment)
- ADR-0009 (kernel execution)
- ADR-0014 (PE-internal execution model)
- ADR-0015 (component port/wire model)
- SPEC R2 (latency by traversal)
+6 -6
@@ -28,7 +28,7 @@ class TensorHandle:
"""
id: str
pa: int # physical address in HBM/TCM
addr: int # address (VA when MMU enabled, PA otherwise)
shape: tuple[int, ...]
dtype: str
nbytes: int # total byte size
@@ -50,19 +50,19 @@ class CompletionHandle:
@dataclass(frozen=True)
class DmaReadCmd:
"""DMA READ: HBM → PE_TCM."""
"""DMA READ: HBM → PE_TCM. src_addr is VA (translated to PA by PE_DMA)."""
handle: TensorHandle
src_pa: int
src_addr: int
nbytes: int
@dataclass(frozen=True)
class DmaWriteCmd:
"""DMA WRITE: PE_TCM → HBM."""
"""DMA WRITE: PE_TCM → HBM. dst_addr is VA (translated to PA by PE_DMA)."""
handle: TensorHandle
dst_pa: int
dst_addr: int
nbytes: int
@@ -108,7 +108,7 @@ class CompositeCmd:
op: Literal["gemm", "math"]
a: TensorHandle
b: TensorHandle | None
out_pa: int
out_addr: int
out_nbytes: int
math_op: str | None = None # for op="math": which math operation
@@ -16,6 +16,7 @@ from kernbench.components.impls.pe_dma import PeDmaComponent
from kernbench.components.impls.pe_gemm import PeGemmComponent
from kernbench.components.impls.pe_math import PeMathComponent
from kernbench.components.impls.pe_scheduler import PeSchedulerComponent
from kernbench.components.impls.pe_mmu import PeMmuComponent
from kernbench.components.impls.pe_tcm import PeTcmComponent
from kernbench.components.impls.sram import SramComponent
from kernbench.components.impls.xbar import PositionAwareXbarComponent
@@ -36,6 +37,7 @@ ComponentRegistry.register("pe_scheduler_v1", PeSchedulerComponent)
ComponentRegistry.register("pe_dma_v1", PeDmaComponent)
ComponentRegistry.register("pe_gemm_v1", PeGemmComponent)
ComponentRegistry.register("pe_math_v1", PeMathComponent)
ComponentRegistry.register("pe_mmu_v1", PeMmuComponent)
ComponentRegistry.register("pe_tcm_v1", PeTcmComponent)
__all__ = [
@@ -47,6 +49,7 @@ __all__ = [
"PeDmaComponent",
"PeGemmComponent",
"PeMathComponent",
"PeMmuComponent",
"PeSchedulerComponent",
"PeTcmComponent",
"TransitComponent",
+13 -1
@@ -93,7 +93,9 @@ class IoCpuComponent(ComponentBase):
def _resolve_cube_targets(self, request: Any) -> list[tuple[int, int]]:
"""Return list of (sip, cube) pairs to fan out to."""
from kernbench.runtime_api.kernel import KernelLaunchMsg, MemoryReadMsg, MemoryWriteMsg
from kernbench.runtime_api.kernel import (
KernelLaunchMsg, MemoryReadMsg, MemoryWriteMsg, MmuMapMsg, MmuUnmapMsg,
)
target_cubes = getattr(request, "target_cubes", "all")
@@ -130,6 +132,16 @@ class IoCpuComponent(ComponentBase):
targets.append(key)
return targets
if isinstance(request, (MmuMapMsg, MmuUnmapMsg)):
my_sip = self._my_sip()
if target_cubes == "all":
n_cubes = 16
if self.ctx and self.ctx.spec:
sips = self.ctx.spec.get("system", {}).get("sips", {})
n_cubes = sips.get("cubes_per_sip", 16)
return [(my_sip, c) for c in range(n_cubes)]
return [(my_sip, c) for c in target_cubes]
return []
def _cube_from_pa(self, pa_val: int, fallback: int) -> int:
+60 -1
@@ -52,7 +52,7 @@ class MCpuComponent(ComponentBase):
def _worker(self, env: simpy.Environment) -> Generator:
"""Dispatch forward txns, collect response txns."""
from kernbench.runtime_api.kernel import KernelLaunchMsg
from kernbench.runtime_api.kernel import KernelLaunchMsg, MmuMapMsg, MmuUnmapMsg
while True:
txn: Any = yield self._inbox.get()
@@ -66,6 +66,8 @@ class MCpuComponent(ComponentBase):
elif self.ctx is not None and txn.request is not None:
if isinstance(txn.request, KernelLaunchMsg):
env.process(self._kernel_launch_fanout(env, txn))
elif isinstance(txn.request, (MmuMapMsg, MmuUnmapMsg)):
env.process(self._mmu_msg_fanout(env, txn))
else:
env.process(self._dma_fanout(env, txn))
else:
@@ -261,6 +263,63 @@ class MCpuComponent(ComponentBase):
n_slices = mm.get("hbm_slices_per_cube", 8)
return [f"{cube_prefix}.hbm_ctrl.slice{i}" for i in range(n_slices)]
def _mmu_msg_fanout(self, env: simpy.Environment, txn: Any) -> Generator:
"""Fan out MmuMapMsg/MmuUnmapMsg to target PE_MMU(s) via NOC.
Routes through find_node_path (M_CPU → NOC → PE_MMU command edges).
PE_MMU is a terminal node — completes the transaction directly.
"""
request = txn.request
target_pe = getattr(request, "target_pe", "all")
cube_prefix = self.node.id.rsplit(".", 1)[0] # e.g. "sip0.cube0"
pe_ids = self._resolve_pe_ids(target_pe)
if not pe_ids:
txn.done.succeed()
return
# Fan out to each PE_MMU
sub_dones: list[simpy.Event] = []
for pe_id in pe_ids:
pe_mmu_id = f"{cube_prefix}.pe{pe_id}.pe_mmu"
try:
path = self.ctx.router.find_node_path(self.node.id, pe_mmu_id)
except Exception:
continue
if len(path) < 2:
continue
sub_done = env.event()
sub_txn = Transaction(
request=request, path=path, step=0,
nbytes=0, done=sub_done,
)
yield self.out_ports[path[1]].put(sub_txn.advance())
sub_dones.append(sub_done)
# Wait for all PE_MMUs to complete
for sd in sub_dones:
yield sd
# Send aggregate response on reverse path
reverse_path = list(reversed(txn.path))
if len(reverse_path) >= 2:
from kernbench.runtime_api.kernel import ResponseMsg
parts = self.node.id.split(".")
cube_id = int(parts[1].replace("cube", ""))
resp_msg = ResponseMsg(
correlation_id=request.correlation_id,
request_id=request.request_id,
src_cube=cube_id, src_pe=-1, success=True,
)
resp_txn = Transaction(
request=resp_msg, path=reverse_path, step=0,
nbytes=0, done=env.event(), is_response=True,
)
yield self.out_ports[reverse_path[1]].put(resp_txn.advance())
else:
txn.done.succeed()
def _resolve_pe_ids(self, target_pe: int | str) -> list[int]:
"""Return list of PE IDs to fan out to (used by kernel launch fan-out)."""
if isinstance(target_pe, int):
+6 -3
@@ -84,12 +84,15 @@ class PeCpuComponent(ComponentBase):
tl = TLContext(pe_id=self._pe_idx, dispatch_cycles=0)
# Unpack KernelLaunchMsg.args into positional args for kernel function
# TensorArg → PA (pointer), ScalarArg → value
# TensorArg → VA base (or PA fallback), ScalarArg → value
kernel_args: list = []
for arg in request.args:
if arg.arg_kind == "tensor":
shard = self._find_shard(arg.shards)
kernel_args.append(shard.pa)
if arg.va_base:
kernel_args.append(arg.va_base)
else:
shard = self._find_shard(arg.shards)
kernel_args.append(shard.pa)
elif arg.arg_kind == "scalar":
kernel_args.append(arg.value)
+17 -4
@@ -31,6 +31,7 @@ class PeDmaComponent(PeEngineBase):
super().__init__(node, ctx)
self._dma_read: simpy.Resource | None = None
self._dma_write: simpy.Resource | None = None
self._mmu = None # PeMMU instance, set by engine wiring
def init_resources(self, env: simpy.Environment) -> None:
self._dma_read = simpy.Resource(env, capacity=1)
@@ -48,20 +49,32 @@ class PeDmaComponent(PeEngineBase):
cmd = pe_txn.command
assert self._dma_read is not None and self._dma_write is not None
# Determine direction and target PA
# Determine direction and target address (VA → PA via MMU)
if isinstance(cmd, DmaReadCmd):
dma_res = self._dma_read
target_pa = cmd.src_pa
raw_addr = cmd.src_addr
is_write = False
elif isinstance(cmd, DmaWriteCmd):
dma_res = self._dma_write
target_pa = cmd.dst_pa
raw_addr = cmd.dst_addr
is_write = True
else:
pe_txn.done.succeed()
return
# Resolve PA → HBM node and compute path
# Translate VA → PA via MMU (if available), then resolve HBM node
# If MMU has no mapping for this address (PageFault), treat as PA directly
# (backward-compatible with PA-only mode)
if self._mmu is not None:
from kernbench.policy.address.pe_mmu import PageFault
try:
target_pa = self._mmu.translate(raw_addr)
if self._mmu.overhead_ns > 0:
yield env.timeout(self._mmu.overhead_ns)
except PageFault:
target_pa = raw_addr
else:
target_pa = raw_addr # fallback: treat as PA directly
pa = PhysAddr.decode(target_pa)
dst_node = self.ctx.resolver.resolve(pa)
path = self.ctx.router.find_path(self._pe_prefix, dst_node)
+66
@@ -0,0 +1,66 @@
"""PE_MMU component: address translation unit.
Component role: receives MmuMapMsg/MmuUnmapMsg via inbox (independent of PE_CPU).
Utility role: PE_DMA/PE_GEMM call mmu.translate() directly (no SimPy overhead).
"""
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
import simpy
from kernbench.components.base import ComponentBase, ComponentRegistry
from kernbench.policy.address.pe_mmu import PeMMU
if TYPE_CHECKING:
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class PeMmuComponent(ComponentBase):
"""PE_MMU: per-PE virtual-to-physical address translation.
Receives MmuMapMsg/MmuUnmapMsg via inbox and updates the internal
page table. PE_DMA and PE_GEMM access the underlying PeMMU object
via the ``mmu`` property for synchronous VA→PA translation.
"""
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
super().__init__(node, ctx)
page_size = int(node.attrs.get("page_size", 2 * 1024 * 1024))
overhead_ns = float(node.attrs.get("tlb_overhead_ns", 0.0))
self._mmu = PeMMU(page_size=page_size, overhead_ns=overhead_ns)
@property
def mmu(self) -> PeMMU:
"""The underlying PeMMU utility object for direct translate() calls."""
return self._mmu
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
yield env.timeout(0)
def _worker(self, env: simpy.Environment) -> Generator:
"""Process MmuMapMsg/MmuUnmapMsg from inbox."""
from kernbench.runtime_api.kernel import MmuMapMsg, MmuUnmapMsg
while True:
txn: Any = yield self._inbox.get()
if hasattr(txn, "request"):
request = txn.request
if isinstance(request, MmuMapMsg):
for entry in request.entries:
self._mmu.map(
va=entry["va"], pa=entry["pa"], size=entry["size"],
)
txn.done.succeed()
elif isinstance(request, MmuUnmapMsg):
for entry in request.entries:
self._mmu.unmap(va=entry["va"], size=entry["size"])
txn.done.succeed()
else:
# Forward non-MMU transactions normally
yield from self._forward_txn(env, txn)
else:
yield from self._forward_txn(env, txn)
@@ -155,12 +155,12 @@ class PeSchedulerComponent(ComponentBase):
# --- Stage 1: DMA_READ b_tile from HBM ---
read_done = env.event()
b_tile_pa = b.pa + (k_start * N + n_start) * dtype_bytes
b_tile_addr = b.addr + (k_start * N + n_start) * dtype_bytes
b_tile_handle = TensorHandle(
id=f"b_tile_{tile_idx}", pa=b_tile_pa,
id=f"b_tile_{tile_idx}", addr=b_tile_addr,
shape=(tile_k, tile_n), dtype=dtype, nbytes=tile_nbytes,
)
read_cmd = DmaReadCmd(handle=b_tile_handle, src_pa=b_tile_pa, nbytes=tile_nbytes)
read_cmd = DmaReadCmd(handle=b_tile_handle, src_addr=b_tile_addr, nbytes=tile_nbytes)
read_txn = PeTxn(command=read_cmd, done=read_done, pe_prefix=pp)
t0 = env.now
yield self.out_ports[f"{pp}.pe_dma"].put(read_txn)
@@ -176,7 +176,7 @@ class PeSchedulerComponent(ComponentBase):
# --- Stage 2: COMPUTE (GEMM) ---
compute_done = env.event()
out_handle = TensorHandle(
id=f"out_tile_{tile_idx}", pa=0,
id=f"out_tile_{tile_idx}", addr=0,
shape=(M, tile_n), dtype=dtype,
nbytes=M * tile_n * dtype_bytes,
)
@@ -197,9 +197,9 @@ class PeSchedulerComponent(ComponentBase):
# --- Stage 3: DMA_WRITE out_tile to HBM ---
write_done = env.event()
out_tile_pa = cmd.out_pa + n_start * dtype_bytes
out_tile_pa = cmd.out_addr + n_start * dtype_bytes
write_nbytes = M * tile_n * dtype_bytes
write_cmd = DmaWriteCmd(handle=out_handle, dst_pa=out_tile_pa, nbytes=write_nbytes)
write_cmd = DmaWriteCmd(handle=out_handle, dst_addr=out_tile_pa, nbytes=write_nbytes)
write_txn = PeTxn(command=write_cmd, done=write_done, pe_prefix=pp)
t0 = env.now
yield self.out_ports[f"{pp}.pe_dma"].put(write_txn)
@@ -237,7 +237,7 @@ class PeSchedulerComponent(ComponentBase):
# Step 2: DMA_WRITE result to HBM
write_done = env.event()
write_cmd = DmaWriteCmd(handle=cmd.a, dst_pa=cmd.out_pa, nbytes=cmd.out_nbytes)
write_cmd = DmaWriteCmd(handle=cmd.a, dst_addr=cmd.out_addr, nbytes=cmd.out_nbytes)
write_txn = PeTxn(command=write_cmd, done=write_done, pe_prefix=pp)
yield self.out_ports[f"{pp}.pe_dma"].put(write_txn)
yield write_done
+83 -16
@@ -1,5 +1,6 @@
from __future__ import annotations
import bisect
from dataclasses import dataclass
from kernbench.policy.address.phyaddr import PhysAddr
@@ -29,6 +30,63 @@ class AddressConfig:
return self.tcm_bytes_per_pe - self.tcm_scheduler_reserved_bytes
class _FreeList:
"""Offset-based free-list allocator with coalescing."""
def __init__(self, capacity: int) -> None:
self._capacity = capacity
self._used = 0
self._free: list[tuple[int, int]] = [(0, capacity)] # (offset, size)
@property
def used(self) -> int:
return self._used
@property
def total(self) -> int:
return self._capacity
def alloc(self, nbytes: int) -> int:
"""Allocate nbytes, return offset. Raises AllocationError if full."""
for i, (start, size) in enumerate(self._free):
if size >= nbytes:
if size == nbytes:
self._free.pop(i)
else:
self._free[i] = (start + nbytes, size - nbytes)
self._used += nbytes
return start
raise AllocationError(
f"overflow: need {nbytes}, "
f"largest free block {max((s for _, s in self._free), default=0)}"
)
def free(self, offset: int, nbytes: int) -> None:
"""Return a range to the free-list with coalescing."""
self._used -= nbytes
new_start = offset
new_end = offset + nbytes
idx = bisect.bisect_left(self._free, (offset,))
# Coalesce with previous block
if idx > 0:
prev_start, prev_size = self._free[idx - 1]
if prev_start + prev_size == new_start:
new_start = prev_start
idx -= 1
self._free.pop(idx)
# Coalesce with next block
if idx < len(self._free):
next_start, next_size = self._free[idx]
if new_end == next_start:
new_end = next_start + next_size
self._free.pop(idx)
self._free.insert(idx, (new_start, new_end - new_start))
class PEMemAllocator:
def __init__(
self, rack_id: int, sip_id: int, cube_id: int, pe_id: int, cfg: AddressConfig,
@@ -38,39 +96,48 @@ class PEMemAllocator:
self._cube_id = cube_id
self._pe_id = pe_id
self._cfg = cfg
self._hbm_cursor = 0
self._tcm_cursor = 0
self._hbm = _FreeList(cfg.hbm_slice_bytes)
self._tcm = _FreeList(cfg.tcm_allocatable_bytes)
def alloc_hbm(self, nbytes: int) -> PhysAddr:
if self._hbm_cursor + nbytes > self._cfg.hbm_slice_bytes:
try:
offset = self._hbm.alloc(nbytes)
except AllocationError:
raise AllocationError(
f"HBM overflow: need {nbytes}, "
f"available {self._cfg.hbm_slice_bytes - self._hbm_cursor}"
f"available {self._cfg.hbm_slice_bytes - self._hbm.used}"
)
pa = PhysAddr.pe_hbm_addr(
return PhysAddr.pe_hbm_addr(
rack_id=self._rack_id, sip_id=self._sip_id, cube_id=self._cube_id,
pe_id=self._pe_id, pe_local_hbm_offset=self._hbm_cursor,
pe_id=self._pe_id, pe_local_hbm_offset=offset,
slice_size_bytes=self._cfg.hbm_slice_bytes,
)
self._hbm_cursor += nbytes
return pa
def free_hbm(self, pa: PhysAddr, nbytes: int) -> None:
# Extract PE-local offset from the PA's hbm_offset
pe_slice_start = self._pe_id * self._cfg.hbm_slice_bytes
offset = pa.hbm_offset - pe_slice_start
self._hbm.free(offset, nbytes)
def alloc_tcm(self, nbytes: int) -> PhysAddr:
if self._tcm_cursor + nbytes > self._cfg.tcm_allocatable_bytes:
try:
offset = self._tcm.alloc(nbytes)
except AllocationError:
raise AllocationError(
f"TCM overflow: need {nbytes}, "
f"available {self._cfg.tcm_allocatable_bytes - self._tcm_cursor}"
f"available {self._cfg.tcm_allocatable_bytes - self._tcm.used}"
)
pa = PhysAddr.pe_tcm_addr(
return PhysAddr.pe_tcm_addr(
rack_id=self._rack_id, sip_id=self._sip_id, cube_id=self._cube_id,
pe_id=self._pe_id, tcm_offset=self._tcm_cursor,
pe_id=self._pe_id, tcm_offset=offset,
)
self._tcm_cursor += nbytes
return pa
def free_tcm(self, pa: PhysAddr, nbytes: int) -> None:
self._tcm.free(pa.sub_offset, nbytes)
@property
def hbm_used(self) -> int:
return self._hbm_cursor
return self._hbm.used
@property
def hbm_total(self) -> int:
@@ -78,7 +145,7 @@ class PEMemAllocator:
@property
def tcm_used(self) -> int:
return self._tcm_cursor
return self._tcm.used
@property
def tcm_total(self) -> int:
+66
@@ -0,0 +1,66 @@
"""PeMMU: per-PE virtual-to-physical address translation.
Page-aligned dict lookup for O(1) VA→PA translation.
Used as a utility class by PE_DMA, PE_GEMM (direct call),
and as a component inbox target for MmuMapMsg/MmuUnmapMsg.
"""
from __future__ import annotations
class PageFault(Exception):
"""Raised when VA has no mapping in the page table."""
def __init__(self, va: int | str) -> None:
if isinstance(va, str):
super().__init__(va)
else:
self.va = va
super().__init__(f"PageFault at VA 0x{va:x}")
class PeMMU:
"""Per-PE MMU with page-aligned VA→PA translation table.
Args:
page_size: Page size in bytes (default 2 MB).
overhead_ns: Per-access TLB lookup latency in nanoseconds.
"""
def __init__(
self,
page_size: int = 2 * 1024 * 1024,
overhead_ns: float = 0.0,
) -> None:
self._page_size = page_size
self._page_shift = (page_size - 1).bit_length()
self._page_mask = page_size - 1
self._table: dict[int, int] = {} # va_page_number → pa_page_base
self._overhead_ns = overhead_ns
@property
def overhead_ns(self) -> float:
return self._overhead_ns
@property
def num_entries(self) -> int:
return len(self._table)
def map(self, va: int, pa: int, size: int) -> None:
"""Register VA→PA mapping for a contiguous range."""
for off in range(0, size, self._page_size):
vpn = (va + off) >> self._page_shift
self._table[vpn] = pa + off
def unmap(self, va: int, size: int) -> None:
"""Remove VA mapping for a contiguous range."""
for off in range(0, size, self._page_size):
vpn = (va + off) >> self._page_shift
self._table.pop(vpn, None)
def translate(self, va: int) -> int:
"""Translate VA to PA. Raises PageFault if unmapped."""
vpn = va >> self._page_shift
pa_page_base = self._table.get(vpn)
if pa_page_base is None:
raise PageFault(va)
return pa_page_base + (va & self._page_mask)
@@ -0,0 +1,93 @@
"""VirtualAllocator: device-wide VA space management with free-list.
Allocations are page-aligned. Freed ranges are coalesced with adjacent
free blocks to allow larger re-allocations.
"""
from __future__ import annotations
import bisect
class VaAllocationError(Exception):
pass
class VirtualAllocator:
"""Manages a contiguous VA address space with page-aligned alloc/free.
Args:
va_base: Start of the VA range.
va_size: Total size of the VA range in bytes.
page_size: Page granularity in bytes.
"""
def __init__(
self,
va_base: int,
va_size: int,
page_size: int = 2 * 1024 * 1024,
) -> None:
self._va_base = va_base
self._va_size = va_size
self._page_size = page_size
self._used = 0
# Free list: sorted list of (start, size) tuples
self._free: list[tuple[int, int]] = [(va_base, va_size)]
def _align_up(self, nbytes: int) -> int:
"""Round up to page boundary."""
return ((nbytes + self._page_size - 1) // self._page_size) * self._page_size
def alloc(self, nbytes: int) -> int:
"""Allocate a contiguous VA range. Returns the start VA."""
aligned = self._align_up(nbytes)
for i, (start, size) in enumerate(self._free):
if size >= aligned:
# Take from the beginning of this free block
if size == aligned:
self._free.pop(i)
else:
self._free[i] = (start + aligned, size - aligned)
self._used += aligned
return start
raise VaAllocationError(
f"Out of VA space: need {aligned}, largest free block "
f"{max((s for _, s in self._free), default=0)}"
)
def free(self, va: int, nbytes: int) -> None:
"""Free a VA range and coalesce with adjacent free blocks."""
aligned = self._align_up(nbytes)
self._used -= aligned
# Insert into sorted free list and coalesce
new_start = va
new_end = va + aligned
# Find insertion point
idx = bisect.bisect_left(self._free, (va,))
# Try coalesce with previous block
if idx > 0:
prev_start, prev_size = self._free[idx - 1]
if prev_start + prev_size == new_start:
new_start = prev_start
idx -= 1
self._free.pop(idx)
# Try coalesce with next block
if idx < len(self._free):
next_start, next_size = self._free[idx]
if new_end == next_start:
new_end = next_start + next_size
self._free.pop(idx)
self._free.insert(idx, (new_start, new_end - new_start))
@property
def used(self) -> int:
return self._used
@property
def total(self) -> int:
return self._va_size
+188 -22
@@ -19,8 +19,18 @@ class RuntimeContext:
_handles: list[RequestHandle] = field(default_factory=list, init=False)
_completed: set[RequestHandle] = field(default_factory=set, init=False)
_allocators: dict[int, Any] = field(default_factory=dict, init=False)
_va_allocator: Any = field(default=None, init=False)
_mmus: dict[int, Any] = field(default_factory=dict, init=False)
_tensor_counter: int = field(default=0, init=False)
_traces: list[dict] = field(default_factory=list, init=False)
_tensors: list[Any] = field(default_factory=list, init=False)
def __enter__(self):
return self
def __exit__(self, *exc):
self.cleanup()
return False
def submit(self, request: Any) -> RequestHandle:
submit_fn = getattr(self.engine, "submit", None)
@@ -58,6 +68,92 @@ class RuntimeContext:
def handles(self) -> list[RequestHandle]:
return list(self._handles)
# ── Tensor lifecycle ─────────────────────────────────────────────
def _free_tensor(self, tensor: Any) -> None:
"""Free a single tensor: unmap MMU, return VA and PA."""
handle = tensor._handle
if handle is None:
return
tensor._handle = None
if not handle.va_base:
return
from kernbench.runtime_api.kernel import MmuUnmapMsg
dp_policy = None
if tensor._dp_metadata is not None:
dp_policy = tensor._dp_metadata.dp_policy
is_cube_replicate = (
dp_policy is not None and dp_policy.cube == "replicate"
)
# Send MmuUnmapMsg through fabric
from collections import defaultdict
if is_cube_replicate:
cube_groups: dict[tuple[int, int], list] = defaultdict(list)
for shard in handle.shards:
cube_groups[(shard.sip, shard.cube)].append(shard)
for (sip, cube), group_shards in cube_groups.items():
entries = tuple(
{"va": handle.va_base + s.offset_bytes, "size": s.nbytes}
for s in group_shards
)
msg = MmuUnmapMsg(
correlation_id=self.correlation_id,
request_id=f"unmap_{tensor.name}_s{sip}c{cube}",
entries=entries,
target_sips=(sip,),
target_cubes=(cube,),
target_pe="all",
)
h = self.submit(msg)
self.wait(h)
else:
entries = tuple(
{"va": handle.va_base + s.offset_bytes, "size": s.nbytes}
for s in handle.shards
)
sip_set = sorted({s.sip for s in handle.shards})
cube_set = sorted({s.cube for s in handle.shards})
msg = MmuUnmapMsg(
correlation_id=self.correlation_id,
request_id=f"unmap_{tensor.name}",
entries=entries,
target_sips=tuple(sip_set),
target_cubes=tuple(cube_set),
target_pe="all",
)
h = self.submit(msg)
self.wait(h)
# Return VA space
if self._va_allocator is not None:
self._va_allocator.free(handle.va_base, handle.nbytes)
# Return PA space
if self._allocators:
for shard in handle.shards:
flat_idx = (
shard.sip * self._num_cubes * self._pes_per_cube
+ shard.cube * self._pes_per_cube
+ shard.pe
)
alloc = self._allocators.get(flat_idx)
if alloc is not None:
from kernbench.policy.address.phyaddr import PhysAddr
alloc.free_hbm(PhysAddr.decode(shard.pa), shard.nbytes)
def cleanup(self) -> None:
"""Free all tensors created by this context."""
for ref in self._tensors:
t = ref()
if t is not None and t._handle is not None:
self._free_tensor(t)
self._tensors.clear()
# ── PyTorch-like tensor API ──────────────────────────────────────
def _ensure_allocators(self) -> dict:
@@ -111,6 +207,26 @@ class RuntimeContext:
self._allocators[flat_idx] = PEMemAllocator(
rack_id=0, sip_id=sip_id, cube_id=cube_id, pe_id=pe_id, cfg=cfg,
)
# Initialize VA allocator and per-PE MMUs
from kernbench.policy.address.pe_mmu import PeMMU
from kernbench.policy.address.va_allocator import VirtualAllocator
pe_mmu_attrs = pe_comps.get("pe_mmu", {}).get("attrs", {})
page_size = int(pe_mmu_attrs.get("page_size", 4096))
tlb_overhead_ns = float(pe_mmu_attrs.get("tlb_overhead_ns", 0.0))
self._va_allocator = VirtualAllocator(
va_base=0x1_0000_0000,
va_size=64 * (1 << 30),  # 64 GiB VA space
page_size=page_size,
)
total_pes = sip_count * cubes_per_sip * pes_per_cube
for flat_idx in range(total_pes):
self._mmus[flat_idx] = PeMMU(
page_size=page_size, overhead_ns=tlb_overhead_ns,
)
return self._allocators
def _next_tensor_name(self) -> str:
@@ -122,63 +238,57 @@ class RuntimeContext:
shape: tuple[int, ...],
dtype: str = "f16",
*,
placement: list | None = None,
dp: Any = None,
name: str | None = None,
):
"""Create a tensor and deploy to HBM with zero-fill (like torch.zeros)."""
return self._create_tensor(shape, dtype, placement, name, pattern="zero", dp=dp)
return self._create_tensor(shape, dtype, name, pattern="zero", dp=dp)
def empty(
self,
shape: tuple[int, ...],
dtype: str = "f16",
*,
placement: list | None = None,
dp: Any = None,
name: str | None = None,
):
"""Allocate a tensor in HBM without initialization (like torch.empty)."""
return self._create_tensor(shape, dtype, placement, name, pattern=None, dp=dp)
return self._create_tensor(shape, dtype, name, pattern=None, dp=dp)
def _create_tensor(
self,
shape: tuple[int, ...],
dtype: str,
placement: list | None,
name: str | None,
pattern: str | None,
dp: Any = None,
):
from kernbench.policy.placement.dp import DPPolicy, ShardSpec, resolve_dp_policy
from kernbench.policy.placement.dp import DPPolicy, resolve_dp_policy
from kernbench.runtime_api.kernel import MemoryWriteMsg
from kernbench.runtime_api.tensor import Tensor, deploy_tensor, dtype_itemsize
if not isinstance(dp, DPPolicy):
raise ValueError("dp=DPPolicy(...) is required for tensor creation")
tensor_name = name or self._next_tensor_name()
t = Tensor(shape=shape, dtype=dtype, name=tensor_name)
dp_policy: DPPolicy | None = None
# Resolve placement: dp= takes priority over placement=
if dp is not None and isinstance(dp, DPPolicy):
dp_policy = dp
allocators = self._ensure_allocators()
itemsize = dtype_itemsize(dtype)
shape_2d = (shape[0], shape[1]) # type: tuple[int, int]
total_cubes = self._num_sips * self._num_cubes
placement = resolve_dp_policy(
dp, shape=shape_2d, itemsize=itemsize,
num_pe=self._pes_per_cube, num_cubes=total_cubes,
)
elif placement is None:
placement = [ShardSpec(pe_index=0, offset_bytes=0, nbytes=t.nbytes)]
dp_policy = dp
allocators = self._ensure_allocators()
itemsize = dtype_itemsize(dtype)
shape_2d = (shape[0], shape[1]) # type: tuple[int, int]
total_cubes = self._num_sips * self._num_cubes
placement = resolve_dp_policy(
dp, shape=shape_2d, itemsize=itemsize,
num_pe=self._pes_per_cube, num_cubes=total_cubes,
)
# Infer target_pe from placement: multi-PE → "all", single PE → pe_index
pe_indices = {s.pe_index for s in placement}
target_pe: int | str = "all" if len(pe_indices) > 1 else next(iter(pe_indices))
t.to(placement=placement, target_pe=target_pe, dp_policy=dp_policy)
# Allocate PAs via PEMemAllocator
# Allocate PAs via PEMemAllocator + VA via VirtualAllocator
allocators = self._ensure_allocators()
handle = deploy_tensor(
name=tensor_name,
@@ -186,8 +296,64 @@ class RuntimeContext:
dtype=dtype,
placement=placement,
allocators=allocators,
va_allocator=self._va_allocator,
mmus=self._mmus,
)
t._handle = handle
import weakref
t._ctx_ref = weakref.ref(self)
self._tensors.append(weakref.ref(t))
# Install VA→PA mappings via fabric MmuMapMsg
if handle.va_base:
from collections import defaultdict
from kernbench.runtime_api.kernel import MmuMapMsg
is_cube_replicate = (
dp_policy is not None and dp_policy.cube == "replicate"
)
if is_cube_replicate:
# Replicate: each (sip, cube) gets only its own local PA mappings
cube_groups: dict[tuple[int, int], list] = defaultdict(list)
for shard in handle.shards:
cube_groups[(shard.sip, shard.cube)].append(shard)
for (sip, cube), group_shards in cube_groups.items():
entries = tuple(
{"va": handle.va_base + s.offset_bytes,
"pa": s.pa, "size": s.nbytes}
for s in group_shards
)
msg = MmuMapMsg(
correlation_id=self.correlation_id,
request_id=f"mmu_{tensor_name}_s{sip}c{cube}",
entries=entries,
target_sips=(sip,),
target_cubes=(cube,),
target_pe="all",
)
h = self.submit(msg)
self.wait(h)
else:
# Sharded: broadcast all mappings to all target (sip, cube)s
entries = tuple(
{"va": handle.va_base + s.offset_bytes,
"pa": s.pa, "size": s.nbytes}
for s in handle.shards
)
sip_set = sorted({s.sip for s in handle.shards})
cube_set = sorted({s.cube for s in handle.shards})
msg = MmuMapMsg(
correlation_id=self.correlation_id,
request_id=f"mmu_{tensor_name}",
entries=entries,
target_sips=tuple(sip_set),
target_cubes=tuple(cube_set),
target_pe="all",
)
h = self.submit(msg)
self.wait(h)
# Submit MemoryWriteMsg per shard (deploy data to device)
if pattern is not None:
+31
@@ -69,6 +69,7 @@ class TensorArgShard:
class TensorArg:
shards: tuple[TensorArgShard, ...]
arg_kind: Literal["tensor"] = "tensor"
va_base: int = 0 # VA base address for the entire tensor
@dataclass(frozen=True)
@@ -121,3 +122,33 @@ class PeDmaMsg:
nbytes: int
is_write: bool = False
msg_type: Literal["pe_dma"] = "pe_dma"
@dataclass(frozen=True)
class MmuMapMsg:
"""MMU mapping install: broadcast VA→PA entries to target PEs.
Sent via fabric: Host → PCIE_EP → IO_CPU → M_CPU → NOC → PE_MMU.
target_sips, target_cubes, and target_pe select which SIPs, cubes, and PEs receive the message.
"""
correlation_id: str
request_id: str
entries: tuple[dict, ...] # ({"va": int, "pa": int, "size": int}, ...)
target_sips: tuple[int, ...] | Literal["all"] = "all"
target_cubes: tuple[int, ...] | Literal["all"] = "all"
target_pe: int | Literal["all"] = "all"
msg_type: Literal["mmu_map"] = "mmu_map"
@dataclass(frozen=True)
class MmuUnmapMsg:
"""MMU mapping removal: broadcast VA ranges to unmap from the targeted PEs."""
correlation_id: str
request_id: str
entries: tuple[dict, ...] # ({"va": int, "size": int}, ...)
target_sips: tuple[int, ...] | Literal["all"] = "all"
target_cubes: tuple[int, ...] | Literal["all"] = "all"
target_pe: int | Literal["all"] = "all"
msg_type: Literal["mmu_unmap"] = "mmu_unmap"
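The two mapping strategies that feed these messages (replicate = one message per cube with local PAs, sharded = one broadcast with every shard) can be sketched in isolation. This is a minimal illustration, not the project's code: `Shard` and `plan_map_messages` are hypothetical names, and `Shard` only mirrors the `TensorShard` fields the diff reads.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass(frozen=True)
class Shard:
    """Hypothetical stand-in for TensorShard (only the fields used here)."""
    sip: int
    cube: int
    pa: int
    nbytes: int
    offset_bytes: int


def plan_map_messages(va_base, shards, replicate):
    """Return a list of (target_sips, target_cubes, entries) tuples.

    replicate=True  -> one plan per (sip, cube), carrying only local shards.
    replicate=False -> a single broadcast plan carrying every shard.
    """
    def entry(s):
        return {"va": va_base + s.offset_bytes, "pa": s.pa, "size": s.nbytes}

    if not replicate:
        sips = tuple(sorted({s.sip for s in shards}))
        cubes = tuple(sorted({s.cube for s in shards}))
        return [(sips, cubes, tuple(entry(s) for s in shards))]

    # Group shards by owning cube so each cube only learns its own PAs.
    groups = defaultdict(list)
    for s in shards:
        groups[(s.sip, s.cube)].append(s)
    return [
        ((sip,), (cube,), tuple(entry(s) for s in group))
        for (sip, cube), group in groups.items()
    ]
```

Each returned tuple corresponds to one `MmuMapMsg`. In the broadcast case the SIP and cube sets are passed independently, so the message appears to address their cross product rather than only the (sip, cube) pairs that hold shards.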
+40 -3
@@ -1,6 +1,7 @@
from __future__ import annotations
import math
import weakref
from dataclasses import dataclass
from typing import Literal
@@ -26,6 +27,7 @@ class TensorHandle:
dtype: str
itemsize: int
shards: tuple[TensorShard, ...]
va_base: int = 0 # VA base address for the entire tensor
@property
def nbytes(self) -> int:
@@ -56,8 +58,19 @@ def deploy_tensor(
placement: list[ShardSpec],
allocators: dict[int, PEMemAllocator],
mem_kind: Literal["hbm", "tcm"] = "hbm",
va_allocator=None,
mmus: dict | None = None,
) -> TensorHandle:
from kernbench.policy.address.pe_mmu import PeMMU
isize = dtype_itemsize(dtype)
total_nbytes = math.prod(shape) * isize
# Allocate VA range for the entire tensor (if VA allocator provided)
va_base = 0
if va_allocator is not None:
va_base = va_allocator.alloc(total_nbytes)
shards: list[TensorShard] = []
for spec in placement:
alloc = allocators[spec.pe_index]
@@ -65,20 +78,29 @@ def deploy_tensor(
pa = alloc.alloc_hbm(spec.nbytes)
else:
pa = alloc.alloc_tcm(spec.nbytes)
encoded_pa = pa.encode()
shards.append(TensorShard(
sip=alloc._sip_id,
cube=alloc._cube_id,
pe=alloc._pe_id,
pa=pa.encode(),
pa=encoded_pa,
nbytes=spec.nbytes,
offset_bytes=spec.offset_bytes,
))
# Register VA→PA mapping in all MMUs (broadcast)
if va_base and mmus is not None:
shard_va = va_base + spec.offset_bytes
for mmu in mmus.values():
mmu.map(va=shard_va, pa=encoded_pa, size=spec.nbytes)
return TensorHandle(
name=name,
shape=shape,
dtype=dtype,
itemsize=isize,
shards=tuple(shards),
va_base=va_base,
)
@@ -101,8 +123,7 @@ class Tensor:
Usage::
a = ctx.zeros((M, K), dtype="f16")
a = ctx.zeros((M, K), dtype="f16", placement=dp.replicate(num_pe=8))
a = ctx.zeros((M, K), dtype="f16", dp=DPPolicy(cube="replicate", pe="replicate"))
ctx.launch("kernel_name", kernel_fn, a, b, out, M=M, K=K)
"""
@@ -117,6 +138,14 @@ class Tensor:
self.name = name
self._dp_metadata: DPMetadata | None = None
self._handle: TensorHandle | None = None
self._ctx_ref: weakref.ref | None = None # set by RuntimeContext
def __del__(self) -> None:
if self._ctx_ref is None or self._handle is None:
return
ctx = self._ctx_ref()
if ctx is not None:
ctx._free_tensor(self)
@property
def itemsize(self) -> int:
@@ -133,6 +162,13 @@ class Tensor:
raise RuntimeError(f"Tensor '{self.name}' is not deployed yet")
return self._handle.shards[0].pa
@property
def va(self) -> int:
"""VA base address for the entire tensor."""
if self._handle is None:
raise RuntimeError(f"Tensor '{self.name}' is not deployed yet")
return self._handle.va_base
def to(
self,
placement: list[ShardSpec] | None = None,
@@ -163,4 +199,5 @@ class Tensor:
)
for s in self._handle.shards
),
va_base=self._handle.va_base,
)
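`deploy_tensor` relies on an MMU object exposing `map(va=, pa=, size=)`, and the tests call `translate` and `unmap` on it. The `PeMMU` source is not part of this excerpt, so the following is only a sketch of how a page-granular table with that interface could behave. `TinyMMU` is a hypothetical name, and the page-alignment requirement is an assumption of this sketch, not a documented property of `PeMMU`.

```python
class PageFault(Exception):
    """Raised when a VA has no mapping (name mirrors the test import)."""


class TinyMMU:
    """Illustrative page-granular VA->PA table; not the kernbench PeMMU."""

    def __init__(self, page_size=4096):
        self.page_size = page_size
        self._table = {}  # virtual page number -> PA of that page's first byte

    def _pages(self, va, size):
        first = va // self.page_size
        last = (va + size - 1) // self.page_size
        return range(first, last + 1)

    def map(self, va, pa, size):
        # Assumption: both addresses are page aligned.
        if va % self.page_size or pa % self.page_size:
            raise ValueError("va and pa must be page aligned in this sketch")
        first = va // self.page_size
        for vpn in self._pages(va, size):
            self._table[vpn] = pa + (vpn - first) * self.page_size

    def unmap(self, va, size):
        for vpn in self._pages(va, size):
            self._table.pop(vpn, None)

    def translate(self, va):
        vpn, offset = divmod(va, self.page_size)
        try:
            return self._table[vpn] + offset
        except KeyError:
            raise PageFault(f"no mapping for VA {va:#x}") from None
```

Storing one dictionary entry per page keeps `translate` a single lookup, at the cost of `map` and `unmap` being linear in the number of pages covered.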
+92
@@ -98,6 +98,16 @@ class GraphEngine:
self._components[node_id].in_ports["host"] = host_q
self._pe_dma_queues[node_id] = host_q
# Wire PE_DMA._mmu to PE_MMU's underlying PeMMU utility object
for node_id, node in graph.nodes.items():
if node.kind == "pe_dma":
# Derive PE_MMU node ID from PE_DMA node ID
pe_prefix = node_id.rsplit(".", 1)[0] # e.g. "sip0.cube0.pe0"
mmu_id = f"{pe_prefix}.pe_mmu"
mmu_comp = self._components.get(mmu_id)
if mmu_comp is not None and hasattr(mmu_comp, "mmu"):
self._components[node_id]._mmu = mmu_comp.mmu
# Start components after all ports are wired (ADR-0015 D3)
for comp in self._components.values():
comp.start(self._env)
@@ -119,6 +129,27 @@ class GraphEngine:
def get_completion(self, handle: RequestHandle) -> tuple[Completion, Trace | None]:
return self._results[str(handle)]
def mmu_map(self, va: int, pa: int, size: int) -> None:
"""Sideband: install VA→PA mapping in all PE_MMU components."""
for comp in self._components.values():
if hasattr(comp, "mmu"):
comp.mmu.map(va=va, pa=pa, size=size)
def mmu_map_pe(
self, sip: int, cube: int, pe: int, va: int, pa: int, size: int,
) -> None:
"""Sideband: install VA→PA mapping in a specific PE's MMU only."""
mmu_id = f"sip{sip}.cube{cube}.pe{pe}.pe_mmu"
comp = self._components.get(mmu_id)
if comp is not None and hasattr(comp, "mmu"):
comp.mmu.map(va=va, pa=pa, size=size)
def mmu_unmap(self, va: int, size: int) -> None:
"""Sideband: remove VA mapping from all PE_MMU components."""
for comp in self._components.values():
if hasattr(comp, "mmu"):
comp.mmu.unmap(va=va, size=size)
# ── internal ────────────────────────────────────────────────────
def _wire(
@@ -166,6 +197,11 @@ class GraphEngine:
yield from self._process_memory_direct(key, request, done)
return
from kernbench.runtime_api.kernel import MmuMapMsg, MmuUnmapMsg
if isinstance(request, (MmuMapMsg, MmuUnmapMsg)):
yield from self._process_mmu_msg(key, request, done)
return
entries = self._entry_points(request)
if not entries:
self._results[key] = (
@@ -341,3 +377,59 @@ class GraphEngine:
return entries
raise ValueError(f"unsupported request type: {type(request)}")
def _process_mmu_msg(self, key: str, request: Any, done: simpy.Event):
"""Route MmuMapMsg/MmuUnmapMsg through fabric like KernelLaunchMsg.
Path: Host → PCIE_EP → IO_NOC → IO_CPU → (fan-out) → M_CPU → (fan-out) → PE_MMU
"""
start_ns = self._env.now
target_sips = getattr(request, "target_sips", "all")
# Determine target SIPs
sip_set: set[int] = set()
if target_sips == "all":
for ep_id in self._resolver.find_all_pcie_eps():
sip_id = int(ep_id.split(".")[0].replace("sip", ""))
sip_set.add(sip_id)
else:
sip_set = set(target_sips)
entries = []
for sip_id in sorted(sip_set):
entries.append((
self._resolver.find_pcie_ep(sip_id),
self._resolver.find_io_cpu(sip_id),
0,  # MMU map/unmap messages carry no data payload
))
if not entries:
self._results[key] = (Completion(ok=True), {"total_ns": 0.0})
done.succeed()
return
if len(entries) == 1:
pcie_ep_id, io_cpu_id, _ = entries[0]
path = self._router.find_node_path(pcie_ep_id, io_cpu_id)
txn_done = self._env.event()
txn = Transaction(request=request, path=path, step=0, nbytes=0, done=txn_done)
yield self._host_queues[pcie_ep_id].put(txn)
yield txn_done
else:
# Multi-SIP fan-out
sub_dones = []
for pcie_ep_id, io_cpu_id, _ in entries:
path = self._router.find_node_path(pcie_ep_id, io_cpu_id)
sub_done = self._env.event()
sub_txn = Transaction(request=request, path=path, step=0, nbytes=0, done=sub_done)
yield self._host_queues[pcie_ep_id].put(sub_txn)
sub_dones.append(sub_done)
for sd in sub_dones:
yield sd
elapsed = self._env.now - start_ns
self._results[key] = (
Completion(ok=True),
{"total_ns": elapsed, "msg_type": request.msg_type},
)
done.succeed()
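The SIP-selection step at the top of `_process_mmu_msg` can be exercised on its own. The helper below is a hypothetical, standalone sketch of that step. It assumes PCIe endpoint IDs start with `sip<N>.` as the parsing in the diff implies; the suffix used in the example IDs is a guess.

```python
def resolve_target_sips(target_sips, pcie_ep_ids):
    """Return the sorted SIP indices a message should be delivered to.

    target_sips is either the string "all" or an iterable of SIP indices.
    pcie_ep_ids are assumed to look like "sip<N>.<rest>".
    """
    if target_sips == "all":
        return sorted(
            {int(ep.split(".")[0].removeprefix("sip")) for ep in pcie_ep_ids}
        )
    # Explicit targets: deduplicate and order for deterministic fan-out.
    return sorted(set(target_sips))
```

For example, `resolve_target_sips("all", ["sip1.pcie_ep", "sip0.pcie_ep"])` yields `[0, 1]`, and an explicit `(1, 1, 0)` collapses to the same list. `str.removeprefix` requires Python 3.9 or newer.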
+11
@@ -22,6 +22,7 @@ _PE_COMP_OFFSETS = {
"pe_dma": (0.0, -0.15),
"pe_gemm": (0.0, 0.0),
"pe_math": (0.0, 0.15),
"pe_mmu": (0.15, -0.15),
"pe_tcm": (0.3, 0.0),
}
@@ -495,6 +496,15 @@ def _instantiate_cube(
kind="pe_response",
))
# noc → PE_MMU (MMU mapping install)
pe_mmu_id = f"{pp}.pe_mmu"
if pe_mmu_id in nodes:
edges.append(Edge(
src=f"{cp}.noc", dst=pe_mmu_id,
distance_mm=clinks.get("noc_to_pe_mmu_mm", 0.0),
kind="command",
))
pe_idx += 1
# ── xbar_top/bot → HBM slices ──
@@ -1073,6 +1083,7 @@ def _build_pe_view(spec: dict) -> ViewGraph:
"pe_dma": (7.0, 1.5),
"pe_gemm": (7.0, 4.0),
"pe_math": (7.0, 6.5),
"pe_mmu": (4.0, 1.5),
"pe_tcm": (10.0, 4.0),
}
+16 -16
@@ -86,11 +86,11 @@ class TLContext:
self._commands.append(PeCpuOverheadCmd(cycles=self._dispatch_cycles))
def _make_handle(
self, pa: int, shape: tuple[int, ...], dtype: str,
self, addr: int, shape: tuple[int, ...], dtype: str,
) -> TensorHandle:
return TensorHandle(
id=self._next_handle_id(),
pa=pa, shape=shape, dtype=dtype,
addr=addr, shape=shape, dtype=dtype,
nbytes=self._nbytes(shape, dtype),
)
@@ -104,7 +104,7 @@ class TLContext:
Used when the scheduler will stream data per-tile (e.g., tensor b
in a composite GEMM). No command is generated.
"""
return self._make_handle(pa=ptr, shape=shape, dtype=dtype)
return self._make_handle(addr=ptr, shape=shape, dtype=dtype)
# ── Data Movement (blocking, DMA engine) ──────────────────────
@@ -113,9 +113,9 @@ class TLContext:
) -> TensorHandle:
"""Load tensor from HBM to TCM. Returns TensorHandle."""
self._emit_dispatch_overhead()
handle = self._make_handle(pa=ptr, shape=shape, dtype=dtype)
handle = self._make_handle(addr=ptr, shape=shape, dtype=dtype)
self._commands.append(DmaReadCmd(
handle=handle, src_pa=ptr, nbytes=handle.nbytes,
handle=handle, src_addr=ptr, nbytes=handle.nbytes,
))
return handle
@@ -123,7 +123,7 @@ class TLContext:
"""Store tensor from TCM to HBM."""
self._emit_dispatch_overhead()
self._commands.append(DmaWriteCmd(
handle=handle, dst_pa=ptr, nbytes=handle.nbytes,
handle=handle, dst_addr=ptr, nbytes=handle.nbytes,
))
# ── GEMM Engine (blocking) ────────────────────────────────────
@@ -141,7 +141,7 @@ class TLContext:
raise ValueError(f"dot shape mismatch: a.K={k} != b.K={k2}")
out_shape = (*a.shape[:-2], m, n)
out_dtype = a.dtype
out = self._make_handle(pa=0, shape=out_shape, dtype=out_dtype)
out = self._make_handle(addr=0, shape=out_shape, dtype=out_dtype)
self._emit_dispatch_overhead()
self._commands.append(GemmCmd(a=a, b=b, out=out, m=m, k=k, n=n))
return out
@@ -149,7 +149,7 @@ class TLContext:
# ── MATH Engine: unary (blocking) ─────────────────────────────
def _unary_math(self, op: str, x: TensorHandle) -> TensorHandle:
out = self._make_handle(pa=0, shape=x.shape, dtype=x.dtype)
out = self._make_handle(addr=0, shape=x.shape, dtype=x.dtype)
self._emit_dispatch_overhead()
self._commands.append(MathCmd(op=op, inputs=(x,), out=out))
return out
@@ -182,7 +182,7 @@ class TLContext:
) -> TensorHandle:
out_shape = list(x.shape)
out_shape[axis] = 1
out = self._make_handle(pa=0, shape=tuple(out_shape), dtype=x.dtype)
out = self._make_handle(addr=0, shape=tuple(out_shape), dtype=x.dtype)
self._emit_dispatch_overhead()
self._commands.append(MathCmd(op=op, inputs=(x,), out=out, axis=axis))
return out
@@ -201,7 +201,7 @@ class TLContext:
def _binary_math(
self, op: str, a: TensorHandle, b: TensorHandle,
) -> TensorHandle:
out = self._make_handle(pa=0, shape=a.shape, dtype=a.dtype)
out = self._make_handle(addr=0, shape=a.shape, dtype=a.dtype)
self._emit_dispatch_overhead()
self._commands.append(MathCmd(op=op, inputs=(a, b), out=out))
return out
@@ -209,7 +209,7 @@ class TLContext:
def where(
self, cond: TensorHandle, a: TensorHandle, b: TensorHandle,
) -> TensorHandle:
out = self._make_handle(pa=0, shape=a.shape, dtype=a.dtype)
out = self._make_handle(addr=0, shape=a.shape, dtype=a.dtype)
self._emit_dispatch_overhead()
self._commands.append(MathCmd(op="where", inputs=(cond, a, b), out=out))
return out
@@ -227,17 +227,17 @@ class TLContext:
def arange(self, start: int, end: int, dtype: str = "i32") -> TensorHandle:
"""Create index range tensor in TCM."""
n = end - start
return self._make_handle(pa=0, shape=(n,), dtype=dtype)
return self._make_handle(addr=0, shape=(n,), dtype=dtype)
def zeros(self, shape: tuple[int, ...], dtype: str = "f16") -> TensorHandle:
"""Create zero-filled tensor in TCM."""
return self._make_handle(pa=0, shape=shape, dtype=dtype)
return self._make_handle(addr=0, shape=shape, dtype=dtype)
def full(
self, shape: tuple[int, ...], value: float | int, dtype: str = "f16",
) -> TensorHandle:
"""Create constant-filled tensor in TCM."""
return self._make_handle(pa=0, shape=shape, dtype=dtype)
return self._make_handle(addr=0, shape=shape, dtype=dtype)
# ── Metadata (no compute, no DMA) ─────────────────────────────
@@ -247,7 +247,7 @@ class TLContext:
raise ValueError("trans requires at least 2D tensor")
new_shape = (*x.shape[:-2], x.shape[-1], x.shape[-2])
return TensorHandle(
id=x.id, pa=x.pa, shape=new_shape,
id=x.id, addr=x.addr, shape=new_shape,
dtype=x.dtype, nbytes=x.nbytes, data=x.data,
)
@@ -278,7 +278,7 @@ class TLContext:
self._emit_dispatch_overhead()
self._commands.append(CompositeCmd(
completion=completion, op=op,
a=a, b=b, out_pa=out_ptr, out_nbytes=out_nbytes,
a=a, b=b, out_addr=out_ptr, out_nbytes=out_nbytes,
math_op=math_op,
))
return completion
+226
@@ -0,0 +1,226 @@
"""Tests for PE_MMU component integration and MmuMapMsg fabric path.
Validates:
T15-a. PE_MMU component registered in ComponentRegistry
T15-b. PE_MMU component receives MmuMapMsg via inbox, updates page table
T15-c. PE_DMA translates VA→PA via mmu before routing
T16. MmuMapMsg/MmuUnmapMsg message types defined with correct fields
T17. PE_CPU passes VA (not PA) to kernel when VA is available
T18. End-to-end: deploy (MmuMapMsg broadcast) → kernel launch → DMA with VA
"""
import pytest
try:
from kernbench.runtime_api.kernel import MmuMapMsg, MmuUnmapMsg
except ImportError:
pytest.skip("MmuMapMsg/MmuUnmapMsg not yet defined (Phase 2)", allow_module_level=True)
# ── T16. MmuMapMsg / MmuUnmapMsg message types ──────────────────────
def test_mmu_map_msg_fields():
"""MmuMapMsg carries VA→PA mapping entries for broadcast."""
msg = MmuMapMsg(
correlation_id="c0",
request_id="r0",
entries=(
{"va": 0x1_0000_0000, "pa": 0xA000_0000, "size": 4096},
{"va": 0x1_0000_1000, "pa": 0xB000_0000, "size": 4096},
),
target_cubes="all",
target_pe="all",
)
assert msg.msg_type == "mmu_map"
assert len(msg.entries) == 2
assert msg.entries[0]["va"] == 0x1_0000_0000
assert msg.entries[0]["pa"] == 0xA000_0000
assert msg.entries[0]["size"] == 4096
def test_mmu_map_msg_immutable():
"""MmuMapMsg is frozen."""
msg = MmuMapMsg(
correlation_id="c0",
request_id="r0",
entries=(),
target_cubes="all",
target_pe="all",
)
with pytest.raises(AttributeError):
msg.entries = () # type: ignore[misc]
def test_mmu_unmap_msg_fields():
"""MmuUnmapMsg carries VA ranges to unmap."""
msg = MmuUnmapMsg(
correlation_id="c0",
request_id="r0",
entries=(
{"va": 0x1_0000_0000, "size": 4096},
),
target_cubes="all",
target_pe="all",
)
assert msg.msg_type == "mmu_unmap"
assert len(msg.entries) == 1
assert msg.entries[0]["va"] == 0x1_0000_0000
# ── T15-a. PE_MMU component registry ────────────────────────────────
def test_pe_mmu_registry():
"""pe_mmu_v1 impl resolves in ComponentRegistry."""
from kernbench.components.base import ComponentRegistry
from kernbench.components.impls.pe_mmu import PeMmuComponent
from kernbench.topology.types import Node
node = Node(
id="sip0.cube0.pe0.pe_mmu",
kind="pe_mmu",
impl="pe_mmu_v1",
pos_mm=None,
attrs={"tlb_overhead_ns": 0.5},
)
comp = ComponentRegistry.create(node)
assert isinstance(comp, PeMmuComponent)
# ── T15-b. PE_MMU receives MmuMapMsg and updates page table ─────────
def test_pe_mmu_processes_map_msg():
"""PE_MMU component receives MmuMapMsg → translate works."""
import simpy
from kernbench.components.impls.pe_mmu import PeMmuComponent
from kernbench.sim_engine.transaction import Transaction
from kernbench.topology.types import Node
env = simpy.Environment()
node = Node(
id="sip0.cube0.pe0.pe_mmu",
kind="pe_mmu",
impl="pe_mmu_v1",
pos_mm=None,
attrs={"tlb_overhead_ns": 0.5, "page_size": 4096},
)
comp = PeMmuComponent(node)
comp.in_ports["src"] = simpy.Store(env)
comp.start(env)
# Submit MmuMapMsg via inbox
map_msg = MmuMapMsg(
correlation_id="c0",
request_id="r0",
entries=(
{"va": 0x1_0000_0000, "pa": 0xABCD_0000, "size": 4096},
),
target_cubes="all",
target_pe="all",
)
done = env.event()
txn = Transaction(
request=map_msg,
path=["sip0.cube0.pe0.pe_mmu"],
step=0, nbytes=0, done=done,
)
def inject():
yield comp._inbox.put(txn)
env.process(inject())
env.run(until=100)
# After processing, the MMU's translate should work
from kernbench.policy.address.pe_mmu import PeMMU
mmu = comp.mmu # the underlying PeMMU utility object
assert isinstance(mmu, PeMMU)
assert mmu.translate(0x1_0000_0000) == 0xABCD_0000
# ── T15-c. PE_DMA uses MMU translate ────────────────────────────────
def test_pe_dma_translates_va():
"""PE_DMA.handle_command calls mmu.translate(va) → PA before routing.
This test validates the contract: after Phase 2, DmaReadCmd carries VA,
and PE_DMA must translate it to PA via the MMU before resolving the
HBM node path.
"""
# This test validates the interface contract only. A full integration test
# requires the engine wiring, which is validated in test_engine.
# Here we check that PE_DMA has an mmu attribute it can call.
from kernbench.components.impls.pe_dma import PeDmaComponent
from kernbench.topology.types import Node
node = Node(
id="sip0.cube0.pe0.pe_dma",
kind="pe_dma",
impl="pe_dma_v1",
pos_mm=None,
attrs={"rd_engines": 1, "wr_engines": 1},
)
comp = PeDmaComponent(node)
# PE_DMA must have a way to access the MMU (via ctx or direct reference)
# The exact wiring mechanism is flexible, but the attribute must exist
assert hasattr(comp, '_mmu') or hasattr(comp, 'mmu') or (
hasattr(comp, 'ctx') and comp.ctx is not None
), "PE_DMA must have access to PE_MMU for VA translation"
# ── T17. PE_CPU passes VA to kernel ──────────────────────────────────
def test_pe_cpu_uses_va_base_from_tensor_arg():
"""PE_CPU should use TensorArg.va_base for kernel pointer args.
After Phase 2, TensorArg carries va_base alongside shards.
PE_CPU extracts va_base and passes it to the kernel function
so kernels operate on VA (not PA).
"""
from kernbench.runtime_api.kernel import TensorArg, TensorArgShard
shard = TensorArgShard(sip=0, cube=0, pe=0, pa=0x1000,
nbytes=4096, offset_bytes=0)
targ = TensorArg(shards=(shard,), va_base=0x1_0000_0000)
# PE_CPU should use targ.va_base for kernel pointer arg
assert targ.va_base == 0x1_0000_0000
# PA still accessible via shard for direct-PA operations (IPCQ etc.)
assert shard.pa == 0x1000
# ── T18. MmuMapMsg broadcast pattern ─────────────────────────────────
def test_mmu_map_msg_broadcast_target():
"""MmuMapMsg with target_pe='all' is a broadcast to all PEs."""
msg = MmuMapMsg(
correlation_id="c0",
request_id="r0",
entries=({"va": 0x1000, "pa": 0x2000, "size": 4096},),
target_cubes="all",
target_pe="all",
)
assert msg.target_pe == "all"
assert msg.target_cubes == "all"
def test_mmu_map_msg_same_entries_all_pes():
"""All PEs in a broadcast receive identical entries (not per-PE splits)."""
entries = (
{"va": 0x1_0000_0000, "pa": 0xA000_0000, "size": 8192},
{"va": 0x1_0000_2000, "pa": 0xB000_0000, "size": 8192},
)
msg = MmuMapMsg(
correlation_id="c0",
request_id="r0",
entries=entries,
target_cubes="all",
target_pe="all",
)
# The message carries the full mapping — every PE receives exactly this
assert msg.entries == entries
+241
@@ -0,0 +1,241 @@
"""Tests for MmuMapMsg fabric path and cross-cube mapping.
Validates:
F1. MmuMapMsg traverses fabric: latency > 0 (not sideband)
F2. MmuMapMsg fan-out: IO_CPU → cubes, M_CPU → PEs
F3. After MmuMapMsg, PE_MMU has correct mappings
F4. Cross-cube sharded tensor: all PEs get global mappings
F5. Replicate tensor: each PE gets own cube's PA (local override)
F6. Cross-cube DMA after sharded mapping: PE can access remote cube's HBM
F7. Overlap detection: replicate vs sharded identified correctly
F8. Existing regression: PA-only benchmarks still pass
"""
import pytest
from pathlib import Path
from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator
from kernbench.policy.address.pe_mmu import PeMMU
from kernbench.policy.address.va_allocator import VirtualAllocator
from kernbench.policy.placement.dp import column_wise, replicate, ShardSpec
from kernbench.runtime_api.tensor import deploy_tensor, TensorHandle
from kernbench.sim_engine.engine import GraphEngine
from kernbench.topology.builder import load_topology
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
_MB = 1 << 20
_GB = 1 << 30
_CFG = AddressConfig(
sip_count=2,
cubes_per_sip=16,
pes_per_cube=8,
hbm_bytes_per_cube=48 * _GB,
hbm_slices_per_cube=8,
tcm_bytes_per_pe=16 * _MB,
tcm_scheduler_reserved_bytes=4 * _MB,
sram_bytes_per_cube=32 * _MB,
)
def _engine():
return GraphEngine(load_topology(TOPOLOGY_PATH))
# ── F1. MmuMapMsg fabric latency ─────────────────────────────────────
def test_mmu_map_via_fabric_has_latency():
"""MmuMapMsg submitted through engine.submit() completes with latency > 0."""
from kernbench.runtime_api.kernel import MmuMapMsg
engine = _engine()
msg = MmuMapMsg(
correlation_id="c0",
request_id="mmu_map_0",
entries=({"va": 0x1_0000_0000, "pa": 0x2000_0000, "size": 4096},),
target_cubes=(0,),
target_pe="all",
)
h = engine.submit(msg)
engine.wait(h)
comp, trace = engine.get_completion(h)
assert comp.ok is True
# Fabric traversal must have non-zero latency
assert trace is not None
assert trace.get("total_ns", 0) > 0
# ── F2. MmuMapMsg fan-out ────────────────────────────────────────────
def test_mmu_map_reaches_all_pes_in_cube():
"""MmuMapMsg with target_pe='all' installs mapping in all 8 PE_MMUs of target cube."""
from kernbench.runtime_api.kernel import MmuMapMsg
engine = _engine()
va, pa, size = 0x1_0000_0000, 0xABCD_0000, 4096
msg = MmuMapMsg(
correlation_id="c0",
request_id="mmu_map_1",
entries=({"va": va, "pa": pa, "size": size},),
target_cubes=(0,),
target_pe="all",
)
h = engine.submit(msg)
engine.wait(h)
# Verify all 8 PE_MMUs in cube 0 have the mapping
for pe_id in range(8):
mmu_id = f"sip0.cube0.pe{pe_id}.pe_mmu"
mmu_comp = engine._components[mmu_id]
assert mmu_comp.mmu.translate(va) == pa
# ── F3. Multiple MmuMapMsg entries ───────────────────────────────────
def test_mmu_map_multiple_entries():
"""MmuMapMsg with multiple entries installs all of them."""
from kernbench.runtime_api.kernel import MmuMapMsg
engine = _engine()
entries = (
{"va": 0x1_0000_0000, "pa": 0xA000_0000, "size": 4096},
{"va": 0x1_0000_1000, "pa": 0xB000_0000, "size": 4096},
)
msg = MmuMapMsg(
correlation_id="c0",
request_id="mmu_map_2",
entries=entries,
target_cubes=(0,),
target_pe="all",
)
h = engine.submit(msg)
engine.wait(h)
mmu_comp = engine._components["sip0.cube0.pe0.pe_mmu"]
assert mmu_comp.mmu.translate(0x1_0000_0000) == 0xA000_0000
assert mmu_comp.mmu.translate(0x1_0000_1000) == 0xB000_0000
# ── F4. Cross-cube sharded: global mapping ───────────────────────────
def test_cross_cube_sharded_all_pes_get_global_mapping():
"""For sharded tensor across cubes (unique offsets), all PEs get all mappings."""
from kernbench.runtime_api.kernel import MmuMapMsg
engine = _engine()
# Simulate 2-cube shard: cube0 has offset=0, cube1 has offset=4096
entries = (
{"va": 0x1_0000_0000, "pa": 0xA000_0000, "size": 4096}, # cube0
{"va": 0x1_0000_1000, "pa": 0xB000_0000, "size": 4096}, # cube1
)
# Broadcast to both cubes
msg = MmuMapMsg(
correlation_id="c0",
request_id="mmu_map_xc",
entries=entries,
target_cubes=(0, 1),
target_pe="all",
)
h = engine.submit(msg)
engine.wait(h)
# PE in cube0 can translate both cube0 and cube1 addresses
mmu_c0 = engine._components["sip0.cube0.pe0.pe_mmu"]
assert mmu_c0.mmu.translate(0x1_0000_0000) == 0xA000_0000 # local
assert mmu_c0.mmu.translate(0x1_0000_1000) == 0xB000_0000 # remote
# PE in cube1 can also translate both
mmu_c1 = engine._components["sip0.cube1.pe0.pe_mmu"]
assert mmu_c1.mmu.translate(0x1_0000_0000) == 0xA000_0000 # remote
assert mmu_c1.mmu.translate(0x1_0000_1000) == 0xB000_0000 # local
# ── F5. Replicate: local PA override ─────────────────────────────────
def test_replicate_local_pa_override():
"""For replicated tensor (same VA range), each cube's PEs see local PA."""
from kernbench.runtime_api.kernel import MmuMapMsg
engine = _engine()
va, size = 0x1_0000_0000, 4096
# Cube 0 gets its own PA
msg0 = MmuMapMsg(
correlation_id="c0",
request_id="mmu_rep_c0",
entries=({"va": va, "pa": 0xA000_0000, "size": size},),
target_cubes=(0,),
target_pe="all",
)
h0 = engine.submit(msg0)
engine.wait(h0)
# Cube 1 gets a different PA for the same VA
msg1 = MmuMapMsg(
correlation_id="c0",
request_id="mmu_rep_c1",
entries=({"va": va, "pa": 0xB000_0000, "size": size},),
target_cubes=(1,),
target_pe="all",
)
h1 = engine.submit(msg1)
engine.wait(h1)
# Cube 0 PEs translate to cube 0's PA
mmu_c0 = engine._components["sip0.cube0.pe0.pe_mmu"]
assert mmu_c0.mmu.translate(va) == 0xA000_0000
# Cube 1 PEs translate to cube 1's PA
mmu_c1 = engine._components["sip0.cube1.pe0.pe_mmu"]
assert mmu_c1.mmu.translate(va) == 0xB000_0000
# ── F7. Overlap detection ────────────────────────────────────────────
def test_detect_overlapping_shards():
"""Utility: detect if shards have overlapping VA ranges (replicate indicator)."""
from kernbench.runtime_api.tensor import TensorShard
# Sharded: unique offsets
sharded = [
TensorShard(sip=0, cube=0, pe=0, pa=0x100, nbytes=4096, offset_bytes=0),
TensorShard(sip=0, cube=0, pe=1, pa=0x200, nbytes=4096, offset_bytes=4096),
]
offsets = [(s.offset_bytes, s.nbytes) for s in sharded]
assert len(set(offsets)) == len(offsets), "Sharded should have unique offsets"
# Replicated: same offset
replicated = [
TensorShard(sip=0, cube=0, pe=0, pa=0x100, nbytes=4096, offset_bytes=0),
TensorShard(sip=0, cube=1, pe=0, pa=0x200, nbytes=4096, offset_bytes=0),
]
offsets_r = [(s.offset_bytes, s.nbytes) for s in replicated]
assert len(set(offsets_r)) < len(offsets_r), "Replicate should have duplicate offsets"
# ── F8. Regression: existing benchmarks still pass ───────────────────
def test_qkv_gemm_still_passes():
"""QKV GEMM benchmark completes successfully with VA/MMU enabled."""
from kernbench.runtime_api.context import RuntimeContext
from kernbench.runtime_api.types import BenchResult, DeviceSelector
graph = load_topology(TOPOLOGY_PATH)
engine = GraphEngine(graph)
ctx = RuntimeContext(
engine=engine,
target_device=DeviceSelector("sip:0"),
correlation_id="test_regression",
spec=graph.spec,
)
from benches.qkv_gemm import run as bench_run
bench_run(ctx)
ctx.wait_all()
# If we get here without an exception, the benchmark succeeded
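The F7 test above treats duplicate `(offset_bytes, nbytes)` pairs as the replicate indicator. A more general check, which also catches partially overlapping ranges, could look like the sketch below. `ranges_overlap` is a hypothetical helper and not part of kernbench.

```python
def ranges_overlap(ranges):
    """Return True if any two (offset, nbytes) ranges share at least one byte."""
    prev_end = None
    # Sorting by offset means only the furthest end seen so far matters.
    for start, nbytes in sorted(ranges):
        if prev_end is not None and start < prev_end:
            return True
        end = start + nbytes
        prev_end = end if prev_end is None else max(prev_end, end)
    return False
```

With the shard layouts from the test, the sharded offsets `[(0, 4096), (4096, 4096)]` report no overlap, while the replicated offsets `[(0, 4096), (0, 4096)]` do.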
+6 -6
View File
@@ -308,9 +308,9 @@ def test_pe_gemm_handles_pe_internal_txn():
gemm.in_ports["src"] = simpy.Store(env)
gemm.start(env)
a = TensorHandle(id="t1", pa=0, shape=(4, 8), dtype="f16", nbytes=64)
b = TensorHandle(id="t2", pa=0, shape=(8, 4), dtype="f16", nbytes=64)
out = TensorHandle(id="t3", pa=0, shape=(4, 4), dtype="f16", nbytes=32)
a = TensorHandle(id="t1", addr=0, shape=(4, 8), dtype="f16", nbytes=64)
b = TensorHandle(id="t2", addr=0, shape=(8, 4), dtype="f16", nbytes=64)
out = TensorHandle(id="t3", addr=0, shape=(4, 4), dtype="f16", nbytes=32)
cmd = GemmCmd(a=a, b=b, out=out, m=4, k=8, n=4)
done = env.event()
pe_txn = PeInternalTxn(command=cmd, done=done, pe_prefix=pe_prefix)
@@ -349,8 +349,8 @@ def test_pe_math_handles_pe_internal_txn():
math_comp.in_ports["src"] = simpy.Store(env)
math_comp.start(env)
x = TensorHandle(id="t1", pa=0, shape=(4, 4), dtype="f16", nbytes=32)
out = TensorHandle(id="t2", pa=0, shape=(4, 4), dtype="f16", nbytes=32)
x = TensorHandle(id="t1", addr=0, shape=(4, 4), dtype="f16", nbytes=32)
out = TensorHandle(id="t2", addr=0, shape=(4, 4), dtype="f16", nbytes=32)
cmd = MathCmd(op="exp", inputs=(x,), out=out)
done = env.event()
pe_txn = PeInternalTxn(command=cmd, done=done, pe_prefix=pe_prefix)
@@ -777,7 +777,7 @@ def test_tl_ref_no_dma():
tl = TLContext(pe_id=0, dispatch_cycles=0)
handle = tl.ref(0x1000, shape=(4, 4), dtype="f16")
assert handle.pa == 0x1000
assert handle.addr == 0x1000
assert handle.shape == (4, 4)
assert len(tl.commands) == 0, f"tl.ref should emit 0 commands, got {len(tl.commands)}"
+203
View File
@@ -0,0 +1,203 @@
"""Tests for PeMMU: per-PE virtual-to-physical address translation.
Validates:
T1. Basic map + translate
T2. Page-aligned dict lookup (O(1), multi-page range)
T3. Multiple tensor mappings accumulate
T4. unmap removes entries, translate raises PageFault
T5. PageFault on unmapped VA
T6. Identical mapping broadcast across multiple PEs
"""
import pytest
from kernbench.policy.address.pe_mmu import PageFault, PeMMU
_2MB = 2 * 1024 * 1024
# ── T1. Basic map + translate ────────────────────────────────────────
def test_map_and_translate_basic():
"""map(va, pa, size) → translate(va) returns pa; offset preserved."""
mmu = PeMMU(page_size=4096)
mmu.map(va=0x1_0000_0000, pa=0xABCD_0000, size=4096)
assert mmu.translate(0x1_0000_0000) == 0xABCD_0000
def test_translate_preserves_offset():
"""translate(va + offset) returns pa + offset within a page."""
mmu = PeMMU(page_size=4096)
mmu.map(va=0x1_0000_0000, pa=0xABCD_0000, size=4096)
assert mmu.translate(0x1_0000_0100) == 0xABCD_0100
assert mmu.translate(0x1_0000_0FFF) == 0xABCD_0FFF
# ── T2. Page-aligned dict lookup (multi-page) ───────────────────────
def test_multi_page_mapping():
"""8 MB mapping with 2 MB pages → 4 page entries, all translate correctly."""
mmu = PeMMU(page_size=_2MB)
va_base = 0x1_0000_0000
pa_base = 0x2_0000_0000
size = 8 * 1024 * 1024 # 8 MB = 4 pages
mmu.map(va=va_base, pa=pa_base, size=size)
# First page
assert mmu.translate(va_base) == pa_base
# Second page start
assert mmu.translate(va_base + _2MB) == pa_base + _2MB
# Third page with offset
assert mmu.translate(va_base + 2 * _2MB + 0x100) == pa_base + 2 * _2MB + 0x100
# Last byte of last page
assert mmu.translate(va_base + size - 1) == pa_base + size - 1
def test_page_table_entry_count():
"""Mapping N bytes with page_size P creates ceil(N/P) entries."""
mmu = PeMMU(page_size=_2MB)
mmu.map(va=0x1000_0000, pa=0x2000_0000, size=8 * 1024 * 1024)
assert mmu.num_entries == 4
mmu.map(va=0x2000_0000, pa=0x3000_0000, size=_2MB)
assert mmu.num_entries == 5
# ── T3. Multiple tensor mappings accumulate ──────────────────────────
def test_multiple_mappings_accumulate():
"""Two non-overlapping tensors → both translate correctly."""
mmu = PeMMU(page_size=4096)
# Tensor A
mmu.map(va=0x1_0000_0000, pa=0xA000_0000, size=4096)
# Tensor B (different VA range)
mmu.map(va=0x1_0001_0000, pa=0xB000_0000, size=4096)
assert mmu.translate(0x1_0000_0000) == 0xA000_0000
assert mmu.translate(0x1_0001_0000) == 0xB000_0000
def test_mappings_do_not_interfere():
"""Adjacent VA ranges map to completely independent PA ranges."""
mmu = PeMMU(page_size=4096)
mmu.map(va=0x0000_0000, pa=0xFFFF_0000, size=4096)
mmu.map(va=0x0000_1000, pa=0x0000_0000, size=4096)
assert mmu.translate(0x0000_0000) == 0xFFFF_0000
assert mmu.translate(0x0000_1000) == 0x0000_0000
# ── T4. unmap removes entries ────────────────────────────────────────
def test_unmap_removes_mapping():
"""After unmap, translate raises PageFault."""
mmu = PeMMU(page_size=4096)
mmu.map(va=0x1_0000_0000, pa=0xABCD_0000, size=4096)
assert mmu.translate(0x1_0000_0000) == 0xABCD_0000
mmu.unmap(va=0x1_0000_0000, size=4096)
with pytest.raises(PageFault):
mmu.translate(0x1_0000_0000)
def test_unmap_partial_range():
"""Unmap only part of a multi-page mapping; rest stays valid."""
mmu = PeMMU(page_size=4096)
mmu.map(va=0x1000_0000, pa=0x2000_0000, size=8192) # 2 pages
assert mmu.num_entries == 2
# Unmap first page only
mmu.unmap(va=0x1000_0000, size=4096)
assert mmu.num_entries == 1
with pytest.raises(PageFault):
mmu.translate(0x1000_0000)
# Second page still valid
assert mmu.translate(0x1000_1000) == 0x2000_1000
def test_unmap_does_not_affect_other_mappings():
"""Unmapping tensor A does not affect tensor B."""
mmu = PeMMU(page_size=4096)
mmu.map(va=0x1_0000_0000, pa=0xA000_0000, size=4096)
mmu.map(va=0x1_0001_0000, pa=0xB000_0000, size=4096)
mmu.unmap(va=0x1_0000_0000, size=4096)
with pytest.raises(PageFault):
mmu.translate(0x1_0000_0000)
assert mmu.translate(0x1_0001_0000) == 0xB000_0000
# ── T5. PageFault on unmapped VA ─────────────────────────────────────
def test_pagefault_on_unmapped_va():
"""translate() on never-mapped VA raises PageFault."""
mmu = PeMMU(page_size=4096)
with pytest.raises(PageFault):
mmu.translate(0xDEAD_BEEF)
def test_pagefault_contains_va():
"""PageFault exception carries the faulting VA."""
mmu = PeMMU(page_size=4096)
with pytest.raises(PageFault, match="0xdeadbeef"):
mmu.translate(0xDEAD_BEEF)
# ── T6. Identical mapping broadcast across PEs ───────────────────────
def test_broadcast_same_mapping_to_all_pes():
"""All PEs receive the same full mapping → identical translate results."""
entries = [
(0x1_0000_0000, 0xA000_0000, 4096), # shard 0
(0x1_0000_1000, 0xB000_0000, 4096), # shard 1
(0x1_0000_2000, 0xC000_0000, 4096), # shard 2
(0x1_0000_3000, 0xD000_0000, 4096), # shard 3
]
num_pes = 8
mmus = [PeMMU(page_size=4096) for _ in range(num_pes)]
# Broadcast: every PE gets the same entries
for mmu in mmus:
for va, pa, size in entries:
mmu.map(va=va, pa=pa, size=size)
# All PEs translate identically
for mmu in mmus:
assert mmu.translate(0x1_0000_0000) == 0xA000_0000
assert mmu.translate(0x1_0000_1000) == 0xB000_0000
assert mmu.translate(0x1_0000_2000) == 0xC000_0000
assert mmu.translate(0x1_0000_3000) == 0xD000_0000
def test_cross_pe_access_via_broadcast():
"""PE0 can translate a VA that maps to PE3's HBM PA (cross-PE DMA scenario)."""
mmu_pe0 = PeMMU(page_size=4096)
# Full mapping includes PE3's shard
mmu_pe0.map(va=0x1_0000_0000, pa=0xA000_0000, size=4096) # PE0 shard
mmu_pe0.map(va=0x1_0000_1000, pa=0xD000_0000, size=4096) # PE3 shard
# PE0 accesses PE3's region → valid translation
pa = mmu_pe0.translate(0x1_0000_1000 + 0x100)
assert pa == 0xD000_0100
# ── TLB overhead attribute ───────────────────────────────────────────
def test_tlb_overhead_default():
"""Default overhead_ns (the per-translation TLB overhead) is 0."""
mmu = PeMMU()
assert mmu.overhead_ns == 0.0
def test_tlb_overhead_configurable():
"""overhead_ns (the per-translation TLB overhead) is configurable via the constructor."""
mmu = PeMMU(overhead_ns=0.5)
assert mmu.overhead_ns == 0.5
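The T1 to T5 behaviours above (page-granular entries, offset preserved within a page, a fault on unmapped VAs) can be illustrated with a minimal sketch. This is not the `PeMMU` implementation from this commit: `SketchMMU` and `SketchPageFault` are hypothetical names, and the sketch assumes `va` and `pa` are page-aligned when mapped.

```python
class SketchPageFault(Exception):
    """Hypothetical stand-in for PageFault: the VA has no page-table entry."""


class SketchMMU:
    """Hypothetical page table: dict of virtual page number -> PA of that page."""

    def __init__(self, page_size: int = 4096) -> None:
        self.page_size = page_size
        self._table: dict[int, int] = {}

    @property
    def num_entries(self) -> int:
        return len(self._table)

    def _pages(self, size: int) -> int:
        # ceil(size / page_size) without floating point.
        return -(-size // self.page_size)

    def map(self, va: int, pa: int, size: int) -> None:
        # Assumption: va and pa are page-aligned. One entry per page covered.
        first_vpn = va // self.page_size
        for i in range(self._pages(size)):
            self._table[first_vpn + i] = pa + i * self.page_size

    def unmap(self, va: int, size: int) -> None:
        # Removing a page that was never mapped is treated as a no-op here.
        first_vpn = va // self.page_size
        for i in range(self._pages(size)):
            self._table.pop(first_vpn + i, None)

    def translate(self, va: int) -> int:
        # Split the VA into page number and in-page offset, then rebase.
        vpn, offset = divmod(va, self.page_size)
        try:
            return self._table[vpn] + offset
        except KeyError:
            raise SketchPageFault(f"no mapping for VA {va:#x}") from None
```

The dict keyed by page number gives a constant-time lookup per access, which is the property T2 names; whether the real component stores entries the same way is not shown in this diff.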
+193
View File
@@ -0,0 +1,193 @@
"""Tests for tensor free: del-based + context manager cleanup.
Validates:
TF1. PEMemAllocator.free_hbm/free_tcm reclaim space
TF2. del tensor triggers cleanup (VA/PA returned, MMU unmapped)
TF3. Context manager cleans up all tensors on exit
TF4. del after context exit is safe (no crash)
TF5. Alloc-del-alloc cycle reuses VA and PA
TF6. del already-freed tensor is safe (no crash)
"""
import gc
import pytest
from pathlib import Path
from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator
from kernbench.policy.address.pe_mmu import PageFault
from kernbench.policy.placement.dp import DPPolicy
from kernbench.sim_engine.engine import GraphEngine
from kernbench.runtime_api.context import RuntimeContext
from kernbench.runtime_api.types import DeviceSelector
from kernbench.topology.builder import load_topology
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
_MB = 1 << 20
_GB = 1 << 30
_CFG = AddressConfig(
sip_count=2,
cubes_per_sip=16,
pes_per_cube=8,
hbm_bytes_per_cube=48 * _GB,
hbm_slices_per_cube=8,
tcm_bytes_per_pe=16 * _MB,
tcm_scheduler_reserved_bytes=4 * _MB,
sram_bytes_per_cube=32 * _MB,
)
def _make_ctx():
graph = load_topology(TOPOLOGY_PATH)
engine = GraphEngine(graph)
ctx = RuntimeContext(
engine=engine,
target_device=DeviceSelector("sip:0"),
correlation_id="test_free",
spec=graph.spec,
)
return ctx, engine
# ── TF1. PEMemAllocator free ─────────────────────────────────────────
def test_allocator_free_hbm_reclaims_space():
"""free_hbm returns HBM space; subsequent alloc can reuse it."""
a = PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=0, cfg=_CFG)
pa1 = a.alloc_hbm(4096)
used_after_alloc = a.hbm_used
a.free_hbm(pa1, 4096)
assert a.hbm_used == used_after_alloc - 4096
pa2 = a.alloc_hbm(4096)
assert pa2 is not None
def test_allocator_free_tcm_reclaims_space():
"""free_tcm returns TCM space."""
a = PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=0, cfg=_CFG)
pa1 = a.alloc_tcm(256)
used_after_alloc = a.tcm_used
a.free_tcm(pa1, 256)
assert a.tcm_used == used_after_alloc - 256
# ── TF2. del tensor triggers cleanup ─────────────────────────────────
def test_del_tensor_unmaps_mmu():
"""del tensor removes MMU mappings."""
ctx, engine = _make_ctx()
dp = DPPolicy(cube="replicate", pe="replicate")
t = ctx.zeros((128, 128), dtype="f16", dp=dp, name="del_test")
va_base = t._handle.va_base
# Verify mapping exists
mmu_comp = engine._components["sip0.cube0.pe0.pe_mmu"]
assert mmu_comp.mmu.translate(va_base) is not None
# Delete tensor
del t
gc.collect()
# Mapping should be gone
with pytest.raises(PageFault):
mmu_comp.mmu.translate(va_base)
def test_del_tensor_reclaims_va():
"""del tensor returns VA space for reuse."""
ctx, engine = _make_ctx()
dp = DPPolicy(cube="replicate", pe="replicate")
t1 = ctx.zeros((128, 128), dtype="f16", dp=dp, name="va_reuse1")
va1 = t1._handle.va_base
del t1
gc.collect()
t2 = ctx.zeros((128, 128), dtype="f16", dp=dp, name="va_reuse2")
va2 = t2._handle.va_base
assert va2 == va1
# ── TF3. Context manager cleanup ─────────────────────────────────────
def test_context_manager_cleans_all():
"""Exiting context manager cleans up all tensors."""
graph = load_topology(TOPOLOGY_PATH)
engine = GraphEngine(graph)
with RuntimeContext(
engine=engine,
target_device=DeviceSelector("sip:0"),
correlation_id="ctx_mgr",
spec=graph.spec,
) as ctx:
dp = DPPolicy(cube="replicate", pe="replicate")
t = ctx.zeros((128, 128), dtype="f16", dp=dp, name="ctx_mgr_test")
va_base = t._handle.va_base
# After context exit, MMU mappings should be cleared
mmu_comp = engine._components["sip0.cube0.pe0.pe_mmu"]
with pytest.raises(PageFault):
mmu_comp.mmu.translate(va_base)
# ── TF4. del after context exit is safe ───────────────────────────────
def test_del_after_context_exit_no_crash():
"""del tensor after context manager exit does not crash."""
graph = load_topology(TOPOLOGY_PATH)
engine = GraphEngine(graph)
ctx = RuntimeContext(
engine=engine,
target_device=DeviceSelector("sip:0"),
correlation_id="safe_del",
spec=graph.spec,
)
dp = DPPolicy(cube="replicate", pe="replicate")
t = ctx.zeros((128, 128), dtype="f16", dp=dp, name="safe_del_test")
# Simulate context going away
ctx.cleanup()
# del should not crash even though context already cleaned up
del t
gc.collect()
# No exception = pass
# ── TF5. Alloc-del-alloc cycle ────────────────────────────────────────
def test_alloc_del_cycle():
"""Multiple alloc-del cycles work correctly."""
ctx, engine = _make_ctx()
dp = DPPolicy(cube="replicate", pe="replicate")
for i in range(3):
t = ctx.zeros((128, 128), dtype="f16", dp=dp, name=f"cycle_{i}")
assert t._handle is not None
assert t._handle.va_base > 0
del t
gc.collect()
# ── TF6. del already-cleaned tensor is safe ──────────────────────────
def test_del_already_cleaned_tensor_no_crash():
"""del on a tensor whose handle is already None does not crash."""
ctx, engine = _make_ctx()
dp = DPPolicy(cube="replicate", pe="replicate")
t = ctx.zeros((128, 128), dtype="f16", dp=dp, name="double_del")
ctx.cleanup() # clears all tensors
# t._handle is now None
del t # should not crash
gc.collect()
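The lifecycle checked in TF2 to TF6 (cleanup on `del`, cleanup on context exit, and no crash when both happen) can be sketched with `weakref.finalize`, whose callback runs at most once. The commit message mentions weakref-based cleanup, but the actual mechanism is not visible in this diff, so everything below is an assumption: `SketchContext`, `SketchTensor`, and the `live` set (a stand-in for MMU unmap plus VA/PA release) are hypothetical.

```python
import gc
import weakref


class SketchTensor:
    """Hypothetical tensor object; only its lifetime matters here."""

    def __init__(self, name: str) -> None:
        self.name = name


class SketchContext:
    """Hypothetical owner of tensor resources; each release runs at most once."""

    def __init__(self) -> None:
        self.live: set[str] = set()
        self._finalizers: list[weakref.finalize] = []

    def zeros(self, name: str) -> SketchTensor:
        t = SketchTensor(name)
        self.live.add(name)
        # The finalizer fires when t is garbage-collected, or when called directly.
        # It holds only the name, not the tensor, so it does not keep t alive.
        self._finalizers.append(weakref.finalize(t, self._release, name))
        return t

    def _release(self, name: str) -> None:
        # Stand-in for unmapping MMU entries and returning VA/PA ranges.
        self.live.discard(name)

    def cleanup(self) -> None:
        for fin in self._finalizers:
            fin()  # calling a finalizer that already ran is a no-op
        self._finalizers.clear()

    def __enter__(self) -> "SketchContext":
        return self

    def __exit__(self, *exc_info) -> None:
        self.cleanup()


# Usage: del releases the resource; a later cleanup() finds nothing left to do.
ctx = SketchContext()
t = ctx.zeros("a")
del t
gc.collect()
ctx.cleanup()
```

Because a finalizer is dead after its first call, the order of `del` and `cleanup()` does not matter, which is the property TF4 and TF6 exercise.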
+11 -10
View File
@@ -17,31 +17,32 @@ def test_full_graph_node_count():
g = _graph()
# 1 switch
# + 2 SIPs × (1 IO × (3 comps + 4 io_ucie + 16 io_conn)
# + 16 cubes × (cube_comps + 8 PEs × 6 pe_comps))
# + 16 cubes × (cube_comps + 8 PEs × 7 pe_comps))
# IO: pcie_ep + io_cpu + io_noc + 4 io_ucie + 4*4 io_conn = 23
# cube_comps: 9 (noc, m_cpu, sram, 2 bridge, 4 ucie)
# + 16 ucie_conn (4 ports × 4 connections)
# + 2 xbar_top/bot
# + 8 hbm_slices = 35
# = 1 + 2*(23 + 16*(35+48)) = 1 + 2*(23+1328) = 1 + 2702 = 2703
assert len(g.nodes) == 2703
# pe_comps: 7 (pe_cpu, pe_scheduler, pe_dma, pe_gemm, pe_math, pe_mmu, pe_tcm)
# = 1 + 2*(23 + 16*(35+56)) = 1 + 2*(23+1456) = 1 + 2958 = 2959
assert len(g.nodes) == 2959
def test_full_graph_edge_count():
g = _graph()
# Per cube: 184
# Per cube: 192
# PE-internal: 56
# PE_DMA→noc: 8, noc→pe_dma: 8, noc→pe_cpu: 8, pe_cpu→noc: 8
# PE_DMA→noc: 8, noc→pe_dma: 8, noc→pe_cpu: 8, pe_cpu→noc: 8, noc→pe_mmu: 8
# xbar_top→hbm{0..3}: 4+4=8, xbar_bot→hbm{4..7}: 4+4=8
# noc↔xbar_top: 2, noc↔xbar_bot: 2
# xbar_top↔bridge.left: 2, bridge.left↔xbar_bot: 2
# xbar_top↔bridge.right: 2, bridge.right↔xbar_bot: 2
# ucie: 64, m_cpu↔noc: 2, noc↔sram: 2
# Total: 56+8+8+8+8+8+8+2+2+2+2+2+2+64+2+2 = 184
# Total: 56+8+8+8+8+8+8+8+2+2+2+2+2+2+64+2+2 = 192
# IO edges per SIP: 77
# Per SIP: 16*184 + 48 inter-cube + 77 IO = 3069
# Total: 2 * 3069 = 6138
assert len(g.edges) == 6138
# Per SIP: 16*192 + 48 inter-cube + 77 IO = 3197
# Total: 2 * 3197 = 6394
assert len(g.edges) == 6394
# ── Full graph: specific nodes exist ─────────────────────────────────
@@ -267,7 +268,7 @@ def test_cube_view_pe_to_noc():
def test_pe_view_has_all_components():
v = _graph().pe_view
assert set(v.nodes.keys()) == {
"pe_cpu", "pe_scheduler", "pe_dma", "pe_gemm", "pe_math", "pe_tcm"
"pe_cpu", "pe_scheduler", "pe_dma", "pe_gemm", "pe_math", "pe_mmu", "pe_tcm"
}
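The totals asserted in `test_full_graph_node_count` and `test_full_graph_edge_count` can be re-derived from the per-level counts listed in their comments. The sketch below only restates that arithmetic; the function names are hypothetical and nothing here queries the real topology.

```python
SIPS = 2
CUBES_PER_SIP = 16
PES_PER_CUBE = 8


def node_count(pe_comps: int) -> int:
    """Total graph nodes given the number of components per PE."""
    io_nodes = 3 + 4 + 16          # pcie_ep/io_cpu/io_noc + io_ucie + io_conn
    cube_nodes = 9 + 16 + 2 + 8    # base comps + ucie_conn + xbars + hbm slices
    per_sip = io_nodes + CUBES_PER_SIP * (cube_nodes + PES_PER_CUBE * pe_comps)
    return 1 + SIPS * per_sip      # +1 for the switch


def edge_count(per_cube_edges: int) -> int:
    """Total graph edges given the number of edges inside one cube."""
    inter_cube_edges = 48
    io_edges = 77
    per_sip = CUBES_PER_SIP * per_cube_edges + inter_cube_edges + io_edges
    return SIPS * per_sip


# Adding pe_mmu raises pe_comps from 6 to 7 and adds 8 noc->pe_mmu edges per cube.
print(node_count(6), node_count(7))
print(edge_count(184), edge_count(184 + PES_PER_CUBE))
```

Each extra PE component adds 2 × 16 × 8 = 256 nodes, which accounts for the change from 2703 to 2959.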
+1 -1
View File
@@ -23,7 +23,7 @@ def test_pe_template_components():
spec = _read_spec(TOPOLOGY_PATH)
comps = spec["cube"]["pe_template"]["components"]
assert set(comps.keys()) == {
"pe_cpu", "pe_scheduler", "pe_dma", "pe_gemm", "pe_math", "pe_tcm"
"pe_cpu", "pe_scheduler", "pe_dma", "pe_gemm", "pe_math", "pe_mmu", "pe_tcm"
}
+3 -3
View File
@@ -34,7 +34,7 @@ def test_tl_load_generates_dma_read():
cmds = tl.commands
assert len(cmds) == 1
assert isinstance(cmds[0], DmaReadCmd)
assert cmds[0].src_pa == 0x1000
assert cmds[0].src_addr == 0x1000
assert cmds[0].nbytes == 32 * 64 * 2
@@ -47,7 +47,7 @@ def test_tl_store_generates_dma_write():
tl.store(0x2000, h)
cmds = [c for c in tl.commands if isinstance(c, DmaWriteCmd)]
assert len(cmds) == 1
assert cmds[0].dst_pa == 0x2000
assert cmds[0].dst_addr == 0x2000
assert cmds[0].nbytes == 16 * 16 * 4
@@ -148,7 +148,7 @@ def test_tl_composite_nonblocking():
comp_cmds = [c for c in tl.commands if isinstance(c, CompositeCmd)]
assert len(comp_cmds) == 1
assert comp_cmds[0].op == "gemm"
assert comp_cmds[0].out_pa == 0x3000
assert comp_cmds[0].out_addr == 0x3000
assert comp_cmds[0].out_nbytes == 32 * 32 * 2 # M×N×dtype_bytes
+140
View File
@@ -0,0 +1,140 @@
"""Tests for VirtualAllocator: device-wide VA space management.
Validates:
T7. Basic VA allocation (contiguous, non-overlapping)
T8. VA free + reallocation (free-list reuse)
T9. VA space exhaustion raises AllocationError
"""
import pytest
from kernbench.policy.address.va_allocator import VirtualAllocator
_KB = 1024
_MB = 1024 * 1024
_GB = 1024 * 1024 * 1024
# ── T7. Basic VA allocation ──────────────────────────────────────────
def test_alloc_returns_aligned_va():
"""First allocation returns va_base."""
va = VirtualAllocator(va_base=0x1_0000_0000, va_size=1 * _GB, page_size=4096)
addr = va.alloc(4096)
assert addr == 0x1_0000_0000
def test_alloc_sequential_non_overlapping():
"""Two allocations return contiguous, non-overlapping VA ranges."""
va = VirtualAllocator(va_base=0x1_0000_0000, va_size=1 * _GB, page_size=4096)
a1 = va.alloc(4096)
a2 = va.alloc(8192)
assert a1 == 0x1_0000_0000
assert a2 == 0x1_0000_1000 # a1 + 4096
# No overlap
assert a2 >= a1 + 4096
def test_alloc_page_aligned():
"""Allocations are page-aligned even if requested size is not page-multiple."""
va = VirtualAllocator(va_base=0x1_0000_0000, va_size=1 * _GB, page_size=4096)
a1 = va.alloc(100) # < 1 page, but occupies 1 page
a2 = va.alloc(100)
assert a2 == 0x1_0000_1000 # aligned to next page
def test_alloc_large_contiguous():
"""Large allocation (multiple pages) is contiguous."""
va = VirtualAllocator(va_base=0x0, va_size=1 * _GB, page_size=2 * _MB)
addr = va.alloc(8 * _MB) # 4 pages
assert addr == 0x0
# Next alloc starts after 8 MB
addr2 = va.alloc(2 * _MB)
assert addr2 == 8 * _MB
# ── T8. VA free + reallocation ───────────────────────────────────────
def test_free_and_realloc():
"""Freed VA range can be reused by subsequent allocation."""
va = VirtualAllocator(va_base=0x1_0000_0000, va_size=1 * _GB, page_size=4096)
a1 = va.alloc(4096)
a2 = va.alloc(4096)
va.free(a1, 4096)
# New alloc should reuse a1's range
a3 = va.alloc(4096)
assert a3 == a1
def test_free_coalesce():
"""Freeing adjacent blocks allows larger reallocation."""
va = VirtualAllocator(va_base=0x0, va_size=1 * _GB, page_size=4096)
a1 = va.alloc(4096)
a2 = va.alloc(4096)
a3 = va.alloc(4096)
# Free first two (adjacent)
va.free(a1, 4096)
va.free(a2, 4096)
# Should be able to allocate 8192 contiguous from the freed region
a4 = va.alloc(8192)
assert a4 == a1 # reuses coalesced region
def test_free_out_of_order():
"""Free in non-sequential order still works."""
va = VirtualAllocator(va_base=0x0, va_size=1 * _GB, page_size=4096)
a1 = va.alloc(4096)
a2 = va.alloc(4096)
a3 = va.alloc(4096)
va.free(a2, 4096) # free middle
a4 = va.alloc(4096)
assert a4 == a2 # reuses middle slot
# ── T9. VA space exhaustion ──────────────────────────────────────────
def test_alloc_exhaustion():
"""Allocation beyond VA space raises AllocationError."""
va = VirtualAllocator(va_base=0x0, va_size=8192, page_size=4096)
va.alloc(4096)
va.alloc(4096)
with pytest.raises(Exception, match="[Aa]lloc|[Ee]xhaust|[Oo]ut of"):
va.alloc(4096)
def test_alloc_after_partial_free():
"""After freeing some, can allocate again within freed space."""
va = VirtualAllocator(va_base=0x0, va_size=8192, page_size=4096)
a1 = va.alloc(4096)
a2 = va.alloc(4096)
# Space is full
with pytest.raises(Exception):
va.alloc(4096)
# Free one, now can allocate again
va.free(a1, 4096)
a3 = va.alloc(4096)
assert a3 == a1
# ── Stats / inspection ───────────────────────────────────────────────
def test_used_and_total():
"""used and total properties reflect allocation state."""
va = VirtualAllocator(va_base=0x0, va_size=1 * _MB, page_size=4096)
assert va.used == 0
assert va.total == 1 * _MB
va.alloc(4096)
assert va.used == 4096
va.alloc(8192)
assert va.used == 4096 + 8192 # page-aligned: 4096 + 8192 = 12288
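The T7 to T9 behaviours (page-rounded sizes, reuse of freed ranges, coalescing of adjacent free blocks, failure on exhaustion) fit a first-fit free list. The sketch below is an illustration under assumptions, not the `VirtualAllocator` from this commit: `SketchVAAllocator` is a hypothetical name, `MemoryError` stands in for whatever exception the real allocator raises, and `used` is counted in page-rounded bytes.

```python
class SketchVAAllocator:
    """Hypothetical first-fit free-list allocator that coalesces on free."""

    def __init__(self, va_base: int, va_size: int, page_size: int = 4096) -> None:
        self.page_size = page_size
        self.total = va_size
        self.used = 0
        # Free blocks as (start, length), kept sorted by start address.
        self._free: list[tuple[int, int]] = [(va_base, va_size)]

    def _round_up(self, size: int) -> int:
        return -(-size // self.page_size) * self.page_size

    def alloc(self, size: int) -> int:
        need = self._round_up(size)
        for i, (start, length) in enumerate(self._free):
            if length >= need:
                if length == need:
                    del self._free[i]
                else:
                    # Carve the request off the front of the block.
                    self._free[i] = (start + need, length - need)
                self.used += need
                return start
        raise MemoryError(f"out of VA space: need {need} bytes")

    def free(self, addr: int, size: int) -> None:
        need = self._round_up(size)
        self._free.append((addr, need))
        self._free.sort()
        # Merge neighbours whose ranges touch.
        merged: list[tuple[int, int]] = []
        for start, length in self._free:
            if merged and merged[-1][0] + merged[-1][1] == start:
                prev_start, prev_len = merged[-1]
                merged[-1] = (prev_start, prev_len + length)
            else:
                merged.append((start, length))
        self._free = merged
        self.used -= need
```

First-fit over a sorted list is what makes a freed low address come back on the next allocation of the same size, as `test_free_and_realloc` expects. The sketch does not guard against double frees.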
+230
View File
@@ -0,0 +1,230 @@
"""Tests for VA integration: Tensor, TLContext, and DMA commands use VA.
Validates:
T10. TensorHandle has va_base; TensorShard does NOT have va field
T11. deploy_tensor allocates VA + creates mapping entries
T12. Tensor.va returns the tensor's VA base
T13. tl.load/tl.store generate DMA commands with VA (not PA)
T14. Kernel VA-based offset calculation flows through DMA commands
"""
import pytest
from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator
from kernbench.policy.address.pe_mmu import PeMMU
from kernbench.policy.address.va_allocator import VirtualAllocator
from kernbench.policy.placement.dp import column_wise, ShardSpec
from kernbench.runtime_api.tensor import (
TensorHandle,
TensorShard,
deploy_tensor,
)
from kernbench.runtime_api.kernel import TensorArgShard
from kernbench.common.pe_commands import DmaReadCmd, DmaWriteCmd
from kernbench.triton_emu.tl_context import TLContext, run_kernel
_MB = 1 << 20
_GB = 1 << 30
_CFG = AddressConfig(
sip_count=2,
cubes_per_sip=16,
pes_per_cube=8,
hbm_bytes_per_cube=48 * _GB,
hbm_slices_per_cube=8,
tcm_bytes_per_pe=16 * _MB,
tcm_scheduler_reserved_bytes=4 * _MB,
sram_bytes_per_cube=32 * _MB,
)
def _make_allocators(num_pe: int = 8) -> dict[int, PEMemAllocator]:
return {
i: PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=i, cfg=_CFG)
for i in range(num_pe)
}
def _make_mmus(num_pe: int = 8, page_size: int = 4096) -> dict[int, PeMMU]:
return {i: PeMMU(page_size=page_size) for i in range(num_pe)}
def _make_va_allocator() -> VirtualAllocator:
return VirtualAllocator(va_base=0x1_0000_0000, va_size=64 * _GB, page_size=2 * _MB)
# ── T10. TensorHandle has va_base ────────────────────────────────────
def test_tensor_handle_has_va_base():
"""TensorHandle must have a 'va_base' field."""
th = TensorHandle(
name="A", shape=(1024, 512), dtype="fp16", itemsize=2,
shards=(), va_base=0x1_0000_0000,
)
assert th.va_base == 0x1_0000_0000
def test_tensor_handle_va_base_immutable():
"""TensorHandle.va_base is immutable (frozen dataclass)."""
th = TensorHandle(
name="A", shape=(1024, 512), dtype="fp16", itemsize=2,
shards=(), va_base=0x1_0000_0000,
)
with pytest.raises(AttributeError):
th.va_base = 0x2_0000_0000 # type: ignore[misc]
def test_tensor_shard_no_va_field():
"""TensorShard should NOT have a va field — va is derived from
TensorHandle.va_base + shard.offset_bytes."""
ts = TensorShard(sip=0, cube=0, pe=0, pa=0x1000, nbytes=4096, offset_bytes=0)
assert not hasattr(ts, "va"), "TensorShard should not have a 'va' field"
# ── T11. deploy_tensor allocates VA + creates mappings ───────────────
def test_deploy_tensor_assigns_va_base():
"""deploy_tensor with VA allocator assigns va_base to TensorHandle."""
allocs = _make_allocators()
va_alloc = _make_va_allocator()
mmus = _make_mmus()
placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8)
th = deploy_tensor(
name="W",
shape=(1024, 512),
dtype="fp16",
placement=placement,
allocators=allocs,
va_allocator=va_alloc,
mmus=mmus,
)
assert th.va_base is not None
assert th.va_base > 0
def test_deploy_tensor_va_covers_all_shards():
"""VA allocation covers the entire tensor; each shard is at va_base + offset."""
allocs = _make_allocators()
va_alloc = _make_va_allocator()
mmus = _make_mmus()
placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8)
th = deploy_tensor(
name="W",
shape=(1024, 512),
dtype="fp16",
placement=placement,
allocators=allocs,
va_allocator=va_alloc,
mmus=mmus,
)
# Each shard's VA is derivable: va_base + offset_bytes
for s in th.shards:
shard_va = th.va_base + s.offset_bytes
assert shard_va > 0
def test_deploy_tensor_registers_mmu_mappings():
"""deploy_tensor registers VA→PA mappings in all PE MMUs."""
allocs = _make_allocators()
va_alloc = _make_va_allocator()
mmus = _make_mmus()
placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8)
th = deploy_tensor(
name="W",
shape=(1024, 512),
dtype="fp16",
placement=placement,
allocators=allocs,
va_allocator=va_alloc,
mmus=mmus,
)
# Every MMU should have entries (broadcast)
for mmu in mmus.values():
assert mmu.num_entries > 0
# Each shard's derived VA should translate to its PA in every MMU
for mmu in mmus.values():
for s in th.shards:
shard_va = th.va_base + s.offset_bytes
assert mmu.translate(shard_va) == s.pa
# ── T12. Tensor.va property ──────────────────────────────────────────
def test_tensor_va_property():
"""Tensor.va returns the VA base of the entire tensor (from TensorHandle.va_base)."""
from kernbench.runtime_api.tensor import Tensor
allocs = _make_allocators(1)
va_alloc = _make_va_allocator()
mmus = _make_mmus(1)
placement = [ShardSpec(pe_index=0, offset_bytes=0, nbytes=4096)]
t = Tensor(shape=(2048,), dtype="f16", name="test")
t._handle = deploy_tensor(
name="test",
shape=(2048,),
dtype="f16",
placement=placement,
allocators=allocs,
va_allocator=va_alloc,
mmus=mmus,
)
assert t.va > 0
assert t.va == t._handle.va_base
# ── T13. tl.load/tl.store use VA ─────────────────────────────────────
def test_tl_load_uses_va_in_dma_cmd():
"""tl.load(va_ptr) generates a DmaReadCmd whose src_addr carries the VA unchanged."""
tl = TLContext(dispatch_cycles=0)
va_ptr = 0x1_0000_0000
h = tl.load(va_ptr, shape=(32, 64), dtype="f16")
dma_cmds = [c for c in tl.commands if isinstance(c, DmaReadCmd)]
assert len(dma_cmds) == 1
# The DMA command should carry the VA
assert dma_cmds[0].src_addr == va_ptr
def test_tl_store_uses_va_in_dma_cmd():
"""tl.store(va_ptr, handle) generates a DmaWriteCmd whose dst_addr carries the VA unchanged."""
tl = TLContext(dispatch_cycles=0)
h = tl.load(0x1_0000_0000, shape=(16, 16), dtype="f32")
va_out = 0x2_0000_0000
tl.store(va_out, h)
dma_cmds = [c for c in tl.commands if isinstance(c, DmaWriteCmd)]
assert len(dma_cmds) == 1
assert dma_cmds[0].dst_addr == va_out
# ── T14. Kernel VA offset calculation ─────────────────────────────────
def test_kernel_va_offset_in_dma():
"""Kernel using base_va + pid * stride generates correct VA in DmaReadCmd."""
def tiled_kernel(a_ptr, tl, BLOCK_SIZE=1024, DTYPE="f16"):
pid = tl.program_id(0)
elem_bytes = 2 # f16
offset = pid * BLOCK_SIZE * elem_bytes
a = tl.load(a_ptr + offset, shape=(BLOCK_SIZE,), dtype=DTYPE)
va_base = 0x1_0000_0000
tl = TLContext(pe_id=3, num_programs=8, dispatch_cycles=0)
run_kernel(tiled_kernel, tl, a_ptr=va_base)
dma_cmds = [c for c in tl.commands if isinstance(c, DmaReadCmd)]
assert len(dma_cmds) == 1
expected_va = va_base + 3 * 1024 * 2 # pid=3, BLOCK_SIZE=1024, 2 bytes
assert dma_cmds[0].src_addr == expected_va
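Two address derivations run through T11 and T14: a shard's VA is `va_base + offset_bytes`, and a kernel instance's load address is `base + pid * BLOCK_SIZE * elem_bytes`. The sketch below restates both and shows how a kernel VA can be matched back to the shard that covers it. `SketchShard`, `shard_va`, `program_va`, and `owning_shard` are hypothetical helpers, not part of the kernbench API.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SketchShard:
    """Hypothetical shard record: a PA range plus its byte offset in the tensor."""
    pe: int
    pa: int
    nbytes: int
    offset_bytes: int


def shard_va(va_base: int, shard: SketchShard) -> int:
    # A shard's VA is derived, not stored: tensor base plus the shard's offset.
    return va_base + shard.offset_bytes


def program_va(base_va: int, pid: int, block_size: int, elem_bytes: int) -> int:
    # Address a tiled kernel passes to tl.load for program instance `pid`.
    return base_va + pid * block_size * elem_bytes


def owning_shard(va_base: int, shards: list[SketchShard], va: int) -> SketchShard:
    # Find the shard whose [offset, offset + nbytes) range contains the VA.
    rel = va - va_base
    for s in shards:
        if s.offset_bytes <= rel < s.offset_bytes + s.nbytes:
            return s
    raise LookupError(f"VA {va:#x} is outside the tensor")


# Eight shards of 2048 bytes each, one per PE (illustrative values only).
VA_BASE = 0x1_0000_0000
SHARDS = [
    SketchShard(pe=i, pa=0xA000_0000 + i * 0x1000_0000, nbytes=2048, offset_bytes=i * 2048)
    for i in range(8)
]
va = program_va(VA_BASE, pid=3, block_size=1024, elem_bytes=2)
hit = owning_shard(VA_BASE, SHARDS, va)
```

With 1024 f16 elements per block, each program's tile is 2048 bytes, so in this illustrative layout program 3 lands exactly on the start of PE 3's shard.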
+1
View File
@@ -65,6 +65,7 @@ cube:
pe_dma: { kind: pe_dma, impl: pe_dma_v1, attrs: { rd_engines: 1, wr_engines: 1 } }
pe_gemm: { kind: pe_gemm, impl: pe_gemm_v1, attrs: { overhead_ns: 0.0, shared_resource: accel_slot, peak_tflops_f16: 8.0 } }
pe_math: { kind: pe_math, impl: pe_math_v1, attrs: { overhead_ns: 0.0, shared_resource: accel_slot } }
pe_mmu: { kind: pe_mmu, impl: pe_mmu_v1, attrs: { tlb_overhead_ns: 0.5, page_size: 4096 } }
pe_tcm: { kind: pe_tcm, impl: pe_tcm_v1, attrs:
{ size_mb: 16 } }
links: