d9e767d048
Two compounding bugs in ctx.launch's dim-translation path, surfaced by the multi_user_* panels of milestone-gqa-llama70b (sub-cycle 4c step 2): Bug A: _compute_local_shape divided by self._num_cubes (the topology's cube count, 16 in the default topology.yaml) instead of the DPPolicy's effective num_cubes (4 for validation-scale multi_user). The tensor allocator at context.py:471-484 already honored dp.num_cubes, so the parallel computation inside launch was out of sync; the fix mirrors the allocator's eff_num_cubes precedence pattern. Bug B: dim_map was keyed by value, so any scalar whose value coincidentally equaled a global tensor dim was rewritten to that dim's local value — e.g. d_head=64 collided with K's global M=64 in multi_user mode. Legacy bench kernels (va_offset etc.) rely on this remap, so it stays on by default (_auto_dim_remap=True) and callers opt out with ctx.launch(..., _auto_dim_remap=False), which preserves scalars exactly as passed. Tests: 3 new dim-translation tests, plus a 4-panel diag harness covering single_user_* (PASS) and multi_user_* (now advances to a new SFR/axis-layer failure, tracked separately). va_offset and the full attention spec suite are unchanged. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
197 lines
7.7 KiB
Python
197 lines
7.7 KiB
Python
"""End-to-end engine drives for the four GQA Llama-70B panels (sub-cycle 4c step 2).
|
|
|
|
Mirrors the existing single_user_decode diag harness across all four panels
|
|
of the milestone-gqa-llama70b sweep (ADR-0057):
|
|
|
|
single_user_prefill ring-K/V kernel, intracube PE ring (8 PEs / 1 cube)
|
|
single_user_decode allreduce-mlo kernel, intracube PE ring
|
|
multi_user_prefill ring-K/V kernel, intercube multisip (4 cubes)
|
|
multi_user_decode allreduce-mlo kernel, intercube multisip
|
|
|
|
Each test runs the panel through ``run_bench`` with ``enable_data=True``
|
|
and asserts ``result.completion.ok``. Failures dump the engine's op_log
|
|
tail and the exception, mirroring the decode-diag harness format.
|
|
|
|
Validation-scale config matches ADR-0057 D4:
|
|
S_q_prefill=16, S_kv_per_rank=16, h_q=h_kv=1, d_head=64
|
|
n_ranks_single_user=8, n_ranks_multi_user=4
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import traceback
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
from kernbench.benches._attention_mesh_kv import attention_mesh_kv_kernel
|
|
from kernbench.benches._attention_mesh_mlo import attention_mesh_mlo_kernel
|
|
from kernbench.ccl.install import load_ccl_config, resolve_algorithm_config
|
|
from kernbench.ccl.sfr_config import (
|
|
configure_sfr_intercube_multisip,
|
|
configure_sfr_intracube_pe_ring,
|
|
)
|
|
from kernbench.policy.placement.dp import DPPolicy
|
|
from kernbench.runtime_api.bench_runner import run_bench
|
|
from kernbench.runtime_api.types import resolve_device
|
|
from kernbench.sim_engine.engine import GraphEngine
|
|
from kernbench.topology.builder import resolve_topology
|
|
|
|
# topology.yaml sits two directories above the directory holding this file.
TOPOLOGY_PATH = Path(__file__).resolve().parents[2] / "topology.yaml"

# Validation-scale attention shapes (see module docstring, ADR-0057 D4).
S_Q_PREFILL: int = 16    # query rows for the prefill panels
S_Q_DECODE: int = 1      # query rows for the decode panels
S_KV_PER_RANK: int = 16  # K/V rows per rank; global K/V rows = this * n_ranks
H_Q: int = 1
H_KV: int = 1
D_HEAD: int = 64

# Rank counts: used as num_pes (one cube) by the single_user panels and as
# num_cubes by the multi_user panels.
N_RANKS_SINGLE_USER: int = 8
N_RANKS_MULTI_USER: int = 4

DTYPE = "f16"
|
|
|
|
|
|
# ── Helpers ──────────────────────────────────────────────────────
|
|
|
|
|
|
def _engine_factory(t, d):
    """Build a data-enabled ``GraphEngine`` for ``run_bench``.

    ``t`` is either an object exposing ``topology_obj`` or the topology
    itself.  ``d`` (presumably the resolved device) is accepted only to
    satisfy the engine-factory call signature and is not used.
    """
    topology = getattr(t, "topology_obj", t)
    return GraphEngine(topology, enable_data=True)
|
|
|
|
|
|
def _run_panel(bench_fn):
    """Drive a panel through ``run_bench``; return ``(exc, result, engine)``.

    ``exc`` is the ``Exception`` that escaped ``run_bench`` (``None`` when it
    returned normally), ``result`` is its return value (``None`` if it
    raised), and ``engine`` is the last engine built by the factory, or
    ``None`` if ``run_bench`` failed before constructing one.

    Only ``Exception`` subclasses are captured.  ``BaseException``-only
    signals (``KeyboardInterrupt``, ``SystemExit``, pytest skip/fail
    outcomes) propagate unchanged so an interrupted or skipped run is not
    misreported as a panel failure.
    """
    topo = resolve_topology(str(TOPOLOGY_PATH))
    captured: dict = {"engine": None}

    def factory(t, d):
        # Keep a handle on the engine so a failure can report its op_log.
        eng = _engine_factory(t, d)
        captured["engine"] = eng
        return eng

    exc = None
    result = None
    try:
        result = run_bench(
            topology=topo, bench_fn=bench_fn,
            device=resolve_device(None), engine_factory=factory,
        )
    except Exception as e:  # noqa: BLE001 — diag harness reports any failure
        exc = e
    return exc, result, captured["engine"]
|
|
|
|
|
|
def _assert_ok(name: str, exc, result, engine, *, tail: int = 10) -> None:
    """Fail the test unless the panel ran without error and completed.

    If ``exc`` is set, print diagnostics and raise ``AssertionError``
    chained from ``exc``.  The diagnostics are the op_log record count, the
    last ``tail`` op_log records, and the exception with its traceback.
    Otherwise assert that ``result`` exists and ``result.completion.ok`` is
    truthy.

    Args:
        name: Panel name used in messages.
        exc: Exception captured while running the panel, or ``None``.
        result: Value returned by ``run_bench`` (``None`` if it raised).
        engine: Engine built for the run, or ``None``; its ``op_log``
            attribute is read if present.
        tail: How many trailing op_log records to print on failure
            (``0`` prints none).
    """
    if exc is not None:
        # ``is not None`` rather than truthiness: an engine object that
        # defines __len__/__bool__ must not hide its op_log.
        op_log = getattr(engine, "op_log", None) if engine is not None else None
        records = list(op_log or [])
        print(f"\n========== {name} FAIL ==========")
        print(f"op_log records before crash: {len(records)}")
        # Guard tail > 0: records[-0:] would dump the whole log.
        if tail > 0 and records:
            shown = records[-tail:]
            print(f"last {len(shown)} op_log records:")
            for rec in shown:
                print(f"  {rec!r}")
        print(f"{type(exc).__name__}: {exc}")
        traceback.print_exception(type(exc), exc, exc.__traceback__)
        raise AssertionError(
            f"{name} failed at runtime: {exc}"
        ) from exc
    assert result is not None, f"{name}: no result"
    assert result.completion.ok, f"{name}: completion not ok — {result.completion}"
|
|
|
|
|
|
# ── Panel bench fns ──────────────────────────────────────────────
|
|
|
|
|
|
def _bench_fn_single_user_prefill(ctx):
    """single_user_prefill panel: ring-K/V kernel on an intracube PE ring.

    Q/O use a fully replicated policy.  K/V are allocated with
    ``S_KV_PER_RANK * n`` rows under ``pe="row_wise"`` across the ``n`` PEs
    of one cube (presumably ``S_KV_PER_RANK`` rows per PE).
    """
    configure_sfr_intracube_pe_ring(
        ctx.engine, ctx.spec,
        resolve_algorithm_config(load_ccl_config(), name="lrab_hierarchical_allreduce"),
    )
    n = N_RANKS_SINGLE_USER
    dp_full = DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=n)
    dp_kv = DPPolicy(cube="replicate", pe="row_wise", num_cubes=1, num_pes=n)
    q = ctx.zeros((S_Q_PREFILL, H_Q * D_HEAD), dtype=DTYPE, dp=dp_full, name="q")
    k = ctx.zeros((S_KV_PER_RANK * n, H_KV * D_HEAD), dtype=DTYPE, dp=dp_kv, name="k")
    v = ctx.zeros((S_KV_PER_RANK * n, H_KV * D_HEAD), dtype=DTYPE, dp=dp_kv, name="v")
    o = ctx.empty((S_Q_PREFILL, H_Q * D_HEAD), dtype=DTYPE, dp=dp_full, name="o")
    ctx.launch(
        "single_user_prefill_mesh", attention_mesh_kv_kernel,
        q, k, v, o,
        S_Q_PREFILL, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, n,
        # The scalars are already per-rank values.  Disable the value-keyed
        # dim remap (as the multi_user panels do) so no scalar is rewritten
        # merely because it equals a global tensor dim, e.g. if D_HEAD were
        # ever set equal to K's global row count S_KV_PER_RANK * n.
        _auto_dim_remap=False,
    )
|
|
|
|
|
|
def _bench_fn_single_user_decode(ctx):
    """single_user_decode panel: allreduce-mlo kernel on an intracube PE ring.

    Q/O (``S_Q_DECODE`` rows) use a fully replicated policy.  K/V are
    allocated with ``S_KV_PER_RANK * n`` rows under ``pe="row_wise"`` across
    the ``n`` PEs of one cube (presumably ``S_KV_PER_RANK`` rows per PE).
    """
    configure_sfr_intracube_pe_ring(
        ctx.engine, ctx.spec,
        resolve_algorithm_config(load_ccl_config(), name="lrab_hierarchical_allreduce"),
    )
    n = N_RANKS_SINGLE_USER
    dp_full = DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=n)
    dp_kv = DPPolicy(cube="replicate", pe="row_wise", num_cubes=1, num_pes=n)
    q = ctx.zeros((S_Q_DECODE, H_Q * D_HEAD), dtype=DTYPE, dp=dp_full, name="q")
    k = ctx.zeros((S_KV_PER_RANK * n, H_KV * D_HEAD), dtype=DTYPE, dp=dp_kv, name="k")
    v = ctx.zeros((S_KV_PER_RANK * n, H_KV * D_HEAD), dtype=DTYPE, dp=dp_kv, name="v")
    o = ctx.empty((S_Q_DECODE, H_Q * D_HEAD), dtype=DTYPE, dp=dp_full, name="o")
    ctx.launch(
        "single_user_decode_mesh", attention_mesh_mlo_kernel,
        q, k, v, o,
        S_Q_DECODE, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, n,
        # The scalars are already per-rank values.  Disable the value-keyed
        # dim remap (as the multi_user panels do) so no scalar is rewritten
        # merely because it equals a global tensor dim, e.g. if D_HEAD were
        # ever set equal to K's global row count S_KV_PER_RANK * n.
        _auto_dim_remap=False,
    )
|
|
|
|
|
|
def _bench_fn_multi_user_prefill(ctx):
    """multi_user_prefill panel: ring-K/V kernel over intercube multisip.

    Q/O use a fully replicated policy.  K/V are allocated with
    ``S_KV_PER_RANK * cubes`` rows under ``cube="row_wise"`` across
    ``N_RANKS_MULTI_USER`` cubes, replicated over 8 PEs per cube.  The scalars
    are passed as-is (``_auto_dim_remap=False``) because D_HEAD would
    otherwise collide with K's global row count.
    """
    engine, spec = ctx.engine, ctx.spec
    algo_cfg = resolve_algorithm_config(load_ccl_config(), name="lrab_hierarchical_allreduce")
    configure_sfr_intercube_multisip(engine, spec, algo_cfg)

    cubes = N_RANKS_MULTI_USER
    pes_per_cube = 8
    replicated = DPPolicy(cube="replicate", pe="replicate", num_cubes=cubes, num_pes=pes_per_cube)
    kv_split = DPPolicy(cube="row_wise", pe="replicate", num_cubes=cubes, num_pes=pes_per_cube)

    qo_shape = (S_Q_PREFILL, H_Q * D_HEAD)
    kv_shape = (S_KV_PER_RANK * cubes, H_KV * D_HEAD)
    q = ctx.zeros(qo_shape, dtype=DTYPE, dp=replicated, name="q")
    k = ctx.zeros(kv_shape, dtype=DTYPE, dp=kv_split, name="k")
    v = ctx.zeros(kv_shape, dtype=DTYPE, dp=kv_split, name="v")
    o = ctx.empty(qo_shape, dtype=DTYPE, dp=replicated, name="o")

    scalars = (S_Q_PREFILL, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, cubes)
    ctx.launch(
        "multi_user_prefill_mesh", attention_mesh_kv_kernel,
        q, k, v, o, *scalars,
        _auto_dim_remap=False,
    )
|
|
|
|
|
|
def _bench_fn_multi_user_decode(ctx):
    """multi_user_decode panel: allreduce-mlo kernel over intercube multisip.

    Q/O (``S_Q_DECODE`` rows) use a fully replicated policy.  K/V are
    allocated with ``S_KV_PER_RANK * cubes`` rows under ``cube="row_wise"``
    across ``N_RANKS_MULTI_USER`` cubes, replicated over 8 PEs per cube.  The
    scalars are passed as-is (``_auto_dim_remap=False``) because D_HEAD
    would otherwise collide with K's global row count.
    """
    engine, spec = ctx.engine, ctx.spec
    algo_cfg = resolve_algorithm_config(load_ccl_config(), name="lrab_hierarchical_allreduce")
    configure_sfr_intercube_multisip(engine, spec, algo_cfg)

    cubes = N_RANKS_MULTI_USER
    pes_per_cube = 8
    replicated = DPPolicy(cube="replicate", pe="replicate", num_cubes=cubes, num_pes=pes_per_cube)
    kv_split = DPPolicy(cube="row_wise", pe="replicate", num_cubes=cubes, num_pes=pes_per_cube)

    qo_shape = (S_Q_DECODE, H_Q * D_HEAD)
    kv_shape = (S_KV_PER_RANK * cubes, H_KV * D_HEAD)
    q = ctx.zeros(qo_shape, dtype=DTYPE, dp=replicated, name="q")
    k = ctx.zeros(kv_shape, dtype=DTYPE, dp=kv_split, name="k")
    v = ctx.zeros(kv_shape, dtype=DTYPE, dp=kv_split, name="v")
    o = ctx.empty(qo_shape, dtype=DTYPE, dp=replicated, name="o")

    scalars = (S_Q_DECODE, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, cubes)
    ctx.launch(
        "multi_user_decode_mesh", attention_mesh_mlo_kernel,
        q, k, v, o, *scalars,
        _auto_dim_remap=False,
    )
|
|
|
|
|
|
# ── Tests ────────────────────────────────────────────────────────
|
|
|
|
|
|
def test_single_user_prefill_through_engine():
    """The single_user_prefill panel runs to completion on the data engine."""
    outcome = _run_panel(_bench_fn_single_user_prefill)
    _assert_ok("single_user_prefill", *outcome)
|
|
|
|
|
|
def test_single_user_decode_through_engine():
    """The single_user_decode panel runs to completion on the data engine."""
    outcome = _run_panel(_bench_fn_single_user_decode)
    _assert_ok("single_user_decode", *outcome)
|
|
|
|
|
|
def test_multi_user_prefill_through_engine():
    """The multi_user_prefill panel runs to completion on the data engine."""
    outcome = _run_panel(_bench_fn_multi_user_prefill)
    _assert_ok("multi_user_prefill", *outcome)
|
|
|
|
|
|
def test_multi_user_decode_through_engine():
    """The multi_user_decode panel runs to completion on the data engine."""
    outcome = _run_panel(_bench_fn_multi_user_decode)
    _assert_ok("multi_user_decode", *outcome)
|