commit - release 1

This commit is contained in:
2026-03-18 11:47:48 -07:00
commit 6f43807900
109 changed files with 14909 additions and 0 deletions
+11
View File
@@ -0,0 +1,11 @@
"""Triton emulator: fake tl module for kernel performance simulation.
Provides TLContext (the fake `tl` parameter) that kernels use to express
memory access patterns and compute operations. Kernel functions are plain
Python — no yield, no async — and generate a PeCommand trace that PE_CPU
replays through SimPy.
Usage:
from kernbench.triton_emu.registry import register_kernel, get_kernel
from kernbench.triton_emu.tl_context import TLContext
"""
+30
View File
@@ -0,0 +1,30 @@
"""Kernel registry: maps kernel names to Python callable generators.
Benchmarks register kernel functions here; PE_CPU looks them up by
KernelRef.name at execution time.
"""
from __future__ import annotations
from collections.abc import Callable
from typing import Any
_kernels: dict[str, Callable[..., None]] = {}
def register_kernel(name: str, fn: Callable[..., None]) -> None:
"""Register a kernel function by name."""
if name in _kernels:
raise ValueError(f"kernel '{name}' already registered")
_kernels[name] = fn
def get_kernel(name: str) -> Callable[..., None]:
"""Look up a registered kernel function by name."""
if name not in _kernels:
raise KeyError(f"kernel '{name}' not registered")
return _kernels[name]
def clear_registry() -> None:
"""Clear all registered kernels (for testing)."""
_kernels.clear()
+356
View File
@@ -0,0 +1,356 @@
"""TLContext: fake Triton Language module for kernel performance simulation.
Passed as the `tl` parameter to kernel functions. Each API call records a
PeCommand in the internal trace. After the kernel returns, PE_CPU replays
the command list through SimPy.
Kernel code looks like standard Python — no yield, no async:
def my_kernel(a_ptr, b_ptr, out_ptr, tl):
pid = tl.program_id(0)
a = tl.load(a_ptr, shape=(32, 64), dtype="f16")
b = tl.load(b_ptr + pid * stride, shape=(64, 32), dtype="f16")
tl.composite(op="gemm", a=a, b=b, out_ptr=out_ptr)
"""
from __future__ import annotations
import math
from typing import Literal
from kernbench.common.pe_commands import (
CompletionHandle,
CompositeCmd,
DmaReadCmd,
DmaWriteCmd,
GemmCmd,
MathCmd,
PeCommand,
PeCpuOverheadCmd,
TensorHandle,
WaitCmd,
)
_DTYPE_BYTES: dict[str, int] = {
"f16": 2, "f32": 4, "f64": 8,
"bf16": 2,
"i8": 1, "i16": 2, "i32": 4, "i64": 8,
"u8": 1, "u16": 2, "u32": 4, "u64": 8,
}
class TLContext:
"""Fake Triton Language context.
Args:
pe_id: program instance index (returned by program_id).
num_programs: total number of program instances.
dispatch_cycles: PE_CPU overhead per tl API call (auto-inserted).
"""
def __init__(
self,
pe_id: int = 0,
num_programs: int = 1,
dispatch_cycles: int = 1,
) -> None:
self._pe_id = pe_id
self._num_programs = num_programs
self._dispatch_cycles = dispatch_cycles
self._commands: list[PeCommand] = []
self._handle_counter = 0
self._completion_counter = 0
@property
def commands(self) -> list[PeCommand]:
"""Return the recorded command trace."""
return self._commands
# ── helpers ────────────────────────────────────────────────────
def _next_handle_id(self) -> str:
self._handle_counter += 1
return f"t{self._handle_counter}"
def _next_completion_id(self) -> str:
self._completion_counter += 1
return f"c{self._completion_counter}"
def _dtype_bytes(self, dtype: str) -> int:
return _DTYPE_BYTES.get(dtype, 2)
def _nbytes(self, shape: tuple[int, ...], dtype: str) -> int:
return math.prod(shape) * self._dtype_bytes(dtype)
def _emit_dispatch_overhead(self) -> None:
if self._dispatch_cycles > 0:
self._commands.append(PeCpuOverheadCmd(cycles=self._dispatch_cycles))
def _make_handle(
self, pa: int, shape: tuple[int, ...], dtype: str,
) -> TensorHandle:
return TensorHandle(
id=self._next_handle_id(),
pa=pa, shape=shape, dtype=dtype,
nbytes=self._nbytes(shape, dtype),
)
# ── Reference (no DMA, metadata only) ────────────────────────
def ref(
self, ptr: int, shape: tuple[int, ...], dtype: str = "f16",
) -> TensorHandle:
"""Create a TensorHandle referencing HBM data without issuing DMA.
Used when the scheduler will stream data per-tile (e.g., tensor b
in a composite GEMM). No command is generated.
"""
return self._make_handle(pa=ptr, shape=shape, dtype=dtype)
# ── Data Movement (blocking, DMA engine) ──────────────────────
def load(
self, ptr: int, shape: tuple[int, ...], dtype: str = "f16",
) -> TensorHandle:
"""Load tensor from HBM to TCM. Returns TensorHandle."""
self._emit_dispatch_overhead()
handle = self._make_handle(pa=ptr, shape=shape, dtype=dtype)
self._commands.append(DmaReadCmd(
handle=handle, src_pa=ptr, nbytes=handle.nbytes,
))
return handle
def store(self, ptr: int, handle: TensorHandle) -> None:
"""Store tensor from TCM to HBM."""
self._emit_dispatch_overhead()
self._commands.append(DmaWriteCmd(
handle=handle, dst_pa=ptr, nbytes=handle.nbytes,
))
# ── GEMM Engine (blocking) ────────────────────────────────────
def dot(self, a: TensorHandle, b: TensorHandle) -> TensorHandle:
"""Matrix multiply: out = a @ b. Both operands must be in TCM.
a: (M, K), b: (K, N) → out: (M, N)
"""
if len(a.shape) < 2 or len(b.shape) < 2:
raise ValueError("dot requires 2D tensors")
m, k = a.shape[-2], a.shape[-1]
k2, n = b.shape[-2], b.shape[-1]
if k != k2:
raise ValueError(f"dot shape mismatch: a.K={k} != b.K={k2}")
out_shape = (*a.shape[:-2], m, n)
out_dtype = a.dtype
out = self._make_handle(pa=0, shape=out_shape, dtype=out_dtype)
self._emit_dispatch_overhead()
self._commands.append(GemmCmd(a=a, b=b, out=out, m=m, k=k, n=n))
return out
# ── MATH Engine: unary (blocking) ─────────────────────────────
def _unary_math(self, op: str, x: TensorHandle) -> TensorHandle:
out = self._make_handle(pa=0, shape=x.shape, dtype=x.dtype)
self._emit_dispatch_overhead()
self._commands.append(MathCmd(op=op, inputs=(x,), out=out))
return out
def exp(self, x: TensorHandle) -> TensorHandle:
return self._unary_math("exp", x)
def log(self, x: TensorHandle) -> TensorHandle:
return self._unary_math("log", x)
def sqrt(self, x: TensorHandle) -> TensorHandle:
return self._unary_math("sqrt", x)
def abs(self, x: TensorHandle) -> TensorHandle:
return self._unary_math("abs", x)
def sigmoid(self, x: TensorHandle) -> TensorHandle:
return self._unary_math("sigmoid", x)
def cos(self, x: TensorHandle) -> TensorHandle:
return self._unary_math("cos", x)
def sin(self, x: TensorHandle) -> TensorHandle:
return self._unary_math("sin", x)
# ── MATH Engine: reduction (blocking) ─────────────────────────
def _reduction(
self, op: str, x: TensorHandle, axis: int,
) -> TensorHandle:
out_shape = list(x.shape)
out_shape[axis] = 1
out = self._make_handle(pa=0, shape=tuple(out_shape), dtype=x.dtype)
self._emit_dispatch_overhead()
self._commands.append(MathCmd(op=op, inputs=(x,), out=out, axis=axis))
return out
def sum(self, x: TensorHandle, axis: int) -> TensorHandle:
return self._reduction("sum", x, axis)
def max(self, x: TensorHandle, axis: int) -> TensorHandle:
return self._reduction("max", x, axis)
def min(self, x: TensorHandle, axis: int) -> TensorHandle:
return self._reduction("min", x, axis)
# ── MATH Engine: binary (blocking) ────────────────────────────
def _binary_math(
self, op: str, a: TensorHandle, b: TensorHandle,
) -> TensorHandle:
out = self._make_handle(pa=0, shape=a.shape, dtype=a.dtype)
self._emit_dispatch_overhead()
self._commands.append(MathCmd(op=op, inputs=(a, b), out=out))
return out
def where(
self, cond: TensorHandle, a: TensorHandle, b: TensorHandle,
) -> TensorHandle:
out = self._make_handle(pa=0, shape=a.shape, dtype=a.dtype)
self._emit_dispatch_overhead()
self._commands.append(MathCmd(op="where", inputs=(cond, a, b), out=out))
return out
# ── Index / Scalar (PE_CPU, no engine) ────────────────────────
def program_id(self, axis: int = 0) -> int:
"""Return program instance index."""
return self._pe_id
def num_programs(self, axis: int = 0) -> int:
"""Return total number of program instances."""
return self._num_programs
def arange(self, start: int, end: int, dtype: str = "i32") -> TensorHandle:
"""Create index range tensor in TCM."""
n = end - start
return self._make_handle(pa=0, shape=(n,), dtype=dtype)
def zeros(self, shape: tuple[int, ...], dtype: str = "f16") -> TensorHandle:
"""Create zero-filled tensor in TCM."""
return self._make_handle(pa=0, shape=shape, dtype=dtype)
def full(
self, shape: tuple[int, ...], value: float | int, dtype: str = "f16",
) -> TensorHandle:
"""Create constant-filled tensor in TCM."""
return self._make_handle(pa=0, shape=shape, dtype=dtype)
# ── Metadata (no compute, no DMA) ─────────────────────────────
def trans(self, x: TensorHandle) -> TensorHandle:
"""Transpose — shape change only, no command generated."""
if len(x.shape) < 2:
raise ValueError("trans requires at least 2D tensor")
new_shape = (*x.shape[:-2], x.shape[-1], x.shape[-2])
return TensorHandle(
id=x.id, pa=x.pa, shape=new_shape,
dtype=x.dtype, nbytes=x.nbytes, data=x.data,
)
# ── Composite + Control ───────────────────────────────────────
def composite(
self,
op: Literal["gemm", "math"],
a: TensorHandle,
b: TensorHandle | None = None,
out_ptr: int = 0,
math_op: str | None = None,
) -> CompletionHandle:
"""Submit a composite command (non-blocking, tiled pipeline).
Returns CompletionHandle for use with wait().
"""
# Compute output size based on op
if op == "gemm" and b is not None:
m, k = a.shape[-2], a.shape[-1]
n = b.shape[-1]
out_dtype = a.dtype
out_nbytes = m * n * self._dtype_bytes(out_dtype)
else:
out_nbytes = a.nbytes
completion = CompletionHandle(id=self._next_completion_id())
self._emit_dispatch_overhead()
self._commands.append(CompositeCmd(
completion=completion, op=op,
a=a, b=b, out_pa=out_ptr, out_nbytes=out_nbytes,
math_op=math_op,
))
return completion
def wait(self, handle: CompletionHandle | None = None) -> None:
"""Wait for a specific composite or all pending composites."""
self._commands.append(WaitCmd(handle=handle))
def cycles(self, n: int) -> None:
"""Declare PE_CPU scalar execution overhead (cycles)."""
self._commands.append(PeCpuOverheadCmd(cycles=n))
# ── TensorHandle arithmetic operators ─────────────────────────────
# Enables: a + b, a * b, a - b, a / b in kernel code.
# Each creates a MathCmd via a module-level helper that requires a
# TLContext. We attach the context to handles via a closure approach.
def _enable_tensor_ops() -> None:
"""Patch TensorHandle with arithmetic operators.
Called once at module load. Operators create MathCmd entries via
a thread-local TLContext reference set during kernel execution.
"""
import threading
_local = threading.local()
def set_active_context(ctx: TLContext | None) -> None:
_local.ctx = ctx
def get_active_context() -> TLContext:
ctx = getattr(_local, "ctx", None)
if ctx is None:
raise RuntimeError("TensorHandle ops require an active TLContext")
return ctx
def _binop(op: str):
def method(self: TensorHandle, other: TensorHandle) -> TensorHandle:
ctx = get_active_context()
return ctx._binary_math(op, self, other)
return method
# Patch TensorHandle class with operators
TensorHandle.__add__ = _binop("add") # type: ignore[attr-defined]
TensorHandle.__sub__ = _binop("sub") # type: ignore[attr-defined]
TensorHandle.__mul__ = _binop("mul") # type: ignore[attr-defined]
TensorHandle.__truediv__ = _binop("div") # type: ignore[attr-defined]
# Expose context management
TLContext._set_active = staticmethod(set_active_context) # type: ignore[attr-defined]
TLContext._get_active = staticmethod(get_active_context) # type: ignore[attr-defined]
_enable_tensor_ops()
def run_kernel(
kernel_fn,
tl_ctx: TLContext,
*args,
**kwargs,
) -> list[PeCommand]:
"""Execute a kernel function with the given TLContext and return commands.
Sets tl_ctx as the active context for TensorHandle operators,
calls the kernel, then clears the context.
"""
TLContext._set_active(tl_ctx) # type: ignore[attr-defined]
try:
kernel_fn(*args, tl=tl_ctx, **kwargs)
finally:
TLContext._set_active(None) # type: ignore[attr-defined]
return tl_ctx.commands