"""Tile plan generators for PE pipeline (ADR-0021). Generates TilePlan with stage sequences for GEMM and Math operations. Ported from pe_accel tiling.py with stage-based plan structure. """ from __future__ import annotations from math import ceil from kernbench.components.builtin.pe_types import ( PipelinePlan, Stage, StageType, TilePlan, ) def generate_gemm_plan( M: int, K: int, N: int, tile_m: int, tile_k: int, tile_n: int, bytes_per_element: int, A_addr: int, B_addr: int, C_addr: int, pe_prefix: str, a_pinned: bool = False, b_pinned: bool = False, ) -> PipelinePlan: """Generate GEMM tile plan: M→N→K order. Each tile follows stage sequence: [DMA_READ(A)] → [DMA_READ(B)] → FETCH → GEMM → [STORE → DMA_WRITE] DMA_READ(A) skipped when a_pinned=True (operand pre-staged in TCM). DMA_READ(B) skipped when b_pinned=True. STORE + DMA_WRITE only emitted on last K-tile per (m,n) — accumulator stays in RegFile across K loop. Args: pe_prefix: e.g. "sip0.cube0.pe0" — used to build component IDs. a_pinned: A operand already resident in TCM (via prior tl.load). b_pinned: B operand already resident in TCM. """ M_tiles = max(1, ceil(M / tile_m)) K_tiles = max(1, ceil(K / tile_k)) N_tiles = max(1, ceil(N / tile_n)) bpe = bytes_per_element dma_id = f"{pe_prefix}.pe_dma" fetch_id = f"{pe_prefix}.pe_fetch_store" gemm_id = f"{pe_prefix}.pe_gemm" # math_id = f"{pe_prefix}.pe_math" # for K-accumulation if needed tiles: list[TilePlan] = [] tile_id = 0 for m in range(M_tiles): for n in range(N_tiles): c_addr = C_addr + (m * tile_m * N + n * tile_n) * bpe for k in range(K_tiles): last_k = k == K_tiles - 1 a_addr = A_addr + (m * tile_m * K + k * tile_k) * bpe b_addr = B_addr + (k * tile_k * N + n * tile_n) * bpe a_bytes = tile_m * tile_k * bpe b_bytes = tile_k * tile_n * bpe out_bytes = tile_m * tile_n * bpe stages: list[Stage] = [] # DMA READ: load A and B tiles from HBM → TCM. # Skip if the operand is already pre-staged via tl.load. if not a_pinned: stages.append(Stage( stage_type=StageType.DMA_READ, component=dma_id, params={ "src_addr": a_addr, "nbytes": a_bytes, "operand": "A", "tile_m": tile_m, "tile_k": tile_k, }, )) if not b_pinned: stages.append(Stage( stage_type=StageType.DMA_READ, component=dma_id, params={ "src_addr": b_addr, "nbytes": b_bytes, "operand": "B", "tile_k": tile_k, "tile_n": tile_n, }, )) # FETCH: TCM → Register File stages.append(Stage( stage_type=StageType.FETCH, component=fetch_id, params={ "direction": "read", "nbytes": a_bytes + b_bytes, }, )) # GEMM: MAC compute stages.append(Stage( stage_type=StageType.GEMM, component=gemm_id, params={ "m": tile_m, "k": tile_k, "n": tile_n, "is_last_k": last_k, }, )) # STORE + DMA_WRITE only on last K-tile per (m,n). The C # accumulator stays in RegFile across the K loop. if last_k: stages.append(Stage( stage_type=StageType.STORE, component=fetch_id, params={ "direction": "write", "nbytes": out_bytes, }, )) stages.append(Stage( stage_type=StageType.DMA_WRITE, component=dma_id, params={ "dst_addr": c_addr, "nbytes": out_bytes, }, )) tiles.append(TilePlan(tile_id=tile_id, stages=tuple(stages))) tile_id += 1 return PipelinePlan( tiles=tiles, m_tiles=M_tiles, k_tiles=K_tiles, n_tiles=N_tiles, ) def generate_math_plan( M: int, N: int, tile_m: int, tile_n: int, bytes_per_element: int, math_op: str, src_addr: int, dst_addr: int, pe_prefix: str, ) -> PipelinePlan: """Generate element-wise math tile plan. Each tile: DMA_READ → FETCH → MATH → STORE → DMA_WRITE """ M_tiles = max(1, ceil(M / tile_m)) N_tiles = max(1, ceil(N / tile_n)) bpe = bytes_per_element dma_id = f"{pe_prefix}.pe_dma" fetch_id = f"{pe_prefix}.pe_fetch_store" math_id = f"{pe_prefix}.pe_math" tiles: list[TilePlan] = [] tile_id = 0 for m in range(M_tiles): for n in range(N_tiles): offset = (m * tile_m * N + n * tile_n) * bpe tile_bytes = tile_m * tile_n * bpe stages = [ Stage(StageType.DMA_READ, dma_id, { "src_addr": src_addr + offset, "nbytes": tile_bytes, }), Stage(StageType.FETCH, fetch_id, { "direction": "read", "nbytes": tile_bytes, }), Stage(StageType.MATH, math_id, { "op": math_op, "num_elements": tile_m * tile_n, }), Stage(StageType.STORE, fetch_id, { "direction": "write", "nbytes": tile_bytes, }), Stage(StageType.DMA_WRITE, dma_id, { "dst_addr": dst_addr + offset, "nbytes": tile_bytes, }), ] tiles.append(TilePlan(tile_id=tile_id, stages=tuple(stages))) tile_id += 1 return PipelinePlan(tiles=tiles, m_tiles=M_tiles, n_tiles=N_tiles)