"""VA offset verification benchmark.

Verifies that Triton-style ``base_ptr + pid * stride`` addressing works
correctly with full TP sharding (sip/cube/pe all column_wise). Each PE loads
its own block from a sharded tensor and stores it back.

The kernel uses standard Triton patterns:

- ``tl.program_id(0)`` for the PE index within a cube
- ``tl.num_programs(0)`` for the PE count within a cube
- Shape args are automatically localized by ``launch()``
"""

import re

from kernbench.policy.placement.dp import DPPolicy

M, K = 128, 256
DTYPE = "f16"


def _elem_bytes(dtype):
    """Return the size in bytes of one element of ``dtype``.

    The bit width is read from the first run of digits in the dtype name,
    e.g. ``"f16"`` -> 16 bits -> 2 bytes.
    NOTE(review): assumes the runtime's dtype strings follow the
    ``<kind><bits>`` pattern seen in ``"f16"`` -- confirm for other dtypes.

    Raises:
        ValueError: if no positive, byte-aligned bit width can be derived.
    """
    match = re.search(r"\d+", dtype)
    bits = int(match.group()) if match else 0
    if bits <= 0 or bits % 8:
        raise ValueError(f"cannot derive element size from dtype {dtype!r}")
    return bits // 8


def _copy_kernel(src_ptr, dst_ptr, M, K, tl, DTYPE="f16"):
    """Standard Triton copy kernel. M and K are cube-local (set by launch).

    Each PE copies an ``(M, K // num_pe)`` block from ``src_ptr`` to
    ``dst_ptr`` at byte offset ``pid * M * cols_per_pe * elem_bytes``.

    Raises:
        ValueError: if ``K`` does not split evenly across the PEs, which
            would otherwise silently leave trailing columns uncopied, or if
            ``DTYPE`` has no derivable element size.
    """
    pid = tl.program_id(0)
    num_pe = tl.num_programs(0)
    if K % num_pe:
        raise ValueError(
            f"K={K} is not divisible by the number of PEs ({num_pe})"
        )
    cols_per_pe = K // num_pe
    # Derive the width from DTYPE instead of hard-coding 2 bytes (f16), so
    # the per-PE offset stays correct for other element types.
    elem_bytes = _elem_bytes(DTYPE)
    offset = pid * M * cols_per_pe * elem_bytes
    data = tl.load(src_ptr + offset, shape=(M, cols_per_pe), dtype=DTYPE)
    tl.store(dst_ptr + offset, data)


def run(torch):
    """Run the VA offset verification benchmark with full TP sharding.

    Args:
        torch: the benchmark runtime handle supplied by the harness. It
            provides ``zeros``/``empty``/``launch`` and a ``_traces`` list.
            It is presumably not PyTorch itself, since it accepts ``dp=`` and
            ``name=`` keywords.

    Raises:
        AssertionError: if no kernel trace was recorded, or if any kernel
            trace reports a non-positive ``total_ns``.
    """
    dp = DPPolicy(sip="column_wise", cube="column_wise", pe="column_wise")

    src = torch.zeros((M, K), dtype=DTYPE, dp=dp, name="src")
    dst = torch.empty((M, K), dtype=DTYPE, dp=dp, name="dst")

    # launch() automatically converts M, K to cube-local values
    torch.launch("va_offset_copy", _copy_kernel, src, dst, M, K)

    # Sanity check: kernel completed with non-zero latency. Explicit raises
    # (rather than bare asserts) keep the check active under `python -O`.
    kernel_traces = [t for t in torch._traces if t["phase"] == "kernel"]
    if not kernel_traces:
        raise AssertionError("No kernel traces recorded")
    for kt in kernel_traces:
        if not kt["total_ns"] > 0:
            raise AssertionError(f"Kernel latency is zero for {kt}")