tl.composite: fused epilogue ops with per-op scope
Extend tl.composite() with an ordered epilogue list. Each op carries
a scope flag - output_tile (default, runs once per (m,n) before
STORE), k_tile (every K-tile right after GEMM), or kernel. Plan
generator slots MATH stages by scope; pe_math reuses pe_dma's
local-loop pattern so chained epilogues (bias->relu) skip the port
hop. op_log captures per-stage params for telemetry. Topology
gains a gemm->math edge (snapshot test updated).
API stays backward-compatible - `epilogue=` is opt-in.
Example:

    h = tl.composite(
        op="gemm", a=a, b=b, out_ptr=int(out),
        epilogue=[
            {"op": "dequant", "scale": s_per_k, "scope": "k_tile"},
            {"op": "bias", "bias": bias_vec},
            {"op": "relu"},
            {"op": "scale", "factor": 0.5},
        ],
    )
    tl.wait(h)
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,140 @@
|
||||
"""Tests for multi-op tl.composite() with epilogue scopes.
|
||||
|
||||
Public-surface tests only: we exercise tl.composite() and inspect the
|
||||
resulting CompositeCmd. Validation, plan generation, and scheduling are
|
||||
covered implicitly — they're internal to tl_context / pe_scheduler / tiling.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from kernbench.common.pe_commands import (
|
||||
EPILOGUE_OPS,
|
||||
CompositeCmd,
|
||||
Scope,
|
||||
TensorHandle,
|
||||
)
|
||||
from kernbench.triton_emu.tl_context import TLContext
|
||||
|
||||
|
||||
def _h(idx: int, shape=(32, 32)) -> TensorHandle:
    """Build a synthetic f16 TensorHandle for use as a test operand.

    Args:
        idx: Distinguishes handles; it feeds both the id (``h<idx>``) and
            the address (``0x1000 + idx * 0x100``).
        shape: Tensor dimensions; defaults to a 32x32 matrix.

    Returns:
        A TensorHandle whose ``nbytes`` is the element count of *shape*
        times two (two bytes per element, paired with the "f16" dtype).
    """
    elem_count = 1
    for dim in shape:
        elem_count = elem_count * dim
    return TensorHandle(
        id=f"h{idx}",
        addr=0x1000 + idx * 0x100,
        shape=shape,
        dtype="f16",
        nbytes=elem_count * 2,
    )
|
||||
|
||||
|
||||
def test_composite_epilogue_roundtrip():
    """A mixed-scope epilogue survives tl.composite() intact.

    The recorded CompositeCmd must list the GEMM followed by each epilogue
    op in call order, every op carrying the scope it gets when none is
    given explicitly. A call without ``epilogue=`` must leave ``ops`` empty.
    """
    ctx = TLContext()
    lhs, rhs = _h(0), _h(1)
    bias_vec = _h(2, shape=(32,))
    dequant_scale = _h(3, shape=(2,))

    # No entry sets "scope", so each op falls back to its default.
    epilogue = [
        {"op": "bias", "bias": bias_vec},
        {"op": "dequant", "scale": dequant_scale},
        {"op": "relu"},
        {"op": "scale", "factor": 0.5},
    ]
    expected = [
        ("gemm", Scope.OUTPUT_TILE),
        ("bias", Scope.OUTPUT_TILE),
        ("dequant", Scope.K_TILE),  # the only op here that defaults to K_TILE
        ("relu", Scope.OUTPUT_TILE),
        ("scale", Scope.OUTPUT_TILE),
    ]

    ctx.composite(op="gemm", a=lhs, b=rhs, out_ptr=0x2000, epilogue=epilogue)
    fused = ctx._commands[-1]
    assert isinstance(fused, CompositeCmd)
    assert [(op.kind, op.scope) for op in fused.ops] == expected

    # Single-op call (no epilogue) keeps the legacy code path: ops stays empty.
    legacy_ctx = TLContext()
    legacy_ctx.composite(op="gemm", a=lhs, b=rhs, out_ptr=0x2000)
    legacy = legacy_ctx._commands[-1]
    assert isinstance(legacy, CompositeCmd)
    assert legacy.ops == ()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "bad,match",
    [
        ([{"op": "biass"}], "unknown op 'biass'"),
        ([{"op": "bias"}], "missing required field"),
        (["relu"], "must be a dict"),
    ],
)
def test_composite_epilogue_rejects_bad_input(bad, match):
    """Malformed epilogue lists make tl.composite() raise ValueError.

    Each case pairs an invalid ``epilogue=`` value with a pattern that the
    error message must match (``pytest.raises(match=...)`` uses re.search).
    """
    ctx = TLContext()
    lhs = _h(0)
    rhs = _h(1)
    with pytest.raises(ValueError, match=match):
        ctx.composite(
            op="gemm", a=lhs, b=rhs, out_ptr=0x2000, epilogue=bad,
        )
|
||||
|
||||
|
||||
def test_epilogue_registry_contract():
    """EPILOGUE_OPS is the registry tl.composite validates against.

    Checks the shape of every entry: a non-empty string kind mapped to a
    two-item value made of a tuple (unpacked here as ``required``) and a
    Scope member (presumably the op's default scope; confirm against
    pe_commands).
    """
    # Without this guard an empty registry would skip the loop entirely and
    # the test would pass without checking anything.
    assert EPILOGUE_OPS, "EPILOGUE_OPS must register at least one epilogue op"

    for kind, (required, scope) in EPILOGUE_OPS.items():
        assert isinstance(kind, str) and kind
        assert isinstance(required, tuple)
        assert isinstance(scope, Scope)
|
||||
|
||||
|
||||
def test_composite_epilogue_e2e():
    """Run a GEMM + bias + relu composite through the simulator end to end.

    The op_log must contain exactly one MATH(bias) and one MATH(relu) record,
    in that order, placed after the GEMM and before the STORE of the single
    (m,n) output tile.
    """
    from pathlib import Path

    from kernbench.runtime_api.bench_runner import run_bench
    from kernbench.runtime_api.types import resolve_device
    from kernbench.sim_engine.engine import GraphEngine
    from kernbench.topology.builder import resolve_topology

    topo_path = Path(__file__).parent.parent / "topology.yaml"
    topo = resolve_topology(str(topo_path))
    device = resolve_device(None)

    def _kernel(a_ptr, b_ptr, bias_ptr, out_ptr, tl):
        # Device-side body: one fused GEMM with a bias -> relu epilogue.
        lhs = tl.ref(int(a_ptr), shape=(32, 32), dtype="f16")
        rhs = tl.ref(int(b_ptr), shape=(32, 32), dtype="f16")
        bias_vec = tl.load(int(bias_ptr), shape=(32,), dtype="f16")
        handle = tl.composite(
            op="gemm",
            a=lhs,
            b=rhs,
            out_ptr=int(out_ptr),
            epilogue=[{"op": "bias", "bias": bias_vec}, {"op": "relu"}],
        )
        tl.wait(handle)

    def _bench(torch):
        # Host-side body: allocate the operands and launch the kernel.
        from kernbench.policy.placement.dp import DPPolicy

        dp = DPPolicy(
            cube="replicate", pe="replicate", num_cubes=1, num_pes=1,
        )
        # Allocation order matches the kernel's positional arguments.
        specs = (
            ("a", (32, 32)),
            ("b", (32, 32)),
            ("bias", (32,)),
            ("out", (32, 32)),
        )
        tensors = [
            torch.empty(shape, dtype="f16", dp=dp, name=name)
            for name, shape in specs
        ]
        torch.launch("composite_epi", _kernel, *tensors)

    result = run_bench(
        topology=topo,
        bench_fn=_bench,
        device=device,
        engine_factory=lambda t, d: GraphEngine(
            getattr(t, "topology_obj", t), enable_data=True,
        ),
    )
    assert result.completion.ok

    def _records(stage_type):
        # op_log entries whose params tag them with the given stage type.
        return [
            rec for rec in result.engine.op_log
            if rec.params.get("stage_type") == stage_type
        ]

    math_recs = _records("MATH")
    assert [rec.params.get("op_kind") for rec in math_recs] == ["bias", "relu"]

    gemm_recs = _records("GEMM")
    store_recs = _records("STORE")
    assert gemm_recs and store_recs

    # Small slack so equal timestamps with float noise still count as ordered.
    eps = 1e-6
    assert gemm_recs[0].t_end <= math_recs[0].t_start + eps
    assert math_recs[-1].t_end <= store_recs[0].t_start + eps
|
||||
Reference in New Issue
Block a user