Implement ADR-0024 Phase A: SIP-level TP launcher MVP

Scope (Phase A):
- D1: world_size fallback = SIP count (rank = SIP, TP boundary)
- D9: greenlet-local get_rank + _bind_rank (single-driver fallback = 0)
- D10: torch.ahbm.set_device + torch.accelerator.set_device_index alias
- D11: tensor placement scoped to current-device SIP (post-hoc pe_index
  shift — ADR-0026 replaces with structural coords)
- D12/D13: multi-greenlet run() with simple round-robin scheduler;
  hybrid dispatch (ws == SIP count → multi-greenlet, else legacy
  single-worker for ccl.yaml override compat)
- D7 partial: backend.all_reduce submit + yield + wait via launch()'s
  new _defer_wait flag; parent-less greenlets skip yield
- Relaxed shard-count check (len(shards) > 0 instead of == world_size)
- rank_to_pe = SIP-representative [(r, 0, 0)] when ws <= n_sips

Deferred to Phase B:
- Engine-routed install (D2) — keeps sideband
- install_plan.py module (D6) — keeps install.py
- Epoch barrier (D7 full) — simple yield is sufficient for ring ws=2 mock
- Validator registry (D8)
- Cross-SIP multi-greenlet + real kernel integration — matrix
  ring_default_ws hangs in SimPy drain despite ADR-0025 direction fix;
  marked xfail(run=False) pending Phase B diagnosis (suspected per-rank
  kernel_args / program_id mismatch)

Tests:
- test_ccl_ddp_launcher.py (6 new tests) — D1/D9/D10/D11/D12/D13
- test_ccl_allreduce_matrix.py — ring_default_ws xfail'd, override
  cases (ring_tcm_8 / hbm_8 / sram_8 / multi_cube / mesh_2x2 /
  tree_binary_7) all pass via legacy path

514 tests pass, 1 xfail.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-04-14 09:00:28 -07:00
parent 32536daf2e
commit 4ba0a83e71
6 changed files with 491 additions and 71 deletions
+89 -1
View File
@@ -42,6 +42,44 @@ def _numpy_to_dtype_str(np_dtype) -> str:
raise ValueError(f"unsupported numpy dtype: {np_dtype!r}")
class _AhbmNamespace:
"""torch.ahbm — per-greenlet SIP device binding (ADR-0024 D10).
Real-PyTorch parity idiom: ``torch.cuda.set_device(rank)``. KernBench's
backend is 'ahbm' (not CUDA), so this namespace avoids pretending to be
a CUDA runtime.
"""
def __init__(self) -> None:
self._device_by_greenlet: dict = {}
def set_device(self, device: int) -> None:
from greenlet import getcurrent
self._device_by_greenlet[getcurrent()] = int(device)
def current_device(self) -> int | None:
from greenlet import getcurrent
return self._device_by_greenlet.get(getcurrent())
class _AcceleratorNamespace:
"""torch.accelerator — device-agnostic alias (PyTorch 2.x style).
Wraps _AhbmNamespace. Bench code can pick either:
torch.ahbm.set_device(rank) # explicit backend
torch.accelerator.set_device_index(rank) # portable
"""
def __init__(self, ahbm: "_AhbmNamespace") -> None:
self._ahbm = ahbm
def set_device_index(self, device: int) -> None:
self._ahbm.set_device(device)
def current_device_index(self) -> int | None:
return self._ahbm.current_device()
@dataclass
class RuntimeContext:
engine: SimEngine
@@ -67,6 +105,10 @@ class RuntimeContext:
dc = DistributedContext()
dc._ctx_ref = self # back-reference for AhbmCCLBackend to reach ctx.launch etc.
self.distributed = dc
# ADR-0024 D10: torch.ahbm (KernBench-native) + torch.accelerator
# (PyTorch 2.x portable) namespaces for per-greenlet device binding.
self.ahbm = _AhbmNamespace()
self.accelerator = _AcceleratorNamespace(self.ahbm)
def install_ipcq(
self,
@@ -394,12 +436,40 @@ class RuntimeContext:
# DPPolicy overrides take precedence over topology dimensions
eff_num_pe = dp.num_pes if dp.num_pes is not None else self._pes_per_cube
eff_num_cubes = dp.num_cubes if dp.num_cubes is not None else self._num_cubes
eff_num_sips = dp.num_sips if dp.num_sips is not None else self._num_sips
# ADR-0024 D11: if torch.ahbm.set_device(r) is active AND DPPolicy
# leaves the SIP dimension at its default (replicate + no num_sips
# override), scope the tensor to SIP r only.
# NOTE: this path uses post-hoc pe_index shifting as a temporary
# measure; ADR-0026 replaces it with structural (sip, cube, pe)
# coords in ShardSpec.
current_sip = (
self.ahbm.current_device() if hasattr(self, "ahbm") else None
)
scope_to_current_sip = (
current_sip is not None
and dp.sip == "replicate"
and dp.num_sips is None
)
if scope_to_current_sip:
eff_num_sips = 1
else:
eff_num_sips = (
dp.num_sips if dp.num_sips is not None else self._num_sips
)
placement = resolve_dp_policy(
dp, shape=shape_2d, itemsize=itemsize,
num_pe=eff_num_pe, num_cubes=eff_num_cubes,
num_sips=eff_num_sips,
)
if scope_to_current_sip:
from kernbench.policy.placement.dp import ShardSpec as _SS
sip_stride = self._num_cubes * self._pes_per_cube
offset = int(current_sip) * sip_stride
placement = [
_SS(pe_index=s.pe_index + offset,
offset_bytes=s.offset_bytes, nbytes=s.nbytes)
for s in placement
]
# Infer target_pe from placement using local (within-cube) PE IDs.
# This ensures M_CPU only fans out to PEs that own shards, not all PEs.
@@ -509,6 +579,7 @@ class RuntimeContext:
kernel_name: str,
kernel_fn: Any,
*args: Any,
_defer_wait: bool = False,
**kwargs: Any,
) -> RequestHandle:
"""Register and launch a kernel (like a fused torch op).
@@ -518,6 +589,11 @@ class RuntimeContext:
Creates per-SIP KernelLaunchMsg with local va_base per tensor
(like host driver sending per-rank launch commands).
When ``_defer_wait=True`` (ADR-0024 D7), returns the list of
``(handle, sip_id, meta)`` tuples instead of waiting. Caller is
responsible for waiting — used by collective ops to yield between
submit and wait so all sibling ranks can submit first.
"""
from collections import defaultdict
@@ -683,6 +759,18 @@ class RuntimeContext:
_pending_handles.append((h, sip_id))
last_handle = h
if _defer_wait:
# ADR-0024 D7: return the pending-list so the caller can yield
# between submit and drain. Used by collective ops that need
# all sibling ranks to submit before any rank waits.
return [
(h, sip_id, {
"phase": "kernel", "name": kernel_name,
"sip": sip_id, "target_pe": target_pe,
})
for h, sip_id in _pending_handles
]
# Drain pending handles now that every SIP has a launch posted.
for h, sip_id in _pending_handles:
self.wait(h, _meta={