Composite GEMM: K-loop accumulator residency, pinned operands, sweep + deck
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -163,6 +163,8 @@ class PeSchedulerComponent(ComponentBase):
|
||||
bytes_per_element=bpe,
|
||||
A_addr=a.addr, B_addr=b.addr, C_addr=cmd.out_addr,
|
||||
pe_prefix=pp,
|
||||
a_pinned=getattr(a, "pinned", False),
|
||||
b_pinned=getattr(b, "pinned", False),
|
||||
)
|
||||
else:
|
||||
# Math composite
|
||||
|
||||
@@ -21,15 +21,22 @@ def generate_gemm_plan(
|
||||
bytes_per_element: int,
|
||||
A_addr: int, B_addr: int, C_addr: int,
|
||||
pe_prefix: str,
|
||||
a_pinned: bool = False,
|
||||
b_pinned: bool = False,
|
||||
) -> PipelinePlan:
|
||||
"""Generate GEMM tile plan: M→N→K order.
|
||||
|
||||
Each tile follows stage sequence:
|
||||
DMA_READ(A) → DMA_READ(B) → FETCH → GEMM → STORE
|
||||
On last K-tile per (m,n): → DMA_WRITE
|
||||
[DMA_READ(A)] → [DMA_READ(B)] → FETCH → GEMM → [STORE → DMA_WRITE]
|
||||
DMA_READ(A) skipped when a_pinned=True (operand pre-staged in TCM).
|
||||
DMA_READ(B) skipped when b_pinned=True.
|
||||
STORE + DMA_WRITE only emitted on last K-tile per (m,n) — accumulator
|
||||
stays in RegFile across K loop.
|
||||
|
||||
Args:
|
||||
pe_prefix: e.g. "sip0.cube0.pe0" — used to build component IDs.
|
||||
a_pinned: A operand already resident in TCM (via prior tl.load).
|
||||
b_pinned: B operand already resident in TCM.
|
||||
"""
|
||||
M_tiles = max(1, ceil(M / tile_m))
|
||||
K_tiles = max(1, ceil(K / tile_k))
|
||||
@@ -58,23 +65,26 @@ def generate_gemm_plan(
|
||||
|
||||
stages: list[Stage] = []
|
||||
|
||||
# DMA READ: load A and B tiles from HBM → TCM
|
||||
stages.append(Stage(
|
||||
stage_type=StageType.DMA_READ,
|
||||
component=dma_id,
|
||||
params={
|
||||
"src_addr": a_addr, "nbytes": a_bytes,
|
||||
"operand": "A", "tile_m": tile_m, "tile_k": tile_k,
|
||||
},
|
||||
))
|
||||
stages.append(Stage(
|
||||
stage_type=StageType.DMA_READ,
|
||||
component=dma_id,
|
||||
params={
|
||||
"src_addr": b_addr, "nbytes": b_bytes,
|
||||
"operand": "B", "tile_k": tile_k, "tile_n": tile_n,
|
||||
},
|
||||
))
|
||||
# DMA READ: load A and B tiles from HBM → TCM.
|
||||
# Skip if the operand is already pre-staged via tl.load.
|
||||
if not a_pinned:
|
||||
stages.append(Stage(
|
||||
stage_type=StageType.DMA_READ,
|
||||
component=dma_id,
|
||||
params={
|
||||
"src_addr": a_addr, "nbytes": a_bytes,
|
||||
"operand": "A", "tile_m": tile_m, "tile_k": tile_k,
|
||||
},
|
||||
))
|
||||
if not b_pinned:
|
||||
stages.append(Stage(
|
||||
stage_type=StageType.DMA_READ,
|
||||
component=dma_id,
|
||||
params={
|
||||
"src_addr": b_addr, "nbytes": b_bytes,
|
||||
"operand": "B", "tile_k": tile_k, "tile_n": tile_n,
|
||||
},
|
||||
))
|
||||
|
||||
# FETCH: TCM → Register File
|
||||
stages.append(Stage(
|
||||
@@ -96,18 +106,17 @@ def generate_gemm_plan(
|
||||
},
|
||||
))
|
||||
|
||||
# STORE: Register File → TCM
|
||||
stages.append(Stage(
|
||||
stage_type=StageType.STORE,
|
||||
component=fetch_id,
|
||||
params={
|
||||
"direction": "write",
|
||||
"nbytes": out_bytes,
|
||||
},
|
||||
))
|
||||
|
||||
# DMA WRITE: TCM → HBM (only on last K-tile)
|
||||
# STORE + DMA_WRITE only on last K-tile per (m,n). The C
|
||||
# accumulator stays in RegFile across the K loop.
|
||||
if last_k:
|
||||
stages.append(Stage(
|
||||
stage_type=StageType.STORE,
|
||||
component=fetch_id,
|
||||
params={
|
||||
"direction": "write",
|
||||
"nbytes": out_bytes,
|
||||
},
|
||||
))
|
||||
stages.append(Stage(
|
||||
stage_type=StageType.DMA_WRITE,
|
||||
component=dma_id,
|
||||
|
||||
Reference in New Issue
Block a user