"""SchedulerV2Component: accelerator scheduler block (pe_scheduler_v2).

Replaces the pe_scheduler slot in topology (impl: pe_scheduler_v2).

Hosts four internal hardware blocks as concurrent SimPy processes:
  - DmaInBlock  — HBM → TCM tile reads (real fabric DMA)
  - DmaWbBlock  — TCM → HBM tile writes (real fabric DMA)
  - GemmBlock   — 2-stage MAC pipeline (fetch + compute)
  - MathBlock   — K-accumulation (GEMM) + element-wise ops (exp, log, etc.)

Command dispatch routes PeInternalTxn to the correct engine or tiling pipeline:
  - DmaReadCmd / DmaWriteCmd  → PE_DMA out_port
  - GemmCmd                   → PE_GEMM out_port
  - MathCmd                   → PE_MATH out_port
  - CompositeCmd(op="gemm")   → GemmPipeline (tiled DMA + GEMM + K-accum)
  - CompositeCmd(op="math")   → MathPipeline (tiled DMA + element-wise + DMA)
  - PeCpuOverheadCmd          → yield timeout

Config via node.attrs (all optional):
  overhead_ns, clock_freq_ghz, bytes_per_element,
  mac_m, mac_k, mac_n, tile_m, tile_k, tile_n,
  port_a_bw, port_b_bw, store_port_bw, vector_width
"""

from __future__ import annotations

from collections.abc import Generator
from typing import TYPE_CHECKING, Any

import simpy

from kernbench.components.base import ComponentBase

if TYPE_CHECKING:
    from kernbench.components.context import ComponentContext
    from kernbench.topology.types import Node


# ==============================================================================
# Component
# ==============================================================================


class SchedulerV2Component(ComponentBase):
    """PE accelerator scheduler: wires internal blocks and dispatches commands."""

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        """Read the optional hardware configuration from ``node.attrs``.

        Only configuration is captured here; the internal blocks and their
        queues are created later by :meth:`start`, once a SimPy environment
        is available.
        """
        super().__init__(node, ctx)
        # Everything before the last "." of the node id; used to build the
        # names of sibling ports ("<prefix>.pe_dma", "<prefix>.pe_gemm", ...).
        self._pe_prefix: str = node.id.rsplit(".", 1)[0]
        attrs = node.attrs

        # Hardware config
        self._overhead_ns: float = float(attrs.get("overhead_ns", 0.0))
        self._clock_freq_ghz: float = float(attrs.get("clock_freq_ghz", 1.0))
        self._bpe: int = int(attrs.get("bytes_per_element", 2))

        # MAC array dimensions
        self._mac_m: int = int(attrs.get("mac_m", 8))
        self._mac_k: int = int(attrs.get("mac_k", 16))
        self._mac_n: int = int(attrs.get("mac_n", 32))

        # Tile dimensions
        self._tile_m: int = int(attrs.get("tile_m", 32))
        self._tile_k: int = int(attrs.get("tile_k", 32))
        self._tile_n: int = int(attrs.get("tile_n", 32))

        # Bandwidth (bytes/cycle)
        self._port_a_bw: int = int(attrs.get("port_a_bw", 256))
        self._port_b_bw: int = int(attrs.get("port_b_bw", 256))
        self._store_port_bw: int = int(attrs.get("store_port_bw", 256))
        self._vector_width: int = int(attrs.get("vector_width", 256))

        # Initialized in start()
        self._tcm_block: Any = None
        self._dmaIN_block: Any = None
        self._dmaWB_block: Any = None
        self._gemm_block: Any = None
        self._math_block: Any = None
        self._dmaIN_cmd_q: simpy.Store | None = None
        self._dmaIN_to_fetch_trig: simpy.Store | None = None

        # Pipeline tracking: pipeline id → reply queue of the active pipeline
        self._next_pipeline_id: int = 0
        self._pipeline_queues: dict[int, simpy.Store] = {}

    # -- SimPy lifecycle -------------------------------------------------------

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Scheduler overhead per dispatch.

        Yields a single timeout of ``overhead_ns``; ``nbytes`` is not used.
        """
        yield env.timeout(self._overhead_ns)

    def start(self, env: simpy.Environment) -> None:
        """Create internal blocks, wire queues, start SimPy processes."""
        from kernbench.components.custom.pe_accel.blocks import (
            DmaInBlock,
            DmaWbBlock,
            GemmBlock,
            MathBlock,
            TcmBlock,
        )

        # May be None when no such out-port is wired (dict.get).
        pe_dma_port = self.out_ports.get(f"{self._pe_prefix}.pe_dma")

        # -- TCM block (shared BW-serialized scratchpad) -----------------------
        # Read TCM BW from topology spec (gemm_to_tcm / math_to_tcm)
        pe_links = {}
        if self.ctx and self.ctx.spec:
            pe_links = self.ctx.spec.get("cube", {}).get("pe_template", {}).get("links", {})
        tcm_read_bw = float(pe_links.get("gemm_to_tcm_bw_gbs", 512.0))
        tcm_write_bw = float(pe_links.get("math_to_tcm_bw_gbs", 512.0))

        self._tcm_block = TcmBlock(
            env=env,
            read_bw_gbs=tcm_read_bw,
            write_bw_gbs=tcm_write_bw,
        )

        # -- Internal queues ---------------------------------------------------
        self._dmaIN_cmd_q = simpy.Store(env)
        self._dmaIN_to_fetch_trig = simpy.Store(env)
        dmaIN_to_math_trig = simpy.Store(env)  # DMA_IN → element-wise math
        fetch_to_gemm_trig = simpy.Store(env)
        gemm_to_math_trig = simpy.Store(env)
        gemm_to_dmaWB_trig = simpy.Store(env)
        math_to_dmaWB_trig = simpy.Store(env)

        # Completion queues (block → _completion_router → pipeline reply_q)
        dmaIN_completion_q = simpy.Store(env)
        dmaWB_completion_q = simpy.Store(env)
        gemm_completion_q = simpy.Store(env)
        math_completion_q = simpy.Store(env)

        # -- Create blocks -----------------------------------------------------
        self._dmaIN_block = DmaInBlock(
            env=env,
            cmd_q=self._dmaIN_cmd_q,
            to_fetch_trig=self._dmaIN_to_fetch_trig,
            to_math_trig=dmaIN_to_math_trig,
            completion_q=dmaIN_completion_q,
            pe_dma_port=pe_dma_port,
            pe_prefix=self._pe_prefix,
        )

        self._dmaWB_block = DmaWbBlock(
            env=env,
            completion_q=dmaWB_completion_q,
            pe_dma_port=pe_dma_port,
            pe_prefix=self._pe_prefix,
            bytes_per_element=self._bpe,
        )

        self._gemm_block = GemmBlock(
            env=env,
            trig_in=self._dmaIN_to_fetch_trig,
            fetch_to_gemm_trig=fetch_to_gemm_trig,
            to_math_trig=gemm_to_math_trig,
            to_dmaWB_trig=gemm_to_dmaWB_trig,
            completion_q=gemm_completion_q,
            tcm_port=self._tcm_block.port,
            mac_m=self._mac_m,
            mac_k=self._mac_k,
            mac_n=self._mac_n,
            bytes_per_element=self._bpe,
            clock_freq_ghz=self._clock_freq_ghz,
        )

        self._math_block = MathBlock(
            env=env,
            trig_in=gemm_to_math_trig,
            to_dmaWB_trig=math_to_dmaWB_trig,
            completion_q=math_completion_q,
            tcm_port=self._tcm_block.port,
            bytes_per_element=self._bpe,
            clock_freq_ghz=self._clock_freq_ghz,
            vector_width=self._vector_width,
        )

        # -- Start block processes ---------------------------------------------
        env.process(self._tcm_block._run())
        env.process(self._dmaIN_block._load_loop())
        # One write-back flush loop per upstream trigger queue (GEMM and math).
        env.process(self._dmaWB_block._flush_loop(gemm_to_dmaWB_trig))
        env.process(self._dmaWB_block._flush_loop(math_to_dmaWB_trig))
        env.process(self._gemm_block._fetch_stage())
        env.process(self._gemm_block._gemm_stage())
        env.process(self._math_block._run_k_accumulation())
        env.process(self._math_block._run_element_wise(dmaIN_to_math_trig))

        # Wire in-ports → inbox, start _worker
        super().start(env)

        # Start completion routers
        for q in (dmaIN_completion_q, dmaWB_completion_q, gemm_completion_q, math_completion_q):
            env.process(self._completion_router(q))

    # -- Internal processes ----------------------------------------------------

    def _completion_router(self, queue: simpy.Store) -> Generator:
        """Route block completion triggers to the correct pipeline's reply queue.

        Runs until a ``None`` item is received. Triggers whose ``pipeline_id``
        has no registered reply queue are dropped.
        """
        while True:
            trigger = yield queue.get()
            if trigger is None:
                break
            reply_q = self._pipeline_queues.get(trigger.pipeline_id)
            if reply_q is not None:
                yield reply_q.put(trigger)

    def _worker(self, env: simpy.Environment) -> Generator:
        """Main inbox loop: dispatch PE commands, forward fabric transactions.

        Each message is handled in its own SimPy process, so the loop does not
        wait for one message to finish before taking the next.
        """
        from kernbench.common.pe_commands import PeInternalTxn

        while True:
            msg: Any = yield self._inbox.get()
            if isinstance(msg, PeInternalTxn):
                env.process(self._dispatch(env, msg))
            else:
                env.process(self._forward_txn(env, msg))

    # ==========================================================================
    # Command Dispatch
    # ==========================================================================

    def _dispatch(self, env: simpy.Environment, pe_txn: Any) -> Generator:
        """Route a PeInternalTxn to the appropriate engine or pipeline.

        DMA / GEMM / math commands are put on the matching out-port without
        completing ``pe_txn.done`` here. Composite commands run a tiling
        pipeline: the GEMM pipeline only when ``op == "gemm"`` and operand
        ``b`` is present, the math pipeline otherwise. CPU-overhead commands
        and unrecognized commands complete ``pe_txn.done`` directly.
        """
        from kernbench.common.pe_commands import (
            CompositeCmd,
            DmaReadCmd,
            DmaWriteCmd,
            GemmCmd,
            MathCmd,
            PeCpuOverheadCmd,
        )

        yield from self.run(env, 0)  # scheduler overhead

        cmd = pe_txn.command
        pp = self._pe_prefix

        if isinstance(cmd, (DmaReadCmd, DmaWriteCmd)):
            yield self.out_ports[f"{pp}.pe_dma"].put(pe_txn)
        elif isinstance(cmd, GemmCmd):
            yield self.out_ports[f"{pp}.pe_gemm"].put(pe_txn)
        elif isinstance(cmd, MathCmd):
            yield self.out_ports[f"{pp}.pe_math"].put(pe_txn)
        elif isinstance(cmd, CompositeCmd):
            if cmd.op == "gemm" and cmd.b is not None:
                yield from self._dispatch_composite_gemm(env, pe_txn, cmd)
            else:
                yield from self._dispatch_composite_math(env, pe_txn, cmd)
        elif isinstance(cmd, PeCpuOverheadCmd):
            # cycles / GHz gives the delay in ns
            yield env.timeout(cmd.cycles / self._clock_freq_ghz)
            pe_txn.done.succeed()
        else:
            # Unrecognized command types are completed immediately.
            pe_txn.done.succeed()

    # -- GEMM composite --------------------------------------------------------

    def _dispatch_composite_gemm(
        self,
        env: simpy.Environment,
        pe_txn: Any,
        cmd: Any,
    ) -> Generator:
        """Run tiled GEMM pipeline and collect per-stage metrics.

        Builds a GemmPipeline from the operand shapes, loads its descriptors
        into the shared blocks, waits for ``pipeline.done``, then fills
        ``pe_txn.result_data`` and completes ``pe_txn.done``.
        """
        from kernbench.components.custom.pe_accel.scheduler.gemm_pipeline import GemmPipeline

        # Queues are created in start(); same guard as the math composite path.
        assert self._dmaIN_cmd_q is not None
        assert self._dmaIN_to_fetch_trig is not None

        a, b = cmd.a, cmd.b
        M, K, N = a.shape[-2], a.shape[-1], b.shape[-1]

        pid, reply_q = self._alloc_pipeline(env)

        pipeline = GemmPipeline(
            env=env,
            M=M,
            K=K,
            N=N,
            tile_m=self._tile_m,
            tile_k=self._tile_k,
            tile_n=self._tile_n,
            bytes_per_element=self._bpe,
            pipeline_id=pid,
            reply_queue=reply_q,
            dmaIN_cmd_q=self._dmaIN_cmd_q,
            dmaIN_to_fetch_trig=self._dmaIN_to_fetch_trig,
            A_addr=a.addr,
            B_addr=b.addr,
            C_addr=cmd.out_addr,
        )

        # Load descriptors into shared blocks
        self._dmaIN_block.load_descriptors(pipeline.dmaIN_descs)
        self._gemm_block.load_descriptors(pipeline.gemm_descs)
        self._math_block.load_descriptors(pipeline.math_descs)
        self._dmaWB_block.load_descriptors(pipeline.dmaWB_descs)

        start_ns = env.now
        yield pipeline.done
        wall_clock_ns = env.now - start_ns

        # Pipeline finished: stop routing completions to its reply queue.
        del self._pipeline_queues[pid]

        # -- Collect metrics ---------------------------------------------------
        rd = pe_txn.result_data
        rd["M_tiles"] = pipeline.M_tiles
        rd["K_tiles"] = pipeline.K_tiles
        rd["N_tiles"] = pipeline.N_tiles
        rd["total_tiles"] = pipeline.M_tiles * pipeline.K_tiles * pipeline.N_tiles
        rd["output_tiles"] = pipeline.M_tiles * pipeline.N_tiles
        rd["total_dma_read_bytes"] = sum(d.size_bytes for d in pipeline.dmaIN_descs.values())
        rd["total_dma_write_bytes"] = sum(
            d.Tm * d.Tn * self._bpe for d in pipeline.dmaWB_descs.values()
        )
        rd["wall_clock_ns"] = wall_clock_ns

        # Per-stage timing (popped so the blocks do not keep per-pipeline data)
        rd["t_tcm_load_per_tile"], rd["t_tcm_load_max_ns"], rd["t_tcm_load_total_ns"] = (
            _collect_timing(self._gemm_block.t_tcm_load_per_tile.pop(pid, []))
        )
        rd["t_compute_per_tile"], rd["t_compute_max_ns"], rd["t_compute_total_ns"] = (
            _collect_timing(self._gemm_block.t_compute_per_tile.pop(pid, []))
        )
        rd["t_tcm_store_per_tile"], rd["t_tcm_store_max_ns"], rd["t_tcm_store_total_ns"] = (
            _collect_timing(self._math_block.t_tcm_store_per_tile.pop(pid, []))
        )
        rd["t_dma_read_per_request"], rd["t_dma_read_max_ns"], rd["t_dma_read_total_ns"] = (
            _collect_timing(self._dmaIN_block.t_dma_read_per_request.pop(pid, []))
        )
        rd["t_dma_write_per_tile"], rd["t_dma_write_max_ns"], rd["t_dma_write_total_ns"] = (
            _collect_timing(self._dmaWB_block.t_dma_write_per_tile.pop(pid, []))
        )

        # Derived metrics (ratios fall back to 0.0 / None when the divisor is 0)
        rd["mac_utilization"] = rd["t_compute_total_ns"] / wall_clock_ns if wall_clock_ns > 0 else 0.0
        rd["effective_tflops"] = 2 * M * K * N / wall_clock_ns / 1e3 if wall_clock_ns > 0 else 0.0
        rd["pipeline_parallelism"] = (
            (rd["t_tcm_load_total_ns"] + rd["t_compute_total_ns"] + rd["t_tcm_store_total_ns"])
            / wall_clock_ns
            if wall_clock_ns > 0
            else 0.0
        )
        rd["effective_read_bw_gbs"] = (
            rd["total_dma_read_bytes"] / rd["t_dma_read_total_ns"]
            if rd["t_dma_read_total_ns"] > 0
            else None
        )
        rd["effective_write_bw_gbs"] = (
            rd["total_dma_write_bytes"] / rd["t_dma_write_total_ns"]
            if rd["t_dma_write_total_ns"] > 0
            else None
        )
        # Stage with the largest single-tile time among load / compute / store.
        rd["bottleneck_stage"] = max(
            [
                ("load", rd["t_tcm_load_max_ns"]),
                ("compute", rd["t_compute_max_ns"]),
                ("store", rd["t_tcm_store_max_ns"]),
            ],
            key=lambda x: x[1],
        )[0]

        # pe_cpu.py compatibility aliases
        rd["dma_ns"] = rd["t_dma_read_total_ns"] + rd["t_dma_write_total_ns"]
        rd["compute_ns"] = rd["t_compute_total_ns"]

        pe_txn.done.succeed()

    # -- Math composite --------------------------------------------------------

    def _dispatch_composite_math(
        self,
        env: simpy.Environment,
        pe_txn: Any,
        cmd: Any,
    ) -> Generator:
        """Run tiled element-wise math pipeline and collect metrics.

        Builds a MathPipeline over operand ``a`` (treated as a single row when
        its shape has fewer than two dimensions), loads its descriptors into
        the shared blocks, waits for ``pipeline.done``, then fills
        ``pe_txn.result_data`` and completes ``pe_txn.done``.
        """
        from kernbench.components.custom.pe_accel.scheduler.math_pipeline import MathPipeline

        assert self._dmaIN_cmd_q is not None
        assert self._dmaIN_to_fetch_trig is not None

        a = cmd.a
        M = a.shape[-2] if len(a.shape) >= 2 else 1
        N = a.shape[-1]
        op = cmd.math_op or "identity"

        pid, reply_q = self._alloc_pipeline(env)

        pipeline = MathPipeline(
            env=env,
            M=M,
            N=N,
            tile_m=self._tile_m,
            tile_n=self._tile_n,
            bytes_per_element=self._bpe,
            pipeline_id=pid,
            reply_queue=reply_q,
            dmaIN_cmd_q=self._dmaIN_cmd_q,
            dmaIN_to_fetch_trig=self._dmaIN_to_fetch_trig,
            op=op,
            src_addr=a.addr,
            dst_addr=cmd.out_addr,
        )

        # Load descriptors into shared blocks
        self._dmaIN_block.load_descriptors(pipeline.dmaIN_descs)
        self._math_block.load_math_op_descriptors(pipeline.math_op_descs)
        self._dmaWB_block.load_descriptors(pipeline.dmaWB_descs)

        start_ns = env.now
        yield pipeline.done
        wall_clock_ns = env.now - start_ns

        # Pipeline finished: stop routing completions to its reply queue.
        del self._pipeline_queues[pid]

        # -- Collect metrics ---------------------------------------------------
        rd = pe_txn.result_data
        rd["M_tiles"] = pipeline.M_tiles
        rd["N_tiles"] = pipeline.N_tiles
        rd["total_tiles"] = pipeline.M_tiles * pipeline.N_tiles
        rd["wall_clock_ns"] = wall_clock_ns
        rd["math_op"] = op

        # DMA timing
        rd["t_dma_read_per_request"], rd["t_dma_read_max_ns"], rd["t_dma_read_total_ns"] = (
            _collect_timing(self._dmaIN_block.t_dma_read_per_request.pop(pid, []))
        )
        rd["t_dma_write_per_tile"], rd["t_dma_write_max_ns"], rd["t_dma_write_total_ns"] = (
            _collect_timing(self._dmaWB_block.t_dma_write_per_tile.pop(pid, []))
        )

        # Math op timing (load + compute + store)
        rd["t_math_load_per_tile"], rd["t_math_load_max_ns"], rd["t_math_load_total_ns"] = (
            _collect_timing(self._math_block.t_math_op_load_per_tile.pop(pid, []))
        )
        rd["t_math_compute_per_tile"], rd["t_math_compute_max_ns"], rd["t_math_compute_total_ns"] = (
            _collect_timing(self._math_block.t_math_op_compute_per_tile.pop(pid, []))
        )
        rd["t_math_store_per_tile"], rd["t_math_store_max_ns"], rd["t_math_store_total_ns"] = (
            _collect_timing(self._math_block.t_math_op_store_per_tile.pop(pid, []))
        )

        # pe_cpu.py compatibility aliases
        rd["dma_ns"] = rd["t_dma_read_total_ns"] + rd["t_dma_write_total_ns"]
        rd["compute_ns"] = rd["t_math_compute_total_ns"]

        pe_txn.done.succeed()

    # -- Helpers ---------------------------------------------------------------

    def _alloc_pipeline(self, env: simpy.Environment) -> tuple[int, simpy.Store]:
        """Allocate a pipeline ID and reply queue.

        The queue is registered in ``_pipeline_queues`` under the new ID so
        that ``_completion_router`` can deliver triggers to it.
        """
        pid = self._next_pipeline_id
        self._next_pipeline_id += 1
        reply_q = simpy.Store(env)
        self._pipeline_queues[pid] = reply_q
        return pid, reply_q


# ==============================================================================
# Utility
# ==============================================================================


def _collect_timing(times: list) -> tuple[list, float, float]:
    """Return (raw_list, max_ns, total_ns) from a timing list.

    An empty list yields a max of 0.0 and a total of 0.
    """
    return times, max(times, default=0.0), sum(times)