Files
kernbench2/src/kernbench/common/pe_commands.py
T
mukesh a7fe785e5f tl.composite: fused epilogue ops with per-op scope
Extend tl.composite() with an ordered epilogue list. Each op carries
a scope flag: output_tile (the default; runs once per (m,n) tile before
STORE), k_tile (runs on every K-tile, right after the GEMM), or kernel
(runs once per kernel). The plan generator slots MATH stages by scope;
pe_math reuses pe_dma's local-loop pattern so that chained epilogues
(bias->relu) skip the port hop. op_log captures per-stage parameters for
telemetry. The topology gains a gemm->math edge (snapshot test updated).

The API stays backward-compatible: `epilogue=` is opt-in.

Example:
    h = tl.composite(
        op="gemm", a=a, b=b, out_ptr=int(out),
        epilogue=[
            {"op": "dequant", "scale": s_per_k, "scope": "k_tile"},
            {"op": "bias",    "bias":  bias_vec},
            {"op": "relu"},
            {"op": "scale",   "factor": 0.5},
        ],
    )
    tl.wait(h)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-15 10:16:47 -07:00

200 lines
5.9 KiB
Python

"""PE-internal command types and handles (ADR-0014).
Generated by triton_emu (TLContext) and consumed by PE component
implementations (PE_CPU, PE_SCHEDULER, PE_DMA, PE_GEMM, PE_MATH).
Command lifecycle:
Triton kernel → TLContext → [PeCommand list] → PE_CPU → PE_SCHEDULER → engines
"""
from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any, Literal

if TYPE_CHECKING:
    import simpy
class Scope(Enum):
    """How often an epilogue stage of a composite fires.

    Member values are the lowercase strings that appear in user-facing
    epilogue dicts (e.g. ``{"scope": "k_tile"}``), so ``Scope("k_tile")``
    looks up the matching member.
    """

    # Fires per K-tile of the head op.
    K_TILE = "k_tile"
    # Fires once per output tile; this is OpSpec's default scope.
    OUTPUT_TILE = "output_tile"
    # Fires once for the whole kernel.
    KERNEL = "kernel"
@dataclass(frozen=True)
class OpSpec:
    """One operation in a multi-op composite (head + epilogue, ADR-0021).

    The head op (first in ``CompositeCmd.ops``) defines tile geometry;
    subsequent ops are epilogue stages whose ``scope`` controls how often
    they fire (per-K-tile, per-output-tile, or once per kernel).

    ``scope`` may be given either as a :class:`Scope` member or as its
    string value (e.g. ``"k_tile"``, the form used in user epilogue dicts).
    Strings are normalized to the ``Scope`` member at construction, so
    downstream comparisons against ``Scope`` members behave consistently.

    Raises:
        ValueError: if ``scope`` is a string that is not a valid
            ``Scope`` value (e.g. a typo such as ``"ktile"``).
    """

    kind: str  # "gemm" | "bias" | "relu" | ...
    scope: Scope = Scope.OUTPUT_TILE
    operands: tuple[Any, ...] = ()  # tuple[TensorHandle, ...]
    scalar: float | None = None
    # Excluded from __hash__ (but still part of ==): a dict is unhashable,
    # so including it made hash() of every OpSpec -- and of any frozen
    # command that embeds OpSpecs -- raise TypeError.
    extra: dict[str, Any] = field(default_factory=dict, hash=False)

    def __post_init__(self) -> None:
        # Normalize string scopes so "k_tile" and Scope.K_TILE are the same
        # thing downstream; an invalid string raises ValueError here rather
        # than silently mis-scheduling later. Non-string values are left
        # untouched for backward compatibility.
        if isinstance(self.scope, str):
            # Frozen dataclass: bypass the generated __setattr__ guard.
            object.__setattr__(self, "scope", Scope(self.scope))
# Contract table for epilogue ops, keyed by op kind. Each entry is
# (names of fields the user's epilogue dict must supply, scope the op
# gets when the dict does not specify one). tl.composite checks user
# epilogue dicts against this table when emitting the command, so a
# typo is reported before anything reaches the scheduler.
EPILOGUE_OPS: dict[str, tuple[tuple[str, ...], "Scope"]] = {
    op_kind: (required_fields, default_scope)
    for op_kind, required_fields, default_scope in (
        ("bias", ("bias",), Scope.OUTPUT_TILE),
        ("relu", (), Scope.OUTPUT_TILE),
        ("gelu", (), Scope.OUTPUT_TILE),
        ("sigmoid", (), Scope.OUTPUT_TILE),
        ("scale", ("factor",), Scope.OUTPUT_TILE),
        ("clamp", ("lo", "hi"), Scope.OUTPUT_TILE),
        ("dequant", ("scale",), Scope.K_TILE),
        ("add", ("other",), Scope.OUTPUT_TILE),
    )
}
# ── Handles ───────────────────────────────────────────────────────
@dataclass(frozen=True)
class TensorHandle:
    """Opaque reference to a tensor residing in PE_TCM.

    Returned by tl.load, tl.dot, tl.exp, etc. Carries metadata for command
    generation; the ``data`` field is reserved for a future validate mode
    (numpy array).

    Equality and hashing are defined by the metadata fields only. ``data``
    is excluded: a numpy array is unhashable and its ``==`` is elementwise,
    so including it would make ``hash(handle)`` raise and ``handle == other``
    raise "truth value of an array is ambiguous" once validate mode fills it.
    """

    id: str
    addr: int  # address (VA when MMU enabled, PA otherwise)
    shape: tuple[int, ...]
    dtype: str
    nbytes: int  # total byte size
    # Reserved for validate mode; kept out of __eq__/__hash__ (see docstring).
    data: object = field(default=None, compare=False)
    space: str = "tcm"  # MemoryStore space ("tcm" | "hbm" | "sram")
    pinned: bool = False  # operand already DMA-staged in TCM (via tl.load)
@dataclass(frozen=True)
class CompletionHandle:
    """Immutable token for one non-blocking composite command.

    tl.composite hands this back to the kernel; passing it to tl.wait
    waits on that specific composite.
    """

    # Identifier of the composite command this handle refers to.
    id: str
# ── PE Commands ───────────────────────────────────────────────────
@dataclass(frozen=True)
class DmaReadCmd:
    """DMA READ: copy ``nbytes`` from HBM into PE_TCM.

    ``src_addr`` is a virtual address; PE_DMA translates it to a physical
    address.
    """

    # Destination tensor in PE_TCM.
    handle: TensorHandle
    # Source address in HBM (VA).
    src_addr: int
    # Number of bytes to transfer.
    nbytes: int
    data_op: bool = True
@dataclass(frozen=True)
class DmaWriteCmd:
    """DMA WRITE: copy ``nbytes`` from PE_TCM out to HBM.

    ``dst_addr`` is a virtual address; PE_DMA translates it to a physical
    address.
    """

    # Source tensor in PE_TCM.
    handle: TensorHandle
    # Destination address in HBM (VA).
    dst_addr: int
    # Number of bytes to transfer.
    nbytes: int
    data_op: bool = True
@dataclass(frozen=True)
class GemmCmd:
    """GEMM engine command computing ``out = a @ b`` on data in TCM.

    All three operands are TCM-resident. ``m``, ``k`` and ``n`` are the
    GEMM dimensions (NOTE(review): conventionally ``a`` is m x k and ``b``
    is k x n -- confirm against PE_GEMM).
    """

    # Operands and result, all in TCM.
    a: TensorHandle
    b: TensorHandle
    out: TensorHandle
    # Problem dimensions.
    m: int
    k: int
    n: int
    data_op: bool = True
@dataclass(frozen=True)
class MathCmd:
    """MATH engine command: unary, binary or reduction op on TCM data.

    Supported ``op`` names:
      * unary:     "exp", "log", "sqrt", "abs", "sigmoid", "cos", "sin"
      * binary:    "add", "sub", "mul", "div", "where"
      * reduction: "sum", "max", "min"
    """

    op: str
    # Input tensors in TCM; arity depends on ``op``.
    inputs: tuple[TensorHandle, ...]
    out: TensorHandle
    # Reduction axis; only meaningful for reductions.
    axis: int | None = None
    data_op: bool = True
@dataclass(frozen=True)
class CompositeCmd:
    """Non-blocking composite: a tiled DMA_READ -> COMPUTE -> DMA_WRITE pipeline.

    Submitted to PE_SCHEDULER, which manages tile splitting and overlaps
    the pipeline stages (ADR-0014 D3.2). ``completion`` is the handle that
    tl.composite returns and tl.wait consumes.

    The command comes in two forms:
      * legacy single-op: ``ops`` is empty, and ``op`` / ``a`` / ``b`` /
        ``math_op`` describe the computation;
      * multi-op (ADR-0021 extension): ``ops[0]`` is the head op and
        ``ops[1:]`` are epilogue stages, each carrying an explicit scope.
    """

    completion: CompletionHandle
    op: Literal["gemm", "math"]
    a: TensorHandle
    b: TensorHandle | None
    # Output address and size in bytes.
    out_addr: int
    out_nbytes: int
    # Which math operation to run when op == "math".
    math_op: str | None = None
    data_op: bool = True
    # Empty tuple selects the legacy single-op semantics (see docstring).
    ops: tuple[OpSpec, ...] = ()
@dataclass(frozen=True)
class WaitCmd:
    """Block until one composite, or every pending composite, completes."""

    # Composite to wait on; the default None means "wait for all of them".
    handle: CompletionHandle | None = None
@dataclass(frozen=True)
class PeCpuOverheadCmd:
    """Charge PE_CPU scalar-execution overhead, expressed in cycles."""

    # Number of cycles to account for.
    cycles: int
# Union of every concrete PE command type; PeInternalTxn.command holds
# exactly one of these.
PeCommand = (
    DmaReadCmd
    | DmaWriteCmd
    | GemmCmd
    | MathCmd
    | CompositeCmd
    | WaitCmd
    | PeCpuOverheadCmd
)
@dataclass
class PeInternalTxn:
    """Envelope for one PeCommand travelling PE_CPU -> PE_SCHEDULER -> engine.

    During the replay phase PE_CPU wraps each command in its own
    PeInternalTxn and sends it to PE_SCHEDULER, which routes it to the
    appropriate engine (PE_DMA, PE_GEMM or PE_MATH). That engine signals
    ``done`` when the command completes.
    """

    # The command to execute.
    command: PeCommand
    # Succeeded by the engine once it has completed ``command``.
    done: simpy.Event
    # PE path such as "sip0.cube0.pe0"; PE_DMA needs it for path resolution.
    pe_prefix: str = ""
    # Per-transaction result dict; each instance gets its own fresh dict.
    result_data: dict[str, Any] = field(default_factory=dict)