"""Tests for PeMMU: per-PE virtual-to-physical address translation. Validates: T1. Basic map + translate T2. Page-aligned dict lookup (O(1), multi-page range) T3. Multiple tensor mappings accumulate T4. unmap removes entries, translate raises PageFault T5. PageFault on unmapped VA T6. Identical mapping broadcast across multiple PEs """ import pytest from kernbench.policy.address.pe_mmu import PageFault, PeMMU _2MB = 2 * 1024 * 1024 # ── T1. Basic map + translate ──────────────────────────────────────── def test_map_and_translate_basic(): """map(va, pa, size) → translate(va) returns pa; offset preserved.""" mmu = PeMMU(page_size=4096) mmu.map(va=0x1_0000_0000, pa=0xABCD_0000, size=4096) assert mmu.translate(0x1_0000_0000) == 0xABCD_0000 def test_translate_preserves_offset(): """translate(va + offset) returns pa + offset within a page.""" mmu = PeMMU(page_size=4096) mmu.map(va=0x1_0000_0000, pa=0xABCD_0000, size=4096) assert mmu.translate(0x1_0000_0100) == 0xABCD_0100 assert mmu.translate(0x1_0000_0FFF) == 0xABCD_0FFF # ── T2. Page-aligned dict lookup (multi-page) ─────────────────────── def test_multi_page_mapping(): """8 MB mapping with 2 MB pages → 4 page entries, all translate correctly.""" mmu = PeMMU(page_size=_2MB) va_base = 0x1_0000_0000 pa_base = 0x2_0000_0000 size = 8 * 1024 * 1024 # 8 MB = 4 pages mmu.map(va=va_base, pa=pa_base, size=size) # First page assert mmu.translate(va_base) == pa_base # Second page start assert mmu.translate(va_base + _2MB) == pa_base + _2MB # Third page with offset assert mmu.translate(va_base + 2 * _2MB + 0x100) == pa_base + 2 * _2MB + 0x100 # Last byte of last page assert mmu.translate(va_base + size - 1) == pa_base + size - 1 def test_page_table_entry_count(): """Mapping N bytes with page_size P creates ceil(N/P) entries.""" mmu = PeMMU(page_size=_2MB) mmu.map(va=0x1000_0000, pa=0x2000_0000, size=8 * 1024 * 1024) assert mmu.num_entries == 4 mmu.map(va=0x2000_0000, pa=0x3000_0000, size=_2MB) assert mmu.num_entries == 5 # ── T3. Multiple tensor mappings accumulate ────────────────────────── def test_multiple_mappings_accumulate(): """Two non-overlapping tensors → both translate correctly.""" mmu = PeMMU(page_size=4096) # Tensor A mmu.map(va=0x1_0000_0000, pa=0xA000_0000, size=4096) # Tensor B (different VA range) mmu.map(va=0x1_0001_0000, pa=0xB000_0000, size=4096) assert mmu.translate(0x1_0000_0000) == 0xA000_0000 assert mmu.translate(0x1_0001_0000) == 0xB000_0000 def test_mappings_do_not_interfere(): """Adjacent VA ranges map to completely independent PA ranges.""" mmu = PeMMU(page_size=4096) mmu.map(va=0x0000_0000, pa=0xFFFF_0000, size=4096) mmu.map(va=0x0000_1000, pa=0x0000_0000, size=4096) assert mmu.translate(0x0000_0000) == 0xFFFF_0000 assert mmu.translate(0x0000_1000) == 0x0000_0000 # ── T4. unmap removes entries ──────────────────────────────────────── def test_unmap_removes_mapping(): """After unmap, translate raises PageFault.""" mmu = PeMMU(page_size=4096) mmu.map(va=0x1_0000_0000, pa=0xABCD_0000, size=4096) assert mmu.translate(0x1_0000_0000) == 0xABCD_0000 mmu.unmap(va=0x1_0000_0000, size=4096) with pytest.raises(PageFault): mmu.translate(0x1_0000_0000) def test_unmap_partial_range(): """Unmap only part of a multi-page mapping; rest stays valid.""" mmu = PeMMU(page_size=4096) mmu.map(va=0x1000_0000, pa=0x2000_0000, size=8192) # 2 pages assert mmu.num_entries == 2 # Unmap first page only mmu.unmap(va=0x1000_0000, size=4096) assert mmu.num_entries == 1 with pytest.raises(PageFault): mmu.translate(0x1000_0000) # Second page still valid assert mmu.translate(0x1000_1000) == 0x2000_1000 def test_unmap_does_not_affect_other_mappings(): """Unmapping tensor A does not affect tensor B.""" mmu = PeMMU(page_size=4096) mmu.map(va=0x1_0000_0000, pa=0xA000_0000, size=4096) mmu.map(va=0x1_0001_0000, pa=0xB000_0000, size=4096) mmu.unmap(va=0x1_0000_0000, size=4096) with pytest.raises(PageFault): mmu.translate(0x1_0000_0000) assert mmu.translate(0x1_0001_0000) == 0xB000_0000 # ── T5. PageFault on unmapped VA ───────────────────────────────────── def test_pagefault_on_unmapped_va(): """translate() on never-mapped VA raises PageFault.""" mmu = PeMMU(page_size=4096) with pytest.raises(PageFault): mmu.translate(0xDEAD_BEEF) def test_pagefault_contains_va(): """PageFault exception carries the faulting VA.""" mmu = PeMMU(page_size=4096) with pytest.raises(PageFault, match="0xdeadbeef"): mmu.translate(0xDEAD_BEEF) # ── T6. Identical mapping broadcast across PEs ─────────────────────── def test_broadcast_same_mapping_to_all_pes(): """All PEs receive the same full mapping → identical translate results.""" entries = [ (0x1_0000_0000, 0xA000_0000, 4096), # shard 0 (0x1_0000_1000, 0xB000_0000, 4096), # shard 1 (0x1_0000_2000, 0xC000_0000, 4096), # shard 2 (0x1_0000_3000, 0xD000_0000, 4096), # shard 3 ] num_pes = 8 mmus = [PeMMU(page_size=4096) for _ in range(num_pes)] # Broadcast: every PE gets the same entries for mmu in mmus: for va, pa, size in entries: mmu.map(va=va, pa=pa, size=size) # All PEs translate identically for mmu in mmus: assert mmu.translate(0x1_0000_0000) == 0xA000_0000 assert mmu.translate(0x1_0000_1000) == 0xB000_0000 assert mmu.translate(0x1_0000_2000) == 0xC000_0000 assert mmu.translate(0x1_0000_3000) == 0xD000_0000 def test_cross_pe_access_via_broadcast(): """PE0 can translate a VA that maps to PE3's HBM PA (cross-PE DMA scenario).""" mmu_pe0 = PeMMU(page_size=4096) # Full mapping includes PE3's shard mmu_pe0.map(va=0x1_0000_0000, pa=0xA000_0000, size=4096) # PE0 shard mmu_pe0.map(va=0x1_0000_1000, pa=0xD000_0000, size=4096) # PE3 shard # PE0 accesses PE3's region → valid translation pa = mmu_pe0.translate(0x1_0000_1000 + 0x100) assert pa == 0xD000_0100 # ── TLB overhead attribute ─────────────────────────────────────────── def test_tlb_overhead_default(): """Default tlb_overhead_ns is 0.""" mmu = PeMMU() assert mmu.overhead_ns == 0.0 def test_tlb_overhead_configurable(): """tlb_overhead_ns is configurable.""" mmu = PeMMU(overhead_ns=0.5) assert mmu.overhead_ns == 0.5