1d95df4bee
- Move builtin_legacy/ → legacy/builtin/ (cleaner structure) - Move pe_accel_legacy/ → legacy/pe_accel/ - Remove custom/pe_accel/ (replaced by new builtin) - Remove pe_scheduler_v2 from components.yaml - Switch topology.yaml to pe_scheduler_v1 (new builtin) - Fix PE_DMA self-routing: handle consecutive DMA_READ stages (same component consecutive stages processed in-place, not via port) 382 tests passing. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
435 lines
17 KiB
Python
435 lines
17 KiB
Python
"""SchedulerV2Component: accelerator scheduler block (pe_scheduler_v2).

Replaces the pe_scheduler slot in topology (impl: pe_scheduler_v2).

Hosts four internal hardware blocks as concurrent SimPy processes, plus a
shared, bandwidth-serialized TCM scratchpad (TcmBlock) whose port is handed
to the GEMM and math blocks:
  - DmaInBlock — HBM → TCM tile reads (real fabric DMA)
  - DmaWbBlock — TCM → HBM tile writes (real fabric DMA)
  - GemmBlock  — 2-stage MAC pipeline (fetch + compute)
  - MathBlock  — K-accumulation (GEMM) + element-wise ops (exp, log, etc.)

Command dispatch (after the per-dispatch scheduler overhead, overhead_ns)
routes each PeInternalTxn to the correct engine or tiling pipeline:
  - DmaReadCmd / DmaWriteCmd          → PE_DMA out_port
  - GemmCmd                           → PE_GEMM out_port
  - MathCmd                           → PE_MATH out_port
  - CompositeCmd(op="gemm", b set)    → GemmPipeline (tiled DMA + GEMM + K-accum)
  - any other CompositeCmd            → MathPipeline (tiled DMA + element-wise + DMA)
  - PeCpuOverheadCmd                  → timeout of cycles / clock_freq_ghz
  - any other command                 → completed immediately
Messages that are not PeInternalTxn are forwarded as fabric transactions.

Config via node.attrs (all optional):
  overhead_ns, clock_freq_ghz, bytes_per_element,
  mac_m, mac_k, mac_n, tile_m, tile_k, tile_n,
  port_a_bw, port_b_bw, store_port_bw, vector_width
TCM read/write bandwidth comes from the topology spec
(cube.pe_template.links.gemm_to_tcm_bw_gbs / math_to_tcm_bw_gbs, default 512.0).
"""
|
|
from __future__ import annotations

from collections.abc import Generator
from typing import TYPE_CHECKING, Any

import simpy

from kernbench.components.base import ComponentBase

if TYPE_CHECKING:
    from kernbench.components.context import ComponentContext
    from kernbench.topology.types import Node
|
|
|
|
|
|
# ==============================================================================
|
|
# Component
|
|
# ==============================================================================
|
|
|
|
|
|
class SchedulerV2Component(ComponentBase):
    """PE accelerator scheduler: wires internal blocks and dispatches commands.

    ``__init__`` reads configuration from ``node.attrs``. ``start()`` builds
    the TCM, DMA-in, DMA-writeback, GEMM and math blocks, connects them with
    ``simpy.Store`` queues and launches their SimPy processes. Every inbox
    message is handled in its own process: ``PeInternalTxn`` goes through
    ``_dispatch``; anything else goes through ``_forward_txn``.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        """Read hardware configuration from ``node.attrs`` (all keys optional).

        Args:
            node: Topology node for this scheduler. The last dotted segment of
                ``node.id`` is stripped to obtain the PE prefix used to name
                sibling ports (e.g. ``f"{prefix}.pe_dma"``).
            ctx: Optional component context; ``ctx.spec`` is consulted in
                ``start()`` for the TCM link bandwidths.
        """
        super().__init__(node, ctx)
        self._pe_prefix: str = node.id.rsplit(".", 1)[0]
        attrs = node.attrs

        # Hardware config
        self._overhead_ns: float = float(attrs.get("overhead_ns", 0.0))
        self._clock_freq_ghz: float = float(attrs.get("clock_freq_ghz", 1.0))
        self._bpe: int = int(attrs.get("bytes_per_element", 2))

        # MAC array dimensions
        self._mac_m: int = int(attrs.get("mac_m", 8))
        self._mac_k: int = int(attrs.get("mac_k", 16))
        self._mac_n: int = int(attrs.get("mac_n", 32))

        # Tile dimensions
        self._tile_m: int = int(attrs.get("tile_m", 32))
        self._tile_k: int = int(attrs.get("tile_k", 32))
        self._tile_n: int = int(attrs.get("tile_n", 32))

        # Bandwidth (bytes/cycle)
        # NOTE(review): port_a_bw / port_b_bw / store_port_bw are stored but
        # not passed to any block in this file; only vector_width is used.
        self._port_a_bw: int = int(attrs.get("port_a_bw", 256))
        self._port_b_bw: int = int(attrs.get("port_b_bw", 256))
        self._store_port_bw: int = int(attrs.get("store_port_bw", 256))
        self._vector_width: int = int(attrs.get("vector_width", 256))

        # Initialized in start()
        self._tcm_block: Any = None
        self._dmaIN_block: Any = None
        self._dmaWB_block: Any = None
        self._gemm_block: Any = None
        self._math_block: Any = None
        self._dmaIN_cmd_q: simpy.Store | None = None
        self._dmaIN_to_fetch_trig: simpy.Store | None = None

        # Pipeline tracking: pipeline id -> reply queue fed by _completion_router
        self._next_pipeline_id: int = 0
        self._pipeline_queues: dict[int, simpy.Store] = {}

    # -- SimPy lifecycle -------------------------------------------------------

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Model the fixed scheduler overhead paid once per dispatch.

        ``nbytes`` is not used by this implementation.
        """
        yield env.timeout(self._overhead_ns)

    def start(self, env: simpy.Environment) -> None:
        """Create internal blocks, wire queues, start SimPy processes.

        Block processes are started first, then ``super().start(env)`` wires
        in-ports to the inbox and starts ``_worker``, and finally one
        completion router per block completion queue is started.
        """
        from kernbench.components.custom.pe_accel.blocks import (
            DmaInBlock, DmaWbBlock, GemmBlock, MathBlock, TcmBlock,
        )

        # .get(): None when no PE_DMA out-port is wired; passed through as-is.
        pe_dma_port = self.out_ports.get(f"{self._pe_prefix}.pe_dma")

        # -- TCM block (shared BW-serialized scratchpad) -----------------------

        # Read TCM BW from topology spec (gemm_to_tcm / math_to_tcm),
        # defaulting to 512.0 when the spec or a key is missing.
        pe_links = {}
        if self.ctx and self.ctx.spec:
            pe_links = self.ctx.spec.get("cube", {}).get("pe_template", {}).get("links", {})
        tcm_read_bw = float(pe_links.get("gemm_to_tcm_bw_gbs", 512.0))
        tcm_write_bw = float(pe_links.get("math_to_tcm_bw_gbs", 512.0))

        self._tcm_block = TcmBlock(
            env=env,
            read_bw_gbs=tcm_read_bw,
            write_bw_gbs=tcm_write_bw,
        )

        # -- Internal queues ---------------------------------------------------

        self._dmaIN_cmd_q = simpy.Store(env)
        self._dmaIN_to_fetch_trig = simpy.Store(env)
        dmaIN_to_math_trig = simpy.Store(env)  # DMA_IN → element-wise math
        fetch_to_gemm_trig = simpy.Store(env)
        gemm_to_math_trig = simpy.Store(env)
        gemm_to_dmaWB_trig = simpy.Store(env)
        math_to_dmaWB_trig = simpy.Store(env)

        # Completion queues (block → _completion_router → pipeline reply_q)
        dmaIN_completion_q = simpy.Store(env)
        dmaWB_completion_q = simpy.Store(env)
        gemm_completion_q = simpy.Store(env)
        math_completion_q = simpy.Store(env)

        # -- Create blocks -----------------------------------------------------

        self._dmaIN_block = DmaInBlock(
            env=env,
            cmd_q=self._dmaIN_cmd_q,
            to_fetch_trig=self._dmaIN_to_fetch_trig,
            to_math_trig=dmaIN_to_math_trig,
            completion_q=dmaIN_completion_q,
            pe_dma_port=pe_dma_port,
            pe_prefix=self._pe_prefix,
        )

        self._dmaWB_block = DmaWbBlock(
            env=env,
            completion_q=dmaWB_completion_q,
            pe_dma_port=pe_dma_port,
            pe_prefix=self._pe_prefix,
            bytes_per_element=self._bpe,
        )

        self._gemm_block = GemmBlock(
            env=env,
            trig_in=self._dmaIN_to_fetch_trig,
            fetch_to_gemm_trig=fetch_to_gemm_trig,
            to_math_trig=gemm_to_math_trig,
            to_dmaWB_trig=gemm_to_dmaWB_trig,
            completion_q=gemm_completion_q,
            tcm_port=self._tcm_block.port,
            mac_m=self._mac_m, mac_k=self._mac_k, mac_n=self._mac_n,
            bytes_per_element=self._bpe,
            clock_freq_ghz=self._clock_freq_ghz,
        )

        self._math_block = MathBlock(
            env=env,
            trig_in=gemm_to_math_trig,
            to_dmaWB_trig=math_to_dmaWB_trig,
            completion_q=math_completion_q,
            tcm_port=self._tcm_block.port,
            bytes_per_element=self._bpe,
            clock_freq_ghz=self._clock_freq_ghz,
            vector_width=self._vector_width,
        )

        # -- Start block processes ---------------------------------------------

        env.process(self._tcm_block._run())
        env.process(self._dmaIN_block._load_loop())
        # One DMA writeback flush loop per producer (GEMM path and math path).
        env.process(self._dmaWB_block._flush_loop(gemm_to_dmaWB_trig))
        env.process(self._dmaWB_block._flush_loop(math_to_dmaWB_trig))
        env.process(self._gemm_block._fetch_stage())
        env.process(self._gemm_block._gemm_stage())
        env.process(self._math_block._run_k_accumulation())
        env.process(self._math_block._run_element_wise(dmaIN_to_math_trig))

        # Wire in-ports → inbox, start _worker
        super().start(env)

        # Start completion routers
        for q in (dmaIN_completion_q, dmaWB_completion_q, gemm_completion_q, math_completion_q):
            env.process(self._completion_router(q))

    # -- Internal processes ----------------------------------------------------

    def _completion_router(self, queue: simpy.Store) -> Generator:
        """Route block completion triggers to the correct pipeline's reply queue.

        A ``None`` item stops the router. Triggers whose ``pipeline_id`` is no
        longer registered in ``_pipeline_queues`` are dropped.
        """
        while True:
            trigger = yield queue.get()
            if trigger is None:
                break
            reply_q = self._pipeline_queues.get(trigger.pipeline_id)
            if reply_q is not None:
                yield reply_q.put(trigger)

    def _worker(self, env: simpy.Environment) -> Generator:
        """Main inbox loop: dispatch PE commands, forward fabric transactions.

        Each message gets its own SimPy process, so a long-running composite
        pipeline does not block the inbox.
        """
        from kernbench.common.pe_commands import PeInternalTxn

        while True:
            msg: Any = yield self._inbox.get()
            if isinstance(msg, PeInternalTxn):
                env.process(self._dispatch(env, msg))
            else:
                env.process(self._forward_txn(env, msg))

    # ==========================================================================
    # Command Dispatch
    # ==========================================================================

    def _dispatch(self, env: simpy.Environment, pe_txn: Any) -> Generator:
        """Route a PeInternalTxn to the appropriate engine or pipeline.

        The scheduler overhead (``run``) is paid first. Primitive commands are
        put on the matching PE out-port. Composite commands run a tiled
        pipeline in-place. ``PeCpuOverheadCmd`` waits ``cycles /
        clock_freq_ghz`` and completes. Any other command is completed
        without doing work.
        """
        from kernbench.common.pe_commands import (
            CompositeCmd, DmaReadCmd, DmaWriteCmd,
            GemmCmd, MathCmd, PeCpuOverheadCmd,
        )

        yield from self.run(env, 0)  # scheduler overhead

        cmd = pe_txn.command
        pp = self._pe_prefix

        if isinstance(cmd, (DmaReadCmd, DmaWriteCmd)):
            yield self.out_ports[f"{pp}.pe_dma"].put(pe_txn)

        elif isinstance(cmd, GemmCmd):
            yield self.out_ports[f"{pp}.pe_gemm"].put(pe_txn)

        elif isinstance(cmd, MathCmd):
            yield self.out_ports[f"{pp}.pe_math"].put(pe_txn)

        elif isinstance(cmd, CompositeCmd):
            # A "gemm" composite without a B operand falls back to the math path.
            if cmd.op == "gemm" and cmd.b is not None:
                yield from self._dispatch_composite_gemm(env, pe_txn, cmd)
            else:
                yield from self._dispatch_composite_math(env, pe_txn, cmd)

        elif isinstance(cmd, PeCpuOverheadCmd):
            yield env.timeout(cmd.cycles / self._clock_freq_ghz)
            pe_txn.done.succeed()

        else:
            pe_txn.done.succeed()

    # -- GEMM composite --------------------------------------------------------

    def _dispatch_composite_gemm(
        self, env: simpy.Environment, pe_txn: Any, cmd: Any,
    ) -> Generator:
        """Run tiled GEMM pipeline and collect per-stage metrics.

        Uses the last two dims of ``cmd.a`` as (M, K) and the last dim of
        ``cmd.b`` as N. Writes tiling counts, byte totals, per-stage timings
        and derived metrics into ``pe_txn.result_data``, then succeeds
        ``pe_txn.done``. The pipeline's reply queue is always unregistered,
        even if the pipeline fails or this process is interrupted.
        """
        from kernbench.components.custom.pe_accel.scheduler.gemm_pipeline import GemmPipeline

        # start() must have run; matches the checks on the math path.
        assert self._dmaIN_cmd_q is not None
        assert self._dmaIN_to_fetch_trig is not None

        a, b = cmd.a, cmd.b
        M, K, N = a.shape[-2], a.shape[-1], b.shape[-1]

        pid, reply_q = self._alloc_pipeline(env)
        try:
            pipeline = GemmPipeline(
                env=env, M=M, K=K, N=N,
                tile_m=self._tile_m, tile_k=self._tile_k, tile_n=self._tile_n,
                bytes_per_element=self._bpe,
                pipeline_id=pid,
                reply_queue=reply_q,
                dmaIN_cmd_q=self._dmaIN_cmd_q,
                dmaIN_to_fetch_trig=self._dmaIN_to_fetch_trig,
                A_addr=a.addr, B_addr=b.addr, C_addr=cmd.out_addr,
            )

            # Load descriptors into shared blocks
            self._dmaIN_block.load_descriptors(pipeline.dmaIN_descs)
            self._gemm_block.load_descriptors(pipeline.gemm_descs)
            self._math_block.load_descriptors(pipeline.math_descs)
            self._dmaWB_block.load_descriptors(pipeline.dmaWB_descs)

            start_ns = env.now
            yield pipeline.done
            wall_clock_ns = env.now - start_ns
        finally:
            # Unregister on every exit path so _completion_router never keeps
            # feeding an orphaned reply queue.
            self._pipeline_queues.pop(pid, None)

        # -- Collect metrics ---------------------------------------------------
        rd = pe_txn.result_data
        rd["M_tiles"] = pipeline.M_tiles
        rd["K_tiles"] = pipeline.K_tiles
        rd["N_tiles"] = pipeline.N_tiles
        rd["total_tiles"] = pipeline.M_tiles * pipeline.K_tiles * pipeline.N_tiles
        rd["output_tiles"] = pipeline.M_tiles * pipeline.N_tiles
        rd["total_dma_read_bytes"] = sum(d.size_bytes for d in pipeline.dmaIN_descs.values())
        rd["total_dma_write_bytes"] = sum(
            d.Tm * d.Tn * self._bpe for d in pipeline.dmaWB_descs.values()
        )
        rd["wall_clock_ns"] = wall_clock_ns

        # Per-stage timing (each block keeps per-pipeline lists keyed by pid)
        rd["t_tcm_load_per_tile"], rd["t_tcm_load_max_ns"], rd["t_tcm_load_total_ns"] = (
            _collect_timing(self._gemm_block.t_tcm_load_per_tile.pop(pid, []))
        )
        rd["t_compute_per_tile"], rd["t_compute_max_ns"], rd["t_compute_total_ns"] = (
            _collect_timing(self._gemm_block.t_compute_per_tile.pop(pid, []))
        )
        rd["t_tcm_store_per_tile"], rd["t_tcm_store_max_ns"], rd["t_tcm_store_total_ns"] = (
            _collect_timing(self._math_block.t_tcm_store_per_tile.pop(pid, []))
        )
        rd["t_dma_read_per_request"], rd["t_dma_read_max_ns"], rd["t_dma_read_total_ns"] = (
            _collect_timing(self._dmaIN_block.t_dma_read_per_request.pop(pid, []))
        )
        rd["t_dma_write_per_tile"], rd["t_dma_write_max_ns"], rd["t_dma_write_total_ns"] = (
            _collect_timing(self._dmaWB_block.t_dma_write_per_tile.pop(pid, []))
        )

        # Derived metrics (FLOPs/ns = GFLOP/s, so /1e3 gives TFLOP/s;
        # bytes/ns = GB/s)
        rd["mac_utilization"] = rd["t_compute_total_ns"] / wall_clock_ns if wall_clock_ns > 0 else 0.0
        rd["effective_tflops"] = 2 * M * K * N / wall_clock_ns / 1e3 if wall_clock_ns > 0 else 0.0
        rd["pipeline_parallelism"] = (
            (rd["t_tcm_load_total_ns"] + rd["t_compute_total_ns"] + rd["t_tcm_store_total_ns"])
            / wall_clock_ns if wall_clock_ns > 0 else 0.0
        )
        rd["effective_read_bw_gbs"] = (
            rd["total_dma_read_bytes"] / rd["t_dma_read_total_ns"]
            if rd["t_dma_read_total_ns"] > 0 else None
        )
        rd["effective_write_bw_gbs"] = (
            rd["total_dma_write_bytes"] / rd["t_dma_write_total_ns"]
            if rd["t_dma_write_total_ns"] > 0 else None
        )
        # Stage with the largest single-tile time; ties resolve to the first.
        rd["bottleneck_stage"] = max(
            [("load", rd["t_tcm_load_max_ns"]), ("compute", rd["t_compute_max_ns"]),
             ("store", rd["t_tcm_store_max_ns"])],
            key=lambda x: x[1],
        )[0]

        # pe_cpu.py compatibility aliases
        rd["dma_ns"] = rd["t_dma_read_total_ns"] + rd["t_dma_write_total_ns"]
        rd["compute_ns"] = rd["t_compute_total_ns"]

        pe_txn.done.succeed()

    # -- Math composite --------------------------------------------------------

    def _dispatch_composite_math(
        self, env: simpy.Environment, pe_txn: Any, cmd: Any,
    ) -> Generator:
        """Run tiled element-wise math pipeline and collect metrics.

        Uses the last dim of ``cmd.a`` as N and the second-to-last as M (1
        for 1-D inputs). ``cmd.math_op`` defaults to ``"identity"``. Writes
        tiling counts, DMA and math-stage timings into
        ``pe_txn.result_data``, then succeeds ``pe_txn.done``. The pipeline's
        reply queue is always unregistered, even on failure or interrupt.
        """
        from kernbench.components.custom.pe_accel.scheduler.math_pipeline import MathPipeline

        assert self._dmaIN_cmd_q is not None
        assert self._dmaIN_to_fetch_trig is not None

        a = cmd.a
        M = a.shape[-2] if len(a.shape) >= 2 else 1
        N = a.shape[-1]
        op = cmd.math_op or "identity"

        pid, reply_q = self._alloc_pipeline(env)
        try:
            pipeline = MathPipeline(
                env=env, M=M, N=N,
                tile_m=self._tile_m, tile_n=self._tile_n,
                bytes_per_element=self._bpe,
                pipeline_id=pid,
                reply_queue=reply_q,
                dmaIN_cmd_q=self._dmaIN_cmd_q,
                dmaIN_to_fetch_trig=self._dmaIN_to_fetch_trig,
                op=op,
                src_addr=a.addr, dst_addr=cmd.out_addr,
            )

            # Load descriptors into shared blocks
            self._dmaIN_block.load_descriptors(pipeline.dmaIN_descs)
            self._math_block.load_math_op_descriptors(pipeline.math_op_descs)
            self._dmaWB_block.load_descriptors(pipeline.dmaWB_descs)

            start_ns = env.now
            yield pipeline.done
            wall_clock_ns = env.now - start_ns
        finally:
            # Unregister on every exit path so _completion_router never keeps
            # feeding an orphaned reply queue.
            self._pipeline_queues.pop(pid, None)

        # -- Collect metrics ---------------------------------------------------
        rd = pe_txn.result_data
        rd["M_tiles"] = pipeline.M_tiles
        rd["N_tiles"] = pipeline.N_tiles
        rd["total_tiles"] = pipeline.M_tiles * pipeline.N_tiles
        rd["wall_clock_ns"] = wall_clock_ns
        rd["math_op"] = op

        # DMA timing
        rd["t_dma_read_per_request"], rd["t_dma_read_max_ns"], rd["t_dma_read_total_ns"] = (
            _collect_timing(self._dmaIN_block.t_dma_read_per_request.pop(pid, []))
        )
        rd["t_dma_write_per_tile"], rd["t_dma_write_max_ns"], rd["t_dma_write_total_ns"] = (
            _collect_timing(self._dmaWB_block.t_dma_write_per_tile.pop(pid, []))
        )

        # Math op timing (load + compute + store)
        rd["t_math_load_per_tile"], rd["t_math_load_max_ns"], rd["t_math_load_total_ns"] = (
            _collect_timing(self._math_block.t_math_op_load_per_tile.pop(pid, []))
        )
        rd["t_math_compute_per_tile"], rd["t_math_compute_max_ns"], rd["t_math_compute_total_ns"] = (
            _collect_timing(self._math_block.t_math_op_compute_per_tile.pop(pid, []))
        )
        rd["t_math_store_per_tile"], rd["t_math_store_max_ns"], rd["t_math_store_total_ns"] = (
            _collect_timing(self._math_block.t_math_op_store_per_tile.pop(pid, []))
        )

        # pe_cpu.py compatibility aliases
        rd["dma_ns"] = rd["t_dma_read_total_ns"] + rd["t_dma_write_total_ns"]
        rd["compute_ns"] = rd["t_math_compute_total_ns"]

        pe_txn.done.succeed()

    # -- Helpers ---------------------------------------------------------------

    def _alloc_pipeline(self, env: simpy.Environment) -> tuple[int, simpy.Store]:
        """Allocate a pipeline ID and register a fresh reply queue for it.

        IDs increase monotonically and are never reused by this component.
        The caller is responsible for removing the entry from
        ``_pipeline_queues`` when the pipeline finishes.
        """
        pid = self._next_pipeline_id
        self._next_pipeline_id += 1
        reply_q = simpy.Store(env)
        self._pipeline_queues[pid] = reply_q
        return pid, reply_q
|
|
|
|
|
|
# ==============================================================================
|
|
# Utility
|
|
# ==============================================================================
|
|
|
|
def _collect_timing(times: list) -> tuple[list, float, float]:
|
|
"""Return (raw_list, max_ns, total_ns) from a timing list."""
|
|
return times, max(times, default=0.0), sum(times)
|