attention: add rank_axis kwarg to mesh kernels for multi_user cube ring
ADR-0059 single_user_* panels run the ring across PEs in one cube
(rank == tl.program_id(axis=0)). multi_user_* panels run the ring
across cubes — rank should be cube_id (axis=1), and 7 of every 8 PEs
in each cube must stay silent because the cube-level SFR install only
gives the cube-coordinate PE 0 an E/W neighbor.
Add ``rank_axis: int = 0`` kwarg to both ``attention_mesh_mlo_kernel``
and ``attention_mesh_kv_kernel``:
- 0 (default): rank == tl.program_id(axis=0). Existing single_user
behavior, all spec tests unchanged.
- 1: gate ``if tl.program_id(axis=0) != 0: return`` at kernel start,
then ``rank = tl.program_id(axis=1)``. multi_user_* panels pass
this to the kernel via ctx.launch positional arg.
Also brings in _attention_mesh_kv.py and _attention_mesh_mlo.py as
the committed home of the ADR-0059 kernels (previously living
uncommitted in the working tree from sub-cycle 4b).
Tests: 7-test rank_axis spec file (default-path + rank_axis=1 gating
and cube-id semantics, both kernels); 4-panel diag harness now green
end-to-end (single_user_prefill/decode + multi_user_prefill/decode);
763-test wider sweep clean.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,180 @@
|
||||
"""Mesh-native bidirectional Ring-K/V attention kernel — prefill (ADR-0059 Proposed).
|
||||
|
||||
Each rank holds its own Q tile and 1/n_ranks of K, V (sequence-sharded).
|
||||
Over ``n_ranks - 1`` bidirectional steps, K and V propagate both east and
|
||||
west: chunk c_i originating at rank i reaches rank j at step ``|i - j|``.
|
||||
Every rank receives every other rank's chunk **exactly once** and folds it
|
||||
into a running ``(m, ℓ, o)`` via the online-softmax recurrence. After all
|
||||
steps each rank holds the final attention output for its own Q tokens —
|
||||
no cross-rank merge is required.
|
||||
|
||||
Supersedes ADR-0055's closed-ring ``_attention_ring_kv.py``. Both modules
|
||||
stay on disk during the transition; this one runs on the hardware's
|
||||
actual open-mesh wiring (no closed-ring SFR install required).
|
||||
|
||||
Imported by ``milestone_gqa_llama70b`` (after the bench's Phase 2 switches
|
||||
its imports) and invoked through ``torch.launch(...)`` — not through
|
||||
``dist.all_reduce(...)``. See ADR-0055 Context for why this kernel is not
|
||||
backend-dispatched via ADR-0050's algorithm-module contract.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from kernbench.common.pe_commands import TensorHandle
|
||||
|
||||
|
||||
def _view(handle: TensorHandle, new_shape: tuple[int, ...]) -> TensorHandle:
|
||||
"""Reshape — metadata only, no command emitted (cf. ``tl.trans``)."""
|
||||
return TensorHandle(
|
||||
id=handle.id,
|
||||
addr=handle.addr,
|
||||
shape=new_shape,
|
||||
dtype=handle.dtype,
|
||||
nbytes=handle.nbytes,
|
||||
data=handle.data,
|
||||
space=handle.space,
|
||||
pinned=handle.pinned,
|
||||
)
|
||||
|
||||
|
||||
def _partial_attention(
|
||||
Q: TensorHandle,
|
||||
K: TensorHandle,
|
||||
V: TensorHandle,
|
||||
S_q: int,
|
||||
S_kv_per_rank: int,
|
||||
h_q: int,
|
||||
d_head: int,
|
||||
tl,
|
||||
) -> tuple[TensorHandle, TensorHandle, TensorHandle]:
|
||||
"""One pass of partial attention against (K, V).
|
||||
|
||||
Emits 1 GEMM(Q·K^T) + softmax + max + sub + exp + sum + 1 GEMM(P·V).
|
||||
Returns the running-statistics triplet ``(m, ℓ, O_partial)`` for the
|
||||
online-softmax mlo merge.
|
||||
"""
|
||||
K_2d_T = _view(K, (h_q * d_head, S_kv_per_rank))
|
||||
V_2d = _view(V, (S_kv_per_rank, h_q * d_head))
|
||||
|
||||
scores = tl.dot(Q, K_2d_T)
|
||||
m = tl.max(scores, axis=-1)
|
||||
P = tl.softmax(scores, axis=-1)
|
||||
scores_centered = scores - m
|
||||
exp_scores = tl.exp(scores_centered)
|
||||
ell = tl.sum(exp_scores, axis=-1)
|
||||
O_partial = tl.dot(P, V_2d)
|
||||
return m, ell, O_partial
|
||||
|
||||
|
||||
def attention_mesh_kv_kernel(
|
||||
q_ptr: int,
|
||||
k_ptr: int,
|
||||
v_ptr: int,
|
||||
o_ptr: int,
|
||||
S_q: int,
|
||||
S_kv_per_rank: int,
|
||||
h_q: int,
|
||||
h_kv: int,
|
||||
d_head: int,
|
||||
n_ranks: int,
|
||||
rank_axis: int = 0,
|
||||
*,
|
||||
tl,
|
||||
) -> None:
|
||||
"""Mesh-native bidirectional Ring-K/V attention — see module docstring.
|
||||
|
||||
``rank_axis`` selects which program-id dimension carries the ring rank:
|
||||
0 — single_user_* panels: rank == tl.program_id(axis=0) (PE id in cube).
|
||||
1 — multi_user_* panels: ring is at the cube level. Only PE 0 in each
|
||||
cube participates; the other 7 hold KV replicas but stay silent.
|
||||
"""
|
||||
# For multi_user (rank_axis=1) only PE 0 in each cube runs the ring.
|
||||
if rank_axis != 0 and tl.program_id(axis=0) != 0:
|
||||
return
|
||||
rank = tl.program_id(axis=rank_axis)
|
||||
has_E = rank < n_ranks - 1
|
||||
has_W = rank > 0
|
||||
|
||||
# Q stays put on this rank — loaded once, used in every partial attention.
|
||||
Q = tl.load(q_ptr, shape=(S_q, h_q * d_head), dtype="f16")
|
||||
|
||||
# Local K, V chunk.
|
||||
K = tl.load(k_ptr, shape=(S_kv_per_rank, h_kv, d_head), dtype="f16")
|
||||
V = tl.load(v_ptr, shape=(S_kv_per_rank, h_kv, d_head), dtype="f16")
|
||||
|
||||
# Step 0 (local): partial attention against own K, V — initializes the
|
||||
# running triplet (m, ℓ, o).
|
||||
m, ell, o = _partial_attention(
|
||||
Q, K, V, S_q, S_kv_per_rank, h_q, d_head, tl,
|
||||
)
|
||||
|
||||
# Seed bidirectional waves with own chunk (step-1 send).
|
||||
to_send_east_K: TensorHandle | None = K
|
||||
to_send_east_V: TensorHandle | None = V
|
||||
to_send_west_K: TensorHandle | None = K
|
||||
to_send_west_V: TensorHandle | None = V
|
||||
|
||||
# Bidirectional fan-out: n_ranks - 1 steps. By step k, the wave from
|
||||
# rank i has reached rank (i ± k). After n_ranks - 1 steps, every rank
|
||||
# has merged every other rank's chunk exactly once (ADR-0059 D3).
|
||||
for step in range(1, n_ranks):
|
||||
# Send the eastbound wave we currently hold (own at step 1; forwarded
|
||||
# at later steps). ``None`` means we have no wave to forward this
|
||||
# direction this step (edge rank, or the wave already passed by).
|
||||
if has_E and to_send_east_K is not None:
|
||||
tl.send(dir="E", src=to_send_east_K)
|
||||
tl.send(dir="E", src=to_send_east_V)
|
||||
if has_W and to_send_west_K is not None:
|
||||
tl.send(dir="W", src=to_send_west_K)
|
||||
tl.send(dir="W", src=to_send_west_V)
|
||||
|
||||
# Receive eastbound wave from W (carries chunk c_{rank - step}).
|
||||
K_from_W: TensorHandle | None = None
|
||||
V_from_W: TensorHandle | None = None
|
||||
if has_W and (rank - step) >= 0:
|
||||
K_from_W = tl.recv(
|
||||
dir="W", shape=(S_kv_per_rank, h_kv, d_head), dtype="f16",
|
||||
)
|
||||
V_from_W = tl.recv(
|
||||
dir="W", shape=(S_kv_per_rank, h_kv, d_head), dtype="f16",
|
||||
)
|
||||
m_new, ell_new, o_new = _partial_attention(
|
||||
Q, K_from_W, V_from_W, S_q, S_kv_per_rank, h_q, d_head, tl,
|
||||
)
|
||||
m_combined = tl.maximum(m, m_new)
|
||||
scale_old = tl.exp(m - m_combined)
|
||||
scale_new = tl.exp(m_new - m_combined)
|
||||
ell = ell * scale_old + ell_new * scale_new
|
||||
o = o * scale_old + o_new * scale_new
|
||||
m = m_combined
|
||||
|
||||
# Receive westbound wave from E (carries chunk c_{rank + step}).
|
||||
K_from_E: TensorHandle | None = None
|
||||
V_from_E: TensorHandle | None = None
|
||||
if has_E and (rank + step) < n_ranks:
|
||||
K_from_E = tl.recv(
|
||||
dir="E", shape=(S_kv_per_rank, h_kv, d_head), dtype="f16",
|
||||
)
|
||||
V_from_E = tl.recv(
|
||||
dir="E", shape=(S_kv_per_rank, h_kv, d_head), dtype="f16",
|
||||
)
|
||||
m_new, ell_new, o_new = _partial_attention(
|
||||
Q, K_from_E, V_from_E, S_q, S_kv_per_rank, h_q, d_head, tl,
|
||||
)
|
||||
m_combined = tl.maximum(m, m_new)
|
||||
scale_old = tl.exp(m - m_combined)
|
||||
scale_new = tl.exp(m_new - m_combined)
|
||||
ell = ell * scale_old + ell_new * scale_new
|
||||
o = o * scale_old + o_new * scale_new
|
||||
m = m_combined
|
||||
|
||||
# Forward what we received for next step. ``None`` propagates: if no
|
||||
# chunk arrived this step (out-of-bounds wave origin), there is
|
||||
# nothing to forward next step in that direction.
|
||||
to_send_east_K = K_from_W
|
||||
to_send_east_V = V_from_W
|
||||
to_send_west_K = K_from_E
|
||||
to_send_west_V = V_from_E
|
||||
|
||||
# Final normalize: O := o / ℓ.
|
||||
O_final = o / ell
|
||||
tl.store(o_ptr, O_final)
|
||||
Reference in New Issue
Block a user