"""Tests for MmuMapMsg fabric path and cross-cube mapping.

Validates:
    F1. MmuMapMsg traverses fabric: latency > 0 (not sideband)
    F2. MmuMapMsg fan-out: IO_CPU → cubes, M_CPU → PEs
    F3. After MmuMapMsg, PE_MMU has correct mappings
    F4. Cross-cube sharded tensor: all PEs get global mappings
    F5. Replicate tensor: each PE gets own cube's PA (local override)
    F6. Cross-cube DMA after sharded mapping: PE can access remote cube's HBM
        NOTE(review): no F6 test is defined in this module — TODO add one or
        drop this item.
    F7. Overlap detection: replicate vs sharded identified correctly
    F8. Existing regression: PA-only benchmarks still pass

Most tests build a fresh ``GraphEngine`` from ``topology.yaml``, submit an
``MmuMapMsg`` and then inspect the PE_MMU components directly through
``engine._components["sip{S}.cube{C}.pe{P}.pe_mmu"]``.
"""

from pathlib import Path

import pytest

# NOTE(review): several of these names (pytest, PEMemAllocator, PeMMU,
# VirtualAllocator, column_wise, replicate, ShardSpec, deploy_tensor,
# TensorHandle) are not referenced by the tests in this module; they may be
# kept as an import smoke test — confirm before removing.
from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator
from kernbench.policy.address.pe_mmu import PeMMU
from kernbench.policy.address.va_allocator import VirtualAllocator
from kernbench.policy.placement.dp import column_wise, replicate, ShardSpec
from kernbench.runtime_api.tensor import deploy_tensor, TensorHandle
from kernbench.sim_engine.engine import GraphEngine
from kernbench.topology.builder import load_topology

# topology.yaml lives in the parent of the directory containing this file.
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"

_MB = 1 << 20
_GB = 1 << 30

# Address-space geometry: 2 SIPs x 16 cubes x 8 PEs; 48 GiB of HBM per cube
# split into 8 slices; 16 MiB TCM per PE of which 4 MiB is reserved for the
# scheduler; 32 MiB SRAM per cube.
# NOTE(review): not referenced by the tests in this module — confirm it is
# still needed.
_CFG = AddressConfig(
    sip_count=2,
    cubes_per_sip=16,
    pes_per_cube=8,
    hbm_bytes_per_cube=48 * _GB,
    hbm_slices_per_cube=8,
    tcm_bytes_per_pe=16 * _MB,
    tcm_scheduler_reserved_bytes=4 * _MB,
    sram_bytes_per_cube=32 * _MB,
)


def _engine():
    """Return a newly constructed GraphEngine for the topology at TOPOLOGY_PATH.

    The topology file is re-loaded and a new engine is built on every call,
    so each test starts from its own engine instance.
    """
    return GraphEngine(load_topology(TOPOLOGY_PATH))
# ── F1. MmuMapMsg fabric latency ─────────────────────────────────────


def test_mmu_map_via_fabric_has_latency():
    """A MmuMapMsg sent via ``engine.submit()`` completes OK with latency > 0.

    A zero ``total_ns`` in the trace would indicate the mapping was applied
    without traversing the fabric (the "sideband" case this guards against).
    """
    from kernbench.runtime_api.kernel import MmuMapMsg

    sim = _engine()
    mapping = {"va": 0x1_0000_0000, "pa": 0x2000_0000, "size": 4096}
    handle = sim.submit(
        MmuMapMsg(
            correlation_id="c0",
            request_id="mmu_map_0",
            entries=(mapping,),
            target_cubes=(0,),
            target_pe="all",
        )
    )
    sim.wait(handle)

    completion, trace = sim.get_completion(handle)
    assert completion.ok is True
    # Going over the fabric must consume simulated time.
    assert trace is not None
    assert trace.get("total_ns", 0) > 0


# ── F2. MmuMapMsg fan-out ────────────────────────────────────────────


def test_mmu_map_reaches_all_pes_in_cube():
    """target_pe='all' installs the mapping in PE_MMUs pe0..pe7 of cube 0."""
    from kernbench.runtime_api.kernel import MmuMapMsg

    sim = _engine()
    virt, phys, nbytes = 0x1_0000_0000, 0xABCD_0000, 4096
    handle = sim.submit(
        MmuMapMsg(
            correlation_id="c0",
            request_id="mmu_map_1",
            entries=({"va": virt, "pa": phys, "size": nbytes},),
            target_cubes=(0,),
            target_pe="all",
        )
    )
    sim.wait(handle)

    # Each of the eight PE_MMU components in cube 0 must resolve virt -> phys.
    mmu_names = (f"sip0.cube0.pe{idx}.pe_mmu" for idx in range(8))
    for name in mmu_names:
        assert sim._components[name].mmu.translate(virt) == phys
# ── F3. Multiple MmuMapMsg entries ───────────────────────────────────


def test_mmu_map_multiple_entries():
    """A single MmuMapMsg carrying two entries installs both translations.

    Only ``sip0.cube0.pe0.pe_mmu`` is inspected.
    """
    from kernbench.runtime_api.kernel import MmuMapMsg

    sim = _engine()
    nbytes = 4096
    va_to_pa = {
        0x1_0000_0000: 0xA000_0000,
        0x1_0000_1000: 0xB000_0000,
    }
    mappings = tuple(
        {"va": virt, "pa": phys, "size": nbytes} for virt, phys in va_to_pa.items()
    )
    handle = sim.submit(
        MmuMapMsg(
            correlation_id="c0",
            request_id="mmu_map_2",
            entries=mappings,
            target_cubes=(0,),
            target_pe="all",
        )
    )
    sim.wait(handle)

    pe0 = sim._components["sip0.cube0.pe0.pe_mmu"]
    for virt, phys in va_to_pa.items():
        assert pe0.mmu.translate(virt) == phys


# ── F4. Cross-cube sharded: global mapping ───────────────────────────


def test_cross_cube_sharded_all_pes_get_global_mapping():
    """For sharded tensor across cubes (unique offsets), all PEs get all mappings.

    Both shard entries are broadcast to cubes 0 and 1, so pe0 in either cube
    resolves its local shard as well as the remote one.
    """
    from kernbench.runtime_api.kernel import MmuMapMsg

    sim = _engine()
    # Two-cube shard layout: cube0 holds offset 0, cube1 holds offset 4096.
    shard_c0 = {"va": 0x1_0000_0000, "pa": 0xA000_0000, "size": 4096}
    shard_c1 = {"va": 0x1_0000_1000, "pa": 0xB000_0000, "size": 4096}
    handle = sim.submit(
        MmuMapMsg(
            correlation_id="c0",
            request_id="mmu_map_xc",
            entries=(shard_c0, shard_c1),
            target_cubes=(0, 1),  # broadcast to both cubes
            target_pe="all",
        )
    )
    sim.wait(handle)

    # Cube 0 sees shard_c0 as local and shard_c1 as remote; cube 1 the reverse.
    for cube in (0, 1):
        comp = sim._components[f"sip0.cube{cube}.pe0.pe_mmu"]
        assert comp.mmu.translate(0x1_0000_0000) == 0xA000_0000
        assert comp.mmu.translate(0x1_0000_1000) == 0xB000_0000
# ── F5. Replicate: local PA override ─────────────────────────────────


def test_replicate_local_pa_override():
    """For replicated tensor (same VA range), each cube's PEs see local PA.

    The same VA is mapped once per cube with a different PA, each message
    targeting a single cube. The second message (cube 1 only) must not
    overwrite cube 0's translation.
    """
    from kernbench.runtime_api.kernel import MmuMapMsg

    engine = _engine()
    va, size = 0x1_0000_0000, 4096

    # Cube 0 gets its own PA
    msg0 = MmuMapMsg(
        correlation_id="c0",
        request_id="mmu_rep_c0",
        entries=({"va": va, "pa": 0xA000_0000, "size": size},),
        target_cubes=(0,),
        target_pe="all",
    )
    h0 = engine.submit(msg0)
    engine.wait(h0)

    # Cube 1 gets a different PA for the same VA
    msg1 = MmuMapMsg(
        correlation_id="c0",
        request_id="mmu_rep_c1",
        entries=({"va": va, "pa": 0xB000_0000, "size": size},),
        target_cubes=(1,),
        target_pe="all",
    )
    h1 = engine.submit(msg1)
    engine.wait(h1)

    # Cube 0 PEs translate to cube 0's PA
    mmu_c0 = engine._components["sip0.cube0.pe0.pe_mmu"]
    assert mmu_c0.mmu.translate(va) == 0xA000_0000

    # Cube 1 PEs translate to cube 1's PA
    mmu_c1 = engine._components["sip0.cube1.pe0.pe_mmu"]
    assert mmu_c1.mmu.translate(va) == 0xB000_0000


# ── F7. Overlap detection ────────────────────────────────────────────


def _shard_ranges_overlap(shards):
    """Return True if any two shards cover intersecting byte ranges.

    Each shard spans the half-open range
    ``[offset_bytes, offset_bytes + nbytes)``; zero-byte shards cover nothing
    and are ignored. Once the ranges are sorted by start, any overlap implies
    an overlap between some neighbouring pair, so only neighbours are compared.
    """
    spans = sorted(
        (s.offset_bytes, s.offset_bytes + s.nbytes) for s in shards if s.nbytes > 0
    )
    return any(
        next_start < prev_end
        for (_, prev_end), (next_start, _) in zip(spans, spans[1:])
    )


def test_detect_overlapping_shards():
    """Utility: detect if shards have overlapping VA ranges (replicate indicator).

    Sharded layouts tile the tensor with disjoint byte ranges; replicated
    layouts put several shards on the same range. Ranges are compared as
    intervals, so partial overlaps are caught as well — comparing
    ``(offset_bytes, nbytes)`` pairs for exact duplicates would miss them.
    """
    from kernbench.runtime_api.tensor import TensorShard

    # Sharded: adjacent, disjoint ranges [0, 4096) and [4096, 8192).
    sharded = [
        TensorShard(sip=0, cube=0, pe=0, pa=0x100, nbytes=4096, offset_bytes=0),
        TensorShard(sip=0, cube=0, pe=1, pa=0x200, nbytes=4096, offset_bytes=4096),
    ]
    assert not _shard_ranges_overlap(sharded), "Sharded ranges must not overlap"

    # Replicated: both shards cover [0, 4096).
    replicated = [
        TensorShard(sip=0, cube=0, pe=0, pa=0x100, nbytes=4096, offset_bytes=0),
        TensorShard(sip=0, cube=1, pe=0, pa=0x200, nbytes=4096, offset_bytes=0),
    ]
    assert _shard_ranges_overlap(replicated), "Replicate should have overlapping ranges"

    # Partial overlap: [0, 8192) and [4096, 8192) collide despite distinct
    # (offset, nbytes) pairs.
    partial = [
        TensorShard(sip=0, cube=0, pe=0, pa=0x100, nbytes=8192, offset_bytes=0),
        TensorShard(sip=0, cube=1, pe=0, pa=0x200, nbytes=4096, offset_bytes=4096),
    ]
    assert _shard_ranges_overlap(partial), "Partial overlaps must be detected"
# ── F8. Regression: existing benchmarks still pass ───────────────────


def test_qkv_gemm_still_passes():
    """QKV GEMM benchmark runs to completion on ``sip:0`` without raising.

    Smoke-level regression check: ``bench_run(ctx)`` and ``ctx.wait_all()``
    must not raise; the value returned by ``bench_run`` is not inspected.
    NOTE(review): nothing here enables VA/MMU explicitly — presumably it is
    on by default in this build; confirm.
    """
    from kernbench.runtime_api.context import RuntimeContext
    from kernbench.runtime_api.types import DeviceSelector

    graph = load_topology(TOPOLOGY_PATH)
    engine = GraphEngine(graph)
    ctx = RuntimeContext(
        engine=engine,
        target_device=DeviceSelector("sip:0"),
        correlation_id="test_regression",
        spec=graph.spec,
    )

    from kernbench.benches.qkv_gemm import run as bench_run

    bench_run(ctx)
    ctx.wait_all()
    # If we get here without exception, the benchmark succeeded