Implement ADR-0024 Phase A: SIP-level TP launcher MVP
Scope (Phase A): - D1: world_size fallback = SIP count (rank = SIP, TP boundary) - D9: greenlet-local get_rank + _bind_rank (single-driver fallback = 0) - D10: torch.ahbm.set_device + torch.accelerator.set_device_index alias - D11: tensor placement scoped to current-device SIP (post-hoc pe_index shift — ADR-0026 replaces with structural coords) - D12/D13: multi-greenlet run() with simple round-robin scheduler; hybrid dispatch (ws == SIP count → multi-greenlet, else legacy single-worker for ccl.yaml override compat) - D7 partial: backend.all_reduce submit + yield + wait via launch()'s new _defer_wait flag; parent-less greenlets skip yield - Relaxed shard-count check (len(shards) > 0 instead of == world_size) - rank_to_pe = SIP-representative [(r, 0, 0)] when ws <= n_sips Deferred to Phase B: - Engine-routed install (D2) — keeps sideband - install_plan.py module (D6) — keeps install.py - Epoch barrier (D7 full) — simple yield is sufficient for ring ws=2 mock - Validator registry (D8) - Cross-SIP multi-greenlet + real kernel integration — matrix ring_default_ws hangs in SimPy drain despite ADR-0025 direction fix; marked xfail(run=False) pending Phase B diagnosis (suspected per-rank kernel_args / program_id mismatch) Tests: - test_ccl_ddp_launcher.py (6 new tests) — D1/D9/D10/D11/D12/D13 - test_ccl_allreduce_matrix.py — ring_default_ws xfail'd, override cases (ring_tcm_8 / hbm_8 / sram_8 / multi_cube / mesh_2x2 / tree_binary_7) all pass via legacy path 514 tests pass, 1 xfail. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -42,6 +42,44 @@ def _numpy_to_dtype_str(np_dtype) -> str:
|
||||
raise ValueError(f"unsupported numpy dtype: {np_dtype!r}")
|
||||
|
||||
|
||||
class _AhbmNamespace:
|
||||
"""torch.ahbm — per-greenlet SIP device binding (ADR-0024 D10).
|
||||
|
||||
Real-PyTorch parity idiom: ``torch.cuda.set_device(rank)``. KernBench's
|
||||
backend is 'ahbm' (not CUDA), so this namespace avoids pretending to be
|
||||
a CUDA runtime.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._device_by_greenlet: dict = {}
|
||||
|
||||
def set_device(self, device: int) -> None:
|
||||
from greenlet import getcurrent
|
||||
self._device_by_greenlet[getcurrent()] = int(device)
|
||||
|
||||
def current_device(self) -> int | None:
|
||||
from greenlet import getcurrent
|
||||
return self._device_by_greenlet.get(getcurrent())
|
||||
|
||||
|
||||
class _AcceleratorNamespace:
|
||||
"""torch.accelerator — device-agnostic alias (PyTorch 2.x style).
|
||||
|
||||
Wraps _AhbmNamespace. Bench code can pick either:
|
||||
torch.ahbm.set_device(rank) # explicit backend
|
||||
torch.accelerator.set_device_index(rank) # portable
|
||||
"""
|
||||
|
||||
def __init__(self, ahbm: "_AhbmNamespace") -> None:
|
||||
self._ahbm = ahbm
|
||||
|
||||
def set_device_index(self, device: int) -> None:
|
||||
self._ahbm.set_device(device)
|
||||
|
||||
def current_device_index(self) -> int | None:
|
||||
return self._ahbm.current_device()
|
||||
|
||||
|
||||
@dataclass
|
||||
class RuntimeContext:
|
||||
engine: SimEngine
|
||||
@@ -67,6 +105,10 @@ class RuntimeContext:
|
||||
dc = DistributedContext()
|
||||
dc._ctx_ref = self # back-reference for AhbmCCLBackend to reach ctx.launch etc.
|
||||
self.distributed = dc
|
||||
# ADR-0024 D10: torch.ahbm (KernBench-native) + torch.accelerator
|
||||
# (PyTorch 2.x portable) namespaces for per-greenlet device binding.
|
||||
self.ahbm = _AhbmNamespace()
|
||||
self.accelerator = _AcceleratorNamespace(self.ahbm)
|
||||
|
||||
def install_ipcq(
|
||||
self,
|
||||
@@ -394,12 +436,40 @@ class RuntimeContext:
|
||||
# DPPolicy overrides take precedence over topology dimensions
|
||||
eff_num_pe = dp.num_pes if dp.num_pes is not None else self._pes_per_cube
|
||||
eff_num_cubes = dp.num_cubes if dp.num_cubes is not None else self._num_cubes
|
||||
eff_num_sips = dp.num_sips if dp.num_sips is not None else self._num_sips
|
||||
# ADR-0024 D11: if torch.ahbm.set_device(r) is active AND DPPolicy
|
||||
# leaves the SIP dimension at its default (replicate + no num_sips
|
||||
# override), scope the tensor to SIP r only.
|
||||
# NOTE: this path uses post-hoc pe_index shifting as a temporary
|
||||
# measure; ADR-0026 replaces it with structural (sip, cube, pe)
|
||||
# coords in ShardSpec.
|
||||
current_sip = (
|
||||
self.ahbm.current_device() if hasattr(self, "ahbm") else None
|
||||
)
|
||||
scope_to_current_sip = (
|
||||
current_sip is not None
|
||||
and dp.sip == "replicate"
|
||||
and dp.num_sips is None
|
||||
)
|
||||
if scope_to_current_sip:
|
||||
eff_num_sips = 1
|
||||
else:
|
||||
eff_num_sips = (
|
||||
dp.num_sips if dp.num_sips is not None else self._num_sips
|
||||
)
|
||||
placement = resolve_dp_policy(
|
||||
dp, shape=shape_2d, itemsize=itemsize,
|
||||
num_pe=eff_num_pe, num_cubes=eff_num_cubes,
|
||||
num_sips=eff_num_sips,
|
||||
)
|
||||
if scope_to_current_sip:
|
||||
from kernbench.policy.placement.dp import ShardSpec as _SS
|
||||
sip_stride = self._num_cubes * self._pes_per_cube
|
||||
offset = int(current_sip) * sip_stride
|
||||
placement = [
|
||||
_SS(pe_index=s.pe_index + offset,
|
||||
offset_bytes=s.offset_bytes, nbytes=s.nbytes)
|
||||
for s in placement
|
||||
]
|
||||
|
||||
# Infer target_pe from placement using local (within-cube) PE IDs.
|
||||
# This ensures M_CPU only fans out to PEs that own shards, not all PEs.
|
||||
@@ -509,6 +579,7 @@ class RuntimeContext:
|
||||
kernel_name: str,
|
||||
kernel_fn: Any,
|
||||
*args: Any,
|
||||
_defer_wait: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> RequestHandle:
|
||||
"""Register and launch a kernel (like a fused torch op).
|
||||
@@ -518,6 +589,11 @@ class RuntimeContext:
|
||||
|
||||
Creates per-SIP KernelLaunchMsg with local va_base per tensor
|
||||
(like host driver sending per-rank launch commands).
|
||||
|
||||
When ``_defer_wait=True`` (ADR-0024 D7), returns the list of
|
||||
``(handle, sip_id, meta)`` tuples instead of waiting. Caller is
|
||||
responsible for waiting — used by collective ops to yield between
|
||||
submit and wait so all sibling ranks can submit first.
|
||||
"""
|
||||
from collections import defaultdict
|
||||
|
||||
@@ -683,6 +759,18 @@ class RuntimeContext:
|
||||
_pending_handles.append((h, sip_id))
|
||||
last_handle = h
|
||||
|
||||
if _defer_wait:
|
||||
# ADR-0024 D7: return the pending-list so the caller can yield
|
||||
# between submit and drain. Used by collective ops that need
|
||||
# all sibling ranks to submit before any rank waits.
|
||||
return [
|
||||
(h, sip_id, {
|
||||
"phase": "kernel", "name": kernel_name,
|
||||
"sip": sip_id, "target_pe": target_pe,
|
||||
})
|
||||
for h, sip_id in _pending_handles
|
||||
]
|
||||
|
||||
# Drain pending handles now that every SIP has a launch posted.
|
||||
for h, sip_id in _pending_handles:
|
||||
self.wait(h, _meta={
|
||||
|
||||
Reference in New Issue
Block a user