83ea97b05f
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
186 lines
6.3 KiB
Python
186 lines
6.3 KiB
Python
"""Tile plan generators for PE pipeline (ADR-0021).
|
|
|
|
Generates TilePlan with stage sequences for GEMM and Math operations.
|
|
Ported from pe_accel tiling.py with stage-based plan structure.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
from math import ceil
|
|
|
|
from kernbench.components.builtin.pe_types import (
|
|
PipelinePlan,
|
|
Stage,
|
|
StageType,
|
|
TilePlan,
|
|
)
|
|
|
|
|
|
def generate_gemm_plan(
    M: int, K: int, N: int,
    tile_m: int, tile_k: int, tile_n: int,
    bytes_per_element: int,
    A_addr: int, B_addr: int, C_addr: int,
    pe_prefix: str,
    a_pinned: bool = False,
    b_pinned: bool = False,
) -> PipelinePlan:
    """Build the stage-based tile plan for a tiled GEMM (loop order M → N → K).

    One TilePlan is emitted per (m, n, k) tile triple. Its stages are, in
    order:

        [DMA_READ(A)] → [DMA_READ(B)] → FETCH → GEMM → [STORE → DMA_WRITE]

    * DMA_READ(A) is left out when ``a_pinned`` is true.
    * DMA_READ(B) is left out when ``b_pinned`` is true.
    * STORE and DMA_WRITE are appended only for the last K-tile of each
      (m, n) pair; earlier K-tiles end at the GEMM stage.

    Byte counts are always computed from the full tile dimensions
    (tile_m / tile_k / tile_n), including for tiles on the matrix edge.

    Args:
        M, K, N: problem dimensions (A is M×K, B is K×N, C is M×N).
        tile_m, tile_k, tile_n: tile dimensions along M, K and N.
        bytes_per_element: element size used for addresses and byte counts.
        A_addr, B_addr, C_addr: base addresses of the three operands;
            tile addresses are row-major offsets from these.
        pe_prefix: e.g. "sip0.cube0.pe0" — prefix for the component IDs
            (``.pe_dma``, ``.pe_fetch_store``, ``.pe_gemm``).
        a_pinned: skip the DMA_READ stage for operand A.
        b_pinned: skip the DMA_READ stage for operand B.

    Returns:
        PipelinePlan holding the tiles plus the M/K/N tile counts.
    """
    # At least one tile per dimension, even for a zero-sized extent.
    m_count = max(1, ceil(M / tile_m))
    k_count = max(1, ceil(K / tile_k))
    n_count = max(1, ceil(N / tile_n))
    elem_size = bytes_per_element

    dma_unit = f"{pe_prefix}.pe_dma"
    fetch_store_unit = f"{pe_prefix}.pe_fetch_store"
    gemm_unit = f"{pe_prefix}.pe_gemm"

    # Tile footprints do not depend on the tile index, so compute them once.
    a_tile_bytes = tile_m * tile_k * elem_size
    b_tile_bytes = tile_k * tile_n * elem_size
    c_tile_bytes = tile_m * tile_n * elem_size

    final_k = k_count - 1
    tiles: list[TilePlan] = []

    for m_idx in range(m_count):
        # Element offsets of this tile row inside A (row stride K) and
        # C (row stride N).
        a_row_base = m_idx * tile_m * K
        c_row_base = m_idx * tile_m * N

        for n_idx in range(n_count):
            n_base = n_idx * tile_n
            c_tile_addr = C_addr + (c_row_base + n_base) * elem_size

            for k_idx in range(k_count):
                k_base = k_idx * tile_k
                is_final_k = k_idx == final_k

                steps: list[Stage] = []

                # Operand loads via the DMA unit; each one is omitted when
                # the corresponding operand is flagged as pinned.
                if not a_pinned:
                    steps.append(Stage(
                        stage_type=StageType.DMA_READ,
                        component=dma_unit,
                        params={
                            "src_addr": A_addr + (a_row_base + k_base) * elem_size,
                            "nbytes": a_tile_bytes,
                            "operand": "A",
                            "tile_m": tile_m,
                            "tile_k": tile_k,
                        },
                    ))
                if not b_pinned:
                    steps.append(Stage(
                        stage_type=StageType.DMA_READ,
                        component=dma_unit,
                        params={
                            "src_addr": B_addr + (k_base * N + n_base) * elem_size,
                            "nbytes": b_tile_bytes,
                            "operand": "B",
                            "tile_k": tile_k,
                            "tile_n": tile_n,
                        },
                    ))

                # FETCH covers both operand tiles in a single stage.
                steps.append(Stage(
                    stage_type=StageType.FETCH,
                    component=fetch_store_unit,
                    params={
                        "direction": "read",
                        "nbytes": a_tile_bytes + b_tile_bytes,
                    },
                ))

                # GEMM: the compute stage for this (m, n, k) tile.
                steps.append(Stage(
                    stage_type=StageType.GEMM,
                    component=gemm_unit,
                    params={
                        "m": tile_m,
                        "k": tile_k,
                        "n": tile_n,
                        "is_last_k": is_final_k,
                    },
                ))

                # The output tile is written back once per (m, n), after
                # the final K-tile has been accumulated.
                if is_final_k:
                    steps.extend((
                        Stage(
                            stage_type=StageType.STORE,
                            component=fetch_store_unit,
                            params={
                                "direction": "write",
                                "nbytes": c_tile_bytes,
                            },
                        ),
                        Stage(
                            stage_type=StageType.DMA_WRITE,
                            component=dma_unit,
                            params={
                                "dst_addr": c_tile_addr,
                                "nbytes": c_tile_bytes,
                            },
                        ),
                    ))

                # Tile IDs are sequential in emission order, so the current
                # list length is the next ID.
                tiles.append(TilePlan(tile_id=len(tiles), stages=tuple(steps)))

    return PipelinePlan(
        tiles=tiles,
        m_tiles=m_count,
        k_tiles=k_count,
        n_tiles=n_count,
    )
|
|
|
|
|
|
def generate_math_plan(
    M: int, N: int,
    tile_m: int, tile_n: int,
    bytes_per_element: int,
    math_op: str,
    src_addr: int, dst_addr: int,
    pe_prefix: str,
) -> PipelinePlan:
    """Build the stage-based tile plan for an element-wise math operation.

    The M×N input is walked tile by tile (M outer, N inner). Every tile
    gets the same five-stage sequence:

        DMA_READ → FETCH → MATH → STORE → DMA_WRITE

    Source and destination use the same row-major byte offset per tile,
    and byte/element counts are always based on the full tile_m × tile_n
    tile.

    Args:
        M, N: dimensions of the operand.
        tile_m, tile_n: tile dimensions along M and N.
        bytes_per_element: element size used for addresses and byte counts.
        math_op: operation name, forwarded as the MATH stage's "op" param.
        src_addr, dst_addr: base addresses of the input and the output.
        pe_prefix: prefix for the component IDs (``.pe_dma``,
            ``.pe_fetch_store``, ``.pe_math``).

    Returns:
        PipelinePlan holding the tiles plus the M/N tile counts.
    """
    # At least one tile per dimension, even for a zero-sized extent.
    m_count = max(1, ceil(M / tile_m))
    n_count = max(1, ceil(N / tile_n))
    elem_size = bytes_per_element

    dma_unit = f"{pe_prefix}.pe_dma"
    fetch_store_unit = f"{pe_prefix}.pe_fetch_store"
    math_unit = f"{pe_prefix}.pe_math"

    # Per-tile sizes are identical for every tile, so compute them once.
    elems_per_tile = tile_m * tile_n
    tile_nbytes = tile_m * tile_n * elem_size

    tiles: list[TilePlan] = []

    for m_idx in range(m_count):
        row_base = m_idx * tile_m * N  # element offset of this tile row
        for n_idx in range(n_count):
            byte_offset = (row_base + n_idx * tile_n) * elem_size

            sequence = (
                Stage(StageType.DMA_READ, dma_unit, {
                    "src_addr": src_addr + byte_offset,
                    "nbytes": tile_nbytes,
                }),
                Stage(StageType.FETCH, fetch_store_unit, {
                    "direction": "read",
                    "nbytes": tile_nbytes,
                }),
                Stage(StageType.MATH, math_unit, {
                    "op": math_op,
                    "num_elements": elems_per_tile,
                }),
                Stage(StageType.STORE, fetch_store_unit, {
                    "direction": "write",
                    "nbytes": tile_nbytes,
                }),
                Stage(StageType.DMA_WRITE, dma_unit, {
                    "dst_addr": dst_addr + byte_offset,
                    "nbytes": tile_nbytes,
                }),
            )

            # Tile IDs are sequential in emission order, so the current
            # list length is the next ID.
            tiles.append(TilePlan(tile_id=len(tiles), stages=sequence))

    return PipelinePlan(tiles=tiles, m_tiles=m_count, n_tiles=n_count)
|