Add virtual memory support: PE_MMU, VA allocator, fabric MmuMapMsg
Implement VA/MMU layer (ADR-0011 Phase 1) enabling Triton kernels to use contiguous virtual addresses on sharded tensors. Key changes: - PE_MMU component: hybrid inbox (MmuMapMsg) + sync translate() for PE_DMA - VirtualAllocator + PEMemAllocator: free-list with coalescing - MmuMapMsg/MmuUnmapMsg fabric path with SIP-level routing - DPPolicy-based mapping: replicate=local, sharded=broadcast - Tensor lifecycle: del + weakref cleanup, context manager - Rename: TensorHandle.pa→addr, DmaReadCmd.src_pa→src_addr, ctx→torch Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,193 @@
|
||||
"""Tests for tensor free: del-based + context manager cleanup.
|
||||
|
||||
Validates:
|
||||
TF1. PEMemAllocator.free_hbm/free_tcm reclaims space
|
||||
TF2. del tensor triggers cleanup (VA/PA returned, MMU unmapped)
|
||||
TF3. Context manager cleans up all tensors on exit
|
||||
TF4. del after context exit is safe (no crash)
|
||||
TF5. Alloc-del-alloc cycle reuses VA and PA
|
||||
TF6. del already-freed tensor is safe (no crash)
|
||||
"""
|
||||
import gc
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator
|
||||
from kernbench.policy.address.pe_mmu import PageFault
|
||||
from kernbench.policy.placement.dp import DPPolicy
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
from kernbench.runtime_api.context import RuntimeContext
|
||||
from kernbench.runtime_api.types import DeviceSelector
|
||||
from kernbench.topology.builder import load_topology
|
||||
|
||||
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
|
||||
|
||||
_MB = 1 << 20
|
||||
_GB = 1 << 30
|
||||
|
||||
_CFG = AddressConfig(
|
||||
sip_count=2,
|
||||
cubes_per_sip=16,
|
||||
pes_per_cube=8,
|
||||
hbm_bytes_per_cube=48 * _GB,
|
||||
hbm_slices_per_cube=8,
|
||||
tcm_bytes_per_pe=16 * _MB,
|
||||
tcm_scheduler_reserved_bytes=4 * _MB,
|
||||
sram_bytes_per_cube=32 * _MB,
|
||||
)
|
||||
|
||||
|
||||
def _make_ctx():
|
||||
graph = load_topology(TOPOLOGY_PATH)
|
||||
engine = GraphEngine(graph)
|
||||
ctx = RuntimeContext(
|
||||
engine=engine,
|
||||
target_device=DeviceSelector("sip:0"),
|
||||
correlation_id="test_free",
|
||||
spec=graph.spec,
|
||||
)
|
||||
return ctx, engine
|
||||
|
||||
|
||||
# ── TF1. PEMemAllocator free ─────────────────────────────────────────
|
||||
|
||||
|
||||
def test_allocator_free_hbm_reclaims_space():
|
||||
"""free_hbm returns HBM space; subsequent alloc can reuse it."""
|
||||
a = PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=0, cfg=_CFG)
|
||||
pa1 = a.alloc_hbm(4096)
|
||||
used_after_alloc = a.hbm_used
|
||||
a.free_hbm(pa1, 4096)
|
||||
assert a.hbm_used == used_after_alloc - 4096
|
||||
pa2 = a.alloc_hbm(4096)
|
||||
assert pa2 is not None
|
||||
|
||||
|
||||
def test_allocator_free_tcm_reclaims_space():
|
||||
"""free_tcm returns TCM space."""
|
||||
a = PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=0, cfg=_CFG)
|
||||
pa1 = a.alloc_tcm(256)
|
||||
used_after_alloc = a.tcm_used
|
||||
a.free_tcm(pa1, 256)
|
||||
assert a.tcm_used == used_after_alloc - 256
|
||||
|
||||
|
||||
# ── TF2. del tensor triggers cleanup ─────────────────────────────────
|
||||
|
||||
|
||||
def test_del_tensor_unmaps_mmu():
|
||||
"""del tensor removes MMU mappings."""
|
||||
ctx, engine = _make_ctx()
|
||||
dp = DPPolicy(cube="replicate", pe="replicate")
|
||||
t = ctx.zeros((128, 128), dtype="f16", dp=dp, name="del_test")
|
||||
va_base = t._handle.va_base
|
||||
|
||||
# Verify mapping exists
|
||||
mmu_comp = engine._components["sip0.cube0.pe0.pe_mmu"]
|
||||
assert mmu_comp.mmu.translate(va_base) is not None
|
||||
|
||||
# Delete tensor
|
||||
del t
|
||||
gc.collect()
|
||||
|
||||
# Mapping should be gone
|
||||
with pytest.raises(PageFault):
|
||||
mmu_comp.mmu.translate(va_base)
|
||||
|
||||
|
||||
def test_del_tensor_reclaims_va():
|
||||
"""del tensor returns VA space for reuse."""
|
||||
ctx, engine = _make_ctx()
|
||||
dp = DPPolicy(cube="replicate", pe="replicate")
|
||||
|
||||
t1 = ctx.zeros((128, 128), dtype="f16", dp=dp, name="va_reuse1")
|
||||
va1 = t1._handle.va_base
|
||||
|
||||
del t1
|
||||
gc.collect()
|
||||
|
||||
t2 = ctx.zeros((128, 128), dtype="f16", dp=dp, name="va_reuse2")
|
||||
va2 = t2._handle.va_base
|
||||
assert va2 == va1
|
||||
|
||||
|
||||
# ── TF3. Context manager cleanup ─────────────────────────────────────
|
||||
|
||||
|
||||
def test_context_manager_cleans_all():
|
||||
"""Exiting context manager cleans up all tensors."""
|
||||
graph = load_topology(TOPOLOGY_PATH)
|
||||
engine = GraphEngine(graph)
|
||||
|
||||
with RuntimeContext(
|
||||
engine=engine,
|
||||
target_device=DeviceSelector("sip:0"),
|
||||
correlation_id="ctx_mgr",
|
||||
spec=graph.spec,
|
||||
) as ctx:
|
||||
dp = DPPolicy(cube="replicate", pe="replicate")
|
||||
t = ctx.zeros((128, 128), dtype="f16", dp=dp, name="ctx_mgr_test")
|
||||
va_base = t._handle.va_base
|
||||
|
||||
# After context exit, MMU mappings should be cleared
|
||||
mmu_comp = engine._components["sip0.cube0.pe0.pe_mmu"]
|
||||
with pytest.raises(PageFault):
|
||||
mmu_comp.mmu.translate(va_base)
|
||||
|
||||
|
||||
# ── TF4. del after context exit is safe ───────────────────────────────
|
||||
|
||||
|
||||
def test_del_after_context_exit_no_crash():
|
||||
"""del tensor after context manager exit does not crash."""
|
||||
graph = load_topology(TOPOLOGY_PATH)
|
||||
engine = GraphEngine(graph)
|
||||
|
||||
ctx = RuntimeContext(
|
||||
engine=engine,
|
||||
target_device=DeviceSelector("sip:0"),
|
||||
correlation_id="safe_del",
|
||||
spec=graph.spec,
|
||||
)
|
||||
dp = DPPolicy(cube="replicate", pe="replicate")
|
||||
t = ctx.zeros((128, 128), dtype="f16", dp=dp, name="safe_del_test")
|
||||
|
||||
# Simulate context going away
|
||||
ctx.cleanup()
|
||||
|
||||
# del should not crash even though context already cleaned up
|
||||
del t
|
||||
gc.collect()
|
||||
# No exception = pass
|
||||
|
||||
|
||||
# ── TF5. Alloc-del-alloc cycle ────────────────────────────────────────
|
||||
|
||||
|
||||
def test_alloc_del_cycle():
|
||||
"""Multiple alloc-del cycles work correctly."""
|
||||
ctx, engine = _make_ctx()
|
||||
dp = DPPolicy(cube="replicate", pe="replicate")
|
||||
|
||||
for i in range(3):
|
||||
t = ctx.zeros((128, 128), dtype="f16", dp=dp, name=f"cycle_{i}")
|
||||
assert t._handle is not None
|
||||
assert t._handle.va_base > 0
|
||||
del t
|
||||
gc.collect()
|
||||
|
||||
|
||||
# ── TF6. del already-cleaned tensor is safe ──────────────────────────
|
||||
|
||||
|
||||
def test_del_already_cleaned_tensor_no_crash():
|
||||
"""del on a tensor whose handle is already None does not crash."""
|
||||
ctx, engine = _make_ctx()
|
||||
dp = DPPolicy(cube="replicate", pe="replicate")
|
||||
t = ctx.zeros((128, 128), dtype="f16", dp=dp, name="double_del")
|
||||
|
||||
ctx.cleanup() # clears all tensors
|
||||
# t._handle is now None
|
||||
del t # should not crash
|
||||
gc.collect()
|
||||
Reference in New Issue
Block a user