"""QKV GEMM benchmark: Q*K^T projection on a single PE. Demonstrates the full host-to-PE kernel launch pipeline: Host → PCIE_EP → IO_CPU → M_CPU → NOC → PE_CPU → PE_SCHEDULER → engines Kernel: tl.load(a) + tl.ref(b) + tl.composite(gemm) + tl.wait() - Tensor a is loaded into TCM via DMA - Tensor b stays in HBM; PE_SCHEDULER streams it per-tile (32x64x32) """ from kernbench.policy.placement.dp import DPPolicy # GEMM dimensions: (M, K) x (K, N) → (M, N) M, K, N = 128, 256, 128 DTYPE = "f16" def _gemm_kernel(a_ptr, b_ptr, out_ptr, M, K, N, tl, DTYPE="f16"): """QKV GEMM kernel: out = a @ b. a is loaded into TCM (DMA_READ). b is referenced in HBM (tl.ref, no DMA — scheduler streams per-tile). """ a = tl.load(a_ptr, shape=(M, K), dtype=DTYPE) b = tl.ref(b_ptr, shape=(K, N), dtype=DTYPE) handle = tl.composite(op="gemm", a=a, b=b, out_ptr=out_ptr) tl.wait(handle) def run(ctx): """Run the QKV GEMM benchmark.""" # DP placement: a=replicate (cube-level), b/out=column_wise (N-axis, single PE) a = ctx.zeros((M, K), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="replicate"), name="a") b = ctx.zeros((K, N), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="column_wise"), name="b") out = ctx.empty( (M, N), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="column_wise"), name="out", ) # Launch GEMM kernel ctx.launch("qkv_gemm", _gemm_kernel, a, b, out, M, K, N)