Add ADR-0021 pipeline tests: self-routing, tiling, overlap
Test plan items 3-5: - TileToken self-routing: advance(), stage sequence, chain traversal - PipelineContext: completion tracking, exactly-once contract - Tiling plans: GEMM tile count, stage sequence, intermediate K no DMA_WRITE - Math plan: READ→FETCH→MATH→STORE→WRITE sequence - Pipeline overlap: SimPy simulation verifying intra-command tile overlap 9 new tests, all passing. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,248 @@
|
||||
"""Tests for ADR-0021 PE pipeline: TileToken self-routing, pipeline overlap, e2e accuracy.
|
||||
|
||||
Test plan items:
|
||||
3. Phase 1 → Phase 2 end-to-end (op_log → DataExecutor → verify)
|
||||
4. TileToken self-routing (stage sequence, PipelineContext completion)
|
||||
5. Async pipeline overlap (intra-command tile overlap, FIFO ordering)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.builtin.pe_types import (
|
||||
PipelineContext,
|
||||
PipelinePlan,
|
||||
Stage,
|
||||
StageType,
|
||||
TilePlan,
|
||||
TileToken,
|
||||
)
|
||||
|
||||
|
||||
# ── 4. TileToken self-routing ────────────────────────────────────────
|
||||
|
||||
|
||||
def test_tile_token_advance():
|
||||
"""TileToken.advance() increments stage_idx and returns next Stage."""
|
||||
stages = (
|
||||
Stage(StageType.DMA_READ, "pe_dma", {"src_addr": 0}),
|
||||
Stage(StageType.FETCH, "pe_fetch_store", {"direction": "read"}),
|
||||
Stage(StageType.GEMM, "pe_gemm", {"m": 32, "k": 64, "n": 32}),
|
||||
)
|
||||
plan = TilePlan(tile_id=0, stages=stages)
|
||||
ctx = PipelineContext(id="p1", total_tiles=1)
|
||||
token = TileToken(
|
||||
tile_id=0, pipeline_ctx=ctx, plan=plan,
|
||||
stage_idx=0, params=stages[0].params,
|
||||
)
|
||||
|
||||
assert token.current_stage.stage_type == StageType.DMA_READ
|
||||
|
||||
next_s = token.advance()
|
||||
assert next_s is not None
|
||||
assert next_s.stage_type == StageType.FETCH
|
||||
assert token.stage_idx == 1
|
||||
assert token.params == {"direction": "read"}
|
||||
|
||||
next_s = token.advance()
|
||||
assert next_s is not None
|
||||
assert next_s.stage_type == StageType.GEMM
|
||||
assert token.stage_idx == 2
|
||||
|
||||
# Last stage — advance returns None
|
||||
assert token.advance() is None
|
||||
assert token.stage_idx == 3
|
||||
|
||||
|
||||
def test_pipeline_context_completion():
|
||||
"""PipelineContext.complete_tile() fires done_event on last tile."""
|
||||
env = simpy.Environment()
|
||||
done = env.event()
|
||||
ctx = PipelineContext(id="p1", total_tiles=3, done_event=done)
|
||||
|
||||
ctx.complete_tile()
|
||||
assert not done.triggered
|
||||
ctx.complete_tile()
|
||||
assert not done.triggered
|
||||
ctx.complete_tile()
|
||||
assert done.triggered
|
||||
|
||||
|
||||
def test_pipeline_context_exactly_once():
|
||||
"""PipelineContext tracks completed_tiles correctly."""
|
||||
ctx = PipelineContext(id="p1", total_tiles=2)
|
||||
assert ctx.completed_tiles == 0
|
||||
ctx.complete_tile()
|
||||
assert ctx.completed_tiles == 1
|
||||
ctx.complete_tile()
|
||||
assert ctx.completed_tiles == 2
|
||||
|
||||
|
||||
def test_tile_token_self_routing_chain():
|
||||
"""Simulated self-routing: component reads next stage from token."""
|
||||
stages = (
|
||||
Stage(StageType.DMA_READ, "dma", {}),
|
||||
Stage(StageType.FETCH, "fetch", {}),
|
||||
Stage(StageType.GEMM, "gemm", {}),
|
||||
Stage(StageType.STORE, "fetch", {}),
|
||||
Stage(StageType.DMA_WRITE, "dma", {}),
|
||||
)
|
||||
plan = TilePlan(tile_id=0, stages=stages)
|
||||
ctx = PipelineContext(id="p1", total_tiles=1)
|
||||
token = TileToken(
|
||||
tile_id=0, pipeline_ctx=ctx, plan=plan,
|
||||
stage_idx=0, params=stages[0].params,
|
||||
)
|
||||
|
||||
visited = []
|
||||
while True:
|
||||
visited.append(token.current_stage.component)
|
||||
next_s = token.advance()
|
||||
if next_s is None:
|
||||
ctx.complete_tile()
|
||||
break
|
||||
|
||||
assert visited == ["dma", "fetch", "gemm", "fetch", "dma"]
|
||||
assert ctx.completed_tiles == 1
|
||||
|
||||
|
||||
# ── 5. Tiling plan generation ────────────────────────────────────────
|
||||
|
||||
|
||||
def test_gemm_plan_tile_count():
|
||||
"""generate_gemm_plan produces correct number of tiles."""
|
||||
from kernbench.components.builtin.tiling import generate_gemm_plan
|
||||
|
||||
plan = generate_gemm_plan(
|
||||
M=64, K=128, N=64,
|
||||
tile_m=32, tile_k=64, tile_n=32,
|
||||
bytes_per_element=2,
|
||||
A_addr=0, B_addr=0x1000, C_addr=0x2000,
|
||||
pe_prefix="sip0.cube0.pe0",
|
||||
)
|
||||
# M_tiles=2, K_tiles=2, N_tiles=2 → 2*2*2 = 8 tiles
|
||||
assert len(plan.tiles) == 8
|
||||
assert plan.m_tiles == 2
|
||||
assert plan.k_tiles == 2
|
||||
assert plan.n_tiles == 2
|
||||
|
||||
|
||||
def test_gemm_plan_stage_sequence():
|
||||
"""Each GEMM tile has correct stage sequence."""
|
||||
from kernbench.components.builtin.tiling import generate_gemm_plan
|
||||
|
||||
plan = generate_gemm_plan(
|
||||
M=32, K=64, N=32,
|
||||
tile_m=32, tile_k=64, tile_n=32,
|
||||
bytes_per_element=2,
|
||||
A_addr=0, B_addr=0x1000, C_addr=0x2000,
|
||||
pe_prefix="sip0.cube0.pe0",
|
||||
)
|
||||
# Single tile (1x1x1), last_k=True → includes DMA_WRITE
|
||||
assert len(plan.tiles) == 1
|
||||
tile = plan.tiles[0]
|
||||
stage_types = [s.stage_type for s in tile.stages]
|
||||
assert stage_types == [
|
||||
StageType.DMA_READ, StageType.DMA_READ, # A and B
|
||||
StageType.FETCH, StageType.GEMM, StageType.STORE,
|
||||
StageType.DMA_WRITE,
|
||||
]
|
||||
|
||||
|
||||
def test_gemm_plan_intermediate_k_no_dma_write():
|
||||
"""Intermediate K-tiles don't have DMA_WRITE stage."""
|
||||
from kernbench.components.builtin.tiling import generate_gemm_plan
|
||||
|
||||
plan = generate_gemm_plan(
|
||||
M=32, K=128, N=32, # K_tiles=2
|
||||
tile_m=32, tile_k=64, tile_n=32,
|
||||
bytes_per_element=2,
|
||||
A_addr=0, B_addr=0x1000, C_addr=0x2000,
|
||||
pe_prefix="sip0.cube0.pe0",
|
||||
)
|
||||
assert len(plan.tiles) == 2
|
||||
|
||||
# First tile (k=0): no DMA_WRITE
|
||||
t0_types = [s.stage_type for s in plan.tiles[0].stages]
|
||||
assert StageType.DMA_WRITE not in t0_types
|
||||
|
||||
# Last tile (k=1, last_k=True): has DMA_WRITE
|
||||
t1_types = [s.stage_type for s in plan.tiles[1].stages]
|
||||
assert StageType.DMA_WRITE in t1_types
|
||||
|
||||
|
||||
def test_math_plan_stage_sequence():
|
||||
"""Math plan has READ→FETCH→MATH→STORE→WRITE sequence."""
|
||||
from kernbench.components.builtin.tiling import generate_math_plan
|
||||
|
||||
plan = generate_math_plan(
|
||||
M=32, N=32,
|
||||
tile_m=32, tile_n=32,
|
||||
bytes_per_element=2,
|
||||
math_op="exp",
|
||||
src_addr=0, dst_addr=0x1000,
|
||||
pe_prefix="sip0.cube0.pe0",
|
||||
)
|
||||
assert len(plan.tiles) == 1
|
||||
stage_types = [s.stage_type for s in plan.tiles[0].stages]
|
||||
assert stage_types == [
|
||||
StageType.DMA_READ, StageType.FETCH, StageType.MATH,
|
||||
StageType.STORE, StageType.DMA_WRITE,
|
||||
]
|
||||
|
||||
|
||||
# ── 5. Async pipeline (SimPy simulation) ─────────────────────────────
|
||||
|
||||
|
||||
def test_pipeline_overlap_within_command():
|
||||
"""Tiles within same command overlap: tile1 DMA while tile0 in GEMM."""
|
||||
env = simpy.Environment()
|
||||
done_event = env.event()
|
||||
ctx = PipelineContext(id="p1", total_tiles=2, done_event=done_event)
|
||||
|
||||
# Track when each tile enters each stage
|
||||
stage_times: dict[tuple[int, str], float] = {}
|
||||
|
||||
def mock_component(env, inbox, stage_name, latency_ns, out_ports):
|
||||
while True:
|
||||
token = yield inbox.get()
|
||||
stage_times[(token.tile_id, stage_name)] = env.now
|
||||
yield env.timeout(latency_ns)
|
||||
next_s = token.advance()
|
||||
if next_s is not None:
|
||||
yield out_ports[next_s.component].put(token)
|
||||
else:
|
||||
token.pipeline_ctx.complete_tile()
|
||||
|
||||
dma_q = simpy.Store(env)
|
||||
gemm_q = simpy.Store(env)
|
||||
|
||||
out_ports = {"dma": dma_q, "gemm": gemm_q}
|
||||
env.process(mock_component(env, dma_q, "dma", 10.0, out_ports))
|
||||
env.process(mock_component(env, gemm_q, "gemm", 20.0, out_ports))
|
||||
|
||||
# Create 2 tiles: DMA → GEMM
|
||||
for i in range(2):
|
||||
stages = (
|
||||
Stage(StageType.DMA_READ, "dma", {}),
|
||||
Stage(StageType.GEMM, "gemm", {}),
|
||||
)
|
||||
plan = TilePlan(tile_id=i, stages=stages)
|
||||
token = TileToken(
|
||||
tile_id=i, pipeline_ctx=ctx, plan=plan,
|
||||
stage_idx=0, params={},
|
||||
)
|
||||
dma_q.put(token)
|
||||
|
||||
env.run()
|
||||
assert done_event.triggered
|
||||
|
||||
# tile0 DMA starts at 0, finishes at 10
|
||||
# tile1 DMA starts at 10, finishes at 20
|
||||
# tile0 GEMM starts at 10, finishes at 30
|
||||
# tile1 GEMM starts at 20 (wait for DMA) but actually at 30 (gemm queue)
|
||||
assert stage_times[(0, "dma")] == 0.0
|
||||
assert stage_times[(0, "gemm")] == 10.0
|
||||
assert stage_times[(1, "dma")] == 10.0
|
||||
# tile1 gemm starts when tile0 gemm finishes (serialized at gemm queue)
|
||||
assert stage_times[(1, "gemm")] == 30.0
|
||||
Reference in New Issue
Block a user