attention: add rank_axis kwarg to mesh kernels for multi_user cube ring

ADR-0059 single_user_* panels run the ring across PEs in one cube
(rank == tl.program_id(axis=0)). multi_user_* panels run the ring
across cubes — rank should be cube_id (axis=1), and 7 of every 8 PEs
in each cube must stay silent because the cube-level SFR install only
gives the cube-coordinate PE 0 an E/W neighbor.

Add ``rank_axis: int = 0`` kwarg to both ``attention_mesh_mlo_kernel``
and ``attention_mesh_kv_kernel``:
  - 0 (default): rank == tl.program_id(axis=0). Existing single_user
    behavior, all spec tests unchanged.
  - 1: gate ``if tl.program_id(axis=0) != 0: return`` at kernel start,
    then ``rank = tl.program_id(axis=1)``. multi_user_* panels pass
    this to the kernel via ctx.launch positional arg.

Also brings in _attention_mesh_kv.py and _attention_mesh_mlo.py as
the committed home of the ADR-0059 kernels (previously living
uncommitted in the working tree from sub-cycle 4b).

Tests: 7-test rank_axis spec file (default-path + rank_axis=1 gating
and cube-id semantics, both kernels); 4-panel diag harness now green
end-to-end (single_user_prefill/decode + multi_user_prefill/decode);
763-test wider sweep clean.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-06-01 19:53:18 -07:00
parent d9e767d048
commit 222815d374
4 changed files with 505 additions and 0 deletions
+180
View File
@@ -0,0 +1,180 @@
"""Mesh-native bidirectional Ring-K/V attention kernel — prefill (ADR-0059 Proposed).
Each rank holds its own Q tile and 1/n_ranks of K, V (sequence-sharded).
Over ``n_ranks - 1`` bidirectional steps, K and V propagate both east and
west: chunk c_i originating at rank i reaches rank j at step ``|i - j|``.
Every rank receives every other rank's chunk **exactly once** and folds it
into a running ``(m, , o)`` via the online-softmax recurrence. After all
steps each rank holds the final attention output for its own Q tokens —
no cross-rank merge is required.
Supersedes ADR-0055's closed-ring ``_attention_ring_kv.py``. Both modules
stay on disk during the transition; this one runs on the hardware's
actual open-mesh wiring (no closed-ring SFR install required).
Imported by ``milestone_gqa_llama70b`` (after the bench's Phase 2 switches
its imports) and invoked through ``torch.launch(...)`` — not through
``dist.all_reduce(...)``. See ADR-0055 Context for why this kernel is not
backend-dispatched via ADR-0050's algorithm-module contract.
"""
from __future__ import annotations
from kernbench.common.pe_commands import TensorHandle
def _view(handle: TensorHandle, new_shape: tuple[int, ...]) -> TensorHandle:
"""Reshape — metadata only, no command emitted (cf. ``tl.trans``)."""
return TensorHandle(
id=handle.id,
addr=handle.addr,
shape=new_shape,
dtype=handle.dtype,
nbytes=handle.nbytes,
data=handle.data,
space=handle.space,
pinned=handle.pinned,
)
def _partial_attention(
Q: TensorHandle,
K: TensorHandle,
V: TensorHandle,
S_q: int,
S_kv_per_rank: int,
h_q: int,
d_head: int,
tl,
) -> tuple[TensorHandle, TensorHandle, TensorHandle]:
"""One pass of partial attention against (K, V).
Emits 1 GEMM(Q·K^T) + softmax + max + sub + exp + sum + 1 GEMM(P·V).
Returns the running-statistics triplet ``(m, , O_partial)`` for the
online-softmax mlo merge.
"""
K_2d_T = _view(K, (h_q * d_head, S_kv_per_rank))
V_2d = _view(V, (S_kv_per_rank, h_q * d_head))
scores = tl.dot(Q, K_2d_T)
m = tl.max(scores, axis=-1)
P = tl.softmax(scores, axis=-1)
scores_centered = scores - m
exp_scores = tl.exp(scores_centered)
ell = tl.sum(exp_scores, axis=-1)
O_partial = tl.dot(P, V_2d)
return m, ell, O_partial
def attention_mesh_kv_kernel(
q_ptr: int,
k_ptr: int,
v_ptr: int,
o_ptr: int,
S_q: int,
S_kv_per_rank: int,
h_q: int,
h_kv: int,
d_head: int,
n_ranks: int,
rank_axis: int = 0,
*,
tl,
) -> None:
"""Mesh-native bidirectional Ring-K/V attention — see module docstring.
``rank_axis`` selects which program-id dimension carries the ring rank:
0 — single_user_* panels: rank == tl.program_id(axis=0) (PE id in cube).
1 — multi_user_* panels: ring is at the cube level. Only PE 0 in each
cube participates; the other 7 hold KV replicas but stay silent.
"""
# For multi_user (rank_axis=1) only PE 0 in each cube runs the ring.
if rank_axis != 0 and tl.program_id(axis=0) != 0:
return
rank = tl.program_id(axis=rank_axis)
has_E = rank < n_ranks - 1
has_W = rank > 0
# Q stays put on this rank — loaded once, used in every partial attention.
Q = tl.load(q_ptr, shape=(S_q, h_q * d_head), dtype="f16")
# Local K, V chunk.
K = tl.load(k_ptr, shape=(S_kv_per_rank, h_kv, d_head), dtype="f16")
V = tl.load(v_ptr, shape=(S_kv_per_rank, h_kv, d_head), dtype="f16")
# Step 0 (local): partial attention against own K, V — initializes the
# running triplet (m, , o).
m, ell, o = _partial_attention(
Q, K, V, S_q, S_kv_per_rank, h_q, d_head, tl,
)
# Seed bidirectional waves with own chunk (step-1 send).
to_send_east_K: TensorHandle | None = K
to_send_east_V: TensorHandle | None = V
to_send_west_K: TensorHandle | None = K
to_send_west_V: TensorHandle | None = V
# Bidirectional fan-out: n_ranks - 1 steps. By step k, the wave from
# rank i has reached rank (i ± k). After n_ranks - 1 steps, every rank
# has merged every other rank's chunk exactly once (ADR-0059 D3).
for step in range(1, n_ranks):
# Send the eastbound wave we currently hold (own at step 1; forwarded
# at later steps). ``None`` means we have no wave to forward this
# direction this step (edge rank, or the wave already passed by).
if has_E and to_send_east_K is not None:
tl.send(dir="E", src=to_send_east_K)
tl.send(dir="E", src=to_send_east_V)
if has_W and to_send_west_K is not None:
tl.send(dir="W", src=to_send_west_K)
tl.send(dir="W", src=to_send_west_V)
# Receive eastbound wave from W (carries chunk c_{rank - step}).
K_from_W: TensorHandle | None = None
V_from_W: TensorHandle | None = None
if has_W and (rank - step) >= 0:
K_from_W = tl.recv(
dir="W", shape=(S_kv_per_rank, h_kv, d_head), dtype="f16",
)
V_from_W = tl.recv(
dir="W", shape=(S_kv_per_rank, h_kv, d_head), dtype="f16",
)
m_new, ell_new, o_new = _partial_attention(
Q, K_from_W, V_from_W, S_q, S_kv_per_rank, h_q, d_head, tl,
)
m_combined = tl.maximum(m, m_new)
scale_old = tl.exp(m - m_combined)
scale_new = tl.exp(m_new - m_combined)
ell = ell * scale_old + ell_new * scale_new
o = o * scale_old + o_new * scale_new
m = m_combined
# Receive westbound wave from E (carries chunk c_{rank + step}).
K_from_E: TensorHandle | None = None
V_from_E: TensorHandle | None = None
if has_E and (rank + step) < n_ranks:
K_from_E = tl.recv(
dir="E", shape=(S_kv_per_rank, h_kv, d_head), dtype="f16",
)
V_from_E = tl.recv(
dir="E", shape=(S_kv_per_rank, h_kv, d_head), dtype="f16",
)
m_new, ell_new, o_new = _partial_attention(
Q, K_from_E, V_from_E, S_q, S_kv_per_rank, h_q, d_head, tl,
)
m_combined = tl.maximum(m, m_new)
scale_old = tl.exp(m - m_combined)
scale_new = tl.exp(m_new - m_combined)
ell = ell * scale_old + ell_new * scale_new
o = o * scale_old + o_new * scale_new
m = m_combined
# Forward what we received for next step. ``None`` propagates: if no
# chunk arrived this step (out-of-bounds wave origin), there is
# nothing to forward next step in that direction.
to_send_east_K = K_from_W
to_send_east_V = V_from_W
to_send_west_K = K_from_E
to_send_west_V = V_from_E
# Final normalize: O := o / .
O_final = o / ell
tl.store(o_ptr, O_final)
@@ -0,0 +1,151 @@
"""Mesh-native bidirectional AllReduce-mlo attention — decode (ADR-0059 Proposed).
Every rank holds the full Q (replicated, small at ``S_q=1``) and 1/n_ranks
of KV (sequence-sharded). Each rank computes its partial attention
against own KV in ONE shot, then runs a bidirectional fan-out of the
``(m, , o)`` triplet: the triplet originating at rank i reaches rank j at
step ``|i - j|``. Every rank merges every other rank's triplet exactly
once over ``n_ranks - 1`` steps, ending with the final answer replicated
on every rank.
Supersedes ADR-0056's closed-ring ``_attention_allreduce_mlo.py``. Both
modules stay on disk during the transition; this one runs on the
hardware's actual open-mesh wiring (no closed-ring SFR install required).
Imported by ``milestone_gqa_llama70b`` (after the bench's Phase 2 switches
its imports) and invoked through ``torch.launch(...)`` — not through
``dist.all_reduce(...)``. See ADR-0056 Context for why this kernel is not
backend-dispatched via ADR-0050's algorithm-module contract.
"""
from __future__ import annotations
from kernbench.common.pe_commands import TensorHandle
def _view(handle: TensorHandle, new_shape: tuple[int, ...]) -> TensorHandle:
"""Reshape — metadata only, no command emitted (cf. ``tl.trans``)."""
return TensorHandle(
id=handle.id,
addr=handle.addr,
shape=new_shape,
dtype=handle.dtype,
nbytes=handle.nbytes,
data=handle.data,
space=handle.space,
pinned=handle.pinned,
)
def attention_mesh_mlo_kernel(
q_ptr: int,
k_ptr: int,
v_ptr: int,
o_ptr: int,
S_q: int,
S_kv_per_rank: int,
h_q: int,
h_kv: int,
d_head: int,
n_ranks: int,
rank_axis: int = 0,
*,
tl,
) -> None:
"""Mesh-native bidirectional AllReduce-mlo — see module docstring.
``rank_axis`` selects which program-id dimension carries the ring rank:
0 — single_user_* panels: rank == tl.program_id(axis=0) (PE id in cube).
1 — multi_user_* panels: ring is at the cube level. Only PE 0 in each
cube participates; the other 7 hold KV replicas but stay silent.
"""
# For multi_user (rank_axis=1) only PE 0 in each cube runs the ring.
if rank_axis != 0 and tl.program_id(axis=0) != 0:
return
rank = tl.program_id(axis=rank_axis)
has_E = rank < n_ranks - 1
has_W = rank > 0
# Q is replicated on every rank — loaded once.
Q = tl.load(q_ptr, shape=(S_q, h_q * d_head), dtype="f16")
# Local KV chunk. KV is sequence-sharded and stays put on this rank for
# the entire fan-out — distinguishing decode from prefill (ADR-0059 D3)
# where KV circulates.
K = tl.load(k_ptr, shape=(S_kv_per_rank, h_kv, d_head), dtype="f16")
V = tl.load(v_ptr, shape=(S_kv_per_rank, h_kv, d_head), dtype="f16")
# ── One-shot local partial attention ──────────────────────────
K_2d_T = _view(K, (h_q * d_head, S_kv_per_rank))
V_2d = _view(V, (S_kv_per_rank, h_q * d_head))
scores = tl.dot(Q, K_2d_T)
m = tl.max(scores, axis=-1)
P = tl.softmax(scores, axis=-1)
scores_centered = scores - m
exp_scores = tl.exp(scores_centered)
ell = tl.sum(exp_scores, axis=-1)
o = tl.dot(P, V_2d)
# Seed bidirectional waves with own triplet (step-1 send).
to_send_east_m: TensorHandle | None = m
to_send_east_ell: TensorHandle | None = ell
to_send_east_o: TensorHandle | None = o
to_send_west_m: TensorHandle | None = m
to_send_west_ell: TensorHandle | None = ell
to_send_west_o: TensorHandle | None = o
# Bidirectional fan-out of (m, , o) triplets — n_ranks - 1 steps.
for step in range(1, n_ranks):
# Send eastbound triplet (own at step 1; forwarded at later steps).
if has_E and to_send_east_m is not None:
tl.send(dir="E", src=to_send_east_m)
tl.send(dir="E", src=to_send_east_ell)
tl.send(dir="E", src=to_send_east_o)
# Send westbound triplet.
if has_W and to_send_west_m is not None:
tl.send(dir="W", src=to_send_west_m)
tl.send(dir="W", src=to_send_west_ell)
tl.send(dir="W", src=to_send_west_o)
# Receive eastbound triplet from W (originated at rank - step).
m_from_W: TensorHandle | None = None
ell_from_W: TensorHandle | None = None
o_from_W: TensorHandle | None = None
if has_W and (rank - step) >= 0:
m_from_W = tl.recv(dir="W", shape=m.shape, dtype="f16")
ell_from_W = tl.recv(dir="W", shape=ell.shape, dtype="f16")
o_from_W = tl.recv(dir="W", shape=o.shape, dtype="f16")
m_combined = tl.maximum(m, m_from_W)
scale_old = tl.exp(m - m_combined)
scale_new = tl.exp(m_from_W - m_combined)
ell = ell * scale_old + ell_from_W * scale_new
o = o * scale_old + o_from_W * scale_new
m = m_combined
# Receive westbound triplet from E (originated at rank + step).
m_from_E: TensorHandle | None = None
ell_from_E: TensorHandle | None = None
o_from_E: TensorHandle | None = None
if has_E and (rank + step) < n_ranks:
m_from_E = tl.recv(dir="E", shape=m.shape, dtype="f16")
ell_from_E = tl.recv(dir="E", shape=ell.shape, dtype="f16")
o_from_E = tl.recv(dir="E", shape=o.shape, dtype="f16")
m_combined = tl.maximum(m, m_from_E)
scale_old = tl.exp(m - m_combined)
scale_new = tl.exp(m_from_E - m_combined)
ell = ell * scale_old + ell_from_E * scale_new
o = o * scale_old + o_from_E * scale_new
m = m_combined
# Forward the original received triplet (not the merged running state)
# so neighbors get the original wave. ``None`` propagates if nothing
# arrived this step.
to_send_east_m = m_from_W
to_send_east_ell = ell_from_W
to_send_east_o = o_from_W
to_send_west_m = m_from_E
to_send_west_ell = ell_from_E
to_send_west_o = o_from_E
# Final normalize: O := o / .
O_final = o / ell
tl.store(o_ptr, O_final)
@@ -149,6 +149,7 @@ def _bench_fn_multi_user_prefill(ctx):
"multi_user_prefill_mesh", attention_mesh_kv_kernel,
q, k, v, o,
S_Q_PREFILL, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, n,
1, # rank_axis=1 → ring at cube level (ADR-0059 multi_user)
_auto_dim_remap=False,
)
@@ -169,6 +170,7 @@ def _bench_fn_multi_user_decode(ctx):
"multi_user_decode_mesh", attention_mesh_mlo_kernel,
q, k, v, o,
S_Q_DECODE, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, n,
1, # rank_axis=1 → ring at cube level (ADR-0059 multi_user)
_auto_dim_remap=False,
)
@@ -0,0 +1,172 @@
"""Phase 1 spec test for ``rank_axis`` parameter on the two mesh kernels.
ADR-0059's mesh kernels currently hard-code ``rank = tl.program_id(axis=0)``,
which only works for single_user_* panels (rank == pe_id within cube).
For multi_user_* panels the ring is at the cube level — rank should be
``cube_id`` (axis=1), and the 7 non-rank-leader PEs in each cube should
not run the ring (they only hold KV replicas).
This test pins the desired ``rank_axis`` kwarg semantics:
rank_axis = 0 (default, single_user)
rank = tl.program_id(axis=0). Every PE in the cube runs the ring.
Existing behavior — no change.
rank_axis = 1 (multi_user)
if tl.program_id(axis=0) != 0: return. (7/8 PEs early-exit.)
rank = tl.program_id(axis=1).
Phase 1 expectation: tests fail today (kernels don't accept the kwarg).
Phase 2 lands the parameter on both kernels; tests turn green and the
multi_user_* diag harness clears its first send.
"""
from __future__ import annotations
from kernbench.common.ipcq_types import IpcqRecvCmd, IpcqSendCmd
from kernbench.common.pe_commands import GemmCmd
from kernbench.triton_emu.tl_context import TLContext, run_kernel
from kernbench.benches._attention_mesh_kv import attention_mesh_kv_kernel
from kernbench.benches._attention_mesh_mlo import attention_mesh_mlo_kernel
S_Q_PREFILL = 16
S_Q_DECODE = 1
S_KV_PER_RANK = 16
H_Q = 1
H_KV = 1
D_HEAD = 64
N_RANKS_MULTI = 4
PES_PER_CUBE = 8
Q_PTR = 0x10000
K_PTR = 0x20000
V_PTR = 0x30000
O_PTR = 0x40000
def _tl(pe_id: int, cube_id: int, num_pes: int, num_cubes: int) -> TLContext:
return TLContext(
pe_id=pe_id,
num_programs=num_pes,
cube_id=cube_id,
num_cubes=num_cubes,
dispatch_cycles=0,
scratch_base=0x80000,
scratch_size=1 << 20,
)
# ── Default rank_axis=0 backward-compat ──────────────────────────
def test_mlo_kernel_default_rank_axis_zero_emits_commands_on_all_pes():
"""rank_axis defaults to 0 → kernel uses pe_id as rank, runs on every
PE. Verify by running rank=3 (interior PE) in a single-cube 8-rank
setup and asserting at least one GEMM and at least one IPCQ send
are emitted (interior ranks send in both directions)."""
tl = _tl(pe_id=3, cube_id=0, num_pes=8, num_cubes=1)
run_kernel(
attention_mesh_mlo_kernel, tl,
Q_PTR, K_PTR, V_PTR, O_PTR,
S_Q_DECODE, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, 8,
)
assert any(isinstance(c, GemmCmd) for c in tl.commands), \
"default rank_axis=0 must run the kernel (≥1 GEMM)"
assert any(isinstance(c, IpcqSendCmd) for c in tl.commands), \
"interior rank must emit ≥1 IpcqSendCmd"
def test_kv_kernel_default_rank_axis_zero_emits_commands_on_all_pes():
tl = _tl(pe_id=3, cube_id=0, num_pes=8, num_cubes=1)
run_kernel(
attention_mesh_kv_kernel, tl,
Q_PTR, K_PTR, V_PTR, O_PTR,
S_Q_PREFILL, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, 8,
)
assert any(isinstance(c, GemmCmd) for c in tl.commands)
assert any(isinstance(c, IpcqSendCmd) for c in tl.commands)
# ── rank_axis=1 multi_user semantics ─────────────────────────────
def test_mlo_kernel_rank_axis_one_gates_non_zero_pe_to_no_commands():
"""rank_axis=1 + pe_id != 0 → kernel must early-return; no GEMM,
no DMA, no IPCQ. The 7 non-rank-leader PEs in a multi_user cube
must stay completely silent so the cube-level SFR install isn't
asked to route sends from PEs that have no neighbors installed."""
tl = _tl(pe_id=2, cube_id=1, num_pes=PES_PER_CUBE, num_cubes=N_RANKS_MULTI)
run_kernel(
attention_mesh_mlo_kernel, tl,
Q_PTR, K_PTR, V_PTR, O_PTR,
S_Q_DECODE, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, N_RANKS_MULTI,
rank_axis=1,
)
assert not any(isinstance(c, GemmCmd) for c in tl.commands), \
"pe_id=2 with rank_axis=1 must not emit GEMMs"
assert not any(isinstance(c, IpcqSendCmd) for c in tl.commands), \
"pe_id=2 with rank_axis=1 must not emit IpcqSendCmd"
assert not any(isinstance(c, IpcqRecvCmd) for c in tl.commands), \
"pe_id=2 with rank_axis=1 must not emit IpcqRecvCmd"
def test_kv_kernel_rank_axis_one_gates_non_zero_pe_to_no_commands():
tl = _tl(pe_id=2, cube_id=1, num_pes=PES_PER_CUBE, num_cubes=N_RANKS_MULTI)
run_kernel(
attention_mesh_kv_kernel, tl,
Q_PTR, K_PTR, V_PTR, O_PTR,
S_Q_PREFILL, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, N_RANKS_MULTI,
rank_axis=1,
)
assert not any(isinstance(c, GemmCmd) for c in tl.commands)
assert not any(isinstance(c, IpcqSendCmd) for c in tl.commands)
assert not any(isinstance(c, IpcqRecvCmd) for c in tl.commands)
def test_mlo_kernel_rank_axis_one_pe_zero_uses_cube_id_as_rank():
"""rank_axis=1 + pe_id == 0 → kernel runs the ring with rank=cube_id.
For cube_id=1 in a 4-cube ring, rank=1 is an interior rank: has_E=True
AND has_W=True → IPCQ sends emitted in both E and W directions.
"""
tl = _tl(pe_id=0, cube_id=1, num_pes=PES_PER_CUBE, num_cubes=N_RANKS_MULTI)
run_kernel(
attention_mesh_mlo_kernel, tl,
Q_PTR, K_PTR, V_PTR, O_PTR,
S_Q_DECODE, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, N_RANKS_MULTI,
rank_axis=1,
)
sends = [c for c in tl.commands if isinstance(c, IpcqSendCmd)]
assert any(s.direction == "E" for s in sends), \
"cube_id=1 (interior) must emit ≥1 E-send"
assert any(s.direction == "W" for s in sends), \
"cube_id=1 (interior) must emit ≥1 W-send"
def test_kv_kernel_rank_axis_one_pe_zero_uses_cube_id_as_rank():
tl = _tl(pe_id=0, cube_id=1, num_pes=PES_PER_CUBE, num_cubes=N_RANKS_MULTI)
run_kernel(
attention_mesh_kv_kernel, tl,
Q_PTR, K_PTR, V_PTR, O_PTR,
S_Q_PREFILL, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, N_RANKS_MULTI,
rank_axis=1,
)
sends = [c for c in tl.commands if isinstance(c, IpcqSendCmd)]
assert any(s.direction == "E" for s in sends)
assert any(s.direction == "W" for s in sends)
def test_mlo_kernel_rank_axis_one_west_edge_cube_no_west_sends():
"""cube_id=0 (west edge) with rank_axis=1: rank=0, has_W=False → no
W-direction IPCQ sends. has_E=True → ≥1 E-direction send."""
tl = _tl(pe_id=0, cube_id=0, num_pes=PES_PER_CUBE, num_cubes=N_RANKS_MULTI)
run_kernel(
attention_mesh_mlo_kernel, tl,
Q_PTR, K_PTR, V_PTR, O_PTR,
S_Q_DECODE, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, N_RANKS_MULTI,
rank_axis=1,
)
sends = [c for c in tl.commands if isinstance(c, IpcqSendCmd)]
assert any(s.direction == "E" for s in sends), \
"west-edge cube_id=0 must still emit ≥1 E-send"
assert not any(s.direction == "W" for s in sends), \
"west-edge cube_id=0 must NOT emit any W-send (no W neighbor)"