83ea97b05f
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
70 lines
2.5 KiB
Python
70 lines
2.5 KiB
Python
"""Single-PE composite GEMM for PE_accelerator perf characterization.
|
|
|
|
Three operand-staging variants are selectable via MATMUL_VARIANT:
|
|
|
|
- "ref_ref" (default): a = tl.ref, b = tl.ref
|
|
Both operands HBM-resident; scheduler streams per-tile DMA.
|
|
- "load_ref": a = tl.load, b = tl.ref
|
|
A eagerly DMA'd into TCM up-front; B streamed per-tile.
|
|
- "load_load": a = tl.load, b = tl.load
|
|
Both eagerly DMA'd into TCM up-front.
|
|
|
|
Other env vars: MATMUL_M, MATMUL_K, MATMUL_N (GEMM dimensions, default 256 each) and MATMUL_DTYPE (element dtype, default f16).
|
|
|
|
Run:
|
|
MATMUL_M=256 MATMUL_K=256 MATMUL_N=256 MATMUL_VARIANT=load_ref \
|
|
kernbench run --topology topology.yaml --bench matmul_composite
|
|
"""
|
|
import os
|
|
|
|
from kernbench.policy.placement.dp import DPPolicy
|
|
|
|
# Benchmark configuration, read once from the environment at import time.
# A is (M, K), B is (K, N) and the output is (M, N); every dimension
# defaults to 256 when its variable is unset.
M, K, N = (
    int(os.getenv(var, "256"))
    for var in ("MATMUL_M", "MATMUL_K", "MATMUL_N")
)
# Element dtype string used for all three allocations in run(); default "f16".
DTYPE = os.getenv("MATMUL_DTYPE", "f16")
# Key into _KERNELS selecting the operand-staging variant; it is validated in
# run(), not here.
VARIANT = os.getenv("MATMUL_VARIANT", "ref_ref")
|
|
|
|
|
|
def _kernel_ref_ref(a_ptr, b_ptr, out_ptr, M, K, N, tl, DTYPE="f16"):
|
|
M, K, N = int(M), int(K), int(N)
|
|
a = tl.ref(int(a_ptr), shape=(M, K), dtype=DTYPE)
|
|
b = tl.ref(int(b_ptr), shape=(K, N), dtype=DTYPE)
|
|
h = tl.composite(op="gemm", a=a, b=b, out_ptr=int(out_ptr))
|
|
tl.wait(h)
|
|
|
|
|
|
def _kernel_load_ref(a_ptr, b_ptr, out_ptr, M, K, N, tl, DTYPE="f16"):
|
|
M, K, N = int(M), int(K), int(N)
|
|
a = tl.load(int(a_ptr), shape=(M, K), dtype=DTYPE)
|
|
b = tl.ref(int(b_ptr), shape=(K, N), dtype=DTYPE)
|
|
h = tl.composite(op="gemm", a=a, b=b, out_ptr=int(out_ptr))
|
|
tl.wait(h)
|
|
|
|
|
|
def _kernel_load_load(a_ptr, b_ptr, out_ptr, M, K, N, tl, DTYPE="f16"):
|
|
M, K, N = int(M), int(K), int(N)
|
|
a = tl.load(int(a_ptr), shape=(M, K), dtype=DTYPE)
|
|
b = tl.load(int(b_ptr), shape=(K, N), dtype=DTYPE)
|
|
h = tl.composite(op="gemm", a=a, b=b, out_ptr=int(out_ptr))
|
|
tl.wait(h)
|
|
|
|
|
|
# Dispatch table: MATMUL_VARIANT value -> kernel function launched by run().
# Insertion order is what run() lists in its "unknown variant" error message.
_KERNELS = dict(
    ref_ref=_kernel_ref_ref,
    load_ref=_kernel_load_ref,
    load_load=_kernel_load_load,
)
|
|
|
|
|
|
def run(torch):
    """Allocate A, B and the output on one PE and launch the chosen kernel.

    The kernel is selected by the module-level ``VARIANT``
    (``MATMUL_VARIANT``). A is ``(M, K)``, B is ``(K, N)`` and the output is
    ``(M, N)``, all allocated with element type ``DTYPE``.

    Args:
        torch: Allocation/launch object supplied by the benchmark harness;
            only its ``empty`` and ``launch`` methods are used here.

    Raises:
        ValueError: If ``VARIANT`` is not a key of ``_KERNELS``.
    """
    kernel_fn = _KERNELS.get(VARIANT)
    if kernel_fn is None:
        raise ValueError(f"unknown MATMUL_VARIANT={VARIANT!r}; "
                         f"expected one of {list(_KERNELS)}")

    # A single cube with a single PE (num_cubes=1, num_pes=1).
    dp = DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=1)

    # Allocation order (a, b, out) and buffer names match the launch below.
    shapes = {"a": (M, K), "b": (K, N), "out": (M, N)}
    buffers = {
        name: torch.empty(shape, dtype=DTYPE, dp=dp, name=name)
        for name, shape in shapes.items()
    }

    # NOTE(review): the kernels' DTYPE parameter defaults to "f16" and DTYPE
    # is not passed to launch() here. If MATMUL_DTYPE is set to another
    # dtype, confirm the launcher forwards it. Otherwise the kernel would
    # describe these buffers with a different dtype than they were
    # allocated with.
    torch.launch(f"matmul_composite_{VARIANT}", kernel_fn,
                 buffers["a"], buffers["b"], buffers["out"], M, K, N)
|