tl.composite: fused epilogue ops with per-op scope

Extend tl.composite() with an ordered epilogue list. Each op carries
a scope flag - output_tile (the default for most ops; runs once per
(m,n) before STORE), k_tile (runs on every K-tile right after GEMM;
the default for dequant), or kernel (runs once per kernel). Plan
generator slots MATH stages by scope; pe_math reuses pe_dma's
local-loop pattern so chained epilogues (bias->relu) skip the port
hop. op_log captures per-stage params for telemetry. Topology
gains a gemm->math edge (snapshot test updated).

API stays backward-compatible - `epilogue=` is opt-in.

Example:
    h = tl.composite(
        op="gemm", a=a, b=b, out_ptr=int(out),
        epilogue=[
            {"op": "dequant", "scale": s_per_k, "scope": "k_tile"},
            {"op": "bias",    "bias":  bias_vec},
            {"op": "relu"},
            {"op": "scale",   "factor": 0.5},
        ],
    )
    tl.wait(h)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-15 10:16:47 -07:00
parent a76487ca48
commit a7fe785e5f
12 changed files with 382 additions and 20 deletions
+42
View File
@@ -9,12 +9,50 @@ Command lifecycle:
from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any, Literal
if TYPE_CHECKING:
import simpy
class Scope(Enum):
    """How often an op in a multi-op composite fires.

    The string values are the spellings used for the ``"scope"`` key of
    user-provided epilogue dicts (e.g. ``{"op": "dequant", "scope": "k_tile"}``).
    """

    # Fires on every K-tile.
    K_TILE = "k_tile"
    # Fires once per output tile; this is the default scope of OpSpec.
    OUTPUT_TILE = "output_tile"
    # Fires once per kernel.
    KERNEL = "kernel"
@dataclass(frozen=True)
class OpSpec:
    """One operation in a multi-op composite (head + epilogue, ADR-0021).

    The head op (first in CompositeCmd.ops) defines tile geometry; subsequent
    ops are epilogue stages whose ``scope`` controls how often they fire
    (per-K-tile, per-output-tile, or once per kernel).

    Instances are immutable (``frozen=True``) and compare field-by-field.
    They are hashable whenever every entry of ``operands`` is hashable:
    ``extra`` still takes part in ``==`` but is left out of the generated
    ``__hash__``, because a ``dict`` is unhashable and would otherwise make
    ``hash(op_spec)`` raise ``TypeError`` for every instance.
    """

    # Operation name, e.g. "gemm" | "bias" | "relu" | ...
    kind: str
    # How often the op fires; defaults to once per output tile.
    scope: Scope = Scope.OUTPUT_TILE
    # Positional tensor operands: tuple[TensorHandle, ...]
    operands: tuple[Any, ...] = ()
    # Optional scalar parameter; None when the op has none.
    scalar: float | None = None
    # Free-form extra parameters. Compared for equality but excluded from the
    # hash (hash=False): equal objects still hash equal, and the instance
    # stays usable as a dict key / set member.
    extra: dict[str, Any] = field(default_factory=dict, hash=False)
# Epilogue op contracts: kind → (required field names, default scope).
# Used by tl.composite to validate user-provided epilogue dicts at
# command-emit time so typos fail before reaching the scheduler.
#
# Each value is a 2-tuple:
#   [0] names of the fields the op's dict must supply (empty tuple when the
#       op takes no parameters, e.g. "relu");
#   [1] the Scope recorded as that op's default.
# Every op listed here defaults to OUTPUT_TILE except "dequant", which
# defaults to K_TILE.
EPILOGUE_OPS: dict[str, tuple[tuple[str, ...], "Scope"]] = {
    "bias": (("bias",), Scope.OUTPUT_TILE),
    "relu": ((), Scope.OUTPUT_TILE),
    "gelu": ((), Scope.OUTPUT_TILE),
    "sigmoid": ((), Scope.OUTPUT_TILE),
    "scale": (("factor",), Scope.OUTPUT_TILE),
    "clamp": (("lo", "hi"), Scope.OUTPUT_TILE),
    "dequant": (("scale",), Scope.K_TILE),
    "add": (("other",), Scope.OUTPUT_TILE),
}
# ── Handles ───────────────────────────────────────────────────────
@@ -118,6 +156,10 @@ class CompositeCmd:
out_nbytes: int
math_op: str | None = None # for op="math": which math operation
data_op: bool = True
# Multi-op composite (ADR-0021 extension): when non-empty, ops[0] is the
# head and ops[1:] are epilogue stages with explicit scope. When empty,
# the legacy single-op semantics (op/a/b/math_op) apply.
ops: tuple[OpSpec, ...] = ()
@dataclass(frozen=True)