b6eb97c49a
Step 1-2: Backup existing code - builtin/ → builtin_legacy/ (unchanged backup) - custom/pe_accel/ → custom/pe_accel_legacy/ (unchanged backup) Step 3-4: New pipeline types and tiling - pe_types.py: StageType, Stage, TilePlan, PipelinePlan, PipelineContext, TileToken - tiling.py: generate_gemm_plan, generate_math_plan (ported from pe_accel) Step 5: Component implementations (ADR-0021 D4-D6) - PE_SCHEDULER: _feed_loop (singleton FIFO feeder) + plan generation - PE_FETCH_STORE: new component — TCM ↔ Register File - PE_GEMM: TileToken pipeline + legacy PeInternalTxn dual-mode - PE_MATH: TileToken pipeline + legacy dual-mode - PE_DMA: TileToken pipeline + legacy + fabric Transaction triple-mode - PE_TCM: TcmRequest handler with dual-channel BW serialization Step 6: Infrastructure - topology.yaml: pe_fetch_store component + chaining edges - components.yaml: pe_fetch_store_v1 registration - builder.py: PE_COMP_OFFSETS, _add_pe_internal_edges, PE view positions - Tests: node/edge counts, PE component sets updated All components handle both TileToken (pipeline) and PeInternalTxn (legacy). Token self-routing: components read next stage from token.plan, chain via out_port. 366 tests passing. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
152 lines
5.7 KiB
Python
152 lines
5.7 KiB
Python
"""PE_GEMM: matrix multiplication engine (ADR-0021 D6).
|
|
|
|
Handles both legacy PeInternalTxn (GemmCmd) and pipeline TileToken.
|
|
In pipeline mode, receives token after fetch stage, computes MAC, chains to next.
|
|
|
|
MAC latency model (from pe_accel):
|
|
cycles = ceil(Tm/mac_m) * ceil(Tk/mac_k) * ceil(Tn/mac_n)
|
|
latency_ns = cycles / clock_freq_ghz
|
|
|
|
Falls back to TFLOPS model when mac dimensions not configured.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import Generator
|
|
from math import ceil
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
import simpy
|
|
|
|
from kernbench.components.base import PeEngineBase
|
|
|
|
if TYPE_CHECKING:
|
|
from kernbench.common.pe_commands import PeInternalTxn
|
|
from kernbench.components.context import ComponentContext
|
|
from kernbench.topology.types import Node
|
|
|
|
# Bit width for every dtype spelling this module recognises. Several aliases
# map to the same width; PeGemmComponent._compute_ns_tflops looks names up here
# and treats anything missing as 16-bit.
_DTYPE_BITS: dict[str, int] = {
    # 16-bit floating point
    "f16": 16,
    "fp16": 16,
    "float16": 16,
    "bf16": 16,
    # 32-bit floating point
    "f32": 32,
    "fp32": 32,
    "float32": 32,
    # integers
    "i8": 8,
    "int8": 8,
    "i16": 16,
    "int16": 16,
    "i32": 32,
    "int32": 32,
}
|
|
|
|
|
|
class PeGemmComponent(PeEngineBase):
    """PE_GEMM: MAC array (ADR-0021 D6).

    In pipeline mode: pure compute on a TileToken, then the token is routed to
    its next stage (register data is expected to be fetched already).
    In legacy mode: handles PeInternalTxn(GemmCmd).

    In both modes the work runs while holding the shared accelerator resource
    named by the ``shared_resource`` node attribute, when one is configured.

    Latency is taken from the cycle-accurate MAC model when ``mac_m``,
    ``mac_k``, ``mac_n`` and ``clock_freq_ghz`` are all positive and the
    resulting cycle count is non-zero; otherwise the TFLOPS model is used.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        """Read the latency-model parameters from ``node.attrs``.

        Args:
            node: Topology node; its ``attrs`` mapping supplies
                ``peak_tflops_f16``, ``mac_m``, ``mac_k``, ``mac_n`` and
                ``clock_freq_ghz`` (each optional).
            ctx: Component context used later to look up the shared resource;
                may be ``None``.
        """
        super().__init__(node, ctx)
        attrs = node.attrs
        # Filled in by init_resources(); stays None when nothing is shared.
        self._accel: simpy.Resource | None = None
        self._peak_tflops_f16: float = float(attrs.get("peak_tflops_f16", 0.0))
        # Cycle-accurate MAC dimensions (from pe_accel); 0 means "not configured".
        self._mac_m: int = int(attrs.get("mac_m", 0))
        self._mac_k: int = int(attrs.get("mac_k", 0))
        self._mac_n: int = int(attrs.get("mac_n", 0))
        self._clock_freq: float = float(attrs.get("clock_freq_ghz", 1.0))

    def init_resources(self, env: simpy.Environment) -> None:
        """Bind the shared accelerator resource if the node names one.

        The resource is requested from the context under the key
        ``"<pe_prefix>.<shared_resource>"``. Nothing happens when the
        ``shared_resource`` attribute is unset/empty or there is no context.
        """
        resource_name = self.node.attrs.get("shared_resource")
        if not (resource_name and self.ctx):
            return
        self._accel = self.ctx.get_shared_resource(
            env, f"{self._pe_prefix}.{resource_name}"
        )

    def _compute_ns_mac(self, m: int, k: int, n: int) -> float:
        """Cycle-accurate MAC latency (pe_accel model).

        Returns:
            ``ceil(m/mac_m) * ceil(k/mac_k) * ceil(n/mac_n) / clock_freq_ghz``
            in nanoseconds, or ``0.0`` when the model is unavailable, i.e. any
            MAC dimension or the clock frequency is not positive.
        """
        # A zero clock used to raise ZeroDivisionError here, while a negative
        # clock produced a negative value that _compute_ns() already rejected.
        # Both now report "model unavailable" so the TFLOPS model takes over.
        if min(self._mac_m, self._mac_k, self._mac_n) <= 0 or self._clock_freq <= 0:
            return 0.0
        cycles = ceil(m / self._mac_m) * ceil(k / self._mac_k) * ceil(n / self._mac_n)
        # cycles / GHz == nanoseconds. A zero cycle count (e.g. a dimension of
        # 0) yields 0.0 and therefore also selects the TFLOPS model.
        return cycles / self._clock_freq

    def _compute_ns_tflops(self, m: int, k: int, n: int, dtype: str = "f16") -> float:
        """TFLOPS-based latency (legacy model).

        Returns the node's ``overhead_ns`` attribute (default 0.0) when no
        positive ``peak_tflops_f16`` is configured. Otherwise the f16 peak is
        scaled by ``16 / bits(dtype)`` (unknown dtypes count as 16-bit) and
        applied to ``2 * m * k * n`` floating-point operations.
        """
        if self._peak_tflops_f16 <= 0:
            return float(self.node.attrs.get("overhead_ns", 0.0))
        dtype_bits = _DTYPE_BITS.get(dtype, 16)
        effective_tflops = self._peak_tflops_f16 * (16.0 / dtype_bits)
        flops = 2.0 * m * k * n
        # 1 TFLOPS = 1e12 FLOP/s = 1e3 FLOP/ns, so this quotient is in ns.
        return flops / (effective_tflops * 1e3)

    def _compute_ns(self, m: int, k: int, n: int, dtype: str = "f16") -> float:
        """Choose best available latency model.

        Prefers the MAC model; any non-positive MAC result falls through to
        the TFLOPS model.
        """
        mac_ns = self._compute_ns_mac(m, k, n)
        return mac_ns if mac_ns > 0 else self._compute_ns_tflops(m, k, n, dtype)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Wait for the node's ``overhead_ns`` (default 0.0).

        ``nbytes`` is accepted for interface compatibility and is not used by
        this implementation.
        """
        overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(overhead_ns)

    def _worker(self, env: simpy.Environment) -> Generator:
        """Dispatch loop: take messages from the inbox and spawn a handler each.

        TileToken -> pipeline processing; PeInternalTxn -> legacy command
        handling via ``_handle_with_hooks`` (defined outside this class);
        anything else is forwarded. Every message gets its own SimPy process,
        so the loop returns to the inbox immediately.
        """
        # Imported here rather than at module level (the module only imports
        # PeInternalTxn under TYPE_CHECKING).
        from kernbench.common.pe_commands import PeInternalTxn
        from kernbench.components.builtin.pe_types import TileToken

        while True:
            msg: Any = yield self._inbox.get()
            if isinstance(msg, TileToken):
                handler = self._pipeline_process(env, msg)
            elif isinstance(msg, PeInternalTxn):
                handler = self._handle_with_hooks(env, msg)
            else:
                handler = self._forward_txn(env, msg)
            env.process(handler)

    def _gemm_hold_accel(self, work: Generator) -> Generator:
        """Drive ``work`` to completion while holding the shared accelerator.

        When no shared resource is bound, ``work`` simply runs. Otherwise a
        request is queued first and ``work`` only starts once it is granted;
        the request is released when the ``with`` block exits.
        """
        if not self._accel:
            yield from work
            return
        with self._accel.request() as req:
            yield req
            yield from work

    def _gemm_delay(
        self, env: simpy.Environment, m: int, k: int, n: int, dtype: str = "f16"
    ) -> Generator:
        """Wait for the modelled latency of one (m, k, n) GEMM."""
        yield env.timeout(self._compute_ns(m, k, n, dtype))

    def _gemm_run_command(self, env: simpy.Environment, cmd: Any) -> Generator:
        """Body of a legacy command: GEMM latency for GemmCmd, else ``run()``."""
        from kernbench.common.pe_commands import GemmCmd

        if isinstance(cmd, GemmCmd):
            yield from self._gemm_delay(env, cmd.m, cmd.k, cmd.n, cmd.a.dtype)
        else:
            yield from self.run(env, 0)

    def _pipeline_process(self, env: simpy.Environment, token: Any) -> Generator:
        """Pipeline mode: pure MAC compute, then self-route.

        Reads ``m``, ``k`` and ``n`` from ``token.params`` (each defaulting to
        0), waits for the modelled latency using the default dtype, then either
        puts the token on the out port of its next stage or, when the plan has
        no further stage, marks the tile complete on the pipeline context.
        """
        self._on_process_start(env, token)

        params = token.params
        m = params.get("m", 0)
        k = params.get("k", 0)
        n = params.get("n", 0)
        yield from self._gemm_hold_accel(self._gemm_delay(env, m, k, n))

        self._on_process_end(env, token)

        # Self-routing: the token itself says which component comes next.
        next_stage = token.advance()
        if next_stage is None:
            token.pipeline_ctx.complete_tile()
        else:
            yield self.out_ports[next_stage.component].put(token)

    def handle_command(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
        """Legacy PeInternalTxn handling.

        Executes ``pe_txn.command`` (GEMM latency for a GemmCmd, the generic
        ``run()`` overhead otherwise) while holding the shared accelerator if
        one is bound, then triggers ``pe_txn.done``.
        """
        yield from self._gemm_hold_accel(self._gemm_run_command(env, pe_txn.command))
        # Signalled on every path, with or without a shared accelerator.
        pe_txn.done.succeed()

    def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator:
        """Forward a non-PE message via the base class, holding the shared accelerator if bound."""
        yield from self._gemm_hold_accel(super()._forward_txn(env, txn))
|