Files
kernbench2/benches/qkv_gemm_multi_pe.py
T
2026-03-18 11:47:48 -07:00

40 lines
1.5 KiB
Python

"""QKV GEMM benchmark: Q*K^T projection on all PEs in a cube (multi-PE).
Column-parallel GEMM: a is replicated on every cube and every PE; b/out are column-sharded (split along N) across the PEs.
M_CPU fans out KernelLaunchMsg to all 8 PE_CPUs (ADR-0009 D3).
Kernel: tl.load(a) + tl.ref(b) + tl.composite(gemm) + tl.wait()
- Tensor a is loaded into TCM via DMA
- Tensor b stays in HBM; PE_SCHEDULER streams it per-tile (32x64x32)
"""
from kernbench.policy.placement.dp import DPPolicy
# GEMM problem size: (M, K) @ (K, N) -> (M, N).
M = 128  # rows of a and of out
K = 256  # shared (reduction) dimension of a and b
N = 128  # columns of b and of out; split across PEs by the column_wise placement
# Element type used when allocating every operand.
DTYPE = "f16"
def _gemm_kernel(a_ptr, b_ptr, out_ptr, M, K, N, tl, DTYPE=DTYPE):
    """QKV GEMM kernel: out = a @ b.

    a is loaded into TCM (DMA_READ).
    b is referenced in HBM (tl.ref, no DMA -- scheduler streams per-tile).

    Args:
        a_ptr: Handle to the left operand ``a``, declared with shape (M, K).
        b_ptr: Handle to the right operand ``b``, declared with shape (K, N).
        out_ptr: Handle to the output buffer, passed straight to the
            composite GEMM op.
        M, K, N: GEMM dimensions used to declare the operand shapes.
        tl: Kernel-language object providing ``load``/``ref``/``composite``/
            ``wait``; presumably injected by ``ctx.launch`` (not visible here).
        DTYPE: Element type for ``a`` and ``b``. Defaults to the module-level
            ``DTYPE`` so the kernel stays in sync with the dtype ``run`` uses
            to allocate the tensors.
    """
    # The default above is evaluated once, at definition time, from the
    # module constant. run() does not forward DTYPE through ctx.launch, so a
    # hard-coded literal here could silently diverge from the allocation dtype.
    a = tl.load(a_ptr, shape=(M, K), dtype=DTYPE)
    # NOTE(review): run() places b/out column_wise across PEs, yet b is
    # declared with the full (K, N) shape -- confirm that ctx.launch / tl.ref
    # map this to the per-PE shard.
    b = tl.ref(b_ptr, shape=(K, N), dtype=DTYPE)
    handle = tl.composite(op="gemm", a=a, b=b, out_ptr=out_ptr)
    # Block until the composite GEMM issued above has completed.
    tl.wait(handle)
def run(ctx):
    """Allocate the GEMM operands and launch the multi-PE QKV GEMM kernel.

    Placement: every operand is replicated across cubes. ``a`` is also
    replicated on each PE, while ``b`` and ``out`` are split column_wise
    (along N) across the PEs.

    Args:
        ctx: Benchmark context providing ``zeros``, ``empty`` and ``launch``.
    """

    def _placement(pe_mode):
        # Only the per-PE distribution differs between operands; a fresh
        # policy object is built for each tensor.
        return DPPolicy(cube="replicate", pe=pe_mode)

    lhs = ctx.zeros((M, K), dtype=DTYPE, dp=_placement("replicate"), name="a")
    rhs = ctx.zeros((K, N), dtype=DTYPE, dp=_placement("column_wise"), name="b")
    result = ctx.empty((M, N), dtype=DTYPE, dp=_placement("column_wise"), name="out")

    # Fan the kernel out to all PEs. DTYPE is not passed here, so the kernel
    # relies on its own DTYPE default matching the allocation dtype above.
    ctx.launch("qkv_gemm_multi", _gemm_kernel, lhs, rhs, result, M, K, N)