08812eda58
Implement VA/MMU layer (ADR-0011 Phase 1) enabling Triton kernels to use contiguous virtual addresses on sharded tensors. Key changes: - PE_MMU component: hybrid inbox (MmuMapMsg) + sync translate() for PE_DMA - VirtualAllocator + PEMemAllocator: free-list with coalescing - MmuMapMsg/MmuUnmapMsg fabric path with SIP-level routing - DPPolicy-based mapping: replicate=local, sharded=broadcast - Tensor lifecycle: del + weakref cleanup, context manager - Rename: TensorHandle.pa→addr, DmaReadCmd.src_pa→src_addr, ctx→torch Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
204 lines
6.9 KiB
Python
204 lines
6.9 KiB
Python
"""Tests for PeMMU: per-PE virtual-to-physical address translation.
|
|
|
|
Validates:
  T1. Basic map + translate
  T2. Page-aligned dict lookup (O(1), multi-page range)
  T3. Multiple tensor mappings accumulate
  T4. unmap removes entries; translate then raises PageFault
  T5. PageFault on unmapped VA
  T6. Identical mapping broadcast across multiple PEs
|
|
"""
|
|
import pytest
|
|
|
|
from kernbench.policy.address.pe_mmu import PageFault, PeMMU
|
|
|
|
# Byte count of 2 MiB; passed as ``page_size`` by the multi-page tests below.
_2MB = 2 * 1024 * 1024
|
|
|
|
|
|
# ── T1. Basic map + translate ────────────────────────────────────────
|
|
|
|
|
|
def test_map_and_translate_basic():
    """A single mapped page translates its base VA to the mapped base PA."""
    va, pa = 0x1_0000_0000, 0xABCD_0000
    unit = PeMMU(page_size=4096)
    unit.map(va=va, pa=pa, size=4096)
    assert unit.translate(va) == pa
|
|
|
|
|
|
def test_translate_preserves_offset():
    """The in-page byte offset of a VA carries over to the translated PA."""
    va_base, pa_base = 0x1_0000_0000, 0xABCD_0000
    unit = PeMMU(page_size=4096)
    unit.map(va=va_base, pa=pa_base, size=4096)

    # Probe an interior offset and the very last byte of the 4096-byte page.
    for offset in (0x100, 0xFFF):
        assert unit.translate(va_base + offset) == pa_base + offset
|
|
|
|
|
|
# ── T2. Page-aligned dict lookup (multi-page) ───────────────────────
|
|
|
|
|
|
def test_multi_page_mapping():
    """An 8 MB range mapped with 2 MB pages translates correctly on each page."""
    va_base, pa_base = 0x1_0000_0000, 0x2_0000_0000
    size = 4 * _2MB  # 8 MB = 4 pages

    unit = PeMMU(page_size=_2MB)
    unit.map(va=va_base, pa=pa_base, size=size)

    # Offsets from the base of the range; VA and PA must move in lockstep.
    probes = (
        0,  # first page
        _2MB,  # second page start
        2 * _2MB + 0x100,  # third page with offset
        size - 1,  # last byte of last page
    )
    for offset in probes:
        assert unit.translate(va_base + offset) == pa_base + offset
|
|
|
|
|
|
def test_page_table_entry_count():
    """num_entries counts pages: 8 MB at 2 MB pages gives 4, one more page gives 5."""
    unit = PeMMU(page_size=_2MB)

    # Four whole pages.
    unit.map(va=0x1000_0000, pa=0x2000_0000, size=4 * _2MB)
    assert unit.num_entries == 4

    # A second, single-page mapping adds exactly one entry.
    unit.map(va=0x2000_0000, pa=0x3000_0000, size=_2MB)
    assert unit.num_entries == 5
|
|
|
|
|
|
# ── T3. Multiple tensor mappings accumulate ──────────────────────────
|
|
|
|
|
|
def test_multiple_mappings_accumulate():
    """Two tensors mapped at distinct VA ranges both remain translatable."""
    tensors = [
        (0x1_0000_0000, 0xA000_0000),  # tensor A
        (0x1_0001_0000, 0xB000_0000),  # tensor B (different VA range)
    ]
    unit = PeMMU(page_size=4096)
    for va, pa in tensors:
        unit.map(va=va, pa=pa, size=4096)

    # Check only after both mappings are installed.
    for va, pa in tensors:
        assert unit.translate(va) == pa
|
|
|
|
|
|
def test_mappings_do_not_interfere():
    """Back-to-back VA pages may point at unrelated PA pages."""
    adjacent_pages = [
        (0x0000_0000, 0xFFFF_0000),
        (0x0000_1000, 0x0000_0000),
    ]
    unit = PeMMU(page_size=4096)
    for va, pa in adjacent_pages:
        unit.map(va=va, pa=pa, size=4096)

    for va, pa in adjacent_pages:
        assert unit.translate(va) == pa
|
|
|
|
|
|
# ── T4. unmap removes entries ────────────────────────────────────────
|
|
|
|
|
|
def test_unmap_removes_mapping():
    """translate() works while mapped and raises PageFault once unmapped."""
    va, pa = 0x1_0000_0000, 0xABCD_0000
    unit = PeMMU(page_size=4096)
    unit.map(va=va, pa=pa, size=4096)
    assert unit.translate(va) == pa

    unit.unmap(va=va, size=4096)
    with pytest.raises(PageFault):
        unit.translate(va)
|
|
|
|
|
|
def test_unmap_partial_range():
    """Unmapping the first of two pages leaves the second page translatable."""
    page = 4096
    va_base, pa_base = 0x1000_0000, 0x2000_0000

    unit = PeMMU(page_size=page)
    unit.map(va=va_base, pa=pa_base, size=2 * page)  # 2 pages
    assert unit.num_entries == 2

    # Drop only the first page.
    unit.unmap(va=va_base, size=page)
    assert unit.num_entries == 1

    with pytest.raises(PageFault):
        unit.translate(va_base)
    # The second page is untouched.
    assert unit.translate(va_base + page) == pa_base + page
|
|
|
|
|
|
def test_unmap_does_not_affect_other_mappings():
    """Removing tensor A's mapping leaves tensor B's mapping intact."""
    va_a, pa_a = 0x1_0000_0000, 0xA000_0000
    va_b, pa_b = 0x1_0001_0000, 0xB000_0000

    unit = PeMMU(page_size=4096)
    unit.map(va=va_a, pa=pa_a, size=4096)
    unit.map(va=va_b, pa=pa_b, size=4096)

    unit.unmap(va=va_a, size=4096)

    with pytest.raises(PageFault):
        unit.translate(va_a)
    assert unit.translate(va_b) == pa_b
|
|
|
|
|
|
# ── T5. PageFault on unmapped VA ─────────────────────────────────────
|
|
|
|
|
|
def test_pagefault_on_unmapped_va():
    """A freshly built PeMMU with no mappings raises PageFault on translate()."""
    never_mapped_va = 0xDEAD_BEEF
    unit = PeMMU(page_size=4096)
    with pytest.raises(PageFault):
        unit.translate(never_mapped_va)
|
|
|
|
|
|
def test_pagefault_contains_va():
    """The PageFault message includes the faulting VA as lowercase hex."""
    faulting_va = 0xDEAD_BEEF
    unit = PeMMU(page_size=4096)
    # ``match`` is searched as a regex against the exception's string form.
    with pytest.raises(PageFault, match="0xdeadbeef"):
        unit.translate(faulting_va)
|
|
|
|
|
|
# ── T6. Identical mapping broadcast across PEs ───────────────────────
|
|
|
|
|
|
def test_broadcast_same_mapping_to_all_pes():
    """Eight PeMMUs given the same four mappings all translate identically."""
    # (va, pa, size) per shard; the VA pages are contiguous, the PAs are not.
    shard_map = [
        (0x1_0000_0000, 0xA000_0000, 4096),  # shard 0
        (0x1_0000_1000, 0xB000_0000, 4096),  # shard 1
        (0x1_0000_2000, 0xC000_0000, 4096),  # shard 2
        (0x1_0000_3000, 0xD000_0000, 4096),  # shard 3
    ]
    pe_count = 8
    units = [PeMMU(page_size=4096) for _ in range(pe_count)]

    # Phase 1 - broadcast: install the full table into every PE's MMU.
    for unit in units:
        for va, pa, size in shard_map:
            unit.map(va=va, pa=pa, size=size)

    # Phase 2 - every PE must resolve each shard's base VA to the same PA.
    for unit in units:
        for va, pa, _size in shard_map:
            assert unit.translate(va) == pa
|
|
|
|
|
|
def test_cross_pe_access_via_broadcast():
    """PE0's MMU resolves a VA inside the shard labelled as PE3's (cross-PE DMA scenario)."""
    pe3_va, pe3_pa = 0x1_0000_1000, 0xD000_0000
    offset = 0x100

    pe0_unit = PeMMU(page_size=4096)
    # PE0 holds the full table, including the entry for PE3's shard.
    pe0_unit.map(va=0x1_0000_0000, pa=0xA000_0000, size=4096)  # PE0 shard
    pe0_unit.map(va=pe3_va, pa=pe3_pa, size=4096)  # PE3 shard

    # A VA inside PE3's page resolves to PE3's PA plus the same offset.
    translated = pe0_unit.translate(pe3_va + offset)
    assert translated == pe3_pa + offset
|
|
|
|
|
|
# ── TLB overhead attribute ───────────────────────────────────────────
|
|
|
|
|
|
def test_tlb_overhead_default():
    """A PeMMU built with no arguments reports overhead_ns == 0.0."""
    default_unit = PeMMU()
    assert default_unit.overhead_ns == 0.0
|
|
|
|
|
|
def test_tlb_overhead_configurable():
    """The overhead_ns constructor argument is exposed unchanged as an attribute."""
    configured_ns = 0.5
    unit = PeMMU(overhead_ns=configured_ns)
    assert unit.overhead_ns == configured_ns
|