commit - release 1
This commit is contained in:
@@ -0,0 +1,2 @@
|
||||
def run(ctx):
|
||||
print("IPCQ all reduce kernel bench")
|
||||
@@ -0,0 +1,37 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from kernbench.runtime_api.context import RuntimeContext
|
||||
|
||||
BenchFn = Callable[[RuntimeContext], Any]
|
||||
|
||||
|
||||
def resolve_bench(bench_id: str) -> BenchFn:
|
||||
"""
|
||||
Resolve a bench id into a callable bench function.
|
||||
|
||||
Expected layout (repo root):
|
||||
benches/<bench_id>.py
|
||||
def run(ctx: RuntimeContext) -> Any
|
||||
"""
|
||||
bench_id = bench_id.strip()
|
||||
if not bench_id:
|
||||
raise ValueError("Bench id is empty.")
|
||||
|
||||
module_path = f"benches.{bench_id}"
|
||||
|
||||
try:
|
||||
mod = importlib.import_module(module_path)
|
||||
except ModuleNotFoundError as e:
|
||||
raise ValueError(f"Unknown bench '{bench_id}'. Expected module {module_path}.py") from e
|
||||
|
||||
run_fn = getattr(mod, "run", None)
|
||||
if run_fn is None:
|
||||
raise ValueError(f"Bench module {module_path} must define a 'run(ctx)' function.")
|
||||
if not callable(run_fn):
|
||||
raise ValueError(f"'run' in {module_path} is not callable.")
|
||||
|
||||
return run_fn
|
||||
@@ -0,0 +1,39 @@
|
||||
"""QKV GEMM benchmark: Q*K^T projection on a single PE.
|
||||
|
||||
Demonstrates the full host-to-PE kernel launch pipeline:
|
||||
Host → PCIE_EP → IO_CPU → M_CPU → NOC → PE_CPU → PE_SCHEDULER → engines
|
||||
|
||||
Kernel: tl.load(a) + tl.ref(b) + tl.composite(gemm) + tl.wait()
|
||||
- Tensor a is loaded into TCM via DMA
|
||||
- Tensor b stays in HBM; PE_SCHEDULER streams it per-tile (32x64x32)
|
||||
"""
|
||||
from kernbench.policy.placement.dp import DPPolicy
|
||||
|
||||
# GEMM dimensions: (M, K) x (K, N) → (M, N)
|
||||
M, K, N = 128, 256, 128
|
||||
DTYPE = "f16"
|
||||
|
||||
|
||||
def _gemm_kernel(a_ptr, b_ptr, out_ptr, M, K, N, tl, DTYPE="f16"):
|
||||
"""QKV GEMM kernel: out = a @ b.
|
||||
|
||||
a is loaded into TCM (DMA_READ).
|
||||
b is referenced in HBM (tl.ref, no DMA — scheduler streams per-tile).
|
||||
"""
|
||||
a = tl.load(a_ptr, shape=(M, K), dtype=DTYPE)
|
||||
b = tl.ref(b_ptr, shape=(K, N), dtype=DTYPE)
|
||||
handle = tl.composite(op="gemm", a=a, b=b, out_ptr=out_ptr)
|
||||
tl.wait(handle)
|
||||
|
||||
|
||||
def run(ctx):
|
||||
"""Run the QKV GEMM benchmark."""
|
||||
# DP placement: a=replicate (cube-level), b/out=column_wise (N-axis, single PE)
|
||||
a = ctx.zeros((M, K), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="replicate"), name="a")
|
||||
b = ctx.zeros((K, N), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="column_wise"), name="b")
|
||||
out = ctx.empty(
|
||||
(M, N), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="column_wise"), name="out",
|
||||
)
|
||||
|
||||
# Launch GEMM kernel
|
||||
ctx.launch("qkv_gemm", _gemm_kernel, a, b, out, M, K, N)
|
||||
@@ -0,0 +1,39 @@
|
||||
"""QKV GEMM benchmark: Q*K^T projection on all PEs in a cube (multi-PE).
|
||||
|
||||
Column-parallel GEMM: a is replicated (cube-level), b/out are column-sharded.
|
||||
M_CPU fans out KernelLaunchMsg to all 8 PE_CPUs (ADR-0009 D3).
|
||||
|
||||
Kernel: tl.load(a) + tl.ref(b) + tl.composite(gemm) + tl.wait()
|
||||
- Tensor a is loaded into TCM via DMA
|
||||
- Tensor b stays in HBM; PE_SCHEDULER streams it per-tile (32x64x32)
|
||||
"""
|
||||
from kernbench.policy.placement.dp import DPPolicy
|
||||
|
||||
# GEMM dimensions: (M, K) x (K, N) -> (M, N)
|
||||
M, K, N = 128, 256, 128
|
||||
DTYPE = "f16"
|
||||
|
||||
|
||||
def _gemm_kernel(a_ptr, b_ptr, out_ptr, M, K, N, tl, DTYPE="f16"):
|
||||
"""QKV GEMM kernel: out = a @ b.
|
||||
|
||||
a is loaded into TCM (DMA_READ).
|
||||
b is referenced in HBM (tl.ref, no DMA -- scheduler streams per-tile).
|
||||
"""
|
||||
a = tl.load(a_ptr, shape=(M, K), dtype=DTYPE)
|
||||
b = tl.ref(b_ptr, shape=(K, N), dtype=DTYPE)
|
||||
handle = tl.composite(op="gemm", a=a, b=b, out_ptr=out_ptr)
|
||||
tl.wait(handle)
|
||||
|
||||
|
||||
def run(ctx):
|
||||
"""Run the multi-PE QKV GEMM benchmark."""
|
||||
# DP placement: a=replicate (cube-level), b/out=column_wise (N-axis split)
|
||||
a = ctx.zeros((M, K), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="replicate"), name="a")
|
||||
b = ctx.zeros((K, N), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="column_wise"), name="b")
|
||||
out = ctx.empty(
|
||||
(M, N), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="column_wise"), name="out",
|
||||
)
|
||||
|
||||
# Launch GEMM kernel on all PEs
|
||||
ctx.launch("qkv_gemm_multi", _gemm_kernel, a, b, out, M, K, N)
|
||||
Reference in New Issue
Block a user