Composite GEMM: K-loop accumulator residency, pinned operands, sweep + deck

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-13 15:00:41 -07:00
parent 5accd98171
commit 83ea97b05f
11 changed files with 4219 additions and 51 deletions
+69
View File
@@ -0,0 +1,69 @@
"""Single-PE composite GEMM for PE_accelerator perf characterization.
Three operand-staging variants are selectable via MATMUL_VARIANT:
- "ref_ref" (default): a = tl.ref, b = tl.ref
Both operands HBM-resident; scheduler streams per-tile DMA.
- "load_ref": a = tl.load, b = tl.ref
A eagerly DMA'd into TCM up-front; B streamed per-tile.
- "load_load": a = tl.load, b = tl.load
Both eagerly DMA'd into TCM up-front.
Other env vars: MATMUL_M, MATMUL_K, MATMUL_N, MATMUL_DTYPE.
Run:
MATMUL_M=256 MATMUL_K=256 MATMUL_N=256 MATMUL_VARIANT=load_ref \
kernbench run --topology topology.yaml --bench matmul_composite
"""
import os
from kernbench.policy.placement.dp import DPPolicy
M = int(os.environ.get("MATMUL_M", "256"))
K = int(os.environ.get("MATMUL_K", "256"))
N = int(os.environ.get("MATMUL_N", "256"))
DTYPE = os.environ.get("MATMUL_DTYPE", "f16")
VARIANT = os.environ.get("MATMUL_VARIANT", "ref_ref")
def _kernel_ref_ref(a_ptr, b_ptr, out_ptr, M, K, N, tl, DTYPE="f16"):
M, K, N = int(M), int(K), int(N)
a = tl.ref(int(a_ptr), shape=(M, K), dtype=DTYPE)
b = tl.ref(int(b_ptr), shape=(K, N), dtype=DTYPE)
h = tl.composite(op="gemm", a=a, b=b, out_ptr=int(out_ptr))
tl.wait(h)
def _kernel_load_ref(a_ptr, b_ptr, out_ptr, M, K, N, tl, DTYPE="f16"):
M, K, N = int(M), int(K), int(N)
a = tl.load(int(a_ptr), shape=(M, K), dtype=DTYPE)
b = tl.ref(int(b_ptr), shape=(K, N), dtype=DTYPE)
h = tl.composite(op="gemm", a=a, b=b, out_ptr=int(out_ptr))
tl.wait(h)
def _kernel_load_load(a_ptr, b_ptr, out_ptr, M, K, N, tl, DTYPE="f16"):
M, K, N = int(M), int(K), int(N)
a = tl.load(int(a_ptr), shape=(M, K), dtype=DTYPE)
b = tl.load(int(b_ptr), shape=(K, N), dtype=DTYPE)
h = tl.composite(op="gemm", a=a, b=b, out_ptr=int(out_ptr))
tl.wait(h)
_KERNELS = {
"ref_ref": _kernel_ref_ref,
"load_ref": _kernel_load_ref,
"load_load": _kernel_load_load,
}
def run(torch):
if VARIANT not in _KERNELS:
raise ValueError(f"unknown MATMUL_VARIANT={VARIANT!r}; "
f"expected one of {list(_KERNELS)}")
kernel_fn = _KERNELS[VARIANT]
dp = DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=1)
a = torch.empty((M, K), dtype=DTYPE, dp=dp, name="a")
b = torch.empty((K, N), dtype=DTYPE, dp=dp, name="b")
out = torch.empty((M, N), dtype=DTYPE, dp=dp, name="out")
torch.launch(f"matmul_composite_{VARIANT}", kernel_fn, a, b, out, M, K, N)