"""Phase 1 spec test for ``rank_axis`` parameter on the two mesh kernels. ADR-0059's mesh kernels currently hard-code ``rank = tl.program_id(axis=0)``, which only works for single_user_* panels (rank == pe_id within cube). For multi_user_* panels the ring is at the cube level — rank should be ``cube_id`` (axis=1), and the 7 non-rank-leader PEs in each cube should not run the ring (they only hold KV replicas). This test pins the desired ``rank_axis`` kwarg semantics: rank_axis = 0 (default, single_user) rank = tl.program_id(axis=0). Every PE in the cube runs the ring. Existing behavior — no change. rank_axis = 1 (multi_user) if tl.program_id(axis=0) != 0: return. (7/8 PEs early-exit.) rank = tl.program_id(axis=1). Phase 1 expectation: tests fail today (kernels don't accept the kwarg). Phase 2 lands the parameter on both kernels; tests turn green and the multi_user_* diag harness clears its first send. """ from __future__ import annotations from kernbench.common.ipcq_types import IpcqRecvCmd, IpcqSendCmd from kernbench.common.pe_commands import GemmCmd from kernbench.triton_emu.tl_context import TLContext, run_kernel from kernbench.benches._attention_mesh_kv import attention_mesh_kv_kernel from kernbench.benches._attention_mesh_mlo import attention_mesh_mlo_kernel S_Q_PREFILL = 16 S_Q_DECODE = 1 S_KV_PER_RANK = 16 H_Q = 1 H_KV = 1 D_HEAD = 64 N_RANKS_MULTI = 4 PES_PER_CUBE = 8 Q_PTR = 0x10000 K_PTR = 0x20000 V_PTR = 0x30000 O_PTR = 0x40000 def _tl(pe_id: int, cube_id: int, num_pes: int, num_cubes: int) -> TLContext: return TLContext( pe_id=pe_id, num_programs=num_pes, cube_id=cube_id, num_cubes=num_cubes, dispatch_cycles=0, scratch_base=0x80000, scratch_size=1 << 20, ) # ── Default rank_axis=0 backward-compat ────────────────────────── def test_mlo_kernel_default_rank_axis_zero_emits_commands_on_all_pes(): """rank_axis defaults to 0 → kernel uses pe_id as rank, runs on every PE. Verify by running rank=3 (interior PE) in a single-cube 8-rank setup and asserting at least one GEMM and at least one IPCQ send are emitted (interior ranks send in both directions).""" tl = _tl(pe_id=3, cube_id=0, num_pes=8, num_cubes=1) run_kernel( attention_mesh_mlo_kernel, tl, Q_PTR, K_PTR, V_PTR, O_PTR, S_Q_DECODE, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, 8, ) assert any(isinstance(c, GemmCmd) for c in tl.commands), \ "default rank_axis=0 must run the kernel (≥1 GEMM)" assert any(isinstance(c, IpcqSendCmd) for c in tl.commands), \ "interior rank must emit ≥1 IpcqSendCmd" def test_kv_kernel_default_rank_axis_zero_emits_commands_on_all_pes(): tl = _tl(pe_id=3, cube_id=0, num_pes=8, num_cubes=1) run_kernel( attention_mesh_kv_kernel, tl, Q_PTR, K_PTR, V_PTR, O_PTR, S_Q_PREFILL, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, 8, ) assert any(isinstance(c, GemmCmd) for c in tl.commands) assert any(isinstance(c, IpcqSendCmd) for c in tl.commands) # ── rank_axis=1 multi_user semantics ───────────────────────────── def test_mlo_kernel_rank_axis_one_gates_non_zero_pe_to_no_commands(): """rank_axis=1 + pe_id != 0 → kernel must early-return; no GEMM, no DMA, no IPCQ. The 7 non-rank-leader PEs in a multi_user cube must stay completely silent so the cube-level SFR install isn't asked to route sends from PEs that have no neighbors installed.""" tl = _tl(pe_id=2, cube_id=1, num_pes=PES_PER_CUBE, num_cubes=N_RANKS_MULTI) run_kernel( attention_mesh_mlo_kernel, tl, Q_PTR, K_PTR, V_PTR, O_PTR, S_Q_DECODE, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, N_RANKS_MULTI, rank_axis=1, ) assert not any(isinstance(c, GemmCmd) for c in tl.commands), \ "pe_id=2 with rank_axis=1 must not emit GEMMs" assert not any(isinstance(c, IpcqSendCmd) for c in tl.commands), \ "pe_id=2 with rank_axis=1 must not emit IpcqSendCmd" assert not any(isinstance(c, IpcqRecvCmd) for c in tl.commands), \ "pe_id=2 with rank_axis=1 must not emit IpcqRecvCmd" def test_kv_kernel_rank_axis_one_gates_non_zero_pe_to_no_commands(): tl = _tl(pe_id=2, cube_id=1, num_pes=PES_PER_CUBE, num_cubes=N_RANKS_MULTI) run_kernel( attention_mesh_kv_kernel, tl, Q_PTR, K_PTR, V_PTR, O_PTR, S_Q_PREFILL, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, N_RANKS_MULTI, rank_axis=1, ) assert not any(isinstance(c, GemmCmd) for c in tl.commands) assert not any(isinstance(c, IpcqSendCmd) for c in tl.commands) assert not any(isinstance(c, IpcqRecvCmd) for c in tl.commands) def test_mlo_kernel_rank_axis_one_pe_zero_uses_cube_id_as_rank(): """rank_axis=1 + pe_id == 0 → kernel runs the ring with rank=cube_id. For cube_id=1 in a 4-cube ring, rank=1 is an interior rank: has_E=True AND has_W=True → IPCQ sends emitted in both E and W directions. """ tl = _tl(pe_id=0, cube_id=1, num_pes=PES_PER_CUBE, num_cubes=N_RANKS_MULTI) run_kernel( attention_mesh_mlo_kernel, tl, Q_PTR, K_PTR, V_PTR, O_PTR, S_Q_DECODE, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, N_RANKS_MULTI, rank_axis=1, ) sends = [c for c in tl.commands if isinstance(c, IpcqSendCmd)] assert any(s.direction == "E" for s in sends), \ "cube_id=1 (interior) must emit ≥1 E-send" assert any(s.direction == "W" for s in sends), \ "cube_id=1 (interior) must emit ≥1 W-send" def test_kv_kernel_rank_axis_one_pe_zero_uses_cube_id_as_rank(): tl = _tl(pe_id=0, cube_id=1, num_pes=PES_PER_CUBE, num_cubes=N_RANKS_MULTI) run_kernel( attention_mesh_kv_kernel, tl, Q_PTR, K_PTR, V_PTR, O_PTR, S_Q_PREFILL, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, N_RANKS_MULTI, rank_axis=1, ) sends = [c for c in tl.commands if isinstance(c, IpcqSendCmd)] assert any(s.direction == "E" for s in sends) assert any(s.direction == "W" for s in sends) def test_mlo_kernel_rank_axis_one_west_edge_cube_no_west_sends(): """cube_id=0 (west edge) with rank_axis=1: rank=0, has_W=False → no W-direction IPCQ sends. has_E=True → ≥1 E-direction send.""" tl = _tl(pe_id=0, cube_id=0, num_pes=PES_PER_CUBE, num_cubes=N_RANKS_MULTI) run_kernel( attention_mesh_mlo_kernel, tl, Q_PTR, K_PTR, V_PTR, O_PTR, S_Q_DECODE, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, N_RANKS_MULTI, rank_axis=1, ) sends = [c for c in tl.commands if isinstance(c, IpcqSendCmd)] assert any(s.direction == "E" for s in sends), \ "west-edge cube_id=0 must still emit ≥1 E-send" assert not any(s.direction == "W" for s in sends), \ "west-edge cube_id=0 must NOT emit any W-send (no W neighbor)"