commit - release 1
This commit is contained in:
@@ -0,0 +1,11 @@
|
||||
"""Triton emulator: fake tl module for kernel performance simulation.
|
||||
|
||||
Provides TLContext (the fake `tl` parameter) that kernels use to express
|
||||
memory access patterns and compute operations. Kernel functions are plain
|
||||
Python — no yield, no async — and generate a PeCommand trace that PE_CPU
|
||||
replays through SimPy.
|
||||
|
||||
Usage:
|
||||
from kernbench.triton_emu.registry import register_kernel, get_kernel
|
||||
from kernbench.triton_emu.tl_context import TLContext
|
||||
"""
|
||||
@@ -0,0 +1,30 @@
|
||||
"""Kernel registry: maps kernel names to Python callable generators.
|
||||
|
||||
Benchmarks register kernel functions here; PE_CPU looks them up by
|
||||
KernelRef.name at execution time.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
# Module-level registry: kernel name -> kernel callable. Mutated in place by
# register_kernel() and clear_registry(); read by get_kernel().
_kernels: dict[str, Callable[..., None]] = {}
|
||||
|
||||
|
||||
def register_kernel(name: str, fn: Callable[..., None]) -> None:
    """Add ``fn`` to the module-level registry under ``name``.

    Args:
        name: key the kernel is stored under.
        fn: the kernel callable to store.

    Raises:
        ValueError: if ``name`` is already present in the registry; the
            existing entry is left untouched.
    """
    if name not in _kernels:
        _kernels[name] = fn
        return
    raise ValueError(f"kernel '{name}' already registered")
|
||||
|
||||
|
||||
def get_kernel(name: str) -> Callable[..., None]:
    """Return the kernel callable stored under ``name``.

    Args:
        name: key the kernel was registered with.

    Raises:
        KeyError: if no kernel has been registered under ``name``.
    """
    try:
        return _kernels[name]
    except KeyError:
        # Re-raise with the registry-specific message; ``from None`` keeps the
        # traceback free of the bare dict lookup failure.
        raise KeyError(f"kernel '{name}' not registered") from None
|
||||
|
||||
|
||||
def clear_registry() -> None:
    """Empty the kernel registry in place (intended for test isolation).

    The registry dict object itself is kept, so existing references to it
    observe the cleared state.
    """
    _kernels.clear()
|
||||
@@ -0,0 +1,356 @@
|
||||
"""TLContext: fake Triton Language module for kernel performance simulation.
|
||||
|
||||
Passed as the `tl` parameter to kernel functions. Each API call records a
|
||||
PeCommand in the internal trace. After the kernel returns, PE_CPU replays
|
||||
the command list through SimPy.
|
||||
|
||||
Kernel code looks like standard Python — no yield, no async:
|
||||
|
||||
def my_kernel(a_ptr, b_ptr, out_ptr, tl):
|
||||
pid = tl.program_id(0)
|
||||
a = tl.load(a_ptr, shape=(32, 64), dtype="f16")
|
||||
b = tl.load(b_ptr + pid * stride, shape=(64, 32), dtype="f16")
|
||||
tl.composite(op="gemm", a=a, b=b, out_ptr=out_ptr)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import Literal
|
||||
|
||||
from kernbench.common.pe_commands import (
|
||||
CompletionHandle,
|
||||
CompositeCmd,
|
||||
DmaReadCmd,
|
||||
DmaWriteCmd,
|
||||
GemmCmd,
|
||||
MathCmd,
|
||||
PeCommand,
|
||||
PeCpuOverheadCmd,
|
||||
TensorHandle,
|
||||
WaitCmd,
|
||||
)
|
||||
|
||||
# Element size in bytes for each dtype name understood by TLContext.
_DTYPE_BYTES: dict[str, int] = {
    # floating point
    "f16": 2,
    "f32": 4,
    "f64": 8,
    "bf16": 2,
    # signed integers
    "i8": 1,
    "i16": 2,
    "i32": 4,
    "i64": 8,
    # unsigned integers
    "u8": 1,
    "u16": 2,
    "u32": 4,
    "u64": 8,
}
|
||||
|
||||
|
||||
class TLContext:
    """Fake Triton Language context that records a PeCommand trace.

    An instance is passed to a kernel function as its ``tl`` argument. Each
    API call that models work appends one or more commands to an internal
    list, which is exposed through :attr:`commands`.

    Args:
        pe_id: program instance index (returned by program_id).
        num_programs: total number of program instances.
        dispatch_cycles: PE_CPU overhead per tl API call (auto-inserted).
            When this is 0 or negative no overhead command is recorded.
    """

    def __init__(
        self,
        pe_id: int = 0,
        num_programs: int = 1,
        dispatch_cycles: int = 1,
    ) -> None:
        self._pe_id = pe_id
        self._num_programs = num_programs
        self._dispatch_cycles = dispatch_cycles
        self._commands: list[PeCommand] = []
        # Monotonic counters used to mint unique "t<N>" / "c<N>" ids.
        self._handle_counter = 0
        self._completion_counter = 0

    @property
    def commands(self) -> list[PeCommand]:
        """Return the recorded command trace (the live list, not a copy)."""
        return self._commands

    # ── helpers ────────────────────────────────────────────────────

    def _next_handle_id(self) -> str:
        """Return the next tensor handle id: "t1", "t2", ..."""
        self._handle_counter += 1
        return f"t{self._handle_counter}"

    def _next_completion_id(self) -> str:
        """Return the next completion handle id: "c1", "c2", ..."""
        self._completion_counter += 1
        return f"c{self._completion_counter}"

    def _dtype_bytes(self, dtype: str) -> int:
        """Return the element size of ``dtype`` in bytes.

        Names missing from ``_DTYPE_BYTES`` fall back to 2 bytes.
        """
        return _DTYPE_BYTES.get(dtype, 2)

    def _nbytes(self, shape: tuple[int, ...], dtype: str) -> int:
        """Return the total byte size of a tensor of ``shape`` and ``dtype``."""
        return math.prod(shape) * self._dtype_bytes(dtype)

    def _emit_dispatch_overhead(self) -> None:
        """Record the per-call PE_CPU dispatch overhead, if it is positive."""
        if self._dispatch_cycles > 0:
            self._commands.append(PeCpuOverheadCmd(cycles=self._dispatch_cycles))

    def _make_handle(
        self, pa: int, shape: tuple[int, ...], dtype: str,
    ) -> TensorHandle:
        """Build a TensorHandle with a fresh id and a computed ``nbytes``."""
        return TensorHandle(
            id=self._next_handle_id(),
            pa=pa, shape=shape, dtype=dtype,
            nbytes=self._nbytes(shape, dtype),
        )

    @staticmethod
    def _matmul_dims(
        who: str, a: TensorHandle, b: TensorHandle,
    ) -> tuple[int, int, int]:
        """Validate matmul operands and return ``(m, k, n)``.

        Only the last two dimensions of each operand are inspected.

        Args:
            who: caller name used as the prefix of the error message.
            a: left operand, trailing dims (M, K).
            b: right operand, trailing dims (K, N).

        Raises:
            ValueError: if either operand has fewer than two dimensions or
                the inner (K) dimensions disagree.
        """
        if len(a.shape) < 2 or len(b.shape) < 2:
            raise ValueError(f"{who} requires 2D tensors")
        m, k = a.shape[-2], a.shape[-1]
        k2, n = b.shape[-2], b.shape[-1]
        if k != k2:
            raise ValueError(f"{who} shape mismatch: a.K={k} != b.K={k2}")
        return m, k, n

    # ── Reference (no DMA, metadata only) ────────────────────────

    def ref(
        self, ptr: int, shape: tuple[int, ...], dtype: str = "f16",
    ) -> TensorHandle:
        """Create a TensorHandle referencing HBM data without issuing DMA.

        Used when the scheduler will stream data per-tile (e.g., tensor b
        in a composite GEMM). No command is generated.
        """
        return self._make_handle(pa=ptr, shape=shape, dtype=dtype)

    # ── Data Movement (blocking, DMA engine) ──────────────────────

    def load(
        self, ptr: int, shape: tuple[int, ...], dtype: str = "f16",
    ) -> TensorHandle:
        """Load tensor from HBM to TCM. Returns TensorHandle.

        Records the dispatch overhead followed by a DmaReadCmd sized from
        ``shape`` and ``dtype``.
        """
        self._emit_dispatch_overhead()
        handle = self._make_handle(pa=ptr, shape=shape, dtype=dtype)
        self._commands.append(DmaReadCmd(
            handle=handle, src_pa=ptr, nbytes=handle.nbytes,
        ))
        return handle

    def store(self, ptr: int, handle: TensorHandle) -> None:
        """Store tensor from TCM to HBM.

        Records the dispatch overhead followed by a DmaWriteCmd of
        ``handle.nbytes`` bytes to ``ptr``.
        """
        self._emit_dispatch_overhead()
        self._commands.append(DmaWriteCmd(
            handle=handle, dst_pa=ptr, nbytes=handle.nbytes,
        ))

    # ── GEMM Engine (blocking) ────────────────────────────────────

    def dot(self, a: TensorHandle, b: TensorHandle) -> TensorHandle:
        """Matrix multiply: out = a @ b. Both operands must be in TCM.

        a: (M, K), b: (K, N) → out: (M, N). Leading (batch) dimensions of
        ``a`` are carried over to the output shape; the output dtype is
        ``a.dtype``.

        Raises:
            ValueError: if an operand has fewer than two dimensions or the
                K dimensions differ.
        """
        m, k, n = self._matmul_dims("dot", a, b)
        out_shape = (*a.shape[:-2], m, n)
        out = self._make_handle(pa=0, shape=out_shape, dtype=a.dtype)
        self._emit_dispatch_overhead()
        self._commands.append(GemmCmd(a=a, b=b, out=out, m=m, k=k, n=n))
        return out

    # ── MATH Engine: unary (blocking) ─────────────────────────────

    def _unary_math(self, op: str, x: TensorHandle) -> TensorHandle:
        """Record an elementwise unary MathCmd; output mirrors x's shape/dtype."""
        out = self._make_handle(pa=0, shape=x.shape, dtype=x.dtype)
        self._emit_dispatch_overhead()
        self._commands.append(MathCmd(op=op, inputs=(x,), out=out))
        return out

    def exp(self, x: TensorHandle) -> TensorHandle:
        """Record an elementwise "exp" MathCmd on ``x``."""
        return self._unary_math("exp", x)

    def log(self, x: TensorHandle) -> TensorHandle:
        """Record an elementwise "log" MathCmd on ``x``."""
        return self._unary_math("log", x)

    def sqrt(self, x: TensorHandle) -> TensorHandle:
        """Record an elementwise "sqrt" MathCmd on ``x``."""
        return self._unary_math("sqrt", x)

    def abs(self, x: TensorHandle) -> TensorHandle:
        """Record an elementwise "abs" MathCmd on ``x``."""
        return self._unary_math("abs", x)

    def sigmoid(self, x: TensorHandle) -> TensorHandle:
        """Record an elementwise "sigmoid" MathCmd on ``x``."""
        return self._unary_math("sigmoid", x)

    def cos(self, x: TensorHandle) -> TensorHandle:
        """Record an elementwise "cos" MathCmd on ``x``."""
        return self._unary_math("cos", x)

    def sin(self, x: TensorHandle) -> TensorHandle:
        """Record an elementwise "sin" MathCmd on ``x``."""
        return self._unary_math("sin", x)

    # ── MATH Engine: reduction (blocking) ─────────────────────────

    def _reduction(
        self, op: str, x: TensorHandle, axis: int,
    ) -> TensorHandle:
        """Record a reduction MathCmd along ``axis``.

        The reduced axis is kept in the output shape with extent 1.
        """
        out_shape = list(x.shape)
        out_shape[axis] = 1
        out = self._make_handle(pa=0, shape=tuple(out_shape), dtype=x.dtype)
        self._emit_dispatch_overhead()
        self._commands.append(MathCmd(op=op, inputs=(x,), out=out, axis=axis))
        return out

    def sum(self, x: TensorHandle, axis: int) -> TensorHandle:
        """Record a "sum" reduction of ``x`` along ``axis``."""
        return self._reduction("sum", x, axis)

    def max(self, x: TensorHandle, axis: int) -> TensorHandle:
        """Record a "max" reduction of ``x`` along ``axis``."""
        return self._reduction("max", x, axis)

    def min(self, x: TensorHandle, axis: int) -> TensorHandle:
        """Record a "min" reduction of ``x`` along ``axis``."""
        return self._reduction("min", x, axis)

    # ── MATH Engine: binary (blocking) ────────────────────────────

    def _binary_math(
        self, op: str, a: TensorHandle, b: TensorHandle,
    ) -> TensorHandle:
        """Record a binary MathCmd; the output takes a's shape and dtype."""
        out = self._make_handle(pa=0, shape=a.shape, dtype=a.dtype)
        self._emit_dispatch_overhead()
        self._commands.append(MathCmd(op=op, inputs=(a, b), out=out))
        return out

    def where(
        self, cond: TensorHandle, a: TensorHandle, b: TensorHandle,
    ) -> TensorHandle:
        """Record a "where" MathCmd; the output takes a's shape and dtype."""
        out = self._make_handle(pa=0, shape=a.shape, dtype=a.dtype)
        self._emit_dispatch_overhead()
        self._commands.append(MathCmd(op="where", inputs=(cond, a, b), out=out))
        return out

    # ── Index / Scalar (PE_CPU, no engine) ────────────────────────

    def program_id(self, axis: int = 0) -> int:
        """Return program instance index (``axis`` is accepted but ignored)."""
        return self._pe_id

    def num_programs(self, axis: int = 0) -> int:
        """Return total number of program instances (``axis`` is ignored)."""
        return self._num_programs

    def arange(self, start: int, end: int, dtype: str = "i32") -> TensorHandle:
        """Create index range tensor in TCM. No command is generated.

        Raises:
            ValueError: if ``end < start``, which would otherwise yield a
                handle with a negative extent and negative ``nbytes``.
        """
        if end < start:
            raise ValueError(
                f"arange requires end >= start, got start={start}, end={end}"
            )
        return self._make_handle(pa=0, shape=(end - start,), dtype=dtype)

    def zeros(self, shape: tuple[int, ...], dtype: str = "f16") -> TensorHandle:
        """Create zero-filled tensor in TCM. No command is generated."""
        return self._make_handle(pa=0, shape=shape, dtype=dtype)

    def full(
        self, shape: tuple[int, ...], value: float | int, dtype: str = "f16",
    ) -> TensorHandle:
        """Create constant-filled tensor in TCM. No command is generated.

        ``value`` is not recorded anywhere; only shape and dtype are used.
        """
        return self._make_handle(pa=0, shape=shape, dtype=dtype)

    # ── Metadata (no compute, no DMA) ─────────────────────────────

    def trans(self, x: TensorHandle) -> TensorHandle:
        """Transpose — shape change only, no command generated.

        Swaps the last two dimensions. The returned handle reuses x's id,
        pa, dtype, nbytes and data.

        Raises:
            ValueError: if ``x`` has fewer than two dimensions.
        """
        if len(x.shape) < 2:
            raise ValueError("trans requires at least 2D tensor")
        new_shape = (*x.shape[:-2], x.shape[-1], x.shape[-2])
        return TensorHandle(
            id=x.id, pa=x.pa, shape=new_shape,
            dtype=x.dtype, nbytes=x.nbytes, data=x.data,
        )

    # ── Composite + Control ───────────────────────────────────────

    def composite(
        self,
        op: Literal["gemm", "math"],
        a: TensorHandle,
        b: TensorHandle | None = None,
        out_ptr: int = 0,
        math_op: str | None = None,
    ) -> CompletionHandle:
        """Submit a composite command (non-blocking, tiled pipeline).

        For ``op == "gemm"`` the operands are validated the same way as in
        :meth:`dot` and the output size covers a's leading (batch) dims
        times (M, N) in a's dtype. For any other ``op`` the output size is
        ``a.nbytes``.

        Returns CompletionHandle for use with wait().

        Raises:
            ValueError: for a gemm without ``b``, with an operand of fewer
                than two dimensions, or with mismatched K dimensions.
        """
        if op == "gemm":
            if b is None:
                raise ValueError("composite gemm requires operand b")
            m, _k, n = self._matmul_dims("composite gemm", a, b)
            # Same output shape rule as dot(): batch dims of a, then (M, N).
            out_nbytes = self._nbytes((*a.shape[:-2], m, n), a.dtype)
        else:
            out_nbytes = a.nbytes

        completion = CompletionHandle(id=self._next_completion_id())
        self._emit_dispatch_overhead()
        self._commands.append(CompositeCmd(
            completion=completion, op=op,
            a=a, b=b, out_pa=out_ptr, out_nbytes=out_nbytes,
            math_op=math_op,
        ))
        return completion

    def wait(self, handle: CompletionHandle | None = None) -> None:
        """Wait for a specific composite or all pending composites.

        Records a WaitCmd only; no dispatch overhead is added.
        """
        self._commands.append(WaitCmd(handle=handle))

    def cycles(self, n: int) -> None:
        """Declare PE_CPU scalar execution overhead (cycles)."""
        self._commands.append(PeCpuOverheadCmd(cycles=n))
|
||||
|
||||
|
||||
# ── TensorHandle arithmetic operators ─────────────────────────────
|
||||
# Enables: a + b, a * b, a - b, a / b in kernel code.
|
||||
# Each creates a MathCmd through the TLContext that is active on the
# current thread (kept in thread-local state); handles carry no context.
|
||||
|
||||
|
||||
def _enable_tensor_ops() -> None:
    """Install ``+``, ``-``, ``*`` and ``/`` on TensorHandle.

    Meant to run once at module load. Each operator looks up the TLContext
    stored in thread-local state and records a binary MathCmd through it.
    The setter/getter for that state are attached to TLContext as the
    static methods ``_set_active`` and ``_get_active``.
    """
    import threading

    # One slot per thread: the context whose trace the operators append to.
    state = threading.local()

    def set_active_context(ctx: TLContext | None) -> None:
        state.ctx = ctx

    def get_active_context() -> TLContext:
        active = getattr(state, "ctx", None)
        if active is not None:
            return active
        raise RuntimeError("TensorHandle ops require an active TLContext")

    def _binop(op: str):
        # Factory so every operator closes over its own ``op`` string.
        def method(self: TensorHandle, other: TensorHandle) -> TensorHandle:
            return get_active_context()._binary_math(op, self, other)

        return method

    # Dunder name -> MathCmd op name.
    operator_table = (
        ("__add__", "add"),
        ("__sub__", "sub"),
        ("__mul__", "mul"),
        ("__truediv__", "div"),
    )
    for dunder, op_name in operator_table:
        setattr(TensorHandle, dunder, _binop(op_name))

    # Expose context management on the class.
    TLContext._set_active = staticmethod(set_active_context)  # type: ignore[attr-defined]
    TLContext._get_active = staticmethod(get_active_context)  # type: ignore[attr-defined]
|
||||
|
||||
|
||||
# Patch TensorHandle operators and attach TLContext._set_active/_get_active
# as soon as this module is imported.
_enable_tensor_ops()
|
||||
|
||||
|
||||
def run_kernel(
    kernel_fn,
    tl_ctx: TLContext,
    *args,
    **kwargs,
) -> list[PeCommand]:
    """Execute a kernel function with the given TLContext and return commands.

    Sets ``tl_ctx`` as the active context for TensorHandle operators, calls
    ``kernel_fn(*args, tl=tl_ctx, **kwargs)``, then restores whichever
    context was active before the call (or none). Restoring rather than
    clearing keeps an outer kernel's operators working when run_kernel is
    nested on the same thread.

    Args:
        kernel_fn: kernel callable; receives the context as keyword ``tl``.
        tl_ctx: context that records the kernel's command trace.
        *args: positional arguments forwarded to ``kernel_fn``.
        **kwargs: keyword arguments forwarded to ``kernel_fn``.

    Returns:
        ``tl_ctx.commands`` after the kernel has returned.
    """
    try:
        previous = TLContext._get_active()  # type: ignore[attr-defined]
    except RuntimeError:
        # _get_active raises when no context is active on this thread.
        previous = None

    TLContext._set_active(tl_ctx)  # type: ignore[attr-defined]
    try:
        kernel_fn(*args, tl=tl_ctx, **kwargs)
    finally:
        TLContext._set_active(previous)  # type: ignore[attr-defined]
    return tl_ctx.commands
|
||||
Reference in New Issue
Block a user