4ba0a83e71
Scope (Phase A): - D1: world_size fallback = SIP count (rank = SIP, TP boundary) - D9: greenlet-local get_rank + _bind_rank (single-driver fallback = 0) - D10: torch.ahbm.set_device + torch.accelerator.set_device_index alias - D11: tensor placement scoped to current-device SIP (post-hoc pe_index shift — ADR-0026 replaces with structural coords) - D12/D13: multi-greenlet run() with simple round-robin scheduler; hybrid dispatch (ws == SIP count → multi-greenlet, else legacy single-worker for ccl.yaml override compat) - D7 partial: backend.all_reduce submit + yield + wait via launch()'s new _defer_wait flag; parent-less greenlets skip yield - Relaxed shard-count check (len(shards) > 0 instead of == world_size) - rank_to_pe = SIP-representative [(r, 0, 0)] when ws <= n_sips Deferred to Phase B: - Engine-routed install (D2) — keeps sideband - install_plan.py module (D6) — keeps install.py - Epoch barrier (D7 full) — simple yield is sufficient for ring ws=2 mock - Validator registry (D8) - Cross-SIP multi-greenlet + real kernel integration — matrix ring_default_ws hangs in SimPy drain despite ADR-0025 direction fix; marked xfail(run=False) pending Phase B diagnosis (suspected per-rank kernel_args / program_id mismatch) Tests: - test_ccl_ddp_launcher.py (6 new tests) — D1/D9/D10/D11/D12/D13 - test_ccl_allreduce_matrix.py — ring_default_ws xfail'd, override cases (ring_tcm_8 / hbm_8 / sram_8 / multi_cube / mesh_2x2 / tree_binary_7) all pass via legacy path 514 tests pass, 1 xfail. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
782 lines
30 KiB
Python
782 lines
30 KiB
Python
# kernbench/runtime_api/context.py
|
||
from __future__ import annotations
|
||
|
||
from dataclasses import dataclass, field
|
||
from typing import Any
|
||
|
||
from kernbench.common.types import Completion, RequestHandle, SimEngine
|
||
|
||
from .types import DeviceSelector
|
||
|
||
|
||
def _world_size_from_spec(spec: dict | None) -> int:
    """Return the default world size implied by a topology spec.

    The result is ``sips * cubes_per_sip * pes_per_cube`` where

    * ``sips`` is ``spec["system"]["sips"]["count"]`` (default 1);
    * ``cubes_per_sip`` is ``spec["sip"]["cube_mesh"]["w"] * ["h"]``
      (each default 1);
    * ``pes_per_cube`` is ``spec["cube"]["pe_layout"]["pe_per_corner"]``
      (default 1) times the number of ``corners`` (a missing or empty
      corner list counts as 1).

    A ``None`` spec is treated like an empty one and yields 1.
    """
    topo = spec or {}
    sip_count = int(topo.get("system", {}).get("sips", {}).get("count", 1))

    mesh = topo.get("sip", {}).get("cube_mesh", {})
    mesh_w = int(mesh.get("w", 1))
    mesh_h = int(mesh.get("h", 1))

    layout = topo.get("cube", {}).get("pe_layout", {})
    corner_count = max(len(layout.get("corners", [])), 1)
    per_corner = int(layout.get("pe_per_corner", 1))

    return sip_count * (mesh_w * mesh_h) * (per_corner * corner_count)
|
||
|
||
|
||
def _numpy_to_dtype_str(np_dtype) -> str:
    """Translate a numpy dtype (or anything ``np.dtype`` accepts) into the
    kernbench dtype string used by ``Tensor``.

    Supported: float16/float32, int8/int16/int32, uint8/uint16/uint32.

    Raises:
        ValueError: if the dtype has no kernbench equivalent.
    """
    import numpy as np

    # Normalise once; the candidates below are compared against this value.
    wanted = np.dtype(np_dtype)
    table = (
        (np.float16, "f16"),
        (np.float32, "f32"),
        (np.int8, "i8"),
        (np.int16, "i16"),
        (np.int32, "i32"),
        (np.uint8, "u8"),
        (np.uint16, "u16"),
        (np.uint32, "u32"),
    )
    for candidate, label in table:
        if wanted == np.dtype(candidate):
            return label
    raise ValueError(f"unsupported numpy dtype: {np_dtype!r}")
|
||
|
||
|
||
class _AhbmNamespace:
    """``torch.ahbm`` — per-greenlet SIP device binding (ADR-0024 D10).

    Mirrors the real-PyTorch idiom ``torch.cuda.set_device(rank)``. The
    backend is named 'ahbm' rather than CUDA so this namespace does not
    pretend to be a CUDA runtime. Each greenlet only sees the device it
    bound itself; a greenlet that never called ``set_device`` sees ``None``.
    """

    def __init__(self) -> None:
        # greenlet object -> device index bound from inside that greenlet
        self._device_by_greenlet: dict = {}

    @staticmethod
    def _calling_greenlet():
        """Return the running greenlet (greenlet is imported lazily)."""
        from greenlet import getcurrent

        return getcurrent()

    def set_device(self, device: int) -> None:
        """Bind ``int(device)`` to the calling greenlet, replacing any earlier binding."""
        me = self._calling_greenlet()
        self._device_by_greenlet[me] = int(device)

    def current_device(self) -> int | None:
        """Return the calling greenlet's bound device, or ``None`` if unset."""
        return self._device_by_greenlet.get(self._calling_greenlet())
|
||
|
||
|
||
class _AcceleratorNamespace:
    """``torch.accelerator`` — device-agnostic alias (PyTorch 2.x style).

    Every call is forwarded to the wrapped ``_AhbmNamespace``, so bench code
    may use either spelling::

        torch.ahbm.set_device(rank)               # explicit backend
        torch.accelerator.set_device_index(rank)  # portable
    """

    def __init__(self, ahbm: "_AhbmNamespace") -> None:
        self._ahbm = ahbm

    def set_device_index(self, device: int) -> None:
        """Forward to the wrapped namespace's ``set_device``."""
        backend = self._ahbm
        backend.set_device(device)

    def current_device_index(self) -> int | None:
        """Forward to the wrapped namespace's ``current_device``."""
        backend = self._ahbm
        return backend.current_device()
|
||
|
||
|
||
@dataclass
class RuntimeContext:
    """Host-side runtime for one bench: request plumbing, tensors, kernels.

    Wraps a ``SimEngine`` and offers:

    * ``submit`` / ``wait`` / ``wait_all`` — request submission and
      completion, optionally recording a trace entry per wait;
    * ``zeros`` / ``empty`` / ``from_numpy`` — PyTorch-like tensor creation
      backed by per-PE allocators and one VA allocator;
    * ``launch`` — per-SIP kernel launch;
    * ``install_ipcq`` — IPCQ neighbor-table install for CCL benches;
    * ``distributed``, ``ahbm`` and ``accelerator`` namespaces, attached in
      ``__post_init__``.

    Used as a context manager, ``__exit__`` frees every tensor this context
    created (see ``cleanup``).
    """

    engine: SimEngine
    target_device: DeviceSelector
    correlation_id: str
    spec: dict | None = None

    _handles: list[RequestHandle] = field(default_factory=list, init=False)
    _completed: set[RequestHandle] = field(default_factory=set, init=False)
    # flat PE index -> PEMemAllocator (layout documented in _ensure_allocators)
    _allocators: dict[int, Any] = field(default_factory=dict, init=False)
    _va_allocator: Any = field(default=None, init=False)
    _tensor_counter: int = field(default=0, init=False)
    # trace entries recorded by wait(..., _meta=...)
    _traces: list[dict] = field(default_factory=list, init=False)
    # weak references to tensors created by this context (freed by cleanup())
    _tensors: list[Any] = field(default_factory=list, init=False)
    distributed: Any = field(default=None, init=False)  # DistributedContext for CCL benches
    _ipcq_plan: dict = field(default_factory=dict, init=False)  # ADR-0023 install plan

    def __post_init__(self) -> None:
        # Eagerly attach a DistributedContext so bench code can do
        # ``dist = torch.distributed`` + ``dist.init_process_group(...)``
        # without needing a separate launcher to install it.
        from kernbench.runtime_api.distributed import DistributedContext

        dc = DistributedContext()
        dc._ctx_ref = self  # back-reference for AhbmCCLBackend to reach ctx.launch etc.
        self.distributed = dc
        # ADR-0024 D10: torch.ahbm (KernBench-native) + torch.accelerator
        # (PyTorch 2.x portable) namespaces for per-greenlet device binding.
        self.ahbm = _AhbmNamespace()
        self.accelerator = _AcceleratorNamespace(self.ahbm)

    def install_ipcq(
        self,
        algorithm: str | None = None,
        ccl_yaml: str | None = None,
        world_size_override: int | None = None,
        rank_to_pe: list[tuple[int, int, int]] | None = None,
    ) -> dict:
        """Install IPCQ neighbor tables on all participating PEs (ADR-0023 D10).

        Loads ``ccl.yaml`` (or the path provided), resolves the chosen
        algorithm (or ``defaults.algorithm`` if None), and pushes per-PE
        IpcqInitMsg into every PE_IPCQ component via the engine.

        Args:
            algorithm: name of the algorithm in ccl.yaml (or use defaults).
            ccl_yaml: optional path to ccl.yaml.
            world_size_override: if set, replace the algorithm's world_size.
                Otherwise the config's ``world_size`` is kept, falling back
                to ``_world_size_from_spec(self.spec)`` when absent.
            rank_to_pe: optional explicit rank -> (sip, cube, pe) mapping,
                forwarded unchanged to ``kernbench.ccl.install.install_ipcq``.

        Returns the install plan dict (rank → (sip,cube,pe), neighbor table).
        The plan is also stored on ``self._ipcq_plan`` and the merged
        algorithm config on ``self._ipcq_config``.

        Raises:
            ModuleNotFoundError: if the algorithm module exists but one of
                its own imports is missing. An absent algorithm module (or
                parent package) is tolerated and the install proceeds with
                ``algo_module=None``.
        """
        import importlib

        from kernbench.ccl.install import (
            install_ipcq as _install,
            load_ccl_config,
            resolve_algorithm_config,
        )

        cfg = load_ccl_config(ccl_yaml)
        merged = resolve_algorithm_config(cfg, algorithm)
        if world_size_override is not None:
            merged["world_size"] = world_size_override
        elif "world_size" not in merged:
            # Derive from topology.yaml when neither the algorithm entry
            # nor ``defaults`` carries ``world_size`` (matches pytorch DDP
            # where env vars determine ranks, not the ccl config file).
            merged["world_size"] = _world_size_from_spec(self.spec)

        module_name = merged["module"]
        algo_module = None
        try:
            algo_module = importlib.import_module(module_name)
        except ModuleNotFoundError as exc:
            # Only tolerate the algorithm module itself (or a parent
            # package) being absent. A missing dependency imported *inside*
            # an existing algorithm module is a genuine error and must not
            # be silently turned into "no algorithm module".
            missing = exc.name
            if missing is not None and not (
                module_name == missing or module_name.startswith(missing + ".")
            ):
                raise

        plan = _install(
            self.engine, self.spec, merged,
            algo_module=algo_module, rank_to_pe=rank_to_pe,
        )
        self._ipcq_plan = plan
        self._ipcq_config = merged
        return plan

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        self.cleanup()
        return False

    def submit(self, request: Any) -> RequestHandle:
        """Submit ``request`` to the engine and remember its handle.

        Raises:
            AttributeError: if the engine has no ``submit`` method.
        """
        submit_fn = getattr(self.engine, "submit", None)
        if submit_fn is None:
            raise AttributeError("Engine does not implement submit(request) -> RequestHandle.")
        handle: RequestHandle = submit_fn(request)  # type: ignore[call-arg]
        self._handles.append(handle)
        return handle

    def is_completed(self, handle: RequestHandle) -> bool:
        """Return True once ``wait`` has completed ``handle`` on this context."""
        return handle in self._completed

    def wait(self, handle: RequestHandle, *, _meta: dict | None = None) -> Completion:
        """Wait for ``handle`` and return its Completion.

        Calls the engine's ``wait`` (when it has one) and then
        ``get_completion``. On the first wait, if ``_meta`` is given and the
        engine returned a trace, a trace entry (the trace dict, or
        ``{"raw": trace}`` for non-dict traces, updated with ``_meta``) is
        appended to ``self._traces``. Waiting again on a completed handle
        just re-reads the completion without recording a trace.
        """
        if handle in self._completed:
            completion, _ = self.engine.get_completion(handle)
            return completion

        wait_fn = getattr(self.engine, "wait", None)
        if wait_fn is not None:
            wait_fn(handle)  # type: ignore[misc]

        completion, trace = self.engine.get_completion(handle)
        self._completed.add(handle)
        if _meta is not None and trace is not None:
            entry = dict(trace) if isinstance(trace, dict) else {"raw": trace}
            entry.update(_meta)
            self._traces.append(entry)
        return completion

    def wait_all(self) -> None:
        """Wait for every submitted handle that has not completed yet."""
        for h in self._handles:
            if h not in self._completed:
                self.wait(h)

    def handles(self) -> list[RequestHandle]:
        """Return a copy of all handles submitted through this context."""
        return list(self._handles)

    # ── Tensor lifecycle ─────────────────────────────────────────────

    def _free_tensor(self, tensor: Any) -> None:
        """Free a single tensor: unmap MMU, return VA and PA.

        The tensor's handle is detached first, so freeing twice is a no-op.
        MMU unmap and VA release only happen when the tensor owns a VA
        range (``handle.va_base`` truthy). Physical shards are returned to
        their PE allocators in either case, so VA-less tensors do not leak
        HBM.
        """
        handle = tensor._handle
        if handle is None:
            return
        tensor._handle = None

        if handle.va_base:
            self._unmap_tensor(tensor, handle)
            # Return VA space
            if self._va_allocator is not None:
                self._va_allocator.free(handle.va_base, handle.nbytes)

        # Return PA space
        self._release_pa(handle)

    def _unmap_tensor(self, tensor: Any, handle: Any) -> None:
        """Submit and wait for MmuUnmapMsg(s) covering every shard of ``handle``."""
        from collections import defaultdict

        from kernbench.runtime_api.kernel import MmuUnmapMsg

        dp_meta = tensor._dp_metadata
        dp_policy = dp_meta.dp_policy if dp_meta is not None else None
        is_cube_replicate = dp_policy is not None and dp_policy.cube == "replicate"

        if is_cube_replicate:
            # One unmap per (sip, cube) that holds shards.
            cube_groups: dict[tuple[int, int], list] = defaultdict(list)
            for shard in handle.shards:
                cube_groups[(shard.sip, shard.cube)].append(shard)
            for (sip, cube), group_shards in cube_groups.items():
                entries = tuple(
                    {"va": handle.va_base + s.offset_bytes, "size": s.nbytes}
                    for s in group_shards
                )
                msg = MmuUnmapMsg(
                    correlation_id=self.correlation_id,
                    request_id=f"unmap_{tensor.name}_s{sip}c{cube}",
                    entries=entries,
                    target_sips=(sip,),
                    target_cubes=(cube,),
                    target_pe="all",
                )
                self.wait(self.submit(msg))
            return

        # NOTE(review): unlike the SIP-scoped MmuMapMsg in _create_tensor,
        # this single message sends every shard's entry to every SIP/cube
        # holding any shard — confirm the MMU tolerates entries it never
        # mapped on a given SIP.
        entries = tuple(
            {"va": handle.va_base + s.offset_bytes, "size": s.nbytes}
            for s in handle.shards
        )
        sip_set = sorted({s.sip for s in handle.shards})
        cube_set = sorted({s.cube for s in handle.shards})
        msg = MmuUnmapMsg(
            correlation_id=self.correlation_id,
            request_id=f"unmap_{tensor.name}",
            entries=entries,
            target_sips=tuple(sip_set),
            target_cubes=tuple(cube_set),
            target_pe="all",
        )
        self.wait(self.submit(msg))

    def _release_pa(self, handle: Any) -> None:
        """Return each shard's HBM range to its PE allocator, if one exists."""
        if not self._allocators:
            return
        from kernbench.policy.address.phyaddr import PhysAddr

        # Same flat-index layout as _ensure_allocators.
        pes_per_sip = self._num_cubes * self._pes_per_cube
        for shard in handle.shards:
            flat_idx = shard.sip * pes_per_sip + shard.cube * self._pes_per_cube + shard.pe
            alloc = self._allocators.get(flat_idx)
            if alloc is not None:
                alloc.free_hbm(PhysAddr.decode(shard.pa), shard.nbytes)

    def cleanup(self) -> None:
        """Free all tensors created by this context."""
        for ref in self._tensors:
            t = ref()
            if t is not None and t._handle is not None:
                self._free_tensor(t)
        self._tensors.clear()

    # ── PyTorch-like tensor API ──────────────────────────────────────

    def _ensure_allocators(self) -> dict:
        """Lazily create PEMemAllocator instances from spec.

        Builds one ``AddressConfig`` from ``self.spec`` (defaults: 16 cubes
        per SIP, 2 PEs per corner x 4 corners, 48 GB HBM / 8 slices per
        cube, 16 MB TCM per PE with 4 MB scheduler-reserved, 32 MB SRAM per
        cube). It then creates one allocator per (sip, cube, pe) of the SIP
        range selected by ``target_device`` (all SIPs, or just the selected
        one), keyed by the flat index
        ``sip * cubes_per_sip * pes_per_cube + cube * pes_per_cube + pe``
        using the absolute SIP id. It also records ``_pes_per_cube``,
        ``_num_cubes`` and ``_num_sips`` and creates the VA allocator. Later
        calls return the cached dict.

        Raises:
            RuntimeError: if ``self.spec`` is None.
        """
        if self._allocators:
            return self._allocators
        if self.spec is None:
            raise RuntimeError(
                "RuntimeContext.spec is required for tensor operations. "
                "Pass spec=graph.spec when creating RuntimeContext."
            )
        from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator

        system = self.spec.get("system", {})
        cube = self.spec.get("cube", {})
        mm = cube.get("memory_map", {})
        pe_template = cube.get("pe_template", {})
        pe_comps = pe_template.get("components", {})
        tcm_cfg = pe_comps.get("pe_tcm", {}).get("attrs", {})

        total_sip_count = system.get("sips", {}).get("count", 1)
        # NOTE(review): cubes_per_sip comes from system.sips.cubes_per_sip
        # here, while _world_size_from_spec derives it from sip.cube_mesh
        # (and uses different PE defaults) — confirm both agree for real
        # topologies.
        cubes_per_sip = system.get("sips", {}).get("cubes_per_sip", 16)
        pes_per_cube = (
            cube.get("pe_layout", {}).get("pe_per_corner", 2)
            * len(cube.get("pe_layout", {}).get("corners", ["NW", "NE", "SW", "SE"]))
        )
        hbm_gb = mm.get("hbm_total_gb_per_cube", 48)
        hbm_slices = mm.get("hbm_slices_per_cube", 8)
        tcm_mb = tcm_cfg.get("size_mb", 16)

        # Scope to target_device: single SIP or all SIPs
        from kernbench.runtime_api.types import DeviceSelector, resolve_device

        td = self.target_device if isinstance(self.target_device, DeviceSelector) else resolve_device(str(self.target_device))
        if td.is_all:
            sip_range = range(total_sip_count)
            sip_count = total_sip_count
        else:
            sip_idx = td.sip_index
            sip_range = range(sip_idx, sip_idx + 1)
            sip_count = 1

        cfg = AddressConfig(
            sip_count=sip_count,
            cubes_per_sip=cubes_per_sip,
            pes_per_cube=pes_per_cube,
            hbm_bytes_per_cube=hbm_gb * (1 << 30),
            hbm_slices_per_cube=hbm_slices,
            tcm_bytes_per_pe=tcm_mb * (1 << 20),
            tcm_scheduler_reserved_bytes=4 * (1 << 20),
            sram_bytes_per_cube=32 * (1 << 20),
        )
        # Create allocators scoped to target SIP(s) only
        # Flat index: sip_id * cubes_per_sip * pes_per_cube + cube_id * pes_per_cube + pe_id
        self._pes_per_cube = pes_per_cube
        self._num_cubes = cubes_per_sip
        self._num_sips = sip_count
        cubes_x_pes = cubes_per_sip * pes_per_cube
        for sip_id in sip_range:
            for cube_id in range(cubes_per_sip):
                for pe_id in range(pes_per_cube):
                    flat_idx = sip_id * cubes_x_pes + cube_id * pes_per_cube + pe_id
                    self._allocators[flat_idx] = PEMemAllocator(
                        rack_id=0, sip_id=sip_id, cube_id=cube_id, pe_id=pe_id, cfg=cfg,
                    )

        # Initialize VA allocator (MMU mappings are installed via fabric MmuMapMsg)
        from kernbench.policy.address.va_allocator import VirtualAllocator

        pe_mmu_attrs = pe_comps.get("pe_mmu", {}).get("attrs", {})
        page_size = int(pe_mmu_attrs.get("page_size", 4096))

        self._va_allocator = VirtualAllocator(
            va_base=0x1_0000_0000,
            va_size=64 * (1 << 30),  # 64 GB VA space
            page_size=page_size,
        )

        return self._allocators

    def _next_tensor_name(self) -> str:
        """Return a fresh auto-generated tensor name ``t1``, ``t2``, ..."""
        self._tensor_counter += 1
        return f"t{self._tensor_counter}"

    def zeros(
        self,
        shape: tuple[int, ...],
        dtype: str = "f16",
        *,
        dp: Any = None,
        name: str | None = None,
    ):
        """Create a tensor and deploy to HBM with zero-fill (like torch.zeros).

        ``dp`` must be a DPPolicy (ValueError otherwise).
        """
        return self._create_tensor(shape, dtype, name, pattern="zero", dp=dp)

    def empty(
        self,
        shape: tuple[int, ...],
        dtype: str = "f16",
        *,
        dp: Any = None,
        name: str | None = None,
    ):
        """Allocate a tensor in HBM without initialization (like torch.empty).

        ``dp`` must be a DPPolicy (ValueError otherwise).
        """
        return self._create_tensor(shape, dtype, name, pattern=None, dp=dp)

    def from_numpy(self, arr: Any):
        """Create a host-side tensor wrapping a numpy array.

        Mirrors ``torch.from_numpy``. The returned tensor is NOT deployed
        to any PE — it lives in an in-memory host staging buffer. Use
        ``target.copy_(host_tensor)`` to scatter its contents into a
        sharded, deployed tensor.
        """
        import numpy as np

        from kernbench.runtime_api.tensor import Tensor

        arr_c = np.ascontiguousarray(arr)
        dtype_str = _numpy_to_dtype_str(arr_c.dtype)
        t = Tensor(shape=tuple(arr_c.shape), dtype=dtype_str, name="host")
        t._host_buffer = arr_c
        t._memory_store = getattr(self.engine, "_memory_store", None)
        return t

    def _create_tensor(
        self,
        shape: tuple[int, ...],
        dtype: str,
        name: str | None,
        pattern: str | None,
        dp: Any = None,
    ):
        """Allocate, map and optionally initialise a sharded device tensor.

        Steps: resolve ``dp`` into a shard placement, allocate PA/VA via
        ``deploy_tensor``, install VA→PA mappings with MmuMapMsg (when a VA
        was assigned), then submit one MemoryWriteMsg per shard when
        ``pattern`` is not None. The tensor is tracked by weak reference so
        ``cleanup`` can free it.

        Raises:
            ValueError: if ``dp`` is not a DPPolicy.
        """
        import weakref
        from collections import defaultdict

        from kernbench.policy.placement.dp import DPPolicy, resolve_dp_policy
        from kernbench.runtime_api.kernel import MemoryWriteMsg, MmuMapMsg
        from kernbench.runtime_api.tensor import Tensor, deploy_tensor, dtype_itemsize

        if not isinstance(dp, DPPolicy):
            raise ValueError("dp=DPPolicy(...) is required for tensor creation")

        tensor_name = name or self._next_tensor_name()
        t = Tensor(shape=shape, dtype=dtype, name=tensor_name)

        # Also populates _pes_per_cube / _num_cubes / _num_sips used below.
        allocators = self._ensure_allocators()
        itemsize = dtype_itemsize(dtype)
        shape_2d = (shape[0], shape[1]) if len(shape) >= 2 else (1, shape[0])
        # DPPolicy overrides take precedence over topology dimensions
        eff_num_pe = dp.num_pes if dp.num_pes is not None else self._pes_per_cube
        eff_num_cubes = dp.num_cubes if dp.num_cubes is not None else self._num_cubes
        # ADR-0024 D11: if torch.ahbm.set_device(r) is active AND DPPolicy
        # leaves the SIP dimension at its default (replicate + no num_sips
        # override), scope the tensor to SIP r only.
        # NOTE: this path uses post-hoc pe_index shifting as a temporary
        # measure; ADR-0026 replaces it with structural (sip, cube, pe)
        # coords in ShardSpec.
        current_sip = (
            self.ahbm.current_device() if hasattr(self, "ahbm") else None
        )
        scope_to_current_sip = (
            current_sip is not None
            and dp.sip == "replicate"
            and dp.num_sips is None
        )
        if scope_to_current_sip:
            eff_num_sips = 1
        else:
            eff_num_sips = (
                dp.num_sips if dp.num_sips is not None else self._num_sips
            )
        placement = resolve_dp_policy(
            dp, shape=shape_2d, itemsize=itemsize,
            num_pe=eff_num_pe, num_cubes=eff_num_cubes,
            num_sips=eff_num_sips,
        )
        if scope_to_current_sip:
            from kernbench.policy.placement.dp import ShardSpec as _SS

            # Shift SIP-0-relative flat PE indices onto the current SIP.
            sip_stride = self._num_cubes * self._pes_per_cube
            offset = int(current_sip) * sip_stride
            placement = [
                _SS(pe_index=s.pe_index + offset,
                    offset_bytes=s.offset_bytes, nbytes=s.nbytes)
                for s in placement
            ]

        # Infer target_pe from placement using local (within-cube) PE IDs.
        # This ensures M_CPU only fans out to PEs that own shards, not all PEs.
        local_pe_ids = sorted({s.pe_index % eff_num_pe for s in placement})
        if len(local_pe_ids) == 1:
            target_pe: int | tuple[int, ...] | str = local_pe_ids[0]
        elif len(local_pe_ids) == eff_num_pe and eff_num_pe == self._pes_per_cube:
            target_pe = "all"
        else:
            target_pe = tuple(local_pe_ids)
        t.to(placement=placement, target_pe=target_pe, dp_policy=dp)

        # Allocate PAs via PEMemAllocator + VA via VirtualAllocator
        handle = deploy_tensor(
            name=tensor_name,
            shape=shape,
            dtype=dtype,
            placement=placement,
            allocators=allocators,
            va_allocator=self._va_allocator,
        )
        t._handle = handle
        t._ctx_ref = weakref.ref(self)
        t._memory_store = getattr(self.engine, "_memory_store", None)
        self._tensors.append(weakref.ref(t))

        # Install VA→PA mappings via fabric MmuMapMsg
        # Strategy: always SIP-scoped (each SIP gets only its own shards).
        # Within each SIP: cube="replicate" → per-cube, else broadcast within SIP.
        if handle.va_base:
            is_cube_replicate = dp.cube == "replicate"

            # Group shards by SIP
            sip_groups: dict[int, list] = defaultdict(list)
            for shard in handle.shards:
                sip_groups[shard.sip].append(shard)

            for sip, sip_shards in sip_groups.items():
                if is_cube_replicate:
                    # Cube replicate: per-(sip, cube) local mapping
                    cube_groups: dict[int, list] = defaultdict(list)
                    for s in sip_shards:
                        cube_groups[s.cube].append(s)

                    for cube, group_shards in cube_groups.items():
                        entries = tuple(
                            {"va": handle.va_base + s.offset_bytes,
                             "pa": s.pa, "size": s.nbytes}
                            for s in group_shards
                        )
                        msg = MmuMapMsg(
                            correlation_id=self.correlation_id,
                            request_id=f"mmu_{tensor_name}_s{sip}c{cube}",
                            entries=entries,
                            target_sips=(sip,),
                            target_cubes=(cube,),
                            target_pe="all",
                        )
                        self.wait(self.submit(msg))
                else:
                    # Cube sharded: broadcast all cubes within this SIP
                    entries = tuple(
                        {"va": handle.va_base + s.offset_bytes,
                         "pa": s.pa, "size": s.nbytes}
                        for s in sip_shards
                    )
                    cube_set = sorted({s.cube for s in sip_shards})
                    msg = MmuMapMsg(
                        correlation_id=self.correlation_id,
                        request_id=f"mmu_{tensor_name}_s{sip}",
                        entries=entries,
                        target_sips=(sip,),
                        target_cubes=tuple(cube_set),
                        target_pe="all",
                    )
                    self.wait(self.submit(msg))

        # Submit MemoryWriteMsg per shard (deploy data to device)
        if pattern is not None:
            for shard in handle.shards:
                h = self.submit(MemoryWriteMsg(
                    correlation_id=self.correlation_id,
                    request_id=f"deploy_{tensor_name}_pe{shard.pe}",
                    dst_sip=shard.sip, dst_cube=shard.cube, dst_pe=shard.pe,
                    dst_pa=shard.pa, nbytes=shard.nbytes, pattern=pattern,
                    target_cubes=(shard.cube,), target_pe=shard.pe,
                ))
                self.wait(h, _meta={
                    "phase": "memory_write", "name": tensor_name,
                    "sip": shard.sip, "cube": shard.cube, "pe": shard.pe,
                    "nbytes": shard.nbytes,
                })

        return t

    def launch(
        self,
        kernel_name: str,
        kernel_fn: Any,
        *args: Any,
        _defer_wait: bool = False,
        **kwargs: Any,
    ) -> RequestHandle | list[tuple[RequestHandle, int, dict]]:
        """Register and launch a kernel (like a fused torch op).

        Positional args: Tensor objects become TensorArg, int/float become ScalarArg.
        Keyword args: become ScalarArg (name is discarded, order preserved).

        Creates per-SIP KernelLaunchMsg with local va_base per tensor
        (like host driver sending per-rank launch commands). Integer
        scalars equal to a tensor's global dimension are replaced by the
        corresponding cube-local dimension.

        When ``_defer_wait=True`` (ADR-0024 D7), returns the list of
        ``(handle, sip_id, meta)`` tuples instead of waiting. Caller is
        responsible for waiting — used by collective ops to yield between
        submit and wait so all sibling ranks can submit first. Otherwise
        waits for every per-SIP launch and returns the last handle.

        Raises:
            ValueError: if a Tensor argument is not deployed to the device
                (has no handle, e.g. a ``from_numpy`` host tensor).
        """
        from kernbench.runtime_api.kernel import (
            KernelLaunchMsg,
            KernelRef,
            ScalarArg,
            TensorArg,
            TensorArgShard,
        )
        from kernbench.runtime_api.tensor import Tensor
        from kernbench.triton_emu.registry import _kernels

        # Register kernel (idempotent overwrite — last call wins).
        # Tests can re-register the same kernel_name with a different
        # function; the user's most recent launch must use the latest fn.
        _kernels[kernel_name] = kernel_fn

        # Collect tensors and scalars
        tensor_args: list[Tensor] = []
        scalar_args: list = []
        _pe_set: set[int] = set()
        _pe_all = False

        for a in args:
            if isinstance(a, Tensor):
                # Every Tensor argument needs a device-side TensorArg; an
                # undeployed one would otherwise misalign the positional
                # interleave below and crash with an opaque IndexError.
                if a._handle is None:
                    raise ValueError(
                        f"launch({kernel_name!r}): tensor argument {a.name!r} "
                        "is not deployed to the device; create it with "
                        "zeros()/empty() (host tensors must be copy_()'d "
                        "into a deployed tensor first)"
                    )
                tensor_args.append(a)
                if a._dp_metadata is not None:
                    dp_target = a._dp_metadata.target_pe
                    if dp_target == "all":
                        _pe_all = True
                    elif isinstance(dp_target, tuple):
                        _pe_set.update(dp_target)
                    elif isinstance(dp_target, int):
                        _pe_set.add(dp_target)
            elif isinstance(a, (int, float)):
                dtype_str = "f32" if isinstance(a, float) else "i32"
                scalar_args.append(ScalarArg(dtype=dtype_str, value=a))

        for v in kwargs.values():
            if isinstance(v, (int, float)):
                dtype_str = "f32" if isinstance(v, float) else "i32"
                scalar_args.append(ScalarArg(dtype=dtype_str, value=v))

        # Resolve target_pe from collected PE info
        if _pe_all:
            target_pe: int | tuple[int, ...] | str = "all"
        elif len(_pe_set) == 1:
            target_pe = next(iter(_pe_set))
        elif len(_pe_set) > 1:
            target_pe = tuple(sorted(_pe_set))
        else:
            target_pe = 0

        # Determine all target SIPs from tensor shards
        sip_set: set[int] = {s.sip for t in tensor_args for s in t._handle.shards}
        if not sip_set:
            sip_set = {0}

        # Build global→local dimension mapping from tensor DPPolicies.
        # Scalar args matching a tensor's global dimension get replaced
        # with the cube-local value (what the kernel actually operates on).
        def _compute_local_shape(t: Tensor) -> tuple[int, ...]:
            """Compute cube-local shape from DPPolicy."""
            shape = t.shape
            if len(shape) < 2:
                shape = (1, shape[0])
            M, K = shape[0], shape[1]
            dp = t._dp_metadata.dp_policy if t._dp_metadata else None
            if dp is None:
                return t.shape
            if dp.sip != "replicate":
                if dp.sip == "column_wise":
                    K = K // self._num_sips
                elif dp.sip == "row_wise":
                    M = M // self._num_sips
            if dp.cube != "replicate":
                if dp.cube == "column_wise":
                    K = K // self._num_cubes
                elif dp.cube == "row_wise":
                    M = M // self._num_cubes
            if len(t.shape) < 2:
                return (K,)
            return (M, K)

        dim_map: dict[int, int] = {}  # global_dim → local_dim
        for t in tensor_args:
            local = _compute_local_shape(t)
            for g, l in zip(t.shape if len(t.shape) >= 2 else (1, t.shape[0]), local if len(local) >= 2 else (1, local[0])):
                if g != l:
                    dim_map[g] = l

        # Per-SIP kernel launch: each SIP gets TensorArgs with local va_base
        last_handle = None
        _pending_handles: list[tuple[Any, int]] = []
        for sip_id in sorted(sip_set):
            sip_kernel_args: list = []
            sip_cube_set: set[int] = set()

            for t in tensor_args:
                sip_shards = [s for s in t._handle.shards if s.sip == sip_id]
                if not sip_shards:
                    # Tensor has no shard on this SIP: pass all of its shards.
                    sip_shards = list(t._handle.shards)

                local_va_base = 0
                if t._handle.va_base:
                    min_offset = min(s.offset_bytes for s in sip_shards)
                    local_va_base = t._handle.va_base + min_offset

                sip_kernel_args.append(TensorArg(
                    shards=tuple(
                        TensorArgShard(
                            sip=s.sip, cube=s.cube, pe=s.pe,
                            pa=s.pa, nbytes=s.nbytes, offset_bytes=s.offset_bytes,
                        )
                        for s in sip_shards
                    ),
                    va_base=local_va_base,
                ))

                for s in sip_shards:
                    sip_cube_set.add(s.cube)

            # Interleave tensor args and scalar args, replacing global dims with local
            final_args: list = []
            t_idx, s_idx = 0, 0
            for a in args:
                if isinstance(a, Tensor):
                    final_args.append(sip_kernel_args[t_idx])
                    t_idx += 1
                elif isinstance(a, (int, float)):
                    sa = scalar_args[s_idx]
                    if isinstance(a, int) and a in dim_map:
                        sa = ScalarArg(dtype=sa.dtype, value=dim_map[a])
                    final_args.append(sa)
                    s_idx += 1
            # Remaining scalars come from kwargs, in insertion order.
            while s_idx < len(scalar_args):
                sa = scalar_args[s_idx]
                if isinstance(sa.value, int) and int(sa.value) in dim_map:
                    sa = ScalarArg(dtype=sa.dtype, value=dim_map[int(sa.value)])
                final_args.append(sa)
                s_idx += 1

            target_cubes = tuple(sorted(sip_cube_set)) if sip_cube_set else (0,)

            h = self.submit(KernelLaunchMsg(
                correlation_id=self.correlation_id,
                request_id=f"{kernel_name}_sip{sip_id}",
                kernel_ref=KernelRef(name=kernel_name, kind="builtin"),
                args=tuple(final_args),
                target_cubes=target_cubes,
                target_pe=target_pe,
            ))
            # Defer wait until all SIPs are submitted (multi-SIP CCL needs
            # all participating PEs to be live concurrently — waiting
            # per-SIP would deadlock when ranks span SIP boundaries).
            _pending_handles.append((h, sip_id))
            last_handle = h

        if _defer_wait:
            # ADR-0024 D7: return the pending-list so the caller can yield
            # between submit and drain. Used by collective ops that need
            # all sibling ranks to submit before any rank waits.
            return [
                (h, sip_id, {
                    "phase": "kernel", "name": kernel_name,
                    "sip": sip_id, "target_pe": target_pe,
                })
                for h, sip_id in _pending_handles
            ]

        # Drain pending handles now that every SIP has a launch posted.
        for h, sip_id in _pending_handles:
            self.wait(h, _meta={
                "phase": "kernel", "name": kernel_name,
                "sip": sip_id, "target_pe": target_pe,
            })

        return last_handle
|