Files
kernbench2/src/kernbench/triton_emu/tl_context.py
T
2026-05-13 15:00:41 -07:00

693 lines
26 KiB
Python

"""TLContext: fake Triton Language module for kernel performance simulation.

Passed as the `tl` parameter to kernel functions. Each API call records a
PeCommand in the internal trace (when a runner is attached, the command is
dispatched to it immediately instead of being recorded). After the kernel
returns, PE_CPU replays the command list through SimPy.

Kernel code looks like standard Python — no yield, no async::

    def my_kernel(a_ptr, b_ptr, out_ptr, tl):
        pid = tl.program_id(0)
        a = tl.load(a_ptr, shape=(32, 64), dtype="f16")
        b = tl.load(b_ptr + pid * stride, shape=(64, 32), dtype="f16")
        tl.composite(op="gemm", a=a, b=b, out_ptr=out_ptr)
"""
from __future__ import annotations

import math
from typing import Any, Literal

from kernbench.common.ipcq_types import IpcqRecvCmd, IpcqSendCmd, RecvFuture
from kernbench.common.pe_commands import (
    CompletionHandle,
    CompositeCmd,
    DmaReadCmd,
    DmaWriteCmd,
    GemmCmd,
    MathCmd,
    PeCommand,
    PeCpuOverheadCmd,
    TensorHandle,
    WaitCmd,
)
_DTYPE_BYTES: dict[str, int] = {
"f16": 2, "f32": 4, "f64": 8,
"bf16": 2,
"i8": 1, "i16": 2, "i32": 4, "i64": 8,
"u8": 1, "u16": 2, "u32": 4, "u64": 8,
}
class TLContext:
    """Fake Triton Language context that turns ``tl.*`` calls into PeCommands.

    With no ``runner`` every command is appended to :attr:`commands`
    ("command-list mode"). With a ``runner`` each command is passed to
    ``runner.switch_to_simpy`` and its return value (e.g. loaded data or a
    recv result dict) flows back into the kernel ("greenlet/runner mode").

    Args:
        pe_id: program instance index along axis 0 (``program_id(0)``).
        num_programs: number of program instances along axis 0
            (``num_programs(0)``).
        dispatch_cycles: PE_CPU overhead per tl API call, auto-inserted as a
            ``PeCpuOverheadCmd``; a value <= 0 disables it.
        runner: optional KernelRunner (ADR-0020 D3); when set, commands are
            dispatched through ``runner.switch_to_simpy`` instead of recorded.
        cube_id: value returned by ``program_id(1)``.
        num_cubes: value returned by ``num_programs(1)``.
        scratch_base: base address of the PE-local scratch pool used for
            compute outputs; 0 disables scratch allocation (handles get
            addr=0).
        scratch_size: capacity of the scratch pool in bytes.
    """

    def __init__(
        self,
        pe_id: int = 0,
        num_programs: int = 1,
        dispatch_cycles: int = 1,
        runner: Any = None,
        cube_id: int = 0,
        num_cubes: int = 1,
        scratch_base: int = 0,
        scratch_size: int = 1 << 20,  # 1 MiB per kernel invocation
    ) -> None:
        self._pe_id = pe_id
        self._num_programs = num_programs
        self._cube_id = cube_id
        self._num_cubes = num_cubes
        self._dispatch_cycles = dispatch_cycles
        self._commands: list[PeCommand] = []
        self._handle_counter = 0
        self._completion_counter = 0
        self._runner = runner  # KernelRunner for greenlet mode (ADR-0020 D3)
        # PE-local bump allocator for math/compute output handles. Each
        # compute op reserves a unique addr from this pool so the resulting
        # TensorHandle can be the source of a later tl.send / tl.store.
        # The cursor only moves forward and starts at 0 for each new
        # TLContext (presumably one per kernel invocation — confirm).
        self._scratch_base = scratch_base
        self._scratch_size = scratch_size
        self._scratch_cursor = 0

    def _scratch_alloc(self, nbytes: int) -> int:
        """Reserve ``nbytes`` (rounded up to 16 bytes) from the scratch pool.

        Returns:
            The absolute address of the reservation, or 0 when no scratch
            base is configured (e.g. command-list mode). A handle with
            addr=0 cannot be used as a send/store source.

        Raises:
            RuntimeError: if the reservation would exceed ``scratch_size``.
                The pool is left unchanged in that case.
        """
        if self._scratch_base == 0:
            return 0
        aligned = (nbytes + 15) & ~15  # round up to a 16-byte boundary
        new_cursor = self._scratch_cursor + aligned
        if new_cursor > self._scratch_size:
            # Check before committing so a rejected request does not leak
            # pool space that later (smaller) allocations could still use.
            raise RuntimeError(
                f"TLContext scratch overflow: requested {nbytes}B, "
                f"used {self._scratch_cursor}/{self._scratch_size}B"
            )
        addr = self._scratch_base + self._scratch_cursor
        self._scratch_cursor = new_cursor
        return addr

    @property
    def commands(self) -> list[PeCommand]:
        """Return the recorded command trace.

        Stays empty in runner mode, where ``_emit`` dispatches commands to
        the runner instead of appending them.
        """
        return self._commands

    # ── helpers ────────────────────────────────────────────────────
    def _next_handle_id(self) -> str:
        """Return a fresh tensor-handle id ("t1", "t2", ...)."""
        self._handle_counter += 1
        return f"t{self._handle_counter}"

    def _next_completion_id(self) -> str:
        """Return a fresh completion-handle id ("c1", "c2", ...)."""
        self._completion_counter += 1
        return f"c{self._completion_counter}"

    def _dtype_bytes(self, dtype: str) -> int:
        """Element size of ``dtype``; unknown names fall back to 2 bytes."""
        return _DTYPE_BYTES.get(dtype, 2)

    def _nbytes(self, shape: tuple[int, ...], dtype: str) -> int:
        """Total byte size of a tensor (a scalar shape ``()`` is one element)."""
        return math.prod(shape) * self._dtype_bytes(dtype)

    def _emit_dispatch_overhead(self) -> None:
        """Charge the per-call PE_CPU dispatch cost, if one is configured."""
        if self._dispatch_cycles > 0:
            self._emit(PeCpuOverheadCmd(cycles=self._dispatch_cycles))

    def _emit(self, cmd: PeCommand) -> Any:
        """Emit command: greenlet switch if runner available, else append to list.

        Returns the runner's reply in runner mode, otherwise None.
        """
        if self._runner is not None:
            return self._runner.switch_to_simpy(cmd)
        self._commands.append(cmd)
        return None

    def _make_handle(
        self, addr: int, shape: tuple[int, ...], dtype: str,
        space: str = "tcm", pinned: bool = False,
    ) -> TensorHandle:
        """Build a TensorHandle at ``addr`` without reserving any scratch."""
        return TensorHandle(
            id=self._next_handle_id(),
            addr=addr, shape=shape, dtype=dtype,
            nbytes=self._nbytes(shape, dtype),
            space=space,
            pinned=pinned,
        )

    def _make_compute_out(
        self, shape: tuple[int, ...], dtype: str,
    ) -> TensorHandle:
        """Allocate an output TensorHandle in PE-local scratch (TCM space).

        Used by math/compute ops so the result has a real address that can
        be the source of a later send/store. The data field stays None in
        Phase 1 — Phase 2 DataExecutor fills the actual ndarray.
        """
        nbytes = self._nbytes(shape, dtype)
        addr = self._scratch_alloc(nbytes)
        return TensorHandle(
            id=self._next_handle_id(),
            addr=addr, shape=shape, dtype=dtype,
            nbytes=nbytes, space="tcm",
        )

    def _handle_from_slot(
        self,
        result: dict,
        shape: tuple[int, ...],
        dtype: str,
        data: Any = None,
    ) -> TensorHandle:
        """Build the TensorHandle describing a completed IPCQ receive.

        ``result`` is the dict returned by the runner; its ``src_addr`` and
        ``src_space`` entries (defaulting to 0 and "tcm") locate the slot or
        destination that now holds the payload.
        """
        slot_addr = int(result.get("src_addr", 0))
        slot_space = str(result.get("src_space", "tcm"))
        return TensorHandle(
            id=self._next_handle_id(),
            addr=slot_addr,
            shape=shape,
            dtype=dtype,
            nbytes=self._nbytes(shape, dtype),
            data=data,
            space=slot_space,
        )

    # ── Reference (no DMA, metadata only) ────────────────────────
    def ref(
        self, ptr: int, shape: tuple[int, ...], dtype: str = "f16",
    ) -> TensorHandle:
        """Create a TensorHandle referencing HBM data without issuing DMA.

        Used when the scheduler will stream data per-tile (e.g., tensor b
        in a composite GEMM). No command is generated and no dispatch
        overhead is charged.

        NOTE(review): the handle gets the default ``space="tcm"`` although
        ``ptr`` is an HBM address (``load`` uses ``space="hbm"`` so later
        ops resolve the data at ``(hbm, ptr)``). Fine if consumers only read
        ``addr``; confirm before using a ref handle as a send/store/math
        source.
        """
        return self._make_handle(addr=ptr, shape=shape, dtype=dtype)

    # ── Data Movement (blocking, DMA engine) ──────────────────────
    def load(
        self, ptr: int, shape: tuple[int, ...], dtype: str = "f16",
    ) -> TensorHandle:
        """Load tensor from HBM. Returns TensorHandle pointing at HBM[ptr].

        In greenlet mode: returns TensorHandle with actual numpy data.
        In command-list mode: returns TensorHandle with data=None.

        The returned handle's ``space`` is "hbm" so subsequent ops (math,
        send, store) using this handle as a source resolve via MemoryStore
        at ``(hbm, ptr)`` — which is where the load's underlying data
        actually lives in Phase 2 storage.
        """
        self._emit_dispatch_overhead()
        handle = self._make_handle(
            addr=ptr, shape=shape, dtype=dtype, space="hbm", pinned=True,
        )
        data = self._emit(
            DmaReadCmd(handle=handle, src_addr=ptr, nbytes=handle.nbytes)
        )
        if data is None:
            return handle
        # Greenlet mode: attach real data to handle (preserve space + pinned)
        return TensorHandle(
            id=handle.id, addr=handle.addr, shape=handle.shape,
            dtype=handle.dtype, nbytes=handle.nbytes, data=data,
            space=handle.space, pinned=handle.pinned,
        )

    def store(self, ptr: int, handle: TensorHandle) -> None:
        """Write ``handle``'s bytes to HBM at ``ptr`` (DMA write)."""
        self._emit_dispatch_overhead()
        self._emit(DmaWriteCmd(handle=handle, dst_addr=ptr, nbytes=handle.nbytes))

    # ── GEMM Engine (blocking) ────────────────────────────────────
    def dot(self, a: TensorHandle, b: TensorHandle) -> TensorHandle:
        """Matrix multiply: out = a @ b on the GEMM engine.

        a: (..., M, K), b: (..., K, N) → out: (*a.shape[:-2], M, N) with
        a's dtype, placed in scratch.

        Raises:
            ValueError: if either operand has fewer than 2 dims or the
                inner (K) dimensions differ.
        """
        if len(a.shape) < 2 or len(b.shape) < 2:
            raise ValueError("dot requires 2D tensors")
        m, k = a.shape[-2], a.shape[-1]
        k2, n = b.shape[-2], b.shape[-1]
        if k != k2:
            raise ValueError(f"dot shape mismatch: a.K={k} != b.K={k2}")
        out = self._make_compute_out(shape=(*a.shape[:-2], m, n), dtype=a.dtype)
        self._emit_dispatch_overhead()
        self._emit(GemmCmd(a=a, b=b, out=out, m=m, k=k, n=n))
        return out

    # ── MATH Engine: unary (blocking) ─────────────────────────────
    def _unary_math(self, op: str, x: TensorHandle) -> TensorHandle:
        """Emit a single-input elementwise MathCmd; output mirrors ``x``."""
        out = self._make_compute_out(shape=x.shape, dtype=x.dtype)
        self._emit_dispatch_overhead()
        self._emit(MathCmd(op=op, inputs=(x,), out=out))
        return out

    def exp(self, x: TensorHandle) -> TensorHandle:
        return self._unary_math("exp", x)

    def log(self, x: TensorHandle) -> TensorHandle:
        return self._unary_math("log", x)

    def sqrt(self, x: TensorHandle) -> TensorHandle:
        return self._unary_math("sqrt", x)

    def abs(self, x: TensorHandle) -> TensorHandle:
        return self._unary_math("abs", x)

    def sigmoid(self, x: TensorHandle) -> TensorHandle:
        return self._unary_math("sigmoid", x)

    def cos(self, x: TensorHandle) -> TensorHandle:
        return self._unary_math("cos", x)

    def sin(self, x: TensorHandle) -> TensorHandle:
        return self._unary_math("sin", x)

    # ── MATH Engine: reduction (blocking) ─────────────────────────
    def _reduction(
        self, op: str, x: TensorHandle, axis: int,
    ) -> TensorHandle:
        """Emit a reduction MathCmd; the reduced axis is kept with size 1.

        Negative ``axis`` values index from the end; an out-of-range axis
        raises IndexError.
        """
        out_shape = list(x.shape)
        out_shape[axis] = 1
        out = self._make_compute_out(shape=tuple(out_shape), dtype=x.dtype)
        self._emit_dispatch_overhead()
        self._emit(MathCmd(op=op, inputs=(x,), out=out, axis=axis))
        return out

    def sum(self, x: TensorHandle, axis: int) -> TensorHandle:
        return self._reduction("sum", x, axis)

    def max(self, x: TensorHandle, axis: int) -> TensorHandle:
        return self._reduction("max", x, axis)

    def min(self, x: TensorHandle, axis: int) -> TensorHandle:
        return self._reduction("min", x, axis)

    # ── MATH Engine: binary (blocking) ────────────────────────────
    def _binary_math(
        self, op: str, a: TensorHandle, b: TensorHandle,
    ) -> TensorHandle:
        """Emit a two-input elementwise MathCmd.

        NOTE(review): the output takes ``a``'s shape and dtype; no
        broadcasting is applied, so ``small - large`` sizes the result
        after ``small``.
        """
        out = self._make_compute_out(shape=a.shape, dtype=a.dtype)
        self._emit_dispatch_overhead()
        self._emit(MathCmd(op=op, inputs=(a, b), out=out))
        return out

    def where(
        self, cond: TensorHandle, a: TensorHandle, b: TensorHandle,
    ) -> TensorHandle:
        """Element-wise select ``a`` where ``cond`` else ``b`` (out mirrors a)."""
        out = self._make_compute_out(shape=a.shape, dtype=a.dtype)
        self._emit_dispatch_overhead()
        self._emit(MathCmd(op="where", inputs=(cond, a, b), out=out))
        return out

    def maximum(self, a: TensorHandle, b: TensorHandle) -> TensorHandle:
        """Element-wise max of two tensors (real Triton: tl.maximum)."""
        return self._binary_math("maximum", a, b)

    def minimum(self, a: TensorHandle, b: TensorHandle) -> TensorHandle:
        """Element-wise min of two tensors (real Triton: tl.minimum)."""
        return self._binary_math("minimum", a, b)

    def fma(
        self, a: TensorHandle, b: TensorHandle, c: TensorHandle,
    ) -> TensorHandle:
        """Fused multiply-add: a * b + c (real Triton: tl.fma)."""
        out = self._make_compute_out(shape=a.shape, dtype=a.dtype)
        self._emit_dispatch_overhead()
        self._emit(MathCmd(op="fma", inputs=(a, b, c), out=out))
        return out

    def clamp(
        self,
        x: TensorHandle,
        min: TensorHandle,
        max: TensorHandle,
    ) -> TensorHandle:
        """Clamp x to [min, max] (real Triton: tl.clamp)."""
        out = self._make_compute_out(shape=x.shape, dtype=x.dtype)
        self._emit_dispatch_overhead()
        self._emit(MathCmd(op="clamp", inputs=(x, min, max), out=out))
        return out

    def softmax(self, x: TensorHandle, axis: int = -1) -> TensorHandle:
        """Numerically-stable softmax along ``axis`` (real Triton: tl.softmax).

        Implemented as a single MathCmd (op="softmax") so timing accounts
        for one MATH dispatch; Phase 2 DataExecutor expands it to the
        canonical (x - max) → exp → sum → div sequence.
        """
        out = self._make_compute_out(shape=x.shape, dtype=x.dtype)
        self._emit_dispatch_overhead()
        self._emit(MathCmd(op="softmax", inputs=(x,), out=out, axis=axis))
        return out

    # ── Scalar helpers (real Triton: tl.cdiv etc.) ────────────────
    @staticmethod
    def cdiv(a: int, b: int) -> int:
        """Ceiling division: (a + b - 1) // b (real Triton: tl.cdiv).

        Used by host/kernel grid math; not a tensor op, so no MathCmd
        is emitted. Mirrors triton.cdiv.
        """
        return -(-int(a) // int(b))

    # ── Index / Scalar (PE_CPU, no engine) ────────────────────────
    def program_id(self, axis: int = 0) -> int:
        """Return program instance index (ADR-0022).

        axis=1: cube id. Any other axis (normally 0): local PE id within
        the cube.
        """
        return self._cube_id if axis == 1 else self._pe_id

    def num_programs(self, axis: int = 0) -> int:
        """Return total number of program instances (ADR-0022).

        axis=1: number of cubes. Any other axis (normally 0): number of PEs
        per cube.
        """
        return self._num_cubes if axis == 1 else self._num_programs

    def arange(self, start: int, end: int, dtype: str = "i32") -> TensorHandle:
        """Create a metadata-only index-range handle of length ``end - start``.

        No command is emitted and no scratch is reserved (addr is 0).

        Raises:
            ValueError: if ``end < start`` (the length would be negative).
        """
        n = end - start
        if n < 0:
            raise ValueError(f"arange: end ({end}) must be >= start ({start})")
        return self._make_handle(addr=0, shape=(n,), dtype=dtype)

    def zeros(self, shape: tuple[int, ...], dtype: str = "f16") -> TensorHandle:
        """Create a metadata-only zero-filled tensor handle (addr 0, no command)."""
        return self._make_handle(addr=0, shape=shape, dtype=dtype)

    def full(
        self, shape: tuple[int, ...], value: float | int, dtype: str = "f16",
    ) -> TensorHandle:
        """Create a metadata-only constant-filled tensor handle.

        ``value`` is accepted for API compatibility but is not recorded;
        the handle has addr 0 and no command is emitted.
        """
        return self._make_handle(addr=0, shape=shape, dtype=dtype)

    # ── Metadata (no compute, no DMA) ─────────────────────────────
    def trans(self, x: TensorHandle) -> TensorHandle:
        """Transpose the last two dims — shape change only, no command.

        The result keeps ``x``'s id, addr, nbytes, space and pinned flag so
        it still resolves to the same storage as ``x``.

        NOTE(review): ``data`` is passed through un-transposed; confirm that
        consumers of ``handle.data`` (e.g. ``send``'s snapshot) expect this.

        Raises:
            ValueError: if ``x`` has fewer than 2 dims.
        """
        if len(x.shape) < 2:
            raise ValueError("trans requires at least 2D tensor")
        new_shape = (*x.shape[:-2], x.shape[-1], x.shape[-2])
        return TensorHandle(
            id=x.id, addr=x.addr, shape=new_shape,
            dtype=x.dtype, nbytes=x.nbytes, data=x.data,
            # Preserve placement: dropping these turned an HBM-loaded handle
            # into a TCM/unpinned one, breaking later (space, addr) lookups.
            space=x.space, pinned=x.pinned,
        )

    # ── IPCQ (CCL) collective primitives (ADR-0023 D4) ────────────
    def send(
        self,
        dir: str,
        src: TensorHandle | None = None,
        *,
        src_addr: int | None = None,
        nbytes: int | None = None,
        shape: tuple[int, ...] | None = None,
        dtype: str = "f16",
        space: str = "tcm",
    ) -> None:
        """Send tensor data to the peer in the given direction.

        Two calling forms:
            tl.send(dir, handle)   # use handle's metadata
            tl.send(dir, src_addr=..., nbytes=..., shape=..., dtype=..., space=...)

        Blocking: returns when PE_IPCQ has accepted the request and
        forwarded the IpcqDmaToken to PE_DMA. Backpressure may apply.

        Raises:
            ValueError: if neither a handle nor src_addr/nbytes/shape is given.
        """
        if src is not None:
            # Handle form: the handle's metadata overrides keyword values.
            src_addr, nbytes, shape, dtype = src.addr, src.nbytes, src.shape, src.dtype
            space = getattr(src, "space", space)
        if src_addr is None or nbytes is None or shape is None:
            raise ValueError("tl.send: provide either a TensorHandle or src_addr/nbytes/shape")
        # Carry the handle's .data snapshot (if available). When the source
        # is a recv slot, .data holds the numpy array that was read from
        # MemoryStore at recv-time. This prevents a Phase 1 race where a
        # later IPCQ inbound overwrites the slot before the outbound
        # PE_DMA reads it.
        handle_data = None if src is None else getattr(src, "data", None)
        self._emit_dispatch_overhead()
        self._emit(IpcqSendCmd(
            direction=dir,
            src_addr=src_addr, src_space=space,
            nbytes=nbytes, shape=shape, dtype=dtype,
            handle_id=self._next_handle_id(),
            data=handle_data,
        ))

    def recv(
        self,
        dir: str | None = None,
        shape: tuple[int, ...] = (),
        dtype: str = "f16",
        space: str = "tcm",
        dst_addr: int | None = None,
        dst_space: str | None = None,
    ) -> TensorHandle:
        """Receive tensor data from a peer.

        Args:
            dir: specific direction (e.g. "W"), or None for round-robin.
            shape, dtype: expected tensor metadata.
            space: accepted for API compatibility; currently unused.
            dst_addr / dst_space: if both are provided, the slot data is
                copied to (dst_space, dst_addr) before the handle is
                returned ("copy_to_dst" mode). Otherwise the slot address
                is returned directly ("return_slot" mode).

        Returns:
            TensorHandle pointing to the slot (or dst) where the data has
            arrived. In greenlet/runner mode, ``handle.data`` carries the
            actual ndarray; in command-list mode the handle is a placeholder.
        """
        self._emit_dispatch_overhead()
        mode_kwargs: dict[str, Any] = {}
        if dst_addr is not None and dst_space is not None:
            mode_kwargs = {
                "recv_mode": "copy_to_dst",
                "dst_addr": dst_addr,
                "dst_space": dst_space,
            }
        result = self._emit(IpcqRecvCmd(
            direction=dir,
            shape=shape, dtype=dtype,
            handle_id=self._next_handle_id(),
            **mode_kwargs,
        ))
        if isinstance(result, dict):
            return self._handle_from_slot(result, shape, dtype, data=result.get("data"))
        return self._make_handle(addr=0, shape=shape, dtype=dtype)

    def recv_no_consume(
        self,
        dir: str | None = None,
        shape: tuple[int, ...] = (),
        dtype: str = "f16",
    ) -> TensorHandle:
        """DIAGNOSTIC ONLY — recv that blocks for arrival but skips slot read.

        Same blocking semantics as ``tl.recv``: the kernel waits until
        the payload has landed in the IPCQ slot. Differs from ``tl.recv``
        by skipping the slot-read latency charge (slot-IO + PE↔bank
        fabric drain) on DST. The returned handle never carries data.

        This entry point exists solely so the pe2pe overview plot can
        draw an apples-to-apples comparison against ``tl.store`` (a
        one-sided fabric write that pays no read on DST). Production
        kernels MUST use ``tl.recv`` — they need to consume the data
        they receive. This API is segregated from ``tl.recv`` so the
        diagnostic flag can never accidentally be set in real workloads.
        """
        self._emit_dispatch_overhead()
        result = self._emit(IpcqRecvCmd(
            direction=dir,
            shape=shape, dtype=dtype,
            handle_id=self._next_handle_id(),
            consume=False,
        ))
        if isinstance(result, dict):
            return self._handle_from_slot(result, shape, dtype, data=None)
        return self._make_handle(addr=0, shape=shape, dtype=dtype)

    def recv_async(
        self,
        dir: str,
        shape: tuple[int, ...] = (),
        dtype: str = "f16",
    ) -> "RecvFuture":
        """Non-blocking recv. Returns a future to pass into ``tl.wait``.

        In command-list mode the recv command is not recorded (only the
        dispatch overhead is); ``tl.wait`` on the future then raises, as
        futures are only supported in runner mode.
        """
        self._emit_dispatch_overhead()
        future = RecvFuture(cmd=IpcqRecvCmd(
            direction=dir,
            shape=shape, dtype=dtype,
            handle_id=self._next_handle_id(),
            blocking=False,
        ))
        if self._runner is not None:
            self._runner.switch_to_simpy(("recv_async", future))
        return future

    # ── Composite + Control ───────────────────────────────────────
    def composite(
        self,
        op: Literal["gemm", "math"],
        a: TensorHandle,
        b: TensorHandle | None = None,
        out_ptr: int = 0,
        math_op: str | None = None,
    ) -> CompletionHandle:
        """Submit a composite command (non-blocking, tiled pipeline).

        The output size is ``(*a.shape[:-2], M, N)`` elements of a's dtype
        for a GEMM with ``b`` (the same layout ``dot`` produces); otherwise
        it equals ``a.nbytes``.

        Returns:
            CompletionHandle for use with wait().
        """
        if op == "gemm" and b is not None:
            m, n = a.shape[-2], b.shape[-1]
            # Include leading batch dims, consistent with dot(); for 2-D
            # operands this is exactly M * N * elem_size.
            out_nbytes = self._nbytes((*a.shape[:-2], m, n), a.dtype)
        else:
            out_nbytes = a.nbytes
        completion = CompletionHandle(id=self._next_completion_id())
        self._emit_dispatch_overhead()
        self._emit(CompositeCmd(
            completion=completion, op=op,
            a=a, b=b, out_addr=out_ptr, out_nbytes=out_nbytes,
            math_op=math_op,
        ))
        return completion

    def wait(self, handle: "CompletionHandle | RecvFuture | None" = None) -> Any:
        """Wait for a composite, a recv future, or all pending composites.

        - ``CompletionHandle`` (or None): emit a WaitCmd for composite
          completion; returns None.
        - ``RecvFuture``: wait for a non-blocking ``recv_async`` to finish
          and return the resolved ``TensorHandle`` (cached on the future,
          so waiting twice returns the same handle).

        Raises:
            RuntimeError: when waiting on a RecvFuture without a runner.
        """
        if not isinstance(handle, RecvFuture):
            # Composite path (existing behaviour)
            self._emit(WaitCmd(handle=handle))
            return None
        if handle.resolved:
            return handle.result
        if self._runner is None:
            raise RuntimeError(
                "tl.wait(RecvFuture) requires runner mode (greenlet)"
            )
        result = self._runner.switch_to_simpy(("recv_wait", handle))
        resolved = self._handle_from_slot(
            result, handle.cmd.shape, handle.cmd.dtype, data=result.get("data"),
        )
        handle.resolved = True
        handle.result = resolved
        return resolved

    def cycles(self, n: int) -> None:
        """Declare PE_CPU scalar execution overhead (cycles)."""
        self._emit(PeCpuOverheadCmd(cycles=n))
# ── TensorHandle arithmetic operators ─────────────────────────────
# Enables: a + b, a * b, a - b, a / b in kernel code.
# Each operator forwards to TLContext._binary_math on the *active*
# TLContext, which is looked up from a thread-local slot that callers
# populate via TLContext._set_active (see run_kernel).
def _enable_tensor_ops() -> None:
    """Patch TensorHandle with arithmetic operators.

    Called once at module load. Installs ``__add__``, ``__sub__``,
    ``__mul__`` and ``__truediv__`` on TensorHandle; each emits a MathCmd
    (ops "add", "sub", "mul", "div") through the TLContext registered as
    active for the current thread. Also exposes ``TLContext._set_active``
    and ``TLContext._get_active`` to manage that registration;
    ``_get_active`` raises RuntimeError when no context is active.

    NOTE(review): the active context is stored per OS thread. If several
    kernels run as greenlets on one thread and interleave, they share this
    slot — confirm the runner re-activates the right context on each switch.
    """
    import threading

    state = threading.local()

    def set_active_context(ctx: TLContext | None) -> None:
        state.ctx = ctx

    def get_active_context() -> TLContext:
        ctx = getattr(state, "ctx", None)
        if ctx is None:
            raise RuntimeError("TensorHandle ops require an active TLContext")
        return ctx

    def make_operator(dunder: str, op: str):
        def operator_method(self: TensorHandle, other: TensorHandle) -> TensorHandle:
            return get_active_context()._binary_math(op, self, other)

        # Give tracebacks and introspection a meaningful name instead of a
        # generic inner-function name.
        operator_method.__name__ = dunder
        operator_method.__qualname__ = f"TensorHandle.{dunder}"
        return operator_method

    for dunder, op in (
        ("__add__", "add"),
        ("__sub__", "sub"),
        ("__mul__", "mul"),
        ("__truediv__", "div"),
    ):
        setattr(TensorHandle, dunder, make_operator(dunder, op))

    # Expose context management
    TLContext._set_active = staticmethod(set_active_context)  # type: ignore[attr-defined]
    TLContext._get_active = staticmethod(get_active_context)  # type: ignore[attr-defined]


_enable_tensor_ops()
def run_kernel(
    kernel_fn,
    tl_ctx: TLContext,
    *args,
    **kwargs,
) -> list[PeCommand]:
    """Execute a kernel function with the given TLContext and return commands.

    ``tl_ctx`` is made the active context for TensorHandle operators
    (``a + b`` etc.) while ``kernel_fn(*args, tl=tl_ctx, **kwargs)`` runs.
    Whatever context was active before is restored afterwards — also when
    the kernel raises — so a nested ``run_kernel`` call does not leave the
    outer kernel without its context.

    Args:
        kernel_fn: kernel callable; receives ``*args``, ``**kwargs`` and
            ``tl=tl_ctx``.
        tl_ctx: the TLContext that records the kernel's commands.

    Returns:
        ``tl_ctx.commands`` — the recorded trace (empty in runner mode,
        where commands are dispatched instead of appended).
    """
    try:
        previous = TLContext._get_active()  # type: ignore[attr-defined]
    except RuntimeError:
        previous = None  # no context was active on this thread
    TLContext._set_active(tl_ctx)  # type: ignore[attr-defined]
    try:
        kernel_fn(*args, tl=tl_ctx, **kwargs)
    finally:
        # Restore rather than clear, so nesting keeps the outer context.
        TLContext._set_active(previous)  # type: ignore[attr-defined]
    return tl_ctx.commands