runtime_api: ctx.launch honors DPPolicy.num_cubes + adds _auto_dim_remap opt-out

Two compounding bugs in ctx.launch's dim-translation path were
surfaced by the multi_user_* panels of milestone-gqa-llama70b
(sub-cycle 4c step 2):

Bug A: _compute_local_shape divided by self._num_cubes (the topology's
cube count, 16 in default topology.yaml) instead of the DPPolicy's
effective num_cubes (4 for validation-scale multi_user). The tensor
allocator at context.py:471-484 already honored dp.num_cubes; the
parallel computation inside launch was out of sync. Fix mirrors the
allocator's eff_num_cubes precedence pattern.
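
A minimal sketch of the precedence rule described for Bug A, not the
actual context.py code. The DPPolicy dataclass and compute_local_shape
below are hypothetical stand-ins, and the 1-D/2-D shape handling is
assumed from the diff; only the "dp.num_cubes wins when set, else the
topology's cube count" rule comes from this commit.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class DPPolicy:
    """Hypothetical stand-in for the real DPPolicy (assumed fields only)."""
    cube: str = "replicate"          # "replicate", "row_wise" or "column_wise"
    num_cubes: Optional[int] = None  # None means "use the topology's count"


def compute_local_shape(
    shape: Tuple[int, ...], dp: DPPolicy, topology_num_cubes: int
) -> Tuple[int, ...]:
    """Shrink a global 1-D or 2-D shape to its per-cube local shape."""
    assert len(shape) in (1, 2), "sketch only handles 1-D and 2-D shapes"
    # Treat a 1-D tensor as a single row, as the diff appears to do.
    M, K = shape if len(shape) == 2 else (1, shape[0])
    if dp.cube != "replicate":
        # Precedence: the policy's num_cubes overrides the topology's.
        eff_num_cubes = (
            dp.num_cubes if dp.num_cubes is not None else topology_num_cubes
        )
        if dp.cube == "column_wise":
            K = K // eff_num_cubes
        elif dp.cube == "row_wise":
            M = M // eff_num_cubes
    return (K,) if len(shape) < 2 else (M, K)


# Validation-scale multi_user: policy says 4 cubes, topology says 16.
policy = DPPolicy(cube="row_wise", num_cubes=4)
print(compute_local_shape((64, 128), policy, topology_num_cubes=16))
# -> (16, 128); dividing by the topology's 16 would have given (4, 128)
```

With num_cubes left as None the same call falls back to the topology's
count, which is the pre-existing behavior for full-scale runs.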

Bug B: dim_map was keyed by dim value (global dim -> local dim), so any
scalar whose value coincidentally equaled a global tensor dim got
rewritten to that dim's local value, e.g. d_head=64 colliding with K's
global M=64 in multi_user mode. Legacy bench kernels (va_offset etc.)
rely on this remap, so the fix is opt-out:
ctx.launch(..., _auto_dim_remap=False) preserves scalars exactly as
passed. The default remains True.
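
The collision in Bug B can be reproduced in isolation. The sketch below
is illustrative only: translate_scalars is a hypothetical helper, not a
function in the runtime, and the (64, 128) shape for K is an assumed
example (the commit only states that K's global M is 64). It mirrors the
diff in that the opt-out simply leaves dim_map empty.

```python
from typing import Dict, List, Sequence, Tuple

Shape = Tuple[int, ...]


def translate_scalars(
    scalars: Sequence[int],
    global_shapes: Sequence[Shape],
    local_shapes: Sequence[Shape],
    auto_dim_remap: bool = True,
) -> List[int]:
    """Rewrite scalars that match a sharded global dim to the local dim."""
    dim_map: Dict[int, int] = {}  # global_dim -> local_dim, keyed by VALUE
    if auto_dim_remap:
        for g_shape, l_shape in zip(global_shapes, local_shapes):
            for g, l in zip(g_shape, l_shape):
                if g != l:
                    dim_map[g] = l
    # With an empty dim_map every scalar passes through untouched.
    return [dim_map.get(s, s) for s in scalars]


# K is sharded row-wise over 4 cubes: global (64, 128) -> local (16, 128).
global_shapes = [(64, 128)]
local_shapes = [(16, 128)]
d_head = 64  # already cube-local, but equal in value to K's global M

print(translate_scalars([d_head], global_shapes, local_shapes))
# -> [16]   d_head is wrongly rewritten because 64 is a key in dim_map
print(translate_scalars([d_head], global_shapes, local_shapes,
                        auto_dim_remap=False))
# -> [64]   opt-out preserves the scalar exactly as passed
```

Because the map is keyed by value, it cannot distinguish a scalar that
means "global M" from one that merely equals it, which is why an
explicit opt-out was chosen over trying to make the remap smarter.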

Tests: 3 new dim-translation tests plus a 4-panel diag harness covering
single_user_* (PASS) and multi_user_* (now advances to a new SFR/axis
layer failure, tracked separately). va_offset and the full attention
spec suite are unchanged.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-01 19:33:40 -07:00
parent 313dee503c
commit d9e767d048
3 changed files with 350 additions and 7 deletions
+23 -7
@@ -609,6 +609,7 @@ class RuntimeContext:
         kernel_fn: Any,
         *args: Any,
         _defer_wait: bool = False,
+        _auto_dim_remap: bool = True,
         **kwargs: Any,
     ) -> RequestHandle:
         """Register and launch a kernel (like a fused torch op).
@@ -700,21 +701,36 @@ class RuntimeContext:
                 return t.shape
             # ADR-0026: DPPolicy no longer crosses SIP boundaries; cube + PE
             # are the only axes that shrink the local shape.
+            # Mirror the tensor allocator's precedence (context.py L471-484):
+            # DPPolicy.num_cubes overrides the topology's cube count when set.
+            # Without this, multi_user panels at validation scale
+            # (DPPolicy.num_cubes=4) get sharded as if the topology's full
+            # cube count (16) applied — see test_launch_dim_translation.py.
             if dp.cube != "replicate":
+                eff_num_cubes = (
+                    dp.num_cubes if dp.num_cubes is not None else self._num_cubes
+                )
                 if dp.cube == "column_wise":
-                    K = K // self._num_cubes
+                    K = K // eff_num_cubes
                 elif dp.cube == "row_wise":
-                    M = M // self._num_cubes
+                    M = M // eff_num_cubes
             if len(t.shape) < 2:
                 return (K,)
             return (M, K)

+        # Auto-dim-remap (opt-out via _auto_dim_remap=False). Legacy
+        # kernels (e.g. va_offset bench) pass global dims as scalars and
+        # rely on launch to rewrite them to local. Mesh attention kernels
+        # already receive cube-local dims (S_kv_per_rank, d_head, …) and
+        # opt out — the remap would otherwise collide d_head=64 with K's
+        # global M=64 and rewrite d_head. See test_launch_dim_translation.py.
         dim_map: dict[int, int] = {} # global_dim → local_dim
-        for t in tensor_args:
-            local = _compute_local_shape(t)
-            for g, l in zip(t.shape if len(t.shape) >= 2 else (1, t.shape[0]), local if len(local) >= 2 else (1, local[0])):
-                if g != l:
-                    dim_map[g] = l
+        if _auto_dim_remap:
+            for t in tensor_args:
+                local = _compute_local_shape(t)
+                for g, l in zip(t.shape if len(t.shape) >= 2 else (1, t.shape[0]), local if len(local) >= 2 else (1, local[0])):
+                    if g != l:
+                        dim_map[g] = l

         # Per-SIP kernel launch: each SIP gets TensorArgs with local va_base
         last_handle = None