From 222815d3742187880399a5a9c8746fcfff6fb462 Mon Sep 17 00:00:00 2001
From: Mukesh Garg
Date: Mon, 1 Jun 2026 19:53:18 -0700
Subject: [PATCH] attention: add rank_axis kwarg to mesh kernels for multi_user cube ring
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

ADR-0059 single_user_* panels run the ring across PEs in one cube
(rank == tl.program_id(axis=0)). multi_user_* panels run the ring across
cubes — rank should be cube_id (axis=1), and 7 of every 8 PEs in each
cube must stay silent because the cube-level SFR install only gives the
cube-coordinate PE 0 an E/W neighbor.

Add ``rank_axis: int = 0`` kwarg to both ``attention_mesh_mlo_kernel``
and ``attention_mesh_kv_kernel``:

  - 0 (default): rank == tl.program_id(axis=0). Existing single_user
    behavior, all spec tests unchanged.
  - 1: gate ``if tl.program_id(axis=0) != 0: return`` at kernel start,
    then ``rank = tl.program_id(axis=1)``. multi_user_* panels pass
    this to the kernel via ctx.launch positional arg.

Also brings in _attention_mesh_kv.py and _attention_mesh_mlo.py as the
committed home of the ADR-0059 kernels (previously living uncommitted
in the working tree from sub-cycle 4b).

Tests: 7-test rank_axis spec file (default-path + rank_axis=1 gating
and cube-id semantics, both kernels); 4-panel diag harness now green
end-to-end (single_user_prefill/decode + multi_user_prefill/decode);
763-test wider sweep clean.

Co-Authored-By: Claude Opus 4.7 (1M context)
---
Review notes (commentary only — everything between the "---" line and
the first "diff --git" is dropped by git-am):

* The line structure of this patch was rebuilt from a copy whose
  newlines and indentation had been collapsed. Added-line counts match
  the hunk headers (180 / 151 / 172), but Python indentation and the
  leading whitespace of the context lines in
  test_attention_mesh_panels_diag.py were inferred from syntax — run
  `git apply --check` before applying.
* The module docstring of test_mesh_kernels_rank_axis.py was reworded
  (same line count): it described the kernels as not yet accepting
  ``rank_axis`` although this commit adds it. The blob id on that
  file's "index" line therefore no longer matches the content.
* NOTE(review): both kernels compute ``o`` as ``tl.dot(P, V_2d)`` with
  ``P = tl.softmax(scores)``, yet merge with
  ``o * scale_old + o_new * scale_new`` and finish with ``o / ell``,
  where ``ell`` is the sum of ``exp(scores - m)``. That recurrence
  presumably expects the un-normalized ``exp_scores @ V``; confirm
  whether the P·V GEMM is intentional (e.g. command-mix modelling) or
  the result is normalized twice.
* NOTE(review): K and V are loaded as (S_kv_per_rank, h_kv, d_head) but
  viewed as 2-D using ``h_q * d_head`` while ``_view`` copies ``nbytes``
  unchanged. The element counts only agree when h_q == h_kv; the new
  tests use H_Q == H_KV == 1, so the GQA case is not exercised here.

 src/kernbench/benches/_attention_mesh_kv.py   | 180 ++++++++++++++++++
 src/kernbench/benches/_attention_mesh_mlo.py  | 151 +++++++++++++++
 .../test_attention_mesh_panels_diag.py        |   2 +
 .../attention/test_mesh_kernels_rank_axis.py  | 172 +++++++++++++++++
 4 files changed, 505 insertions(+)
 create mode 100644 src/kernbench/benches/_attention_mesh_kv.py
 create mode 100644 src/kernbench/benches/_attention_mesh_mlo.py
 create mode 100644 tests/attention/test_mesh_kernels_rank_axis.py

diff --git a/src/kernbench/benches/_attention_mesh_kv.py b/src/kernbench/benches/_attention_mesh_kv.py
new file mode 100644
index 0000000..df03fa9
--- /dev/null
+++ b/src/kernbench/benches/_attention_mesh_kv.py
@@ -0,0 +1,180 @@
+"""Mesh-native bidirectional Ring-K/V attention kernel — prefill (ADR-0059 Proposed).
+
+Each rank holds its own Q tile and 1/n_ranks of K, V (sequence-sharded).
+Over ``n_ranks - 1`` bidirectional steps, K and V propagate both east and
+west: chunk c_i originating at rank i reaches rank j at step ``|i - j|``.
+Every rank receives every other rank's chunk **exactly once** and folds it
+into a running ``(m, ℓ, o)`` via the online-softmax recurrence. After all
+steps each rank holds the final attention output for its own Q tokens —
+no cross-rank merge is required.
+
+Supersedes ADR-0055's closed-ring ``_attention_ring_kv.py``. Both modules
+stay on disk during the transition; this one runs on the hardware's
+actual open-mesh wiring (no closed-ring SFR install required).
+
+Imported by ``milestone_gqa_llama70b`` (after the bench's Phase 2 switches
+its imports) and invoked through ``torch.launch(...)`` — not through
+``dist.all_reduce(...)``. See ADR-0055 Context for why this kernel is not
+backend-dispatched via ADR-0050's algorithm-module contract.
+"""
+from __future__ import annotations
+
+from kernbench.common.pe_commands import TensorHandle
+
+
+def _view(handle: TensorHandle, new_shape: tuple[int, ...]) -> TensorHandle:
+    """Reshape — metadata only, no command emitted (cf. ``tl.trans``)."""
+    return TensorHandle(
+        id=handle.id,
+        addr=handle.addr,
+        shape=new_shape,
+        dtype=handle.dtype,
+        nbytes=handle.nbytes,
+        data=handle.data,
+        space=handle.space,
+        pinned=handle.pinned,
+    )
+
+
+def _partial_attention(
+    Q: TensorHandle,
+    K: TensorHandle,
+    V: TensorHandle,
+    S_q: int,
+    S_kv_per_rank: int,
+    h_q: int,
+    d_head: int,
+    tl,
+) -> tuple[TensorHandle, TensorHandle, TensorHandle]:
+    """One pass of partial attention against (K, V).
+
+    Emits 1 GEMM(Q·K^T) + softmax + max + sub + exp + sum + 1 GEMM(P·V).
+    Returns the running-statistics triplet ``(m, ℓ, O_partial)`` for the
+    online-softmax mlo merge.
+    """
+    K_2d_T = _view(K, (h_q * d_head, S_kv_per_rank))
+    V_2d = _view(V, (S_kv_per_rank, h_q * d_head))
+
+    scores = tl.dot(Q, K_2d_T)
+    m = tl.max(scores, axis=-1)
+    P = tl.softmax(scores, axis=-1)
+    scores_centered = scores - m
+    exp_scores = tl.exp(scores_centered)
+    ell = tl.sum(exp_scores, axis=-1)
+    O_partial = tl.dot(P, V_2d)
+    return m, ell, O_partial
+
+
+def attention_mesh_kv_kernel(
+    q_ptr: int,
+    k_ptr: int,
+    v_ptr: int,
+    o_ptr: int,
+    S_q: int,
+    S_kv_per_rank: int,
+    h_q: int,
+    h_kv: int,
+    d_head: int,
+    n_ranks: int,
+    rank_axis: int = 0,
+    *,
+    tl,
+) -> None:
+    """Mesh-native bidirectional Ring-K/V attention — see module docstring.
+
+    ``rank_axis`` selects which program-id dimension carries the ring rank:
+      0 — single_user_* panels: rank == tl.program_id(axis=0) (PE id in cube).
+      1 — multi_user_* panels: ring is at the cube level. Only PE 0 in each
+          cube participates; the other 7 hold KV replicas but stay silent.
+    """
+    # For multi_user (rank_axis=1) only PE 0 in each cube runs the ring.
+    if rank_axis != 0 and tl.program_id(axis=0) != 0:
+        return
+    rank = tl.program_id(axis=rank_axis)
+    has_E = rank < n_ranks - 1
+    has_W = rank > 0
+
+    # Q stays put on this rank — loaded once, used in every partial attention.
+    Q = tl.load(q_ptr, shape=(S_q, h_q * d_head), dtype="f16")
+
+    # Local K, V chunk.
+    K = tl.load(k_ptr, shape=(S_kv_per_rank, h_kv, d_head), dtype="f16")
+    V = tl.load(v_ptr, shape=(S_kv_per_rank, h_kv, d_head), dtype="f16")
+
+    # Step 0 (local): partial attention against own K, V — initializes the
+    # running triplet (m, ℓ, o).
+    m, ell, o = _partial_attention(
+        Q, K, V, S_q, S_kv_per_rank, h_q, d_head, tl,
+    )
+
+    # Seed bidirectional waves with own chunk (step-1 send).
+    to_send_east_K: TensorHandle | None = K
+    to_send_east_V: TensorHandle | None = V
+    to_send_west_K: TensorHandle | None = K
+    to_send_west_V: TensorHandle | None = V
+
+    # Bidirectional fan-out: n_ranks - 1 steps. By step k, the wave from
+    # rank i has reached rank (i ± k). After n_ranks - 1 steps, every rank
+    # has merged every other rank's chunk exactly once (ADR-0059 D3).
+    for step in range(1, n_ranks):
+        # Send the eastbound wave we currently hold (own at step 1; forwarded
+        # at later steps). ``None`` means we have no wave to forward this
+        # direction this step (edge rank, or the wave already passed by).
+        if has_E and to_send_east_K is not None:
+            tl.send(dir="E", src=to_send_east_K)
+            tl.send(dir="E", src=to_send_east_V)
+        if has_W and to_send_west_K is not None:
+            tl.send(dir="W", src=to_send_west_K)
+            tl.send(dir="W", src=to_send_west_V)
+
+        # Receive eastbound wave from W (carries chunk c_{rank - step}).
+        K_from_W: TensorHandle | None = None
+        V_from_W: TensorHandle | None = None
+        if has_W and (rank - step) >= 0:
+            K_from_W = tl.recv(
+                dir="W", shape=(S_kv_per_rank, h_kv, d_head), dtype="f16",
+            )
+            V_from_W = tl.recv(
+                dir="W", shape=(S_kv_per_rank, h_kv, d_head), dtype="f16",
+            )
+            m_new, ell_new, o_new = _partial_attention(
+                Q, K_from_W, V_from_W, S_q, S_kv_per_rank, h_q, d_head, tl,
+            )
+            m_combined = tl.maximum(m, m_new)
+            scale_old = tl.exp(m - m_combined)
+            scale_new = tl.exp(m_new - m_combined)
+            ell = ell * scale_old + ell_new * scale_new
+            o = o * scale_old + o_new * scale_new
+            m = m_combined
+
+        # Receive westbound wave from E (carries chunk c_{rank + step}).
+        K_from_E: TensorHandle | None = None
+        V_from_E: TensorHandle | None = None
+        if has_E and (rank + step) < n_ranks:
+            K_from_E = tl.recv(
+                dir="E", shape=(S_kv_per_rank, h_kv, d_head), dtype="f16",
+            )
+            V_from_E = tl.recv(
+                dir="E", shape=(S_kv_per_rank, h_kv, d_head), dtype="f16",
+            )
+            m_new, ell_new, o_new = _partial_attention(
+                Q, K_from_E, V_from_E, S_q, S_kv_per_rank, h_q, d_head, tl,
+            )
+            m_combined = tl.maximum(m, m_new)
+            scale_old = tl.exp(m - m_combined)
+            scale_new = tl.exp(m_new - m_combined)
+            ell = ell * scale_old + ell_new * scale_new
+            o = o * scale_old + o_new * scale_new
+            m = m_combined
+
+        # Forward what we received for next step. ``None`` propagates: if no
+        # chunk arrived this step (out-of-bounds wave origin), there is
+        # nothing to forward next step in that direction.
+        to_send_east_K = K_from_W
+        to_send_east_V = V_from_W
+        to_send_west_K = K_from_E
+        to_send_west_V = V_from_E
+
+    # Final normalize: O := o / ℓ.
+    O_final = o / ell
+    tl.store(o_ptr, O_final)
diff --git a/src/kernbench/benches/_attention_mesh_mlo.py b/src/kernbench/benches/_attention_mesh_mlo.py
new file mode 100644
index 0000000..2474626
--- /dev/null
+++ b/src/kernbench/benches/_attention_mesh_mlo.py
@@ -0,0 +1,151 @@
+"""Mesh-native bidirectional AllReduce-mlo attention — decode (ADR-0059 Proposed).
+
+Every rank holds the full Q (replicated, small at ``S_q=1``) and 1/n_ranks
+of KV (sequence-sharded). Each rank computes its partial attention
+against own KV in ONE shot, then runs a bidirectional fan-out of the
+``(m, ℓ, o)`` triplet: the triplet originating at rank i reaches rank j at
+step ``|i - j|``. Every rank merges every other rank's triplet exactly
+once over ``n_ranks - 1`` steps, ending with the final answer replicated
+on every rank.
+
+Supersedes ADR-0056's closed-ring ``_attention_allreduce_mlo.py``. Both
+modules stay on disk during the transition; this one runs on the
+hardware's actual open-mesh wiring (no closed-ring SFR install required).
+
+Imported by ``milestone_gqa_llama70b`` (after the bench's Phase 2 switches
+its imports) and invoked through ``torch.launch(...)`` — not through
+``dist.all_reduce(...)``. See ADR-0056 Context for why this kernel is not
+backend-dispatched via ADR-0050's algorithm-module contract.
+"""
+from __future__ import annotations
+
+from kernbench.common.pe_commands import TensorHandle
+
+
+def _view(handle: TensorHandle, new_shape: tuple[int, ...]) -> TensorHandle:
+    """Reshape — metadata only, no command emitted (cf. ``tl.trans``)."""
+    return TensorHandle(
+        id=handle.id,
+        addr=handle.addr,
+        shape=new_shape,
+        dtype=handle.dtype,
+        nbytes=handle.nbytes,
+        data=handle.data,
+        space=handle.space,
+        pinned=handle.pinned,
+    )
+
+
+def attention_mesh_mlo_kernel(
+    q_ptr: int,
+    k_ptr: int,
+    v_ptr: int,
+    o_ptr: int,
+    S_q: int,
+    S_kv_per_rank: int,
+    h_q: int,
+    h_kv: int,
+    d_head: int,
+    n_ranks: int,
+    rank_axis: int = 0,
+    *,
+    tl,
+) -> None:
+    """Mesh-native bidirectional AllReduce-mlo — see module docstring.
+
+    ``rank_axis`` selects which program-id dimension carries the ring rank:
+      0 — single_user_* panels: rank == tl.program_id(axis=0) (PE id in cube).
+      1 — multi_user_* panels: ring is at the cube level. Only PE 0 in each
+          cube participates; the other 7 hold KV replicas but stay silent.
+    """
+    # For multi_user (rank_axis=1) only PE 0 in each cube runs the ring.
+    if rank_axis != 0 and tl.program_id(axis=0) != 0:
+        return
+    rank = tl.program_id(axis=rank_axis)
+    has_E = rank < n_ranks - 1
+    has_W = rank > 0
+
+    # Q is replicated on every rank — loaded once.
+    Q = tl.load(q_ptr, shape=(S_q, h_q * d_head), dtype="f16")
+
+    # Local KV chunk. KV is sequence-sharded and stays put on this rank for
+    # the entire fan-out — distinguishing decode from prefill (ADR-0059 D3)
+    # where KV circulates.
+    K = tl.load(k_ptr, shape=(S_kv_per_rank, h_kv, d_head), dtype="f16")
+    V = tl.load(v_ptr, shape=(S_kv_per_rank, h_kv, d_head), dtype="f16")
+
+    # ── One-shot local partial attention ──────────────────────────
+    K_2d_T = _view(K, (h_q * d_head, S_kv_per_rank))
+    V_2d = _view(V, (S_kv_per_rank, h_q * d_head))
+    scores = tl.dot(Q, K_2d_T)
+    m = tl.max(scores, axis=-1)
+    P = tl.softmax(scores, axis=-1)
+    scores_centered = scores - m
+    exp_scores = tl.exp(scores_centered)
+    ell = tl.sum(exp_scores, axis=-1)
+    o = tl.dot(P, V_2d)
+
+    # Seed bidirectional waves with own triplet (step-1 send).
+    to_send_east_m: TensorHandle | None = m
+    to_send_east_ell: TensorHandle | None = ell
+    to_send_east_o: TensorHandle | None = o
+    to_send_west_m: TensorHandle | None = m
+    to_send_west_ell: TensorHandle | None = ell
+    to_send_west_o: TensorHandle | None = o
+
+    # Bidirectional fan-out of (m, ℓ, o) triplets — n_ranks - 1 steps.
+    for step in range(1, n_ranks):
+        # Send eastbound triplet (own at step 1; forwarded at later steps).
+        if has_E and to_send_east_m is not None:
+            tl.send(dir="E", src=to_send_east_m)
+            tl.send(dir="E", src=to_send_east_ell)
+            tl.send(dir="E", src=to_send_east_o)
+        # Send westbound triplet.
+        if has_W and to_send_west_m is not None:
+            tl.send(dir="W", src=to_send_west_m)
+            tl.send(dir="W", src=to_send_west_ell)
+            tl.send(dir="W", src=to_send_west_o)
+
+        # Receive eastbound triplet from W (originated at rank - step).
+        m_from_W: TensorHandle | None = None
+        ell_from_W: TensorHandle | None = None
+        o_from_W: TensorHandle | None = None
+        if has_W and (rank - step) >= 0:
+            m_from_W = tl.recv(dir="W", shape=m.shape, dtype="f16")
+            ell_from_W = tl.recv(dir="W", shape=ell.shape, dtype="f16")
+            o_from_W = tl.recv(dir="W", shape=o.shape, dtype="f16")
+            m_combined = tl.maximum(m, m_from_W)
+            scale_old = tl.exp(m - m_combined)
+            scale_new = tl.exp(m_from_W - m_combined)
+            ell = ell * scale_old + ell_from_W * scale_new
+            o = o * scale_old + o_from_W * scale_new
+            m = m_combined
+
+        # Receive westbound triplet from E (originated at rank + step).
+        m_from_E: TensorHandle | None = None
+        ell_from_E: TensorHandle | None = None
+        o_from_E: TensorHandle | None = None
+        if has_E and (rank + step) < n_ranks:
+            m_from_E = tl.recv(dir="E", shape=m.shape, dtype="f16")
+            ell_from_E = tl.recv(dir="E", shape=ell.shape, dtype="f16")
+            o_from_E = tl.recv(dir="E", shape=o.shape, dtype="f16")
+            m_combined = tl.maximum(m, m_from_E)
+            scale_old = tl.exp(m - m_combined)
+            scale_new = tl.exp(m_from_E - m_combined)
+            ell = ell * scale_old + ell_from_E * scale_new
+            o = o * scale_old + o_from_E * scale_new
+            m = m_combined
+
+        # Forward the original received triplet (not the merged running state)
+        # so neighbors get the original wave. ``None`` propagates if nothing
+        # arrived this step.
+        to_send_east_m = m_from_W
+        to_send_east_ell = ell_from_W
+        to_send_east_o = o_from_W
+        to_send_west_m = m_from_E
+        to_send_west_ell = ell_from_E
+        to_send_west_o = o_from_E
+
+    # Final normalize: O := o / ℓ.
+    O_final = o / ell
+    tl.store(o_ptr, O_final)
diff --git a/tests/attention/test_attention_mesh_panels_diag.py b/tests/attention/test_attention_mesh_panels_diag.py
index df15132..706071d 100644
--- a/tests/attention/test_attention_mesh_panels_diag.py
+++ b/tests/attention/test_attention_mesh_panels_diag.py
@@ -149,6 +149,7 @@ def _bench_fn_multi_user_prefill(ctx):
         "multi_user_prefill_mesh", attention_mesh_kv_kernel,
         q, k, v, o,
         S_Q_PREFILL, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, n,
+        1,  # rank_axis=1 → ring at cube level (ADR-0059 multi_user)
         _auto_dim_remap=False,
     )
 
@@ -169,6 +170,7 @@ def _bench_fn_multi_user_decode(ctx):
         "multi_user_decode_mesh", attention_mesh_mlo_kernel,
         q, k, v, o,
         S_Q_DECODE, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, n,
+        1,  # rank_axis=1 → ring at cube level (ADR-0059 multi_user)
         _auto_dim_remap=False,
     )
 
diff --git a/tests/attention/test_mesh_kernels_rank_axis.py b/tests/attention/test_mesh_kernels_rank_axis.py
new file mode 100644
index 0000000..be4229b
--- /dev/null
+++ b/tests/attention/test_mesh_kernels_rank_axis.py
@@ -0,0 +1,172 @@
+"""Spec test for the ``rank_axis`` parameter on the two mesh kernels.
+
+With the default ``rank_axis=0`` the ADR-0059 mesh kernels take
+``rank = tl.program_id(axis=0)``, which fits single_user_* panels
+(rank == pe_id within cube). For multi_user_* panels the ring is at the
+cube level — rank is ``cube_id`` (axis=1), and the 7 non-rank-leader PEs
+in each cube do not run the ring (they only hold KV replicas).
+
+This test pins the ``rank_axis`` kwarg semantics:
+
+    rank_axis = 0 (default, single_user)
+        rank = tl.program_id(axis=0). Every PE in the cube runs the ring.
+        Existing behavior — no change.
+
+    rank_axis = 1 (multi_user)
+        if tl.program_id(axis=0) != 0: return. (7/8 PEs early-exit.)
+        rank = tl.program_id(axis=1).
+
+Both kernels accept the parameter as of the commit that adds this file,
+so every test here is expected to pass; the multi_user_* diag harness
+relies on the same ``rank_axis=1`` path.
+"""
+from __future__ import annotations
+
+from kernbench.common.ipcq_types import IpcqRecvCmd, IpcqSendCmd
+from kernbench.common.pe_commands import GemmCmd
+from kernbench.triton_emu.tl_context import TLContext, run_kernel
+
+from kernbench.benches._attention_mesh_kv import attention_mesh_kv_kernel
+from kernbench.benches._attention_mesh_mlo import attention_mesh_mlo_kernel
+
+S_Q_PREFILL = 16
+S_Q_DECODE = 1
+S_KV_PER_RANK = 16
+H_Q = 1
+H_KV = 1
+D_HEAD = 64
+N_RANKS_MULTI = 4
+PES_PER_CUBE = 8
+
+Q_PTR = 0x10000
+K_PTR = 0x20000
+V_PTR = 0x30000
+O_PTR = 0x40000
+
+
+def _tl(pe_id: int, cube_id: int, num_pes: int, num_cubes: int) -> TLContext:
+    return TLContext(
+        pe_id=pe_id,
+        num_programs=num_pes,
+        cube_id=cube_id,
+        num_cubes=num_cubes,
+        dispatch_cycles=0,
+        scratch_base=0x80000,
+        scratch_size=1 << 20,
+    )
+
+
+# ── Default rank_axis=0 backward-compat ──────────────────────────
+
+
+def test_mlo_kernel_default_rank_axis_zero_emits_commands_on_all_pes():
+    """rank_axis defaults to 0 → kernel uses pe_id as rank, runs on every
+    PE. Verify by running rank=3 (interior PE) in a single-cube 8-rank
+    setup and asserting at least one GEMM and at least one IPCQ send
+    are emitted (interior ranks send in both directions)."""
+    tl = _tl(pe_id=3, cube_id=0, num_pes=8, num_cubes=1)
+    run_kernel(
+        attention_mesh_mlo_kernel, tl,
+        Q_PTR, K_PTR, V_PTR, O_PTR,
+        S_Q_DECODE, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, 8,
+    )
+    assert any(isinstance(c, GemmCmd) for c in tl.commands), \
+        "default rank_axis=0 must run the kernel (≥1 GEMM)"
+    assert any(isinstance(c, IpcqSendCmd) for c in tl.commands), \
+        "interior rank must emit ≥1 IpcqSendCmd"
+
+
+def test_kv_kernel_default_rank_axis_zero_emits_commands_on_all_pes():
+    tl = _tl(pe_id=3, cube_id=0, num_pes=8, num_cubes=1)
+    run_kernel(
+        attention_mesh_kv_kernel, tl,
+        Q_PTR, K_PTR, V_PTR, O_PTR,
+        S_Q_PREFILL, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, 8,
+    )
+    assert any(isinstance(c, GemmCmd) for c in tl.commands)
+    assert any(isinstance(c, IpcqSendCmd) for c in tl.commands)
+
+
+# ── rank_axis=1 multi_user semantics ─────────────────────────────
+
+
+def test_mlo_kernel_rank_axis_one_gates_non_zero_pe_to_no_commands():
+    """rank_axis=1 + pe_id != 0 → kernel must early-return; no GEMM,
+    no DMA, no IPCQ. The 7 non-rank-leader PEs in a multi_user cube
+    must stay completely silent so the cube-level SFR install isn't
+    asked to route sends from PEs that have no neighbors installed."""
+    tl = _tl(pe_id=2, cube_id=1, num_pes=PES_PER_CUBE, num_cubes=N_RANKS_MULTI)
+    run_kernel(
+        attention_mesh_mlo_kernel, tl,
+        Q_PTR, K_PTR, V_PTR, O_PTR,
+        S_Q_DECODE, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, N_RANKS_MULTI,
+        rank_axis=1,
+    )
+    assert not any(isinstance(c, GemmCmd) for c in tl.commands), \
+        "pe_id=2 with rank_axis=1 must not emit GEMMs"
+    assert not any(isinstance(c, IpcqSendCmd) for c in tl.commands), \
+        "pe_id=2 with rank_axis=1 must not emit IpcqSendCmd"
+    assert not any(isinstance(c, IpcqRecvCmd) for c in tl.commands), \
+        "pe_id=2 with rank_axis=1 must not emit IpcqRecvCmd"
+
+
+def test_kv_kernel_rank_axis_one_gates_non_zero_pe_to_no_commands():
+    tl = _tl(pe_id=2, cube_id=1, num_pes=PES_PER_CUBE, num_cubes=N_RANKS_MULTI)
+    run_kernel(
+        attention_mesh_kv_kernel, tl,
+        Q_PTR, K_PTR, V_PTR, O_PTR,
+        S_Q_PREFILL, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, N_RANKS_MULTI,
+        rank_axis=1,
+    )
+    assert not any(isinstance(c, GemmCmd) for c in tl.commands)
+    assert not any(isinstance(c, IpcqSendCmd) for c in tl.commands)
+    assert not any(isinstance(c, IpcqRecvCmd) for c in tl.commands)
+
+
+def test_mlo_kernel_rank_axis_one_pe_zero_uses_cube_id_as_rank():
+    """rank_axis=1 + pe_id == 0 → kernel runs the ring with rank=cube_id.
+    For cube_id=1 in a 4-cube ring, rank=1 is an interior rank: has_E=True
+    AND has_W=True → IPCQ sends emitted in both E and W directions.
+    """
+    tl = _tl(pe_id=0, cube_id=1, num_pes=PES_PER_CUBE, num_cubes=N_RANKS_MULTI)
+    run_kernel(
+        attention_mesh_mlo_kernel, tl,
+        Q_PTR, K_PTR, V_PTR, O_PTR,
+        S_Q_DECODE, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, N_RANKS_MULTI,
+        rank_axis=1,
+    )
+    sends = [c for c in tl.commands if isinstance(c, IpcqSendCmd)]
+    assert any(s.direction == "E" for s in sends), \
+        "cube_id=1 (interior) must emit ≥1 E-send"
+    assert any(s.direction == "W" for s in sends), \
+        "cube_id=1 (interior) must emit ≥1 W-send"
+
+
+def test_kv_kernel_rank_axis_one_pe_zero_uses_cube_id_as_rank():
+    tl = _tl(pe_id=0, cube_id=1, num_pes=PES_PER_CUBE, num_cubes=N_RANKS_MULTI)
+    run_kernel(
+        attention_mesh_kv_kernel, tl,
+        Q_PTR, K_PTR, V_PTR, O_PTR,
+        S_Q_PREFILL, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, N_RANKS_MULTI,
+        rank_axis=1,
+    )
+    sends = [c for c in tl.commands if isinstance(c, IpcqSendCmd)]
+    assert any(s.direction == "E" for s in sends)
+    assert any(s.direction == "W" for s in sends)
+
+
+def test_mlo_kernel_rank_axis_one_west_edge_cube_no_west_sends():
+    """cube_id=0 (west edge) with rank_axis=1: rank=0, has_W=False → no
+    W-direction IPCQ sends. has_E=True → ≥1 E-direction send."""
+    tl = _tl(pe_id=0, cube_id=0, num_pes=PES_PER_CUBE, num_cubes=N_RANKS_MULTI)
+    run_kernel(
+        attention_mesh_mlo_kernel, tl,
+        Q_PTR, K_PTR, V_PTR, O_PTR,
+        S_Q_DECODE, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, N_RANKS_MULTI,
+        rank_axis=1,
+    )
+    sends = [c for c in tl.commands if isinstance(c, IpcqSendCmd)]
+    assert any(s.direction == "E" for s in sends), \
+        "west-edge cube_id=0 must still emit ≥1 E-send"
+    assert not any(s.direction == "W" for s in sends), \
+        "west-edge cube_id=0 must NOT emit any W-send (no W neighbor)"