222815d374
ADR-0059 single_user_* panels run the ring across PEs in one cube
(rank == tl.program_id(axis=0)). multi_user_* panels run the ring
across cubes — rank should be cube_id (axis=1), and 7 of every 8 PEs
in each cube must stay silent because the cube-level SFR install only
gives the cube-coordinate PE 0 an E/W neighbor.
Add ``rank_axis: int = 0`` kwarg to both ``attention_mesh_mlo_kernel``
and ``attention_mesh_kv_kernel``:
- 0 (default): rank == tl.program_id(axis=0). Existing single_user
behavior, all spec tests unchanged.
- 1: gate ``if tl.program_id(axis=0) != 0: return`` at kernel start,
then ``rank = tl.program_id(axis=1)``. multi_user_* panels pass
this to the kernel via ctx.launch positional arg.
Also brings in _attention_mesh_kv.py and _attention_mesh_mlo.py as
the committed home of the ADR-0059 kernels (previously living
uncommitted in the working tree from sub-cycle 4b).
Tests: 7-test rank_axis spec file (default-path + rank_axis=1 gating
and cube-id semantics, both kernels); 4-panel diag harness now green
end-to-end (single_user_prefill/decode + multi_user_prefill/decode);
763-test wider sweep clean.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
173 lines
6.8 KiB
Python
173 lines
6.8 KiB
Python
"""Phase 1 spec test for ``rank_axis`` parameter on the two mesh kernels.

ADR-0059's mesh kernels currently hard-code ``rank = tl.program_id(axis=0)``,
which only works for single_user_* panels (rank == pe_id within cube).
For multi_user_* panels the ring is at the cube level — rank should be
``cube_id`` (axis=1), and the 7 non-rank-leader PEs in each cube should
not run the ring (they only hold KV replicas).

This test pins the desired ``rank_axis`` kwarg semantics:

  rank_axis = 0 (default, single_user)
    rank = tl.program_id(axis=0). Every PE in the cube runs the ring.
    Existing behavior — no change.

  rank_axis = 1 (multi_user)
    if tl.program_id(axis=0) != 0: return.   (7/8 PEs early-exit.)
    rank = tl.program_id(axis=1).

Phase 1 expectation: tests fail today (kernels don't accept the kwarg).
Phase 2 lands the parameter on both kernels; tests turn green and the
multi_user_* diag harness clears its first send.
"""
|
|
from __future__ import annotations
|
|
|
|
from kernbench.common.ipcq_types import IpcqRecvCmd, IpcqSendCmd
|
|
from kernbench.common.pe_commands import GemmCmd
|
|
from kernbench.triton_emu.tl_context import TLContext, run_kernel
|
|
|
|
from kernbench.benches._attention_mesh_kv import attention_mesh_kv_kernel
|
|
from kernbench.benches._attention_mesh_mlo import attention_mesh_mlo_kernel
|
|
|
|
# Problem-shape arguments passed positionally to the kernels under test.
S_Q_PREFILL = 16     # query length used by the attention_mesh_kv_kernel tests
S_Q_DECODE = 1       # query length used by the attention_mesh_mlo_kernel tests
S_KV_PER_RANK = 16
H_Q = 1
H_KV = 1
D_HEAD = 64

# Topology used by the rank_axis=1 (multi_user) tests: N_RANKS_MULTI is both
# the cube count given to the context and the rank count given to the kernel;
# PES_PER_CUBE is the per-cube program count given to the context.
N_RANKS_MULTI = 4
PES_PER_CUBE = 8

# Base addresses handed to the kernels as their first four positional args.
Q_PTR = 0x10000
K_PTR = 0x20000
V_PTR = 0x30000
O_PTR = 0x40000
|
|
|
|
|
|
def _tl(pe_id: int, cube_id: int, num_pes: int, num_cubes: int) -> TLContext:
    """Build the ``TLContext`` that one test runs a kernel against.

    Args:
        pe_id: Forwarded as ``TLContext(pe_id=...)``.
        cube_id: Forwarded as ``TLContext(cube_id=...)``.
        num_pes: Forwarded as ``TLContext(num_programs=...)``.
        num_cubes: Forwarded as ``TLContext(num_cubes=...)``.

    Returns:
        A fresh context with ``dispatch_cycles=0`` and a fixed scratch
        window (base ``0x80000``, size ``1 << 20``); the tests inspect its
        ``commands`` attribute after ``run_kernel``.
    """
    return TLContext(
        pe_id=pe_id,
        num_programs=num_pes,
        cube_id=cube_id,
        num_cubes=num_cubes,
        dispatch_cycles=0,
        scratch_base=0x80000,
        scratch_size=1 << 20,
    )
|
|
|
|
|
|
# ── Default rank_axis=0 backward-compat ──────────────────────────
|
|
|
|
|
|
def test_mlo_kernel_default_rank_axis_zero_emits_commands_on_all_pes():
    """rank_axis defaults to 0 → kernel uses pe_id as rank, runs on every
    PE. Verify by running rank=3 (interior PE) in a single-cube 8-rank
    setup and asserting at least one GEMM and at least one IPCQ send
    are emitted (interior ranks send in both directions).

    ``rank_axis`` is deliberately not passed, so the kernel's default is
    what is exercised. Only pe_id=3 is run here.
    """
    tl = _tl(pe_id=3, cube_id=0, num_pes=8, num_cubes=1)
    run_kernel(
        attention_mesh_mlo_kernel, tl,
        Q_PTR, K_PTR, V_PTR, O_PTR,
        S_Q_DECODE, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, 8,
    )
    assert any(isinstance(c, GemmCmd) for c in tl.commands), (
        "default rank_axis=0 must run the kernel (≥1 GEMM)"
    )
    assert any(isinstance(c, IpcqSendCmd) for c in tl.commands), (
        "interior rank must emit ≥1 IpcqSendCmd"
    )
|
|
|
|
|
|
def test_kv_kernel_default_rank_axis_zero_emits_commands_on_all_pes():
    """KV-kernel counterpart of the MLO default-``rank_axis`` test.

    Runs ``attention_mesh_kv_kernel`` without a ``rank_axis`` argument on
    pe_id=3 of a single-cube, 8-program context (prefill query length) and
    requires at least one GEMM and at least one IPCQ send in the emitted
    command list.
    """
    tl = _tl(pe_id=3, cube_id=0, num_pes=8, num_cubes=1)
    run_kernel(
        attention_mesh_kv_kernel, tl,
        Q_PTR, K_PTR, V_PTR, O_PTR,
        S_Q_PREFILL, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, 8,
    )
    assert any(isinstance(c, GemmCmd) for c in tl.commands), (
        "default rank_axis=0 must run the kernel (≥1 GEMM)"
    )
    assert any(isinstance(c, IpcqSendCmd) for c in tl.commands), (
        "interior rank must emit ≥1 IpcqSendCmd"
    )
|
|
|
|
|
|
# ── rank_axis=1 multi_user semantics ─────────────────────────────
|
|
|
|
|
|
def test_mlo_kernel_rank_axis_one_gates_non_zero_pe_to_no_commands():
    """rank_axis=1 + pe_id != 0 → kernel must early-return; no GEMM,
    no DMA, no IPCQ. The 7 non-rank-leader PEs in a multi_user cube
    must stay completely silent so the cube-level SFR install isn't
    asked to route sends from PEs that have no neighbors installed.

    NOTE: the assertions below cover GEMM and IPCQ send/recv commands
    only; DMA silence is part of the intended contract but is not checked
    by this test.
    """
    tl = _tl(pe_id=2, cube_id=1, num_pes=PES_PER_CUBE, num_cubes=N_RANKS_MULTI)
    run_kernel(
        attention_mesh_mlo_kernel, tl,
        Q_PTR, K_PTR, V_PTR, O_PTR,
        S_Q_DECODE, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, N_RANKS_MULTI,
        rank_axis=1,
    )
    assert not any(isinstance(c, GemmCmd) for c in tl.commands), (
        "pe_id=2 with rank_axis=1 must not emit GEMMs"
    )
    assert not any(isinstance(c, IpcqSendCmd) for c in tl.commands), (
        "pe_id=2 with rank_axis=1 must not emit IpcqSendCmd"
    )
    assert not any(isinstance(c, IpcqRecvCmd) for c in tl.commands), (
        "pe_id=2 with rank_axis=1 must not emit IpcqRecvCmd"
    )
|
|
|
|
|
|
def test_kv_kernel_rank_axis_one_gates_non_zero_pe_to_no_commands():
    """KV-kernel counterpart of the MLO gating test.

    With ``rank_axis=1`` and pe_id=2 (not the cube's PE 0), running
    ``attention_mesh_kv_kernel`` must leave no GEMM, IPCQ send, or IPCQ
    recv command in the context's command list.
    """
    tl = _tl(pe_id=2, cube_id=1, num_pes=PES_PER_CUBE, num_cubes=N_RANKS_MULTI)
    run_kernel(
        attention_mesh_kv_kernel, tl,
        Q_PTR, K_PTR, V_PTR, O_PTR,
        S_Q_PREFILL, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, N_RANKS_MULTI,
        rank_axis=1,
    )
    assert not any(isinstance(c, GemmCmd) for c in tl.commands), (
        "pe_id=2 with rank_axis=1 must not emit GEMMs"
    )
    assert not any(isinstance(c, IpcqSendCmd) for c in tl.commands), (
        "pe_id=2 with rank_axis=1 must not emit IpcqSendCmd"
    )
    assert not any(isinstance(c, IpcqRecvCmd) for c in tl.commands), (
        "pe_id=2 with rank_axis=1 must not emit IpcqRecvCmd"
    )
|
|
|
|
|
|
def test_mlo_kernel_rank_axis_one_pe_zero_uses_cube_id_as_rank():
    """rank_axis=1 + pe_id == 0 → kernel runs the ring with rank=cube_id.

    For cube_id=1 in a 4-cube ring, rank=1 is an interior rank: has_E=True
    AND has_W=True → IPCQ sends emitted in both E and W directions.
    """
    tl = _tl(pe_id=0, cube_id=1, num_pes=PES_PER_CUBE, num_cubes=N_RANKS_MULTI)
    run_kernel(
        attention_mesh_mlo_kernel, tl,
        Q_PTR, K_PTR, V_PTR, O_PTR,
        S_Q_DECODE, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, N_RANKS_MULTI,
        rank_axis=1,
    )
    # Only the send commands matter here; their ``direction`` field shows
    # which neighbors the kernel believed it had.
    sends = [c for c in tl.commands if isinstance(c, IpcqSendCmd)]
    assert any(s.direction == "E" for s in sends), (
        "cube_id=1 (interior) must emit ≥1 E-send"
    )
    assert any(s.direction == "W" for s in sends), (
        "cube_id=1 (interior) must emit ≥1 W-send"
    )
|
|
|
|
|
|
def test_kv_kernel_rank_axis_one_pe_zero_uses_cube_id_as_rank():
    """KV-kernel counterpart of the MLO cube-id-as-rank test.

    With ``rank_axis=1``, pe_id=0 and cube_id=1 in a 4-cube context, the
    kernel must emit at least one IPCQ send in direction "E" and at least
    one in direction "W".
    """
    tl = _tl(pe_id=0, cube_id=1, num_pes=PES_PER_CUBE, num_cubes=N_RANKS_MULTI)
    run_kernel(
        attention_mesh_kv_kernel, tl,
        Q_PTR, K_PTR, V_PTR, O_PTR,
        S_Q_PREFILL, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, N_RANKS_MULTI,
        rank_axis=1,
    )
    sends = [c for c in tl.commands if isinstance(c, IpcqSendCmd)]
    assert any(s.direction == "E" for s in sends), (
        "cube_id=1 (interior) must emit ≥1 E-send"
    )
    assert any(s.direction == "W" for s in sends), (
        "cube_id=1 (interior) must emit ≥1 W-send"
    )
|
|
|
|
|
|
def test_mlo_kernel_rank_axis_one_west_edge_cube_no_west_sends():
    """cube_id=0 (west edge) with rank_axis=1: rank=0, has_W=False → no
    W-direction IPCQ sends. has_E=True → ≥1 E-direction send.
    """
    tl = _tl(pe_id=0, cube_id=0, num_pes=PES_PER_CUBE, num_cubes=N_RANKS_MULTI)
    run_kernel(
        attention_mesh_mlo_kernel, tl,
        Q_PTR, K_PTR, V_PTR, O_PTR,
        S_Q_DECODE, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, N_RANKS_MULTI,
        rank_axis=1,
    )
    sends = [c for c in tl.commands if isinstance(c, IpcqSendCmd)]
    assert any(s.direction == "E" for s in sends), (
        "west-edge cube_id=0 must still emit ≥1 E-send"
    )
    assert not any(s.direction == "W" for s in sends), (
        "west-edge cube_id=0 must NOT emit any W-send (no W neighbor)"
    )
|