Files
kernbench2/tests/test_mmu_component.py
T
ywkang 81ce55571d Rename impl names: add builtin. prefix for clear provenance
- components.yaml: all builtin impls use builtin.xxx naming
- topology.yaml: all impl references updated to builtin.xxx
- builder.py: hardcoded ucie impl → builtin.ucie
- Tests: all impl string references updated

Convention: builtin.<name> for built-in, custom.<name> for user-defined.
382 tests passing.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-09 00:16:24 -07:00

227 lines
7.4 KiB
Python

"""Tests for PE_MMU component integration and MmuMapMsg fabric path.
Validates:
T15-a. PE_MMU component registered in ComponentRegistry
T15-b. PE_MMU component receives MmuMapMsg via inbox, updates page table
T15-c. PE_DMA translates VA→PA via mmu before routing
T16. MmuMapMsg/MmuUnmapMsg message types defined with correct fields
T17. PE_CPU passes VA (not PA) to kernel when VA is available
T18. End-to-end: deploy (MmuMapMsg broadcast) → kernel launch → DMA with VA
"""
import pytest
try:
from kernbench.runtime_api.kernel import MmuMapMsg, MmuUnmapMsg
except ImportError:
pytest.skip("MmuMapMsg/MmuUnmapMsg not yet defined (Phase 2)", allow_module_level=True)
# ── T16. MmuMapMsg / MmuUnmapMsg message types ──────────────────────
def test_mmu_map_msg_fields():
    """A constructed MmuMapMsg reports type "mmu_map" and keeps its entries.

    Two VA→PA entries are passed in; the test checks the message type tag,
    the entry count, and every field of the first entry.
    """
    page_bytes = 4096
    mappings = (
        {"va": 0x1_0000_0000, "pa": 0xA000_0000, "size": page_bytes},
        {"va": 0x1_0000_1000, "pa": 0xB000_0000, "size": page_bytes},
    )
    msg = MmuMapMsg(
        correlation_id="c0",
        request_id="r0",
        entries=mappings,
        target_cubes="all",
        target_pe="all",
    )

    assert msg.msg_type == "mmu_map"
    assert len(msg.entries) == 2

    # Only the first entry is inspected field by field.
    assert msg.entries[0]["va"] == 0x1_0000_0000
    assert msg.entries[0]["pa"] == 0xA000_0000
    assert msg.entries[0]["size"] == 4096
def test_mmu_map_msg_immutable():
    """Assigning to a field of an existing MmuMapMsg raises AttributeError."""
    init_fields = {
        "correlation_id": "c0",
        "request_id": "r0",
        "entries": (),
        "target_cubes": "all",
        "target_pe": "all",
    }
    msg = MmuMapMsg(**init_fields)

    # Rebinding an attribute on the instance must be rejected.
    with pytest.raises(AttributeError):
        msg.entries = ()  # type: ignore[misc]
def test_mmu_unmap_msg_fields():
    """A constructed MmuUnmapMsg reports type "mmu_unmap" and keeps its entries.

    Unmap entries carry only "va" and "size" (no "pa").
    """
    ranges = ({"va": 0x1_0000_0000, "size": 4096},)
    msg = MmuUnmapMsg(
        correlation_id="c0",
        request_id="r0",
        entries=ranges,
        target_cubes="all",
        target_pe="all",
    )

    assert msg.msg_type == "mmu_unmap"
    assert len(msg.entries) == 1
    assert msg.entries[0]["va"] == 0x1_0000_0000
# ── T15-a. PE_MMU component registry ────────────────────────────────
def test_pe_mmu_registry():
    """ComponentRegistry.create on a "builtin.pe_mmu" node yields a PeMmuComponent."""
    from kernbench.components.base import ComponentRegistry
    from kernbench.components.builtin.pe_mmu import PeMmuComponent
    from kernbench.topology.types import Node

    node_attrs = {"tlb_overhead_ns": 0.5}
    mmu_node = Node(
        id="sip0.cube0.pe0.pe_mmu",
        kind="pe_mmu",
        impl="builtin.pe_mmu",
        pos_mm=None,
        attrs=node_attrs,
    )

    created = ComponentRegistry.create(mmu_node)

    assert isinstance(created, PeMmuComponent)
# ── T15-b. PE_MMU receives MmuMapMsg and updates page table ─────────
def test_pe_mmu_processes_map_msg():
    """After an MmuMapMsg is put on the component inbox, translate() resolves it.

    Steps:
      1. Build a PeMmuComponent on a fresh simpy environment and start it.
      2. Wrap an MmuMapMsg (one VA→PA entry) in a Transaction and put it
         on the component's ``_inbox`` from a simpy process.
      3. Run the simulation until t=100.
      4. Check that ``comp.mmu`` is a PeMMU and that translating the mapped
         VA returns the mapped PA.
    """
    import simpy
    from kernbench.components.builtin.pe_mmu import PeMmuComponent
    from kernbench.sim_engine.transaction import Transaction
    from kernbench.topology.types import Node

    node_id = "sip0.cube0.pe0.pe_mmu"
    virt_addr = 0x1_0000_0000
    phys_addr = 0xABCD_0000

    env = simpy.Environment()
    comp = PeMmuComponent(
        Node(
            id=node_id,
            kind="pe_mmu",
            impl="builtin.pe_mmu",
            pos_mm=None,
            attrs={"tlb_overhead_ns": 0.5, "page_size": 4096},
        )
    )
    comp.in_ports["src"] = simpy.Store(env)
    comp.start(env)

    # Build the mapping request and the transaction that carries it.
    request = MmuMapMsg(
        correlation_id="c0",
        request_id="r0",
        entries=({"va": virt_addr, "pa": phys_addr, "size": 4096},),
        target_cubes="all",
        target_pe="all",
    )
    completion = env.event()
    txn = Transaction(
        request=request,
        path=[node_id],
        step=0,
        nbytes=0,
        done=completion,
    )

    def deliver():
        # Hand the transaction to the component through its inbox.
        yield comp._inbox.put(txn)

    env.process(deliver())
    env.run(until=100)

    # Once the simulation has run, the mapping must be visible via translate().
    from kernbench.policy.address.pe_mmu import PeMMU

    page_table = comp.mmu  # the PeMMU object held by the component
    assert isinstance(page_table, PeMMU)
    assert page_table.translate(virt_addr) == phys_addr
# ── T15-c. PE_DMA uses MMU translate ────────────────────────────────
def test_pe_dma_translates_va():
    """A freshly built PeDmaComponent exposes some handle for reaching the MMU.

    This is an interface-contract check only: it does not exercise
    translation or routing. It passes if the component has a ``_mmu``
    attribute, or an ``mmu`` attribute, or a ``ctx`` attribute that is not
    None. The intent stated by the original author is that PE_DMA translates
    the VA carried by a DMA command to a PA via the MMU before routing, with
    the full path validated by the engine tests.
    """
    from kernbench.components.builtin.pe_dma import PeDmaComponent
    from kernbench.topology.types import Node

    dma_node = Node(
        id="sip0.cube0.pe0.pe_dma",
        kind="pe_dma",
        impl="builtin.pe_dma",
        pos_mm=None,
        attrs={"rd_engines": 1, "wr_engines": 1},
    )
    comp = PeDmaComponent(dma_node)

    # The wiring mechanism is left open: a direct reference is checked first,
    # and only if neither name exists is a non-None ctx accepted instead.
    mmu_reachable = any(hasattr(comp, name) for name in ("_mmu", "mmu"))
    if not mmu_reachable:
        mmu_reachable = hasattr(comp, "ctx") and comp.ctx is not None

    assert mmu_reachable, "PE_DMA must have access to PE_MMU for VA translation"
# ── T17. PE_CPU passes VA to kernel ──────────────────────────────────
def test_pe_cpu_uses_va_base_from_tensor_arg():
    """TensorArg keeps ``va_base`` while its shard keeps ``pa``.

    The test only inspects the data objects: a TensorArg built with a
    ``va_base`` reports it back, and the shard it was built from still
    reports its physical address. The original author's stated intent is
    that PE_CPU hands ``va_base`` to kernels while PA stays available on
    the shard for direct-PA operations.
    """
    from kernbench.runtime_api.kernel import TensorArg, TensorArgShard

    virt_base = 0x1_0000_0000
    phys_addr = 0x1000

    shard = TensorArgShard(
        sip=0,
        cube=0,
        pe=0,
        pa=phys_addr,
        nbytes=4096,
        offset_bytes=0,
    )
    targ = TensorArg(shards=(shard,), va_base=virt_base)

    # Virtual base lives on the TensorArg ...
    assert targ.va_base == 0x1_0000_0000
    # ... and the physical address is still readable from the shard.
    assert shard.pa == 0x1000
# ── T18. MmuMapMsg broadcast pattern ─────────────────────────────────
def test_mmu_map_msg_broadcast_target():
    """An MmuMapMsg built with "all" targets reports "all" for PEs and cubes."""
    single_mapping = ({"va": 0x1000, "pa": 0x2000, "size": 4096},)
    msg = MmuMapMsg(
        correlation_id="c0",
        request_id="r0",
        entries=single_mapping,
        target_cubes="all",
        target_pe="all",
    )

    # Checked in the same order as before: PE target first, then cube target.
    for field_name in ("target_pe", "target_cubes"):
        assert getattr(msg, field_name) == "all"
def test_mmu_map_msg_same_entries_all_pes():
    """A broadcast MmuMapMsg holds exactly the entries tuple it was given.

    The message is not split per PE at construction time: ``msg.entries``
    compares equal to the full tuple passed in.
    """
    span = 8192
    entries = (
        {"va": 0x1_0000_0000, "pa": 0xA000_0000, "size": span},
        {"va": 0x1_0000_2000, "pa": 0xB000_0000, "size": span},
    )
    broadcast_kwargs = dict(
        correlation_id="c0",
        request_id="r0",
        entries=entries,
        target_cubes="all",
        target_pe="all",
    )
    msg = MmuMapMsg(**broadcast_kwargs)

    # The whole mapping travels in the one message.
    assert msg.entries == entries