Files
kernbench2/src/kernbench/benches/_attention_mesh_kv.py
T
mukesh 222815d374 attention: add rank_axis kwarg to mesh kernels for multi_user cube ring
ADR-0059 single_user_* panels run the ring across PEs in one cube
(rank == tl.program_id(axis=0)). multi_user_* panels run the ring
across cubes — rank should be cube_id (axis=1), and 7 of every 8 PEs
in each cube must stay silent because the cube-level SFR install only
gives the cube-coordinate PE 0 an E/W neighbor.

Add ``rank_axis: int = 0`` kwarg to both ``attention_mesh_mlo_kernel``
and ``attention_mesh_kv_kernel``:
  - 0 (default): rank == tl.program_id(axis=0). Existing single_user
    behavior, all spec tests unchanged.
  - 1: gate ``if tl.program_id(axis=0) != 0: return`` at kernel start,
    then ``rank = tl.program_id(axis=1)``. multi_user_* panels pass
    this to the kernel via ctx.launch positional arg.

Also brings in _attention_mesh_kv.py and _attention_mesh_mlo.py as
the committed home of the ADR-0059 kernels (previously living
uncommitted in the working tree from sub-cycle 4b).

Tests: 7-test rank_axis spec file (default-path + rank_axis=1 gating
and cube-id semantics, both kernels); 4-panel diag harness now green
end-to-end (single_user_prefill/decode + multi_user_prefill/decode);
763-test wider sweep clean.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-01 19:53:18 -07:00

181 lines
6.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Mesh-native bidirectional Ring-K/V attention kernel — prefill (ADR-0059 Proposed).
Each rank holds its own Q tile and 1/n_ranks of K, V (sequence-sharded).
Over ``n_ranks - 1`` bidirectional steps, K and V propagate both east and
west: chunk c_i originating at rank i reaches rank j at step ``|i - j|``.
Every rank receives every other rank's chunk **exactly once** and folds it
into a running ``(m, ℓ, o)`` via the online-softmax recurrence. After all
steps each rank holds the final attention output for its own Q tokens —
no cross-rank merge is required.
Supersedes ADR-0055's closed-ring ``_attention_ring_kv.py``. Both modules
stay on disk during the transition; this one runs on the hardware's
actual open-mesh wiring (no closed-ring SFR install required).
Imported by ``milestone_gqa_llama70b`` (after the bench's Phase 2 switches
its imports) and invoked through ``torch.launch(...)`` — not through
``dist.all_reduce(...)``. See ADR-0055 Context for why this kernel is not
backend-dispatched via ADR-0050's algorithm-module contract.
"""
from __future__ import annotations
from kernbench.common.pe_commands import TensorHandle
def _view(handle: TensorHandle, new_shape: tuple[int, ...]) -> TensorHandle:
    """Return a new handle that describes ``handle`` under ``new_shape``.

    Only the ``shape`` field differs from the input: ``id``, ``addr``,
    ``dtype``, ``nbytes``, ``data``, ``space`` and ``pinned`` are copied over
    unchanged. No ``tl`` call is made, so the reshape is metadata only
    (cf. ``tl.trans``).

    Args:
        handle: the handle to re-describe; it is not modified.
        new_shape: shape recorded on the returned handle. It is stored as
            given — this function does not check it against ``nbytes``.

    Returns:
        A fresh ``TensorHandle`` sharing ``id`` and ``addr`` with ``handle``.
    """
    # Collect the carried-over fields first, overriding the shape alone.
    fields = {
        "id": handle.id,
        "addr": handle.addr,
        "shape": new_shape,
        "dtype": handle.dtype,
        "nbytes": handle.nbytes,
        "data": handle.data,
        "space": handle.space,
        "pinned": handle.pinned,
    }
    return TensorHandle(**fields)
def _partial_attention(
    Q: TensorHandle,
    K: TensorHandle,
    V: TensorHandle,
    S_q: int,
    S_kv_per_rank: int,
    h_q: int,
    d_head: int,
    tl,
) -> tuple[TensorHandle, TensorHandle, TensorHandle]:
    """Run one partial-attention pass of ``Q`` against a single (K, V) chunk.

    Operations issued on ``tl``, in this order: ``dot`` (Q·K^T), ``max``,
    ``softmax``, one subtraction, ``exp``, ``sum``, ``dot`` (P·V).

    Args:
        Q: query tile; left operand of the first GEMM.
        K: key chunk, re-viewed as ``(h_q * d_head, S_kv_per_rank)``.
        V: value chunk, re-viewed as ``(S_kv_per_rank, h_q * d_head)``.
        S_q: not read by this function; kept in the signature for callers.
        S_kv_per_rank: chunk length used in both views.
        h_q: number of query heads used in both views.
        d_head: per-head width used in both views.
        tl: kernel-language object supplying ``dot``, ``max``, ``softmax``,
            ``exp`` and ``sum``.

    Returns:
        The triplet ``(m, ℓ, O_partial)`` consumed by the online-softmax mlo
        merge: ``m`` is the max of the scores over the last axis, ``ℓ`` the
        sum over the last axis of ``exp(scores - m)``, and ``O_partial`` the
        P·V product.
    """
    # Both views are metadata-only re-descriptions of the incoming chunk.
    keys_t = _view(K, (h_q * d_head, S_kv_per_rank))
    values_2d = _view(V, (S_kv_per_rank, h_q * d_head))

    logits = tl.dot(Q, keys_t)
    row_max = tl.max(logits, axis=-1)
    # NOTE(review): the P·V GEMM below consumes the tl.softmax output, not
    # exp(logits - row_max) — confirm that is the operand the downstream
    # merge expects.
    probs = tl.softmax(logits, axis=-1)
    row_sum = tl.sum(tl.exp(logits - row_max), axis=-1)
    return row_max, row_sum, tl.dot(probs, values_2d)
def _merge_partial(
    running: tuple[TensorHandle, TensorHandle, TensorHandle],
    incoming: tuple[TensorHandle, TensorHandle, TensorHandle],
    tl,
) -> tuple[TensorHandle, TensorHandle, TensorHandle]:
    """Fold one partial ``(m, ℓ, o)`` triplet into the running triplet.

    Online-softmax recurrence: both sides are rescaled by
    ``exp(m_side - max(m, m_new))`` before ``ℓ`` and ``o`` are added.

    Args:
        running: the accumulated ``(m, ℓ, o)`` so far.
        incoming: the ``(m, ℓ, o)`` of the chunk just processed.
        tl: kernel-language object supplying ``maximum`` and ``exp``.

    Returns:
        The updated running ``(m, ℓ, o)``.
    """
    m, ell, o = running
    m_new, ell_new, o_new = incoming
    m_combined = tl.maximum(m, m_new)
    scale_old = tl.exp(m - m_combined)
    scale_new = tl.exp(m_new - m_combined)
    ell = ell * scale_old + ell_new * scale_new
    o = o * scale_old + o_new * scale_new
    return m_combined, ell, o


def attention_mesh_kv_kernel(
    q_ptr: int,
    k_ptr: int,
    v_ptr: int,
    o_ptr: int,
    S_q: int,
    S_kv_per_rank: int,
    h_q: int,
    h_kv: int,
    d_head: int,
    n_ranks: int,
    rank_axis: int = 0,
    *,
    tl,
) -> None:
    """Mesh-native bidirectional Ring-K/V attention — see module docstring.

    ``rank_axis`` selects which program-id dimension carries the ring rank:
      0 — single_user_* panels: rank == tl.program_id(axis=0) (PE id in cube).
      1 — multi_user_* panels: ring is at the cube level. Only PE 0 in each
          cube participates; the other 7 hold KV replicas but stay silent.

    Args:
        q_ptr: address the Q tile is loaded from, as ``(S_q, h_q * d_head)``.
        k_ptr: address of the local K chunk, loaded as
            ``(S_kv_per_rank, h_kv, d_head)``.
        v_ptr: address of the local V chunk, same shape as K.
        o_ptr: address the final output is stored to.
        S_q: number of query rows in the Q tile.
        S_kv_per_rank: K/V rows held by each rank.
        h_q: number of query heads.
        h_kv: number of K/V heads.
        d_head: per-head width.
        n_ranks: ring length; the exchange loop runs ``n_ranks - 1`` steps.
        rank_axis: program-id axis that supplies the rank (see above).
        tl: kernel-language object (keyword-only).

    Returns:
        None. The result is written with ``tl.store(o_ptr, ...)``; programs
        gated out by ``rank_axis`` return without issuing any operation.
    """
    # For multi_user (rank_axis != 0) only the program with axis-0 id 0 runs
    # the ring; every other program exits before touching memory or links.
    if rank_axis != 0 and tl.program_id(axis=0) != 0:
        return

    rank = tl.program_id(axis=rank_axis)
    has_E = rank < n_ranks - 1
    has_W = rank > 0

    # Every K/V chunk — local or received — is described by this one shape.
    kv_shape = (S_kv_per_rank, h_kv, d_head)

    # Q stays put on this rank — loaded once, used in every partial attention.
    Q = tl.load(q_ptr, shape=(S_q, h_q * d_head), dtype="f16")
    # Local K, V chunk.
    K = tl.load(k_ptr, shape=kv_shape, dtype="f16")
    V = tl.load(v_ptr, shape=kv_shape, dtype="f16")

    # Step 0 (local): partial attention against own K, V — initializes the
    # running triplet (m, ℓ, o).
    running = _partial_attention(
        Q, K, V, S_q, S_kv_per_rank, h_q, d_head, tl,
    )

    # Seed bidirectional waves with own chunk (step-1 send).
    to_send_east_K: TensorHandle | None = K
    to_send_east_V: TensorHandle | None = V
    to_send_west_K: TensorHandle | None = K
    to_send_west_V: TensorHandle | None = V

    # Bidirectional fan-out: n_ranks - 1 steps. By step k, the wave from
    # rank i has reached rank (i ± k). After n_ranks - 1 steps, every rank
    # has merged every other rank's chunk exactly once (ADR-0059 D3).
    for step in range(1, n_ranks):
        # Send the waves we currently hold (own at step 1; forwarded at later
        # steps). ``None`` means there is nothing to forward in that direction
        # this step (edge rank, or the wave already passed by).
        if has_E and to_send_east_K is not None:
            tl.send(dir="E", src=to_send_east_K)
            tl.send(dir="E", src=to_send_east_V)
        if has_W and to_send_west_K is not None:
            tl.send(dir="W", src=to_send_west_K)
            tl.send(dir="W", src=to_send_west_V)

        # Receive eastbound wave from W (carries chunk c_{rank - step}).
        K_from_W: TensorHandle | None = None
        V_from_W: TensorHandle | None = None
        if has_W and (rank - step) >= 0:
            K_from_W = tl.recv(dir="W", shape=kv_shape, dtype="f16")
            V_from_W = tl.recv(dir="W", shape=kv_shape, dtype="f16")
            incoming = _partial_attention(
                Q, K_from_W, V_from_W, S_q, S_kv_per_rank, h_q, d_head, tl,
            )
            running = _merge_partial(running, incoming, tl)

        # Receive westbound wave from E (carries chunk c_{rank + step}).
        K_from_E: TensorHandle | None = None
        V_from_E: TensorHandle | None = None
        if has_E and (rank + step) < n_ranks:
            K_from_E = tl.recv(dir="E", shape=kv_shape, dtype="f16")
            V_from_E = tl.recv(dir="E", shape=kv_shape, dtype="f16")
            incoming = _partial_attention(
                Q, K_from_E, V_from_E, S_q, S_kv_per_rank, h_q, d_head, tl,
            )
            running = _merge_partial(running, incoming, tl)

        # Forward what we received for next step. ``None`` propagates: if no
        # chunk arrived this step (out-of-bounds wave origin), there is
        # nothing to forward next step in that direction.
        to_send_east_K = K_from_W
        to_send_east_V = V_from_W
        to_send_west_K = K_from_E
        to_send_west_V = V_from_E

    # Final normalize: O := o / ℓ.
    # NOTE(review): this division assumes ``o`` accumulates un-normalized
    # partials, while ``_partial_attention`` builds its output from
    # ``tl.softmax`` — confirm the two agree.
    _, ell, o = running
    tl.store(o_ptr, o / ell)