"""Math Tiling Pipeline: splits element-wise ops into tiles for pipelined execution. Flow per tile: DMA_IN(input tile) → MATH_OP(exp/log/etc.) → DMA_WB(output tile) Mirrors GemmTilingPipeline but for unary element-wise operations. Pipeline overlap across tiles: while one tile is in MATH_OP, the next tile's DMA_IN can proceed concurrently. Constructor starts two SimPy processes: - _feed_commands(): sends DmaRequests to shared dmaIN_cmd_q - _collect_completions(): waits for all tiles to writeback """ from __future__ import annotations from math import ceil import simpy from kernbench.components.custom.pe_accel.scheduler.tiling import generate_math_tiles from kernbench.components.custom.pe_accel.types import ( DmaInDescriptor, DmaRequest, DmaWBDescriptor, MathOpDescriptor, Trigger, ) class MathPipeline: """Coordinates one tiled element-wise math operation across shared blocks.""" def __init__( self, env: simpy.Environment, M: int, N: int, tile_m: int, tile_n: int, bytes_per_element: int, pipeline_id: int, reply_queue: simpy.Store, dmaIN_cmd_q: simpy.Store, dmaIN_to_fetch_trig: simpy.Store, op: str = "exp", src_addr: int = 0, dst_addr: int = 0, dma_in: bool = True, dma_out: bool = True, ) -> None: self.env = env self.M, self.N = M, N self.pipeline_id = pipeline_id self.reply_queue = reply_queue self.dmaIN_cmd_q = dmaIN_cmd_q self.dmaIN_to_fetch_trig = dmaIN_to_fetch_trig self.op = op self._dma_in = dma_in self._skip_dmaWB = not dma_out _Tm = min(tile_m, M) _Tn = min(tile_n, N) self.M_tiles = ceil(M / tile_m) self.N_tiles = ceil(N / tile_n) # Generate tile schedule with pre-computed addresses self.schedule = generate_math_tiles( self.M_tiles, self.N_tiles, M=M, N=N, tile_m=_Tm, tile_n=_Tn, bytes_per_element=bytes_per_element, src_addr=src_addr, dst_addr=dst_addr, pipeline_id=pipeline_id, ) # Build descriptor tables for shared blocks pid = pipeline_id tile_bytes = _Tm * _Tn * bytes_per_element self.dmaIN_descs: dict[tuple, DmaInDescriptor] = {} self.math_op_descs: dict[tuple, MathOpDescriptor] = {} self.dmaWB_descs: dict[tuple, DmaWBDescriptor] = {} for cmd in self.schedule.commands: t = cmd.tile_id if dma_in: self.dmaIN_descs[(pid, t, "A")] = DmaInDescriptor( size_bytes=tile_bytes, src_addr=cmd.src_tile_addr, next_block="MATH", ) self.math_op_descs[(pid, t)] = MathOpDescriptor( Tm=_Tm, Tn=_Tn, op=op, src_addr=cmd.src_tile_addr, dst_addr=cmd.dst_tile_addr, ) if not self._skip_dmaWB: self.dmaWB_descs[(pid, t)] = DmaWBDescriptor( Tm=_Tm, Tn=_Tn, dst_addr=cmd.dst_tile_addr, ) self.expected_flushes = self.M_tiles * self.N_tiles self.completed_flushes = 0 self.done_at: int = 0 self.done: simpy.Event = env.event() env.process(self._feed_commands()) env.process(self._collect_completions()) def _feed_commands(self): """Send DmaRequests for each tile's input to dmaIN_cmd_q.""" for cmd in self.schedule.commands: if self._dma_in: yield self.dmaIN_cmd_q.put(DmaRequest( tile_id=cmd.tile_id, pipeline_id=self.pipeline_id, operand="A", )) else: yield self.dmaIN_to_fetch_trig.put(Trigger( tile_id=cmd.tile_id, pipeline_id=self.pipeline_id, source_block="PIPELINE", )) def _collect_completions(self): """Wait for all tile completions, then signal done.""" while self.completed_flushes < self.expected_flushes: yield self.reply_queue.get() self.completed_flushes += 1 self.done_at = int(self.env.now) self.done.succeed()