From a7fe785e5fed17ff15a5db35c0fb40175641fdb1 Mon Sep 17 00:00:00 2001 From: Mukesh Garg Date: Fri, 15 May 2026 10:16:47 -0700 Subject: [PATCH] tl.composite: fused epilogue ops with per-op scope Extend tl.composite() with an ordered epilogue list. Each op carries a scope flag - output_tile (default, runs once per (m,n) before STORE), k_tile (every K-tile right after GEMM), or kernel. Plan generator slots MATH stages by scope; pe_math reuses pe_dma's local-loop pattern so chained epilogues (bias->relu) skip the port hop. op_log captures per-stage params for telemetry. Topology gains a gemm->math edge (snapshot test updated). API stays backward-compatible - `epilogue=` is opt-in. Example: h = tl.composite( op="gemm", a=a, b=b, out_ptr=int(out), epilogue=[ {"op": "dequant", "scale": s_per_k, "scope": "k_tile"}, {"op": "bias", "bias": bias_vec}, {"op": "relu"}, {"op": "scale", "factor": 0.5}, ], ) tl.wait(h) Co-Authored-By: Claude Opus 4.7 (1M context) --- .claude/settings.json | 54 ++++++- .../allreduce_latency_plots/topology.png | Bin 198707 -> 198707 bytes src/kernbench/common/pe_commands.py | 42 ++++++ src/kernbench/components/builtin/pe_math.py | 37 +++-- .../components/builtin/pe_scheduler.py | 4 + src/kernbench/components/builtin/tiling.py | 37 ++++- src/kernbench/sim_engine/op_log.py | 11 +- src/kernbench/topology/builder.py | 1 + src/kernbench/triton_emu/tl_context.py | 71 ++++++++- tests/test_composite_epilogue.py | 140 ++++++++++++++++++ tests/test_topology_compile.py | 4 +- topology.yaml | 1 + 12 files changed, 382 insertions(+), 20 deletions(-) create mode 100644 tests/test_composite_epilogue.py diff --git a/.claude/settings.json b/.claude/settings.json index bb31422..4d5bfc2 100644 --- a/.claude/settings.json +++ b/.claude/settings.json @@ -9,7 +9,59 @@ "Bash(python -m kernbench.cli.main probe --topology topology.yaml)", "Bash(xargs grep -l \"class.*ComponentBase\\\\|class.*DefaultComponent\")", "Bash(python -m pytest tests/test_probe.py -v)", - "Bash(python -m pytest tests/test_probe.py tests/test_component_registry.py -v)" + "Bash(python -m pytest tests/test_probe.py tests/test_component_registry.py -v)", + "Bash(python -m pytest -o \"addopts=\" --no-header tests/test_intercube_root_center.py)", + "Bash(python -m pytest -o \"addopts=\" --no-header tests/test_tp_layers.py tests/test_tp_mlp.py)", + "Bash(git commit -m ' *)", + "Bash(git stash *)", + "Bash(python scripts/emit_overview_with_external_ref.py)", + "Bash(where inkscape *)", + "Bash(\"/c/Program Files \\(x86\\)/Microsoft/Edge/Application/msedge.exe\" --headless --disable-gpu --screenshot=\"$\\(pwd\\)/docs/diagrams/cube_mesh_view.png\" --window-size=1400,1300 \"file:///$\\(pwd)", + "Bash(python scripts/build_overview_slides.py)", + "Bash(git fetch *)", + "Bash(git pull *)", + "Bash(python -m pytest --no-header tests/test_allreduce_buffer_kind_sweep.py)", + "Bash(python -m pytest --no-header tests/test_pe_to_pe_latency.py)", + "Bash(python -m pytest --no-header tests/test_ipcq_buffer_kind_locations.py -v)", + "Bash(python -m pytest --no-header tests/test_ipcq_buffer_kind_locations.py tests/test_ipcq_buffer_kind_latency.py tests/test_allreduce_buffer_kind_sweep.py)", + "Bash(git checkout *)", + "Bash(python -m pytest --no-header tests/test_ipcq_buffer_kind_latency.py::test_slot_write_latency_orders_tcm_hbm_sram)", + "Bash(python scripts/emit_ipcq_send_recv_model_plots.py)", + "Bash(python -m pytest --no-header tests/test_pe_to_pe_latency.py -x)", + "Bash(python -m pytest --no-header tests/test_pe_to_pe_latency.py tests/test_ipcq_buffer_kind_locations.py tests/test_ipcq_buffer_kind_latency.py tests/test_allreduce_buffer_kind_sweep.py)", + "Bash(kill %1)", + "Bash(awk '{print $2}')", + "Bash(xargs -r kill)", + "Bash(python scripts/_debug_op_log.py)", + "Bash(SWEEP_SHAPES=\"16,32,64,128,256\" python scripts/gemm_sweep.py)", + "Bash(python scripts/plot_gemm_sweep.py)", + "Bash(python scripts/gemm_sweep.py)", + "Bash(python scripts/gen_pe_pipeline_diagram.py)", + "Bash(python scripts/gen_matmul_32x128x32_diagram.py)", + "Bash(python -m pytest tests/test_pe_pipeline.py -x --tb=short)", + "Bash(python -m pytest tests/test_pe_pipeline.py tests/test_e2e_pipeline.py tests/test_op_log.py -x --tb=short -q)", + "Bash(ls -la C:/Users/mukes/.claude/projects/c--Users-mukes-Mukesh-ywkang-git-kernbench2/ 2>&1 | head -20)", + "Read(//c/Users/mukes/.claude/projects/c--Users-mukes-Mukesh-ywkang-git-kernbench2/**)", + "Bash(awk 'NR==1812 || NR==1815' C:/Users/mukes/.claude/projects/c--Users-mukes-Mukesh-ywkang-git-kernbench2/e55237ed-5c1f-4a89-a3b9-9b74fec45366.jsonl)", + "Bash(awk 'NR==1058' C:/Users/mukes/.claude/projects/c--Users-mukes-Mukesh-ywkang-git-kernbench2/e55237ed-5c1f-4a89-a3b9-9b74fec45366.jsonl)", + "Bash(awk -F: '$1 > 1700 && $1 < 1815 {print $1}')", + "Bash(awk 'NR==1812' C:/Users/mukes/.claude/projects/c--Users-mukes-Mukesh-ywkang-git-kernbench2/e55237ed-5c1f-4a89-a3b9-9b74fec45366.jsonl)", + "Bash(awk 'NR>=1815 && NR<=1825' C:/Users/mukes/.claude/projects/c--Users-mukes-Mukesh-ywkang-git-kernbench2/e55237ed-5c1f-4a89-a3b9-9b74fec45366.jsonl)", + "Bash(awk 'NR>1815' C:/Users/mukes/.claude/projects/c--Users-mukes-Mukesh-ywkang-git-kernbench2/e55237ed-5c1f-4a89-a3b9-9b74fec45366.jsonl)", + "Bash(awk 'NR==1839' C:/Users/mukes/.claude/projects/c--Users-mukes-Mukesh-ywkang-git-kernbench2/e55237ed-5c1f-4a89-a3b9-9b74fec45366.jsonl)", + "Bash(git log *)", + "Bash(python -m pytest tests/test_op_log.py tests/test_pe_components.py tests/test_pe_pipeline.py -x --tb=short)", + "Bash(python -m pytest tests/test_pe_to_pe_latency.py tests/test_e2e_pipeline.py tests/test_e2e_data.py tests/test_data_executor.py tests/test_pe_dma_ipcq.py -x --tb=short)", + "Bash(python -m pytest tests/test_pe_pipeline.py::test_pe_dma_record_start_after_channel_acquire -x --tb=long)", + "Bash(python -m pytest tests/test_pe_pipeline.py::test_pe_dma_record_start_after_channel_acquire -x --tb=short)", + "Bash(python -m pytest tests/test_op_log.py tests/test_pe_components.py tests/test_pe_pipeline.py tests/test_pe_to_pe_latency.py tests/test_e2e_pipeline.py tests/test_e2e_data.py tests/test_data_executor.py tests/test_pe_dma_ipcq.py --tb=short)", + "Bash(python -m pytest tests/test_pe_pipeline.py -q)", + "Bash(python -m pytest tests/test_pe_pipeline.py tests/test_triton_emu.py -q)", + "Bash(python -m pytest tests/test_composite_epilogue.py -v)" + ], + "additionalDirectories": [ + "c:\\Users\\mukes\\Mukesh\\ywkang_git\\kernbench2\\tests", + "C:\\Users\\mukes\\Mukesh\\ywkang_git\\kernbench2\\tests\\pe2pe_latency_plots" ] } } diff --git a/docs/diagrams/allreduce_latency_plots/topology.png b/docs/diagrams/allreduce_latency_plots/topology.png index 1990768987a9e4cc6e6aaba31b3128d1427555c0..40e8719689b5078b4672abc22cdeda5afa75ecf1 100644 GIT binary patch delta 52 zcmdlyfoJmso(Z0E7CH(UB_##LR{Hw6i6sR&`6W4-NqYH3>G}twOV2f?x27|;rZa6# IXFkIY00}M=)&Kwi delta 52 zcmdlyfoJmso(Z0EmO2UH2DJmm?d~Thkd^)0wuW IGoN7x0Q}1n5C8xG diff --git a/src/kernbench/common/pe_commands.py b/src/kernbench/common/pe_commands.py index 72c245e..1c47c4f 100644 --- a/src/kernbench/common/pe_commands.py +++ b/src/kernbench/common/pe_commands.py @@ -9,12 +9,50 @@ Command lifecycle: from __future__ import annotations from dataclasses import dataclass, field +from enum import Enum from typing import TYPE_CHECKING, Any, Literal if TYPE_CHECKING: import simpy +class Scope(Enum): + K_TILE = "k_tile" + OUTPUT_TILE = "output_tile" + KERNEL = "kernel" + + +@dataclass(frozen=True) +class OpSpec: + """One operation in a multi-op composite (head + epilogue, ADR-0021). + + The head op (first in CompositeCmd.ops) defines tile geometry; subsequent + ops are epilogue stages whose ``scope`` controls how often they fire + (per-K-tile, per-output-tile, or once per kernel). + """ + + kind: str # "gemm" | "bias" | "relu" | ... + scope: "Scope" = Scope.OUTPUT_TILE + operands: tuple[Any, ...] = () # tuple[TensorHandle, ...] + scalar: float | None = None + extra: dict[str, Any] = field(default_factory=dict) + + +# Epilogue op contracts: kind → (required field names, default scope). +# Used by tl.composite to validate user-provided epilogue dicts at +# command-emit time so typos fail before reaching the scheduler. +EPILOGUE_OPS: dict[str, tuple[tuple[str, ...], "Scope"]] = { + "bias": (("bias",), Scope.OUTPUT_TILE), + "relu": ((), Scope.OUTPUT_TILE), + "gelu": ((), Scope.OUTPUT_TILE), + "sigmoid": ((), Scope.OUTPUT_TILE), + "scale": (("factor",), Scope.OUTPUT_TILE), + "clamp": (("lo", "hi"), Scope.OUTPUT_TILE), + "dequant": (("scale",), Scope.K_TILE), + "add": (("other",), Scope.OUTPUT_TILE), +} + + # ── Handles ─────────────────────────────────────────────────────── @@ -118,6 +156,10 @@ class CompositeCmd: out_nbytes: int math_op: str | None = None # for op="math": which math operation data_op: bool = True + # Multi-op composite (ADR-0021 extension): when non-empty, ops[0] is the + # head and ops[1:] are epilogue stages with explicit scope. When empty, + # the legacy single-op semantics (op/a/b/math_op) apply. + ops: tuple[OpSpec, ...] = () @dataclass(frozen=True) diff --git a/src/kernbench/components/builtin/pe_math.py b/src/kernbench/components/builtin/pe_math.py index ff9e142..cf5bcf9 100644 --- a/src/kernbench/components/builtin/pe_math.py +++ b/src/kernbench/components/builtin/pe_math.py @@ -68,29 +68,36 @@ class PeMathComponent(PeEngineBase): env.process(self._forward_txn(env, msg)) def _pipeline_process(self, env: simpy.Environment, token: Any) -> Generator: - """Pipeline mode: pure SIMD compute, then self-route.""" - self._on_process_start(env, token) + """Pipeline mode: pure SIMD compute, then self-route. - num_elements = token.params.get("num_elements", 0) + Handles consecutive same-component MATH stages (e.g. chained + epilogue ops bias→relu) by looping locally before forwarding, + mirroring the pattern in PeDmaComponent._pipeline_process. + """ + yield from self._do_compute(env, token) - if self._accel: - with self._accel.request() as req: - yield req - ns = self._compute_ns(num_elements) - yield env.timeout(ns) - else: - ns = self._compute_ns(num_elements) - yield env.timeout(ns) - - self._on_process_end(env, token) - - # Self-routing next_stage = token.advance() + while next_stage is not None and next_stage.component == self.node.id: + yield from self._do_compute(env, token) + next_stage = token.advance() + if next_stage is not None: yield self.out_ports[next_stage.component].put(token) else: token.pipeline_ctx.complete_tile() + def _do_compute(self, env: simpy.Environment, token: Any) -> Generator: + """Single MATH stage: latency model + op_log bracketing.""" + self._on_process_start(env, token) + num_elements = token.params.get("num_elements", 0) + if self._accel: + with self._accel.request() as req: + yield req + yield env.timeout(self._compute_ns(num_elements)) + else: + yield env.timeout(self._compute_ns(num_elements)) + self._on_process_end(env, token) + def handle_command(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator: """PeInternalTxn handling for standalone MathCmd (CCL kernels). diff --git a/src/kernbench/components/builtin/pe_scheduler.py b/src/kernbench/components/builtin/pe_scheduler.py index 4f50dd5..994acfb 100644 --- a/src/kernbench/components/builtin/pe_scheduler.py +++ b/src/kernbench/components/builtin/pe_scheduler.py @@ -157,6 +157,9 @@ class PeSchedulerComponent(ComponentBase): b = cmd.b M, K = a.shape[-2], a.shape[-1] N = b.shape[-1] + # When CompositeCmd.ops is populated, ops[0] is the head and + # ops[1:] is the epilogue spec list. Empty ops → legacy path. + epi_specs = tuple(cmd.ops[1:]) if cmd.ops else () return generate_gemm_plan( M=M, K=K, N=N, tile_m=self.TILE_M, tile_k=self.TILE_K, tile_n=self.TILE_N, @@ -165,6 +168,7 @@ class PeSchedulerComponent(ComponentBase): pe_prefix=pp, a_pinned=getattr(a, "pinned", False), b_pinned=getattr(b, "pinned", False), + epilogue_specs=epi_specs, ) else: # Math composite diff --git a/src/kernbench/components/builtin/tiling.py b/src/kernbench/components/builtin/tiling.py index 321da19..88884a6 100644 --- a/src/kernbench/components/builtin/tiling.py +++ b/src/kernbench/components/builtin/tiling.py @@ -23,6 +23,7 @@ def generate_gemm_plan( pe_prefix: str, a_pinned: bool = False, b_pinned: bool = False, + epilogue_specs: tuple = (), ) -> PipelinePlan: """Generate GEMM tile plan: M→N→K order. @@ -46,7 +47,15 @@ def generate_gemm_plan( dma_id = f"{pe_prefix}.pe_dma" fetch_id = f"{pe_prefix}.pe_fetch_store" gemm_id = f"{pe_prefix}.pe_gemm" - # math_id = f"{pe_prefix}.pe_math" # for K-accumulation if needed + math_id = f"{pe_prefix}.pe_math" + + # Split epilogue_specs by scope. Lazy import to avoid circular dep + # between pe_commands ← pe_types ← tiling. + from kernbench.common.pe_commands import Scope as _Scope + k_tile_ops = [o for o in epilogue_specs + if getattr(o, "scope", None) == _Scope.K_TILE] + out_tile_ops = [o for o in epilogue_specs + if getattr(o, "scope", None) == _Scope.OUTPUT_TILE] tiles: list[TilePlan] = [] tile_id = 0 @@ -106,9 +115,35 @@ def generate_gemm_plan( }, )) + # K-tile-scope epilogue MATH stages (e.g. dequant) — run on + # every K-tile right after GEMM, before the accumulator + # advances to the next K slice. + for op in k_tile_ops: + stages.append(Stage( + stage_type=StageType.MATH, + component=math_id, + params={ + "op_kind": op.kind, + "num_elements": tile_m * tile_n, + "scope": "k_tile", + }, + )) + # STORE + DMA_WRITE only on last K-tile per (m,n). The C # accumulator stays in RegFile across the K loop. if last_k: + # Output-tile-scope epilogue MATH (bias, relu, ...) runs + # ONCE per (m,n) after the final K-tile, before writeback. + for op in out_tile_ops: + stages.append(Stage( + stage_type=StageType.MATH, + component=math_id, + params={ + "op_kind": op.kind, + "num_elements": tile_m * tile_n, + "scope": "output_tile", + }, + )) stages.append(Stage( stage_type=StageType.STORE, component=fetch_id, diff --git a/src/kernbench/sim_engine/op_log.py b/src/kernbench/sim_engine/op_log.py index 9e083c5..acc0d5d 100644 --- a/src/kernbench/sim_engine/op_log.py +++ b/src/kernbench/sim_engine/op_log.py @@ -51,11 +51,15 @@ class OpLogger: record_end fires. """ snap: dict[str, Any] = {} - # TileToken (ADR-0021 pipeline) — capture which stage this is. + # TileToken (ADR-0021 pipeline) — capture which stage this is and its + # per-stage params (e.g. op_kind/scope for epilogue MATH stages) so + # we can recover them at record_end even after the token advances. try: stage = getattr(msg, "current_stage", None) if stage is not None: snap["stage_type"] = stage.stage_type.name + if isinstance(getattr(stage, "params", None), dict): + snap["stage_params"] = dict(stage.params) except Exception: pass self._pending[id(msg)] = { @@ -78,6 +82,11 @@ class OpLogger: stage_type = snap.get("stage_type") if stage_type is not None: params = dict(params) + # Merge per-stage params (e.g. op_kind, scope) captured at start. + stage_params = snap.get("stage_params") + if isinstance(stage_params, dict): + for k, v in stage_params.items(): + params.setdefault(k, v) params["stage_type"] = stage_type if op_name == "TileToken": op_name = f"TileToken/{stage_type}" diff --git a/src/kernbench/topology/builder.py b/src/kernbench/topology/builder.py index b42b108..8516b3b 100644 --- a/src/kernbench/topology/builder.py +++ b/src/kernbench/topology/builder.py @@ -700,6 +700,7 @@ def _add_pe_internal_edges(edges: list[Edge], pp: str, pe_links: dict) -> None: ("pe_fetch_store", "pe_gemm", "fetch_store_to_gemm_mm"), ("pe_fetch_store", "pe_math", "fetch_store_to_math_mm"), ("pe_gemm", "pe_fetch_store", "gemm_to_fetch_store_mm"), + ("pe_gemm", "pe_math", "gemm_to_math_mm"), ("pe_math", "pe_fetch_store", "math_to_fetch_store_mm"), ("pe_fetch_store", "pe_dma", "fetch_store_to_dma_mm"), ] diff --git a/src/kernbench/triton_emu/tl_context.py b/src/kernbench/triton_emu/tl_context.py index bc1984d..67764b1 100644 --- a/src/kernbench/triton_emu/tl_context.py +++ b/src/kernbench/triton_emu/tl_context.py @@ -19,14 +19,17 @@ from typing import Literal from kernbench.common.ipcq_types import IpcqRecvCmd, IpcqSendCmd, RecvFuture from kernbench.common.pe_commands import ( + EPILOGUE_OPS, CompletionHandle, CompositeCmd, DmaReadCmd, DmaWriteCmd, GemmCmd, MathCmd, + OpSpec, PeCommand, PeCpuOverheadCmd, + Scope, TensorHandle, WaitCmd, ) @@ -565,9 +568,18 @@ class TLContext: b: TensorHandle | None = None, out_ptr: int = 0, math_op: str | None = None, + *, + epilogue: list[dict] | None = None, + acc_dtype: str | None = None, + tile_shape: tuple[int, int, int] | None = None, ) -> CompletionHandle: """Submit a composite command (non-blocking, tiled pipeline). + Optional ``epilogue`` is an ordered list of dicts; each dict has a + required ``"op"`` key (one of ``EPILOGUE_OPS``) plus op-specific + fields and an optional ``"scope"``. Validation happens here so + typos fail before the command is emitted. + Returns CompletionHandle for use with wait(). """ # Compute output size based on op @@ -579,15 +591,72 @@ class TLContext: else: out_nbytes = a.nbytes + ops_tuple: tuple[OpSpec, ...] = () + if epilogue is not None: + head_operands = (a, b) if (op == "gemm" and b is not None) else (a,) + head_spec = OpSpec( + kind=op, scope=Scope.OUTPUT_TILE, operands=head_operands, + extra={ + k: v for k, v in ( + ("acc_dtype", acc_dtype), + ("tile_shape", tile_shape), + ("math_op", math_op), + ) if v is not None + }, + ) + epi_specs = tuple(self._build_epilogue_spec(e, i) + for i, e in enumerate(epilogue)) + ops_tuple = (head_spec, *epi_specs) + completion = CompletionHandle(id=self._next_completion_id()) self._emit_dispatch_overhead() self._emit(CompositeCmd( completion=completion, op=op, a=a, b=b, out_addr=out_ptr, out_nbytes=out_nbytes, - math_op=math_op, + math_op=math_op, ops=ops_tuple, )) return completion + @staticmethod + def _build_epilogue_spec(entry: dict, idx: int) -> OpSpec: + if not isinstance(entry, dict) or "op" not in entry: + raise ValueError( + f"epilogue[{idx}]: each entry must be a dict with an 'op' key" + ) + kind = entry["op"] + if kind not in EPILOGUE_OPS: + known = ", ".join(sorted(EPILOGUE_OPS)) + raise ValueError( + f"epilogue[{idx}]: unknown op {kind!r} " + f"(known ops: {known})" + ) + required, default_scope = EPILOGUE_OPS[kind] + missing = [f for f in required if f not in entry] + if missing: + raise ValueError( + f"epilogue[{idx}] op {kind!r} missing required field(s): " + f"{', '.join(missing)}" + ) + scope = Scope(entry["scope"]) if "scope" in entry else default_scope + operands: list = [] + scalar: float | None = None + extra: dict = {} + for f in required: + v = entry[f] + if isinstance(v, TensorHandle): + operands.append(v) + elif isinstance(v, (int, float)): + if scalar is None: + scalar = float(v) + else: + extra[f] = v + else: + extra[f] = v + return OpSpec( + kind=kind, scope=scope, + operands=tuple(operands), scalar=scalar, extra=extra, + ) + def wait(self, handle: "CompletionHandle | RecvFuture | None" = None) -> Any: """Wait for a composite, a recv future, or all pending composites. diff --git a/tests/test_composite_epilogue.py b/tests/test_composite_epilogue.py new file mode 100644 index 0000000..b4a4a5f --- /dev/null +++ b/tests/test_composite_epilogue.py @@ -0,0 +1,140 @@ +"""Tests for multi-op tl.composite() with epilogue scopes. + +Public-surface tests only: we exercise tl.composite() and inspect the +resulting CompositeCmd. Validation, plan generation, and scheduling are +covered implicitly — they're internal to tl_context / pe_scheduler / tiling. +""" +from __future__ import annotations + +import pytest + +from kernbench.common.pe_commands import ( + EPILOGUE_OPS, + CompositeCmd, + Scope, + TensorHandle, +) +from kernbench.triton_emu.tl_context import TLContext + + +def _h(idx: int, shape=(32, 32)) -> TensorHandle: + nbytes = 1 + for d in shape: + nbytes *= d + return TensorHandle( + id=f"h{idx}", addr=0x1000 + idx * 0x100, + shape=shape, dtype="f16", nbytes=nbytes * 2, + ) + + +def test_composite_epilogue_roundtrip(): + """tl.composite() with mixed-scope epilogue produces a CompositeCmd whose + ops tuple preserves order, kinds, and default scopes.""" + tl = TLContext() + a, b = _h(0), _h(1) + bias = _h(2, shape=(32,)) + scale = _h(3, shape=(2,)) + + tl.composite( + op="gemm", a=a, b=b, out_ptr=0x2000, + epilogue=[ + {"op": "bias", "bias": bias}, # default OUTPUT_TILE + {"op": "dequant", "scale": scale}, # default K_TILE + {"op": "relu"}, # default OUTPUT_TILE + {"op": "scale", "factor": 0.5}, + ], + ) + cmd = tl._commands[-1] + assert isinstance(cmd, CompositeCmd) + + kinds_scopes = [(o.kind, o.scope) for o in cmd.ops] + assert kinds_scopes == [ + ("gemm", Scope.OUTPUT_TILE), + ("bias", Scope.OUTPUT_TILE), + ("dequant", Scope.K_TILE), + ("relu", Scope.OUTPUT_TILE), + ("scale", Scope.OUTPUT_TILE), + ] + + # Single-op call (no epilogue) keeps the legacy code path: ops stays empty. + tl2 = TLContext() + tl2.composite(op="gemm", a=a, b=b, out_ptr=0x2000) + cmd2 = tl2._commands[-1] + assert isinstance(cmd2, CompositeCmd) + assert cmd2.ops == () + + +@pytest.mark.parametrize("bad,match", [ + ([{"op": "biass"}], "unknown op 'biass'"), + ([{"op": "bias"}], "missing required field"), + (["relu"], "must be a dict"), +]) +def test_composite_epilogue_rejects_bad_input(bad, match): + tl = TLContext() + with pytest.raises(ValueError, match=match): + tl.composite(op="gemm", a=_h(0), b=_h(1), out_ptr=0x2000, + epilogue=bad) + + +def test_epilogue_registry_contract(): + """EPILOGUE_OPS is the registry tl.composite validates against.""" + for kind, (required, scope) in EPILOGUE_OPS.items(): + assert isinstance(kind, str) and kind + assert isinstance(required, tuple) + assert isinstance(scope, Scope) + + +def test_composite_epilogue_e2e(): + """Drive a GEMM + bias + relu composite through the simulator and check + op_log: exactly one MATH(bias) and one MATH(relu) record, ordered after + GEMM and before STORE for the single (m,n) output tile.""" + from pathlib import Path + + from kernbench.runtime_api.bench_runner import run_bench + from kernbench.runtime_api.types import resolve_device + from kernbench.sim_engine.engine import GraphEngine + from kernbench.topology.builder import resolve_topology + + topo_path = Path(__file__).parent.parent / "topology.yaml" + topo = resolve_topology(str(topo_path)) + device = resolve_device(None) + + def _kernel(a_ptr, b_ptr, bias_ptr, out_ptr, tl): + a = tl.ref(int(a_ptr), shape=(32, 32), dtype="f16") + b = tl.ref(int(b_ptr), shape=(32, 32), dtype="f16") + bias = tl.load(int(bias_ptr), shape=(32,), dtype="f16") + h = tl.composite( + op="gemm", a=a, b=b, out_ptr=int(out_ptr), + epilogue=[{"op": "bias", "bias": bias}, {"op": "relu"}], + ) + tl.wait(h) + + def _bench(torch): + from kernbench.policy.placement.dp import DPPolicy + dp = DPPolicy(cube="replicate", pe="replicate", + num_cubes=1, num_pes=1) + a = torch.empty((32, 32), dtype="f16", dp=dp, name="a") + b = torch.empty((32, 32), dtype="f16", dp=dp, name="b") + bias = torch.empty((32,), dtype="f16", dp=dp, name="bias") + out = torch.empty((32, 32), dtype="f16", dp=dp, name="out") + torch.launch("composite_epi", _kernel, a, b, bias, out) + + result = run_bench( + topology=topo, bench_fn=_bench, device=device, + engine_factory=lambda t, d: GraphEngine( + getattr(t, "topology_obj", t), enable_data=True, + ), + ) + assert result.completion.ok + + math = [r for r in result.engine.op_log + if r.params.get("stage_type") == "MATH"] + assert [r.params.get("op_kind") for r in math] == ["bias", "relu"] + + gemm = [r for r in result.engine.op_log + if r.params.get("stage_type") == "GEMM"] + store = [r for r in result.engine.op_log + if r.params.get("stage_type") == "STORE"] + assert gemm and store + assert gemm[0].t_end <= math[0].t_start + 1e-6 + assert math[-1].t_end <= store[0].t_start + 1e-6 diff --git a/tests/test_topology_compile.py b/tests/test_topology_compile.py index bbfca48..77fe943 100644 --- a/tests/test_topology_compile.py +++ b/tests/test_topology_compile.py @@ -31,7 +31,9 @@ def test_full_graph_edge_count(): # ADR-0023: +3 IPCQ edges per PE # ADR-0019 D1 (restored): HBM↔router edges drop from 32 routers × 2 # to 8 PE-routers × 2 per cube. 32 cubes × (16-64) = -1536 edges. - assert len(g.edges) == 12156 + # Multi-op composite (ADR-0021): +1 gemm→math edge per PE for + # epilogue chaining = 2 SIPs × 16 cubes × 8 PEs = +256 edges. + assert len(g.edges) == 12412 # -- Full graph: specific nodes exist ----------------------------------------- diff --git a/topology.yaml b/topology.yaml index b94f3fd..6ccb9c3 100644 --- a/topology.yaml +++ b/topology.yaml @@ -84,6 +84,7 @@ cube: fetch_store_to_gemm_mm: 0.0 # fetch → GEMM chaining (ADR-0021) fetch_store_to_math_mm: 0.0 # fetch → MATH chaining (ADR-0021) gemm_to_fetch_store_mm: 0.0 # GEMM → store chaining (ADR-0021) + gemm_to_math_mm: 0.0 # GEMM → MATH epilogue chaining (ADR-0021) math_to_fetch_store_mm: 0.0 # MATH → store chaining (ADR-0021) fetch_store_to_dma_mm: 0.0 # store → DMA writeback chaining (ADR-0021) gemm_to_tcm_bw_gbs: 512.0