tl.composite: fused epilogue ops with per-op scope

Extend tl.composite() with an ordered epilogue list. Each op carries
a scope flag - output_tile (default, runs once per (m,n) before
STORE), k_tile (every K-tile right after GEMM), or kernel. Plan
generator slots MATH stages by scope; pe_math reuses pe_dma's
local-loop pattern so chained epilogues (bias->relu) skip the port
hop. op_log captures per-stage params for telemetry. Topology
gains a gemm->math edge (snapshot test updated).

API stays backward-compatible - `epilogue=` is opt-in.

Example:
    h = tl.composite(
        op="gemm", a=a, b=b, out_ptr=int(out),
        epilogue=[
            {"op": "dequant", "scale": s_per_k, "scope": "k_tile"},
            {"op": "bias",    "bias":  bias_vec},
            {"op": "relu"},
            {"op": "scale",   "factor": 0.5},
        ],
    )
    tl.wait(h)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-15 10:16:47 -07:00
parent a76487ca48
commit a7fe785e5f
12 changed files with 382 additions and 20 deletions
+10 -1
View File
@@ -51,11 +51,15 @@ class OpLogger:
record_end fires.
"""
snap: dict[str, Any] = {}
# TileToken (ADR-0021 pipeline) — capture which stage this is.
# TileToken (ADR-0021 pipeline) — capture which stage this is and its
# per-stage params (e.g. op_kind/scope for epilogue MATH stages) so
# we can recover them at record_end even after the token advances.
try:
stage = getattr(msg, "current_stage", None)
if stage is not None:
snap["stage_type"] = stage.stage_type.name
if isinstance(getattr(stage, "params", None), dict):
snap["stage_params"] = dict(stage.params)
except Exception:
pass
self._pending[id(msg)] = {
@@ -78,6 +82,11 @@ class OpLogger:
stage_type = snap.get("stage_type")
if stage_type is not None:
params = dict(params)
# Merge per-stage params (e.g. op_kind, scope) captured at start.
stage_params = snap.get("stage_params")
if isinstance(stage_params, dict):
for k, v in stage_params.items():
params.setdefault(k, v)
params["stage_type"] = stage_type
if op_name == "TileToken":
op_name = f"TileToken/{stage_type}"