tl.composite: fused epilogue ops with per-op scope
Extend tl.composite() with an ordered epilogue list. Each op carries
a scope flag: output_tile (the default; runs once per (m,n) output
tile, before STORE), k_tile (runs on every K-tile, right after GEMM),
or kernel (runs once per kernel). The plan generator slots MATH stages
by scope; pe_math reuses pe_dma's local-loop pattern so that chained
epilogues (bias->relu) skip the port hop. op_log captures per-stage
params for telemetry. The topology gains a gemm->math edge (snapshot
test updated).
The API stays backward-compatible: `epilogue=` is opt-in.
Example:
    h = tl.composite(
        op="gemm", a=a, b=b, out_ptr=int(out),
        epilogue=[
            {"op": "dequant", "scale": s_per_k, "scope": "k_tile"},
            {"op": "bias", "bias": bias_vec},
            {"op": "relu"},
            {"op": "scale", "factor": 0.5},
        ],
    )
    tl.wait(h)
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -9,12 +9,50 @@ Command lifecycle:
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import simpy
|
||||
|
||||
|
||||
class Scope(Enum):
    """How often a stage of a multi-op composite fires.

    Each member's value is the lowercase spelling of its name.
    """

    # NOTE(review): the values look like the strings accepted for the
    # ``scope`` key of an epilogue dict — confirm against tl.composite.
    K_TILE = "k_tile"  # per K-tile
    OUTPUT_TILE = "output_tile"  # per output tile
    KERNEL = "kernel"  # once per kernel
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class OpSpec:
    """One operation in a multi-op composite (head + epilogue, ADR-0021).

    The head op (first in CompositeCmd.ops) defines tile geometry; subsequent
    ops are epilogue stages whose ``scope`` controls how often they fire
    (per-K-tile, per-output-tile, or once per kernel).
    """

    # NOTE(review): ``frozen=True`` generates a ``__hash__`` over every
    # field, and ``extra`` is a dict, so ``hash(OpSpec(...))`` raises
    # TypeError. Confirm that no caller needs to hash an OpSpec.

    # Operation name.
    kind: str  # "gemm" | "bias" | "relu" | ...
    # Firing frequency of this stage; defaults to once per output tile.
    scope: Scope = Scope.OUTPUT_TILE
    # Positional operands of the op; empty when the op takes none.
    operands: tuple[Any, ...] = ()  # tuple[TensorHandle, ...]
    # Optional numeric parameter; None when the op has none.
    # Presumably carries values such as a scale factor — TODO confirm.
    scalar: float | None = None
    # Free-form additional parameters; a fresh dict per instance.
    extra: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
# Epilogue op contracts: kind → (required field names, default scope).
# Used by tl.composite to validate user-provided epilogue dicts at
# command-emit time so typos fail before reaching the scheduler.
#
# Each key is an epilogue op name. Each value pairs the dict keys that op
# requires with the scope to use when the user does not specify one.
# "dequant" is the only op in this table whose default scope is K_TILE;
# every other op defaults to OUTPUT_TILE.
EPILOGUE_OPS: dict[str, tuple[tuple[str, ...], Scope]] = {
    "bias": (("bias",), Scope.OUTPUT_TILE),
    "relu": ((), Scope.OUTPUT_TILE),
    "gelu": ((), Scope.OUTPUT_TILE),
    "sigmoid": ((), Scope.OUTPUT_TILE),
    "scale": (("factor",), Scope.OUTPUT_TILE),
    "clamp": (("lo", "hi"), Scope.OUTPUT_TILE),
    "dequant": (("scale",), Scope.K_TILE),
    "add": (("other",), Scope.OUTPUT_TILE),
}
|
||||
|
||||
|
||||
# ── Handles ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -118,6 +156,10 @@ class CompositeCmd:
|
||||
out_nbytes: int
|
||||
math_op: str | None = None # for op="math": which math operation
|
||||
data_op: bool = True
|
||||
# Multi-op composite (ADR-0021 extension): when non-empty, ops[0] is the
|
||||
# head and ops[1:] are epilogue stages with explicit scope. When empty,
|
||||
# the legacy single-op semantics (op/a/b/math_op) apply.
|
||||
ops: tuple[OpSpec, ...] = ()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
||||
Reference in New Issue
Block a user