kernbench2/tests/test_launch_dim_translation.py
mukesh d9e767d048 runtime_api: ctx.launch honors DPPolicy.num_cubes + adds _auto_dim_remap opt-out
Two compounding bugs in ctx.launch's dim-translation path were surfaced
by the multi_user_* panels of milestone-gqa-llama70b (sub-cycle 4c step 2):

Bug A: _compute_local_shape divided by self._num_cubes (the topology's
cube count, 16 in the default topology.yaml) instead of the DPPolicy's
effective num_cubes (4 for validation-scale multi_user). The tensor
allocator at context.py:471-484 already honored dp.num_cubes; the
parallel computation inside launch was out of sync. The fix mirrors the
allocator's eff_num_cubes precedence pattern.
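The Bug A arithmetic can be illustrated with a minimal pure-Python sketch. It assumes the "precedence pattern" means that the DPPolicy's num_cubes wins when set and the topology's cube count is the fallback. `effective_num_cubes` and `local_shape` are hypothetical stand-ins, not kernbench's actual helpers, and only row_wise cube sharding (a split along axis 0) is modeled.

```python
from typing import Optional, Tuple


def effective_num_cubes(dp_num_cubes: Optional[int], topology_num_cubes: int) -> int:
    """Hypothetical: prefer the DPPolicy's cube count, fall back to the topology's."""
    return dp_num_cubes if dp_num_cubes is not None else topology_num_cubes


def local_shape(
    global_shape: Tuple[int, ...], cube_mode: str, num_cubes: int
) -> Tuple[int, ...]:
    """Hypothetical per-cube shape for a tensor sharded across num_cubes cubes.

    Only "row_wise" (split along axis 0) is modeled. Any other mode is treated
    as unsharded across cubes, and uneven splits are rejected (an assumption).
    """
    if cube_mode != "row_wise":
        return global_shape
    m, *rest = global_shape
    if m % num_cubes != 0:
        raise ValueError(f"M={m} is not divisible by num_cubes={num_cubes}")
    return (m // num_cubes, *rest)


# Validation-scale multi_user: the topology has 16 cubes, but the DPPolicy asks for 4.
buggy = local_shape((80, 64), "row_wise", 16)  # divisor = topology count (Bug A)
fixed = local_shape((80, 64), "row_wise", effective_num_cubes(4, 16))
print(buggy, fixed)  # (5, 64) (20, 64)
```

The key point is that the divisor comes from a single precedence helper. This means the allocator and the launch path cannot drift apart again.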

Bug B: dim_map was keyed by value, so any scalar whose value
coincidentally equaled a global tensor dim got rewritten to that dim's
local value — e.g. d_head=64 colliding with K's global M=64 in
multi_user mode. Legacy bench kernels (va_offset etc.) rely on this
remap, so the fix is opt-out: ctx.launch(..., _auto_dim_remap=False)
preserves scalars exactly as passed. Default remains True.
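The Bug B collision can be reproduced with a small pure-Python model of a value-keyed dim_map. This sketch rests on several assumptions. `build_dim_map` and `remap_scalars` are hypothetical names. The map is assumed to record only dims whose local value differs from the global one, with the first entry winning. An `auto_dim_remap` parameter stands in for the real `_auto_dim_remap` kwarg.

```python
from typing import Dict, List, Sequence, Tuple


def build_dim_map(
    global_shapes: Sequence[Tuple[int, ...]],
    local_shapes: Sequence[Tuple[int, ...]],
) -> Dict[int, int]:
    """Hypothetical value-keyed map: global dim value -> local dim value.

    Keying by value (rather than by tensor and axis) is what lets unrelated
    scalars collide. Only dims that actually shrink are recorded, and the
    first entry seen for a given value wins.
    """
    dim_map: Dict[int, int] = {}
    for g_shape, l_shape in zip(global_shapes, local_shapes):
        for g_dim, l_dim in zip(g_shape, l_shape):
            if g_dim != l_dim:
                dim_map.setdefault(g_dim, l_dim)
    return dim_map


def remap_scalars(
    scalars: Sequence[int], dim_map: Dict[int, int], auto_dim_remap: bool = True
) -> List[int]:
    """Rewrite any scalar whose value matches a global dim, unless opted out."""
    if not auto_dim_remap:
        return list(scalars)  # pass scalars through exactly as given
    return [dim_map.get(s, s) for s in scalars]


# multi_user K: global M = S_kv_per_rank * n = 16 * 4 = 64, row_wise over 4 cubes.
# The second axis (128) is an arbitrary illustrative value.
dim_map = build_dim_map([(64, 128)], [(16, 128)])
d_head = 64  # semantically unrelated to K's M, but numerically equal
print(remap_scalars([d_head], dim_map))                        # [16]
print(remap_scalars([d_head], dim_map, auto_dim_remap=False))  # [64]
```

Legacy kernels such as va_offset depend on this same lookup when they pass a global N and expect to see the local N. That is why the sketch, like the commit, keeps remapping on by default and makes it opt-out.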

Tests: 3 new dim-translation tests, plus a 4-panel diag harness covering
single_user_* (PASS) and multi_user_* (now advances to a new SFR/axis-layer
failure, tracked separately). va_offset and the full attention spec suite
are unchanged.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-01 19:33:40 -07:00

"""Phase 1 spec test for ``ctx.launch`` dim-translation bugs surfaced by
the multi_user_* panels of milestone-gqa-llama70b (sub-cycle 4c step 2).

The default ``topology.yaml`` has 4×4 = 16 cubes per SIP, so
``RuntimeContext._num_cubes == 16``. Multi-user attention panels run a
4-cube ring (validation scale) by passing ``DPPolicy(num_cubes=4)``.
Two bugs in ``ctx.launch`` make this combination silently produce wrong
kernel arguments.

Bug A: ``_compute_local_shape`` ignores ``DPPolicy.num_cubes``

``_compute_local_shape`` in ``ctx.launch`` divides by
``self._num_cubes`` (the topology's cube count, 16) instead of the
DPPolicy's effective ``num_cubes`` (4). As a result, a ``(M=80, K=64)`` tensor
sharded ``cube="row_wise"`` with ``DPPolicy(num_cubes=4)`` gets
a local M of ``80 // 16 = 5`` rather than the kernel-expected ``80 // 4 = 20``.

Note: tensor allocation already honors ``dp.num_cubes`` correctly at
[context.py:471-484](src/kernbench/runtime_api/context.py#L471-L484).
The bug is that the parallel computation inside ``launch`` is out of sync.

Bug B: scalar args that coincidentally equal a global tensor dim get auto-remapped

The dim_map at [context.py:712-770](src/kernbench/runtime_api/context.py#L712-L770)
is keyed by *value*. Any scalar whose value coincides with a
global tensor dim is therefore rewritten to that dim's local value, even
when the scalar is unrelated. ``d_head=64`` coincides with the
multi_user K's global M = ``S_kv_per_rank * n = 16 * 4 = 64``. The
kernel then receives ``d_head = 16`` (the local M with the Bug A fix,
64 // 4) or ``d_head = 4`` (the local M without it, 64 // 16) instead of ``64``.

Legacy bench kernels rely on auto-remap (e.g. ``test_va_offset.py``
passes global N and expects the kernel to see local N). The fix is therefore
opt-out, not removal: ``ctx.launch(..., _auto_dim_remap=False)``
preserves scalars exactly as passed, and the default behavior is unchanged.

The Bug A and Bug B tests below both fail today. Phase 2 fixes them in
[src/kernbench/runtime_api/context.py](src/kernbench/runtime_api/context.py).
"""
from __future__ import annotations

from pathlib import Path

from kernbench.policy.placement.dp import DPPolicy
from kernbench.runtime_api.context import RuntimeContext
from kernbench.runtime_api.types import DeviceSelector
from kernbench.sim_engine.engine import GraphEngine
from kernbench.topology.builder import load_topology

TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"


def _make_ctx(corr_id: str) -> RuntimeContext:
    graph = load_topology(TOPOLOGY_PATH)
    engine = GraphEngine(graph)
    return RuntimeContext(
        engine=engine, target_device=DeviceSelector("sip:0"),
        correlation_id=corr_id, spec=graph.spec,
    )


def test_topology_num_cubes_is_16_baseline_assumption():
    """Sanity: confirm the topology this test assumes (16 cubes per SIP).

    If this fails, recheck the topology.yaml cube_mesh setting before
    interpreting the other failures below. ``_num_cubes`` is initialized
    lazily by ``_ensure_allocators`` on first tensor op, so trigger it."""
    ctx = _make_ctx("dim-baseline")
    ctx._ensure_allocators()
    assert ctx._num_cubes == 16, (
        f"expected default topology.yaml to give 16 cubes per SIP, "
        f"got {ctx._num_cubes}"
    )


def test_ctx_launch_local_shape_honors_dppolicy_num_cubes():
    """Bug A. ``DPPolicy(num_cubes=4)`` must be the divisor for
    row_wise sharding inside ctx.launch's dim_map, not the topology's 16.

    Setup: K-like tensor with M_global = 80 (cleanly divisible by both
    4 and 16, distinct local values 20 vs 5). Pass M_global as a kernel
    scalar; the kernel records what it received. With correct dim_map,
    scalar 80 is remapped to 20 (80 / dp.num_cubes). With current code,
    it is remapped to 5 (80 / self._num_cubes = 16).
    """
    captured: dict[str, int] = {}

    def _kernel(t, m_scalar, *, tl):  # noqa: ARG001
        captured["m_scalar"] = int(m_scalar)

    ctx = _make_ctx("dim-bugA")
    dp = DPPolicy(cube="row_wise", pe="replicate", num_cubes=4, num_pes=8)
    t = ctx.zeros((80, 64), dtype="f16", dp=dp, name="t80x64")
    ctx.launch("bugA_capture", _kernel, t, 80)
    ctx.wait_all()

    assert "m_scalar" in captured, "kernel was not invoked"
    assert captured["m_scalar"] == 20, (
        f"expected dim_map to divide 80 by dp.num_cubes=4 → 20; "
        f"got {captured['m_scalar']} (likely divided by topology cubes=16)"
    )


def test_ctx_launch_scalar_passed_through_when_auto_remap_disabled():
    """Bug B. Scalars must not be silently remapped when their value
    happens to equal a tensor's global dim; at minimum, the caller must
    have an opt-out.

    Setup: K-like tensor with M_global = 64 row_wise. Pass d_head = 64
    as a scalar (semantically unrelated to K's M, but coincidentally
    equal). The kernel records d_head. With ``_auto_dim_remap=False``
    on ctx.launch, d_head must stay 64.

    Today: the ``_auto_dim_remap`` kwarg doesn't exist → TypeError. After
    Phase 2: the kwarg exists and defaults to True (legacy behavior
    unchanged); passing False preserves the scalar.
    """
    captured: dict[str, int] = {}

    def _kernel(t, d_head, *, tl):  # noqa: ARG001
        captured["d_head"] = int(d_head)

    ctx = _make_ctx("dim-bugB")
    dp = DPPolicy(cube="row_wise", pe="replicate", num_cubes=4, num_pes=8)
    t = ctx.zeros((64, 64), dtype="f16", dp=dp, name="t64x64")
    ctx.launch(
        "bugB_capture", _kernel, t, 64,
        _auto_dim_remap=False,
    )
    ctx.wait_all()

    assert captured.get("d_head") == 64, (
        f"expected d_head scalar to pass through unchanged when "
        f"_auto_dim_remap=False; got {captured.get('d_head')!r}"
    )