"""Tests for multi-op tl.composite() with epilogue scopes. Public-surface tests only: we exercise tl.composite() and inspect the resulting CompositeCmd. Validation, plan generation, and scheduling are covered implicitly — they're internal to tl_context / pe_scheduler / tiling. """ from __future__ import annotations import pytest from kernbench.common.pe_commands import ( EPILOGUE_OPS, CompositeCmd, Scope, TensorHandle, ) from kernbench.triton_emu.tl_context import TLContext def _h(idx: int, shape=(32, 32)) -> TensorHandle: nbytes = 1 for d in shape: nbytes *= d return TensorHandle( id=f"h{idx}", addr=0x1000 + idx * 0x100, shape=shape, dtype="f16", nbytes=nbytes * 2, ) def test_composite_epilogue_roundtrip(): """tl.composite() with mixed-scope epilogue produces a CompositeCmd whose ops tuple preserves order, kinds, and default scopes.""" tl = TLContext() a, b = _h(0), _h(1) bias = _h(2, shape=(32,)) scale = _h(3, shape=(2,)) tl.composite( op="gemm", a=a, b=b, out_ptr=0x2000, epilogue=[ {"op": "bias", "bias": bias}, # default OUTPUT_TILE {"op": "dequant", "scale": scale}, # default K_TILE {"op": "relu"}, # default OUTPUT_TILE {"op": "scale", "factor": 0.5}, ], ) cmd = tl._commands[-1] assert isinstance(cmd, CompositeCmd) kinds_scopes = [(o.kind, o.scope) for o in cmd.ops] assert kinds_scopes == [ ("gemm", Scope.OUTPUT_TILE), ("bias", Scope.OUTPUT_TILE), ("dequant", Scope.K_TILE), ("relu", Scope.OUTPUT_TILE), ("scale", Scope.OUTPUT_TILE), ] # Single-op call (no epilogue) keeps the legacy code path: ops stays empty. tl2 = TLContext() tl2.composite(op="gemm", a=a, b=b, out_ptr=0x2000) cmd2 = tl2._commands[-1] assert isinstance(cmd2, CompositeCmd) assert cmd2.ops == () @pytest.mark.parametrize("bad,match", [ ([{"op": "biass"}], "unknown op 'biass'"), ([{"op": "bias"}], "missing required field"), (["relu"], "must be a dict"), ]) def test_composite_epilogue_rejects_bad_input(bad, match): tl = TLContext() with pytest.raises(ValueError, match=match): tl.composite(op="gemm", a=_h(0), b=_h(1), out_ptr=0x2000, epilogue=bad) def test_epilogue_registry_contract(): """EPILOGUE_OPS is the registry tl.composite validates against.""" for kind, (required, scope) in EPILOGUE_OPS.items(): assert isinstance(kind, str) and kind assert isinstance(required, tuple) assert isinstance(scope, Scope) def test_composite_epilogue_e2e(): """Drive a GEMM + bias + relu composite through the simulator and check op_log: exactly one MATH(bias) and one MATH(relu) record, ordered after GEMM and before STORE for the single (m,n) output tile.""" from pathlib import Path from kernbench.runtime_api.bench_runner import run_bench from kernbench.runtime_api.types import resolve_device from kernbench.sim_engine.engine import GraphEngine from kernbench.topology.builder import resolve_topology topo_path = Path(__file__).parent.parent / "topology.yaml" topo = resolve_topology(str(topo_path)) device = resolve_device(None) def _kernel(a_ptr, b_ptr, bias_ptr, out_ptr, tl): a = tl.ref(int(a_ptr), shape=(32, 32), dtype="f16") b = tl.ref(int(b_ptr), shape=(32, 32), dtype="f16") bias = tl.load(int(bias_ptr), shape=(32,), dtype="f16") h = tl.composite( op="gemm", a=a, b=b, out_ptr=int(out_ptr), epilogue=[{"op": "bias", "bias": bias}, {"op": "relu"}], ) tl.wait(h) def _bench(torch): from kernbench.policy.placement.dp import DPPolicy dp = DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=1) a = torch.empty((32, 32), dtype="f16", dp=dp, name="a") b = torch.empty((32, 32), dtype="f16", dp=dp, name="b") bias = torch.empty((32,), dtype="f16", dp=dp, name="bias") out = torch.empty((32, 32), dtype="f16", dp=dp, name="out") torch.launch("composite_epi", _kernel, a, b, bias, out) result = run_bench( topology=topo, bench_fn=_bench, device=device, engine_factory=lambda t, d: GraphEngine( getattr(t, "topology_obj", t), enable_data=True, ), ) assert result.completion.ok math = [r for r in result.engine.op_log if r.params.get("stage_type") == "MATH"] assert [r.params.get("op_kind") for r in math] == ["bias", "relu"] gemm = [r for r in result.engine.op_log if r.params.get("stage_type") == "GEMM"] store = [r for r in result.engine.op_log if r.params.get("stage_type") == "STORE"] assert gemm and store assert gemm[0].t_end <= math[0].t_start + 1e-6 assert math[-1].t_end <= store[0].t_start + 1e-6