Files
kernbench2/src/kernbench/components/legacy/pe_accel/scheduler/scheduler.py
T
ywkang 1d95df4bee Restructure legacy backups, remove pe_accel, fix DMA self-routing
- Move builtin_legacy/ → legacy/builtin/ (cleaner structure)
- Move pe_accel_legacy/ → legacy/pe_accel/
- Remove custom/pe_accel/ (replaced by new builtin)
- Remove pe_scheduler_v2 from components.yaml
- Switch topology.yaml to pe_scheduler_v1 (new builtin)
- Fix PE_DMA self-routing: handle consecutive DMA_READ stages
  (same component consecutive stages processed in-place, not via port)

382 tests passing.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-09 00:02:26 -07:00

435 lines
17 KiB
Python

"""SchedulerV2Component: accelerator scheduler block (pe_scheduler_v2).

Replaces the pe_scheduler slot in topology (impl: pe_scheduler_v2).
NOTE(review): this copy lives under ``components/legacy/pe_accel/``; the
pe_scheduler_v2 registration may no longer exist in components.yaml —
confirm before wiring this component into a topology.

Hosts five internal hardware blocks as concurrent SimPy processes:
  - TcmBlock   — shared, bandwidth-serialized TCM scratchpad
  - DmaInBlock — HBM → TCM tile reads (real fabric DMA)
  - DmaWbBlock — TCM → HBM tile writes (real fabric DMA)
  - GemmBlock  — 2-stage MAC pipeline (fetch + compute)
  - MathBlock  — K-accumulation (GEMM) + element-wise ops (exp, log, etc.)

Command dispatch routes PeInternalTxn to the correct engine or tiling pipeline:
  - DmaReadCmd / DmaWriteCmd → PE_DMA out_port
  - GemmCmd                  → PE_GEMM out_port
  - MathCmd                  → PE_MATH out_port
  - CompositeCmd(op="gemm")  → GemmPipeline (tiled DMA + GEMM + K-accum)
  - CompositeCmd(op="math")  → MathPipeline (tiled DMA + element-wise + DMA)
  - PeCpuOverheadCmd         → yield timeout
  - anything else            → completed immediately

Config via node.attrs (all optional):
  overhead_ns, clock_freq_ghz, bytes_per_element,
  mac_m, mac_k, mac_n, tile_m, tile_k, tile_n,
  port_a_bw, port_b_bw, store_port_bw, vector_width
"""
from __future__ import annotations

from collections.abc import Generator, Iterable
from typing import TYPE_CHECKING, Any

import simpy

from kernbench.components.base import ComponentBase

if TYPE_CHECKING:
    from kernbench.components.context import ComponentContext
    from kernbench.topology.types import Node
# ==============================================================================
# Component
# ==============================================================================
class SchedulerV2Component(ComponentBase):
    """PE accelerator scheduler: wires internal blocks and dispatches commands.

    ``start()`` builds the internal TCM / DMA-in / DMA write-back / GEMM / math
    blocks and their trigger queues. ``_worker`` then hands every
    ``PeInternalTxn`` from the inbox to ``_dispatch`` and forwards any other
    message via ``_forward_txn``.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        # Drop the last dotted component of the node id; used to build the
        # keys of sibling out-ports (e.g. f"{prefix}.pe_dma").
        self._pe_prefix: str = node.id.rsplit(".", 1)[0]
        attrs = node.attrs

        # Hardware config
        self._overhead_ns: float = float(attrs.get("overhead_ns", 0.0))
        self._clock_freq_ghz: float = float(attrs.get("clock_freq_ghz", 1.0))
        self._bpe: int = int(attrs.get("bytes_per_element", 2))

        # MAC array dimensions
        self._mac_m: int = int(attrs.get("mac_m", 8))
        self._mac_k: int = int(attrs.get("mac_k", 16))
        self._mac_n: int = int(attrs.get("mac_n", 32))

        # Tile dimensions
        self._tile_m: int = int(attrs.get("tile_m", 32))
        self._tile_k: int = int(attrs.get("tile_k", 32))
        self._tile_n: int = int(attrs.get("tile_n", 32))

        # Bandwidth (bytes/cycle)
        self._port_a_bw: int = int(attrs.get("port_a_bw", 256))
        self._port_b_bw: int = int(attrs.get("port_b_bw", 256))
        self._store_port_bw: int = int(attrs.get("store_port_bw", 256))
        self._vector_width: int = int(attrs.get("vector_width", 256))

        # Initialized in start()
        self._tcm_block: Any = None
        self._dmaIN_block: Any = None
        self._dmaWB_block: Any = None
        self._gemm_block: Any = None
        self._math_block: Any = None
        self._dmaIN_cmd_q: simpy.Store | None = None
        self._dmaIN_to_fetch_trig: simpy.Store | None = None

        # Pipeline tracking: pipeline_id -> reply queue fed by _completion_router
        self._next_pipeline_id: int = 0
        self._pipeline_queues: dict[int, simpy.Store] = {}

    # -- SimPy lifecycle -------------------------------------------------------

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Scheduler overhead per dispatch (``nbytes`` is unused here)."""
        yield env.timeout(self._overhead_ns)

    def start(self, env: simpy.Environment) -> None:
        """Create internal blocks, wire queues, start SimPy processes."""
        # Package-relative import: this module now lives under
        # components/legacy/pe_accel/ and the old components/custom/pe_accel/
        # package was removed, so the absolute custom.* path no longer resolves.
        from ..blocks import (
            DmaInBlock, DmaWbBlock, GemmBlock, MathBlock, TcmBlock,
        )

        # May be None if the topology does not wire a pe_dma out-port.
        pe_dma_port = self.out_ports.get(f"{self._pe_prefix}.pe_dma")

        # -- TCM block (shared BW-serialized scratchpad) -----------------------
        # Read TCM BW from topology spec (gemm_to_tcm / math_to_tcm)
        pe_links = {}
        if self.ctx and self.ctx.spec:
            pe_links = self.ctx.spec.get("cube", {}).get("pe_template", {}).get("links", {})
        tcm_read_bw = float(pe_links.get("gemm_to_tcm_bw_gbs", 512.0))
        tcm_write_bw = float(pe_links.get("math_to_tcm_bw_gbs", 512.0))
        self._tcm_block = TcmBlock(
            env=env,
            read_bw_gbs=tcm_read_bw,
            write_bw_gbs=tcm_write_bw,
        )

        # -- Internal queues ---------------------------------------------------
        self._dmaIN_cmd_q = simpy.Store(env)
        self._dmaIN_to_fetch_trig = simpy.Store(env)
        dmaIN_to_math_trig = simpy.Store(env)  # DMA_IN → element-wise math
        fetch_to_gemm_trig = simpy.Store(env)
        gemm_to_math_trig = simpy.Store(env)
        gemm_to_dmaWB_trig = simpy.Store(env)
        math_to_dmaWB_trig = simpy.Store(env)

        # Completion queues (block → _completion_router → pipeline reply_q)
        dmaIN_completion_q = simpy.Store(env)
        dmaWB_completion_q = simpy.Store(env)
        gemm_completion_q = simpy.Store(env)
        math_completion_q = simpy.Store(env)

        # -- Create blocks -----------------------------------------------------
        self._dmaIN_block = DmaInBlock(
            env=env,
            cmd_q=self._dmaIN_cmd_q,
            to_fetch_trig=self._dmaIN_to_fetch_trig,
            to_math_trig=dmaIN_to_math_trig,
            completion_q=dmaIN_completion_q,
            pe_dma_port=pe_dma_port,
            pe_prefix=self._pe_prefix,
        )
        self._dmaWB_block = DmaWbBlock(
            env=env,
            completion_q=dmaWB_completion_q,
            pe_dma_port=pe_dma_port,
            pe_prefix=self._pe_prefix,
            bytes_per_element=self._bpe,
        )
        self._gemm_block = GemmBlock(
            env=env,
            trig_in=self._dmaIN_to_fetch_trig,
            fetch_to_gemm_trig=fetch_to_gemm_trig,
            to_math_trig=gemm_to_math_trig,
            to_dmaWB_trig=gemm_to_dmaWB_trig,
            completion_q=gemm_completion_q,
            tcm_port=self._tcm_block.port,
            mac_m=self._mac_m, mac_k=self._mac_k, mac_n=self._mac_n,
            bytes_per_element=self._bpe,
            clock_freq_ghz=self._clock_freq_ghz,
        )
        self._math_block = MathBlock(
            env=env,
            trig_in=gemm_to_math_trig,
            to_dmaWB_trig=math_to_dmaWB_trig,
            completion_q=math_completion_q,
            tcm_port=self._tcm_block.port,
            bytes_per_element=self._bpe,
            clock_freq_ghz=self._clock_freq_ghz,
            vector_width=self._vector_width,
        )

        # -- Start block processes ---------------------------------------------
        env.process(self._tcm_block._run())
        env.process(self._dmaIN_block._load_loop())
        # One write-back flush loop per producer (GEMM output and math output).
        env.process(self._dmaWB_block._flush_loop(gemm_to_dmaWB_trig))
        env.process(self._dmaWB_block._flush_loop(math_to_dmaWB_trig))
        env.process(self._gemm_block._fetch_stage())
        env.process(self._gemm_block._gemm_stage())
        env.process(self._math_block._run_k_accumulation())
        env.process(self._math_block._run_element_wise(dmaIN_to_math_trig))

        # Wire in-ports → inbox, start _worker
        super().start(env)

        # Start completion routers
        for q in (dmaIN_completion_q, dmaWB_completion_q, gemm_completion_q, math_completion_q):
            env.process(self._completion_router(q))

    # -- Internal processes ----------------------------------------------------

    def _completion_router(self, queue: simpy.Store) -> Generator:
        """Route block completion triggers to the correct pipeline's reply queue.

        A ``None`` item stops the router. Triggers whose pipeline has already
        been released are dropped.
        """
        while True:
            trigger = yield queue.get()
            if trigger is None:
                break
            reply_q = self._pipeline_queues.get(trigger.pipeline_id)
            if reply_q is not None:
                yield reply_q.put(trigger)

    def _worker(self, env: simpy.Environment) -> Generator:
        """Main inbox loop: dispatch PE commands, forward fabric transactions."""
        from kernbench.common.pe_commands import PeInternalTxn

        while True:
            msg: Any = yield self._inbox.get()
            if isinstance(msg, PeInternalTxn):
                env.process(self._dispatch(env, msg))
            else:
                env.process(self._forward_txn(env, msg))

    # ==========================================================================
    # Command Dispatch
    # ==========================================================================

    def _dispatch(self, env: simpy.Environment, pe_txn: Any) -> Generator:
        """Route a PeInternalTxn to the appropriate engine or pipeline."""
        from kernbench.common.pe_commands import (
            CompositeCmd, DmaReadCmd, DmaWriteCmd,
            GemmCmd, MathCmd, PeCpuOverheadCmd,
        )

        yield from self.run(env, 0)  # scheduler overhead

        cmd = pe_txn.command
        pp = self._pe_prefix

        if isinstance(cmd, (DmaReadCmd, DmaWriteCmd)):
            yield self.out_ports[f"{pp}.pe_dma"].put(pe_txn)
        elif isinstance(cmd, GemmCmd):
            yield self.out_ports[f"{pp}.pe_gemm"].put(pe_txn)
        elif isinstance(cmd, MathCmd):
            yield self.out_ports[f"{pp}.pe_math"].put(pe_txn)
        elif isinstance(cmd, CompositeCmd):
            # A composite "gemm" needs a second operand; anything else takes
            # the element-wise path.
            if cmd.op == "gemm" and cmd.b is not None:
                yield from self._dispatch_composite_gemm(env, pe_txn, cmd)
            else:
                yield from self._dispatch_composite_math(env, pe_txn, cmd)
        elif isinstance(cmd, PeCpuOverheadCmd):
            # cycles / GHz == ns
            yield env.timeout(cmd.cycles / self._clock_freq_ghz)
            pe_txn.done.succeed()
        else:
            # Unknown command types are completed immediately without work.
            pe_txn.done.succeed()

    # -- GEMM composite --------------------------------------------------------

    def _dispatch_composite_gemm(
        self, env: simpy.Environment, pe_txn: Any, cmd: Any,
    ) -> Generator:
        """Run tiled GEMM pipeline and collect per-stage metrics into ``result_data``."""
        from .gemm_pipeline import GemmPipeline

        assert self._dmaIN_cmd_q is not None
        assert self._dmaIN_to_fetch_trig is not None

        a, b = cmd.a, cmd.b
        M, K, N = a.shape[-2], a.shape[-1], b.shape[-1]

        pid, reply_q = self._alloc_pipeline(env)
        try:
            pipeline = GemmPipeline(
                env=env, M=M, K=K, N=N,
                tile_m=self._tile_m, tile_k=self._tile_k, tile_n=self._tile_n,
                bytes_per_element=self._bpe,
                pipeline_id=pid,
                reply_queue=reply_q,
                dmaIN_cmd_q=self._dmaIN_cmd_q,
                dmaIN_to_fetch_trig=self._dmaIN_to_fetch_trig,
                A_addr=a.addr, B_addr=b.addr, C_addr=cmd.out_addr,
            )

            # Load descriptors into shared blocks
            self._dmaIN_block.load_descriptors(pipeline.dmaIN_descs)
            self._gemm_block.load_descriptors(pipeline.gemm_descs)
            self._math_block.load_descriptors(pipeline.math_descs)
            self._dmaWB_block.load_descriptors(pipeline.dmaWB_descs)

            start_ns = env.now
            yield pipeline.done
            wall_clock_ns = env.now - start_ns
        finally:
            # Release the reply queue on success AND failure, so a failed
            # pipeline does not leave a stale entry behind.
            self._pipeline_queues.pop(pid, None)

        # -- Collect metrics ---------------------------------------------------
        rd = pe_txn.result_data
        rd["M_tiles"] = pipeline.M_tiles
        rd["K_tiles"] = pipeline.K_tiles
        rd["N_tiles"] = pipeline.N_tiles
        rd["total_tiles"] = pipeline.M_tiles * pipeline.K_tiles * pipeline.N_tiles
        rd["output_tiles"] = pipeline.M_tiles * pipeline.N_tiles
        rd["total_dma_read_bytes"] = sum(d.size_bytes for d in pipeline.dmaIN_descs.values())
        rd["total_dma_write_bytes"] = sum(
            d.Tm * d.Tn * self._bpe for d in pipeline.dmaWB_descs.values()
        )
        rd["wall_clock_ns"] = wall_clock_ns

        # Per-stage timing (each stage's per-tile list is popped so the shared
        # block does not keep it after this pipeline finishes)
        rd["t_tcm_load_per_tile"], rd["t_tcm_load_max_ns"], rd["t_tcm_load_total_ns"] = (
            _collect_timing(self._gemm_block.t_tcm_load_per_tile.pop(pid, []))
        )
        rd["t_compute_per_tile"], rd["t_compute_max_ns"], rd["t_compute_total_ns"] = (
            _collect_timing(self._gemm_block.t_compute_per_tile.pop(pid, []))
        )
        rd["t_tcm_store_per_tile"], rd["t_tcm_store_max_ns"], rd["t_tcm_store_total_ns"] = (
            _collect_timing(self._math_block.t_tcm_store_per_tile.pop(pid, []))
        )
        rd["t_dma_read_per_request"], rd["t_dma_read_max_ns"], rd["t_dma_read_total_ns"] = (
            _collect_timing(self._dmaIN_block.t_dma_read_per_request.pop(pid, []))
        )
        rd["t_dma_write_per_tile"], rd["t_dma_write_max_ns"], rd["t_dma_write_total_ns"] = (
            _collect_timing(self._dmaWB_block.t_dma_write_per_tile.pop(pid, []))
        )

        # Derived metrics (ratios guarded against a zero-length run)
        rd["mac_utilization"] = rd["t_compute_total_ns"] / wall_clock_ns if wall_clock_ns > 0 else 0.0
        # 2*M*K*N FLOPs / ns == GFLOP/s; / 1e3 -> TFLOP/s
        rd["effective_tflops"] = 2 * M * K * N / wall_clock_ns / 1e3 if wall_clock_ns > 0 else 0.0
        # > 1.0 means load/compute/store stages overlapped in time
        rd["pipeline_parallelism"] = (
            (rd["t_tcm_load_total_ns"] + rd["t_compute_total_ns"] + rd["t_tcm_store_total_ns"])
            / wall_clock_ns if wall_clock_ns > 0 else 0.0
        )
        # bytes / ns == GB/s
        rd["effective_read_bw_gbs"] = (
            rd["total_dma_read_bytes"] / rd["t_dma_read_total_ns"]
            if rd["t_dma_read_total_ns"] > 0 else None
        )
        rd["effective_write_bw_gbs"] = (
            rd["total_dma_write_bytes"] / rd["t_dma_write_total_ns"]
            if rd["t_dma_write_total_ns"] > 0 else None
        )
        rd["bottleneck_stage"] = max(
            [("load", rd["t_tcm_load_max_ns"]), ("compute", rd["t_compute_max_ns"]),
             ("store", rd["t_tcm_store_max_ns"])],
            key=lambda x: x[1],
        )[0]

        # pe_cpu.py compatibility aliases
        rd["dma_ns"] = rd["t_dma_read_total_ns"] + rd["t_dma_write_total_ns"]
        rd["compute_ns"] = rd["t_compute_total_ns"]

        pe_txn.done.succeed()

    # -- Math composite --------------------------------------------------------

    def _dispatch_composite_math(
        self, env: simpy.Environment, pe_txn: Any, cmd: Any,
    ) -> Generator:
        """Run tiled element-wise math pipeline and collect metrics into ``result_data``."""
        from .math_pipeline import MathPipeline

        assert self._dmaIN_cmd_q is not None
        assert self._dmaIN_to_fetch_trig is not None

        a = cmd.a
        # 1-D operands are treated as a single row.
        M = a.shape[-2] if len(a.shape) >= 2 else 1
        N = a.shape[-1]
        op = cmd.math_op or "identity"

        pid, reply_q = self._alloc_pipeline(env)
        try:
            pipeline = MathPipeline(
                env=env, M=M, N=N,
                tile_m=self._tile_m, tile_n=self._tile_n,
                bytes_per_element=self._bpe,
                pipeline_id=pid,
                reply_queue=reply_q,
                dmaIN_cmd_q=self._dmaIN_cmd_q,
                dmaIN_to_fetch_trig=self._dmaIN_to_fetch_trig,
                op=op,
                src_addr=a.addr, dst_addr=cmd.out_addr,
            )

            # Load descriptors into shared blocks
            self._dmaIN_block.load_descriptors(pipeline.dmaIN_descs)
            self._math_block.load_math_op_descriptors(pipeline.math_op_descs)
            self._dmaWB_block.load_descriptors(pipeline.dmaWB_descs)

            start_ns = env.now
            yield pipeline.done
            wall_clock_ns = env.now - start_ns
        finally:
            # Release the reply queue on success AND failure.
            self._pipeline_queues.pop(pid, None)

        # -- Collect metrics ---------------------------------------------------
        rd = pe_txn.result_data
        rd["M_tiles"] = pipeline.M_tiles
        rd["N_tiles"] = pipeline.N_tiles
        rd["total_tiles"] = pipeline.M_tiles * pipeline.N_tiles
        rd["wall_clock_ns"] = wall_clock_ns
        rd["math_op"] = op

        # DMA timing
        rd["t_dma_read_per_request"], rd["t_dma_read_max_ns"], rd["t_dma_read_total_ns"] = (
            _collect_timing(self._dmaIN_block.t_dma_read_per_request.pop(pid, []))
        )
        rd["t_dma_write_per_tile"], rd["t_dma_write_max_ns"], rd["t_dma_write_total_ns"] = (
            _collect_timing(self._dmaWB_block.t_dma_write_per_tile.pop(pid, []))
        )

        # Math op timing (load + compute + store)
        rd["t_math_load_per_tile"], rd["t_math_load_max_ns"], rd["t_math_load_total_ns"] = (
            _collect_timing(self._math_block.t_math_op_load_per_tile.pop(pid, []))
        )
        rd["t_math_compute_per_tile"], rd["t_math_compute_max_ns"], rd["t_math_compute_total_ns"] = (
            _collect_timing(self._math_block.t_math_op_compute_per_tile.pop(pid, []))
        )
        rd["t_math_store_per_tile"], rd["t_math_store_max_ns"], rd["t_math_store_total_ns"] = (
            _collect_timing(self._math_block.t_math_op_store_per_tile.pop(pid, []))
        )

        # pe_cpu.py compatibility aliases
        rd["dma_ns"] = rd["t_dma_read_total_ns"] + rd["t_dma_write_total_ns"]
        rd["compute_ns"] = rd["t_math_compute_total_ns"]

        pe_txn.done.succeed()

    # -- Helpers ---------------------------------------------------------------

    def _alloc_pipeline(self, env: simpy.Environment) -> tuple[int, simpy.Store]:
        """Allocate a fresh pipeline ID and register its reply queue."""
        pid = self._next_pipeline_id
        self._next_pipeline_id += 1
        reply_q = simpy.Store(env)
        self._pipeline_queues[pid] = reply_q
        return pid, reply_q
# ==============================================================================
# Utility
# ==============================================================================
def _collect_timing(times: list) -> tuple[list, float, float]:
"""Return (raw_list, max_ns, total_ns) from a timing list."""
return times, max(times, default=0.0), sum(times)