Restructure legacy backups, remove pe_accel, fix DMA self-routing
- Move builtin_legacy/ → legacy/builtin/ (cleaner structure) - Move pe_accel_legacy/ → legacy/pe_accel/ - Remove custom/pe_accel/ (replaced by new builtin) - Remove pe_scheduler_v2 from components.yaml - Switch topology.yaml to pe_scheduler_v1 (new builtin) - Fix PE_DMA self-routing: handle consecutive DMA_READ stages (same component consecutive stages processed in-place, not via port) 382 tests passing. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,434 @@
|
||||
"""SchedulerV2Component: accelerator scheduler block (pe_scheduler_v2).
|
||||
|
||||
Replaces the pe_scheduler slot in topology (impl: pe_scheduler_v2).
|
||||
|
||||
Hosts four internal hardware blocks as concurrent SimPy processes:
|
||||
- DmaInBlock — HBM → TCM tile reads (real fabric DMA)
|
||||
- DmaWbBlock — TCM → HBM tile writes (real fabric DMA)
|
||||
- GemmBlock — 2-stage MAC pipeline (fetch + compute)
|
||||
- MathBlock — K-accumulation (GEMM) + element-wise ops (exp, log, etc.)
|
||||
|
||||
Command dispatch routes PeInternalTxn to the correct engine or tiling pipeline:
|
||||
- DmaReadCmd / DmaWriteCmd → PE_DMA out_port
|
||||
- GemmCmd → PE_GEMM out_port
|
||||
- MathCmd → PE_MATH out_port
|
||||
- CompositeCmd(op="gemm") → GemmPipeline (tiled DMA + GEMM + K-accum)
|
||||
- CompositeCmd(op="math") → MathPipeline (tiled DMA + element-wise + DMA)
|
||||
- PeCpuOverheadCmd → yield timeout
|
||||
|
||||
Config via node.attrs (all optional):
|
||||
overhead_ns, clock_freq_ghz, bytes_per_element,
|
||||
mac_m, mac_k, mac_n, tile_m, tile_k, tile_n,
|
||||
port_a_bw, port_b_bw, store_port_bw, vector_width
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import ComponentBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Component
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class SchedulerV2Component(ComponentBase):
|
||||
"""PE accelerator scheduler: wires internal blocks and dispatches commands."""
|
||||
|
||||
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
|
||||
super().__init__(node, ctx)
|
||||
self._pe_prefix: str = node.id.rsplit(".", 1)[0]
|
||||
attrs = node.attrs
|
||||
|
||||
# Hardware config
|
||||
self._overhead_ns: float = float(attrs.get("overhead_ns", 0.0))
|
||||
self._clock_freq_ghz: float = float(attrs.get("clock_freq_ghz", 1.0))
|
||||
self._bpe: int = int(attrs.get("bytes_per_element", 2))
|
||||
|
||||
# MAC array dimensions
|
||||
self._mac_m: int = int(attrs.get("mac_m", 8))
|
||||
self._mac_k: int = int(attrs.get("mac_k", 16))
|
||||
self._mac_n: int = int(attrs.get("mac_n", 32))
|
||||
|
||||
# Tile dimensions
|
||||
self._tile_m: int = int(attrs.get("tile_m", 32))
|
||||
self._tile_k: int = int(attrs.get("tile_k", 32))
|
||||
self._tile_n: int = int(attrs.get("tile_n", 32))
|
||||
|
||||
# Bandwidth (bytes/cycle)
|
||||
self._port_a_bw: int = int(attrs.get("port_a_bw", 256))
|
||||
self._port_b_bw: int = int(attrs.get("port_b_bw", 256))
|
||||
self._store_port_bw: int = int(attrs.get("store_port_bw", 256))
|
||||
self._vector_width: int = int(attrs.get("vector_width", 256))
|
||||
|
||||
# Initialized in start()
|
||||
self._dmaIN_block: Any = None
|
||||
self._dmaWB_block: Any = None
|
||||
self._gemm_block: Any = None
|
||||
self._math_block: Any = None
|
||||
self._dmaIN_cmd_q: simpy.Store | None = None
|
||||
self._dmaIN_to_fetch_trig: simpy.Store | None = None
|
||||
|
||||
# Pipeline tracking
|
||||
self._next_pipeline_id: int = 0
|
||||
self._pipeline_queues: dict[int, simpy.Store] = {}
|
||||
|
||||
# -- SimPy lifecycle -------------------------------------------------------
|
||||
|
||||
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
|
||||
"""Scheduler overhead per dispatch."""
|
||||
yield env.timeout(self._overhead_ns)
|
||||
|
||||
def start(self, env: simpy.Environment) -> None:
|
||||
"""Create internal blocks, wire queues, start SimPy processes."""
|
||||
from kernbench.components.custom.pe_accel.blocks import (
|
||||
DmaInBlock, DmaWbBlock, GemmBlock, MathBlock, TcmBlock,
|
||||
)
|
||||
|
||||
pe_dma_port = self.out_ports.get(f"{self._pe_prefix}.pe_dma")
|
||||
|
||||
# -- TCM block (shared BW-serialized scratchpad) -----------------------
|
||||
|
||||
# Read TCM BW from topology spec (gemm_to_tcm / math_to_tcm)
|
||||
pe_links = {}
|
||||
if self.ctx and self.ctx.spec:
|
||||
pe_links = self.ctx.spec.get("cube", {}).get("pe_template", {}).get("links", {})
|
||||
tcm_read_bw = float(pe_links.get("gemm_to_tcm_bw_gbs", 512.0))
|
||||
tcm_write_bw = float(pe_links.get("math_to_tcm_bw_gbs", 512.0))
|
||||
|
||||
self._tcm_block = TcmBlock(
|
||||
env=env,
|
||||
read_bw_gbs=tcm_read_bw,
|
||||
write_bw_gbs=tcm_write_bw,
|
||||
)
|
||||
|
||||
# -- Internal queues ---------------------------------------------------
|
||||
|
||||
self._dmaIN_cmd_q = simpy.Store(env)
|
||||
self._dmaIN_to_fetch_trig = simpy.Store(env)
|
||||
dmaIN_to_math_trig = simpy.Store(env) # DMA_IN → element-wise math
|
||||
fetch_to_gemm_trig = simpy.Store(env)
|
||||
gemm_to_math_trig = simpy.Store(env)
|
||||
gemm_to_dmaWB_trig = simpy.Store(env)
|
||||
math_to_dmaWB_trig = simpy.Store(env)
|
||||
|
||||
# Completion queues (block → _completion_router → pipeline reply_q)
|
||||
dmaIN_completion_q = simpy.Store(env)
|
||||
dmaWB_completion_q = simpy.Store(env)
|
||||
gemm_completion_q = simpy.Store(env)
|
||||
math_completion_q = simpy.Store(env)
|
||||
|
||||
# -- Create blocks -----------------------------------------------------
|
||||
|
||||
self._dmaIN_block = DmaInBlock(
|
||||
env=env,
|
||||
cmd_q=self._dmaIN_cmd_q,
|
||||
to_fetch_trig=self._dmaIN_to_fetch_trig,
|
||||
to_math_trig=dmaIN_to_math_trig,
|
||||
completion_q=dmaIN_completion_q,
|
||||
pe_dma_port=pe_dma_port,
|
||||
pe_prefix=self._pe_prefix,
|
||||
)
|
||||
|
||||
self._dmaWB_block = DmaWbBlock(
|
||||
env=env,
|
||||
completion_q=dmaWB_completion_q,
|
||||
pe_dma_port=pe_dma_port,
|
||||
pe_prefix=self._pe_prefix,
|
||||
bytes_per_element=self._bpe,
|
||||
)
|
||||
|
||||
self._gemm_block = GemmBlock(
|
||||
env=env,
|
||||
trig_in=self._dmaIN_to_fetch_trig,
|
||||
fetch_to_gemm_trig=fetch_to_gemm_trig,
|
||||
to_math_trig=gemm_to_math_trig,
|
||||
to_dmaWB_trig=gemm_to_dmaWB_trig,
|
||||
completion_q=gemm_completion_q,
|
||||
tcm_port=self._tcm_block.port,
|
||||
mac_m=self._mac_m, mac_k=self._mac_k, mac_n=self._mac_n,
|
||||
bytes_per_element=self._bpe,
|
||||
clock_freq_ghz=self._clock_freq_ghz,
|
||||
)
|
||||
|
||||
self._math_block = MathBlock(
|
||||
env=env,
|
||||
trig_in=gemm_to_math_trig,
|
||||
to_dmaWB_trig=math_to_dmaWB_trig,
|
||||
completion_q=math_completion_q,
|
||||
tcm_port=self._tcm_block.port,
|
||||
bytes_per_element=self._bpe,
|
||||
clock_freq_ghz=self._clock_freq_ghz,
|
||||
vector_width=self._vector_width,
|
||||
)
|
||||
|
||||
# -- Start block processes ---------------------------------------------
|
||||
|
||||
env.process(self._tcm_block._run())
|
||||
env.process(self._dmaIN_block._load_loop())
|
||||
env.process(self._dmaWB_block._flush_loop(gemm_to_dmaWB_trig))
|
||||
env.process(self._dmaWB_block._flush_loop(math_to_dmaWB_trig))
|
||||
env.process(self._gemm_block._fetch_stage())
|
||||
env.process(self._gemm_block._gemm_stage())
|
||||
env.process(self._math_block._run_k_accumulation())
|
||||
env.process(self._math_block._run_element_wise(dmaIN_to_math_trig))
|
||||
|
||||
# Wire in-ports → inbox, start _worker
|
||||
super().start(env)
|
||||
|
||||
# Start completion routers
|
||||
for q in (dmaIN_completion_q, dmaWB_completion_q, gemm_completion_q, math_completion_q):
|
||||
env.process(self._completion_router(q))
|
||||
|
||||
# -- Internal processes ----------------------------------------------------
|
||||
|
||||
def _completion_router(self, queue: simpy.Store) -> Generator:
|
||||
"""Route block completion triggers to the correct pipeline's reply queue."""
|
||||
while True:
|
||||
trigger = yield queue.get()
|
||||
if trigger is None:
|
||||
break
|
||||
reply_q = self._pipeline_queues.get(trigger.pipeline_id)
|
||||
if reply_q is not None:
|
||||
yield reply_q.put(trigger)
|
||||
|
||||
def _worker(self, env: simpy.Environment) -> Generator:
|
||||
"""Main inbox loop: dispatch PE commands, forward fabric transactions."""
|
||||
from kernbench.common.pe_commands import PeInternalTxn
|
||||
|
||||
while True:
|
||||
msg: Any = yield self._inbox.get()
|
||||
if isinstance(msg, PeInternalTxn):
|
||||
env.process(self._dispatch(env, msg))
|
||||
else:
|
||||
env.process(self._forward_txn(env, msg))
|
||||
|
||||
# ==========================================================================
|
||||
# Command Dispatch
|
||||
# ==========================================================================
|
||||
|
||||
def _dispatch(self, env: simpy.Environment, pe_txn: Any) -> Generator:
|
||||
"""Route a PeInternalTxn to the appropriate engine or pipeline."""
|
||||
from kernbench.common.pe_commands import (
|
||||
CompositeCmd, DmaReadCmd, DmaWriteCmd,
|
||||
GemmCmd, MathCmd, PeCpuOverheadCmd,
|
||||
)
|
||||
|
||||
yield from self.run(env, 0) # scheduler overhead
|
||||
|
||||
cmd = pe_txn.command
|
||||
pp = self._pe_prefix
|
||||
|
||||
if isinstance(cmd, (DmaReadCmd, DmaWriteCmd)):
|
||||
yield self.out_ports[f"{pp}.pe_dma"].put(pe_txn)
|
||||
|
||||
elif isinstance(cmd, GemmCmd):
|
||||
yield self.out_ports[f"{pp}.pe_gemm"].put(pe_txn)
|
||||
|
||||
elif isinstance(cmd, MathCmd):
|
||||
yield self.out_ports[f"{pp}.pe_math"].put(pe_txn)
|
||||
|
||||
elif isinstance(cmd, CompositeCmd):
|
||||
if cmd.op == "gemm" and cmd.b is not None:
|
||||
yield from self._dispatch_composite_gemm(env, pe_txn, cmd)
|
||||
else:
|
||||
yield from self._dispatch_composite_math(env, pe_txn, cmd)
|
||||
|
||||
elif isinstance(cmd, PeCpuOverheadCmd):
|
||||
yield env.timeout(cmd.cycles / self._clock_freq_ghz)
|
||||
pe_txn.done.succeed()
|
||||
|
||||
else:
|
||||
pe_txn.done.succeed()
|
||||
|
||||
# -- GEMM composite --------------------------------------------------------
|
||||
|
||||
def _dispatch_composite_gemm(
|
||||
self, env: simpy.Environment, pe_txn: Any, cmd: Any,
|
||||
) -> Generator:
|
||||
"""Run tiled GEMM pipeline and collect per-stage metrics."""
|
||||
from kernbench.components.custom.pe_accel.scheduler.gemm_pipeline import GemmPipeline
|
||||
|
||||
a, b = cmd.a, cmd.b
|
||||
M, K, N = a.shape[-2], a.shape[-1], b.shape[-1]
|
||||
|
||||
pid, reply_q = self._alloc_pipeline(env)
|
||||
|
||||
pipeline = GemmPipeline(
|
||||
env=env, M=M, K=K, N=N,
|
||||
tile_m=self._tile_m, tile_k=self._tile_k, tile_n=self._tile_n,
|
||||
bytes_per_element=self._bpe,
|
||||
pipeline_id=pid,
|
||||
reply_queue=reply_q,
|
||||
dmaIN_cmd_q=self._dmaIN_cmd_q,
|
||||
dmaIN_to_fetch_trig=self._dmaIN_to_fetch_trig,
|
||||
A_addr=a.addr, B_addr=b.addr, C_addr=cmd.out_addr,
|
||||
)
|
||||
|
||||
# Load descriptors into shared blocks
|
||||
self._dmaIN_block.load_descriptors(pipeline.dmaIN_descs)
|
||||
self._gemm_block.load_descriptors(pipeline.gemm_descs)
|
||||
self._math_block.load_descriptors(pipeline.math_descs)
|
||||
self._dmaWB_block.load_descriptors(pipeline.dmaWB_descs)
|
||||
|
||||
start_ns = env.now
|
||||
yield pipeline.done
|
||||
wall_clock_ns = env.now - start_ns
|
||||
del self._pipeline_queues[pid]
|
||||
|
||||
# -- Collect metrics ---------------------------------------------------
|
||||
rd = pe_txn.result_data
|
||||
rd["M_tiles"] = pipeline.M_tiles
|
||||
rd["K_tiles"] = pipeline.K_tiles
|
||||
rd["N_tiles"] = pipeline.N_tiles
|
||||
rd["total_tiles"] = pipeline.M_tiles * pipeline.K_tiles * pipeline.N_tiles
|
||||
rd["output_tiles"] = pipeline.M_tiles * pipeline.N_tiles
|
||||
rd["total_dma_read_bytes"] = sum(d.size_bytes for d in pipeline.dmaIN_descs.values())
|
||||
rd["total_dma_write_bytes"] = sum(
|
||||
d.Tm * d.Tn * self._bpe for d in pipeline.dmaWB_descs.values()
|
||||
)
|
||||
rd["wall_clock_ns"] = wall_clock_ns
|
||||
|
||||
# Per-stage timing
|
||||
rd["t_tcm_load_per_tile"], rd["t_tcm_load_max_ns"], rd["t_tcm_load_total_ns"] = (
|
||||
_collect_timing(self._gemm_block.t_tcm_load_per_tile.pop(pid, []))
|
||||
)
|
||||
rd["t_compute_per_tile"], rd["t_compute_max_ns"], rd["t_compute_total_ns"] = (
|
||||
_collect_timing(self._gemm_block.t_compute_per_tile.pop(pid, []))
|
||||
)
|
||||
rd["t_tcm_store_per_tile"], rd["t_tcm_store_max_ns"], rd["t_tcm_store_total_ns"] = (
|
||||
_collect_timing(self._math_block.t_tcm_store_per_tile.pop(pid, []))
|
||||
)
|
||||
rd["t_dma_read_per_request"], rd["t_dma_read_max_ns"], rd["t_dma_read_total_ns"] = (
|
||||
_collect_timing(self._dmaIN_block.t_dma_read_per_request.pop(pid, []))
|
||||
)
|
||||
rd["t_dma_write_per_tile"], rd["t_dma_write_max_ns"], rd["t_dma_write_total_ns"] = (
|
||||
_collect_timing(self._dmaWB_block.t_dma_write_per_tile.pop(pid, []))
|
||||
)
|
||||
|
||||
# Derived metrics
|
||||
rd["mac_utilization"] = rd["t_compute_total_ns"] / wall_clock_ns if wall_clock_ns > 0 else 0.0
|
||||
rd["effective_tflops"] = 2 * M * K * N / wall_clock_ns / 1e3 if wall_clock_ns > 0 else 0.0
|
||||
rd["pipeline_parallelism"] = (
|
||||
(rd["t_tcm_load_total_ns"] + rd["t_compute_total_ns"] + rd["t_tcm_store_total_ns"])
|
||||
/ wall_clock_ns if wall_clock_ns > 0 else 0.0
|
||||
)
|
||||
rd["effective_read_bw_gbs"] = (
|
||||
rd["total_dma_read_bytes"] / rd["t_dma_read_total_ns"]
|
||||
if rd["t_dma_read_total_ns"] > 0 else None
|
||||
)
|
||||
rd["effective_write_bw_gbs"] = (
|
||||
rd["total_dma_write_bytes"] / rd["t_dma_write_total_ns"]
|
||||
if rd["t_dma_write_total_ns"] > 0 else None
|
||||
)
|
||||
rd["bottleneck_stage"] = max(
|
||||
[("load", rd["t_tcm_load_max_ns"]), ("compute", rd["t_compute_max_ns"]),
|
||||
("store", rd["t_tcm_store_max_ns"])],
|
||||
key=lambda x: x[1],
|
||||
)[0]
|
||||
|
||||
# pe_cpu.py compatibility aliases
|
||||
rd["dma_ns"] = rd["t_dma_read_total_ns"] + rd["t_dma_write_total_ns"]
|
||||
rd["compute_ns"] = rd["t_compute_total_ns"]
|
||||
|
||||
pe_txn.done.succeed()
|
||||
|
||||
# -- Math composite --------------------------------------------------------
|
||||
|
||||
def _dispatch_composite_math(
|
||||
self, env: simpy.Environment, pe_txn: Any, cmd: Any,
|
||||
) -> Generator:
|
||||
"""Run tiled element-wise math pipeline and collect metrics."""
|
||||
from kernbench.components.custom.pe_accel.scheduler.math_pipeline import MathPipeline
|
||||
|
||||
assert self._dmaIN_cmd_q is not None
|
||||
assert self._dmaIN_to_fetch_trig is not None
|
||||
|
||||
a = cmd.a
|
||||
M = a.shape[-2] if len(a.shape) >= 2 else 1
|
||||
N = a.shape[-1]
|
||||
op = cmd.math_op or "identity"
|
||||
|
||||
pid, reply_q = self._alloc_pipeline(env)
|
||||
|
||||
pipeline = MathPipeline(
|
||||
env=env, M=M, N=N,
|
||||
tile_m=self._tile_m, tile_n=self._tile_n,
|
||||
bytes_per_element=self._bpe,
|
||||
pipeline_id=pid,
|
||||
reply_queue=reply_q,
|
||||
dmaIN_cmd_q=self._dmaIN_cmd_q,
|
||||
dmaIN_to_fetch_trig=self._dmaIN_to_fetch_trig,
|
||||
op=op,
|
||||
src_addr=a.addr, dst_addr=cmd.out_addr,
|
||||
)
|
||||
|
||||
# Load descriptors into shared blocks
|
||||
self._dmaIN_block.load_descriptors(pipeline.dmaIN_descs)
|
||||
self._math_block.load_math_op_descriptors(pipeline.math_op_descs)
|
||||
self._dmaWB_block.load_descriptors(pipeline.dmaWB_descs)
|
||||
|
||||
start_ns = env.now
|
||||
yield pipeline.done
|
||||
wall_clock_ns = env.now - start_ns
|
||||
del self._pipeline_queues[pid]
|
||||
|
||||
# -- Collect metrics ---------------------------------------------------
|
||||
rd = pe_txn.result_data
|
||||
rd["M_tiles"] = pipeline.M_tiles
|
||||
rd["N_tiles"] = pipeline.N_tiles
|
||||
rd["total_tiles"] = pipeline.M_tiles * pipeline.N_tiles
|
||||
rd["wall_clock_ns"] = wall_clock_ns
|
||||
rd["math_op"] = op
|
||||
|
||||
# DMA timing
|
||||
rd["t_dma_read_per_request"], rd["t_dma_read_max_ns"], rd["t_dma_read_total_ns"] = (
|
||||
_collect_timing(self._dmaIN_block.t_dma_read_per_request.pop(pid, []))
|
||||
)
|
||||
rd["t_dma_write_per_tile"], rd["t_dma_write_max_ns"], rd["t_dma_write_total_ns"] = (
|
||||
_collect_timing(self._dmaWB_block.t_dma_write_per_tile.pop(pid, []))
|
||||
)
|
||||
|
||||
# Math op timing (load + compute + store)
|
||||
rd["t_math_load_per_tile"], rd["t_math_load_max_ns"], rd["t_math_load_total_ns"] = (
|
||||
_collect_timing(self._math_block.t_math_op_load_per_tile.pop(pid, []))
|
||||
)
|
||||
rd["t_math_compute_per_tile"], rd["t_math_compute_max_ns"], rd["t_math_compute_total_ns"] = (
|
||||
_collect_timing(self._math_block.t_math_op_compute_per_tile.pop(pid, []))
|
||||
)
|
||||
rd["t_math_store_per_tile"], rd["t_math_store_max_ns"], rd["t_math_store_total_ns"] = (
|
||||
_collect_timing(self._math_block.t_math_op_store_per_tile.pop(pid, []))
|
||||
)
|
||||
|
||||
# pe_cpu.py compatibility aliases
|
||||
rd["dma_ns"] = rd["t_dma_read_total_ns"] + rd["t_dma_write_total_ns"]
|
||||
rd["compute_ns"] = rd["t_math_compute_total_ns"]
|
||||
|
||||
pe_txn.done.succeed()
|
||||
|
||||
# -- Helpers ---------------------------------------------------------------
|
||||
|
||||
def _alloc_pipeline(self, env: simpy.Environment) -> tuple[int, simpy.Store]:
|
||||
"""Allocate a pipeline ID and reply queue."""
|
||||
pid = self._next_pipeline_id
|
||||
self._next_pipeline_id += 1
|
||||
reply_q = simpy.Store(env)
|
||||
self._pipeline_queues[pid] = reply_q
|
||||
return pid, reply_q
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Utility
|
||||
# ==============================================================================
|
||||
|
||||
def _collect_timing(times: list) -> tuple[list, float, float]:
|
||||
"""Return (raw_list, max_ns, total_ns) from a timing list."""
|
||||
return times, max(times, default=0.0), sum(times)
|
||||
Reference in New Issue
Block a user