"""Tests for ADR-0021 PE pipeline: TileToken self-routing, pipeline overlap, e2e accuracy. Test plan items: 3. Phase 1 → Phase 2 end-to-end (op_log → DataExecutor → verify) 4. TileToken self-routing (stage sequence, PipelineContext completion) 5. Async pipeline overlap (intra-command tile overlap, FIFO ordering) """ from __future__ import annotations import simpy from kernbench.components.builtin.pe_types import ( PipelineContext, PipelinePlan, Stage, StageType, TilePlan, TileToken, ) # ── 4. TileToken self-routing ──────────────────────────────────────── def test_tile_token_advance(): """TileToken.advance() increments stage_idx and returns next Stage.""" stages = ( Stage(StageType.DMA_READ, "pe_dma", {"src_addr": 0}), Stage(StageType.FETCH, "pe_fetch_store", {"direction": "read"}), Stage(StageType.GEMM, "pe_gemm", {"m": 32, "k": 64, "n": 32}), ) plan = TilePlan(tile_id=0, stages=stages) ctx = PipelineContext(id="p1", total_tiles=1) token = TileToken( tile_id=0, pipeline_ctx=ctx, plan=plan, stage_idx=0, params=stages[0].params, ) assert token.current_stage.stage_type == StageType.DMA_READ next_s = token.advance() assert next_s is not None assert next_s.stage_type == StageType.FETCH assert token.stage_idx == 1 assert token.params == {"direction": "read"} next_s = token.advance() assert next_s is not None assert next_s.stage_type == StageType.GEMM assert token.stage_idx == 2 # Last stage — advance returns None assert token.advance() is None assert token.stage_idx == 3 def test_pipeline_context_completion(): """PipelineContext.complete_tile() fires done_event on last tile.""" env = simpy.Environment() done = env.event() ctx = PipelineContext(id="p1", total_tiles=3, done_event=done) ctx.complete_tile() assert not done.triggered ctx.complete_tile() assert not done.triggered ctx.complete_tile() assert done.triggered def test_pipeline_context_exactly_once(): """PipelineContext tracks completed_tiles correctly.""" ctx = PipelineContext(id="p1", total_tiles=2) assert ctx.completed_tiles == 0 ctx.complete_tile() assert ctx.completed_tiles == 1 ctx.complete_tile() assert ctx.completed_tiles == 2 def test_tile_token_self_routing_chain(): """Simulated self-routing: component reads next stage from token.""" stages = ( Stage(StageType.DMA_READ, "dma", {}), Stage(StageType.FETCH, "fetch", {}), Stage(StageType.GEMM, "gemm", {}), Stage(StageType.STORE, "fetch", {}), Stage(StageType.DMA_WRITE, "dma", {}), ) plan = TilePlan(tile_id=0, stages=stages) ctx = PipelineContext(id="p1", total_tiles=1) token = TileToken( tile_id=0, pipeline_ctx=ctx, plan=plan, stage_idx=0, params=stages[0].params, ) visited = [] while True: visited.append(token.current_stage.component) next_s = token.advance() if next_s is None: ctx.complete_tile() break assert visited == ["dma", "fetch", "gemm", "fetch", "dma"] assert ctx.completed_tiles == 1 # ── 5. Tiling plan generation ──────────────────────────────────────── def test_gemm_plan_tile_count(): """generate_gemm_plan produces correct number of tiles.""" from kernbench.components.builtin.tiling import generate_gemm_plan plan = generate_gemm_plan( M=64, K=128, N=64, tile_m=32, tile_k=64, tile_n=32, bytes_per_element=2, A_addr=0, B_addr=0x1000, C_addr=0x2000, pe_prefix="sip0.cube0.pe0", ) # M_tiles=2, K_tiles=2, N_tiles=2 → 2*2*2 = 8 tiles assert len(plan.tiles) == 8 assert plan.m_tiles == 2 assert plan.k_tiles == 2 assert plan.n_tiles == 2 def test_gemm_plan_stage_sequence(): """Each GEMM tile has correct stage sequence.""" from kernbench.components.builtin.tiling import generate_gemm_plan plan = generate_gemm_plan( M=32, K=64, N=32, tile_m=32, tile_k=64, tile_n=32, bytes_per_element=2, A_addr=0, B_addr=0x1000, C_addr=0x2000, pe_prefix="sip0.cube0.pe0", ) # Single tile (1x1x1), last_k=True → includes DMA_WRITE assert len(plan.tiles) == 1 tile = plan.tiles[0] stage_types = [s.stage_type for s in tile.stages] assert stage_types == [ StageType.DMA_READ, StageType.DMA_READ, # A and B StageType.FETCH, StageType.GEMM, StageType.STORE, StageType.DMA_WRITE, ] def test_gemm_plan_intermediate_k_no_dma_write(): """Intermediate K-tiles don't have DMA_WRITE or STORE stage. The C accumulator stays in RegFile across the K loop; STORE + DMA_WRITE only fire on the last K-tile per (m,n). """ from kernbench.components.builtin.tiling import generate_gemm_plan plan = generate_gemm_plan( M=32, K=128, N=32, # K_tiles=2 tile_m=32, tile_k=64, tile_n=32, bytes_per_element=2, A_addr=0, B_addr=0x1000, C_addr=0x2000, pe_prefix="sip0.cube0.pe0", ) assert len(plan.tiles) == 2 # First tile (k=0): no STORE, no DMA_WRITE — accumulator stays in RegFile t0_types = [s.stage_type for s in plan.tiles[0].stages] assert StageType.STORE not in t0_types assert StageType.DMA_WRITE not in t0_types # Last tile (k=1, last_k=True): has both STORE and DMA_WRITE t1_types = [s.stage_type for s in plan.tiles[1].stages] assert StageType.STORE in t1_types assert StageType.DMA_WRITE in t1_types def test_gemm_plan_pinned_operand_skips_dma_read(): """When a_pinned=True, A's per-tile DMA_READ is omitted. Same for b_pinned. FETCH is unaffected — it still stages from TCM into RegFile. """ from kernbench.components.builtin.tiling import generate_gemm_plan # Baseline: neither pinned — both A and B get DMA_READ per tile. base = generate_gemm_plan( M=32, K=128, N=32, # K_tiles=2 tile_m=32, tile_k=64, tile_n=32, bytes_per_element=2, A_addr=0, B_addr=0x1000, C_addr=0x2000, pe_prefix="sip0.cube0.pe0", ) for tile in base.tiles: operands = [s.params.get("operand") for s in tile.stages if s.stage_type == StageType.DMA_READ] assert operands == ["A", "B"], \ f"baseline tile should DMA_READ A and B, got {operands}" # a_pinned: no A DMA_READ. plan_a = generate_gemm_plan( M=32, K=128, N=32, tile_m=32, tile_k=64, tile_n=32, bytes_per_element=2, A_addr=0, B_addr=0x1000, C_addr=0x2000, pe_prefix="sip0.cube0.pe0", a_pinned=True, ) for tile in plan_a.tiles: operands = [s.params.get("operand") for s in tile.stages if s.stage_type == StageType.DMA_READ] assert operands == ["B"], \ f"a_pinned should leave only B DMA_READ, got {operands}" # FETCH must still exist assert any(s.stage_type == StageType.FETCH for s in tile.stages) # Both pinned: no DMA_READ at all. plan_both = generate_gemm_plan( M=32, K=128, N=32, tile_m=32, tile_k=64, tile_n=32, bytes_per_element=2, A_addr=0, B_addr=0x1000, C_addr=0x2000, pe_prefix="sip0.cube0.pe0", a_pinned=True, b_pinned=True, ) for tile in plan_both.tiles: dma_reads = [s for s in tile.stages if s.stage_type == StageType.DMA_READ] assert dma_reads == [], \ f"both pinned should skip all DMA_READ, got {dma_reads}" def test_math_plan_stage_sequence(): """Math plan has READ→FETCH→MATH→STORE→WRITE sequence.""" from kernbench.components.builtin.tiling import generate_math_plan plan = generate_math_plan( M=32, N=32, tile_m=32, tile_n=32, bytes_per_element=2, math_op="exp", src_addr=0, dst_addr=0x1000, pe_prefix="sip0.cube0.pe0", ) assert len(plan.tiles) == 1 stage_types = [s.stage_type for s in plan.tiles[0].stages] assert stage_types == [ StageType.DMA_READ, StageType.FETCH, StageType.MATH, StageType.STORE, StageType.DMA_WRITE, ] # ── 5. Async pipeline (SimPy simulation) ───────────────────────────── def test_pipeline_overlap_within_command(): """Tiles within same command overlap: tile1 DMA while tile0 in GEMM.""" env = simpy.Environment() done_event = env.event() ctx = PipelineContext(id="p1", total_tiles=2, done_event=done_event) # Track when each tile enters each stage stage_times: dict[tuple[int, str], float] = {} def mock_component(env, inbox, stage_name, latency_ns, out_ports): while True: token = yield inbox.get() stage_times[(token.tile_id, stage_name)] = env.now yield env.timeout(latency_ns) next_s = token.advance() if next_s is not None: yield out_ports[next_s.component].put(token) else: token.pipeline_ctx.complete_tile() dma_q = simpy.Store(env) gemm_q = simpy.Store(env) out_ports = {"dma": dma_q, "gemm": gemm_q} env.process(mock_component(env, dma_q, "dma", 10.0, out_ports)) env.process(mock_component(env, gemm_q, "gemm", 20.0, out_ports)) # Create 2 tiles: DMA → GEMM for i in range(2): stages = ( Stage(StageType.DMA_READ, "dma", {}), Stage(StageType.GEMM, "gemm", {}), ) plan = TilePlan(tile_id=i, stages=stages) token = TileToken( tile_id=i, pipeline_ctx=ctx, plan=plan, stage_idx=0, params={}, ) dma_q.put(token) env.run() assert done_event.triggered # tile0 DMA starts at 0, finishes at 10 # tile1 DMA starts at 10, finishes at 20 # tile0 GEMM starts at 10, finishes at 30 # tile1 GEMM starts at 20 (wait for DMA) but actually at 30 (gemm queue) assert stage_times[(0, "dma")] == 0.0 assert stage_times[(0, "gemm")] == 10.0 assert stage_times[(1, "dma")] == 10.0 # tile1 gemm starts when tile0 gemm finishes (serialized at gemm queue) assert stage_times[(1, "gemm")] == 30.0 # ── 6. Option B: pe_dma record_start fires post channel-acquire ──────── def test_pe_dma_record_start_after_channel_acquire(): """Three back-to-back DMA_READs serialise on pe_dma.cap=1. With ``_DEFER_RECORD_START = True`` on PeDmaComponent, each op's ``t_start`` is captured right after ``yield req`` succeeds. Result: - op N's ``(t_end - t_start)`` is the *actual transfer time* — same across all three ops (no queueing inflation). - op N+1's ``t_start`` >= op N's ``t_end - epsilon`` (waited for the previous holder to release the channel before being recorded). Counter-example (the bug this fix addresses): if ``record_start`` fired on command entry, all three ops would share ``t_start == 0`` and the second/third would show inflated ``t_end - t_start``. """ from pathlib import Path from kernbench.common.pe_commands import DmaReadCmd, PeInternalTxn, TensorHandle from kernbench.policy.address.phyaddr import PhysAddr from kernbench.sim_engine.engine import GraphEngine from kernbench.topology.builder import load_topology TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml" def _hbm_pa() -> int: slice_bytes = 48 * (1 << 30) // 8 pa = PhysAddr.pe_hbm_addr( sip_id=0, die_id=0, pe_id=0, pe_local_hbm_offset=0x1000, slice_size_bytes=slice_bytes, ) return pa.encode() # enable_data=True wires the OpLogger into every component. engine = GraphEngine(load_topology(TOPOLOGY_PATH), enable_data=True) pe_dma_id = "sip0.cube0.pe0.pe_dma" pe_dma = engine._components[pe_dma_id] env = engine._env # Three back-to-back DMA_READ commands fed straight into pe_dma's inbox # at t=0 so they all race for the cap=1 channel. handles = [ TensorHandle(id=f"r{i}", addr=0x1000 + i * 0x1000, shape=(64, 32), dtype="f16", nbytes=4096) for i in range(3) ] cmds = [ DmaReadCmd(handle=h, src_addr=_hbm_pa(), nbytes=4096) for h in handles ] txns = [PeInternalTxn(command=c, done=env.event()) for c in cmds] def submit_all(): for txn in txns: yield pe_dma._inbox.put(txn) env.process(submit_all()) env.run() # Pull the three dma_read records out of the op log in order dma_records = [ r for r in engine.op_log if r.op_name == "dma_read" and r.component_id == pe_dma_id ] assert len(dma_records) == 3, ( f"expected 3 dma_read records, got {len(dma_records)}: {dma_records}" ) durations = [r.t_end - r.t_start for r in dma_records] # All three should have the same actual transfer time within ±1 ns. base = durations[0] assert base > 0, f"first dma duration must be positive, got {base}" for i, d in enumerate(durations): assert abs(d - base) <= 1.0, ( f"op {i} duration {d} differs from baseline {base} by >1 ns " f"— record_start may still be including queue wait" ) # Each subsequent op's t_start must be at or after the previous op's # t_end (modulo a few ns of scheduler overhead) — i.e. the wait is # *excluded* from the recorded interval, not folded into it. for i in range(1, len(dma_records)): prev_end = dma_records[i - 1].t_end cur_start = dma_records[i].t_start assert cur_start >= prev_end - 1.0, ( f"op {i} t_start={cur_start} began before op {i-1} t_end={prev_end} " f"— channel was not actually held, fix is incorrect" )