83ea97b05f
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
693 lines
26 KiB
Python
693 lines
26 KiB
Python
"""TLContext: fake Triton Language module for kernel performance simulation.
|
|
|
|
Passed as the `tl` parameter to kernel functions. Each API call records a
|
|
PeCommand in the internal trace. After the kernel returns, PE_CPU replays
|
|
the command list through SimPy.
|
|
|
|
Kernel code looks like standard Python — no yield, no async:
|
|
|
|
def my_kernel(a_ptr, b_ptr, out_ptr, tl):
|
|
pid = tl.program_id(0)
|
|
a = tl.load(a_ptr, shape=(32, 64), dtype="f16")
|
|
b = tl.load(b_ptr + pid * stride, shape=(64, 32), dtype="f16")
|
|
tl.composite(op="gemm", a=a, b=b, out_ptr=out_ptr)
|
|
"""
|
|
from __future__ import annotations

import math
from typing import Any, Literal

from kernbench.common.ipcq_types import IpcqRecvCmd, IpcqSendCmd, RecvFuture
from kernbench.common.pe_commands import (
    CompletionHandle,
    CompositeCmd,
    DmaReadCmd,
    DmaWriteCmd,
    GemmCmd,
    MathCmd,
    PeCommand,
    PeCpuOverheadCmd,
    TensorHandle,
    WaitCmd,
)
|
|
|
|
_DTYPE_BYTES: dict[str, int] = {
|
|
"f16": 2, "f32": 4, "f64": 8,
|
|
"bf16": 2,
|
|
"i8": 1, "i16": 2, "i32": 4, "i64": 8,
|
|
"u8": 1, "u16": 2, "u32": 4, "u64": 8,
|
|
}
|
|
|
|
|
|
class TLContext:
    """Fake Triton Language context.

    An instance is handed to a kernel function as its ``tl`` argument. Every
    API call builds a ``PeCommand`` and passes it to ``_emit``: without a
    runner the command is appended to the trace exposed by ``commands``;
    with a runner it is handed to ``runner.switch_to_simpy`` and whatever
    that call returns is used as the command's result.

    Args:
        pe_id: program instance index (returned by ``program_id(0)``).
        num_programs: total number of program instances
            (returned by ``num_programs(0)``).
        dispatch_cycles: PE_CPU overhead per tl API call (auto-inserted as a
            ``PeCpuOverheadCmd`` when greater than zero).
        runner: optional object exposing ``switch_to_simpy(cmd)``; enables
            greenlet mode.
        cube_id: cube index (returned by ``program_id(1)``).
        num_cubes: number of cubes (returned by ``num_programs(1)``).
        scratch_base: base address of the PE-local scratch pool. A value of
            0 means "no pool configured" and disables scratch allocation.
        scratch_size: size of the scratch pool in bytes.
    """

    def __init__(
        self,
        pe_id: int = 0,
        num_programs: int = 1,
        dispatch_cycles: int = 1,
        runner: Any = None,
        cube_id: int = 0,
        num_cubes: int = 1,
        scratch_base: int = 0,
        scratch_size: int = 1 << 20,  # 1 MiB per kernel invocation
    ) -> None:
        self._pe_id = pe_id
        self._num_programs = num_programs
        self._cube_id = cube_id
        self._num_cubes = num_cubes
        self._dispatch_cycles = dispatch_cycles
        self._commands: list[PeCommand] = []
        self._handle_counter = 0
        self._completion_counter = 0
        self._runner = runner  # KernelRunner for greenlet mode (ADR-0020 D3)
        # PE-local scratch allocator for math/compute output handles.
        # Each binary/unary/reduction op auto-allocates a unique addr from
        # this pool so the resulting TensorHandle can be the source of a
        # later tl.send / tl.store. The cursor starts at 0 for every new
        # TLContext.
        self._scratch_base = scratch_base
        self._scratch_size = scratch_size
        self._scratch_cursor = 0

    def _scratch_alloc(self, nbytes: int) -> int:
        """Allocate a unique scratch address for an output TensorHandle.

        Returns 0 if no scratch base was configured (e.g. command-list mode);
        in that case the resulting handle has addr=0 and cannot be used as a
        send/store source.

        Args:
            nbytes: requested size; rounded up to a multiple of 16 bytes.

        Returns:
            ``scratch_base`` plus the cursor value before this allocation.

        Raises:
            RuntimeError: if the allocation would exceed ``scratch_size``.
                The cursor is left untouched in that case, so smaller
                allocations that still fit keep working.
        """
        if self._scratch_base == 0:
            return 0
        # 16-byte alignment
        aligned = (nbytes + 15) & ~15
        new_cursor = self._scratch_cursor + aligned
        # Validate before committing: a failed request must not consume
        # pool space.
        if new_cursor > self._scratch_size:
            raise RuntimeError(
                f"TLContext scratch overflow: requested {nbytes}B, "
                f"used {new_cursor}/{self._scratch_size}B"
            )
        addr = self._scratch_base + self._scratch_cursor
        self._scratch_cursor = new_cursor
        return addr

    @property
    def commands(self) -> list[PeCommand]:
        """Return the recorded command trace (the live list, not a copy)."""
        return self._commands

    # ── helpers ────────────────────────────────────────────────────

    def _next_handle_id(self) -> str:
        """Return the next tensor-handle id: "t1", "t2", ..."""
        self._handle_counter += 1
        return f"t{self._handle_counter}"

    def _next_completion_id(self) -> str:
        """Return the next completion id: "c1", "c2", ..."""
        self._completion_counter += 1
        return f"c{self._completion_counter}"

    def _dtype_bytes(self, dtype: str) -> int:
        """Return the element size of ``dtype``; unknown names count as 2 bytes."""
        return _DTYPE_BYTES.get(dtype, 2)

    def _nbytes(self, shape: tuple[int, ...], dtype: str) -> int:
        """Return the byte size of a dense tensor of ``shape`` and ``dtype``."""
        return math.prod(shape) * self._dtype_bytes(dtype)

    def _emit_dispatch_overhead(self) -> None:
        """Emit the per-call PE_CPU overhead, if ``dispatch_cycles`` > 0."""
        if self._dispatch_cycles > 0:
            self._emit(PeCpuOverheadCmd(cycles=self._dispatch_cycles))

    def _make_handle(
        self, addr: int, shape: tuple[int, ...], dtype: str,
        space: str = "tcm", pinned: bool = False,
    ) -> TensorHandle:
        """Build a TensorHandle with a fresh id and ``nbytes`` derived from shape/dtype."""
        return TensorHandle(
            id=self._next_handle_id(),
            addr=addr, shape=shape, dtype=dtype,
            nbytes=self._nbytes(shape, dtype),
            space=space,
            pinned=pinned,
        )

    def _make_compute_out(
        self, shape: tuple[int, ...], dtype: str,
    ) -> TensorHandle:
        """Allocate an output TensorHandle in PE-local scratch (TCM space).

        Used by math/compute ops so the result has a real address that can
        be the source of a later send/store. The data field stays None in
        Phase 1 — Phase 2 DataExecutor fills the actual ndarray.
        """
        nbytes = self._nbytes(shape, dtype)
        addr = self._scratch_alloc(nbytes)
        return TensorHandle(
            id=self._next_handle_id(),
            addr=addr, shape=shape, dtype=dtype,
            nbytes=nbytes, space="tcm",
        )

    def _handle_from_slot(
        self,
        slot: Any,
        shape: tuple[int, ...],
        dtype: str,
        *,
        with_data: bool = True,
    ) -> TensorHandle:
        """Turn the runner's recv result mapping into a TensorHandle.

        ``slot`` is read through ``.get`` for the keys "src_addr" (default
        0), "src_space" (default "tcm") and, when ``with_data`` is true,
        "data". Shared by ``recv``, ``recv_no_consume`` and ``wait``.
        """
        slot_addr = int(slot.get("src_addr", 0))
        slot_space = str(slot.get("src_space", "tcm"))
        data = slot.get("data") if with_data else None
        return TensorHandle(
            id=self._next_handle_id(),
            addr=slot_addr,
            shape=shape,
            dtype=dtype,
            nbytes=self._nbytes(shape, dtype),
            data=data,
            space=slot_space,
        )

    # ── Reference (no DMA, metadata only) ────────────────────────

    def ref(
        self, ptr: int, shape: tuple[int, ...], dtype: str = "f16",
    ) -> TensorHandle:
        """Create a TensorHandle referencing HBM data without issuing DMA.

        Used when the scheduler will stream data per-tile (e.g., tensor b
        in a composite GEMM). No command is generated.
        """
        return self._make_handle(addr=ptr, shape=shape, dtype=dtype)

    # ── Data Movement (blocking, DMA engine) ──────────────────────

    def _emit(self, cmd: PeCommand) -> Any:
        """Emit command: greenlet switch if runner available, else append to list.

        Returns whatever ``runner.switch_to_simpy`` returns in runner mode,
        and None in command-list mode.
        """
        if self._runner is not None:
            return self._runner.switch_to_simpy(cmd)
        self._commands.append(cmd)
        return None

    def load(
        self, ptr: int, shape: tuple[int, ...], dtype: str = "f16",
    ) -> TensorHandle:
        """Load tensor from HBM. Returns TensorHandle pointing at HBM[ptr].

        In greenlet mode: returns TensorHandle with actual numpy data.
        In command-list mode: returns TensorHandle with data=None.

        The returned handle's ``space`` is "hbm" so subsequent ops (math,
        send, store) using this handle as a source resolve via MemoryStore
        at ``(hbm, ptr)`` — which is where the load's underlying data
        actually lives in Phase 2 storage.
        """
        self._emit_dispatch_overhead()
        handle = self._make_handle(
            addr=ptr, shape=shape, dtype=dtype, space="hbm", pinned=True,
        )
        cmd = DmaReadCmd(handle=handle, src_addr=ptr, nbytes=handle.nbytes)
        data = self._emit(cmd)
        if data is not None:
            # Greenlet mode: attach real data to handle (preserve space + pinned)
            return TensorHandle(
                id=handle.id, addr=handle.addr, shape=handle.shape,
                dtype=handle.dtype, nbytes=handle.nbytes, data=data,
                space=handle.space, pinned=handle.pinned,
            )
        return handle

    def store(self, ptr: int, handle: TensorHandle) -> None:
        """Store tensor from TCM to HBM (emits a DmaWriteCmd to ``ptr``)."""
        self._emit_dispatch_overhead()
        cmd = DmaWriteCmd(handle=handle, dst_addr=ptr, nbytes=handle.nbytes)
        self._emit(cmd)

    # ── GEMM Engine (blocking) ────────────────────────────────────

    def dot(self, a: TensorHandle, b: TensorHandle) -> TensorHandle:
        """Matrix multiply: out = a @ b. Both operands must be in TCM.

        a: (..., M, K), b: (..., K, N) → out: (*a.shape[:-2], M, N).
        The output dtype is ``a.dtype``. The memory space of the operands
        is not checked here.

        Raises:
            ValueError: if either operand has fewer than two dimensions or
                the inner dimensions disagree.
        """
        if len(a.shape) < 2 or len(b.shape) < 2:
            raise ValueError("dot requires 2D tensors")
        m, k = a.shape[-2], a.shape[-1]
        k2, n = b.shape[-2], b.shape[-1]
        if k != k2:
            raise ValueError(f"dot shape mismatch: a.K={k} != b.K={k2}")
        out_shape = (*a.shape[:-2], m, n)
        out_dtype = a.dtype
        out = self._make_compute_out(shape=out_shape, dtype=out_dtype)
        self._emit_dispatch_overhead()
        self._emit(GemmCmd(a=a, b=b, out=out, m=m, k=k, n=n))
        return out

    # ── MATH Engine: unary (blocking) ─────────────────────────────

    def _unary_math(self, op: str, x: TensorHandle) -> TensorHandle:
        """Emit a one-input MathCmd; the output mirrors ``x``'s shape and dtype."""
        out = self._make_compute_out(shape=x.shape, dtype=x.dtype)
        self._emit_dispatch_overhead()
        self._emit(MathCmd(op=op, inputs=(x,), out=out))
        return out

    def exp(self, x: TensorHandle) -> TensorHandle:
        """Element-wise "exp" MathCmd."""
        return self._unary_math("exp", x)

    def log(self, x: TensorHandle) -> TensorHandle:
        """Element-wise "log" MathCmd."""
        return self._unary_math("log", x)

    def sqrt(self, x: TensorHandle) -> TensorHandle:
        """Element-wise "sqrt" MathCmd."""
        return self._unary_math("sqrt", x)

    def abs(self, x: TensorHandle) -> TensorHandle:
        """Element-wise "abs" MathCmd."""
        return self._unary_math("abs", x)

    def sigmoid(self, x: TensorHandle) -> TensorHandle:
        """Element-wise "sigmoid" MathCmd."""
        return self._unary_math("sigmoid", x)

    def cos(self, x: TensorHandle) -> TensorHandle:
        """Element-wise "cos" MathCmd."""
        return self._unary_math("cos", x)

    def sin(self, x: TensorHandle) -> TensorHandle:
        """Element-wise "sin" MathCmd."""
        return self._unary_math("sin", x)

    # ── MATH Engine: reduction (blocking) ─────────────────────────

    def _reduction(
        self, op: str, x: TensorHandle, axis: int,
    ) -> TensorHandle:
        """Emit a reduction MathCmd; the reduced axis is kept with size 1."""
        out_shape = list(x.shape)
        out_shape[axis] = 1
        out = self._make_compute_out(shape=tuple(out_shape), dtype=x.dtype)
        self._emit_dispatch_overhead()
        self._emit(MathCmd(op=op, inputs=(x,), out=out, axis=axis))
        return out

    def sum(self, x: TensorHandle, axis: int) -> TensorHandle:
        """"sum" reduction along ``axis``."""
        return self._reduction("sum", x, axis)

    def max(self, x: TensorHandle, axis: int) -> TensorHandle:
        """"max" reduction along ``axis``."""
        return self._reduction("max", x, axis)

    def min(self, x: TensorHandle, axis: int) -> TensorHandle:
        """"min" reduction along ``axis``."""
        return self._reduction("min", x, axis)

    # ── MATH Engine: binary (blocking) ────────────────────────────

    def _binary_math(
        self, op: str, a: TensorHandle, b: TensorHandle,
    ) -> TensorHandle:
        """Emit a two-input MathCmd; the output takes ``a``'s shape and dtype."""
        out = self._make_compute_out(shape=a.shape, dtype=a.dtype)
        self._emit_dispatch_overhead()
        self._emit(MathCmd(op=op, inputs=(a, b), out=out))
        return out

    def where(
        self, cond: TensorHandle, a: TensorHandle, b: TensorHandle,
    ) -> TensorHandle:
        """Emit a "where" MathCmd over (cond, a, b); output shaped like ``a``."""
        out = self._make_compute_out(shape=a.shape, dtype=a.dtype)
        self._emit_dispatch_overhead()
        self._emit(MathCmd(op="where", inputs=(cond, a, b), out=out))
        return out

    def maximum(self, a: TensorHandle, b: TensorHandle) -> TensorHandle:
        """Element-wise max of two tensors (real Triton: tl.maximum)."""
        return self._binary_math("maximum", a, b)

    def minimum(self, a: TensorHandle, b: TensorHandle) -> TensorHandle:
        """Element-wise min of two tensors (real Triton: tl.minimum)."""
        return self._binary_math("minimum", a, b)

    def fma(
        self, a: TensorHandle, b: TensorHandle, c: TensorHandle,
    ) -> TensorHandle:
        """Fused multiply-add: a * b + c (real Triton: tl.fma)."""
        out = self._make_compute_out(shape=a.shape, dtype=a.dtype)
        self._emit_dispatch_overhead()
        self._emit(MathCmd(op="fma", inputs=(a, b, c), out=out))
        return out

    def clamp(
        self,
        x: TensorHandle,
        min: TensorHandle,
        max: TensorHandle,
    ) -> TensorHandle:
        """Clamp x to [min, max] (real Triton: tl.clamp)."""
        out = self._make_compute_out(shape=x.shape, dtype=x.dtype)
        self._emit_dispatch_overhead()
        self._emit(MathCmd(op="clamp", inputs=(x, min, max), out=out))
        return out

    def softmax(self, x: TensorHandle, axis: int = -1) -> TensorHandle:
        """Numerically-stable softmax along ``axis`` (real Triton: tl.softmax).

        Implemented as a single MathCmd (op="softmax") so timing accounts
        for one MATH dispatch; Phase 2 DataExecutor expands it to the
        canonical (x - max) → exp → sum → div sequence.
        """
        out = self._make_compute_out(shape=x.shape, dtype=x.dtype)
        self._emit_dispatch_overhead()
        self._emit(MathCmd(op="softmax", inputs=(x,), out=out, axis=axis))
        return out

    # ── Scalar helpers (real Triton: tl.cdiv etc.) ────────────────

    @staticmethod
    def cdiv(a: int, b: int) -> int:
        """Ceiling division: (a + b - 1) // b (real Triton: tl.cdiv).

        Used by host/kernel grid math; not a tensor op, so no MathCmd
        is emitted. Mirrors triton.cdiv.
        """
        # Negate-floor-negate gives the ceiling without forming a + b - 1.
        return -(-int(a) // int(b))

    # ── Index / Scalar (PE_CPU, no engine) ────────────────────────

    def program_id(self, axis: int = 0) -> int:
        """Return program instance index (ADR-0022).

        axis=0: local PE id within cube.
        axis=1: cube id.

        Any axis other than 1 is treated like axis 0.
        """
        if axis == 1:
            return self._cube_id
        return self._pe_id

    def num_programs(self, axis: int = 0) -> int:
        """Return total number of program instances (ADR-0022).

        axis=0: num PEs per cube.
        axis=1: num cubes.

        Any axis other than 1 is treated like axis 0.
        """
        if axis == 1:
            return self._num_cubes
        return self._num_programs

    def arange(self, start: int, end: int, dtype: str = "i32") -> TensorHandle:
        """Create index range tensor in TCM.

        Only metadata is produced: a handle of shape ``(end - start,)`` at
        addr 0; no command is emitted.
        """
        n = end - start
        return self._make_handle(addr=0, shape=(n,), dtype=dtype)

    def zeros(self, shape: tuple[int, ...], dtype: str = "f16") -> TensorHandle:
        """Create zero-filled tensor in TCM (metadata-only handle at addr 0)."""
        return self._make_handle(addr=0, shape=shape, dtype=dtype)

    def full(
        self, shape: tuple[int, ...], value: float | int, dtype: str = "f16",
    ) -> TensorHandle:
        """Create constant-filled tensor in TCM.

        Metadata-only handle at addr 0; ``value`` is not recorded anywhere.
        """
        return self._make_handle(addr=0, shape=shape, dtype=dtype)

    # ── Metadata (no compute, no DMA) ─────────────────────────────

    def trans(self, x: TensorHandle) -> TensorHandle:
        """Transpose — shape change only, no command generated.

        The last two dimensions are swapped; id, addr, dtype, nbytes, data,
        memory space and pinned flag are carried over from ``x`` so the
        result still resolves to the same storage as its source.

        Raises:
            ValueError: if ``x`` has fewer than two dimensions.
        """
        if len(x.shape) < 2:
            raise ValueError("trans requires at least 2D tensor")
        new_shape = (*x.shape[:-2], x.shape[-1], x.shape[-2])
        # NOTE(review): ``data`` is forwarded untransposed — confirm that
        # consumers interpret it according to ``shape``.
        return TensorHandle(
            id=x.id, addr=x.addr, shape=new_shape,
            dtype=x.dtype, nbytes=x.nbytes, data=x.data,
            # Keep the source's space/pinned; dropping them would relabel
            # e.g. a loaded "hbm" handle with the constructor defaults.
            space=x.space, pinned=x.pinned,
        )

    # ── IPCQ (CCL) collective primitives (ADR-0023 D4) ────────────

    def send(
        self,
        dir: str,
        src: TensorHandle | None = None,
        *,
        src_addr: int | None = None,
        nbytes: int | None = None,
        shape: tuple[int, ...] | None = None,
        dtype: str = "f16",
        space: str = "tcm",
    ) -> None:
        """Send tensor data to the peer in the given direction.

        Two calling forms:
            tl.send(dir, handle)    # use handle's metadata
            tl.send(dir, src_addr=..., nbytes=..., shape=..., dtype=..., space=...)

        When ``src`` is given, its metadata overrides the keyword arguments.

        Blocking: returns when PE_IPCQ has accepted the request and
        forwarded the IpcqDmaToken to PE_DMA. Backpressure may apply.

        Raises:
            ValueError: if neither a handle nor all of
                src_addr/nbytes/shape are supplied.
        """
        if src is not None:
            src_addr = src.addr
            nbytes = src.nbytes
            shape = src.shape
            dtype = src.dtype
            space = getattr(src, "space", space)
        if src_addr is None or nbytes is None or shape is None:
            raise ValueError("tl.send: provide either a TensorHandle or src_addr/nbytes/shape")
        # Carry the handle's .data snapshot (if available). When the source
        # is a recv slot, .data holds the numpy array that was read from
        # MemoryStore at recv-time. This prevents a Phase 1 race where a
        # later IPCQ inbound overwrites the slot before the outbound
        # PE_DMA reads it.
        handle_data = getattr(src, "data", None) if src is not None else None
        self._emit_dispatch_overhead()
        cmd = IpcqSendCmd(
            direction=dir,
            src_addr=src_addr, src_space=space,
            nbytes=nbytes, shape=shape, dtype=dtype,
            handle_id=self._next_handle_id(),
            data=handle_data,
        )
        self._emit(cmd)

    def recv(
        self,
        dir: str | None = None,
        shape: tuple[int, ...] = (),
        dtype: str = "f16",
        space: str = "tcm",
        dst_addr: int | None = None,
        dst_space: str | None = None,
    ) -> TensorHandle:
        """Receive tensor data from a peer.

        Args:
            dir: specific direction (e.g. "W"), or None for round-robin.
            shape, dtype: expected tensor metadata.
            space: accepted for interface compatibility; not used here.
            dst_addr / dst_space: if both are provided, the slot data is
                copied to (dst_space, dst_addr) before the handle is
                returned ("copy_to_dst" mode). Otherwise the slot address
                is returned directly ("return_slot" mode).

        Returns:
            TensorHandle pointing to the slot (or dst) where the data has
            arrived. In greenlet/runner mode, ``handle.data`` carries the
            actual ndarray; in command-list mode the handle is a placeholder.
        """
        self._emit_dispatch_overhead()
        # The copy-to-destination fields are only set when both halves of
        # the destination are known; otherwise IpcqRecvCmd defaults apply.
        extra: dict[str, Any] = {}
        if dst_addr is not None and dst_space is not None:
            extra = {
                "recv_mode": "copy_to_dst",
                "dst_addr": dst_addr,
                "dst_space": dst_space,
            }
        cmd = IpcqRecvCmd(
            direction=dir,
            shape=shape, dtype=dtype,
            handle_id=self._next_handle_id(),
            **extra,
        )
        result = self._emit(cmd)
        if isinstance(result, dict):
            return self._handle_from_slot(result, shape, dtype)
        return self._make_handle(addr=0, shape=shape, dtype=dtype)

    def recv_no_consume(
        self,
        dir: str | None = None,
        shape: tuple[int, ...] = (),
        dtype: str = "f16",
    ) -> TensorHandle:
        """DIAGNOSTIC ONLY — recv that blocks for arrival but skips slot read.

        Same blocking semantics as ``tl.recv``: the kernel waits until
        the payload has landed in the IPCQ slot. Differs from ``tl.recv``
        by skipping the slot-read latency charge (slot-IO + PE↔bank
        fabric drain) on DST.

        This entry point exists solely so the pe2pe overview plot can
        draw an apples-to-apples comparison against ``tl.store`` (a
        one-sided fabric write that pays no read on DST). Production
        kernels MUST use ``tl.recv`` — they need to consume the data
        they receive. This API is segregated from ``tl.recv`` so the
        diagnostic flag can never accidentally be set in real workloads.

        The returned handle never carries data.
        """
        self._emit_dispatch_overhead()
        cmd = IpcqRecvCmd(
            direction=dir,
            shape=shape, dtype=dtype,
            handle_id=self._next_handle_id(),
            consume=False,
        )
        result = self._emit(cmd)  # type: ignore[arg-type]
        if isinstance(result, dict):
            return self._handle_from_slot(result, shape, dtype, with_data=False)
        return self._make_handle(addr=0, shape=shape, dtype=dtype)

    def recv_async(
        self,
        dir: str,
        shape: tuple[int, ...] = (),
        dtype: str = "f16",
    ) -> "RecvFuture":
        """Non-blocking recv. Returns a future to pass into ``tl.wait``.

        With a runner, the future is handed over as ``("recv_async",
        future)``. Without a runner the recv command itself is not added
        to the trace; only the dispatch overhead is.
        """
        self._emit_dispatch_overhead()
        cmd = IpcqRecvCmd(
            direction=dir,
            shape=shape, dtype=dtype,
            handle_id=self._next_handle_id(),
            blocking=False,
        )
        future = RecvFuture(cmd=cmd)
        if self._runner is not None:
            self._runner.switch_to_simpy(("recv_async", future))
        return future

    # ── Composite + Control ───────────────────────────────────────

    def composite(
        self,
        op: Literal["gemm", "math"],
        a: TensorHandle,
        b: TensorHandle | None = None,
        out_ptr: int = 0,
        math_op: str | None = None,
    ) -> CompletionHandle:
        """Submit a composite command (non-blocking, tiled pipeline).

        For ``op="gemm"`` with ``b`` given, the output size is M*N elements
        of ``a.dtype`` (M from ``a.shape[-2]``, N from ``b.shape[-1]``); in
        every other case it equals ``a.nbytes``.

        Returns CompletionHandle for use with wait().
        """
        # Compute output size based on op
        if op == "gemm" and b is not None:
            m = a.shape[-2]
            n = b.shape[-1]
            out_nbytes = m * n * self._dtype_bytes(a.dtype)
        else:
            out_nbytes = a.nbytes

        completion = CompletionHandle(id=self._next_completion_id())
        self._emit_dispatch_overhead()
        self._emit(CompositeCmd(
            completion=completion, op=op,
            a=a, b=b, out_addr=out_ptr, out_nbytes=out_nbytes,
            math_op=math_op,
        ))
        return completion

    def wait(self, handle: "CompletionHandle | RecvFuture | None" = None) -> Any:
        """Wait for a composite, a recv future, or all pending composites.

        - ``CompletionHandle`` (or None): wait for composite completion.
          Emits a WaitCmd and returns None.
        - ``RecvFuture``: wait for a non-blocking ``recv_async`` to finish.
          Returns the resolved ``TensorHandle``; the result is cached on
          the future, so waiting again returns the same handle.

        Raises:
            RuntimeError: if an unresolved RecvFuture is awaited without a
                runner.
        """
        if isinstance(handle, RecvFuture):
            if handle.resolved:
                return handle.result
            if self._runner is None:
                raise RuntimeError(
                    "tl.wait(RecvFuture) requires runner mode (greenlet)"
                )
            result_dict = self._runner.switch_to_simpy(("recv_wait", handle))
            th = self._handle_from_slot(
                result_dict, handle.cmd.shape, handle.cmd.dtype,
            )
            handle.resolved = True
            handle.result = th
            return th

        # Composite path (existing behaviour)
        self._emit(WaitCmd(handle=handle))
        return None

    def cycles(self, n: int) -> None:
        """Declare PE_CPU scalar execution overhead (cycles)."""
        self._emit(PeCpuOverheadCmd(cycles=n))
|
|
|
|
|
|
# ── TensorHandle arithmetic operators ─────────────────────────────
# Enables: a + b, a * b, a - b, a / b in kernel code.
# Each operator emits a MathCmd through the TLContext that is currently
# active. The active context is kept in thread-local storage and is set
# and cleared around the kernel call by run_kernel().
|
|
|
|
|
|
def _enable_tensor_ops() -> None:
    """Install arithmetic operators on TensorHandle.

    Invoked once at import time (see the call right below). ``+``, ``-``,
    ``*`` and ``/`` on a TensorHandle forward to ``_binary_math`` of the
    TLContext currently stored in a thread-local slot. The accessors for
    that slot are published as ``TLContext._set_active`` and
    ``TLContext._get_active``.
    """
    import threading

    state = threading.local()

    def set_active_context(ctx: TLContext | None) -> None:
        state.ctx = ctx

    def get_active_context() -> TLContext:
        active = getattr(state, "ctx", None)
        if active is not None:
            return active
        raise RuntimeError("TensorHandle ops require an active TLContext")

    def _binop(op: str):
        def method(self: TensorHandle, other: TensorHandle) -> TensorHandle:
            return get_active_context()._binary_math(op, self, other)
        return method

    # Operator dunder → MathCmd op name.
    operator_table = (
        ("__add__", "add"),
        ("__sub__", "sub"),
        ("__mul__", "mul"),
        ("__truediv__", "div"),
    )
    for dunder, op_name in operator_table:
        setattr(TensorHandle, dunder, _binop(op_name))

    # Expose context management
    TLContext._set_active = staticmethod(set_active_context)  # type: ignore[attr-defined]
    TLContext._get_active = staticmethod(get_active_context)  # type: ignore[attr-defined]


_enable_tensor_ops()
|
|
|
|
|
|
def run_kernel(kernel_fn, tl_ctx: TLContext, *args, **kwargs) -> list[PeCommand]:
    """Run ``kernel_fn`` against ``tl_ctx`` and return the recorded commands.

    ``tl_ctx`` is made the active context for TensorHandle operators for
    the duration of the call and the active context is reset to None
    afterwards, even if the kernel raises. The kernel receives ``tl_ctx``
    as its ``tl`` keyword argument, after ``*args`` and alongside
    ``**kwargs``.
    """
    activate = TLContext._set_active  # type: ignore[attr-defined]
    activate(tl_ctx)
    try:
        kernel_fn(*args, tl=tl_ctx, **kwargs)
    finally:
        activate(None)
    return tl_ctx.commands
|