diff --git a/SPEC.md b/SPEC.md index a850c60..1aeb0ea 100644 --- a/SPEC.md +++ b/SPEC.md @@ -207,12 +207,15 @@ benchmark instances by default. ## R10. Memory Addressing (Phase 0) -In Phase 0, the simulator uses a **PA-first memory model**: +The simulator uses a **VA/PA memory model** (ADR-0011): -- All memory operations use device physical addresses (PA) only. -- Virtual addressing, MMU/IOMMU, and address translation latency are out of scope. +- Tensors are assigned a contiguous virtual address (VA) range at deployment. +- PE_MMU translates VA→PA per access; TLB overhead is configurable. +- Mapping installation (MmuMapMsg) traverses the fabric with measured latency. +- Replicate tensors use per-cube local PA mapping; sharded tensors broadcast. +- PA-only fallback is retained for backward compatibility. - Tensor placement is represented as a list of PA shards, each explicitly tagged - with `(sip, cube, pe)`. + with `(sip, cube, pe)`, plus a tensor-wide `va_base`. All memory access latency MUST be modeled explicitly via graph traversal. No implicit translation or hidden latency is allowed. diff --git a/benches/ipcq_allreduce.py b/benches/ipcq_allreduce.py index 99e5217..798173e 100644 --- a/benches/ipcq_allreduce.py +++ b/benches/ipcq_allreduce.py @@ -1,2 +1,2 @@ -def run(ctx): +def run(torch): print("IPCQ all reduce kernel bench") diff --git a/benches/loader.py b/benches/loader.py index e78e1a2..abc5ac7 100644 --- a/benches/loader.py +++ b/benches/loader.py @@ -15,7 +15,7 @@ def resolve_bench(bench_id: str) -> BenchFn: Expected layout (repo root): benches/.py - def run(ctx: RuntimeContext) -> Any + def run(torch: RuntimeContext) -> Any """ bench_id = bench_id.strip() if not bench_id: @@ -30,7 +30,7 @@ def resolve_bench(bench_id: str) -> BenchFn: run_fn = getattr(mod, "run", None) if run_fn is None: - raise ValueError(f"Bench module {module_path} must define a 'run(ctx)' function.") + raise ValueError(f"Bench module {module_path} must define a 'run(torch)' function.") if not callable(run_fn): raise ValueError(f"'run' in {module_path} is not callable.") diff --git a/benches/qkv_gemm.py b/benches/qkv_gemm.py index 7c92569..492c632 100644 --- a/benches/qkv_gemm.py +++ b/benches/qkv_gemm.py @@ -26,14 +26,14 @@ def _gemm_kernel(a_ptr, b_ptr, out_ptr, M, K, N, tl, DTYPE="f16"): tl.wait(handle) -def run(ctx): +def run(torch): """Run the QKV GEMM benchmark.""" # DP placement: a=replicate (cube-level), b/out=column_wise (N-axis, single PE) - a = ctx.zeros((M, K), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="replicate"), name="a") - b = ctx.zeros((K, N), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="column_wise"), name="b") - out = ctx.empty( + a = torch.zeros((M, K), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="replicate"), name="a") + b = torch.zeros((K, N), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="column_wise"), name="b") + out = torch.empty( (M, N), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="column_wise"), name="out", ) # Launch GEMM kernel - ctx.launch("qkv_gemm", _gemm_kernel, a, b, out, M, K, N) + torch.launch("qkv_gemm", _gemm_kernel, a, b, out, M, K, N) diff --git a/benches/qkv_gemm_multi_pe.py b/benches/qkv_gemm_multi_pe.py index 2b7bd87..e8e8649 100644 --- a/benches/qkv_gemm_multi_pe.py +++ b/benches/qkv_gemm_multi_pe.py @@ -26,14 +26,14 @@ def _gemm_kernel(a_ptr, b_ptr, out_ptr, M, K, N, tl, DTYPE="f16"): tl.wait(handle) -def run(ctx): +def run(torch): """Run the multi-PE QKV GEMM benchmark.""" # DP placement: a=replicate (cube-level), b/out=column_wise (N-axis split) - a = ctx.zeros((M, K), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="replicate"), name="a") - b = ctx.zeros((K, N), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="column_wise"), name="b") - out = ctx.empty( + a = torch.zeros((M, K), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="replicate"), name="a") + b = torch.zeros((K, N), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="column_wise"), name="b") + out = torch.empty( (M, N), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="column_wise"), name="out", ) # Launch GEMM kernel on all PEs - ctx.launch("qkv_gemm_multi", _gemm_kernel, a, b, out, M, K, N) + torch.launch("qkv_gemm_multi", _gemm_kernel, a, b, out, M, K, N) diff --git a/docs/adr/ADR-0011-memory-addressing-simplification.md b/docs/adr/ADR-0011-memory-addressing-simplification.md index 3fa7003..96a4c97 100644 --- a/docs/adr/ADR-0011-memory-addressing-simplification.md +++ b/docs/adr/ADR-0011-memory-addressing-simplification.md @@ -1,8 +1,8 @@ -# ADR-0011: Memory Addressing Simplification (PA-first) +# ADR-0011: Memory Addressing — PA-first with VA/MMU Extension ## Status -Accepted +Accepted (Phase 1 VA/MMU implemented) ## Context @@ -11,49 +11,82 @@ translation path for DMA: host allocates physical memory at PE level, maps it into a virtual address space, installs mappings, and DMA requests use virtual addresses that are translated to physical addresses. -For early development, we want a minimal, deterministic model that enables: - -- correct routing and latency accounting through the graph, -- stable tensor deployment and kernel execution semantics, -- future extension toward VA/MMU without rewriting workflows. +The PA-only model (Phase 0) was insufficient for running standard Triton kernels +that use `base_addr + offset` patterns on sharded tensors — each PE's shard has +a different PA, but the kernel needs a single contiguous address space. --- ## Decision -### D1. Phase 0 model is PA-only - -The simulator uses a PA-first model: +### D1. Phase 0 model is PA-only (original, retained as fallback) - All device memory accesses (MemoryRead/MemoryWrite) operate on device physical addresses (PA) plus size. -- Tensor handles store PA-based shard mappings after deployment. -- KernelLaunch passes tensor arguments as PA-based mappings (or references to them). -- MMU/IOMMU concepts (virtual address spaces, page tables, translation latency) - are NOT modeled in Phase 0. +- PA-only mode remains functional via PageFault fallback in PE_DMA. ### D2. Allocation produces PA mappings Device allocation selects PE-local memory regions and returns PA mappings sufficient to execute kernels and issue DMA requests. -### D3. Extension path (non-breaking) +### D3. Phase 1: VA/MMU layer (implemented) -A future ADR MAY introduce an optional VA/MMU layer by: +#### D3.1 Virtual Address Model -- introducing virtual addresses in tensor handles, -- adding a mapping-install step, -- modeling translation latency and page granularity. +- Each tensor gets a single contiguous VA range (`TensorHandle.va_base`). +- `TensorShard` does NOT carry a `va` field — shard VA is derived as + `va_base + offset_bytes`. +- Kernels receive `va_base` as their pointer argument (via `TensorArg.va_base`). +- `DmaReadCmd.src_addr` and `DmaWriteCmd.dst_addr` carry VA (not PA). -The Phase 0 PA model remains a valid fast-path configuration. +#### D3.2 PE_MMU Component + +- Hybrid design: SimPy component (inbox for MmuMapMsg) + utility (synchronous + `translate()` called by PE_DMA). +- Page-aligned dict lookup for O(1) VA→PA translation. +- `tlb_overhead_ns` configurable per-access latency. +- PageFault fallback: if VA has no mapping, PE_DMA treats it as PA directly + (backward compatibility with PA-only tests). + +#### D3.3 Mapping Installation + +- `MmuMapMsg` traverses the fabric: Host → PCIE_EP → IO_CPU (cube fan-out) → + M_CPU (PE fan-out) → NOC → PE_MMU. Latency is measured end-to-end. +- `MmuMapMsg.target_sips` controls SIP-level routing to prevent cross-SIP + mapping contamination for replicated tensors. +- Mapping strategy based on `DPPolicy.cube`: + - **Replicate** (`cube="replicate"`): per-(sip, cube) local mapping only. + Each cube's PEs see only their local PA. No cross-cube mapping installed. + - **Sharded** (`cube="shard_m"`, etc.): broadcast all shard mappings to all + target cubes. Enables cross-PE and cross-cube DMA. + +#### D3.4 Tensor Lifecycle + +- `del tensor` triggers automatic cleanup via `Tensor.__del__` + `weakref` to + RuntimeContext. Sends `MmuUnmapMsg` through fabric, returns VA and PA space. +- `with RuntimeContext(...) as ctx:` provides scope-based bulk cleanup. +- `RuntimeContext._tensors` uses `weakref.ref` to avoid preventing GC. +- `PEMemAllocator` uses free-list with coalescing (not bump allocator). +- `VirtualAllocator` uses free-list with coalescing for VA space. + +#### D3.5 Allocators + +- `VirtualAllocator`: device-wide VA space, page-aligned alloc/free with + coalescing. +- `PEMemAllocator`: per-PE HBM/TCM, free-list based alloc/free with coalescing. +- Page size configurable via `topology.yaml` pe_mmu attrs (default 4096). --- ## Consequences -- Early implementation stays simple and testable. -- All latency remains explicit via graph traversal, not hidden translation. -- Future VA/MMU modeling can be added without breaking existing benchmarks. +- Triton kernels use `base_addr + offset` patterns naturally on sharded tensors. +- All latency remains explicit via graph traversal, including MMU mapping + installation and per-access TLB overhead. +- PA-only mode retained as fallback (PageFault → treat as PA). +- Benchmark parameter renamed `ctx` → `torch` for PyTorch code compatibility. +- IPCQ and other fixed-address resources bypass MMU (use PA directly). --- @@ -62,4 +95,6 @@ The Phase 0 PA model remains a valid fast-path configuration. - ADR-0007 (runtime_api vs sim_engine boundaries) - ADR-0008 (tensor deployment) - ADR-0009 (kernel execution) +- ADR-0014 (PE-internal execution model) +- ADR-0015 (component port/wire model) - SPEC R2 (latency by traversal) diff --git a/src/kernbench/common/pe_commands.py b/src/kernbench/common/pe_commands.py index d1d2c39..c6bf991 100644 --- a/src/kernbench/common/pe_commands.py +++ b/src/kernbench/common/pe_commands.py @@ -28,7 +28,7 @@ class TensorHandle: """ id: str - pa: int # physical address in HBM/TCM + addr: int # address (VA when MMU enabled, PA otherwise) shape: tuple[int, ...] dtype: str nbytes: int # total byte size @@ -50,19 +50,19 @@ class CompletionHandle: @dataclass(frozen=True) class DmaReadCmd: - """DMA READ: HBM → PE_TCM.""" + """DMA READ: HBM → PE_TCM. src_addr is VA (translated to PA by PE_DMA).""" handle: TensorHandle - src_pa: int + src_addr: int nbytes: int @dataclass(frozen=True) class DmaWriteCmd: - """DMA WRITE: PE_TCM → HBM.""" + """DMA WRITE: PE_TCM → HBM. dst_addr is VA (translated to PA by PE_DMA).""" handle: TensorHandle - dst_pa: int + dst_addr: int nbytes: int @@ -108,7 +108,7 @@ class CompositeCmd: op: Literal["gemm", "math"] a: TensorHandle b: TensorHandle | None - out_pa: int + out_addr: int out_nbytes: int math_op: str | None = None # for op="math": which math operation diff --git a/src/kernbench/components/impls/__init__.py b/src/kernbench/components/impls/__init__.py index 38e68e4..cd170ef 100644 --- a/src/kernbench/components/impls/__init__.py +++ b/src/kernbench/components/impls/__init__.py @@ -16,6 +16,7 @@ from kernbench.components.impls.pe_dma import PeDmaComponent from kernbench.components.impls.pe_gemm import PeGemmComponent from kernbench.components.impls.pe_math import PeMathComponent from kernbench.components.impls.pe_scheduler import PeSchedulerComponent +from kernbench.components.impls.pe_mmu import PeMmuComponent from kernbench.components.impls.pe_tcm import PeTcmComponent from kernbench.components.impls.sram import SramComponent from kernbench.components.impls.xbar import PositionAwareXbarComponent @@ -36,6 +37,7 @@ ComponentRegistry.register("pe_scheduler_v1", PeSchedulerComponent) ComponentRegistry.register("pe_dma_v1", PeDmaComponent) ComponentRegistry.register("pe_gemm_v1", PeGemmComponent) ComponentRegistry.register("pe_math_v1", PeMathComponent) +ComponentRegistry.register("pe_mmu_v1", PeMmuComponent) ComponentRegistry.register("pe_tcm_v1", PeTcmComponent) __all__ = [ @@ -47,6 +49,7 @@ __all__ = [ "PeDmaComponent", "PeGemmComponent", "PeMathComponent", + "PeMmuComponent", "PeSchedulerComponent", "PeTcmComponent", "TransitComponent", diff --git a/src/kernbench/components/impls/io_cpu.py b/src/kernbench/components/impls/io_cpu.py index ad123a6..83f2b8a 100644 --- a/src/kernbench/components/impls/io_cpu.py +++ b/src/kernbench/components/impls/io_cpu.py @@ -93,7 +93,9 @@ class IoCpuComponent(ComponentBase): def _resolve_cube_targets(self, request: Any) -> list[tuple[int, int]]: """Return list of (sip, cube) pairs to fan out to.""" - from kernbench.runtime_api.kernel import KernelLaunchMsg, MemoryReadMsg, MemoryWriteMsg + from kernbench.runtime_api.kernel import ( + KernelLaunchMsg, MemoryReadMsg, MemoryWriteMsg, MmuMapMsg, MmuUnmapMsg, + ) target_cubes = getattr(request, "target_cubes", "all") @@ -130,6 +132,16 @@ class IoCpuComponent(ComponentBase): targets.append(key) return targets + if isinstance(request, (MmuMapMsg, MmuUnmapMsg)): + my_sip = self._my_sip() + if target_cubes == "all": + n_cubes = 16 + if self.ctx and self.ctx.spec: + sips = self.ctx.spec.get("system", {}).get("sips", {}) + n_cubes = sips.get("cubes_per_sip", 16) + return [(my_sip, c) for c in range(n_cubes)] + return [(my_sip, c) for c in target_cubes] + return [] def _cube_from_pa(self, pa_val: int, fallback: int) -> int: diff --git a/src/kernbench/components/impls/m_cpu.py b/src/kernbench/components/impls/m_cpu.py index c818a16..40c9ae5 100644 --- a/src/kernbench/components/impls/m_cpu.py +++ b/src/kernbench/components/impls/m_cpu.py @@ -52,7 +52,7 @@ class MCpuComponent(ComponentBase): def _worker(self, env: simpy.Environment) -> Generator: """Dispatch forward txns, collect response txns.""" - from kernbench.runtime_api.kernel import KernelLaunchMsg + from kernbench.runtime_api.kernel import KernelLaunchMsg, MmuMapMsg, MmuUnmapMsg while True: txn: Any = yield self._inbox.get() @@ -66,6 +66,8 @@ class MCpuComponent(ComponentBase): elif self.ctx is not None and txn.request is not None: if isinstance(txn.request, KernelLaunchMsg): env.process(self._kernel_launch_fanout(env, txn)) + elif isinstance(txn.request, (MmuMapMsg, MmuUnmapMsg)): + env.process(self._mmu_msg_fanout(env, txn)) else: env.process(self._dma_fanout(env, txn)) else: @@ -261,6 +263,63 @@ class MCpuComponent(ComponentBase): n_slices = mm.get("hbm_slices_per_cube", 8) return [f"{cube_prefix}.hbm_ctrl.slice{i}" for i in range(n_slices)] + def _mmu_msg_fanout(self, env: simpy.Environment, txn: Any) -> Generator: + """Fan out MmuMapMsg/MmuUnmapMsg to target PE_MMU(s) via NOC. + + Routes through find_node_path (M_CPU → NOC → PE_MMU command edges). + PE_MMU is a terminal node — completes the transaction directly. + """ + request = txn.request + target_pe = getattr(request, "target_pe", "all") + cube_prefix = self.node.id.rsplit(".", 1)[0] # e.g. "sip0.cube0" + pe_ids = self._resolve_pe_ids(target_pe) + + if not pe_ids: + txn.done.succeed() + return + + # Fan out to each PE_MMU + sub_dones: list[simpy.Event] = [] + for pe_id in pe_ids: + pe_mmu_id = f"{cube_prefix}.pe{pe_id}.pe_mmu" + try: + path = self.ctx.router.find_node_path(self.node.id, pe_mmu_id) + except Exception: + continue + if len(path) < 2: + continue + sub_done = env.event() + sub_txn = Transaction( + request=request, path=path, step=0, + nbytes=0, done=sub_done, + ) + yield self.out_ports[path[1]].put(sub_txn.advance()) + sub_dones.append(sub_done) + + # Wait for all PE_MMUs to complete + for sd in sub_dones: + yield sd + + # Send aggregate response on reverse path + reverse_path = list(reversed(txn.path)) + if len(reverse_path) >= 2: + from kernbench.runtime_api.kernel import ResponseMsg + + parts = self.node.id.split(".") + cube_id = int(parts[1].replace("cube", "")) + resp_msg = ResponseMsg( + correlation_id=request.correlation_id, + request_id=request.request_id, + src_cube=cube_id, src_pe=-1, success=True, + ) + resp_txn = Transaction( + request=resp_msg, path=reverse_path, step=0, + nbytes=0, done=env.event(), is_response=True, + ) + yield self.out_ports[reverse_path[1]].put(resp_txn.advance()) + else: + txn.done.succeed() + def _resolve_pe_ids(self, target_pe: int | str) -> list[int]: """Return list of PE IDs to fan out to (used by kernel launch fan-out).""" if isinstance(target_pe, int): diff --git a/src/kernbench/components/impls/pe_cpu.py b/src/kernbench/components/impls/pe_cpu.py index 6274c6e..34fcf8e 100644 --- a/src/kernbench/components/impls/pe_cpu.py +++ b/src/kernbench/components/impls/pe_cpu.py @@ -84,12 +84,15 @@ class PeCpuComponent(ComponentBase): tl = TLContext(pe_id=self._pe_idx, dispatch_cycles=0) # Unpack KernelLaunchMsg.args into positional args for kernel function - # TensorArg → PA (pointer), ScalarArg → value + # TensorArg → VA base (or PA fallback), ScalarArg → value kernel_args: list = [] for arg in request.args: if arg.arg_kind == "tensor": - shard = self._find_shard(arg.shards) - kernel_args.append(shard.pa) + if arg.va_base: + kernel_args.append(arg.va_base) + else: + shard = self._find_shard(arg.shards) + kernel_args.append(shard.pa) elif arg.arg_kind == "scalar": kernel_args.append(arg.value) diff --git a/src/kernbench/components/impls/pe_dma.py b/src/kernbench/components/impls/pe_dma.py index 40830bf..857456c 100644 --- a/src/kernbench/components/impls/pe_dma.py +++ b/src/kernbench/components/impls/pe_dma.py @@ -31,6 +31,7 @@ class PeDmaComponent(PeEngineBase): super().__init__(node, ctx) self._dma_read: simpy.Resource | None = None self._dma_write: simpy.Resource | None = None + self._mmu = None # PeMMU instance, set by engine wiring def init_resources(self, env: simpy.Environment) -> None: self._dma_read = simpy.Resource(env, capacity=1) @@ -48,20 +49,32 @@ class PeDmaComponent(PeEngineBase): cmd = pe_txn.command assert self._dma_read is not None and self._dma_write is not None - # Determine direction and target PA + # Determine direction and target address (VA → PA via MMU) if isinstance(cmd, DmaReadCmd): dma_res = self._dma_read - target_pa = cmd.src_pa + raw_addr = cmd.src_addr is_write = False elif isinstance(cmd, DmaWriteCmd): dma_res = self._dma_write - target_pa = cmd.dst_pa + raw_addr = cmd.dst_addr is_write = True else: pe_txn.done.succeed() return - # Resolve PA → HBM node and compute path + # Translate VA → PA via MMU (if available), then resolve HBM node + # If MMU has no mapping for this address (PageFault), treat as PA directly + # (backward-compatible with PA-only mode) + if self._mmu is not None: + from kernbench.policy.address.pe_mmu import PageFault + try: + target_pa = self._mmu.translate(raw_addr) + if self._mmu.overhead_ns > 0: + yield env.timeout(self._mmu.overhead_ns) + except PageFault: + target_pa = raw_addr + else: + target_pa = raw_addr # fallback: treat as PA directly pa = PhysAddr.decode(target_pa) dst_node = self.ctx.resolver.resolve(pa) path = self.ctx.router.find_path(self._pe_prefix, dst_node) diff --git a/src/kernbench/components/impls/pe_mmu.py b/src/kernbench/components/impls/pe_mmu.py new file mode 100644 index 0000000..3481cc4 --- /dev/null +++ b/src/kernbench/components/impls/pe_mmu.py @@ -0,0 +1,66 @@ +"""PE_MMU component: address translation unit. + +Component role: receives MmuMapMsg/MmuUnmapMsg via inbox (independent of PE_CPU). +Utility role: PE_DMA/PE_GEMM call mmu.translate() directly (no SimPy overhead). +""" +from __future__ import annotations + +from collections.abc import Generator +from typing import TYPE_CHECKING, Any + +import simpy + +from kernbench.components.base import ComponentBase, ComponentRegistry +from kernbench.policy.address.pe_mmu import PeMMU + +if TYPE_CHECKING: + from kernbench.components.context import ComponentContext + from kernbench.topology.types import Node + + +class PeMmuComponent(ComponentBase): + """PE_MMU: per-PE virtual-to-physical address translation. + + Receives MmuMapMsg/MmuUnmapMsg via inbox and updates the internal + page table. PE_DMA and PE_GEMM access the underlying PeMMU object + via the ``mmu`` property for synchronous VA→PA translation. + """ + + def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None: + super().__init__(node, ctx) + page_size = int(node.attrs.get("page_size", 2 * 1024 * 1024)) + overhead_ns = float(node.attrs.get("tlb_overhead_ns", 0.0)) + self._mmu = PeMMU(page_size=page_size, overhead_ns=overhead_ns) + + @property + def mmu(self) -> PeMMU: + """The underlying PeMMU utility object for direct translate() calls.""" + return self._mmu + + def run(self, env: simpy.Environment, nbytes: int) -> Generator: + yield env.timeout(0) + + def _worker(self, env: simpy.Environment) -> Generator: + """Process MmuMapMsg/MmuUnmapMsg from inbox.""" + from kernbench.runtime_api.kernel import MmuMapMsg, MmuUnmapMsg + + while True: + txn: Any = yield self._inbox.get() + + if hasattr(txn, "request"): + request = txn.request + if isinstance(request, MmuMapMsg): + for entry in request.entries: + self._mmu.map( + va=entry["va"], pa=entry["pa"], size=entry["size"], + ) + txn.done.succeed() + elif isinstance(request, MmuUnmapMsg): + for entry in request.entries: + self._mmu.unmap(va=entry["va"], size=entry["size"]) + txn.done.succeed() + else: + # Forward non-MMU transactions normally + yield from self._forward_txn(env, txn) + else: + yield from self._forward_txn(env, txn) diff --git a/src/kernbench/components/impls/pe_scheduler.py b/src/kernbench/components/impls/pe_scheduler.py index d196759..daa7c3a 100644 --- a/src/kernbench/components/impls/pe_scheduler.py +++ b/src/kernbench/components/impls/pe_scheduler.py @@ -155,12 +155,12 @@ class PeSchedulerComponent(ComponentBase): # --- Stage 1: DMA_READ b_tile from HBM --- read_done = env.event() - b_tile_pa = b.pa + (k_start * N + n_start) * dtype_bytes + b_tile_addr = b.addr + (k_start * N + n_start) * dtype_bytes b_tile_handle = TensorHandle( - id=f"b_tile_{tile_idx}", pa=b_tile_pa, + id=f"b_tile_{tile_idx}", addr=b_tile_addr, shape=(tile_k, tile_n), dtype=dtype, nbytes=tile_nbytes, ) - read_cmd = DmaReadCmd(handle=b_tile_handle, src_pa=b_tile_pa, nbytes=tile_nbytes) + read_cmd = DmaReadCmd(handle=b_tile_handle, src_addr=b_tile_addr, nbytes=tile_nbytes) read_txn = PeTxn(command=read_cmd, done=read_done, pe_prefix=pp) t0 = env.now yield self.out_ports[f"{pp}.pe_dma"].put(read_txn) @@ -176,7 +176,7 @@ class PeSchedulerComponent(ComponentBase): # --- Stage 2: COMPUTE (GEMM) --- compute_done = env.event() out_handle = TensorHandle( - id=f"out_tile_{tile_idx}", pa=0, + id=f"out_tile_{tile_idx}", addr=0, shape=(M, tile_n), dtype=dtype, nbytes=M * tile_n * dtype_bytes, ) @@ -197,9 +197,9 @@ class PeSchedulerComponent(ComponentBase): # --- Stage 3: DMA_WRITE out_tile to HBM --- write_done = env.event() - out_tile_pa = cmd.out_pa + n_start * dtype_bytes + out_tile_pa = cmd.out_addr + n_start * dtype_bytes write_nbytes = M * tile_n * dtype_bytes - write_cmd = DmaWriteCmd(handle=out_handle, dst_pa=out_tile_pa, nbytes=write_nbytes) + write_cmd = DmaWriteCmd(handle=out_handle, dst_addr=out_tile_pa, nbytes=write_nbytes) write_txn = PeTxn(command=write_cmd, done=write_done, pe_prefix=pp) t0 = env.now yield self.out_ports[f"{pp}.pe_dma"].put(write_txn) @@ -237,7 +237,7 @@ class PeSchedulerComponent(ComponentBase): # Step 2: DMA_WRITE result to HBM write_done = env.event() - write_cmd = DmaWriteCmd(handle=cmd.a, dst_pa=cmd.out_pa, nbytes=cmd.out_nbytes) + write_cmd = DmaWriteCmd(handle=cmd.a, dst_addr=cmd.out_addr, nbytes=cmd.out_nbytes) write_txn = PeTxn(command=write_cmd, done=write_done, pe_prefix=pp) yield self.out_ports[f"{pp}.pe_dma"].put(write_txn) yield write_done diff --git a/src/kernbench/policy/address/allocator.py b/src/kernbench/policy/address/allocator.py index 5d10bb4..068bc2d 100644 --- a/src/kernbench/policy/address/allocator.py +++ b/src/kernbench/policy/address/allocator.py @@ -1,5 +1,6 @@ from __future__ import annotations +import bisect from dataclasses import dataclass from kernbench.policy.address.phyaddr import PhysAddr @@ -29,6 +30,63 @@ class AddressConfig: return self.tcm_bytes_per_pe - self.tcm_scheduler_reserved_bytes +class _FreeList: + """Offset-based free-list allocator with coalescing.""" + + def __init__(self, capacity: int) -> None: + self._capacity = capacity + self._used = 0 + self._free: list[tuple[int, int]] = [(0, capacity)] # (offset, size) + + @property + def used(self) -> int: + return self._used + + @property + def total(self) -> int: + return self._capacity + + def alloc(self, nbytes: int) -> int: + """Allocate nbytes, return offset. Raises AllocationError if full.""" + for i, (start, size) in enumerate(self._free): + if size >= nbytes: + if size == nbytes: + self._free.pop(i) + else: + self._free[i] = (start + nbytes, size - nbytes) + self._used += nbytes + return start + raise AllocationError( + f"overflow: need {nbytes}, " + f"largest free block {max((s for _, s in self._free), default=0)}" + ) + + def free(self, offset: int, nbytes: int) -> None: + """Return a range to the free-list with coalescing.""" + self._used -= nbytes + new_start = offset + new_end = offset + nbytes + + idx = bisect.bisect_left(self._free, (offset,)) + + # Coalesce with previous block + if idx > 0: + prev_start, prev_size = self._free[idx - 1] + if prev_start + prev_size == new_start: + new_start = prev_start + idx -= 1 + self._free.pop(idx) + + # Coalesce with next block + if idx < len(self._free): + next_start, next_size = self._free[idx] + if new_end == next_start: + new_end = next_start + next_size + self._free.pop(idx) + + self._free.insert(idx, (new_start, new_end - new_start)) + + class PEMemAllocator: def __init__( self, rack_id: int, sip_id: int, cube_id: int, pe_id: int, cfg: AddressConfig, @@ -38,39 +96,48 @@ class PEMemAllocator: self._cube_id = cube_id self._pe_id = pe_id self._cfg = cfg - self._hbm_cursor = 0 - self._tcm_cursor = 0 + self._hbm = _FreeList(cfg.hbm_slice_bytes) + self._tcm = _FreeList(cfg.tcm_allocatable_bytes) def alloc_hbm(self, nbytes: int) -> PhysAddr: - if self._hbm_cursor + nbytes > self._cfg.hbm_slice_bytes: + try: + offset = self._hbm.alloc(nbytes) + except AllocationError: raise AllocationError( f"HBM overflow: need {nbytes}, " - f"available {self._cfg.hbm_slice_bytes - self._hbm_cursor}" + f"available {self._cfg.hbm_slice_bytes - self._hbm.used}" ) - pa = PhysAddr.pe_hbm_addr( + return PhysAddr.pe_hbm_addr( rack_id=self._rack_id, sip_id=self._sip_id, cube_id=self._cube_id, - pe_id=self._pe_id, pe_local_hbm_offset=self._hbm_cursor, + pe_id=self._pe_id, pe_local_hbm_offset=offset, slice_size_bytes=self._cfg.hbm_slice_bytes, ) - self._hbm_cursor += nbytes - return pa + + def free_hbm(self, pa: PhysAddr, nbytes: int) -> None: + # Extract PE-local offset from the PA's hbm_offset + pe_slice_start = self._pe_id * self._cfg.hbm_slice_bytes + offset = pa.hbm_offset - pe_slice_start + self._hbm.free(offset, nbytes) def alloc_tcm(self, nbytes: int) -> PhysAddr: - if self._tcm_cursor + nbytes > self._cfg.tcm_allocatable_bytes: + try: + offset = self._tcm.alloc(nbytes) + except AllocationError: raise AllocationError( f"TCM overflow: need {nbytes}, " - f"available {self._cfg.tcm_allocatable_bytes - self._tcm_cursor}" + f"available {self._cfg.tcm_allocatable_bytes - self._tcm.used}" ) - pa = PhysAddr.pe_tcm_addr( + return PhysAddr.pe_tcm_addr( rack_id=self._rack_id, sip_id=self._sip_id, cube_id=self._cube_id, - pe_id=self._pe_id, tcm_offset=self._tcm_cursor, + pe_id=self._pe_id, tcm_offset=offset, ) - self._tcm_cursor += nbytes - return pa + + def free_tcm(self, pa: PhysAddr, nbytes: int) -> None: + self._tcm.free(pa.sub_offset, nbytes) @property def hbm_used(self) -> int: - return self._hbm_cursor + return self._hbm.used @property def hbm_total(self) -> int: @@ -78,7 +145,7 @@ class PEMemAllocator: @property def tcm_used(self) -> int: - return self._tcm_cursor + return self._tcm.used @property def tcm_total(self) -> int: diff --git a/src/kernbench/policy/address/pe_mmu.py b/src/kernbench/policy/address/pe_mmu.py new file mode 100644 index 0000000..6080b24 --- /dev/null +++ b/src/kernbench/policy/address/pe_mmu.py @@ -0,0 +1,66 @@ +"""PeMMU: per-PE virtual-to-physical address translation. + +Page-aligned dict lookup for O(1) VA→PA translation. +Used as a utility class by PE_DMA, PE_GEMM (direct call), +and as a component inbox target for MmuMapMsg/MmuUnmapMsg. +""" +from __future__ import annotations + + +class PageFault(Exception): + """Raised when VA has no mapping in the page table.""" + + def __init__(self, va: int | str) -> None: + if isinstance(va, str): + super().__init__(va) + else: + self.va = va + super().__init__(f"PageFault at VA 0x{va:x}") + + +class PeMMU: + """Per-PE MMU with page-aligned VA→PA translation table. + + Args: + page_size: Page size in bytes (default 2 MB). + overhead_ns: Per-access TLB lookup latency in nanoseconds. + """ + + def __init__( + self, + page_size: int = 2 * 1024 * 1024, + overhead_ns: float = 0.0, + ) -> None: + self._page_size = page_size + self._page_shift = (page_size - 1).bit_length() + self._page_mask = page_size - 1 + self._table: dict[int, int] = {} # va_page_number → pa_page_base + self._overhead_ns = overhead_ns + + @property + def overhead_ns(self) -> float: + return self._overhead_ns + + @property + def num_entries(self) -> int: + return len(self._table) + + def map(self, va: int, pa: int, size: int) -> None: + """Register VA→PA mapping for a contiguous range.""" + for off in range(0, size, self._page_size): + vpn = (va + off) >> self._page_shift + self._table[vpn] = pa + off + + def unmap(self, va: int, size: int) -> None: + """Remove VA mapping for a contiguous range.""" + for off in range(0, size, self._page_size): + vpn = (va + off) >> self._page_shift + self._table.pop(vpn, None) + + def translate(self, va: int) -> int: + """Translate VA to PA. Raises PageFault if unmapped.""" + vpn = va >> self._page_shift + pa_page_base = self._table.get(vpn) + if pa_page_base is None: + raise PageFault(va) + return pa_page_base + (va & self._page_mask) diff --git a/src/kernbench/policy/address/va_allocator.py b/src/kernbench/policy/address/va_allocator.py new file mode 100644 index 0000000..6f5e811 --- /dev/null +++ b/src/kernbench/policy/address/va_allocator.py @@ -0,0 +1,93 @@ +"""VirtualAllocator: device-wide VA space management with free-list. + +Allocations are page-aligned. Freed ranges are coalesced with adjacent +free blocks to allow larger re-allocations. +""" +from __future__ import annotations + +import bisect + + +class VaAllocationError(Exception): + pass + + +class VirtualAllocator: + """Manages a contiguous VA address space with page-aligned alloc/free. + + Args: + va_base: Start of the VA range. + va_size: Total size of the VA range in bytes. + page_size: Page granularity in bytes. + """ + + def __init__( + self, + va_base: int, + va_size: int, + page_size: int = 2 * 1024 * 1024, + ) -> None: + self._va_base = va_base + self._va_size = va_size + self._page_size = page_size + self._used = 0 + # Free list: sorted list of (start, size) tuples + self._free: list[tuple[int, int]] = [(va_base, va_size)] + + def _align_up(self, nbytes: int) -> int: + """Round up to page boundary.""" + return ((nbytes + self._page_size - 1) // self._page_size) * self._page_size + + def alloc(self, nbytes: int) -> int: + """Allocate a contiguous VA range. Returns the start VA.""" + aligned = self._align_up(nbytes) + for i, (start, size) in enumerate(self._free): + if size >= aligned: + # Take from the beginning of this free block + if size == aligned: + self._free.pop(i) + else: + self._free[i] = (start + aligned, size - aligned) + self._used += aligned + return start + raise VaAllocationError( + f"Out of VA space: need {aligned}, largest free block " + f"{max((s for _, s in self._free), default=0)}" + ) + + def free(self, va: int, nbytes: int) -> None: + """Free a VA range and coalesce with adjacent free blocks.""" + aligned = self._align_up(nbytes) + self._used -= aligned + + # Insert into sorted free list and coalesce + new_start = va + new_end = va + aligned + + # Find insertion point + idx = bisect.bisect_left(self._free, (va,)) + + # Try coalesce with previous block + if idx > 0: + prev_start, prev_size = self._free[idx - 1] + if prev_start + prev_size == new_start: + new_start = prev_start + idx -= 1 + self._free.pop(idx) + + # Try coalesce with next block + if idx < len(self._free): + next_start, next_size = self._free[idx] + if new_end == next_start: + new_end = next_start + next_size + self._free.pop(idx) + + self._free.insert(idx, (new_start, new_end - new_start)) + + @property + def used(self) -> int: + return self._used + + @property + def total(self) -> int: + return self._va_size diff --git a/src/kernbench/runtime_api/context.py b/src/kernbench/runtime_api/context.py index e9cf270..021babe 100644 --- a/src/kernbench/runtime_api/context.py +++ b/src/kernbench/runtime_api/context.py @@ -19,8 +19,18 @@ class RuntimeContext: _handles: list[RequestHandle] = field(default_factory=list, init=False) _completed: set[RequestHandle] = field(default_factory=set, init=False) _allocators: dict[int, Any] = field(default_factory=dict, init=False) + _va_allocator: Any = field(default=None, init=False) + _mmus: dict[int, Any] = field(default_factory=dict, init=False) _tensor_counter: int = field(default=0, init=False) _traces: list[dict] = field(default_factory=list, init=False) + _tensors: list[Any] = field(default_factory=list, init=False) + + def __enter__(self): + return self + + def __exit__(self, *exc): + self.cleanup() + return False def submit(self, request: Any) -> RequestHandle: submit_fn = getattr(self.engine, "submit", None) @@ -58,6 +68,92 @@ class RuntimeContext: def handles(self) -> list[RequestHandle]: return list(self._handles) + # ── Tensor lifecycle ───────────────────────────────────────────── + + def _free_tensor(self, tensor: Any) -> None: + """Free a single tensor: unmap MMU, return VA and PA.""" + handle = tensor._handle + if handle is None: + return + tensor._handle = None + + if not handle.va_base: + return + + from kernbench.runtime_api.kernel import MmuUnmapMsg + + dp_policy = None + if tensor._dp_metadata is not None: + dp_policy = tensor._dp_metadata.dp_policy + + is_cube_replicate = ( + dp_policy is not None and dp_policy.cube == "replicate" + ) + + # Send MmuUnmapMsg through fabric + from collections import defaultdict + if is_cube_replicate: + cube_groups: dict[tuple[int, int], list] = defaultdict(list) + for shard in handle.shards: + cube_groups[(shard.sip, shard.cube)].append(shard) + for (sip, cube), group_shards in cube_groups.items(): + entries = tuple( + {"va": handle.va_base + s.offset_bytes, "size": s.nbytes} + for s in group_shards + ) + msg = MmuUnmapMsg( + correlation_id=self.correlation_id, + request_id=f"unmap_{tensor.name}_s{sip}c{cube}", + entries=entries, + target_sips=(sip,), + target_cubes=(cube,), + target_pe="all", + ) + h = self.submit(msg) + self.wait(h) + else: + entries = tuple( + {"va": handle.va_base + s.offset_bytes, "size": s.nbytes} + for s in handle.shards + ) + sip_set = sorted({s.sip for s in handle.shards}) + cube_set = sorted({s.cube for s in handle.shards}) + msg = MmuUnmapMsg( + correlation_id=self.correlation_id, + request_id=f"unmap_{tensor.name}", + entries=entries, + target_sips=tuple(sip_set), + target_cubes=tuple(cube_set), + target_pe="all", + ) + h = self.submit(msg) + self.wait(h) + + # Return VA space + if self._va_allocator is not None: + self._va_allocator.free(handle.va_base, handle.nbytes) + + # Return PA space + if self._allocators: + for shard in handle.shards: + flat_idx = ( + shard.sip * self._num_cubes * self._pes_per_cube + + shard.cube * self._pes_per_cube + + shard.pe + ) + alloc = self._allocators.get(flat_idx) + if alloc is not None: + from kernbench.policy.address.phyaddr import PhysAddr + alloc.free_hbm(PhysAddr.decode(shard.pa), shard.nbytes) + + def cleanup(self) -> None: + """Free all tensors created by this context.""" + for ref in self._tensors: + t = ref() + if t is not None and t._handle is not None: + self._free_tensor(t) + self._tensors.clear() + # ── PyTorch-like tensor API ────────────────────────────────────── def _ensure_allocators(self) -> dict: @@ -111,6 +207,26 @@ class RuntimeContext: self._allocators[flat_idx] = PEMemAllocator( rack_id=0, sip_id=sip_id, cube_id=cube_id, pe_id=pe_id, cfg=cfg, ) + + # Initialize VA allocator and per-PE MMUs + from kernbench.policy.address.pe_mmu import PeMMU + from kernbench.policy.address.va_allocator import VirtualAllocator + + pe_mmu_attrs = pe_comps.get("pe_mmu", {}).get("attrs", {}) + page_size = int(pe_mmu_attrs.get("page_size", 4096)) + tlb_overhead_ns = float(pe_mmu_attrs.get("tlb_overhead_ns", 0.0)) + + self._va_allocator = VirtualAllocator( + va_base=0x1_0000_0000, + va_size=64 * (1 << 30), # 64 GB VA space + page_size=page_size, + ) + total_pes = sip_count * cubes_per_sip * pes_per_cube + for flat_idx in range(total_pes): + self._mmus[flat_idx] = PeMMU( + page_size=page_size, overhead_ns=tlb_overhead_ns, + ) + return self._allocators def _next_tensor_name(self) -> str: @@ -122,63 +238,57 @@ class RuntimeContext: shape: tuple[int, ...], dtype: str = "f16", *, - placement: list | None = None, dp: Any = None, name: str | None = None, ): """Create a tensor and deploy to HBM with zero-fill (like torch.zeros).""" - return self._create_tensor(shape, dtype, placement, name, pattern="zero", dp=dp) + return self._create_tensor(shape, dtype, name, pattern="zero", dp=dp) def empty( self, shape: tuple[int, ...], dtype: str = "f16", *, - placement: list | None = None, dp: Any = None, name: str | None = None, ): """Allocate a tensor in HBM without initialization (like torch.empty).""" - return self._create_tensor(shape, dtype, placement, name, pattern=None, dp=dp) + return self._create_tensor(shape, dtype, name, pattern=None, dp=dp) def _create_tensor( self, shape: tuple[int, ...], dtype: str, - placement: list | None, name: str | None, pattern: str | None, dp: Any = None, ): - from kernbench.policy.placement.dp import DPPolicy, ShardSpec, resolve_dp_policy + from kernbench.policy.placement.dp import DPPolicy, resolve_dp_policy from kernbench.runtime_api.kernel import MemoryWriteMsg from kernbench.runtime_api.tensor import Tensor, deploy_tensor, dtype_itemsize + if not isinstance(dp, DPPolicy): + raise ValueError("dp=DPPolicy(...) is required for tensor creation") + tensor_name = name or self._next_tensor_name() t = Tensor(shape=shape, dtype=dtype, name=tensor_name) - dp_policy: DPPolicy | None = None - - # Resolve placement: dp= takes priority over placement= - if dp is not None and isinstance(dp, DPPolicy): - dp_policy = dp - allocators = self._ensure_allocators() - itemsize = dtype_itemsize(dtype) - shape_2d = (shape[0], shape[1]) # type: tuple[int, int] - total_cubes = self._num_sips * self._num_cubes - placement = resolve_dp_policy( - dp, shape=shape_2d, itemsize=itemsize, - num_pe=self._pes_per_cube, num_cubes=total_cubes, - ) - elif placement is None: - placement = [ShardSpec(pe_index=0, offset_bytes=0, nbytes=t.nbytes)] + dp_policy = dp + allocators = self._ensure_allocators() + itemsize = dtype_itemsize(dtype) + shape_2d = (shape[0], shape[1]) # type: tuple[int, int] + total_cubes = self._num_sips * self._num_cubes + placement = resolve_dp_policy( + dp, shape=shape_2d, itemsize=itemsize, + num_pe=self._pes_per_cube, num_cubes=total_cubes, + ) # Infer target_pe from placement: multi-PE → "all", single PE → pe_index pe_indices = {s.pe_index for s in placement} target_pe: int | str = "all" if len(pe_indices) > 1 else next(iter(pe_indices)) t.to(placement=placement, target_pe=target_pe, dp_policy=dp_policy) - # Allocate PAs via PEMemAllocator + # Allocate PAs via PEMemAllocator + VA via VirtualAllocator allocators = self._ensure_allocators() handle = deploy_tensor( name=tensor_name, @@ -186,8 +296,64 @@ class RuntimeContext: dtype=dtype, placement=placement, allocators=allocators, + va_allocator=self._va_allocator, + mmus=self._mmus, ) t._handle = handle + import weakref + t._ctx_ref = weakref.ref(self) + self._tensors.append(weakref.ref(t)) + + # Install VA→PA mappings via fabric MmuMapMsg + if handle.va_base: + from collections import defaultdict + from kernbench.runtime_api.kernel import MmuMapMsg + + is_cube_replicate = ( + dp_policy is not None and dp_policy.cube == "replicate" + ) + + if is_cube_replicate: + # Replicate: each (sip, cube) gets only its own local PA mappings + cube_groups: dict[tuple[int, int], list] = defaultdict(list) + for shard in handle.shards: + cube_groups[(shard.sip, shard.cube)].append(shard) + + for (sip, cube), group_shards in cube_groups.items(): + entries = tuple( + {"va": handle.va_base + s.offset_bytes, + "pa": s.pa, "size": s.nbytes} + for s in group_shards + ) + msg = MmuMapMsg( + correlation_id=self.correlation_id, + request_id=f"mmu_{tensor_name}_s{sip}c{cube}", + entries=entries, + target_sips=(sip,), + target_cubes=(cube,), + target_pe="all", + ) + h = self.submit(msg) + self.wait(h) + else: + # Sharded: broadcast all mappings to all target (sip, cube)s + entries = tuple( + {"va": handle.va_base + s.offset_bytes, + "pa": s.pa, "size": s.nbytes} + for s in handle.shards + ) + sip_set = sorted({s.sip for s in handle.shards}) + cube_set = sorted({s.cube for s in handle.shards}) + msg = MmuMapMsg( + correlation_id=self.correlation_id, + request_id=f"mmu_{tensor_name}", + entries=entries, + target_sips=tuple(sip_set), + target_cubes=tuple(cube_set), + target_pe="all", + ) + h = self.submit(msg) + self.wait(h) # Submit MemoryWriteMsg per shard (deploy data to device) if pattern is not None: diff --git a/src/kernbench/runtime_api/kernel.py b/src/kernbench/runtime_api/kernel.py index 433d976..3fc8624 100644 --- a/src/kernbench/runtime_api/kernel.py +++ b/src/kernbench/runtime_api/kernel.py @@ -69,6 +69,7 @@ class TensorArgShard: class TensorArg: shards: tuple[TensorArgShard, ...] arg_kind: Literal["tensor"] = "tensor" + va_base: int = 0 # VA base address for the entire tensor @dataclass(frozen=True) @@ -121,3 +122,33 @@ class PeDmaMsg: nbytes: int is_write: bool = False msg_type: Literal["pe_dma"] = "pe_dma" + + +@dataclass(frozen=True) +class MmuMapMsg: + """MMU mapping install: broadcast VA→PA entries to target PEs. + + Sent via fabric: Host → PCIE_EP → IO_CPU → M_CPU → NOC → PE_MMU. + target_sips controls which SIPs receive the message. + """ + + correlation_id: str + request_id: str + entries: tuple[dict, ...] # ({"va": int, "pa": int, "size": int}, ...) + target_sips: tuple[int, ...] | Literal["all"] = "all" + target_cubes: tuple[int, ...] | Literal["all"] = "all" + target_pe: int | Literal["all"] = "all" + msg_type: Literal["mmu_map"] = "mmu_map" + + +@dataclass(frozen=True) +class MmuUnmapMsg: + """MMU mapping removal: broadcast VA ranges to unmap from all PEs.""" + + correlation_id: str + request_id: str + entries: tuple[dict, ...] # ({"va": int, "size": int}, ...) + target_sips: tuple[int, ...] | Literal["all"] = "all" + target_cubes: tuple[int, ...] | Literal["all"] = "all" + target_pe: int | Literal["all"] = "all" + msg_type: Literal["mmu_unmap"] = "mmu_unmap" diff --git a/src/kernbench/runtime_api/tensor.py b/src/kernbench/runtime_api/tensor.py index 26d4749..4dde44f 100644 --- a/src/kernbench/runtime_api/tensor.py +++ b/src/kernbench/runtime_api/tensor.py @@ -1,6 +1,7 @@ from __future__ import annotations import math +import weakref from dataclasses import dataclass from typing import Literal @@ -26,6 +27,7 @@ class TensorHandle: dtype: str itemsize: int shards: tuple[TensorShard, ...] + va_base: int = 0 # VA base address for the entire tensor @property def nbytes(self) -> int: @@ -56,8 +58,19 @@ def deploy_tensor( placement: list[ShardSpec], allocators: dict[int, PEMemAllocator], mem_kind: Literal["hbm", "tcm"] = "hbm", + va_allocator=None, + mmus: dict | None = None, ) -> TensorHandle: + from kernbench.policy.address.pe_mmu import PeMMU + isize = dtype_itemsize(dtype) + total_nbytes = math.prod(shape) * isize + + # Allocate VA range for the entire tensor (if VA allocator provided) + va_base = 0 + if va_allocator is not None: + va_base = va_allocator.alloc(total_nbytes) + shards: list[TensorShard] = [] for spec in placement: alloc = allocators[spec.pe_index] @@ -65,20 +78,29 @@ def deploy_tensor( pa = alloc.alloc_hbm(spec.nbytes) else: pa = alloc.alloc_tcm(spec.nbytes) + encoded_pa = pa.encode() shards.append(TensorShard( sip=alloc._sip_id, cube=alloc._cube_id, pe=alloc._pe_id, - pa=pa.encode(), + pa=encoded_pa, nbytes=spec.nbytes, offset_bytes=spec.offset_bytes, )) + + # Register VA→PA mapping in all MMUs (broadcast) + if va_base and mmus is not None: + shard_va = va_base + spec.offset_bytes + for mmu in mmus.values(): + mmu.map(va=shard_va, pa=encoded_pa, size=spec.nbytes) + return TensorHandle( name=name, shape=shape, dtype=dtype, itemsize=isize, shards=tuple(shards), + va_base=va_base, ) @@ -101,8 +123,7 @@ class Tensor: Usage:: - a = ctx.zeros((M, K), dtype="f16") - a = ctx.zeros((M, K), dtype="f16", placement=dp.replicate(num_pe=8)) + a = ctx.zeros((M, K), dtype="f16", dp=DPPolicy(cube="replicate", pe="replicate")) ctx.launch("kernel_name", kernel_fn, a, b, out, M=M, K=K) """ @@ -117,6 +138,14 @@ class Tensor: self.name = name self._dp_metadata: DPMetadata | None = None self._handle: TensorHandle | None = None + self._ctx_ref: weakref.ref | None = None # set by RuntimeContext + + def __del__(self) -> None: + if self._ctx_ref is None or self._handle is None: + return + ctx = self._ctx_ref() + if ctx is not None: + ctx._free_tensor(self) @property def itemsize(self) -> int: @@ -133,6 +162,13 @@ class Tensor: raise RuntimeError(f"Tensor '{self.name}' is not deployed yet") return self._handle.shards[0].pa + @property + def va(self) -> int: + """VA base address for the entire tensor.""" + if self._handle is None: + raise RuntimeError(f"Tensor '{self.name}' is not deployed yet") + return self._handle.va_base + def to( self, placement: list[ShardSpec] | None = None, @@ -163,4 +199,5 @@ class Tensor: ) for s in self._handle.shards ), + va_base=self._handle.va_base, ) diff --git a/src/kernbench/sim_engine/engine.py b/src/kernbench/sim_engine/engine.py index 3388334..6c25813 100644 --- a/src/kernbench/sim_engine/engine.py +++ b/src/kernbench/sim_engine/engine.py @@ -98,6 +98,16 @@ class GraphEngine: self._components[node_id].in_ports["host"] = host_q self._pe_dma_queues[node_id] = host_q + # Wire PE_DMA._mmu to PE_MMU's underlying PeMMU utility object + for node_id, node in graph.nodes.items(): + if node.kind == "pe_dma": + # Derive PE_MMU node ID from PE_DMA node ID + pe_prefix = node_id.rsplit(".", 1)[0] # e.g. "sip0.cube0.pe0" + mmu_id = f"{pe_prefix}.pe_mmu" + mmu_comp = self._components.get(mmu_id) + if mmu_comp is not None and hasattr(mmu_comp, "mmu"): + self._components[node_id]._mmu = mmu_comp.mmu + # Start components after all ports are wired (ADR-0015 D3) for comp in self._components.values(): comp.start(self._env) @@ -119,6 +129,27 @@ class GraphEngine: def get_completion(self, handle: RequestHandle) -> tuple[Completion, Trace | None]: return self._results[str(handle)] + def mmu_map(self, va: int, pa: int, size: int) -> None: + """Sideband: install VA→PA mapping in all PE_MMU components.""" + for node_id, comp in self._components.items(): + if hasattr(comp, "mmu"): + comp.mmu.map(va=va, pa=pa, size=size) + + def mmu_map_pe( + self, sip: int, cube: int, pe: int, va: int, pa: int, size: int, + ) -> None: + """Sideband: install VA→PA mapping in a specific PE's MMU only.""" + mmu_id = f"sip{sip}.cube{cube}.pe{pe}.pe_mmu" + comp = self._components.get(mmu_id) + if comp is not None and hasattr(comp, "mmu"): + comp.mmu.map(va=va, pa=pa, size=size) + + def mmu_unmap(self, va: int, size: int) -> None: + """Sideband: remove VA mapping from all PE_MMU components.""" + for node_id, comp in self._components.items(): + if hasattr(comp, "mmu"): + comp.mmu.unmap(va=va, size=size) + # ── internal ──────────────────────────────────────────────────── def _wire( @@ -166,6 +197,11 @@ class GraphEngine: yield from self._process_memory_direct(key, request, done) return + from kernbench.runtime_api.kernel import MmuMapMsg, MmuUnmapMsg + if isinstance(request, (MmuMapMsg, MmuUnmapMsg)): + yield from self._process_mmu_msg(key, request, done) + return + entries = self._entry_points(request) if not entries: self._results[key] = ( @@ -341,3 +377,59 @@ class GraphEngine: return entries raise ValueError(f"unsupported request type: {type(request)}") + + def _process_mmu_msg(self, key: str, request: Any, done: simpy.Event): + """Route MmuMapMsg/MmuUnmapMsg through fabric like KernelLaunchMsg. + + Path: Host → PCIE_EP → IO_NOC → IO_CPU → (fan-out) → M_CPU → (fan-out) → PE_MMU + """ + start_ns = self._env.now + target_sips = getattr(request, "target_sips", "all") + + # Determine target SIPs + sip_set: set[int] = set() + if target_sips == "all": + for ep_id in self._resolver.find_all_pcie_eps(): + sip_id = int(ep_id.split(".")[0].replace("sip", "")) + sip_set.add(sip_id) + else: + sip_set = set(target_sips) + + entries = [] + for sip_id in sorted(sip_set): + entries.append(( + self._resolver.find_pcie_ep(sip_id), + self._resolver.find_io_cpu(sip_id), + 0, # MmuMapMsg has no data payload + )) + + if not entries: + self._results[key] = (Completion(ok=True), {"total_ns": 0.0}) + done.succeed() + return + + if len(entries) == 1: + pcie_ep_id, io_cpu_id, _ = entries[0] + path = self._router.find_node_path(pcie_ep_id, io_cpu_id) + txn_done = self._env.event() + txn = Transaction(request=request, path=path, step=0, nbytes=0, done=txn_done) + yield self._host_queues[pcie_ep_id].put(txn) + yield txn_done + else: + # Multi-SIP fan-out + sub_dones = [] + for pcie_ep_id, io_cpu_id, _ in entries: + path = self._router.find_node_path(pcie_ep_id, io_cpu_id) + sub_done = self._env.event() + sub_txn = Transaction(request=request, path=path, step=0, nbytes=0, done=sub_done) + yield self._host_queues[pcie_ep_id].put(sub_txn) + sub_dones.append(sub_done) + for sd in sub_dones: + yield sd + + elapsed = self._env.now - start_ns + self._results[key] = ( + Completion(ok=True), + {"total_ns": elapsed, "msg_type": request.msg_type}, + ) + done.succeed() diff --git a/src/kernbench/topology/builder.py b/src/kernbench/topology/builder.py index 246f01f..d9c267b 100644 --- a/src/kernbench/topology/builder.py +++ b/src/kernbench/topology/builder.py @@ -22,6 +22,7 @@ _PE_COMP_OFFSETS = { "pe_dma": (0.0, -0.15), "pe_gemm": (0.0, 0.0), "pe_math": (0.0, 0.15), + "pe_mmu": (0.15, -0.15), "pe_tcm": (0.3, 0.0), } @@ -495,6 +496,15 @@ def _instantiate_cube( kind="pe_response", )) + # noc → PE_MMU (MMU mapping install) + pe_mmu_id = f"{pp}.pe_mmu" + if pe_mmu_id in nodes: + edges.append(Edge( + src=f"{cp}.noc", dst=pe_mmu_id, + distance_mm=clinks.get("noc_to_pe_mmu_mm", 0.0), + kind="command", + )) + pe_idx += 1 # ── xbar_top/bot → HBM slices ── @@ -1073,6 +1083,7 @@ def _build_pe_view(spec: dict) -> ViewGraph: "pe_dma": (7.0, 1.5), "pe_gemm": (7.0, 4.0), "pe_math": (7.0, 6.5), + "pe_mmu": (4.0, 1.5), "pe_tcm": (10.0, 4.0), } diff --git a/src/kernbench/triton_emu/tl_context.py b/src/kernbench/triton_emu/tl_context.py index 4d5296c..63c867b 100644 --- a/src/kernbench/triton_emu/tl_context.py +++ b/src/kernbench/triton_emu/tl_context.py @@ -86,11 +86,11 @@ class TLContext: self._commands.append(PeCpuOverheadCmd(cycles=self._dispatch_cycles)) def _make_handle( - self, pa: int, shape: tuple[int, ...], dtype: str, + self, addr: int, shape: tuple[int, ...], dtype: str, ) -> TensorHandle: return TensorHandle( id=self._next_handle_id(), - pa=pa, shape=shape, dtype=dtype, + addr=addr, shape=shape, dtype=dtype, nbytes=self._nbytes(shape, dtype), ) @@ -104,7 +104,7 @@ class TLContext: Used when the scheduler will stream data per-tile (e.g., tensor b in a composite GEMM). No command is generated. """ - return self._make_handle(pa=ptr, shape=shape, dtype=dtype) + return self._make_handle(addr=ptr, shape=shape, dtype=dtype) # ── Data Movement (blocking, DMA engine) ────────────────────── @@ -113,9 +113,9 @@ class TLContext: ) -> TensorHandle: """Load tensor from HBM to TCM. Returns TensorHandle.""" self._emit_dispatch_overhead() - handle = self._make_handle(pa=ptr, shape=shape, dtype=dtype) + handle = self._make_handle(addr=ptr, shape=shape, dtype=dtype) self._commands.append(DmaReadCmd( - handle=handle, src_pa=ptr, nbytes=handle.nbytes, + handle=handle, src_addr=ptr, nbytes=handle.nbytes, )) return handle @@ -123,7 +123,7 @@ class TLContext: """Store tensor from TCM to HBM.""" self._emit_dispatch_overhead() self._commands.append(DmaWriteCmd( - handle=handle, dst_pa=ptr, nbytes=handle.nbytes, + handle=handle, dst_addr=ptr, nbytes=handle.nbytes, )) # ── GEMM Engine (blocking) ──────────────────────────────────── @@ -141,7 +141,7 @@ class TLContext: raise ValueError(f"dot shape mismatch: a.K={k} != b.K={k2}") out_shape = (*a.shape[:-2], m, n) out_dtype = a.dtype - out = self._make_handle(pa=0, shape=out_shape, dtype=out_dtype) + out = self._make_handle(addr=0, shape=out_shape, dtype=out_dtype) self._emit_dispatch_overhead() self._commands.append(GemmCmd(a=a, b=b, out=out, m=m, k=k, n=n)) return out @@ -149,7 +149,7 @@ class TLContext: # ── MATH Engine: unary (blocking) ───────────────────────────── def _unary_math(self, op: str, x: TensorHandle) -> TensorHandle: - out = self._make_handle(pa=0, shape=x.shape, dtype=x.dtype) + out = self._make_handle(addr=0, shape=x.shape, dtype=x.dtype) self._emit_dispatch_overhead() self._commands.append(MathCmd(op=op, inputs=(x,), out=out)) return out @@ -182,7 +182,7 @@ class TLContext: ) -> TensorHandle: out_shape = list(x.shape) out_shape[axis] = 1 - out = self._make_handle(pa=0, shape=tuple(out_shape), dtype=x.dtype) + out = self._make_handle(addr=0, shape=tuple(out_shape), dtype=x.dtype) self._emit_dispatch_overhead() self._commands.append(MathCmd(op=op, inputs=(x,), out=out, axis=axis)) return out @@ -201,7 +201,7 @@ class TLContext: def _binary_math( self, op: str, a: TensorHandle, b: TensorHandle, ) -> TensorHandle: - out = self._make_handle(pa=0, shape=a.shape, dtype=a.dtype) + out = self._make_handle(addr=0, shape=a.shape, dtype=a.dtype) self._emit_dispatch_overhead() self._commands.append(MathCmd(op=op, inputs=(a, b), out=out)) return out @@ -209,7 +209,7 @@ class TLContext: def where( self, cond: TensorHandle, a: TensorHandle, b: TensorHandle, ) -> TensorHandle: - out = self._make_handle(pa=0, shape=a.shape, dtype=a.dtype) + out = self._make_handle(addr=0, shape=a.shape, dtype=a.dtype) self._emit_dispatch_overhead() self._commands.append(MathCmd(op="where", inputs=(cond, a, b), out=out)) return out @@ -227,17 +227,17 @@ class TLContext: def arange(self, start: int, end: int, dtype: str = "i32") -> TensorHandle: """Create index range tensor in TCM.""" n = end - start - return self._make_handle(pa=0, shape=(n,), dtype=dtype) + return self._make_handle(addr=0, shape=(n,), dtype=dtype) def zeros(self, shape: tuple[int, ...], dtype: str = "f16") -> TensorHandle: """Create zero-filled tensor in TCM.""" - return self._make_handle(pa=0, shape=shape, dtype=dtype) + return self._make_handle(addr=0, shape=shape, dtype=dtype) def full( self, shape: tuple[int, ...], value: float | int, dtype: str = "f16", ) -> TensorHandle: """Create constant-filled tensor in TCM.""" - return self._make_handle(pa=0, shape=shape, dtype=dtype) + return self._make_handle(addr=0, shape=shape, dtype=dtype) # ── Metadata (no compute, no DMA) ───────────────────────────── @@ -247,7 +247,7 @@ class TLContext: raise ValueError("trans requires at least 2D tensor") new_shape = (*x.shape[:-2], x.shape[-1], x.shape[-2]) return TensorHandle( - id=x.id, pa=x.pa, shape=new_shape, + id=x.id, addr=x.addr, shape=new_shape, dtype=x.dtype, nbytes=x.nbytes, data=x.data, ) @@ -278,7 +278,7 @@ class TLContext: self._emit_dispatch_overhead() self._commands.append(CompositeCmd( completion=completion, op=op, - a=a, b=b, out_pa=out_ptr, out_nbytes=out_nbytes, + a=a, b=b, out_addr=out_ptr, out_nbytes=out_nbytes, math_op=math_op, )) return completion diff --git a/tests/test_mmu_component.py b/tests/test_mmu_component.py new file mode 100644 index 0000000..08ba17f --- /dev/null +++ b/tests/test_mmu_component.py @@ -0,0 +1,226 @@ +"""Tests for PE_MMU component integration and MmuMapMsg fabric path. + +Validates: + T15-a. PE_MMU component registered in ComponentRegistry + T15-b. PE_MMU component receives MmuMapMsg via inbox, updates page table + T15-c. PE_DMA translates VA→PA via mmu before routing + T16. MmuMapMsg/MmuUnmapMsg message types defined with correct fields + T17. PE_CPU passes VA (not PA) to kernel when VA is available + T18. End-to-end: deploy (MmuMapMsg broadcast) → kernel launch → DMA with VA +""" +import pytest + +try: + from kernbench.runtime_api.kernel import MmuMapMsg, MmuUnmapMsg +except ImportError: + pytest.skip("MmuMapMsg/MmuUnmapMsg not yet defined (Phase 2)", allow_module_level=True) + + +# ── T16. MmuMapMsg / MmuUnmapMsg message types ────────────────────── + + +def test_mmu_map_msg_fields(): + """MmuMapMsg carries VA→PA mapping entries for broadcast.""" + msg = MmuMapMsg( + correlation_id="c0", + request_id="r0", + entries=( + {"va": 0x1_0000_0000, "pa": 0xA000_0000, "size": 4096}, + {"va": 0x1_0000_1000, "pa": 0xB000_0000, "size": 4096}, + ), + target_cubes="all", + target_pe="all", + ) + assert msg.msg_type == "mmu_map" + assert len(msg.entries) == 2 + assert msg.entries[0]["va"] == 0x1_0000_0000 + assert msg.entries[0]["pa"] == 0xA000_0000 + assert msg.entries[0]["size"] == 4096 + + +def test_mmu_map_msg_immutable(): + """MmuMapMsg is frozen.""" + msg = MmuMapMsg( + correlation_id="c0", + request_id="r0", + entries=(), + target_cubes="all", + target_pe="all", + ) + with pytest.raises(AttributeError): + msg.entries = () # type: ignore[misc] + + +def test_mmu_unmap_msg_fields(): + """MmuUnmapMsg carries VA ranges to unmap.""" + msg = MmuUnmapMsg( + correlation_id="c0", + request_id="r0", + entries=( + {"va": 0x1_0000_0000, "size": 4096}, + ), + target_cubes="all", + target_pe="all", + ) + assert msg.msg_type == "mmu_unmap" + assert len(msg.entries) == 1 + assert msg.entries[0]["va"] == 0x1_0000_0000 + + +# ── T15-a. PE_MMU component registry ──────────────────────────────── + + +def test_pe_mmu_registry(): + """pe_mmu_v1 impl resolves in ComponentRegistry.""" + from kernbench.components.base import ComponentRegistry + from kernbench.components.impls.pe_mmu import PeMmuComponent + from kernbench.topology.types import Node + + node = Node( + id="sip0.cube0.pe0.pe_mmu", + kind="pe_mmu", + impl="pe_mmu_v1", + pos_mm=None, + attrs={"tlb_overhead_ns": 0.5}, + ) + comp = ComponentRegistry.create(node) + assert isinstance(comp, PeMmuComponent) + + +# ── T15-b. PE_MMU receives MmuMapMsg and updates page table ───────── + + +def test_pe_mmu_processes_map_msg(): + """PE_MMU component receives MmuMapMsg → translate works.""" + import simpy + from kernbench.components.impls.pe_mmu import PeMmuComponent + from kernbench.sim_engine.transaction import Transaction + from kernbench.topology.types import Node + + env = simpy.Environment() + node = Node( + id="sip0.cube0.pe0.pe_mmu", + kind="pe_mmu", + impl="pe_mmu_v1", + pos_mm=None, + attrs={"tlb_overhead_ns": 0.5, "page_size": 4096}, + ) + comp = PeMmuComponent(node) + comp.in_ports["src"] = simpy.Store(env) + comp.start(env) + + # Submit MmuMapMsg via inbox + map_msg = MmuMapMsg( + correlation_id="c0", + request_id="r0", + entries=( + {"va": 0x1_0000_0000, "pa": 0xABCD_0000, "size": 4096}, + ), + target_cubes="all", + target_pe="all", + ) + done = env.event() + txn = Transaction( + request=map_msg, + path=["sip0.cube0.pe0.pe_mmu"], + step=0, nbytes=0, done=done, + ) + + def inject(): + yield comp._inbox.put(txn) + + env.process(inject()) + env.run(until=100) + + # After processing, the MMU's translate should work + from kernbench.policy.address.pe_mmu import PeMMU + mmu = comp.mmu # the underlying PeMMU utility object + assert isinstance(mmu, PeMMU) + assert mmu.translate(0x1_0000_0000) == 0xABCD_0000 + + +# ── T15-c. PE_DMA uses MMU translate ──────────────────────────────── + + +def test_pe_dma_translates_va(): + """PE_DMA.handle_command calls mmu.translate(va) → PA before routing. + + This test validates the contract: after Phase 2, DmaReadCmd carries VA, + and PE_DMA must translate it to PA via the MMU before resolving the + HBM node path. + """ + # This test validates the interface contract. Full integration test + # requires the engine wiring which is validated in test_engine. + # Here we check that PE_DMA has an mmu attribute it can call. + from kernbench.components.impls.pe_dma import PeDmaComponent + from kernbench.topology.types import Node + + node = Node( + id="sip0.cube0.pe0.pe_dma", + kind="pe_dma", + impl="pe_dma_v1", + pos_mm=None, + attrs={"rd_engines": 1, "wr_engines": 1}, + ) + comp = PeDmaComponent(node) + + # PE_DMA must have a way to access the MMU (via ctx or direct reference) + # The exact wiring mechanism is flexible, but the attribute must exist + assert hasattr(comp, '_mmu') or hasattr(comp, 'mmu') or ( + hasattr(comp, 'ctx') and comp.ctx is not None + ), "PE_DMA must have access to PE_MMU for VA translation" + + +# ── T17. PE_CPU passes VA to kernel ────────────────────────────────── + + +def test_pe_cpu_uses_va_base_from_tensor_arg(): + """PE_CPU should use TensorArg.va_base for kernel pointer args. + + After Phase 2, TensorArg carries va_base alongside shards. + PE_CPU extracts va_base and passes it to the kernel function + so kernels operate on VA (not PA). + """ + from kernbench.runtime_api.kernel import TensorArg, TensorArgShard + + shard = TensorArgShard(sip=0, cube=0, pe=0, pa=0x1000, + nbytes=4096, offset_bytes=0) + targ = TensorArg(shards=(shard,), va_base=0x1_0000_0000) + + # PE_CPU should use targ.va_base for kernel pointer arg + assert targ.va_base == 0x1_0000_0000 + # PA still accessible via shard for direct-PA operations (IPCQ etc.) + assert shard.pa == 0x1000 + + +# ── T18. MmuMapMsg broadcast pattern ───────────────────────────────── + + +def test_mmu_map_msg_broadcast_target(): + """MmuMapMsg with target_pe='all' is a broadcast to all PEs.""" + msg = MmuMapMsg( + correlation_id="c0", + request_id="r0", + entries=({"va": 0x1000, "pa": 0x2000, "size": 4096},), + target_cubes="all", + target_pe="all", + ) + assert msg.target_pe == "all" + assert msg.target_cubes == "all" + + +def test_mmu_map_msg_same_entries_all_pes(): + """All PEs in a broadcast receive identical entries (not per-PE splits).""" + entries = ( + {"va": 0x1_0000_0000, "pa": 0xA000_0000, "size": 8192}, + {"va": 0x1_0000_2000, "pa": 0xB000_0000, "size": 8192}, + ) + msg = MmuMapMsg( + correlation_id="c0", + request_id="r0", + entries=entries, + target_cubes="all", + target_pe="all", + ) + # The message carries the full mapping — every PE receives exactly this + assert msg.entries == entries diff --git a/tests/test_mmu_fabric.py b/tests/test_mmu_fabric.py new file mode 100644 index 0000000..62a2ad3 --- /dev/null +++ b/tests/test_mmu_fabric.py @@ -0,0 +1,241 @@ +"""Tests for MmuMapMsg fabric path and cross-cube mapping. + +Validates: + F1. MmuMapMsg traverses fabric: latency > 0 (not sideband) + F2. MmuMapMsg fan-out: IO_CPU → cubes, M_CPU → PEs + F3. After MmuMapMsg, PE_MMU has correct mappings + F4. Cross-cube sharded tensor: all PEs get global mappings + F5. Replicate tensor: each PE gets own cube's PA (local override) + F6. Cross-cube DMA after sharded mapping: PE can access remote cube's HBM + F7. Overlap detection: replicate vs sharded identified correctly + F8. Existing regression: PA-only benchmarks still pass +""" +import pytest +from pathlib import Path + +from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator +from kernbench.policy.address.pe_mmu import PeMMU +from kernbench.policy.address.va_allocator import VirtualAllocator +from kernbench.policy.placement.dp import column_wise, replicate, ShardSpec +from kernbench.runtime_api.tensor import deploy_tensor, TensorHandle +from kernbench.sim_engine.engine import GraphEngine +from kernbench.topology.builder import load_topology + +TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml" + +_MB = 1 << 20 +_GB = 1 << 30 + +_CFG = AddressConfig( + sip_count=2, + cubes_per_sip=16, + pes_per_cube=8, + hbm_bytes_per_cube=48 * _GB, + hbm_slices_per_cube=8, + tcm_bytes_per_pe=16 * _MB, + tcm_scheduler_reserved_bytes=4 * _MB, + sram_bytes_per_cube=32 * _MB, +) + + +def _engine(): + return GraphEngine(load_topology(TOPOLOGY_PATH)) + + +# ── F1. MmuMapMsg fabric latency ───────────────────────────────────── + + +def test_mmu_map_via_fabric_has_latency(): + """MmuMapMsg submitted through engine.submit() completes with latency > 0.""" + from kernbench.runtime_api.kernel import MmuMapMsg + + engine = _engine() + msg = MmuMapMsg( + correlation_id="c0", + request_id="mmu_map_0", + entries=({"va": 0x1_0000_0000, "pa": 0x2000_0000, "size": 4096},), + target_cubes=(0,), + target_pe="all", + ) + h = engine.submit(msg) + engine.wait(h) + comp, trace = engine.get_completion(h) + assert comp.ok is True + # Fabric traversal must have non-zero latency + assert trace is not None + assert trace.get("total_ns", 0) > 0 + + +# ── F2. MmuMapMsg fan-out ──────────────────────────────────────────── + + +def test_mmu_map_reaches_all_pes_in_cube(): + """MmuMapMsg with target_pe='all' installs mapping in all 8 PE_MMUs of target cube.""" + from kernbench.runtime_api.kernel import MmuMapMsg + + engine = _engine() + va, pa, size = 0x1_0000_0000, 0xABCD_0000, 4096 + msg = MmuMapMsg( + correlation_id="c0", + request_id="mmu_map_1", + entries=({"va": va, "pa": pa, "size": size},), + target_cubes=(0,), + target_pe="all", + ) + h = engine.submit(msg) + engine.wait(h) + + # Verify all 8 PE_MMUs in cube 0 have the mapping + for pe_id in range(8): + mmu_id = f"sip0.cube0.pe{pe_id}.pe_mmu" + mmu_comp = engine._components[mmu_id] + assert mmu_comp.mmu.translate(va) == pa + + +# ── F3. Multiple MmuMapMsg entries ─────────────────────────────────── + + +def test_mmu_map_multiple_entries(): + """MmuMapMsg with multiple entries installs all of them.""" + from kernbench.runtime_api.kernel import MmuMapMsg + + engine = _engine() + entries = ( + {"va": 0x1_0000_0000, "pa": 0xA000_0000, "size": 4096}, + {"va": 0x1_0000_1000, "pa": 0xB000_0000, "size": 4096}, + ) + msg = MmuMapMsg( + correlation_id="c0", + request_id="mmu_map_2", + entries=entries, + target_cubes=(0,), + target_pe="all", + ) + h = engine.submit(msg) + engine.wait(h) + + mmu_comp = engine._components["sip0.cube0.pe0.pe_mmu"] + assert mmu_comp.mmu.translate(0x1_0000_0000) == 0xA000_0000 + assert mmu_comp.mmu.translate(0x1_0000_1000) == 0xB000_0000 + + +# ── F4. Cross-cube sharded: global mapping ─────────────────────────── + + +def test_cross_cube_sharded_all_pes_get_global_mapping(): + """For sharded tensor across cubes (unique offsets), all PEs get all mappings.""" + from kernbench.runtime_api.kernel import MmuMapMsg + + engine = _engine() + # Simulate 2-cube shard: cube0 has offset=0, cube1 has offset=4096 + entries = ( + {"va": 0x1_0000_0000, "pa": 0xA000_0000, "size": 4096}, # cube0 + {"va": 0x1_0000_1000, "pa": 0xB000_0000, "size": 4096}, # cube1 + ) + # Broadcast to both cubes + msg = MmuMapMsg( + correlation_id="c0", + request_id="mmu_map_xc", + entries=entries, + target_cubes=(0, 1), + target_pe="all", + ) + h = engine.submit(msg) + engine.wait(h) + + # PE in cube0 can translate both cube0 and cube1 addresses + mmu_c0 = engine._components["sip0.cube0.pe0.pe_mmu"] + assert mmu_c0.mmu.translate(0x1_0000_0000) == 0xA000_0000 # local + assert mmu_c0.mmu.translate(0x1_0000_1000) == 0xB000_0000 # remote + + # PE in cube1 can also translate both + mmu_c1 = engine._components["sip0.cube1.pe0.pe_mmu"] + assert mmu_c1.mmu.translate(0x1_0000_0000) == 0xA000_0000 # remote + assert mmu_c1.mmu.translate(0x1_0000_1000) == 0xB000_0000 # local + + +# ── F5. Replicate: local PA override ───────────────────────────────── + + +def test_replicate_local_pa_override(): + """For replicated tensor (same VA range), each cube's PEs see local PA.""" + from kernbench.runtime_api.kernel import MmuMapMsg + + engine = _engine() + va, size = 0x1_0000_0000, 4096 + + # Cube 0 gets its own PA + msg0 = MmuMapMsg( + correlation_id="c0", + request_id="mmu_rep_c0", + entries=({"va": va, "pa": 0xA000_0000, "size": size},), + target_cubes=(0,), + target_pe="all", + ) + h0 = engine.submit(msg0) + engine.wait(h0) + + # Cube 1 gets a different PA for the same VA + msg1 = MmuMapMsg( + correlation_id="c0", + request_id="mmu_rep_c1", + entries=({"va": va, "pa": 0xB000_0000, "size": size},), + target_cubes=(1,), + target_pe="all", + ) + h1 = engine.submit(msg1) + engine.wait(h1) + + # Cube 0 PEs translate to cube 0's PA + mmu_c0 = engine._components["sip0.cube0.pe0.pe_mmu"] + assert mmu_c0.mmu.translate(va) == 0xA000_0000 + + # Cube 1 PEs translate to cube 1's PA + mmu_c1 = engine._components["sip0.cube1.pe0.pe_mmu"] + assert mmu_c1.mmu.translate(va) == 0xB000_0000 + + +# ── F7. Overlap detection ──────────────────────────────────────────── + + +def test_detect_overlapping_shards(): + """Utility: detect if shards have overlapping VA ranges (replicate indicator).""" + from kernbench.runtime_api.tensor import TensorShard + + # Sharded: unique offsets + sharded = [ + TensorShard(sip=0, cube=0, pe=0, pa=0x100, nbytes=4096, offset_bytes=0), + TensorShard(sip=0, cube=0, pe=1, pa=0x200, nbytes=4096, offset_bytes=4096), + ] + offsets = [(s.offset_bytes, s.nbytes) for s in sharded] + assert len(set(offsets)) == len(offsets), "Sharded should have unique offsets" + + # Replicated: same offset + replicated = [ + TensorShard(sip=0, cube=0, pe=0, pa=0x100, nbytes=4096, offset_bytes=0), + TensorShard(sip=0, cube=1, pe=0, pa=0x200, nbytes=4096, offset_bytes=0), + ] + offsets_r = [(s.offset_bytes, s.nbytes) for s in replicated] + assert len(set(offsets_r)) < len(offsets_r), "Replicate should have duplicate offsets" + + +# ── F8. Regression: existing benchmarks still pass ─────────────────── + + +def test_qkv_gemm_still_passes(): + """QKV GEMM benchmark completes successfully with VA/MMU enabled.""" + from kernbench.runtime_api.context import RuntimeContext + from kernbench.runtime_api.types import BenchResult, DeviceSelector + + graph = load_topology(TOPOLOGY_PATH) + engine = GraphEngine(graph) + ctx = RuntimeContext( + engine=engine, + target_device=DeviceSelector("sip:0"), + correlation_id="test_regression", + spec=graph.spec, + ) + from benches.qkv_gemm import run as bench_run + bench_run(ctx) + ctx.wait_all() + # If we get here without exception, the benchmark succeeded diff --git a/tests/test_pe_components.py b/tests/test_pe_components.py index 35c4efb..3149edc 100644 --- a/tests/test_pe_components.py +++ b/tests/test_pe_components.py @@ -308,9 +308,9 @@ def test_pe_gemm_handles_pe_internal_txn(): gemm.in_ports["src"] = simpy.Store(env) gemm.start(env) - a = TensorHandle(id="t1", pa=0, shape=(4, 8), dtype="f16", nbytes=64) - b = TensorHandle(id="t2", pa=0, shape=(8, 4), dtype="f16", nbytes=64) - out = TensorHandle(id="t3", pa=0, shape=(4, 4), dtype="f16", nbytes=32) + a = TensorHandle(id="t1", addr=0, shape=(4, 8), dtype="f16", nbytes=64) + b = TensorHandle(id="t2", addr=0, shape=(8, 4), dtype="f16", nbytes=64) + out = TensorHandle(id="t3", addr=0, shape=(4, 4), dtype="f16", nbytes=32) cmd = GemmCmd(a=a, b=b, out=out, m=4, k=8, n=4) done = env.event() pe_txn = PeInternalTxn(command=cmd, done=done, pe_prefix=pe_prefix) @@ -349,8 +349,8 @@ def test_pe_math_handles_pe_internal_txn(): math_comp.in_ports["src"] = simpy.Store(env) math_comp.start(env) - x = TensorHandle(id="t1", pa=0, shape=(4, 4), dtype="f16", nbytes=32) - out = TensorHandle(id="t2", pa=0, shape=(4, 4), dtype="f16", nbytes=32) + x = TensorHandle(id="t1", addr=0, shape=(4, 4), dtype="f16", nbytes=32) + out = TensorHandle(id="t2", addr=0, shape=(4, 4), dtype="f16", nbytes=32) cmd = MathCmd(op="exp", inputs=(x,), out=out) done = env.event() pe_txn = PeInternalTxn(command=cmd, done=done, pe_prefix=pe_prefix) @@ -777,7 +777,7 @@ def test_tl_ref_no_dma(): tl = TLContext(pe_id=0, dispatch_cycles=0) handle = tl.ref(0x1000, shape=(4, 4), dtype="f16") - assert handle.pa == 0x1000 + assert handle.addr == 0x1000 assert handle.shape == (4, 4) assert len(tl.commands) == 0, f"tl.ref should emit 0 commands, got {len(tl.commands)}" diff --git a/tests/test_pe_mmu.py b/tests/test_pe_mmu.py new file mode 100644 index 0000000..35a5190 --- /dev/null +++ b/tests/test_pe_mmu.py @@ -0,0 +1,203 @@ +"""Tests for PeMMU: per-PE virtual-to-physical address translation. + +Validates: + T1. Basic map + translate + T2. Page-aligned dict lookup (O(1), multi-page range) + T3. Multiple tensor mappings accumulate + T4. unmap removes entries, translate raises PageFault + T5. PageFault on unmapped VA + T6. Identical mapping broadcast across multiple PEs +""" +import pytest + +from kernbench.policy.address.pe_mmu import PageFault, PeMMU + +_2MB = 2 * 1024 * 1024 + + +# ── T1. Basic map + translate ──────────────────────────────────────── + + +def test_map_and_translate_basic(): + """map(va, pa, size) → translate(va) returns pa; offset preserved.""" + mmu = PeMMU(page_size=4096) + mmu.map(va=0x1_0000_0000, pa=0xABCD_0000, size=4096) + assert mmu.translate(0x1_0000_0000) == 0xABCD_0000 + + +def test_translate_preserves_offset(): + """translate(va + offset) returns pa + offset within a page.""" + mmu = PeMMU(page_size=4096) + mmu.map(va=0x1_0000_0000, pa=0xABCD_0000, size=4096) + assert mmu.translate(0x1_0000_0100) == 0xABCD_0100 + assert mmu.translate(0x1_0000_0FFF) == 0xABCD_0FFF + + +# ── T2. Page-aligned dict lookup (multi-page) ─────────────────────── + + +def test_multi_page_mapping(): + """8 MB mapping with 2 MB pages → 4 page entries, all translate correctly.""" + mmu = PeMMU(page_size=_2MB) + va_base = 0x1_0000_0000 + pa_base = 0x2_0000_0000 + size = 8 * 1024 * 1024 # 8 MB = 4 pages + + mmu.map(va=va_base, pa=pa_base, size=size) + + # First page + assert mmu.translate(va_base) == pa_base + # Second page start + assert mmu.translate(va_base + _2MB) == pa_base + _2MB + # Third page with offset + assert mmu.translate(va_base + 2 * _2MB + 0x100) == pa_base + 2 * _2MB + 0x100 + # Last byte of last page + assert mmu.translate(va_base + size - 1) == pa_base + size - 1 + + +def test_page_table_entry_count(): + """Mapping N bytes with page_size P creates ceil(N/P) entries.""" + mmu = PeMMU(page_size=_2MB) + mmu.map(va=0x1000_0000, pa=0x2000_0000, size=8 * 1024 * 1024) + assert mmu.num_entries == 4 + + mmu.map(va=0x2000_0000, pa=0x3000_0000, size=_2MB) + assert mmu.num_entries == 5 + + +# ── T3. Multiple tensor mappings accumulate ────────────────────────── + + +def test_multiple_mappings_accumulate(): + """Two non-overlapping tensors → both translate correctly.""" + mmu = PeMMU(page_size=4096) + # Tensor A + mmu.map(va=0x1_0000_0000, pa=0xA000_0000, size=4096) + # Tensor B (different VA range) + mmu.map(va=0x1_0001_0000, pa=0xB000_0000, size=4096) + + assert mmu.translate(0x1_0000_0000) == 0xA000_0000 + assert mmu.translate(0x1_0001_0000) == 0xB000_0000 + + +def test_mappings_do_not_interfere(): + """Adjacent VA ranges map to completely independent PA ranges.""" + mmu = PeMMU(page_size=4096) + mmu.map(va=0x0000_0000, pa=0xFFFF_0000, size=4096) + mmu.map(va=0x0000_1000, pa=0x0000_0000, size=4096) + + assert mmu.translate(0x0000_0000) == 0xFFFF_0000 + assert mmu.translate(0x0000_1000) == 0x0000_0000 + + +# ── T4. unmap removes entries ──────────────────────────────────────── + + +def test_unmap_removes_mapping(): + """After unmap, translate raises PageFault.""" + mmu = PeMMU(page_size=4096) + mmu.map(va=0x1_0000_0000, pa=0xABCD_0000, size=4096) + assert mmu.translate(0x1_0000_0000) == 0xABCD_0000 + + mmu.unmap(va=0x1_0000_0000, size=4096) + with pytest.raises(PageFault): + mmu.translate(0x1_0000_0000) + + +def test_unmap_partial_range(): + """Unmap only part of a multi-page mapping; rest stays valid.""" + mmu = PeMMU(page_size=4096) + mmu.map(va=0x1000_0000, pa=0x2000_0000, size=8192) # 2 pages + assert mmu.num_entries == 2 + + # Unmap first page only + mmu.unmap(va=0x1000_0000, size=4096) + assert mmu.num_entries == 1 + + with pytest.raises(PageFault): + mmu.translate(0x1000_0000) + # Second page still valid + assert mmu.translate(0x1000_1000) == 0x2000_1000 + + +def test_unmap_does_not_affect_other_mappings(): + """Unmapping tensor A does not affect tensor B.""" + mmu = PeMMU(page_size=4096) + mmu.map(va=0x1_0000_0000, pa=0xA000_0000, size=4096) + mmu.map(va=0x1_0001_0000, pa=0xB000_0000, size=4096) + + mmu.unmap(va=0x1_0000_0000, size=4096) + with pytest.raises(PageFault): + mmu.translate(0x1_0000_0000) + assert mmu.translate(0x1_0001_0000) == 0xB000_0000 + + +# ── T5. PageFault on unmapped VA ───────────────────────────────────── + + +def test_pagefault_on_unmapped_va(): + """translate() on never-mapped VA raises PageFault.""" + mmu = PeMMU(page_size=4096) + with pytest.raises(PageFault): + mmu.translate(0xDEAD_BEEF) + + +def test_pagefault_contains_va(): + """PageFault exception carries the faulting VA.""" + mmu = PeMMU(page_size=4096) + with pytest.raises(PageFault, match="0xdeadbeef"): + mmu.translate(0xDEAD_BEEF) + + +# ── T6. Identical mapping broadcast across PEs ─────────────────────── + + +def test_broadcast_same_mapping_to_all_pes(): + """All PEs receive the same full mapping → identical translate results.""" + entries = [ + (0x1_0000_0000, 0xA000_0000, 4096), # shard 0 + (0x1_0000_1000, 0xB000_0000, 4096), # shard 1 + (0x1_0000_2000, 0xC000_0000, 4096), # shard 2 + (0x1_0000_3000, 0xD000_0000, 4096), # shard 3 + ] + num_pes = 8 + mmus = [PeMMU(page_size=4096) for _ in range(num_pes)] + + # Broadcast: every PE gets the same entries + for mmu in mmus: + for va, pa, size in entries: + mmu.map(va=va, pa=pa, size=size) + + # All PEs translate identically + for mmu in mmus: + assert mmu.translate(0x1_0000_0000) == 0xA000_0000 + assert mmu.translate(0x1_0000_1000) == 0xB000_0000 + assert mmu.translate(0x1_0000_2000) == 0xC000_0000 + assert mmu.translate(0x1_0000_3000) == 0xD000_0000 + + +def test_cross_pe_access_via_broadcast(): + """PE0 can translate a VA that maps to PE3's HBM PA (cross-PE DMA scenario).""" + mmu_pe0 = PeMMU(page_size=4096) + # Full mapping includes PE3's shard + mmu_pe0.map(va=0x1_0000_0000, pa=0xA000_0000, size=4096) # PE0 shard + mmu_pe0.map(va=0x1_0000_1000, pa=0xD000_0000, size=4096) # PE3 shard + + # PE0 accesses PE3's region → valid translation + pa = mmu_pe0.translate(0x1_0000_1000 + 0x100) + assert pa == 0xD000_0100 + + +# ── TLB overhead attribute ─────────────────────────────────────────── + + +def test_tlb_overhead_default(): + """Default tlb_overhead_ns is 0.""" + mmu = PeMMU() + assert mmu.overhead_ns == 0.0 + + +def test_tlb_overhead_configurable(): + """tlb_overhead_ns is configurable.""" + mmu = PeMMU(overhead_ns=0.5) + assert mmu.overhead_ns == 0.5 diff --git a/tests/test_tensor_free.py b/tests/test_tensor_free.py new file mode 100644 index 0000000..20d9913 --- /dev/null +++ b/tests/test_tensor_free.py @@ -0,0 +1,193 @@ +"""Tests for tensor free: del-based + context manager cleanup. + +Validates: + TF1. PEMemAllocator.free_hbm/free_tcm reclaims space + TF2. del tensor triggers cleanup (VA/PA returned, MMU unmapped) + TF3. Context manager cleans up all tensors on exit + TF4. del after context exit is safe (no crash) + TF5. Alloc-del-alloc cycle reuses VA and PA + TF6. del already-freed tensor is safe (no crash) +""" +import gc + +import pytest +from pathlib import Path + +from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator +from kernbench.policy.address.pe_mmu import PageFault +from kernbench.policy.placement.dp import DPPolicy +from kernbench.sim_engine.engine import GraphEngine +from kernbench.runtime_api.context import RuntimeContext +from kernbench.runtime_api.types import DeviceSelector +from kernbench.topology.builder import load_topology + +TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml" + +_MB = 1 << 20 +_GB = 1 << 30 + +_CFG = AddressConfig( + sip_count=2, + cubes_per_sip=16, + pes_per_cube=8, + hbm_bytes_per_cube=48 * _GB, + hbm_slices_per_cube=8, + tcm_bytes_per_pe=16 * _MB, + tcm_scheduler_reserved_bytes=4 * _MB, + sram_bytes_per_cube=32 * _MB, +) + + +def _make_ctx(): + graph = load_topology(TOPOLOGY_PATH) + engine = GraphEngine(graph) + ctx = RuntimeContext( + engine=engine, + target_device=DeviceSelector("sip:0"), + correlation_id="test_free", + spec=graph.spec, + ) + return ctx, engine + + +# ── TF1. PEMemAllocator free ───────────────────────────────────────── + + +def test_allocator_free_hbm_reclaims_space(): + """free_hbm returns HBM space; subsequent alloc can reuse it.""" + a = PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=0, cfg=_CFG) + pa1 = a.alloc_hbm(4096) + used_after_alloc = a.hbm_used + a.free_hbm(pa1, 4096) + assert a.hbm_used == used_after_alloc - 4096 + pa2 = a.alloc_hbm(4096) + assert pa2 is not None + + +def test_allocator_free_tcm_reclaims_space(): + """free_tcm returns TCM space.""" + a = PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=0, cfg=_CFG) + pa1 = a.alloc_tcm(256) + used_after_alloc = a.tcm_used + a.free_tcm(pa1, 256) + assert a.tcm_used == used_after_alloc - 256 + + +# ── TF2. del tensor triggers cleanup ───────────────────────────────── + + +def test_del_tensor_unmaps_mmu(): + """del tensor removes MMU mappings.""" + ctx, engine = _make_ctx() + dp = DPPolicy(cube="replicate", pe="replicate") + t = ctx.zeros((128, 128), dtype="f16", dp=dp, name="del_test") + va_base = t._handle.va_base + + # Verify mapping exists + mmu_comp = engine._components["sip0.cube0.pe0.pe_mmu"] + assert mmu_comp.mmu.translate(va_base) is not None + + # Delete tensor + del t + gc.collect() + + # Mapping should be gone + with pytest.raises(PageFault): + mmu_comp.mmu.translate(va_base) + + +def test_del_tensor_reclaims_va(): + """del tensor returns VA space for reuse.""" + ctx, engine = _make_ctx() + dp = DPPolicy(cube="replicate", pe="replicate") + + t1 = ctx.zeros((128, 128), dtype="f16", dp=dp, name="va_reuse1") + va1 = t1._handle.va_base + + del t1 + gc.collect() + + t2 = ctx.zeros((128, 128), dtype="f16", dp=dp, name="va_reuse2") + va2 = t2._handle.va_base + assert va2 == va1 + + +# ── TF3. Context manager cleanup ───────────────────────────────────── + + +def test_context_manager_cleans_all(): + """Exiting context manager cleans up all tensors.""" + graph = load_topology(TOPOLOGY_PATH) + engine = GraphEngine(graph) + + with RuntimeContext( + engine=engine, + target_device=DeviceSelector("sip:0"), + correlation_id="ctx_mgr", + spec=graph.spec, + ) as ctx: + dp = DPPolicy(cube="replicate", pe="replicate") + t = ctx.zeros((128, 128), dtype="f16", dp=dp, name="ctx_mgr_test") + va_base = t._handle.va_base + + # After context exit, MMU mappings should be cleared + mmu_comp = engine._components["sip0.cube0.pe0.pe_mmu"] + with pytest.raises(PageFault): + mmu_comp.mmu.translate(va_base) + + +# ── TF4. del after context exit is safe ─────────────────────────────── + + +def test_del_after_context_exit_no_crash(): + """del tensor after context manager exit does not crash.""" + graph = load_topology(TOPOLOGY_PATH) + engine = GraphEngine(graph) + + ctx = RuntimeContext( + engine=engine, + target_device=DeviceSelector("sip:0"), + correlation_id="safe_del", + spec=graph.spec, + ) + dp = DPPolicy(cube="replicate", pe="replicate") + t = ctx.zeros((128, 128), dtype="f16", dp=dp, name="safe_del_test") + + # Simulate context going away + ctx.cleanup() + + # del should not crash even though context already cleaned up + del t + gc.collect() + # No exception = pass + + +# ── TF5. Alloc-del-alloc cycle ──────────────────────────────────────── + + +def test_alloc_del_cycle(): + """Multiple alloc-del cycles work correctly.""" + ctx, engine = _make_ctx() + dp = DPPolicy(cube="replicate", pe="replicate") + + for i in range(3): + t = ctx.zeros((128, 128), dtype="f16", dp=dp, name=f"cycle_{i}") + assert t._handle is not None + assert t._handle.va_base > 0 + del t + gc.collect() + + +# ── TF6. del already-cleaned tensor is safe ────────────────────────── + + +def test_del_already_cleaned_tensor_no_crash(): + """del on a tensor whose handle is already None does not crash.""" + ctx, engine = _make_ctx() + dp = DPPolicy(cube="replicate", pe="replicate") + t = ctx.zeros((128, 128), dtype="f16", dp=dp, name="double_del") + + ctx.cleanup() # clears all tensors + # t._handle is now None + del t # should not crash + gc.collect() diff --git a/tests/test_topology_compile.py b/tests/test_topology_compile.py index 14f6bbd..e3d0223 100644 --- a/tests/test_topology_compile.py +++ b/tests/test_topology_compile.py @@ -17,31 +17,32 @@ def test_full_graph_node_count(): g = _graph() # 1 switch # + 2 SIPs × (1 IO × (3 comps + 4 io_ucie + 16 io_conn) - # + 16 cubes × (cube_comps + 8 PEs × 6 pe_comps)) + # + 16 cubes × (cube_comps + 8 PEs × 7 pe_comps)) # IO: pcie_ep + io_cpu + io_noc + 4 io_ucie + 4*4 io_conn = 23 # cube_comps: 9 (noc, m_cpu, sram, 2 bridge, 4 ucie) # + 16 ucie_conn (4 ports × 4 connections) # + 2 xbar_top/bot # + 8 hbm_slices = 35 - # = 1 + 2*(23 + 16*(35+48)) = 1 + 2*(23+1328) = 1 + 2702 = 2703 - assert len(g.nodes) == 2703 + # pe_comps: 7 (pe_cpu, pe_scheduler, pe_dma, pe_gemm, pe_math, pe_mmu, pe_tcm) + # = 1 + 2*(23 + 16*(35+56)) = 1 + 2*(23+1456) = 1 + 2958 = 2959 + assert len(g.nodes) == 2959 def test_full_graph_edge_count(): g = _graph() - # Per cube: 184 + # Per cube: 192 # PE-internal: 56 - # PE_DMA→noc: 8, noc→pe_dma: 8, noc→pe_cpu: 8, pe_cpu→noc: 8 + # PE_DMA→noc: 8, noc→pe_dma: 8, noc→pe_cpu: 8, pe_cpu→noc: 8, noc→pe_mmu: 8 # xbar_top→hbm{0..3}: 4+4=8, xbar_bot→hbm{4..7}: 4+4=8 # noc↔xbar_top: 2, noc↔xbar_bot: 2 # xbar_top↔bridge.left: 2, bridge.left↔xbar_bot: 2 # xbar_top↔bridge.right: 2, bridge.right↔xbar_bot: 2 # ucie: 64, m_cpu↔noc: 2, noc↔sram: 2 - # Total: 56+8+8+8+8+8+8+2+2+2+2+2+2+64+2+2 = 184 + # Total: 56+8+8+8+8+8+8+8+2+2+2+2+2+2+64+2+2 = 192 # IO edges per SIP: 77 - # Per SIP: 16*184 + 48 inter-cube + 77 IO = 3069 - # Total: 2 * 3069 = 6138 - assert len(g.edges) == 6138 + # Per SIP: 16*192 + 48 inter-cube + 77 IO = 3197 + # Total: 2 * 3197 = 6394 + assert len(g.edges) == 6394 # ── Full graph: specific nodes exist ───────────────────────────────── @@ -267,7 +268,7 @@ def test_cube_view_pe_to_noc(): def test_pe_view_has_all_components(): v = _graph().pe_view assert set(v.nodes.keys()) == { - "pe_cpu", "pe_scheduler", "pe_dma", "pe_gemm", "pe_math", "pe_tcm" + "pe_cpu", "pe_scheduler", "pe_dma", "pe_gemm", "pe_math", "pe_mmu", "pe_tcm" } diff --git a/tests/test_topology_load.py b/tests/test_topology_load.py index a30edf5..e16db62 100644 --- a/tests/test_topology_load.py +++ b/tests/test_topology_load.py @@ -23,7 +23,7 @@ def test_pe_template_components(): spec = _read_spec(TOPOLOGY_PATH) comps = spec["cube"]["pe_template"]["components"] assert set(comps.keys()) == { - "pe_cpu", "pe_scheduler", "pe_dma", "pe_gemm", "pe_math", "pe_tcm" + "pe_cpu", "pe_scheduler", "pe_dma", "pe_gemm", "pe_math", "pe_mmu", "pe_tcm" } diff --git a/tests/test_triton_emu.py b/tests/test_triton_emu.py index 036fc06..e144c80 100644 --- a/tests/test_triton_emu.py +++ b/tests/test_triton_emu.py @@ -34,7 +34,7 @@ def test_tl_load_generates_dma_read(): cmds = tl.commands assert len(cmds) == 1 assert isinstance(cmds[0], DmaReadCmd) - assert cmds[0].src_pa == 0x1000 + assert cmds[0].src_addr == 0x1000 assert cmds[0].nbytes == 32 * 64 * 2 @@ -47,7 +47,7 @@ def test_tl_store_generates_dma_write(): tl.store(0x2000, h) cmds = [c for c in tl.commands if isinstance(c, DmaWriteCmd)] assert len(cmds) == 1 - assert cmds[0].dst_pa == 0x2000 + assert cmds[0].dst_addr == 0x2000 assert cmds[0].nbytes == 16 * 16 * 4 @@ -148,7 +148,7 @@ def test_tl_composite_nonblocking(): comp_cmds = [c for c in tl.commands if isinstance(c, CompositeCmd)] assert len(comp_cmds) == 1 assert comp_cmds[0].op == "gemm" - assert comp_cmds[0].out_pa == 0x3000 + assert comp_cmds[0].out_addr == 0x3000 assert comp_cmds[0].out_nbytes == 32 * 32 * 2 # M×N×dtype_bytes diff --git a/tests/test_va_allocator.py b/tests/test_va_allocator.py new file mode 100644 index 0000000..4461364 --- /dev/null +++ b/tests/test_va_allocator.py @@ -0,0 +1,140 @@ +"""Tests for VirtualAllocator: device-wide VA space management. + +Validates: + T7. Basic VA allocation (contiguous, non-overlapping) + T8. VA free + reallocation (free-list reuse) + T9. VA space exhaustion raises AllocationError +""" +import pytest + +from kernbench.policy.address.va_allocator import VirtualAllocator + +_KB = 1024 +_MB = 1024 * 1024 +_GB = 1024 * 1024 * 1024 + + +# ── T7. Basic VA allocation ────────────────────────────────────────── + + +def test_alloc_returns_aligned_va(): + """First allocation returns va_base.""" + va = VirtualAllocator(va_base=0x1_0000_0000, va_size=1 * _GB, page_size=4096) + addr = va.alloc(4096) + assert addr == 0x1_0000_0000 + + +def test_alloc_sequential_non_overlapping(): + """Two allocations return contiguous, non-overlapping VA ranges.""" + va = VirtualAllocator(va_base=0x1_0000_0000, va_size=1 * _GB, page_size=4096) + a1 = va.alloc(4096) + a2 = va.alloc(8192) + assert a1 == 0x1_0000_0000 + assert a2 == 0x1_0000_1000 # a1 + 4096 + # No overlap + assert a2 >= a1 + 4096 + + +def test_alloc_page_aligned(): + """Allocations are page-aligned even if requested size is not page-multiple.""" + va = VirtualAllocator(va_base=0x1_0000_0000, va_size=1 * _GB, page_size=4096) + a1 = va.alloc(100) # < 1 page, but occupies 1 page + a2 = va.alloc(100) + assert a2 == 0x1_0000_1000 # aligned to next page + + +def test_alloc_large_contiguous(): + """Large allocation (multiple pages) is contiguous.""" + va = VirtualAllocator(va_base=0x0, va_size=1 * _GB, page_size=2 * _MB) + addr = va.alloc(8 * _MB) # 4 pages + assert addr == 0x0 + # Next alloc starts after 8 MB + addr2 = va.alloc(2 * _MB) + assert addr2 == 8 * _MB + + +# ── T8. VA free + reallocation ─────────────────────────────────────── + + +def test_free_and_realloc(): + """Freed VA range can be reused by subsequent allocation.""" + va = VirtualAllocator(va_base=0x1_0000_0000, va_size=1 * _GB, page_size=4096) + a1 = va.alloc(4096) + a2 = va.alloc(4096) + va.free(a1, 4096) + + # New alloc should reuse a1's range + a3 = va.alloc(4096) + assert a3 == a1 + + +def test_free_coalesce(): + """Freeing adjacent blocks allows larger reallocation.""" + va = VirtualAllocator(va_base=0x0, va_size=1 * _GB, page_size=4096) + a1 = va.alloc(4096) + a2 = va.alloc(4096) + a3 = va.alloc(4096) + + # Free first two (adjacent) + va.free(a1, 4096) + va.free(a2, 4096) + + # Should be able to allocate 8192 contiguous from the freed region + a4 = va.alloc(8192) + assert a4 == a1 # reuses coalesced region + + +def test_free_out_of_order(): + """Free in non-sequential order still works.""" + va = VirtualAllocator(va_base=0x0, va_size=1 * _GB, page_size=4096) + a1 = va.alloc(4096) + a2 = va.alloc(4096) + a3 = va.alloc(4096) + + va.free(a2, 4096) # free middle + a4 = va.alloc(4096) + assert a4 == a2 # reuses middle slot + + +# ── T9. VA space exhaustion ────────────────────────────────────────── + + +def test_alloc_exhaustion(): + """Allocation beyond VA space raises AllocationError.""" + va = VirtualAllocator(va_base=0x0, va_size=8192, page_size=4096) + va.alloc(4096) + va.alloc(4096) + with pytest.raises(Exception, match="[Aa]lloc|[Ee]xhaust|[Oo]ut of"): + va.alloc(4096) + + +def test_alloc_after_partial_free(): + """After freeing some, can allocate again within freed space.""" + va = VirtualAllocator(va_base=0x0, va_size=8192, page_size=4096) + a1 = va.alloc(4096) + a2 = va.alloc(4096) + + # Space is full + with pytest.raises(Exception): + va.alloc(4096) + + # Free one, now can allocate again + va.free(a1, 4096) + a3 = va.alloc(4096) + assert a3 == a1 + + +# ── Stats / inspection ─────────────────────────────────────────────── + + +def test_used_and_total(): + """used and total properties reflect allocation state.""" + va = VirtualAllocator(va_base=0x0, va_size=1 * _MB, page_size=4096) + assert va.used == 0 + assert va.total == 1 * _MB + + va.alloc(4096) + assert va.used == 4096 + + va.alloc(8192) + assert va.used == 4096 + 8192 # page-aligned: 4096 + 8192 = 12288 diff --git a/tests/test_va_integration.py b/tests/test_va_integration.py new file mode 100644 index 0000000..a173cbb --- /dev/null +++ b/tests/test_va_integration.py @@ -0,0 +1,230 @@ +"""Tests for VA integration: Tensor, TLContext, and DMA commands use VA. + +Validates: + T10. TensorHandle has va_base; TensorShard does NOT have va field + T11. deploy_tensor allocates VA + creates mapping entries + T12. Tensor.va returns the tensor's VA base + T13. tl.load/tl.store generate DMA commands with VA (not PA) + T14. Kernel VA-based offset calculation flows through DMA commands +""" +import pytest + +from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator +from kernbench.policy.address.pe_mmu import PeMMU +from kernbench.policy.address.va_allocator import VirtualAllocator +from kernbench.policy.placement.dp import column_wise, ShardSpec +from kernbench.runtime_api.tensor import ( + TensorHandle, + TensorShard, + deploy_tensor, +) +from kernbench.runtime_api.kernel import TensorArgShard +from kernbench.common.pe_commands import DmaReadCmd, DmaWriteCmd +from kernbench.triton_emu.tl_context import TLContext, run_kernel + +_MB = 1 << 20 +_GB = 1 << 30 + +_CFG = AddressConfig( + sip_count=2, + cubes_per_sip=16, + pes_per_cube=8, + hbm_bytes_per_cube=48 * _GB, + hbm_slices_per_cube=8, + tcm_bytes_per_pe=16 * _MB, + tcm_scheduler_reserved_bytes=4 * _MB, + sram_bytes_per_cube=32 * _MB, +) + + +def _make_allocators(num_pe: int = 8) -> dict[int, PEMemAllocator]: + return { + i: PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=i, cfg=_CFG) + for i in range(num_pe) + } + + +def _make_mmus(num_pe: int = 8, page_size: int = 4096) -> dict[int, PeMMU]: + return {i: PeMMU(page_size=page_size) for i in range(num_pe)} + + +def _make_va_allocator() -> VirtualAllocator: + return VirtualAllocator(va_base=0x1_0000_0000, va_size=64 * _GB, page_size=2 * _MB) + + +# ── T10. TensorHandle has va_base ──────────────────────────────────── + + +def test_tensor_handle_has_va_base(): + """TensorHandle must have a 'va_base' field.""" + th = TensorHandle( + name="A", shape=(1024, 512), dtype="fp16", itemsize=2, + shards=(), va_base=0x1_0000_0000, + ) + assert th.va_base == 0x1_0000_0000 + + +def test_tensor_handle_va_base_immutable(): + """TensorHandle.va_base is immutable (frozen dataclass).""" + th = TensorHandle( + name="A", shape=(1024, 512), dtype="fp16", itemsize=2, + shards=(), va_base=0x1_0000_0000, + ) + with pytest.raises(AttributeError): + th.va_base = 0x2_0000_0000 # type: ignore[misc] + + +def test_tensor_shard_no_va_field(): + """TensorShard should NOT have a va field — va is derived from + TensorHandle.va_base + shard.offset_bytes.""" + ts = TensorShard(sip=0, cube=0, pe=0, pa=0x1000, nbytes=4096, offset_bytes=0) + assert not hasattr(ts, "va"), "TensorShard should not have a 'va' field" + + +# ── T11. deploy_tensor allocates VA + creates mappings ─────────────── + + +def test_deploy_tensor_assigns_va_base(): + """deploy_tensor with VA allocator assigns va_base to TensorHandle.""" + allocs = _make_allocators() + va_alloc = _make_va_allocator() + mmus = _make_mmus() + placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8) + + th = deploy_tensor( + name="W", + shape=(1024, 512), + dtype="fp16", + placement=placement, + allocators=allocs, + va_allocator=va_alloc, + mmus=mmus, + ) + + assert th.va_base is not None + assert th.va_base > 0 + + +def test_deploy_tensor_va_covers_all_shards(): + """VA allocation covers the entire tensor; each shard is at va_base + offset.""" + allocs = _make_allocators() + va_alloc = _make_va_allocator() + mmus = _make_mmus() + placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8) + + th = deploy_tensor( + name="W", + shape=(1024, 512), + dtype="fp16", + placement=placement, + allocators=allocs, + va_allocator=va_alloc, + mmus=mmus, + ) + + # Each shard's VA is derivable: va_base + offset_bytes + for s in th.shards: + shard_va = th.va_base + s.offset_bytes + assert shard_va > 0 + + +def test_deploy_tensor_registers_mmu_mappings(): + """deploy_tensor registers VA→PA mappings in all PE MMUs.""" + allocs = _make_allocators() + va_alloc = _make_va_allocator() + mmus = _make_mmus() + placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8) + + th = deploy_tensor( + name="W", + shape=(1024, 512), + dtype="fp16", + placement=placement, + allocators=allocs, + va_allocator=va_alloc, + mmus=mmus, + ) + + # Every MMU should have entries (broadcast) + for mmu in mmus.values(): + assert mmu.num_entries > 0 + + # Each shard's derived VA should translate to its PA in every MMU + for mmu in mmus.values(): + for s in th.shards: + shard_va = th.va_base + s.offset_bytes + assert mmu.translate(shard_va) == s.pa + + +# ── T12. Tensor.va property ────────────────────────────────────────── + + +def test_tensor_va_property(): + """Tensor.va returns the VA base of the entire tensor (from TensorHandle.va_base).""" + from kernbench.runtime_api.tensor import Tensor + + allocs = _make_allocators(1) + va_alloc = _make_va_allocator() + mmus = _make_mmus(1) + placement = [ShardSpec(pe_index=0, offset_bytes=0, nbytes=4096)] + + t = Tensor(shape=(2048,), dtype="f16", name="test") + t._handle = deploy_tensor( + name="test", + shape=(2048,), + dtype="f16", + placement=placement, + allocators=allocs, + va_allocator=va_alloc, + mmus=mmus, + ) + assert t.va > 0 + assert t.va == t._handle.va_base + + +# ── T13. tl.load/tl.store use VA ───────────────────────────────────── + + +def test_tl_load_uses_va_in_dma_cmd(): + """tl.load(va_ptr) generates DmaReadCmd with src_va (not src_pa).""" + tl = TLContext(dispatch_cycles=0) + va_ptr = 0x1_0000_0000 + h = tl.load(va_ptr, shape=(32, 64), dtype="f16") + + dma_cmds = [c for c in tl.commands if isinstance(c, DmaReadCmd)] + assert len(dma_cmds) == 1 + # The DMA command should carry the VA + assert dma_cmds[0].src_addr == va_ptr + + +def test_tl_store_uses_va_in_dma_cmd(): + """tl.store(va_ptr, handle) generates DmaWriteCmd with dst_va.""" + tl = TLContext(dispatch_cycles=0) + h = tl.load(0x1_0000_0000, shape=(16, 16), dtype="f32") + va_out = 0x2_0000_0000 + tl.store(va_out, h) + + dma_cmds = [c for c in tl.commands if isinstance(c, DmaWriteCmd)] + assert len(dma_cmds) == 1 + assert dma_cmds[0].dst_addr == va_out + + +# ── T14. Kernel VA offset calculation ───────────────────────────────── + + +def test_kernel_va_offset_in_dma(): + """Kernel using base_va + pid * stride generates correct VA in DmaReadCmd.""" + def tiled_kernel(a_ptr, tl, BLOCK_SIZE=1024, DTYPE="f16"): + pid = tl.program_id(0) + elem_bytes = 2 # f16 + offset = pid * BLOCK_SIZE * elem_bytes + a = tl.load(a_ptr + offset, shape=(BLOCK_SIZE,), dtype=DTYPE) + + va_base = 0x1_0000_0000 + tl = TLContext(pe_id=3, num_programs=8, dispatch_cycles=0) + run_kernel(tiled_kernel, tl, a_ptr=va_base) + + dma_cmds = [c for c in tl.commands if isinstance(c, DmaReadCmd)] + assert len(dma_cmds) == 1 + expected_va = va_base + 3 * 1024 * 2 # pid=3, BLOCK_SIZE=1024, 2 bytes + assert dma_cmds[0].src_addr == expected_va diff --git a/topology.yaml b/topology.yaml index 4d5e7b5..9fce8f9 100644 --- a/topology.yaml +++ b/topology.yaml @@ -65,6 +65,7 @@ cube: pe_dma: { kind: pe_dma, impl: pe_dma_v1, attrs: { rd_engines: 1, wr_engines: 1 } } pe_gemm: { kind: pe_gemm, impl: pe_gemm_v1, attrs: { overhead_ns: 0.0, shared_resource: accel_slot, peak_tflops_f16: 8.0 } } pe_math: { kind: pe_math, impl: pe_math_v1, attrs: { overhead_ns: 0.0, shared_resource: accel_slot } } + pe_mmu: { kind: pe_mmu, impl: pe_mmu_v1, attrs: { tlb_overhead_ns: 0.5, page_size: 4096 } } pe_tcm: { kind: pe_tcm, impl: pe_tcm_v1, attrs: { size_mb: 16 } } links: