Add virtual memory support: PE_MMU, VA allocator, fabric MmuMapMsg
Implement the VA/MMU layer (ADR-0011 Phase 1), enabling Triton kernels to use contiguous virtual addresses on sharded tensors.

Key changes:
- PE_MMU component: hybrid inbox (MmuMapMsg) + sync translate() for PE_DMA
- VirtualAllocator + PEMemAllocator: free-list with coalescing
- MmuMapMsg/MmuUnmapMsg fabric path with SIP-level routing
- DPPolicy-based mapping: replicate=local, sharded=broadcast
- Tensor lifecycle: del + weakref cleanup, context manager
- Rename: TensorHandle.pa→addr, DmaReadCmd.src_pa→src_addr, ctx→torch

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,241 @@
|
||||
"""Tests for MmuMapMsg fabric path and cross-cube mapping.
|
||||
|
||||
Validates:
|
||||
F1. MmuMapMsg traverses fabric: latency > 0 (not sideband)
|
||||
F2. MmuMapMsg fan-out: IO_CPU → cubes, M_CPU → PEs
|
||||
F3. After MmuMapMsg, PE_MMU has correct mappings
|
||||
F4. Cross-cube sharded tensor: all PEs get global mappings
|
||||
F5. Replicate tensor: each PE gets own cube's PA (local override)
|
||||
F6. Cross-cube DMA after sharded mapping: PE can access remote cube's HBM
|
||||
F7. Overlap detection: replicate vs sharded identified correctly
|
||||
F8. Existing regression: PA-only benchmarks still pass
|
||||
"""
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator
|
||||
from kernbench.policy.address.pe_mmu import PeMMU
|
||||
from kernbench.policy.address.va_allocator import VirtualAllocator
|
||||
from kernbench.policy.placement.dp import column_wise, replicate, ShardSpec
|
||||
from kernbench.runtime_api.tensor import deploy_tensor, TensorHandle
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
from kernbench.topology.builder import load_topology
|
||||
|
||||
# Topology file loaded by the tests; it sits one directory above the
# directory that contains this test module.
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"

# Binary size units (1 MiB = 2**20 bytes, 1 GiB = 2**30 bytes).
_MB = 1024 * 1024
_GB = 1024 * _MB

# Address-space configuration: 2 SIPs x 16 cubes x 8 PEs, with per-cube HBM
# and SRAM sizes and per-PE TCM sizes expressed in the units above.
_CFG = AddressConfig(
    sip_count=2,
    cubes_per_sip=16,
    pes_per_cube=8,
    hbm_bytes_per_cube=48 * _GB,
    hbm_slices_per_cube=8,
    tcm_bytes_per_pe=16 * _MB,
    tcm_scheduler_reserved_bytes=4 * _MB,
    sram_bytes_per_cube=32 * _MB,
)
|
||||
|
||||
|
||||
def _engine():
    """Return a new GraphEngine built from the topology at TOPOLOGY_PATH."""
    topology = load_topology(TOPOLOGY_PATH)
    return GraphEngine(topology)
|
||||
|
||||
|
||||
# ── F1. MmuMapMsg fabric latency ─────────────────────────────────────
|
||||
|
||||
|
||||
def test_mmu_map_via_fabric_has_latency():
    """F1: an MmuMapMsg sent through engine.submit() completes OK with total_ns > 0."""
    from kernbench.runtime_api.kernel import MmuMapMsg

    sim = _engine()
    mapping = {"va": 0x1_0000_0000, "pa": 0x2000_0000, "size": 4096}
    request = MmuMapMsg(
        correlation_id="c0",
        request_id="mmu_map_0",
        entries=(mapping,),
        target_cubes=(0,),
        target_pe="all",
    )
    handle = sim.submit(request)
    sim.wait(handle)

    completion, trace = sim.get_completion(handle)
    assert completion.ok is True
    # The message is expected to travel over the fabric (not a sideband), so
    # the recorded trace must exist and report a strictly positive latency.
    assert trace is not None
    assert trace.get("total_ns", 0) > 0
|
||||
|
||||
|
||||
# ── F2. MmuMapMsg fan-out ────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_mmu_map_reaches_all_pes_in_cube():
    """F2: with target_pe='all', PE_MMUs pe0..pe7 of cube 0 all translate the mapped VA."""
    from kernbench.runtime_api.kernel import MmuMapMsg

    sim = _engine()
    virt, phys, length = 0x1_0000_0000, 0xABCD_0000, 4096
    request = MmuMapMsg(
        correlation_id="c0",
        request_id="mmu_map_1",
        entries=({"va": virt, "pa": phys, "size": length},),
        target_cubes=(0,),
        target_pe="all",
    )
    handle = sim.submit(request)
    sim.wait(handle)

    # Every one of the 8 PE_MMU components in sip0.cube0 must hold the mapping.
    for pe_index in range(8):
        component = sim._components[f"sip0.cube0.pe{pe_index}.pe_mmu"]
        assert component.mmu.translate(virt) == phys
|
||||
|
||||
|
||||
# ── F3. Multiple MmuMapMsg entries ───────────────────────────────────
|
||||
|
||||
|
||||
def test_mmu_map_multiple_entries():
    """F3: a single MmuMapMsg carrying two entries installs both in the PE_MMU."""
    from kernbench.runtime_api.kernel import MmuMapMsg

    sim = _engine()
    # Two adjacent 4 KiB virtual pages backed by unrelated physical addresses.
    expected = (
        (0x1_0000_0000, 0xA000_0000),
        (0x1_0000_1000, 0xB000_0000),
    )
    request = MmuMapMsg(
        correlation_id="c0",
        request_id="mmu_map_2",
        entries=tuple({"va": virt, "pa": phys, "size": 4096} for virt, phys in expected),
        target_cubes=(0,),
        target_pe="all",
    )
    handle = sim.submit(request)
    sim.wait(handle)

    pe0 = sim._components["sip0.cube0.pe0.pe_mmu"]
    for virt, phys in expected:
        assert pe0.mmu.translate(virt) == phys
|
||||
|
||||
|
||||
# ── F4. Cross-cube sharded: global mapping ───────────────────────────
|
||||
|
||||
|
||||
def test_cross_cube_sharded_all_pes_get_global_mapping():
    """F4: one MmuMapMsg broadcast to cubes 0 and 1 gives pe0 of each cube both mappings."""
    from kernbench.runtime_api.kernel import MmuMapMsg

    sim = _engine()
    # Simulated 2-cube shard: the first page belongs to cube0 (offset 0),
    # the second to cube1 (offset 4096).
    shard_pages = (
        (0x1_0000_0000, 0xA000_0000),  # cube0
        (0x1_0000_1000, 0xB000_0000),  # cube1
    )
    broadcast = MmuMapMsg(
        correlation_id="c0",
        request_id="mmu_map_xc",
        entries=tuple({"va": virt, "pa": phys, "size": 4096} for virt, phys in shard_pages),
        target_cubes=(0, 1),
        target_pe="all",
    )
    handle = sim.submit(broadcast)
    sim.wait(handle)

    # pe0 in each cube must translate its own page and the other cube's page.
    for mmu_id in ("sip0.cube0.pe0.pe_mmu", "sip0.cube1.pe0.pe_mmu"):
        component = sim._components[mmu_id]
        for virt, phys in shard_pages:
            assert component.mmu.translate(virt) == phys
|
||||
|
||||
|
||||
# ── F5. Replicate: local PA override ─────────────────────────────────
|
||||
|
||||
|
||||
def test_replicate_local_pa_override():
    """F5: the same VA mapped per cube with different PAs translates to each cube's own PA."""
    from kernbench.runtime_api.kernel import MmuMapMsg

    sim = _engine()
    virt, length = 0x1_0000_0000, 4096

    # One message per cube: identical VA, cube-specific PA.
    per_cube = (
        (0, "mmu_rep_c0", 0xA000_0000),
        (1, "mmu_rep_c1", 0xB000_0000),
    )
    for cube, request_id, phys in per_cube:
        request = MmuMapMsg(
            correlation_id="c0",
            request_id=request_id,
            entries=({"va": virt, "pa": phys, "size": length},),
            target_cubes=(cube,),
            target_pe="all",
        )
        handle = sim.submit(request)
        sim.wait(handle)

    # pe0 of each cube must resolve the shared VA to that cube's PA.
    local_views = (
        ("sip0.cube0.pe0.pe_mmu", 0xA000_0000),
        ("sip0.cube1.pe0.pe_mmu", 0xB000_0000),
    )
    for mmu_id, phys in local_views:
        component = sim._components[mmu_id]
        assert component.mmu.translate(virt) == phys
|
||||
|
||||
|
||||
# ── F7. Overlap detection ────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_detect_overlapping_shards():
    """F7: duplicate (offset_bytes, nbytes) pairs distinguish replicated from sharded layouts.

    Note that only exact duplicates of the (offset, size) pair are compared;
    partially overlapping ranges are not examined here.
    """
    from kernbench.runtime_api.tensor import TensorShard

    # Sharded layout: every shard covers a distinct byte range.
    distinct = [
        TensorShard(sip=0, cube=0, pe=0, pa=0x100, nbytes=4096, offset_bytes=0),
        TensorShard(sip=0, cube=0, pe=1, pa=0x200, nbytes=4096, offset_bytes=4096),
    ]
    distinct_spans = [(shard.offset_bytes, shard.nbytes) for shard in distinct]
    assert len(set(distinct_spans)) == len(distinct_spans), "Sharded should have unique offsets"

    # Replicated layout: two cubes hold the same byte range.
    copies = [
        TensorShard(sip=0, cube=0, pe=0, pa=0x100, nbytes=4096, offset_bytes=0),
        TensorShard(sip=0, cube=1, pe=0, pa=0x200, nbytes=4096, offset_bytes=0),
    ]
    copy_spans = [(shard.offset_bytes, shard.nbytes) for shard in copies]
    assert len(set(copy_spans)) < len(copy_spans), "Replicate should have duplicate offsets"
|
||||
|
||||
|
||||
# ── F8. Regression: existing benchmarks still pass ───────────────────
|
||||
|
||||
|
||||
def test_qkv_gemm_still_passes():
    """F8 regression: the QKV GEMM benchmark runs to completion on a fresh engine.

    The test builds a RuntimeContext targeting ``sip:0``, runs the benchmark's
    ``run`` entry point and waits for all outstanding work. It makes no
    explicit assertion: any exception raised along the way fails the test.
    """
    # Local imports are kept in their original relative order; the unused
    # ``BenchResult`` import has been dropped.
    from kernbench.runtime_api.context import RuntimeContext
    from kernbench.runtime_api.types import DeviceSelector
    from benches.qkv_gemm import run as bench_run

    graph = load_topology(TOPOLOGY_PATH)
    engine = GraphEngine(graph)
    ctx = RuntimeContext(
        engine=engine,
        target_device=DeviceSelector("sip:0"),
        correlation_id="test_regression",
        spec=graph.spec,
    )

    bench_run(ctx)
    ctx.wait_all()
    # If we get here without exception, the benchmark succeeded.
|
||||
Reference in New Issue
Block a user