372c987995
GEMM dimension reduction: - qkv_gemm.py: M,K,N = 128,256,128 → 32,64,32 (64 tiles → 1 tile). - qkv_gemm_multi_pe.py: same reduction. - Tests verify pipeline correctness, not large-matrix throughput. - Per-test time: 18s → 1.7s. 6 tests total: 108s → 10s. pytest-xdist parallel execution: - Add pytest-xdist to dev dependencies. - pyproject.toml addopts: -n auto (use all CPU cores), -m "not slow". - Default `pytest` runs 501 tests in ~12s (previously 148s). - Full suite including slow: `pytest -m ""` → 3m24s (previously 5m43s). pytest.mark.slow: - Registered in pyproject.toml markers section. - 256-rank full-system test is the only slow-marked test. - Run with: pytest -m "" (CI) or pytest (local dev, skips slow). 502 tests pass. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
42 lines
1.6 KiB
Python
42 lines
1.6 KiB
Python
"""QKV GEMM benchmark: Q*K^T projection on a single PE.
|
|
|
|
Demonstrates the full host-to-PE kernel launch pipeline:
|
|
Host → PCIE_EP → IO_CPU → M_CPU → NOC → PE_CPU → PE_SCHEDULER → engines
|
|
|
|
Kernel: tl.load(a) + tl.ref(b) + tl.composite(gemm) + tl.wait()
|
|
- Tensor a is loaded into TCM via DMA
|
|
- Tensor b stays in HBM; PE_SCHEDULER streams it per-tile (32x64x32)
|
|
"""
|
|
from kernbench.policy.placement.dp import DPPolicy
|
|
|
|
# GEMM problem size for (M, K) x (K, N) -> (M, N).
# Kept deliberately tiny (a single tile) so regression runs stay fast: the
# benchmark exercises the host -> PE launch pipeline, not large-matrix
# throughput.
M = 32
K = 64
N = 32

# Element dtype for the tensors allocated in run().
DTYPE = "f16"
|
|
|
|
|
|
def _gemm_kernel(a_ptr, b_ptr, out_ptr, M, K, N, tl, DTYPE="f16"):
|
|
"""QKV GEMM kernel: out = a @ b.
|
|
|
|
a is loaded into TCM (DMA_READ).
|
|
b is referenced in HBM (tl.ref, no DMA — scheduler streams per-tile).
|
|
"""
|
|
a = tl.load(a_ptr, shape=(M, K), dtype=DTYPE)
|
|
b = tl.ref(b_ptr, shape=(K, N), dtype=DTYPE)
|
|
handle = tl.composite(op="gemm", a=a, b=b, out_ptr=out_ptr)
|
|
tl.wait(handle)
|
|
|
|
|
|
def run(torch):
    """Allocate the GEMM operands and launch the QKV GEMM kernel.

    Args:
        torch: the benchmark runtime handle; this function uses only its
            ``zeros``, ``empty`` and ``launch`` entry points.
    """
    lhs_shape = (M, K)
    rhs_shape = (K, N)
    out_shape = (M, N)

    # Placement (DPPolicy):
    #   a   -> replicated at both the cube and the PE level.
    #   b   -> replicated across cubes, split column_wise across PEs
    #          (columns of (K, N) run along the N axis).
    #   out -> same placement as b (columns of (M, N) run along N).
    lhs = torch.zeros(
        lhs_shape,
        dtype=DTYPE,
        dp=DPPolicy(cube="replicate", pe="replicate"),
        name="a",
    )
    rhs = torch.zeros(
        rhs_shape,
        dtype=DTYPE,
        dp=DPPolicy(cube="replicate", pe="column_wise"),
        name="b",
    )
    result = torch.empty(
        out_shape,
        dtype=DTYPE,
        dp=DPPolicy(cube="replicate", pe="column_wise"),
        name="out",
    )

    # Hand the kernel and its arguments to the runtime under the "qkv_gemm" tag.
    torch.launch("qkv_gemm", _gemm_kernel, lhs, rhs, result, M, K, N)
|