Files
kernbench2/src/kernbench/components/legacy/pe_accel/scheduler/math_pipeline.py
T
ywkang 1d95df4bee Restructure legacy backups, remove pe_accel, fix DMA self-routing
- Move builtin_legacy/ → legacy/builtin/ (cleaner structure)
- Move pe_accel_legacy/ → legacy/pe_accel/
- Remove custom/pe_accel/ (replaced by new builtin)
- Remove pe_scheduler_v2 from components.yaml
- Switch topology.yaml to pe_scheduler_v1 (new builtin)
- Fix PE_DMA self-routing: handle consecutive DMA_READ stages
  (same component consecutive stages processed in-place, not via port)

382 tests passing.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-09 00:02:26 -07:00

133 lines
4.2 KiB
Python

"""Math Tiling Pipeline: splits element-wise ops into tiles for pipelined execution.
Flow per tile:
DMA_IN(input tile) → MATH_OP(exp/log/etc.) → DMA_WB(output tile)
Mirrors GemmTilingPipeline but for unary element-wise operations.
Pipeline overlap across tiles: while one tile is in MATH_OP, the next
tile's DMA_IN can proceed concurrently.
Constructor starts two SimPy processes:
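- _feed_commands(): sends one DmaRequest per tile to the shared dmaIN_cmd_q
(or, when dma_in is False, one Trigger per tile to dmaIN_to_fetch_trig)
- _collect_completions(): consumes one reply_queue item per tile, then
fires the `done` event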
"""
from __future__ import annotations
from math import ceil
import simpy
from kernbench.components.custom.pe_accel.scheduler.tiling import generate_math_tiles
from kernbench.components.custom.pe_accel.types import (
DmaInDescriptor,
DmaRequest,
DmaWBDescriptor,
MathOpDescriptor,
Trigger,
)
class MathPipeline:
    """Coordinate one tiled element-wise math operation across shared blocks.

    Construction builds the tile schedule via ``generate_math_tiles``, fills
    three descriptor tables keyed by pipeline id and tile id, and starts two
    SimPy processes: ``_feed_commands`` and ``_collect_completions``.

    Attributes:
        schedule: Result of ``generate_math_tiles``; its ``commands`` are
            iterated once to build descriptors and once to feed the tiles.
        dmaIN_descs: ``(pipeline_id, tile_id, "A")`` -> ``DmaInDescriptor``.
            Left empty when ``dma_in`` is False.
        math_op_descs: ``(pipeline_id, tile_id)`` -> ``MathOpDescriptor``.
        dmaWB_descs: ``(pipeline_id, tile_id)`` -> ``DmaWBDescriptor``.
            Left empty when ``dma_out`` is False.
        M_tiles, N_tiles: Number of tiles along each dimension.
        expected_flushes: Total number of tiles (``M_tiles * N_tiles``).
        completed_flushes: Items consumed from ``reply_queue`` so far.
        done: Event triggered once every expected reply has been consumed.
        done_at: ``int(env.now)`` when ``done`` fired; 0 until then.
    """

    def __init__(
        self,
        env: simpy.Environment,
        M: int, N: int,
        tile_m: int, tile_n: int,
        bytes_per_element: int,
        pipeline_id: int,
        reply_queue: simpy.Store,
        dmaIN_cmd_q: simpy.Store,
        dmaIN_to_fetch_trig: simpy.Store,
        op: str = "exp",
        src_addr: int = 0,
        dst_addr: int = 0,
        dma_in: bool = True,
        dma_out: bool = True,
    ) -> None:
        """Build the schedule and descriptor tables, then start the processes.

        Args:
            env: SimPy environment used to create ``done`` and to register
                the two processes.
            M, N: Extents of the operation along the two tiled dimensions.
            tile_m, tile_n: Requested tile extents; each is clamped to the
                corresponding problem extent before use.
            bytes_per_element: Element size used to compute the byte size
                recorded in each ``DmaInDescriptor``.
            pipeline_id: Identifier embedded in descriptor keys, requests
                and triggers.
            reply_queue: Store from which one item is consumed per tile.
            dmaIN_cmd_q: Store that receives one ``DmaRequest`` per tile
                when ``dma_in`` is True.
            dmaIN_to_fetch_trig: Store that receives one ``Trigger`` per
                tile when ``dma_in`` is False.
            op: Operation name stored in every ``MathOpDescriptor``.
            src_addr, dst_addr: Base addresses forwarded to
                ``generate_math_tiles``.
            dma_in: When False, no input-DMA descriptors are built and
                triggers are sent instead of DMA requests.
            dma_out: When False, no writeback descriptors are built.

        Raises:
            ValueError: If ``tile_m`` or ``tile_n`` is not positive, or if
                ``M`` or ``N`` is negative.
        """
        # Reject nonsensical geometry up front.  Without this, a zero tile
        # size fails with an opaque ZeroDivisionError, and negative values
        # produce negative tile counts whose product can be a positive
        # expected_flushes that is never satisfied.
        if tile_m <= 0 or tile_n <= 0:
            raise ValueError(
                f"tile_m and tile_n must be positive, got {tile_m} and {tile_n}"
            )
        if M < 0 or N < 0:
            raise ValueError(f"M and N must be non-negative, got {M} and {N}")

        self.env = env
        self.M, self.N = M, N
        self.pipeline_id = pipeline_id
        self.reply_queue = reply_queue
        self.dmaIN_cmd_q = dmaIN_cmd_q
        self.dmaIN_to_fetch_trig = dmaIN_to_fetch_trig
        self.op = op
        self._dma_in = dma_in
        self._skip_dmaWB = not dma_out

        # A tile can never be larger than the problem itself.
        _Tm = min(tile_m, M)
        _Tn = min(tile_n, N)

        # Exact integer ceiling division (no float round-trip).
        self.M_tiles = -(-M // tile_m)
        self.N_tiles = -(-N // tile_n)

        # Generate tile schedule with pre-computed addresses.
        self.schedule = generate_math_tiles(
            self.M_tiles, self.N_tiles,
            M=M, N=N,
            tile_m=_Tm, tile_n=_Tn,
            bytes_per_element=bytes_per_element,
            src_addr=src_addr, dst_addr=dst_addr,
            pipeline_id=pipeline_id,
        )

        # Build descriptor tables for shared blocks.
        # NOTE(review): every tile, including a ragged edge tile when M or N
        # is not a multiple of the tile extent, is described with the full
        # clamped tile size - confirm this is intended.
        pid = pipeline_id
        tile_bytes = _Tm * _Tn * bytes_per_element
        self.dmaIN_descs: dict[tuple, DmaInDescriptor] = {}
        self.math_op_descs: dict[tuple, MathOpDescriptor] = {}
        self.dmaWB_descs: dict[tuple, DmaWBDescriptor] = {}
        for cmd in self.schedule.commands:
            t = cmd.tile_id
            if dma_in:
                self.dmaIN_descs[(pid, t, "A")] = DmaInDescriptor(
                    size_bytes=tile_bytes,
                    src_addr=cmd.src_tile_addr,
                    next_block="MATH",
                )
            self.math_op_descs[(pid, t)] = MathOpDescriptor(
                Tm=_Tm, Tn=_Tn, op=op,
                src_addr=cmd.src_tile_addr,
                dst_addr=cmd.dst_tile_addr,
            )
            if not self._skip_dmaWB:
                self.dmaWB_descs[(pid, t)] = DmaWBDescriptor(
                    Tm=_Tm, Tn=_Tn, dst_addr=cmd.dst_tile_addr,
                )

        self.expected_flushes = self.M_tiles * self.N_tiles
        self.completed_flushes = 0
        self.done_at: int = 0
        self.done: simpy.Event = env.event()

        env.process(self._feed_commands())
        env.process(self._collect_completions())

    def _feed_commands(self):
        """Emit one request (or trigger) per scheduled tile, in schedule order.

        With input DMA enabled a ``DmaRequest`` for operand ``"A"`` is put on
        ``dmaIN_cmd_q``; otherwise a ``Trigger`` tagged with source block
        ``"PIPELINE"`` is put on ``dmaIN_to_fetch_trig``.  Each put is
        yielded, so the process waits for it before moving to the next tile.
        """
        for cmd in self.schedule.commands:
            if self._dma_in:
                yield self.dmaIN_cmd_q.put(DmaRequest(
                    tile_id=cmd.tile_id, pipeline_id=self.pipeline_id, operand="A",
                ))
            else:
                yield self.dmaIN_to_fetch_trig.put(Trigger(
                    tile_id=cmd.tile_id, pipeline_id=self.pipeline_id, source_block="PIPELINE",
                ))

    def _collect_completions(self):
        """Consume one reply per tile, then record the time and fire ``done``.

        The content of each reply is ignored; only the count matters.  When
        there are no tiles to wait for, ``done`` fires without consuming
        anything from ``reply_queue``.
        """
        while self.completed_flushes < self.expected_flushes:
            yield self.reply_queue.get()
            self.completed_flushes += 1
        self.done_at = int(self.env.now)
        self.done.succeed()