"""TLContext: fake Triton Language module for kernel performance simulation.

Passed as the `tl` parameter to kernel functions. Each API call records a
PeCommand in the internal trace. After the kernel returns, PE_CPU replays the
command list through SimPy. When a runner is attached (greenlet mode), each
command is instead handed to the runner as soon as it is emitted.

Kernel code looks like standard Python — no yield, no async:

    def my_kernel(a_ptr, b_ptr, out_ptr, tl):
        pid = tl.program_id(0)
        a = tl.load(a_ptr, shape=(32, 64), dtype="f16")
        b = tl.load(b_ptr + pid * stride, shape=(64, 32), dtype="f16")
        tl.composite(op="gemm", a=a, b=b, out_ptr=out_ptr)
"""

from __future__ import annotations

import math
from typing import Any, Literal

from kernbench.common.ipcq_types import IpcqRecvCmd, IpcqSendCmd, RecvFuture
from kernbench.common.pe_commands import (
    CompletionHandle,
    CompositeCmd,
    DmaReadCmd,
    DmaWriteCmd,
    GemmCmd,
    MathCmd,
    PeCommand,
    PeCpuOverheadCmd,
    TensorHandle,
    WaitCmd,
)

# Element size in bytes per dtype name. Unknown dtypes fall back to 2 bytes
# (see TLContext._dtype_bytes).
_DTYPE_BYTES: dict[str, int] = {
    "f16": 2,
    "f32": 4,
    "f64": 8,
    "bf16": 2,
    "i8": 1,
    "i16": 2,
    "i32": 4,
    "i64": 8,
    "u8": 1,
    "u16": 2,
    "u32": 4,
    "u64": 8,
}


class TLContext:
    """Fake Triton Language context.

    Args:
        pe_id: program instance index (returned by ``program_id(0)``).
        num_programs: total number of program instances (returned by
            ``num_programs(0)``).
        dispatch_cycles: PE_CPU overhead per tl API call (auto-inserted);
            0 disables the overhead command.
        runner: optional KernelRunner. When set, every command is passed to
            ``runner.switch_to_simpy`` instead of being appended to the
            command list (greenlet mode, ADR-0020 D3).
        cube_id: value returned by ``program_id(1)``.
        num_cubes: value returned by ``num_programs(1)``.
        scratch_base: base address of the PE-local scratch pool used for
            compute outputs; 0 disables scratch allocation (outputs get
            addr=0).
        scratch_size: capacity of the scratch pool in bytes.
    """

    def __init__(
        self,
        pe_id: int = 0,
        num_programs: int = 1,
        dispatch_cycles: int = 1,
        runner: Any = None,
        cube_id: int = 0,
        num_cubes: int = 1,
        scratch_base: int = 0,
        scratch_size: int = 1 << 20,  # 1 MiB per kernel invocation
    ) -> None:
        self._pe_id = pe_id
        self._num_programs = num_programs
        self._cube_id = cube_id
        self._num_cubes = num_cubes
        self._dispatch_cycles = dispatch_cycles
        self._commands: list[PeCommand] = []
        self._handle_counter = 0
        self._completion_counter = 0
        self._runner = runner  # KernelRunner for greenlet mode (ADR-0020 D3)
        # PE-local scratch allocator for math/compute output handles.
        # Each binary/unary/reduction op auto-allocates a unique addr from
        # this pool so the resulting TensorHandle can be the source of a
        # later tl.send / tl.store. The cursor starts at 0 for every new
        # TLContext (presumably one context per kernel invocation).
        self._scratch_base = scratch_base
        self._scratch_size = scratch_size
        self._scratch_cursor = 0

    def _scratch_alloc(self, nbytes: int) -> int:
        """Allocate a unique scratch address for an output TensorHandle.

        Returns 0 if no scratch base was configured (e.g. command-list mode);
        in that case the resulting handle has addr=0 and cannot be used as a
        send/store source. Greenlet/runner mode is expected to supply a base.

        Allocations are rounded up to 16 bytes. The pool is bump-allocated
        and never freed within one context.

        Raises:
            RuntimeError: if the allocation would exceed ``scratch_size``.
                The cursor is left unchanged in that case.
        """
        if self._scratch_base == 0:
            return 0
        # 16-byte alignment
        aligned = (nbytes + 15) & ~15
        new_cursor = self._scratch_cursor + aligned
        # Check before committing so a failed allocation does not leave the
        # allocator permanently over-committed.
        if new_cursor > self._scratch_size:
            raise RuntimeError(
                f"TLContext scratch overflow: requested {nbytes}B, "
                f"used {new_cursor}/{self._scratch_size}B"
            )
        addr = self._scratch_base + self._scratch_cursor
        self._scratch_cursor = new_cursor
        return addr

    @property
    def commands(self) -> list[PeCommand]:
        """Return the recorded command trace (empty in greenlet mode)."""
        return self._commands

    # ── helpers ────────────────────────────────────────────────────

    def _next_handle_id(self) -> str:
        self._handle_counter += 1
        return f"t{self._handle_counter}"

    def _next_completion_id(self) -> str:
        self._completion_counter += 1
        return f"c{self._completion_counter}"

    def _dtype_bytes(self, dtype: str) -> int:
        # Unknown dtypes silently default to 2 bytes.
        return _DTYPE_BYTES.get(dtype, 2)

    def _nbytes(self, shape: tuple[int, ...], dtype: str) -> int:
        return math.prod(shape) * self._dtype_bytes(dtype)

    def _emit_dispatch_overhead(self) -> None:
        if self._dispatch_cycles > 0:
            self._emit(PeCpuOverheadCmd(cycles=self._dispatch_cycles))

    def _make_handle(
        self,
        addr: int,
        shape: tuple[int, ...],
        dtype: str,
        space: str = "tcm",
        pinned: bool = False,
    ) -> TensorHandle:
        return TensorHandle(
            id=self._next_handle_id(),
            addr=addr,
            shape=shape,
            dtype=dtype,
            nbytes=self._nbytes(shape, dtype),
            space=space,
            pinned=pinned,
        )

    def _make_compute_out(
        self,
        shape: tuple[int, ...],
        dtype: str,
    ) -> TensorHandle:
        """Allocate an output TensorHandle in PE-local scratch (TCM space).

        Used by math/compute ops so the result has a real address that can
        be the source of a later send/store. The data field stays None in
        Phase 1 — Phase 2 DataExecutor fills the actual ndarray.
        """
        nbytes = self._nbytes(shape, dtype)
        addr = self._scratch_alloc(nbytes)
        return TensorHandle(
            id=self._next_handle_id(),
            addr=addr,
            shape=shape,
            dtype=dtype,
            nbytes=nbytes,
            space="tcm",
        )

    def _emit(self, cmd: PeCommand) -> Any:
        """Emit command: greenlet switch if runner available, else append to list.

        Returns whatever the runner hands back (greenlet mode) or None
        (command-list mode).
        """
        if self._runner is not None:
            return self._runner.switch_to_simpy(cmd)
        self._commands.append(cmd)
        return None

    # ── Reference (no DMA, metadata only) ────────────────────────

    def ref(
        self,
        ptr: int,
        shape: tuple[int, ...],
        dtype: str = "f16",
    ) -> TensorHandle:
        """Create a TensorHandle referencing HBM data without issuing DMA.

        Used when the scheduler will stream data per-tile (e.g., tensor b in
        a composite GEMM). No command is generated.
        """
        return self._make_handle(addr=ptr, shape=shape, dtype=dtype)

    # ── Data Movement (blocking, DMA engine) ──────────────────────

    def load(
        self,
        ptr: int,
        shape: tuple[int, ...],
        dtype: str = "f16",
    ) -> TensorHandle:
        """Load tensor from HBM. Returns TensorHandle pointing at HBM[ptr].

        In greenlet mode: returns TensorHandle with actual numpy data.
        In command-list mode: returns TensorHandle with data=None.

        The returned handle's ``space`` is "hbm" so subsequent ops (math,
        send, store) using this handle as a source resolve via MemoryStore
        at ``(hbm, ptr)`` — which is where the load's underlying data
        actually lives in Phase 2 storage.
        """
        self._emit_dispatch_overhead()
        handle = self._make_handle(
            addr=ptr,
            shape=shape,
            dtype=dtype,
            space="hbm",
            pinned=True,
        )
        cmd = DmaReadCmd(handle=handle, src_addr=ptr, nbytes=handle.nbytes)
        data = self._emit(cmd)
        if data is not None:
            # Greenlet mode: attach real data to handle (preserve space + pinned)
            return TensorHandle(
                id=handle.id,
                addr=handle.addr,
                shape=handle.shape,
                dtype=handle.dtype,
                nbytes=handle.nbytes,
                data=data,
                space=handle.space,
                pinned=handle.pinned,
            )
        return handle

    def store(self, ptr: int, handle: TensorHandle) -> None:
        """Store tensor from TCM to HBM."""
        self._emit_dispatch_overhead()
        cmd = DmaWriteCmd(handle=handle, dst_addr=ptr, nbytes=handle.nbytes)
        self._emit(cmd)

    # ── GEMM Engine (blocking) ────────────────────────────────────

    def dot(self, a: TensorHandle, b: TensorHandle) -> TensorHandle:
        """Matrix multiply: out = a @ b. Both operands must be in TCM.

        a: (M, K), b: (K, N) → out: (M, N). Leading (batch) dims of ``a``
        are carried into the output shape.

        Raises:
            ValueError: if either operand is below 2-D or the K dims differ.
        """
        if len(a.shape) < 2 or len(b.shape) < 2:
            raise ValueError("dot requires 2D tensors")
        m, k = a.shape[-2], a.shape[-1]
        k2, n = b.shape[-2], b.shape[-1]
        if k != k2:
            raise ValueError(f"dot shape mismatch: a.K={k} != b.K={k2}")
        out_shape = (*a.shape[:-2], m, n)
        out_dtype = a.dtype
        out = self._make_compute_out(shape=out_shape, dtype=out_dtype)
        self._emit_dispatch_overhead()
        self._emit(GemmCmd(a=a, b=b, out=out, m=m, k=k, n=n))
        return out

    # ── MATH Engine: unary (blocking) ─────────────────────────────

    def _unary_math(self, op: str, x: TensorHandle) -> TensorHandle:
        out = self._make_compute_out(shape=x.shape, dtype=x.dtype)
        self._emit_dispatch_overhead()
        self._emit(MathCmd(op=op, inputs=(x,), out=out))
        return out

    def exp(self, x: TensorHandle) -> TensorHandle:
        return self._unary_math("exp", x)

    def log(self, x: TensorHandle) -> TensorHandle:
        return self._unary_math("log", x)

    def sqrt(self, x: TensorHandle) -> TensorHandle:
        return self._unary_math("sqrt", x)

    def abs(self, x: TensorHandle) -> TensorHandle:
        return self._unary_math("abs", x)

    def sigmoid(self, x: TensorHandle) -> TensorHandle:
        return self._unary_math("sigmoid", x)

    def cos(self, x: TensorHandle) -> TensorHandle:
        return self._unary_math("cos", x)

    def sin(self, x: TensorHandle) -> TensorHandle:
        return self._unary_math("sin", x)

    # ── MATH Engine: reduction (blocking) ─────────────────────────

    def _reduction(
        self,
        op: str,
        x: TensorHandle,
        axis: int,
    ) -> TensorHandle:
        # The reduced axis is kept with size 1 (keepdims-style output shape).
        out_shape = list(x.shape)
        out_shape[axis] = 1
        out = self._make_compute_out(shape=tuple(out_shape), dtype=x.dtype)
        self._emit_dispatch_overhead()
        self._emit(MathCmd(op=op, inputs=(x,), out=out, axis=axis))
        return out

    def sum(self, x: TensorHandle, axis: int) -> TensorHandle:
        return self._reduction("sum", x, axis)

    def max(self, x: TensorHandle, axis: int) -> TensorHandle:
        return self._reduction("max", x, axis)

    def min(self, x: TensorHandle, axis: int) -> TensorHandle:
        return self._reduction("min", x, axis)

    # ── MATH Engine: binary (blocking) ────────────────────────────

    def _binary_math(
        self,
        op: str,
        a: TensorHandle,
        b: TensorHandle,
    ) -> TensorHandle:
        # Output takes a's shape/dtype; no broadcasting is computed here.
        out = self._make_compute_out(shape=a.shape, dtype=a.dtype)
        self._emit_dispatch_overhead()
        self._emit(MathCmd(op=op, inputs=(a, b), out=out))
        return out

    def where(
        self,
        cond: TensorHandle,
        a: TensorHandle,
        b: TensorHandle,
    ) -> TensorHandle:
        out = self._make_compute_out(shape=a.shape, dtype=a.dtype)
        self._emit_dispatch_overhead()
        self._emit(MathCmd(op="where", inputs=(cond, a, b), out=out))
        return out

    def maximum(self, a: TensorHandle, b: TensorHandle) -> TensorHandle:
        """Element-wise max of two tensors (real Triton: tl.maximum)."""
        return self._binary_math("maximum", a, b)

    def minimum(self, a: TensorHandle, b: TensorHandle) -> TensorHandle:
        """Element-wise min of two tensors (real Triton: tl.minimum)."""
        return self._binary_math("minimum", a, b)

    def fma(
        self,
        a: TensorHandle,
        b: TensorHandle,
        c: TensorHandle,
    ) -> TensorHandle:
        """Fused multiply-add: a * b + c (real Triton: tl.fma)."""
        out = self._make_compute_out(shape=a.shape, dtype=a.dtype)
        self._emit_dispatch_overhead()
        self._emit(MathCmd(op="fma", inputs=(a, b, c), out=out))
        return out

    def clamp(
        self,
        x: TensorHandle,
        min: TensorHandle,
        max: TensorHandle,
    ) -> TensorHandle:
        """Clamp x to [min, max] (real Triton: tl.clamp)."""
        out = self._make_compute_out(shape=x.shape, dtype=x.dtype)
        self._emit_dispatch_overhead()
        self._emit(MathCmd(op="clamp", inputs=(x, min, max), out=out))
        return out

    def softmax(self, x: TensorHandle, axis: int = -1) -> TensorHandle:
        """Numerically-stable softmax along ``axis`` (real Triton: tl.softmax).

        Implemented as a single MathCmd (op="softmax") so timing accounts for
        one MATH dispatch; Phase 2 DataExecutor expands it to the canonical
        (x - max) → exp → sum → div sequence.
        """
        out = self._make_compute_out(shape=x.shape, dtype=x.dtype)
        self._emit_dispatch_overhead()
        self._emit(MathCmd(op="softmax", inputs=(x,), out=out, axis=axis))
        return out

    # ── Scalar helpers (real Triton: tl.cdiv etc.) ────────────────

    @staticmethod
    def cdiv(a: int, b: int) -> int:
        """Ceiling division: (a + b - 1) // b (real Triton: tl.cdiv).

        Used by host/kernel grid math; not a tensor op, so no MathCmd is
        emitted. Mirrors triton.cdiv.
        """
        return -(-int(a) // int(b))

    # ── Index / Scalar (PE_CPU, no engine) ────────────────────────

    def program_id(self, axis: int = 0) -> int:
        """Return program instance index (ADR-0022).

        axis=0: local PE id within cube.
        axis=1: cube id.
        Any other axis is treated like axis=0.
        """
        if axis == 1:
            return self._cube_id
        return self._pe_id

    def num_programs(self, axis: int = 0) -> int:
        """Return total number of program instances (ADR-0022).

        axis=0: num PEs per cube.
        axis=1: num cubes.
        Any other axis is treated like axis=0.
        """
        if axis == 1:
            return self._num_cubes
        return self._num_programs

    def arange(self, start: int, end: int, dtype: str = "i32") -> TensorHandle:
        """Create index range tensor in TCM (metadata only, addr=0)."""
        n = end - start
        return self._make_handle(addr=0, shape=(n,), dtype=dtype)

    def zeros(self, shape: tuple[int, ...], dtype: str = "f16") -> TensorHandle:
        """Create zero-filled tensor in TCM (metadata only, addr=0)."""
        return self._make_handle(addr=0, shape=shape, dtype=dtype)

    def full(
        self,
        shape: tuple[int, ...],
        value: float | int,
        dtype: str = "f16",
    ) -> TensorHandle:
        """Create constant-filled tensor in TCM (metadata only, addr=0).

        ``value`` is accepted for API compatibility but not recorded.
        """
        return self._make_handle(addr=0, shape=shape, dtype=dtype)

    # ── Metadata (no compute, no DMA) ─────────────────────────────

    def trans(self, x: TensorHandle) -> TensorHandle:
        """Transpose — shape change only, no command generated.

        Swaps the last two dims. The returned handle keeps the same id,
        address, memory space and pinned flag as ``x`` so later ops still
        resolve the data at the original location.

        Raises:
            ValueError: if ``x`` is below 2-D.
        """
        if len(x.shape) < 2:
            raise ValueError("trans requires at least 2D tensor")
        new_shape = (*x.shape[:-2], x.shape[-1], x.shape[-2])
        # NOTE(review): ``data`` is passed through untransposed; confirm
        # downstream consumers expect the original layout.
        return TensorHandle(
            id=x.id,
            addr=x.addr,
            shape=new_shape,
            dtype=x.dtype,
            nbytes=x.nbytes,
            data=x.data,
            space=x.space,
            pinned=x.pinned,
        )

    # ── IPCQ (CCL) collective primitives (ADR-0023 D4) ────────────

    def send(
        self,
        dir: str,
        src: TensorHandle | None = None,
        *,
        src_addr: int | None = None,
        nbytes: int | None = None,
        shape: tuple[int, ...] | None = None,
        dtype: str = "f16",
        space: str = "tcm",
    ) -> None:
        """Send tensor data to the peer in the given direction.

        Two calling forms:
          tl.send(dir, handle)        # use handle's metadata
          tl.send(dir, src_addr=..., nbytes=..., shape=..., dtype=..., space=...)

        Blocking: returns when PE_IPCQ has accepted the request and
        forwarded the IpcqDmaToken to PE_DMA. Backpressure may apply.

        Raises:
            ValueError: if neither a handle nor src_addr/nbytes/shape is given.
        """
        if src is not None:
            src_addr = src.addr
            nbytes = src.nbytes
            shape = src.shape
            dtype = src.dtype
            space = getattr(src, "space", space)
        if src_addr is None or nbytes is None or shape is None:
            raise ValueError("tl.send: provide either a TensorHandle or src_addr/nbytes/shape")
        # Carry the handle's .data snapshot (if available). When the source
        # is a recv slot, .data holds the numpy array that was read from
        # MemoryStore at recv-time. This prevents a Phase 1 race where a
        # later IPCQ inbound overwrites the slot before the outbound
        # PE_DMA reads it.
        handle_data = getattr(src, "data", None) if src is not None else None
        self._emit_dispatch_overhead()
        cmd = IpcqSendCmd(
            direction=dir,
            src_addr=src_addr,
            src_space=space,
            nbytes=nbytes,
            shape=shape,
            dtype=dtype,
            handle_id=self._next_handle_id(),
            data=handle_data,
        )
        self._emit(cmd)

    def recv(
        self,
        dir: str | None = None,
        shape: tuple[int, ...] = (),
        dtype: str = "f16",
        space: str = "tcm",
        dst_addr: int | None = None,
        dst_space: str | None = None,
    ) -> TensorHandle:
        """Receive tensor data from a peer.

        Args:
            dir: specific direction (e.g. "W"), or None for round-robin.
            shape, dtype: expected tensor metadata.
            space: accepted for API compatibility; not used by this method.
            dst_addr / dst_space: if both are provided, the slot data is
                copied to (dst_space, dst_addr) before the handle is
                returned ("copy_to_dst" mode). Otherwise the slot address is
                returned directly ("return_slot" mode).

        Returns:
            TensorHandle pointing to the slot (or dst) where the data has
            arrived. In greenlet/runner mode, ``handle.data`` carries the
            actual ndarray; in command-list mode the handle is a placeholder.
        """
        self._emit_dispatch_overhead()
        if dst_addr is not None and dst_space is not None:
            cmd = IpcqRecvCmd(
                direction=dir,
                shape=shape,
                dtype=dtype,
                handle_id=self._next_handle_id(),
                recv_mode="copy_to_dst",
                dst_addr=dst_addr,
                dst_space=dst_space,
            )
        else:
            cmd = IpcqRecvCmd(
                direction=dir,
                shape=shape,
                dtype=dtype,
                handle_id=self._next_handle_id(),
            )
        result = self._emit(cmd)
        if isinstance(result, dict):
            # Runner mode: the runner reports where the data landed.
            slot_addr = int(result.get("src_addr", 0))
            slot_space = str(result.get("src_space", "tcm"))
            data = result.get("data")
            return TensorHandle(
                id=self._next_handle_id(),
                addr=slot_addr,
                shape=shape,
                dtype=dtype,
                nbytes=self._nbytes(shape, dtype),
                data=data,
                space=slot_space,
            )
        return self._make_handle(addr=0, shape=shape, dtype=dtype)

    def recv_no_consume(
        self,
        dir: str | None = None,
        shape: tuple[int, ...] = (),
        dtype: str = "f16",
    ) -> TensorHandle:
        """DIAGNOSTIC ONLY — recv that blocks for arrival but skips slot read.

        Same blocking semantics as ``tl.recv``: the kernel waits until the
        payload has landed in the IPCQ slot. Differs from ``tl.recv`` by
        skipping the slot-read latency charge (slot-IO + PE↔bank fabric
        drain) on DST.

        This entry point exists solely so the pe2pe overview plot can draw an
        apples-to-apples comparison against ``tl.store`` (a one-sided fabric
        write that pays no read on DST). Production kernels MUST use
        ``tl.recv`` — they need to consume the data they receive. This API is
        segregated from ``tl.recv`` so the diagnostic flag can never
        accidentally be set in real workloads.
        """
        self._emit_dispatch_overhead()
        cmd = IpcqRecvCmd(
            direction=dir,
            shape=shape,
            dtype=dtype,
            handle_id=self._next_handle_id(),
            consume=False,
        )
        result = self._emit(cmd)  # type: ignore[arg-type]
        if isinstance(result, dict):
            slot_addr = int(result.get("src_addr", 0))
            slot_space = str(result.get("src_space", "tcm"))
            return TensorHandle(
                id=self._next_handle_id(),
                addr=slot_addr,
                shape=shape,
                dtype=dtype,
                nbytes=self._nbytes(shape, dtype),
                data=None,
                space=slot_space,
            )
        return self._make_handle(addr=0, shape=shape, dtype=dtype)

    def recv_async(
        self,
        dir: str,
        shape: tuple[int, ...] = (),
        dtype: str = "f16",
    ) -> RecvFuture:
        """Non-blocking recv. Returns a future to pass into ``tl.wait``.

        Only runner mode forwards the request; in command-list mode the
        command is not recorded (and ``tl.wait`` on the future raises).
        """
        self._emit_dispatch_overhead()
        cmd = IpcqRecvCmd(
            direction=dir,
            shape=shape,
            dtype=dtype,
            handle_id=self._next_handle_id(),
            blocking=False,
        )
        future = RecvFuture(cmd=cmd)
        if self._runner is not None:
            self._runner.switch_to_simpy(("recv_async", future))
        return future

    # ── Composite + Control ───────────────────────────────────────

    def composite(
        self,
        op: Literal["gemm", "math"],
        a: TensorHandle,
        b: TensorHandle | None = None,
        out_ptr: int = 0,
        math_op: str | None = None,
    ) -> CompletionHandle:
        """Submit a composite command (non-blocking, tiled pipeline).

        Returns CompletionHandle for use with wait().
        """
        # Compute output size based on op: an (M, N) GEMM result in a's
        # dtype, otherwise the same size as ``a``.
        if op == "gemm" and b is not None:
            m = a.shape[-2]
            n = b.shape[-1]
            out_dtype = a.dtype
            out_nbytes = m * n * self._dtype_bytes(out_dtype)
        else:
            out_nbytes = a.nbytes

        completion = CompletionHandle(id=self._next_completion_id())
        self._emit_dispatch_overhead()
        self._emit(
            CompositeCmd(
                completion=completion,
                op=op,
                a=a,
                b=b,
                out_addr=out_ptr,
                out_nbytes=out_nbytes,
                math_op=math_op,
            )
        )
        return completion

    def wait(self, handle: CompletionHandle | RecvFuture | None = None) -> Any:
        """Wait for a composite, a recv future, or all pending composites.

        - ``CompletionHandle`` (or None): wait for composite completion.
          Returns None.
        - ``RecvFuture``: wait for a non-blocking ``recv_async`` to finish.
          Returns the resolved ``TensorHandle``; repeated waits on the same
          future return the cached result.

        Raises:
            RuntimeError: when waiting on a RecvFuture without a runner.
        """
        if isinstance(handle, RecvFuture):
            if handle.resolved:
                return handle.result
            if self._runner is None:
                raise RuntimeError(
                    "tl.wait(RecvFuture) requires runner mode (greenlet)"
                )
            result_dict = self._runner.switch_to_simpy(("recv_wait", handle))
            slot_addr = int(result_dict.get("src_addr", 0))
            slot_space = str(result_dict.get("src_space", "tcm"))
            data = result_dict.get("data")
            th = TensorHandle(
                id=self._next_handle_id(),
                addr=slot_addr,
                shape=handle.cmd.shape,
                dtype=handle.cmd.dtype,
                nbytes=self._nbytes(handle.cmd.shape, handle.cmd.dtype),
                data=data,
                space=slot_space,
            )
            handle.resolved = True
            handle.result = th
            return th

        # Composite path (existing behaviour)
        self._emit(WaitCmd(handle=handle))
        return None

    def cycles(self, n: int) -> None:
        """Declare PE_CPU scalar execution overhead (cycles)."""
        self._emit(PeCpuOverheadCmd(cycles=n))


# ── TensorHandle arithmetic operators ─────────────────────────────
# Enables: a + b, a * b, a - b, a / b in kernel code.
# Each operator emits a MathCmd through the TLContext that is currently
# marked active (a thread-local set by run_kernel); handles themselves do
# not carry a context reference.


def _enable_tensor_ops() -> None:
    """Patch TensorHandle with arithmetic operators.

    Called once at module load. Operators create MathCmd entries via a
    thread-local TLContext reference set during kernel execution.
    Also installs ``TLContext._set_active`` / ``TLContext._get_active``.
    """
    import threading

    _local = threading.local()

    def set_active_context(ctx: TLContext | None) -> None:
        _local.ctx = ctx

    def get_active_context() -> TLContext:
        ctx = getattr(_local, "ctx", None)
        if ctx is None:
            raise RuntimeError("TensorHandle ops require an active TLContext")
        return ctx

    def _binop(op: str):
        def method(self: TensorHandle, other: TensorHandle) -> TensorHandle:
            ctx = get_active_context()
            return ctx._binary_math(op, self, other)

        return method

    # Patch TensorHandle class with operators
    TensorHandle.__add__ = _binop("add")  # type: ignore[attr-defined]
    TensorHandle.__sub__ = _binop("sub")  # type: ignore[attr-defined]
    TensorHandle.__mul__ = _binop("mul")  # type: ignore[attr-defined]
    TensorHandle.__truediv__ = _binop("div")  # type: ignore[attr-defined]

    # Expose context management
    TLContext._set_active = staticmethod(set_active_context)  # type: ignore[attr-defined]
    TLContext._get_active = staticmethod(get_active_context)  # type: ignore[attr-defined]


_enable_tensor_ops()


def run_kernel(
    kernel_fn,
    tl_ctx: TLContext,
    *args,
    **kwargs,
) -> list[PeCommand]:
    """Execute a kernel function with the given TLContext and return commands.

    Sets tl_ctx as the active context for TensorHandle operators, calls the
    kernel (passing the context as the ``tl`` keyword), then clears the
    active context even if the kernel raises.
    """
    TLContext._set_active(tl_ctx)  # type: ignore[attr-defined]
    try:
        kernel_fn(*args, tl=tl_ctx, **kwargs)
    finally:
        TLContext._set_active(None)  # type: ignore[attr-defined]
    return tl_ctx.commands