Composite GEMM: K-loop accumulator residency, pinned operands, sweep + deck
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -150,7 +150,11 @@ def test_gemm_plan_stage_sequence():
|
||||
|
||||
|
||||
def test_gemm_plan_intermediate_k_no_dma_write():
|
||||
"""Intermediate K-tiles don't have DMA_WRITE stage."""
|
||||
"""Intermediate K-tiles don't have DMA_WRITE or STORE stage.
|
||||
|
||||
The C accumulator stays in RegFile across the K loop; STORE +
|
||||
DMA_WRITE only fire on the last K-tile per (m,n).
|
||||
"""
|
||||
from kernbench.components.builtin.tiling import generate_gemm_plan
|
||||
|
||||
plan = generate_gemm_plan(
|
||||
@@ -162,15 +166,72 @@ def test_gemm_plan_intermediate_k_no_dma_write():
|
||||
)
|
||||
assert len(plan.tiles) == 2
|
||||
|
||||
# First tile (k=0): no DMA_WRITE
|
||||
# First tile (k=0): no STORE, no DMA_WRITE — accumulator stays in RegFile
|
||||
t0_types = [s.stage_type for s in plan.tiles[0].stages]
|
||||
assert StageType.STORE not in t0_types
|
||||
assert StageType.DMA_WRITE not in t0_types
|
||||
|
||||
# Last tile (k=1, last_k=True): has DMA_WRITE
|
||||
# Last tile (k=1, last_k=True): has both STORE and DMA_WRITE
|
||||
t1_types = [s.stage_type for s in plan.tiles[1].stages]
|
||||
assert StageType.STORE in t1_types
|
||||
assert StageType.DMA_WRITE in t1_types
|
||||
|
||||
|
||||
def test_gemm_plan_pinned_operand_skips_dma_read():
|
||||
"""When a_pinned=True, A's per-tile DMA_READ is omitted.
|
||||
|
||||
Same for b_pinned. FETCH is unaffected — it still stages from TCM
|
||||
into RegFile.
|
||||
"""
|
||||
from kernbench.components.builtin.tiling import generate_gemm_plan
|
||||
|
||||
# Baseline: neither pinned — both A and B get DMA_READ per tile.
|
||||
base = generate_gemm_plan(
|
||||
M=32, K=128, N=32, # K_tiles=2
|
||||
tile_m=32, tile_k=64, tile_n=32,
|
||||
bytes_per_element=2,
|
||||
A_addr=0, B_addr=0x1000, C_addr=0x2000,
|
||||
pe_prefix="sip0.cube0.pe0",
|
||||
)
|
||||
for tile in base.tiles:
|
||||
operands = [s.params.get("operand") for s in tile.stages
|
||||
if s.stage_type == StageType.DMA_READ]
|
||||
assert operands == ["A", "B"], \
|
||||
f"baseline tile should DMA_READ A and B, got {operands}"
|
||||
|
||||
# a_pinned: no A DMA_READ.
|
||||
plan_a = generate_gemm_plan(
|
||||
M=32, K=128, N=32,
|
||||
tile_m=32, tile_k=64, tile_n=32,
|
||||
bytes_per_element=2,
|
||||
A_addr=0, B_addr=0x1000, C_addr=0x2000,
|
||||
pe_prefix="sip0.cube0.pe0",
|
||||
a_pinned=True,
|
||||
)
|
||||
for tile in plan_a.tiles:
|
||||
operands = [s.params.get("operand") for s in tile.stages
|
||||
if s.stage_type == StageType.DMA_READ]
|
||||
assert operands == ["B"], \
|
||||
f"a_pinned should leave only B DMA_READ, got {operands}"
|
||||
# FETCH must still exist
|
||||
assert any(s.stage_type == StageType.FETCH for s in tile.stages)
|
||||
|
||||
# Both pinned: no DMA_READ at all.
|
||||
plan_both = generate_gemm_plan(
|
||||
M=32, K=128, N=32,
|
||||
tile_m=32, tile_k=64, tile_n=32,
|
||||
bytes_per_element=2,
|
||||
A_addr=0, B_addr=0x1000, C_addr=0x2000,
|
||||
pe_prefix="sip0.cube0.pe0",
|
||||
a_pinned=True, b_pinned=True,
|
||||
)
|
||||
for tile in plan_both.tiles:
|
||||
dma_reads = [s for s in tile.stages
|
||||
if s.stage_type == StageType.DMA_READ]
|
||||
assert dma_reads == [], \
|
||||
f"both pinned should skip all DMA_READ, got {dma_reads}"
|
||||
|
||||
|
||||
def test_math_plan_stage_sequence():
|
||||
"""Math plan has READ→FETCH→MATH→STORE→WRITE sequence."""
|
||||
from kernbench.components.builtin.tiling import generate_math_plan
|
||||
|
||||
Reference in New Issue
Block a user