attention: add rank_axis kwarg to mesh kernels for multi_user cube ring
ADR-0059 single_user_* panels run the ring across PEs in one cube
(rank == tl.program_id(axis=0)). multi_user_* panels run the ring
across cubes — rank should be cube_id (axis=1), and 7 of every 8 PEs
in each cube must stay silent because the cube-level SFR install only
gives the cube-coordinate PE 0 an E/W neighbor.
Add ``rank_axis: int = 0`` kwarg to both ``attention_mesh_mlo_kernel``
and ``attention_mesh_kv_kernel``:
- 0 (default): rank == tl.program_id(axis=0). Existing single_user
behavior, all spec tests unchanged.
- 1: gate ``if tl.program_id(axis=0) != 0: return`` at kernel start,
then ``rank = tl.program_id(axis=1)``. multi_user_* panels pass
this to the kernel via ctx.launch positional arg.
Also brings in _attention_mesh_kv.py and _attention_mesh_mlo.py as
the committed home of the ADR-0059 kernels (previously living
uncommitted in the working tree from sub-cycle 4b).
Tests: 7-test rank_axis spec file (default-path + rank_axis=1 gating
and cube-id semantics, both kernels); 4-panel diag harness now green
end-to-end (single_user_prefill/decode + multi_user_prefill/decode);
763-test wider sweep clean.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -149,6 +149,7 @@ def _bench_fn_multi_user_prefill(ctx):
|
||||
"multi_user_prefill_mesh", attention_mesh_kv_kernel,
|
||||
q, k, v, o,
|
||||
S_Q_PREFILL, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, n,
|
||||
1, # rank_axis=1 → ring at cube level (ADR-0059 multi_user)
|
||||
_auto_dim_remap=False,
|
||||
)
|
||||
|
||||
@@ -169,6 +170,7 @@ def _bench_fn_multi_user_decode(ctx):
|
||||
"multi_user_decode_mesh", attention_mesh_mlo_kernel,
|
||||
q, k, v, o,
|
||||
S_Q_DECODE, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, n,
|
||||
1, # rank_axis=1 → ring at cube level (ADR-0059 multi_user)
|
||||
_auto_dim_remap=False,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user