attention: add rank_axis kwarg to mesh kernels for multi_user cube ring
ADR-0059 single_user_* panels run the ring across PEs in one cube
(rank == tl.program_id(axis=0)). multi_user_* panels run the ring
across cubes — rank should be cube_id (axis=1), and 7 of every 8 PEs
in each cube must stay silent because the cube-level SFR install only
gives the cube-coordinate PE 0 an E/W neighbor.
Add ``rank_axis: int = 0`` kwarg to both ``attention_mesh_mlo_kernel``
and ``attention_mesh_kv_kernel``:
- 0 (default): rank == tl.program_id(axis=0). Existing single_user
behavior, all spec tests unchanged.
- 1: gate ``if tl.program_id(axis=0) != 0: return`` at kernel start,
then ``rank = tl.program_id(axis=1)``. multi_user_* panels pass
this to the kernel via ctx.launch positional arg.
Also brings in _attention_mesh_kv.py and _attention_mesh_mlo.py as
the committed home of the ADR-0059 kernels (previously living
uncommitted in the working tree from sub-cycle 4b).
Tests: 7-test rank_axis spec file (default-path + rank_axis=1 gating
and cube-id semantics, both kernels); 4-panel diag harness now green
end-to-end (single_user_prefill/decode + multi_user_prefill/decode);
763-test wider sweep clean.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -149,6 +149,7 @@ def _bench_fn_multi_user_prefill(ctx):
|
||||
"multi_user_prefill_mesh", attention_mesh_kv_kernel,
|
||||
q, k, v, o,
|
||||
S_Q_PREFILL, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, n,
|
||||
1, # rank_axis=1 → ring at cube level (ADR-0059 multi_user)
|
||||
_auto_dim_remap=False,
|
||||
)
|
||||
|
||||
@@ -169,6 +170,7 @@ def _bench_fn_multi_user_decode(ctx):
|
||||
"multi_user_decode_mesh", attention_mesh_mlo_kernel,
|
||||
q, k, v, o,
|
||||
S_Q_DECODE, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, n,
|
||||
1, # rank_axis=1 → ring at cube level (ADR-0059 multi_user)
|
||||
_auto_dim_remap=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,172 @@
|
||||
"""Phase 1 spec test for ``rank_axis`` parameter on the two mesh kernels.
|
||||
|
||||
ADR-0059's mesh kernels currently hard-code ``rank = tl.program_id(axis=0)``,
|
||||
which only works for single_user_* panels (rank == pe_id within cube).
|
||||
For multi_user_* panels the ring is at the cube level — rank should be
|
||||
``cube_id`` (axis=1), and the 7 non-rank-leader PEs in each cube should
|
||||
not run the ring (they only hold KV replicas).
|
||||
|
||||
This test pins the desired ``rank_axis`` kwarg semantics:
|
||||
|
||||
rank_axis = 0 (default, single_user)
|
||||
rank = tl.program_id(axis=0). Every PE in the cube runs the ring.
|
||||
Existing behavior — no change.
|
||||
|
||||
rank_axis = 1 (multi_user)
|
||||
if tl.program_id(axis=0) != 0: return. (7/8 PEs early-exit.)
|
||||
rank = tl.program_id(axis=1).
|
||||
|
||||
Phase 1 expectation: tests fail today (kernels don't accept the kwarg).
|
||||
Phase 2 lands the parameter on both kernels; tests turn green and the
|
||||
multi_user_* diag harness clears its first send.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from kernbench.common.ipcq_types import IpcqRecvCmd, IpcqSendCmd
|
||||
from kernbench.common.pe_commands import GemmCmd
|
||||
from kernbench.triton_emu.tl_context import TLContext, run_kernel
|
||||
|
||||
from kernbench.benches._attention_mesh_kv import attention_mesh_kv_kernel
|
||||
from kernbench.benches._attention_mesh_mlo import attention_mesh_mlo_kernel
|
||||
|
||||
S_Q_PREFILL = 16
|
||||
S_Q_DECODE = 1
|
||||
S_KV_PER_RANK = 16
|
||||
H_Q = 1
|
||||
H_KV = 1
|
||||
D_HEAD = 64
|
||||
N_RANKS_MULTI = 4
|
||||
PES_PER_CUBE = 8
|
||||
|
||||
Q_PTR = 0x10000
|
||||
K_PTR = 0x20000
|
||||
V_PTR = 0x30000
|
||||
O_PTR = 0x40000
|
||||
|
||||
|
||||
def _tl(pe_id: int, cube_id: int, num_pes: int, num_cubes: int) -> TLContext:
|
||||
return TLContext(
|
||||
pe_id=pe_id,
|
||||
num_programs=num_pes,
|
||||
cube_id=cube_id,
|
||||
num_cubes=num_cubes,
|
||||
dispatch_cycles=0,
|
||||
scratch_base=0x80000,
|
||||
scratch_size=1 << 20,
|
||||
)
|
||||
|
||||
|
||||
# ── Default rank_axis=0 backward-compat ──────────────────────────
|
||||
|
||||
|
||||
def test_mlo_kernel_default_rank_axis_zero_emits_commands_on_all_pes():
|
||||
"""rank_axis defaults to 0 → kernel uses pe_id as rank, runs on every
|
||||
PE. Verify by running rank=3 (interior PE) in a single-cube 8-rank
|
||||
setup and asserting at least one GEMM and at least one IPCQ send
|
||||
are emitted (interior ranks send in both directions)."""
|
||||
tl = _tl(pe_id=3, cube_id=0, num_pes=8, num_cubes=1)
|
||||
run_kernel(
|
||||
attention_mesh_mlo_kernel, tl,
|
||||
Q_PTR, K_PTR, V_PTR, O_PTR,
|
||||
S_Q_DECODE, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, 8,
|
||||
)
|
||||
assert any(isinstance(c, GemmCmd) for c in tl.commands), \
|
||||
"default rank_axis=0 must run the kernel (≥1 GEMM)"
|
||||
assert any(isinstance(c, IpcqSendCmd) for c in tl.commands), \
|
||||
"interior rank must emit ≥1 IpcqSendCmd"
|
||||
|
||||
|
||||
def test_kv_kernel_default_rank_axis_zero_emits_commands_on_all_pes():
|
||||
tl = _tl(pe_id=3, cube_id=0, num_pes=8, num_cubes=1)
|
||||
run_kernel(
|
||||
attention_mesh_kv_kernel, tl,
|
||||
Q_PTR, K_PTR, V_PTR, O_PTR,
|
||||
S_Q_PREFILL, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, 8,
|
||||
)
|
||||
assert any(isinstance(c, GemmCmd) for c in tl.commands)
|
||||
assert any(isinstance(c, IpcqSendCmd) for c in tl.commands)
|
||||
|
||||
|
||||
# ── rank_axis=1 multi_user semantics ─────────────────────────────
|
||||
|
||||
|
||||
def test_mlo_kernel_rank_axis_one_gates_non_zero_pe_to_no_commands():
|
||||
"""rank_axis=1 + pe_id != 0 → kernel must early-return; no GEMM,
|
||||
no DMA, no IPCQ. The 7 non-rank-leader PEs in a multi_user cube
|
||||
must stay completely silent so the cube-level SFR install isn't
|
||||
asked to route sends from PEs that have no neighbors installed."""
|
||||
tl = _tl(pe_id=2, cube_id=1, num_pes=PES_PER_CUBE, num_cubes=N_RANKS_MULTI)
|
||||
run_kernel(
|
||||
attention_mesh_mlo_kernel, tl,
|
||||
Q_PTR, K_PTR, V_PTR, O_PTR,
|
||||
S_Q_DECODE, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, N_RANKS_MULTI,
|
||||
rank_axis=1,
|
||||
)
|
||||
assert not any(isinstance(c, GemmCmd) for c in tl.commands), \
|
||||
"pe_id=2 with rank_axis=1 must not emit GEMMs"
|
||||
assert not any(isinstance(c, IpcqSendCmd) for c in tl.commands), \
|
||||
"pe_id=2 with rank_axis=1 must not emit IpcqSendCmd"
|
||||
assert not any(isinstance(c, IpcqRecvCmd) for c in tl.commands), \
|
||||
"pe_id=2 with rank_axis=1 must not emit IpcqRecvCmd"
|
||||
|
||||
|
||||
def test_kv_kernel_rank_axis_one_gates_non_zero_pe_to_no_commands():
|
||||
tl = _tl(pe_id=2, cube_id=1, num_pes=PES_PER_CUBE, num_cubes=N_RANKS_MULTI)
|
||||
run_kernel(
|
||||
attention_mesh_kv_kernel, tl,
|
||||
Q_PTR, K_PTR, V_PTR, O_PTR,
|
||||
S_Q_PREFILL, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, N_RANKS_MULTI,
|
||||
rank_axis=1,
|
||||
)
|
||||
assert not any(isinstance(c, GemmCmd) for c in tl.commands)
|
||||
assert not any(isinstance(c, IpcqSendCmd) for c in tl.commands)
|
||||
assert not any(isinstance(c, IpcqRecvCmd) for c in tl.commands)
|
||||
|
||||
|
||||
def test_mlo_kernel_rank_axis_one_pe_zero_uses_cube_id_as_rank():
|
||||
"""rank_axis=1 + pe_id == 0 → kernel runs the ring with rank=cube_id.
|
||||
For cube_id=1 in a 4-cube ring, rank=1 is an interior rank: has_E=True
|
||||
AND has_W=True → IPCQ sends emitted in both E and W directions.
|
||||
"""
|
||||
tl = _tl(pe_id=0, cube_id=1, num_pes=PES_PER_CUBE, num_cubes=N_RANKS_MULTI)
|
||||
run_kernel(
|
||||
attention_mesh_mlo_kernel, tl,
|
||||
Q_PTR, K_PTR, V_PTR, O_PTR,
|
||||
S_Q_DECODE, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, N_RANKS_MULTI,
|
||||
rank_axis=1,
|
||||
)
|
||||
sends = [c for c in tl.commands if isinstance(c, IpcqSendCmd)]
|
||||
assert any(s.direction == "E" for s in sends), \
|
||||
"cube_id=1 (interior) must emit ≥1 E-send"
|
||||
assert any(s.direction == "W" for s in sends), \
|
||||
"cube_id=1 (interior) must emit ≥1 W-send"
|
||||
|
||||
|
||||
def test_kv_kernel_rank_axis_one_pe_zero_uses_cube_id_as_rank():
|
||||
tl = _tl(pe_id=0, cube_id=1, num_pes=PES_PER_CUBE, num_cubes=N_RANKS_MULTI)
|
||||
run_kernel(
|
||||
attention_mesh_kv_kernel, tl,
|
||||
Q_PTR, K_PTR, V_PTR, O_PTR,
|
||||
S_Q_PREFILL, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, N_RANKS_MULTI,
|
||||
rank_axis=1,
|
||||
)
|
||||
sends = [c for c in tl.commands if isinstance(c, IpcqSendCmd)]
|
||||
assert any(s.direction == "E" for s in sends)
|
||||
assert any(s.direction == "W" for s in sends)
|
||||
|
||||
|
||||
def test_mlo_kernel_rank_axis_one_west_edge_cube_no_west_sends():
|
||||
"""cube_id=0 (west edge) with rank_axis=1: rank=0, has_W=False → no
|
||||
W-direction IPCQ sends. has_E=True → ≥1 E-direction send."""
|
||||
tl = _tl(pe_id=0, cube_id=0, num_pes=PES_PER_CUBE, num_cubes=N_RANKS_MULTI)
|
||||
run_kernel(
|
||||
attention_mesh_mlo_kernel, tl,
|
||||
Q_PTR, K_PTR, V_PTR, O_PTR,
|
||||
S_Q_DECODE, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, N_RANKS_MULTI,
|
||||
rank_axis=1,
|
||||
)
|
||||
sends = [c for c in tl.commands if isinstance(c, IpcqSendCmd)]
|
||||
assert any(s.direction == "E" for s in sends), \
|
||||
"west-edge cube_id=0 must still emit ≥1 E-send"
|
||||
assert not any(s.direction == "W" for s in sends), \
|
||||
"west-edge cube_id=0 must NOT emit any W-send (no W neighbor)"
|
||||
Reference in New Issue
Block a user