Add SchedulerV2 (pe_accel), DPPolicy overrides, and new benchmarks
- Add cycle-accurate PE accelerator scheduler (SchedulerV2) with tiled GEMM/Math pipelines (DMA_IN → GEMM → MATH → DMA_WB) - Add DPPolicy num_pes/num_cubes/num_sips overrides for single-PE testing - Support tuple target_pe for targeting specific PE subsets - Add gemm_single_pe and gpt3_qkv benchmarks - Switch default topology to pe_scheduler_v2 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,39 @@
|
||||
"""Single-PE GEMM benchmark via scheduler_v2 (pe_accel).
|
||||
|
||||
Full host-to-PE pipeline:
|
||||
Host → PCIE_EP → IO_CPU → M_CPU → PE_CPU → SchedulerV2 → PE_DMA → HBM
|
||||
|
||||
Single PE: num_sips=1, num_cubes=1, num_pes=1 via DPPolicy override.
|
||||
Both operands use tl.ref (HBM-resident); scheduler_v2 tiles and streams
|
||||
per-tile DMA internally.
|
||||
|
||||
Run:
|
||||
kernbench run gemm_single_pe
|
||||
"""
|
||||
from kernbench.policy.placement.dp import DPPolicy
|
||||
|
||||
# GEMM dimensions: (M, K) x (K, N) → (M, N)
|
||||
M, K, N = 32, 128, 32
|
||||
DTYPE = "f16"
|
||||
|
||||
|
||||
def _gemm_kernel(a_ptr, b_ptr, out_ptr, M, K, N, tl, DTYPE="f16"):
|
||||
"""Single-PE GEMM: out = a @ b. Both operands streamed from HBM by scheduler."""
|
||||
M, K, N = int(M), int(K), int(N)
|
||||
|
||||
a = tl.ref(int(a_ptr), shape=(M, K), dtype=DTYPE)
|
||||
b = tl.ref(int(b_ptr), shape=(K, N), dtype=DTYPE)
|
||||
h = tl.composite(op="gemm", a=a, b=b, out_ptr=int(out_ptr))
|
||||
tl.wait(h)
|
||||
|
||||
|
||||
def run(torch):
|
||||
"""Run the single-PE GEMM benchmark."""
|
||||
dp = DPPolicy(cube="replicate", pe="replicate",
|
||||
num_sips=1, num_cubes=1, num_pes=1)
|
||||
|
||||
a = torch.empty((M, K), dtype=DTYPE, dp=dp, name="a")
|
||||
b = torch.empty((K, N), dtype=DTYPE, dp=dp, name="b")
|
||||
out = torch.empty((M, N), dtype=DTYPE, dp=dp, name="out")
|
||||
|
||||
torch.launch("gemm_single_pe", _gemm_kernel, a, b, out, M, K, N)
|
||||
@@ -0,0 +1,92 @@
|
||||
"""GPT-3 QKV projection benchmark: sharded across PEs via pe_accel_v1.
|
||||
|
||||
GPT-3 architecture:
|
||||
d_model = 12288 (hidden dimension)
|
||||
n_heads = 96 (attention heads)
|
||||
d_head = 128 (dimension per head)
|
||||
|
||||
Sharding strategy (column-wise across all PEs):
|
||||
X : (seq_len, d_model) -- replicated to all PEs
|
||||
W_Q/K/V : (d_model, d_model) -- column-wise sharded across cubes × PEs
|
||||
out_Q/K/V: (seq_len, d_model) -- column-wise sharded across cubes × PEs
|
||||
|
||||
Each PE computes:
|
||||
Q_slice = X @ W_Q_slice : (seq_len, d_model) @ (d_model, cols_per_pe) -> (seq_len, cols_per_pe)
|
||||
K_slice, V_slice: same
|
||||
|
||||
PE count is configurable via N_CUBES × N_PE_PER_CUBE (DPPolicy override).
|
||||
topology.yaml is unchanged.
|
||||
|
||||
Run:
|
||||
kernbench run gpt3_qkv
|
||||
"""
|
||||
from kernbench.policy.placement.dp import DPPolicy
|
||||
|
||||
# -- PE configuration (DPPolicy overrides — does not change topology.yaml) -----
|
||||
N_SIPS = 1
|
||||
N_CUBES = 16 # cubes per SIP
|
||||
N_PE_PER_CUBE = 8 # PEs per cube
|
||||
N_PES = N_CUBES * N_PE_PER_CUBE # 128 total
|
||||
|
||||
# -- GPT-3 architecture -------------------------------------------------------
|
||||
GPT3_D_MODEL = 12288
|
||||
SEQ_LEN = 32
|
||||
COLS_PER_PE = GPT3_D_MODEL // N_PES # 12288 / 128 = 96
|
||||
DTYPE = "f16"
|
||||
|
||||
|
||||
def _gpt3_qkv_kernel(x_ptr, wq_ptr, wk_ptr, wv_ptr,
|
||||
out_q_ptr, out_k_ptr, out_v_ptr,
|
||||
seq_len, d_model, cols_per_pe, tl, DTYPE="f16"):
|
||||
"""GPT-3 QKV sharded: each PE uses program_id to index its VA slice."""
|
||||
pid = tl.program_id(0)
|
||||
bpe = 2 # f16
|
||||
|
||||
M = int(seq_len)
|
||||
K = int(d_model)
|
||||
N = int(cols_per_pe)
|
||||
|
||||
w_slice = K * N * bpe
|
||||
out_slice = M * N * bpe
|
||||
|
||||
x = tl.load(int(x_ptr), shape=(M, K), dtype=DTYPE)
|
||||
wq = tl.ref(int(wq_ptr) + pid * w_slice, shape=(K, N), dtype=DTYPE)
|
||||
wk = tl.ref(int(wk_ptr) + pid * w_slice, shape=(K, N), dtype=DTYPE)
|
||||
wv = tl.ref(int(wv_ptr) + pid * w_slice, shape=(K, N), dtype=DTYPE)
|
||||
|
||||
hq = tl.composite(op="gemm", a=x, b=wq,
|
||||
out_ptr=int(out_q_ptr) + pid * out_slice)
|
||||
hk = tl.composite(op="gemm", a=x, b=wk,
|
||||
out_ptr=int(out_k_ptr) + pid * out_slice)
|
||||
hv = tl.composite(op="gemm", a=x, b=wv,
|
||||
out_ptr=int(out_v_ptr) + pid * out_slice)
|
||||
|
||||
tl.wait(hq)
|
||||
tl.wait(hk)
|
||||
tl.wait(hv)
|
||||
|
||||
|
||||
def run(torch):
|
||||
"""Run the GPT-3 QKV benchmark."""
|
||||
M = SEQ_LEN
|
||||
K = GPT3_D_MODEL
|
||||
N = COLS_PER_PE
|
||||
|
||||
# X: replicated across all PEs
|
||||
dp_replicate = DPPolicy(cube="replicate", pe="replicate",
|
||||
num_sips=N_SIPS, num_cubes=N_CUBES, num_pes=N_PE_PER_CUBE)
|
||||
# W_Q/K/V, out_Q/K/V: column-wise sharded across all PEs
|
||||
dp_sharded = DPPolicy(cube="column_wise", pe="column_wise",
|
||||
num_sips=N_SIPS, num_cubes=N_CUBES, num_pes=N_PE_PER_CUBE)
|
||||
|
||||
x = torch.empty((M, K), dtype=DTYPE, dp=dp_replicate, name="x")
|
||||
wq = torch.empty((K, GPT3_D_MODEL), dtype=DTYPE, dp=dp_sharded, name="wq")
|
||||
wk = torch.empty((K, GPT3_D_MODEL), dtype=DTYPE, dp=dp_sharded, name="wk")
|
||||
wv = torch.empty((K, GPT3_D_MODEL), dtype=DTYPE, dp=dp_sharded, name="wv")
|
||||
out_q = torch.empty((M, GPT3_D_MODEL), dtype=DTYPE, dp=dp_sharded, name="out_q")
|
||||
out_k = torch.empty((M, GPT3_D_MODEL), dtype=DTYPE, dp=dp_sharded, name="out_k")
|
||||
out_v = torch.empty((M, GPT3_D_MODEL), dtype=DTYPE, dp=dp_sharded, name="out_v")
|
||||
|
||||
torch.launch("gpt3_qkv", _gpt3_qkv_kernel,
|
||||
x, wq, wk, wv, out_q, out_k, out_v,
|
||||
M, K, N)
|
||||
+1
-1
@@ -50,4 +50,4 @@ components:
|
||||
pe_tcm_v1: kernbench.components.builtin.pe_tcm:PeTcmComponent
|
||||
|
||||
# Custom — add your implementations here
|
||||
# pe_cpu_v2: kernbench.components.custom.my_pe_cpu:MyPeCpuComponent
|
||||
pe_scheduler_v2: kernbench.components.custom.pe_accel.scheduler:SchedulerV2Component
|
||||
|
||||
@@ -320,10 +320,12 @@ class MCpuComponent(ComponentBase):
|
||||
else:
|
||||
txn.done.succeed()
|
||||
|
||||
def _resolve_pe_ids(self, target_pe: int | str) -> list[int]:
|
||||
def _resolve_pe_ids(self, target_pe: int | tuple | str) -> list[int]:
|
||||
"""Return list of PE IDs to fan out to (used by kernel launch fan-out)."""
|
||||
if isinstance(target_pe, int):
|
||||
return [target_pe]
|
||||
if isinstance(target_pe, tuple):
|
||||
return list(target_pe)
|
||||
# "all": all PEs in local cube
|
||||
n_slices = 8
|
||||
if self.ctx and self.ctx.spec:
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
"""PeAccel: cycle-accurate accelerator component for pe_scheduler slot.
|
||||
|
||||
Register in components.yaml as:
|
||||
pe_scheduler_v2: kernbench.components.custom.pe_accel.scheduler:SchedulerV2Component
|
||||
|
||||
Then reference in topology.yaml:
|
||||
pe_scheduler: { kind: pe_scheduler, impl: pe_scheduler_v2, attrs: { ... } }
|
||||
|
||||
Package layout:
|
||||
scheduler/ — scheduler block (component + dispatch + tiling)
|
||||
scheduler.py — SchedulerV2Component
|
||||
gemm_pipeline.py — tiled GEMM coordinator
|
||||
math_pipeline.py — tiled element-wise math coordinator
|
||||
tile_address.py — per-tile address computation
|
||||
blocks/ — hardware blocks (DMA_IN, DMA_WB, GEMM, MATH, TCM)
|
||||
types.py — data classes (descriptors, triggers, tile commands)
|
||||
"""
|
||||
from kernbench.components.custom.pe_accel.scheduler import SchedulerV2Component
|
||||
|
||||
__all__ = ["SchedulerV2Component"]
|
||||
@@ -0,0 +1,16 @@
|
||||
"""Hardware blocks for pe_accel.
|
||||
|
||||
Each block is a concurrent SimPy process modeling one functional unit:
|
||||
- DmaInBlock: HBM → TCM tile reads (issues real DmaReadCmd to PE_DMA)
|
||||
- DmaWbBlock: TCM → HBM tile writes (issues real DmaWriteCmd to PE_DMA)
|
||||
- GemmBlock: 2-stage MAC pipeline (fetch + compute)
|
||||
- MathBlock: K-accumulation (GEMM helper) + element-wise ops (exp, log, etc.)
|
||||
- TcmBlock: TCM access serialization with BW-based timing
|
||||
"""
|
||||
from kernbench.components.custom.pe_accel.blocks.dma_in import DmaInBlock
|
||||
from kernbench.components.custom.pe_accel.blocks.dma_wb import DmaWbBlock
|
||||
from kernbench.components.custom.pe_accel.blocks.gemm import GemmBlock
|
||||
from kernbench.components.custom.pe_accel.blocks.math import MathBlock
|
||||
from kernbench.components.custom.pe_accel.blocks.tcm import TcmBlock, TcmRequest
|
||||
|
||||
__all__ = ["DmaInBlock", "DmaWbBlock", "GemmBlock", "MathBlock", "TcmBlock", "TcmRequest"]
|
||||
@@ -0,0 +1,96 @@
|
||||
"""DMA IN Block: reads tiles from HBM into TCM via real PE_DMA fabric.
|
||||
|
||||
Flow per tile:
|
||||
1. Receive DmaRequest from tiling pipeline
|
||||
2. Look up DmaInDescriptor for address and size
|
||||
3. Issue DmaReadCmd → PE_DMA → fabric → HBM controller → response
|
||||
4. Route completion Trigger to next block (GEMM, MATH, or COMPLETION)
|
||||
|
||||
Timing is real fabric latency (not analytical) — includes BW contention,
|
||||
propagation delay, and HBM controller serialization.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.custom.pe_accel.types import DmaInDescriptor, DmaRequest, Trigger
|
||||
|
||||
|
||||
class DmaInBlock:
|
||||
"""HBM → TCM tile reader. Shared across all concurrent pipelines.
|
||||
|
||||
Pipelines pre-load DmaInDescriptors keyed by (pipeline_id, tile_id, operand).
|
||||
The _load_loop process reads DmaRequests, issues real DmaReadCmd to PE_DMA,
|
||||
and routes completion triggers based on descriptor.next_block.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: simpy.Environment,
|
||||
cmd_q: simpy.Store,
|
||||
to_fetch_trig: simpy.Store,
|
||||
to_math_trig: simpy.Store,
|
||||
completion_q: simpy.Store,
|
||||
*,
|
||||
pe_dma_port: simpy.Store | None,
|
||||
pe_prefix: str,
|
||||
) -> None:
|
||||
self.env = env
|
||||
self.cmd_q = cmd_q
|
||||
self.to_fetch_trig = to_fetch_trig
|
||||
self.to_math_trig = to_math_trig
|
||||
self.completion_q = completion_q
|
||||
self._pe_dma_port = pe_dma_port
|
||||
self._pe_prefix = pe_prefix
|
||||
self._descriptor_table: dict[tuple[int, int, str], DmaInDescriptor] = {}
|
||||
|
||||
# Per-pipeline timing histogram (keyed by pipeline_id)
|
||||
self.t_dma_read_per_request: dict[int, list[float]] = {}
|
||||
|
||||
def load_descriptors(self, descs: dict[tuple, DmaInDescriptor]) -> None:
|
||||
"""Pre-load per-operand DMA descriptors (cumulative across pipelines)."""
|
||||
self._descriptor_table.update(descs)
|
||||
|
||||
def _load_loop(self):
|
||||
"""Main process: receive DmaRequests, issue DmaReadCmd, route triggers."""
|
||||
from kernbench.common.pe_commands import DmaReadCmd, PeInternalTxn, TensorHandle
|
||||
|
||||
while True:
|
||||
req: DmaRequest = yield self.cmd_q.get()
|
||||
if req is None:
|
||||
break
|
||||
|
||||
desc = self._descriptor_table[(req.pipeline_id, req.tile_id, req.operand)]
|
||||
|
||||
# Issue real DMA read through PE_DMA → fabric → HBM
|
||||
read_done = self.env.event()
|
||||
handle = TensorHandle(
|
||||
id=f"accel_rd_{req.pipeline_id}_{req.tile_id}_{req.operand}",
|
||||
addr=desc.src_addr,
|
||||
shape=(desc.size_bytes,),
|
||||
dtype="uint8",
|
||||
nbytes=desc.size_bytes,
|
||||
)
|
||||
txn = PeInternalTxn(
|
||||
command=DmaReadCmd(handle=handle, src_addr=desc.src_addr, nbytes=desc.size_bytes),
|
||||
done=read_done,
|
||||
pe_prefix=self._pe_prefix,
|
||||
)
|
||||
t0 = self.env.now
|
||||
yield self._pe_dma_port.put(txn)
|
||||
yield read_done
|
||||
self.t_dma_read_per_request.setdefault(req.pipeline_id, []).append(self.env.now - t0)
|
||||
|
||||
# Route trigger to next block
|
||||
trig = Trigger(
|
||||
tile_id=req.tile_id,
|
||||
pipeline_id=req.pipeline_id,
|
||||
vc=0 if req.operand == "A" else 1,
|
||||
source_block="DMA_IN",
|
||||
)
|
||||
if desc.next_block == "MATH":
|
||||
yield self.to_math_trig.put(trig)
|
||||
elif desc.next_block == "COMPLETION":
|
||||
yield self.completion_q.put(trig)
|
||||
else: # "GEMM" (default)
|
||||
yield self.to_fetch_trig.put(trig)
|
||||
@@ -0,0 +1,88 @@
|
||||
"""DMA Writeback Block: writes result tiles from TCM to HBM via real PE_DMA fabric.
|
||||
|
||||
Flow per tile:
|
||||
1. Receive flush Trigger from GEMM or MATH block
|
||||
2. Look up DmaWBDescriptor for address and tile size
|
||||
3. Issue DmaWriteCmd → PE_DMA → fabric → HBM controller → response
|
||||
4. Send completion Trigger to pipeline
|
||||
|
||||
Two _flush_loop processes run concurrently:
|
||||
- One drains GEMM → DMA_WB triggers (direct writeback path)
|
||||
- One drains MATH → DMA_WB triggers (K-accumulation or element-wise flush)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.custom.pe_accel.types import DmaWBDescriptor, Trigger
|
||||
|
||||
|
||||
class DmaWbBlock:
|
||||
"""TCM → HBM tile writer. Shared across all concurrent pipelines.
|
||||
|
||||
Pipelines pre-load DmaWBDescriptors keyed by (pipeline_id, tile_id).
|
||||
Each _flush_loop process reads triggers, issues real DmaWriteCmd to PE_DMA,
|
||||
and forwards completion to the pipeline's reply queue.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: simpy.Environment,
|
||||
completion_q: simpy.Store,
|
||||
*,
|
||||
pe_dma_port: simpy.Store | None,
|
||||
pe_prefix: str,
|
||||
bytes_per_element: int,
|
||||
) -> None:
|
||||
self.env = env
|
||||
self.completion_q = completion_q
|
||||
self._pe_dma_port = pe_dma_port
|
||||
self._pe_prefix = pe_prefix
|
||||
self._bpe = bytes_per_element
|
||||
self._descriptor_table: dict[tuple[int, int], DmaWBDescriptor] = {}
|
||||
|
||||
# Per-pipeline timing histogram (keyed by pipeline_id)
|
||||
self.t_dma_write_per_tile: dict[int, list[float]] = {}
|
||||
|
||||
def load_descriptors(self, descs: dict[tuple, DmaWBDescriptor]) -> None:
|
||||
"""Pre-load per-tile writeback descriptors (cumulative across pipelines)."""
|
||||
self._descriptor_table.update(descs)
|
||||
|
||||
def _flush_loop(self, trig_q: simpy.Store):
|
||||
"""Main process: receive flush triggers, issue DmaWriteCmd, send completion."""
|
||||
from kernbench.common.pe_commands import DmaWriteCmd, PeInternalTxn, TensorHandle
|
||||
|
||||
while True:
|
||||
trigger: Trigger = yield trig_q.get()
|
||||
if trigger is None:
|
||||
break
|
||||
|
||||
pid = trigger.pipeline_id
|
||||
tile_id = trigger.tile_id
|
||||
desc = self._descriptor_table.get((pid, tile_id))
|
||||
|
||||
if desc:
|
||||
c_bytes = desc.Tm * desc.Tn * self._bpe
|
||||
|
||||
# Issue real DMA write through PE_DMA → fabric → HBM
|
||||
write_done = self.env.event()
|
||||
handle = TensorHandle(
|
||||
id=f"accel_wb_{pid}_{tile_id}",
|
||||
addr=desc.dst_addr,
|
||||
shape=(desc.Tm, desc.Tn),
|
||||
dtype="float16",
|
||||
nbytes=c_bytes,
|
||||
)
|
||||
txn = PeInternalTxn(
|
||||
command=DmaWriteCmd(handle=handle, dst_addr=desc.dst_addr, nbytes=c_bytes),
|
||||
done=write_done,
|
||||
pe_prefix=self._pe_prefix,
|
||||
)
|
||||
t0 = self.env.now
|
||||
yield self._pe_dma_port.put(txn)
|
||||
yield write_done
|
||||
self.t_dma_write_per_tile.setdefault(pid, []).append(self.env.now - t0)
|
||||
|
||||
yield self.completion_q.put(
|
||||
Trigger(tile_id=tile_id, pipeline_id=pid, source_block="DMA_WB")
|
||||
)
|
||||
@@ -0,0 +1,160 @@
|
||||
"""GEMM Block: 2-stage MAC pipeline (fetch + compute).
|
||||
|
||||
Stage 1 — Fetch (_fetch_stage):
|
||||
Collects DMA completion triggers (one per operand per tile).
|
||||
When all operands arrive, issues TCM read request (SPMem → MAC registers).
|
||||
|
||||
Stage 2 — Compute (_gemm_stage):
|
||||
Models MAC array computation time.
|
||||
Issues TCM write request (MAC result → SPMem).
|
||||
Routes output trigger to MathBlock, DmaWbBlock, or completion.
|
||||
|
||||
TCM access goes through TcmBlock for real BW serialization.
|
||||
MAC compute time is cycle-accurate: ceil(Tm/mac_m) * ceil(Tk/mac_k) * ceil(Tn/mac_n).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from math import ceil
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.custom.pe_accel.blocks.tcm import TcmRequest
|
||||
from kernbench.components.custom.pe_accel.types import GemmDescriptor, Trigger
|
||||
|
||||
|
||||
class GemmBlock:
|
||||
"""2-stage MAC pipeline shared across all concurrent pipelines.
|
||||
|
||||
Pipelines pre-load GemmDescriptors keyed by (pipeline_id, tile_id).
|
||||
Two SimPy processes run concurrently: _fetch_stage and _gemm_stage.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: simpy.Environment,
|
||||
trig_in: simpy.Store,
|
||||
fetch_to_gemm_trig: simpy.Store,
|
||||
to_math_trig: simpy.Store,
|
||||
to_dmaWB_trig: simpy.Store,
|
||||
completion_q: simpy.Store,
|
||||
*,
|
||||
tcm_port: simpy.Store,
|
||||
mac_m: int,
|
||||
mac_k: int,
|
||||
mac_n: int,
|
||||
bytes_per_element: int,
|
||||
clock_freq_ghz: float,
|
||||
) -> None:
|
||||
self.env = env
|
||||
self.trig_in = trig_in
|
||||
self.fetch_to_gemm_trig = fetch_to_gemm_trig
|
||||
self.to_math_trig = to_math_trig
|
||||
self.to_dmaWB_trig = to_dmaWB_trig
|
||||
self.completion_q = completion_q
|
||||
self._tcm_port = tcm_port
|
||||
|
||||
self._mac_m = mac_m
|
||||
self._mac_k = mac_k
|
||||
self._mac_n = mac_n
|
||||
self._bpe = bytes_per_element
|
||||
self._freq = clock_freq_ghz
|
||||
|
||||
self._descriptor_table: dict[tuple[int, int], GemmDescriptor] = {}
|
||||
|
||||
# Per-pipeline timing histograms
|
||||
self.t_tcm_load_per_tile: dict[int, list[float]] = {}
|
||||
self.t_compute_per_tile: dict[int, list[float]] = {}
|
||||
|
||||
def load_descriptors(self, descs: dict[tuple, GemmDescriptor]) -> None:
|
||||
"""Pre-load per-tile GEMM descriptors (cumulative across pipelines)."""
|
||||
self._descriptor_table.update(descs)
|
||||
|
||||
def _compute_ns(self, desc: GemmDescriptor) -> float:
|
||||
"""MAC array compute time for one tile (ns)."""
|
||||
cycles = ceil(desc.Tm / self._mac_m) * ceil(desc.Tk / self._mac_k) * ceil(desc.Tn / self._mac_n)
|
||||
return cycles / self._freq
|
||||
|
||||
# -- Stage 1: Fetch (TCM → MAC load) --------------------------------------
|
||||
|
||||
def _fetch_stage(self):
|
||||
"""Collect DMA triggers per tile, issue TCM read for operand load."""
|
||||
pending: dict[tuple[int, int], list[Trigger]] = {}
|
||||
while True:
|
||||
trigger = yield self.trig_in.get()
|
||||
if trigger is None:
|
||||
yield self.fetch_to_gemm_trig.put(None)
|
||||
break
|
||||
|
||||
key = (trigger.pipeline_id, trigger.tile_id)
|
||||
pending.setdefault(key, []).append(trigger)
|
||||
|
||||
desc = self._descriptor_table.get(key)
|
||||
needed = desc.triggers_needed if desc else 2
|
||||
if len(pending[key]) < needed:
|
||||
continue
|
||||
|
||||
del pending[key]
|
||||
|
||||
# TCM load: read A and B tile data from SPMem → MAC registers
|
||||
if desc and desc.gemm_load:
|
||||
a_bytes = desc.Tm * desc.Tk * self._bpe
|
||||
b_bytes = desc.Tk * desc.Tn * self._bpe
|
||||
load_bytes = a_bytes + b_bytes
|
||||
|
||||
t0 = self.env.now
|
||||
done = self.env.event()
|
||||
yield self._tcm_port.put(TcmRequest("read", load_bytes, done, tag="gemm_load"))
|
||||
yield done
|
||||
self.t_tcm_load_per_tile.setdefault(trigger.pipeline_id, []).append(
|
||||
self.env.now - t0
|
||||
)
|
||||
|
||||
yield self.fetch_to_gemm_trig.put(trigger)
|
||||
|
||||
# -- Stage 2: Compute (MAC array) + Store (MAC → TCM) ---------------------
|
||||
|
||||
def _gemm_stage(self):
|
||||
"""MAC computation, then TCM store, then route to next block."""
|
||||
while True:
|
||||
trigger = yield self.fetch_to_gemm_trig.get()
|
||||
if trigger is None:
|
||||
break
|
||||
|
||||
key = (trigger.pipeline_id, trigger.tile_id)
|
||||
desc = self._descriptor_table.get(key)
|
||||
|
||||
# MAC compute
|
||||
if desc and desc.gemm_compute:
|
||||
t_compute = self._compute_ns(desc)
|
||||
t0 = self.env.now
|
||||
if t_compute > 0:
|
||||
yield self.env.timeout(t_compute)
|
||||
self.t_compute_per_tile.setdefault(trigger.pipeline_id, []).append(
|
||||
self.env.now - t0
|
||||
)
|
||||
|
||||
# Route output
|
||||
route = desc.next_block if desc else "MATH"
|
||||
out_trig = Trigger(
|
||||
tile_id=trigger.tile_id,
|
||||
pipeline_id=trigger.pipeline_id,
|
||||
source_block="GEMM",
|
||||
)
|
||||
|
||||
if route == "MATH":
|
||||
yield self.to_math_trig.put(out_trig)
|
||||
elif route == "DMAWB":
|
||||
# TCM store before writeback
|
||||
if desc:
|
||||
c_bytes = desc.Tm * desc.Tn * self._bpe
|
||||
done = self.env.event()
|
||||
yield self._tcm_port.put(TcmRequest("write", c_bytes, done, tag="gemm_store"))
|
||||
yield done
|
||||
yield self.to_dmaWB_trig.put(out_trig)
|
||||
else: # "DONE" — C stays in SPMem, no flush
|
||||
if desc:
|
||||
c_bytes = desc.Tm * desc.Tn * self._bpe
|
||||
done = self.env.event()
|
||||
yield self._tcm_port.put(TcmRequest("write", c_bytes, done, tag="gemm_store"))
|
||||
yield done
|
||||
yield self.completion_q.put(out_trig)
|
||||
@@ -0,0 +1,181 @@
|
||||
"""Math Block: K-accumulation (GEMM helper) + element-wise ops (exp, log, etc.).
|
||||
|
||||
Two concurrent processing modes:
|
||||
|
||||
1. K-accumulation (_run_k_accumulation):
|
||||
Receives triggers from GemmBlock after each K-tile compute.
|
||||
Issues TCM write for partial-result store.
|
||||
On final K-tile, routes to DMA_WB or completion.
|
||||
|
||||
2. Element-wise ops (_run_element_wise):
|
||||
Receives triggers from DMA_IN after each tile read.
|
||||
Issues TCM read (load input), compute (SIMD), TCM write (store result).
|
||||
Routes output to DMA_WB for writeback.
|
||||
|
||||
TCM access goes through TcmBlock for real BW serialization.
|
||||
SIMD compute time: ceil(num_elements / vector_width) / clock_freq.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from math import ceil
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.custom.pe_accel.blocks.tcm import TcmRequest
|
||||
from kernbench.components.custom.pe_accel.types import MathDescriptor, MathOpDescriptor, Trigger
|
||||
|
||||
|
||||
class MathBlock:
|
||||
"""K-accumulation + element-wise math unit.
|
||||
|
||||
Descriptor tables:
|
||||
- _accum_table: MathDescriptor (pipeline_id, tile_id) — for GEMM K-accumulation
|
||||
- _elemwise_table: MathOpDescriptor (pipeline_id, tile_id) — for element-wise ops
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: simpy.Environment,
|
||||
trig_in: simpy.Store,
|
||||
to_dmaWB_trig: simpy.Store,
|
||||
completion_q: simpy.Store,
|
||||
*,
|
||||
tcm_port: simpy.Store,
|
||||
bytes_per_element: int,
|
||||
clock_freq_ghz: float,
|
||||
vector_width: int = 256,
|
||||
) -> None:
|
||||
self.env = env
|
||||
self.trig_in = trig_in # from GemmBlock (K-accumulation)
|
||||
self.to_dmaWB_trig = to_dmaWB_trig
|
||||
self.completion_q = completion_q
|
||||
self._tcm_port = tcm_port
|
||||
|
||||
self._bpe = bytes_per_element
|
||||
self._freq = clock_freq_ghz
|
||||
self._vector_width = vector_width
|
||||
|
||||
# Descriptor tables
|
||||
self._accum_table: dict[tuple[int, int], MathDescriptor] = {}
|
||||
self._elemwise_table: dict[tuple[int, int], MathOpDescriptor] = {}
|
||||
|
||||
# -- Timing histograms (per pipeline_id) --
|
||||
|
||||
# K-accumulation
|
||||
self.t_tcm_store_per_tile: dict[int, list[float]] = {}
|
||||
|
||||
# Element-wise ops
|
||||
self.t_math_op_load_per_tile: dict[int, list[float]] = {}
|
||||
self.t_math_op_compute_per_tile: dict[int, list[float]] = {}
|
||||
self.t_math_op_store_per_tile: dict[int, list[float]] = {}
|
||||
|
||||
# -- Descriptor loading ----------------------------------------------------
|
||||
|
||||
def load_descriptors(self, descs: dict[tuple, MathDescriptor]) -> None:
|
||||
"""Pre-load K-accumulation descriptors (cumulative across pipelines)."""
|
||||
self._accum_table.update(descs)
|
||||
|
||||
def load_math_op_descriptors(self, descs: dict[tuple, MathOpDescriptor]) -> None:
|
||||
"""Pre-load element-wise op descriptors (cumulative across pipelines)."""
|
||||
self._elemwise_table.update(descs)
|
||||
|
||||
# -- Mode 1: K-accumulation ------------------------------------------------
|
||||
|
||||
def _run(self):
|
||||
"""Backward-compat alias."""
|
||||
yield from self._run_k_accumulation()
|
||||
|
||||
def _run_k_accumulation(self):
|
||||
"""Receive GEMM output triggers, TCM store partial result, flush on last-K."""
|
||||
while True:
|
||||
trigger = yield self.trig_in.get()
|
||||
if trigger is None:
|
||||
break
|
||||
|
||||
key = (trigger.pipeline_id, trigger.tile_id)
|
||||
desc = self._accum_table.get(key)
|
||||
|
||||
# TCM store: write partial sum to SPMem
|
||||
if desc:
|
||||
c_bytes = desc.Tm * desc.Tn * self._bpe
|
||||
t0 = self.env.now
|
||||
done = self.env.event()
|
||||
yield self._tcm_port.put(TcmRequest("write", c_bytes, done, tag="k_accum_store"))
|
||||
yield done
|
||||
self.t_tcm_store_per_tile.setdefault(trigger.pipeline_id, []).append(
|
||||
self.env.now - t0
|
||||
)
|
||||
|
||||
if not desc or not desc.is_last_k:
|
||||
continue # intermediate K-tile: store done, no flush yet
|
||||
|
||||
out_trig = Trigger(
|
||||
tile_id=trigger.tile_id,
|
||||
pipeline_id=trigger.pipeline_id,
|
||||
source_block="MATH",
|
||||
)
|
||||
if desc.skip_dmaWB:
|
||||
yield self.completion_q.put(out_trig)
|
||||
else:
|
||||
yield self.to_dmaWB_trig.put(out_trig)
|
||||
|
||||
# -- Mode 2: Element-wise ops ----------------------------------------------
|
||||
|
||||
def _run_math_op(self, trig_q: simpy.Store):
|
||||
"""Backward-compat alias."""
|
||||
yield from self._run_element_wise(trig_q)
|
||||
|
||||
def _run_element_wise(self, trig_q: simpy.Store):
|
||||
"""Receive DMA_IN triggers, apply element-wise op via TCM, route to DMA_WB.
|
||||
|
||||
Per tile:
|
||||
1. TCM read — load input tile from SPMem to SIMD
|
||||
2. Compute — SIMD operation (exp/log/etc.)
|
||||
3. TCM write — store result from SIMD to SPMem
|
||||
4. Route to DMA_WB
|
||||
"""
|
||||
while True:
|
||||
trigger = yield trig_q.get()
|
||||
if trigger is None:
|
||||
break
|
||||
|
||||
key = (trigger.pipeline_id, trigger.tile_id)
|
||||
desc = self._elemwise_table.get(key)
|
||||
|
||||
if desc:
|
||||
tile_bytes = desc.Tm * desc.Tn * self._bpe
|
||||
num_elements = desc.Tm * desc.Tn
|
||||
|
||||
# 1. TCM read
|
||||
t0 = self.env.now
|
||||
done = self.env.event()
|
||||
yield self._tcm_port.put(TcmRequest("read", tile_bytes, done, tag="elemwise_load"))
|
||||
yield done
|
||||
self.t_math_op_load_per_tile.setdefault(trigger.pipeline_id, []).append(
|
||||
self.env.now - t0
|
||||
)
|
||||
|
||||
# 2. SIMD compute
|
||||
t0 = self.env.now
|
||||
compute_cycles = ceil(num_elements / self._vector_width)
|
||||
compute_ns = compute_cycles / self._freq
|
||||
if compute_ns > 0:
|
||||
yield self.env.timeout(compute_ns)
|
||||
self.t_math_op_compute_per_tile.setdefault(trigger.pipeline_id, []).append(
|
||||
self.env.now - t0
|
||||
)
|
||||
|
||||
# 3. TCM write
|
||||
t0 = self.env.now
|
||||
done = self.env.event()
|
||||
yield self._tcm_port.put(TcmRequest("write", tile_bytes, done, tag="elemwise_store"))
|
||||
yield done
|
||||
self.t_math_op_store_per_tile.setdefault(trigger.pipeline_id, []).append(
|
||||
self.env.now - t0
|
||||
)
|
||||
|
||||
yield self.to_dmaWB_trig.put(Trigger(
|
||||
tile_id=trigger.tile_id,
|
||||
pipeline_id=trigger.pipeline_id,
|
||||
source_block="MATH_OP",
|
||||
))
|
||||
@@ -0,0 +1,80 @@
|
||||
"""TCM Block: tightly-coupled memory with BW-based access serialization.
|
||||
|
||||
Models SPMem (scratchpad memory) inside the PE. Compute blocks (GEMM, MATH)
|
||||
send TcmRequests for load/store operations. The TCM block serializes access
|
||||
per channel and computes timing based on data size and bandwidth.
|
||||
|
||||
Two channels:
|
||||
- READ (SPMem → compute unit): models operand fetch for MAC/SIMD
|
||||
- WRITE (compute unit → SPMem): models result store from MAC/SIMD
|
||||
|
||||
Each channel has capacity=1: concurrent reads serialize, concurrent writes
|
||||
serialize, but a read and a write can proceed in parallel.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import simpy
|
||||
|
||||
|
||||
@dataclass
|
||||
class TcmRequest:
|
||||
"""Request to read from or write to TCM."""
|
||||
|
||||
direction: str # "read" or "write"
|
||||
nbytes: int
|
||||
done: simpy.Event
|
||||
tag: str = "" # optional label for debugging
|
||||
|
||||
|
||||
class TcmBlock:
|
||||
"""BW-serialized TCM model with dual read/write channels.
|
||||
|
||||
Args:
|
||||
env: SimPy environment.
|
||||
read_bw_gbs: read bandwidth in GB/s (SPMem → compute).
|
||||
write_bw_gbs: write bandwidth in GB/s (compute → SPMem).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: simpy.Environment,
|
||||
read_bw_gbs: float = 512.0,
|
||||
write_bw_gbs: float = 512.0,
|
||||
) -> None:
|
||||
self.env = env
|
||||
self._read_bw = read_bw_gbs
|
||||
self._write_bw = write_bw_gbs
|
||||
self._read_res = simpy.Resource(env, capacity=1)
|
||||
self._write_res = simpy.Resource(env, capacity=1)
|
||||
self._port: simpy.Store = simpy.Store(env)
|
||||
|
||||
@property
|
||||
def port(self) -> simpy.Store:
|
||||
"""The SimPy Store that blocks send TcmRequests to."""
|
||||
return self._port
|
||||
|
||||
def _run(self):
|
||||
"""Main process: receive TcmRequests, dispatch to channel processes."""
|
||||
while True:
|
||||
req: TcmRequest = yield self._port.get()
|
||||
if req is None:
|
||||
break
|
||||
self.env.process(self._handle(req))
|
||||
|
||||
def _handle(self, req: TcmRequest):
|
||||
"""Acquire channel, apply BW-based delay, signal done."""
|
||||
if req.direction == "write":
|
||||
res = self._write_res
|
||||
bw = self._write_bw
|
||||
else:
|
||||
res = self._read_res
|
||||
bw = self._read_bw
|
||||
|
||||
with res.request() as lock:
|
||||
yield lock
|
||||
if bw > 0 and req.nbytes > 0:
|
||||
delay_ns = req.nbytes / bw
|
||||
yield self.env.timeout(delay_ns)
|
||||
req.done.succeed()
|
||||
@@ -0,0 +1,10 @@
|
||||
"""Scheduler: accelerator component + dispatch + tiling pipelines.
|
||||
|
||||
scheduler.py — SchedulerV2Component (init, wiring, dispatch, metrics)
|
||||
gemm_pipeline.py — GemmPipeline (tiled GEMM coordinator)
|
||||
math_pipeline.py — MathPipeline (tiled element-wise math coordinator)
|
||||
tile_address.py — per-tile address computation
|
||||
"""
|
||||
from kernbench.components.custom.pe_accel.scheduler.scheduler import SchedulerV2Component
|
||||
|
||||
__all__ = ["SchedulerV2Component"]
|
||||
@@ -0,0 +1,157 @@
|
||||
"""GEMM Tiling Pipeline: splits (M,K)×(K,N) into tiles and coordinates execution.
|
||||
|
||||
Flow per tile:
|
||||
DMA_IN(A tile) + DMA_IN(B tile) → GEMM(fetch + compute) → MATH(K-accum) → DMA_WB
|
||||
|
||||
The pipeline does NOT own hardware blocks — it uses the component's shared
|
||||
blocks via descriptor tables and SimPy queues.
|
||||
|
||||
Constructor starts two SimPy processes:
|
||||
- _feed_commands(): sends DmaRequests to shared dmaIN_cmd_q
|
||||
- _collect_completions(): waits for all output tiles to flush
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from math import ceil
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.custom.pe_accel.scheduler.tiling import generate_gemm_tiles
|
||||
from kernbench.components.custom.pe_accel.types import (
|
||||
CmdType,
|
||||
DmaInDescriptor,
|
||||
DmaRequest,
|
||||
DmaWBDescriptor,
|
||||
GemmDescriptor,
|
||||
MathDescriptor,
|
||||
Trigger,
|
||||
)
|
||||
|
||||
|
||||
class GemmPipeline:
|
||||
"""Coordinates one tiled GEMM operation across shared hardware blocks."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: simpy.Environment,
|
||||
M: int, K: int, N: int,
|
||||
tile_m: int, tile_k: int, tile_n: int,
|
||||
bytes_per_element: int,
|
||||
pipeline_id: int,
|
||||
reply_queue: simpy.Store,
|
||||
dmaIN_cmd_q: simpy.Store,
|
||||
dmaIN_to_fetch_trig: simpy.Store,
|
||||
A_addr: int = 0,
|
||||
B_addr: int = 0,
|
||||
C_addr: int = 0,
|
||||
dma_a: bool = True,
|
||||
dma_b: bool = True,
|
||||
dma_c: bool = True,
|
||||
) -> None:
|
||||
self.env = env
|
||||
self.M, self.K, self.N = M, K, N
|
||||
self.pipeline_id = pipeline_id
|
||||
self.reply_queue = reply_queue
|
||||
self.dmaIN_cmd_q = dmaIN_cmd_q
|
||||
self.dmaIN_to_fetch_trig = dmaIN_to_fetch_trig
|
||||
|
||||
self._dma_a = dma_a
|
||||
self._dma_b = dma_b
|
||||
self._skip_dmaWB = not dma_c
|
||||
|
||||
_Tm = min(tile_m, M)
|
||||
_Tk = min(tile_k, K)
|
||||
_Tn = min(tile_n, N)
|
||||
|
||||
self.M_tiles = ceil(M / tile_m)
|
||||
self.K_tiles = ceil(K / tile_k)
|
||||
self.N_tiles = ceil(N / tile_n)
|
||||
|
||||
triggers_per_tile = 2 if (dma_a and dma_b) else 1
|
||||
|
||||
# Generate tile schedule with pre-computed addresses
|
||||
self.schedule = generate_gemm_tiles(
|
||||
self.M_tiles, self.K_tiles, self.N_tiles,
|
||||
M=M, K=K, N=N,
|
||||
tile_m=_Tm, tile_k=_Tk, tile_n=_Tn,
|
||||
bytes_per_element=bytes_per_element,
|
||||
A_addr=A_addr, B_addr=B_addr, C_addr=C_addr,
|
||||
pipeline_id=pipeline_id,
|
||||
)
|
||||
|
||||
# Build descriptor tables for shared blocks
|
||||
pid = pipeline_id
|
||||
a_tile_bytes = _Tm * _Tk * bytes_per_element
|
||||
b_tile_bytes = _Tk * _Tn * bytes_per_element
|
||||
|
||||
self.dmaIN_descs: dict[tuple, DmaInDescriptor] = {}
|
||||
self.gemm_descs: dict[tuple, GemmDescriptor] = {}
|
||||
self.math_descs: dict[tuple, MathDescriptor] = {}
|
||||
self.dmaWB_descs: dict[tuple, DmaWBDescriptor] = {}
|
||||
|
||||
for cmd in self.schedule.commands:
|
||||
if cmd.cmd_type != CmdType.DMA_LOAD:
|
||||
continue
|
||||
t = cmd.tile_id
|
||||
|
||||
if dma_a:
|
||||
self.dmaIN_descs[(pid, t, "A")] = DmaInDescriptor(
|
||||
size_bytes=a_tile_bytes, src_addr=cmd.a_tile_addr
|
||||
)
|
||||
if dma_b:
|
||||
self.dmaIN_descs[(pid, t, "B")] = DmaInDescriptor(
|
||||
size_bytes=b_tile_bytes, src_addr=cmd.b_tile_addr
|
||||
)
|
||||
|
||||
self.gemm_descs[(pid, t)] = GemmDescriptor(
|
||||
Tm=_Tm, Tk=_Tk, Tn=_Tn,
|
||||
triggers_needed=triggers_per_tile,
|
||||
next_block="MATH",
|
||||
)
|
||||
|
||||
self.math_descs[(pid, t)] = MathDescriptor(
|
||||
Tm=_Tm, Tn=_Tn,
|
||||
is_last_k=cmd.is_last_k,
|
||||
skip_dmaWB=self._skip_dmaWB,
|
||||
)
|
||||
if not self._skip_dmaWB and cmd.is_last_k:
|
||||
self.dmaWB_descs[(pid, t)] = DmaWBDescriptor(
|
||||
Tm=_Tm, Tn=_Tn, dst_addr=cmd.c_tile_addr
|
||||
)
|
||||
|
||||
self.expected_flushes = self.M_tiles * self.N_tiles
|
||||
self.completed_flushes = 0
|
||||
self.done_at: int = 0
|
||||
self.done: simpy.Event = env.event()
|
||||
|
||||
env.process(self._feed_commands())
|
||||
env.process(self._collect_completions())
|
||||
|
||||
def _feed_commands(self):
|
||||
"""Send DmaRequests for each tile's operands to dmaIN_cmd_q."""
|
||||
for cmd in self.schedule.commands:
|
||||
if cmd.cmd_type != CmdType.DMA_LOAD:
|
||||
continue
|
||||
|
||||
if self._dma_a:
|
||||
yield self.dmaIN_cmd_q.put(DmaRequest(
|
||||
tile_id=cmd.tile_id, pipeline_id=self.pipeline_id, operand="A",
|
||||
))
|
||||
if self._dma_b:
|
||||
yield self.dmaIN_cmd_q.put(DmaRequest(
|
||||
tile_id=cmd.tile_id, pipeline_id=self.pipeline_id, operand="B",
|
||||
))
|
||||
|
||||
if not self._dma_a and not self._dma_b:
|
||||
yield self.dmaIN_to_fetch_trig.put(Trigger(
|
||||
tile_id=cmd.tile_id, pipeline_id=self.pipeline_id, source_block="PIPELINE",
|
||||
))
|
||||
|
||||
def _collect_completions(self):
|
||||
"""Wait for all output tile flush completions, then signal done."""
|
||||
while self.completed_flushes < self.expected_flushes:
|
||||
yield self.reply_queue.get()
|
||||
self.completed_flushes += 1
|
||||
|
||||
self.done_at = int(self.env.now)
|
||||
self.done.succeed()
|
||||
@@ -0,0 +1,132 @@
|
||||
"""Math Tiling Pipeline: splits element-wise ops into tiles for pipelined execution.
|
||||
|
||||
Flow per tile:
|
||||
DMA_IN(input tile) → MATH_OP(exp/log/etc.) → DMA_WB(output tile)
|
||||
|
||||
Mirrors GemmTilingPipeline but for unary element-wise operations.
|
||||
Pipeline overlap across tiles: while one tile is in MATH_OP, the next
|
||||
tile's DMA_IN can proceed concurrently.
|
||||
|
||||
Constructor starts two SimPy processes:
|
||||
- _feed_commands(): sends DmaRequests to shared dmaIN_cmd_q
|
||||
- _collect_completions(): waits for all tiles to writeback
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from math import ceil
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.custom.pe_accel.scheduler.tiling import generate_math_tiles
|
||||
from kernbench.components.custom.pe_accel.types import (
|
||||
DmaInDescriptor,
|
||||
DmaRequest,
|
||||
DmaWBDescriptor,
|
||||
MathOpDescriptor,
|
||||
Trigger,
|
||||
)
|
||||
|
||||
|
||||
class MathPipeline:
|
||||
"""Coordinates one tiled element-wise math operation across shared blocks."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: simpy.Environment,
|
||||
M: int, N: int,
|
||||
tile_m: int, tile_n: int,
|
||||
bytes_per_element: int,
|
||||
pipeline_id: int,
|
||||
reply_queue: simpy.Store,
|
||||
dmaIN_cmd_q: simpy.Store,
|
||||
dmaIN_to_fetch_trig: simpy.Store,
|
||||
op: str = "exp",
|
||||
src_addr: int = 0,
|
||||
dst_addr: int = 0,
|
||||
dma_in: bool = True,
|
||||
dma_out: bool = True,
|
||||
) -> None:
|
||||
self.env = env
|
||||
self.M, self.N = M, N
|
||||
self.pipeline_id = pipeline_id
|
||||
self.reply_queue = reply_queue
|
||||
self.dmaIN_cmd_q = dmaIN_cmd_q
|
||||
self.dmaIN_to_fetch_trig = dmaIN_to_fetch_trig
|
||||
self.op = op
|
||||
|
||||
self._dma_in = dma_in
|
||||
self._skip_dmaWB = not dma_out
|
||||
|
||||
_Tm = min(tile_m, M)
|
||||
_Tn = min(tile_n, N)
|
||||
|
||||
self.M_tiles = ceil(M / tile_m)
|
||||
self.N_tiles = ceil(N / tile_n)
|
||||
|
||||
# Generate tile schedule with pre-computed addresses
|
||||
self.schedule = generate_math_tiles(
|
||||
self.M_tiles, self.N_tiles,
|
||||
M=M, N=N,
|
||||
tile_m=_Tm, tile_n=_Tn,
|
||||
bytes_per_element=bytes_per_element,
|
||||
src_addr=src_addr, dst_addr=dst_addr,
|
||||
pipeline_id=pipeline_id,
|
||||
)
|
||||
|
||||
# Build descriptor tables for shared blocks
|
||||
pid = pipeline_id
|
||||
tile_bytes = _Tm * _Tn * bytes_per_element
|
||||
|
||||
self.dmaIN_descs: dict[tuple, DmaInDescriptor] = {}
|
||||
self.math_op_descs: dict[tuple, MathOpDescriptor] = {}
|
||||
self.dmaWB_descs: dict[tuple, DmaWBDescriptor] = {}
|
||||
|
||||
for cmd in self.schedule.commands:
|
||||
t = cmd.tile_id
|
||||
|
||||
if dma_in:
|
||||
self.dmaIN_descs[(pid, t, "A")] = DmaInDescriptor(
|
||||
size_bytes=tile_bytes,
|
||||
src_addr=cmd.src_tile_addr,
|
||||
next_block="MATH",
|
||||
)
|
||||
|
||||
self.math_op_descs[(pid, t)] = MathOpDescriptor(
|
||||
Tm=_Tm, Tn=_Tn, op=op,
|
||||
src_addr=cmd.src_tile_addr,
|
||||
dst_addr=cmd.dst_tile_addr,
|
||||
)
|
||||
|
||||
if not self._skip_dmaWB:
|
||||
self.dmaWB_descs[(pid, t)] = DmaWBDescriptor(
|
||||
Tm=_Tm, Tn=_Tn, dst_addr=cmd.dst_tile_addr,
|
||||
)
|
||||
|
||||
self.expected_flushes = self.M_tiles * self.N_tiles
|
||||
self.completed_flushes = 0
|
||||
self.done_at: int = 0
|
||||
self.done: simpy.Event = env.event()
|
||||
|
||||
env.process(self._feed_commands())
|
||||
env.process(self._collect_completions())
|
||||
|
||||
def _feed_commands(self):
|
||||
"""Send DmaRequests for each tile's input to dmaIN_cmd_q."""
|
||||
for cmd in self.schedule.commands:
|
||||
if self._dma_in:
|
||||
yield self.dmaIN_cmd_q.put(DmaRequest(
|
||||
tile_id=cmd.tile_id, pipeline_id=self.pipeline_id, operand="A",
|
||||
))
|
||||
else:
|
||||
yield self.dmaIN_to_fetch_trig.put(Trigger(
|
||||
tile_id=cmd.tile_id, pipeline_id=self.pipeline_id, source_block="PIPELINE",
|
||||
))
|
||||
|
||||
def _collect_completions(self):
|
||||
"""Wait for all tile completions, then signal done."""
|
||||
while self.completed_flushes < self.expected_flushes:
|
||||
yield self.reply_queue.get()
|
||||
self.completed_flushes += 1
|
||||
|
||||
self.done_at = int(self.env.now)
|
||||
self.done.succeed()
|
||||
@@ -0,0 +1,434 @@
|
||||
"""SchedulerV2Component: accelerator scheduler block (pe_scheduler_v2).
|
||||
|
||||
Replaces the pe_scheduler slot in topology (impl: pe_scheduler_v2).
|
||||
|
||||
Hosts four internal hardware blocks as concurrent SimPy processes:
|
||||
- DmaInBlock — HBM → TCM tile reads (real fabric DMA)
|
||||
- DmaWbBlock — TCM → HBM tile writes (real fabric DMA)
|
||||
- GemmBlock — 2-stage MAC pipeline (fetch + compute)
|
||||
- MathBlock — K-accumulation (GEMM) + element-wise ops (exp, log, etc.)
|
||||
|
||||
Command dispatch routes PeInternalTxn to the correct engine or tiling pipeline:
|
||||
- DmaReadCmd / DmaWriteCmd → PE_DMA out_port
|
||||
- GemmCmd → PE_GEMM out_port
|
||||
- MathCmd → PE_MATH out_port
|
||||
- CompositeCmd(op="gemm") → GemmPipeline (tiled DMA + GEMM + K-accum)
|
||||
- CompositeCmd(op="math") → MathPipeline (tiled DMA + element-wise + DMA)
|
||||
- PeCpuOverheadCmd → yield timeout
|
||||
|
||||
Config via node.attrs (all optional):
|
||||
overhead_ns, clock_freq_ghz, bytes_per_element,
|
||||
mac_m, mac_k, mac_n, tile_m, tile_k, tile_n,
|
||||
port_a_bw, port_b_bw, store_port_bw, vector_width
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import ComponentBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Component
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class SchedulerV2Component(ComponentBase):
|
||||
"""PE accelerator scheduler: wires internal blocks and dispatches commands."""
|
||||
|
||||
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
|
||||
super().__init__(node, ctx)
|
||||
self._pe_prefix: str = node.id.rsplit(".", 1)[0]
|
||||
attrs = node.attrs
|
||||
|
||||
# Hardware config
|
||||
self._overhead_ns: float = float(attrs.get("overhead_ns", 0.0))
|
||||
self._clock_freq_ghz: float = float(attrs.get("clock_freq_ghz", 1.0))
|
||||
self._bpe: int = int(attrs.get("bytes_per_element", 2))
|
||||
|
||||
# MAC array dimensions
|
||||
self._mac_m: int = int(attrs.get("mac_m", 8))
|
||||
self._mac_k: int = int(attrs.get("mac_k", 16))
|
||||
self._mac_n: int = int(attrs.get("mac_n", 32))
|
||||
|
||||
# Tile dimensions
|
||||
self._tile_m: int = int(attrs.get("tile_m", 32))
|
||||
self._tile_k: int = int(attrs.get("tile_k", 32))
|
||||
self._tile_n: int = int(attrs.get("tile_n", 32))
|
||||
|
||||
# Bandwidth (bytes/cycle)
|
||||
self._port_a_bw: int = int(attrs.get("port_a_bw", 256))
|
||||
self._port_b_bw: int = int(attrs.get("port_b_bw", 256))
|
||||
self._store_port_bw: int = int(attrs.get("store_port_bw", 256))
|
||||
self._vector_width: int = int(attrs.get("vector_width", 256))
|
||||
|
||||
# Initialized in start()
|
||||
self._dmaIN_block: Any = None
|
||||
self._dmaWB_block: Any = None
|
||||
self._gemm_block: Any = None
|
||||
self._math_block: Any = None
|
||||
self._dmaIN_cmd_q: simpy.Store | None = None
|
||||
self._dmaIN_to_fetch_trig: simpy.Store | None = None
|
||||
|
||||
# Pipeline tracking
|
||||
self._next_pipeline_id: int = 0
|
||||
self._pipeline_queues: dict[int, simpy.Store] = {}
|
||||
|
||||
# -- SimPy lifecycle -------------------------------------------------------
|
||||
|
||||
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
|
||||
"""Scheduler overhead per dispatch."""
|
||||
yield env.timeout(self._overhead_ns)
|
||||
|
||||
def start(self, env: simpy.Environment) -> None:
|
||||
"""Create internal blocks, wire queues, start SimPy processes."""
|
||||
from kernbench.components.custom.pe_accel.blocks import (
|
||||
DmaInBlock, DmaWbBlock, GemmBlock, MathBlock, TcmBlock,
|
||||
)
|
||||
|
||||
pe_dma_port = self.out_ports.get(f"{self._pe_prefix}.pe_dma")
|
||||
|
||||
# -- TCM block (shared BW-serialized scratchpad) -----------------------
|
||||
|
||||
# Read TCM BW from topology spec (gemm_to_tcm / math_to_tcm)
|
||||
pe_links = {}
|
||||
if self.ctx and self.ctx.spec:
|
||||
pe_links = self.ctx.spec.get("cube", {}).get("pe_template", {}).get("links", {})
|
||||
tcm_read_bw = float(pe_links.get("gemm_to_tcm_bw_gbs", 512.0))
|
||||
tcm_write_bw = float(pe_links.get("math_to_tcm_bw_gbs", 512.0))
|
||||
|
||||
self._tcm_block = TcmBlock(
|
||||
env=env,
|
||||
read_bw_gbs=tcm_read_bw,
|
||||
write_bw_gbs=tcm_write_bw,
|
||||
)
|
||||
|
||||
# -- Internal queues ---------------------------------------------------
|
||||
|
||||
self._dmaIN_cmd_q = simpy.Store(env)
|
||||
self._dmaIN_to_fetch_trig = simpy.Store(env)
|
||||
dmaIN_to_math_trig = simpy.Store(env) # DMA_IN → element-wise math
|
||||
fetch_to_gemm_trig = simpy.Store(env)
|
||||
gemm_to_math_trig = simpy.Store(env)
|
||||
gemm_to_dmaWB_trig = simpy.Store(env)
|
||||
math_to_dmaWB_trig = simpy.Store(env)
|
||||
|
||||
# Completion queues (block → _completion_router → pipeline reply_q)
|
||||
dmaIN_completion_q = simpy.Store(env)
|
||||
dmaWB_completion_q = simpy.Store(env)
|
||||
gemm_completion_q = simpy.Store(env)
|
||||
math_completion_q = simpy.Store(env)
|
||||
|
||||
# -- Create blocks -----------------------------------------------------
|
||||
|
||||
self._dmaIN_block = DmaInBlock(
|
||||
env=env,
|
||||
cmd_q=self._dmaIN_cmd_q,
|
||||
to_fetch_trig=self._dmaIN_to_fetch_trig,
|
||||
to_math_trig=dmaIN_to_math_trig,
|
||||
completion_q=dmaIN_completion_q,
|
||||
pe_dma_port=pe_dma_port,
|
||||
pe_prefix=self._pe_prefix,
|
||||
)
|
||||
|
||||
self._dmaWB_block = DmaWbBlock(
|
||||
env=env,
|
||||
completion_q=dmaWB_completion_q,
|
||||
pe_dma_port=pe_dma_port,
|
||||
pe_prefix=self._pe_prefix,
|
||||
bytes_per_element=self._bpe,
|
||||
)
|
||||
|
||||
self._gemm_block = GemmBlock(
|
||||
env=env,
|
||||
trig_in=self._dmaIN_to_fetch_trig,
|
||||
fetch_to_gemm_trig=fetch_to_gemm_trig,
|
||||
to_math_trig=gemm_to_math_trig,
|
||||
to_dmaWB_trig=gemm_to_dmaWB_trig,
|
||||
completion_q=gemm_completion_q,
|
||||
tcm_port=self._tcm_block.port,
|
||||
mac_m=self._mac_m, mac_k=self._mac_k, mac_n=self._mac_n,
|
||||
bytes_per_element=self._bpe,
|
||||
clock_freq_ghz=self._clock_freq_ghz,
|
||||
)
|
||||
|
||||
self._math_block = MathBlock(
|
||||
env=env,
|
||||
trig_in=gemm_to_math_trig,
|
||||
to_dmaWB_trig=math_to_dmaWB_trig,
|
||||
completion_q=math_completion_q,
|
||||
tcm_port=self._tcm_block.port,
|
||||
bytes_per_element=self._bpe,
|
||||
clock_freq_ghz=self._clock_freq_ghz,
|
||||
vector_width=self._vector_width,
|
||||
)
|
||||
|
||||
# -- Start block processes ---------------------------------------------
|
||||
|
||||
env.process(self._tcm_block._run())
|
||||
env.process(self._dmaIN_block._load_loop())
|
||||
env.process(self._dmaWB_block._flush_loop(gemm_to_dmaWB_trig))
|
||||
env.process(self._dmaWB_block._flush_loop(math_to_dmaWB_trig))
|
||||
env.process(self._gemm_block._fetch_stage())
|
||||
env.process(self._gemm_block._gemm_stage())
|
||||
env.process(self._math_block._run_k_accumulation())
|
||||
env.process(self._math_block._run_element_wise(dmaIN_to_math_trig))
|
||||
|
||||
# Wire in-ports → inbox, start _worker
|
||||
super().start(env)
|
||||
|
||||
# Start completion routers
|
||||
for q in (dmaIN_completion_q, dmaWB_completion_q, gemm_completion_q, math_completion_q):
|
||||
env.process(self._completion_router(q))
|
||||
|
||||
# -- Internal processes ----------------------------------------------------
|
||||
|
||||
def _completion_router(self, queue: simpy.Store) -> Generator:
|
||||
"""Route block completion triggers to the correct pipeline's reply queue."""
|
||||
while True:
|
||||
trigger = yield queue.get()
|
||||
if trigger is None:
|
||||
break
|
||||
reply_q = self._pipeline_queues.get(trigger.pipeline_id)
|
||||
if reply_q is not None:
|
||||
yield reply_q.put(trigger)
|
||||
|
||||
def _worker(self, env: simpy.Environment) -> Generator:
|
||||
"""Main inbox loop: dispatch PE commands, forward fabric transactions."""
|
||||
from kernbench.common.pe_commands import PeInternalTxn
|
||||
|
||||
while True:
|
||||
msg: Any = yield self._inbox.get()
|
||||
if isinstance(msg, PeInternalTxn):
|
||||
env.process(self._dispatch(env, msg))
|
||||
else:
|
||||
env.process(self._forward_txn(env, msg))
|
||||
|
||||
# ==========================================================================
|
||||
# Command Dispatch
|
||||
# ==========================================================================
|
||||
|
||||
def _dispatch(self, env: simpy.Environment, pe_txn: Any) -> Generator:
|
||||
"""Route a PeInternalTxn to the appropriate engine or pipeline."""
|
||||
from kernbench.common.pe_commands import (
|
||||
CompositeCmd, DmaReadCmd, DmaWriteCmd,
|
||||
GemmCmd, MathCmd, PeCpuOverheadCmd,
|
||||
)
|
||||
|
||||
yield from self.run(env, 0) # scheduler overhead
|
||||
|
||||
cmd = pe_txn.command
|
||||
pp = self._pe_prefix
|
||||
|
||||
if isinstance(cmd, (DmaReadCmd, DmaWriteCmd)):
|
||||
yield self.out_ports[f"{pp}.pe_dma"].put(pe_txn)
|
||||
|
||||
elif isinstance(cmd, GemmCmd):
|
||||
yield self.out_ports[f"{pp}.pe_gemm"].put(pe_txn)
|
||||
|
||||
elif isinstance(cmd, MathCmd):
|
||||
yield self.out_ports[f"{pp}.pe_math"].put(pe_txn)
|
||||
|
||||
elif isinstance(cmd, CompositeCmd):
|
||||
if cmd.op == "gemm" and cmd.b is not None:
|
||||
yield from self._dispatch_composite_gemm(env, pe_txn, cmd)
|
||||
else:
|
||||
yield from self._dispatch_composite_math(env, pe_txn, cmd)
|
||||
|
||||
elif isinstance(cmd, PeCpuOverheadCmd):
|
||||
yield env.timeout(cmd.cycles / self._clock_freq_ghz)
|
||||
pe_txn.done.succeed()
|
||||
|
||||
else:
|
||||
pe_txn.done.succeed()
|
||||
|
||||
# -- GEMM composite --------------------------------------------------------
|
||||
|
||||
def _dispatch_composite_gemm(
|
||||
self, env: simpy.Environment, pe_txn: Any, cmd: Any,
|
||||
) -> Generator:
|
||||
"""Run tiled GEMM pipeline and collect per-stage metrics."""
|
||||
from kernbench.components.custom.pe_accel.scheduler.gemm_pipeline import GemmPipeline
|
||||
|
||||
a, b = cmd.a, cmd.b
|
||||
M, K, N = a.shape[-2], a.shape[-1], b.shape[-1]
|
||||
|
||||
pid, reply_q = self._alloc_pipeline(env)
|
||||
|
||||
pipeline = GemmPipeline(
|
||||
env=env, M=M, K=K, N=N,
|
||||
tile_m=self._tile_m, tile_k=self._tile_k, tile_n=self._tile_n,
|
||||
bytes_per_element=self._bpe,
|
||||
pipeline_id=pid,
|
||||
reply_queue=reply_q,
|
||||
dmaIN_cmd_q=self._dmaIN_cmd_q,
|
||||
dmaIN_to_fetch_trig=self._dmaIN_to_fetch_trig,
|
||||
A_addr=a.addr, B_addr=b.addr, C_addr=cmd.out_addr,
|
||||
)
|
||||
|
||||
# Load descriptors into shared blocks
|
||||
self._dmaIN_block.load_descriptors(pipeline.dmaIN_descs)
|
||||
self._gemm_block.load_descriptors(pipeline.gemm_descs)
|
||||
self._math_block.load_descriptors(pipeline.math_descs)
|
||||
self._dmaWB_block.load_descriptors(pipeline.dmaWB_descs)
|
||||
|
||||
start_ns = env.now
|
||||
yield pipeline.done
|
||||
wall_clock_ns = env.now - start_ns
|
||||
del self._pipeline_queues[pid]
|
||||
|
||||
# -- Collect metrics ---------------------------------------------------
|
||||
rd = pe_txn.result_data
|
||||
rd["M_tiles"] = pipeline.M_tiles
|
||||
rd["K_tiles"] = pipeline.K_tiles
|
||||
rd["N_tiles"] = pipeline.N_tiles
|
||||
rd["total_tiles"] = pipeline.M_tiles * pipeline.K_tiles * pipeline.N_tiles
|
||||
rd["output_tiles"] = pipeline.M_tiles * pipeline.N_tiles
|
||||
rd["total_dma_read_bytes"] = sum(d.size_bytes for d in pipeline.dmaIN_descs.values())
|
||||
rd["total_dma_write_bytes"] = sum(
|
||||
d.Tm * d.Tn * self._bpe for d in pipeline.dmaWB_descs.values()
|
||||
)
|
||||
rd["wall_clock_ns"] = wall_clock_ns
|
||||
|
||||
# Per-stage timing
|
||||
rd["t_tcm_load_per_tile"], rd["t_tcm_load_max_ns"], rd["t_tcm_load_total_ns"] = (
|
||||
_collect_timing(self._gemm_block.t_tcm_load_per_tile.pop(pid, []))
|
||||
)
|
||||
rd["t_compute_per_tile"], rd["t_compute_max_ns"], rd["t_compute_total_ns"] = (
|
||||
_collect_timing(self._gemm_block.t_compute_per_tile.pop(pid, []))
|
||||
)
|
||||
rd["t_tcm_store_per_tile"], rd["t_tcm_store_max_ns"], rd["t_tcm_store_total_ns"] = (
|
||||
_collect_timing(self._math_block.t_tcm_store_per_tile.pop(pid, []))
|
||||
)
|
||||
rd["t_dma_read_per_request"], rd["t_dma_read_max_ns"], rd["t_dma_read_total_ns"] = (
|
||||
_collect_timing(self._dmaIN_block.t_dma_read_per_request.pop(pid, []))
|
||||
)
|
||||
rd["t_dma_write_per_tile"], rd["t_dma_write_max_ns"], rd["t_dma_write_total_ns"] = (
|
||||
_collect_timing(self._dmaWB_block.t_dma_write_per_tile.pop(pid, []))
|
||||
)
|
||||
|
||||
# Derived metrics
|
||||
rd["mac_utilization"] = rd["t_compute_total_ns"] / wall_clock_ns if wall_clock_ns > 0 else 0.0
|
||||
rd["effective_tflops"] = 2 * M * K * N / wall_clock_ns / 1e3 if wall_clock_ns > 0 else 0.0
|
||||
rd["pipeline_parallelism"] = (
|
||||
(rd["t_tcm_load_total_ns"] + rd["t_compute_total_ns"] + rd["t_tcm_store_total_ns"])
|
||||
/ wall_clock_ns if wall_clock_ns > 0 else 0.0
|
||||
)
|
||||
rd["effective_read_bw_gbs"] = (
|
||||
rd["total_dma_read_bytes"] / rd["t_dma_read_total_ns"]
|
||||
if rd["t_dma_read_total_ns"] > 0 else None
|
||||
)
|
||||
rd["effective_write_bw_gbs"] = (
|
||||
rd["total_dma_write_bytes"] / rd["t_dma_write_total_ns"]
|
||||
if rd["t_dma_write_total_ns"] > 0 else None
|
||||
)
|
||||
rd["bottleneck_stage"] = max(
|
||||
[("load", rd["t_tcm_load_max_ns"]), ("compute", rd["t_compute_max_ns"]),
|
||||
("store", rd["t_tcm_store_max_ns"])],
|
||||
key=lambda x: x[1],
|
||||
)[0]
|
||||
|
||||
# pe_cpu.py compatibility aliases
|
||||
rd["dma_ns"] = rd["t_dma_read_total_ns"] + rd["t_dma_write_total_ns"]
|
||||
rd["compute_ns"] = rd["t_compute_total_ns"]
|
||||
|
||||
pe_txn.done.succeed()
|
||||
|
||||
# -- Math composite --------------------------------------------------------
|
||||
|
||||
def _dispatch_composite_math(
|
||||
self, env: simpy.Environment, pe_txn: Any, cmd: Any,
|
||||
) -> Generator:
|
||||
"""Run tiled element-wise math pipeline and collect metrics."""
|
||||
from kernbench.components.custom.pe_accel.scheduler.math_pipeline import MathPipeline
|
||||
|
||||
assert self._dmaIN_cmd_q is not None
|
||||
assert self._dmaIN_to_fetch_trig is not None
|
||||
|
||||
a = cmd.a
|
||||
M = a.shape[-2] if len(a.shape) >= 2 else 1
|
||||
N = a.shape[-1]
|
||||
op = cmd.math_op or "identity"
|
||||
|
||||
pid, reply_q = self._alloc_pipeline(env)
|
||||
|
||||
pipeline = MathPipeline(
|
||||
env=env, M=M, N=N,
|
||||
tile_m=self._tile_m, tile_n=self._tile_n,
|
||||
bytes_per_element=self._bpe,
|
||||
pipeline_id=pid,
|
||||
reply_queue=reply_q,
|
||||
dmaIN_cmd_q=self._dmaIN_cmd_q,
|
||||
dmaIN_to_fetch_trig=self._dmaIN_to_fetch_trig,
|
||||
op=op,
|
||||
src_addr=a.addr, dst_addr=cmd.out_addr,
|
||||
)
|
||||
|
||||
# Load descriptors into shared blocks
|
||||
self._dmaIN_block.load_descriptors(pipeline.dmaIN_descs)
|
||||
self._math_block.load_math_op_descriptors(pipeline.math_op_descs)
|
||||
self._dmaWB_block.load_descriptors(pipeline.dmaWB_descs)
|
||||
|
||||
start_ns = env.now
|
||||
yield pipeline.done
|
||||
wall_clock_ns = env.now - start_ns
|
||||
del self._pipeline_queues[pid]
|
||||
|
||||
# -- Collect metrics ---------------------------------------------------
|
||||
rd = pe_txn.result_data
|
||||
rd["M_tiles"] = pipeline.M_tiles
|
||||
rd["N_tiles"] = pipeline.N_tiles
|
||||
rd["total_tiles"] = pipeline.M_tiles * pipeline.N_tiles
|
||||
rd["wall_clock_ns"] = wall_clock_ns
|
||||
rd["math_op"] = op
|
||||
|
||||
# DMA timing
|
||||
rd["t_dma_read_per_request"], rd["t_dma_read_max_ns"], rd["t_dma_read_total_ns"] = (
|
||||
_collect_timing(self._dmaIN_block.t_dma_read_per_request.pop(pid, []))
|
||||
)
|
||||
rd["t_dma_write_per_tile"], rd["t_dma_write_max_ns"], rd["t_dma_write_total_ns"] = (
|
||||
_collect_timing(self._dmaWB_block.t_dma_write_per_tile.pop(pid, []))
|
||||
)
|
||||
|
||||
# Math op timing (load + compute + store)
|
||||
rd["t_math_load_per_tile"], rd["t_math_load_max_ns"], rd["t_math_load_total_ns"] = (
|
||||
_collect_timing(self._math_block.t_math_op_load_per_tile.pop(pid, []))
|
||||
)
|
||||
rd["t_math_compute_per_tile"], rd["t_math_compute_max_ns"], rd["t_math_compute_total_ns"] = (
|
||||
_collect_timing(self._math_block.t_math_op_compute_per_tile.pop(pid, []))
|
||||
)
|
||||
rd["t_math_store_per_tile"], rd["t_math_store_max_ns"], rd["t_math_store_total_ns"] = (
|
||||
_collect_timing(self._math_block.t_math_op_store_per_tile.pop(pid, []))
|
||||
)
|
||||
|
||||
# pe_cpu.py compatibility aliases
|
||||
rd["dma_ns"] = rd["t_dma_read_total_ns"] + rd["t_dma_write_total_ns"]
|
||||
rd["compute_ns"] = rd["t_math_compute_total_ns"]
|
||||
|
||||
pe_txn.done.succeed()
|
||||
|
||||
# -- Helpers ---------------------------------------------------------------
|
||||
|
||||
def _alloc_pipeline(self, env: simpy.Environment) -> tuple[int, simpy.Store]:
|
||||
"""Allocate a pipeline ID and reply queue."""
|
||||
pid = self._next_pipeline_id
|
||||
self._next_pipeline_id += 1
|
||||
reply_q = simpy.Store(env)
|
||||
self._pipeline_queues[pid] = reply_q
|
||||
return pid, reply_q
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Utility
|
||||
# ==============================================================================
|
||||
|
||||
def _collect_timing(times: list) -> tuple[list, float, float]:
|
||||
"""Return (raw_list, max_ns, total_ns) from a timing list."""
|
||||
return times, max(times, default=0.0), sum(times)
|
||||
@@ -0,0 +1,121 @@
|
||||
"""Tile schedule generators for GEMM and element-wise math operations.
|
||||
|
||||
Each generator produces a plan of tile commands with pre-computed addresses.
|
||||
Pipelines use these plans to build descriptor tables and feed commands
|
||||
to the shared hardware blocks.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from math import ceil
|
||||
|
||||
from kernbench.components.custom.pe_accel.types import (
|
||||
CmdType,
|
||||
MathSchedulePlan,
|
||||
MathTileCommand,
|
||||
SchedulePlan,
|
||||
TileCommand,
|
||||
)
|
||||
|
||||
|
||||
def generate_gemm_tiles(
|
||||
M_tiles: int, K_tiles: int, N_tiles: int,
|
||||
M: int = 0, K: int = 0, N: int = 0,
|
||||
tile_m: int = 0, tile_k: int = 0, tile_n: int = 0,
|
||||
bytes_per_element: int = 2,
|
||||
A_addr: int = 0, B_addr: int = 0, C_addr: int = 0,
|
||||
pipeline_id: int = 0,
|
||||
) -> SchedulePlan:
|
||||
"""Generate GEMM tile commands in M (outer) -> N -> K (inner) order.
|
||||
|
||||
Stamps is_last_k=True on the final K-tile per (m, n) pair.
|
||||
Emits one DMA_FLUSH per (m, n) pair after all K tiles.
|
||||
|
||||
Per-tile addresses (row-major layout):
|
||||
A (M,K): A_addr + (m * tile_m * K + k * tile_k) * bpe
|
||||
B (K,N): B_addr + (k * tile_k * N + n * tile_n) * bpe
|
||||
C (M,N): C_addr + (m * tile_m * N + n * tile_n) * bpe
|
||||
"""
|
||||
commands: list[TileCommand] = []
|
||||
cmd_id = 0
|
||||
tile_id = 0
|
||||
bpe = bytes_per_element
|
||||
|
||||
for m in range(M_tiles):
|
||||
for n in range(N_tiles):
|
||||
c_tile_addr = C_addr + (m * tile_m * N + n * tile_n) * bpe
|
||||
|
||||
for k in range(K_tiles):
|
||||
last_k = k == K_tiles - 1
|
||||
a_tile_addr = A_addr + (m * tile_m * K + k * tile_k) * bpe
|
||||
b_tile_addr = B_addr + (k * tile_k * N + n * tile_n) * bpe
|
||||
|
||||
commands.append(TileCommand(
|
||||
cmd_id=cmd_id, cmd_type=CmdType.DMA_LOAD,
|
||||
tile_id=tile_id, m_idx=m, k_idx=k, n_idx=n,
|
||||
is_last_k=last_k, pipeline_id=pipeline_id,
|
||||
a_tile_addr=a_tile_addr, b_tile_addr=b_tile_addr,
|
||||
c_tile_addr=c_tile_addr,
|
||||
))
|
||||
cmd_id += 1
|
||||
|
||||
commands.append(TileCommand(
|
||||
cmd_id=cmd_id, cmd_type=CmdType.TENSOR_OP,
|
||||
tile_id=tile_id, m_idx=m, k_idx=k, n_idx=n,
|
||||
is_last_k=last_k, pipeline_id=pipeline_id,
|
||||
a_tile_addr=a_tile_addr, b_tile_addr=b_tile_addr,
|
||||
c_tile_addr=c_tile_addr,
|
||||
))
|
||||
cmd_id += 1
|
||||
tile_id += 1
|
||||
|
||||
# One flush per (m, n) pair after all K tiles
|
||||
commands.append(TileCommand(
|
||||
cmd_id=cmd_id, cmd_type=CmdType.DMA_FLUSH,
|
||||
tile_id=tile_id - 1, m_idx=m, k_idx=0, n_idx=n,
|
||||
pipeline_id=pipeline_id,
|
||||
c_tile_addr=c_tile_addr,
|
||||
))
|
||||
cmd_id += 1
|
||||
|
||||
return SchedulePlan(
|
||||
commands=commands, M_tiles=M_tiles, K_tiles=K_tiles, N_tiles=N_tiles
|
||||
)
|
||||
|
||||
|
||||
def generate_math_tiles(
|
||||
M_tiles: int, N_tiles: int,
|
||||
M: int = 0, N: int = 0,
|
||||
tile_m: int = 0, tile_n: int = 0,
|
||||
bytes_per_element: int = 2,
|
||||
src_addr: int = 0, dst_addr: int = 0,
|
||||
pipeline_id: int = 0,
|
||||
) -> MathSchedulePlan:
|
||||
"""Generate element-wise math tile commands in row-major order.
|
||||
|
||||
Per-tile addresses (row-major layout):
|
||||
src: src_addr + (m * tile_m * N + n * tile_n) * bpe
|
||||
dst: dst_addr + (m * tile_m * N + n * tile_n) * bpe
|
||||
"""
|
||||
commands: list[MathTileCommand] = []
|
||||
cmd_id = 0
|
||||
tile_id = 0
|
||||
bpe = bytes_per_element
|
||||
|
||||
for m in range(M_tiles):
|
||||
for n in range(N_tiles):
|
||||
offset = (m * tile_m * N + n * tile_n) * bpe
|
||||
commands.append(MathTileCommand(
|
||||
cmd_id=cmd_id,
|
||||
tile_id=tile_id,
|
||||
m_idx=m,
|
||||
n_idx=n,
|
||||
src_tile_addr=src_addr + offset,
|
||||
dst_tile_addr=dst_addr + offset,
|
||||
pipeline_id=pipeline_id,
|
||||
))
|
||||
cmd_id += 1
|
||||
tile_id += 1
|
||||
|
||||
return MathSchedulePlan(
|
||||
commands=commands, M_tiles=M_tiles, N_tiles=N_tiles
|
||||
)
|
||||
@@ -0,0 +1,148 @@
|
||||
"""Data types for pe_accel_v1: descriptors, triggers, tile commands.
|
||||
|
||||
All types are frozen/plain dataclasses with no logic.
|
||||
Schedule generators live in tiling/schedule.py.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
|
||||
|
||||
# -- Enums ---------------------------------------------------------------------
|
||||
|
||||
class CmdType(Enum):
|
||||
DMA_LOAD = auto()
|
||||
TENSOR_OP = auto()
|
||||
DMA_FLUSH = auto()
|
||||
|
||||
|
||||
# -- Inter-block messaging -----------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class Trigger:
|
||||
"""Completion token passed between hardware blocks."""
|
||||
|
||||
tile_id: int
|
||||
pipeline_id: int
|
||||
vc: int | None = None
|
||||
source_block: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class DmaRequest:
|
||||
"""DMA load request — descriptor lookup key only.
|
||||
|
||||
Transfer params (size, address) live in the pre-loaded DmaInDescriptor.
|
||||
"""
|
||||
|
||||
tile_id: int
|
||||
pipeline_id: int
|
||||
operand: str # "A" or "B"
|
||||
|
||||
|
||||
# -- Descriptors (pre-loaded by pipelines, consumed by blocks) -----------------
|
||||
|
||||
@dataclass
|
||||
class DmaInDescriptor:
|
||||
"""Per-operand DMA read descriptor."""
|
||||
|
||||
size_bytes: int
|
||||
src_addr: int = 0
|
||||
next_block: str = "GEMM" # "GEMM" | "MATH" | "COMPLETION"
|
||||
|
||||
|
||||
@dataclass
|
||||
class GemmDescriptor:
|
||||
"""Per-tile GEMM descriptor."""
|
||||
|
||||
Tm: int
|
||||
Tk: int
|
||||
Tn: int
|
||||
triggers_needed: int = 2 # 2 = both operands from DMA; 1 = SPMem bypass
|
||||
gemm_load: bool = True
|
||||
gemm_compute: bool = True
|
||||
next_block: str = "MATH" # "MATH" | "DMAWB" | "DONE"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MathDescriptor:
|
||||
"""Per-tile K-accumulation descriptor (used by GEMM pipeline)."""
|
||||
|
||||
Tm: int
|
||||
Tn: int
|
||||
is_last_k: bool
|
||||
skip_dmaWB: bool # True = C stays in SPMem; False = flush to HBM
|
||||
|
||||
|
||||
@dataclass
|
||||
class MathOpDescriptor:
|
||||
"""Per-tile element-wise math op descriptor (used by math pipeline)."""
|
||||
|
||||
Tm: int
|
||||
Tn: int
|
||||
op: str # "exp", "log", "sqrt", "sigmoid", etc.
|
||||
src_addr: int = 0
|
||||
dst_addr: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class DmaWBDescriptor:
|
||||
"""Per-tile DMA writeback descriptor."""
|
||||
|
||||
Tm: int
|
||||
Tn: int
|
||||
dst_addr: int = 0
|
||||
|
||||
|
||||
# -- Tile commands (produced by schedule generators) ---------------------------
|
||||
|
||||
@dataclass
|
||||
class TileCommand:
|
||||
"""A single GEMM tile command."""
|
||||
|
||||
cmd_id: int
|
||||
cmd_type: CmdType
|
||||
tile_id: int
|
||||
m_idx: int
|
||||
k_idx: int
|
||||
n_idx: int
|
||||
is_last_k: bool = False
|
||||
pipeline_id: int = 0
|
||||
a_tile_addr: int = 0
|
||||
b_tile_addr: int = 0
|
||||
c_tile_addr: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class SchedulePlan:
|
||||
"""Full tile schedule for one GEMM operation."""
|
||||
|
||||
commands: list # list[TileCommand]
|
||||
M_tiles: int
|
||||
K_tiles: int
|
||||
N_tiles: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class MathTileCommand:
|
||||
"""A single element-wise math tile command."""
|
||||
|
||||
cmd_id: int
|
||||
tile_id: int
|
||||
m_idx: int
|
||||
n_idx: int
|
||||
src_tile_addr: int = 0
|
||||
dst_tile_addr: int = 0
|
||||
pipeline_id: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class MathSchedulePlan:
|
||||
"""Full tile schedule for one element-wise math operation."""
|
||||
|
||||
commands: list # list[MathTileCommand]
|
||||
M_tiles: int
|
||||
N_tiles: int
|
||||
|
||||
|
||||
@@ -13,11 +13,19 @@ class DPPolicy:
|
||||
- "replicate": full copy at each unit
|
||||
- "column_wise": split K (column) axis across units
|
||||
- "row_wise": split M (row) axis across units
|
||||
|
||||
Optional overrides (default None = use topology dimensions):
|
||||
- num_pes: override PEs per cube (e.g., 1 for single-PE test)
|
||||
- num_cubes: override cubes per SIP (e.g., 1 for single-cube test)
|
||||
- num_sips: override SIP count
|
||||
"""
|
||||
|
||||
sip: Literal["replicate", "column_wise", "row_wise"] = "replicate"
|
||||
cube: Literal["replicate", "column_wise", "row_wise"] = "replicate"
|
||||
pe: Literal["replicate", "column_wise", "row_wise"] = "replicate"
|
||||
num_pes: int | None = None
|
||||
num_cubes: int | None = None
|
||||
num_sips: int | None = None
|
||||
|
||||
|
||||
def _split_shape(
|
||||
|
||||
@@ -269,15 +269,25 @@ class RuntimeContext:
|
||||
allocators = self._ensure_allocators()
|
||||
itemsize = dtype_itemsize(dtype)
|
||||
shape_2d = (shape[0], shape[1]) if len(shape) >= 2 else (1, shape[0])
|
||||
# DPPolicy overrides take precedence over topology dimensions
|
||||
eff_num_pe = dp.num_pes if dp.num_pes is not None else self._pes_per_cube
|
||||
eff_num_cubes = dp.num_cubes if dp.num_cubes is not None else self._num_cubes
|
||||
eff_num_sips = dp.num_sips if dp.num_sips is not None else self._num_sips
|
||||
placement = resolve_dp_policy(
|
||||
dp, shape=shape_2d, itemsize=itemsize,
|
||||
num_pe=self._pes_per_cube, num_cubes=self._num_cubes,
|
||||
num_sips=self._num_sips,
|
||||
num_pe=eff_num_pe, num_cubes=eff_num_cubes,
|
||||
num_sips=eff_num_sips,
|
||||
)
|
||||
|
||||
# Infer target_pe from placement: multi-PE → "all", single PE → pe_index
|
||||
pe_indices = {s.pe_index for s in placement}
|
||||
target_pe: int | str = "all" if len(pe_indices) > 1 else next(iter(pe_indices))
|
||||
# Infer target_pe from placement using local (within-cube) PE IDs.
|
||||
# This ensures M_CPU only fans out to PEs that own shards, not all PEs.
|
||||
local_pe_ids = sorted({s.pe_index % eff_num_pe for s in placement})
|
||||
if len(local_pe_ids) == 1:
|
||||
target_pe: int | tuple[int, ...] | str = local_pe_ids[0]
|
||||
elif len(local_pe_ids) == eff_num_pe and eff_num_pe == self._pes_per_cube:
|
||||
target_pe = "all"
|
||||
else:
|
||||
target_pe = tuple(local_pe_ids)
|
||||
t.to(placement=placement, target_pe=target_pe, dp_policy=dp_policy)
|
||||
|
||||
# Allocate PAs via PEMemAllocator + VA via VirtualAllocator
|
||||
@@ -407,7 +417,8 @@ class RuntimeContext:
|
||||
# Collect tensors and scalars
|
||||
tensor_args: list[Tensor] = []
|
||||
scalar_args: list = []
|
||||
target_pe: int | str = 0
|
||||
_pe_set: set[int] = set()
|
||||
_pe_all = False
|
||||
|
||||
for a in args:
|
||||
if isinstance(a, Tensor):
|
||||
@@ -415,9 +426,11 @@ class RuntimeContext:
|
||||
if a._dp_metadata is not None:
|
||||
dp_target = a._dp_metadata.target_pe
|
||||
if dp_target == "all":
|
||||
target_pe = "all"
|
||||
elif isinstance(dp_target, int) and target_pe != "all":
|
||||
target_pe = dp_target
|
||||
_pe_all = True
|
||||
elif isinstance(dp_target, tuple):
|
||||
_pe_set.update(dp_target)
|
||||
elif isinstance(dp_target, int):
|
||||
_pe_set.add(dp_target)
|
||||
elif isinstance(a, (int, float)):
|
||||
dtype_str = "f32" if isinstance(a, float) else "i32"
|
||||
scalar_args.append(ScalarArg(dtype=dtype_str, value=a))
|
||||
@@ -427,6 +440,16 @@ class RuntimeContext:
|
||||
dtype_str = "f32" if isinstance(v, float) else "i32"
|
||||
scalar_args.append(ScalarArg(dtype=dtype_str, value=v))
|
||||
|
||||
# Resolve target_pe from collected PE info
|
||||
if _pe_all:
|
||||
target_pe: int | tuple[int, ...] | str = "all"
|
||||
elif len(_pe_set) == 1:
|
||||
target_pe = next(iter(_pe_set))
|
||||
elif len(_pe_set) > 1:
|
||||
target_pe = tuple(sorted(_pe_set))
|
||||
else:
|
||||
target_pe = 0
|
||||
|
||||
# Determine all target SIPs from tensor shards
|
||||
sip_set: set[int] = set()
|
||||
for t in tensor_args:
|
||||
|
||||
@@ -89,7 +89,7 @@ class KernelLaunchMsg:
|
||||
kernel_ref: KernelRef
|
||||
args: tuple[KernelArg, ...]
|
||||
target_cubes: tuple[int, ...] | Literal["all"] = "all"
|
||||
target_pe: int | Literal["all"] = "all"
|
||||
target_pe: int | tuple[int, ...] | Literal["all"] = "all"
|
||||
msg_type: Literal["kernel_launch"] = "kernel_launch"
|
||||
|
||||
|
||||
|
||||
@@ -105,7 +105,7 @@ class DPMetadata:
|
||||
dp_policy: DPPolicy | None = None
|
||||
sip: int = 0
|
||||
cube: int = 0
|
||||
target_pe: int | str = 0 # int → single PE, "all" → all PEs
|
||||
target_pe: int | tuple[int, ...] | str = 0 # int → single PE, tuple → specific PEs, "all" → all PEs
|
||||
|
||||
|
||||
class Tensor:
|
||||
@@ -166,7 +166,7 @@ class Tensor:
|
||||
dp_policy: DPPolicy | None = None,
|
||||
sip: int = 0,
|
||||
cube: int = 0,
|
||||
target_pe: int | str = 0,
|
||||
target_pe: int | tuple[int, ...] | str = 0,
|
||||
) -> Tensor:
|
||||
"""Set DP placement metadata (like torch.Tensor.to())."""
|
||||
if placement is None:
|
||||
|
||||
+1
-1
@@ -61,7 +61,7 @@ cube:
|
||||
pe_template:
|
||||
components:
|
||||
pe_cpu: { kind: pe_cpu, impl: pe_cpu_v1, attrs: { overhead_ns: 2.0 } }
|
||||
pe_scheduler: { kind: pe_scheduler, impl: pe_scheduler_v1, attrs: { overhead_ns: 1.0 } }
|
||||
pe_scheduler: { kind: pe_scheduler, impl: pe_scheduler_v2, attrs: { overhead_ns: 1.0 } }
|
||||
pe_dma: { kind: pe_dma, impl: pe_dma_v1, attrs: { rd_engines: 1, wr_engines: 1 } }
|
||||
pe_gemm: { kind: pe_gemm, impl: pe_gemm_v1, attrs: { overhead_ns: 0.0, shared_resource: accel_slot, peak_tflops_f16: 8.0 } }
|
||||
pe_math: { kind: pe_math, impl: pe_math_v1, attrs: { overhead_ns: 0.0, shared_resource: accel_slot } }
|
||||
|
||||
Reference in New Issue
Block a user