runtime_api: ctx.launch honors DPPolicy.num_cubes + adds _auto_dim_remap opt-out
Two compounding bugs in ctx.launch's dim-translation path surfaced by multi_user_* panels of milestone-gqa-llama70b (sub-cycle 4c step 2): Bug A: _compute_local_shape divided by self._num_cubes (the topology's cube count, 16 in default topology.yaml) instead of the DPPolicy's effective num_cubes (4 for validation-scale multi_user). The tensor allocator at context.py:471-484 already honored dp.num_cubes; the parallel computation inside launch was out of sync. Fix mirrors the allocator's eff_num_cubes precedence pattern. Bug B: dim_map was keyed by value, so any scalar whose value coincidentally equaled a global tensor dim got rewritten to that dim's local value — e.g. d_head=64 colliding with K's global M=64 in multi_user mode. Legacy bench kernels (va_offset etc.) rely on this remap, so the fix is opt-out: ctx.launch(..., _auto_dim_remap=False) preserves scalars exactly as passed. Default remains True. Tests: 3 new dim-translation tests + 4-panel diag harness covers single_user_* (PASS) and multi_user_* (advances to new SFR/axis layer failure, tracked separately). va_offset + full attention spec suite unchanged. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -609,6 +609,7 @@ class RuntimeContext:
|
||||
kernel_fn: Any,
|
||||
*args: Any,
|
||||
_defer_wait: bool = False,
|
||||
_auto_dim_remap: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> RequestHandle:
|
||||
"""Register and launch a kernel (like a fused torch op).
|
||||
@@ -700,21 +701,36 @@ class RuntimeContext:
|
||||
return t.shape
|
||||
# ADR-0026: DPPolicy no longer crosses SIP boundaries; cube + PE
|
||||
# are the only axes that shrink the local shape.
|
||||
# Mirror the tensor allocator's precedence (context.py L471-484):
|
||||
# DPPolicy.num_cubes overrides the topology's cube count when set.
|
||||
# Without this, multi_user panels at validation scale
|
||||
# (DPPolicy.num_cubes=4) get sharded as if the topology's full
|
||||
# cube count (16) applied — see test_launch_dim_translation.py.
|
||||
if dp.cube != "replicate":
|
||||
eff_num_cubes = (
|
||||
dp.num_cubes if dp.num_cubes is not None else self._num_cubes
|
||||
)
|
||||
if dp.cube == "column_wise":
|
||||
K = K // self._num_cubes
|
||||
K = K // eff_num_cubes
|
||||
elif dp.cube == "row_wise":
|
||||
M = M // self._num_cubes
|
||||
M = M // eff_num_cubes
|
||||
if len(t.shape) < 2:
|
||||
return (K,)
|
||||
return (M, K)
|
||||
|
||||
# Auto-dim-remap (opt-out via _auto_dim_remap=False). Legacy
|
||||
# kernels (e.g. va_offset bench) pass global dims as scalars and
|
||||
# rely on launch to rewrite them to local. Mesh attention kernels
|
||||
# already receive cube-local dims (S_kv_per_rank, d_head, …) and
|
||||
# opt out — the remap would otherwise collide d_head=64 with K's
|
||||
# global M=64 and rewrite d_head. See test_launch_dim_translation.py.
|
||||
dim_map: dict[int, int] = {} # global_dim → local_dim
|
||||
for t in tensor_args:
|
||||
local = _compute_local_shape(t)
|
||||
for g, l in zip(t.shape if len(t.shape) >= 2 else (1, t.shape[0]), local if len(local) >= 2 else (1, local[0])):
|
||||
if g != l:
|
||||
dim_map[g] = l
|
||||
if _auto_dim_remap:
|
||||
for t in tensor_args:
|
||||
local = _compute_local_shape(t)
|
||||
for g, l in zip(t.shape if len(t.shape) >= 2 else (1, t.shape[0]), local if len(local) >= 2 else (1, local[0])):
|
||||
if g != l:
|
||||
dim_map[g] = l
|
||||
|
||||
# Per-SIP kernel launch: each SIP gets TensorArgs with local va_base
|
||||
last_handle = None
|
||||
|
||||
@@ -0,0 +1,196 @@
|
||||
"""End-to-end engine drives for the four GQA Llama-70B panels (sub-cycle 4c step 2).
|
||||
|
||||
Mirrors the existing single_user_decode diag harness across all four panels
|
||||
of the milestone-gqa-llama70b sweep (ADR-0057):
|
||||
|
||||
single_user_prefill ring-K/V kernel, intracube PE ring (8 PEs / 1 cube)
|
||||
single_user_decode allreduce-mlo kernel, intracube PE ring
|
||||
multi_user_prefill ring-K/V kernel, intercube multisip (4 cubes)
|
||||
multi_user_decode allreduce-mlo kernel, intercube multisip
|
||||
|
||||
Each test runs the panel through ``run_bench`` with ``enable_data=True``
|
||||
and asserts ``result.completion.ok``. Failures print the number of op_log
|
||||
records and the exception traceback, mirroring the decode-diag harness format.
|
||||
|
||||
Validation-scale config matches ADR-0057 D4:
|
||||
S_q_prefill=16, S_kv_per_rank=16, h_q=h_kv=1, d_head=64
|
||||
n_ranks_single_user=8, n_ranks_multi_user=4
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from kernbench.benches._attention_mesh_kv import attention_mesh_kv_kernel
|
||||
from kernbench.benches._attention_mesh_mlo import attention_mesh_mlo_kernel
|
||||
from kernbench.ccl.install import load_ccl_config, resolve_algorithm_config
|
||||
from kernbench.ccl.sfr_config import (
|
||||
configure_sfr_intercube_multisip,
|
||||
configure_sfr_intracube_pe_ring,
|
||||
)
|
||||
from kernbench.policy.placement.dp import DPPolicy
|
||||
from kernbench.runtime_api.bench_runner import run_bench
|
||||
from kernbench.runtime_api.types import resolve_device
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
from kernbench.topology.builder import resolve_topology
|
||||
|
||||
# topology.yaml located at ``parents[2]`` of this file's resolved path,
# i.e. two levels above the directory that contains this file.
TOPOLOGY_PATH = Path(__file__).resolve().parents[2] / "topology.yaml"

# Validation-scale configuration (ADR-0057 D4, per the module docstring).
S_Q_PREFILL = 16  # query rows used by the prefill panels
S_Q_DECODE = 1  # query rows used by the decode panels
S_KV_PER_RANK = 16  # K/V rows per rank; global K/V rows = S_KV_PER_RANK * n
H_Q = 1  # query head count (multiplies D_HEAD in the Q/O column dim)
H_KV = 1  # K/V head count (multiplies D_HEAD in the K/V column dim)
D_HEAD = 64  # per-head dimension
N_RANKS_SINGLE_USER = 8  # single_user panels: passed as num_pes with num_cubes=1
N_RANKS_MULTI_USER = 4  # multi_user panels: passed as num_cubes with num_pes=8
DTYPE = "f16"  # dtype string passed to ctx.zeros / ctx.empty
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _engine_factory(t, d):
    """Build a ``GraphEngine`` with ``enable_data=True`` for topology ``t``.

    If ``t`` exposes a ``topology_obj`` attribute, that object is handed to
    the engine; otherwise ``t`` itself is used. ``d`` is not used here.
    """
    topology = getattr(t, "topology_obj", t)
    return GraphEngine(topology, enable_data=True)
|
||||
|
||||
|
||||
def _run_panel(bench_fn):
    """Drive a panel through run_bench; return (exc, result, engine).

    Args:
        bench_fn: Panel bench function, forwarded to ``run_bench``.

    Returns:
        A ``(exc, result, engine)`` tuple:
        * ``exc`` - the exception raised by ``run_bench``, or ``None``.
        * ``result`` - the value returned by ``run_bench``, or ``None`` if
          it raised.
        * ``engine`` - the engine built by the factory, or ``None`` if the
          factory was never invoked.

    Only ``Exception`` subclasses are captured. ``KeyboardInterrupt``,
    ``SystemExit`` and other ``BaseException``-only signals propagate so that
    interrupting a run is not misreported as a panel failure.
    """
    topo = resolve_topology(str(TOPOLOGY_PATH))
    # Mutable cell so the nested factory can hand the engine back to us even
    # when run_bench raises before returning.
    captured: dict = {"engine": None}

    def factory(t, d):
        eng = _engine_factory(t, d)
        captured["engine"] = eng
        return eng

    exc = None
    result = None
    try:
        result = run_bench(
            topology=topo,
            bench_fn=bench_fn,
            device=resolve_device(None),
            engine_factory=factory,
        )
    except Exception as e:  # noqa: BLE001 - any runtime failure is reported by the caller
        exc = e
    return exc, result, captured["engine"]
|
||||
|
||||
|
||||
def _assert_ok(name: str, exc, result, engine) -> None:
    """Fail the calling test unless the panel ran and completed ok.

    Args:
        name: Panel name used in printed diagnostics and assertion messages.
        exc: Exception captured while running the panel, or ``None``.
        result: Object whose ``completion.ok`` is checked when ``exc`` is None.
        engine: Engine whose ``op_log`` length is reported on failure; may be
            ``None``.

    Raises:
        AssertionError: if ``exc`` is set (chained from ``exc``), if
            ``result`` is ``None``, or if ``result.completion.ok`` is falsy.
    """
    if exc is None:
        # Happy path: only the result itself needs validating.
        assert result is not None, f"{name}: no result"
        assert result.completion.ok, f"{name}: completion not ok — {result.completion}"
        return

    # Failure path: print diagnostics, then re-raise as an assertion failure.
    if engine:
        records = getattr(engine, "op_log", []) or []
        oplog_len = len(records)
    else:
        oplog_len = 0

    print(f"\n========== {name} FAIL ==========")
    print(f"op_log records before crash: {oplog_len}")
    print(f"{type(exc).__name__}: {exc}")
    traceback.print_exception(type(exc), exc, exc.__traceback__)
    raise AssertionError(f"{name} failed at runtime: {exc}") from exc
|
||||
|
||||
|
||||
# ── Panel bench fns ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def _bench_fn_single_user_prefill(ctx):
    """single_user_prefill panel: mesh K/V attention kernel on an intracube PE ring.

    Q and O use a fully replicated policy; K and V are sharded ``row_wise``
    across ``N_RANKS_SINGLE_USER`` PEs with ``num_cubes=1``.
    """
    engine, spec = ctx.engine, ctx.spec
    algo_cfg = resolve_algorithm_config(
        load_ccl_config(), name="lrab_hierarchical_allreduce",
    )
    configure_sfr_intracube_pe_ring(engine, spec, algo_cfg)

    n = N_RANKS_SINGLE_USER
    replicated = DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=n)
    kv_sharded = DPPolicy(cube="replicate", pe="row_wise", num_cubes=1, num_pes=n)
    q_shape = (S_Q_PREFILL, H_Q * D_HEAD)
    kv_shape = (S_KV_PER_RANK * n, H_KV * D_HEAD)

    # Tensor creation order is kept as q, k, v, o.
    q = ctx.zeros(q_shape, dtype=DTYPE, dp=replicated, name="q")
    k = ctx.zeros(kv_shape, dtype=DTYPE, dp=kv_sharded, name="k")
    v = ctx.zeros(kv_shape, dtype=DTYPE, dp=kv_sharded, name="v")
    o = ctx.empty(q_shape, dtype=DTYPE, dp=replicated, name="o")

    scalars = (S_Q_PREFILL, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, n)
    ctx.launch(
        "single_user_prefill_mesh", attention_mesh_kv_kernel,
        q, k, v, o,
        *scalars,
    )
|
||||
|
||||
|
||||
def _bench_fn_single_user_decode(ctx):
    """single_user_decode panel: mesh m/l/o attention kernel on an intracube PE ring.

    Q and O use a fully replicated policy; K and V are sharded ``row_wise``
    across ``N_RANKS_SINGLE_USER`` PEs with ``num_cubes=1``. The query has
    ``S_Q_DECODE`` rows.
    """
    engine, spec = ctx.engine, ctx.spec
    algo_cfg = resolve_algorithm_config(
        load_ccl_config(), name="lrab_hierarchical_allreduce",
    )
    configure_sfr_intracube_pe_ring(engine, spec, algo_cfg)

    n = N_RANKS_SINGLE_USER
    replicated = DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=n)
    kv_sharded = DPPolicy(cube="replicate", pe="row_wise", num_cubes=1, num_pes=n)
    q_shape = (S_Q_DECODE, H_Q * D_HEAD)
    kv_shape = (S_KV_PER_RANK * n, H_KV * D_HEAD)

    # Tensor creation order is kept as q, k, v, o.
    q = ctx.zeros(q_shape, dtype=DTYPE, dp=replicated, name="q")
    k = ctx.zeros(kv_shape, dtype=DTYPE, dp=kv_sharded, name="k")
    v = ctx.zeros(kv_shape, dtype=DTYPE, dp=kv_sharded, name="v")
    o = ctx.empty(q_shape, dtype=DTYPE, dp=replicated, name="o")

    scalars = (S_Q_DECODE, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, n)
    ctx.launch(
        "single_user_decode_mesh", attention_mesh_mlo_kernel,
        q, k, v, o,
        *scalars,
    )
|
||||
|
||||
|
||||
def _bench_fn_multi_user_prefill(ctx):
    """multi_user_prefill panel: mesh K/V attention kernel, intercube multisip.

    Q and O use a fully replicated policy; K and V are sharded ``row_wise``
    across ``N_RANKS_MULTI_USER`` cubes (``num_pes=8``).
    """
    engine, spec = ctx.engine, ctx.spec
    algo_cfg = resolve_algorithm_config(
        load_ccl_config(), name="lrab_hierarchical_allreduce",
    )
    configure_sfr_intercube_multisip(engine, spec, algo_cfg)

    n = N_RANKS_MULTI_USER
    replicated = DPPolicy(cube="replicate", pe="replicate", num_cubes=n, num_pes=8)
    kv_sharded = DPPolicy(cube="row_wise", pe="replicate", num_cubes=n, num_pes=8)
    q_shape = (S_Q_PREFILL, H_Q * D_HEAD)
    kv_shape = (S_KV_PER_RANK * n, H_KV * D_HEAD)

    # Tensor creation order is kept as q, k, v, o.
    q = ctx.zeros(q_shape, dtype=DTYPE, dp=replicated, name="q")
    k = ctx.zeros(kv_shape, dtype=DTYPE, dp=kv_sharded, name="k")
    v = ctx.zeros(kv_shape, dtype=DTYPE, dp=kv_sharded, name="v")
    o = ctx.empty(q_shape, dtype=DTYPE, dp=replicated, name="o")

    scalars = (S_Q_PREFILL, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, n)
    # Opt out of launch's value-keyed scalar remap: D_HEAD (64) equals K's
    # global row count S_KV_PER_RANK * n (16 * 4 = 64).
    ctx.launch(
        "multi_user_prefill_mesh", attention_mesh_kv_kernel,
        q, k, v, o,
        *scalars,
        _auto_dim_remap=False,
    )
|
||||
|
||||
|
||||
def _bench_fn_multi_user_decode(ctx):
    """multi_user_decode panel: mesh m/l/o attention kernel, intercube multisip.

    Q and O use a fully replicated policy; K and V are sharded ``row_wise``
    across ``N_RANKS_MULTI_USER`` cubes (``num_pes=8``). The query has
    ``S_Q_DECODE`` rows.
    """
    engine, spec = ctx.engine, ctx.spec
    algo_cfg = resolve_algorithm_config(
        load_ccl_config(), name="lrab_hierarchical_allreduce",
    )
    configure_sfr_intercube_multisip(engine, spec, algo_cfg)

    n = N_RANKS_MULTI_USER
    replicated = DPPolicy(cube="replicate", pe="replicate", num_cubes=n, num_pes=8)
    kv_sharded = DPPolicy(cube="row_wise", pe="replicate", num_cubes=n, num_pes=8)
    q_shape = (S_Q_DECODE, H_Q * D_HEAD)
    kv_shape = (S_KV_PER_RANK * n, H_KV * D_HEAD)

    # Tensor creation order is kept as q, k, v, o.
    q = ctx.zeros(q_shape, dtype=DTYPE, dp=replicated, name="q")
    k = ctx.zeros(kv_shape, dtype=DTYPE, dp=kv_sharded, name="k")
    v = ctx.zeros(kv_shape, dtype=DTYPE, dp=kv_sharded, name="v")
    o = ctx.empty(q_shape, dtype=DTYPE, dp=replicated, name="o")

    scalars = (S_Q_DECODE, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, n)
    # Opt out of launch's value-keyed scalar remap: D_HEAD (64) equals K's
    # global row count S_KV_PER_RANK * n (16 * 4 = 64).
    ctx.launch(
        "multi_user_decode_mesh", attention_mesh_mlo_kernel,
        q, k, v, o,
        *scalars,
        _auto_dim_remap=False,
    )
|
||||
|
||||
|
||||
# ── Tests ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_single_user_prefill_through_engine():
    """The single_user_prefill panel runs through the engine and completes ok."""
    outcome = _run_panel(_bench_fn_single_user_prefill)
    _assert_ok("single_user_prefill", *outcome)
|
||||
|
||||
|
||||
def test_single_user_decode_through_engine():
    """The single_user_decode panel runs through the engine and completes ok."""
    outcome = _run_panel(_bench_fn_single_user_decode)
    _assert_ok("single_user_decode", *outcome)
|
||||
|
||||
|
||||
def test_multi_user_prefill_through_engine():
    """The multi_user_prefill panel runs through the engine and completes ok."""
    outcome = _run_panel(_bench_fn_multi_user_prefill)
    _assert_ok("multi_user_prefill", *outcome)
|
||||
|
||||
|
||||
def test_multi_user_decode_through_engine():
    """The multi_user_decode panel runs through the engine and completes ok."""
    outcome = _run_panel(_bench_fn_multi_user_decode)
    _assert_ok("multi_user_decode", *outcome)
|
||||
@@ -0,0 +1,131 @@
|
||||
"""Phase 1 spec test for ``ctx.launch`` dim-translation bugs surfaced by
|
||||
the multi_user_* panels of milestone-gqa-llama70b (sub-cycle 4c step 2).
|
||||
|
||||
The default ``topology.yaml`` has 4×4 = 16 cubes per SIP, so
|
||||
``RuntimeContext._num_cubes == 16``. Multi-user attention panels run a
|
||||
4-cube ring (validation scale) by passing ``DPPolicy(num_cubes=4)``.
|
||||
|
||||
Two bugs in ``ctx.launch`` make this combination silently produce wrong
|
||||
kernel arguments:
|
||||
|
||||
Bug A — _compute_local_shape ignores DPPolicy.num_cubes
|
||||
``_compute_local_shape`` in ``ctx.launch`` divides by
|
||||
``self._num_cubes`` (the topology's cube count, 16) instead of the
|
||||
DPPolicy's effective ``num_cubes`` (4). So a ``(M=80, K=64)`` tensor
|
||||
sharded ``cube="row_wise"`` with ``DPPolicy(num_cubes=4)`` produces
|
||||
a local M of ``80 // 16 = 5``, not the kernel-expected ``80 // 4 = 20``.
|
||||
Note: tensor allocation already honors ``dp.num_cubes`` correctly at
|
||||
[context.py:471-484](src/kernbench/runtime_api/context.py#L471-L484);
|
||||
the bug is the parallel computation inside ``launch`` is out of sync.
|
||||
|
||||
Bug B — scalar args coincidentally equal to a global tensor dim get auto-remapped
|
||||
The dim_map at [context.py:712-770](src/kernbench/runtime_api/context.py#L712-L770)
|
||||
is keyed by *value*, so any scalar whose value coincides with a
|
||||
global tensor dim gets rewritten to that dim's local value — even
|
||||
when the scalar is unrelated. ``d_head=64`` coincides with the
|
||||
multi_user K's global M = ``S_kv_per_rank * n = 16 * 4 = 64``, so
|
||||
the kernel receives ``d_head = 16`` (the post-Bug-A local) or
|
||||
``d_head = 4`` (the pre-Bug-A local) instead of ``64``.
|
||||
|
||||
Legacy bench kernels rely on auto-remap (e.g. ``test_va_offset.py``
|
||||
passes global N and expects the kernel to see local N). The fix is
|
||||
opt-out, not removal: ``ctx.launch(..., _auto_dim_remap=False)``
|
||||
preserves scalars exactly as passed, default behavior unchanged.
|
||||
|
||||
Both bug tests failed before the fix. Phase 2 fixed them in [src/kernbench/runtime_api/context.py](src/kernbench/runtime_api/context.py).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from kernbench.policy.placement.dp import DPPolicy
|
||||
from kernbench.runtime_api.context import RuntimeContext
|
||||
from kernbench.runtime_api.types import DeviceSelector
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
from kernbench.topology.builder import load_topology
|
||||
|
||||
# topology.yaml one level above the directory that contains this file.
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
|
||||
|
||||
|
||||
def _make_ctx(corr_id: str) -> RuntimeContext:
    """Create a ``RuntimeContext`` targeting ``sip:0`` on a freshly loaded topology.

    Args:
        corr_id: Value passed as the context's ``correlation_id``.

    Returns:
        A context whose engine is a new ``GraphEngine`` built from
        ``TOPOLOGY_PATH`` and whose spec is the loaded graph's ``spec``.
    """
    topo_graph = load_topology(TOPOLOGY_PATH)
    return RuntimeContext(
        engine=GraphEngine(topo_graph),
        target_device=DeviceSelector("sip:0"),
        correlation_id=corr_id,
        spec=topo_graph.spec,
    )
|
||||
|
||||
|
||||
def test_topology_num_cubes_is_16_baseline_assumption():
    """Sanity check: the default topology yields 16 cubes per SIP.

    The other tests in this module assume that count, so if this one fails,
    recheck topology.yaml before interpreting their failures.
    ``_ensure_allocators`` is called explicitly before ``_num_cubes`` is read.
    """
    context = _make_ctx("dim-baseline")
    context._ensure_allocators()
    ctx = context  # name referenced by the failure message below
    assert ctx._num_cubes == 16, (
        f"expected default topology.yaml to give 16 cubes per SIP, "
        f"got {ctx._num_cubes}"
    )
|
||||
|
||||
|
||||
def test_ctx_launch_local_shape_honors_dppolicy_num_cubes():
    """Bug A regression: launch's dim_map must divide by ``DPPolicy.num_cubes``.

    An ``(80, 64)`` tensor is sharded ``cube="row_wise"`` under
    ``DPPolicy(num_cubes=4)`` and its global M (80) is also passed as a
    kernel scalar. The kernel records the scalar it receives:

    * 20 (``80 // 4``) - the policy's cube count was used (expected);
    * 5 (``80 // 16``) - the topology's cube count was used (the bug).
    """
    captured: dict[str, int] = {}

    def _kernel(t, m_scalar, *, tl):  # noqa: ARG001
        # Record the scalar exactly as launch delivered it.
        captured["m_scalar"] = int(m_scalar)

    m_global = 80
    policy = DPPolicy(cube="row_wise", pe="replicate", num_cubes=4, num_pes=8)

    ctx = _make_ctx("dim-bugA")
    tensor = ctx.zeros((m_global, 64), dtype="f16", dp=policy, name="t80x64")
    ctx.launch("bugA_capture", _kernel, tensor, m_global)
    ctx.wait_all()

    assert "m_scalar" in captured, "kernel was not invoked"
    assert captured["m_scalar"] == 20, (
        f"expected dim_map to divide 80 by dp.num_cubes=4 → 20; "
        f"got {captured['m_scalar']} (likely divided by topology cubes=16)"
    )
|
||||
|
||||
|
||||
def test_ctx_launch_scalar_passed_through_when_auto_remap_disabled():
    """Bug B regression: ``_auto_dim_remap=False`` leaves scalars untouched.

    A ``(64, 64)`` tensor is sharded ``cube="row_wise"``, and the scalar 64 is
    passed as ``d_head``. It is semantically unrelated to the tensor's
    global M but has the same value. With ``_auto_dim_remap=False`` the
    kernel must receive 64 unchanged.
    """
    captured: dict[str, int] = {}

    def _kernel(t, d_head, *, tl):  # noqa: ARG001
        # Record the scalar exactly as launch delivered it.
        captured["d_head"] = int(d_head)

    d_head_value = 64
    policy = DPPolicy(cube="row_wise", pe="replicate", num_cubes=4, num_pes=8)

    ctx = _make_ctx("dim-bugB")
    tensor = ctx.zeros((64, 64), dtype="f16", dp=policy, name="t64x64")
    ctx.launch("bugB_capture", _kernel, tensor, d_head_value, _auto_dim_remap=False)
    ctx.wait_all()

    assert captured.get("d_head") == 64, (
        f"expected d_head scalar to pass through unchanged when "
        f"_auto_dim_remap=False; got {captured.get('d_head')!r}"
    )
|
||||
Reference in New Issue
Block a user