"""Tests for VA integration: Tensor, TLContext, and DMA commands use VA. Validates: T10. TensorHandle has va_base; TensorShard does NOT have va field T11. deploy_tensor allocates VA + creates mapping entries T12. Tensor.va returns the tensor's VA base T13. tl.load/tl.store generate DMA commands with VA (not PA) T14. Kernel VA-based offset calculation flows through DMA commands """ import pytest from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator from kernbench.policy.address.pe_mmu import PeMMU from kernbench.policy.address.va_allocator import VirtualAllocator from kernbench.policy.placement.dp import DPPolicy, ShardSpec, resolve_dp_policy from kernbench.runtime_api.tensor import ( TensorHandle, TensorShard, deploy_tensor, ) from kernbench.runtime_api.kernel import TensorArgShard from kernbench.common.pe_commands import DmaReadCmd, DmaWriteCmd from kernbench.triton_emu.tl_context import TLContext, run_kernel _MB = 1 << 20 _GB = 1 << 30 _CFG = AddressConfig( sip_count=2, cubes_per_sip=16, pes_per_cube=8, hbm_bytes_per_cube=48 * _GB, hbm_slices_per_cube=8, tcm_bytes_per_pe=16 * _MB, tcm_scheduler_reserved_bytes=4 * _MB, sram_bytes_per_cube=32 * _MB, ) def _make_allocators(num_pe: int = 8) -> dict[tuple[int, int, int], PEMemAllocator]: return { (0, 0, i): PEMemAllocator(sip_id=0, die_id=0, pe_id=i, cfg=_CFG) for i in range(num_pe) } def _make_mmus(num_pe: int = 8, page_size: int = 4096) -> dict[int, PeMMU]: return {i: PeMMU(page_size=page_size) for i in range(num_pe)} def _make_va_allocator() -> VirtualAllocator: return VirtualAllocator(va_base=0x1_0000_0000, va_size=64 * _GB, page_size=2 * _MB) # ── T10. 
# ── T10. TensorHandle has va_base ────────────────────────────────────


def test_tensor_handle_has_va_base():
    """TensorHandle must have a 'va_base' field."""
    th = TensorHandle(
        name="A",
        shape=(1024, 512),
        dtype="fp16",
        itemsize=2,
        shards=(),
        va_base=0x1_0000_0000,
    )
    assert th.va_base == 0x1_0000_0000


def test_tensor_handle_va_base_immutable():
    """TensorHandle.va_base is immutable (frozen dataclass)."""
    th = TensorHandle(
        name="A",
        shape=(1024, 512),
        dtype="fp16",
        itemsize=2,
        shards=(),
        va_base=0x1_0000_0000,
    )
    # Assigning to a field of a frozen dataclass raises FrozenInstanceError,
    # which is a subclass of AttributeError.
    with pytest.raises(AttributeError):
        th.va_base = 0x2_0000_0000  # type: ignore[misc]


def test_tensor_shard_no_va_field():
    """TensorShard should NOT have a va field.

    A shard's VA is derived from TensorHandle.va_base + shard.offset_bytes.
    """
    ts = TensorShard(sip=0, cube=0, pe=0, pa=0x1000, nbytes=4096, offset_bytes=0)
    assert not hasattr(ts, "va"), "TensorShard should not have a 'va' field"


# ── T11. deploy_tensor allocates VA + creates mappings ───────────────


def _deploy_column_wise_w(allocators, va_allocator):
    """Deploy the (1024, 512) fp16 tensor "W" shared by the T11 tests.

    The placement is resolved for a replicate/column_wise DPPolicy over
    8 PEs, 1 cube, target SIP 0.

    Args:
        allocators: Per-PE allocators, as built by ``_make_allocators``.
        va_allocator: Virtual allocator, as built by ``_make_va_allocator``.

    Returns:
        Whatever ``deploy_tensor`` returns for that placement.
    """
    placement = resolve_dp_policy(
        DPPolicy(cube="replicate", pe="column_wise"),
        shape=(1024, 512),
        itemsize=2,
        num_pe=8,
        num_cubes=1,
        target_sip=0,
    )
    return deploy_tensor(
        name="W",
        shape=(1024, 512),
        dtype="fp16",
        placement=placement,
        allocators=allocators,
        va_allocator=va_allocator,
    )


def test_deploy_tensor_assigns_va_base():
    """deploy_tensor with VA allocator assigns va_base to TensorHandle."""
    th = _deploy_column_wise_w(_make_allocators(), _make_va_allocator())
    assert th.va_base is not None
    assert th.va_base > 0


def test_deploy_tensor_va_covers_all_shards():
    """VA allocation covers the entire tensor; each shard is at va_base + offset."""
    th = _deploy_column_wise_w(_make_allocators(), _make_va_allocator())
    # Guard against a vacuous pass: without shards the loop checks nothing.
    assert th.shards, "deploy_tensor produced no shards"
    for s in th.shards:
        shard_va = th.va_base + s.offset_bytes
        assert shard_va > 0


def test_deploy_tensor_does_not_install_mmu_mappings():
    """deploy_tensor does NOT install MMU mappings — that's context's job."""
    mmus = _make_mmus()
    _deploy_column_wise_w(_make_allocators(), _make_va_allocator())
    # No MMU should have any entries (mappings come from fabric MmuMapMsg).
    # NOTE(review): `mmus` is never handed to deploy_tensor, so this only
    # checks that freshly built PeMMU objects stay empty — confirm whether
    # the MMUs should be wired into the deployment for a meaningful check.
    for mmu in mmus.values():
        assert mmu.num_entries == 0


# ── T12. Tensor.va property ──────────────────────────────────────────


def test_tensor_va_property():
    """Tensor.va returns the VA base of the entire tensor (from TensorHandle.va_base)."""
    from kernbench.runtime_api.tensor import Tensor

    allocs = _make_allocators(1)
    va_alloc = _make_va_allocator()
    # Single 4096-byte shard on PE 0: 2048 elements of a 2-byte dtype.
    placement = [ShardSpec(sip=0, cube=0, pe=0, offset_bytes=0, nbytes=4096)]
    # NOTE(review): dtype is "f16" here but "fp16" in the T10/T11 tests —
    # confirm both spellings are accepted.
    t = Tensor(shape=(2048,), dtype="f16", name="test")
    t._handle = deploy_tensor(
        name="test",
        shape=(2048,),
        dtype="f16",
        placement=placement,
        allocators=allocs,
        va_allocator=va_alloc,
    )
    assert t.va > 0
    assert t.va == t._handle.va_base
# ── T13. tl.load/tl.store use VA ─────────────────────────────────────


def test_tl_load_uses_va_in_dma_cmd():
    """tl.load(va_ptr) generates one DmaReadCmd whose src_addr is the VA."""
    tl = TLContext(dispatch_cycles=0)
    va_ptr = 0x1_0000_0000
    # The returned handle is not needed; only the recorded command matters.
    tl.load(va_ptr, shape=(32, 64), dtype="f16")
    dma_cmds = [c for c in tl.commands if isinstance(c, DmaReadCmd)]
    assert len(dma_cmds) == 1
    # The DMA command should carry the VA.
    assert dma_cmds[0].src_addr == va_ptr


def test_tl_store_uses_va_in_dma_cmd():
    """tl.store(va_ptr, handle) generates one DmaWriteCmd whose dst_addr is the VA."""
    tl = TLContext(dispatch_cycles=0)
    h = tl.load(0x1_0000_0000, shape=(16, 16), dtype="f32")
    va_out = 0x2_0000_0000
    tl.store(va_out, h)
    dma_cmds = [c for c in tl.commands if isinstance(c, DmaWriteCmd)]
    assert len(dma_cmds) == 1
    assert dma_cmds[0].dst_addr == va_out


# ── T14. Kernel VA offset calculation ────────────────────────────────


def test_kernel_va_offset_in_dma():
    """Kernel using base_va + pid * stride generates correct VA in DmaReadCmd."""

    def tiled_kernel(a_ptr, tl, BLOCK_SIZE=1024, DTYPE="f16"):
        pid = tl.program_id(0)
        elem_bytes = 2  # f16
        offset = pid * BLOCK_SIZE * elem_bytes
        tl.load(a_ptr + offset, shape=(BLOCK_SIZE,), dtype=DTYPE)

    va_base = 0x1_0000_0000
    pe_id = 3
    block_size = 1024  # tiled_kernel's default BLOCK_SIZE
    elem_bytes = 2  # f16
    tl = TLContext(pe_id=pe_id, num_programs=8, dispatch_cycles=0)
    run_kernel(tiled_kernel, tl, a_ptr=va_base)
    dma_cmds = [c for c in tl.commands if isinstance(c, DmaReadCmd)]
    assert len(dma_cmds) == 1
    # The expectation relies on program_id(0) yielding the context's pe_id (3).
    expected_va = va_base + pe_id * block_size * elem_bytes
    assert dma_cmds[0].src_addr == expected_va