commit - release 1
This commit is contained in:
@@ -0,0 +1,150 @@
|
||||
"""PE-internal command types and handles (ADR-0014).
|
||||
|
||||
Generated by triton_emu (TLContext) and consumed by PE component
|
||||
implementations (PE_CPU, PE_SCHEDULER, PE_DMA, PE_GEMM, PE_MATH).
|
||||
|
||||
Command lifecycle:
|
||||
Triton kernel → TLContext → [PeCommand list] → PE_CPU → PE_SCHEDULER → engines
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import simpy
|
||||
|
||||
|
||||
# ── Handles ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class TensorHandle:
    """Opaque reference to a tensor residing in PE_TCM.

    Returned by tl.load, tl.dot, tl.exp, etc.
    Carries metadata for command generation; ``data`` is reserved for a
    future validate mode (numpy array).

    ``data`` is excluded from the generated ``__eq__``/``__hash__``. A
    frozen dataclass hashes every compared field, so an unhashable payload
    (e.g. a numpy array) would make ``hash(handle)`` raise ``TypeError``.
    Two handles are therefore equal when their metadata is equal,
    regardless of the attached payload.
    """

    id: str
    pa: int  # physical address in HBM/TCM
    shape: tuple[int, ...]
    dtype: str
    nbytes: int  # total byte size
    # Reserved for validate mode; not part of the handle's identity.
    data: object = field(default=None, compare=False)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class CompletionHandle:
    """Opaque token identifying a non-blocking composite command.

    tl.composite returns one; tl.wait consumes it.
    """

    id: str
|
||||
|
||||
|
||||
# ── PE Commands ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DmaReadCmd:
    """DMA READ command: transfer from HBM into PE_TCM."""

    handle: TensorHandle  # tensor handle this transfer is associated with
    src_pa: int  # source physical address
    nbytes: int  # byte count of the transfer
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DmaWriteCmd:
    """DMA WRITE command: transfer from PE_TCM out to HBM."""

    handle: TensorHandle  # tensor handle this transfer is associated with
    dst_pa: int  # destination physical address
    nbytes: int  # byte count of the transfer
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class GemmCmd:
    """Matrix-multiply command for the GEMM engine.

    Computes ``out = a @ b``; every operand lives in TCM.

    NOTE(review): m/k/n presumably describe a as (m, k) and b as (k, n);
    nothing in this class enforces that against the handles' shapes.
    """

    a: TensorHandle
    b: TensorHandle
    out: TensorHandle
    m: int
    k: int
    n: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class MathCmd:
    """Command for the MATH engine: unary, binary or reduction op on TCM data.

    Documented values of ``op``:
        "exp", "log", "sqrt", "abs", "sigmoid", "cos", "sin",
        "add", "sub", "mul", "div", "where",
        "sum", "max", "min"

    The op name is stored as a plain string and is not validated here.
    """

    op: str
    inputs: tuple[TensorHandle, ...]
    out: TensorHandle
    axis: int | None = None  # only meaningful for reductions
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class CompositeCmd:
    """Tiled DMA_READ + COMPUTE + DMA_WRITE pipeline issued as one command.

    The command is non-blocking: it is handed to PE_SCHEDULER, which takes
    care of tile splitting and pipeline overlaps (ADR-0014 D3.2).
    """

    completion: CompletionHandle
    op: Literal["gemm", "math"]
    a: TensorHandle
    b: TensorHandle | None
    out_pa: int
    out_nbytes: int
    math_op: str | None = None  # which math operation, when op="math"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class WaitCmd:
    """Wait on one composite, or on every pending composite."""

    # None means "wait for all pending composites".
    handle: CompletionHandle | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class PeCpuOverheadCmd:
    """Scalar execution overhead on PE_CPU, expressed in cycles."""

    cycles: int
|
||||
|
||||
|
||||
# Union of every PE command type defined in this module.
PeCommand = (
    DmaReadCmd
    | DmaWriteCmd
    | GemmCmd
    | MathCmd
    | CompositeCmd
    | WaitCmd
    | PeCpuOverheadCmd
)
|
||||
|
||||
|
||||
@dataclass
class PeInternalTxn:
    """Message passed along PE_CPU → PE_SCHEDULER → engines inside one PE.

    Each transaction wraps exactly one PeCommand together with a completion
    event. During the replay phase PE_CPU builds one PeInternalTxn per
    command and sends it to PE_SCHEDULER; the scheduler routes it to the
    matching engine (PE_DMA, PE_GEMM, PE_MATH), and that engine signals
    ``done`` once the command completes.
    """

    command: PeCommand
    # Succeeded by the engine when this command completes.
    done: simpy.Event
    # e.g. "sip0.cube0.pe0" — PE_DMA needs it for path resolution.
    pe_prefix: str = ""
    # A fresh dict is created for every transaction.
    result_data: dict[str, Any] = field(default_factory=dict)
|
||||
@@ -0,0 +1,29 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, NewType, Protocol, TypeAlias
|
||||
|
||||
# Distinct static type for request handles (returned by SimEngine.submit);
# NewType adds no runtime wrapper, so values are plain ``str`` at runtime.
RequestHandle = NewType("RequestHandle", str)
|
||||
|
||||
# Trace payloads are untyped here: the alias only names the slot.
Trace: TypeAlias = Any
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class Completion:
    """Immutable completion record, as returned by SimEngine.get_completion.

    ``error_code`` and ``error_message`` both default to None. This class
    does not enforce any relationship between them and ``ok``.
    """

    ok: bool
    error_code: str | None = None
    error_message: str | None = None
|
||||
|
||||
|
||||
class SimEngine(Protocol):
    """Structural contract for a backend simulation/runner engine.

    An implementation is expected to:
      * take in requests produced by RuntimeContext (submit/dispatch), and
      * report the completion, plus an optional trace, for a given handle.
    """

    def submit(self, request: Any) -> RequestHandle:
        """Accept *request* and return the handle that identifies it."""
        ...

    def wait(self, handle: RequestHandle) -> None:
        """Wait on *handle*.

        NOTE(review): presumably blocks until the request has finished —
        confirm against the concrete engines.
        """
        ...

    def get_completion(self, handle: RequestHandle) -> tuple[Completion, Trace | None]:
        """Return the completion for *handle* and its trace, if there is one."""
        ...
|
||||
Reference in New Issue
Block a user