From 114510d4b94247073c485b07a83a50bf21c23cf1 Mon Sep 17 00:00:00 2001 From: Yangwook Date: Thu, 26 Mar 2026 23:18:49 -0700 Subject: [PATCH] Add SchedulerV2 (pe_accel), DPPolicy overrides, and new benchmarks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add cycle-accurate PE accelerator scheduler (SchedulerV2) with tiled GEMM/Math pipelines (DMA_IN → GEMM → MATH → DMA_WB) - Add DPPolicy num_pes/num_cubes/num_sips overrides for single-PE testing - Support tuple target_pe for targeting specific PE subsets - Add gemm_single_pe and gpt3_qkv benchmarks - Switch default topology to pe_scheduler_v2 Co-Authored-By: Claude Opus 4.6 (1M context) --- benches/gemm_single_pe.py | 39 ++ benches/gpt3_qkv.py | 92 ++++ components.yaml | 2 +- src/kernbench/components/builtin/m_cpu.py | 4 +- .../components/custom/pe_accel/__init__.py | 20 + .../custom/pe_accel/blocks/__init__.py | 16 + .../custom/pe_accel/blocks/dma_in.py | 96 ++++ .../custom/pe_accel/blocks/dma_wb.py | 88 ++++ .../components/custom/pe_accel/blocks/gemm.py | 160 +++++++ .../components/custom/pe_accel/blocks/math.py | 181 ++++++++ .../components/custom/pe_accel/blocks/tcm.py | 80 ++++ .../custom/pe_accel/scheduler/__init__.py | 10 + .../pe_accel/scheduler/gemm_pipeline.py | 157 +++++++ .../pe_accel/scheduler/math_pipeline.py | 132 ++++++ .../custom/pe_accel/scheduler/scheduler.py | 434 ++++++++++++++++++ .../custom/pe_accel/scheduler/tiling.py | 121 +++++ .../components/custom/pe_accel/types.py | 148 ++++++ src/kernbench/policy/placement/dp.py | 8 + src/kernbench/runtime_api/context.py | 41 +- src/kernbench/runtime_api/kernel.py | 2 +- src/kernbench/runtime_api/tensor.py | 4 +- topology.yaml | 2 +- 22 files changed, 1822 insertions(+), 15 deletions(-) create mode 100644 benches/gemm_single_pe.py create mode 100644 benches/gpt3_qkv.py create mode 100644 src/kernbench/components/custom/pe_accel/__init__.py create mode 100644 
src/kernbench/components/custom/pe_accel/blocks/__init__.py create mode 100644 src/kernbench/components/custom/pe_accel/blocks/dma_in.py create mode 100644 src/kernbench/components/custom/pe_accel/blocks/dma_wb.py create mode 100644 src/kernbench/components/custom/pe_accel/blocks/gemm.py create mode 100644 src/kernbench/components/custom/pe_accel/blocks/math.py create mode 100644 src/kernbench/components/custom/pe_accel/blocks/tcm.py create mode 100644 src/kernbench/components/custom/pe_accel/scheduler/__init__.py create mode 100644 src/kernbench/components/custom/pe_accel/scheduler/gemm_pipeline.py create mode 100644 src/kernbench/components/custom/pe_accel/scheduler/math_pipeline.py create mode 100644 src/kernbench/components/custom/pe_accel/scheduler/scheduler.py create mode 100644 src/kernbench/components/custom/pe_accel/scheduler/tiling.py create mode 100644 src/kernbench/components/custom/pe_accel/types.py diff --git a/benches/gemm_single_pe.py b/benches/gemm_single_pe.py new file mode 100644 index 0000000..dda336f --- /dev/null +++ b/benches/gemm_single_pe.py @@ -0,0 +1,39 @@ +"""Single-PE GEMM benchmark via scheduler_v2 (pe_accel). + +Full host-to-PE pipeline: + Host → PCIE_EP → IO_CPU → M_CPU → PE_CPU → SchedulerV2 → PE_DMA → HBM + +Single PE: num_sips=1, num_cubes=1, num_pes=1 via DPPolicy override. +Both operands use tl.ref (HBM-resident); scheduler_v2 tiles and streams +per-tile DMA internally. + +Run: + kernbench run gemm_single_pe +""" +from kernbench.policy.placement.dp import DPPolicy + +# GEMM dimensions: (M, K) x (K, N) → (M, N) +M, K, N = 32, 128, 32 +DTYPE = "f16" + + +def _gemm_kernel(a_ptr, b_ptr, out_ptr, M, K, N, tl, DTYPE="f16"): + """Single-PE GEMM: out = a @ b. 
Both operands streamed from HBM by scheduler.""" + M, K, N = int(M), int(K), int(N) + + a = tl.ref(int(a_ptr), shape=(M, K), dtype=DTYPE) + b = tl.ref(int(b_ptr), shape=(K, N), dtype=DTYPE) + h = tl.composite(op="gemm", a=a, b=b, out_ptr=int(out_ptr)) + tl.wait(h) + + +def run(torch): + """Run the single-PE GEMM benchmark.""" + dp = DPPolicy(cube="replicate", pe="replicate", + num_sips=1, num_cubes=1, num_pes=1) + + a = torch.empty((M, K), dtype=DTYPE, dp=dp, name="a") + b = torch.empty((K, N), dtype=DTYPE, dp=dp, name="b") + out = torch.empty((M, N), dtype=DTYPE, dp=dp, name="out") + + torch.launch("gemm_single_pe", _gemm_kernel, a, b, out, M, K, N) diff --git a/benches/gpt3_qkv.py b/benches/gpt3_qkv.py new file mode 100644 index 0000000..5ff8fd4 --- /dev/null +++ b/benches/gpt3_qkv.py @@ -0,0 +1,92 @@ +"""GPT-3 QKV projection benchmark: sharded across PEs via pe_accel_v1. + +GPT-3 architecture: + d_model = 12288 (hidden dimension) + n_heads = 96 (attention heads) + d_head = 128 (dimension per head) + +Sharding strategy (column-wise across all PEs): + X : (seq_len, d_model) -- replicated to all PEs + W_Q/K/V : (d_model, d_model) -- column-wise sharded across cubes × PEs + out_Q/K/V: (seq_len, d_model) -- column-wise sharded across cubes × PEs + +Each PE computes: + Q_slice = X @ W_Q_slice : (seq_len, d_model) @ (d_model, cols_per_pe) -> (seq_len, cols_per_pe) + K_slice, V_slice: same + +PE count is configurable via N_CUBES × N_PE_PER_CUBE (DPPolicy override). +topology.yaml is unchanged. 
+ +Run: + kernbench run gpt3_qkv +""" +from kernbench.policy.placement.dp import DPPolicy + +# -- PE configuration (DPPolicy overrides — does not change topology.yaml) ----- +N_SIPS = 1 +N_CUBES = 16 # cubes per SIP +N_PE_PER_CUBE = 8 # PEs per cube +N_PES = N_CUBES * N_PE_PER_CUBE # 128 total + +# -- GPT-3 architecture ------------------------------------------------------- +GPT3_D_MODEL = 12288 +SEQ_LEN = 32 +COLS_PER_PE = GPT3_D_MODEL // N_PES # 12288 / 128 = 96 +DTYPE = "f16" + + +def _gpt3_qkv_kernel(x_ptr, wq_ptr, wk_ptr, wv_ptr, + out_q_ptr, out_k_ptr, out_v_ptr, + seq_len, d_model, cols_per_pe, tl, DTYPE="f16"): + """GPT-3 QKV sharded: each PE uses program_id to index its VA slice.""" + pid = tl.program_id(0) + bpe = 2 # f16 + + M = int(seq_len) + K = int(d_model) + N = int(cols_per_pe) + + w_slice = K * N * bpe + out_slice = M * N * bpe + + x = tl.load(int(x_ptr), shape=(M, K), dtype=DTYPE) + wq = tl.ref(int(wq_ptr) + pid * w_slice, shape=(K, N), dtype=DTYPE) + wk = tl.ref(int(wk_ptr) + pid * w_slice, shape=(K, N), dtype=DTYPE) + wv = tl.ref(int(wv_ptr) + pid * w_slice, shape=(K, N), dtype=DTYPE) + + hq = tl.composite(op="gemm", a=x, b=wq, + out_ptr=int(out_q_ptr) + pid * out_slice) + hk = tl.composite(op="gemm", a=x, b=wk, + out_ptr=int(out_k_ptr) + pid * out_slice) + hv = tl.composite(op="gemm", a=x, b=wv, + out_ptr=int(out_v_ptr) + pid * out_slice) + + tl.wait(hq) + tl.wait(hk) + tl.wait(hv) + + +def run(torch): + """Run the GPT-3 QKV benchmark.""" + M = SEQ_LEN + K = GPT3_D_MODEL + N = COLS_PER_PE + + # X: replicated across all PEs + dp_replicate = DPPolicy(cube="replicate", pe="replicate", + num_sips=N_SIPS, num_cubes=N_CUBES, num_pes=N_PE_PER_CUBE) + # W_Q/K/V, out_Q/K/V: column-wise sharded across all PEs + dp_sharded = DPPolicy(cube="column_wise", pe="column_wise", + num_sips=N_SIPS, num_cubes=N_CUBES, num_pes=N_PE_PER_CUBE) + + x = torch.empty((M, K), dtype=DTYPE, dp=dp_replicate, name="x") + wq = torch.empty((K, GPT3_D_MODEL), dtype=DTYPE, 
dp=dp_sharded, name="wq") + wk = torch.empty((K, GPT3_D_MODEL), dtype=DTYPE, dp=dp_sharded, name="wk") + wv = torch.empty((K, GPT3_D_MODEL), dtype=DTYPE, dp=dp_sharded, name="wv") + out_q = torch.empty((M, GPT3_D_MODEL), dtype=DTYPE, dp=dp_sharded, name="out_q") + out_k = torch.empty((M, GPT3_D_MODEL), dtype=DTYPE, dp=dp_sharded, name="out_k") + out_v = torch.empty((M, GPT3_D_MODEL), dtype=DTYPE, dp=dp_sharded, name="out_v") + + torch.launch("gpt3_qkv", _gpt3_qkv_kernel, + x, wq, wk, wv, out_q, out_k, out_v, + M, K, N) diff --git a/components.yaml b/components.yaml index ee459d8..8bf0f85 100644 --- a/components.yaml +++ b/components.yaml @@ -50,4 +50,4 @@ components: pe_tcm_v1: kernbench.components.builtin.pe_tcm:PeTcmComponent # Custom — add your implementations here - # pe_cpu_v2: kernbench.components.custom.my_pe_cpu:MyPeCpuComponent + pe_scheduler_v2: kernbench.components.custom.pe_accel.scheduler:SchedulerV2Component diff --git a/src/kernbench/components/builtin/m_cpu.py b/src/kernbench/components/builtin/m_cpu.py index 40c9ae5..f62a15b 100644 --- a/src/kernbench/components/builtin/m_cpu.py +++ b/src/kernbench/components/builtin/m_cpu.py @@ -320,10 +320,12 @@ class MCpuComponent(ComponentBase): else: txn.done.succeed() - def _resolve_pe_ids(self, target_pe: int | str) -> list[int]: + def _resolve_pe_ids(self, target_pe: int | tuple | str) -> list[int]: """Return list of PE IDs to fan out to (used by kernel launch fan-out).""" if isinstance(target_pe, int): return [target_pe] + if isinstance(target_pe, tuple): + return list(target_pe) # "all": all PEs in local cube n_slices = 8 if self.ctx and self.ctx.spec: diff --git a/src/kernbench/components/custom/pe_accel/__init__.py b/src/kernbench/components/custom/pe_accel/__init__.py new file mode 100644 index 0000000..60b3647 --- /dev/null +++ b/src/kernbench/components/custom/pe_accel/__init__.py @@ -0,0 +1,20 @@ +"""PeAccel: cycle-accurate accelerator component for pe_scheduler slot. 
+ +Register in components.yaml as: + pe_scheduler_v2: kernbench.components.custom.pe_accel.scheduler:SchedulerV2Component + +Then reference in topology.yaml: + pe_scheduler: { kind: pe_scheduler, impl: pe_scheduler_v2, attrs: { ... } } + +Package layout: + scheduler/ — scheduler block (component + dispatch + tiling) + scheduler.py — SchedulerV2Component + gemm_pipeline.py — tiled GEMM coordinator + math_pipeline.py — tiled element-wise math coordinator + tiling.py — per-tile address computation + blocks/ — hardware blocks (DMA_IN, DMA_WB, GEMM, MATH, TCM) + types.py — data classes (descriptors, triggers, tile commands) +""" +from kernbench.components.custom.pe_accel.scheduler import SchedulerV2Component + +__all__ = ["SchedulerV2Component"] diff --git a/src/kernbench/components/custom/pe_accel/blocks/__init__.py b/src/kernbench/components/custom/pe_accel/blocks/__init__.py new file mode 100644 index 0000000..9864a31 --- /dev/null +++ b/src/kernbench/components/custom/pe_accel/blocks/__init__.py @@ -0,0 +1,16 @@ +"""Hardware blocks for pe_accel. 
+ - TcmBlock: TCM access serialization with BW-based timing +""" +from kernbench.components.custom.pe_accel.blocks.dma_in import DmaInBlock +from kernbench.components.custom.pe_accel.blocks.dma_wb import DmaWbBlock +from kernbench.components.custom.pe_accel.blocks.gemm import GemmBlock +from kernbench.components.custom.pe_accel.blocks.math import MathBlock +from kernbench.components.custom.pe_accel.blocks.tcm import TcmBlock, TcmRequest + +__all__ = ["DmaInBlock", "DmaWbBlock", "GemmBlock", "MathBlock", "TcmBlock", "TcmRequest"] diff --git a/src/kernbench/components/custom/pe_accel/blocks/dma_in.py b/src/kernbench/components/custom/pe_accel/blocks/dma_in.py new file mode 100644 index 0000000..979fcc7 --- /dev/null +++ b/src/kernbench/components/custom/pe_accel/blocks/dma_in.py @@ -0,0 +1,96 @@ +"""DMA IN Block: reads tiles from HBM into TCM via real PE_DMA fabric. + +Flow per tile: + 1. Receive DmaRequest from tiling pipeline + 2. Look up DmaInDescriptor for address and size + 3. Issue DmaReadCmd → PE_DMA → fabric → HBM controller → response + 4. Route completion Trigger to next block (GEMM, MATH, or COMPLETION) + +Timing is real fabric latency (not analytical) — includes BW contention, +propagation delay, and HBM controller serialization. +""" +from __future__ import annotations + +import simpy + +from kernbench.components.custom.pe_accel.types import DmaInDescriptor, DmaRequest, Trigger + + +class DmaInBlock: + """HBM → TCM tile reader. Shared across all concurrent pipelines. + + Pipelines pre-load DmaInDescriptors keyed by (pipeline_id, tile_id, operand). + The _load_loop process reads DmaRequests, issues real DmaReadCmd to PE_DMA, + and routes completion triggers based on descriptor.next_block. 
+ """ + + def __init__( + self, + env: simpy.Environment, + cmd_q: simpy.Store, + to_fetch_trig: simpy.Store, + to_math_trig: simpy.Store, + completion_q: simpy.Store, + *, + pe_dma_port: simpy.Store | None, + pe_prefix: str, + ) -> None: + self.env = env + self.cmd_q = cmd_q + self.to_fetch_trig = to_fetch_trig + self.to_math_trig = to_math_trig + self.completion_q = completion_q + self._pe_dma_port = pe_dma_port + self._pe_prefix = pe_prefix + self._descriptor_table: dict[tuple[int, int, str], DmaInDescriptor] = {} + + # Per-pipeline timing histogram (keyed by pipeline_id) + self.t_dma_read_per_request: dict[int, list[float]] = {} + + def load_descriptors(self, descs: dict[tuple, DmaInDescriptor]) -> None: + """Pre-load per-operand DMA descriptors (cumulative across pipelines).""" + self._descriptor_table.update(descs) + + def _load_loop(self): + """Main process: receive DmaRequests, issue DmaReadCmd, route triggers.""" + from kernbench.common.pe_commands import DmaReadCmd, PeInternalTxn, TensorHandle + + while True: + req: DmaRequest = yield self.cmd_q.get() + if req is None: + break + + desc = self._descriptor_table[(req.pipeline_id, req.tile_id, req.operand)] + + # Issue real DMA read through PE_DMA → fabric → HBM + read_done = self.env.event() + handle = TensorHandle( + id=f"accel_rd_{req.pipeline_id}_{req.tile_id}_{req.operand}", + addr=desc.src_addr, + shape=(desc.size_bytes,), + dtype="uint8", + nbytes=desc.size_bytes, + ) + txn = PeInternalTxn( + command=DmaReadCmd(handle=handle, src_addr=desc.src_addr, nbytes=desc.size_bytes), + done=read_done, + pe_prefix=self._pe_prefix, + ) + t0 = self.env.now + yield self._pe_dma_port.put(txn) + yield read_done + self.t_dma_read_per_request.setdefault(req.pipeline_id, []).append(self.env.now - t0) + + # Route trigger to next block + trig = Trigger( + tile_id=req.tile_id, + pipeline_id=req.pipeline_id, + vc=0 if req.operand == "A" else 1, + source_block="DMA_IN", + ) + if desc.next_block == "MATH": + yield 
self.to_math_trig.put(trig) + elif desc.next_block == "COMPLETION": + yield self.completion_q.put(trig) + else: # "GEMM" (default) + yield self.to_fetch_trig.put(trig) diff --git a/src/kernbench/components/custom/pe_accel/blocks/dma_wb.py b/src/kernbench/components/custom/pe_accel/blocks/dma_wb.py new file mode 100644 index 0000000..538eca6 --- /dev/null +++ b/src/kernbench/components/custom/pe_accel/blocks/dma_wb.py @@ -0,0 +1,88 @@ +"""DMA Writeback Block: writes result tiles from TCM to HBM via real PE_DMA fabric. + +Flow per tile: + 1. Receive flush Trigger from GEMM or MATH block + 2. Look up DmaWBDescriptor for address and tile size + 3. Issue DmaWriteCmd → PE_DMA → fabric → HBM controller → response + 4. Send completion Trigger to pipeline + +Two _flush_loop processes run concurrently: + - One drains GEMM → DMA_WB triggers (direct writeback path) + - One drains MATH → DMA_WB triggers (K-accumulation or element-wise flush) +""" +from __future__ import annotations + +import simpy + +from kernbench.components.custom.pe_accel.types import DmaWBDescriptor, Trigger + + +class DmaWbBlock: + """TCM → HBM tile writer. Shared across all concurrent pipelines. + + Pipelines pre-load DmaWBDescriptors keyed by (pipeline_id, tile_id). + Each _flush_loop process reads triggers, issues real DmaWriteCmd to PE_DMA, + and forwards completion to the pipeline's reply queue. 
+ """ + + def __init__( + self, + env: simpy.Environment, + completion_q: simpy.Store, + *, + pe_dma_port: simpy.Store | None, + pe_prefix: str, + bytes_per_element: int, + ) -> None: + self.env = env + self.completion_q = completion_q + self._pe_dma_port = pe_dma_port + self._pe_prefix = pe_prefix + self._bpe = bytes_per_element + self._descriptor_table: dict[tuple[int, int], DmaWBDescriptor] = {} + + # Per-pipeline timing histogram (keyed by pipeline_id) + self.t_dma_write_per_tile: dict[int, list[float]] = {} + + def load_descriptors(self, descs: dict[tuple, DmaWBDescriptor]) -> None: + """Pre-load per-tile writeback descriptors (cumulative across pipelines).""" + self._descriptor_table.update(descs) + + def _flush_loop(self, trig_q: simpy.Store): + """Main process: receive flush triggers, issue DmaWriteCmd, send completion.""" + from kernbench.common.pe_commands import DmaWriteCmd, PeInternalTxn, TensorHandle + + while True: + trigger: Trigger = yield trig_q.get() + if trigger is None: + break + + pid = trigger.pipeline_id + tile_id = trigger.tile_id + desc = self._descriptor_table.get((pid, tile_id)) + + if desc: + c_bytes = desc.Tm * desc.Tn * self._bpe + + # Issue real DMA write through PE_DMA → fabric → HBM + write_done = self.env.event() + handle = TensorHandle( + id=f"accel_wb_{pid}_{tile_id}", + addr=desc.dst_addr, + shape=(desc.Tm, desc.Tn), + dtype="float16", + nbytes=c_bytes, + ) + txn = PeInternalTxn( + command=DmaWriteCmd(handle=handle, dst_addr=desc.dst_addr, nbytes=c_bytes), + done=write_done, + pe_prefix=self._pe_prefix, + ) + t0 = self.env.now + yield self._pe_dma_port.put(txn) + yield write_done + self.t_dma_write_per_tile.setdefault(pid, []).append(self.env.now - t0) + + yield self.completion_q.put( + Trigger(tile_id=tile_id, pipeline_id=pid, source_block="DMA_WB") + ) diff --git a/src/kernbench/components/custom/pe_accel/blocks/gemm.py b/src/kernbench/components/custom/pe_accel/blocks/gemm.py new file mode 100644 index 0000000..4ac6d2b --- 
/dev/null +++ b/src/kernbench/components/custom/pe_accel/blocks/gemm.py @@ -0,0 +1,160 @@ +"""GEMM Block: 2-stage MAC pipeline (fetch + compute). + +Stage 1 — Fetch (_fetch_stage): + Collects DMA completion triggers (one per operand per tile). + When all operands arrive, issues TCM read request (SPMem → MAC registers). + +Stage 2 — Compute (_gemm_stage): + Models MAC array computation time. + Issues TCM write request (MAC result → SPMem). + Routes output trigger to MathBlock, DmaWbBlock, or completion. + +TCM access goes through TcmBlock for real BW serialization. +MAC compute time is cycle-accurate: ceil(Tm/mac_m) * ceil(Tk/mac_k) * ceil(Tn/mac_n). +""" +from __future__ import annotations + +from math import ceil + +import simpy + +from kernbench.components.custom.pe_accel.blocks.tcm import TcmRequest +from kernbench.components.custom.pe_accel.types import GemmDescriptor, Trigger + + +class GemmBlock: + """2-stage MAC pipeline shared across all concurrent pipelines. + + Pipelines pre-load GemmDescriptors keyed by (pipeline_id, tile_id). + Two SimPy processes run concurrently: _fetch_stage and _gemm_stage. 
+ """ + + def __init__( + self, + env: simpy.Environment, + trig_in: simpy.Store, + fetch_to_gemm_trig: simpy.Store, + to_math_trig: simpy.Store, + to_dmaWB_trig: simpy.Store, + completion_q: simpy.Store, + *, + tcm_port: simpy.Store, + mac_m: int, + mac_k: int, + mac_n: int, + bytes_per_element: int, + clock_freq_ghz: float, + ) -> None: + self.env = env + self.trig_in = trig_in + self.fetch_to_gemm_trig = fetch_to_gemm_trig + self.to_math_trig = to_math_trig + self.to_dmaWB_trig = to_dmaWB_trig + self.completion_q = completion_q + self._tcm_port = tcm_port + + self._mac_m = mac_m + self._mac_k = mac_k + self._mac_n = mac_n + self._bpe = bytes_per_element + self._freq = clock_freq_ghz + + self._descriptor_table: dict[tuple[int, int], GemmDescriptor] = {} + + # Per-pipeline timing histograms + self.t_tcm_load_per_tile: dict[int, list[float]] = {} + self.t_compute_per_tile: dict[int, list[float]] = {} + + def load_descriptors(self, descs: dict[tuple, GemmDescriptor]) -> None: + """Pre-load per-tile GEMM descriptors (cumulative across pipelines).""" + self._descriptor_table.update(descs) + + def _compute_ns(self, desc: GemmDescriptor) -> float: + """MAC array compute time for one tile (ns).""" + cycles = ceil(desc.Tm / self._mac_m) * ceil(desc.Tk / self._mac_k) * ceil(desc.Tn / self._mac_n) + return cycles / self._freq + + # -- Stage 1: Fetch (TCM → MAC load) -------------------------------------- + + def _fetch_stage(self): + """Collect DMA triggers per tile, issue TCM read for operand load.""" + pending: dict[tuple[int, int], list[Trigger]] = {} + while True: + trigger = yield self.trig_in.get() + if trigger is None: + yield self.fetch_to_gemm_trig.put(None) + break + + key = (trigger.pipeline_id, trigger.tile_id) + pending.setdefault(key, []).append(trigger) + + desc = self._descriptor_table.get(key) + needed = desc.triggers_needed if desc else 2 + if len(pending[key]) < needed: + continue + + del pending[key] + + # TCM load: read A and B tile data from SPMem → 
MAC registers + if desc and desc.gemm_load: + a_bytes = desc.Tm * desc.Tk * self._bpe + b_bytes = desc.Tk * desc.Tn * self._bpe + load_bytes = a_bytes + b_bytes + + t0 = self.env.now + done = self.env.event() + yield self._tcm_port.put(TcmRequest("read", load_bytes, done, tag="gemm_load")) + yield done + self.t_tcm_load_per_tile.setdefault(trigger.pipeline_id, []).append( + self.env.now - t0 + ) + + yield self.fetch_to_gemm_trig.put(trigger) + + # -- Stage 2: Compute (MAC array) + Store (MAC → TCM) --------------------- + + def _gemm_stage(self): + """MAC computation, then TCM store, then route to next block.""" + while True: + trigger = yield self.fetch_to_gemm_trig.get() + if trigger is None: + break + + key = (trigger.pipeline_id, trigger.tile_id) + desc = self._descriptor_table.get(key) + + # MAC compute + if desc and desc.gemm_compute: + t_compute = self._compute_ns(desc) + t0 = self.env.now + if t_compute > 0: + yield self.env.timeout(t_compute) + self.t_compute_per_tile.setdefault(trigger.pipeline_id, []).append( + self.env.now - t0 + ) + + # Route output + route = desc.next_block if desc else "MATH" + out_trig = Trigger( + tile_id=trigger.tile_id, + pipeline_id=trigger.pipeline_id, + source_block="GEMM", + ) + + if route == "MATH": + yield self.to_math_trig.put(out_trig) + elif route == "DMAWB": + # TCM store before writeback + if desc: + c_bytes = desc.Tm * desc.Tn * self._bpe + done = self.env.event() + yield self._tcm_port.put(TcmRequest("write", c_bytes, done, tag="gemm_store")) + yield done + yield self.to_dmaWB_trig.put(out_trig) + else: # "DONE" — C stays in SPMem, no flush + if desc: + c_bytes = desc.Tm * desc.Tn * self._bpe + done = self.env.event() + yield self._tcm_port.put(TcmRequest("write", c_bytes, done, tag="gemm_store")) + yield done + yield self.completion_q.put(out_trig) diff --git a/src/kernbench/components/custom/pe_accel/blocks/math.py b/src/kernbench/components/custom/pe_accel/blocks/math.py new file mode 100644 index 0000000..e1f5305 
--- /dev/null +++ b/src/kernbench/components/custom/pe_accel/blocks/math.py @@ -0,0 +1,181 @@ +"""Math Block: K-accumulation (GEMM helper) + element-wise ops (exp, log, etc.). + +Two concurrent processing modes: + + 1. K-accumulation (_run_k_accumulation): + Receives triggers from GemmBlock after each K-tile compute. + Issues TCM write for partial-result store. + On final K-tile, routes to DMA_WB or completion. + + 2. Element-wise ops (_run_element_wise): + Receives triggers from DMA_IN after each tile read. + Issues TCM read (load input), compute (SIMD), TCM write (store result). + Routes output to DMA_WB for writeback. + +TCM access goes through TcmBlock for real BW serialization. +SIMD compute time: ceil(num_elements / vector_width) / clock_freq. +""" +from __future__ import annotations + +from math import ceil + +import simpy + +from kernbench.components.custom.pe_accel.blocks.tcm import TcmRequest +from kernbench.components.custom.pe_accel.types import MathDescriptor, MathOpDescriptor, Trigger + + +class MathBlock: + """K-accumulation + element-wise math unit. 
+ + Descriptor tables: + - _accum_table: MathDescriptor (pipeline_id, tile_id) — for GEMM K-accumulation + - _elemwise_table: MathOpDescriptor (pipeline_id, tile_id) — for element-wise ops + """ + + def __init__( + self, + env: simpy.Environment, + trig_in: simpy.Store, + to_dmaWB_trig: simpy.Store, + completion_q: simpy.Store, + *, + tcm_port: simpy.Store, + bytes_per_element: int, + clock_freq_ghz: float, + vector_width: int = 256, + ) -> None: + self.env = env + self.trig_in = trig_in # from GemmBlock (K-accumulation) + self.to_dmaWB_trig = to_dmaWB_trig + self.completion_q = completion_q + self._tcm_port = tcm_port + + self._bpe = bytes_per_element + self._freq = clock_freq_ghz + self._vector_width = vector_width + + # Descriptor tables + self._accum_table: dict[tuple[int, int], MathDescriptor] = {} + self._elemwise_table: dict[tuple[int, int], MathOpDescriptor] = {} + + # -- Timing histograms (per pipeline_id) -- + + # K-accumulation + self.t_tcm_store_per_tile: dict[int, list[float]] = {} + + # Element-wise ops + self.t_math_op_load_per_tile: dict[int, list[float]] = {} + self.t_math_op_compute_per_tile: dict[int, list[float]] = {} + self.t_math_op_store_per_tile: dict[int, list[float]] = {} + + # -- Descriptor loading ---------------------------------------------------- + + def load_descriptors(self, descs: dict[tuple, MathDescriptor]) -> None: + """Pre-load K-accumulation descriptors (cumulative across pipelines).""" + self._accum_table.update(descs) + + def load_math_op_descriptors(self, descs: dict[tuple, MathOpDescriptor]) -> None: + """Pre-load element-wise op descriptors (cumulative across pipelines).""" + self._elemwise_table.update(descs) + + # -- Mode 1: K-accumulation ------------------------------------------------ + + def _run(self): + """Backward-compat alias.""" + yield from self._run_k_accumulation() + + def _run_k_accumulation(self): + """Receive GEMM output triggers, TCM store partial result, flush on last-K.""" + while True: + trigger = 
yield self.trig_in.get() + if trigger is None: + break + + key = (trigger.pipeline_id, trigger.tile_id) + desc = self._accum_table.get(key) + + # TCM store: write partial sum to SPMem + if desc: + c_bytes = desc.Tm * desc.Tn * self._bpe + t0 = self.env.now + done = self.env.event() + yield self._tcm_port.put(TcmRequest("write", c_bytes, done, tag="k_accum_store")) + yield done + self.t_tcm_store_per_tile.setdefault(trigger.pipeline_id, []).append( + self.env.now - t0 + ) + + if not desc or not desc.is_last_k: + continue # intermediate K-tile: store done, no flush yet + + out_trig = Trigger( + tile_id=trigger.tile_id, + pipeline_id=trigger.pipeline_id, + source_block="MATH", + ) + if desc.skip_dmaWB: + yield self.completion_q.put(out_trig) + else: + yield self.to_dmaWB_trig.put(out_trig) + + # -- Mode 2: Element-wise ops ---------------------------------------------- + + def _run_math_op(self, trig_q: simpy.Store): + """Backward-compat alias.""" + yield from self._run_element_wise(trig_q) + + def _run_element_wise(self, trig_q: simpy.Store): + """Receive DMA_IN triggers, apply element-wise op via TCM, route to DMA_WB. + + Per tile: + 1. TCM read — load input tile from SPMem to SIMD + 2. Compute — SIMD operation (exp/log/etc.) + 3. TCM write — store result from SIMD to SPMem + 4. Route to DMA_WB + """ + while True: + trigger = yield trig_q.get() + if trigger is None: + break + + key = (trigger.pipeline_id, trigger.tile_id) + desc = self._elemwise_table.get(key) + + if desc: + tile_bytes = desc.Tm * desc.Tn * self._bpe + num_elements = desc.Tm * desc.Tn + + # 1. TCM read + t0 = self.env.now + done = self.env.event() + yield self._tcm_port.put(TcmRequest("read", tile_bytes, done, tag="elemwise_load")) + yield done + self.t_math_op_load_per_tile.setdefault(trigger.pipeline_id, []).append( + self.env.now - t0 + ) + + # 2. 
SIMD compute + t0 = self.env.now + compute_cycles = ceil(num_elements / self._vector_width) + compute_ns = compute_cycles / self._freq + if compute_ns > 0: + yield self.env.timeout(compute_ns) + self.t_math_op_compute_per_tile.setdefault(trigger.pipeline_id, []).append( + self.env.now - t0 + ) + + # 3. TCM write + t0 = self.env.now + done = self.env.event() + yield self._tcm_port.put(TcmRequest("write", tile_bytes, done, tag="elemwise_store")) + yield done + self.t_math_op_store_per_tile.setdefault(trigger.pipeline_id, []).append( + self.env.now - t0 + ) + + yield self.to_dmaWB_trig.put(Trigger( + tile_id=trigger.tile_id, + pipeline_id=trigger.pipeline_id, + source_block="MATH_OP", + )) diff --git a/src/kernbench/components/custom/pe_accel/blocks/tcm.py b/src/kernbench/components/custom/pe_accel/blocks/tcm.py new file mode 100644 index 0000000..d5d098b --- /dev/null +++ b/src/kernbench/components/custom/pe_accel/blocks/tcm.py @@ -0,0 +1,80 @@ +"""TCM Block: tightly-coupled memory with BW-based access serialization. + +Models SPMem (scratchpad memory) inside the PE. Compute blocks (GEMM, MATH) +send TcmRequests for load/store operations. The TCM block serializes access +per channel and computes timing based on data size and bandwidth. + +Two channels: + - READ (SPMem → compute unit): models operand fetch for MAC/SIMD + - WRITE (compute unit → SPMem): models result store from MAC/SIMD + +Each channel has capacity=1: concurrent reads serialize, concurrent writes +serialize, but a read and a write can proceed in parallel. +""" +from __future__ import annotations + +from dataclasses import dataclass + +import simpy + + +@dataclass +class TcmRequest: + """Request to read from or write to TCM.""" + + direction: str # "read" or "write" + nbytes: int + done: simpy.Event + tag: str = "" # optional label for debugging + + +class TcmBlock: + """BW-serialized TCM model with dual read/write channels. + + Args: + env: SimPy environment. 
+ read_bw_gbs: read bandwidth in GB/s (SPMem → compute). + write_bw_gbs: write bandwidth in GB/s (compute → SPMem). + """ + + def __init__( + self, + env: simpy.Environment, + read_bw_gbs: float = 512.0, + write_bw_gbs: float = 512.0, + ) -> None: + self.env = env + self._read_bw = read_bw_gbs + self._write_bw = write_bw_gbs + self._read_res = simpy.Resource(env, capacity=1) + self._write_res = simpy.Resource(env, capacity=1) + self._port: simpy.Store = simpy.Store(env) + + @property + def port(self) -> simpy.Store: + """The SimPy Store that blocks send TcmRequests to.""" + return self._port + + def _run(self): + """Main process: receive TcmRequests, dispatch to channel processes.""" + while True: + req: TcmRequest = yield self._port.get() + if req is None: + break + self.env.process(self._handle(req)) + + def _handle(self, req: TcmRequest): + """Acquire channel, apply BW-based delay, signal done.""" + if req.direction == "write": + res = self._write_res + bw = self._write_bw + else: + res = self._read_res + bw = self._read_bw + + with res.request() as lock: + yield lock + if bw > 0 and req.nbytes > 0: + delay_ns = req.nbytes / bw + yield self.env.timeout(delay_ns) + req.done.succeed() diff --git a/src/kernbench/components/custom/pe_accel/scheduler/__init__.py b/src/kernbench/components/custom/pe_accel/scheduler/__init__.py new file mode 100644 index 0000000..885fc36 --- /dev/null +++ b/src/kernbench/components/custom/pe_accel/scheduler/__init__.py @@ -0,0 +1,10 @@ +"""Scheduler: accelerator component + dispatch + tiling pipelines. 
+ + scheduler.py — SchedulerV2Component (init, wiring, dispatch, metrics) + gemm_pipeline.py — GemmPipeline (tiled GEMM coordinator) + math_pipeline.py — MathPipeline (tiled element-wise math coordinator) + tiling.py — per-tile address computation +""" +from kernbench.components.custom.pe_accel.scheduler.scheduler import SchedulerV2Component + +__all__ = ["SchedulerV2Component"] diff --git a/src/kernbench/components/custom/pe_accel/scheduler/gemm_pipeline.py b/src/kernbench/components/custom/pe_accel/scheduler/gemm_pipeline.py new file mode 100644 index 0000000..afce12d --- /dev/null +++ b/src/kernbench/components/custom/pe_accel/scheduler/gemm_pipeline.py @@ -0,0 +1,157 @@ +"""GEMM Tiling Pipeline: splits (M,K)×(K,N) into tiles and coordinates execution. + +Flow per tile: + DMA_IN(A tile) + DMA_IN(B tile) → GEMM(fetch + compute) → MATH(K-accum) → DMA_WB + +The pipeline does NOT own hardware blocks — it uses the component's shared +blocks via descriptor tables and SimPy queues. + +Constructor starts two SimPy processes: + - _feed_commands(): sends DmaRequests to shared dmaIN_cmd_q + - _collect_completions(): waits for all output tiles to flush +""" +from __future__ import annotations + +from math import ceil + +import simpy + +from kernbench.components.custom.pe_accel.scheduler.tiling import generate_gemm_tiles +from kernbench.components.custom.pe_accel.types import ( + CmdType, + DmaInDescriptor, + DmaRequest, + DmaWBDescriptor, + GemmDescriptor, + MathDescriptor, + Trigger, +) + + +class GemmPipeline: + """Coordinates one tiled GEMM operation across shared hardware blocks.""" + + def __init__( + self, + env: simpy.Environment, + M: int, K: int, N: int, + tile_m: int, tile_k: int, tile_n: int, + bytes_per_element: int, + pipeline_id: int, + reply_queue: simpy.Store, + dmaIN_cmd_q: simpy.Store, + dmaIN_to_fetch_trig: simpy.Store, + A_addr: int = 0, + B_addr: int = 0, + C_addr: int = 0, + dma_a: bool = True, + dma_b: bool = True, + dma_c: bool = True, + ) -> 
None: + self.env = env + self.M, self.K, self.N = M, K, N + self.pipeline_id = pipeline_id + self.reply_queue = reply_queue + self.dmaIN_cmd_q = dmaIN_cmd_q + self.dmaIN_to_fetch_trig = dmaIN_to_fetch_trig + + self._dma_a = dma_a + self._dma_b = dma_b + self._skip_dmaWB = not dma_c + + _Tm = min(tile_m, M) + _Tk = min(tile_k, K) + _Tn = min(tile_n, N) + + self.M_tiles = ceil(M / tile_m) + self.K_tiles = ceil(K / tile_k) + self.N_tiles = ceil(N / tile_n) + + triggers_per_tile = 2 if (dma_a and dma_b) else 1 + + # Generate tile schedule with pre-computed addresses + self.schedule = generate_gemm_tiles( + self.M_tiles, self.K_tiles, self.N_tiles, + M=M, K=K, N=N, + tile_m=_Tm, tile_k=_Tk, tile_n=_Tn, + bytes_per_element=bytes_per_element, + A_addr=A_addr, B_addr=B_addr, C_addr=C_addr, + pipeline_id=pipeline_id, + ) + + # Build descriptor tables for shared blocks + pid = pipeline_id + a_tile_bytes = _Tm * _Tk * bytes_per_element + b_tile_bytes = _Tk * _Tn * bytes_per_element + + self.dmaIN_descs: dict[tuple, DmaInDescriptor] = {} + self.gemm_descs: dict[tuple, GemmDescriptor] = {} + self.math_descs: dict[tuple, MathDescriptor] = {} + self.dmaWB_descs: dict[tuple, DmaWBDescriptor] = {} + + for cmd in self.schedule.commands: + if cmd.cmd_type != CmdType.DMA_LOAD: + continue + t = cmd.tile_id + + if dma_a: + self.dmaIN_descs[(pid, t, "A")] = DmaInDescriptor( + size_bytes=a_tile_bytes, src_addr=cmd.a_tile_addr + ) + if dma_b: + self.dmaIN_descs[(pid, t, "B")] = DmaInDescriptor( + size_bytes=b_tile_bytes, src_addr=cmd.b_tile_addr + ) + + self.gemm_descs[(pid, t)] = GemmDescriptor( + Tm=_Tm, Tk=_Tk, Tn=_Tn, + triggers_needed=triggers_per_tile, + next_block="MATH", + ) + + self.math_descs[(pid, t)] = MathDescriptor( + Tm=_Tm, Tn=_Tn, + is_last_k=cmd.is_last_k, + skip_dmaWB=self._skip_dmaWB, + ) + if not self._skip_dmaWB and cmd.is_last_k: + self.dmaWB_descs[(pid, t)] = DmaWBDescriptor( + Tm=_Tm, Tn=_Tn, dst_addr=cmd.c_tile_addr + ) + + self.expected_flushes = self.M_tiles * 
self.N_tiles + self.completed_flushes = 0 + self.done_at: int = 0 + self.done: simpy.Event = env.event() + + env.process(self._feed_commands()) + env.process(self._collect_completions()) + + def _feed_commands(self): + """Send DmaRequests for each tile's operands to dmaIN_cmd_q.""" + for cmd in self.schedule.commands: + if cmd.cmd_type != CmdType.DMA_LOAD: + continue + + if self._dma_a: + yield self.dmaIN_cmd_q.put(DmaRequest( + tile_id=cmd.tile_id, pipeline_id=self.pipeline_id, operand="A", + )) + if self._dma_b: + yield self.dmaIN_cmd_q.put(DmaRequest( + tile_id=cmd.tile_id, pipeline_id=self.pipeline_id, operand="B", + )) + + if not self._dma_a and not self._dma_b: + yield self.dmaIN_to_fetch_trig.put(Trigger( + tile_id=cmd.tile_id, pipeline_id=self.pipeline_id, source_block="PIPELINE", + )) + + def _collect_completions(self): + """Wait for all output tile flush completions, then signal done.""" + while self.completed_flushes < self.expected_flushes: + yield self.reply_queue.get() + self.completed_flushes += 1 + + self.done_at = int(self.env.now) + self.done.succeed() diff --git a/src/kernbench/components/custom/pe_accel/scheduler/math_pipeline.py b/src/kernbench/components/custom/pe_accel/scheduler/math_pipeline.py new file mode 100644 index 0000000..146b9b5 --- /dev/null +++ b/src/kernbench/components/custom/pe_accel/scheduler/math_pipeline.py @@ -0,0 +1,132 @@ +"""Math Tiling Pipeline: splits element-wise ops into tiles for pipelined execution. + +Flow per tile: + DMA_IN(input tile) → MATH_OP(exp/log/etc.) → DMA_WB(output tile) + +Mirrors GemmTilingPipeline but for unary element-wise operations. +Pipeline overlap across tiles: while one tile is in MATH_OP, the next +tile's DMA_IN can proceed concurrently. 
class MathPipeline:
    """Coordinates one tiled element-wise math operation across shared blocks.

    The constructor builds the tile schedule, derives the per-tile descriptor
    tables (``dmaIN_descs``, ``math_op_descs``, ``dmaWB_descs``) and starts
    two SimPy processes: one that feeds DMA requests / triggers and one that
    counts completions.  This class only builds the tables; it does not hand
    them to any block itself.

    Descriptor keys are ``(pipeline_id, tile_id)``; DMA-in keys carry the
    operand name ``"A"`` as a third element.

    Args:
        env: SimPy environment the two processes are registered on.
        M, N: extents of the 2-D operand.
        tile_m, tile_n: requested tile extents; each is capped at the
            corresponding operand dimension.  A zero tile extent raises
            ``ZeroDivisionError`` when the tile counts are computed.
        bytes_per_element: element size used for DMA sizes and addresses.
        pipeline_id: identifier stamped on every request and descriptor key.
        reply_queue: store from which completion items are consumed; each item
            taken is counted as one finished tile.
        dmaIN_cmd_q: store that receives one ``DmaRequest`` per tile when
            ``dma_in`` is true.
        dmaIN_to_fetch_trig: store that receives a ``Trigger`` per tile when
            ``dma_in`` is false.
        op: name of the element-wise operation, copied into each
            ``MathOpDescriptor``.
        src_addr, dst_addr: base addresses passed to the schedule generator.
        dma_in: whether the input is loaded through DMA.
        dma_out: whether the output is written back (``False`` creates no
            write-back descriptors).

    Attributes:
        done: SimPy event that succeeds once ``expected_flushes`` completion
            items have been received.
        done_at: ``int(env.now)`` at the moment ``done`` was triggered.
    """

    def __init__(
        self,
        env: simpy.Environment,
        M: int, N: int,
        tile_m: int, tile_n: int,
        bytes_per_element: int,
        pipeline_id: int,
        reply_queue: simpy.Store,
        dmaIN_cmd_q: simpy.Store,
        dmaIN_to_fetch_trig: simpy.Store,
        op: str = "exp",
        src_addr: int = 0,
        dst_addr: int = 0,
        dma_in: bool = True,
        dma_out: bool = True,
    ) -> None:
        self.env = env
        self.M, self.N = M, N
        self.pipeline_id = pipeline_id
        self.reply_queue = reply_queue
        self.dmaIN_cmd_q = dmaIN_cmd_q
        self.dmaIN_to_fetch_trig = dmaIN_to_fetch_trig
        self.op = op

        self._dma_in = dma_in
        self._skip_dmaWB = not dma_out

        # Nominal tile extents: a tile is never larger than the operand.
        _Tm = min(tile_m, M)
        _Tn = min(tile_n, N)

        self.M_tiles = ceil(M / tile_m)
        self.N_tiles = ceil(N / tile_n)

        # Generate tile schedule with pre-computed addresses
        self.schedule = generate_math_tiles(
            self.M_tiles, self.N_tiles,
            M=M, N=N,
            tile_m=_Tm, tile_n=_Tn,
            bytes_per_element=bytes_per_element,
            src_addr=src_addr, dst_addr=dst_addr,
            pipeline_id=pipeline_id,
        )

        # Build descriptor tables for shared blocks
        pid = pipeline_id

        self.dmaIN_descs: dict[tuple, DmaInDescriptor] = {}
        self.math_op_descs: dict[tuple, MathOpDescriptor] = {}
        self.dmaWB_descs: dict[tuple, DmaWBDescriptor] = {}

        for cmd in self.schedule.commands:
            t = cmd.tile_id

            # Clamp the extents of edge tiles: when a dimension is not a
            # multiple of the tile size the last tile along it is smaller,
            # and sizing it as a full tile would overstate DMA traffic and
            # math work.
            tm = min(_Tm, M - cmd.m_idx * _Tm)
            tn = min(_Tn, N - cmd.n_idx * _Tn)

            if dma_in:
                self.dmaIN_descs[(pid, t, "A")] = DmaInDescriptor(
                    size_bytes=tm * tn * bytes_per_element,
                    src_addr=cmd.src_tile_addr,
                    next_block="MATH",
                )

            self.math_op_descs[(pid, t)] = MathOpDescriptor(
                Tm=tm, Tn=tn, op=op,
                src_addr=cmd.src_tile_addr,
                dst_addr=cmd.dst_tile_addr,
            )

            if not self._skip_dmaWB:
                self.dmaWB_descs[(pid, t)] = DmaWBDescriptor(
                    Tm=tm, Tn=tn, dst_addr=cmd.dst_tile_addr,
                )

        # One completion item is expected per tile.
        self.expected_flushes = self.M_tiles * self.N_tiles
        self.completed_flushes = 0
        self.done_at: int = 0
        self.done: simpy.Event = env.event()

        env.process(self._feed_commands())
        env.process(self._collect_completions())

    def _feed_commands(self):
        """Send DmaRequests for each tile's input to dmaIN_cmd_q.

        When ``dma_in`` is false a ``Trigger`` with ``source_block="PIPELINE"``
        is put on ``dmaIN_to_fetch_trig`` for each tile instead.
        """
        for cmd in self.schedule.commands:
            if self._dma_in:
                yield self.dmaIN_cmd_q.put(DmaRequest(
                    tile_id=cmd.tile_id, pipeline_id=self.pipeline_id, operand="A",
                ))
            else:
                # NOTE(review): the DMA path tags its descriptors with
                # next_block="MATH", but this bypass posts to the queue named
                # dmaIN_to_fetch_trig — confirm that the consumer of that
                # queue is the block expected to run element-wise math tiles.
                yield self.dmaIN_to_fetch_trig.put(Trigger(
                    tile_id=cmd.tile_id, pipeline_id=self.pipeline_id, source_block="PIPELINE",
                ))

    def _collect_completions(self):
        """Wait for all tile completions, then signal done.

        Every item received on ``reply_queue`` increments
        ``completed_flushes``; the item's content is not inspected.
        """
        while self.completed_flushes < self.expected_flushes:
            yield self.reply_queue.get()
            self.completed_flushes += 1

        self.done_at = int(self.env.now)
        self.done.succeed()
class SchedulerV2Component(ComponentBase):
    """PE accelerator scheduler: wires internal blocks and dispatches commands.

    Configuration is read from ``node.attrs`` (every key optional, defaults in
    ``__init__``).  The internal blocks and queues are created in ``start()``;
    until then the corresponding attributes are ``None``.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        # Everything before the last "." of the node id; used to build the
        # names of sibling out-ports ("<prefix>.pe_dma", ...).
        self._pe_prefix: str = node.id.rsplit(".", 1)[0]
        attrs = node.attrs

        # Hardware config
        self._overhead_ns: float = float(attrs.get("overhead_ns", 0.0))
        self._clock_freq_ghz: float = float(attrs.get("clock_freq_ghz", 1.0))
        self._bpe: int = int(attrs.get("bytes_per_element", 2))

        # MAC array dimensions
        self._mac_m: int = int(attrs.get("mac_m", 8))
        self._mac_k: int = int(attrs.get("mac_k", 16))
        self._mac_n: int = int(attrs.get("mac_n", 32))

        # Tile dimensions
        self._tile_m: int = int(attrs.get("tile_m", 32))
        self._tile_k: int = int(attrs.get("tile_k", 32))
        self._tile_n: int = int(attrs.get("tile_n", 32))

        # Bandwidth (bytes/cycle)
        self._port_a_bw: int = int(attrs.get("port_a_bw", 256))
        self._port_b_bw: int = int(attrs.get("port_b_bw", 256))
        self._store_port_bw: int = int(attrs.get("store_port_bw", 256))
        self._vector_width: int = int(attrs.get("vector_width", 256))

        # Initialized in start()
        self._tcm_block: Any = None
        self._dmaIN_block: Any = None
        self._dmaWB_block: Any = None
        self._gemm_block: Any = None
        self._math_block: Any = None
        self._dmaIN_cmd_q: simpy.Store | None = None
        self._dmaIN_to_fetch_trig: simpy.Store | None = None

        # Pipeline tracking: pipeline id -> reply queue of the live pipeline
        self._next_pipeline_id: int = 0
        self._pipeline_queues: dict[int, simpy.Store] = {}

    # -- SimPy lifecycle -------------------------------------------------------

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Scheduler overhead per dispatch (``nbytes`` is not used)."""
        yield env.timeout(self._overhead_ns)

    def start(self, env: simpy.Environment) -> None:
        """Create internal blocks, wire queues, start SimPy processes."""
        from kernbench.components.custom.pe_accel.blocks import (
            DmaInBlock, DmaWbBlock, GemmBlock, MathBlock, TcmBlock,
        )

        pe_dma_port = self.out_ports.get(f"{self._pe_prefix}.pe_dma")

        # -- TCM block (shared BW-serialized scratchpad) -----------------------

        # Read TCM BW from topology spec (gemm_to_tcm / math_to_tcm); fall
        # back to 512.0 when the spec or the key is absent.
        pe_links = {}
        if self.ctx and self.ctx.spec:
            pe_links = self.ctx.spec.get("cube", {}).get("pe_template", {}).get("links", {})
        tcm_read_bw = float(pe_links.get("gemm_to_tcm_bw_gbs", 512.0))
        tcm_write_bw = float(pe_links.get("math_to_tcm_bw_gbs", 512.0))

        self._tcm_block = TcmBlock(
            env=env,
            read_bw_gbs=tcm_read_bw,
            write_bw_gbs=tcm_write_bw,
        )

        # -- Internal queues ---------------------------------------------------

        self._dmaIN_cmd_q = simpy.Store(env)
        self._dmaIN_to_fetch_trig = simpy.Store(env)
        dmaIN_to_math_trig = simpy.Store(env)      # DMA_IN → element-wise math
        fetch_to_gemm_trig = simpy.Store(env)
        gemm_to_math_trig = simpy.Store(env)
        gemm_to_dmaWB_trig = simpy.Store(env)
        math_to_dmaWB_trig = simpy.Store(env)

        # Completion queues (block → _completion_router → pipeline reply_q)
        dmaIN_completion_q = simpy.Store(env)
        dmaWB_completion_q = simpy.Store(env)
        gemm_completion_q = simpy.Store(env)
        math_completion_q = simpy.Store(env)

        # -- Create blocks -----------------------------------------------------

        self._dmaIN_block = DmaInBlock(
            env=env,
            cmd_q=self._dmaIN_cmd_q,
            to_fetch_trig=self._dmaIN_to_fetch_trig,
            to_math_trig=dmaIN_to_math_trig,
            completion_q=dmaIN_completion_q,
            pe_dma_port=pe_dma_port,
            pe_prefix=self._pe_prefix,
        )

        self._dmaWB_block = DmaWbBlock(
            env=env,
            completion_q=dmaWB_completion_q,
            pe_dma_port=pe_dma_port,
            pe_prefix=self._pe_prefix,
            bytes_per_element=self._bpe,
        )

        self._gemm_block = GemmBlock(
            env=env,
            trig_in=self._dmaIN_to_fetch_trig,
            fetch_to_gemm_trig=fetch_to_gemm_trig,
            to_math_trig=gemm_to_math_trig,
            to_dmaWB_trig=gemm_to_dmaWB_trig,
            completion_q=gemm_completion_q,
            tcm_port=self._tcm_block.port,
            mac_m=self._mac_m, mac_k=self._mac_k, mac_n=self._mac_n,
            bytes_per_element=self._bpe,
            clock_freq_ghz=self._clock_freq_ghz,
        )

        self._math_block = MathBlock(
            env=env,
            trig_in=gemm_to_math_trig,
            to_dmaWB_trig=math_to_dmaWB_trig,
            completion_q=math_completion_q,
            tcm_port=self._tcm_block.port,
            bytes_per_element=self._bpe,
            clock_freq_ghz=self._clock_freq_ghz,
            vector_width=self._vector_width,
        )

        # -- Start block processes ---------------------------------------------

        env.process(self._tcm_block._run())
        env.process(self._dmaIN_block._load_loop())
        # One write-back loop per upstream trigger queue (GEMM and MATH).
        env.process(self._dmaWB_block._flush_loop(gemm_to_dmaWB_trig))
        env.process(self._dmaWB_block._flush_loop(math_to_dmaWB_trig))
        env.process(self._gemm_block._fetch_stage())
        env.process(self._gemm_block._gemm_stage())
        env.process(self._math_block._run_k_accumulation())
        env.process(self._math_block._run_element_wise(dmaIN_to_math_trig))

        # Wire in-ports → inbox, start _worker
        super().start(env)

        # Start completion routers, one per block completion queue
        for q in (dmaIN_completion_q, dmaWB_completion_q, gemm_completion_q, math_completion_q):
            env.process(self._completion_router(q))

    # -- Internal processes ----------------------------------------------------

    def _completion_router(self, queue: simpy.Store) -> Generator:
        """Route block completion triggers to the correct pipeline's reply queue.

        A ``None`` item stops the router.  Triggers whose ``pipeline_id`` has
        no registered reply queue (e.g. the pipeline already finished) are
        dropped.
        """
        while True:
            trigger = yield queue.get()
            if trigger is None:
                break
            reply_q = self._pipeline_queues.get(trigger.pipeline_id)
            if reply_q is not None:
                yield reply_q.put(trigger)

    def _worker(self, env: simpy.Environment) -> Generator:
        """Main inbox loop: dispatch PE commands, forward fabric transactions.

        Each message is handled in its own SimPy process so the inbox loop
        never blocks on a long-running command.
        """
        from kernbench.common.pe_commands import PeInternalTxn

        while True:
            msg: Any = yield self._inbox.get()
            if isinstance(msg, PeInternalTxn):
                env.process(self._dispatch(env, msg))
            else:
                env.process(self._forward_txn(env, msg))

    # ==========================================================================
    # Command Dispatch
    # ==========================================================================

    def _dispatch(self, env: simpy.Environment, pe_txn: Any) -> Generator:
        """Route a PeInternalTxn to the appropriate engine or pipeline.

        Commands forwarded to an out-port are not completed here; composite,
        overhead and unrecognised commands have ``pe_txn.done`` succeeded by
        this component.
        """
        from kernbench.common.pe_commands import (
            CompositeCmd, DmaReadCmd, DmaWriteCmd,
            GemmCmd, MathCmd, PeCpuOverheadCmd,
        )

        yield from self.run(env, 0)  # scheduler overhead

        cmd = pe_txn.command
        pp = self._pe_prefix

        if isinstance(cmd, (DmaReadCmd, DmaWriteCmd)):
            yield self.out_ports[f"{pp}.pe_dma"].put(pe_txn)

        elif isinstance(cmd, GemmCmd):
            yield self.out_ports[f"{pp}.pe_gemm"].put(pe_txn)

        elif isinstance(cmd, MathCmd):
            yield self.out_ports[f"{pp}.pe_math"].put(pe_txn)

        elif isinstance(cmd, CompositeCmd):
            # A "gemm" composite without a second operand falls through to
            # the element-wise math path, like every other composite op.
            if cmd.op == "gemm" and cmd.b is not None:
                yield from self._dispatch_composite_gemm(env, pe_txn, cmd)
            else:
                yield from self._dispatch_composite_math(env, pe_txn, cmd)

        elif isinstance(cmd, PeCpuOverheadCmd):
            # cycles / GHz = nanoseconds
            yield env.timeout(cmd.cycles / self._clock_freq_ghz)
            pe_txn.done.succeed()

        else:
            # Unrecognised command types complete immediately.
            pe_txn.done.succeed()

    # -- GEMM composite --------------------------------------------------------

    def _dispatch_composite_gemm(
        self, env: simpy.Environment, pe_txn: Any, cmd: Any,
    ) -> Generator:
        """Run tiled GEMM pipeline and collect per-stage metrics.

        Dimensions come from the last two axes of ``cmd.a`` and the last axis
        of ``cmd.b``.  Metrics are written into ``pe_txn.result_data`` and
        ``pe_txn.done`` is succeeded once the pipeline has finished.
        """
        from kernbench.components.custom.pe_accel.scheduler.gemm_pipeline import GemmPipeline

        # Same guard as the math path: the queues only exist after start().
        assert self._dmaIN_cmd_q is not None
        assert self._dmaIN_to_fetch_trig is not None

        a, b = cmd.a, cmd.b
        M, K, N = a.shape[-2], a.shape[-1], b.shape[-1]

        pid, reply_q = self._alloc_pipeline(env)

        pipeline = GemmPipeline(
            env=env, M=M, K=K, N=N,
            tile_m=self._tile_m, tile_k=self._tile_k, tile_n=self._tile_n,
            bytes_per_element=self._bpe,
            pipeline_id=pid,
            reply_queue=reply_q,
            dmaIN_cmd_q=self._dmaIN_cmd_q,
            dmaIN_to_fetch_trig=self._dmaIN_to_fetch_trig,
            A_addr=a.addr, B_addr=b.addr, C_addr=cmd.out_addr,
        )

        # Load descriptors into shared blocks
        self._dmaIN_block.load_descriptors(pipeline.dmaIN_descs)
        self._gemm_block.load_descriptors(pipeline.gemm_descs)
        self._math_block.load_descriptors(pipeline.math_descs)
        self._dmaWB_block.load_descriptors(pipeline.dmaWB_descs)

        start_ns = env.now
        try:
            yield pipeline.done
        finally:
            # Always deregister the reply queue, even if waiting raised, so
            # the routing table does not accumulate dead pipelines.
            self._pipeline_queues.pop(pid, None)
        wall_clock_ns = env.now - start_ns

        # -- Collect metrics ---------------------------------------------------
        rd = pe_txn.result_data
        rd["M_tiles"] = pipeline.M_tiles
        rd["K_tiles"] = pipeline.K_tiles
        rd["N_tiles"] = pipeline.N_tiles
        rd["total_tiles"] = pipeline.M_tiles * pipeline.K_tiles * pipeline.N_tiles
        rd["output_tiles"] = pipeline.M_tiles * pipeline.N_tiles
        rd["total_dma_read_bytes"] = sum(d.size_bytes for d in pipeline.dmaIN_descs.values())
        rd["total_dma_write_bytes"] = sum(
            d.Tm * d.Tn * self._bpe for d in pipeline.dmaWB_descs.values()
        )
        rd["wall_clock_ns"] = wall_clock_ns

        # Per-stage timing (the per-pipeline lists are removed from the blocks)
        _store_timing(rd, "t_tcm_load", self._gemm_block.t_tcm_load_per_tile.pop(pid, []))
        _store_timing(rd, "t_compute", self._gemm_block.t_compute_per_tile.pop(pid, []))
        _store_timing(rd, "t_tcm_store", self._math_block.t_tcm_store_per_tile.pop(pid, []))
        _store_timing(
            rd, "t_dma_read",
            self._dmaIN_block.t_dma_read_per_request.pop(pid, []),
            unit="request",
        )
        _store_timing(rd, "t_dma_write", self._dmaWB_block.t_dma_write_per_tile.pop(pid, []))

        # Derived metrics (all guarded against a zero-length run)
        rd["mac_utilization"] = rd["t_compute_total_ns"] / wall_clock_ns if wall_clock_ns > 0 else 0.0
        # 2*M*K*N FLOPs per ns is GFLOP/s; dividing by 1e3 gives TFLOP/s.
        rd["effective_tflops"] = 2 * M * K * N / wall_clock_ns / 1e3 if wall_clock_ns > 0 else 0.0
        rd["pipeline_parallelism"] = (
            (rd["t_tcm_load_total_ns"] + rd["t_compute_total_ns"] + rd["t_tcm_store_total_ns"])
            / wall_clock_ns if wall_clock_ns > 0 else 0.0
        )
        # bytes per ns is numerically GB/s; None when no DMA time was recorded.
        rd["effective_read_bw_gbs"] = (
            rd["total_dma_read_bytes"] / rd["t_dma_read_total_ns"]
            if rd["t_dma_read_total_ns"] > 0 else None
        )
        rd["effective_write_bw_gbs"] = (
            rd["total_dma_write_bytes"] / rd["t_dma_write_total_ns"]
            if rd["t_dma_write_total_ns"] > 0 else None
        )
        # Stage with the largest single-tile time; ties resolve to the
        # earliest entry in the list (load, compute, store).
        rd["bottleneck_stage"] = max(
            [("load", rd["t_tcm_load_max_ns"]), ("compute", rd["t_compute_max_ns"]),
             ("store", rd["t_tcm_store_max_ns"])],
            key=lambda x: x[1],
        )[0]

        # pe_cpu.py compatibility aliases
        rd["dma_ns"] = rd["t_dma_read_total_ns"] + rd["t_dma_write_total_ns"]
        rd["compute_ns"] = rd["t_compute_total_ns"]

        pe_txn.done.succeed()

    # -- Math composite --------------------------------------------------------

    def _dispatch_composite_math(
        self, env: simpy.Environment, pe_txn: Any, cmd: Any,
    ) -> Generator:
        """Run tiled element-wise math pipeline and collect metrics.

        A 1-D operand is treated as a single row (``M = 1``).  When
        ``cmd.math_op`` is falsy the op name ``"identity"`` is used.
        """
        from kernbench.components.custom.pe_accel.scheduler.math_pipeline import MathPipeline

        assert self._dmaIN_cmd_q is not None
        assert self._dmaIN_to_fetch_trig is not None

        a = cmd.a
        M = a.shape[-2] if len(a.shape) >= 2 else 1
        N = a.shape[-1]
        op = cmd.math_op or "identity"

        pid, reply_q = self._alloc_pipeline(env)

        pipeline = MathPipeline(
            env=env, M=M, N=N,
            tile_m=self._tile_m, tile_n=self._tile_n,
            bytes_per_element=self._bpe,
            pipeline_id=pid,
            reply_queue=reply_q,
            dmaIN_cmd_q=self._dmaIN_cmd_q,
            dmaIN_to_fetch_trig=self._dmaIN_to_fetch_trig,
            op=op,
            src_addr=a.addr, dst_addr=cmd.out_addr,
        )

        # Load descriptors into shared blocks
        self._dmaIN_block.load_descriptors(pipeline.dmaIN_descs)
        self._math_block.load_math_op_descriptors(pipeline.math_op_descs)
        self._dmaWB_block.load_descriptors(pipeline.dmaWB_descs)

        start_ns = env.now
        try:
            yield pipeline.done
        finally:
            # Always deregister the reply queue, even if waiting raised.
            self._pipeline_queues.pop(pid, None)
        wall_clock_ns = env.now - start_ns

        # -- Collect metrics ---------------------------------------------------
        rd = pe_txn.result_data
        rd["M_tiles"] = pipeline.M_tiles
        rd["N_tiles"] = pipeline.N_tiles
        rd["total_tiles"] = pipeline.M_tiles * pipeline.N_tiles
        rd["wall_clock_ns"] = wall_clock_ns
        rd["math_op"] = op

        # DMA timing
        _store_timing(
            rd, "t_dma_read",
            self._dmaIN_block.t_dma_read_per_request.pop(pid, []),
            unit="request",
        )
        _store_timing(rd, "t_dma_write", self._dmaWB_block.t_dma_write_per_tile.pop(pid, []))

        # Math op timing (load + compute + store)
        _store_timing(rd, "t_math_load", self._math_block.t_math_op_load_per_tile.pop(pid, []))
        _store_timing(rd, "t_math_compute", self._math_block.t_math_op_compute_per_tile.pop(pid, []))
        _store_timing(rd, "t_math_store", self._math_block.t_math_op_store_per_tile.pop(pid, []))

        # pe_cpu.py compatibility aliases
        rd["dma_ns"] = rd["t_dma_read_total_ns"] + rd["t_dma_write_total_ns"]
        rd["compute_ns"] = rd["t_math_compute_total_ns"]

        pe_txn.done.succeed()

    # -- Helpers ---------------------------------------------------------------

    def _alloc_pipeline(self, env: simpy.Environment) -> tuple[int, simpy.Store]:
        """Allocate a pipeline ID and reply queue.

        IDs increase monotonically and are never reused; the queue is
        registered in ``_pipeline_queues`` so completion routers can find it.
        """
        pid = self._next_pipeline_id
        self._next_pipeline_id += 1
        reply_q = simpy.Store(env)
        self._pipeline_queues[pid] = reply_q
        return pid, reply_q


# ==============================================================================
# Utility
# ==============================================================================

def _collect_timing(times: list) -> tuple[list, float, float]:
    """Return (raw_list, max_ns, total_ns) from a timing list.

    An empty list yields ``max_ns == 0.0`` and ``total_ns == 0``.
    """
    return times, max(times, default=0.0), sum(times)


def _store_timing(rd: Any, stem: str, times: list, unit: str = "tile") -> None:
    """Write one timing triple into ``rd``.

    Stores ``<stem>_per_<unit>`` (the raw list), ``<stem>_max_ns`` and
    ``<stem>_total_ns``, in that order.
    """
    per_item, max_ns, total_ns = _collect_timing(times)
    rd[f"{stem}_per_{unit}"] = per_item
    rd[f"{stem}_max_ns"] = max_ns
    rd[f"{stem}_total_ns"] = total_ns
def generate_gemm_tiles(
    M_tiles: int, K_tiles: int, N_tiles: int,
    M: int = 0, K: int = 0, N: int = 0,
    tile_m: int = 0, tile_k: int = 0, tile_n: int = 0,
    bytes_per_element: int = 2,
    A_addr: int = 0, B_addr: int = 0, C_addr: int = 0,
    pipeline_id: int = 0,
) -> SchedulePlan:
    """Build the GEMM tile schedule, iterating M (outer) -> N -> K (inner).

    For every (m, n, k) tile a DMA_LOAD command and a TENSOR_OP command are
    emitted back to back, sharing one ``tile_id``; ``is_last_k`` is set on the
    commands of the final K step.  After the K loop of each (m, n) pair a
    single DMA_FLUSH command follows, carrying the ``tile_id`` of that pair's
    last K tile.  ``cmd_id`` is the command's position in the returned list.

    Per-tile byte addresses assume a row-major layout:
        A (M,K): A_addr + (m * tile_m * K + k * tile_k) * bpe
        B (K,N): B_addr + (k * tile_k * N + n * tile_n) * bpe
        C (M,N): C_addr + (m * tile_m * N + n * tile_n) * bpe
    """
    plan: list[TileCommand] = []
    next_tile = 0
    elem = bytes_per_element

    for row in range(M_tiles):
        # Element offsets that depend only on the M index.
        a_row_base = row * tile_m * K
        c_row_base = row * tile_m * N

        for col in range(N_tiles):
            col_off = col * tile_n
            c_addr = C_addr + (c_row_base + col_off) * elem

            for depth in range(K_tiles):
                # Fields common to the load and the compute command.
                shared = dict(
                    tile_id=next_tile, m_idx=row, k_idx=depth, n_idx=col,
                    is_last_k=(depth == K_tiles - 1),
                    pipeline_id=pipeline_id,
                    a_tile_addr=A_addr + (a_row_base + depth * tile_k) * elem,
                    b_tile_addr=B_addr + (depth * tile_k * N + col_off) * elem,
                    c_tile_addr=c_addr,
                )
                for kind in (CmdType.DMA_LOAD, CmdType.TENSOR_OP):
                    # Command ids are consecutive, so the id of the next
                    # command equals the current length of the plan.
                    plan.append(TileCommand(cmd_id=len(plan), cmd_type=kind, **shared))
                next_tile += 1

            # One flush per (m, n) pair after all K tiles
            plan.append(TileCommand(
                cmd_id=len(plan), cmd_type=CmdType.DMA_FLUSH,
                tile_id=next_tile - 1, m_idx=row, k_idx=0, n_idx=col,
                pipeline_id=pipeline_id,
                c_tile_addr=c_addr,
            ))

    return SchedulePlan(
        commands=plan, M_tiles=M_tiles, K_tiles=K_tiles, N_tiles=N_tiles
    )


def generate_math_tiles(
    M_tiles: int, N_tiles: int,
    M: int = 0, N: int = 0,
    tile_m: int = 0, tile_n: int = 0,
    bytes_per_element: int = 2,
    src_addr: int = 0, dst_addr: int = 0,
    pipeline_id: int = 0,
) -> MathSchedulePlan:
    """Build the element-wise math tile schedule in row-major tile order.

    One command is emitted per (m, n) tile; ``cmd_id`` and ``tile_id`` are
    both the tile's position in the returned list.  Source and destination
    share the same byte offset (row-major layout):
        src: src_addr + (m * tile_m * N + n * tile_n) * bpe
        dst: dst_addr + (m * tile_m * N + n * tile_n) * bpe
    """
    elem = bytes_per_element
    grid = [(row, col) for row in range(M_tiles) for col in range(N_tiles)]

    plan: list[MathTileCommand] = []
    for seq, (row, col) in enumerate(grid):
        byte_off = (row * tile_m * N + col * tile_n) * elem
        plan.append(MathTileCommand(
            cmd_id=seq,
            tile_id=seq,
            m_idx=row,
            n_idx=col,
            src_tile_addr=src_addr + byte_off,
            dst_tile_addr=dst_addr + byte_off,
            pipeline_id=pipeline_id,
        ))

    return MathSchedulePlan(
        commands=plan, M_tiles=M_tiles, N_tiles=N_tiles
    )
+""" +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum, auto + + +# -- Enums --------------------------------------------------------------------- + +class CmdType(Enum): + DMA_LOAD = auto() + TENSOR_OP = auto() + DMA_FLUSH = auto() + + +# -- Inter-block messaging ----------------------------------------------------- + +@dataclass +class Trigger: + """Completion token passed between hardware blocks.""" + + tile_id: int + pipeline_id: int + vc: int | None = None + source_block: str = "" + + +@dataclass +class DmaRequest: + """DMA load request — descriptor lookup key only. + + Transfer params (size, address) live in the pre-loaded DmaInDescriptor. + """ + + tile_id: int + pipeline_id: int + operand: str # "A" or "B" + + +# -- Descriptors (pre-loaded by pipelines, consumed by blocks) ----------------- + +@dataclass +class DmaInDescriptor: + """Per-operand DMA read descriptor.""" + + size_bytes: int + src_addr: int = 0 + next_block: str = "GEMM" # "GEMM" | "MATH" | "COMPLETION" + + +@dataclass +class GemmDescriptor: + """Per-tile GEMM descriptor.""" + + Tm: int + Tk: int + Tn: int + triggers_needed: int = 2 # 2 = both operands from DMA; 1 = SPMem bypass + gemm_load: bool = True + gemm_compute: bool = True + next_block: str = "MATH" # "MATH" | "DMAWB" | "DONE" + + +@dataclass +class MathDescriptor: + """Per-tile K-accumulation descriptor (used by GEMM pipeline).""" + + Tm: int + Tn: int + is_last_k: bool + skip_dmaWB: bool # True = C stays in SPMem; False = flush to HBM + + +@dataclass +class MathOpDescriptor: + """Per-tile element-wise math op descriptor (used by math pipeline).""" + + Tm: int + Tn: int + op: str # "exp", "log", "sqrt", "sigmoid", etc. 
+ src_addr: int = 0 + dst_addr: int = 0 + + +@dataclass +class DmaWBDescriptor: + """Per-tile DMA writeback descriptor.""" + + Tm: int + Tn: int + dst_addr: int = 0 + + +# -- Tile commands (produced by schedule generators) --------------------------- + +@dataclass +class TileCommand: + """A single GEMM tile command.""" + + cmd_id: int + cmd_type: CmdType + tile_id: int + m_idx: int + k_idx: int + n_idx: int + is_last_k: bool = False + pipeline_id: int = 0 + a_tile_addr: int = 0 + b_tile_addr: int = 0 + c_tile_addr: int = 0 + + +@dataclass +class SchedulePlan: + """Full tile schedule for one GEMM operation.""" + + commands: list # list[TileCommand] + M_tiles: int + K_tiles: int + N_tiles: int + + +@dataclass +class MathTileCommand: + """A single element-wise math tile command.""" + + cmd_id: int + tile_id: int + m_idx: int + n_idx: int + src_tile_addr: int = 0 + dst_tile_addr: int = 0 + pipeline_id: int = 0 + + +@dataclass +class MathSchedulePlan: + """Full tile schedule for one element-wise math operation.""" + + commands: list # list[MathTileCommand] + M_tiles: int + N_tiles: int + + diff --git a/src/kernbench/policy/placement/dp.py b/src/kernbench/policy/placement/dp.py index 705b791..5b0e01a 100644 --- a/src/kernbench/policy/placement/dp.py +++ b/src/kernbench/policy/placement/dp.py @@ -13,11 +13,19 @@ class DPPolicy: - "replicate": full copy at each unit - "column_wise": split K (column) axis across units - "row_wise": split M (row) axis across units + + Optional overrides (default None = use topology dimensions): + - num_pes: override PEs per cube (e.g., 1 for single-PE test) + - num_cubes: override cubes per SIP (e.g., 1 for single-cube test) + - num_sips: override SIP count """ sip: Literal["replicate", "column_wise", "row_wise"] = "replicate" cube: Literal["replicate", "column_wise", "row_wise"] = "replicate" pe: Literal["replicate", "column_wise", "row_wise"] = "replicate" + num_pes: int | None = None + num_cubes: int | None = None + num_sips: int | None = 
None def _split_shape( diff --git a/src/kernbench/runtime_api/context.py b/src/kernbench/runtime_api/context.py index 7e94877..3fee065 100644 --- a/src/kernbench/runtime_api/context.py +++ b/src/kernbench/runtime_api/context.py @@ -269,15 +269,25 @@ class RuntimeContext: allocators = self._ensure_allocators() itemsize = dtype_itemsize(dtype) shape_2d = (shape[0], shape[1]) if len(shape) >= 2 else (1, shape[0]) + # DPPolicy overrides take precedence over topology dimensions + eff_num_pe = dp.num_pes if dp.num_pes is not None else self._pes_per_cube + eff_num_cubes = dp.num_cubes if dp.num_cubes is not None else self._num_cubes + eff_num_sips = dp.num_sips if dp.num_sips is not None else self._num_sips placement = resolve_dp_policy( dp, shape=shape_2d, itemsize=itemsize, - num_pe=self._pes_per_cube, num_cubes=self._num_cubes, - num_sips=self._num_sips, + num_pe=eff_num_pe, num_cubes=eff_num_cubes, + num_sips=eff_num_sips, ) - # Infer target_pe from placement: multi-PE → "all", single PE → pe_index - pe_indices = {s.pe_index for s in placement} - target_pe: int | str = "all" if len(pe_indices) > 1 else next(iter(pe_indices)) + # Infer target_pe from placement using local (within-cube) PE IDs. + # This ensures M_CPU only fans out to PEs that own shards, not all PEs. + local_pe_ids = sorted({s.pe_index % eff_num_pe for s in placement}) + if len(local_pe_ids) == 1: + target_pe: int | tuple[int, ...] 
| str = local_pe_ids[0] + elif len(local_pe_ids) == eff_num_pe and eff_num_pe == self._pes_per_cube: + target_pe = "all" + else: + target_pe = tuple(local_pe_ids) t.to(placement=placement, target_pe=target_pe, dp_policy=dp_policy) # Allocate PAs via PEMemAllocator + VA via VirtualAllocator @@ -407,7 +417,8 @@ class RuntimeContext: # Collect tensors and scalars tensor_args: list[Tensor] = [] scalar_args: list = [] - target_pe: int | str = 0 + _pe_set: set[int] = set() + _pe_all = False for a in args: if isinstance(a, Tensor): @@ -415,9 +426,11 @@ class RuntimeContext: if a._dp_metadata is not None: dp_target = a._dp_metadata.target_pe if dp_target == "all": - target_pe = "all" - elif isinstance(dp_target, int) and target_pe != "all": - target_pe = dp_target + _pe_all = True + elif isinstance(dp_target, tuple): + _pe_set.update(dp_target) + elif isinstance(dp_target, int): + _pe_set.add(dp_target) elif isinstance(a, (int, float)): dtype_str = "f32" if isinstance(a, float) else "i32" scalar_args.append(ScalarArg(dtype=dtype_str, value=a)) @@ -427,6 +440,16 @@ class RuntimeContext: dtype_str = "f32" if isinstance(v, float) else "i32" scalar_args.append(ScalarArg(dtype=dtype_str, value=v)) + # Resolve target_pe from collected PE info + if _pe_all: + target_pe: int | tuple[int, ...] | str = "all" + elif len(_pe_set) == 1: + target_pe = next(iter(_pe_set)) + elif len(_pe_set) > 1: + target_pe = tuple(sorted(_pe_set)) + else: + target_pe = 0 + # Determine all target SIPs from tensor shards sip_set: set[int] = set() for t in tensor_args: diff --git a/src/kernbench/runtime_api/kernel.py b/src/kernbench/runtime_api/kernel.py index 3fc8624..acda736 100644 --- a/src/kernbench/runtime_api/kernel.py +++ b/src/kernbench/runtime_api/kernel.py @@ -89,7 +89,7 @@ class KernelLaunchMsg: kernel_ref: KernelRef args: tuple[KernelArg, ...] target_cubes: tuple[int, ...] | Literal["all"] = "all" - target_pe: int | Literal["all"] = "all" + target_pe: int | tuple[int, ...] 
| Literal["all"] = "all" msg_type: Literal["kernel_launch"] = "kernel_launch" diff --git a/src/kernbench/runtime_api/tensor.py b/src/kernbench/runtime_api/tensor.py index 51369fe..88ff5a3 100644 --- a/src/kernbench/runtime_api/tensor.py +++ b/src/kernbench/runtime_api/tensor.py @@ -105,7 +105,7 @@ class DPMetadata: dp_policy: DPPolicy | None = None sip: int = 0 cube: int = 0 - target_pe: int | str = 0 # int → single PE, "all" → all PEs + target_pe: int | tuple[int, ...] | str = 0 # int → single PE, tuple → specific PEs, "all" → all PEs class Tensor: @@ -166,7 +166,7 @@ class Tensor: dp_policy: DPPolicy | None = None, sip: int = 0, cube: int = 0, - target_pe: int | str = 0, + target_pe: int | tuple[int, ...] | str = 0, ) -> Tensor: """Set DP placement metadata (like torch.Tensor.to()).""" if placement is None: diff --git a/topology.yaml b/topology.yaml index 9fce8f9..0104960 100644 --- a/topology.yaml +++ b/topology.yaml @@ -61,7 +61,7 @@ cube: pe_template: components: pe_cpu: { kind: pe_cpu, impl: pe_cpu_v1, attrs: { overhead_ns: 2.0 } } - pe_scheduler: { kind: pe_scheduler, impl: pe_scheduler_v1, attrs: { overhead_ns: 1.0 } } + pe_scheduler: { kind: pe_scheduler, impl: pe_scheduler_v2, attrs: { overhead_ns: 1.0 } } pe_dma: { kind: pe_dma, impl: pe_dma_v1, attrs: { rd_engines: 1, wr_engines: 1 } } pe_gemm: { kind: pe_gemm, impl: pe_gemm_v1, attrs: { overhead_ns: 0.0, shared_resource: accel_slot, peak_tflops_f16: 8.0 } } pe_math: { kind: pe_math, impl: pe_math_v1, attrs: { overhead_ns: 0.0, shared_resource: accel_slot } }