08812eda58
Implement VA/MMU layer (ADR-0011 Phase 1) enabling Triton kernels to use contiguous virtual addresses on sharded tensors. Key changes: - PE_MMU component: hybrid inbox (MmuMapMsg) + sync translate() for PE_DMA - VirtualAllocator + PEMemAllocator: free-list with coalescing - MmuMapMsg/MmuUnmapMsg fabric path with SIP-level routing - DPPolicy-based mapping: replicate=local, sharded=broadcast - Tensor lifecycle: del + weakref cleanup, context manager - Rename: TensorHandle.pa→addr, DmaReadCmd.src_pa→src_addr, ctx→torch Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
194 lines
6.0 KiB
Python
194 lines
6.0 KiB
Python
"""Tests for tensor free: del-based + context manager cleanup.
|
|
|
|
Validates:
|
|
TF1. PEMemAllocator.free_hbm/free_tcm reclaims space
|
|
TF2. del tensor triggers cleanup (VA/PA returned, MMU unmapped)
|
|
TF3. Context manager cleans up all tensors on exit
|
|
TF4. del after context exit is safe (no crash)
|
|
TF5. Alloc-del-alloc cycle reuses VA and PA
|
|
TF6. del already-freed tensor is safe (no crash)
|
|
"""
|
|
import gc
|
|
|
|
import pytest
|
|
from pathlib import Path
|
|
|
|
from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator
|
|
from kernbench.policy.address.pe_mmu import PageFault
|
|
from kernbench.policy.placement.dp import DPPolicy
|
|
from kernbench.sim_engine.engine import GraphEngine
|
|
from kernbench.runtime_api.context import RuntimeContext
|
|
from kernbench.runtime_api.types import DeviceSelector
|
|
from kernbench.topology.builder import load_topology
|
|
|
|
# Topology description shared by the context-based tests; it lives in the
# parent of the directory that contains this test file.
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"

# Binary size units used to spell out the capacities below.
_MB = 2**20
_GB = 1024 * _MB

# Address-space configuration handed to the standalone PEMemAllocator tests.
_CFG = AddressConfig(
    # Hierarchy fan-out.
    sip_count=2,
    cubes_per_sip=16,
    pes_per_cube=8,
    # HBM capacity per cube and how many slices it is divided into.
    hbm_bytes_per_cube=48 * _GB,
    hbm_slices_per_cube=8,
    # Per-PE TCM capacity and the portion reserved for the scheduler.
    tcm_bytes_per_pe=16 * _MB,
    tcm_scheduler_reserved_bytes=4 * _MB,
    # Per-cube SRAM capacity.
    sram_bytes_per_cube=32 * _MB,
)
|
|
|
|
|
|
def _make_ctx():
    """Build a fresh engine and runtime context for a single test.

    Loads the topology at ``TOPOLOGY_PATH``, wraps it in a ``GraphEngine``
    and creates a ``RuntimeContext`` targeting ``sip:0`` with the
    correlation id ``"test_free"``.

    Returns:
        A ``(context, engine)`` tuple.
    """
    topo = load_topology(TOPOLOGY_PATH)
    sim = GraphEngine(topo)
    runtime = RuntimeContext(
        engine=sim,
        target_device=DeviceSelector("sip:0"),
        correlation_id="test_free",
        spec=topo.spec,
    )
    return runtime, sim
|
|
|
|
|
|
# ── TF1. PEMemAllocator free ─────────────────────────────────────────
|
|
|
|
|
|
def test_allocator_free_hbm_reclaims_space():
    """TF1: free_hbm lowers ``hbm_used`` and a later alloc still succeeds."""
    size = 4096
    allocator = PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=0, cfg=_CFG)

    first_pa = allocator.alloc_hbm(size)
    used_before_free = allocator.hbm_used

    # Freeing the block must drop the usage counter by exactly its size.
    allocator.free_hbm(first_pa, size)
    assert allocator.hbm_used == used_before_free - size

    # A same-sized request after the free must still yield an address.
    second_pa = allocator.alloc_hbm(size)
    assert second_pa is not None
|
|
|
|
|
|
def test_allocator_free_tcm_reclaims_space():
    """TF1: free_tcm lowers ``tcm_used`` by the size of the freed block."""
    size = 256
    allocator = PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=0, cfg=_CFG)

    block_pa = allocator.alloc_tcm(size)
    used_before_free = allocator.tcm_used

    allocator.free_tcm(block_pa, size)
    assert allocator.tcm_used == used_before_free - size
|
|
|
|
|
|
# ── TF2. del tensor triggers cleanup ─────────────────────────────────
|
|
|
|
|
|
def test_del_tensor_unmaps_mmu():
    """TF2: dropping the only reference to a tensor removes its MMU mapping."""
    ctx, engine = _make_ctx()
    policy = DPPolicy(cube="replicate", pe="replicate")
    tensor = ctx.zeros((128, 128), dtype="f16", dp=policy, name="del_test")
    base_va = tensor._handle.va_base

    # While the tensor is alive, its base VA must translate on sip0/cube0/pe0.
    pe_mmu = engine._components["sip0.cube0.pe0.pe_mmu"]
    assert pe_mmu.mmu.translate(base_va) is not None

    # Drop the tensor and force a collection so its cleanup has run.
    del tensor
    gc.collect()

    # The same VA must now fault.
    with pytest.raises(PageFault):
        pe_mmu.mmu.translate(base_va)
|
|
|
|
|
|
def test_del_tensor_reclaims_va():
    """TF2: a deleted tensor's VA is handed to the next same-sized tensor."""
    ctx, _engine = _make_ctx()
    policy = DPPolicy(cube="replicate", pe="replicate")
    shape = (128, 128)

    first = ctx.zeros(shape, dtype="f16", dp=policy, name="va_reuse1")
    freed_va = first._handle.va_base

    del first
    gc.collect()

    # An identical allocation after the free must land on the same base VA.
    second = ctx.zeros(shape, dtype="f16", dp=policy, name="va_reuse2")
    reused_va = second._handle.va_base
    assert reused_va == freed_va
|
|
|
|
|
|
# ── TF3. Context manager cleanup ─────────────────────────────────────
|
|
|
|
|
|
def test_context_manager_cleans_all():
    """TF3: leaving the ``with`` block unmaps tensors created inside it."""
    topo = load_topology(TOPOLOGY_PATH)
    sim = GraphEngine(topo)

    with RuntimeContext(
        engine=sim,
        target_device=DeviceSelector("sip:0"),
        correlation_id="ctx_mgr",
        spec=topo.spec,
    ) as ctx:
        policy = DPPolicy(cube="replicate", pe="replicate")
        tensor = ctx.zeros((128, 128), dtype="f16", dp=policy, name="ctx_mgr_test")
        base_va = tensor._handle.va_base

    # The tensor object is still referenced here, so a fault below is due to
    # the context exit rather than to garbage collection of the tensor.
    pe_mmu = sim._components["sip0.cube0.pe0.pe_mmu"]
    with pytest.raises(PageFault):
        pe_mmu.mmu.translate(base_va)
|
|
|
|
|
|
# ── TF4. del after context exit is safe ───────────────────────────────
|
|
|
|
|
|
def test_del_after_context_exit_no_crash():
    """TF4: deleting a tensor after ``ctx.cleanup()`` raises nothing."""
    topo = load_topology(TOPOLOGY_PATH)
    sim = GraphEngine(topo)
    ctx = RuntimeContext(
        engine=sim,
        target_device=DeviceSelector("sip:0"),
        correlation_id="safe_del",
        spec=topo.spec,
    )
    policy = DPPolicy(cube="replicate", pe="replicate")
    tensor = ctx.zeros((128, 128), dtype="f16", dp=policy, name="safe_del_test")

    # Tear the context down first, as if it had already gone away.
    ctx.cleanup()

    # Dropping the tensor afterwards must be harmless; the test passes as
    # long as neither statement raises.
    del tensor
    gc.collect()
|
|
|
|
|
|
# ── TF5. Alloc-del-alloc cycle ────────────────────────────────────────
|
|
|
|
|
|
def test_alloc_del_cycle():
    """TF5: repeated alloc/del cycles succeed and keep reusing the same VA.

    Each iteration allocates a tensor of identical shape and dtype, checks
    that it received a handle with a positive base VA, then deletes it and
    forces a collection. Because the previous tensor has been freed before
    the next one is created, every cycle is expected to get the VA the first
    cycle got; a differing VA means the freed range was not returned to the
    allocator.
    """
    ctx, _engine = _make_ctx()
    dp = DPPolicy(cube="replicate", pe="replicate")

    first_va = None
    for i in range(3):
        t = ctx.zeros((128, 128), dtype="f16", dp=dp, name=f"cycle_{i}")
        assert t._handle is not None
        va = t._handle.va_base
        assert va > 0

        if first_va is None:
            first_va = va
        else:
            # Without this check the loop would pass even if every cycle
            # leaked its VA range.
            assert va == first_va, (
                f"cycle {i} got VA {va:#x}, expected reuse of {first_va:#x}"
            )

        del t
        gc.collect()
|
|
|
|
|
|
# ── TF6. del already-cleaned tensor is safe ──────────────────────────
|
|
|
|
|
|
def test_del_already_cleaned_tensor_no_crash():
    """TF6: deleting a tensor whose handle is already None does not crash.

    ``ctx.cleanup()`` is called while the tensor is still referenced; the
    test then verifies the premise (the handle has been cleared) before
    dropping the tensor, so it exercises the already-freed path rather than
    silently duplicating the TF4 scenario.
    """
    ctx, _engine = _make_ctx()
    dp = DPPolicy(cube="replicate", pe="replicate")
    t = ctx.zeros((128, 128), dtype="f16", dp=dp, name="double_del")

    ctx.cleanup()  # clears all tensors

    # Check the precondition this test is named after instead of assuming it.
    assert t._handle is None

    del t  # should not crash
    gc.collect()
|