tl.composite: fused epilogue ops with per-op scope

Extend tl.composite() with an ordered epilogue list. Each op carries
a scope flag: output_tile (runs once per (m,n) before STORE; the
default for every op except dequant), k_tile (runs on every K-tile
right after GEMM; the default for dequant), or kernel (accepted by
the API; generate_gemm_plan does not slot kernel-scope ops in this
change). The plan generator slots MATH stages by scope; pe_math
reuses pe_dma's local-loop pattern so chained epilogues (bias->relu)
skip the port hop. op_log captures per-stage params for telemetry.
Topology gains a gemm->math edge (snapshot test updated).

API stays backward-compatible - `epilogue=` is opt-in.

Example:
    h = tl.composite(
        op="gemm", a=a, b=b, out_ptr=int(out),
        epilogue=[
            {"op": "dequant", "scale": s_per_k, "scope": "k_tile"},
            {"op": "bias",    "bias":  bias_vec},
            {"op": "relu"},
            {"op": "scale",   "factor": 0.5},
        ],
    )
    tl.wait(h)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-15 10:16:47 -07:00
parent a76487ca48
commit a7fe785e5f
12 changed files with 382 additions and 20 deletions
+42
View File
@@ -9,12 +9,50 @@ Command lifecycle:
from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any, Literal
if TYPE_CHECKING:
import simpy
class Scope(Enum):
    """How often an epilogue op fires within a multi-op composite.

    The member values are the strings a user may put in the ``"scope"``
    field of an epilogue dict; ``Scope(value)`` converts them.
    """

    # Fires on every K-tile.
    K_TILE = "k_tile"
    # Fires per output tile; this is OpSpec's default scope.
    OUTPUT_TILE = "output_tile"
    # Fires once per kernel.
    KERNEL = "kernel"
@dataclass(frozen=True)
class OpSpec:
    """One operation in a multi-op composite (head + epilogue, ADR-0021).

    The head op (first in CompositeCmd.ops) defines tile geometry; subsequent
    ops are epilogue stages whose ``scope`` controls how often they fire
    (per-K-tile, per-output-tile, or once per kernel).

    ``scope`` may be given either as a ``Scope`` member or as its string
    value (e.g. ``"k_tile"``); it is always stored as a ``Scope`` member so
    that consumers comparing against ``Scope.K_TILE`` / ``Scope.OUTPUT_TILE``
    never silently miss a spec built with a raw string.

    Raises:
        ValueError: if ``scope`` is neither a ``Scope`` member nor one of
            the ``Scope`` values.
    """

    kind: str  # "gemm" | "bias" | "relu" | ...
    scope: "Scope" = Scope.OUTPUT_TILE
    operands: tuple[Any, ...] = ()  # tuple[TensorHandle, ...]
    scalar: float | None = None
    extra: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # The dataclass is frozen, so normalisation has to bypass the
        # generated __setattr__ guard via object.__setattr__.
        if not isinstance(self.scope, Scope):
            # Scope(...) raises ValueError for anything that is not a
            # valid scope value, so a typo fails at construction time.
            object.__setattr__(self, "scope", Scope(self.scope))
        if not isinstance(self.operands, tuple):
            # Accept any iterable of operands but store an immutable tuple.
            object.__setattr__(self, "operands", tuple(self.operands))
# Epilogue op contracts: kind → (required field names, default scope).
# tl.composite validates every user-provided epilogue dict against this
# table at command-emit time, so an unknown op or a missing required field
# fails before reaching the scheduler.  The default scope is used when the
# dict carries no explicit "scope" key; "dequant" is the only op in this
# table whose default is K_TILE rather than OUTPUT_TILE.
EPILOGUE_OPS: dict[str, tuple[tuple[str, ...], "Scope"]] = {
    "bias": (("bias",), Scope.OUTPUT_TILE),
    "relu": ((), Scope.OUTPUT_TILE),
    "gelu": ((), Scope.OUTPUT_TILE),
    "sigmoid": ((), Scope.OUTPUT_TILE),
    "scale": (("factor",), Scope.OUTPUT_TILE),
    "clamp": (("lo", "hi"), Scope.OUTPUT_TILE),
    "dequant": (("scale",), Scope.K_TILE),
    "add": (("other",), Scope.OUTPUT_TILE),
}
# ── Handles ───────────────────────────────────────────────────────
@@ -118,6 +156,10 @@ class CompositeCmd:
out_nbytes: int
math_op: str | None = None # for op="math": which math operation
data_op: bool = True
# Multi-op composite (ADR-0021 extension): when non-empty, ops[0] is the
# head and ops[1:] are epilogue stages with explicit scope. When empty,
# the legacy single-op semantics (op/a/b/math_op) apply.
ops: tuple[OpSpec, ...] = ()
@dataclass(frozen=True)
+22 -15
View File
@@ -68,29 +68,36 @@ class PeMathComponent(PeEngineBase):
env.process(self._forward_txn(env, msg))
def _pipeline_process(self, env: simpy.Environment, token: Any) -> Generator:
"""Pipeline mode: pure SIMD compute, then self-route."""
self._on_process_start(env, token)
"""Pipeline mode: pure SIMD compute, then self-route.
num_elements = token.params.get("num_elements", 0)
Handles consecutive same-component MATH stages (e.g. chained
epilogue ops bias→relu) by looping locally before forwarding,
mirroring the pattern in PeDmaComponent._pipeline_process.
"""
yield from self._do_compute(env, token)
if self._accel:
with self._accel.request() as req:
yield req
ns = self._compute_ns(num_elements)
yield env.timeout(ns)
else:
ns = self._compute_ns(num_elements)
yield env.timeout(ns)
self._on_process_end(env, token)
# Self-routing
next_stage = token.advance()
while next_stage is not None and next_stage.component == self.node.id:
yield from self._do_compute(env, token)
next_stage = token.advance()
if next_stage is not None:
yield self.out_ports[next_stage.component].put(token)
else:
token.pipeline_ctx.complete_tile()
def _do_compute(self, env: simpy.Environment, token: Any) -> Generator:
"""Single MATH stage: latency model + op_log bracketing."""
self._on_process_start(env, token)
num_elements = token.params.get("num_elements", 0)
if self._accel:
with self._accel.request() as req:
yield req
yield env.timeout(self._compute_ns(num_elements))
else:
yield env.timeout(self._compute_ns(num_elements))
self._on_process_end(env, token)
def handle_command(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
"""PeInternalTxn handling for standalone MathCmd (CCL kernels).
@@ -157,6 +157,9 @@ class PeSchedulerComponent(ComponentBase):
b = cmd.b
M, K = a.shape[-2], a.shape[-1]
N = b.shape[-1]
# When CompositeCmd.ops is populated, ops[0] is the head and
# ops[1:] is the epilogue spec list. Empty ops → legacy path.
epi_specs = tuple(cmd.ops[1:]) if cmd.ops else ()
return generate_gemm_plan(
M=M, K=K, N=N,
tile_m=self.TILE_M, tile_k=self.TILE_K, tile_n=self.TILE_N,
@@ -165,6 +168,7 @@ class PeSchedulerComponent(ComponentBase):
pe_prefix=pp,
a_pinned=getattr(a, "pinned", False),
b_pinned=getattr(b, "pinned", False),
epilogue_specs=epi_specs,
)
else:
# Math composite
+36 -1
View File
@@ -23,6 +23,7 @@ def generate_gemm_plan(
pe_prefix: str,
a_pinned: bool = False,
b_pinned: bool = False,
epilogue_specs: tuple = (),
) -> PipelinePlan:
"""Generate GEMM tile plan: M→N→K order.
@@ -46,7 +47,15 @@ def generate_gemm_plan(
dma_id = f"{pe_prefix}.pe_dma"
fetch_id = f"{pe_prefix}.pe_fetch_store"
gemm_id = f"{pe_prefix}.pe_gemm"
# math_id = f"{pe_prefix}.pe_math" # for K-accumulation if needed
math_id = f"{pe_prefix}.pe_math"
# Split epilogue_specs by scope. Lazy import to avoid circular dep
# between pe_commands ← pe_types ← tiling.
from kernbench.common.pe_commands import Scope as _Scope
k_tile_ops = [o for o in epilogue_specs
if getattr(o, "scope", None) == _Scope.K_TILE]
out_tile_ops = [o for o in epilogue_specs
if getattr(o, "scope", None) == _Scope.OUTPUT_TILE]
tiles: list[TilePlan] = []
tile_id = 0
@@ -106,9 +115,35 @@ def generate_gemm_plan(
},
))
# K-tile-scope epilogue MATH stages (e.g. dequant) — run on
# every K-tile right after GEMM, before the accumulator
# advances to the next K slice.
for op in k_tile_ops:
stages.append(Stage(
stage_type=StageType.MATH,
component=math_id,
params={
"op_kind": op.kind,
"num_elements": tile_m * tile_n,
"scope": "k_tile",
},
))
# STORE + DMA_WRITE only on last K-tile per (m,n). The C
# accumulator stays in RegFile across the K loop.
if last_k:
# Output-tile-scope epilogue MATH (bias, relu, ...) runs
# ONCE per (m,n) after the final K-tile, before writeback.
for op in out_tile_ops:
stages.append(Stage(
stage_type=StageType.MATH,
component=math_id,
params={
"op_kind": op.kind,
"num_elements": tile_m * tile_n,
"scope": "output_tile",
},
))
stages.append(Stage(
stage_type=StageType.STORE,
component=fetch_id,
+10 -1
View File
@@ -51,11 +51,15 @@ class OpLogger:
record_end fires.
"""
snap: dict[str, Any] = {}
# TileToken (ADR-0021 pipeline) — capture which stage this is.
# TileToken (ADR-0021 pipeline) — capture which stage this is and its
# per-stage params (e.g. op_kind/scope for epilogue MATH stages) so
# we can recover them at record_end even after the token advances.
try:
stage = getattr(msg, "current_stage", None)
if stage is not None:
snap["stage_type"] = stage.stage_type.name
if isinstance(getattr(stage, "params", None), dict):
snap["stage_params"] = dict(stage.params)
except Exception:
pass
self._pending[id(msg)] = {
@@ -78,6 +82,11 @@ class OpLogger:
stage_type = snap.get("stage_type")
if stage_type is not None:
params = dict(params)
# Merge per-stage params (e.g. op_kind, scope) captured at start.
stage_params = snap.get("stage_params")
if isinstance(stage_params, dict):
for k, v in stage_params.items():
params.setdefault(k, v)
params["stage_type"] = stage_type
if op_name == "TileToken":
op_name = f"TileToken/{stage_type}"
+1
View File
@@ -700,6 +700,7 @@ def _add_pe_internal_edges(edges: list[Edge], pp: str, pe_links: dict) -> None:
("pe_fetch_store", "pe_gemm", "fetch_store_to_gemm_mm"),
("pe_fetch_store", "pe_math", "fetch_store_to_math_mm"),
("pe_gemm", "pe_fetch_store", "gemm_to_fetch_store_mm"),
("pe_gemm", "pe_math", "gemm_to_math_mm"),
("pe_math", "pe_fetch_store", "math_to_fetch_store_mm"),
("pe_fetch_store", "pe_dma", "fetch_store_to_dma_mm"),
]
+70 -1
View File
@@ -19,14 +19,17 @@ from typing import Literal
from kernbench.common.ipcq_types import IpcqRecvCmd, IpcqSendCmd, RecvFuture
from kernbench.common.pe_commands import (
EPILOGUE_OPS,
CompletionHandle,
CompositeCmd,
DmaReadCmd,
DmaWriteCmd,
GemmCmd,
MathCmd,
OpSpec,
PeCommand,
PeCpuOverheadCmd,
Scope,
TensorHandle,
WaitCmd,
)
@@ -565,9 +568,18 @@ class TLContext:
b: TensorHandle | None = None,
out_ptr: int = 0,
math_op: str | None = None,
*,
epilogue: list[dict] | None = None,
acc_dtype: str | None = None,
tile_shape: tuple[int, int, int] | None = None,
) -> CompletionHandle:
"""Submit a composite command (non-blocking, tiled pipeline).
Optional ``epilogue`` is an ordered list of dicts; each dict has a
required ``"op"`` key (one of ``EPILOGUE_OPS``) plus op-specific
fields and an optional ``"scope"``. Validation happens here so
typos fail before the command is emitted.
Returns CompletionHandle for use with wait().
"""
# Compute output size based on op
@@ -579,15 +591,72 @@ class TLContext:
else:
out_nbytes = a.nbytes
ops_tuple: tuple[OpSpec, ...] = ()
if epilogue is not None:
head_operands = (a, b) if (op == "gemm" and b is not None) else (a,)
head_spec = OpSpec(
kind=op, scope=Scope.OUTPUT_TILE, operands=head_operands,
extra={
k: v for k, v in (
("acc_dtype", acc_dtype),
("tile_shape", tile_shape),
("math_op", math_op),
) if v is not None
},
)
epi_specs = tuple(self._build_epilogue_spec(e, i)
for i, e in enumerate(epilogue))
ops_tuple = (head_spec, *epi_specs)
completion = CompletionHandle(id=self._next_completion_id())
self._emit_dispatch_overhead()
self._emit(CompositeCmd(
completion=completion, op=op,
a=a, b=b, out_addr=out_ptr, out_nbytes=out_nbytes,
math_op=math_op,
math_op=math_op, ops=ops_tuple,
))
return completion
@staticmethod
def _build_epilogue_spec(entry: dict, idx: int) -> OpSpec:
    """Validate one user-supplied epilogue dict and convert it to an OpSpec.

    Args:
        entry: the epilogue dict.  Must contain ``"op"`` (a key of
            ``EPILOGUE_OPS``) and every field that op requires; may
            contain ``"scope"`` (a ``Scope`` member or its string value).
            No other keys are accepted, so a misspelt key such as
            ``"scop"`` fails here instead of being silently ignored.
        idx: position of ``entry`` in the epilogue list; used only to
            build error messages.

    Returns:
        An ``OpSpec`` whose ``scope`` is the explicit scope if given,
        else the op's default from ``EPILOGUE_OPS``.  Required field
        values are distributed as follows, in the order the op declares
        them: ``TensorHandle`` values go to ``operands``; the first
        int/float value becomes ``scalar`` (converted to ``float``);
        every other value, including any later numeric one, is stored in
        ``extra`` under its field name.

    Raises:
        ValueError: if ``entry`` is not a dict with an ``"op"`` key, the
            op is unknown, a required field is missing, an unexpected
            key is present, or ``"scope"`` is not a valid scope.
    """
    if not isinstance(entry, dict) or "op" not in entry:
        raise ValueError(
            f"epilogue[{idx}]: each entry must be a dict with an 'op' key"
        )

    kind = entry["op"]
    # The isinstance guard keeps an unhashable "op" (e.g. a list) from
    # escaping as a TypeError out of the dict membership test.
    if not isinstance(kind, str) or kind not in EPILOGUE_OPS:
        known = ", ".join(sorted(EPILOGUE_OPS))
        raise ValueError(
            f"epilogue[{idx}]: unknown op {kind!r} "
            f"(known ops: {known})"
        )

    required, default_scope = EPILOGUE_OPS[kind]
    missing = [f for f in required if f not in entry]
    if missing:
        raise ValueError(
            f"epilogue[{idx}] op {kind!r} missing required field(s): "
            f"{', '.join(missing)}"
        )

    # Anything beyond "op", "scope" and the op's declared fields is most
    # likely a typo; reject it rather than drop it silently.
    allowed = ("op", "scope", *required)
    unexpected = sorted(repr(k) for k in entry if k not in allowed)
    if unexpected:
        raise ValueError(
            f"epilogue[{idx}] op {kind!r} got unexpected field(s): "
            f"{', '.join(unexpected)}"
        )

    if "scope" in entry:
        raw_scope = entry["scope"]
        try:
            scope = Scope(raw_scope)
        except ValueError:
            valid = ", ".join(repr(s.value) for s in Scope)
            raise ValueError(
                f"epilogue[{idx}] op {kind!r}: invalid scope {raw_scope!r} "
                f"(valid scopes: {valid})"
            ) from None
    else:
        scope = default_scope

    operands: list = []
    scalar: float | None = None
    extra: dict = {}
    for name in required:
        value = entry[name]
        if isinstance(value, TensorHandle):
            operands.append(value)
        elif isinstance(value, (int, float)) and scalar is None:
            # Only the first numeric field occupies the scalar slot.
            scalar = float(value)
        else:
            # Later numeric fields (e.g. clamp's "hi") and any other
            # value type are kept by name.
            extra[name] = value

    return OpSpec(
        kind=kind, scope=scope,
        operands=tuple(operands), scalar=scalar, extra=extra,
    )
def wait(self, handle: "CompletionHandle | RecvFuture | None" = None) -> Any:
"""Wait for a composite, a recv future, or all pending composites.