Files
kernbench2/src/kernbench/components/builtin/pe_gemm.py
T
ywkang b6eb97c49a Implement ADR-0021: PE pipeline refactor with token self-routing
Step 1-2: Backup existing code
- builtin/ → builtin_legacy/ (unchanged backup)
- custom/pe_accel/ → custom/pe_accel_legacy/ (unchanged backup)

Step 3-4: New pipeline types and tiling
- pe_types.py: StageType, Stage, TilePlan, PipelinePlan, PipelineContext, TileToken
- tiling.py: generate_gemm_plan, generate_math_plan (ported from pe_accel)

Step 5: Component implementations (ADR-0021 D4-D6)
- PE_SCHEDULER: _feed_loop (singleton FIFO feeder) + plan generation
- PE_FETCH_STORE: new component — TCM ↔ Register File
- PE_GEMM: TileToken pipeline + legacy PeInternalTxn dual-mode
- PE_MATH: TileToken pipeline + legacy dual-mode
- PE_DMA: TileToken pipeline + legacy + fabric Transaction triple-mode
- PE_TCM: TcmRequest handler with dual-channel BW serialization

Step 6: Infrastructure
- topology.yaml: pe_fetch_store component + chaining edges
- components.yaml: pe_fetch_store_v1 registration
- builder.py: PE_COMP_OFFSETS, _add_pe_internal_edges, PE view positions
- Tests: node/edge counts, PE component sets updated

All components handle both TileToken (pipeline) and PeInternalTxn (legacy).
Token self-routing: components read next stage from token.plan, chain via out_port.
366 tests passing.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-08 23:35:31 -07:00

152 lines
5.7 KiB
Python

"""PE_GEMM: matrix multiplication engine (ADR-0021 D6).

Handles both the legacy PeInternalTxn (GemmCmd) path and the pipeline
TileToken path. In pipeline mode the component receives a token after the
fetch stage, models the MAC compute time, and chains the token to its next
stage.

MAC latency model (from pe_accel):

    cycles     = ceil(Tm / mac_m) * ceil(Tk / mac_k) * ceil(Tn / mac_n)
    latency_ns = cycles / clock_freq_ghz

Falls back to the TFLOPS model when the MAC dimensions are not configured
(more precisely: whenever the MAC model yields a non-positive latency).
"""
from __future__ import annotations
from collections.abc import Generator
from math import ceil
from typing import TYPE_CHECKING, Any
import simpy
from kernbench.components.base import PeEngineBase
if TYPE_CHECKING:
from kernbench.common.pe_commands import PeInternalTxn
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
_DTYPE_BITS: dict[str, int] = {
"f16": 16, "fp16": 16, "float16": 16, "bf16": 16,
"f32": 32, "fp32": 32, "float32": 32,
"i8": 8, "int8": 8, "i16": 16, "int16": 16, "i32": 32, "int32": 32,
}
class PeGemmComponent(PeEngineBase):
    """PE_GEMM: MAC array engine (ADR-0021 D6).

    The worker loop serves three kinds of inbox messages:

    * ``TileToken`` (pipeline mode) -- the MAC latency is modelled and the
      token is then routed to the next stage named by its plan.
    * ``PeInternalTxn`` (legacy mode) -- handed to ``_handle_with_hooks``;
      ``handle_command`` times ``GemmCmd`` payloads.
    * anything else -- forwarded via ``_forward_txn``.

    When the node declares a ``shared_resource`` attribute (and a context is
    available), every path above holds that shared resource while it works.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        """Read the latency-model parameters from ``node.attrs``.

        Args:
            node: Topology node; its ``attrs`` mapping supplies
                ``peak_tflops_f16``, ``mac_m``/``mac_k``/``mac_n`` and
                ``clock_freq_ghz`` (defaults: 0.0, 0, 0, 0 and 1.0).
            ctx: Optional component context, passed through to the base class.
        """
        super().__init__(node, ctx)
        attrs = node.attrs
        # Populated by init_resources() when a shared resource is configured.
        self._accel: simpy.Resource | None = None
        self._peak_tflops_f16: float = float(attrs.get("peak_tflops_f16", 0.0))
        # Cycle-accurate MAC dimensions (from pe_accel); 0 means "not set".
        self._mac_m: int = int(attrs.get("mac_m", 0))
        self._mac_k: int = int(attrs.get("mac_k", 0))
        self._mac_n: int = int(attrs.get("mac_n", 0))
        self._clock_freq: float = float(attrs.get("clock_freq_ghz", 1.0))

    def init_resources(self, env: simpy.Environment) -> None:
        """Look up the shared resource named by the ``shared_resource`` attr.

        Does nothing when the attribute is missing/empty or when no context
        was supplied. The lookup key is ``"<pe_prefix>.<shared_resource>"``.
        """
        shared_name = self.node.attrs.get("shared_resource")
        if not (shared_name and self.ctx):
            return
        self._accel = self.ctx.get_shared_resource(
            env, f"{self._pe_prefix}.{shared_name}"
        )

    def _compute_ns_mac(self, m: int, k: int, n: int) -> float:
        """Cycle-accurate MAC latency in ns (pe_accel model).

        Returns 0.0 when any of the three MAC dimensions is not positive.
        """
        if self._mac_m <= 0 or self._mac_k <= 0 or self._mac_n <= 0:
            return 0.0
        cycles = 1
        # One MAC-array pass per (ceil-divided) step along each axis.
        for extent, lanes in zip((m, k, n), (self._mac_m, self._mac_k, self._mac_n)):
            cycles *= ceil(extent / lanes)
        return cycles / self._clock_freq

    def _compute_ns_tflops(self, m: int, k: int, n: int, dtype: str = "f16") -> float:
        """TFLOPS-based latency in ns (legacy model).

        With no positive ``peak_tflops_f16`` the node's ``overhead_ns``
        attribute (default 0.0) is returned instead. Otherwise the f16 peak is
        scaled by ``16 / dtype_bits``; unknown dtypes count as 16-bit.
        """
        peak = self._peak_tflops_f16
        if peak <= 0:
            return float(self.node.attrs.get("overhead_ns", 0.0))
        width = _DTYPE_BITS.get(dtype, 16)
        # 1 TFLOPS == 1e3 floating-point operations per nanosecond.
        flops_per_ns = peak * (16.0 / width) * 1e3
        return (2.0 * m * k * n) / flops_per_ns

    def _compute_ns(self, m: int, k: int, n: int, dtype: str = "f16") -> float:
        """Return the MAC-model latency if positive, else the TFLOPS one."""
        mac_ns = self._compute_ns_mac(m, k, n)
        return mac_ns if mac_ns > 0 else self._compute_ns_tflops(m, k, n, dtype)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Wait for the node's ``overhead_ns`` attribute (default 0.0).

        ``nbytes`` is accepted but not used by this component.
        """
        yield env.timeout(float(self.node.attrs.get("overhead_ns", 0.0)))

    def _worker(self, env: simpy.Environment) -> Generator:
        """Drain the inbox forever, spawning one process per message."""
        from kernbench.common.pe_commands import PeInternalTxn
        from kernbench.components.builtin.pe_types import TileToken

        while True:
            msg: Any = yield self._inbox.get()
            # Pick the handler by message type; each runs as its own process.
            if isinstance(msg, TileToken):
                handler = self._pipeline_process
            elif isinstance(msg, PeInternalTxn):
                handler = self._handle_with_hooks
            else:
                handler = self._forward_txn
            env.process(handler(env, msg))

    def _pipeline_process(self, env: simpy.Environment, token: Any) -> Generator:
        """Pipeline mode: model the MAC time, then self-route the token.

        Tile extents come from ``token.params`` keys ``m``/``k``/``n``
        (missing keys default to 0). The latency uses the default dtype.
        """
        self._on_process_start(env, token)
        m, k, n = (token.params.get(axis, 0) for axis in ("m", "k", "n"))

        if self._accel:
            with self._accel.request() as grant:
                yield grant
                # Latency is evaluated only once the resource is granted.
                yield env.timeout(self._compute_ns(m, k, n))
        else:
            yield env.timeout(self._compute_ns(m, k, n))

        self._on_process_end(env, token)

        # Self-routing: the token's advance() yields the next stage, if any.
        following = token.advance()
        if following is None:
            token.pipeline_ctx.complete_tile()
        else:
            yield self.out_ports[following.component].put(token)

    def handle_command(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
        """Legacy PeInternalTxn handling.

        A ``GemmCmd`` is timed with ``_compute_ns`` using its ``m``/``k``/``n``
        and the dtype of operand ``a``; any other command falls back to
        ``run``. ``pe_txn.done`` is succeeded once the work has finished.
        """
        from kernbench.common.pe_commands import GemmCmd

        cmd = pe_txn.command

        def _execute() -> Generator:
            # Shared by both branches below; runs lazily when iterated.
            if isinstance(cmd, GemmCmd):
                yield env.timeout(
                    self._compute_ns(cmd.m, cmd.k, cmd.n, cmd.a.dtype)
                )
            else:
                yield from self.run(env, 0)

        if self._accel:
            with self._accel.request() as grant:
                yield grant
                yield from _execute()
        else:
            yield from _execute()
        pe_txn.done.succeed()

    def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator:
        """Delegate to the base ``_forward_txn``, holding the shared resource if set."""
        if not self._accel:
            yield from super()._forward_txn(env, txn)
            return
        with self._accel.request() as grant:
            yield grant
            yield from super()._forward_txn(env, txn)