Files
kernbench2/tests/test_tensor_free.py
T
ywkang 08812eda58 Add virtual memory support: PE_MMU, VA allocator, fabric MmuMapMsg
Implement VA/MMU layer (ADR-0011 Phase 1) enabling Triton kernels to use
contiguous virtual addresses on sharded tensors.

Key changes:
- PE_MMU component: hybrid inbox (MmuMapMsg) + sync translate() for PE_DMA
- VirtualAllocator + PEMemAllocator: free-list with coalescing
- MmuMapMsg/MmuUnmapMsg fabric path with SIP-level routing
- DPPolicy-based mapping: replicate=local, sharded=broadcast
- Tensor lifecycle: del + weakref cleanup, context manager
- Rename: TensorHandle.pa→addr, DmaReadCmd.src_pa→src_addr, ctx→torch

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 00:01:47 -07:00

194 lines
6.0 KiB
Python

"""Tests for tensor free: del-based + context manager cleanup.
Validates:
TF1. PEMemAllocator.free_hbm/free_tcm reclaims space
TF2. del tensor triggers cleanup (VA/PA returned, MMU unmapped)
TF3. Context manager cleans up all tensors on exit
TF4. del after context exit is safe (no crash)
TF5. Alloc-del-alloc cycle reuses VA and PA
TF6. del already-freed tensor is safe (no crash)
"""
import gc
import pytest
from pathlib import Path
from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator
from kernbench.policy.address.pe_mmu import PageFault
from kernbench.policy.placement.dp import DPPolicy
from kernbench.sim_engine.engine import GraphEngine
from kernbench.runtime_api.context import RuntimeContext
from kernbench.runtime_api.types import DeviceSelector
from kernbench.topology.builder import load_topology
# Topology description consumed by load_topology(); located two directory
# levels above this test file.
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"

# Binary size units used to spell out the capacities below.
_MB = 2 ** 20
_GB = 2 ** 30

# Address configuration handed to PEMemAllocator in the TF1 allocator tests.
_CFG = AddressConfig(
    sip_count=2,
    cubes_per_sip=16,
    pes_per_cube=8,
    hbm_bytes_per_cube=48 * _GB,
    hbm_slices_per_cube=8,
    tcm_bytes_per_pe=16 * _MB,
    tcm_scheduler_reserved_bytes=4 * _MB,
    sram_bytes_per_cube=32 * _MB,
)
def _make_ctx():
    """Create a fresh engine and a RuntimeContext bound to it.

    The topology is loaded from ``TOPOLOGY_PATH`` and the context targets
    device ``sip:0`` with correlation id ``test_free``.

    Returns:
        A ``(ctx, engine)`` tuple: the new ``RuntimeContext`` and the
        ``GraphEngine`` it was built on.
    """
    topo = load_topology(TOPOLOGY_PATH)
    sim = GraphEngine(topo)
    runtime = RuntimeContext(
        engine=sim,
        target_device=DeviceSelector("sip:0"),
        correlation_id="test_free",
        spec=topo.spec,
    )
    return runtime, sim
# ── TF1. PEMemAllocator free ─────────────────────────────────────────
def test_allocator_free_hbm_reclaims_space():
    """TF1: ``free_hbm`` gives back HBM bytes and a follow-up alloc succeeds."""
    nbytes = 4096
    allocator = PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=0, cfg=_CFG)

    first_pa = allocator.alloc_hbm(nbytes)
    used_before_free = allocator.hbm_used

    allocator.free_hbm(first_pa, nbytes)
    # The usage counter must drop by exactly the number of bytes freed.
    assert allocator.hbm_used == used_before_free - nbytes

    # Allocating the same size again must still yield an address.
    second_pa = allocator.alloc_hbm(nbytes)
    assert second_pa is not None
def test_allocator_free_tcm_reclaims_space():
    """TF1: ``free_tcm`` lowers ``tcm_used`` by the number of bytes freed."""
    nbytes = 256
    allocator = PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=0, cfg=_CFG)

    tcm_pa = allocator.alloc_tcm(nbytes)
    used_before_free = allocator.tcm_used

    allocator.free_tcm(tcm_pa, nbytes)
    assert allocator.tcm_used == used_before_free - nbytes
# ── TF2. del tensor triggers cleanup ─────────────────────────────────
def test_del_tensor_unmaps_mmu():
    """TF2: deleting a tensor removes its mapping from the PE MMU.

    While the tensor is alive its base VA must translate; after ``del`` and a
    garbage-collection pass the same lookup must raise ``PageFault``.
    """
    ctx, engine = _make_ctx()
    policy = DPPolicy(cube="replicate", pe="replicate")
    tensor = ctx.zeros((128, 128), dtype="f16", dp=policy, name="del_test")
    base_va = tensor._handle.va_base

    # Precondition: the mapping is present on the PE MMU being inspected.
    pe_mmu = engine._components["sip0.cube0.pe0.pe_mmu"]
    assert pe_mmu.mmu.translate(base_va) is not None

    # Drop the only reference, then force a collection so any pending
    # cleanup has run before the mapping is probed again.
    del tensor
    gc.collect()

    # Postcondition: translating the old VA now faults.
    with pytest.raises(PageFault):
        pe_mmu.mmu.translate(base_va)
def test_del_tensor_reclaims_va():
    """TF2: the VA range of a deleted tensor is handed out again.

    Two tensors of identical shape and dtype are allocated back to back with
    a ``del`` in between; the second must receive the first one's base VA.
    """
    ctx, _ = _make_ctx()
    policy = DPPolicy(cube="replicate", pe="replicate")

    first = ctx.zeros((128, 128), dtype="f16", dp=policy, name="va_reuse1")
    first_va = first._handle.va_base
    del first
    gc.collect()

    second = ctx.zeros((128, 128), dtype="f16", dp=policy, name="va_reuse2")
    second_va = second._handle.va_base
    assert second_va == first_va
# ── TF3. Context manager cleanup ─────────────────────────────────────
def test_context_manager_cleans_all():
    """TF3: leaving the ``with RuntimeContext(...)`` block unmaps its tensors.

    The tensor object itself is still referenced after the block, so the
    cleanup observed here comes from the context exit, not from ``del``.
    """
    topo = load_topology(TOPOLOGY_PATH)
    sim = GraphEngine(topo)

    with RuntimeContext(
        engine=sim,
        target_device=DeviceSelector("sip:0"),
        correlation_id="ctx_mgr",
        spec=topo.spec,
    ) as ctx:
        policy = DPPolicy(cube="replicate", pe="replicate")
        tensor = ctx.zeros((128, 128), dtype="f16", dp=policy, name="ctx_mgr_test")
        base_va = tensor._handle.va_base

    # Once the context has exited, the tensor's VA must no longer translate.
    pe_mmu = sim._components["sip0.cube0.pe0.pe_mmu"]
    with pytest.raises(PageFault):
        pe_mmu.mmu.translate(base_va)
# ── TF4. del after context exit is safe ───────────────────────────────
def test_del_after_context_exit_no_crash():
    """TF4: deleting a tensor after its context was cleaned up must not raise.

    ``ctx.cleanup()`` is called explicitly to stand in for the context going
    away; the test passes as long as the subsequent ``del`` and collection
    complete without an exception.
    """
    topo = load_topology(TOPOLOGY_PATH)
    sim = GraphEngine(topo)
    runtime = RuntimeContext(
        engine=sim,
        target_device=DeviceSelector("sip:0"),
        correlation_id="safe_del",
        spec=topo.spec,
    )
    policy = DPPolicy(cube="replicate", pe="replicate")
    tensor = runtime.zeros((128, 128), dtype="f16", dp=policy, name="safe_del_test")

    # Tear the context down first ...
    runtime.cleanup()

    # ... then release the tensor. There is no assertion here: surviving both
    # statements is the pass condition.
    del tensor
    gc.collect()
# ── TF5. Alloc-del-alloc cycle ────────────────────────────────────────
def test_alloc_del_cycle():
    """TF5: repeated alloc/del cycles succeed and hand back the same VA.

    Each iteration allocates a tensor of the same shape and dtype, checks
    that it received a handle with a positive base VA, and deletes it again.
    Because the previous tensor is fully released before the next one is
    created, every iteration is expected to get the base VA of the first one;
    a differing VA would mean the freed range is not being returned to the
    allocator. PA reuse is not inspected here.
    """
    ctx, _ = _make_ctx()
    dp = DPPolicy(cube="replicate", pe="replicate")

    first_va = None
    for i in range(3):
        t = ctx.zeros((128, 128), dtype="f16", dp=dp, name=f"cycle_{i}")
        assert t._handle is not None
        va = t._handle.va_base
        assert va > 0

        # Only the integer VA is kept across iterations, so no reference to
        # the tensor survives the ``del`` below.
        if first_va is None:
            first_va = va
        else:
            assert va == first_va, (
                f"cycle {i}: expected reused VA {first_va:#x}, got {va:#x}"
            )

        del t
        gc.collect()
# ── TF6. del already-cleaned tensor is safe ──────────────────────────
def test_del_already_cleaned_tensor_no_crash():
    """TF6: ``del`` on a tensor already released by ``ctx.cleanup()`` is safe.

    The test passes if neither the ``del`` nor the following collection
    raises.
    """
    ctx, _ = _make_ctx()
    policy = DPPolicy(cube="replicate", pe="replicate")
    tensor = ctx.zeros((128, 128), dtype="f16", dp=policy, name="double_del")

    # Release everything the context owns up front.
    # NOTE(review): the original author's note says tensor._handle is None
    # from this point on; that is not asserted here - confirm against
    # RuntimeContext.cleanup().
    ctx.cleanup()

    # A second release via del must be a harmless no-op.
    del tensor
    gc.collect()