attention: add rank_axis kwarg to mesh kernels for multi_user cube ring

ADR-0059 single_user_* panels run the ring across PEs in one cube
(rank == tl.program_id(axis=0)). multi_user_* panels run the ring
across cubes — rank should be cube_id (axis=1), and 7 of every 8 PEs
in each cube must stay silent because the cube-level SFR install only
gives the cube-coordinate PE 0 an E/W neighbor.

Add ``rank_axis: int = 0`` kwarg to both ``attention_mesh_mlo_kernel``
and ``attention_mesh_kv_kernel``:
  - 0 (default): rank == tl.program_id(axis=0). Existing single_user
    behavior, all spec tests unchanged.
  - 1: gate ``if tl.program_id(axis=0) != 0: return`` at kernel start,
    then ``rank = tl.program_id(axis=1)``. multi_user_* panels pass
    this to the kernel via ctx.launch positional arg.

Also brings in _attention_mesh_kv.py and _attention_mesh_mlo.py as
the committed home of the ADR-0059 kernels (previously living
uncommitted in the working tree from sub-cycle 4b).

Tests: 7-test rank_axis spec file (default-path + rank_axis=1 gating
and cube-id semantics, both kernels); 4-panel diag harness now green
end-to-end (single_user_prefill/decode + multi_user_prefill/decode);
763-test wider sweep clean.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-06-01 19:53:18 -07:00
parent d9e767d048
commit 222815d374
4 changed files with 505 additions and 0 deletions
@@ -149,6 +149,7 @@ def _bench_fn_multi_user_prefill(ctx):
"multi_user_prefill_mesh", attention_mesh_kv_kernel,
q, k, v, o,
S_Q_PREFILL, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, n,
1, # rank_axis=1 → ring at cube level (ADR-0059 multi_user)
_auto_dim_remap=False,
)
@@ -169,6 +170,7 @@ def _bench_fn_multi_user_decode(ctx):
"multi_user_decode_mesh", attention_mesh_mlo_kernel,
q, k, v, o,
S_Q_DECODE, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, n,
1, # rank_axis=1 → ring at cube level (ADR-0059 multi_user)
_auto_dim_remap=False,
)
@@ -0,0 +1,172 @@
"""Phase 1 spec test for ``rank_axis`` parameter on the two mesh kernels.
ADR-0059's mesh kernels currently hard-code ``rank = tl.program_id(axis=0)``,
which only works for single_user_* panels (rank == pe_id within cube).
For multi_user_* panels the ring is at the cube level — rank should be
``cube_id`` (axis=1), and the 7 non-rank-leader PEs in each cube should
not run the ring (they only hold KV replicas).
This test pins the desired ``rank_axis`` kwarg semantics:
rank_axis = 0 (default, single_user)
rank = tl.program_id(axis=0). Every PE in the cube runs the ring.
Existing behavior — no change.
rank_axis = 1 (multi_user)
if tl.program_id(axis=0) != 0: return. (7/8 PEs early-exit.)
rank = tl.program_id(axis=1).
Phase 1 expectation: tests fail today (kernels don't accept the kwarg).
Phase 2 lands the parameter on both kernels; tests turn green and the
multi_user_* diag harness clears its first send.
"""
from __future__ import annotations
from kernbench.common.ipcq_types import IpcqRecvCmd, IpcqSendCmd
from kernbench.common.pe_commands import GemmCmd
from kernbench.triton_emu.tl_context import TLContext, run_kernel
from kernbench.benches._attention_mesh_kv import attention_mesh_kv_kernel
from kernbench.benches._attention_mesh_mlo import attention_mesh_mlo_kernel
S_Q_PREFILL = 16
S_Q_DECODE = 1
S_KV_PER_RANK = 16
H_Q = 1
H_KV = 1
D_HEAD = 64
N_RANKS_MULTI = 4
PES_PER_CUBE = 8
Q_PTR = 0x10000
K_PTR = 0x20000
V_PTR = 0x30000
O_PTR = 0x40000
def _tl(pe_id: int, cube_id: int, num_pes: int, num_cubes: int) -> TLContext:
return TLContext(
pe_id=pe_id,
num_programs=num_pes,
cube_id=cube_id,
num_cubes=num_cubes,
dispatch_cycles=0,
scratch_base=0x80000,
scratch_size=1 << 20,
)
# ── Default rank_axis=0 backward-compat ──────────────────────────
def test_mlo_kernel_default_rank_axis_zero_emits_commands_on_all_pes():
"""rank_axis defaults to 0 → kernel uses pe_id as rank, runs on every
PE. Verify by running rank=3 (interior PE) in a single-cube 8-rank
setup and asserting at least one GEMM and at least one IPCQ send
are emitted (interior ranks send in both directions)."""
tl = _tl(pe_id=3, cube_id=0, num_pes=8, num_cubes=1)
run_kernel(
attention_mesh_mlo_kernel, tl,
Q_PTR, K_PTR, V_PTR, O_PTR,
S_Q_DECODE, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, 8,
)
assert any(isinstance(c, GemmCmd) for c in tl.commands), \
"default rank_axis=0 must run the kernel (≥1 GEMM)"
assert any(isinstance(c, IpcqSendCmd) for c in tl.commands), \
"interior rank must emit ≥1 IpcqSendCmd"
def test_kv_kernel_default_rank_axis_zero_emits_commands_on_all_pes():
tl = _tl(pe_id=3, cube_id=0, num_pes=8, num_cubes=1)
run_kernel(
attention_mesh_kv_kernel, tl,
Q_PTR, K_PTR, V_PTR, O_PTR,
S_Q_PREFILL, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, 8,
)
assert any(isinstance(c, GemmCmd) for c in tl.commands)
assert any(isinstance(c, IpcqSendCmd) for c in tl.commands)
# ── rank_axis=1 multi_user semantics ─────────────────────────────
def test_mlo_kernel_rank_axis_one_gates_non_zero_pe_to_no_commands():
"""rank_axis=1 + pe_id != 0 → kernel must early-return; no GEMM,
no DMA, no IPCQ. The 7 non-rank-leader PEs in a multi_user cube
must stay completely silent so the cube-level SFR install isn't
asked to route sends from PEs that have no neighbors installed."""
tl = _tl(pe_id=2, cube_id=1, num_pes=PES_PER_CUBE, num_cubes=N_RANKS_MULTI)
run_kernel(
attention_mesh_mlo_kernel, tl,
Q_PTR, K_PTR, V_PTR, O_PTR,
S_Q_DECODE, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, N_RANKS_MULTI,
rank_axis=1,
)
assert not any(isinstance(c, GemmCmd) for c in tl.commands), \
"pe_id=2 with rank_axis=1 must not emit GEMMs"
assert not any(isinstance(c, IpcqSendCmd) for c in tl.commands), \
"pe_id=2 with rank_axis=1 must not emit IpcqSendCmd"
assert not any(isinstance(c, IpcqRecvCmd) for c in tl.commands), \
"pe_id=2 with rank_axis=1 must not emit IpcqRecvCmd"
def test_kv_kernel_rank_axis_one_gates_non_zero_pe_to_no_commands():
tl = _tl(pe_id=2, cube_id=1, num_pes=PES_PER_CUBE, num_cubes=N_RANKS_MULTI)
run_kernel(
attention_mesh_kv_kernel, tl,
Q_PTR, K_PTR, V_PTR, O_PTR,
S_Q_PREFILL, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, N_RANKS_MULTI,
rank_axis=1,
)
assert not any(isinstance(c, GemmCmd) for c in tl.commands)
assert not any(isinstance(c, IpcqSendCmd) for c in tl.commands)
assert not any(isinstance(c, IpcqRecvCmd) for c in tl.commands)
def test_mlo_kernel_rank_axis_one_pe_zero_uses_cube_id_as_rank():
"""rank_axis=1 + pe_id == 0 → kernel runs the ring with rank=cube_id.
For cube_id=1 in a 4-cube ring, rank=1 is an interior rank: has_E=True
AND has_W=True → IPCQ sends emitted in both E and W directions.
"""
tl = _tl(pe_id=0, cube_id=1, num_pes=PES_PER_CUBE, num_cubes=N_RANKS_MULTI)
run_kernel(
attention_mesh_mlo_kernel, tl,
Q_PTR, K_PTR, V_PTR, O_PTR,
S_Q_DECODE, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, N_RANKS_MULTI,
rank_axis=1,
)
sends = [c for c in tl.commands if isinstance(c, IpcqSendCmd)]
assert any(s.direction == "E" for s in sends), \
"cube_id=1 (interior) must emit ≥1 E-send"
assert any(s.direction == "W" for s in sends), \
"cube_id=1 (interior) must emit ≥1 W-send"
def test_kv_kernel_rank_axis_one_pe_zero_uses_cube_id_as_rank():
tl = _tl(pe_id=0, cube_id=1, num_pes=PES_PER_CUBE, num_cubes=N_RANKS_MULTI)
run_kernel(
attention_mesh_kv_kernel, tl,
Q_PTR, K_PTR, V_PTR, O_PTR,
S_Q_PREFILL, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, N_RANKS_MULTI,
rank_axis=1,
)
sends = [c for c in tl.commands if isinstance(c, IpcqSendCmd)]
assert any(s.direction == "E" for s in sends)
assert any(s.direction == "W" for s in sends)
def test_mlo_kernel_rank_axis_one_west_edge_cube_no_west_sends():
"""cube_id=0 (west edge) with rank_axis=1: rank=0, has_W=False → no
W-direction IPCQ sends. has_E=True → ≥1 E-direction send."""
tl = _tl(pe_id=0, cube_id=0, num_pes=PES_PER_CUBE, num_cubes=N_RANKS_MULTI)
run_kernel(
attention_mesh_mlo_kernel, tl,
Q_PTR, K_PTR, V_PTR, O_PTR,
S_Q_DECODE, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, N_RANKS_MULTI,
rank_axis=1,
)
sends = [c for c in tl.commands if isinstance(c, IpcqSendCmd)]
assert any(s.direction == "E" for s in sends), \
"west-edge cube_id=0 must still emit ≥1 E-send"
assert not any(s.direction == "W" for s in sends), \
"west-edge cube_id=0 must NOT emit any W-send (no W neighbor)"