ADR-0024 Phase B (partial): scheduler-level collective drain

Root cause (hang diagnosis):
`kernel_runner.run()` captures `greenlet.getcurrent()` at spawn time as
the kernel greenlet's `_parent`. When a worker greenlet (say g0) calls
`dist.all_reduce` → `ctx.wait(h)` → `env.run(until=h0)`, the SimPy
scheduler steps pe_cpu processes, which in turn spawn kernel greenlets.
Those kernels' `_parent` becomes g0 (current greenlet at spawn). When a
kernel yields via switch_to_simpy, control jumps back up to g0's LAST
switch point — which is the main scheduler's `g.switch()` call — rather
than the kernel_runner's generator frame. Main then re-enters its
`for g in alive: g.switch()` loop mid-wait, producing nested greenlet
re-entry. Scheduler spins: g0 never completes, g1 appears to complete
out of order, infinite loop at 100% CPU.

Fix:
- AhbmCCLBackend.all_reduce: in multi-greenlet mode, submit via
  launch(_defer_wait=True), extend backend._pending_collective_handles,
  and yield to the parent greenlet. Worker does NOT call wait.
- benches/ccl_allreduce.py run(): after each scheduler round, the MAIN
  greenlet drains backend._pending_collective_handles. This keeps
  env.run invocation in the main context, so kernel_runner's spawned
  kernel greenlets have main as their _parent — no nested re-entry.
- Legacy single-driver path (no bench scheduler): all_reduce falls back
  to inline wait when g.parent is None.

Result:
- Multi-greenlet cross-SIP ring no longer hangs (was 100% CPU infinite
  loop in kernel_runner._switch_kernel).
- ring_default_ws still xfail(strict=True): now fails as a data
  correctness issue — DataExecutor reports only 1 math op for a 2-rank
  ring (expected 2). Cross-SIP op_log replay integration is the
  remaining Phase B task.

514 passed, 1 xfailed (strict).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-04-14 09:14:03 -07:00
parent 4ba0a83e71
commit 79124daab1
3 changed files with 42 additions and 15 deletions
+15 -1
View File
@@ -156,7 +156,13 @@ def run(torch) -> None:
n_sips = int(spec.get("system", {}).get("sips", {}).get("count", 1))
if world_size == n_sips:
# ADR-0024 D12/D13: one greenlet per rank, simple round-robin.
# ADR-0024 D12/D13: one greenlet per rank. After each scheduler
# round, the main greenlet drains any pending collective handles
# (ADR-0024 D7) — this must happen in the main context, not inside
# a worker, so env.run is invoked with main as the current greenlet
# and kernel_runner's spawned kernel greenlets correctly get main
# as their parent.
backend = dist._backend
gs: list[greenlet] = []
for rank in range(world_size):
def _entry(r: int = rank) -> None:
@@ -171,6 +177,14 @@ def run(torch) -> None:
for g in alive:
if not g.dead:
g.switch()
# Drain pending collective handles. All sibling workers have
# either submitted (and yielded) or completed; their kernels
# are live in the SimPy queue, ready to exchange via IPCQ.
pending = backend._pending_collective_handles
if pending:
for h, _sip_id, meta in pending:
torch.wait(h, _meta=meta)
backend._pending_collective_handles = []
else:
# Legacy single-worker path (ccl.yaml world_size override).
worker(rank=dist.get_rank(), world_size=world_size, torch=torch)