diff --git a/src/kernbench/runtime_api/context.py b/src/kernbench/runtime_api/context.py
index 6e94388..97d2855 100644
--- a/src/kernbench/runtime_api/context.py
+++ b/src/kernbench/runtime_api/context.py
@@ -609,6 +609,7 @@ class RuntimeContext:
         kernel_fn: Any,
         *args: Any,
         _defer_wait: bool = False,
+        _auto_dim_remap: bool = True,
         **kwargs: Any,
     ) -> RequestHandle:
         """Register and launch a kernel (like a fused torch op).
@@ -700,21 +701,36 @@ class RuntimeContext:
                 return t.shape
             # ADR-0026: DPPolicy no longer crosses SIP boundaries; cube + PE
             # are the only axes that shrink the local shape.
+            # Mirror the tensor allocator's precedence (context.py L471-484):
+            # DPPolicy.num_cubes overrides the topology's cube count when set.
+            # Without this, multi_user panels at validation scale
+            # (DPPolicy.num_cubes=4) get sharded as if the topology's full
+            # cube count (16) applied — see test_launch_dim_translation.py.
             if dp.cube != "replicate":
+                eff_num_cubes = (
+                    dp.num_cubes if dp.num_cubes is not None else self._num_cubes
+                )
                 if dp.cube == "column_wise":
-                    K = K // self._num_cubes
+                    K = K // eff_num_cubes
                 elif dp.cube == "row_wise":
-                    M = M // self._num_cubes
+                    M = M // eff_num_cubes
             if len(t.shape) < 2:
                 return (K,)
             return (M, K)
 
+        # Auto-dim-remap (opt-out via _auto_dim_remap=False). Legacy
+        # kernels (e.g. va_offset bench) pass global dims as scalars and
+        # rely on launch to rewrite them to local. Mesh attention kernels
+        # already receive cube-local dims (S_kv_per_rank, d_head, …) and
+        # opt out — the remap would otherwise collide d_head=64 with K's
+        # global M=64 and rewrite d_head. See test_launch_dim_translation.py.
         dim_map: dict[int, int] = {}  # global_dim → local_dim
-        for t in tensor_args:
-            local = _compute_local_shape(t)
-            for g, l in zip(t.shape if len(t.shape) >= 2 else (1, t.shape[0]), local if len(local) >= 2 else (1, local[0])):
-                if g != l:
-                    dim_map[g] = l
+        if _auto_dim_remap:
+            for t in tensor_args:
+                local = _compute_local_shape(t)
+                for g, l in zip(t.shape if len(t.shape) >= 2 else (1, t.shape[0]), local if len(local) >= 2 else (1, local[0])):
+                    if g != l:
+                        dim_map[g] = l
 
         # Per-SIP kernel launch: each SIP gets TensorArgs with local va_base
         last_handle = None
diff --git a/tests/attention/test_attention_mesh_panels_diag.py b/tests/attention/test_attention_mesh_panels_diag.py
new file mode 100644
index 0000000..df15132
--- /dev/null
+++ b/tests/attention/test_attention_mesh_panels_diag.py
@@ -0,0 +1,202 @@
+"""End-to-end engine drives for the four GQA Llama-70B panels (sub-cycle 4c step 2).
+
+Mirrors the existing single_user_decode diag harness across all four panels
+of the milestone-gqa-llama70b sweep (ADR-0057):
+
+  single_user_prefill   ring-K/V kernel, intracube PE ring (8 PEs / 1 cube)
+  single_user_decode    allreduce-mlo kernel, intracube PE ring
+  multi_user_prefill    ring-K/V kernel, intercube multisip (4 cubes)
+  multi_user_decode     allreduce-mlo kernel, intercube multisip
+
+Each test runs the panel through ``run_bench`` with ``enable_data=True``
+and asserts ``result.completion.ok``. Failures dump the engine's op_log
+record count and the exception, mirroring the decode-diag harness format.
+
+Validation-scale config matches ADR-0057 D4:
+  S_q_prefill=16, S_kv_per_rank=16, h_q=h_kv=1, d_head=64
+  n_ranks_single_user=8, n_ranks_multi_user=4
+"""
+from __future__ import annotations
+
+import traceback
+from pathlib import Path
+
+import pytest
+
+from kernbench.benches._attention_mesh_kv import attention_mesh_kv_kernel
+from kernbench.benches._attention_mesh_mlo import attention_mesh_mlo_kernel
+from kernbench.ccl.install import load_ccl_config, resolve_algorithm_config
+from kernbench.ccl.sfr_config import (
+    configure_sfr_intercube_multisip,
+    configure_sfr_intracube_pe_ring,
+)
+from kernbench.policy.placement.dp import DPPolicy
+from kernbench.runtime_api.bench_runner import run_bench
+from kernbench.runtime_api.types import resolve_device
+from kernbench.sim_engine.engine import GraphEngine
+from kernbench.topology.builder import resolve_topology
+
+TOPOLOGY_PATH = Path(__file__).resolve().parents[2] / "topology.yaml"
+
+S_Q_PREFILL = 16
+S_Q_DECODE = 1
+S_KV_PER_RANK = 16
+H_Q = 1
+H_KV = 1
+D_HEAD = 64
+N_RANKS_SINGLE_USER = 8
+N_RANKS_MULTI_USER = 4
+DTYPE = "f16"
+
+
+# ── Helpers ──────────────────────────────────────────────────────
+
+
+def _engine_factory(t, d):
+    """Build a data-enabled GraphEngine from ``t`` (or ``t.topology_obj``); ``d`` is unused."""
+    return GraphEngine(getattr(t, "topology_obj", t), enable_data=True)
+
+
+def _run_panel(bench_fn):
+    """Drive a panel through run_bench; return (exc, result, engine)."""
+    topo = resolve_topology(str(TOPOLOGY_PATH))
+    captured: dict = {"engine": None}
+
+    def factory(t, d):
+        eng = _engine_factory(t, d)
+        captured["engine"] = eng
+        return eng
+
+    exc = None
+    result = None
+    try:
+        result = run_bench(
+            topology=topo, bench_fn=bench_fn,
+            device=resolve_device(None), engine_factory=factory,
+        )
+    # Exception, not BaseException: KeyboardInterrupt / SystemExit must
+    # propagate rather than be re-raised as an AssertionError by _assert_ok.
+    except Exception as e:  # noqa: BLE001
+        exc = e
+    return exc, result, captured["engine"]
+
+
+def _assert_ok(name: str, exc, result, engine) -> None:
+    """Print crash diagnostics and fail if the panel raised or did not complete ok."""
+    if exc is not None:
+        oplog_len = len(getattr(engine, "op_log", None) or [])
+        print(f"\n========== {name} FAIL ==========")
+        print(f"op_log records before crash: {oplog_len}")
+        print(f"{type(exc).__name__}: {exc}")
+        traceback.print_exception(type(exc), exc, exc.__traceback__)
+        raise AssertionError(
+            f"{name} failed at runtime: {exc}"
+        ) from exc
+    assert result is not None, f"{name}: no result"
+    assert result.completion.ok, f"{name}: completion not ok — {result.completion}"
+
+
+# ── Panel bench fns ──────────────────────────────────────────────
+
+
+def _bench_fn_single_user_prefill(ctx):
+    configure_sfr_intracube_pe_ring(
+        ctx.engine, ctx.spec,
+        resolve_algorithm_config(load_ccl_config(), name="lrab_hierarchical_allreduce"),
+    )
+    n = N_RANKS_SINGLE_USER
+    dp_full = DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=n)
+    dp_kv = DPPolicy(cube="replicate", pe="row_wise", num_cubes=1, num_pes=n)
+    q = ctx.zeros((S_Q_PREFILL, H_Q * D_HEAD), dtype=DTYPE, dp=dp_full, name="q")
+    k = ctx.zeros((S_KV_PER_RANK * n, H_KV * D_HEAD), dtype=DTYPE, dp=dp_kv, name="k")
+    v = ctx.zeros((S_KV_PER_RANK * n, H_KV * D_HEAD), dtype=DTYPE, dp=dp_kv, name="v")
+    o = ctx.empty((S_Q_PREFILL, H_Q * D_HEAD), dtype=DTYPE, dp=dp_full, name="o")
+    ctx.launch(
+        "single_user_prefill_mesh", attention_mesh_kv_kernel,
+        q, k, v, o,
+        S_Q_PREFILL, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, n,
+        _auto_dim_remap=False,  # scalars are already rank-local
+    )
+
+
+def _bench_fn_single_user_decode(ctx):
+    configure_sfr_intracube_pe_ring(
+        ctx.engine, ctx.spec,
+        resolve_algorithm_config(load_ccl_config(), name="lrab_hierarchical_allreduce"),
+    )
+    n = N_RANKS_SINGLE_USER
+    dp_full = DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=n)
+    dp_kv = DPPolicy(cube="replicate", pe="row_wise", num_cubes=1, num_pes=n)
+    q = ctx.zeros((S_Q_DECODE, H_Q * D_HEAD), dtype=DTYPE, dp=dp_full, name="q")
+    k = ctx.zeros((S_KV_PER_RANK * n, H_KV * D_HEAD), dtype=DTYPE, dp=dp_kv, name="k")
+    v = ctx.zeros((S_KV_PER_RANK * n, H_KV * D_HEAD), dtype=DTYPE, dp=dp_kv, name="v")
+    o = ctx.empty((S_Q_DECODE, H_Q * D_HEAD), dtype=DTYPE, dp=dp_full, name="o")
+    ctx.launch(
+        "single_user_decode_mesh", attention_mesh_mlo_kernel,
+        q, k, v, o,
+        S_Q_DECODE, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, n,
+        _auto_dim_remap=False,  # scalars are already rank-local
+    )
+
+
+def _bench_fn_multi_user_prefill(ctx):
+    configure_sfr_intercube_multisip(
+        ctx.engine, ctx.spec,
+        resolve_algorithm_config(load_ccl_config(), name="lrab_hierarchical_allreduce"),
+    )
+    n = N_RANKS_MULTI_USER
+    dp_full = DPPolicy(cube="replicate", pe="replicate", num_cubes=n, num_pes=8)
+    dp_kv = DPPolicy(cube="row_wise", pe="replicate", num_cubes=n, num_pes=8)
+    q = ctx.zeros((S_Q_PREFILL, H_Q * D_HEAD), dtype=DTYPE, dp=dp_full, name="q")
+    k = ctx.zeros((S_KV_PER_RANK * n, H_KV * D_HEAD), dtype=DTYPE, dp=dp_kv, name="k")
+    v = ctx.zeros((S_KV_PER_RANK * n, H_KV * D_HEAD), dtype=DTYPE, dp=dp_kv, name="v")
+    o = ctx.empty((S_Q_PREFILL, H_Q * D_HEAD), dtype=DTYPE, dp=dp_full, name="o")
+    ctx.launch(
+        "multi_user_prefill_mesh", attention_mesh_kv_kernel,
+        q, k, v, o,
+        S_Q_PREFILL, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, n,
+        _auto_dim_remap=False,
+    )
+
+
+def _bench_fn_multi_user_decode(ctx):
+    configure_sfr_intercube_multisip(
+        ctx.engine, ctx.spec,
+        resolve_algorithm_config(load_ccl_config(), name="lrab_hierarchical_allreduce"),
+    )
+    n = N_RANKS_MULTI_USER
+    dp_full = DPPolicy(cube="replicate", pe="replicate", num_cubes=n, num_pes=8)
+    dp_kv = DPPolicy(cube="row_wise", pe="replicate", num_cubes=n, num_pes=8)
+    q = ctx.zeros((S_Q_DECODE, H_Q * D_HEAD), dtype=DTYPE, dp=dp_full, name="q")
+    k = ctx.zeros((S_KV_PER_RANK * n, H_KV * D_HEAD), dtype=DTYPE, dp=dp_kv, name="k")
+    v = ctx.zeros((S_KV_PER_RANK * n, H_KV * D_HEAD), dtype=DTYPE, dp=dp_kv, name="v")
+    o = ctx.empty((S_Q_DECODE, H_Q * D_HEAD), dtype=DTYPE, dp=dp_full, name="o")
+    ctx.launch(
+        "multi_user_decode_mesh", attention_mesh_mlo_kernel,
+        q, k, v, o,
+        S_Q_DECODE, S_KV_PER_RANK, H_Q, H_KV, D_HEAD, n,
+        _auto_dim_remap=False,
+    )
+
+
+# ── Tests ────────────────────────────────────────────────────────
+
+
+def test_single_user_prefill_through_engine():
+    exc, result, engine = _run_panel(_bench_fn_single_user_prefill)
+    _assert_ok("single_user_prefill", exc, result, engine)
+
+
+def test_single_user_decode_through_engine():
+    exc, result, engine = _run_panel(_bench_fn_single_user_decode)
+    _assert_ok("single_user_decode", exc, result, engine)
+
+
+def test_multi_user_prefill_through_engine():
+    exc, result, engine = _run_panel(_bench_fn_multi_user_prefill)
+    _assert_ok("multi_user_prefill", exc, result, engine)
+
+
+def test_multi_user_decode_through_engine():
+    exc, result, engine = _run_panel(_bench_fn_multi_user_decode)
+    _assert_ok("multi_user_decode", exc, result, engine)
diff --git a/tests/test_launch_dim_translation.py b/tests/test_launch_dim_translation.py
new file mode 100644
index 0000000..488e7c8
--- /dev/null
+++ b/tests/test_launch_dim_translation.py
@@ -0,0 +1,131 @@
+"""Phase 1 spec test for ``ctx.launch`` dim-translation bugs surfaced by
+the multi_user_* panels of milestone-gqa-llama70b (sub-cycle 4c step 2).
+
+The default ``topology.yaml`` has 4×4 = 16 cubes per SIP, so
+``RuntimeContext._num_cubes == 16``. Multi-user attention panels run a
+4-cube ring (validation scale) by passing ``DPPolicy(num_cubes=4)``.
+
+Two bugs in ``ctx.launch`` make this combination silently produce wrong
+kernel arguments:
+
+Bug A — _compute_local_shape ignores DPPolicy.num_cubes
+  ``_compute_local_shape`` in ``ctx.launch`` divides by
+  ``self._num_cubes`` (the topology's cube count, 16) instead of the
+  DPPolicy's effective ``num_cubes`` (4). So a ``(M=80, K=64)`` tensor
+  sharded ``cube="row_wise"`` with ``DPPolicy(num_cubes=4)`` produces
+  a local M of ``80 // 16 = 5``, not the kernel-expected ``80 // 4 = 20``.
+  Note: tensor allocation already honors ``dp.num_cubes`` correctly at
+  [context.py:471-484](src/kernbench/runtime_api/context.py#L471-L484);
+  the bug is that the parallel computation inside ``launch`` is out of sync.
+
+Bug B — scalar args coincidentally equal to a global tensor dim get auto-remapped
+  The dim_map at [context.py:712-770](src/kernbench/runtime_api/context.py#L712-L770)
+  is keyed by *value*, so any scalar whose value coincides with a
+  global tensor dim gets rewritten to that dim's local value — even
+  when the scalar is unrelated. ``d_head=64`` coincides with the
+  multi_user K's global M = ``S_kv_per_rank * n = 16 * 4 = 64``, so
+  the kernel receives ``d_head = 16`` (the post-Bug-A local) or
+  ``d_head = 4`` (the pre-Bug-A local) instead of ``64``.
+
+  Legacy bench kernels rely on auto-remap (e.g. ``test_va_offset.py``
+  passes global N and expects the kernel to see local N). The fix is
+  opt-out, not removal: ``ctx.launch(..., _auto_dim_remap=False)``
+  preserves scalars exactly as passed, default behavior unchanged.
+
+Both bug tests failed before the Phase 2 fix in [src/kernbench/runtime_api/context.py](src/kernbench/runtime_api/context.py).
+"""
+from __future__ import annotations
+
+from pathlib import Path
+
+from kernbench.policy.placement.dp import DPPolicy
+from kernbench.runtime_api.context import RuntimeContext
+from kernbench.runtime_api.types import DeviceSelector
+from kernbench.sim_engine.engine import GraphEngine
+from kernbench.topology.builder import load_topology
+
+TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
+
+
+def _make_ctx(corr_id: str) -> RuntimeContext:
+    graph = load_topology(TOPOLOGY_PATH)
+    engine = GraphEngine(graph)
+    return RuntimeContext(
+        engine=engine, target_device=DeviceSelector("sip:0"),
+        correlation_id=corr_id, spec=graph.spec,
+    )
+
+
+def test_topology_num_cubes_is_16_baseline_assumption():
+    """Sanity: confirm the topology this test assumes (16 cubes per SIP).
+    If this fails, recheck the topology.yaml cube_mesh setting before
+    interpreting the other failures below. ``_num_cubes`` is initialized
+    lazily by ``_ensure_allocators`` on first tensor op, so trigger it."""
+    ctx = _make_ctx("dim-baseline")
+    ctx._ensure_allocators()
+    assert ctx._num_cubes == 16, (
+        f"expected default topology.yaml to give 16 cubes per SIP, "
+        f"got {ctx._num_cubes}"
+    )
+
+
+def test_ctx_launch_local_shape_honors_dppolicy_num_cubes():
+    """Bug A. ``DPPolicy(num_cubes=4)`` must be the divisor for
+    row_wise sharding inside ctx.launch's dim_map, not the topology's 16.
+
+    Setup: K-like tensor with M_global = 80 (cleanly divisible by both
+    4 and 16, distinct local values 20 vs 5). Pass M_global as a kernel
+    scalar; the kernel records what it received. With correct dim_map,
+    scalar 80 is remapped to 20 (80 / dp.num_cubes). Before the fix,
+    it was remapped to 5 (80 / self._num_cubes = 16).
+    """
+    captured: dict[str, int] = {}
+
+    def _kernel(t, m_scalar, *, tl):  # noqa: ARG001
+        captured["m_scalar"] = int(m_scalar)
+
+    ctx = _make_ctx("dim-bugA")
+    dp = DPPolicy(cube="row_wise", pe="replicate", num_cubes=4, num_pes=8)
+    t = ctx.zeros((80, 64), dtype="f16", dp=dp, name="t80x64")
+    ctx.launch("bugA_capture", _kernel, t, 80)
+    ctx.wait_all()
+
+    assert "m_scalar" in captured, "kernel was not invoked"
+    assert captured["m_scalar"] == 20, (
+        f"expected dim_map to divide 80 by dp.num_cubes=4 → 20; "
+        f"got {captured['m_scalar']} (likely divided by topology cubes=16)"
+    )
+
+
+def test_ctx_launch_scalar_passed_through_when_auto_remap_disabled():
+    """Bug B. Scalars must not be silently remapped when their value
+    happens to equal a tensor's global dim — at minimum the caller must
+    have an opt-out.
+
+    Setup: K-like tensor with M_global = 64 row_wise. Pass d_head = 64
+    as a scalar (semantically unrelated to K's M, but coincidentally
+    equal). The kernel records d_head. With ``_auto_dim_remap=False``
+    on ctx.launch, d_head must stay 64.
+
+    Before Phase 2 the ``_auto_dim_remap`` kwarg didn't exist → TypeError.
+    After Phase 2: kwarg exists, defaults to True (legacy unchanged);
+    passing False preserves the scalar.
+    """
+    captured: dict[str, int] = {}
+
+    def _kernel(t, d_head, *, tl):  # noqa: ARG001
+        captured["d_head"] = int(d_head)
+
+    ctx = _make_ctx("dim-bugB")
+    dp = DPPolicy(cube="row_wise", pe="replicate", num_cubes=4, num_pes=8)
+    t = ctx.zeros((64, 64), dtype="f16", dp=dp, name="t64x64")
+    ctx.launch(
+        "bugB_capture", _kernel, t, 64,
+        _auto_dim_remap=False,
+    )
+    ctx.wait_all()
+
+    assert captured.get("d_head") == 64, (
+        f"expected d_head scalar to pass through unchanged when "
+        f"_auto_dim_remap=False; got {captured.get('d_head')!r}"
+    )