tl.composite: fused epilogue ops with per-op scope
Extend tl.composite() with an ordered epilogue list. Each op carries
a scope flag - output_tile (default, runs once per (m,n) before
STORE), k_tile (every K-tile right after GEMM), or kernel. Plan
generator slots MATH stages by scope; pe_math reuses pe_dma's
local-loop pattern so chained epilogues (bias->relu) skip the port
hop. op_log captures per-stage params for telemetry. Topology
gains a gemm->math edge (snapshot test updated).
API stays backward-compatible - `epilogue=` is opt-in.
Example:
h = tl.composite(
op="gemm", a=a, b=b, out_ptr=int(out),
epilogue=[
{"op": "dequant", "scale": s_per_k, "scope": "k_tile"},
{"op": "bias", "bias": bias_vec},
{"op": "relu"},
{"op": "scale", "factor": 0.5},
],
)
tl.wait(h)
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -68,29 +68,36 @@ class PeMathComponent(PeEngineBase):
|
||||
env.process(self._forward_txn(env, msg))
|
||||
|
||||
def _pipeline_process(self, env: simpy.Environment, token: Any) -> Generator:
|
||||
"""Pipeline mode: pure SIMD compute, then self-route."""
|
||||
self._on_process_start(env, token)
|
||||
"""Pipeline mode: pure SIMD compute, then self-route.
|
||||
|
||||
num_elements = token.params.get("num_elements", 0)
|
||||
Handles consecutive same-component MATH stages (e.g. chained
|
||||
epilogue ops bias→relu) by looping locally before forwarding,
|
||||
mirroring the pattern in PeDmaComponent._pipeline_process.
|
||||
"""
|
||||
yield from self._do_compute(env, token)
|
||||
|
||||
if self._accel:
|
||||
with self._accel.request() as req:
|
||||
yield req
|
||||
ns = self._compute_ns(num_elements)
|
||||
yield env.timeout(ns)
|
||||
else:
|
||||
ns = self._compute_ns(num_elements)
|
||||
yield env.timeout(ns)
|
||||
|
||||
self._on_process_end(env, token)
|
||||
|
||||
# Self-routing
|
||||
next_stage = token.advance()
|
||||
while next_stage is not None and next_stage.component == self.node.id:
|
||||
yield from self._do_compute(env, token)
|
||||
next_stage = token.advance()
|
||||
|
||||
if next_stage is not None:
|
||||
yield self.out_ports[next_stage.component].put(token)
|
||||
else:
|
||||
token.pipeline_ctx.complete_tile()
|
||||
|
||||
def _do_compute(self, env: simpy.Environment, token: Any) -> Generator:
|
||||
"""Single MATH stage: latency model + op_log bracketing."""
|
||||
self._on_process_start(env, token)
|
||||
num_elements = token.params.get("num_elements", 0)
|
||||
if self._accel:
|
||||
with self._accel.request() as req:
|
||||
yield req
|
||||
yield env.timeout(self._compute_ns(num_elements))
|
||||
else:
|
||||
yield env.timeout(self._compute_ns(num_elements))
|
||||
self._on_process_end(env, token)
|
||||
|
||||
def handle_command(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
|
||||
"""PeInternalTxn handling for standalone MathCmd (CCL kernels).
|
||||
|
||||
|
||||
@@ -157,6 +157,9 @@ class PeSchedulerComponent(ComponentBase):
|
||||
b = cmd.b
|
||||
M, K = a.shape[-2], a.shape[-1]
|
||||
N = b.shape[-1]
|
||||
# When CompositeCmd.ops is populated, ops[0] is the head and
|
||||
# ops[1:] is the epilogue spec list. Empty ops → legacy path.
|
||||
epi_specs = tuple(cmd.ops[1:]) if cmd.ops else ()
|
||||
return generate_gemm_plan(
|
||||
M=M, K=K, N=N,
|
||||
tile_m=self.TILE_M, tile_k=self.TILE_K, tile_n=self.TILE_N,
|
||||
@@ -165,6 +168,7 @@ class PeSchedulerComponent(ComponentBase):
|
||||
pe_prefix=pp,
|
||||
a_pinned=getattr(a, "pinned", False),
|
||||
b_pinned=getattr(b, "pinned", False),
|
||||
epilogue_specs=epi_specs,
|
||||
)
|
||||
else:
|
||||
# Math composite
|
||||
|
||||
@@ -23,6 +23,7 @@ def generate_gemm_plan(
|
||||
pe_prefix: str,
|
||||
a_pinned: bool = False,
|
||||
b_pinned: bool = False,
|
||||
epilogue_specs: tuple = (),
|
||||
) -> PipelinePlan:
|
||||
"""Generate GEMM tile plan: M→N→K order.
|
||||
|
||||
@@ -46,7 +47,15 @@ def generate_gemm_plan(
|
||||
dma_id = f"{pe_prefix}.pe_dma"
|
||||
fetch_id = f"{pe_prefix}.pe_fetch_store"
|
||||
gemm_id = f"{pe_prefix}.pe_gemm"
|
||||
# math_id = f"{pe_prefix}.pe_math" # for K-accumulation if needed
|
||||
math_id = f"{pe_prefix}.pe_math"
|
||||
|
||||
# Split epilogue_specs by scope. Lazy import to avoid circular dep
|
||||
# between pe_commands ← pe_types ← tiling.
|
||||
from kernbench.common.pe_commands import Scope as _Scope
|
||||
k_tile_ops = [o for o in epilogue_specs
|
||||
if getattr(o, "scope", None) == _Scope.K_TILE]
|
||||
out_tile_ops = [o for o in epilogue_specs
|
||||
if getattr(o, "scope", None) == _Scope.OUTPUT_TILE]
|
||||
|
||||
tiles: list[TilePlan] = []
|
||||
tile_id = 0
|
||||
@@ -106,9 +115,35 @@ def generate_gemm_plan(
|
||||
},
|
||||
))
|
||||
|
||||
# K-tile-scope epilogue MATH stages (e.g. dequant) — run on
|
||||
# every K-tile right after GEMM, before the accumulator
|
||||
# advances to the next K slice.
|
||||
for op in k_tile_ops:
|
||||
stages.append(Stage(
|
||||
stage_type=StageType.MATH,
|
||||
component=math_id,
|
||||
params={
|
||||
"op_kind": op.kind,
|
||||
"num_elements": tile_m * tile_n,
|
||||
"scope": "k_tile",
|
||||
},
|
||||
))
|
||||
|
||||
# STORE + DMA_WRITE only on last K-tile per (m,n). The C
|
||||
# accumulator stays in RegFile across the K loop.
|
||||
if last_k:
|
||||
# Output-tile-scope epilogue MATH (bias, relu, ...) runs
|
||||
# ONCE per (m,n) after the final K-tile, before writeback.
|
||||
for op in out_tile_ops:
|
||||
stages.append(Stage(
|
||||
stage_type=StageType.MATH,
|
||||
component=math_id,
|
||||
params={
|
||||
"op_kind": op.kind,
|
||||
"num_elements": tile_m * tile_n,
|
||||
"scope": "output_tile",
|
||||
},
|
||||
))
|
||||
stages.append(Stage(
|
||||
stage_type=StageType.STORE,
|
||||
component=fetch_id,
|
||||
|
||||
Reference in New Issue
Block a user