Files
ywkang 08812eda58 Add virtual memory support: PE_MMU, VA allocator, fabric MmuMapMsg
Implement VA/MMU layer (ADR-0011 Phase 1) enabling Triton kernels to use
contiguous virtual addresses on sharded tensors.

Key changes:
- PE_MMU component: hybrid inbox (MmuMapMsg) + sync translate() for PE_DMA
- VirtualAllocator + PEMemAllocator: free-list with coalescing
- MmuMapMsg/MmuUnmapMsg fabric path with SIP-level routing
- DPPolicy-based mapping: replicate=local, sharded=broadcast
- Tensor lifecycle: del + weakref cleanup, context manager
- Rename: TensorHandle.pa→addr, DmaReadCmd.src_pa→src_addr, ctx→torch

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 00:01:47 -07:00

204 lines
6.9 KiB
Python

"""Tests for PeMMU: per-PE virtual-to-physical address translation.
Validates:
T1. Basic map + translate
T2. Page-aligned dict lookup (O(1), multi-page range)
T3. Multiple tensor mappings accumulate
T4. unmap removes entries, translate raises PageFault
T5. PageFault on unmapped VA
T6. Identical mapping broadcast across multiple PEs
T7. Translation overhead attribute (overhead_ns): default value and constructor override
"""
import pytest
from kernbench.policy.address.pe_mmu import PageFault, PeMMU
_2MB = 2 * 1024 * 1024
# ── T1. Basic map + translate ────────────────────────────────────────
def test_map_and_translate_basic():
    """A single mapped 4 KiB page translates its base VA to the mapped base PA."""
    virt_base = 0x1_0000_0000
    phys_base = 0xABCD_0000

    unit = PeMMU(page_size=4096)
    unit.map(va=virt_base, pa=phys_base, size=4096)

    translated = unit.translate(virt_base)
    assert translated == phys_base
def test_translate_preserves_offset():
    """A byte offset inside a mapped page carries over unchanged from VA to PA."""
    virt_base, phys_base = 0x1_0000_0000, 0xABCD_0000

    unit = PeMMU(page_size=4096)
    unit.map(va=virt_base, pa=phys_base, size=4096)

    # 0x100 is an interior offset; 0xFFF is the last byte of the 4 KiB page.
    for delta in (0x100, 0xFFF):
        assert unit.translate(virt_base + delta) == phys_base + delta
# ── T2. Page-aligned dict lookup (multi-page) ───────────────────────
def test_multi_page_mapping():
    """An 8 MB range mapped with 2 MB pages translates correctly on each probed page."""
    page = _2MB
    span = 4 * page  # 8 MB, i.e. four 2 MB pages
    virt_start = 0x1_0000_0000
    phys_start = 0x2_0000_0000

    unit = PeMMU(page_size=page)
    unit.map(va=virt_start, pa=phys_start, size=span)

    probes = (
        0,                  # first byte of the first page
        page,               # first byte of the second page
        2 * page + 0x100,   # interior offset within the third page
        span - 1,           # last byte of the last page
    )
    for delta in probes:
        assert unit.translate(virt_start + delta) == phys_start + delta
def test_page_table_entry_count():
    """num_entries is 4 after mapping 8 MB of 2 MB pages, and 5 after one more page."""
    unit = PeMMU(page_size=_2MB)

    # Each row: (va, pa, size, expected cumulative num_entries after the map).
    steps = (
        (0x1000_0000, 0x2000_0000, 8 * 1024 * 1024, 4),
        (0x2000_0000, 0x3000_0000, _2MB, 5),
    )
    for virt, phys, nbytes, expected_total in steps:
        unit.map(va=virt, pa=phys, size=nbytes)
        assert unit.num_entries == expected_total
# ── T3. Multiple tensor mappings accumulate ──────────────────────────
def test_multiple_mappings_accumulate():
    """Mapping a second, disjoint VA range leaves the first one translatable."""
    # VA -> PA for two tensors living in different VA ranges.
    tensors = {
        0x1_0000_0000: 0xA000_0000,  # tensor A
        0x1_0001_0000: 0xB000_0000,  # tensor B
    }

    unit = PeMMU(page_size=4096)
    for virt, phys in tensors.items():
        unit.map(va=virt, pa=phys, size=4096)

    # Check only after both maps are installed, so accumulation is what is tested.
    for virt, phys in tensors.items():
        assert unit.translate(virt) == phys
def test_mappings_do_not_interfere():
    """Two back-to-back VA pages each resolve to their own, unrelated PA."""
    # (va, pa) pairs: the VAs are adjacent 4 KiB pages, the PAs are far apart.
    neighbours = [
        (0x0000_0000, 0xFFFF_0000),
        (0x0000_1000, 0x0000_0000),
    ]

    unit = PeMMU(page_size=4096)
    for virt, phys in neighbours:
        unit.map(va=virt, pa=phys, size=4096)

    for virt, phys in neighbours:
        assert unit.translate(virt) == phys
# ── T4. unmap removes entries ────────────────────────────────────────
def test_unmap_removes_mapping():
    """A VA that translated before unmap() raises PageFault afterwards."""
    virt = 0x1_0000_0000
    phys = 0xABCD_0000

    unit = PeMMU(page_size=4096)
    unit.map(va=virt, pa=phys, size=4096)
    # Confirm the mapping is live first, so the later fault is due to unmap().
    assert unit.translate(virt) == phys

    unit.unmap(va=virt, size=4096)
    with pytest.raises(PageFault):
        unit.translate(virt)
def test_unmap_partial_range():
    """Unmapping the first of two pages faults that page only; the second still works."""
    page = 4096
    virt_start, phys_start = 0x1000_0000, 0x2000_0000

    unit = PeMMU(page_size=page)
    unit.map(va=virt_start, pa=phys_start, size=2 * page)  # two pages
    assert unit.num_entries == 2

    # Drop just the leading page.
    unit.unmap(va=virt_start, size=page)
    assert unit.num_entries == 1

    with pytest.raises(PageFault):
        unit.translate(virt_start)

    # The trailing page must be untouched.
    assert unit.translate(virt_start + page) == phys_start + page
def test_unmap_does_not_affect_other_mappings():
    """Removing one tensor's mapping leaves a different tensor's mapping intact."""
    removed_va, removed_pa = 0x1_0000_0000, 0xA000_0000  # tensor A
    kept_va, kept_pa = 0x1_0001_0000, 0xB000_0000        # tensor B

    unit = PeMMU(page_size=4096)
    unit.map(va=removed_va, pa=removed_pa, size=4096)
    unit.map(va=kept_va, pa=kept_pa, size=4096)

    unit.unmap(va=removed_va, size=4096)

    with pytest.raises(PageFault):
        unit.translate(removed_va)
    assert unit.translate(kept_va) == kept_pa
# ── T5. PageFault on unmapped VA ─────────────────────────────────────
def test_pagefault_on_unmapped_va():
    """translate() on an MMU with no mappings at all raises PageFault."""
    never_mapped = 0xDEAD_BEEF
    empty_unit = PeMMU(page_size=4096)

    with pytest.raises(PageFault):
        empty_unit.translate(never_mapped)
def test_pagefault_contains_va():
    """The PageFault text includes the faulting VA formatted as lowercase hex."""
    faulting_va = 0xDEAD_BEEF
    empty_unit = PeMMU(page_size=4096)

    # `match` is searched as a regex against the string form of the exception.
    with pytest.raises(PageFault, match="0xdeadbeef"):
        empty_unit.translate(faulting_va)
# ── T6. Identical mapping broadcast across PEs ───────────────────────
def test_broadcast_same_mapping_to_all_pes():
    """Eight MMUs loaded with the same four shard mappings all translate identically."""
    # (va, pa, size) per shard: contiguous 4 KiB VA pages, scattered PAs.
    shard_table = (
        (0x1_0000_0000, 0xA000_0000, 4096),  # shard 0
        (0x1_0000_1000, 0xB000_0000, 4096),  # shard 1
        (0x1_0000_2000, 0xC000_0000, 4096),  # shard 2
        (0x1_0000_3000, 0xD000_0000, 4096),  # shard 3
    )
    pe_count = 8

    # Build every MMU up front, then install the full table into each one.
    units = []
    for _ in range(pe_count):
        units.append(PeMMU(page_size=4096))

    for unit in units:
        for virt, phys, nbytes in shard_table:
            unit.map(va=virt, pa=phys, size=nbytes)

    # Every MMU must resolve every shard's base VA to that shard's PA.
    for unit in units:
        for virt, phys, _nbytes in shard_table:
            assert unit.translate(virt) == phys
def test_cross_pe_access_via_broadcast():
    """An MMU holding two shard mappings translates an offset VA in the second shard.

    Models the cross-PE DMA scenario: PE0's MMU carries the mapping for the
    shard labelled as PE3's, so PE0 can resolve a VA inside that shard.
    """
    shard_pe0 = (0x1_0000_0000, 0xA000_0000)
    shard_pe3 = (0x1_0000_1000, 0xD000_0000)

    pe0_unit = PeMMU(page_size=4096)
    for virt, phys in (shard_pe0, shard_pe3):
        pe0_unit.map(va=virt, pa=phys, size=4096)

    # Probe 0x100 bytes into the shard labelled PE3.
    remote_virt, remote_phys = shard_pe3
    resolved = pe0_unit.translate(remote_virt + 0x100)
    assert resolved == remote_phys + 0x100
# ── TLB overhead attribute ───────────────────────────────────────────
def test_tlb_overhead_default():
    """A PeMMU constructed with no arguments reports overhead_ns equal to 0.0."""
    default_unit = PeMMU()
    observed = default_unit.overhead_ns
    assert observed == 0.0
def test_tlb_overhead_configurable():
    """An overhead_ns value passed to the constructor is read back unchanged."""
    requested_ns = 0.5
    tuned_unit = PeMMU(overhead_ns=requested_ns)
    assert tuned_unit.overhead_ns == requested_ns