tl.composite: fused epilogue ops with per-op scope
Extend tl.composite() with an ordered epilogue list. Each op carries
a scope flag - output_tile (default, runs once per (m,n) before
STORE), k_tile (every K-tile right after GEMM), or kernel. Plan
generator slots MATH stages by scope; pe_math reuses pe_dma's
local-loop pattern so chained epilogues (bias->relu) skip the port
hop. op_log captures per-stage params for telemetry. Topology
gains a gemm->math edge (snapshot test updated).
API stays backward-compatible - `epilogue=` is opt-in.
Example:
h = tl.composite(
op="gemm", a=a, b=b, out_ptr=int(out),
epilogue=[
{"op": "dequant", "scale": s_per_k, "scope": "k_tile"},
{"op": "bias", "bias": bias_vec},
{"op": "relu"},
{"op": "scale", "factor": 0.5},
],
)
tl.wait(h)
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -19,14 +19,17 @@ from typing import Literal
|
||||
|
||||
from kernbench.common.ipcq_types import IpcqRecvCmd, IpcqSendCmd, RecvFuture
|
||||
from kernbench.common.pe_commands import (
|
||||
EPILOGUE_OPS,
|
||||
CompletionHandle,
|
||||
CompositeCmd,
|
||||
DmaReadCmd,
|
||||
DmaWriteCmd,
|
||||
GemmCmd,
|
||||
MathCmd,
|
||||
OpSpec,
|
||||
PeCommand,
|
||||
PeCpuOverheadCmd,
|
||||
Scope,
|
||||
TensorHandle,
|
||||
WaitCmd,
|
||||
)
|
||||
@@ -565,9 +568,18 @@ class TLContext:
|
||||
b: TensorHandle | None = None,
|
||||
out_ptr: int = 0,
|
||||
math_op: str | None = None,
|
||||
*,
|
||||
epilogue: list[dict] | None = None,
|
||||
acc_dtype: str | None = None,
|
||||
tile_shape: tuple[int, int, int] | None = None,
|
||||
) -> CompletionHandle:
|
||||
"""Submit a composite command (non-blocking, tiled pipeline).
|
||||
|
||||
Optional ``epilogue`` is an ordered list of dicts; each dict has a
|
||||
required ``"op"`` key (one of ``EPILOGUE_OPS``) plus op-specific
|
||||
fields and an optional ``"scope"``. Validation happens here so
|
||||
typos fail before the command is emitted.
|
||||
|
||||
Returns CompletionHandle for use with wait().
|
||||
"""
|
||||
# Compute output size based on op
|
||||
@@ -579,15 +591,72 @@ class TLContext:
|
||||
else:
|
||||
out_nbytes = a.nbytes
|
||||
|
||||
ops_tuple: tuple[OpSpec, ...] = ()
|
||||
if epilogue is not None:
|
||||
head_operands = (a, b) if (op == "gemm" and b is not None) else (a,)
|
||||
head_spec = OpSpec(
|
||||
kind=op, scope=Scope.OUTPUT_TILE, operands=head_operands,
|
||||
extra={
|
||||
k: v for k, v in (
|
||||
("acc_dtype", acc_dtype),
|
||||
("tile_shape", tile_shape),
|
||||
("math_op", math_op),
|
||||
) if v is not None
|
||||
},
|
||||
)
|
||||
epi_specs = tuple(self._build_epilogue_spec(e, i)
|
||||
for i, e in enumerate(epilogue))
|
||||
ops_tuple = (head_spec, *epi_specs)
|
||||
|
||||
completion = CompletionHandle(id=self._next_completion_id())
|
||||
self._emit_dispatch_overhead()
|
||||
self._emit(CompositeCmd(
|
||||
completion=completion, op=op,
|
||||
a=a, b=b, out_addr=out_ptr, out_nbytes=out_nbytes,
|
||||
math_op=math_op,
|
||||
math_op=math_op, ops=ops_tuple,
|
||||
))
|
||||
return completion
|
||||
|
||||
@staticmethod
|
||||
def _build_epilogue_spec(entry: dict, idx: int) -> OpSpec:
|
||||
if not isinstance(entry, dict) or "op" not in entry:
|
||||
raise ValueError(
|
||||
f"epilogue[{idx}]: each entry must be a dict with an 'op' key"
|
||||
)
|
||||
kind = entry["op"]
|
||||
if kind not in EPILOGUE_OPS:
|
||||
known = ", ".join(sorted(EPILOGUE_OPS))
|
||||
raise ValueError(
|
||||
f"epilogue[{idx}]: unknown op {kind!r} "
|
||||
f"(known ops: {known})"
|
||||
)
|
||||
required, default_scope = EPILOGUE_OPS[kind]
|
||||
missing = [f for f in required if f not in entry]
|
||||
if missing:
|
||||
raise ValueError(
|
||||
f"epilogue[{idx}] op {kind!r} missing required field(s): "
|
||||
f"{', '.join(missing)}"
|
||||
)
|
||||
scope = Scope(entry["scope"]) if "scope" in entry else default_scope
|
||||
operands: list = []
|
||||
scalar: float | None = None
|
||||
extra: dict = {}
|
||||
for f in required:
|
||||
v = entry[f]
|
||||
if isinstance(v, TensorHandle):
|
||||
operands.append(v)
|
||||
elif isinstance(v, (int, float)):
|
||||
if scalar is None:
|
||||
scalar = float(v)
|
||||
else:
|
||||
extra[f] = v
|
||||
else:
|
||||
extra[f] = v
|
||||
return OpSpec(
|
||||
kind=kind, scope=scope,
|
||||
operands=tuple(operands), scalar=scalar, extra=extra,
|
||||
)
|
||||
|
||||
def wait(self, handle: "CompletionHandle | RecvFuture | None" = None) -> Any:
|
||||
"""Wait for a composite, a recv future, or all pending composites.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user