Composite GEMM: K-loop accumulator residency, pinned operands, sweep + deck

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-13 15:00:41 -07:00
parent 5accd98171
commit 83ea97b05f
11 changed files with 4219 additions and 51 deletions
+39 -30
View File
@@ -21,15 +21,22 @@ def generate_gemm_plan(
bytes_per_element: int,
A_addr: int, B_addr: int, C_addr: int,
pe_prefix: str,
a_pinned: bool = False,
b_pinned: bool = False,
) -> PipelinePlan:
"""Generate GEMM tile plan: M→N→K order.
Each tile follows stage sequence:
DMA_READ(A) → DMA_READ(B) → FETCH → GEMM → STORE
On last K-tile per (m,n): → DMA_WRITE
[DMA_READ(A)] → [DMA_READ(B)] → FETCH → GEMM → [STORE → DMA_WRITE]
DMA_READ(A) skipped when a_pinned=True (operand pre-staged in TCM).
DMA_READ(B) skipped when b_pinned=True.
STORE + DMA_WRITE only emitted on last K-tile per (m,n) — accumulator
stays in RegFile across K loop.
Args:
pe_prefix: e.g. "sip0.cube0.pe0" — used to build component IDs.
a_pinned: A operand already resident in TCM (via prior tl.load).
b_pinned: B operand already resident in TCM.
"""
M_tiles = max(1, ceil(M / tile_m))
K_tiles = max(1, ceil(K / tile_k))
@@ -58,23 +65,26 @@ def generate_gemm_plan(
stages: list[Stage] = []
# DMA READ: load A and B tiles from HBM → TCM
stages.append(Stage(
stage_type=StageType.DMA_READ,
component=dma_id,
params={
"src_addr": a_addr, "nbytes": a_bytes,
"operand": "A", "tile_m": tile_m, "tile_k": tile_k,
},
))
stages.append(Stage(
stage_type=StageType.DMA_READ,
component=dma_id,
params={
"src_addr": b_addr, "nbytes": b_bytes,
"operand": "B", "tile_k": tile_k, "tile_n": tile_n,
},
))
# DMA READ: load A and B tiles from HBM → TCM.
# Skip if the operand is already pre-staged via tl.load.
if not a_pinned:
stages.append(Stage(
stage_type=StageType.DMA_READ,
component=dma_id,
params={
"src_addr": a_addr, "nbytes": a_bytes,
"operand": "A", "tile_m": tile_m, "tile_k": tile_k,
},
))
if not b_pinned:
stages.append(Stage(
stage_type=StageType.DMA_READ,
component=dma_id,
params={
"src_addr": b_addr, "nbytes": b_bytes,
"operand": "B", "tile_k": tile_k, "tile_n": tile_n,
},
))
# FETCH: TCM → Register File
stages.append(Stage(
@@ -96,18 +106,17 @@ def generate_gemm_plan(
},
))
# STORE: Register File → TCM
stages.append(Stage(
stage_type=StageType.STORE,
component=fetch_id,
params={
"direction": "write",
"nbytes": out_bytes,
},
))
# DMA WRITE: TCM → HBM (only on last K-tile)
# STORE + DMA_WRITE only on last K-tile per (m,n). The C
# accumulator stays in RegFile across the K loop.
if last_k:
stages.append(Stage(
stage_type=StageType.STORE,
component=fetch_id,
params={
"direction": "write",
"nbytes": out_bytes,
},
))
stages.append(Stage(
stage_type=StageType.DMA_WRITE,
component=dma_id,