"""Mock CCL runtime for fast unit tests of algorithm kernels (ADR-0023 D15). Runs a kernel function once per rank with a minimal ``tl`` shim — no SimPy, no PE_DMA, no fabric simulation. Just enough to verify *functional* correctness of an IPCQ-based collective algorithm. Cross-rank send/recv is implemented with greenlet cooperative scheduling plus per-(rank, direction) FIFO queues. Backpressure is not modeled — queues are unbounded. Typical usage in a test:: from kernbench.ccl.testing import run_kernel_in_mock from kernbench.ccl.algorithms.ring_allreduce import kernel inputs = [np.full(16, r + 1, dtype="f16") for r in range(4)] outputs = run_kernel_in_mock( kernel_fn=kernel, world_size=4, topology="ring_1d", inputs=inputs, kernel_args=(16,), ) for r in range(4): assert np.allclose(outputs[r], sum(inputs)) """ from __future__ import annotations from collections import deque from typing import Any, Callable import numpy as np from greenlet import greenlet from kernbench.ccl.topologies import resolve_topology from kernbench.common.ipcq_types import IpcqInvalidDirection from kernbench.common.pe_commands import TensorHandle # ── Per-rank fake state ────────────────────────────────────────────── class _MockRankState: """Per-rank scratch holding HBM/recv slots and tl shim hooks.""" def __init__( self, rank: int, world_size: int, neighbors: dict[str, int], input_arr: np.ndarray, pes_per_cube: int = 0, ) -> None: self.rank = rank self.world_size = world_size # PEs per cube for program_id(axis=0/1). If 0 or world_size, # all ranks are in one cube (legacy single-cube behavior). self.pes_per_cube = pes_per_cube if pes_per_cube > 0 else world_size self.neighbors = neighbors # direction → peer rank # HBM "memory": addr → ndarray. Per-rank, no cross-rank sharing. self._hbm: dict[int, np.ndarray] = {} self._tcm: dict[int, np.ndarray] = {} # ``t_ptr`` is the address the kernel sees. Real benches use a # column-sharded VA so each rank reads from ``t_ptr + rank*nbytes``. # Mirror that here: each rank's slice lives at the rank-specific addr. nbytes = int(input_arr.nbytes) self.t_ptr = 0 # base; per-rank offset is rank * nbytes self._slice_addr = rank * nbytes self._hbm[self._slice_addr] = input_arr.copy() # Inbound recv FIFOs: direction → deque[ndarray] self.recv_q: dict[str, deque[np.ndarray]] = {d: deque() for d in neighbors} # Output (set when kernel calls tl.store at slice address) self.output: np.ndarray | None = None # Greenlet for this rank — set later self.g: greenlet | None = None # ── Mock TLContext ─────────────────────────────────────────────────── class _MockTL: """Drop-in tl shim for mock runtime. Supports the subset of TLContext API that algorithm authors use: program_id, num_programs, load, store, send, recv, recv_async, wait, plus arithmetic operations on TensorHandle (eager numpy execution, no SimPy involved). """ def __init__(self, state: _MockRankState, scheduler: "_MockScheduler") -> None: self._state = state self._scheduler = scheduler self._handle_counter = 0 def _next_id(self) -> str: self._handle_counter += 1 return f"mt{self._handle_counter}" @property def rank(self) -> int: return self._state.rank @property def world_size(self) -> int: return self._state.world_size # axis-aware def program_id(self, axis: int = 0) -> int: # Multi-cube: axis=0 = PE within cube, axis=1 = global cube id. # Falls back to flat (all ranks in one cube) if pes_per_cube # is not set (legacy single-cube tests). ppc = self._state.pes_per_cube if axis == 1: return self._state.rank // ppc return self._state.rank % ppc def num_programs(self, axis: int = 0) -> int: ppc = self._state.pes_per_cube if axis == 1: return self._state.world_size // ppc return ppc # ── arithmetic ops (called by TensorHandle.__add__ etc.) ── def _binary_math(self, op: str, a: TensorHandle, b: TensorHandle) -> TensorHandle: a_data = np.asarray(a.data) if a.data is not None else None b_data = np.asarray(b.data) if b.data is not None else None if a_data is None or b_data is None: result = None elif op == "add": result = a_data + b_data elif op == "sub": result = a_data - b_data elif op == "mul": result = a_data * b_data elif op == "div": result = a_data / b_data elif op == "maximum": result = np.maximum(a_data, b_data) elif op == "minimum": result = np.minimum(a_data, b_data) else: raise NotImplementedError(f"mock _binary_math: op {op!r} not implemented") return TensorHandle( id=self._next_id(), addr=0, shape=a.shape, dtype=a.dtype, nbytes=int(np.prod(a.shape)) * 2 if a.shape else 0, data=result, space="tcm", ) def maximum(self, a: TensorHandle, b: TensorHandle) -> TensorHandle: return self._binary_math("maximum", a, b) def minimum(self, a: TensorHandle, b: TensorHandle) -> TensorHandle: return self._binary_math("minimum", a, b) def fma( self, a: TensorHandle, b: TensorHandle, c: TensorHandle, ) -> TensorHandle: a_data = np.asarray(a.data) if a.data is not None else None b_data = np.asarray(b.data) if b.data is not None else None c_data = np.asarray(c.data) if c.data is not None else None result = ( a_data * b_data + c_data if (a_data is not None and b_data is not None and c_data is not None) else None ) return TensorHandle( id=self._next_id(), addr=0, shape=a.shape, dtype=a.dtype, nbytes=int(np.prod(a.shape)) * 2 if a.shape else 0, data=result, space="tcm", ) def clamp( self, x: TensorHandle, min: TensorHandle, max: TensorHandle, ) -> TensorHandle: x_data = np.asarray(x.data) if x.data is not None else None lo = np.asarray(min.data) if min.data is not None else None hi = np.asarray(max.data) if max.data is not None else None result = ( np.minimum(np.maximum(x_data, lo), hi) if (x_data is not None and lo is not None and hi is not None) else None ) return TensorHandle( id=self._next_id(), addr=0, shape=x.shape, dtype=x.dtype, nbytes=int(np.prod(x.shape)) * 2 if x.shape else 0, data=result, space="tcm", ) def softmax(self, x: TensorHandle, axis: int = -1) -> TensorHandle: x_data = np.asarray(x.data) if x.data is not None else None if x_data is None: result = None else: x_max = np.max(x_data, axis=axis, keepdims=True) e = np.exp(x_data - x_max) s = np.sum(e, axis=axis, keepdims=True) result = e / s return TensorHandle( id=self._next_id(), addr=0, shape=x.shape, dtype=x.dtype, nbytes=int(np.prod(x.shape)) * 2 if x.shape else 0, data=result, space="tcm", ) @staticmethod def cdiv(a: int, b: int) -> int: return -(-int(a) // int(b)) def _unary_math(self, op: str, x: TensorHandle) -> TensorHandle: x_data = np.asarray(x.data) if x.data is not None else None if x_data is None: result = None elif op == "exp": result = np.exp(x_data) elif op == "log": result = np.log(x_data) elif op == "sqrt": result = np.sqrt(x_data) elif op == "abs": result = np.abs(x_data) elif op == "sigmoid": result = 1.0 / (1.0 + np.exp(-x_data)) elif op == "cos": result = np.cos(x_data) elif op == "sin": result = np.sin(x_data) else: raise NotImplementedError(f"mock _unary_math: op {op!r} not implemented") return TensorHandle( id=self._next_id(), addr=0, shape=x.shape, dtype=x.dtype, nbytes=int(np.prod(x.shape)) * 2 if x.shape else 0, data=result, space="tcm", ) def load(self, ptr: int, shape: tuple[int, ...], dtype: str = "f16") -> TensorHandle: data = self._state._hbm.get(ptr) if data is None: data = np.zeros(shape, dtype=np.float16) return TensorHandle( id=f"load_{ptr}", addr=ptr, shape=shape, dtype=dtype, nbytes=int(np.prod(shape)) * 2, data=data, space="hbm", ) def store(self, ptr: int, handle: TensorHandle) -> None: if handle.data is not None: self._state._hbm[ptr] = np.asarray(handle.data) if ptr == self._state._slice_addr: self._state.output = self._state._hbm[ptr] # IPCQ def send( self, dir: str, src: TensorHandle | None = None, *, src_addr: int | None = None, nbytes: int | None = None, shape: tuple[int, ...] | None = None, dtype: str = "f16", space: str = "tcm", ) -> None: if dir not in self._state.neighbors: raise IpcqInvalidDirection( f"mock tl.send: direction {dir!r} not in neighbors {list(self._state.neighbors)}" ) if src is not None: if src.data is not None: data = np.asarray(src.data) else: # Resolve from this rank's local memory at src.addr space_dict = self._state._hbm if src.space == "hbm" else self._state._tcm stored = space_dict.get(src.addr) if stored is None: raise RuntimeError( f"mock tl.send: no data at {src.space}:0x{src.addr:x}" ) data = np.asarray(stored) else: data = None if data is None: raise RuntimeError("mock tl.send: src is None") peer_rank = self._state.neighbors[dir] # Find the reverse direction at the peer, mirroring real IPCQ # install pairing: N↔S, E↔W, parent↔parent, child_left↔child_left, etc. _REVERSE = {"N": "S", "S": "N", "E": "W", "W": "E", "parent": "parent", "child_left": "child_left", "child_right": "child_right"} peer_state = self._scheduler.states[peer_rank] reverse_dir = _REVERSE.get(dir) # Fall back to "first direction pointing at me" if the explicit # reverse doesn't exist at the peer (e.g. custom directions). if reverse_dir is None or reverse_dir not in peer_state.neighbors: reverse_dir = None for d, target in peer_state.neighbors.items(): if target == self._state.rank: reverse_dir = d break if reverse_dir is None: raise RuntimeError( f"mock tl.send: peer rank {peer_rank} has no reverse direction" ) peer_state.recv_q[reverse_dir].append(data.copy()) self._scheduler._send_counter += 1 # After delivering, hand control back to scheduler so the receiver # can wake up. self._scheduler.yield_() def recv_async( self, dir: str, shape: tuple[int, ...] = (), dtype: str = "f16", ) -> dict: """Non-blocking recv. Returns a future dict to pass to tl.wait.""" if dir not in self._state.neighbors: raise IpcqInvalidDirection( f"mock tl.recv_async: direction {dir!r} not in neighbors" ) return {"_kind": "recv_future", "dir": dir, "shape": shape, "dtype": dtype} def wait(self, future: Any) -> TensorHandle: """Block until the recv future has data.""" if not isinstance(future, dict) or future.get("_kind") != "recv_future": raise TypeError("tl.wait: expected recv future from tl.recv_async") d = future["dir"] while not self._state.recv_q[d]: self._scheduler.yield_() data = self._state.recv_q[d].popleft() return self._make_handle(data, d, future["dtype"]) def recv( self, dir: str | None = None, shape: tuple[int, ...] = (), dtype: str = "f16", ) -> TensorHandle: if dir is not None and dir not in self._state.neighbors: raise IpcqInvalidDirection( f"mock tl.recv: direction {dir!r} not in neighbors {list(self._state.neighbors)}" ) # Wait for data while True: if dir is None: # round-robin over directions for d in self._state.neighbors: if self._state.recv_q[d]: data = self._state.recv_q[d].popleft() return self._make_handle(data, d, dtype) else: if self._state.recv_q[dir]: data = self._state.recv_q[dir].popleft() return self._make_handle(data, dir, dtype) # Yield to other ranks self._scheduler.yield_() def _make_handle(self, data: np.ndarray, direction: str, dtype: str) -> TensorHandle: return TensorHandle( id=f"recv_{direction}", addr=0, shape=data.shape, dtype=dtype, nbytes=int(data.nbytes), data=data, space="tcm", ) # ── Cooperative scheduler ──────────────────────────────────────────── class _MockScheduler: """Round-robin cooperative scheduler over rank greenlets.""" def __init__(self, states: list[_MockRankState]) -> None: self.states = states self._parent: greenlet | None = None self._cur_idx = 0 def yield_(self) -> None: """Called from inside a rank greenlet to give other ranks a turn.""" assert self._parent is not None self._parent.switch() def run(self, kernel_fn: Callable, kernel_args: tuple) -> list[np.ndarray]: from kernbench.triton_emu.tl_context import TLContext self._parent = greenlet.getcurrent() n = len(self.states) # Per-rank tl shim tls: dict[int, _MockTL] = {} def _spawn(rank_idx: int) -> greenlet: state = self.states[rank_idx] tl = _MockTL(state, self) tls[rank_idx] = tl def _entry(): # Activate this rank's tl for TensorHandle operator overloads TLContext._set_active(tl) # type: ignore[attr-defined] try: kernel_fn(state.t_ptr, *kernel_args, tl=tl) finally: TLContext._set_active(None) # type: ignore[attr-defined] return greenlet(_entry) for state in self.states: state.g = _spawn(state.rank) # Drive each rank round-robin until all dead. Detect global deadlock. # A global send counter tracks whether any greenlet delivered data # in the current round. This is more reliable than queue-depth # tracking because a recv+send pair in the same round nets to zero # depth change yet still represents real progress. self._send_counter = 0 max_idle_rounds = 10_000 idle_rounds = 0 while True: alive = [s for s in self.states if s.g is not None and not s.g.dead] if not alive: break counter_before = self._send_counter for s in self.states: if s.g is None or s.g.dead: continue TLContext._set_active(tls[s.rank]) # type: ignore[attr-defined] s.g.switch() TLContext._set_active(None) # type: ignore[attr-defined] any_died = any(s.g is not None and s.g.dead for s in self.states) if self._send_counter > counter_before or any_died: idle_rounds = 0 else: idle_rounds += 1 if idle_rounds >= max_idle_rounds: raise RuntimeError( "mock CCL runtime: deadlock detected (no progress for " f"{max_idle_rounds} rounds)" ) return [ s.output if s.output is not None else s._hbm.get(s._slice_addr) for s in self.states ] # ── Public entry ──────────────────────────────────────────────────── def run_kernel_in_mock( kernel_fn: Callable, world_size: int, topology: str, inputs: list[np.ndarray], kernel_args: tuple = (), algo_module: Any | None = None, pes_per_cube: int = 0, ) -> list[np.ndarray]: """Run a CCL kernel under the mock runtime with no SimPy/fabric. Args: kernel_fn: ``kernel(t_ptr, *kernel_args, tl=...)`` world_size: number of ranks topology: builtin topology name (e.g. "ring_1d") inputs: per-rank input ndarrays. ``inputs[r]`` becomes rank r's local tile at HBM address 0. kernel_args: extra positional args after t_ptr algo_module: optional module providing ``neighbors()`` override pes_per_cube: PEs per cube for multi-cube program_id mapping. 0 → single-cube legacy (all ranks in one cube). Returns: Per-rank output ndarrays — whatever the kernel wrote via tl.store (or the original input if the kernel didn't store). """ if len(inputs) != world_size: raise ValueError(f"len(inputs)={len(inputs)} != world_size={world_size}") topo_fn = resolve_topology(topology, algo_module=algo_module) states = [ _MockRankState( rank=r, world_size=world_size, neighbors=topo_fn(r, world_size), input_arr=inputs[r], pes_per_cube=pes_per_cube, ) for r in range(world_size) ] sched = _MockScheduler(states) return sched.run(kernel_fn, kernel_args)