Files
kernbench2/src/kernbench/components/builtin/tiling.py
T
2026-05-13 15:00:41 -07:00

186 lines
6.3 KiB
Python

"""Tile plan generators for PE pipeline (ADR-0021).
Generates TilePlan with stage sequences for GEMM and Math operations.
Ported from pe_accel tiling.py with stage-based plan structure.
"""
from __future__ import annotations
from math import ceil
from kernbench.components.builtin.pe_types import (
PipelinePlan,
Stage,
StageType,
TilePlan,
)
def generate_gemm_plan(
    M: int, K: int, N: int,
    tile_m: int, tile_k: int, tile_n: int,
    bytes_per_element: int,
    A_addr: int, B_addr: int, C_addr: int,
    pe_prefix: str,
    a_pinned: bool = False,
    b_pinned: bool = False,
) -> PipelinePlan:
    """Generate GEMM tile plan: M→N→K order.

    Each tile follows stage sequence:
        [DMA_READ(A)] → [DMA_READ(B)] → FETCH → GEMM → [STORE → DMA_WRITE]

    DMA_READ(A) skipped when a_pinned=True (operand pre-staged in TCM).
    DMA_READ(B) skipped when b_pinned=True.
    STORE + DMA_WRITE only emitted on last K-tile per (m,n) — accumulator
    stays in RegFile across K loop.

    Tile addresses are computed as row-major offsets: A is addressed as an
    M x K matrix, B as K x N and C as M x N. Each tile's address is the byte
    offset of its top-left element.

    Args:
        M, K, N: GEMM problem dimensions (must be >= 0).
        tile_m, tile_k, tile_n: tile dimensions (must be > 0).
        bytes_per_element: element size in bytes (must be > 0).
        A_addr, B_addr, C_addr: base addresses of the A, B and C operands.
        pe_prefix: e.g. "sip0.cube0.pe0" — used to build component IDs.
        a_pinned: A operand already resident in TCM (via prior tl.load).
        b_pinned: B operand already resident in TCM.

    Returns:
        PipelinePlan with one TilePlan per (m, n, k) tile. The TilePlans are
        numbered sequentially in loop order, and the plan also carries the
        per-dimension tile counts. A dimension of 0 still produces one tile
        along that axis.

    Raises:
        ValueError: if a tile size or bytes_per_element is not positive, or
            if a problem dimension is negative.
    """
    # Reject inputs that would otherwise divide by zero or silently produce
    # negative byte counts.
    if min(tile_m, tile_k, tile_n) <= 0:
        raise ValueError(
            "tile sizes must be positive, got "
            f"tile_m={tile_m}, tile_k={tile_k}, tile_n={tile_n}"
        )
    if bytes_per_element <= 0:
        raise ValueError(
            f"bytes_per_element must be positive, got {bytes_per_element}"
        )
    if min(M, K, N) < 0:
        raise ValueError(
            f"GEMM dimensions must be non-negative, got M={M}, K={K}, N={N}"
        )

    M_tiles = max(1, ceil(M / tile_m))
    K_tiles = max(1, ceil(K / tile_k))
    N_tiles = max(1, ceil(N / tile_n))
    bpe = bytes_per_element

    dma_id = f"{pe_prefix}.pe_dma"
    fetch_id = f"{pe_prefix}.pe_fetch_store"
    gemm_id = f"{pe_prefix}.pe_gemm"

    # Per-tile transfer sizes do not depend on the tile position, so they
    # are computed once.
    # NOTE(review): edge tiles are not clamped. When a dimension is not a
    # multiple of its tile size, the last tile still reports full tile_*
    # extents and byte counts. It therefore overstates traffic and compute,
    # and its address range can extend past the matrix. Also, nbytes is the
    # full 2D tile size starting at the tile's first element; whether the
    # DMA treats this as a strided copy (row stride K / N) is not visible
    # here — confirm.
    a_bytes = tile_m * tile_k * bpe
    b_bytes = tile_k * tile_n * bpe
    out_bytes = tile_m * tile_n * bpe

    tiles: list[TilePlan] = []
    tile_id = 0
    for m in range(M_tiles):
        for n in range(N_tiles):
            c_addr = C_addr + (m * tile_m * N + n * tile_n) * bpe
            for k in range(K_tiles):
                last_k = k == K_tiles - 1
                a_addr = A_addr + (m * tile_m * K + k * tile_k) * bpe
                b_addr = B_addr + (k * tile_k * N + n * tile_n) * bpe

                stages: list[Stage] = []

                # DMA READ: load A and B tiles from HBM → TCM.
                # Skip if the operand is already pre-staged via tl.load.
                if not a_pinned:
                    stages.append(Stage(
                        stage_type=StageType.DMA_READ,
                        component=dma_id,
                        params={
                            "src_addr": a_addr, "nbytes": a_bytes,
                            "operand": "A", "tile_m": tile_m, "tile_k": tile_k,
                        },
                    ))
                if not b_pinned:
                    stages.append(Stage(
                        stage_type=StageType.DMA_READ,
                        component=dma_id,
                        params={
                            "src_addr": b_addr, "nbytes": b_bytes,
                            "operand": "B", "tile_k": tile_k, "tile_n": tile_n,
                        },
                    ))

                # FETCH: TCM → Register File (both operand tiles).
                stages.append(Stage(
                    stage_type=StageType.FETCH,
                    component=fetch_id,
                    params={
                        "direction": "read",
                        "nbytes": a_bytes + b_bytes,
                    },
                ))

                # GEMM: MAC compute; is_last_k marks the end of accumulation.
                stages.append(Stage(
                    stage_type=StageType.GEMM,
                    component=gemm_id,
                    params={
                        "m": tile_m, "k": tile_k, "n": tile_n,
                        "is_last_k": last_k,
                    },
                ))

                # STORE + DMA_WRITE only on last K-tile per (m,n). The C
                # accumulator stays in RegFile across the K loop.
                if last_k:
                    stages.append(Stage(
                        stage_type=StageType.STORE,
                        component=fetch_id,
                        params={
                            "direction": "write",
                            "nbytes": out_bytes,
                        },
                    ))
                    stages.append(Stage(
                        stage_type=StageType.DMA_WRITE,
                        component=dma_id,
                        params={
                            "dst_addr": c_addr, "nbytes": out_bytes,
                        },
                    ))

                tiles.append(TilePlan(tile_id=tile_id, stages=tuple(stages)))
                tile_id += 1

    return PipelinePlan(
        tiles=tiles, m_tiles=M_tiles, k_tiles=K_tiles, n_tiles=N_tiles,
    )
def generate_math_plan(
    M: int, N: int,
    tile_m: int, tile_n: int,
    bytes_per_element: int,
    math_op: str,
    src_addr: int, dst_addr: int,
    pe_prefix: str,
) -> PipelinePlan:
    """Generate element-wise math tile plan.

    Each tile: DMA_READ → FETCH → MATH → STORE → DMA_WRITE

    Source and destination are addressed as row-major M x N matrices. Each
    tile's source and destination addresses are the byte offset of its
    top-left element, added to src_addr and dst_addr respectively.

    Args:
        M, N: matrix dimensions (must be >= 0).
        tile_m, tile_n: tile dimensions (must be > 0).
        bytes_per_element: element size in bytes (must be > 0).
        math_op: operation name forwarded to the MATH stage as "op".
        src_addr: base address of the input matrix.
        dst_addr: base address of the output matrix.
        pe_prefix: e.g. "sip0.cube0.pe0" — used to build component IDs.

    Returns:
        PipelinePlan with one TilePlan per (m, n) tile, numbered
        sequentially in row-major tile order. A dimension of 0 still
        produces one tile along that axis.

    Raises:
        ValueError: if a tile size or bytes_per_element is not positive, or
            if a dimension is negative.
    """
    # Reject inputs that would otherwise divide by zero or silently produce
    # negative byte counts.
    if min(tile_m, tile_n) <= 0:
        raise ValueError(
            f"tile sizes must be positive, got tile_m={tile_m}, tile_n={tile_n}"
        )
    if bytes_per_element <= 0:
        raise ValueError(
            f"bytes_per_element must be positive, got {bytes_per_element}"
        )
    if min(M, N) < 0:
        raise ValueError(
            f"dimensions must be non-negative, got M={M}, N={N}"
        )

    M_tiles = max(1, ceil(M / tile_m))
    N_tiles = max(1, ceil(N / tile_n))
    bpe = bytes_per_element

    dma_id = f"{pe_prefix}.pe_dma"
    fetch_id = f"{pe_prefix}.pe_fetch_store"
    math_id = f"{pe_prefix}.pe_math"

    # Tile size does not depend on tile position, so it is computed once.
    # NOTE(review): edge tiles are not clamped. A partial last tile still
    # reports the full tile_m x tile_n size.
    num_elements = tile_m * tile_n
    tile_bytes = num_elements * bpe

    tiles: list[TilePlan] = []
    tile_id = 0
    for m in range(M_tiles):
        for n in range(N_tiles):
            offset = (m * tile_m * N + n * tile_n) * bpe
            stages = (
                Stage(
                    stage_type=StageType.DMA_READ,
                    component=dma_id,
                    params={
                        "src_addr": src_addr + offset, "nbytes": tile_bytes,
                    },
                ),
                Stage(
                    stage_type=StageType.FETCH,
                    component=fetch_id,
                    params={"direction": "read", "nbytes": tile_bytes},
                ),
                Stage(
                    stage_type=StageType.MATH,
                    component=math_id,
                    params={"op": math_op, "num_elements": num_elements},
                ),
                Stage(
                    stage_type=StageType.STORE,
                    component=fetch_id,
                    params={"direction": "write", "nbytes": tile_bytes},
                ),
                Stage(
                    stage_type=StageType.DMA_WRITE,
                    component=dma_id,
                    params={
                        "dst_addr": dst_addr + offset, "nbytes": tile_bytes,
                    },
                ),
            )
            tiles.append(TilePlan(tile_id=tile_id, stages=stages))
            tile_id += 1

    return PipelinePlan(tiles=tiles, m_tiles=M_tiles, n_tiles=N_tiles)