1d95df4bee
- Move builtin_legacy/ → legacy/builtin/ (cleaner structure) - Move pe_accel_legacy/ → legacy/pe_accel/ - Remove custom/pe_accel/ (replaced by new builtin) - Remove pe_scheduler_v2 from components.yaml - Switch topology.yaml to pe_scheduler_v1 (new builtin) - Fix PE_DMA self-routing: handle consecutive DMA_READ stages (same component consecutive stages processed in-place, not via port) 382 tests passing. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
133 lines
4.2 KiB
Python
133 lines
4.2 KiB
Python
"""Math Tiling Pipeline: splits element-wise ops into tiles for pipelined execution.

Flow per tile:

    DMA_IN(input tile) → MATH_OP(exp/log/etc.) → DMA_WB(output tile)

Mirrors GemmTilingPipeline, but for unary element-wise operations. The
DMA_IN and DMA_WB stages are optional (controlled by the ``dma_in`` and
``dma_out`` constructor flags).

Pipeline overlap across tiles: while one tile is in MATH_OP, the next
tile's DMA_IN can proceed concurrently.

The constructor starts two SimPy processes:

- _feed_commands(): sends one DmaRequest per tile to the shared dmaIN_cmd_q
  (or, with DMA_IN disabled, one Trigger per tile to dmaIN_to_fetch_trig)
- _collect_completions(): waits for one reply per tile on reply_queue, then
  signals completion
"""
|
|
from __future__ import annotations
|
|
|
|
from math import ceil
|
|
|
|
import simpy
|
|
|
|
from kernbench.components.custom.pe_accel.scheduler.tiling import generate_math_tiles
|
|
from kernbench.components.custom.pe_accel.types import (
|
|
DmaInDescriptor,
|
|
DmaRequest,
|
|
DmaWBDescriptor,
|
|
MathOpDescriptor,
|
|
Trigger,
|
|
)
|
|
|
|
|
|
class MathPipeline:
    """Coordinate one tiled, unary element-wise math op across the shared blocks.

    Building an instance does three things:

    1. asks ``generate_math_tiles`` for a tile schedule whose commands carry
       each tile's id plus precomputed source/destination tile addresses;
    2. fills per-tile descriptor tables for the DMA_IN, MATH and DMA_WB
       stages, keyed by ``(pipeline_id, tile_id)`` (DMA_IN additionally by
       operand ``"A"``);
    3. registers two SimPy processes on ``env``: a feeder that injects one
       work item per tile, and a collector that fires ``self.done`` once one
       reply per tile has arrived on ``reply_queue``.
    """

    def __init__(
        self,
        env: simpy.Environment,
        M: int, N: int,
        tile_m: int, tile_n: int,
        bytes_per_element: int,
        pipeline_id: int,
        reply_queue: simpy.Store,
        dmaIN_cmd_q: simpy.Store,
        dmaIN_to_fetch_trig: simpy.Store,
        op: str = "exp",
        src_addr: int = 0,
        dst_addr: int = 0,
        dma_in: bool = True,
        dma_out: bool = True,
    ) -> None:
        """Build the tile schedule and descriptors, then start the processes.

        Args:
            env: SimPy environment the two pipeline processes run in.
            M, N: Extents of the operand along the two tiled dimensions.
            tile_m, tile_n: Requested tile sizes. Descriptors use these
                clamped to ``M``/``N``; tile counts are ``ceil(M / tile_m)``
                and ``ceil(N / tile_n)``.
            bytes_per_element: Element size, used for the DMA_IN transfer
                size and forwarded to ``generate_math_tiles``.
            pipeline_id: Identifier stamped on every descriptor key and on
                every request/trigger this pipeline emits.
            reply_queue: Store from which one item per tile is consumed to
                detect completion (item contents are ignored).
            dmaIN_cmd_q: Shared store receiving a ``DmaRequest`` per tile
                when ``dma_in`` is true.
            dmaIN_to_fetch_trig: Store receiving a ``Trigger`` per tile when
                ``dma_in`` is false (DMA_IN bypassed).
            op: Element-wise op name recorded in each ``MathOpDescriptor``.
            src_addr, dst_addr: Base addresses forwarded to
                ``generate_math_tiles`` for per-tile address computation.
            dma_in: If false, no DMA_IN descriptors are built.
            dma_out: If false, no DMA_WB descriptors are built.
        """
        self.env = env
        self.M = M
        self.N = N
        self.pipeline_id = pipeline_id
        self.reply_queue = reply_queue
        self.dmaIN_cmd_q = dmaIN_cmd_q
        self.dmaIN_to_fetch_trig = dmaIN_to_fetch_trig
        self.op = op

        self._dma_in = dma_in
        self._skip_dmaWB = not dma_out

        # Effective tile shape: a tile never exceeds the operand extent.
        eff_m = min(tile_m, M)
        eff_n = min(tile_n, N)

        # Tile counts use the *requested* sizes. For positive M/N this gives
        # the same count as the clamped sizes, and it avoids 0/0 when M or
        # N is zero.
        self.M_tiles = ceil(M / tile_m)
        self.N_tiles = ceil(N / tile_n)

        # Tile schedule with pre-computed per-tile addresses.
        self.schedule = generate_math_tiles(
            self.M_tiles, self.N_tiles,
            M=M, N=N,
            tile_m=eff_m, tile_n=eff_n,
            bytes_per_element=bytes_per_element,
            src_addr=src_addr, dst_addr=dst_addr,
            pipeline_id=pipeline_id,
        )

        # NOTE(review): every tile's descriptors use the full effective tile
        # shape, including a partial edge tile when M/N is not a multiple of
        # the tile size -- confirm whether edge tiles should be sized exactly.
        bytes_per_tile = eff_m * eff_n * bytes_per_element

        # Descriptor tables consulted by the shared DMA_IN / MATH / DMA_WB
        # blocks.
        self.dmaIN_descs: dict[tuple, DmaInDescriptor] = {}
        self.math_op_descs: dict[tuple, MathOpDescriptor] = {}
        self.dmaWB_descs: dict[tuple, DmaWBDescriptor] = {}

        for command in self.schedule.commands:
            tile = command.tile_id
            tile_src = command.src_tile_addr
            tile_dst = command.dst_tile_addr

            if dma_in:
                self.dmaIN_descs[(pipeline_id, tile, "A")] = DmaInDescriptor(
                    size_bytes=bytes_per_tile,
                    src_addr=tile_src,
                    next_block="MATH",
                )

            self.math_op_descs[(pipeline_id, tile)] = MathOpDescriptor(
                Tm=eff_m,
                Tn=eff_n,
                op=op,
                src_addr=tile_src,
                dst_addr=tile_dst,
            )

            if not self._skip_dmaWB:
                self.dmaWB_descs[(pipeline_id, tile)] = DmaWBDescriptor(
                    Tm=eff_m,
                    Tn=eff_n,
                    dst_addr=tile_dst,
                )

        # Completion bookkeeping: one reply is expected per tile in the grid.
        self.expected_flushes = self.M_tiles * self.N_tiles
        self.completed_flushes = 0
        # Simulation time (truncated to int) at which the last reply arrived.
        self.done_at: int = 0
        # Succeeds once every expected reply has been collected.
        self.done: simpy.Event = env.event()

        # Registration order is kept: feeder first, then collector.
        env.process(self._feed_commands())
        env.process(self._collect_completions())

    def _feed_commands(self):
        """SimPy process that injects one work item per scheduled tile.

        With DMA_IN enabled, each tile becomes a ``DmaRequest`` for operand
        ``"A"`` on ``dmaIN_cmd_q``. Otherwise the DMA_IN stage is bypassed
        and a ``Trigger`` (source block ``"PIPELINE"``) goes straight to
        ``dmaIN_to_fetch_trig``. Each ``put`` event is yielded, so the feeder
        advances only after the store has accepted the item.
        """
        for command in self.schedule.commands:
            if self._dma_in:
                target = self.dmaIN_cmd_q
                message = DmaRequest(
                    tile_id=command.tile_id,
                    pipeline_id=self.pipeline_id,
                    operand="A",
                )
            else:
                target = self.dmaIN_to_fetch_trig
                message = Trigger(
                    tile_id=command.tile_id,
                    pipeline_id=self.pipeline_id,
                    source_block="PIPELINE",
                )
            yield target.put(message)

    def _collect_completions(self):
        """SimPy process that waits for every tile to report back.

        Consumes one item from ``reply_queue`` per expected tile (the item
        itself is ignored), then records the finish time in ``done_at`` and
        succeeds the ``done`` event.
        """
        while True:
            if self.completed_flushes >= self.expected_flushes:
                break
            yield self.reply_queue.get()
            self.completed_flushes += 1

        self.done_at = int(self.env.now)
        self.done.succeed()