Files
kernbench2/tests/attention/test_attention_mesh_panels_diag.py
T
mukesh 222815d374 attention: add rank_axis kwarg to mesh kernels for multi_user cube ring
ADR-0059 single_user_* panels run the ring across PEs in one cube
(rank == tl.program_id(axis=0)). multi_user_* panels run the ring
across cubes — rank should be cube_id (axis=1), and 7 of every 8 PEs
in each cube must stay silent because the cube-level SFR install only
gives PE 0 of each cube an E/W neighbor.

Add ``rank_axis: int = 0`` kwarg to both ``attention_mesh_mlo_kernel``
and ``attention_mesh_kv_kernel``:
  - 0 (default): rank == tl.program_id(axis=0). Existing single_user
    behavior, all spec tests unchanged.
  - 1: gate ``if tl.program_id(axis=0) != 0: return`` at kernel start,
    then ``rank = tl.program_id(axis=1)``. The multi_user_* panels pass
    rank_axis=1 to the kernel as a positional argument to ctx.launch.

Also brings in _attention_mesh_kv.py and _attention_mesh_mlo.py as
the committed home of the ADR-0059 kernels (previously living
uncommitted in the working tree from sub-cycle 4b).

Tests: 7-test rank_axis spec file (default-path + rank_axis=1 gating
and cube-id semantics, both kernels); 4-panel diag harness now green
end-to-end (single_user_prefill/decode + multi_user_prefill/decode);
763-test wider sweep clean.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-01 19:53:18 -07:00

199 lines
7.9 KiB
Python

"""End-to-end engine drives for the four GQA Llama-70B panels (sub-cycle 4c step 2).

Mirrors the existing single_user_decode diag harness across all four panels
of the milestone-gqa-llama70b sweep (ADR-0057):

    single_user_prefill   ring-K/V kernel,       intracube PE ring (8 PEs / 1 cube)
    single_user_decode    allreduce-mlo kernel,  intracube PE ring
    multi_user_prefill    ring-K/V kernel,       intercube multisip (4 cubes)
    multi_user_decode     allreduce-mlo kernel,  intercube multisip

The multi_user panels pass ``rank_axis=1`` to the kernel so the ring runs at
cube level (ADR-0059).

Each test runs its panel through ``run_bench`` with ``enable_data=True``
and asserts ``result.completion.ok``. On failure the harness prints op_log
diagnostics from the engine plus the exception and its traceback, mirroring
the decode-diag harness format.

Validation-scale config matches ADR-0057 D4:

    S_q_prefill=16, S_q_decode=1, S_kv_per_rank=16, h_q=h_kv=1, d_head=64
    n_ranks_single_user=8, n_ranks_multi_user=4
"""
from __future__ import annotations
import traceback
from pathlib import Path
import pytest
from kernbench.benches._attention_mesh_kv import attention_mesh_kv_kernel
from kernbench.benches._attention_mesh_mlo import attention_mesh_mlo_kernel
from kernbench.ccl.install import load_ccl_config, resolve_algorithm_config
from kernbench.ccl.sfr_config import (
configure_sfr_intercube_multisip,
configure_sfr_intracube_pe_ring,
)
from kernbench.policy.placement.dp import DPPolicy
from kernbench.runtime_api.bench_runner import run_bench
from kernbench.runtime_api.types import resolve_device
from kernbench.sim_engine.engine import GraphEngine
from kernbench.topology.builder import resolve_topology
# Topology description used by every panel: parents[2] of this file, i.e. the
# project root when this module lives under tests/attention/.
TOPOLOGY_PATH: Path = Path(__file__).resolve().parents[2] / "topology.yaml"

# Validation-scale attention shapes.
S_Q_PREFILL: int = 16    # query rows for the *_prefill panels
S_Q_DECODE: int = 1      # query rows for the *_decode panels
S_KV_PER_RANK: int = 16  # K/V rows per ring rank (global K/V = this * n_ranks)
H_Q: int = 1             # query heads
H_KV: int = 1            # key/value heads
D_HEAD: int = 64         # head dimension

# Ring sizes: PEs inside one cube (single_user) vs. cubes (multi_user).
N_RANKS_SINGLE_USER: int = 8
N_RANKS_MULTI_USER: int = 4

# Element type for every tensor allocated by the panels.
DTYPE: str = "f16"
# ── Helpers ──────────────────────────────────────────────────────
def _engine_factory(t, d):
    """Return a data-enabled ``GraphEngine`` for topology ``t``.

    ``t`` is unwrapped through its ``topology_obj`` attribute when it has one,
    otherwise it is passed to ``GraphEngine`` unchanged. ``d`` (presumably the
    device handed in by ``run_bench``) is accepted to match the engine-factory
    call shape but is not used.
    """
    topology = getattr(t, "topology_obj", t)
    engine = GraphEngine(topology, enable_data=True)
    return engine
def _run_panel(bench_fn):
    """Drive one panel through ``run_bench``; return ``(exc, result, engine)``.

    Runtime failures inside ``run_bench`` are captured instead of raised so the
    caller can print diagnostics from the engine that was built for the run:

    * ``exc``: the exception raised by ``run_bench``, or ``None``.
    * ``result``: the ``run_bench`` return value, or ``None`` on failure.
    * ``engine``: the most recent engine created by the factory, or ``None``
      if the factory was never called.

    ``KeyboardInterrupt`` and ``SystemExit`` are re-raised rather than
    captured, so an interrupted pytest session aborts instead of being
    reported as a panel failure. The topology is loaded outside the capture,
    so a problem with ``topology.yaml`` propagates directly.
    """
    topo = resolve_topology(str(TOPOLOGY_PATH))
    captured: dict = {"engine": None}

    def factory(t, d):
        # Remember the engine so failure diagnostics can inspect its op_log.
        eng = _engine_factory(t, d)
        captured["engine"] = eng
        return eng

    exc = None
    result = None
    try:
        result = run_bench(
            topology=topo, bench_fn=bench_fn,
            device=resolve_device(None), engine_factory=factory,
        )
    except (KeyboardInterrupt, SystemExit):
        # Never turn a user interrupt / interpreter exit into a test failure.
        raise
    except BaseException as e:  # noqa: BLE001 - captured deliberately for diagnostics
        exc = e
    return exc, result, captured["engine"]
def _assert_ok(name: str, exc, result, engine) -> None:
if exc is not None:
oplog_len = len(getattr(engine, "op_log", []) or []) if engine else 0
print(f"\n========== {name} FAIL ==========")
print(f"op_log records before crash: {oplog_len}")
print(f"{type(exc).__name__}: {exc}")
traceback.print_exception(type(exc), exc, exc.__traceback__)
raise AssertionError(
f"{name} failed at runtime: {exc}"
) from exc
assert result is not None, f"{name}: no result"
assert result.completion.ok, f"{name}: completion not ok — {result.completion}"
# ── Panel bench fns ──────────────────────────────────────────────
# CCL algorithm whose config drives the SFR install for every panel.
_CCL_ALGORITHM = "lrab_hierarchical_allreduce"
# PEs per cube; the multi_user panels allocate across all of them.
_PES_PER_CUBE = 8


def _bench_panel(ctx, *, launch_name: str, kernel, s_q: int, multi_user: bool) -> None:
    """Install SFR routing, allocate Q/K/V/O and launch ``kernel`` for one panel.

    single_user (``multi_user=False``): intracube PE ring over
    ``N_RANKS_SINGLE_USER`` PEs of a single cube. Q/O are replicated; K/V are
    split ``row_wise`` across PEs.

    multi_user (``multi_user=True``): intercube multisip ring over
    ``N_RANKS_MULTI_USER`` cubes. Q/O are replicated; K/V are split
    ``row_wise`` across cubes and replicated across the PEs of each cube. The
    kernel additionally receives ``rank_axis=1`` as a positional argument
    after ``n`` (ring at cube level, ADR-0059), and the launch passes
    ``_auto_dim_remap=False``.

    K/V are allocated with ``S_KV_PER_RANK * n`` rows, so a ``row_wise`` split
    over ``n`` ranks presumably leaves ``S_KV_PER_RANK`` rows on each rank.
    """
    if multi_user:
        configure_sfr_intercube_multisip(
            ctx.engine, ctx.spec,
            resolve_algorithm_config(load_ccl_config(), name=_CCL_ALGORITHM),
        )
        n = N_RANKS_MULTI_USER
        dp_full = DPPolicy(cube="replicate", pe="replicate", num_cubes=n, num_pes=_PES_PER_CUBE)
        dp_kv = DPPolicy(cube="row_wise", pe="replicate", num_cubes=n, num_pes=_PES_PER_CUBE)
        extra_args = (1,)  # rank_axis=1 → ring at cube level (ADR-0059 multi_user)
        launch_opts = {"_auto_dim_remap": False}
    else:
        configure_sfr_intracube_pe_ring(
            ctx.engine, ctx.spec,
            resolve_algorithm_config(load_ccl_config(), name=_CCL_ALGORITHM),
        )
        n = N_RANKS_SINGLE_USER
        dp_full = DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=n)
        dp_kv = DPPolicy(cube="replicate", pe="row_wise", num_cubes=1, num_pes=n)
        # Default rank_axis and default launch options for single_user.
        extra_args = ()
        launch_opts = {}

    q_cols = H_Q * D_HEAD
    kv_cols = H_KV * D_HEAD
    q = ctx.zeros((s_q, q_cols), dtype=DTYPE, dp=dp_full, name="q")
    k = ctx.zeros((S_KV_PER_RANK * n, kv_cols), dtype=DTYPE, dp=dp_kv, name="k")
    v = ctx.zeros((S_KV_PER_RANK * n, kv_cols), dtype=DTYPE, dp=dp_kv, name="v")
    o = ctx.empty((s_q, q_cols), dtype=DTYPE, dp=dp_full, name="o")
    ctx.launch(
        launch_name, kernel,
        q, k, v, o,
        s_q, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, n,
        *extra_args,
        **launch_opts,
    )


def _bench_fn_single_user_prefill(ctx):
    """single_user_prefill panel: ring-K/V kernel over an intracube PE ring."""
    _bench_panel(
        ctx, launch_name="single_user_prefill_mesh", kernel=attention_mesh_kv_kernel,
        s_q=S_Q_PREFILL, multi_user=False,
    )


def _bench_fn_single_user_decode(ctx):
    """single_user_decode panel: allreduce-mlo kernel over an intracube PE ring."""
    _bench_panel(
        ctx, launch_name="single_user_decode_mesh", kernel=attention_mesh_mlo_kernel,
        s_q=S_Q_DECODE, multi_user=False,
    )


def _bench_fn_multi_user_prefill(ctx):
    """multi_user_prefill panel: ring-K/V kernel over an intercube multisip ring."""
    _bench_panel(
        ctx, launch_name="multi_user_prefill_mesh", kernel=attention_mesh_kv_kernel,
        s_q=S_Q_PREFILL, multi_user=True,
    )


def _bench_fn_multi_user_decode(ctx):
    """multi_user_decode panel: allreduce-mlo kernel over an intercube multisip ring."""
    _bench_panel(
        ctx, launch_name="multi_user_decode_mesh", kernel=attention_mesh_mlo_kernel,
        s_q=S_Q_DECODE, multi_user=True,
    )
# ── Tests ────────────────────────────────────────────────────────
def test_single_user_prefill_through_engine():
    """Run the single_user_prefill panel end to end and require success."""
    outcome = _run_panel(_bench_fn_single_user_prefill)
    _assert_ok("single_user_prefill", *outcome)
def test_single_user_decode_through_engine():
    """Run the single_user_decode panel end to end and require success."""
    outcome = _run_panel(_bench_fn_single_user_decode)
    _assert_ok("single_user_decode", *outcome)
def test_multi_user_prefill_through_engine():
    """Run the multi_user_prefill panel end to end and require success."""
    outcome = _run_panel(_bench_fn_multi_user_prefill)
    _assert_ok("multi_user_prefill", *outcome)
def test_multi_user_decode_through_engine():
    """Run the multi_user_decode panel end to end and require success."""
    outcome = _run_panel(_bench_fn_multi_user_decode)
    _assert_ok("multi_user_decode", *outcome)