Files
kernbench2/benches/gpt3_qkv.py
T
ywkang 114510d4b9 Add SchedulerV2 (pe_accel), DPPolicy overrides, and new benchmarks
- Add cycle-accurate PE accelerator scheduler (SchedulerV2) with tiled
  GEMM/Math pipelines (DMA_IN → GEMM → MATH → DMA_WB)
- Add DPPolicy num_pes/num_cubes/num_sips overrides for single-PE testing
- Support tuple target_pe for targeting specific PE subsets
- Add gemm_single_pe and gpt3_qkv benchmarks
- Switch default topology to pe_scheduler_v2

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 23:18:49 -07:00

93 lines
3.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""GPT-3 QKV projection benchmark: sharded across PEs via pe_accel_v1.
GPT-3 architecture:
d_model = 12288 (hidden dimension)
n_heads = 96 (attention heads)
d_head = 128 (dimension per head)
Sharding strategy (column-wise across all PEs):
X : (seq_len, d_model) -- replicated to all PEs
W_Q/K/V : (d_model, d_model) -- column-wise sharded across cubes × PEs
out_Q/K/V: (seq_len, d_model) -- column-wise sharded across cubes × PEs
Each PE computes:
Q_slice = X @ W_Q_slice : (seq_len, d_model) @ (d_model, cols_per_pe) -> (seq_len, cols_per_pe)
K_slice, V_slice: same
PE count is configurable via N_CUBES × N_PE_PER_CUBE (DPPolicy override).
topology.yaml is unchanged.
Run:
kernbench run gpt3_qkv
"""
from kernbench.policy.placement.dp import DPPolicy
# -- PE grid, applied through DPPolicy overrides (topology.yaml is untouched) --
N_SIPS = 1
N_CUBES = 16       # number of cubes in each SIP
N_PE_PER_CUBE = 8  # number of PEs in each cube
N_PES = N_CUBES * N_PE_PER_CUBE  # 16 * 8 = 128 PEs share the columns

# -- Problem size (GPT-3 hidden dimension, short sequence) ---------------------
DTYPE = "f16"
SEQ_LEN = 32
GPT3_D_MODEL = 12288
# Width of the column slice handled by one PE: 12288 // 128 = 96.
COLS_PER_PE = GPT3_D_MODEL // N_PES
def _gpt3_qkv_kernel(x_ptr, wq_ptr, wk_ptr, wv_ptr,
                     out_q_ptr, out_k_ptr, out_v_ptr,
                     seq_len, d_model, cols_per_pe, tl, DTYPE="f16"):
    """Issue this PE's Q, K and V projection GEMMs on its own column slice.

    The program id selects the slice: slice ``pid`` of a weight buffer starts
    ``pid * d_model * cols_per_pe * elem_bytes`` past its base pointer, and
    slice ``pid`` of an output buffer starts
    ``pid * seq_len * cols_per_pe * elem_bytes`` past its base pointer.

    Args:
        x_ptr: Base address of X, loaded with shape (seq_len, d_model).
        wq_ptr, wk_ptr, wv_ptr: Base addresses of the W_Q / W_K / W_V buffers;
            each per-PE slice is referenced with shape (d_model, cols_per_pe).
        out_q_ptr, out_k_ptr, out_v_ptr: Base addresses of the output buffers
            the GEMM results are written to.
        seq_len: Number of rows of X (converted with ``int``).
        d_model: Shared inner dimension of the GEMMs (converted with ``int``).
        cols_per_pe: Number of output columns per PE (converted with ``int``).
        tl: Kernel API object; must provide ``program_id``, ``load``, ``ref``,
            ``composite`` and ``wait``.
        DTYPE: Element type name. It is forwarded to ``tl.load`` / ``tl.ref``
            and also determines the byte size used for slice offsets.

    Raises:
        ValueError: If the element size of ``DTYPE`` is not known.
    """
    # Element sizes in bytes. The offsets below are byte offsets, so they must
    # follow DTYPE instead of assuming 2-byte elements.
    # NOTE(review): only "f16" is used in this file; the other names assume the
    # same short naming scheme -- confirm against the framework's dtype list.
    elem_bytes = {
        "i8": 1,
        "f16": 2,
        "bf16": 2,
        "i16": 2,
        "f32": 4,
        "i32": 4,
        "f64": 8,
    }
    try:
        bpe = elem_bytes[DTYPE]
    except (KeyError, TypeError):
        raise ValueError(
            f"unsupported DTYPE {DTYPE!r}; expected one of {sorted(elem_bytes)}"
        ) from None

    pid = tl.program_id(0)
    M = int(seq_len)
    K = int(d_model)
    N = int(cols_per_pe)
    w_slice = K * N * bpe      # bytes in one (K, N) weight slice
    out_slice = M * N * bpe    # bytes in one (M, N) output slice

    x = tl.load(int(x_ptr), shape=(M, K), dtype=DTYPE)
    wq = tl.ref(int(wq_ptr) + pid * w_slice, shape=(K, N), dtype=DTYPE)
    wk = tl.ref(int(wk_ptr) + pid * w_slice, shape=(K, N), dtype=DTYPE)
    wv = tl.ref(int(wv_ptr) + pid * w_slice, shape=(K, N), dtype=DTYPE)

    # All three GEMMs are issued before any wait, as in the original ordering.
    hq = tl.composite(op="gemm", a=x, b=wq,
                      out_ptr=int(out_q_ptr) + pid * out_slice)
    hk = tl.composite(op="gemm", a=x, b=wk,
                      out_ptr=int(out_k_ptr) + pid * out_slice)
    hv = tl.composite(op="gemm", a=x, b=wv,
                      out_ptr=int(out_v_ptr) + pid * out_slice)

    tl.wait(hq)
    tl.wait(hk)
    tl.wait(hv)
def run(torch):
    """Allocate the QKV operands and launch the sharded GPT-3 QKV kernel.

    X is allocated with a replicate policy; the three weight matrices and the
    three output matrices are allocated with a column-wise policy. Both
    policies carry the module's SIP / cube / PE count overrides.

    Args:
        torch: Object supplied by the benchmark runner; it must expose
            ``empty(shape, dtype=, dp=, name=)`` and ``launch(name, kernel, *args)``.
    """
    rows, inner, shard_cols = SEQ_LEN, GPT3_D_MODEL, COLS_PER_PE

    # Same grid overrides for both placement policies.
    grid = dict(num_sips=N_SIPS, num_cubes=N_CUBES, num_pes=N_PE_PER_CUBE)
    replicated = DPPolicy(cube="replicate", pe="replicate", **grid)
    col_sharded = DPPolicy(cube="column_wise", pe="column_wise", **grid)

    # Allocation order is kept: x, then wq/wk/wv, then out_q/out_k/out_v.
    x = torch.empty((rows, inner), dtype=DTYPE, dp=replicated, name="x")
    weights = [
        torch.empty((inner, GPT3_D_MODEL), dtype=DTYPE, dp=col_sharded, name=label)
        for label in ("wq", "wk", "wv")
    ]
    outputs = [
        torch.empty((rows, GPT3_D_MODEL), dtype=DTYPE, dp=col_sharded, name=label)
        for label in ("out_q", "out_k", "out_v")
    ]

    torch.launch("gpt3_qkv", _gpt3_qkv_kernel,
                 x, *weights, *outputs,
                 rows, inner, shard_cols)