runtime_api: ctx.launch honors DPPolicy.num_cubes + adds _auto_dim_remap opt-out
Two compounding bugs in ctx.launch's dim-translation path surfaced by multi_user_* panels of milestone-gqa-llama70b (sub-cycle 4c step 2): Bug A: _compute_local_shape divided by self._num_cubes (the topology's cube count, 16 in default topology.yaml) instead of the DPPolicy's effective num_cubes (4 for validation-scale multi_user). The tensor allocator at context.py:471-484 already honored dp.num_cubes; the parallel computation inside launch was out of sync. Fix mirrors the allocator's eff_num_cubes precedence pattern. Bug B: dim_map was keyed by value, so any scalar whose value coincidentally equaled a global tensor dim got rewritten to that dim's local value — e.g. d_head=64 colliding with K's global M=64 in multi_user mode. Legacy bench kernels (va_offset etc.) rely on this remap, so the fix is opt-out: ctx.launch(..., _auto_dim_remap=False) preserves scalars exactly as passed. Default remains True. Tests: 3 new dim-translation tests + 4-panel diag harness covers single_user_* (PASS) and multi_user_* (advances to new SFR/axis layer failure, tracked separately). va_offset + full attention spec suite unchanged. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -609,6 +609,7 @@ class RuntimeContext:
|
||||
kernel_fn: Any,
|
||||
*args: Any,
|
||||
_defer_wait: bool = False,
|
||||
_auto_dim_remap: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> RequestHandle:
|
||||
"""Register and launch a kernel (like a fused torch op).
|
||||
@@ -700,21 +701,36 @@ class RuntimeContext:
|
||||
return t.shape
|
||||
# ADR-0026: DPPolicy no longer crosses SIP boundaries; cube + PE
|
||||
# are the only axes that shrink the local shape.
|
||||
# Mirror the tensor allocator's precedence (context.py L471-484):
|
||||
# DPPolicy.num_cubes overrides the topology's cube count when set.
|
||||
# Without this, multi_user panels at validation scale
|
||||
# (DPPolicy.num_cubes=4) get sharded as if the topology's full
|
||||
# cube count (16) applied — see test_launch_dim_translation.py.
|
||||
if dp.cube != "replicate":
|
||||
eff_num_cubes = (
|
||||
dp.num_cubes if dp.num_cubes is not None else self._num_cubes
|
||||
)
|
||||
if dp.cube == "column_wise":
|
||||
K = K // self._num_cubes
|
||||
K = K // eff_num_cubes
|
||||
elif dp.cube == "row_wise":
|
||||
M = M // self._num_cubes
|
||||
M = M // eff_num_cubes
|
||||
if len(t.shape) < 2:
|
||||
return (K,)
|
||||
return (M, K)
|
||||
|
||||
# Auto-dim-remap (opt-out via _auto_dim_remap=False). Legacy
|
||||
# kernels (e.g. va_offset bench) pass global dims as scalars and
|
||||
# rely on launch to rewrite them to local. Mesh attention kernels
|
||||
# already receive cube-local dims (S_kv_per_rank, d_head, …) and
|
||||
# opt out — the remap would otherwise collide d_head=64 with K's
|
||||
# global M=64 and rewrite d_head. See test_launch_dim_translation.py.
|
||||
dim_map: dict[int, int] = {} # global_dim → local_dim
|
||||
for t in tensor_args:
|
||||
local = _compute_local_shape(t)
|
||||
for g, l in zip(t.shape if len(t.shape) >= 2 else (1, t.shape[0]), local if len(local) >= 2 else (1, local[0])):
|
||||
if g != l:
|
||||
dim_map[g] = l
|
||||
if _auto_dim_remap:
|
||||
for t in tensor_args:
|
||||
local = _compute_local_shape(t)
|
||||
for g, l in zip(t.shape if len(t.shape) >= 2 else (1, t.shape[0]), local if len(local) >= 2 else (1, local[0])):
|
||||
if g != l:
|
||||
dim_map[g] = l
|
||||
|
||||
# Per-SIP kernel launch: each SIP gets TensorArgs with local va_base
|
||||
last_handle = None
|
||||
|
||||
Reference in New Issue
Block a user