"""Single-PE composite GEMM for PE_accelerator perf characterization.

Three operand-staging variants are selectable via MATMUL_VARIANT:

  - "ref_ref" (default): a = tl.ref,  b = tl.ref
        Both operands HBM-resident; scheduler streams per-tile DMA.
  - "load_ref":          a = tl.load, b = tl.ref
        A eagerly DMA'd into TCM up-front; B streamed per-tile.
  - "load_load":         a = tl.load, b = tl.load
        Both eagerly DMA'd into TCM up-front.

Other env vars: MATMUL_M, MATMUL_K, MATMUL_N, MATMUL_DTYPE. MATMUL_DTYPE
sets the element type both of the allocated tensors and of the kernel's
operand views.

Run:
    MATMUL_M=256 MATMUL_K=256 MATMUL_N=256 MATMUL_VARIANT=load_ref \\
        kernbench run --topology topology.yaml --bench matmul_composite
"""
import os

from kernbench.policy.placement.dp import DPPolicy

# Problem size, element type and staging variant, read once at import time.
M = int(os.environ.get("MATMUL_M", "256"))
K = int(os.environ.get("MATMUL_K", "256"))
N = int(os.environ.get("MATMUL_N", "256"))
DTYPE = os.environ.get("MATMUL_DTYPE", "f16")
VARIANT = os.environ.get("MATMUL_VARIANT", "ref_ref")

# NOTE: ``run`` does not pass DTYPE to ``torch.launch``, so each kernel's
# DTYPE default is bound to the module-level MATMUL_DTYPE value. A hard-coded
# "f16" would make the kernel's operand views disagree with the tensors that
# ``run`` allocates whenever MATMUL_DTYPE is overridden.


def _kernel_ref_ref(a_ptr, b_ptr, out_ptr, M, K, N, tl, DTYPE=DTYPE):
    """Composite GEMM with both A and B passed by reference (``tl.ref``).

    Args:
        a_ptr, b_ptr, out_ptr: pointer-like handles (converted with ``int``)
            for A (M x K), B (K x N) and the output.
        M, K, N: GEMM dimensions (converted with ``int``).
        tl: kernel-language handle supplied by the launcher.
        DTYPE: element type of the A/B views; defaults to MATMUL_DTYPE.
    """
    M, K, N = int(M), int(K), int(N)
    a = tl.ref(int(a_ptr), shape=(M, K), dtype=DTYPE)
    b = tl.ref(int(b_ptr), shape=(K, N), dtype=DTYPE)
    h = tl.composite(op="gemm", a=a, b=b, out_ptr=int(out_ptr))
    # Block on the handle returned by the composite op before returning.
    tl.wait(h)


def _kernel_load_ref(a_ptr, b_ptr, out_ptr, M, K, N, tl, DTYPE=DTYPE):
    """Composite GEMM with A loaded up-front (``tl.load``) and B by reference.

    Arguments are the same as for ``_kernel_ref_ref``.
    """
    M, K, N = int(M), int(K), int(N)
    a = tl.load(int(a_ptr), shape=(M, K), dtype=DTYPE)
    b = tl.ref(int(b_ptr), shape=(K, N), dtype=DTYPE)
    h = tl.composite(op="gemm", a=a, b=b, out_ptr=int(out_ptr))
    tl.wait(h)


def _kernel_load_load(a_ptr, b_ptr, out_ptr, M, K, N, tl, DTYPE=DTYPE):
    """Composite GEMM with both A and B loaded up-front (``tl.load``).

    Arguments are the same as for ``_kernel_ref_ref``.
    """
    M, K, N = int(M), int(K), int(N)
    a = tl.load(int(a_ptr), shape=(M, K), dtype=DTYPE)
    b = tl.load(int(b_ptr), shape=(K, N), dtype=DTYPE)
    h = tl.composite(op="gemm", a=a, b=b, out_ptr=int(out_ptr))
    tl.wait(h)


# MATMUL_VARIANT value -> kernel implementing that operand-staging strategy.
_KERNELS = {
    "ref_ref": _kernel_ref_ref,
    "load_ref": _kernel_load_ref,
    "load_load": _kernel_load_load,
}


def run(torch):
    """Allocate A, B and the output on one PE and launch the selected kernel.

    Args:
        torch: the benchmark's torch-like handle; must provide ``empty`` and
            ``launch``.

    Raises:
        ValueError: if MATMUL_VARIANT does not name a key of ``_KERNELS``.
    """
    # Validate before allocating anything.
    if VARIANT not in _KERNELS:
        raise ValueError(f"unknown MATMUL_VARIANT={VARIANT!r}; "
                         f"expected one of {list(_KERNELS)}")
    kernel_fn = _KERNELS[VARIANT]

    # One cube and one PE, with every tensor replicated: a single-PE placement.
    dp = DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=1)
    a = torch.empty((M, K), dtype=DTYPE, dp=dp, name="a")
    b = torch.empty((K, N), dtype=DTYPE, dp=dp, name="b")
    out = torch.empty((M, N), dtype=DTYPE, dp=dp, name="out")
    torch.launch(f"matmul_composite_{VARIANT}", kernel_fn, a, b, out, M, K, N)