"""Mesh-native bidirectional Ring-K/V attention kernel — prefill (ADR-0059 Proposed).

Each rank holds its own Q tile and 1/n_ranks of K, V (sequence-sharded).
Over ``n_ranks - 1`` bidirectional steps, K and V propagate both east and
west: chunk c_i originating at rank i reaches rank j at step ``|i - j|``.
Every rank receives every other rank's chunk **exactly once** and folds it
into a running ``(m, ℓ, o)`` via the online-softmax recurrence. After all
steps each rank holds the final attention output for its own Q tokens — no
cross-rank merge is required.

Supersedes ADR-0055's closed-ring ``_attention_ring_kv.py``. Both modules
stay on disk during the transition; this one runs on the hardware's actual
open-mesh wiring (no closed-ring SFR install required).

Imported by ``milestone_gqa_llama70b`` (after the bench's Phase 2 switches
its imports) and invoked through ``torch.launch(...)`` — not through
``dist.all_reduce(...)``. See ADR-0055 Context for why this kernel is not
backend-dispatched via ADR-0050's algorithm-module contract.
"""

from __future__ import annotations

from kernbench.common.pe_commands import TensorHandle


def _view(handle: TensorHandle, new_shape: tuple[int, ...]) -> TensorHandle:
    """Return a handle identical to ``handle`` except for its ``shape``.

    Pure metadata rewrite: no ``tl`` call is made here, so no command is
    emitted (cf. ``tl.trans``). ``nbytes`` is copied as-is and is not
    re-derived from ``new_shape``.
    """
    carried = {
        field: getattr(handle, field)
        for field in ("id", "addr", "dtype", "nbytes", "data", "space", "pinned")
    }
    return TensorHandle(shape=new_shape, **carried)


def _partial_attention(
    Q: TensorHandle,
    K: TensorHandle,
    V: TensorHandle,
    S_q: int,
    S_kv_per_rank: int,
    h_q: int,
    d_head: int,
    tl,
) -> tuple[TensorHandle, TensorHandle, TensorHandle]:
    """Run one partial-attention pass of ``Q`` against a single (K, V) chunk.

    The ``tl`` calls are issued in this fixed order: GEMM(Q·K^T), max,
    softmax, subtract, exp, sum, GEMM(P·V). The result is the triplet
    ``(m, ℓ, O_partial)`` that the caller feeds into the online-softmax
    merge.

    ``S_q`` is accepted for signature symmetry with the caller but is not
    read in this function.
    """
    hidden = h_q * d_head
    # NOTE(review): the kernel loads K and V as (S_kv_per_rank, h_kv, d_head)
    # but they are viewed here with h_q * d_head; the two shapes describe the
    # same element count only when h_kv == h_q — confirm the GQA handling is
    # intentionally metadata-only.
    K_as_rhs = _view(K, (hidden, S_kv_per_rank))
    V_as_rhs = _view(V, (S_kv_per_rank, hidden))

    logits = tl.dot(Q, K_as_rhs)
    row_max = tl.max(logits, axis=-1)
    probs = tl.softmax(logits, axis=-1)
    shifted = logits - row_max
    exp_shifted = tl.exp(shifted)
    row_sum = tl.sum(exp_shifted, axis=-1)
    # NOTE(review): ``probs`` is the normalized softmax, whereas the caller
    # later divides the accumulated output by ℓ again. If ``tl`` ops carry
    # real numerics (rather than only emitting commands), the P·V GEMM would
    # presumably need ``exp_shifted`` instead — confirm against ADR-0059
    # before altering the emitted op sequence.
    out_partial = tl.dot(probs, V_as_rhs)
    return row_max, row_sum, out_partial


def _fold_chunk(
    Q: TensorHandle,
    K_chunk: TensorHandle,
    V_chunk: TensorHandle,
    m: TensorHandle,
    ell: TensorHandle,
    o: TensorHandle,
    S_q: int,
    S_kv_per_rank: int,
    h_q: int,
    d_head: int,
    tl,
) -> tuple[TensorHandle, TensorHandle, TensorHandle]:
    """Attend ``Q`` to one received chunk and merge it into ``(m, ℓ, o)``.

    Computes the chunk's own triplet with ``_partial_attention`` and then
    rescales both the running and the new statistics to their joint maximum
    before summing. Returns the updated ``(m, ℓ, o)``.
    """
    m_chunk, ell_chunk, o_chunk = _partial_attention(
        Q, K_chunk, V_chunk, S_q, S_kv_per_rank, h_q, d_head, tl,
    )
    m_joint = tl.maximum(m, m_chunk)
    weight_prev = tl.exp(m - m_joint)
    weight_chunk = tl.exp(m_chunk - m_joint)
    ell_joint = ell * weight_prev + ell_chunk * weight_chunk
    o_joint = o * weight_prev + o_chunk * weight_chunk
    return m_joint, ell_joint, o_joint


def attention_mesh_kv_kernel(
    q_ptr: int,
    k_ptr: int,
    v_ptr: int,
    o_ptr: int,
    S_q: int,
    S_kv_per_rank: int,
    h_q: int,
    h_kv: int,
    d_head: int,
    n_ranks: int,
    rank_axis: int = 0,
    *,
    tl,
) -> None:
    """Mesh-native bidirectional Ring-K/V attention — see module docstring.

    ``q_ptr``, ``k_ptr`` and ``v_ptr`` are handed to ``tl.load`` with shapes
    ``(S_q, h_q * d_head)`` for Q and ``(S_kv_per_rank, h_kv, d_head)`` for
    K and V, all as ``"f16"``; the normalized result goes to ``tl.store`` at
    ``o_ptr``. Nothing is returned.

    ``rank_axis`` selects which program-id dimension carries the ring rank:
      0 — single_user_* panels: rank == tl.program_id(axis=0) (PE id in cube).
      1 — multi_user_* panels: ring is at the cube level. Only PE 0 in each
          cube participates; the other 7 hold KV replicas but stay silent.
    """
    # With a non-zero rank_axis (multi_user), only program 0 on axis 0 takes
    # part in the ring; everything else exits before issuing any command.
    if rank_axis != 0 and tl.program_id(axis=0) != 0:
        return

    rank = tl.program_id(axis=rank_axis)
    has_E = rank < n_ranks - 1
    has_W = rank > 0

    chunk_shape = (S_kv_per_rank, h_kv, d_head)

    # Q never leaves this rank: load it once and reuse it for every chunk.
    Q = tl.load(q_ptr, shape=(S_q, h_q * d_head), dtype="f16")

    # This rank's own slice of K and V.
    K_local = tl.load(k_ptr, shape=chunk_shape, dtype="f16")
    V_local = tl.load(v_ptr, shape=chunk_shape, dtype="f16")

    # Step 0: attention against the local chunk seeds the running (m, ℓ, o).
    m, ell, o = _partial_attention(
        Q, K_local, V_local, S_q, S_kv_per_rank, h_q, d_head, tl,
    )

    # Each direction's in-flight wave is a (K, V) pair, or None when there is
    # nothing to pass on. Both waves start out as the local chunk.
    eastbound: tuple[TensorHandle, TensorHandle] | None = (K_local, V_local)
    westbound: tuple[TensorHandle, TensorHandle] | None = (K_local, V_local)

    # Bidirectional fan-out over n_ranks - 1 steps. By step k the wave that
    # started at rank i has reached rank (i ± k), so after the last step every
    # rank has merged every other rank's chunk exactly once (ADR-0059 D3).
    for step in range(1, n_ranks):
        # Sends come first, K before V, east before west. A missing wave
        # means an edge rank or a wave that has already passed through.
        if has_E and eastbound is not None:
            for tensor in eastbound:
                tl.send(dir="E", src=tensor)
        if has_W and westbound is not None:
            for tensor in westbound:
                tl.send(dir="W", src=tensor)

        # The eastbound wave arriving from W carries chunk c_{rank - step}.
        from_west: tuple[TensorHandle, TensorHandle] | None = None
        if has_W and (rank - step) >= 0:
            K_in = tl.recv(dir="W", shape=chunk_shape, dtype="f16")
            V_in = tl.recv(dir="W", shape=chunk_shape, dtype="f16")
            from_west = (K_in, V_in)
            m, ell, o = _fold_chunk(
                Q, K_in, V_in, m, ell, o,
                S_q, S_kv_per_rank, h_q, d_head, tl,
            )

        # The westbound wave arriving from E carries chunk c_{rank + step}.
        from_east: tuple[TensorHandle, TensorHandle] | None = None
        if has_E and (rank + step) < n_ranks:
            K_in = tl.recv(dir="E", shape=chunk_shape, dtype="f16")
            V_in = tl.recv(dir="E", shape=chunk_shape, dtype="f16")
            from_east = (K_in, V_in)
            m, ell, o = _fold_chunk(
                Q, K_in, V_in, m, ell, o,
                S_q, S_kv_per_rank, h_q, d_head, tl,
            )

        # Whatever arrived keeps travelling in the same direction next step;
        # None carries over when nothing arrived from that side.
        eastbound = from_west
        westbound = from_east

    # Final normalize: O := o / ℓ.
    tl.store(o_ptr, o / ell)