"""ADR-0027 T3: Worker-wait generalization + orphan invariant. Direct regression guard for ADR-0024 Phase B's kernel-greenlet orphan bug. Phase 1 of ADR-0027: these tests fail against the current code (no ``_pending_worker_waits`` field, no worker-fork in ``ctx.wait``, no scheduler drain). Phase 2 implements D0.1/D0.2/D0.4 and these pass. """ from __future__ import annotations import os import textwrap import pytest from greenlet import greenlet # ── helpers ────────────────────────────────────────────────────────── def _write_minimal_ccl_yaml(tmp_path) -> str: body = textwrap.dedent("""\ defaults: algorithm: ring_allreduce_tcm buffer_kind: tcm backpressure: sleep n_slots: 4 slot_size: 4096 vc_chunk_size: 256 ipcq_credit_size_bytes: 16 algorithms: ring_allreduce_tcm: module: kernbench.ccl.algorithms.ring_allreduce topology: ring_1d buffer_kind: tcm n_elem: 8 """) yaml_path = tmp_path / "ccl.yaml" yaml_path.write_text(body) return str(tmp_path) def _make_ctx(topology): from kernbench.runtime_api.context import RuntimeContext from kernbench.runtime_api.types import DeviceSelector from kernbench.sim_engine.engine import GraphEngine engine = GraphEngine(topology.topology_obj, enable_data=True) return RuntimeContext( engine=engine, target_device=DeviceSelector("all"), correlation_id="test_t3", spec=topology.topology_obj.spec, ) # ── D0.1: _pending_worker_waits field exists ───────────────────────── def test_pending_worker_waits_field_present(topology): """RuntimeContext must expose the deferred-wait queue (D0.1).""" with _make_ctx(topology) as ctx: assert hasattr(ctx, "_pending_worker_waits"), ( "ADR-0027 D0.1: RuntimeContext must declare _pending_worker_waits" ) assert ctx._pending_worker_waits == [], ( "_pending_worker_waits should start empty" ) # ── T3.a / T3.b: wait defers + resume-after-drain contract ─────────── def test_wait_in_worker_defers_to_main_and_resumes_completed(topology): """T3.a + T3.b: worker ctx.wait enqueues + yields; resume → _completed. Direct test of D0.2 (worker-fork) + D0.3 resume invariant (handle must be in ctx._completed when worker resumes). """ with _make_ctx(topology) as ctx: from kernbench.policy.placement.dp import DPPolicy # Worker that submits one tensor (which internally calls ctx.wait) # and records the pending-queue state observed before/after. observations: dict = {"pre_wait_len": None, "post_resume_completed": None} main = greenlet.getcurrent() def _worker(): # Observation hook: patch ctx.wait to capture a single deferral. original_wait = ctx.wait def wrapping_wait(h, *, _meta=None): observations["pre_wait_len"] = len(ctx._pending_worker_waits) result = original_wait(h, _meta=_meta) observations["post_resume_completed"] = h in ctx._completed return result ctx.wait = wrapping_wait # type: ignore[assignment] try: ctx.zeros( (1, 8), dtype="f16", dp=DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=1), name="t3_defer", ) finally: ctx.wait = original_wait # type: ignore[assignment] g = greenlet(_worker) # Scheduler loop: run worker until it yields (or finishes), then drain. while not g.dead: g.switch() if not g.dead: # Worker yielded mid-wait → simulate D0.4 drain. from kernbench.runtime_api.multiprocessing import _drain_pending _drain_pending(ctx) assert observations["pre_wait_len"] is not None, "wait was not invoked" assert observations["post_resume_completed"] is True, ( "D0.3 resume invariant: handle must be in ctx._completed on resume" ) # ── T3.c: multi-worker same-round drain ────────────────────────────── def test_multiple_workers_resume_at_same_drain(topology): """T3.c: every worker yields before any drain; all resume together.""" with _make_ctx(topology) as ctx: from kernbench.policy.placement.dp import DPPolicy dp = DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=1) observations: list[int] = [] def _make_worker(rank: int): def _entry(): # Before its wait, observe queue state so we can assert that # *every* worker has enqueued before any drain happened. ctx.zeros((1, 4), dtype="f16", dp=dp, name=f"r{rank}") observations.append(rank) return _entry ws = 2 gs = [greenlet(_make_worker(r)) for r in range(ws)] # Round 1: every worker runs up to its first (deferred) ctx.wait. for g in gs: g.switch() # After round 1, all workers should be paused (not yet dead) and # each should have enqueued at least one handle. assert all(not g.dead for g in gs), ( "after round 1 switch, workers must be paused mid-wait, not dead" ) assert len(ctx._pending_worker_waits) >= ws, ( f"expected >= {ws} pending worker waits after round 1; " f"got {len(ctx._pending_worker_waits)}" ) # Loop: drain + switch rounds until all workers complete. A single # ctx.zeros() call contains multiple yield points (MmuMap, then # MemoryWrite), so more than one round is needed. from kernbench.runtime_api.multiprocessing import _drain_pending rounds = 0 while any(not g.dead for g in gs): _drain_pending(ctx) for g in gs: if not g.dead: g.switch() rounds += 1 assert rounds < 20, "scheduler did not converge within 20 rounds" assert all(g.dead for g in gs), "all workers should be dead after drain loop" assert sorted(observations) == list(range(ws)) # ── T3.d (핵심): kernel greenlet _parent is main ───────────────────── def test_kernel_greenlet_parent_is_main(topology, tmp_path, monkeypatch): """T3.d orphan invariant: kernel_runner._parent must be main greenlet. This is the direct regression guard for ADR-0024 Phase B. Runs a worker that invokes torch.launch (which eventually spawns a kernel greenlet). The kernel_runner.run() captures greenlet.getcurrent() as _parent at spawn time — that value MUST be the main greenlet, else the orphan bug is back. """ monkeypatch.chdir(_write_minimal_ccl_yaml(tmp_path)) from kernbench.triton_emu import kernel_runner as kr_mod captured_parents: list = [] main = greenlet.getcurrent() original_run = kr_mod.KernelRunner.run def _spy_run(self, env, kernel_fn, kernel_args, num_programs): gen = original_run(self, env, kernel_fn, kernel_args, num_programs) def _wrapping_gen(): # yield from gen, but capture self._parent on first step try: value = next(gen) # First yield happens after _parent is set. captured_parents.append(self._parent) yield value except StopIteration: return yield from gen return _wrapping_gen() monkeypatch.setattr(kr_mod.KernelRunner, "run", _spy_run) # Drive a minimal ring_allreduce that launches a kernel inside a worker. import benches.ccl_allreduce as bench with _make_ctx(topology) as ctx: bench.run(ctx) assert captured_parents, "no kernel_runner.run invocations observed" for p in captured_parents: assert p is main, ( f"ADR-0027 D0.7 / T3.d: kernel greenlet _parent must be main " f"greenlet; got {p!r} (main={main!r})" ) # ── T3.f: idempotency ──────────────────────────────────────────────── def test_wait_same_handle_twice_drives_engine_once(topology): """T3.f: ctx.wait(h) + ctx.wait(h) → engine.wait called once (D0.4-(3)).""" with _make_ctx(topology) as ctx: from kernbench.policy.placement.dp import DPPolicy dp = DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=1) call_count = {"n": 0} original_engine_wait = ctx.engine.wait def _counting_wait(h): call_count["n"] += 1 return original_engine_wait(h) ctx.engine.wait = _counting_wait # type: ignore[assignment] def _worker(): ctx.zeros((1, 4), dtype="f16", dp=dp, name="t3f") # Manually pick a completed handle and wait twice. assert ctx._completed, "there should be at least one completed handle" h = next(iter(ctx._completed)) before = call_count["n"] ctx.wait(h) ctx.wait(h) assert call_count["n"] == before, ( "already-completed handle must not re-drive engine.wait" ) g = greenlet(_worker) while not g.dead: g.switch() if not g.dead: from kernbench.runtime_api.multiprocessing import _drain_pending _drain_pending(ctx) # ── T3.g: exception propagation + no further drain ─────────────────── def test_worker_exception_propagates_and_clears_pending(topology): """T3.g: worker raise → main propagates; _pending_worker_waits cleared.""" with _make_ctx(topology) as ctx: from kernbench.runtime_api.multiprocessing import SpawnException def _bad_worker(rank: int): raise ValueError(f"rank {rank} intentional failure") with pytest.raises(SpawnException) as exc_info: ctx.multiprocessing.spawn(_bad_worker, args=(), nprocs=2) assert ctx._pending_worker_waits == [], ( "D0.4-(4): _pending_worker_waits must be cleared on failure" ) # Root-cause rank errors are present; sibling SystemExit not in dict. assert 0 in exc_info.value.errors or 1 in exc_info.value.errors # ── T3.e: historical failure (pre-D0) — skipped per ADR ────────────── @pytest.mark.skip( reason="ADR-0027 T3.e: historical failure mode — reproduces only " "pre-D0.2. Kept as documentation; not run in Phase 2." ) def test_pre_d0_orphan_reproduction(): """Placeholder: exercises the pre-D0.2 code path that causes GreenletExit from kernel_runner._parent captured in worker context. See ADR-0024 Phase B postmortem.""" pass