a7fe785e5f
Extend tl.composite() with an ordered epilogue list. Each op carries
a scope flag - output_tile (default, runs once per (m,n) before
STORE), k_tile (every K-tile right after GEMM), or kernel. Plan
generator slots MATH stages by scope; pe_math reuses pe_dma's
local-loop pattern so chained epilogues (bias->relu) skip the port
hop. op_log captures per-stage params for telemetry. Topology
gains a gemm->math edge (snapshot test updated).
API stays backward-compatible - `epilogue=` is opt-in.
Example:
    h = tl.composite(
        op="gemm", a=a, b=b, out_ptr=int(out),
        epilogue=[
            {"op": "dequant", "scale": s_per_k, "scope": "k_tile"},
            {"op": "bias", "bias": bias_vec},
            {"op": "relu"},
            {"op": "scale", "factor": 0.5},
        ],
    )
    tl.wait(h)
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
200 lines
5.9 KiB
Python
200 lines
5.9 KiB
Python
"""PE-internal command types and handles (ADR-0014).

Generated by triton_emu (TLContext) and consumed by PE component
implementations (PE_CPU, PE_SCHEDULER, PE_DMA, PE_GEMM, PE_MATH).

Command lifecycle:
    Triton kernel → TLContext → [PeCommand list] → PE_CPU → PE_SCHEDULER → engines
"""
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass, field
|
|
from enum import Enum
|
|
from typing import TYPE_CHECKING, Any, Literal
|
|
|
|
if TYPE_CHECKING:
|
|
import simpy
|
|
|
|
|
|
class Scope(Enum):
    """Firing frequency of an epilogue stage inside a composite command.

    Each member's value is the string spelling of that scope, so
    ``Scope("k_tile")`` resolves to ``Scope.K_TILE``.
    """

    K_TILE = "k_tile"
    OUTPUT_TILE = "output_tile"
    KERNEL = "kernel"
|
|
|
|
|
|
@dataclass(frozen=True)
class OpSpec:
    """One operation in a multi-op composite (head + epilogue, ADR-0021).

    The head op (first in CompositeCmd.ops) defines tile geometry; subsequent
    ops are epilogue stages whose ``scope`` controls how often they fire
    (per-K-tile, per-output-tile, or once per kernel).
    """

    kind: str  # "gemm" | "bias" | "relu" | ...
    scope: "Scope" = Scope.OUTPUT_TILE
    operands: tuple[Any, ...] = ()  # tuple[TensorHandle, ...]
    scalar: float | None = None
    # A frozen dataclass derives __hash__ from its fields, and a dict is
    # unhashable; leaving ``extra`` in the hash would make hash(OpSpec(...))
    # raise TypeError. It is still compared by __eq__, so equal specs keep
    # equal hashes.
    extra: dict[str, Any] = field(default_factory=dict, hash=False)
|
|
|
|
|
|
# Epilogue op contracts: kind → (required field names, default scope).
# Used by tl.composite to validate user-provided epilogue dicts at
# command-emit time so typos fail before reaching the scheduler.
EPILOGUE_OPS: dict[str, tuple[tuple[str, ...], "Scope"]] = {
    "bias": (("bias",), Scope.OUTPUT_TILE),
    "relu": ((), Scope.OUTPUT_TILE),
    "gelu": ((), Scope.OUTPUT_TILE),
    "sigmoid": ((), Scope.OUTPUT_TILE),
    "scale": (("factor",), Scope.OUTPUT_TILE),
    "clamp": (("lo", "hi"), Scope.OUTPUT_TILE),
    "dequant": (("scale",), Scope.K_TILE),
    "add": (("other",), Scope.OUTPUT_TILE),
}
|
|
|
|
|
|
# ── Handles ───────────────────────────────────────────────────────
|
|
|
|
|
|
@dataclass(frozen=True)
class TensorHandle:
    """Opaque reference to a tensor residing in PE_TCM.

    Returned by tl.load, tl.dot, tl.exp, etc.
    Carries metadata for command generation; data field is reserved
    for future validate mode (numpy array).
    """

    id: str
    addr: int  # address (VA when MMU enabled, PA otherwise)
    shape: tuple[int, ...]
    dtype: str
    nbytes: int  # total byte size
    data: object = None  # reserved for validate mode
    space: str = "tcm"  # MemoryStore space ("tcm" | "hbm" | "sram")
    pinned: bool = False  # operand already DMA-staged in TCM (via tl.load)
|
|
|
|
|
|
@dataclass(frozen=True)
class CompletionHandle:
    """Opaque handle for a non-blocking composite command.

    Returned by tl.composite, consumed by tl.wait.
    """

    id: str
|
|
|
|
|
|
# ── PE Commands ───────────────────────────────────────────────────
|
|
|
|
|
|
@dataclass(frozen=True)
class DmaReadCmd:
    """DMA READ: HBM → PE_TCM. src_addr is VA (translated to PA by PE_DMA)."""

    handle: TensorHandle
    src_addr: int
    nbytes: int
    data_op: bool = True
|
|
|
|
|
|
@dataclass(frozen=True)
class DmaWriteCmd:
    """DMA WRITE: PE_TCM → HBM. dst_addr is VA (translated to PA by PE_DMA)."""

    handle: TensorHandle
    dst_addr: int
    nbytes: int
    data_op: bool = True
|
|
|
|
|
|
@dataclass(frozen=True)
class GemmCmd:
    """GEMM engine command: matrix multiply on TCM data.

    out = a @ b, all operands in TCM.
    """

    a: TensorHandle
    b: TensorHandle
    out: TensorHandle
    m: int
    k: int
    n: int
    data_op: bool = True
|
|
|
|
|
|
@dataclass(frozen=True)
class MathCmd:
    """MATH engine command: unary/binary/reduction on TCM data.

    op: "exp", "log", "sqrt", "abs", "sigmoid", "cos", "sin",
        "add", "sub", "mul", "div", "where",
        "sum", "max", "min"
    """

    op: str
    inputs: tuple[TensorHandle, ...]
    out: TensorHandle
    axis: int | None = None  # for reductions
    data_op: bool = True
|
|
|
|
|
|
@dataclass(frozen=True)
class CompositeCmd:
    """Composite command: tiled pipeline of DMA_READ + COMPUTE + DMA_WRITE.

    Non-blocking — submitted to PE_SCHEDULER which manages tile splitting
    and pipeline overlaps (ADR-0014 D3.2).
    """

    completion: CompletionHandle
    op: Literal["gemm", "math"]
    a: TensorHandle
    b: TensorHandle | None
    out_addr: int
    out_nbytes: int
    math_op: str | None = None  # for op="math": which math operation
    data_op: bool = True
    # Multi-op composite (ADR-0021 extension): when non-empty, ops[0] is the
    # head and ops[1:] are epilogue stages with explicit scope. When empty,
    # the legacy single-op semantics (op/a/b/math_op) apply.
    ops: tuple[OpSpec, ...] = ()
|
|
|
|
|
|
@dataclass(frozen=True)
class WaitCmd:
    """Wait for a specific composite or all pending composites."""

    handle: CompletionHandle | None = None  # None = wait all
|
|
|
|
|
|
@dataclass(frozen=True)
class PeCpuOverheadCmd:
    """PE_CPU scalar execution overhead (cycles)."""

    cycles: int
|
|
|
|
|
|
# Union type for all PE commands
PeCommand = (
    DmaReadCmd | DmaWriteCmd | GemmCmd | MathCmd
    | CompositeCmd | WaitCmd | PeCpuOverheadCmd
)
|
|
|
|
|
|
@dataclass
class PeInternalTxn:
    """PE-internal message flowing PE_CPU → PE_SCHEDULER → engines.

    Carries a single PeCommand and a completion event. PE_CPU creates one
    PeInternalTxn per command during the replay phase and sends it to
    PE_SCHEDULER, which routes it to the appropriate engine (PE_DMA,
    PE_GEMM, PE_MATH). The engine signals ``done`` on completion.
    """

    command: PeCommand
    done: simpy.Event  # succeeded when the engine completes this command
    pe_prefix: str = ""  # e.g. "sip0.cube0.pe0" — needed by PE_DMA for path resolution
    # default_factory gives every transaction its own dict rather than a
    # shared mutable default.
    result_data: dict[str, Any] = field(default_factory=dict)
|