"""VA offset verification: each PE accesses its own local HBM slice.

Verifies that column-wise sharding + VA offset calculation produces
DMA addresses that translate to the correct PE's local HBM — not a
remote PE.

Tests:
  VO1. Per-PE DMA addresses are correct VAs (2D)
  VO2. Each VA translates to the executing PE's own HBM slice (2D)
  VO3. End-to-end bench completes (2D, full TP)
  VO4. Per-PE DMA addresses are correct VAs (1D)
  VO5. Each VA translates to local HBM (1D)
  VO6. End-to-end 1D bench completes
"""
from pathlib import Path

import pytest

from kernbench.common.pe_commands import DmaReadCmd, DmaWriteCmd
from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator
from kernbench.policy.address.pe_mmu import PeMMU
from kernbench.policy.address.phyaddr import PhysAddr
from kernbench.policy.address.va_allocator import VirtualAllocator
from kernbench.policy.placement.dp import DPPolicy, column_wise
from kernbench.runtime_api.tensor import deploy_tensor
from kernbench.sim_engine.engine import GraphEngine
from kernbench.runtime_api.context import RuntimeContext
from kernbench.runtime_api.types import DeviceSelector
from kernbench.topology.builder import load_topology
from kernbench.triton_emu.tl_context import TLContext, run_kernel

# Topology description shared by the end-to-end tests (two directories up).
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"

_MB = 1 << 20
_GB = 1 << 30

# Default 2D problem size, element type and PE count used by the tests.
M, K = 128, 256
DTYPE = "f16"
NUM_PE = 8
ELEM_BYTES = 2


def _copy_kernel_2d(src_ptr, dst_ptr, M, K, tl, DTYPE="f16"):
    """Standard Triton 2D copy. M, K are cube-local.

    Each program copies one ``(M, K // num_programs)`` tile. The tile for
    program ``pid`` starts ``pid * M * cols_per_pe * elem_bytes`` bytes
    past both ``src_ptr`` and ``dst_ptr``.
    """
    pid = tl.program_id(0)
    num_pe = tl.num_programs(0)
    cols_per_pe = K // num_pe
    elem_bytes = 2  # hard-coded element size; not derived from DTYPE
    offset = pid * M * cols_per_pe * elem_bytes
    data = tl.load(src_ptr + offset, shape=(M, cols_per_pe), dtype=DTYPE)
    tl.store(dst_ptr + offset, data)


def _copy_kernel_1d(src_ptr, dst_ptr, N, tl, DTYPE="f16"):
    """Standard Triton 1D copy. N is cube-local.

    Each program copies ``N // num_programs`` elements. The chunk for
    program ``pid`` starts ``pid * elems_per_pe * elem_bytes`` bytes past
    both ``src_ptr`` and ``dst_ptr``.
    """
    pid = tl.program_id(0)
    num_pe = tl.num_programs(0)
    elems_per_pe = N // num_pe
    elem_bytes = 2  # hard-coded element size; not derived from DTYPE
    offset = pid * elems_per_pe * elem_bytes
    data = tl.load(src_ptr + offset, shape=(elems_per_pe,), dtype=DTYPE)
    tl.store(dst_ptr + offset, data)


def _make_standalone(shape, num_pe=NUM_PE):
    """Create standalone allocators + MMUs for unit testing.

    Args:
        shape: Currently unused by this helper; kept so existing call
            sites keep working.
        num_pe: Number of PEs; also used as the HBM slice count of the
            single cube described by the config.

    Returns:
        A ``(cfg, allocators, va_alloc, mmus)`` tuple where ``allocators``
        and ``mmus`` are dicts keyed by PE index ``0 .. num_pe - 1``.
    """
    cfg = AddressConfig(
        sip_count=1,
        cubes_per_sip=1,
        pes_per_cube=num_pe,
        hbm_bytes_per_cube=48 * _GB,
        hbm_slices_per_cube=num_pe,
        tcm_bytes_per_pe=16 * _MB,
        tcm_scheduler_reserved_bytes=4 * _MB,
        sram_bytes_per_cube=32 * _MB,
    )
    allocators = {
        i: PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=i, cfg=cfg)
        for i in range(num_pe)
    }
    va_alloc = VirtualAllocator(
        va_base=0x1_0000_0000, va_size=64 * _GB, page_size=4096
    )
    mmus = {i: PeMMU(page_size=4096) for i in range(num_pe)}
    return cfg, allocators, va_alloc, mmus


# ── VO1. 2D: Per-PE DMA addresses are correct VAs ────────────────────


def test_2d_each_pe_computes_correct_va_offset():
    """2D: each PE generates DMA at va_base + pid * block_bytes."""
    src_va = 0x1_0000_0000
    dst_va = 0x2_0000_0000
    cols_per_pe = K // NUM_PE
    block_bytes = M * cols_per_pe * ELEM_BYTES

    for pe_id in range(NUM_PE):
        tl = TLContext(pe_id=pe_id, num_programs=NUM_PE, dispatch_cycles=0)
        run_kernel(_copy_kernel_2d, tl, src_ptr=src_va, dst_ptr=dst_va, M=M, K=K)

        reads = [c for c in tl.commands if isinstance(c, DmaReadCmd)]
        writes = [c for c in tl.commands if isinstance(c, DmaWriteCmd)]

        # Fail with a readable message instead of an IndexError when the
        # kernel recorded no DMA command at all.
        assert reads, f"PE{pe_id} emitted no DmaReadCmd"
        assert writes, f"PE{pe_id} emitted no DmaWriteCmd"

        expected_offset = pe_id * block_bytes
        assert reads[0].src_addr == src_va + expected_offset, (
            f"PE{pe_id} read from {reads[0].src_addr:#x}, "
            f"expected {src_va + expected_offset:#x}"
        )
        assert writes[0].dst_addr == dst_va + expected_offset, (
            f"PE{pe_id} wrote to {writes[0].dst_addr:#x}, "
            f"expected {dst_va + expected_offset:#x}"
        )
# ── VO2. 2D: Each VA translates to local HBM ─────────────────────────


def test_2d_va_translates_to_local_hbm():
    """2D: the VA each PE issues must resolve to that PE's own HBM slice.

    Deploys a column-wise sharded (M, K) tensor, maps every shard into the
    MMU of the PE that owns it, then translates the per-PE block address
    and checks which HBM slice the resulting physical address falls in.
    """
    shape_2d = (M, K)
    cfg, allocators, va_alloc, mmus = _make_standalone(shape_2d)
    hbm_slice_bytes = cfg.hbm_slice_bytes
    bytes_per_block = M * (K // NUM_PE) * ELEM_BYTES

    # NOTE(review): dtype is spelled "fp16" here while the module constant
    # DTYPE is "f16" — confirm both spellings are accepted.
    handle = deploy_tensor(
        name="src",
        shape=shape_2d,
        dtype="fp16",
        placement=column_wise(
            shape=shape_2d, itemsize=ELEM_BYTES, num_pe=NUM_PE
        ),
        allocators=allocators,
        va_allocator=va_alloc,
    )

    # Install per-PE mappings (simulating what context does via MmuMapMsg)
    for shard in handle.shards:
        shard_va = handle.va_base + shard.offset_bytes
        mmus[shard.pe].map(va=shard_va, pa=shard.pa, size=shard.nbytes)

    for pe_id in range(NUM_PE):
        block_va = handle.va_base + pe_id * bytes_per_block
        phys = PhysAddr.decode(mmus[pe_id].translate(block_va))
        owner = PhysAddr.hbm_pe_id(phys.hbm_offset, hbm_slice_bytes)
        assert owner == pe_id, f"PE{pe_id} accessed PE{owner}'s HBM"


# ── VO3. 2D: End-to-end bench completes ──────────────────────────────


def test_2d_bench_completes():
    """2D: full TP bench with standard Triton kernel pattern.

    Passes when the bench's ``run`` and the final ``wait_all`` both return
    without raising; no result values are inspected.
    """
    topo = load_topology(TOPOLOGY_PATH)
    runtime = RuntimeContext(
        engine=GraphEngine(topo),
        target_device=DeviceSelector("sip:0"),
        correlation_id="vo3",
        spec=topo.spec,
    )

    # Imported only once the runtime context exists, as in the original.
    from benches.va_offset_verify import run as bench_run

    bench_run(runtime)
    runtime.wait_all()
# ── VO4. 1D: Per-PE DMA addresses ────────────────────────────────────

# Element count of the 1D tensors used by VO4–VO6.
N_1D = 1024


def test_1d_each_pe_computes_correct_offset():
    """1D: each PE generates DMA at correct offset."""
    src_va = 0x1_0000_0000
    dst_va = 0x2_0000_0000
    elems_per_pe = N_1D // NUM_PE
    block_bytes = elems_per_pe * ELEM_BYTES

    for pe_id in range(NUM_PE):
        tl = TLContext(pe_id=pe_id, num_programs=NUM_PE, dispatch_cycles=0)
        run_kernel(_copy_kernel_1d, tl, src_ptr=src_va, dst_ptr=dst_va, N=N_1D)

        reads = [c for c in tl.commands if isinstance(c, DmaReadCmd)]
        writes = [c for c in tl.commands if isinstance(c, DmaWriteCmd)]

        # Fail with a readable message instead of an IndexError when the
        # kernel recorded no DMA command at all.
        assert reads, f"1D PE{pe_id} emitted no DmaReadCmd"
        assert writes, f"1D PE{pe_id} emitted no DmaWriteCmd"

        expected_offset = pe_id * block_bytes
        assert reads[0].src_addr == src_va + expected_offset, (
            f"1D PE{pe_id} read from {reads[0].src_addr:#x}, "
            f"expected {src_va + expected_offset:#x}"
        )
        assert writes[0].dst_addr == dst_va + expected_offset, (
            f"1D PE{pe_id} wrote to {writes[0].dst_addr:#x}, "
            f"expected {dst_va + expected_offset:#x}"
        )


# ── VO5. 1D: VA translates to local HBM ──────────────────────────────


def test_1d_va_translates_to_local_hbm():
    """1D: each PE's DMA VA translates to its own HBM slice.

    Mirrors the 2D translation test: deploy a column-wise sharded tensor,
    map each shard into its owning PE's MMU, then check that the per-PE
    block address decodes to that PE's HBM slice.
    """
    cfg, allocators, va_alloc, mmus = _make_standalone((1, N_1D))
    slice_size = cfg.hbm_slice_bytes
    elems_per_pe = N_1D // NUM_PE
    block_bytes = elems_per_pe * ELEM_BYTES

    # NOTE(review): placement is computed for shape (1, N_1D) while the
    # tensor is deployed with shape (N_1D,) — presumably column_wise needs
    # a 2D shape; confirm this mismatch is intended.
    placement = column_wise(shape=(1, N_1D), itemsize=ELEM_BYTES, num_pe=NUM_PE)
    handle = deploy_tensor(
        name="src_1d",
        shape=(N_1D,),
        dtype="fp16",
        placement=placement,
        allocators=allocators,
        va_allocator=va_alloc,
    )

    # Install each shard's VA → PA mapping in the MMU of the owning PE.
    for s in handle.shards:
        mmus[s.pe].map(va=handle.va_base + s.offset_bytes, pa=s.pa, size=s.nbytes)

    for pe_id in range(NUM_PE):
        va = handle.va_base + pe_id * block_bytes
        pa = mmus[pe_id].translate(va)
        decoded = PhysAddr.decode(pa)
        hbm_pe = PhysAddr.hbm_pe_id(decoded.hbm_offset, slice_size)
        assert hbm_pe == pe_id, f"1D PE{pe_id} accessed PE{hbm_pe}'s HBM"
# ── VO6. 1D: End-to-end ──────────────────────────────────────────────


def test_1d_e2e_completes():
    """1D: full engine run with column_wise TP sharding.

    Passes when tensor creation, the kernel launch and the final
    ``wait_all`` all return without raising; no data is checked.
    """
    topo = load_topology(TOPOLOGY_PATH)
    runtime = RuntimeContext(
        engine=GraphEngine(topo),
        target_device=DeviceSelector("sip:0"),
        correlation_id="vo6",
        spec=topo.spec,
    )

    # Same column-wise policy at every level: sip, cube and pe.
    policy = DPPolicy(sip="column_wise", cube="column_wise", pe="column_wise")
    shape_1d = (N_1D,)
    source = runtime.zeros(shape_1d, dtype=DTYPE, dp=policy, name="src_1d")
    target = runtime.empty(shape_1d, dtype=DTYPE, dp=policy, name="dst_1d")

    # launch() auto-localizes N_1D → cube-local N
    runtime.launch("va_1d_copy", _copy_kernel_1d, source, target, N_1D)
    runtime.wait_all()