"""End-to-end engine drives for the four GQA Llama-70B panels (sub-cycle 4c step 2).

Mirrors the existing single_user_decode diag harness across all four panels
of the milestone-gqa-llama70b sweep (ADR-0057):

    single_user_prefill   ring-K/V kernel,       intracube PE ring (8 PEs / 1 cube)
    single_user_decode    allreduce-mlo kernel,  intracube PE ring
    multi_user_prefill    ring-K/V kernel,       intercube multisip (4 cubes)
    multi_user_decode     allreduce-mlo kernel,  intercube multisip

Each test runs the panel through ``run_bench`` with ``enable_data=True`` and
asserts ``result.completion.ok``. On failure the harness prints the number of
op_log records the engine recorded before the crash, followed by the exception
and its full traceback, mirroring the decode-diag harness format.

Validation-scale config matches ADR-0057 D4:
    S_q_prefill=16, S_kv_per_rank=16, h_q=h_kv=1, d_head=64
    n_ranks_single_user=8, n_ranks_multi_user=4
"""

from __future__ import annotations

import traceback
from pathlib import Path

import pytest

from kernbench.benches._attention_mesh_kv import attention_mesh_kv_kernel
from kernbench.benches._attention_mesh_mlo import attention_mesh_mlo_kernel
from kernbench.ccl.install import load_ccl_config, resolve_algorithm_config
from kernbench.ccl.sfr_config import (
    configure_sfr_intercube_multisip,
    configure_sfr_intracube_pe_ring,
)
from kernbench.policy.placement.dp import DPPolicy
from kernbench.runtime_api.bench_runner import run_bench
from kernbench.runtime_api.types import resolve_device
from kernbench.sim_engine.engine import GraphEngine
from kernbench.topology.builder import resolve_topology

TOPOLOGY_PATH = Path(__file__).resolve().parents[2] / "topology.yaml"

S_Q_PREFILL = 16
S_Q_DECODE = 1
S_KV_PER_RANK = 16
H_Q = 1
H_KV = 1
D_HEAD = 64
N_RANKS_SINGLE_USER = 8
N_RANKS_MULTI_USER = 4
DTYPE = "f16"

# CCL algorithm every panel configures its SFR ring with.
_CCL_ALGORITHM = "lrab_hierarchical_allreduce"
# PE count per cube used by the multi-user (intercube) placement policies.
_NUM_PES_PER_CUBE = 8


# ── Helpers ──────────────────────────────────────────────────────


def _engine_factory(t, d):
    """Build a data-enabled GraphEngine for topology ``t`` (``d`` is unused).

    ``t`` may be a wrapper exposing ``topology_obj`` or the topology itself.
    """
    return GraphEngine(getattr(t, "topology_obj", t), enable_data=True)


def _run_panel(bench_fn):
    """Drive a panel through ``run_bench``; return ``(exc, result, engine)``.

    - ``exc``: the exception raised by ``run_bench``, or None on success.
    - ``result``: the return value of ``run_bench``, or None if it raised.
    - ``engine``: the last engine built by the factory, or None if the
      factory was never invoked.
    """
    topo = resolve_topology(str(TOPOLOGY_PATH))
    captured: dict = {"engine": None}

    def factory(t, d):
        # Keep a handle on the engine so failures can report its op_log.
        eng = _engine_factory(t, d)
        captured["engine"] = eng
        return eng

    exc = None
    result = None
    try:
        result = run_bench(
            topology=topo,
            bench_fn=bench_fn,
            device=resolve_device(None),
            engine_factory=factory,
        )
    # Exception, not BaseException: KeyboardInterrupt / SystemExit and pytest's
    # own outcome exceptions (skip / fail / xfail) must propagate unchanged
    # instead of being re-reported as a panel runtime failure.
    except Exception as e:  # noqa: BLE001
        exc = e
    return exc, result, captured["engine"]


def _assert_ok(name: str, exc, result, engine) -> None:
    """Fail the test with diagnostics unless the panel completed successfully.

    Raises:
        AssertionError: if ``exc`` is set (chained from it, after printing the
            op_log record count and the traceback), if ``result`` is None,
            or if ``result.completion.ok`` is falsy.
    """
    if exc is not None:
        op_log = getattr(engine, "op_log", None) if engine is not None else None
        oplog_len = len(op_log or [])
        print(f"\n========== {name} FAIL ==========")
        print(f"op_log records before crash: {oplog_len}")
        print(f"{type(exc).__name__}: {exc}")
        traceback.print_exception(type(exc), exc, exc.__traceback__)
        raise AssertionError(f"{name} failed at runtime: {exc}") from exc
    assert result is not None, f"{name}: no result"
    assert result.completion.ok, f"{name}: completion not ok — {result.completion}"


def _configure_ring(ctx, *, multi_user: bool) -> None:
    """Configure the SFR ring for the panel: intercube multisip or intracube PE ring."""
    algo = resolve_algorithm_config(load_ccl_config(), name=_CCL_ALGORITHM)
    if multi_user:
        configure_sfr_intercube_multisip(ctx.engine, ctx.spec, algo)
    else:
        configure_sfr_intracube_pe_ring(ctx.engine, ctx.spec, algo)


def _panel_policies(multi_user: bool) -> tuple[int, DPPolicy, DPPolicy]:
    """Return ``(n_ranks, dp_full, dp_kv)`` placement for the panel.

    Single-user panels shard K/V row-wise across ``n`` PEs of one cube;
    multi-user panels shard K/V row-wise across ``n`` cubes.
    ``dp_full`` replicates the tensor everywhere.
    """
    if multi_user:
        n = N_RANKS_MULTI_USER
        dp_full = DPPolicy(
            cube="replicate", pe="replicate", num_cubes=n, num_pes=_NUM_PES_PER_CUBE
        )
        dp_kv = DPPolicy(
            cube="row_wise", pe="replicate", num_cubes=n, num_pes=_NUM_PES_PER_CUBE
        )
    else:
        n = N_RANKS_SINGLE_USER
        dp_full = DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=n)
        dp_kv = DPPolicy(cube="replicate", pe="row_wise", num_cubes=1, num_pes=n)
    return n, dp_full, dp_kv


def _launch_panel(ctx, *, name: str, kernel, s_q: int, multi_user: bool) -> None:
    """Configure the ring, allocate q/k/v/o, and launch ``kernel`` as ``name``.

    q/o are ``(s_q, H_Q * D_HEAD)`` replicated; k/v are
    ``(S_KV_PER_RANK * n, H_KV * D_HEAD)`` sharded per ``_panel_policies``.
    """
    _configure_ring(ctx, multi_user=multi_user)
    n, dp_full, dp_kv = _panel_policies(multi_user)
    q = ctx.zeros((s_q, H_Q * D_HEAD), dtype=DTYPE, dp=dp_full, name="q")
    k = ctx.zeros((S_KV_PER_RANK * n, H_KV * D_HEAD), dtype=DTYPE, dp=dp_kv, name="k")
    v = ctx.zeros((S_KV_PER_RANK * n, H_KV * D_HEAD), dtype=DTYPE, dp=dp_kv, name="v")
    o = ctx.empty((s_q, H_Q * D_HEAD), dtype=DTYPE, dp=dp_full, name="o")
    # Multi-user panels pass _auto_dim_remap=False (single-user panels do not);
    # presumably because ranks map to cubes rather than PEs — confirm with ADR-0057.
    extra = {"_auto_dim_remap": False} if multi_user else {}
    ctx.launch(
        name,
        kernel,
        q, k, v, o,
        s_q, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, n,
        **extra,
    )


# ── Panel bench fns ──────────────────────────────────────────────


def _bench_fn_single_user_prefill(ctx):
    """Ring-K/V prefill over an intracube PE ring."""
    _launch_panel(
        ctx,
        name="single_user_prefill_mesh",
        kernel=attention_mesh_kv_kernel,
        s_q=S_Q_PREFILL,
        multi_user=False,
    )


def _bench_fn_single_user_decode(ctx):
    """Allreduce-mlo decode over an intracube PE ring."""
    _launch_panel(
        ctx,
        name="single_user_decode_mesh",
        kernel=attention_mesh_mlo_kernel,
        s_q=S_Q_DECODE,
        multi_user=False,
    )


def _bench_fn_multi_user_prefill(ctx):
    """Ring-K/V prefill over an intercube multisip ring."""
    _launch_panel(
        ctx,
        name="multi_user_prefill_mesh",
        kernel=attention_mesh_kv_kernel,
        s_q=S_Q_PREFILL,
        multi_user=True,
    )


def _bench_fn_multi_user_decode(ctx):
    """Allreduce-mlo decode over an intercube multisip ring."""
    _launch_panel(
        ctx,
        name="multi_user_decode_mesh",
        kernel=attention_mesh_mlo_kernel,
        s_q=S_Q_DECODE,
        multi_user=True,
    )


# ── Tests ────────────────────────────────────────────────────────


def test_single_user_prefill_through_engine():
    exc, result, engine = _run_panel(_bench_fn_single_user_prefill)
    _assert_ok("single_user_prefill", exc, result, engine)


def test_single_user_decode_through_engine():
    exc, result, engine = _run_panel(_bench_fn_single_user_decode)
    _assert_ok("single_user_decode", exc, result, engine)


def test_multi_user_prefill_through_engine():
    exc, result, engine = _run_panel(_bench_fn_multi_user_prefill)
    _assert_ok("multi_user_prefill", exc, result, engine)


def test_multi_user_decode_through_engine():
    exc, result, engine = _run_panel(_bench_fn_multi_user_decode)
    _assert_ok("multi_user_decode", exc, result, engine)