Files
kernbench2/src/kernbench/runtime_api/context.py
T
ywkang 4ba0a83e71 Implement ADR-0024 Phase A: SIP-level TP launcher MVP
Scope (Phase A):
- D1: world_size fallback = SIP count (rank = SIP, TP boundary)
- D9: greenlet-local get_rank + _bind_rank (single-driver fallback = 0)
- D10: torch.ahbm.set_device + torch.accelerator.set_device_index alias
- D11: tensor placement scoped to current-device SIP (post-hoc pe_index
  shift — ADR-0026 replaces with structural coords)
- D12/D13: multi-greenlet run() with simple round-robin scheduler;
  hybrid dispatch (ws == SIP count → multi-greenlet, else legacy
  single-worker for ccl.yaml override compat)
- D7 partial: backend.all_reduce submit + yield + wait via launch()'s
  new _defer_wait flag; parent-less greenlets skip yield
- Relaxed shard-count check (len(shards) > 0 instead of == world_size)
- rank_to_pe = SIP-representative [(r, 0, 0)] when ws <= n_sips

Deferred to Phase B:
- Engine-routed install (D2) — keeps sideband
- install_plan.py module (D6) — keeps install.py
- Epoch barrier (D7 full) — simple yield is sufficient for ring ws=2 mock
- Validator registry (D8)
- Cross-SIP multi-greenlet + real kernel integration — matrix
  ring_default_ws hangs in SimPy drain despite ADR-0025 direction fix;
  marked xfail(run=False) pending Phase B diagnosis (suspected per-rank
  kernel_args / program_id mismatch)

Tests:
- test_ccl_ddp_launcher.py (6 new tests) — D1/D9/D10/D11/D12/D13
- test_ccl_allreduce_matrix.py — ring_default_ws xfail'd, override
  cases (ring_tcm_8 / hbm_8 / sram_8 / multi_cube / mesh_2x2 /
  tree_binary_7) all pass via legacy path

514 tests pass, 1 xfail.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-14 09:00:28 -07:00

782 lines
30 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# kernbench/runtime_api/context.py
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
from kernbench.common.types import Completion, RequestHandle, SimEngine
from .types import DeviceSelector
def _world_size_from_spec(spec: dict | None) -> int:
"""Derive world_size from topology spec: sips × cubes × pes_per_cube."""
spec = spec or {}
sips = int(spec.get("system", {}).get("sips", {}).get("count", 1))
cm = spec.get("sip", {}).get("cube_mesh", {})
cubes_per_sip = int(cm.get("w", 1)) * int(cm.get("h", 1))
pl = spec.get("cube", {}).get("pe_layout", {})
corners = pl.get("corners", [])
pe_per_corner = int(pl.get("pe_per_corner", 1))
pes_per_cube = pe_per_corner * max(len(corners), 1)
return sips * cubes_per_sip * pes_per_cube
def _numpy_to_dtype_str(np_dtype) -> str:
"""Map numpy dtype → kernbench dtype string used by Tensor."""
import numpy as np
kind_map = {
np.float16: "f16",
np.float32: "f32",
np.int8: "i8",
np.int16: "i16",
np.int32: "i32",
np.uint8: "u8",
np.uint16: "u16",
np.uint32: "u32",
}
for np_type, s in kind_map.items():
if np.dtype(np_dtype) == np.dtype(np_type):
return s
raise ValueError(f"unsupported numpy dtype: {np_dtype!r}")
class _AhbmNamespace:
"""torch.ahbm — per-greenlet SIP device binding (ADR-0024 D10).
Real-PyTorch parity idiom: ``torch.cuda.set_device(rank)``. KernBench's
backend is 'ahbm' (not CUDA), so this namespace avoids pretending to be
a CUDA runtime.
"""
def __init__(self) -> None:
self._device_by_greenlet: dict = {}
def set_device(self, device: int) -> None:
from greenlet import getcurrent
self._device_by_greenlet[getcurrent()] = int(device)
def current_device(self) -> int | None:
from greenlet import getcurrent
return self._device_by_greenlet.get(getcurrent())
class _AcceleratorNamespace:
"""torch.accelerator — device-agnostic alias (PyTorch 2.x style).
Wraps _AhbmNamespace. Bench code can pick either:
torch.ahbm.set_device(rank) # explicit backend
torch.accelerator.set_device_index(rank) # portable
"""
def __init__(self, ahbm: "_AhbmNamespace") -> None:
self._ahbm = ahbm
def set_device_index(self, device: int) -> None:
self._ahbm.set_device(device)
def current_device_index(self) -> int | None:
return self._ahbm.current_device()
@dataclass
class RuntimeContext:
    """Host-side context for submitting requests to a simulation engine.

    Tracks the request handles it has submitted, which of them have
    completed, the traces collected while waiting, and the tensors it has
    created so they can be released by ``cleanup()`` (also run on context
    manager exit).
    """

    # Engine used through ``submit``, ``wait`` and ``get_completion``.
    engine: SimEngine
    # Device selection; decides which SIP(s) get memory allocators.
    target_device: DeviceSelector
    # Stamped on every message this context builds.
    correlation_id: str
    # Topology spec; required before any tensor operation.
    spec: dict | None = None
    # Every handle returned by ``submit``, in submission order.
    _handles: list[RequestHandle] = field(default_factory=list, init=False)
    # Handles whose completion has already been collected by ``wait``.
    _completed: set[RequestHandle] = field(default_factory=set, init=False)
    # Flat PE index -> memory allocator, filled lazily on first tensor use.
    _allocators: dict[int, Any] = field(default_factory=dict, init=False)
    # Virtual-address allocator, created together with ``_allocators``.
    _va_allocator: Any = field(default=None, init=False)
    # Counter behind auto-generated tensor names ("t1", "t2", ...).
    _tensor_counter: int = field(default=0, init=False)
    # Trace entries recorded by ``wait`` when it is given ``_meta``.
    _traces: list[dict] = field(default_factory=list, init=False)
    # Weak references to tensors created by this context.
    _tensors: list[Any] = field(default_factory=list, init=False)
    distributed: Any = field(default=None, init=False) # DistributedContext for CCL benches
    _ipcq_plan: dict = field(default_factory=dict, init=False) # ADR-0023 install plan
    # NOTE(review): ``ahbm``, ``accelerator`` and ``_ipcq_config`` are set as
    # plain instance attributes by methods of this class and are not
    # declared as fields here.
def __post_init__(self) -> None:
    """Attach the distributed context and the device-binding namespaces.

    A ``DistributedContext`` is installed eagerly so bench code can write
    ``dist = torch.distributed`` followed by ``dist.init_process_group(...)``
    without a separate launcher having to install it first.
    """
    from kernbench.runtime_api.distributed import DistributedContext

    dist_ctx = DistributedContext()
    # Back-reference so AhbmCCLBackend can reach ctx.launch etc.
    dist_ctx._ctx_ref = self
    self.distributed = dist_ctx

    # ADR-0024 D10: torch.ahbm (KernBench-native) and torch.accelerator
    # (PyTorch 2.x portable) share one per-greenlet device binding.
    ahbm_ns = _AhbmNamespace()
    self.ahbm = ahbm_ns
    self.accelerator = _AcceleratorNamespace(ahbm_ns)
def install_ipcq(
    self,
    algorithm: str | None = None,
    ccl_yaml: str | None = None,
    world_size_override: int | None = None,
    rank_to_pe: list[tuple[int, int, int]] | None = None,
) -> dict:
    """Install IPCQ neighbor tables on all participating PEs (ADR-0023 D10).

    Loads ``ccl.yaml`` (or the path provided), resolves the chosen
    algorithm (or ``defaults.algorithm`` if None), and hands the merged
    configuration to ``kernbench.ccl.install.install_ipcq`` together with
    this context's engine and spec.

    Args:
        algorithm: name of the algorithm in ccl.yaml (or use defaults).
        ccl_yaml: optional path to ccl.yaml.
        world_size_override: if set, replace the algorithm's world_size.
            When unset and the merged config carries no ``world_size``,
            the value is derived from the topology spec.
        rank_to_pe: forwarded unchanged to the installer; presumably one
            ``(sip, cube, pe)`` coordinate per rank — confirm against
            ``kernbench.ccl.install``.

    Returns:
        The install plan dict (rank → (sip,cube,pe), neighbor table).  The
        plan and the merged config are also stored on the context as
        ``_ipcq_plan`` and ``_ipcq_config``.

    Raises:
        KeyError: if the merged algorithm config has no ``module`` entry.
        ModuleNotFoundError: if the algorithm module exists but one of its
            own imports is missing.
    """
    import importlib

    from kernbench.ccl.install import (
        install_ipcq as _install,
        load_ccl_config,
        resolve_algorithm_config,
    )

    cfg = load_ccl_config(ccl_yaml)
    merged = resolve_algorithm_config(cfg, algorithm)
    if world_size_override is not None:
        merged["world_size"] = world_size_override
    elif "world_size" not in merged:
        # Derive from topology.yaml when neither the algorithm entry
        # nor ``defaults`` carries ``world_size`` (matches pytorch DDP
        # where env vars determine ranks, not the ccl config file).
        merged["world_size"] = _world_size_from_spec(self.spec)

    module_name = merged["module"]
    algo_module = None
    try:
        algo_module = importlib.import_module(module_name)
    except ModuleNotFoundError as exc:
        # A missing algorithm module is tolerated: the installer then runs
        # with ``algo_module=None``.  A module that exists but fails on one
        # of its *own* imports is a real error and must not be hidden, so
        # only swallow when the missing name is the requested module or
        # one of its parent packages.
        missing = exc.name
        if missing is not None and not (
            module_name == missing or module_name.startswith(missing + ".")
        ):
            raise

    plan = _install(
        self.engine, self.spec, merged,
        algo_module=algo_module, rank_to_pe=rank_to_pe,
    )
    self._ipcq_plan = plan
    self._ipcq_config = merged
    return plan
def __enter__(self):
return self
def __exit__(self, *exc):
self.cleanup()
return False
def submit(self, request: Any) -> RequestHandle:
submit_fn = getattr(self.engine, "submit", None)
if submit_fn is None:
raise AttributeError("Engine does not implement submit(request) -> RequestHandle.")
handle: RequestHandle = submit_fn(request) # type: ignore[call-arg]
self._handles.append(handle)
return handle
def is_completed(self, handle: RequestHandle) -> bool:
return handle in self._completed
def wait(self, handle: RequestHandle, *, _meta: dict | None = None) -> Completion:
if handle in self._completed:
completion, trace = self.engine.get_completion(handle)
return completion
wait_fn = getattr(self.engine, "wait", None)
if wait_fn is not None:
wait_fn(handle) # type: ignore[misc]
completion, trace = self.engine.get_completion(handle)
self._completed.add(handle)
if _meta is not None and trace is not None:
entry = dict(trace) if isinstance(trace, dict) else {"raw": trace}
entry.update(_meta)
self._traces.append(entry)
return completion
def wait_all(self) -> None:
for h in self._handles:
if h not in self._completed:
self.wait(h)
def handles(self) -> list[RequestHandle]:
return list(self._handles)
# ── Tensor lifecycle ─────────────────────────────────────────────
def _free_tensor(self, tensor: Any) -> None:
"""Free a single tensor: unmap MMU, return VA and PA."""
handle = tensor._handle
if handle is None:
return
tensor._handle = None
if not handle.va_base:
return
from kernbench.runtime_api.kernel import MmuUnmapMsg
dp_policy = None
if tensor._dp_metadata is not None:
dp_policy = tensor._dp_metadata.dp_policy
is_cube_replicate = (
dp_policy is not None and dp_policy.cube == "replicate"
)
# Send MmuUnmapMsg through fabric
from collections import defaultdict
if is_cube_replicate:
cube_groups: dict[tuple[int, int], list] = defaultdict(list)
for shard in handle.shards:
cube_groups[(shard.sip, shard.cube)].append(shard)
for (sip, cube), group_shards in cube_groups.items():
entries = tuple(
{"va": handle.va_base + s.offset_bytes, "size": s.nbytes}
for s in group_shards
)
msg = MmuUnmapMsg(
correlation_id=self.correlation_id,
request_id=f"unmap_{tensor.name}_s{sip}c{cube}",
entries=entries,
target_sips=(sip,),
target_cubes=(cube,),
target_pe="all",
)
h = self.submit(msg)
self.wait(h)
else:
entries = tuple(
{"va": handle.va_base + s.offset_bytes, "size": s.nbytes}
for s in handle.shards
)
sip_set = sorted({s.sip for s in handle.shards})
cube_set = sorted({s.cube for s in handle.shards})
msg = MmuUnmapMsg(
correlation_id=self.correlation_id,
request_id=f"unmap_{tensor.name}",
entries=entries,
target_sips=tuple(sip_set),
target_cubes=tuple(cube_set),
target_pe="all",
)
h = self.submit(msg)
self.wait(h)
# Return VA space
if self._va_allocator is not None:
self._va_allocator.free(handle.va_base, handle.nbytes)
# Return PA space
if self._allocators:
for shard in handle.shards:
flat_idx = (
shard.sip * self._num_cubes * self._pes_per_cube
+ shard.cube * self._pes_per_cube
+ shard.pe
)
alloc = self._allocators.get(flat_idx)
if alloc is not None:
from kernbench.policy.address.phyaddr import PhysAddr
alloc.free_hbm(PhysAddr.decode(shard.pa), shard.nbytes)
def cleanup(self) -> None:
"""Free all tensors created by this context."""
for ref in self._tensors:
t = ref()
if t is not None and t._handle is not None:
self._free_tensor(t)
self._tensors.clear()
# ── PyTorch-like tensor API ──────────────────────────────────────
def _ensure_allocators(self) -> dict:
"""Lazily create PEMemAllocator instances from spec."""
if self._allocators:
return self._allocators
if self.spec is None:
raise RuntimeError(
"RuntimeContext.spec is required for tensor operations. "
"Pass spec=graph.spec when creating RuntimeContext."
)
from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator
system = self.spec.get("system", {})
cube = self.spec.get("cube", {})
mm = cube.get("memory_map", {})
pe_template = cube.get("pe_template", {})
pe_comps = pe_template.get("components", {})
tcm_cfg = pe_comps.get("pe_tcm", {}).get("attrs", {})
total_sip_count = system.get("sips", {}).get("count", 1)
cubes_per_sip = system.get("sips", {}).get("cubes_per_sip", 16)
pes_per_cube = (
cube.get("pe_layout", {}).get("pe_per_corner", 2)
* len(cube.get("pe_layout", {}).get("corners", ["NW", "NE", "SW", "SE"]))
)
hbm_gb = mm.get("hbm_total_gb_per_cube", 48)
hbm_slices = mm.get("hbm_slices_per_cube", 8)
tcm_mb = tcm_cfg.get("size_mb", 16)
# Scope to target_device: single SIP or all SIPs
from kernbench.runtime_api.types import DeviceSelector, resolve_device
td = self.target_device if isinstance(self.target_device, DeviceSelector) else resolve_device(str(self.target_device))
if td.is_all:
sip_range = range(total_sip_count)
sip_count = total_sip_count
else:
sip_idx = td.sip_index
sip_range = range(sip_idx, sip_idx + 1)
sip_count = 1
cfg = AddressConfig(
sip_count=sip_count,
cubes_per_sip=cubes_per_sip,
pes_per_cube=pes_per_cube,
hbm_bytes_per_cube=hbm_gb * (1 << 30),
hbm_slices_per_cube=hbm_slices,
tcm_bytes_per_pe=tcm_mb * (1 << 20),
tcm_scheduler_reserved_bytes=4 * (1 << 20),
sram_bytes_per_cube=32 * (1 << 20),
)
# Create allocators scoped to target SIP(s) only
# Flat index: sip_id * cubes_per_sip * pes_per_cube + cube_id * pes_per_cube + pe_id
self._pes_per_cube = pes_per_cube
self._num_cubes = cubes_per_sip
self._num_sips = sip_count
cubes_x_pes = cubes_per_sip * pes_per_cube
for sip_id in sip_range:
for cube_id in range(cubes_per_sip):
for pe_id in range(pes_per_cube):
flat_idx = sip_id * cubes_x_pes + cube_id * pes_per_cube + pe_id
self._allocators[flat_idx] = PEMemAllocator(
rack_id=0, sip_id=sip_id, cube_id=cube_id, pe_id=pe_id, cfg=cfg,
)
# Initialize VA allocator (MMU mappings are installed via fabric MmuMapMsg)
from kernbench.policy.address.va_allocator import VirtualAllocator
pe_mmu_attrs = pe_comps.get("pe_mmu", {}).get("attrs", {})
page_size = int(pe_mmu_attrs.get("page_size", 4096))
self._va_allocator = VirtualAllocator(
va_base=0x1_0000_0000,
va_size=64 * (1 << 30), # 64 GB VA space
page_size=page_size,
)
return self._allocators
def _next_tensor_name(self) -> str:
self._tensor_counter += 1
return f"t{self._tensor_counter}"
def zeros(
self,
shape: tuple[int, ...],
dtype: str = "f16",
*,
dp: Any = None,
name: str | None = None,
):
"""Create a tensor and deploy to HBM with zero-fill (like torch.zeros)."""
return self._create_tensor(shape, dtype, name, pattern="zero", dp=dp)
def empty(
self,
shape: tuple[int, ...],
dtype: str = "f16",
*,
dp: Any = None,
name: str | None = None,
):
"""Allocate a tensor in HBM without initialization (like torch.empty)."""
return self._create_tensor(shape, dtype, name, pattern=None, dp=dp)
def from_numpy(self, arr: Any):
    """Wrap a numpy array in a host-side tensor (mirrors ``torch.from_numpy``).

    The result is NOT deployed to any PE: it only carries a C-contiguous
    view/copy of ``arr`` as its host staging buffer.  To move the data
    onto the device, scatter it into a deployed tensor with
    ``target.copy_(host_tensor)``.

    Raises:
        ValueError: if the array's dtype has no KernBench equivalent.
    """
    import numpy as np
    from kernbench.runtime_api.tensor import Tensor

    contiguous = np.ascontiguousarray(arr)
    kb_dtype = _numpy_to_dtype_str(contiguous.dtype)

    host_tensor = Tensor(shape=tuple(contiguous.shape), dtype=kb_dtype, name="host")
    host_tensor._host_buffer = contiguous
    host_tensor._memory_store = getattr(self.engine, "_memory_store", None)
    return host_tensor
def _create_tensor(
    self,
    shape: tuple[int, ...],
    dtype: str,
    name: str | None,
    pattern: str | None,
    dp: Any = None,
):
    """Create a tensor, place it, allocate memory, and map it on the device.

    Steps, in order: resolve the placement from ``dp``, allocate physical
    and virtual addresses through ``deploy_tensor``, install the VA→PA
    mappings with ``MmuMapMsg`` requests, and — when ``pattern`` is not
    ``None`` — issue one ``MemoryWriteMsg`` per shard.  Every request is
    waited on before the next one is submitted.

    Args:
        shape: tensor shape.  Placement only sees the first two
            dimensions; a 1-D shape ``(n,)`` is treated as ``(1, n)``.
        dtype: KernBench dtype string (e.g. ``"f16"``).
        name: tensor name; an auto-generated ``t<N>`` is used when falsy.
        pattern: fill pattern passed to ``MemoryWriteMsg`` (``"zero"`` for
            ``zeros``), or ``None`` to skip the data write.
        dp: placement policy; must be a ``DPPolicy`` instance.

    Returns:
        The new tensor, tracked by this context through a weak reference.

    Raises:
        ValueError: if ``dp`` is not a ``DPPolicy``.
    """
    from kernbench.policy.placement.dp import DPPolicy, resolve_dp_policy
    from kernbench.runtime_api.kernel import MemoryWriteMsg
    from kernbench.runtime_api.tensor import Tensor, deploy_tensor, dtype_itemsize
    if not isinstance(dp, DPPolicy):
        raise ValueError("dp=DPPolicy(...) is required for tensor creation")
    tensor_name = name or self._next_tensor_name()
    t = Tensor(shape=shape, dtype=dtype, name=tensor_name)
    dp_policy = dp
    # Also populates self._pes_per_cube / _num_cubes / _num_sips used below.
    allocators = self._ensure_allocators()
    itemsize = dtype_itemsize(dtype)
    shape_2d = (shape[0], shape[1]) if len(shape) >= 2 else (1, shape[0])
    # DPPolicy overrides take precedence over topology dimensions
    eff_num_pe = dp.num_pes if dp.num_pes is not None else self._pes_per_cube
    eff_num_cubes = dp.num_cubes if dp.num_cubes is not None else self._num_cubes
    # ADR-0024 D11: if torch.ahbm.set_device(r) is active AND DPPolicy
    # leaves the SIP dimension at its default (replicate + no num_sips
    # override), scope the tensor to SIP r only.
    # NOTE: this path uses post-hoc pe_index shifting as a temporary
    # measure; ADR-0026 replaces it with structural (sip, cube, pe)
    # coords in ShardSpec.
    current_sip = (
        self.ahbm.current_device() if hasattr(self, "ahbm") else None
    )
    scope_to_current_sip = (
        current_sip is not None
        and dp.sip == "replicate"
        and dp.num_sips is None
    )
    if scope_to_current_sip:
        eff_num_sips = 1
    else:
        eff_num_sips = (
            dp.num_sips if dp.num_sips is not None else self._num_sips
        )
    placement = resolve_dp_policy(
        dp, shape=shape_2d, itemsize=itemsize,
        num_pe=eff_num_pe, num_cubes=eff_num_cubes,
        num_sips=eff_num_sips,
    )
    if scope_to_current_sip:
        # Placement was resolved for a single SIP; shift every pe_index
        # by the current SIP's block of (cubes * PEs) flat indices.
        from kernbench.policy.placement.dp import ShardSpec as _SS
        sip_stride = self._num_cubes * self._pes_per_cube
        offset = int(current_sip) * sip_stride
        placement = [
            _SS(pe_index=s.pe_index + offset,
                offset_bytes=s.offset_bytes, nbytes=s.nbytes)
            for s in placement
        ]
    # Infer target_pe from placement using local (within-cube) PE IDs.
    # This ensures M_CPU only fans out to PEs that own shards, not all PEs.
    local_pe_ids = sorted({s.pe_index % eff_num_pe for s in placement})
    if len(local_pe_ids) == 1:
        target_pe: int | tuple[int, ...] | str = local_pe_ids[0]
    elif len(local_pe_ids) == eff_num_pe and eff_num_pe == self._pes_per_cube:
        target_pe = "all"
    else:
        target_pe = tuple(local_pe_ids)
    t.to(placement=placement, target_pe=target_pe, dp_policy=dp_policy)
    # Allocate PAs via PEMemAllocator + VA via VirtualAllocator
    # (this second call returns the dict cached by the first one)
    allocators = self._ensure_allocators()
    handle = deploy_tensor(
        name=tensor_name,
        shape=shape,
        dtype=dtype,
        placement=placement,
        allocators=allocators,
        va_allocator=self._va_allocator,
    )
    t._handle = handle
    import weakref
    # Weak references on both sides: the tensor does not keep the context
    # alive and the context does not keep the tensor alive.
    t._ctx_ref = weakref.ref(self)
    t._memory_store = getattr(self.engine, "_memory_store", None)
    self._tensors.append(weakref.ref(t))
    # Install VA→PA mappings via fabric MmuMapMsg
    # Strategy: always SIP-scoped (each SIP gets only its own shards).
    # Within each SIP: cube="replicate" → per-cube, else broadcast within SIP.
    if handle.va_base:
        from collections import defaultdict
        from kernbench.runtime_api.kernel import MmuMapMsg
        is_cube_replicate = (
            dp_policy is not None and dp_policy.cube == "replicate"
        )
        # Group shards by SIP
        sip_groups: dict[int, list] = defaultdict(list)
        for shard in handle.shards:
            sip_groups[shard.sip].append(shard)
        for sip, sip_shards in sip_groups.items():
            if is_cube_replicate:
                # Cube replicate: per-(sip, cube) local mapping
                cube_groups: dict[int, list] = defaultdict(list)
                for s in sip_shards:
                    cube_groups[s.cube].append(s)
                for cube, group_shards in cube_groups.items():
                    entries = tuple(
                        {"va": handle.va_base + s.offset_bytes,
                         "pa": s.pa, "size": s.nbytes}
                        for s in group_shards
                    )
                    msg = MmuMapMsg(
                        correlation_id=self.correlation_id,
                        request_id=f"mmu_{tensor_name}_s{sip}c{cube}",
                        entries=entries,
                        target_sips=(sip,),
                        target_cubes=(cube,),
                        target_pe="all",
                    )
                    h = self.submit(msg)
                    self.wait(h)
            else:
                # Cube sharded: broadcast all cubes within this SIP
                entries = tuple(
                    {"va": handle.va_base + s.offset_bytes,
                     "pa": s.pa, "size": s.nbytes}
                    for s in sip_shards
                )
                cube_set = sorted({s.cube for s in sip_shards})
                msg = MmuMapMsg(
                    correlation_id=self.correlation_id,
                    request_id=f"mmu_{tensor_name}_s{sip}",
                    entries=entries,
                    target_sips=(sip,),
                    target_cubes=tuple(cube_set),
                    target_pe="all",
                )
                h = self.submit(msg)
                self.wait(h)
    # Submit MemoryWriteMsg per shard (deploy data to device)
    # NOTE(review): request_id only encodes the PE id, so shards with the
    # same PE id on different SIPs/cubes share a request_id — confirm the
    # engine does not require request_id to be unique.
    if pattern is not None:
        for shard in handle.shards:
            h = self.submit(MemoryWriteMsg(
                correlation_id=self.correlation_id,
                request_id=f"deploy_{tensor_name}_pe{shard.pe}",
                dst_sip=shard.sip, dst_cube=shard.cube, dst_pe=shard.pe,
                dst_pa=shard.pa, nbytes=shard.nbytes, pattern=pattern,
                target_cubes=(shard.cube,), target_pe=shard.pe,
            ))
            # _meta makes wait() record a trace entry for this write.
            self.wait(h, _meta={
                "phase": "memory_write", "name": tensor_name,
                "sip": shard.sip, "cube": shard.cube, "pe": shard.pe,
                "nbytes": shard.nbytes,
            })
    return t
def launch(
    self,
    kernel_name: str,
    kernel_fn: Any,
    *args: Any,
    _defer_wait: bool = False,
    **kwargs: Any,
) -> RequestHandle | list[tuple[RequestHandle, int, dict]]:
    """Register and launch a kernel (like a fused torch op).

    Positional args: Tensor objects become TensorArg, int/float become
    ScalarArg; any other positional value is ignored.
    Keyword args: int/float values become ScalarArg (name is discarded,
    order preserved) and are appended after the positional arguments.

    Creates one KernelLaunchMsg per SIP that owns a shard of any tensor
    argument (SIP 0 when there are none), each carrying the va_base local
    to that SIP (like a host driver sending per-rank launch commands).
    All launches are submitted before any of them is waited on.

    Integer scalars whose value equals a tensor's global dimension are
    replaced by the corresponding sharded (local) dimension derived from
    that tensor's DPPolicy.

    Args:
        kernel_name: registry key; re-registering overwrites (last wins).
        kernel_fn: the kernel callable stored under ``kernel_name``.
        *args: tensors and scalars, in kernel-argument order.
        _defer_wait: when True (ADR-0024 D7), do not wait; return the
            pending launches so the caller can yield between submit and
            wait, letting all sibling ranks submit first.
        **kwargs: extra scalar arguments.

    Returns:
        The handle of the last submitted launch, or — with
        ``_defer_wait=True`` — a list of ``(handle, sip_id, meta)`` tuples
        the caller is responsible for waiting on.

    Raises:
        ValueError: if a tensor argument has no device allocation
            (host-side tensor or already freed).  Such a tensor cannot be
            turned into a TensorArg and would otherwise misalign the
            kernel's argument list.
    """
    from kernbench.runtime_api.kernel import (
        KernelLaunchMsg,
        KernelRef,
        ScalarArg,
        TensorArg,
        TensorArgShard,
    )
    from kernbench.runtime_api.tensor import Tensor
    from kernbench.triton_emu.registry import _kernels

    # Register kernel (idempotent overwrite — last call wins).
    # Tests can re-register the same kernel_name with a different
    # function; the user's most recent launch must use the latest fn.
    _kernels[kernel_name] = kernel_fn

    def _as_scalar(value):
        """Wrap an int/float as a ScalarArg (floats are f32, the rest i32)."""
        dtype_str = "f32" if isinstance(value, float) else "i32"
        return ScalarArg(dtype=dtype_str, value=value)

    # Collect tensors and scalars
    tensor_args: list[Tensor] = []
    scalar_args: list = []
    pe_ids: set[int] = set()
    pe_all = False
    for a in args:
        if isinstance(a, Tensor):
            if a._handle is None:
                raise ValueError(
                    f"launch({kernel_name!r}): tensor "
                    f"{getattr(a, 'name', None)!r} has no device allocation "
                    "(host-side or already freed); deploy it before passing "
                    "it to a kernel"
                )
            tensor_args.append(a)
            if a._dp_metadata is not None:
                dp_target = a._dp_metadata.target_pe
                if dp_target == "all":
                    pe_all = True
                elif isinstance(dp_target, tuple):
                    pe_ids.update(dp_target)
                elif isinstance(dp_target, int):
                    pe_ids.add(dp_target)
        elif isinstance(a, (int, float)):
            scalar_args.append(_as_scalar(a))
    scalar_args.extend(
        _as_scalar(v) for v in kwargs.values() if isinstance(v, (int, float))
    )

    # Resolve target_pe from collected PE info
    if pe_all:
        target_pe: int | tuple[int, ...] | str = "all"
    elif len(pe_ids) == 1:
        target_pe = next(iter(pe_ids))
    elif len(pe_ids) > 1:
        target_pe = tuple(sorted(pe_ids))
    else:
        target_pe = 0

    # Determine all target SIPs from tensor shards
    sip_set: set[int] = {s.sip for t in tensor_args for s in t._handle.shards}
    if not sip_set:
        sip_set = {0}

    # Build global→local dimension mapping from tensor DPPolicies.
    # Scalar args matching a tensor's global dimension get replaced
    # with the cube-local value (what the kernel actually operates on).
    def _compute_local_shape(t: Tensor) -> tuple[int, ...]:
        """Compute cube-local shape from DPPolicy."""
        shape = t.shape
        if len(shape) < 2:
            shape = (1, shape[0])
        M, K = shape[0], shape[1]
        dp = t._dp_metadata.dp_policy if t._dp_metadata else None
        if dp is None:
            return t.shape
        if dp.sip == "column_wise":
            K = K // self._num_sips
        elif dp.sip == "row_wise":
            M = M // self._num_sips
        if dp.cube == "column_wise":
            K = K // self._num_cubes
        elif dp.cube == "row_wise":
            M = M // self._num_cubes
        if len(t.shape) < 2:
            return (K,)
        return (M, K)

    dim_map: dict[int, int] = {}  # global_dim → local_dim
    for t in tensor_args:
        local = _compute_local_shape(t)
        global_2d = t.shape if len(t.shape) >= 2 else (1, t.shape[0])
        local_2d = local if len(local) >= 2 else (1, local[0])
        for g, loc in zip(global_2d, local_2d):
            if g != loc:
                dim_map[g] = loc

    # Per-SIP kernel launch: each SIP gets TensorArgs with local va_base
    last_handle = None
    pending: list[tuple[Any, int]] = []
    for sip_id in sorted(sip_set):
        sip_kernel_args: list = []
        sip_cube_set: set[int] = set()
        for t in tensor_args:
            sip_shards = [s for s in t._handle.shards if s.sip == sip_id]
            if not sip_shards:
                # Tensor has no shard on this SIP: pass all of its shards.
                sip_shards = list(t._handle.shards)
            local_va_base = 0
            if t._handle.va_base:
                min_offset = min(s.offset_bytes for s in sip_shards)
                local_va_base = t._handle.va_base + min_offset
            sip_kernel_args.append(TensorArg(
                shards=tuple(
                    TensorArgShard(
                        sip=s.sip, cube=s.cube, pe=s.pe,
                        pa=s.pa, nbytes=s.nbytes, offset_bytes=s.offset_bytes,
                    )
                    for s in sip_shards
                ),
                va_base=local_va_base,
            ))
            sip_cube_set.update(s.cube for s in sip_shards)

        # Interleave tensor args and scalar args, replacing global dims with local
        final_args: list = []
        t_idx, s_idx = 0, 0
        for a in args:
            if isinstance(a, Tensor):
                final_args.append(sip_kernel_args[t_idx])
                t_idx += 1
            elif isinstance(a, (int, float)):
                sa = scalar_args[s_idx]
                if isinstance(a, int) and a in dim_map:
                    sa = ScalarArg(dtype=sa.dtype, value=dim_map[a])
                final_args.append(sa)
                s_idx += 1
        # Whatever is left in scalar_args came from kwargs.
        for sa in scalar_args[s_idx:]:
            if isinstance(sa.value, int) and int(sa.value) in dim_map:
                sa = ScalarArg(dtype=sa.dtype, value=dim_map[int(sa.value)])
            final_args.append(sa)

        target_cubes = tuple(sorted(sip_cube_set)) if sip_cube_set else (0,)
        h = self.submit(KernelLaunchMsg(
            correlation_id=self.correlation_id,
            request_id=f"{kernel_name}_sip{sip_id}",
            kernel_ref=KernelRef(name=kernel_name, kind="builtin"),
            args=tuple(final_args),
            target_cubes=target_cubes,
            target_pe=target_pe,
        ))
        # Defer wait until all SIPs are submitted (multi-SIP CCL needs
        # all participating PEs to be live concurrently — waiting
        # per-SIP would deadlock when ranks span SIP boundaries).
        pending.append((h, sip_id))
        last_handle = h

    def _trace_meta(sip_id: int) -> dict:
        """Trace metadata attached to the wait of one per-SIP launch."""
        return {
            "phase": "kernel", "name": kernel_name,
            "sip": sip_id, "target_pe": target_pe,
        }

    if _defer_wait:
        # ADR-0024 D7: return the pending-list so the caller can yield
        # between submit and drain. Used by collective ops that need
        # all sibling ranks to submit before any rank waits.
        return [(h, sip_id, _trace_meta(sip_id)) for h, sip_id in pending]

    # Drain pending handles now that every SIP has a launch posted.
    for h, sip_id in pending:
        self.wait(h, _meta=_trace_meta(sip_id))
    return last_handle