tl.composite: fused epilogue ops with per-op scope

Extend tl.composite() with an ordered epilogue list. Each op carries
a scope flag - output_tile (default, runs once per (m,n) before
STORE), k_tile (every K-tile right after GEMM), or kernel. Plan
generator slots MATH stages by scope; pe_math reuses pe_dma's
local-loop pattern so chained epilogues (bias->relu) skip the port
hop. op_log captures per-stage params for telemetry. Topology
gains a gemm->math edge (snapshot test updated).

API stays backward-compatible - `epilogue=` is opt-in.

Example:
    h = tl.composite(
        op="gemm", a=a, b=b, out_ptr=int(out),
        epilogue=[
            {"op": "dequant", "scale": s_per_k, "scope": "k_tile"},
            {"op": "bias",    "bias":  bias_vec},
            {"op": "relu"},
            {"op": "scale",   "factor": 0.5},
        ],
    )
    tl.wait(h)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-15 10:16:47 -07:00
parent a76487ca48
commit a7fe785e5f
12 changed files with 382 additions and 20 deletions
+22 -15
View File
@@ -68,29 +68,36 @@ class PeMathComponent(PeEngineBase):
env.process(self._forward_txn(env, msg))
def _pipeline_process(self, env: simpy.Environment, token: Any) -> Generator:
"""Pipeline mode: pure SIMD compute, then self-route."""
self._on_process_start(env, token)
"""Pipeline mode: pure SIMD compute, then self-route.
num_elements = token.params.get("num_elements", 0)
Handles consecutive same-component MATH stages (e.g. chained
epilogue ops bias→relu) by looping locally before forwarding,
mirroring the pattern in PeDmaComponent._pipeline_process.
"""
yield from self._do_compute(env, token)
if self._accel:
with self._accel.request() as req:
yield req
ns = self._compute_ns(num_elements)
yield env.timeout(ns)
else:
ns = self._compute_ns(num_elements)
yield env.timeout(ns)
self._on_process_end(env, token)
# Self-routing
next_stage = token.advance()
while next_stage is not None and next_stage.component == self.node.id:
yield from self._do_compute(env, token)
next_stage = token.advance()
if next_stage is not None:
yield self.out_ports[next_stage.component].put(token)
else:
token.pipeline_ctx.complete_tile()
def _do_compute(self, env: simpy.Environment, token: Any) -> Generator:
"""Single MATH stage: latency model + op_log bracketing."""
self._on_process_start(env, token)
num_elements = token.params.get("num_elements", 0)
if self._accel:
with self._accel.request() as req:
yield req
yield env.timeout(self._compute_ns(num_elements))
else:
yield env.timeout(self._compute_ns(num_elements))
self._on_process_end(env, token)
def handle_command(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
"""PeInternalTxn handling for standalone MathCmd (CCL kernels).
@@ -157,6 +157,9 @@ class PeSchedulerComponent(ComponentBase):
b = cmd.b
M, K = a.shape[-2], a.shape[-1]
N = b.shape[-1]
# When CompositeCmd.ops is populated, ops[0] is the head and
# ops[1:] is the epilogue spec list. Empty ops → legacy path.
epi_specs = tuple(cmd.ops[1:]) if cmd.ops else ()
return generate_gemm_plan(
M=M, K=K, N=N,
tile_m=self.TILE_M, tile_k=self.TILE_K, tile_n=self.TILE_N,
@@ -165,6 +168,7 @@ class PeSchedulerComponent(ComponentBase):
pe_prefix=pp,
a_pinned=getattr(a, "pinned", False),
b_pinned=getattr(b, "pinned", False),
epilogue_specs=epi_specs,
)
else:
# Math composite
+36 -1
View File
@@ -23,6 +23,7 @@ def generate_gemm_plan(
pe_prefix: str,
a_pinned: bool = False,
b_pinned: bool = False,
epilogue_specs: tuple = (),
) -> PipelinePlan:
"""Generate GEMM tile plan: M→N→K order.
@@ -46,7 +47,15 @@ def generate_gemm_plan(
dma_id = f"{pe_prefix}.pe_dma"
fetch_id = f"{pe_prefix}.pe_fetch_store"
gemm_id = f"{pe_prefix}.pe_gemm"
# math_id = f"{pe_prefix}.pe_math" # for K-accumulation if needed
math_id = f"{pe_prefix}.pe_math"
# Split epilogue_specs by scope. Lazy import to avoid circular dep
# between pe_commands ← pe_types ← tiling.
from kernbench.common.pe_commands import Scope as _Scope
k_tile_ops = [o for o in epilogue_specs
if getattr(o, "scope", None) == _Scope.K_TILE]
out_tile_ops = [o for o in epilogue_specs
if getattr(o, "scope", None) == _Scope.OUTPUT_TILE]
tiles: list[TilePlan] = []
tile_id = 0
@@ -106,9 +115,35 @@ def generate_gemm_plan(
},
))
# K-tile-scope epilogue MATH stages (e.g. dequant) — run on
# every K-tile right after GEMM, before the accumulator
# advances to the next K slice.
for op in k_tile_ops:
stages.append(Stage(
stage_type=StageType.MATH,
component=math_id,
params={
"op_kind": op.kind,
"num_elements": tile_m * tile_n,
"scope": "k_tile",
},
))
# STORE + DMA_WRITE only on last K-tile per (m,n). The C
# accumulator stays in RegFile across the K loop.
if last_k:
# Output-tile-scope epilogue MATH (bias, relu, ...) runs
# ONCE per (m,n) after the final K-tile, before writeback.
for op in out_tile_ops:
stages.append(Stage(
stage_type=StageType.MATH,
component=math_id,
params={
"op_kind": op.kind,
"num_elements": tile_m * tile_n,
"scope": "output_tile",
},
))
stages.append(Stage(
stage_type=StageType.STORE,
component=fetch_id,