kernbench2/benches/gpt3_qkv.py
ywkang 357cab525b ADR-0026: DPPolicy intra-device only + ShardSpec structural coords
DPPolicy no longer carries a cross-SIP axis. SIP-level placement is
controlled solely by torch.ahbm.set_device(rank) (ADR-0024); DPPolicy
itself describes only the cube × PE layout within one SIP. ShardSpec
switches to structural (sip, cube, pe) coordinates, and the flat pe_index
field/property is removed entirely: silent drift between its global-flat
and SIP-local interpretations was a foot-gun flagged by ADR-0024 D11.

Breaking API (explicit TypeError / AttributeError):
- DPPolicy(sip=...) / DPPolicy(num_sips=...) -> TypeError
- ShardSpec.pe_index -> AttributeError
- ShardSpec(pe_index=...) -> TypeError
- resolve_dp_policy now requires target_sip= and no longer accepts num_sips.
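The breaking-API list can be modeled with plain frozen dataclasses. These are hypothetical stand-ins, not kernbench's actual DPPolicy and ShardSpec (the real classes likely carry more fields, such as offsets or shapes); the sketch only shows that once the removed fields are absent from the class definition, Python itself raises the TypeError and AttributeError the list describes.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DPPolicy:
    """Hypothetical stand-in: intra-device layout only (no sip / num_sips)."""
    cube: str
    pe: str
    num_cubes: int
    num_pes: int


@dataclass(frozen=True)
class ShardSpec:
    """Hypothetical stand-in: structural coordinates only (no pe_index)."""
    sip: int
    cube: int
    pe: int


def raises(exc_type, fn) -> bool:
    """Return True if calling fn() raises exc_type."""
    try:
        fn()
    except exc_type:
        return True
    return False


spec = ShardSpec(sip=0, cube=3, pe=7)
# Each check below prints True.
print(raises(TypeError, lambda: DPPolicy(cube="replicate", pe="replicate",
                                         num_cubes=16, num_pes=8, sip=0)))
print(raises(TypeError, lambda: ShardSpec(sip=0, cube=3, pe=7, pe_index=31)))
print(raises(AttributeError, lambda: spec.pe_index))
```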

Downstream migration:
- PE allocator dict keyed by (sip, cube, pe) tuples, in both
  _ensure_allocators and _free_tensor. deploy_tensor uses tuple lookup.
- _create_tensor passes target_sip=current_sip; post-hoc pe_index
  shifting removed entirely.
- launch._compute_local_shape drops the dp.sip branch.
- Internal resolvers (column_wise / row_wise / replicate / tiled_*)
  return _LocalPeShard (a cube-local identifier) instead of ShardSpec;
  resolve_dp_policy lifts them to full structural coords.
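A minimal sketch of the resolver/lifting split, under stated assumptions: LocalPeShard, column_wise_local and lift are hypothetical names (the commit's real type is _LocalPeShard, and resolve_dp_policy's full signature is not shown here), shards are assigned cube-major, and the dict of shard lists merely stands in for the PE allocator dict. Because every entry is keyed by a (sip, cube, pe) tuple, composing two SIPs as two independent lifts, one per target_sip, yields distinct keys without any index shifting.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LocalPeShard:
    """Hypothetical cube-local resolver output: no SIP coordinate."""
    cube: int
    pe: int
    col_start: int
    col_end: int


@dataclass(frozen=True)
class ShardSpec:
    """Hypothetical structural shard: full (sip, cube, pe) coordinates."""
    sip: int
    cube: int
    pe: int
    col_start: int
    col_end: int


def column_wise_local(n_cols, num_cubes, num_pes):
    """Split n_cols evenly over num_cubes * num_pes PEs, cube-major."""
    per_pe = n_cols // (num_cubes * num_pes)
    shards = []
    for cube in range(num_cubes):
        for pe in range(num_pes):
            start = (cube * num_pes + pe) * per_pe
            shards.append(LocalPeShard(cube, pe, start, start + per_pe))
    return shards


def lift(local_shards, target_sip):
    """Attach the SIP coordinate; no flat index is shifted."""
    return [ShardSpec(target_sip, s.cube, s.pe, s.col_start, s.col_end)
            for s in local_shards]


# Two SIPs compose as two independent lifts, one per target_sip.
allocators = {}  # stand-in for the allocator dict keyed by (sip, cube, pe)
for sip in (0, 1):
    for spec in lift(column_wise_local(12288, 16, 8), target_sip=sip):
        allocators.setdefault((spec.sip, spec.cube, spec.pe), []).append(spec)

print(len(allocators))                      # 256 distinct keys
print(allocators[(1, 15, 7)][0].col_start)  # 12192 = 127 * 96
```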

Tests:
- New tests/test_adr0026_dppolicy_intra_device.py (12 tests) pins the
  contract end-to-end.
- test_sip_parallel.py rewritten: SIP composition now modeled as two
  resolve_dp_policy(target_sip=...) calls (ADR-0024 launcher style).
- Call-site migration: test_tensor, test_va_integration, test_va_offset,
  test_runtime_api_tensor, test_tl_recv_async, test_ccl_* and benches
  gemm_single_pe, gpt3_qkv, va_offset_verify, ccl_allreduce (legacy
  branch) all use intra-device DPPolicy and structural ShardSpec.

Result: 523 passed, 1 strict xfail (ring_default_ws — unchanged
ADR-0024 Phase B blocker; architectural fix deferred to ADR-0027).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-14 13:02:19 -07:00

97 lines
3.7 KiB
Python

"""GPT-3 QKV projection benchmark: sharded across PEs via pe_accel_v1.

GPT-3 architecture:
    d_model = 12288  (hidden dimension)
    n_heads = 96     (attention heads)
    d_head  = 128    (dimension per head)

Sharding strategy (column-wise across all PEs):
    X        : (seq_len, d_model) -- replicated to all PEs
    W_Q/K/V  : (d_model, d_model) -- column-wise sharded across cubes × PEs
    out_Q/K/V: (seq_len, d_model) -- column-wise sharded across cubes × PEs

Each PE computes:
    Q_slice = X @ W_Q_slice : (seq_len, d_model) @ (d_model, cols_per_pe) -> (seq_len, cols_per_pe)
    K_slice, V_slice: same

PE count is configurable via N_CUBES × N_PE_PER_CUBE (DPPolicy override);
topology.yaml is unchanged.

Run:
    kernbench run gpt3_qkv
"""
from kernbench.policy.placement.dp import DPPolicy

# -- PE configuration (DPPolicy overrides; does not change topology.yaml) -----
N_SIPS = 1                       # not used by run(); SIP placement comes from torch.ahbm.set_device (ADR-0024/0026)
N_CUBES = 16                     # cubes per SIP
N_PE_PER_CUBE = 8                # PEs per cube
N_PES = N_CUBES * N_PE_PER_CUBE  # 128 PEs per SIP

# -- GPT-3 architecture -------------------------------------------------------
GPT3_D_MODEL = 12288
SEQ_LEN = 32
COLS_PER_PE = GPT3_D_MODEL // N_PES  # 12288 / 128 = 96
DTYPE = "f16"


def _gpt3_qkv_kernel(x_ptr, wq_ptr, wk_ptr, wv_ptr,
                     out_q_ptr, out_k_ptr, out_v_ptr,
                     seq_len, d_model, cols_per_pe, tl, DTYPE="f16"):
    """GPT-3 QKV sharded: each PE uses program_id to index its VA slice."""
    pid = tl.program_id(0)
    bpe = 2  # bytes per element; assumes DTYPE == "f16"
    M = int(seq_len)
    K = int(d_model)
    N = int(cols_per_pe)
    w_slice = K * N * bpe    # bytes in one PE's (K, N) weight shard
    out_slice = M * N * bpe  # bytes in one PE's (M, N) output shard

    # X is replicated, so every PE loads the same full (M, K) activation.
    x = tl.load(int(x_ptr), shape=(M, K), dtype=DTYPE)

    # Reference this PE's weight shards at byte offset pid * w_slice.
    wq = tl.ref(int(wq_ptr) + pid * w_slice, shape=(K, N), dtype=DTYPE)
    wk = tl.ref(int(wk_ptr) + pid * w_slice, shape=(K, N), dtype=DTYPE)
    wv = tl.ref(int(wv_ptr) + pid * w_slice, shape=(K, N), dtype=DTYPE)

    # Issue the three GEMMs, each writing into this PE's output shard.
    hq = tl.composite(op="gemm", a=x, b=wq,
                      out_ptr=int(out_q_ptr) + pid * out_slice)
    hk = tl.composite(op="gemm", a=x, b=wk,
                      out_ptr=int(out_k_ptr) + pid * out_slice)
    hv = tl.composite(op="gemm", a=x, b=wv,
                      out_ptr=int(out_v_ptr) + pid * out_slice)

    # Block until all three GEMMs have completed.
    tl.wait(hq)
    tl.wait(hk)
    tl.wait(hv)


def run(torch):
    """Run the GPT-3 QKV benchmark."""
    M = SEQ_LEN
    K = GPT3_D_MODEL
    N = COLS_PER_PE

    # ADR-0026: DPPolicy is intra-device only. For multi-SIP execution the
    # ADR-0024 launcher calls this bench once per SIP (each worker via
    # torch.ahbm.set_device(rank)); here the policy describes only the
    # cube × PE layout within a single SIP.

    # X: replicated across all PEs within the SIP
    dp_replicate = DPPolicy(cube="replicate", pe="replicate",
                            num_cubes=N_CUBES, num_pes=N_PE_PER_CUBE)
    # W_Q/K/V, out_Q/K/V: column-wise sharded across all PEs within the SIP
    dp_sharded = DPPolicy(cube="column_wise", pe="column_wise",
                          num_cubes=N_CUBES, num_pes=N_PE_PER_CUBE)

    x = torch.empty((M, K), dtype=DTYPE, dp=dp_replicate, name="x")
    wq = torch.empty((K, GPT3_D_MODEL), dtype=DTYPE, dp=dp_sharded, name="wq")
    wk = torch.empty((K, GPT3_D_MODEL), dtype=DTYPE, dp=dp_sharded, name="wk")
    wv = torch.empty((K, GPT3_D_MODEL), dtype=DTYPE, dp=dp_sharded, name="wv")
    out_q = torch.empty((M, GPT3_D_MODEL), dtype=DTYPE, dp=dp_sharded, name="out_q")
    out_k = torch.empty((M, GPT3_D_MODEL), dtype=DTYPE, dp=dp_sharded, name="out_k")
    out_v = torch.empty((M, GPT3_D_MODEL), dtype=DTYPE, dp=dp_sharded, name="out_v")

    torch.launch("gpt3_qkv", _gpt3_qkv_kernel,
                 x, wq, wk, wv, out_q, out_k, out_v,
                 M, K, N)
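The kernel's pointer arithmetic (wq_ptr + pid * w_slice, out_q_ptr + pid * out_slice) implies that each PE's column shard is stored as one contiguous block. The sketch below checks, in plain Python at toy sizes, that computing X @ W_slice per PE and reassembling the column blocks reproduces the full X @ W. The packing is a hypothetical model of that layout, not kernbench's allocator, and offsets are counted in elements rather than bytes. It ends by evaluating the kernel's byte-size formulas at the bench's real shape.

```python
# Toy sizes (assumption); the bench uses SEQ_LEN=32, d_model=12288, 128 PEs.
M, K = 2, 8          # seq_len, d_model
N_PES = 4
N = K // N_PES       # cols_per_pe


def matmul(a, b):
    """Plain-Python (rows x inner) @ (inner x cols)."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


X = [[(i + 1) * (k + 1) % 5 for k in range(K)] for i in range(M)]
W = [[(k * 3 + j) % 7 for j in range(K)] for k in range(K)]

# Pack W so each PE's (K x N) column shard is contiguous at pid * K * N,
# mirroring `wq_ptr + pid * w_slice` (element offsets, not bytes).
w_flat = []
for pid in range(N_PES):
    for k in range(K):
        w_flat.extend(W[k][pid * N:(pid + 1) * N])

out_flat = [0] * (M * K)
for pid in range(N_PES):
    w_off = pid * K * N
    w_shard = [w_flat[w_off + k * N:w_off + (k + 1) * N] for k in range(K)]
    q = matmul(X, w_shard)           # (M, N) result for this PE
    o_off = pid * M * N              # mirrors `out_q_ptr + pid * out_slice`
    for i in range(M):
        out_flat[o_off + i * N:o_off + (i + 1) * N] = q[i]

# Un-pack the per-PE output blocks back into (M, K) columns and compare.
out = [[out_flat[(j // N) * M * N + i * N + j % N] for j in range(K)]
       for i in range(M)]
print(out == matmul(X, W))  # True

# Kernel byte-size formulas at the bench's real shape (f16, bpe = 2).
bpe, SEQ_LEN, D_MODEL, COLS = 2, 32, 12288, 96
print(D_MODEL * COLS * bpe, SEQ_LEN * COLS * bpe, SEQ_LEN * D_MODEL * bpe)
# 2359296 6144 786432  (w_slice, out_slice, replicated X)
```

At full size each PE therefore references three 2.25 MiB weight shards and writes three 6 KiB output shards, while the replicated X is 768 KiB.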