from __future__ import annotations

from collections.abc import Generator
from typing import TYPE_CHECKING, Any

import simpy

from kernbench.components.base import ComponentBase

if TYPE_CHECKING:
    from kernbench.common.pe_commands import PeInternalTxn
    from kernbench.components.context import ComponentContext
    from kernbench.topology.types import Node


class PeSchedulerComponent(ComponentBase):
    """PE_SCHEDULER: the single dispatcher inside a PE (ADR-0014 D1).

    Messages arriving in the inbox are handled as follows:

    - ``PeInternalTxn`` carrying ``DmaReadCmd`` / ``DmaWriteCmd`` → ``pe_dma``
    - ``PeInternalTxn`` carrying ``GemmCmd`` → ``pe_gemm``
    - ``PeInternalTxn`` carrying ``MathCmd`` → ``pe_math``
    - ``PeInternalTxn`` carrying ``CompositeCmd`` → expanded here into a
      sequence of DMA / compute sub-transactions (ADR-0014 D3.2)
    - ``PeInternalTxn`` carrying any other command → completed immediately
    - anything that is not a ``PeInternalTxn`` → inherited ``_forward_txn()``

    The scheduler's ``overhead_ns`` (a node attribute, default 0) is charged
    once per dispatched ``PeInternalTxn`` before it is routed.

    Composite GEMM is split into tiles of at most ``TILE_K`` x ``TILE_N``
    elements of the ``b`` operand; each tile goes through
    DMA_READ(b_tile) → GEMM → DMA_WRITE(out_tile).
    """

    # Scheduler tile dimensions (ADR-0014 D3.2).
    # TILE_M is declared for completeness; the methods in this class only
    # read TILE_K and TILE_N.
    TILE_M = 32
    TILE_K = 64
    TILE_N = 32

    # Maps a command class to the suffix of the engine port it is sent to.
    # Filled lazily by _ensure_dispatch_table(), because the command classes
    # are only imported at runtime inside methods. To support a new engine,
    # add one entry there (e.g. ConvCmd: "pe_conv").
    _CMD_DISPATCH: dict[type, str] = {}

    @classmethod
    def _ensure_dispatch_table(cls) -> None:
        """Populate ``_CMD_DISPATCH`` the first time it is needed.

        Does nothing when the table already has entries.
        """
        if not cls._CMD_DISPATCH:
            from kernbench.common.pe_commands import (
                DmaReadCmd,
                DmaWriteCmd,
                GemmCmd,
                MathCmd,
            )

            cls._CMD_DISPATCH = {
                DmaReadCmd: "pe_dma",
                DmaWriteCmd: "pe_dma",
                GemmCmd: "pe_gemm",
                MathCmd: "pe_math",
            }

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        """Create the scheduler for ``node``.

        The PE prefix is the node id with its last dot-separated segment
        removed; it is used to build the names of the sibling engine ports.
        """
        super().__init__(node, ctx)
        self._pe_prefix = node.id.rsplit(".", 1)[0]
        self._ensure_dispatch_table()

    def _port_for(self, engine_suffix: str) -> Any:
        """Return the output port named ``"<pe_prefix>.<engine_suffix>"``."""
        return self.out_ports[f"{self._pe_prefix}.{engine_suffix}"]

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Model the scheduler overhead as a single timeout.

        ``nbytes`` is accepted for interface compatibility and is not used.
        """
        delay_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(delay_ns)

    def _worker(self, env: simpy.Environment) -> Generator:
        """Inbox loop: dispatch PE transactions, forward everything else.

        Each ``PeInternalTxn`` is handled in its own simpy process, so the
        loop does not wait for it. Other messages are forwarded inline, so
        the loop resumes only once ``_forward_txn`` has finished.
        """
        from kernbench.common.pe_commands import PeInternalTxn

        while True:
            incoming: Any = yield self._inbox.get()
            if not isinstance(incoming, PeInternalTxn):
                yield from self._forward_txn(env, incoming)
                continue
            env.process(self._dispatch(env, incoming))

    def _dispatch(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
        """Charge the scheduler overhead, then route ``pe_txn`` by command type.

        The dispatch table is consulted with the exact type of the command.
        A ``CompositeCmd`` is expanded by ``_dispatch_composite``. Any other
        command is treated as unknown and its ``done`` event is triggered
        immediately.
        """
        from kernbench.common.pe_commands import CompositeCmd

        # Scheduler overhead comes first, regardless of the command type.
        yield from self.run(env, 0)

        command = pe_txn.command
        suffix = self._CMD_DISPATCH.get(type(command))

        if suffix is not None:
            # Plain command: hand the transaction to its engine unchanged.
            yield self._port_for(suffix).put(pe_txn)
        elif isinstance(command, CompositeCmd):
            # Composite: not a simple forward, expanded into sub-transactions.
            yield from self._dispatch_composite(env, pe_txn)
        else:
            # Unknown command: signal completion right away.
            pe_txn.done.succeed()

    def _dispatch_composite(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
        """Expand a composite command (ADR-0014 D3.2).

        A command with ``op == "gemm"`` and a ``b`` operand goes through the
        tiled GEMM pipeline; every other composite goes through the
        sequential compute + DMA_WRITE path.
        """
        from kernbench.common.pe_commands import CompositeCmd

        cmd = pe_txn.command
        assert isinstance(cmd, CompositeCmd)

        is_tiled_gemm = cmd.op == "gemm" and cmd.b is not None
        pipeline = self._pipeline_gemm if is_tiled_gemm else self._pipeline_math
        yield from pipeline(env, pe_txn, cmd)

    def _pipeline_gemm(self, env: simpy.Environment, pe_txn: PeInternalTxn, cmd: Any) -> Generator:
        """Run a composite GEMM as a sequence of per-tile sub-transactions.

        Per the original notes, operand ``a`` is already in TCM (loaded via
        tl.load) and operand ``b`` is an HBM reference (via tl.ref). ``b`` is
        walked in tiles of at most TILE_K x TILE_N elements, and for each
        tile the scheduler issues:

            DMA_READ(b_tile) -> GEMM(a, b_tile) -> DMA_WRITE(out_tile)

        The DMA_WRITE of a tile is not awaited before the next tile's
        DMA_READ and GEMM are issued; it is awaited before the next tile's
        compute is considered finished, and once more after the last tile.

        NOTE(review): ADR-0014 D3.2 describes the overlap as
        READ(t+1) || COMPUTE(t) || WRITE(t-1). As written here, the read for
        tile t+1 is only issued after compute for tile t has completed —
        confirm whether that is intended.

        On completion, ``pe_txn.result_data`` receives ``"dma_ns"`` (time
        spent waiting on reads plus the final write) and ``"compute_ns"``
        (time spent waiting on compute, which includes waiting for the
        previous tile's write), and ``pe_txn.done`` is triggered.
        """
        from kernbench.common.pe_commands import (
            DmaReadCmd,
            DmaWriteCmd,
            GemmCmd,
            PeInternalTxn as PeTxn,
            TensorHandle,
        )

        prefix = self._pe_prefix
        a = cmd.a  # already in TCM
        b = cmd.b  # HBM reference (via tl.ref)

        M, K_a = a.shape[-2], a.shape[-1]
        K_b, N = b.shape[-2], b.shape[-1]
        dtype = a.dtype

        # Element size is derived from b's byte count over its last two
        # dimensions; falls back to 2 bytes when b has no elements.
        b_elems = K_b * N
        elem_bytes = b.nbytes // b_elems if b_elems > 0 else 2

        # Number of tiles along K and N (ceiling division, at least one each).
        tiles_k = max(1, (K_a + self.TILE_K - 1) // self.TILE_K)
        tiles_n = max(1, (N + self.TILE_N - 1) // self.TILE_N)
        total_tiles = tiles_k * tiles_n

        last_compute_done = None
        last_write_done = None
        dma_ns_acc = 0.0
        compute_ns_acc = 0.0

        for tile_idx in range(total_tiles):
            # Tiles are visited K-major: all N tiles of one K slab, then the next.
            k_idx, n_idx = divmod(tile_idx, tiles_n)
            k_off = k_idx * self.TILE_K
            n_off = n_idx * self.TILE_N
            cur_k = min(self.TILE_K, K_a - k_off)
            cur_n = min(self.TILE_N, N - n_off)
            b_tile_bytes = cur_k * cur_n * elem_bytes
            out_tile_bytes = M * cur_n * elem_bytes

            # --- Stage 1: DMA_READ of this b tile from HBM ---
            read_done = env.event()
            b_tile_addr = b.addr + (k_off * N + n_off) * elem_bytes
            b_tile = TensorHandle(
                id=f"b_tile_{tile_idx}",
                addr=b_tile_addr,
                shape=(cur_k, cur_n),
                dtype=dtype,
                nbytes=b_tile_bytes,
            )
            read_txn = PeTxn(
                command=DmaReadCmd(handle=b_tile, src_addr=b_tile_addr, nbytes=b_tile_bytes),
                done=read_done,
                pe_prefix=prefix,
            )

            started = env.now
            yield self._port_for("pe_dma").put(read_txn)

            # Previous tile's compute must be finished before this tile's
            # compute starts. NOTE(review): that event was already awaited at
            # the end of the previous iteration, so this is not expected to
            # block.
            if last_compute_done is not None:
                yield last_compute_done

            # Wait for this tile's read to land.
            yield read_done
            dma_ns_acc += env.now - started

            # --- Stage 2: GEMM on this tile ---
            compute_done = env.event()
            out_tile = TensorHandle(
                id=f"out_tile_{tile_idx}",
                addr=0,
                shape=(M, cur_n),
                dtype=dtype,
                nbytes=out_tile_bytes,
            )
            compute_txn = PeTxn(
                command=GemmCmd(a=a, b=b_tile, out=out_tile, m=M, k=cur_k, n=cur_n),
                done=compute_done,
                pe_prefix=prefix,
            )

            started = env.now
            yield self._port_for("pe_gemm").put(compute_txn)

            # Serialize DMA_WRITEs: the previous tile's write must finish
            # before this tile moves on to its own write.
            if last_write_done is not None:
                yield last_write_done

            # Wait for this tile's compute.
            yield compute_done
            compute_ns_acc += env.now - started
            last_compute_done = compute_done

            # --- Stage 3: DMA_WRITE of the output tile to HBM ---
            write_done = env.event()
            out_tile_pa = cmd.out_addr + n_off * elem_bytes
            write_txn = PeTxn(
                command=DmaWriteCmd(handle=out_tile, dst_addr=out_tile_pa, nbytes=out_tile_bytes),
                done=write_done,
                pe_prefix=prefix,
            )

            started = env.now
            yield self._port_for("pe_dma").put(write_txn)
            # Not awaited here: the wait happens in the next iteration or
            # after the loop.
            last_write_done = write_done

        # Drain the write of the last tile.
        if last_write_done is not None:
            started = env.now
            yield last_write_done
            dma_ns_acc += env.now - started

        pe_txn.result_data["dma_ns"] = dma_ns_acc
        pe_txn.result_data["compute_ns"] = compute_ns_acc
        pe_txn.done.succeed()

    def _pipeline_math(self, env: simpy.Environment, pe_txn: PeInternalTxn, cmd: Any) -> Generator:
        """Run a non-GEMM composite: one MATH op, then one DMA_WRITE (no tiling).

        The math op is applied in place on ``cmd.a`` (``"identity"`` when
        ``cmd.math_op`` is falsy). The result is then written to
        ``cmd.out_addr``, and ``pe_txn.done`` is triggered once the write has
        completed.
        """
        from kernbench.common.pe_commands import (
            DmaWriteCmd,
            MathCmd,
            PeInternalTxn as PeTxn,
        )

        prefix = self._pe_prefix

        # Step 1: compute on the MATH engine and wait for it.
        math_done = env.event()
        math_txn = PeTxn(
            command=MathCmd(
                op=cmd.math_op or "identity",
                inputs=(cmd.a,),
                out=cmd.a,
            ),
            done=math_done,
            pe_prefix=prefix,
        )
        yield self._port_for("pe_math").put(math_txn)
        yield math_done

        # Step 2: write the result back to HBM and wait for it.
        store_done = env.event()
        store_txn = PeTxn(
            command=DmaWriteCmd(handle=cmd.a, dst_addr=cmd.out_addr, nbytes=cmd.out_nbytes),
            done=store_done,
            pe_prefix=prefix,
        )
        yield self._port_for("pe_dma").put(store_txn)
        yield store_done

        pe_txn.done.succeed()