Composite GEMM: K-loop accumulator residency, pinned operands, sweep + deck

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-13 15:00:41 -07:00
parent 5accd98171
commit 83ea97b05f
11 changed files with 4219 additions and 51 deletions
+64 -3
View File
@@ -150,7 +150,11 @@ def test_gemm_plan_stage_sequence():
def test_gemm_plan_intermediate_k_no_dma_write():
"""Intermediate K-tiles don't have DMA_WRITE stage."""
"""Intermediate K-tiles don't have DMA_WRITE or STORE stage.
The C accumulator stays in RegFile across the K loop; STORE +
DMA_WRITE only fire on the last K-tile per (m,n).
"""
from kernbench.components.builtin.tiling import generate_gemm_plan
plan = generate_gemm_plan(
@@ -162,15 +166,72 @@ def test_gemm_plan_intermediate_k_no_dma_write():
)
assert len(plan.tiles) == 2
# First tile (k=0): no DMA_WRITE
# First tile (k=0): no STORE, no DMA_WRITE — accumulator stays in RegFile
t0_types = [s.stage_type for s in plan.tiles[0].stages]
assert StageType.STORE not in t0_types
assert StageType.DMA_WRITE not in t0_types
# Last tile (k=1, last_k=True): has DMA_WRITE
# Last tile (k=1, last_k=True): has both STORE and DMA_WRITE
t1_types = [s.stage_type for s in plan.tiles[1].stages]
assert StageType.STORE in t1_types
assert StageType.DMA_WRITE in t1_types
def test_gemm_plan_pinned_operand_skips_dma_read():
"""When a_pinned=True, A's per-tile DMA_READ is omitted.
Same for b_pinned. FETCH is unaffected — it still stages from TCM
into RegFile.
"""
from kernbench.components.builtin.tiling import generate_gemm_plan
# Baseline: neither pinned — both A and B get DMA_READ per tile.
base = generate_gemm_plan(
M=32, K=128, N=32, # K_tiles=2
tile_m=32, tile_k=64, tile_n=32,
bytes_per_element=2,
A_addr=0, B_addr=0x1000, C_addr=0x2000,
pe_prefix="sip0.cube0.pe0",
)
for tile in base.tiles:
operands = [s.params.get("operand") for s in tile.stages
if s.stage_type == StageType.DMA_READ]
assert operands == ["A", "B"], \
f"baseline tile should DMA_READ A and B, got {operands}"
# a_pinned: no A DMA_READ.
plan_a = generate_gemm_plan(
M=32, K=128, N=32,
tile_m=32, tile_k=64, tile_n=32,
bytes_per_element=2,
A_addr=0, B_addr=0x1000, C_addr=0x2000,
pe_prefix="sip0.cube0.pe0",
a_pinned=True,
)
for tile in plan_a.tiles:
operands = [s.params.get("operand") for s in tile.stages
if s.stage_type == StageType.DMA_READ]
assert operands == ["B"], \
f"a_pinned should leave only B DMA_READ, got {operands}"
# FETCH must still exist
assert any(s.stage_type == StageType.FETCH for s in tile.stages)
# Both pinned: no DMA_READ at all.
plan_both = generate_gemm_plan(
M=32, K=128, N=32,
tile_m=32, tile_k=64, tile_n=32,
bytes_per_element=2,
A_addr=0, B_addr=0x1000, C_addr=0x2000,
pe_prefix="sip0.cube0.pe0",
a_pinned=True, b_pinned=True,
)
for tile in plan_both.tiles:
dma_reads = [s for s in tile.stages
if s.stage_type == StageType.DMA_READ]
assert dma_reads == [], \
f"both pinned should skip all DMA_READ, got {dma_reads}"
def test_math_plan_stage_sequence():
"""Math plan has READ→FETCH→MATH→STORE→WRITE sequence."""
from kernbench.components.builtin.tiling import generate_math_plan