"""PE_GEMM: matrix multiplication engine (ADR-0021 D6).

Handles both legacy ``PeInternalTxn`` (``GemmCmd``) and pipeline ``TileToken``.
In pipeline mode, receives the token after the fetch stage, computes the MAC
latency, and chains the token to the next stage.

MAC latency model (from pe_accel)::

    cycles     = ceil(Tm / mac_m) * ceil(Tk / mac_k) * ceil(Tn / mac_n)
    latency_ns = cycles / clock_freq_ghz

Falls back to a TFLOPS model when the MAC dimensions are not configured.
"""

from __future__ import annotations

from collections.abc import Generator
from contextlib import AbstractContextManager, nullcontext
from math import ceil
from typing import TYPE_CHECKING, Any

import simpy

from kernbench.components.base import PeEngineBase

if TYPE_CHECKING:
    from kernbench.common.pe_commands import PeInternalTxn
    from kernbench.components.context import ComponentContext
    from kernbench.topology.types import Node


# Bit width per dtype name. The TFLOPS model scales the f16 peak by 16 / bits.
_DTYPE_BITS: dict[str, int] = {
    "f16": 16,
    "fp16": 16,
    "float16": 16,
    "bf16": 16,
    "f32": 32,
    "fp32": 32,
    "float32": 32,
    "i8": 8,
    "int8": 8,
    "i16": 16,
    "int16": 16,
    "i32": 32,
    "int32": 32,
}


def _dtype_bits(dtype: object, default: int = 16) -> int:
    """Return the bit width for *dtype*, or *default* if it is not recognised.

    The lookup goes through ``str()`` and is case-insensitive, so spellings
    such as ``"FP32"`` resolve to the same entry as ``"fp32"``.
    """
    return _DTYPE_BITS.get(str(dtype).strip().lower(), default)


class PeGemmComponent(PeEngineBase):
    """PE_GEMM: MAC array (ADR-0021 D6).

    In pipeline mode: pure compute -- register data already fetched.
    In legacy mode: handles PeInternalTxn(GemmCmd) with shared accel_slot.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        # Shared accelerator slot, attached in init_resources() when the node
        # names a "shared_resource". While None, no slot is acquired before work.
        self._accel: simpy.Resource | None = None
        # Peak f16 throughput for the TFLOPS fallback model (<= 0 disables it).
        self._peak_tflops_f16: float = float(node.attrs.get("peak_tflops_f16", 0.0))
        # Cycle-accurate MAC dimensions (from pe_accel). All three must be > 0
        # for the MAC latency model to be used.
        self._mac_m: int = int(node.attrs.get("mac_m", 0))
        self._mac_k: int = int(node.attrs.get("mac_k", 0))
        self._mac_n: int = int(node.attrs.get("mac_n", 0))
        self._clock_freq: float = float(node.attrs.get("clock_freq_ghz", 1.0))

    def init_resources(self, env: simpy.Environment) -> None:
        """Attach the shared accelerator slot named by the ``shared_resource`` attr.

        The resource is looked up through the component context under the
        key ``"<pe_prefix>.<shared_resource>"``. Nothing is attached when the
        attr is missing or no context is available.
        """
        resource_name = self.node.attrs.get("shared_resource")
        if resource_name and self.ctx:
            self._accel = self.ctx.get_shared_resource(
                env, f"{self._pe_prefix}.{resource_name}"
            )

    @property
    def _mac_configured(self) -> bool:
        """True when all MAC dimensions are positive, enabling the MAC model."""
        return self._mac_m > 0 and self._mac_k > 0 and self._mac_n > 0

    def _compute_ns_mac(self, m: int, k: int, n: int) -> float:
        """Cycle-accurate MAC latency in ns (pe_accel model).

        Returns 0.0 when the MAC dimensions are not configured.

        Raises:
            ValueError: if the MAC model is configured but ``clock_freq_ghz``
                is not positive.
        """
        if not self._mac_configured:
            return 0.0
        if self._clock_freq <= 0:
            raise ValueError(
                "clock_freq_ghz must be > 0 for the MAC latency model, "
                f"got {self._clock_freq}"
            )
        cycles = ceil(m / self._mac_m) * ceil(k / self._mac_k) * ceil(n / self._mac_n)
        # cycles / GHz == nanoseconds
        return cycles / self._clock_freq

    def _compute_ns_tflops(self, m: int, k: int, n: int, dtype: str = "f16") -> float:
        """TFLOPS-based latency in ns (legacy model).

        Throughput scales with the dtype's bit width: the f16 peak is
        multiplied by 16 / bits, and unknown dtypes count as 16 bits. When no
        peak is configured, the node's ``overhead_ns`` is returned instead.
        """
        if self._peak_tflops_f16 <= 0:
            return float(self.node.attrs.get("overhead_ns", 0.0))
        effective_tflops = self._peak_tflops_f16 * (16.0 / _dtype_bits(dtype))
        flops = 2.0 * m * k * n  # one multiply + one add per MAC
        # flops / (TFLOPS * 1e12 flop/s) seconds == flops / (TFLOPS * 1e3) ns
        return flops / (effective_tflops * 1e3)

    def _compute_ns(self, m: int, k: int, n: int, dtype: str = "f16") -> float:
        """Latency in ns from the MAC model if configured, else the TFLOPS model.

        The choice depends only on configuration, so a zero-sized tile on a
        configured MAC array costs 0 ns rather than switching models.
        """
        if self._mac_configured:
            return self._compute_ns_mac(m, k, n)
        return self._compute_ns_tflops(m, k, n, dtype)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Wait for the node's ``overhead_ns`` (``nbytes`` is not used here)."""
        overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(overhead_ns)

    def _worker(self, env: simpy.Environment) -> Generator:
        """Inbox loop: spawn a separate SimPy process for each received message.

        ``TileToken`` goes to the pipeline path, ``PeInternalTxn`` goes to the
        hooked legacy path, and anything else is forwarded.
        """
        from kernbench.common.pe_commands import PeInternalTxn
        from kernbench.components.builtin.pe_types import TileToken

        while True:
            msg: Any = yield self._inbox.get()
            if isinstance(msg, TileToken):
                env.process(self._pipeline_process(env, msg))
            elif isinstance(msg, PeInternalTxn):
                env.process(self._handle_with_hooks(env, msg))
            else:
                env.process(self._forward_txn(env, msg))

    def _accel_slot(self) -> AbstractContextManager[Any]:
        """Return a request on the shared accel slot, or a no-op context.

        Usage: ``with self._accel_slot() as req:`` then ``yield req`` only when
        ``req is not None``. The request is released when the block exits.
        """
        if self._accel is None:
            return nullcontext()
        return self._accel.request()

    def _pipeline_process(self, env: simpy.Environment, token: Any) -> Generator:
        """Pipeline mode: pure MAC compute, then self-route."""
        self._on_process_start(env, token)
        params = token.params
        m = params.get("m", 0)
        k = params.get("k", 0)
        n = params.get("n", 0)
        # Honour the tile dtype like the legacy GemmCmd path does; the default
        # keeps the previous f16 behaviour when the token carries no dtype.
        dtype = params.get("dtype", "f16")
        with self._accel_slot() as req:
            if req is not None:
                yield req
            yield env.timeout(self._compute_ns(m, k, n, dtype))
        self._on_process_end(env, token)
        # Self-routing: hand the token to the next stage, or finish the tile.
        next_stage = token.advance()
        if next_stage is not None:
            yield self.out_ports[next_stage.component].put(token)
        else:
            token.pipeline_ctx.complete_tile()

    def handle_command(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
        """Legacy PeInternalTxn handling.

        A ``GemmCmd`` costs the modelled GEMM latency for its dtype, and any
        other command costs ``run()``. Both run while holding the accel slot,
        if one is attached. ``pe_txn.done`` is succeeded afterwards.
        """
        from kernbench.common.pe_commands import GemmCmd

        cmd = pe_txn.command
        with self._accel_slot() as req:
            if req is not None:
                yield req
            if isinstance(cmd, GemmCmd):
                yield env.timeout(self._compute_ns(cmd.m, cmd.k, cmd.n, cmd.a.dtype))
            else:
                yield from self.run(env, 0)
        pe_txn.done.succeed()

    def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator:
        """Forward via the base class while holding the accel slot, if attached."""
        with self._accel_slot() as req:
            if req is not None:
                yield req
            yield from super()._forward_txn(env, txn)