"""PE-internal command types and handles (ADR-0014). Generated by triton_emu (TLContext) and consumed by PE component implementations (PE_CPU, PE_SCHEDULER, PE_DMA, PE_GEMM, PE_MATH). Command lifecycle: Triton kernel → TLContext → [PeCommand list] → PE_CPU → PE_SCHEDULER → engines """ from __future__ import annotations from dataclasses import dataclass, field from enum import Enum from typing import TYPE_CHECKING, Any, Literal if TYPE_CHECKING: import simpy class Scope(Enum): K_TILE = "k_tile" OUTPUT_TILE = "output_tile" KERNEL = "kernel" @dataclass(frozen=True) class OpSpec: """One operation in a multi-op composite (head + epilogue, ADR-0021). The head op (first in CompositeCmd.ops) defines tile geometry; subsequent ops are epilogue stages whose ``scope`` controls how often they fire (per-K-tile, per-output-tile, or once per kernel). """ kind: str # "gemm" | "bias" | "relu" | ... scope: "Scope" = Scope.OUTPUT_TILE operands: tuple[Any, ...] = () # tuple[TensorHandle, ...] scalar: float | None = None extra: dict[str, Any] = field(default_factory=dict) # Epilogue op contracts: kind → (required field names, default scope). # Used by tl.composite to validate user-provided epilogue dicts at # command-emit time so typos fail before reaching the scheduler. EPILOGUE_OPS: dict[str, tuple[tuple[str, ...], "Scope"]] = { "bias": (("bias",), Scope.OUTPUT_TILE), "relu": ((), Scope.OUTPUT_TILE), "gelu": ((), Scope.OUTPUT_TILE), "sigmoid": ((), Scope.OUTPUT_TILE), "scale": (("factor",), Scope.OUTPUT_TILE), "clamp": (("lo", "hi"), Scope.OUTPUT_TILE), "dequant": (("scale",), Scope.K_TILE), "add": (("other",), Scope.OUTPUT_TILE), } # ── Handles ─────────────────────────────────────────────────────── @dataclass(frozen=True) class TensorHandle: """Opaque reference to a tensor residing in PE_TCM. Returned by tl.load, tl.dot, tl.exp, etc. Carries metadata for command generation; data field is reserved for future validate mode (numpy array). """ id: str addr: int # address (VA when MMU enabled, PA otherwise) shape: tuple[int, ...] dtype: str nbytes: int # total byte size data: object = None # reserved for validate mode space: str = "tcm" # MemoryStore space ("tcm" | "hbm" | "sram") pinned: bool = False # operand already DMA-staged in TCM (via tl.load) @dataclass(frozen=True) class CompletionHandle: """Opaque handle for a non-blocking composite command. Returned by tl.composite, consumed by tl.wait. """ id: str # ── PE Commands ─────────────────────────────────────────────────── @dataclass(frozen=True) class DmaReadCmd: """DMA READ: HBM → PE_TCM. src_addr is VA (translated to PA by PE_DMA).""" handle: TensorHandle src_addr: int nbytes: int data_op: bool = True @dataclass(frozen=True) class DmaWriteCmd: """DMA WRITE: PE_TCM → HBM. dst_addr is VA (translated to PA by PE_DMA).""" handle: TensorHandle dst_addr: int nbytes: int data_op: bool = True @dataclass(frozen=True) class GemmCmd: """GEMM engine command: matrix multiply on TCM data. out = a @ b, all operands in TCM. """ a: TensorHandle b: TensorHandle out: TensorHandle m: int k: int n: int data_op: bool = True @dataclass(frozen=True) class MathCmd: """MATH engine command: unary/binary/reduction on TCM data. op: "exp", "log", "sqrt", "abs", "sigmoid", "cos", "sin", "add", "sub", "mul", "div", "where", "sum", "max", "min" """ op: str inputs: tuple[TensorHandle, ...] out: TensorHandle axis: int | None = None # for reductions data_op: bool = True @dataclass(frozen=True) class CompositeCmd: """Composite command: tiled pipeline of DMA_READ + COMPUTE + DMA_WRITE. Non-blocking — submitted to PE_SCHEDULER which manages tile splitting and pipeline overlaps (ADR-0014 D3.2). """ completion: CompletionHandle op: Literal["gemm", "math"] a: TensorHandle b: TensorHandle | None out_addr: int out_nbytes: int math_op: str | None = None # for op="math": which math operation data_op: bool = True # Multi-op composite (ADR-0021 extension): when non-empty, ops[0] is the # head and ops[1:] are epilogue stages with explicit scope. When empty, # the legacy single-op semantics (op/a/b/math_op) apply. ops: tuple[OpSpec, ...] = () @dataclass(frozen=True) class WaitCmd: """Wait for a specific composite or all pending composites.""" handle: CompletionHandle | None = None # None = wait all @dataclass(frozen=True) class PeCpuOverheadCmd: """PE_CPU scalar execution overhead (cycles).""" cycles: int # Union type for all PE commands PeCommand = ( DmaReadCmd | DmaWriteCmd | GemmCmd | MathCmd | CompositeCmd | WaitCmd | PeCpuOverheadCmd ) @dataclass class PeInternalTxn: """PE-internal message flowing PE_CPU → PE_SCHEDULER → engines. Carries a single PeCommand and a completion event. PE_CPU creates one PeInternalTxn per command during the replay phase and sends it to PE_SCHEDULER, which routes it to the appropriate engine (PE_DMA, PE_GEMM, PE_MATH). The engine signals ``done`` on completion. """ command: PeCommand done: simpy.Event # succeeded when the engine completes this command pe_prefix: str = "" # e.g. "sip0.cube0.pe0" — needed by PE_DMA for path resolution result_data: dict[str, Any] = field(default_factory=dict)