"""Tests for ADR-0021 PE pipeline: TileToken self-routing, pipeline overlap, e2e accuracy. Test plan items: 3. Phase 1 → Phase 2 end-to-end (op_log → DataExecutor → verify) 4. TileToken self-routing (stage sequence, PipelineContext completion) 5. Async pipeline overlap (intra-command tile overlap, FIFO ordering) """ from __future__ import annotations import simpy from kernbench.components.builtin.pe_types import ( PipelineContext, PipelinePlan, Stage, StageType, TilePlan, TileToken, ) # ── 4. TileToken self-routing ──────────────────────────────────────── def test_tile_token_advance(): """TileToken.advance() increments stage_idx and returns next Stage.""" stages = ( Stage(StageType.DMA_READ, "pe_dma", {"src_addr": 0}), Stage(StageType.FETCH, "pe_fetch_store", {"direction": "read"}), Stage(StageType.GEMM, "pe_gemm", {"m": 32, "k": 64, "n": 32}), ) plan = TilePlan(tile_id=0, stages=stages) ctx = PipelineContext(id="p1", total_tiles=1) token = TileToken( tile_id=0, pipeline_ctx=ctx, plan=plan, stage_idx=0, params=stages[0].params, ) assert token.current_stage.stage_type == StageType.DMA_READ next_s = token.advance() assert next_s is not None assert next_s.stage_type == StageType.FETCH assert token.stage_idx == 1 assert token.params == {"direction": "read"} next_s = token.advance() assert next_s is not None assert next_s.stage_type == StageType.GEMM assert token.stage_idx == 2 # Last stage — advance returns None assert token.advance() is None assert token.stage_idx == 3 def test_pipeline_context_completion(): """PipelineContext.complete_tile() fires done_event on last tile.""" env = simpy.Environment() done = env.event() ctx = PipelineContext(id="p1", total_tiles=3, done_event=done) ctx.complete_tile() assert not done.triggered ctx.complete_tile() assert not done.triggered ctx.complete_tile() assert done.triggered def test_pipeline_context_exactly_once(): """PipelineContext tracks completed_tiles correctly.""" ctx = PipelineContext(id="p1", total_tiles=2) assert 
ctx.completed_tiles == 0 ctx.complete_tile() assert ctx.completed_tiles == 1 ctx.complete_tile() assert ctx.completed_tiles == 2 def test_tile_token_self_routing_chain(): """Simulated self-routing: component reads next stage from token.""" stages = ( Stage(StageType.DMA_READ, "dma", {}), Stage(StageType.FETCH, "fetch", {}), Stage(StageType.GEMM, "gemm", {}), Stage(StageType.STORE, "fetch", {}), Stage(StageType.DMA_WRITE, "dma", {}), ) plan = TilePlan(tile_id=0, stages=stages) ctx = PipelineContext(id="p1", total_tiles=1) token = TileToken( tile_id=0, pipeline_ctx=ctx, plan=plan, stage_idx=0, params=stages[0].params, ) visited = [] while True: visited.append(token.current_stage.component) next_s = token.advance() if next_s is None: ctx.complete_tile() break assert visited == ["dma", "fetch", "gemm", "fetch", "dma"] assert ctx.completed_tiles == 1 # ── 5. Tiling plan generation ──────────────────────────────────────── def test_gemm_plan_tile_count(): """generate_gemm_plan produces correct number of tiles.""" from kernbench.components.builtin.tiling import generate_gemm_plan plan = generate_gemm_plan( M=64, K=128, N=64, tile_m=32, tile_k=64, tile_n=32, bytes_per_element=2, A_addr=0, B_addr=0x1000, C_addr=0x2000, pe_prefix="sip0.cube0.pe0", ) # M_tiles=2, K_tiles=2, N_tiles=2 → 2*2*2 = 8 tiles assert len(plan.tiles) == 8 assert plan.m_tiles == 2 assert plan.k_tiles == 2 assert plan.n_tiles == 2 def test_gemm_plan_stage_sequence(): """Each GEMM tile has correct stage sequence.""" from kernbench.components.builtin.tiling import generate_gemm_plan plan = generate_gemm_plan( M=32, K=64, N=32, tile_m=32, tile_k=64, tile_n=32, bytes_per_element=2, A_addr=0, B_addr=0x1000, C_addr=0x2000, pe_prefix="sip0.cube0.pe0", ) # Single tile (1x1x1), last_k=True → includes DMA_WRITE assert len(plan.tiles) == 1 tile = plan.tiles[0] stage_types = [s.stage_type for s in tile.stages] assert stage_types == [ StageType.DMA_READ, StageType.DMA_READ, # A and B StageType.FETCH, StageType.GEMM, 
StageType.STORE, StageType.DMA_WRITE, ] def test_gemm_plan_intermediate_k_no_dma_write(): """Intermediate K-tiles don't have DMA_WRITE stage.""" from kernbench.components.builtin.tiling import generate_gemm_plan plan = generate_gemm_plan( M=32, K=128, N=32, # K_tiles=2 tile_m=32, tile_k=64, tile_n=32, bytes_per_element=2, A_addr=0, B_addr=0x1000, C_addr=0x2000, pe_prefix="sip0.cube0.pe0", ) assert len(plan.tiles) == 2 # First tile (k=0): no DMA_WRITE t0_types = [s.stage_type for s in plan.tiles[0].stages] assert StageType.DMA_WRITE not in t0_types # Last tile (k=1, last_k=True): has DMA_WRITE t1_types = [s.stage_type for s in plan.tiles[1].stages] assert StageType.DMA_WRITE in t1_types def test_math_plan_stage_sequence(): """Math plan has READ→FETCH→MATH→STORE→WRITE sequence.""" from kernbench.components.builtin.tiling import generate_math_plan plan = generate_math_plan( M=32, N=32, tile_m=32, tile_n=32, bytes_per_element=2, math_op="exp", src_addr=0, dst_addr=0x1000, pe_prefix="sip0.cube0.pe0", ) assert len(plan.tiles) == 1 stage_types = [s.stage_type for s in plan.tiles[0].stages] assert stage_types == [ StageType.DMA_READ, StageType.FETCH, StageType.MATH, StageType.STORE, StageType.DMA_WRITE, ] # ── 5. 
# Async pipeline (SimPy simulation) ─────────────────────────────


def test_pipeline_overlap_within_command():
    """Two tiles of one command flow DMA → GEMM and overlap in time.

    Tile 1's DMA runs while tile 0 is in GEMM. Each mock component handles
    one token at a time, so tile 1's GEMM waits until tile 0's has finished.
    """
    env = simpy.Environment()
    finished = env.event()
    pipeline = PipelineContext(id="p1", total_tiles=2, done_event=finished)

    # (tile_id, stage name) → simulation time at which the tile entered it.
    entered_at: dict[tuple[int, str], float] = {}

    def stage_worker(sim, inbox, label, latency_ns, routes):
        """Take tokens from ``inbox`` one at a time, then self-route them."""
        while True:
            tok = yield inbox.get()
            entered_at[(tok.tile_id, label)] = sim.now
            yield sim.timeout(latency_ns)
            following = tok.advance()
            if following is None:
                # No stage left: report the tile as complete.
                tok.pipeline_ctx.complete_tile()
                continue
            # The token itself names the component that handles it next.
            yield routes[following.component].put(tok)

    queues = {"dma": simpy.Store(env), "gemm": simpy.Store(env)}
    for label, delay in (("dma", 10.0), ("gemm", 20.0)):
        env.process(stage_worker(env, queues[label], label, delay, queues))

    # Both tiles follow the same two-stage route and start in the DMA queue.
    for tile_id in range(2):
        route = (
            Stage(StageType.DMA_READ, "dma", {}),
            Stage(StageType.GEMM, "gemm", {}),
        )
        tile_plan = TilePlan(tile_id=tile_id, stages=route)
        queues["dma"].put(
            TileToken(
                tile_id=tile_id,
                pipeline_ctx=pipeline,
                plan=tile_plan,
                stage_idx=0,
                params={},
            )
        )

    env.run()
    assert finished.triggered

    # Expected timeline:
    #   tile0 DMA   0 → 10
    #   tile1 DMA  10 → 20   (overlaps tile0's GEMM)
    #   tile0 GEMM 10 → 30
    #   tile1 GEMM enters at 30: it reaches the GEMM queue at 20 but is only
    #   picked up once tile0's GEMM has finished.
    expected_entry = {
        (0, "dma"): 0.0,
        (0, "gemm"): 10.0,
        (1, "dma"): 10.0,
        (1, "gemm"): 30.0,
    }
    for key, when in expected_entry.items():
        assert entered_at[key] == when