tl.composite: fused epilogue ops with per-op scope
Extend tl.composite() with an ordered epilogue list. Each op carries
a scope flag - output_tile (default, runs once per (m,n) before
STORE), k_tile (every K-tile right after GEMM), or kernel. Plan
generator slots MATH stages by scope; pe_math reuses pe_dma's
local-loop pattern so chained epilogues (bias->relu) skip the port
hop. op_log captures per-stage params for telemetry. Topology
gains a gemm->math edge (snapshot test updated).
API stays backward-compatible - `epilogue=` is opt-in.
Example:
h = tl.composite(
op="gemm", a=a, b=b, out_ptr=int(out),
epilogue=[
{"op": "dequant", "scale": s_per_k, "scope": "k_tile"},
{"op": "bias", "bias": bias_vec},
{"op": "relu"},
{"op": "scale", "factor": 0.5},
],
)
tl.wait(h)
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
+53
-1
@@ -9,7 +9,59 @@
|
|||||||
"Bash(python -m kernbench.cli.main probe --topology topology.yaml)",
|
"Bash(python -m kernbench.cli.main probe --topology topology.yaml)",
|
||||||
"Bash(xargs grep -l \"class.*ComponentBase\\\\|class.*DefaultComponent\")",
|
"Bash(xargs grep -l \"class.*ComponentBase\\\\|class.*DefaultComponent\")",
|
||||||
"Bash(python -m pytest tests/test_probe.py -v)",
|
"Bash(python -m pytest tests/test_probe.py -v)",
|
||||||
"Bash(python -m pytest tests/test_probe.py tests/test_component_registry.py -v)"
|
"Bash(python -m pytest tests/test_probe.py tests/test_component_registry.py -v)",
|
||||||
|
"Bash(python -m pytest -o \"addopts=\" --no-header tests/test_intercube_root_center.py)",
|
||||||
|
"Bash(python -m pytest -o \"addopts=\" --no-header tests/test_tp_layers.py tests/test_tp_mlp.py)",
|
||||||
|
"Bash(git commit -m ' *)",
|
||||||
|
"Bash(git stash *)",
|
||||||
|
"Bash(python scripts/emit_overview_with_external_ref.py)",
|
||||||
|
"Bash(where inkscape *)",
|
||||||
|
"Bash(\"/c/Program Files \\(x86\\)/Microsoft/Edge/Application/msedge.exe\" --headless --disable-gpu --screenshot=\"$\\(pwd\\)/docs/diagrams/cube_mesh_view.png\" --window-size=1400,1300 \"file:///$\\(pwd)",
|
||||||
|
"Bash(python scripts/build_overview_slides.py)",
|
||||||
|
"Bash(git fetch *)",
|
||||||
|
"Bash(git pull *)",
|
||||||
|
"Bash(python -m pytest --no-header tests/test_allreduce_buffer_kind_sweep.py)",
|
||||||
|
"Bash(python -m pytest --no-header tests/test_pe_to_pe_latency.py)",
|
||||||
|
"Bash(python -m pytest --no-header tests/test_ipcq_buffer_kind_locations.py -v)",
|
||||||
|
"Bash(python -m pytest --no-header tests/test_ipcq_buffer_kind_locations.py tests/test_ipcq_buffer_kind_latency.py tests/test_allreduce_buffer_kind_sweep.py)",
|
||||||
|
"Bash(git checkout *)",
|
||||||
|
"Bash(python -m pytest --no-header tests/test_ipcq_buffer_kind_latency.py::test_slot_write_latency_orders_tcm_hbm_sram)",
|
||||||
|
"Bash(python scripts/emit_ipcq_send_recv_model_plots.py)",
|
||||||
|
"Bash(python -m pytest --no-header tests/test_pe_to_pe_latency.py -x)",
|
||||||
|
"Bash(python -m pytest --no-header tests/test_pe_to_pe_latency.py tests/test_ipcq_buffer_kind_locations.py tests/test_ipcq_buffer_kind_latency.py tests/test_allreduce_buffer_kind_sweep.py)",
|
||||||
|
"Bash(kill %1)",
|
||||||
|
"Bash(awk '{print $2}')",
|
||||||
|
"Bash(xargs -r kill)",
|
||||||
|
"Bash(python scripts/_debug_op_log.py)",
|
||||||
|
"Bash(SWEEP_SHAPES=\"16,32,64,128,256\" python scripts/gemm_sweep.py)",
|
||||||
|
"Bash(python scripts/plot_gemm_sweep.py)",
|
||||||
|
"Bash(python scripts/gemm_sweep.py)",
|
||||||
|
"Bash(python scripts/gen_pe_pipeline_diagram.py)",
|
||||||
|
"Bash(python scripts/gen_matmul_32x128x32_diagram.py)",
|
||||||
|
"Bash(python -m pytest tests/test_pe_pipeline.py -x --tb=short)",
|
||||||
|
"Bash(python -m pytest tests/test_pe_pipeline.py tests/test_e2e_pipeline.py tests/test_op_log.py -x --tb=short -q)",
|
||||||
|
"Bash(ls -la C:/Users/mukes/.claude/projects/c--Users-mukes-Mukesh-ywkang-git-kernbench2/ 2>&1 | head -20)",
|
||||||
|
"Read(//c/Users/mukes/.claude/projects/c--Users-mukes-Mukesh-ywkang-git-kernbench2/**)",
|
||||||
|
"Bash(awk 'NR==1812 || NR==1815' C:/Users/mukes/.claude/projects/c--Users-mukes-Mukesh-ywkang-git-kernbench2/e55237ed-5c1f-4a89-a3b9-9b74fec45366.jsonl)",
|
||||||
|
"Bash(awk 'NR==1058' C:/Users/mukes/.claude/projects/c--Users-mukes-Mukesh-ywkang-git-kernbench2/e55237ed-5c1f-4a89-a3b9-9b74fec45366.jsonl)",
|
||||||
|
"Bash(awk -F: '$1 > 1700 && $1 < 1815 {print $1}')",
|
||||||
|
"Bash(awk 'NR==1812' C:/Users/mukes/.claude/projects/c--Users-mukes-Mukesh-ywkang-git-kernbench2/e55237ed-5c1f-4a89-a3b9-9b74fec45366.jsonl)",
|
||||||
|
"Bash(awk 'NR>=1815 && NR<=1825' C:/Users/mukes/.claude/projects/c--Users-mukes-Mukesh-ywkang-git-kernbench2/e55237ed-5c1f-4a89-a3b9-9b74fec45366.jsonl)",
|
||||||
|
"Bash(awk 'NR>1815' C:/Users/mukes/.claude/projects/c--Users-mukes-Mukesh-ywkang-git-kernbench2/e55237ed-5c1f-4a89-a3b9-9b74fec45366.jsonl)",
|
||||||
|
"Bash(awk 'NR==1839' C:/Users/mukes/.claude/projects/c--Users-mukes-Mukesh-ywkang-git-kernbench2/e55237ed-5c1f-4a89-a3b9-9b74fec45366.jsonl)",
|
||||||
|
"Bash(git log *)",
|
||||||
|
"Bash(python -m pytest tests/test_op_log.py tests/test_pe_components.py tests/test_pe_pipeline.py -x --tb=short)",
|
||||||
|
"Bash(python -m pytest tests/test_pe_to_pe_latency.py tests/test_e2e_pipeline.py tests/test_e2e_data.py tests/test_data_executor.py tests/test_pe_dma_ipcq.py -x --tb=short)",
|
||||||
|
"Bash(python -m pytest tests/test_pe_pipeline.py::test_pe_dma_record_start_after_channel_acquire -x --tb=long)",
|
||||||
|
"Bash(python -m pytest tests/test_pe_pipeline.py::test_pe_dma_record_start_after_channel_acquire -x --tb=short)",
|
||||||
|
"Bash(python -m pytest tests/test_op_log.py tests/test_pe_components.py tests/test_pe_pipeline.py tests/test_pe_to_pe_latency.py tests/test_e2e_pipeline.py tests/test_e2e_data.py tests/test_data_executor.py tests/test_pe_dma_ipcq.py --tb=short)",
|
||||||
|
"Bash(python -m pytest tests/test_pe_pipeline.py -q)",
|
||||||
|
"Bash(python -m pytest tests/test_pe_pipeline.py tests/test_triton_emu.py -q)",
|
||||||
|
"Bash(python -m pytest tests/test_composite_epilogue.py -v)"
|
||||||
|
],
|
||||||
|
"additionalDirectories": [
|
||||||
|
"c:\\Users\\mukes\\Mukesh\\ywkang_git\\kernbench2\\tests",
|
||||||
|
"C:\\Users\\mukes\\Mukesh\\ywkang_git\\kernbench2\\tests\\pe2pe_latency_plots"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Binary file not shown.
|
Before Width: | Height: | Size: 194 KiB After Width: | Height: | Size: 194 KiB |
@@ -9,12 +9,50 @@ Command lifecycle:
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
from typing import TYPE_CHECKING, Any, Literal
|
from typing import TYPE_CHECKING, Any, Literal
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import simpy
|
import simpy
|
||||||
|
|
||||||
|
|
||||||
|
class Scope(Enum):
|
||||||
|
K_TILE = "k_tile"
|
||||||
|
OUTPUT_TILE = "output_tile"
|
||||||
|
KERNEL = "kernel"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class OpSpec:
|
||||||
|
"""One operation in a multi-op composite (head + epilogue, ADR-0021).
|
||||||
|
|
||||||
|
The head op (first in CompositeCmd.ops) defines tile geometry; subsequent
|
||||||
|
ops are epilogue stages whose ``scope`` controls how often they fire
|
||||||
|
(per-K-tile, per-output-tile, or once per kernel).
|
||||||
|
"""
|
||||||
|
|
||||||
|
kind: str # "gemm" | "bias" | "relu" | ...
|
||||||
|
scope: "Scope" = Scope.OUTPUT_TILE
|
||||||
|
operands: tuple[Any, ...] = () # tuple[TensorHandle, ...]
|
||||||
|
scalar: float | None = None
|
||||||
|
extra: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
# Epilogue op contracts: kind → (required field names, default scope).
|
||||||
|
# Used by tl.composite to validate user-provided epilogue dicts at
|
||||||
|
# command-emit time so typos fail before reaching the scheduler.
|
||||||
|
EPILOGUE_OPS: dict[str, tuple[tuple[str, ...], "Scope"]] = {
|
||||||
|
"bias": (("bias",), Scope.OUTPUT_TILE),
|
||||||
|
"relu": ((), Scope.OUTPUT_TILE),
|
||||||
|
"gelu": ((), Scope.OUTPUT_TILE),
|
||||||
|
"sigmoid": ((), Scope.OUTPUT_TILE),
|
||||||
|
"scale": (("factor",), Scope.OUTPUT_TILE),
|
||||||
|
"clamp": (("lo", "hi"), Scope.OUTPUT_TILE),
|
||||||
|
"dequant": (("scale",), Scope.K_TILE),
|
||||||
|
"add": (("other",), Scope.OUTPUT_TILE),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# ── Handles ───────────────────────────────────────────────────────
|
# ── Handles ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@@ -118,6 +156,10 @@ class CompositeCmd:
|
|||||||
out_nbytes: int
|
out_nbytes: int
|
||||||
math_op: str | None = None # for op="math": which math operation
|
math_op: str | None = None # for op="math": which math operation
|
||||||
data_op: bool = True
|
data_op: bool = True
|
||||||
|
# Multi-op composite (ADR-0021 extension): when non-empty, ops[0] is the
|
||||||
|
# head and ops[1:] are epilogue stages with explicit scope. When empty,
|
||||||
|
# the legacy single-op semantics (op/a/b/math_op) apply.
|
||||||
|
ops: tuple[OpSpec, ...] = ()
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
|
|||||||
@@ -68,29 +68,36 @@ class PeMathComponent(PeEngineBase):
|
|||||||
env.process(self._forward_txn(env, msg))
|
env.process(self._forward_txn(env, msg))
|
||||||
|
|
||||||
def _pipeline_process(self, env: simpy.Environment, token: Any) -> Generator:
|
def _pipeline_process(self, env: simpy.Environment, token: Any) -> Generator:
|
||||||
"""Pipeline mode: pure SIMD compute, then self-route."""
|
"""Pipeline mode: pure SIMD compute, then self-route.
|
||||||
self._on_process_start(env, token)
|
|
||||||
|
|
||||||
num_elements = token.params.get("num_elements", 0)
|
Handles consecutive same-component MATH stages (e.g. chained
|
||||||
|
epilogue ops bias→relu) by looping locally before forwarding,
|
||||||
|
mirroring the pattern in PeDmaComponent._pipeline_process.
|
||||||
|
"""
|
||||||
|
yield from self._do_compute(env, token)
|
||||||
|
|
||||||
if self._accel:
|
|
||||||
with self._accel.request() as req:
|
|
||||||
yield req
|
|
||||||
ns = self._compute_ns(num_elements)
|
|
||||||
yield env.timeout(ns)
|
|
||||||
else:
|
|
||||||
ns = self._compute_ns(num_elements)
|
|
||||||
yield env.timeout(ns)
|
|
||||||
|
|
||||||
self._on_process_end(env, token)
|
|
||||||
|
|
||||||
# Self-routing
|
|
||||||
next_stage = token.advance()
|
next_stage = token.advance()
|
||||||
|
while next_stage is not None and next_stage.component == self.node.id:
|
||||||
|
yield from self._do_compute(env, token)
|
||||||
|
next_stage = token.advance()
|
||||||
|
|
||||||
if next_stage is not None:
|
if next_stage is not None:
|
||||||
yield self.out_ports[next_stage.component].put(token)
|
yield self.out_ports[next_stage.component].put(token)
|
||||||
else:
|
else:
|
||||||
token.pipeline_ctx.complete_tile()
|
token.pipeline_ctx.complete_tile()
|
||||||
|
|
||||||
|
def _do_compute(self, env: simpy.Environment, token: Any) -> Generator:
|
||||||
|
"""Single MATH stage: latency model + op_log bracketing."""
|
||||||
|
self._on_process_start(env, token)
|
||||||
|
num_elements = token.params.get("num_elements", 0)
|
||||||
|
if self._accel:
|
||||||
|
with self._accel.request() as req:
|
||||||
|
yield req
|
||||||
|
yield env.timeout(self._compute_ns(num_elements))
|
||||||
|
else:
|
||||||
|
yield env.timeout(self._compute_ns(num_elements))
|
||||||
|
self._on_process_end(env, token)
|
||||||
|
|
||||||
def handle_command(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
|
def handle_command(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
|
||||||
"""PeInternalTxn handling for standalone MathCmd (CCL kernels).
|
"""PeInternalTxn handling for standalone MathCmd (CCL kernels).
|
||||||
|
|
||||||
|
|||||||
@@ -157,6 +157,9 @@ class PeSchedulerComponent(ComponentBase):
|
|||||||
b = cmd.b
|
b = cmd.b
|
||||||
M, K = a.shape[-2], a.shape[-1]
|
M, K = a.shape[-2], a.shape[-1]
|
||||||
N = b.shape[-1]
|
N = b.shape[-1]
|
||||||
|
# When CompositeCmd.ops is populated, ops[0] is the head and
|
||||||
|
# ops[1:] is the epilogue spec list. Empty ops → legacy path.
|
||||||
|
epi_specs = tuple(cmd.ops[1:]) if cmd.ops else ()
|
||||||
return generate_gemm_plan(
|
return generate_gemm_plan(
|
||||||
M=M, K=K, N=N,
|
M=M, K=K, N=N,
|
||||||
tile_m=self.TILE_M, tile_k=self.TILE_K, tile_n=self.TILE_N,
|
tile_m=self.TILE_M, tile_k=self.TILE_K, tile_n=self.TILE_N,
|
||||||
@@ -165,6 +168,7 @@ class PeSchedulerComponent(ComponentBase):
|
|||||||
pe_prefix=pp,
|
pe_prefix=pp,
|
||||||
a_pinned=getattr(a, "pinned", False),
|
a_pinned=getattr(a, "pinned", False),
|
||||||
b_pinned=getattr(b, "pinned", False),
|
b_pinned=getattr(b, "pinned", False),
|
||||||
|
epilogue_specs=epi_specs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Math composite
|
# Math composite
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ def generate_gemm_plan(
|
|||||||
pe_prefix: str,
|
pe_prefix: str,
|
||||||
a_pinned: bool = False,
|
a_pinned: bool = False,
|
||||||
b_pinned: bool = False,
|
b_pinned: bool = False,
|
||||||
|
epilogue_specs: tuple = (),
|
||||||
) -> PipelinePlan:
|
) -> PipelinePlan:
|
||||||
"""Generate GEMM tile plan: M→N→K order.
|
"""Generate GEMM tile plan: M→N→K order.
|
||||||
|
|
||||||
@@ -46,7 +47,15 @@ def generate_gemm_plan(
|
|||||||
dma_id = f"{pe_prefix}.pe_dma"
|
dma_id = f"{pe_prefix}.pe_dma"
|
||||||
fetch_id = f"{pe_prefix}.pe_fetch_store"
|
fetch_id = f"{pe_prefix}.pe_fetch_store"
|
||||||
gemm_id = f"{pe_prefix}.pe_gemm"
|
gemm_id = f"{pe_prefix}.pe_gemm"
|
||||||
# math_id = f"{pe_prefix}.pe_math" # for K-accumulation if needed
|
math_id = f"{pe_prefix}.pe_math"
|
||||||
|
|
||||||
|
# Split epilogue_specs by scope. Lazy import to avoid circular dep
|
||||||
|
# between pe_commands ← pe_types ← tiling.
|
||||||
|
from kernbench.common.pe_commands import Scope as _Scope
|
||||||
|
k_tile_ops = [o for o in epilogue_specs
|
||||||
|
if getattr(o, "scope", None) == _Scope.K_TILE]
|
||||||
|
out_tile_ops = [o for o in epilogue_specs
|
||||||
|
if getattr(o, "scope", None) == _Scope.OUTPUT_TILE]
|
||||||
|
|
||||||
tiles: list[TilePlan] = []
|
tiles: list[TilePlan] = []
|
||||||
tile_id = 0
|
tile_id = 0
|
||||||
@@ -106,9 +115,35 @@ def generate_gemm_plan(
|
|||||||
},
|
},
|
||||||
))
|
))
|
||||||
|
|
||||||
|
# K-tile-scope epilogue MATH stages (e.g. dequant) — run on
|
||||||
|
# every K-tile right after GEMM, before the accumulator
|
||||||
|
# advances to the next K slice.
|
||||||
|
for op in k_tile_ops:
|
||||||
|
stages.append(Stage(
|
||||||
|
stage_type=StageType.MATH,
|
||||||
|
component=math_id,
|
||||||
|
params={
|
||||||
|
"op_kind": op.kind,
|
||||||
|
"num_elements": tile_m * tile_n,
|
||||||
|
"scope": "k_tile",
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
# STORE + DMA_WRITE only on last K-tile per (m,n). The C
|
# STORE + DMA_WRITE only on last K-tile per (m,n). The C
|
||||||
# accumulator stays in RegFile across the K loop.
|
# accumulator stays in RegFile across the K loop.
|
||||||
if last_k:
|
if last_k:
|
||||||
|
# Output-tile-scope epilogue MATH (bias, relu, ...) runs
|
||||||
|
# ONCE per (m,n) after the final K-tile, before writeback.
|
||||||
|
for op in out_tile_ops:
|
||||||
|
stages.append(Stage(
|
||||||
|
stage_type=StageType.MATH,
|
||||||
|
component=math_id,
|
||||||
|
params={
|
||||||
|
"op_kind": op.kind,
|
||||||
|
"num_elements": tile_m * tile_n,
|
||||||
|
"scope": "output_tile",
|
||||||
|
},
|
||||||
|
))
|
||||||
stages.append(Stage(
|
stages.append(Stage(
|
||||||
stage_type=StageType.STORE,
|
stage_type=StageType.STORE,
|
||||||
component=fetch_id,
|
component=fetch_id,
|
||||||
|
|||||||
@@ -51,11 +51,15 @@ class OpLogger:
|
|||||||
record_end fires.
|
record_end fires.
|
||||||
"""
|
"""
|
||||||
snap: dict[str, Any] = {}
|
snap: dict[str, Any] = {}
|
||||||
# TileToken (ADR-0021 pipeline) — capture which stage this is.
|
# TileToken (ADR-0021 pipeline) — capture which stage this is and its
|
||||||
|
# per-stage params (e.g. op_kind/scope for epilogue MATH stages) so
|
||||||
|
# we can recover them at record_end even after the token advances.
|
||||||
try:
|
try:
|
||||||
stage = getattr(msg, "current_stage", None)
|
stage = getattr(msg, "current_stage", None)
|
||||||
if stage is not None:
|
if stage is not None:
|
||||||
snap["stage_type"] = stage.stage_type.name
|
snap["stage_type"] = stage.stage_type.name
|
||||||
|
if isinstance(getattr(stage, "params", None), dict):
|
||||||
|
snap["stage_params"] = dict(stage.params)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
self._pending[id(msg)] = {
|
self._pending[id(msg)] = {
|
||||||
@@ -78,6 +82,11 @@ class OpLogger:
|
|||||||
stage_type = snap.get("stage_type")
|
stage_type = snap.get("stage_type")
|
||||||
if stage_type is not None:
|
if stage_type is not None:
|
||||||
params = dict(params)
|
params = dict(params)
|
||||||
|
# Merge per-stage params (e.g. op_kind, scope) captured at start.
|
||||||
|
stage_params = snap.get("stage_params")
|
||||||
|
if isinstance(stage_params, dict):
|
||||||
|
for k, v in stage_params.items():
|
||||||
|
params.setdefault(k, v)
|
||||||
params["stage_type"] = stage_type
|
params["stage_type"] = stage_type
|
||||||
if op_name == "TileToken":
|
if op_name == "TileToken":
|
||||||
op_name = f"TileToken/{stage_type}"
|
op_name = f"TileToken/{stage_type}"
|
||||||
|
|||||||
@@ -700,6 +700,7 @@ def _add_pe_internal_edges(edges: list[Edge], pp: str, pe_links: dict) -> None:
|
|||||||
("pe_fetch_store", "pe_gemm", "fetch_store_to_gemm_mm"),
|
("pe_fetch_store", "pe_gemm", "fetch_store_to_gemm_mm"),
|
||||||
("pe_fetch_store", "pe_math", "fetch_store_to_math_mm"),
|
("pe_fetch_store", "pe_math", "fetch_store_to_math_mm"),
|
||||||
("pe_gemm", "pe_fetch_store", "gemm_to_fetch_store_mm"),
|
("pe_gemm", "pe_fetch_store", "gemm_to_fetch_store_mm"),
|
||||||
|
("pe_gemm", "pe_math", "gemm_to_math_mm"),
|
||||||
("pe_math", "pe_fetch_store", "math_to_fetch_store_mm"),
|
("pe_math", "pe_fetch_store", "math_to_fetch_store_mm"),
|
||||||
("pe_fetch_store", "pe_dma", "fetch_store_to_dma_mm"),
|
("pe_fetch_store", "pe_dma", "fetch_store_to_dma_mm"),
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -19,14 +19,17 @@ from typing import Literal
|
|||||||
|
|
||||||
from kernbench.common.ipcq_types import IpcqRecvCmd, IpcqSendCmd, RecvFuture
|
from kernbench.common.ipcq_types import IpcqRecvCmd, IpcqSendCmd, RecvFuture
|
||||||
from kernbench.common.pe_commands import (
|
from kernbench.common.pe_commands import (
|
||||||
|
EPILOGUE_OPS,
|
||||||
CompletionHandle,
|
CompletionHandle,
|
||||||
CompositeCmd,
|
CompositeCmd,
|
||||||
DmaReadCmd,
|
DmaReadCmd,
|
||||||
DmaWriteCmd,
|
DmaWriteCmd,
|
||||||
GemmCmd,
|
GemmCmd,
|
||||||
MathCmd,
|
MathCmd,
|
||||||
|
OpSpec,
|
||||||
PeCommand,
|
PeCommand,
|
||||||
PeCpuOverheadCmd,
|
PeCpuOverheadCmd,
|
||||||
|
Scope,
|
||||||
TensorHandle,
|
TensorHandle,
|
||||||
WaitCmd,
|
WaitCmd,
|
||||||
)
|
)
|
||||||
@@ -565,9 +568,18 @@ class TLContext:
|
|||||||
b: TensorHandle | None = None,
|
b: TensorHandle | None = None,
|
||||||
out_ptr: int = 0,
|
out_ptr: int = 0,
|
||||||
math_op: str | None = None,
|
math_op: str | None = None,
|
||||||
|
*,
|
||||||
|
epilogue: list[dict] | None = None,
|
||||||
|
acc_dtype: str | None = None,
|
||||||
|
tile_shape: tuple[int, int, int] | None = None,
|
||||||
) -> CompletionHandle:
|
) -> CompletionHandle:
|
||||||
"""Submit a composite command (non-blocking, tiled pipeline).
|
"""Submit a composite command (non-blocking, tiled pipeline).
|
||||||
|
|
||||||
|
Optional ``epilogue`` is an ordered list of dicts; each dict has a
|
||||||
|
required ``"op"`` key (one of ``EPILOGUE_OPS``) plus op-specific
|
||||||
|
fields and an optional ``"scope"``. Validation happens here so
|
||||||
|
typos fail before the command is emitted.
|
||||||
|
|
||||||
Returns CompletionHandle for use with wait().
|
Returns CompletionHandle for use with wait().
|
||||||
"""
|
"""
|
||||||
# Compute output size based on op
|
# Compute output size based on op
|
||||||
@@ -579,15 +591,72 @@ class TLContext:
|
|||||||
else:
|
else:
|
||||||
out_nbytes = a.nbytes
|
out_nbytes = a.nbytes
|
||||||
|
|
||||||
|
ops_tuple: tuple[OpSpec, ...] = ()
|
||||||
|
if epilogue is not None:
|
||||||
|
head_operands = (a, b) if (op == "gemm" and b is not None) else (a,)
|
||||||
|
head_spec = OpSpec(
|
||||||
|
kind=op, scope=Scope.OUTPUT_TILE, operands=head_operands,
|
||||||
|
extra={
|
||||||
|
k: v for k, v in (
|
||||||
|
("acc_dtype", acc_dtype),
|
||||||
|
("tile_shape", tile_shape),
|
||||||
|
("math_op", math_op),
|
||||||
|
) if v is not None
|
||||||
|
},
|
||||||
|
)
|
||||||
|
epi_specs = tuple(self._build_epilogue_spec(e, i)
|
||||||
|
for i, e in enumerate(epilogue))
|
||||||
|
ops_tuple = (head_spec, *epi_specs)
|
||||||
|
|
||||||
completion = CompletionHandle(id=self._next_completion_id())
|
completion = CompletionHandle(id=self._next_completion_id())
|
||||||
self._emit_dispatch_overhead()
|
self._emit_dispatch_overhead()
|
||||||
self._emit(CompositeCmd(
|
self._emit(CompositeCmd(
|
||||||
completion=completion, op=op,
|
completion=completion, op=op,
|
||||||
a=a, b=b, out_addr=out_ptr, out_nbytes=out_nbytes,
|
a=a, b=b, out_addr=out_ptr, out_nbytes=out_nbytes,
|
||||||
math_op=math_op,
|
math_op=math_op, ops=ops_tuple,
|
||||||
))
|
))
|
||||||
return completion
|
return completion
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_epilogue_spec(entry: dict, idx: int) -> OpSpec:
|
||||||
|
if not isinstance(entry, dict) or "op" not in entry:
|
||||||
|
raise ValueError(
|
||||||
|
f"epilogue[{idx}]: each entry must be a dict with an 'op' key"
|
||||||
|
)
|
||||||
|
kind = entry["op"]
|
||||||
|
if kind not in EPILOGUE_OPS:
|
||||||
|
known = ", ".join(sorted(EPILOGUE_OPS))
|
||||||
|
raise ValueError(
|
||||||
|
f"epilogue[{idx}]: unknown op {kind!r} "
|
||||||
|
f"(known ops: {known})"
|
||||||
|
)
|
||||||
|
required, default_scope = EPILOGUE_OPS[kind]
|
||||||
|
missing = [f for f in required if f not in entry]
|
||||||
|
if missing:
|
||||||
|
raise ValueError(
|
||||||
|
f"epilogue[{idx}] op {kind!r} missing required field(s): "
|
||||||
|
f"{', '.join(missing)}"
|
||||||
|
)
|
||||||
|
scope = Scope(entry["scope"]) if "scope" in entry else default_scope
|
||||||
|
operands: list = []
|
||||||
|
scalar: float | None = None
|
||||||
|
extra: dict = {}
|
||||||
|
for f in required:
|
||||||
|
v = entry[f]
|
||||||
|
if isinstance(v, TensorHandle):
|
||||||
|
operands.append(v)
|
||||||
|
elif isinstance(v, (int, float)):
|
||||||
|
if scalar is None:
|
||||||
|
scalar = float(v)
|
||||||
|
else:
|
||||||
|
extra[f] = v
|
||||||
|
else:
|
||||||
|
extra[f] = v
|
||||||
|
return OpSpec(
|
||||||
|
kind=kind, scope=scope,
|
||||||
|
operands=tuple(operands), scalar=scalar, extra=extra,
|
||||||
|
)
|
||||||
|
|
||||||
def wait(self, handle: "CompletionHandle | RecvFuture | None" = None) -> Any:
|
def wait(self, handle: "CompletionHandle | RecvFuture | None" = None) -> Any:
|
||||||
"""Wait for a composite, a recv future, or all pending composites.
|
"""Wait for a composite, a recv future, or all pending composites.
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,140 @@
|
|||||||
|
"""Tests for multi-op tl.composite() with epilogue scopes.
|
||||||
|
|
||||||
|
Public-surface tests only: we exercise tl.composite() and inspect the
|
||||||
|
resulting CompositeCmd. Validation, plan generation, and scheduling are
|
||||||
|
covered implicitly — they're internal to tl_context / pe_scheduler / tiling.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from kernbench.common.pe_commands import (
|
||||||
|
EPILOGUE_OPS,
|
||||||
|
CompositeCmd,
|
||||||
|
Scope,
|
||||||
|
TensorHandle,
|
||||||
|
)
|
||||||
|
from kernbench.triton_emu.tl_context import TLContext
|
||||||
|
|
||||||
|
|
||||||
|
def _h(idx: int, shape=(32, 32)) -> TensorHandle:
|
||||||
|
nbytes = 1
|
||||||
|
for d in shape:
|
||||||
|
nbytes *= d
|
||||||
|
return TensorHandle(
|
||||||
|
id=f"h{idx}", addr=0x1000 + idx * 0x100,
|
||||||
|
shape=shape, dtype="f16", nbytes=nbytes * 2,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_composite_epilogue_roundtrip():
|
||||||
|
"""tl.composite() with mixed-scope epilogue produces a CompositeCmd whose
|
||||||
|
ops tuple preserves order, kinds, and default scopes."""
|
||||||
|
tl = TLContext()
|
||||||
|
a, b = _h(0), _h(1)
|
||||||
|
bias = _h(2, shape=(32,))
|
||||||
|
scale = _h(3, shape=(2,))
|
||||||
|
|
||||||
|
tl.composite(
|
||||||
|
op="gemm", a=a, b=b, out_ptr=0x2000,
|
||||||
|
epilogue=[
|
||||||
|
{"op": "bias", "bias": bias}, # default OUTPUT_TILE
|
||||||
|
{"op": "dequant", "scale": scale}, # default K_TILE
|
||||||
|
{"op": "relu"}, # default OUTPUT_TILE
|
||||||
|
{"op": "scale", "factor": 0.5},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
cmd = tl._commands[-1]
|
||||||
|
assert isinstance(cmd, CompositeCmd)
|
||||||
|
|
||||||
|
kinds_scopes = [(o.kind, o.scope) for o in cmd.ops]
|
||||||
|
assert kinds_scopes == [
|
||||||
|
("gemm", Scope.OUTPUT_TILE),
|
||||||
|
("bias", Scope.OUTPUT_TILE),
|
||||||
|
("dequant", Scope.K_TILE),
|
||||||
|
("relu", Scope.OUTPUT_TILE),
|
||||||
|
("scale", Scope.OUTPUT_TILE),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Single-op call (no epilogue) keeps the legacy code path: ops stays empty.
|
||||||
|
tl2 = TLContext()
|
||||||
|
tl2.composite(op="gemm", a=a, b=b, out_ptr=0x2000)
|
||||||
|
cmd2 = tl2._commands[-1]
|
||||||
|
assert isinstance(cmd2, CompositeCmd)
|
||||||
|
assert cmd2.ops == ()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("bad,match", [
|
||||||
|
([{"op": "biass"}], "unknown op 'biass'"),
|
||||||
|
([{"op": "bias"}], "missing required field"),
|
||||||
|
(["relu"], "must be a dict"),
|
||||||
|
])
|
||||||
|
def test_composite_epilogue_rejects_bad_input(bad, match):
|
||||||
|
tl = TLContext()
|
||||||
|
with pytest.raises(ValueError, match=match):
|
||||||
|
tl.composite(op="gemm", a=_h(0), b=_h(1), out_ptr=0x2000,
|
||||||
|
epilogue=bad)
|
||||||
|
|
||||||
|
|
||||||
|
def test_epilogue_registry_contract():
|
||||||
|
"""EPILOGUE_OPS is the registry tl.composite validates against."""
|
||||||
|
for kind, (required, scope) in EPILOGUE_OPS.items():
|
||||||
|
assert isinstance(kind, str) and kind
|
||||||
|
assert isinstance(required, tuple)
|
||||||
|
assert isinstance(scope, Scope)
|
||||||
|
|
||||||
|
|
||||||
|
def test_composite_epilogue_e2e():
|
||||||
|
"""Drive a GEMM + bias + relu composite through the simulator and check
|
||||||
|
op_log: exactly one MATH(bias) and one MATH(relu) record, ordered after
|
||||||
|
GEMM and before STORE for the single (m,n) output tile."""
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from kernbench.runtime_api.bench_runner import run_bench
|
||||||
|
from kernbench.runtime_api.types import resolve_device
|
||||||
|
from kernbench.sim_engine.engine import GraphEngine
|
||||||
|
from kernbench.topology.builder import resolve_topology
|
||||||
|
|
||||||
|
topo_path = Path(__file__).parent.parent / "topology.yaml"
|
||||||
|
topo = resolve_topology(str(topo_path))
|
||||||
|
device = resolve_device(None)
|
||||||
|
|
||||||
|
def _kernel(a_ptr, b_ptr, bias_ptr, out_ptr, tl):
|
||||||
|
a = tl.ref(int(a_ptr), shape=(32, 32), dtype="f16")
|
||||||
|
b = tl.ref(int(b_ptr), shape=(32, 32), dtype="f16")
|
||||||
|
bias = tl.load(int(bias_ptr), shape=(32,), dtype="f16")
|
||||||
|
h = tl.composite(
|
||||||
|
op="gemm", a=a, b=b, out_ptr=int(out_ptr),
|
||||||
|
epilogue=[{"op": "bias", "bias": bias}, {"op": "relu"}],
|
||||||
|
)
|
||||||
|
tl.wait(h)
|
||||||
|
|
||||||
|
def _bench(torch):
|
||||||
|
from kernbench.policy.placement.dp import DPPolicy
|
||||||
|
dp = DPPolicy(cube="replicate", pe="replicate",
|
||||||
|
num_cubes=1, num_pes=1)
|
||||||
|
a = torch.empty((32, 32), dtype="f16", dp=dp, name="a")
|
||||||
|
b = torch.empty((32, 32), dtype="f16", dp=dp, name="b")
|
||||||
|
bias = torch.empty((32,), dtype="f16", dp=dp, name="bias")
|
||||||
|
out = torch.empty((32, 32), dtype="f16", dp=dp, name="out")
|
||||||
|
torch.launch("composite_epi", _kernel, a, b, bias, out)
|
||||||
|
|
||||||
|
result = run_bench(
|
||||||
|
topology=topo, bench_fn=_bench, device=device,
|
||||||
|
engine_factory=lambda t, d: GraphEngine(
|
||||||
|
getattr(t, "topology_obj", t), enable_data=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
assert result.completion.ok
|
||||||
|
|
||||||
|
math = [r for r in result.engine.op_log
|
||||||
|
if r.params.get("stage_type") == "MATH"]
|
||||||
|
assert [r.params.get("op_kind") for r in math] == ["bias", "relu"]
|
||||||
|
|
||||||
|
gemm = [r for r in result.engine.op_log
|
||||||
|
if r.params.get("stage_type") == "GEMM"]
|
||||||
|
store = [r for r in result.engine.op_log
|
||||||
|
if r.params.get("stage_type") == "STORE"]
|
||||||
|
assert gemm and store
|
||||||
|
assert gemm[0].t_end <= math[0].t_start + 1e-6
|
||||||
|
assert math[-1].t_end <= store[0].t_start + 1e-6
|
||||||
@@ -31,7 +31,9 @@ def test_full_graph_edge_count():
|
|||||||
# ADR-0023: +3 IPCQ edges per PE
|
# ADR-0023: +3 IPCQ edges per PE
|
||||||
# ADR-0019 D1 (restored): HBM↔router edges drop from 32 routers × 2
|
# ADR-0019 D1 (restored): HBM↔router edges drop from 32 routers × 2
|
||||||
# to 8 PE-routers × 2 per cube. 32 cubes × (16-64) = -1536 edges.
|
# to 8 PE-routers × 2 per cube. 32 cubes × (16-64) = -1536 edges.
|
||||||
assert len(g.edges) == 12156
|
# Multi-op composite (ADR-0021): +1 gemm→math edge per PE for
|
||||||
|
# epilogue chaining = 2 SIPs × 16 cubes × 8 PEs = +256 edges.
|
||||||
|
assert len(g.edges) == 12412
|
||||||
|
|
||||||
|
|
||||||
# -- Full graph: specific nodes exist -----------------------------------------
|
# -- Full graph: specific nodes exist -----------------------------------------
|
||||||
|
|||||||
@@ -84,6 +84,7 @@ cube:
|
|||||||
fetch_store_to_gemm_mm: 0.0 # fetch → GEMM chaining (ADR-0021)
|
fetch_store_to_gemm_mm: 0.0 # fetch → GEMM chaining (ADR-0021)
|
||||||
fetch_store_to_math_mm: 0.0 # fetch → MATH chaining (ADR-0021)
|
fetch_store_to_math_mm: 0.0 # fetch → MATH chaining (ADR-0021)
|
||||||
gemm_to_fetch_store_mm: 0.0 # GEMM → store chaining (ADR-0021)
|
gemm_to_fetch_store_mm: 0.0 # GEMM → store chaining (ADR-0021)
|
||||||
|
gemm_to_math_mm: 0.0 # GEMM → MATH epilogue chaining (ADR-0021)
|
||||||
math_to_fetch_store_mm: 0.0 # MATH → store chaining (ADR-0021)
|
math_to_fetch_store_mm: 0.0 # MATH → store chaining (ADR-0021)
|
||||||
fetch_store_to_dma_mm: 0.0 # store → DMA writeback chaining (ADR-0021)
|
fetch_store_to_dma_mm: 0.0 # store → DMA writeback chaining (ADR-0021)
|
||||||
gemm_to_tcm_bw_gbs: 512.0
|
gemm_to_tcm_bw_gbs: 512.0
|
||||||
|
|||||||
Reference in New Issue
Block a user