From b6eb97c49abcac2d1ba7507dac57610bde923ba1 Mon Sep 17 00:00:00 2001 From: Yangwook Kang Date: Wed, 8 Apr 2026 23:35:31 -0700 Subject: [PATCH] Implement ADR-0021: PE pipeline refactor with token self-routing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Step 1-2: Backup existing code - builtin/ → builtin_legacy/ (unchanged backup) - custom/pe_accel/ → custom/pe_accel_legacy/ (unchanged backup) Step 3-4: New pipeline types and tiling - pe_types.py: StageType, Stage, TilePlan, PipelinePlan, PipelineContext, TileToken - tiling.py: generate_gemm_plan, generate_math_plan (ported from pe_accel) Step 5: Component implementations (ADR-0021 D4-D6) - PE_SCHEDULER: _feed_loop (singleton FIFO feeder) + plan generation - PE_FETCH_STORE: new component — TCM ↔ Register File - PE_GEMM: TileToken pipeline + legacy PeInternalTxn dual-mode - PE_MATH: TileToken pipeline + legacy dual-mode - PE_DMA: TileToken pipeline + legacy + fabric Transaction triple-mode - PE_TCM: TcmRequest handler with dual-channel BW serialization Step 6: Infrastructure - topology.yaml: pe_fetch_store component + chaining edges - components.yaml: pe_fetch_store_v1 registration - builder.py: PE_COMP_OFFSETS, _add_pe_internal_edges, PE view positions - Tests: node/edge counts, PE component sets updated All components handle both TileToken (pipeline) and PeInternalTxn (legacy). Token self-routing: components read next stage from token.plan, chain via out_port. 366 tests passing. Co-Authored-By: Claude Opus 4.6 (1M context) --- components.yaml | 15 +- src/kernbench/components/builtin/pe_dma.py | 67 +++ .../components/builtin/pe_fetch_store.py | 77 ++++ src/kernbench/components/builtin/pe_gemm.py | 93 +++- src/kernbench/components/builtin/pe_math.py | 64 ++- .../components/builtin/pe_scheduler.py | 268 ++++------- src/kernbench/components/builtin/pe_tcm.py | 69 ++- src/kernbench/components/builtin/pe_types.py | 115 +++++ src/kernbench/components/builtin/tiling.py | 176 +++++++ .../components/builtin_legacy/__init__.py | 34 ++ .../components/builtin_legacy/forwarding.py | 27 ++ .../components/builtin_legacy/hbm_ctrl.py | 129 ++++++ .../components/builtin_legacy/io_cpu.py | 157 +++++++ .../components/builtin_legacy/m_cpu.py | 327 +++++++++++++ .../components/builtin_legacy/pcie_ep.py | 27 ++ .../components/builtin_legacy/pe_cpu.py | 214 +++++++++ .../components/builtin_legacy/pe_dma.py | 138 ++++++ .../components/builtin_legacy/pe_gemm.py | 90 ++++ .../components/builtin_legacy/pe_math.py | 54 +++ .../components/builtin_legacy/pe_mmu.py | 66 +++ .../components/builtin_legacy/pe_scheduler.py | 245 ++++++++++ .../components/builtin_legacy/pe_tcm.py | 25 + .../components/builtin_legacy/sram.py | 59 +++ .../custom/pe_accel_legacy/__init__.py | 20 + .../custom/pe_accel_legacy/blocks/__init__.py | 16 + .../custom/pe_accel_legacy/blocks/dma_in.py | 96 ++++ .../custom/pe_accel_legacy/blocks/dma_wb.py | 88 ++++ .../custom/pe_accel_legacy/blocks/gemm.py | 160 +++++++ .../custom/pe_accel_legacy/blocks/math.py | 181 ++++++++ .../custom/pe_accel_legacy/blocks/tcm.py | 80 ++++ .../pe_accel_legacy/scheduler/__init__.py | 10 + .../scheduler/gemm_pipeline.py | 157 +++++++ .../scheduler/math_pipeline.py | 132 ++++++ .../pe_accel_legacy/scheduler/scheduler.py | 434 ++++++++++++++++++ .../pe_accel_legacy/scheduler/tiling.py | 121 +++++ .../custom/pe_accel_legacy/types.py | 148 ++++++ src/kernbench/topology/builder.py | 53 ++- tests/test_topology_compile.py | 13 +- tests/test_topology_load.py | 3 +- topology.yaml | 21 +- 40 files changed, 4055 insertions(+), 214 deletions(-) create mode 100644 src/kernbench/components/builtin/pe_fetch_store.py create mode 100644 src/kernbench/components/builtin/pe_types.py create mode 100644 src/kernbench/components/builtin/tiling.py create mode 100644 src/kernbench/components/builtin_legacy/__init__.py create mode 100644 src/kernbench/components/builtin_legacy/forwarding.py create mode 100644 src/kernbench/components/builtin_legacy/hbm_ctrl.py create mode 100644 src/kernbench/components/builtin_legacy/io_cpu.py create mode 100644 src/kernbench/components/builtin_legacy/m_cpu.py create mode 100644 src/kernbench/components/builtin_legacy/pcie_ep.py create mode 100644 src/kernbench/components/builtin_legacy/pe_cpu.py create mode 100644 src/kernbench/components/builtin_legacy/pe_dma.py create mode 100644 src/kernbench/components/builtin_legacy/pe_gemm.py create mode 100644 src/kernbench/components/builtin_legacy/pe_math.py create mode 100644 src/kernbench/components/builtin_legacy/pe_mmu.py create mode 100644 src/kernbench/components/builtin_legacy/pe_scheduler.py create mode 100644 src/kernbench/components/builtin_legacy/pe_tcm.py create mode 100644 src/kernbench/components/builtin_legacy/sram.py create mode 100644 src/kernbench/components/custom/pe_accel_legacy/__init__.py create mode 100644 src/kernbench/components/custom/pe_accel_legacy/blocks/__init__.py create mode 100644 src/kernbench/components/custom/pe_accel_legacy/blocks/dma_in.py create mode 100644 src/kernbench/components/custom/pe_accel_legacy/blocks/dma_wb.py create mode 100644 src/kernbench/components/custom/pe_accel_legacy/blocks/gemm.py create mode 100644 src/kernbench/components/custom/pe_accel_legacy/blocks/math.py create mode 100644 src/kernbench/components/custom/pe_accel_legacy/blocks/tcm.py create mode 100644 src/kernbench/components/custom/pe_accel_legacy/scheduler/__init__.py create mode 100644 src/kernbench/components/custom/pe_accel_legacy/scheduler/gemm_pipeline.py create mode 100644 src/kernbench/components/custom/pe_accel_legacy/scheduler/math_pipeline.py create mode 100644 src/kernbench/components/custom/pe_accel_legacy/scheduler/scheduler.py create mode 100644 src/kernbench/components/custom/pe_accel_legacy/scheduler/tiling.py create mode 100644 src/kernbench/components/custom/pe_accel_legacy/types.py diff --git a/components.yaml b/components.yaml index 1e22d69..bf48f89 100644 --- a/components.yaml +++ b/components.yaml @@ -38,13 +38,14 @@ components: sram_v1: kernbench.components.builtin.sram:SramComponent # PE-level - pe_cpu_v1: kernbench.components.builtin.pe_cpu:PeCpuComponent - pe_scheduler_v1: kernbench.components.builtin.pe_scheduler:PeSchedulerComponent - pe_dma_v1: kernbench.components.builtin.pe_dma:PeDmaComponent - pe_gemm_v1: kernbench.components.builtin.pe_gemm:PeGemmComponent - pe_math_v1: kernbench.components.builtin.pe_math:PeMathComponent - pe_mmu_v1: kernbench.components.builtin.pe_mmu:PeMmuComponent - pe_tcm_v1: kernbench.components.builtin.pe_tcm:PeTcmComponent + pe_cpu_v1: kernbench.components.builtin.pe_cpu:PeCpuComponent + pe_scheduler_v1: kernbench.components.builtin.pe_scheduler:PeSchedulerComponent + pe_dma_v1: kernbench.components.builtin.pe_dma:PeDmaComponent + pe_gemm_v1: kernbench.components.builtin.pe_gemm:PeGemmComponent + pe_math_v1: kernbench.components.builtin.pe_math:PeMathComponent + pe_fetch_store_v1: kernbench.components.builtin.pe_fetch_store:PeFetchStoreComponent + pe_mmu_v1: kernbench.components.builtin.pe_mmu:PeMmuComponent + pe_tcm_v1: kernbench.components.builtin.pe_tcm:PeTcmComponent # Custom — add your implementations here pe_scheduler_v2: kernbench.components.custom.pe_accel.scheduler:SchedulerV2Component diff --git a/src/kernbench/components/builtin/pe_dma.py b/src/kernbench/components/builtin/pe_dma.py index c8ee823..4412f8e 100644 --- a/src/kernbench/components/builtin/pe_dma.py +++ b/src/kernbench/components/builtin/pe_dma.py @@ -105,6 +105,73 @@ class PeDmaComponent(PeEngineBase): yield sub_done pe_txn.done.succeed() + def _worker(self, env: simpy.Environment) -> Generator: + """Handle TileToken (pipeline), PeInternalTxn (legacy), and Transaction (fabric).""" + from kernbench.common.pe_commands import PeInternalTxn + from kernbench.components.builtin.pe_types import TileToken + + while True: + msg: Any = yield self._inbox.get() + if isinstance(msg, TileToken): + env.process(self._pipeline_process(env, msg)) + elif isinstance(msg, PeInternalTxn): + env.process(self._handle_with_hooks(env, msg)) + else: + env.process(self._forward_txn(env, msg)) + + def _pipeline_process(self, env: simpy.Environment, token: Any) -> Generator: + """Pipeline mode: DMA read/write via fabric, then self-route.""" + from kernbench.common.pe_commands import DmaReadCmd, DmaWriteCmd, TensorHandle + from kernbench.policy.address.phyaddr import PhysAddr + from kernbench.runtime_api.kernel import PeDmaMsg + + self._on_process_start(env, token) + + params = token.params + stage_type = token.current_stage.stage_type + + from kernbench.components.builtin.pe_types import StageType + is_write = stage_type == StageType.DMA_WRITE + addr = params.get("dst_addr" if is_write else "src_addr", 0) + nbytes = params.get("nbytes", 0) + + if nbytes > 0 and self.ctx: + dma_res = self._dma_write if is_write else self._dma_read + assert dma_res is not None + + pa = PhysAddr.decode(addr) + dst_node = self.ctx.resolver.resolve(pa) + path = self.ctx.router.find_path(self._pe_prefix, dst_node) + drain_ns = self.ctx.compute_drain_ns(path, nbytes) + + with dma_res.request() as req: + yield req + sub_done = env.event() + sub_request = PeDmaMsg( + correlation_id="pipeline", + request_id=f"tile_{token.tile_id}", + src_sip=0, src_cube=0, src_pe=0, + dst_pa=addr, nbytes=nbytes, + is_write=is_write, + ) + sub_txn = Transaction( + request=sub_request, path=path, step=0, + nbytes=nbytes, done=sub_done, drain_ns=drain_ns, + ) + if len(path) > 1: + yield self.out_ports[path[1]].put(sub_txn.advance()) + + yield sub_done + + self._on_process_end(env, token) + + # Self-routing + next_stage = token.advance() + if next_stage is not None: + yield self.out_ports[next_stage.component].put(token) + else: + token.pipeline_ctx.complete_tile() + def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator: """Handle external Transaction (PeDmaMsg probe, M_CPU DMA) with channel acquisition.""" # Response transactions bypass DMA channel (no outbound resource needed) diff --git a/src/kernbench/components/builtin/pe_fetch_store.py b/src/kernbench/components/builtin/pe_fetch_store.py new file mode 100644 index 0000000..3d65e2c --- /dev/null +++ b/src/kernbench/components/builtin/pe_fetch_store.py @@ -0,0 +1,77 @@ +"""PE_FETCH_STORE: TCM ↔ Register File transfer unit (ADR-0021 D5). + +Handles both fetch (TCM → register) and store (register → TCM). +BW serialization is delegated to PE_TCM via port communication. +""" +from __future__ import annotations + +from collections.abc import Generator +from typing import TYPE_CHECKING, Any + +import simpy + +from kernbench.components.base import PeEngineBase + +if TYPE_CHECKING: + from kernbench.components.context import ComponentContext + from kernbench.topology.types import Node + + +class PeFetchStoreComponent(PeEngineBase): + """PE_FETCH_STORE: TCM ↔ Register File (ADR-0021 D5). + + Receives TileTokens via pipeline self-routing. + Sends TcmRequest to PE_TCM for BW-based latency. + """ + + def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None: + super().__init__(node, ctx) + self._tcm_id = f"{self._pe_prefix}.pe_tcm" + + def run(self, env: simpy.Environment, nbytes: int) -> Generator: + overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0)) + yield env.timeout(overhead_ns) + + def _worker(self, env: simpy.Environment) -> Generator: + """Handle both PeInternalTxn (legacy) and TileToken (pipeline).""" + from kernbench.common.pe_commands import PeInternalTxn + from kernbench.components.builtin.pe_types import TileToken + + while True: + msg: Any = yield self._inbox.get() + if isinstance(msg, TileToken): + env.process(self._pipeline_process(env, msg)) + elif isinstance(msg, PeInternalTxn): + env.process(self.handle_command(env, msg)) + else: + env.process(self._forward_txn(env, msg)) + + def _pipeline_process(self, env: simpy.Environment, token: Any) -> Generator: + """Process a pipeline TileToken: fetch or store via TCM.""" + from kernbench.components.builtin.pe_tcm import TcmRequest + + self._on_process_start(env, token) + + direction = token.params.get("direction", "read") + nbytes = token.params.get("nbytes", 0) + + if nbytes > 0 and self._tcm_id in self.out_ports: + done = env.event() + yield self.out_ports[self._tcm_id].put( + TcmRequest(direction=direction, nbytes=nbytes, done=done) + ) + yield done + + self._on_process_end(env, token) + + # Self-routing: advance to next stage + next_stage = token.advance() + if next_stage is not None: + yield self.out_ports[next_stage.component].put(token) + else: + token.pipeline_ctx.complete_tile() + + def handle_command(self, env: simpy.Environment, pe_txn: Any) -> Generator: + """Legacy PeInternalTxn handling.""" + yield from self.run(env, 0) + pe_txn.done.succeed() diff --git a/src/kernbench/components/builtin/pe_gemm.py b/src/kernbench/components/builtin/pe_gemm.py index 3fc74e3..718d130 100644 --- a/src/kernbench/components/builtin/pe_gemm.py +++ b/src/kernbench/components/builtin/pe_gemm.py @@ -1,6 +1,18 @@ +"""PE_GEMM: matrix multiplication engine (ADR-0021 D6). + +Handles both legacy PeInternalTxn (GemmCmd) and pipeline TileToken. +In pipeline mode, receives token after fetch stage, computes MAC, chains to next. + +MAC latency model (from pe_accel): + cycles = ceil(Tm/mac_m) * ceil(Tk/mac_k) * ceil(Tn/mac_n) + latency_ns = cycles / clock_freq_ghz + +Falls back to TFLOPS model when mac dimensions not configured. +""" from __future__ import annotations from collections.abc import Generator +from math import ceil from typing import TYPE_CHECKING, Any import simpy @@ -12,33 +24,29 @@ if TYPE_CHECKING: from kernbench.components.context import ComponentContext from kernbench.topology.types import Node - -# dtype → bit width (for TFLOPS scaling) _DTYPE_BITS: dict[str, int] = { "f16": 16, "fp16": 16, "float16": 16, "bf16": 16, "f32": 32, "fp32": 32, "float32": 32, - "i8": 8, "int8": 8, - "i16": 16, "int16": 16, - "i32": 32, "int32": 32, + "i8": 8, "int8": 8, "i16": 16, "int16": 16, "i32": 32, "int32": 32, } class PeGemmComponent(PeEngineBase): - """PE_GEMM: matrix multiplication engine sharing accel_slot (ADR-0014 D4). + """PE_GEMM: MAC array (ADR-0021 D6). - Uses a shared compute resource (PE_ACCEL capacity=1) that is mutually - exclusive with PE_MATH within the same PE. - - Compute latency model: - FLOPs = 2 * M * K * N - effective_tflops = peak_tflops_f16 * (16 / dtype_bits) - compute_ns = FLOPs / (effective_tflops * 1e3) + In pipeline mode: pure compute — register data already fetched. + In legacy mode: handles PeInternalTxn(GemmCmd) with shared accel_slot. """ def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None: super().__init__(node, ctx) self._accel: simpy.Resource | None = None self._peak_tflops_f16: float = float(node.attrs.get("peak_tflops_f16", 0.0)) + # Cycle-accurate MAC dimensions (from pe_accel) + self._mac_m: int = int(node.attrs.get("mac_m", 0)) + self._mac_k: int = int(node.attrs.get("mac_k", 0)) + self._mac_n: int = int(node.attrs.get("mac_n", 0)) + self._clock_freq: float = float(node.attrs.get("clock_freq_ghz", 1.0)) def init_resources(self, env: simpy.Environment) -> None: resource_name = self.node.attrs.get("shared_resource") @@ -47,8 +55,15 @@ class PeGemmComponent(PeEngineBase): env, f"{self._pe_prefix}.{resource_name}" ) - def _compute_ns(self, m: int, k: int, n: int, dtype: str) -> float: - """Compute GEMM latency in nanoseconds.""" + def _compute_ns_mac(self, m: int, k: int, n: int) -> float: + """Cycle-accurate MAC latency (pe_accel model).""" + if self._mac_m > 0 and self._mac_k > 0 and self._mac_n > 0: + cycles = ceil(m / self._mac_m) * ceil(k / self._mac_k) * ceil(n / self._mac_n) + return cycles / self._clock_freq + return 0.0 + + def _compute_ns_tflops(self, m: int, k: int, n: int, dtype: str = "f16") -> float: + """TFLOPS-based latency (legacy model).""" if self._peak_tflops_f16 <= 0: return float(self.node.attrs.get("overhead_ns", 0.0)) dtype_bits = _DTYPE_BITS.get(dtype, 16) @@ -56,11 +71,58 @@ class PeGemmComponent(PeEngineBase): flops = 2.0 * m * k * n return flops / (effective_tflops * 1e3) + def _compute_ns(self, m: int, k: int, n: int, dtype: str = "f16") -> float: + """Choose best available latency model.""" + mac_ns = self._compute_ns_mac(m, k, n) + if mac_ns > 0: + return mac_ns + return self._compute_ns_tflops(m, k, n, dtype) + def run(self, env: simpy.Environment, nbytes: int) -> Generator: overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0)) yield env.timeout(overhead_ns) + def _worker(self, env: simpy.Environment) -> Generator: + from kernbench.common.pe_commands import PeInternalTxn + from kernbench.components.builtin.pe_types import TileToken + + while True: + msg: Any = yield self._inbox.get() + if isinstance(msg, TileToken): + env.process(self._pipeline_process(env, msg)) + elif isinstance(msg, PeInternalTxn): + env.process(self._handle_with_hooks(env, msg)) + else: + env.process(self._forward_txn(env, msg)) + + def _pipeline_process(self, env: simpy.Environment, token: Any) -> Generator: + """Pipeline mode: pure MAC compute, then self-route.""" + self._on_process_start(env, token) + + m = token.params.get("m", 0) + k = token.params.get("k", 0) + n = token.params.get("n", 0) + + if self._accel: + with self._accel.request() as req: + yield req + ns = self._compute_ns(m, k, n) + yield env.timeout(ns) + else: + ns = self._compute_ns(m, k, n) + yield env.timeout(ns) + + self._on_process_end(env, token) + + # Self-routing + next_stage = token.advance() + if next_stage is not None: + yield self.out_ports[next_stage.component].put(token) + else: + token.pipeline_ctx.complete_tile() + def handle_command(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator: + """Legacy PeInternalTxn handling.""" from kernbench.common.pe_commands import GemmCmd cmd = pe_txn.command @@ -81,7 +143,6 @@ class PeGemmComponent(PeEngineBase): pe_txn.done.succeed() def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator: - """Transaction forwarding with accel_slot acquisition.""" if self._accel: with self._accel.request() as req: yield req diff --git a/src/kernbench/components/builtin/pe_math.py b/src/kernbench/components/builtin/pe_math.py index c3c3a83..664c545 100644 --- a/src/kernbench/components/builtin/pe_math.py +++ b/src/kernbench/components/builtin/pe_math.py @@ -1,6 +1,16 @@ +"""PE_MATH: element-wise / reduction computation engine (ADR-0021 D6). + +Handles both legacy PeInternalTxn (MathCmd) and pipeline TileToken. +In pipeline mode, receives token after fetch stage, computes SIMD, chains to next. + +SIMD latency model (from pe_accel): + cycles = ceil(num_elements / vector_width) + latency_ns = cycles / clock_freq_ghz +""" from __future__ import annotations from collections.abc import Generator +from math import ceil from typing import TYPE_CHECKING, Any import simpy @@ -14,15 +24,17 @@ if TYPE_CHECKING: class PeMathComponent(PeEngineBase): - """PE_MATH: element-wise computation engine sharing accel_slot (ADR-0014 D4). + """PE_MATH: SIMD/Vector unit (ADR-0021 D6). - Uses a shared compute resource (PE_ACCEL capacity=1) that is mutually - exclusive with PE_GEMM within the same PE. + In pipeline mode: pure compute — register data already fetched. + In legacy mode: handles PeInternalTxn(MathCmd) with shared accel_slot. """ def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None: super().__init__(node, ctx) self._accel: simpy.Resource | None = None + self._vector_width: int = int(node.attrs.get("vector_width", 256)) + self._clock_freq: float = float(node.attrs.get("clock_freq_ghz", 1.0)) def init_resources(self, env: simpy.Environment) -> None: resource_name = self.node.attrs.get("shared_resource") @@ -31,11 +43,56 @@ class PeMathComponent(PeEngineBase): env, f"{self._pe_prefix}.{resource_name}" ) + def _compute_ns(self, num_elements: int) -> float: + """SIMD latency (pe_accel model).""" + if self._vector_width > 0 and self._clock_freq > 0 and num_elements > 0: + cycles = ceil(num_elements / self._vector_width) + return cycles / self._clock_freq + return float(self.node.attrs.get("overhead_ns", 0.0)) + def run(self, env: simpy.Environment, nbytes: int) -> Generator: overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0)) yield env.timeout(overhead_ns) + def _worker(self, env: simpy.Environment) -> Generator: + from kernbench.common.pe_commands import PeInternalTxn + from kernbench.components.builtin.pe_types import TileToken + + while True: + msg: Any = yield self._inbox.get() + if isinstance(msg, TileToken): + env.process(self._pipeline_process(env, msg)) + elif isinstance(msg, PeInternalTxn): + env.process(self._handle_with_hooks(env, msg)) + else: + env.process(self._forward_txn(env, msg)) + + def _pipeline_process(self, env: simpy.Environment, token: Any) -> Generator: + """Pipeline mode: pure SIMD compute, then self-route.""" + self._on_process_start(env, token) + + num_elements = token.params.get("num_elements", 0) + + if self._accel: + with self._accel.request() as req: + yield req + ns = self._compute_ns(num_elements) + yield env.timeout(ns) + else: + ns = self._compute_ns(num_elements) + yield env.timeout(ns) + + self._on_process_end(env, token) + + # Self-routing + next_stage = token.advance() + if next_stage is not None: + yield self.out_ports[next_stage.component].put(token) + else: + token.pipeline_ctx.complete_tile() + def handle_command(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator: + """Legacy PeInternalTxn handling.""" if self._accel: with self._accel.request() as req: yield req @@ -45,7 +102,6 @@ class PeMathComponent(PeEngineBase): pe_txn.done.succeed() def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator: - """Transaction forwarding with accel_slot acquisition.""" if self._accel: with self._accel.request() as req: yield req diff --git a/src/kernbench/components/builtin/pe_scheduler.py b/src/kernbench/components/builtin/pe_scheduler.py index daa7c3a..ea07ee2 100644 --- a/src/kernbench/components/builtin/pe_scheduler.py +++ b/src/kernbench/components/builtin/pe_scheduler.py @@ -1,3 +1,13 @@ +"""PE_SCHEDULER: plan generation + tile dispatch (ADR-0021 D2). + +Receives PeInternalTxn from PE_CPU, routes to engines: + - Simple commands (DmaReadCmd, GemmCmd, etc.) → direct dispatch to engine + - CompositeCmd → generate TilePlan, feed tiles via _feed_loop + +Composite pipeline uses token self-routing (ADR-0021 D4): + Scheduler only does initial dispatch + completion tracking. + Tiles chain through components based on their plan's stage sequence. +""" from __future__ import annotations from collections.abc import Generator @@ -14,29 +24,18 @@ if TYPE_CHECKING: class PeSchedulerComponent(ComponentBase): - """PE_SCHEDULER: sole dispatcher inside a PE (ADR-0014 D1). + """PE_SCHEDULER: sole dispatcher inside a PE (ADR-0014 D1, ADR-0021 D2). - Receives PeInternalTxn from PE_CPU, routes to the appropriate engine: - - DmaReadCmd / DmaWriteCmd → PE_DMA - - GemmCmd → PE_GEMM - - MathCmd → PE_MATH - - CompositeCmd → tiled pipeline (Stage 3: ADR-0014 D3.2) + Simple commands are forwarded to the appropriate engine. + CompositeCmd creates a TilePlan and feeds tiles into the pipeline. - Composite GEMM pipeline (32x64x32 tiles): - DMA_READ(b_tile_t) → COMPUTE(t) → DMA_WRITE(out_tile_t) - with overlap: READ(t+1) || COMPUTE(t) || WRITE(t-1) - - Applies scheduler overhead_ns before dispatching each command. - Non-PeInternalTxn messages are forwarded via inherited _forward_txn(). + Single _feed_loop process per scheduler ensures FIFO command ordering. """ - # Scheduler tile dimensions (ADR-0014 D3.2) TILE_M = 32 TILE_K = 64 TILE_N = 32 - # Command → engine suffix dispatch table. - # New engines: add a single entry here (e.g. ConvCmd: "pe_conv"). _CMD_DISPATCH: dict[type, str] = {} @classmethod @@ -44,7 +43,6 @@ class PeSchedulerComponent(ComponentBase): if cls._CMD_DISPATCH: return from kernbench.common.pe_commands import DmaReadCmd, DmaWriteCmd, GemmCmd, MathCmd - cls._CMD_DISPATCH = { DmaReadCmd: "pe_dma", DmaWriteCmd: "pe_dma", @@ -56,6 +54,13 @@ class PeSchedulerComponent(ComponentBase): super().__init__(node, ctx) self._pe_prefix = node.id.rsplit(".", 1)[0] self._ensure_dispatch_table() + self._pending_feeds: simpy.Store | None = None + self._pipeline_counter = 0 + + def start(self, env: simpy.Environment) -> None: + self._pending_feeds = simpy.Store(env) + super().start(env) + env.process(self._feed_loop(env)) def run(self, env: simpy.Environment, nbytes: int) -> Generator: overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0)) @@ -72,174 +77,103 @@ class PeSchedulerComponent(ComponentBase): yield from self._forward_txn(env, msg) def _dispatch(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator: - """Route a PeInternalTxn to the correct engine via dispatch table.""" - from kernbench.common.pe_commands import CompositeCmd + from kernbench.common.pe_commands import CompositeCmd, PeCpuOverheadCmd - # Scheduler overhead - yield from self.run(env, 0) + yield from self.run(env, 0) # scheduler overhead cmd = pe_txn.command - # Check dispatch table first + # Simple command dispatch engine_suffix = self._CMD_DISPATCH.get(type(cmd)) if engine_suffix is not None: yield self.out_ports[f"{self._pe_prefix}.{engine_suffix}"].put(pe_txn) return - # CompositeCmd: tiled pipeline (not a simple forward) + # CompositeCmd: generate plan and feed if isinstance(cmd, CompositeCmd): - yield from self._dispatch_composite(env, pe_txn) + yield from self._dispatch_composite(env, pe_txn, cmd) + return + + if isinstance(cmd, PeCpuOverheadCmd): + yield env.timeout(cmd.cycles) + pe_txn.done.succeed() return - # Unknown command — signal done immediately pe_txn.done.succeed() - def _dispatch_composite(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator: - """Composite tiled pipeline (ADR-0014 D3.2). + def _dispatch_composite( + self, env: simpy.Environment, pe_txn: Any, cmd: Any, + ) -> Generator: + """Generate plan and enqueue to feeder. Non-blocking (ADR-0021 D4).""" + from kernbench.components.builtin.pe_types import PipelineContext - GEMM: 3-stage pipeline with b-tile streaming from HBM. - MATH: sequential compute + DMA_WRITE (no tiling). + plan = self._generate_plan(cmd) + + self._pipeline_counter += 1 + ctx = PipelineContext( + id=f"p{self._pipeline_counter}", + total_tiles=len(plan.tiles), + done_event=pe_txn.done, + ) + + # Enqueue to feeder — scheduler worker returns immediately + assert self._pending_feeds is not None + yield self._pending_feeds.put((plan, ctx)) + + def _feed_loop(self, env: simpy.Environment) -> Generator: + """Single feeder process: FIFO command ordering (ADR-0021 D2). + + No tile feed interleaving between commands. + Queue full → only this process blocks. """ - from kernbench.common.pe_commands import CompositeCmd + from kernbench.components.builtin.pe_types import TileToken + + assert self._pending_feeds is not None + while True: + plan, ctx = yield self._pending_feeds.get() + for tile in plan.tiles: + first_stage = tile.stages[0] + token = TileToken( + tile_id=tile.tile_id, + pipeline_ctx=ctx, + plan=tile, + stage_idx=0, + params=first_stage.params, + ) + yield self.out_ports[first_stage.component].put(token) + + def _generate_plan(self, cmd: Any) -> Any: + """Generate a PipelinePlan from CompositeCmd.""" + from kernbench.components.builtin.tiling import ( + generate_gemm_plan, + generate_math_plan, + ) + + pp = self._pe_prefix + bpe = 2 # default bytes per element (f16) - cmd = pe_txn.command - assert isinstance(cmd, CompositeCmd) if cmd.op == "gemm" and cmd.b is not None: - yield from self._pipeline_gemm(env, pe_txn, cmd) + a = cmd.a + b = cmd.b + M, K = a.shape[-2], a.shape[-1] + N = b.shape[-1] + return generate_gemm_plan( + M=M, K=K, N=N, + tile_m=self.TILE_M, tile_k=self.TILE_K, tile_n=self.TILE_N, + bytes_per_element=bpe, + A_addr=a.addr, B_addr=b.addr, C_addr=cmd.out_addr, + pe_prefix=pp, + ) else: - yield from self._pipeline_math(env, pe_txn, cmd) - - def _pipeline_gemm(self, env: simpy.Environment, pe_txn: PeInternalTxn, cmd: Any) -> Generator: - """Tiled GEMM pipeline: stream b tiles from HBM, compute, write results. - - Tensor a is in TCM (loaded via tl.load). Tensor b is in HBM (via tl.ref). - Pipeline: DMA_READ(b_tile_t) -> COMPUTE(t) -> DMA_WRITE(out_tile_t) - Overlap: READ(t+1) || COMPUTE(t) || WRITE(t-1) - """ - from kernbench.common.pe_commands import ( - DmaReadCmd, - DmaWriteCmd, - GemmCmd, - PeInternalTxn as PeTxn, - TensorHandle, - ) - - pp = self._pe_prefix - a = cmd.a # already in TCM - b = cmd.b # HBM reference (via tl.ref) - - M, K_a = a.shape[-2], a.shape[-1] - K_b, N = b.shape[-2], b.shape[-1] - dtype = a.dtype - dtype_bytes = b.nbytes // (K_b * N) if (K_b * N) > 0 else 2 - - # Tile counts - n_tiles_k = max(1, (K_a + self.TILE_K - 1) // self.TILE_K) - n_tiles_n = max(1, (N + self.TILE_N - 1) // self.TILE_N) - n_tiles = n_tiles_k * n_tiles_n - - prev_compute_done = None - prev_write_done = None - total_dma_ns = 0.0 - total_compute_ns = 0.0 - - for tile_idx in range(n_tiles): - tk = tile_idx // n_tiles_n - tn = tile_idx % n_tiles_n - - k_start = tk * self.TILE_K - n_start = tn * self.TILE_N - tile_k = min(self.TILE_K, K_a - k_start) - tile_n = min(self.TILE_N, N - n_start) - tile_nbytes = tile_k * tile_n * dtype_bytes - - # --- Stage 1: DMA_READ b_tile from HBM --- - read_done = env.event() - b_tile_addr = b.addr + (k_start * N + n_start) * dtype_bytes - b_tile_handle = TensorHandle( - id=f"b_tile_{tile_idx}", addr=b_tile_addr, - shape=(tile_k, tile_n), dtype=dtype, nbytes=tile_nbytes, + # Math composite + a = cmd.a + M = a.shape[-2] if len(a.shape) >= 2 else a.shape[0] + N = a.shape[-1] if len(a.shape) >= 2 else 1 + return generate_math_plan( + M=M, N=N, + tile_m=self.TILE_M, tile_n=self.TILE_N, + bytes_per_element=bpe, + math_op=cmd.math_op or "identity", + src_addr=a.addr, dst_addr=cmd.out_addr, + pe_prefix=pp, ) - read_cmd = DmaReadCmd(handle=b_tile_handle, src_addr=b_tile_addr, nbytes=tile_nbytes) - read_txn = PeTxn(command=read_cmd, done=read_done, pe_prefix=pp) - t0 = env.now - yield self.out_ports[f"{pp}.pe_dma"].put(read_txn) - - # Wait for previous compute before starting this tile's compute - if prev_compute_done is not None: - yield prev_compute_done - - # Wait for this tile's DMA_READ - yield read_done - total_dma_ns += env.now - t0 - - # --- Stage 2: COMPUTE (GEMM) --- - compute_done = env.event() - out_handle = TensorHandle( - id=f"out_tile_{tile_idx}", addr=0, - shape=(M, tile_n), dtype=dtype, - nbytes=M * tile_n * dtype_bytes, - ) - compute_cmd = GemmCmd(a=a, b=b_tile_handle, out=out_handle, - m=M, k=tile_k, n=tile_n) - compute_txn = PeTxn(command=compute_cmd, done=compute_done, pe_prefix=pp) - t0 = env.now - yield self.out_ports[f"{pp}.pe_gemm"].put(compute_txn) - - # Wait for previous write (DMA_WRITE serialization) - if prev_write_done is not None: - yield prev_write_done - - # Wait for compute of THIS tile - yield compute_done - total_compute_ns += env.now - t0 - prev_compute_done = compute_done - - # --- Stage 3: DMA_WRITE out_tile to HBM --- - write_done = env.event() - out_tile_pa = cmd.out_addr + n_start * dtype_bytes - write_nbytes = M * tile_n * dtype_bytes - write_cmd = DmaWriteCmd(handle=out_handle, dst_addr=out_tile_pa, nbytes=write_nbytes) - write_txn = PeTxn(command=write_cmd, done=write_done, pe_prefix=pp) - t0 = env.now - yield self.out_ports[f"{pp}.pe_dma"].put(write_txn) - prev_write_done = write_done - - # Wait for final write - if prev_write_done is not None: - t0 = env.now - yield prev_write_done - total_dma_ns += env.now - t0 - - pe_txn.result_data["dma_ns"] = total_dma_ns - pe_txn.result_data["compute_ns"] = total_compute_ns - pe_txn.done.succeed() - - def _pipeline_math(self, env: simpy.Environment, pe_txn: PeInternalTxn, cmd: Any) -> Generator: - """Non-GEMM composite: sequential compute + DMA_WRITE (no tiling).""" - from kernbench.common.pe_commands import ( - DmaWriteCmd, - MathCmd, - PeInternalTxn as PeTxn, - ) - - pp = self._pe_prefix - - # Step 1: Compute (MATH) - compute_done = env.event() - compute_cmd = MathCmd( - op=cmd.math_op or "identity", - inputs=(cmd.a,), out=cmd.a, - ) - compute_txn = PeTxn(command=compute_cmd, done=compute_done, pe_prefix=pp) - yield self.out_ports[f"{pp}.pe_math"].put(compute_txn) - yield compute_done - - # Step 2: DMA_WRITE result to HBM - write_done = env.event() - write_cmd = DmaWriteCmd(handle=cmd.a, dst_addr=cmd.out_addr, nbytes=cmd.out_nbytes) - write_txn = PeTxn(command=write_cmd, done=write_done, pe_prefix=pp) - yield self.out_ports[f"{pp}.pe_dma"].put(write_txn) - yield write_done - - pe_txn.done.succeed() diff --git a/src/kernbench/components/builtin/pe_tcm.py b/src/kernbench/components/builtin/pe_tcm.py index 6458d56..dfe940e 100644 --- a/src/kernbench/components/builtin/pe_tcm.py +++ b/src/kernbench/components/builtin/pe_tcm.py @@ -1,7 +1,18 @@ +"""PE_TCM: tightly-coupled memory with BW-based access serialization (ADR-0021). + +Models scratchpad memory inside the PE. Handles both legacy Transaction forwarding +and TcmRequest from PE_FETCH_STORE for BW-serialized read/write access. + +Two channels (read/write) with independent serialization. +Ported from pe_accel TcmBlock timing model. +""" from __future__ import annotations from collections.abc import Generator -from typing import TYPE_CHECKING +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +import simpy from kernbench.components.base import ComponentBase @@ -10,16 +21,62 @@ if TYPE_CHECKING: from kernbench.topology.types import Node -class PeTcmComponent(ComponentBase): - """PE_TCM: tightly-coupled memory / local SRAM staging buffer. +@dataclass +class TcmRequest: + """Request to read from or write to TCM (used by PE_FETCH_STORE).""" - Terminal storage component for PE-internal dataflow (ADR-0014 D5). - Phase 0: applies overhead_ns and drain_ns at terminal. + direction: str # "read" or "write" + nbytes: int + done: simpy.Event + tag: str = "" + + +class PeTcmComponent(ComponentBase): + """PE_TCM: BW-serialized scratchpad memory (ADR-0021 D1). + + Dual-channel: read and write can proceed in parallel, + but concurrent reads serialize, concurrent writes serialize. + BW from topology attrs or pe_template links. """ def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None: super().__init__(node, ctx) + self._read_bw: float = float(node.attrs.get("read_bw_gbs", 512.0)) + self._write_bw: float = float(node.attrs.get("write_bw_gbs", 512.0)) + self._read_res: simpy.Resource | None = None + self._write_res: simpy.Resource | None = None - def run(self, env, nbytes: int) -> Generator: + def start(self, env: simpy.Environment) -> None: + self._read_res = simpy.Resource(env, capacity=1) + self._write_res = simpy.Resource(env, capacity=1) + super().start(env) + + def run(self, env: simpy.Environment, nbytes: int) -> Generator: overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0)) yield env.timeout(overhead_ns) + + def _worker(self, env: simpy.Environment) -> Generator: + """Dispatch TcmRequest (from fetch_store) and Transaction (fabric).""" + while True: + msg: Any = yield self._inbox.get() + if isinstance(msg, TcmRequest): + env.process(self._handle_tcm_request(env, msg)) + else: + env.process(self._forward_txn(env, msg)) + + def _handle_tcm_request(self, env: simpy.Environment, req: TcmRequest) -> Generator: + """BW-serialized access: acquire channel, apply delay, signal done.""" + if req.direction == "write": + res = self._write_res + bw = self._write_bw + else: + res = self._read_res + bw = self._read_bw + + assert res is not None + with res.request() as lock: + yield lock + if bw > 0 and req.nbytes > 0: + delay_ns = req.nbytes / bw + yield env.timeout(delay_ns) + req.done.succeed() diff --git a/src/kernbench/components/builtin/pe_types.py b/src/kernbench/components/builtin/pe_types.py new file mode 100644 index 0000000..77b92bb --- /dev/null +++ b/src/kernbench/components/builtin/pe_types.py @@ -0,0 +1,115 @@ +"""PE pipeline types for ADR-0021: TileToken, TilePlan, Stage, PipelineContext. + +These types are used by the PE_SCHEDULER and all PE engine components +for tile-based pipeline execution with self-routing. +""" +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import Enum, auto +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + import simpy + + +# ── Stage types ────────────────────────────────────────────────────── + + +class StageType(Enum): + DMA_READ = auto() + FETCH = auto() + GEMM = auto() + MATH = auto() + STORE = auto() + DMA_WRITE = auto() + + +@dataclass +class Stage: + """One stage in a tile's execution plan.""" + + stage_type: StageType + component: str # topology node ID (e.g. "sip0.cube0.pe0.pe_dma") + params: dict = field(default_factory=dict) + + +# ── Plan ───────────────────────────────────────────────────────────── + + +@dataclass +class TilePlan: + """Execution plan for a single tile (immutable stage sequence).""" + + tile_id: int + stages: tuple[Stage, ...] + + +@dataclass +class PipelinePlan: + """Full pipeline plan for one CompositeCmd.""" + + tiles: list[TilePlan] + # Metadata for metrics + m_tiles: int = 0 + k_tiles: int = 0 + n_tiles: int = 0 + + +# ── Pipeline Context ───────────────────────────────────────────────── + + +@dataclass +class PipelineContext: + """Tracks completion of a pipeline (exactly-once contract). + + Each tile's last stage calls complete_tile() exactly once. + When all tiles complete, done_event.succeed() is called. + """ + + id: str + total_tiles: int + completed_tiles: int = 0 + done_event: Any = None # simpy.Event + + def complete_tile(self) -> None: + self.completed_tiles += 1 + if self.completed_tiles == self.total_tiles: + if self.done_event is not None: + self.done_event.succeed() + + +# ── TileToken ──────────────────────────────────────────────────────── + + +@dataclass +class TileToken: + """Self-routing tile token passed between PE components (ADR-0021 D9). + + Single-owner: only one component holds this token at any time. + params is a cache of plan.stages[stage_idx].params (canonical source). + """ + + tile_id: int + pipeline_ctx: PipelineContext + plan: TilePlan + stage_idx: int + params: dict = field(default_factory=dict) + data_op: bool = True # op_log recording target (ADR-0020) + + @property + def current_stage(self) -> Stage: + return self.plan.stages[self.stage_idx] + + @property + def has_next_stage(self) -> bool: + return self.stage_idx + 1 < len(self.plan.stages) + + def advance(self) -> Stage | None: + """Advance to next stage. Returns next Stage or None if last.""" + self.stage_idx += 1 + if self.stage_idx < len(self.plan.stages): + next_stage = self.plan.stages[self.stage_idx] + self.params = next_stage.params + return next_stage + return None diff --git a/src/kernbench/components/builtin/tiling.py b/src/kernbench/components/builtin/tiling.py new file mode 100644 index 0000000..4ee63ad --- /dev/null +++ b/src/kernbench/components/builtin/tiling.py @@ -0,0 +1,176 @@ +"""Tile plan generators for PE pipeline (ADR-0021). + +Generates TilePlan with stage sequences for GEMM and Math operations. +Ported from pe_accel tiling.py with stage-based plan structure. +""" +from __future__ import annotations + +from math import ceil + +from kernbench.components.builtin.pe_types import ( + PipelinePlan, + Stage, + StageType, + TilePlan, +) + + +def generate_gemm_plan( + M: int, K: int, N: int, + tile_m: int, tile_k: int, tile_n: int, + bytes_per_element: int, + A_addr: int, B_addr: int, C_addr: int, + pe_prefix: str, +) -> PipelinePlan: + """Generate GEMM tile plan: M→N→K order. + + Each tile follows stage sequence: + DMA_READ(A) → DMA_READ(B) → FETCH → GEMM → STORE + On last K-tile per (m,n): → DMA_WRITE + + Args: + pe_prefix: e.g. "sip0.cube0.pe0" — used to build component IDs. + """ + M_tiles = max(1, ceil(M / tile_m)) + K_tiles = max(1, ceil(K / tile_k)) + N_tiles = max(1, ceil(N / tile_n)) + bpe = bytes_per_element + + dma_id = f"{pe_prefix}.pe_dma" + fetch_id = f"{pe_prefix}.pe_fetch_store" + gemm_id = f"{pe_prefix}.pe_gemm" + # math_id = f"{pe_prefix}.pe_math" # for K-accumulation if needed + + tiles: list[TilePlan] = [] + tile_id = 0 + + for m in range(M_tiles): + for n in range(N_tiles): + c_addr = C_addr + (m * tile_m * N + n * tile_n) * bpe + for k in range(K_tiles): + last_k = k == K_tiles - 1 + a_addr = A_addr + (m * tile_m * K + k * tile_k) * bpe + b_addr = B_addr + (k * tile_k * N + n * tile_n) * bpe + + a_bytes = tile_m * tile_k * bpe + b_bytes = tile_k * tile_n * bpe + out_bytes = tile_m * tile_n * bpe + + stages: list[Stage] = [] + + # DMA READ: load A and B tiles from HBM → TCM + stages.append(Stage( + stage_type=StageType.DMA_READ, + component=dma_id, + params={ + "src_addr": a_addr, "nbytes": a_bytes, + "operand": "A", "tile_m": tile_m, "tile_k": tile_k, + }, + )) + stages.append(Stage( + stage_type=StageType.DMA_READ, + component=dma_id, + params={ + "src_addr": b_addr, "nbytes": b_bytes, + "operand": "B", "tile_k": tile_k, "tile_n": tile_n, + }, + )) + + # FETCH: TCM → Register File + stages.append(Stage( + stage_type=StageType.FETCH, + component=fetch_id, + params={ + "direction": "read", + "nbytes": a_bytes + b_bytes, + }, + )) + + # GEMM: MAC compute + stages.append(Stage( + stage_type=StageType.GEMM, + component=gemm_id, + params={ + "m": tile_m, "k": tile_k, "n": tile_n, + "is_last_k": last_k, + }, + )) + + # STORE: Register File → TCM + stages.append(Stage( + stage_type=StageType.STORE, + component=fetch_id, + params={ + "direction": "write", + "nbytes": out_bytes, + }, + )) + + # DMA WRITE: TCM → HBM (only on last K-tile) + if last_k: + stages.append(Stage( + stage_type=StageType.DMA_WRITE, + component=dma_id, + params={ + "dst_addr": c_addr, "nbytes": out_bytes, + }, + )) + + tiles.append(TilePlan(tile_id=tile_id, stages=tuple(stages))) + tile_id += 1 + + return PipelinePlan( + tiles=tiles, m_tiles=M_tiles, k_tiles=K_tiles, n_tiles=N_tiles, + ) + + +def generate_math_plan( + M: int, N: int, + tile_m: int, tile_n: int, + bytes_per_element: int, + math_op: str, + src_addr: int, dst_addr: int, + pe_prefix: str, +) -> PipelinePlan: + """Generate element-wise math tile plan. + + Each tile: DMA_READ → FETCH → MATH → STORE → DMA_WRITE + """ + M_tiles = max(1, ceil(M / tile_m)) + N_tiles = max(1, ceil(N / tile_n)) + bpe = bytes_per_element + + dma_id = f"{pe_prefix}.pe_dma" + fetch_id = f"{pe_prefix}.pe_fetch_store" + math_id = f"{pe_prefix}.pe_math" + + tiles: list[TilePlan] = [] + tile_id = 0 + + for m in range(M_tiles): + for n in range(N_tiles): + offset = (m * tile_m * N + n * tile_n) * bpe + tile_bytes = tile_m * tile_n * bpe + + stages = [ + Stage(StageType.DMA_READ, dma_id, { + "src_addr": src_addr + offset, "nbytes": tile_bytes, + }), + Stage(StageType.FETCH, fetch_id, { + "direction": "read", "nbytes": tile_bytes, + }), + Stage(StageType.MATH, math_id, { + "op": math_op, "num_elements": tile_m * tile_n, + }), + Stage(StageType.STORE, fetch_id, { + "direction": "write", "nbytes": tile_bytes, + }), + Stage(StageType.DMA_WRITE, dma_id, { + "dst_addr": dst_addr + offset, "nbytes": tile_bytes, + }), + ] + + tiles.append(TilePlan(tile_id=tile_id, stages=tuple(stages))) + tile_id += 1 + + return PipelinePlan(tiles=tiles, m_tiles=M_tiles, n_tiles=N_tiles) diff --git a/src/kernbench/components/builtin_legacy/__init__.py b/src/kernbench/components/builtin_legacy/__init__.py new file mode 100644 index 0000000..9e2e26b --- /dev/null +++ b/src/kernbench/components/builtin_legacy/__init__.py @@ -0,0 +1,34 @@ +"""Concrete component implementations. + +Loaded from components.yaml via ComponentRegistry.load_components_yaml(). +Manual imports are no longer needed — add new impls to components.yaml. + +Classes are still importable from this package via lazy __getattr__. +""" + +from kernbench.components.base import ComponentRegistry + +ComponentRegistry.load_components_yaml() + +# Lazy re-export: allow `from kernbench.components.builtin import FooComponent` +# without eagerly importing every module. +_CLASS_MAP: dict[str, str] = {} # ClassName → "module.path:ClassName" + + +def _build_class_map() -> None: + if _CLASS_MAP: + return + for class_path in ComponentRegistry._lazy.values(): + module_path, class_name = class_path.rsplit(":", 1) + _CLASS_MAP[class_name] = class_path + + +def __getattr__(name: str): + _build_class_map() + class_path = _CLASS_MAP.get(name) + if class_path is None: + raise ImportError(f"cannot import name '{name}' from 'kernbench.components.builtin'") + import importlib + module_path, class_name = class_path.rsplit(":", 1) + mod = importlib.import_module(module_path) + return getattr(mod, class_name) diff --git a/src/kernbench/components/builtin_legacy/forwarding.py b/src/kernbench/components/builtin_legacy/forwarding.py new file mode 100644 index 0000000..1fa8eee --- /dev/null +++ b/src/kernbench/components/builtin_legacy/forwarding.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from collections.abc import Generator +from typing import TYPE_CHECKING + +import simpy + +from kernbench.components.base import ComponentBase + +if TYPE_CHECKING: + from kernbench.components.context import ComponentContext + from kernbench.topology.types import Node + + +class TransitComponent(ComponentBase): + """Transit component for NOC, UCIe, XBAR nodes. + + Applies overhead_ns processing delay (from node.attrs) then forwards the + Transaction to the next hop via inherited _forward_txn(). + """ + + def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None: + super().__init__(node, ctx) + + def run(self, env: simpy.Environment, nbytes: int) -> Generator: + overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0)) + yield env.timeout(overhead_ns) diff --git a/src/kernbench/components/builtin_legacy/hbm_ctrl.py b/src/kernbench/components/builtin_legacy/hbm_ctrl.py new file mode 100644 index 0000000..a75ec25 --- /dev/null +++ b/src/kernbench/components/builtin_legacy/hbm_ctrl.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +from collections.abc import Generator +from typing import TYPE_CHECKING, Any + +import simpy + +from kernbench.components.base import ComponentBase +from kernbench.sim_engine.transaction import Transaction + +if TYPE_CHECKING: + from kernbench.components.context import ComponentContext + from kernbench.topology.types import Node + + +class HbmCtrlComponent(ComponentBase): + """HBM controller: terminal component that models HBM access latency. + + Dual-channel model: separate read and write resources (each capacity=1) + allowing concurrent read/write like PE_DMA. Multiple reads or multiple + writes still serialize within their respective channel. + + On completion, creates a ResponseMsg and sends it back on the reverse path + so that response latency is modeled through the fabric. + """ + + def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None: + super().__init__(node, ctx) + self._read: simpy.Resource | None = None + self._write: simpy.Resource | None = None + + def start(self, env: simpy.Environment) -> None: + capacity = int(self.node.attrs.get("capacity", 1)) + self._read = simpy.Resource(env, capacity=capacity) + self._write = simpy.Resource(env, capacity=capacity) + super().start(env) + + def run(self, env: simpy.Environment, nbytes: int) -> Generator: + overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0)) + yield env.timeout(overhead_ns) + + def _select_channel(self, txn: Any) -> simpy.Resource: + """Select channel based on request type: write requests → write, else → read.""" + from kernbench.runtime_api.kernel import MemoryWriteMsg, PeDmaMsg + + assert self._read is not None and self._write is not None + req = txn.request + if isinstance(req, MemoryWriteMsg): + return self._write + if isinstance(req, PeDmaMsg) and req.is_write: + return self._write + return self._read + + def _worker(self, env: simpy.Environment) -> Generator: + """Dispatch each incoming txn to a concurrent process for channel-level parallelism.""" + while True: + txn: Any = yield self._inbox.get() + env.process(self._handle_txn(env, txn)) + + def _handle_txn(self, env: simpy.Environment, txn: Any) -> Generator: + """Acquire channel, run, apply drain, send response.""" + channel = self._select_channel(txn) + with channel.request() as req: + yield req + yield from self.run(env, txn.nbytes) + drain = getattr(txn, "drain_ns", 0.0) + if drain > 0: + yield env.timeout(drain) + yield from self._send_response(env, txn) + + def _send_response(self, env: simpy.Environment, txn: Any) -> Generator: + """Route completion based on path type. + + - PeDmaMsg: succeed done directly (probe). + - Bypass path (no m_cpu): MemoryWrite succeeds done; MemoryRead sends + data back on reverse path with original done event. + - M_CPU DMA path: send ResponseMsg for m_cpu/io_cpu aggregation. + """ + from kernbench.runtime_api.kernel import MemoryReadMsg, PeDmaMsg + + if isinstance(txn.request, PeDmaMsg): + reverse_path = list(reversed(txn.path)) + if len(reverse_path) >= 2: + resp_txn = Transaction( + request=txn.request, path=reverse_path, step=0, + nbytes=0, done=txn.done, is_response=True, + ) + yield self.out_ports[reverse_path[1]].put(resp_txn.advance()) + return + txn.done.succeed() + return + + # Bypass path: no m_cpu in the transaction path + is_bypass = not any("m_cpu" in n for n in txn.path) + if is_bypass: + if isinstance(txn.request, MemoryReadMsg): + # D2H: send data back on reverse path to pcie_ep + reverse_path = list(reversed(txn.path)) + if len(reverse_path) >= 2: + resp_txn = Transaction( + request=txn.request, path=reverse_path, step=0, + nbytes=txn.request.nbytes, done=txn.done, + ) + yield self.out_ports[reverse_path[1]].put(resp_txn.advance()) + return + # MemoryWrite bypass or short path: done + txn.done.succeed() + return + + # M_CPU DMA path: send ResponseMsg for aggregation + reverse_path = list(reversed(txn.path)) + if len(reverse_path) >= 2 and self.ctx: + from kernbench.runtime_api.kernel import ResponseMsg + + parts = self.node.id.split(".") + cube_id = int(parts[1].replace("cube", "")) + pe_id = 0 # single hbm_ctrl, PE info from request + resp_msg = ResponseMsg( + correlation_id=txn.request.correlation_id, + request_id=txn.request.request_id, + src_cube=cube_id, src_pe=pe_id, success=True, + ) + resp_txn = Transaction( + request=resp_msg, path=reverse_path, step=0, + nbytes=0, done=env.event(), is_response=True, + ) + yield self.out_ports[reverse_path[1]].put(resp_txn.advance()) + else: + txn.done.succeed() diff --git a/src/kernbench/components/builtin_legacy/io_cpu.py b/src/kernbench/components/builtin_legacy/io_cpu.py new file mode 100644 index 0000000..83f2b8a --- /dev/null +++ b/src/kernbench/components/builtin_legacy/io_cpu.py @@ -0,0 +1,157 @@ +from __future__ import annotations + +from collections.abc import Generator +from typing import TYPE_CHECKING, Any + +import simpy + +from kernbench.components.base import ComponentBase +from kernbench.sim_engine.transaction import Transaction + +if TYPE_CHECKING: + from kernbench.components.context import ComponentContext + from kernbench.topology.types import Node + + +class IoCpuComponent(ComponentBase): + """IO_CPU component: multi-cube fan-out with response aggregation. + + Forward path: + 1. Applies overhead_ns processing overhead. + 2. Resolves target cube(s) from request.target_cubes. + 3. Fans out sub-Transactions to each target cube's M_CPU. + + Response path: + Collects ResponseMsg from each M_CPU. When all cube responses are + received, succeeds the parent txn.done. + """ + + def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None: + super().__init__(node, ctx) + # Pending fan-out tracking: request_id → (expected, received, parent_txn_done) + self._pending: dict[str, tuple[int, int, simpy.Event]] = {} + + def run(self, env: simpy.Environment, nbytes: int) -> Generator: + overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0)) + yield env.timeout(overhead_ns) + + def _worker(self, env: simpy.Environment) -> Generator: + while True: + txn: Any = yield self._inbox.get() + if getattr(txn, "is_response", False): + self._collect_response(txn) + else: + yield from self.run(env, txn.nbytes) + env.process(self._dispatch_to_m_cpus(env, txn)) + + def _collect_response(self, resp_txn: Any) -> None: + """Receive a cube response and increment the aggregation counter.""" + key = resp_txn.request.request_id + if key not in self._pending: + return + expected, received, parent_done = self._pending[key] + received += 1 + if received >= expected: + parent_done.succeed() + del self._pending[key] + else: + self._pending[key] = (expected, received, parent_done) + + def _dispatch_to_m_cpus(self, env: simpy.Environment, txn: Any) -> Generator: + """Fan out sub-Transactions to target cube M_CPUs, wait for responses.""" + from kernbench.runtime_api.kernel import KernelLaunchMsg, MemoryReadMsg, MemoryWriteMsg + + request = txn.request + try: + cube_targets = self._resolve_cube_targets(request) + except Exception: + txn.done.succeed() + return + + if not cube_targets: + txn.done.succeed() + return + + # Setup aggregation + self._pending[request.request_id] = (len(cube_targets), 0, txn.done) + + # Fan out to each target cube's M_CPU + for sip, cube in cube_targets: + try: + m_cpu_id = self.ctx.resolver.find_m_cpu(sip, cube) + path = self.ctx.router.find_node_path(self.node.id, m_cpu_id) + except Exception: + continue + if len(path) < 2: + continue + sub_txn = Transaction( + request=request, path=path, step=0, + nbytes=txn.nbytes, done=env.event(), + result_data=txn.result_data, + ) + yield self.out_ports[path[1]].put(sub_txn.advance()) + + def _resolve_cube_targets(self, request: Any) -> list[tuple[int, int]]: + """Return list of (sip, cube) pairs to fan out to.""" + from kernbench.runtime_api.kernel import ( + KernelLaunchMsg, MemoryReadMsg, MemoryWriteMsg, MmuMapMsg, MmuUnmapMsg, + ) + + target_cubes = getattr(request, "target_cubes", "all") + + if isinstance(request, MemoryWriteMsg): + sip = request.dst_sip + if target_cubes == "all": + cube = self._cube_from_pa(request.dst_pa, fallback=request.dst_cube) + return [(sip, cube)] + return [(sip, c) for c in target_cubes] + + if isinstance(request, MemoryReadMsg): + sip = request.src_sip + if target_cubes == "all": + cube = self._cube_from_pa(request.src_pa, fallback=request.src_cube) + return [(sip, cube)] + return [(sip, c) for c in target_cubes] + + if isinstance(request, KernelLaunchMsg): + my_sip = self._my_sip() + if target_cubes != "all": + return [(my_sip, c) for c in target_cubes] + # "all": derive from tensor shards, filtered to this SIP + seen: set[tuple[int, int]] = set() + targets: list[tuple[int, int]] = [] + for arg in request.args: + if arg.arg_kind != "tensor": + continue + for shard in arg.shards: + if shard.sip != my_sip: + continue + key = (shard.sip, shard.cube) + if key not in seen: + seen.add(key) + targets.append(key) + return targets + + if isinstance(request, (MmuMapMsg, MmuUnmapMsg)): + my_sip = self._my_sip() + if target_cubes == "all": + n_cubes = 16 + if self.ctx and self.ctx.spec: + sips = self.ctx.spec.get("system", {}).get("sips", {}) + n_cubes = sips.get("cubes_per_sip", 16) + return [(my_sip, c) for c in range(n_cubes)] + return [(my_sip, c) for c in target_cubes] + + return [] + + def _cube_from_pa(self, pa_val: int, fallback: int) -> int: + """Extract cube_id from a physical address, with fallback.""" + from kernbench.policy.address.phyaddr import PhysAddr + try: + return PhysAddr.decode(pa_val).cube_id + except Exception: + return fallback + + def _my_sip(self) -> int: + """Extract this IO_CPU's SIP ID from its node ID (e.g. 'sip0.io0.io_cpu' → 0).""" + return int(self.node.id.split(".")[0].replace("sip", "")) diff --git a/src/kernbench/components/builtin_legacy/m_cpu.py b/src/kernbench/components/builtin_legacy/m_cpu.py new file mode 100644 index 0000000..4fb9a12 --- /dev/null +++ b/src/kernbench/components/builtin_legacy/m_cpu.py @@ -0,0 +1,327 @@ +from __future__ import annotations + +from collections.abc import Generator +from typing import TYPE_CHECKING, Any + +import simpy + +from kernbench.components.base import ComponentBase +from kernbench.sim_engine.transaction import Transaction + +if TYPE_CHECKING: + from kernbench.components.context import ComponentContext + from kernbench.topology.types import Node + + +class MCpuComponent(ComponentBase): + """M_CPU component: multi-PE DMA fan-out with response aggregation. + + Forward path (ADR-0015 D5): + When a forward Transaction arrives at m_cpu (terminal hop), M_CPU fans out + DMA sub-Transactions to target PEs' HBM slices. target_pe on the request + controls fan-out: int → single PE, "all" → all PEs in the cube. + + Response path: + ResponseMsg from each hbm_ctrl arrives back at m_cpu. Once all PE responses + are collected, m_cpu sends an aggregate ResponseMsg on the reverse command + path back to io_cpu. + + Transit: + When m_cpu is NOT the terminal hop (transit or response relay), the + Transaction is forwarded normally to the next hop. + """ + + def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None: + super().__init__(node, ctx) + # Pending fan-out tracking: request_id → (expected, received, all_done_event) + self._pending: dict[str, tuple[int, int, simpy.Event]] = {} + # Store parent txn for response sending: request_id → parent_txn + self._parent_txns: dict[str, Any] = {} + # DMA engine resources (ADR-0015 D5, ADR-0014 D4): capacity=1 each + self._dma_write: simpy.Resource | None = None + self._dma_read: simpy.Resource | None = None + + def start(self, env: simpy.Environment) -> None: + self._dma_write = simpy.Resource(env, capacity=1) + self._dma_read = simpy.Resource(env, capacity=1) + super().start(env) + + def run(self, env: simpy.Environment, nbytes: int) -> Generator: + overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0)) + yield env.timeout(overhead_ns) + + def _worker(self, env: simpy.Environment) -> Generator: + """Dispatch forward txns, collect response txns.""" + from kernbench.runtime_api.kernel import KernelLaunchMsg, MmuMapMsg, MmuUnmapMsg + + while True: + txn: Any = yield self._inbox.get() + if getattr(txn, "is_response", False): + self._collect_response(txn) + else: + yield from self.run(env, txn.nbytes) + next_hop = txn.next_hop + if next_hop: + yield self.out_ports[next_hop].put(txn.advance()) + elif self.ctx is not None and txn.request is not None: + if isinstance(txn.request, KernelLaunchMsg): + env.process(self._kernel_launch_fanout(env, txn)) + elif isinstance(txn.request, (MmuMapMsg, MmuUnmapMsg)): + env.process(self._mmu_msg_fanout(env, txn)) + else: + env.process(self._dma_fanout(env, txn)) + else: + txn.done.succeed() + + def _collect_response(self, resp_txn: Any) -> None: + """Receive a PE response and increment the aggregation counter.""" + key = resp_txn.request.request_id + if key not in self._pending: + return + expected, received, all_done = self._pending[key] + received += 1 + if received >= expected: + all_done.succeed() + del self._pending[key] + else: + self._pending[key] = (expected, received, all_done) + + def _dma_fanout(self, env: simpy.Environment, txn: Any) -> Generator: + """Fan out DMA sub-Transactions to target PE(s), wait for responses, + then send aggregate response on reverse command path. + + Each DMA transfer acquires the DMA resource (capacity=1 per ADR-0014 D4), + so multi-PE fan-out is serialized through the DMA engine. + """ + from kernbench.runtime_api.kernel import MemoryWriteMsg + + request = txn.request + target_pe = getattr(request, "target_pe", "all") + + dst_nodes = self._resolve_dma_destinations(request, target_pe) + if not dst_nodes: + txn.done.succeed() + return + + # Setup aggregation + all_done = env.event() + self._pending[request.request_id] = (len(dst_nodes), 0, all_done) + self._parent_txns[request.request_id] = txn + + # Select DMA resource based on operation type + dma_res = self._dma_write if isinstance(request, MemoryWriteMsg) else self._dma_read + + # Fan out DMA sub-txns (serialized through DMA resource) + max_drain_ns = 0.0 + for dst_node in dst_nodes: + try: + dma_path = self.ctx.router.find_mcpu_dma_path(self.node.id, dst_node) + except Exception: + continue + if len(dma_path) < 2: + continue + drain_ns = self.ctx.compute_drain_ns(dma_path, txn.nbytes) + max_drain_ns = max(max_drain_ns, drain_ns) + sub_txn = Transaction( + request=request, path=dma_path, step=0, + nbytes=txn.nbytes, done=env.event(), + drain_ns=drain_ns, + ) + with dma_res.request() as req: + yield req + yield self.out_ports[dma_path[1]].put(sub_txn.advance()) + + # Wait for all PE responses + yield all_done + txn.result_data["xfer_ns"] = max_drain_ns + del self._parent_txns[request.request_id] + + # Send aggregate response on reverse command path + reverse_path = list(reversed(txn.path)) + if len(reverse_path) >= 2: + from kernbench.runtime_api.kernel import ResponseMsg + + parts = self.node.id.split(".") + cube_id = int(parts[1].replace("cube", "")) + resp_msg = ResponseMsg( + correlation_id=request.correlation_id, + request_id=request.request_id, + src_cube=cube_id, src_pe=-1, success=True, + ) + resp_txn = Transaction( + request=resp_msg, path=reverse_path, step=0, + nbytes=0, done=env.event(), is_response=True, + ) + yield self.out_ports[reverse_path[1]].put(resp_txn.advance()) + else: + txn.done.succeed() + + def _kernel_launch_fanout(self, env: simpy.Environment, txn: Any) -> Generator: + """Fan out KernelLaunchMsg to target PE_CPU(s) via NOC (ADR-0009 D3). + + Routes through find_node_path (M_CPU → NOC → PE_CPU command edges). + PE_CPU sends ResponseMsg back via NOC → M_CPU on completion. + Then sends aggregate ResponseMsg back to IO_CPU on the reverse path. + """ + request = txn.request + target_pe = getattr(request, "target_pe", "all") + cube_prefix = self.node.id.rsplit(".", 1)[0] # e.g. "sip0.cube0" + pe_ids = self._resolve_pe_ids(target_pe) + + if not pe_ids: + txn.done.succeed() + return + + # Fan out to each PE_CPU, using response-based aggregation + sub_txns: list[Transaction] = [] + n_dispatched = 0 + for pe_id in pe_ids: + pe_cpu_id = f"{cube_prefix}.pe{pe_id}.pe_cpu" + try: + path = self.ctx.router.find_node_path(self.node.id, pe_cpu_id) + except Exception: + continue + if len(path) < 2: + continue + sub_txn = Transaction( + request=request, path=path, step=0, + nbytes=0, done=env.event(), + ) + yield self.out_ports[path[1]].put(sub_txn.advance()) + sub_txns.append(sub_txn) + n_dispatched += 1 + + if n_dispatched == 0: + txn.done.succeed() + return + + # Setup response aggregation (PE_CPU ResponseMsg arrives via _collect_response) + all_done = env.event() + self._pending[request.request_id] = (n_dispatched, 0, all_done) + self._parent_txns[request.request_id] = txn + + # Wait for all PE_CPU responses via NOC + yield all_done + del self._parent_txns[request.request_id] + + # Aggregate PE-internal metrics (max across PEs) + pe_exec_values = [st.result_data.get("pe_exec_ns", 0.0) for st in sub_txns] + if pe_exec_values: + txn.result_data["pe_exec_ns"] = max(pe_exec_values) + dma_values = [st.result_data.get("dma_ns", 0.0) for st in sub_txns] + if dma_values: + txn.result_data["dma_ns"] = max(dma_values) + compute_values = [st.result_data.get("compute_ns", 0.0) for st in sub_txns] + if compute_values: + txn.result_data["compute_ns"] = max(compute_values) + + # Send aggregate response on reverse command path back to IO_CPU + reverse_path = list(reversed(txn.path)) + if len(reverse_path) >= 2: + from kernbench.runtime_api.kernel import ResponseMsg + + parts = self.node.id.split(".") + cube_id = int(parts[1].replace("cube", "")) + resp_msg = ResponseMsg( + correlation_id=request.correlation_id, + request_id=request.request_id, + src_cube=cube_id, src_pe=-1, success=True, + ) + resp_txn = Transaction( + request=resp_msg, path=reverse_path, step=0, + nbytes=0, done=env.event(), is_response=True, + ) + yield self.out_ports[reverse_path[1]].put(resp_txn.advance()) + else: + txn.done.succeed() + + def _resolve_dma_destinations(self, request: Any, target_pe: int | str) -> list[str]: + """Return list of HBM destination node_ids for DMA fan-out. + + With single hbm_ctrl per cube (ADR-0019), always returns one node. + PA-based resolution still used for cross-cube routing. + """ + cube_prefix = self.node.id.rsplit(".", 1)[0] # e.g. "sip0.cube0" + + # PA-based resolution: extract actual target from physical address + pa_val = getattr(request, "dst_pa", None) or getattr(request, "src_pa", None) + if pa_val is not None: + from kernbench.policy.address.phyaddr import PhysAddr + try: + pa = PhysAddr.decode(pa_val) + return [self.ctx.resolver.resolve(pa)] + except Exception: + pass + + # Default: single hbm_ctrl in local cube + return [f"{cube_prefix}.hbm_ctrl"] + + def _mmu_msg_fanout(self, env: simpy.Environment, txn: Any) -> Generator: + """Fan out MmuMapMsg/MmuUnmapMsg to target PE_MMU(s) via NOC. + + Routes through find_node_path (M_CPU → NOC → PE_MMU command edges). + PE_MMU is a terminal node — completes the transaction directly. + """ + request = txn.request + target_pe = getattr(request, "target_pe", "all") + cube_prefix = self.node.id.rsplit(".", 1)[0] # e.g. "sip0.cube0" + pe_ids = self._resolve_pe_ids(target_pe) + + if not pe_ids: + txn.done.succeed() + return + + # Fan out to each PE_MMU + sub_dones: list[simpy.Event] = [] + for pe_id in pe_ids: + pe_mmu_id = f"{cube_prefix}.pe{pe_id}.pe_mmu" + try: + path = self.ctx.router.find_node_path(self.node.id, pe_mmu_id) + except Exception: + continue + if len(path) < 2: + continue + sub_done = env.event() + sub_txn = Transaction( + request=request, path=path, step=0, + nbytes=0, done=sub_done, + ) + yield self.out_ports[path[1]].put(sub_txn.advance()) + sub_dones.append(sub_done) + + # Wait for all PE_MMUs to complete + for sd in sub_dones: + yield sd + + # Send aggregate response on reverse path + reverse_path = list(reversed(txn.path)) + if len(reverse_path) >= 2: + from kernbench.runtime_api.kernel import ResponseMsg + + parts = self.node.id.split(".") + cube_id = int(parts[1].replace("cube", "")) + resp_msg = ResponseMsg( + correlation_id=request.correlation_id, + request_id=request.request_id, + src_cube=cube_id, src_pe=-1, success=True, + ) + resp_txn = Transaction( + request=resp_msg, path=reverse_path, step=0, + nbytes=0, done=env.event(), is_response=True, + ) + yield self.out_ports[reverse_path[1]].put(resp_txn.advance()) + else: + txn.done.succeed() + + def _resolve_pe_ids(self, target_pe: int | tuple | str) -> list[int]: + """Return list of PE IDs to fan out to (used by kernel launch fan-out).""" + if isinstance(target_pe, int): + return [target_pe] + if isinstance(target_pe, tuple): + return list(target_pe) + # "all": all PEs in local cube + n_slices = 8 + if self.ctx and self.ctx.spec: + mm = self.ctx.spec.get("cube", {}).get("memory_map", {}) + n_slices = mm.get("hbm_slices_per_cube", 8) + return list(range(n_slices)) diff --git a/src/kernbench/components/builtin_legacy/pcie_ep.py b/src/kernbench/components/builtin_legacy/pcie_ep.py new file mode 100644 index 0000000..53faac0 --- /dev/null +++ b/src/kernbench/components/builtin_legacy/pcie_ep.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from collections.abc import Generator +from typing import TYPE_CHECKING + +import simpy + +from kernbench.components.base import ComponentBase + +if TYPE_CHECKING: + from kernbench.components.context import ComponentContext + from kernbench.topology.types import Node + + +class PcieEpComponent(ComponentBase): + """PCIe endpoint: protocol processing overhead before forwarding. + + Applies overhead_ns (from node.attrs) for PCIe protocol handling, + then forwards via inherited _forward_txn(). + """ + + def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None: + super().__init__(node, ctx) + + def run(self, env: simpy.Environment, nbytes: int) -> Generator: + overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0)) + yield env.timeout(overhead_ns) diff --git a/src/kernbench/components/builtin_legacy/pe_cpu.py b/src/kernbench/components/builtin_legacy/pe_cpu.py new file mode 100644 index 0000000..4947b9d --- /dev/null +++ b/src/kernbench/components/builtin_legacy/pe_cpu.py @@ -0,0 +1,214 @@ +from __future__ import annotations + +from collections.abc import Generator +from typing import TYPE_CHECKING, Any + +import simpy + +from kernbench.components.base import ComponentBase +from kernbench.sim_engine.transaction import Transaction + +if TYPE_CHECKING: + from kernbench.components.context import ComponentContext + from kernbench.topology.types import Node + + +class PeCpuComponent(ComponentBase): + """PE_CPU: kernel execution controller (Stage 2). + + Two-phase kernel execution (ADR-0014 D1): + Phase 1 (compile): look up kernel from registry, run it with TLContext + to generate a PeCommand list. + Phase 2 (replay): iterate commands, dispatch to PE_SCHEDULER via + PeInternalTxn, wait for blocking commands. + + Non-kernel Transactions are forwarded normally. + """ + + def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None: + super().__init__(node, ctx) + self._pe_prefix = node.id.rsplit(".", 1)[0] # e.g. "sip0.cube0.pe0" + try: + self._pe_idx = int(self._pe_prefix.rsplit("pe", 1)[1]) + except (IndexError, ValueError): + self._pe_idx = 0 + # Extract sip/cube index for multi-SIP/cube shard matching + parts = node.id.split(".") + try: + self._sip_idx = int(parts[0].replace("sip", "")) + except (IndexError, ValueError): + self._sip_idx = 0 + try: + self._cube_idx = int(parts[1].replace("cube", "")) + except (IndexError, ValueError): + self._cube_idx = 0 + + def _find_shard(self, shards: tuple) -> Any: + """Find shard matching this PE's (sip, cube, pe). Fallback to positional index.""" + for s in shards: + if s.sip == self._sip_idx and s.cube == self._cube_idx and s.pe == self._pe_idx: + return s + return shards[min(self._pe_idx, len(shards) - 1)] + + def run(self, env: simpy.Environment, nbytes: int) -> Generator: + overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0)) + yield env.timeout(overhead_ns) + + def _worker(self, env: simpy.Environment) -> Generator: + while True: + txn: Any = yield self._inbox.get() + from kernbench.runtime_api.kernel import KernelLaunchMsg + + if hasattr(txn, "request") and isinstance(txn.request, KernelLaunchMsg): + yield from self._execute_kernel(env, txn) + else: + yield from self._forward_txn(env, txn) + + def _execute_kernel(self, env: simpy.Environment, txn: Any) -> Generator: + """Execute kernel: greenlet mode (ADR-0020) or legacy Phase 0 + replay.""" + from kernbench.triton_emu.registry import get_kernel + + request = txn.request + yield from self.run(env, 0) + + kernel_fn = get_kernel(request.kernel_ref.name) + num_programs = self._derive_num_programs(request) + kernel_args = self._unpack_kernel_args(request) + + pe_exec_start = env.now + scheduler_id = f"{self._pe_prefix}.pe_scheduler" + + # Choose execution mode: greenlet (ADR-0020) or legacy command-list + store = getattr(self.ctx, "memory_store", None) if self.ctx else None + + if store is not None: + composite_results = yield from self._execute_greenlet( + env, kernel_fn, kernel_args, num_programs, scheduler_id, store, + ) + else: + composite_results = yield from self._execute_legacy( + env, kernel_fn, kernel_args, num_programs, scheduler_id, + ) + + # Record PE-internal execution time + txn.result_data["pe_exec_ns"] = env.now - pe_exec_start + total_dma_ns = 0.0 + total_compute_ns = 0.0 + for rd in composite_results: + total_dma_ns += rd.get("dma_ns", 0.0) + total_compute_ns += rd.get("compute_ns", 0.0) + txn.result_data["dma_ns"] = total_dma_ns + txn.result_data["compute_ns"] = total_compute_ns + + # Send ResponseMsg on reverse path + yield from self._send_response(env, txn, request) + + def _derive_num_programs(self, request: Any) -> int: + num_programs = 1 + for arg in request.args: + if arg.arg_kind == "tensor": + cube_pe_count = sum( + 1 for s in arg.shards + if s.sip == self._sip_idx and s.cube == self._cube_idx + ) + if cube_pe_count > num_programs: + num_programs = cube_pe_count + return num_programs + + def _unpack_kernel_args(self, request: Any) -> list: + kernel_args: list = [] + for arg in request.args: + if arg.arg_kind == "tensor": + if arg.va_base: + kernel_args.append(arg.va_base) + else: + shard = self._find_shard(arg.shards) + kernel_args.append(shard.pa) + elif arg.arg_kind == "scalar": + kernel_args.append(arg.value) + return kernel_args + + def _execute_greenlet( + self, env, kernel_fn, kernel_args, num_programs, scheduler_id, store, + ) -> Generator: + """Greenlet-based execution (ADR-0020 D3): kernel ↔ SimPy interleaved.""" + from kernbench.triton_emu.kernel_runner import KernelRunner + + runner = KernelRunner( + pe_prefix=self._pe_prefix, + pe_idx=self._pe_idx, + sip_idx=self._sip_idx, + cube_idx=self._cube_idx, + scheduler_id=scheduler_id, + out_ports=self.out_ports, + store=store, + ) + yield from runner.run(env, kernel_fn, kernel_args, num_programs) + return getattr(runner, "_composite_results", []) + + def _execute_legacy( + self, env, kernel_fn, kernel_args, num_programs, scheduler_id, + ) -> Generator: + """Legacy Phase 0 + replay: generate command list, then dispatch.""" + from kernbench.common.pe_commands import ( + CompositeCmd, PeCpuOverheadCmd, PeInternalTxn, WaitCmd, + ) + from kernbench.triton_emu.tl_context import TLContext, run_kernel + + tl = TLContext(pe_id=self._pe_idx, num_programs=num_programs, dispatch_cycles=0) + run_kernel(kernel_fn, tl, *kernel_args) + commands = tl.commands + + pending: dict[str, simpy.Event] = {} + composite_results: list[dict] = [] + + for cmd in commands: + if isinstance(cmd, PeCpuOverheadCmd): + yield env.timeout(cmd.cycles) + elif isinstance(cmd, WaitCmd): + if cmd.handle is not None: + evt = pending.pop(cmd.handle.id, None) + if evt: + yield evt + else: + for evt in pending.values(): + yield evt + pending.clear() + elif isinstance(cmd, CompositeCmd): + done_evt = env.event() + pe_txn = PeInternalTxn( + command=cmd, done=done_evt, pe_prefix=self._pe_prefix, + ) + composite_results.append(pe_txn.result_data) + yield self.out_ports[scheduler_id].put(pe_txn) + pending[cmd.completion.id] = done_evt + else: + done_evt = env.event() + pe_txn = PeInternalTxn( + command=cmd, done=done_evt, pe_prefix=self._pe_prefix, + ) + yield self.out_ports[scheduler_id].put(pe_txn) + yield done_evt + + for evt in pending.values(): + yield evt + return composite_results + + def _send_response(self, env, txn, request) -> Generator: + reverse_path = list(reversed(txn.path)) + if len(reverse_path) >= 2: + from kernbench.runtime_api.kernel import ResponseMsg + + resp_msg = ResponseMsg( + correlation_id=request.correlation_id, + request_id=request.request_id, + src_cube=self._cube_idx, src_pe=self._pe_idx, + success=True, + ) + resp_txn = Transaction( + request=resp_msg, path=reverse_path, step=0, + nbytes=0, done=env.event(), is_response=True, + ) + yield self.out_ports[reverse_path[1]].put(resp_txn.advance()) + else: + txn.done.succeed() diff --git a/src/kernbench/components/builtin_legacy/pe_dma.py b/src/kernbench/components/builtin_legacy/pe_dma.py new file mode 100644 index 0000000..c8ee823 --- /dev/null +++ b/src/kernbench/components/builtin_legacy/pe_dma.py @@ -0,0 +1,138 @@ +from __future__ import annotations + +from collections.abc import Generator +from typing import TYPE_CHECKING, Any + +import simpy + +from kernbench.components.base import PeEngineBase +from kernbench.sim_engine.transaction import Transaction + +if TYPE_CHECKING: + from kernbench.common.pe_commands import PeInternalTxn + from kernbench.components.context import ComponentContext + from kernbench.topology.types import Node + + +class PeDmaComponent(PeEngineBase): + """PE_DMA: dual-channel DMA engine with READ and WRITE resources. + + Each channel has capacity=1 (ADR-0014 D4): + - DMA_READ and DMA_WRITE may execute concurrently. + - Multiple READs cannot overlap; multiple WRITEs cannot overlap. + + Handles two message types: + - Transaction: external fabric messages (PeDmaMsg probes, M_CPU DMA) + - PeInternalTxn: PE-internal commands from PE_SCHEDULER + (DmaReadCmd → HBM read, DmaWriteCmd → HBM write) + """ + + def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None: + super().__init__(node, ctx) + self._dma_read: simpy.Resource | None = None + self._dma_write: simpy.Resource | None = None + self._mmu = None # PeMMU instance, set by engine wiring + + def init_resources(self, env: simpy.Environment) -> None: + self._dma_read = simpy.Resource(env, capacity=1) + self._dma_write = simpy.Resource(env, capacity=1) + + def run(self, env: simpy.Environment, nbytes: int) -> Generator: + yield env.timeout(0) + + def handle_command(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator: + """Handle PE-internal DMA command: resolve PA → HBM path → transfer.""" + from kernbench.common.pe_commands import DmaReadCmd, DmaWriteCmd + from kernbench.policy.address.phyaddr import PhysAddr + from kernbench.runtime_api.kernel import PeDmaMsg + + cmd = pe_txn.command + assert self._dma_read is not None and self._dma_write is not None + + # Determine direction and target address (VA → PA via MMU) + if isinstance(cmd, DmaReadCmd): + dma_res = self._dma_read + raw_addr = cmd.src_addr + is_write = False + elif isinstance(cmd, DmaWriteCmd): + dma_res = self._dma_write + raw_addr = cmd.dst_addr + is_write = True + else: + pe_txn.done.succeed() + return + + # Translate VA → PA via MMU (if available), then resolve HBM node + # If MMU has no mapping for this address (PageFault), treat as PA directly + # (backward-compatible with PA-only mode) + if self._mmu is not None: + from kernbench.policy.address.pe_mmu import PageFault + try: + target_pa = self._mmu.translate(raw_addr) + if self._mmu.overhead_ns > 0: + yield env.timeout(self._mmu.overhead_ns) + except PageFault: + target_pa = raw_addr + else: + target_pa = raw_addr # fallback: treat as PA directly + pa = PhysAddr.decode(target_pa) + dst_node = self.ctx.resolver.resolve(pa) + path = self.ctx.router.find_path(self._pe_prefix, dst_node) + drain_ns = self.ctx.compute_drain_ns(path, cmd.nbytes) + + # Acquire DMA channel (command issue serialization) + with dma_res.request() as req: + yield req + # Create sub-Transaction with PeDmaMsg (HbmCtrl handles it directly) + sub_done = env.event() + sub_request = PeDmaMsg( + correlation_id="pe_internal", + request_id=f"dma_{id(pe_txn)}", + src_sip=0, src_cube=0, src_pe=0, + dst_pa=target_pa, nbytes=cmd.nbytes, + is_write=is_write, + ) + sub_txn = Transaction( + request=sub_request, path=path, step=0, + nbytes=cmd.nbytes, done=sub_done, drain_ns=drain_ns, + ) + # Send to next hop (path[0] is pe_dma itself, path[1] is router) + if len(path) > 1: + yield self.out_ports[path[1]].put(sub_txn.advance()) + # DMA channel released after issue + + # Wait for HBM transfer completion + yield sub_done + pe_txn.done.succeed() + + def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator: + """Handle external Transaction (PeDmaMsg probe, M_CPU DMA) with channel acquisition.""" + # Response transactions bypass DMA channel (no outbound resource needed) + if getattr(txn, "is_response", False): + next_hop = txn.next_hop + if next_hop: + yield self.out_ports[next_hop].put(txn.advance()) + else: + txn.done.succeed() + return + + dma_res = self._select_channel(txn) + with dma_res.request() as req: + yield req + next_hop = txn.next_hop + if next_hop: + yield self.out_ports[next_hop].put(txn.advance()) + else: + drain = getattr(txn, "drain_ns", 0.0) + if drain > 0: + yield env.timeout(drain) + txn.done.succeed() + + def _select_channel(self, txn: Any) -> simpy.Resource: + """Select DMA channel based on request type.""" + from kernbench.runtime_api.kernel import MemoryWriteMsg + + assert self._dma_read is not None and self._dma_write is not None + if isinstance(txn.request, MemoryWriteMsg): + return self._dma_write + return self._dma_read diff --git a/src/kernbench/components/builtin_legacy/pe_gemm.py b/src/kernbench/components/builtin_legacy/pe_gemm.py new file mode 100644 index 0000000..3fc74e3 --- /dev/null +++ b/src/kernbench/components/builtin_legacy/pe_gemm.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +from collections.abc import Generator +from typing import TYPE_CHECKING, Any + +import simpy + +from kernbench.components.base import PeEngineBase + +if TYPE_CHECKING: + from kernbench.common.pe_commands import PeInternalTxn + from kernbench.components.context import ComponentContext + from kernbench.topology.types import Node + + +# dtype → bit width (for TFLOPS scaling) +_DTYPE_BITS: dict[str, int] = { + "f16": 16, "fp16": 16, "float16": 16, "bf16": 16, + "f32": 32, "fp32": 32, "float32": 32, + "i8": 8, "int8": 8, + "i16": 16, "int16": 16, + "i32": 32, "int32": 32, +} + + +class PeGemmComponent(PeEngineBase): + """PE_GEMM: matrix multiplication engine sharing accel_slot (ADR-0014 D4). + + Uses a shared compute resource (PE_ACCEL capacity=1) that is mutually + exclusive with PE_MATH within the same PE. + + Compute latency model: + FLOPs = 2 * M * K * N + effective_tflops = peak_tflops_f16 * (16 / dtype_bits) + compute_ns = FLOPs / (effective_tflops * 1e3) + """ + + def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None: + super().__init__(node, ctx) + self._accel: simpy.Resource | None = None + self._peak_tflops_f16: float = float(node.attrs.get("peak_tflops_f16", 0.0)) + + def init_resources(self, env: simpy.Environment) -> None: + resource_name = self.node.attrs.get("shared_resource") + if resource_name and self.ctx: + self._accel = self.ctx.get_shared_resource( + env, f"{self._pe_prefix}.{resource_name}" + ) + + def _compute_ns(self, m: int, k: int, n: int, dtype: str) -> float: + """Compute GEMM latency in nanoseconds.""" + if self._peak_tflops_f16 <= 0: + return float(self.node.attrs.get("overhead_ns", 0.0)) + dtype_bits = _DTYPE_BITS.get(dtype, 16) + effective_tflops = self._peak_tflops_f16 * (16.0 / dtype_bits) + flops = 2.0 * m * k * n + return flops / (effective_tflops * 1e3) + + def run(self, env: simpy.Environment, nbytes: int) -> Generator: + overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0)) + yield env.timeout(overhead_ns) + + def handle_command(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator: + from kernbench.common.pe_commands import GemmCmd + + cmd = pe_txn.command + if self._accel: + with self._accel.request() as req: + yield req + if isinstance(cmd, GemmCmd): + ns = self._compute_ns(cmd.m, cmd.k, cmd.n, cmd.a.dtype) + yield env.timeout(ns) + else: + yield from self.run(env, 0) + else: + if isinstance(cmd, GemmCmd): + ns = self._compute_ns(cmd.m, cmd.k, cmd.n, cmd.a.dtype) + yield env.timeout(ns) + else: + yield from self.run(env, 0) + pe_txn.done.succeed() + + def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator: + """Transaction forwarding with accel_slot acquisition.""" + if self._accel: + with self._accel.request() as req: + yield req + yield from super()._forward_txn(env, txn) + else: + yield from super()._forward_txn(env, txn) diff --git a/src/kernbench/components/builtin_legacy/pe_math.py b/src/kernbench/components/builtin_legacy/pe_math.py new file mode 100644 index 0000000..c3c3a83 --- /dev/null +++ b/src/kernbench/components/builtin_legacy/pe_math.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from collections.abc import Generator +from typing import TYPE_CHECKING, Any + +import simpy + +from kernbench.components.base import PeEngineBase + +if TYPE_CHECKING: + from kernbench.common.pe_commands import PeInternalTxn + from kernbench.components.context import ComponentContext + from kernbench.topology.types import Node + + +class PeMathComponent(PeEngineBase): + """PE_MATH: element-wise computation engine sharing accel_slot (ADR-0014 D4). + + Uses a shared compute resource (PE_ACCEL capacity=1) that is mutually + exclusive with PE_GEMM within the same PE. + """ + + def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None: + super().__init__(node, ctx) + self._accel: simpy.Resource | None = None + + def init_resources(self, env: simpy.Environment) -> None: + resource_name = self.node.attrs.get("shared_resource") + if resource_name and self.ctx: + self._accel = self.ctx.get_shared_resource( + env, f"{self._pe_prefix}.{resource_name}" + ) + + def run(self, env: simpy.Environment, nbytes: int) -> Generator: + overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0)) + yield env.timeout(overhead_ns) + + def handle_command(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator: + if self._accel: + with self._accel.request() as req: + yield req + yield from self.run(env, 0) + else: + yield from self.run(env, 0) + pe_txn.done.succeed() + + def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator: + """Transaction forwarding with accel_slot acquisition.""" + if self._accel: + with self._accel.request() as req: + yield req + yield from super()._forward_txn(env, txn) + else: + yield from super()._forward_txn(env, txn) diff --git a/src/kernbench/components/builtin_legacy/pe_mmu.py b/src/kernbench/components/builtin_legacy/pe_mmu.py new file mode 100644 index 0000000..3481cc4 --- /dev/null +++ b/src/kernbench/components/builtin_legacy/pe_mmu.py @@ -0,0 +1,66 @@ +"""PE_MMU component: address translation unit. + +Component role: receives MmuMapMsg/MmuUnmapMsg via inbox (independent of PE_CPU). +Utility role: PE_DMA/PE_GEMM call mmu.translate() directly (no SimPy overhead). +""" +from __future__ import annotations + +from collections.abc import Generator +from typing import TYPE_CHECKING, Any + +import simpy + +from kernbench.components.base import ComponentBase, ComponentRegistry +from kernbench.policy.address.pe_mmu import PeMMU + +if TYPE_CHECKING: + from kernbench.components.context import ComponentContext + from kernbench.topology.types import Node + + +class PeMmuComponent(ComponentBase): + """PE_MMU: per-PE virtual-to-physical address translation. + + Receives MmuMapMsg/MmuUnmapMsg via inbox and updates the internal + page table. PE_DMA and PE_GEMM access the underlying PeMMU object + via the ``mmu`` property for synchronous VA→PA translation. + """ + + def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None: + super().__init__(node, ctx) + page_size = int(node.attrs.get("page_size", 2 * 1024 * 1024)) + overhead_ns = float(node.attrs.get("tlb_overhead_ns", 0.0)) + self._mmu = PeMMU(page_size=page_size, overhead_ns=overhead_ns) + + @property + def mmu(self) -> PeMMU: + """The underlying PeMMU utility object for direct translate() calls.""" + return self._mmu + + def run(self, env: simpy.Environment, nbytes: int) -> Generator: + yield env.timeout(0) + + def _worker(self, env: simpy.Environment) -> Generator: + """Process MmuMapMsg/MmuUnmapMsg from inbox.""" + from kernbench.runtime_api.kernel import MmuMapMsg, MmuUnmapMsg + + while True: + txn: Any = yield self._inbox.get() + + if hasattr(txn, "request"): + request = txn.request + if isinstance(request, MmuMapMsg): + for entry in request.entries: + self._mmu.map( + va=entry["va"], pa=entry["pa"], size=entry["size"], + ) + txn.done.succeed() + elif isinstance(request, MmuUnmapMsg): + for entry in request.entries: + self._mmu.unmap(va=entry["va"], size=entry["size"]) + txn.done.succeed() + else: + # Forward non-MMU transactions normally + yield from self._forward_txn(env, txn) + else: + yield from self._forward_txn(env, txn) diff --git a/src/kernbench/components/builtin_legacy/pe_scheduler.py b/src/kernbench/components/builtin_legacy/pe_scheduler.py new file mode 100644 index 0000000..daa7c3a --- /dev/null +++ b/src/kernbench/components/builtin_legacy/pe_scheduler.py @@ -0,0 +1,245 @@ +from __future__ import annotations + +from collections.abc import Generator +from typing import TYPE_CHECKING, Any + +import simpy + +from kernbench.components.base import ComponentBase + +if TYPE_CHECKING: + from kernbench.common.pe_commands import PeInternalTxn + from kernbench.components.context import ComponentContext + from kernbench.topology.types import Node + + +class PeSchedulerComponent(ComponentBase): + """PE_SCHEDULER: sole dispatcher inside a PE (ADR-0014 D1). + + Receives PeInternalTxn from PE_CPU, routes to the appropriate engine: + - DmaReadCmd / DmaWriteCmd → PE_DMA + - GemmCmd → PE_GEMM + - MathCmd → PE_MATH + - CompositeCmd → tiled pipeline (Stage 3: ADR-0014 D3.2) + + Composite GEMM pipeline (32x64x32 tiles): + DMA_READ(b_tile_t) → COMPUTE(t) → DMA_WRITE(out_tile_t) + with overlap: READ(t+1) || COMPUTE(t) || WRITE(t-1) + + Applies scheduler overhead_ns before dispatching each command. + Non-PeInternalTxn messages are forwarded via inherited _forward_txn(). + """ + + # Scheduler tile dimensions (ADR-0014 D3.2) + TILE_M = 32 + TILE_K = 64 + TILE_N = 32 + + # Command → engine suffix dispatch table. + # New engines: add a single entry here (e.g. ConvCmd: "pe_conv"). + _CMD_DISPATCH: dict[type, str] = {} + + @classmethod + def _ensure_dispatch_table(cls) -> None: + if cls._CMD_DISPATCH: + return + from kernbench.common.pe_commands import DmaReadCmd, DmaWriteCmd, GemmCmd, MathCmd + + cls._CMD_DISPATCH = { + DmaReadCmd: "pe_dma", + DmaWriteCmd: "pe_dma", + GemmCmd: "pe_gemm", + MathCmd: "pe_math", + } + + def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None: + super().__init__(node, ctx) + self._pe_prefix = node.id.rsplit(".", 1)[0] + self._ensure_dispatch_table() + + def run(self, env: simpy.Environment, nbytes: int) -> Generator: + overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0)) + yield env.timeout(overhead_ns) + + def _worker(self, env: simpy.Environment) -> Generator: + from kernbench.common.pe_commands import PeInternalTxn + + while True: + msg: Any = yield self._inbox.get() + if isinstance(msg, PeInternalTxn): + env.process(self._dispatch(env, msg)) + else: + yield from self._forward_txn(env, msg) + + def _dispatch(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator: + """Route a PeInternalTxn to the correct engine via dispatch table.""" + from kernbench.common.pe_commands import CompositeCmd + + # Scheduler overhead + yield from self.run(env, 0) + + cmd = pe_txn.command + + # Check dispatch table first + engine_suffix = self._CMD_DISPATCH.get(type(cmd)) + if engine_suffix is not None: + yield self.out_ports[f"{self._pe_prefix}.{engine_suffix}"].put(pe_txn) + return + + # CompositeCmd: tiled pipeline (not a simple forward) + if isinstance(cmd, CompositeCmd): + yield from self._dispatch_composite(env, pe_txn) + return + + # Unknown command — signal done immediately + pe_txn.done.succeed() + + def _dispatch_composite(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator: + """Composite tiled pipeline (ADR-0014 D3.2). + + GEMM: 3-stage pipeline with b-tile streaming from HBM. + MATH: sequential compute + DMA_WRITE (no tiling). + """ + from kernbench.common.pe_commands import CompositeCmd + + cmd = pe_txn.command + assert isinstance(cmd, CompositeCmd) + if cmd.op == "gemm" and cmd.b is not None: + yield from self._pipeline_gemm(env, pe_txn, cmd) + else: + yield from self._pipeline_math(env, pe_txn, cmd) + + def _pipeline_gemm(self, env: simpy.Environment, pe_txn: PeInternalTxn, cmd: Any) -> Generator: + """Tiled GEMM pipeline: stream b tiles from HBM, compute, write results. + + Tensor a is in TCM (loaded via tl.load). Tensor b is in HBM (via tl.ref). + Pipeline: DMA_READ(b_tile_t) -> COMPUTE(t) -> DMA_WRITE(out_tile_t) + Overlap: READ(t+1) || COMPUTE(t) || WRITE(t-1) + """ + from kernbench.common.pe_commands import ( + DmaReadCmd, + DmaWriteCmd, + GemmCmd, + PeInternalTxn as PeTxn, + TensorHandle, + ) + + pp = self._pe_prefix + a = cmd.a # already in TCM + b = cmd.b # HBM reference (via tl.ref) + + M, K_a = a.shape[-2], a.shape[-1] + K_b, N = b.shape[-2], b.shape[-1] + dtype = a.dtype + dtype_bytes = b.nbytes // (K_b * N) if (K_b * N) > 0 else 2 + + # Tile counts + n_tiles_k = max(1, (K_a + self.TILE_K - 1) // self.TILE_K) + n_tiles_n = max(1, (N + self.TILE_N - 1) // self.TILE_N) + n_tiles = n_tiles_k * n_tiles_n + + prev_compute_done = None + prev_write_done = None + total_dma_ns = 0.0 + total_compute_ns = 0.0 + + for tile_idx in range(n_tiles): + tk = tile_idx // n_tiles_n + tn = tile_idx % n_tiles_n + + k_start = tk * self.TILE_K + n_start = tn * self.TILE_N + tile_k = min(self.TILE_K, K_a - k_start) + tile_n = min(self.TILE_N, N - n_start) + tile_nbytes = tile_k * tile_n * dtype_bytes + + # --- Stage 1: DMA_READ b_tile from HBM --- + read_done = env.event() + b_tile_addr = b.addr + (k_start * N + n_start) * dtype_bytes + b_tile_handle = TensorHandle( + id=f"b_tile_{tile_idx}", addr=b_tile_addr, + shape=(tile_k, tile_n), dtype=dtype, nbytes=tile_nbytes, + ) + read_cmd = DmaReadCmd(handle=b_tile_handle, src_addr=b_tile_addr, nbytes=tile_nbytes) + read_txn = PeTxn(command=read_cmd, done=read_done, pe_prefix=pp) + t0 = env.now + yield self.out_ports[f"{pp}.pe_dma"].put(read_txn) + + # Wait for previous compute before starting this tile's compute + if prev_compute_done is not None: + yield prev_compute_done + + # Wait for this tile's DMA_READ + yield read_done + total_dma_ns += env.now - t0 + + # --- Stage 2: COMPUTE (GEMM) --- + compute_done = env.event() + out_handle = TensorHandle( + id=f"out_tile_{tile_idx}", addr=0, + shape=(M, tile_n), dtype=dtype, + nbytes=M * tile_n * dtype_bytes, + ) + compute_cmd = GemmCmd(a=a, b=b_tile_handle, out=out_handle, + m=M, k=tile_k, n=tile_n) + compute_txn = PeTxn(command=compute_cmd, done=compute_done, pe_prefix=pp) + t0 = env.now + yield self.out_ports[f"{pp}.pe_gemm"].put(compute_txn) + + # Wait for previous write (DMA_WRITE serialization) + if prev_write_done is not None: + yield prev_write_done + + # Wait for compute of THIS tile + yield compute_done + total_compute_ns += env.now - t0 + prev_compute_done = compute_done + + # --- Stage 3: DMA_WRITE out_tile to HBM --- + write_done = env.event() + out_tile_pa = cmd.out_addr + n_start * dtype_bytes + write_nbytes = M * tile_n * dtype_bytes + write_cmd = DmaWriteCmd(handle=out_handle, dst_addr=out_tile_pa, nbytes=write_nbytes) + write_txn = PeTxn(command=write_cmd, done=write_done, pe_prefix=pp) + t0 = env.now + yield self.out_ports[f"{pp}.pe_dma"].put(write_txn) + prev_write_done = write_done + + # Wait for final write + if prev_write_done is not None: + t0 = env.now + yield prev_write_done + total_dma_ns += env.now - t0 + + pe_txn.result_data["dma_ns"] = total_dma_ns + pe_txn.result_data["compute_ns"] = total_compute_ns + pe_txn.done.succeed() + + def _pipeline_math(self, env: simpy.Environment, pe_txn: PeInternalTxn, cmd: Any) -> Generator: + """Non-GEMM composite: sequential compute + DMA_WRITE (no tiling).""" + from kernbench.common.pe_commands import ( + DmaWriteCmd, + MathCmd, + PeInternalTxn as PeTxn, + ) + + pp = self._pe_prefix + + # Step 1: Compute (MATH) + compute_done = env.event() + compute_cmd = MathCmd( + op=cmd.math_op or "identity", + inputs=(cmd.a,), out=cmd.a, + ) + compute_txn = PeTxn(command=compute_cmd, done=compute_done, pe_prefix=pp) + yield self.out_ports[f"{pp}.pe_math"].put(compute_txn) + yield compute_done + + # Step 2: DMA_WRITE result to HBM + write_done = env.event() + write_cmd = DmaWriteCmd(handle=cmd.a, dst_addr=cmd.out_addr, nbytes=cmd.out_nbytes) + write_txn = PeTxn(command=write_cmd, done=write_done, pe_prefix=pp) + yield self.out_ports[f"{pp}.pe_dma"].put(write_txn) + yield write_done + + pe_txn.done.succeed() diff --git a/src/kernbench/components/builtin_legacy/pe_tcm.py b/src/kernbench/components/builtin_legacy/pe_tcm.py new file mode 100644 index 0000000..6458d56 --- /dev/null +++ b/src/kernbench/components/builtin_legacy/pe_tcm.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from collections.abc import Generator +from typing import TYPE_CHECKING + +from kernbench.components.base import ComponentBase + +if TYPE_CHECKING: + from kernbench.components.context import ComponentContext + from kernbench.topology.types import Node + + +class PeTcmComponent(ComponentBase): + """PE_TCM: tightly-coupled memory / local SRAM staging buffer. + + Terminal storage component for PE-internal dataflow (ADR-0014 D5). + Phase 0: applies overhead_ns and drain_ns at terminal. + """ + + def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None: + super().__init__(node, ctx) + + def run(self, env, nbytes: int) -> Generator: + overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0)) + yield env.timeout(overhead_ns) diff --git a/src/kernbench/components/builtin_legacy/sram.py b/src/kernbench/components/builtin_legacy/sram.py new file mode 100644 index 0000000..d631ec4 --- /dev/null +++ b/src/kernbench/components/builtin_legacy/sram.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +from collections.abc import Generator +from typing import TYPE_CHECKING, Any + +import simpy + +from kernbench.components.base import ComponentBase +from kernbench.sim_engine.transaction import Transaction + +if TYPE_CHECKING: + from kernbench.components.context import ComponentContext + from kernbench.topology.types import Node + + +class SramComponent(ComponentBase): + """Cube SRAM: terminal component that models SRAM access latency. + + Applies overhead_ns processing overhead (from node.attrs). + On completion, sends a ResponseMsg back on the reverse path. + """ + + def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None: + super().__init__(node, ctx) + + def run(self, env: simpy.Environment, nbytes: int) -> Generator: + overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0)) + yield env.timeout(overhead_ns) + + def _worker(self, env: simpy.Environment) -> Generator: + """Terminal worker: process, apply drain, send response.""" + while True: + txn: Any = yield self._inbox.get() + yield from self.run(env, txn.nbytes) + drain = getattr(txn, "drain_ns", 0.0) + if drain > 0: + yield env.timeout(drain) + yield from self._send_response(env, txn) + + def _send_response(self, env: simpy.Environment, txn: Any) -> Generator: + """Create ResponseMsg and send on reverse path.""" + reverse_path = list(reversed(txn.path)) + if len(reverse_path) >= 2 and self.ctx: + from kernbench.runtime_api.kernel import ResponseMsg + + parts = self.node.id.split(".") + cube_id = int(parts[1].replace("cube", "")) + resp_msg = ResponseMsg( + correlation_id=txn.request.correlation_id, + request_id=txn.request.request_id, + src_cube=cube_id, src_pe=-1, success=True, + ) + resp_txn = Transaction( + request=resp_msg, path=reverse_path, step=0, + nbytes=0, done=env.event(), is_response=True, + ) + yield self.out_ports[reverse_path[1]].put(resp_txn.advance()) + else: + txn.done.succeed() diff --git a/src/kernbench/components/custom/pe_accel_legacy/__init__.py b/src/kernbench/components/custom/pe_accel_legacy/__init__.py new file mode 100644 index 0000000..60b3647 --- /dev/null +++ b/src/kernbench/components/custom/pe_accel_legacy/__init__.py @@ -0,0 +1,20 @@ +"""PeAccel: cycle-accurate accelerator component for pe_scheduler slot. + +Register in components.yaml as: + pe_scheduler_v2: kernbench.components.custom.pe_accel.scheduler:SchedulerV2Component + +Then reference in topology.yaml: + pe_scheduler: { kind: pe_scheduler, impl: pe_scheduler_v2, attrs: { ... } } + +Package layout: + scheduler/ — scheduler block (component + dispatch + tiling) + scheduler.py — SchedulerV2Component + gemm_pipeline.py — tiled GEMM coordinator + math_pipeline.py — tiled element-wise math coordinator + tile_address.py — per-tile address computation + blocks/ — hardware blocks (DMA_IN, DMA_WB, GEMM, MATH, TCM) + types.py — data classes (descriptors, triggers, tile commands) +""" +from kernbench.components.custom.pe_accel.scheduler import SchedulerV2Component + +__all__ = ["SchedulerV2Component"] diff --git a/src/kernbench/components/custom/pe_accel_legacy/blocks/__init__.py b/src/kernbench/components/custom/pe_accel_legacy/blocks/__init__.py new file mode 100644 index 0000000..9864a31 --- /dev/null +++ b/src/kernbench/components/custom/pe_accel_legacy/blocks/__init__.py @@ -0,0 +1,16 @@ +"""Hardware blocks for pe_accel. + +Each block is a concurrent SimPy process modeling one functional unit: + - DmaInBlock: HBM → TCM tile reads (issues real DmaReadCmd to PE_DMA) + - DmaWbBlock: TCM → HBM tile writes (issues real DmaWriteCmd to PE_DMA) + - GemmBlock: 2-stage MAC pipeline (fetch + compute) + - MathBlock: K-accumulation (GEMM helper) + element-wise ops (exp, log, etc.) + - TcmBlock: TCM access serialization with BW-based timing +""" +from kernbench.components.custom.pe_accel.blocks.dma_in import DmaInBlock +from kernbench.components.custom.pe_accel.blocks.dma_wb import DmaWbBlock +from kernbench.components.custom.pe_accel.blocks.gemm import GemmBlock +from kernbench.components.custom.pe_accel.blocks.math import MathBlock +from kernbench.components.custom.pe_accel.blocks.tcm import TcmBlock, TcmRequest + +__all__ = ["DmaInBlock", "DmaWbBlock", "GemmBlock", "MathBlock", "TcmBlock", "TcmRequest"] diff --git a/src/kernbench/components/custom/pe_accel_legacy/blocks/dma_in.py b/src/kernbench/components/custom/pe_accel_legacy/blocks/dma_in.py new file mode 100644 index 0000000..979fcc7 --- /dev/null +++ b/src/kernbench/components/custom/pe_accel_legacy/blocks/dma_in.py @@ -0,0 +1,96 @@ +"""DMA IN Block: reads tiles from HBM into TCM via real PE_DMA fabric. + +Flow per tile: + 1. Receive DmaRequest from tiling pipeline + 2. Look up DmaInDescriptor for address and size + 3. Issue DmaReadCmd → PE_DMA → fabric → HBM controller → response + 4. Route completion Trigger to next block (GEMM, MATH, or COMPLETION) + +Timing is real fabric latency (not analytical) — includes BW contention, +propagation delay, and HBM controller serialization. +""" +from __future__ import annotations + +import simpy + +from kernbench.components.custom.pe_accel.types import DmaInDescriptor, DmaRequest, Trigger + + +class DmaInBlock: + """HBM → TCM tile reader. Shared across all concurrent pipelines. + + Pipelines pre-load DmaInDescriptors keyed by (pipeline_id, tile_id, operand). + The _load_loop process reads DmaRequests, issues real DmaReadCmd to PE_DMA, + and routes completion triggers based on descriptor.next_block. + """ + + def __init__( + self, + env: simpy.Environment, + cmd_q: simpy.Store, + to_fetch_trig: simpy.Store, + to_math_trig: simpy.Store, + completion_q: simpy.Store, + *, + pe_dma_port: simpy.Store | None, + pe_prefix: str, + ) -> None: + self.env = env + self.cmd_q = cmd_q + self.to_fetch_trig = to_fetch_trig + self.to_math_trig = to_math_trig + self.completion_q = completion_q + self._pe_dma_port = pe_dma_port + self._pe_prefix = pe_prefix + self._descriptor_table: dict[tuple[int, int, str], DmaInDescriptor] = {} + + # Per-pipeline timing histogram (keyed by pipeline_id) + self.t_dma_read_per_request: dict[int, list[float]] = {} + + def load_descriptors(self, descs: dict[tuple, DmaInDescriptor]) -> None: + """Pre-load per-operand DMA descriptors (cumulative across pipelines).""" + self._descriptor_table.update(descs) + + def _load_loop(self): + """Main process: receive DmaRequests, issue DmaReadCmd, route triggers.""" + from kernbench.common.pe_commands import DmaReadCmd, PeInternalTxn, TensorHandle + + while True: + req: DmaRequest = yield self.cmd_q.get() + if req is None: + break + + desc = self._descriptor_table[(req.pipeline_id, req.tile_id, req.operand)] + + # Issue real DMA read through PE_DMA → fabric → HBM + read_done = self.env.event() + handle = TensorHandle( + id=f"accel_rd_{req.pipeline_id}_{req.tile_id}_{req.operand}", + addr=desc.src_addr, + shape=(desc.size_bytes,), + dtype="uint8", + nbytes=desc.size_bytes, + ) + txn = PeInternalTxn( + command=DmaReadCmd(handle=handle, src_addr=desc.src_addr, nbytes=desc.size_bytes), + done=read_done, + pe_prefix=self._pe_prefix, + ) + t0 = self.env.now + yield self._pe_dma_port.put(txn) + yield read_done + self.t_dma_read_per_request.setdefault(req.pipeline_id, []).append(self.env.now - t0) + + # Route trigger to next block + trig = Trigger( + tile_id=req.tile_id, + pipeline_id=req.pipeline_id, + vc=0 if req.operand == "A" else 1, + source_block="DMA_IN", + ) + if desc.next_block == "MATH": + yield self.to_math_trig.put(trig) + elif desc.next_block == "COMPLETION": + yield self.completion_q.put(trig) + else: # "GEMM" (default) + yield self.to_fetch_trig.put(trig) diff --git a/src/kernbench/components/custom/pe_accel_legacy/blocks/dma_wb.py b/src/kernbench/components/custom/pe_accel_legacy/blocks/dma_wb.py new file mode 100644 index 0000000..538eca6 --- /dev/null +++ b/src/kernbench/components/custom/pe_accel_legacy/blocks/dma_wb.py @@ -0,0 +1,88 @@ +"""DMA Writeback Block: writes result tiles from TCM to HBM via real PE_DMA fabric. + +Flow per tile: + 1. Receive flush Trigger from GEMM or MATH block + 2. Look up DmaWBDescriptor for address and tile size + 3. Issue DmaWriteCmd → PE_DMA → fabric → HBM controller → response + 4. Send completion Trigger to pipeline + +Two _flush_loop processes run concurrently: + - One drains GEMM → DMA_WB triggers (direct writeback path) + - One drains MATH → DMA_WB triggers (K-accumulation or element-wise flush) +""" +from __future__ import annotations + +import simpy + +from kernbench.components.custom.pe_accel.types import DmaWBDescriptor, Trigger + + +class DmaWbBlock: + """TCM → HBM tile writer. Shared across all concurrent pipelines. + + Pipelines pre-load DmaWBDescriptors keyed by (pipeline_id, tile_id). + Each _flush_loop process reads triggers, issues real DmaWriteCmd to PE_DMA, + and forwards completion to the pipeline's reply queue. + """ + + def __init__( + self, + env: simpy.Environment, + completion_q: simpy.Store, + *, + pe_dma_port: simpy.Store | None, + pe_prefix: str, + bytes_per_element: int, + ) -> None: + self.env = env + self.completion_q = completion_q + self._pe_dma_port = pe_dma_port + self._pe_prefix = pe_prefix + self._bpe = bytes_per_element + self._descriptor_table: dict[tuple[int, int], DmaWBDescriptor] = {} + + # Per-pipeline timing histogram (keyed by pipeline_id) + self.t_dma_write_per_tile: dict[int, list[float]] = {} + + def load_descriptors(self, descs: dict[tuple, DmaWBDescriptor]) -> None: + """Pre-load per-tile writeback descriptors (cumulative across pipelines).""" + self._descriptor_table.update(descs) + + def _flush_loop(self, trig_q: simpy.Store): + """Main process: receive flush triggers, issue DmaWriteCmd, send completion.""" + from kernbench.common.pe_commands import DmaWriteCmd, PeInternalTxn, TensorHandle + + while True: + trigger: Trigger = yield trig_q.get() + if trigger is None: + break + + pid = trigger.pipeline_id + tile_id = trigger.tile_id + desc = self._descriptor_table.get((pid, tile_id)) + + if desc: + c_bytes = desc.Tm * desc.Tn * self._bpe + + # Issue real DMA write through PE_DMA → fabric → HBM + write_done = self.env.event() + handle = TensorHandle( + id=f"accel_wb_{pid}_{tile_id}", + addr=desc.dst_addr, + shape=(desc.Tm, desc.Tn), + dtype="float16", + nbytes=c_bytes, + ) + txn = PeInternalTxn( + command=DmaWriteCmd(handle=handle, dst_addr=desc.dst_addr, nbytes=c_bytes), + done=write_done, + pe_prefix=self._pe_prefix, + ) + t0 = self.env.now + yield self._pe_dma_port.put(txn) + yield write_done + self.t_dma_write_per_tile.setdefault(pid, []).append(self.env.now - t0) + + yield self.completion_q.put( + Trigger(tile_id=tile_id, pipeline_id=pid, source_block="DMA_WB") + ) diff --git a/src/kernbench/components/custom/pe_accel_legacy/blocks/gemm.py b/src/kernbench/components/custom/pe_accel_legacy/blocks/gemm.py new file mode 100644 index 0000000..4ac6d2b --- /dev/null +++ b/src/kernbench/components/custom/pe_accel_legacy/blocks/gemm.py @@ -0,0 +1,160 @@ +"""GEMM Block: 2-stage MAC pipeline (fetch + compute). + +Stage 1 — Fetch (_fetch_stage): + Collects DMA completion triggers (one per operand per tile). + When all operands arrive, issues TCM read request (SPMem → MAC registers). + +Stage 2 — Compute (_gemm_stage): + Models MAC array computation time. + Issues TCM write request (MAC result → SPMem). + Routes output trigger to MathBlock, DmaWbBlock, or completion. + +TCM access goes through TcmBlock for real BW serialization. +MAC compute time is cycle-accurate: ceil(Tm/mac_m) * ceil(Tk/mac_k) * ceil(Tn/mac_n). +""" +from __future__ import annotations + +from math import ceil + +import simpy + +from kernbench.components.custom.pe_accel.blocks.tcm import TcmRequest +from kernbench.components.custom.pe_accel.types import GemmDescriptor, Trigger + + +class GemmBlock: + """2-stage MAC pipeline shared across all concurrent pipelines. + + Pipelines pre-load GemmDescriptors keyed by (pipeline_id, tile_id). + Two SimPy processes run concurrently: _fetch_stage and _gemm_stage. + """ + + def __init__( + self, + env: simpy.Environment, + trig_in: simpy.Store, + fetch_to_gemm_trig: simpy.Store, + to_math_trig: simpy.Store, + to_dmaWB_trig: simpy.Store, + completion_q: simpy.Store, + *, + tcm_port: simpy.Store, + mac_m: int, + mac_k: int, + mac_n: int, + bytes_per_element: int, + clock_freq_ghz: float, + ) -> None: + self.env = env + self.trig_in = trig_in + self.fetch_to_gemm_trig = fetch_to_gemm_trig + self.to_math_trig = to_math_trig + self.to_dmaWB_trig = to_dmaWB_trig + self.completion_q = completion_q + self._tcm_port = tcm_port + + self._mac_m = mac_m + self._mac_k = mac_k + self._mac_n = mac_n + self._bpe = bytes_per_element + self._freq = clock_freq_ghz + + self._descriptor_table: dict[tuple[int, int], GemmDescriptor] = {} + + # Per-pipeline timing histograms + self.t_tcm_load_per_tile: dict[int, list[float]] = {} + self.t_compute_per_tile: dict[int, list[float]] = {} + + def load_descriptors(self, descs: dict[tuple, GemmDescriptor]) -> None: + """Pre-load per-tile GEMM descriptors (cumulative across pipelines).""" + self._descriptor_table.update(descs) + + def _compute_ns(self, desc: GemmDescriptor) -> float: + """MAC array compute time for one tile (ns).""" + cycles = ceil(desc.Tm / self._mac_m) * ceil(desc.Tk / self._mac_k) * ceil(desc.Tn / self._mac_n) + return cycles / self._freq + + # -- Stage 1: Fetch (TCM → MAC load) -------------------------------------- + + def _fetch_stage(self): + """Collect DMA triggers per tile, issue TCM read for operand load.""" + pending: dict[tuple[int, int], list[Trigger]] = {} + while True: + trigger = yield self.trig_in.get() + if trigger is None: + yield self.fetch_to_gemm_trig.put(None) + break + + key = (trigger.pipeline_id, trigger.tile_id) + pending.setdefault(key, []).append(trigger) + + desc = self._descriptor_table.get(key) + needed = desc.triggers_needed if desc else 2 + if len(pending[key]) < needed: + continue + + del pending[key] + + # TCM load: read A and B tile data from SPMem → MAC registers + if desc and desc.gemm_load: + a_bytes = desc.Tm * desc.Tk * self._bpe + b_bytes = desc.Tk * desc.Tn * self._bpe + load_bytes = a_bytes + b_bytes + + t0 = self.env.now + done = self.env.event() + yield self._tcm_port.put(TcmRequest("read", load_bytes, done, tag="gemm_load")) + yield done + self.t_tcm_load_per_tile.setdefault(trigger.pipeline_id, []).append( + self.env.now - t0 + ) + + yield self.fetch_to_gemm_trig.put(trigger) + + # -- Stage 2: Compute (MAC array) + Store (MAC → TCM) --------------------- + + def _gemm_stage(self): + """MAC computation, then TCM store, then route to next block.""" + while True: + trigger = yield self.fetch_to_gemm_trig.get() + if trigger is None: + break + + key = (trigger.pipeline_id, trigger.tile_id) + desc = self._descriptor_table.get(key) + + # MAC compute + if desc and desc.gemm_compute: + t_compute = self._compute_ns(desc) + t0 = self.env.now + if t_compute > 0: + yield self.env.timeout(t_compute) + self.t_compute_per_tile.setdefault(trigger.pipeline_id, []).append( + self.env.now - t0 + ) + + # Route output + route = desc.next_block if desc else "MATH" + out_trig = Trigger( + tile_id=trigger.tile_id, + pipeline_id=trigger.pipeline_id, + source_block="GEMM", + ) + + if route == "MATH": + yield self.to_math_trig.put(out_trig) + elif route == "DMAWB": + # TCM store before writeback + if desc: + c_bytes = desc.Tm * desc.Tn * self._bpe + done = self.env.event() + yield self._tcm_port.put(TcmRequest("write", c_bytes, done, tag="gemm_store")) + yield done + yield self.to_dmaWB_trig.put(out_trig) + else: # "DONE" — C stays in SPMem, no flush + if desc: + c_bytes = desc.Tm * desc.Tn * self._bpe + done = self.env.event() + yield self._tcm_port.put(TcmRequest("write", c_bytes, done, tag="gemm_store")) + yield done + yield self.completion_q.put(out_trig) diff --git a/src/kernbench/components/custom/pe_accel_legacy/blocks/math.py b/src/kernbench/components/custom/pe_accel_legacy/blocks/math.py new file mode 100644 index 0000000..e1f5305 --- /dev/null +++ b/src/kernbench/components/custom/pe_accel_legacy/blocks/math.py @@ -0,0 +1,181 @@ +"""Math Block: K-accumulation (GEMM helper) + element-wise ops (exp, log, etc.). + +Two concurrent processing modes: + + 1. K-accumulation (_run_k_accumulation): + Receives triggers from GemmBlock after each K-tile compute. + Issues TCM write for partial-result store. + On final K-tile, routes to DMA_WB or completion. + + 2. Element-wise ops (_run_element_wise): + Receives triggers from DMA_IN after each tile read. + Issues TCM read (load input), compute (SIMD), TCM write (store result). + Routes output to DMA_WB for writeback. + +TCM access goes through TcmBlock for real BW serialization. +SIMD compute time: ceil(num_elements / vector_width) / clock_freq. +""" +from __future__ import annotations + +from math import ceil + +import simpy + +from kernbench.components.custom.pe_accel.blocks.tcm import TcmRequest +from kernbench.components.custom.pe_accel.types import MathDescriptor, MathOpDescriptor, Trigger + + +class MathBlock: + """K-accumulation + element-wise math unit. + + Descriptor tables: + - _accum_table: MathDescriptor (pipeline_id, tile_id) — for GEMM K-accumulation + - _elemwise_table: MathOpDescriptor (pipeline_id, tile_id) — for element-wise ops + """ + + def __init__( + self, + env: simpy.Environment, + trig_in: simpy.Store, + to_dmaWB_trig: simpy.Store, + completion_q: simpy.Store, + *, + tcm_port: simpy.Store, + bytes_per_element: int, + clock_freq_ghz: float, + vector_width: int = 256, + ) -> None: + self.env = env + self.trig_in = trig_in # from GemmBlock (K-accumulation) + self.to_dmaWB_trig = to_dmaWB_trig + self.completion_q = completion_q + self._tcm_port = tcm_port + + self._bpe = bytes_per_element + self._freq = clock_freq_ghz + self._vector_width = vector_width + + # Descriptor tables + self._accum_table: dict[tuple[int, int], MathDescriptor] = {} + self._elemwise_table: dict[tuple[int, int], MathOpDescriptor] = {} + + # -- Timing histograms (per pipeline_id) -- + + # K-accumulation + self.t_tcm_store_per_tile: dict[int, list[float]] = {} + + # Element-wise ops + self.t_math_op_load_per_tile: dict[int, list[float]] = {} + self.t_math_op_compute_per_tile: dict[int, list[float]] = {} + self.t_math_op_store_per_tile: dict[int, list[float]] = {} + + # -- Descriptor loading ---------------------------------------------------- + + def load_descriptors(self, descs: dict[tuple, MathDescriptor]) -> None: + """Pre-load K-accumulation descriptors (cumulative across pipelines).""" + self._accum_table.update(descs) + + def load_math_op_descriptors(self, descs: dict[tuple, MathOpDescriptor]) -> None: + """Pre-load element-wise op descriptors (cumulative across pipelines).""" + self._elemwise_table.update(descs) + + # -- Mode 1: K-accumulation ------------------------------------------------ + + def _run(self): + """Backward-compat alias.""" + yield from self._run_k_accumulation() + + def _run_k_accumulation(self): + """Receive GEMM output triggers, TCM store partial result, flush on last-K.""" + while True: + trigger = yield self.trig_in.get() + if trigger is None: + break + + key = (trigger.pipeline_id, trigger.tile_id) + desc = self._accum_table.get(key) + + # TCM store: write partial sum to SPMem + if desc: + c_bytes = desc.Tm * desc.Tn * self._bpe + t0 = self.env.now + done = self.env.event() + yield self._tcm_port.put(TcmRequest("write", c_bytes, done, tag="k_accum_store")) + yield done + self.t_tcm_store_per_tile.setdefault(trigger.pipeline_id, []).append( + self.env.now - t0 + ) + + if not desc or not desc.is_last_k: + continue # intermediate K-tile: store done, no flush yet + + out_trig = Trigger( + tile_id=trigger.tile_id, + pipeline_id=trigger.pipeline_id, + source_block="MATH", + ) + if desc.skip_dmaWB: + yield self.completion_q.put(out_trig) + else: + yield self.to_dmaWB_trig.put(out_trig) + + # -- Mode 2: Element-wise ops ---------------------------------------------- + + def _run_math_op(self, trig_q: simpy.Store): + """Backward-compat alias.""" + yield from self._run_element_wise(trig_q) + + def _run_element_wise(self, trig_q: simpy.Store): + """Receive DMA_IN triggers, apply element-wise op via TCM, route to DMA_WB. + + Per tile: + 1. TCM read — load input tile from SPMem to SIMD + 2. Compute — SIMD operation (exp/log/etc.) + 3. TCM write — store result from SIMD to SPMem + 4. Route to DMA_WB + """ + while True: + trigger = yield trig_q.get() + if trigger is None: + break + + key = (trigger.pipeline_id, trigger.tile_id) + desc = self._elemwise_table.get(key) + + if desc: + tile_bytes = desc.Tm * desc.Tn * self._bpe + num_elements = desc.Tm * desc.Tn + + # 1. TCM read + t0 = self.env.now + done = self.env.event() + yield self._tcm_port.put(TcmRequest("read", tile_bytes, done, tag="elemwise_load")) + yield done + self.t_math_op_load_per_tile.setdefault(trigger.pipeline_id, []).append( + self.env.now - t0 + ) + + # 2. SIMD compute + t0 = self.env.now + compute_cycles = ceil(num_elements / self._vector_width) + compute_ns = compute_cycles / self._freq + if compute_ns > 0: + yield self.env.timeout(compute_ns) + self.t_math_op_compute_per_tile.setdefault(trigger.pipeline_id, []).append( + self.env.now - t0 + ) + + # 3. TCM write + t0 = self.env.now + done = self.env.event() + yield self._tcm_port.put(TcmRequest("write", tile_bytes, done, tag="elemwise_store")) + yield done + self.t_math_op_store_per_tile.setdefault(trigger.pipeline_id, []).append( + self.env.now - t0 + ) + + yield self.to_dmaWB_trig.put(Trigger( + tile_id=trigger.tile_id, + pipeline_id=trigger.pipeline_id, + source_block="MATH_OP", + )) diff --git a/src/kernbench/components/custom/pe_accel_legacy/blocks/tcm.py b/src/kernbench/components/custom/pe_accel_legacy/blocks/tcm.py new file mode 100644 index 0000000..d5d098b --- /dev/null +++ b/src/kernbench/components/custom/pe_accel_legacy/blocks/tcm.py @@ -0,0 +1,80 @@ +"""TCM Block: tightly-coupled memory with BW-based access serialization. + +Models SPMem (scratchpad memory) inside the PE. Compute blocks (GEMM, MATH) +send TcmRequests for load/store operations. The TCM block serializes access +per channel and computes timing based on data size and bandwidth. + +Two channels: + - READ (SPMem → compute unit): models operand fetch for MAC/SIMD + - WRITE (compute unit → SPMem): models result store from MAC/SIMD + +Each channel has capacity=1: concurrent reads serialize, concurrent writes +serialize, but a read and a write can proceed in parallel. +""" +from __future__ import annotations + +from dataclasses import dataclass + +import simpy + + +@dataclass +class TcmRequest: + """Request to read from or write to TCM.""" + + direction: str # "read" or "write" + nbytes: int + done: simpy.Event + tag: str = "" # optional label for debugging + + +class TcmBlock: + """BW-serialized TCM model with dual read/write channels. + + Args: + env: SimPy environment. + read_bw_gbs: read bandwidth in GB/s (SPMem → compute). + write_bw_gbs: write bandwidth in GB/s (compute → SPMem). + """ + + def __init__( + self, + env: simpy.Environment, + read_bw_gbs: float = 512.0, + write_bw_gbs: float = 512.0, + ) -> None: + self.env = env + self._read_bw = read_bw_gbs + self._write_bw = write_bw_gbs + self._read_res = simpy.Resource(env, capacity=1) + self._write_res = simpy.Resource(env, capacity=1) + self._port: simpy.Store = simpy.Store(env) + + @property + def port(self) -> simpy.Store: + """The SimPy Store that blocks send TcmRequests to.""" + return self._port + + def _run(self): + """Main process: receive TcmRequests, dispatch to channel processes.""" + while True: + req: TcmRequest = yield self._port.get() + if req is None: + break + self.env.process(self._handle(req)) + + def _handle(self, req: TcmRequest): + """Acquire channel, apply BW-based delay, signal done.""" + if req.direction == "write": + res = self._write_res + bw = self._write_bw + else: + res = self._read_res + bw = self._read_bw + + with res.request() as lock: + yield lock + if bw > 0 and req.nbytes > 0: + delay_ns = req.nbytes / bw + yield self.env.timeout(delay_ns) + req.done.succeed() diff --git a/src/kernbench/components/custom/pe_accel_legacy/scheduler/__init__.py b/src/kernbench/components/custom/pe_accel_legacy/scheduler/__init__.py new file mode 100644 index 0000000..885fc36 --- /dev/null +++ b/src/kernbench/components/custom/pe_accel_legacy/scheduler/__init__.py @@ -0,0 +1,10 @@ +"""Scheduler: accelerator component + dispatch + tiling pipelines. + + scheduler.py — SchedulerV2Component (init, wiring, dispatch, metrics) + gemm_pipeline.py — GemmPipeline (tiled GEMM coordinator) + math_pipeline.py — MathPipeline (tiled element-wise math coordinator) + tile_address.py — per-tile address computation +""" +from kernbench.components.custom.pe_accel.scheduler.scheduler import SchedulerV2Component + +__all__ = ["SchedulerV2Component"] diff --git a/src/kernbench/components/custom/pe_accel_legacy/scheduler/gemm_pipeline.py b/src/kernbench/components/custom/pe_accel_legacy/scheduler/gemm_pipeline.py new file mode 100644 index 0000000..afce12d --- /dev/null +++ b/src/kernbench/components/custom/pe_accel_legacy/scheduler/gemm_pipeline.py @@ -0,0 +1,157 @@ +"""GEMM Tiling Pipeline: splits (M,K)×(K,N) into tiles and coordinates execution. + +Flow per tile: + DMA_IN(A tile) + DMA_IN(B tile) → GEMM(fetch + compute) → MATH(K-accum) → DMA_WB + +The pipeline does NOT own hardware blocks — it uses the component's shared +blocks via descriptor tables and SimPy queues. + +Constructor starts two SimPy processes: + - _feed_commands(): sends DmaRequests to shared dmaIN_cmd_q + - _collect_completions(): waits for all output tiles to flush +""" +from __future__ import annotations + +from math import ceil + +import simpy + +from kernbench.components.custom.pe_accel.scheduler.tiling import generate_gemm_tiles +from kernbench.components.custom.pe_accel.types import ( + CmdType, + DmaInDescriptor, + DmaRequest, + DmaWBDescriptor, + GemmDescriptor, + MathDescriptor, + Trigger, +) + + +class GemmPipeline: + """Coordinates one tiled GEMM operation across shared hardware blocks.""" + + def __init__( + self, + env: simpy.Environment, + M: int, K: int, N: int, + tile_m: int, tile_k: int, tile_n: int, + bytes_per_element: int, + pipeline_id: int, + reply_queue: simpy.Store, + dmaIN_cmd_q: simpy.Store, + dmaIN_to_fetch_trig: simpy.Store, + A_addr: int = 0, + B_addr: int = 0, + C_addr: int = 0, + dma_a: bool = True, + dma_b: bool = True, + dma_c: bool = True, + ) -> None: + self.env = env + self.M, self.K, self.N = M, K, N + self.pipeline_id = pipeline_id + self.reply_queue = reply_queue + self.dmaIN_cmd_q = dmaIN_cmd_q + self.dmaIN_to_fetch_trig = dmaIN_to_fetch_trig + + self._dma_a = dma_a + self._dma_b = dma_b + self._skip_dmaWB = not dma_c + + _Tm = min(tile_m, M) + _Tk = min(tile_k, K) + _Tn = min(tile_n, N) + + self.M_tiles = ceil(M / tile_m) + self.K_tiles = ceil(K / tile_k) + self.N_tiles = ceil(N / tile_n) + + triggers_per_tile = 2 if (dma_a and dma_b) else 1 + + # Generate tile schedule with pre-computed addresses + self.schedule = generate_gemm_tiles( + self.M_tiles, self.K_tiles, self.N_tiles, + M=M, K=K, N=N, + tile_m=_Tm, tile_k=_Tk, tile_n=_Tn, + bytes_per_element=bytes_per_element, + A_addr=A_addr, B_addr=B_addr, C_addr=C_addr, + pipeline_id=pipeline_id, + ) + + # Build descriptor tables for shared blocks + pid = pipeline_id + a_tile_bytes = _Tm * _Tk * bytes_per_element + b_tile_bytes = _Tk * _Tn * bytes_per_element + + self.dmaIN_descs: dict[tuple, DmaInDescriptor] = {} + self.gemm_descs: dict[tuple, GemmDescriptor] = {} + self.math_descs: dict[tuple, MathDescriptor] = {} + self.dmaWB_descs: dict[tuple, DmaWBDescriptor] = {} + + for cmd in self.schedule.commands: + if cmd.cmd_type != CmdType.DMA_LOAD: + continue + t = cmd.tile_id + + if dma_a: + self.dmaIN_descs[(pid, t, "A")] = DmaInDescriptor( + size_bytes=a_tile_bytes, src_addr=cmd.a_tile_addr + ) + if dma_b: + self.dmaIN_descs[(pid, t, "B")] = DmaInDescriptor( + size_bytes=b_tile_bytes, src_addr=cmd.b_tile_addr + ) + + self.gemm_descs[(pid, t)] = GemmDescriptor( + Tm=_Tm, Tk=_Tk, Tn=_Tn, + triggers_needed=triggers_per_tile, + next_block="MATH", + ) + + self.math_descs[(pid, t)] = MathDescriptor( + Tm=_Tm, Tn=_Tn, + is_last_k=cmd.is_last_k, + skip_dmaWB=self._skip_dmaWB, + ) + if not self._skip_dmaWB and cmd.is_last_k: + self.dmaWB_descs[(pid, t)] = DmaWBDescriptor( + Tm=_Tm, Tn=_Tn, dst_addr=cmd.c_tile_addr + ) + + self.expected_flushes = self.M_tiles * self.N_tiles + self.completed_flushes = 0 + self.done_at: int = 0 + self.done: simpy.Event = env.event() + + env.process(self._feed_commands()) + env.process(self._collect_completions()) + + def _feed_commands(self): + """Send DmaRequests for each tile's operands to dmaIN_cmd_q.""" + for cmd in self.schedule.commands: + if cmd.cmd_type != CmdType.DMA_LOAD: + continue + + if self._dma_a: + yield self.dmaIN_cmd_q.put(DmaRequest( + tile_id=cmd.tile_id, pipeline_id=self.pipeline_id, operand="A", + )) + if self._dma_b: + yield self.dmaIN_cmd_q.put(DmaRequest( + tile_id=cmd.tile_id, pipeline_id=self.pipeline_id, operand="B", + )) + + if not self._dma_a and not self._dma_b: + yield self.dmaIN_to_fetch_trig.put(Trigger( + tile_id=cmd.tile_id, pipeline_id=self.pipeline_id, source_block="PIPELINE", + )) + + def _collect_completions(self): + """Wait for all output tile flush completions, then signal done.""" + while self.completed_flushes < self.expected_flushes: + yield self.reply_queue.get() + self.completed_flushes += 1 + + self.done_at = int(self.env.now) + self.done.succeed() diff --git a/src/kernbench/components/custom/pe_accel_legacy/scheduler/math_pipeline.py b/src/kernbench/components/custom/pe_accel_legacy/scheduler/math_pipeline.py new file mode 100644 index 0000000..146b9b5 --- /dev/null +++ b/src/kernbench/components/custom/pe_accel_legacy/scheduler/math_pipeline.py @@ -0,0 +1,132 @@ +"""Math Tiling Pipeline: splits element-wise ops into tiles for pipelined execution. + +Flow per tile: + DMA_IN(input tile) → MATH_OP(exp/log/etc.) → DMA_WB(output tile) + +Mirrors GemmTilingPipeline but for unary element-wise operations. +Pipeline overlap across tiles: while one tile is in MATH_OP, the next +tile's DMA_IN can proceed concurrently. + +Constructor starts two SimPy processes: + - _feed_commands(): sends DmaRequests to shared dmaIN_cmd_q + - _collect_completions(): waits for all tiles to writeback +""" +from __future__ import annotations + +from math import ceil + +import simpy + +from kernbench.components.custom.pe_accel.scheduler.tiling import generate_math_tiles +from kernbench.components.custom.pe_accel.types import ( + DmaInDescriptor, + DmaRequest, + DmaWBDescriptor, + MathOpDescriptor, + Trigger, +) + + +class MathPipeline: + """Coordinates one tiled element-wise math operation across shared blocks.""" + + def __init__( + self, + env: simpy.Environment, + M: int, N: int, + tile_m: int, tile_n: int, + bytes_per_element: int, + pipeline_id: int, + reply_queue: simpy.Store, + dmaIN_cmd_q: simpy.Store, + dmaIN_to_fetch_trig: simpy.Store, + op: str = "exp", + src_addr: int = 0, + dst_addr: int = 0, + dma_in: bool = True, + dma_out: bool = True, + ) -> None: + self.env = env + self.M, self.N = M, N + self.pipeline_id = pipeline_id + self.reply_queue = reply_queue + self.dmaIN_cmd_q = dmaIN_cmd_q + self.dmaIN_to_fetch_trig = dmaIN_to_fetch_trig + self.op = op + + self._dma_in = dma_in + self._skip_dmaWB = not dma_out + + _Tm = min(tile_m, M) + _Tn = min(tile_n, N) + + self.M_tiles = ceil(M / tile_m) + self.N_tiles = ceil(N / tile_n) + + # Generate tile schedule with pre-computed addresses + self.schedule = generate_math_tiles( + self.M_tiles, self.N_tiles, + M=M, N=N, + tile_m=_Tm, tile_n=_Tn, + bytes_per_element=bytes_per_element, + src_addr=src_addr, dst_addr=dst_addr, + pipeline_id=pipeline_id, + ) + + # Build descriptor tables for shared blocks + pid = pipeline_id + tile_bytes = _Tm * _Tn * bytes_per_element + + self.dmaIN_descs: dict[tuple, DmaInDescriptor] = {} + self.math_op_descs: dict[tuple, MathOpDescriptor] = {} + self.dmaWB_descs: dict[tuple, DmaWBDescriptor] = {} + + for cmd in self.schedule.commands: + t = cmd.tile_id + + if dma_in: + self.dmaIN_descs[(pid, t, "A")] = DmaInDescriptor( + size_bytes=tile_bytes, + src_addr=cmd.src_tile_addr, + next_block="MATH", + ) + + self.math_op_descs[(pid, t)] = MathOpDescriptor( + Tm=_Tm, Tn=_Tn, op=op, + src_addr=cmd.src_tile_addr, + dst_addr=cmd.dst_tile_addr, + ) + + if not self._skip_dmaWB: + self.dmaWB_descs[(pid, t)] = DmaWBDescriptor( + Tm=_Tm, Tn=_Tn, dst_addr=cmd.dst_tile_addr, + ) + + self.expected_flushes = self.M_tiles * self.N_tiles + self.completed_flushes = 0 + self.done_at: int = 0 + self.done: simpy.Event = env.event() + + env.process(self._feed_commands()) + env.process(self._collect_completions()) + + def _feed_commands(self): + """Send DmaRequests for each tile's input to dmaIN_cmd_q.""" + for cmd in self.schedule.commands: + if self._dma_in: + yield self.dmaIN_cmd_q.put(DmaRequest( + tile_id=cmd.tile_id, pipeline_id=self.pipeline_id, operand="A", + )) + else: + yield self.dmaIN_to_fetch_trig.put(Trigger( + tile_id=cmd.tile_id, pipeline_id=self.pipeline_id, source_block="PIPELINE", + )) + + def _collect_completions(self): + """Wait for all tile completions, then signal done.""" + while self.completed_flushes < self.expected_flushes: + yield self.reply_queue.get() + self.completed_flushes += 1 + + self.done_at = int(self.env.now) + self.done.succeed() diff --git a/src/kernbench/components/custom/pe_accel_legacy/scheduler/scheduler.py b/src/kernbench/components/custom/pe_accel_legacy/scheduler/scheduler.py new file mode 100644 index 0000000..a803193 --- /dev/null +++ b/src/kernbench/components/custom/pe_accel_legacy/scheduler/scheduler.py @@ -0,0 +1,434 @@ +"""SchedulerV2Component: accelerator scheduler block (pe_scheduler_v2). + +Replaces the pe_scheduler slot in topology (impl: pe_scheduler_v2). + +Hosts four internal hardware blocks as concurrent SimPy processes: + - DmaInBlock — HBM → TCM tile reads (real fabric DMA) + - DmaWbBlock — TCM → HBM tile writes (real fabric DMA) + - GemmBlock — 2-stage MAC pipeline (fetch + compute) + - MathBlock — K-accumulation (GEMM) + element-wise ops (exp, log, etc.) + +Command dispatch routes PeInternalTxn to the correct engine or tiling pipeline: + - DmaReadCmd / DmaWriteCmd → PE_DMA out_port + - GemmCmd → PE_GEMM out_port + - MathCmd → PE_MATH out_port + - CompositeCmd(op="gemm") → GemmPipeline (tiled DMA + GEMM + K-accum) + - CompositeCmd(op="math") → MathPipeline (tiled DMA + element-wise + DMA) + - PeCpuOverheadCmd → yield timeout + +Config via node.attrs (all optional): + overhead_ns, clock_freq_ghz, bytes_per_element, + mac_m, mac_k, mac_n, tile_m, tile_k, tile_n, + port_a_bw, port_b_bw, store_port_bw, vector_width +""" +from __future__ import annotations + +from collections.abc import Generator +from typing import TYPE_CHECKING, Any + +import simpy + +from kernbench.components.base import ComponentBase + +if TYPE_CHECKING: + from kernbench.components.context import ComponentContext + from kernbench.topology.types import Node + + +# ============================================================================== +# Component +# ============================================================================== + + +class SchedulerV2Component(ComponentBase): + """PE accelerator scheduler: wires internal blocks and dispatches commands.""" + + def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None: + super().__init__(node, ctx) + self._pe_prefix: str = node.id.rsplit(".", 1)[0] + attrs = node.attrs + + # Hardware config + self._overhead_ns: float = float(attrs.get("overhead_ns", 0.0)) + self._clock_freq_ghz: float = float(attrs.get("clock_freq_ghz", 1.0)) + self._bpe: int = int(attrs.get("bytes_per_element", 2)) + + # MAC array dimensions + self._mac_m: int = int(attrs.get("mac_m", 8)) + self._mac_k: int = int(attrs.get("mac_k", 16)) + self._mac_n: int = int(attrs.get("mac_n", 32)) + + # Tile dimensions + self._tile_m: int = int(attrs.get("tile_m", 32)) + self._tile_k: int = int(attrs.get("tile_k", 32)) + self._tile_n: int = int(attrs.get("tile_n", 32)) + + # Bandwidth (bytes/cycle) + self._port_a_bw: int = int(attrs.get("port_a_bw", 256)) + self._port_b_bw: int = int(attrs.get("port_b_bw", 256)) + self._store_port_bw: int = int(attrs.get("store_port_bw", 256)) + self._vector_width: int = int(attrs.get("vector_width", 256)) + + # Initialized in start() + self._dmaIN_block: Any = None + self._dmaWB_block: Any = None + self._gemm_block: Any = None + self._math_block: Any = None + self._dmaIN_cmd_q: simpy.Store | None = None + self._dmaIN_to_fetch_trig: simpy.Store | None = None + + # Pipeline tracking + self._next_pipeline_id: int = 0 + self._pipeline_queues: dict[int, simpy.Store] = {} + + # -- SimPy lifecycle ------------------------------------------------------- + + def run(self, env: simpy.Environment, nbytes: int) -> Generator: + """Scheduler overhead per dispatch.""" + yield env.timeout(self._overhead_ns) + + def start(self, env: simpy.Environment) -> None: + """Create internal blocks, wire queues, start SimPy processes.""" + from kernbench.components.custom.pe_accel.blocks import ( + DmaInBlock, DmaWbBlock, GemmBlock, MathBlock, TcmBlock, + ) + + pe_dma_port = self.out_ports.get(f"{self._pe_prefix}.pe_dma") + + # -- TCM block (shared BW-serialized scratchpad) ----------------------- + + # Read TCM BW from topology spec (gemm_to_tcm / math_to_tcm) + pe_links = {} + if self.ctx and self.ctx.spec: + pe_links = self.ctx.spec.get("cube", {}).get("pe_template", {}).get("links", {}) + tcm_read_bw = float(pe_links.get("gemm_to_tcm_bw_gbs", 512.0)) + tcm_write_bw = float(pe_links.get("math_to_tcm_bw_gbs", 512.0)) + + self._tcm_block = TcmBlock( + env=env, + read_bw_gbs=tcm_read_bw, + write_bw_gbs=tcm_write_bw, + ) + + # -- Internal queues --------------------------------------------------- + + self._dmaIN_cmd_q = simpy.Store(env) + self._dmaIN_to_fetch_trig = simpy.Store(env) + dmaIN_to_math_trig = simpy.Store(env) # DMA_IN → element-wise math + fetch_to_gemm_trig = simpy.Store(env) + gemm_to_math_trig = simpy.Store(env) + gemm_to_dmaWB_trig = simpy.Store(env) + math_to_dmaWB_trig = simpy.Store(env) + + # Completion queues (block → _completion_router → pipeline reply_q) + dmaIN_completion_q = simpy.Store(env) + dmaWB_completion_q = simpy.Store(env) + gemm_completion_q = simpy.Store(env) + math_completion_q = simpy.Store(env) + + # -- Create blocks ----------------------------------------------------- + + self._dmaIN_block = DmaInBlock( + env=env, + cmd_q=self._dmaIN_cmd_q, + to_fetch_trig=self._dmaIN_to_fetch_trig, + to_math_trig=dmaIN_to_math_trig, + completion_q=dmaIN_completion_q, + pe_dma_port=pe_dma_port, + pe_prefix=self._pe_prefix, + ) + + self._dmaWB_block = DmaWbBlock( + env=env, + completion_q=dmaWB_completion_q, + pe_dma_port=pe_dma_port, + pe_prefix=self._pe_prefix, + bytes_per_element=self._bpe, + ) + + self._gemm_block = GemmBlock( + env=env, + trig_in=self._dmaIN_to_fetch_trig, + fetch_to_gemm_trig=fetch_to_gemm_trig, + to_math_trig=gemm_to_math_trig, + to_dmaWB_trig=gemm_to_dmaWB_trig, + completion_q=gemm_completion_q, + tcm_port=self._tcm_block.port, + mac_m=self._mac_m, mac_k=self._mac_k, mac_n=self._mac_n, + bytes_per_element=self._bpe, + clock_freq_ghz=self._clock_freq_ghz, + ) + + self._math_block = MathBlock( + env=env, + trig_in=gemm_to_math_trig, + to_dmaWB_trig=math_to_dmaWB_trig, + completion_q=math_completion_q, + tcm_port=self._tcm_block.port, + bytes_per_element=self._bpe, + clock_freq_ghz=self._clock_freq_ghz, + vector_width=self._vector_width, + ) + + # -- Start block processes --------------------------------------------- + + env.process(self._tcm_block._run()) + env.process(self._dmaIN_block._load_loop()) + env.process(self._dmaWB_block._flush_loop(gemm_to_dmaWB_trig)) + env.process(self._dmaWB_block._flush_loop(math_to_dmaWB_trig)) + env.process(self._gemm_block._fetch_stage()) + env.process(self._gemm_block._gemm_stage()) + env.process(self._math_block._run_k_accumulation()) + env.process(self._math_block._run_element_wise(dmaIN_to_math_trig)) + + # Wire in-ports → inbox, start _worker + super().start(env) + + # Start completion routers + for q in (dmaIN_completion_q, dmaWB_completion_q, gemm_completion_q, math_completion_q): + env.process(self._completion_router(q)) + + # -- Internal processes ---------------------------------------------------- + + def _completion_router(self, queue: simpy.Store) -> Generator: + """Route block completion triggers to the correct pipeline's reply queue.""" + while True: + trigger = yield queue.get() + if trigger is None: + break + reply_q = self._pipeline_queues.get(trigger.pipeline_id) + if reply_q is not None: + yield reply_q.put(trigger) + + def _worker(self, env: simpy.Environment) -> Generator: + """Main inbox loop: dispatch PE commands, forward fabric transactions.""" + from kernbench.common.pe_commands import PeInternalTxn + + while True: + msg: Any = yield self._inbox.get() + if isinstance(msg, PeInternalTxn): + env.process(self._dispatch(env, msg)) + else: + env.process(self._forward_txn(env, msg)) + + # ========================================================================== + # Command Dispatch + # ========================================================================== + + def _dispatch(self, env: simpy.Environment, pe_txn: Any) -> Generator: + """Route a PeInternalTxn to the appropriate engine or pipeline.""" + from kernbench.common.pe_commands import ( + CompositeCmd, DmaReadCmd, DmaWriteCmd, + GemmCmd, MathCmd, PeCpuOverheadCmd, + ) + + yield from self.run(env, 0) # scheduler overhead + + cmd = pe_txn.command + pp = self._pe_prefix + + if isinstance(cmd, (DmaReadCmd, DmaWriteCmd)): + yield self.out_ports[f"{pp}.pe_dma"].put(pe_txn) + + elif isinstance(cmd, GemmCmd): + yield self.out_ports[f"{pp}.pe_gemm"].put(pe_txn) + + elif isinstance(cmd, MathCmd): + yield self.out_ports[f"{pp}.pe_math"].put(pe_txn) + + elif isinstance(cmd, CompositeCmd): + if cmd.op == "gemm" and cmd.b is not None: + yield from self._dispatch_composite_gemm(env, pe_txn, cmd) + else: + yield from self._dispatch_composite_math(env, pe_txn, cmd) + + elif isinstance(cmd, PeCpuOverheadCmd): + yield env.timeout(cmd.cycles / self._clock_freq_ghz) + pe_txn.done.succeed() + + else: + pe_txn.done.succeed() + + # -- GEMM composite -------------------------------------------------------- + + def _dispatch_composite_gemm( + self, env: simpy.Environment, pe_txn: Any, cmd: Any, + ) -> Generator: + """Run tiled GEMM pipeline and collect per-stage metrics.""" + from kernbench.components.custom.pe_accel.scheduler.gemm_pipeline import GemmPipeline + + a, b = cmd.a, cmd.b + M, K, N = a.shape[-2], a.shape[-1], b.shape[-1] + + pid, reply_q = self._alloc_pipeline(env) + + pipeline = GemmPipeline( + env=env, M=M, K=K, N=N, + tile_m=self._tile_m, tile_k=self._tile_k, tile_n=self._tile_n, + bytes_per_element=self._bpe, + pipeline_id=pid, + reply_queue=reply_q, + dmaIN_cmd_q=self._dmaIN_cmd_q, + dmaIN_to_fetch_trig=self._dmaIN_to_fetch_trig, + A_addr=a.addr, B_addr=b.addr, C_addr=cmd.out_addr, + ) + + # Load descriptors into shared blocks + self._dmaIN_block.load_descriptors(pipeline.dmaIN_descs) + self._gemm_block.load_descriptors(pipeline.gemm_descs) + self._math_block.load_descriptors(pipeline.math_descs) + self._dmaWB_block.load_descriptors(pipeline.dmaWB_descs) + + start_ns = env.now + yield pipeline.done + wall_clock_ns = env.now - start_ns + del self._pipeline_queues[pid] + + # -- Collect metrics --------------------------------------------------- + rd = pe_txn.result_data + rd["M_tiles"] = pipeline.M_tiles + rd["K_tiles"] = pipeline.K_tiles + rd["N_tiles"] = pipeline.N_tiles + rd["total_tiles"] = pipeline.M_tiles * pipeline.K_tiles * pipeline.N_tiles + rd["output_tiles"] = pipeline.M_tiles * pipeline.N_tiles + rd["total_dma_read_bytes"] = sum(d.size_bytes for d in pipeline.dmaIN_descs.values()) + rd["total_dma_write_bytes"] = sum( + d.Tm * d.Tn * self._bpe for d in pipeline.dmaWB_descs.values() + ) + rd["wall_clock_ns"] = wall_clock_ns + + # Per-stage timing + rd["t_tcm_load_per_tile"], rd["t_tcm_load_max_ns"], rd["t_tcm_load_total_ns"] = ( + _collect_timing(self._gemm_block.t_tcm_load_per_tile.pop(pid, [])) + ) + rd["t_compute_per_tile"], rd["t_compute_max_ns"], rd["t_compute_total_ns"] = ( + _collect_timing(self._gemm_block.t_compute_per_tile.pop(pid, [])) + ) + rd["t_tcm_store_per_tile"], rd["t_tcm_store_max_ns"], rd["t_tcm_store_total_ns"] = ( + _collect_timing(self._math_block.t_tcm_store_per_tile.pop(pid, [])) + ) + rd["t_dma_read_per_request"], rd["t_dma_read_max_ns"], rd["t_dma_read_total_ns"] = ( + _collect_timing(self._dmaIN_block.t_dma_read_per_request.pop(pid, [])) + ) + rd["t_dma_write_per_tile"], rd["t_dma_write_max_ns"], rd["t_dma_write_total_ns"] = ( + _collect_timing(self._dmaWB_block.t_dma_write_per_tile.pop(pid, [])) + ) + + # Derived metrics + rd["mac_utilization"] = rd["t_compute_total_ns"] / wall_clock_ns if wall_clock_ns > 0 else 0.0 + rd["effective_tflops"] = 2 * M * K * N / wall_clock_ns / 1e3 if wall_clock_ns > 0 else 0.0 + rd["pipeline_parallelism"] = ( + (rd["t_tcm_load_total_ns"] + rd["t_compute_total_ns"] + rd["t_tcm_store_total_ns"]) + / wall_clock_ns if wall_clock_ns > 0 else 0.0 + ) + rd["effective_read_bw_gbs"] = ( + rd["total_dma_read_bytes"] / rd["t_dma_read_total_ns"] + if rd["t_dma_read_total_ns"] > 0 else None + ) + rd["effective_write_bw_gbs"] = ( + rd["total_dma_write_bytes"] / rd["t_dma_write_total_ns"] + if rd["t_dma_write_total_ns"] > 0 else None + ) + rd["bottleneck_stage"] = max( + [("load", rd["t_tcm_load_max_ns"]), ("compute", rd["t_compute_max_ns"]), + ("store", rd["t_tcm_store_max_ns"])], + key=lambda x: x[1], + )[0] + + # pe_cpu.py compatibility aliases + rd["dma_ns"] = rd["t_dma_read_total_ns"] + rd["t_dma_write_total_ns"] + rd["compute_ns"] = rd["t_compute_total_ns"] + + pe_txn.done.succeed() + + # -- Math composite -------------------------------------------------------- + + def _dispatch_composite_math( + self, env: simpy.Environment, pe_txn: Any, cmd: Any, + ) -> Generator: + """Run tiled element-wise math pipeline and collect metrics.""" + from kernbench.components.custom.pe_accel.scheduler.math_pipeline import MathPipeline + + assert self._dmaIN_cmd_q is not None + assert self._dmaIN_to_fetch_trig is not None + + a = cmd.a + M = a.shape[-2] if len(a.shape) >= 2 else 1 + N = a.shape[-1] + op = cmd.math_op or "identity" + + pid, reply_q = self._alloc_pipeline(env) + + pipeline = MathPipeline( + env=env, M=M, N=N, + tile_m=self._tile_m, tile_n=self._tile_n, + bytes_per_element=self._bpe, + pipeline_id=pid, + reply_queue=reply_q, + dmaIN_cmd_q=self._dmaIN_cmd_q, + dmaIN_to_fetch_trig=self._dmaIN_to_fetch_trig, + op=op, + src_addr=a.addr, dst_addr=cmd.out_addr, + ) + + # Load descriptors into shared blocks + self._dmaIN_block.load_descriptors(pipeline.dmaIN_descs) + self._math_block.load_math_op_descriptors(pipeline.math_op_descs) + self._dmaWB_block.load_descriptors(pipeline.dmaWB_descs) + + start_ns = env.now + yield pipeline.done + wall_clock_ns = env.now - start_ns + del self._pipeline_queues[pid] + + # -- Collect metrics --------------------------------------------------- + rd = pe_txn.result_data + rd["M_tiles"] = pipeline.M_tiles + rd["N_tiles"] = pipeline.N_tiles + rd["total_tiles"] = pipeline.M_tiles * pipeline.N_tiles + rd["wall_clock_ns"] = wall_clock_ns + rd["math_op"] = op + + # DMA timing + rd["t_dma_read_per_request"], rd["t_dma_read_max_ns"], rd["t_dma_read_total_ns"] = ( + _collect_timing(self._dmaIN_block.t_dma_read_per_request.pop(pid, [])) + ) + rd["t_dma_write_per_tile"], rd["t_dma_write_max_ns"], rd["t_dma_write_total_ns"] = ( + _collect_timing(self._dmaWB_block.t_dma_write_per_tile.pop(pid, [])) + ) + + # Math op timing (load + compute + store) + rd["t_math_load_per_tile"], rd["t_math_load_max_ns"], rd["t_math_load_total_ns"] = ( + _collect_timing(self._math_block.t_math_op_load_per_tile.pop(pid, [])) + ) + rd["t_math_compute_per_tile"], rd["t_math_compute_max_ns"], rd["t_math_compute_total_ns"] = ( + _collect_timing(self._math_block.t_math_op_compute_per_tile.pop(pid, [])) + ) + rd["t_math_store_per_tile"], rd["t_math_store_max_ns"], rd["t_math_store_total_ns"] = ( + _collect_timing(self._math_block.t_math_op_store_per_tile.pop(pid, [])) + ) + + # pe_cpu.py compatibility aliases + rd["dma_ns"] = rd["t_dma_read_total_ns"] + rd["t_dma_write_total_ns"] + rd["compute_ns"] = rd["t_math_compute_total_ns"] + + pe_txn.done.succeed() + + # -- Helpers --------------------------------------------------------------- + + def _alloc_pipeline(self, env: simpy.Environment) -> tuple[int, simpy.Store]: + """Allocate a pipeline ID and reply queue.""" + pid = self._next_pipeline_id + self._next_pipeline_id += 1 + reply_q = simpy.Store(env) + self._pipeline_queues[pid] = reply_q + return pid, reply_q + + +# ============================================================================== +# Utility +# ============================================================================== + +def _collect_timing(times: list) -> tuple[list, float, float]: + """Return (raw_list, max_ns, total_ns) from a timing list.""" + return times, max(times, default=0.0), sum(times) diff --git a/src/kernbench/components/custom/pe_accel_legacy/scheduler/tiling.py b/src/kernbench/components/custom/pe_accel_legacy/scheduler/tiling.py new file mode 100644 index 0000000..c84f577 --- /dev/null +++ b/src/kernbench/components/custom/pe_accel_legacy/scheduler/tiling.py @@ -0,0 +1,121 @@ +"""Tile schedule generators for GEMM and element-wise math operations. + +Each generator produces a plan of tile commands with pre-computed addresses. +Pipelines use these plans to build descriptor tables and feed commands +to the shared hardware blocks. +""" +from __future__ import annotations + +from math import ceil + +from kernbench.components.custom.pe_accel.types import ( + CmdType, + MathSchedulePlan, + MathTileCommand, + SchedulePlan, + TileCommand, +) + + +def generate_gemm_tiles( + M_tiles: int, K_tiles: int, N_tiles: int, + M: int = 0, K: int = 0, N: int = 0, + tile_m: int = 0, tile_k: int = 0, tile_n: int = 0, + bytes_per_element: int = 2, + A_addr: int = 0, B_addr: int = 0, C_addr: int = 0, + pipeline_id: int = 0, +) -> SchedulePlan: + """Generate GEMM tile commands in M (outer) -> N -> K (inner) order. + + Stamps is_last_k=True on the final K-tile per (m, n) pair. + Emits one DMA_FLUSH per (m, n) pair after all K tiles. + + Per-tile addresses (row-major layout): + A (M,K): A_addr + (m * tile_m * K + k * tile_k) * bpe + B (K,N): B_addr + (k * tile_k * N + n * tile_n) * bpe + C (M,N): C_addr + (m * tile_m * N + n * tile_n) * bpe + """ + commands: list[TileCommand] = [] + cmd_id = 0 + tile_id = 0 + bpe = bytes_per_element + + for m in range(M_tiles): + for n in range(N_tiles): + c_tile_addr = C_addr + (m * tile_m * N + n * tile_n) * bpe + + for k in range(K_tiles): + last_k = k == K_tiles - 1 + a_tile_addr = A_addr + (m * tile_m * K + k * tile_k) * bpe + b_tile_addr = B_addr + (k * tile_k * N + n * tile_n) * bpe + + commands.append(TileCommand( + cmd_id=cmd_id, cmd_type=CmdType.DMA_LOAD, + tile_id=tile_id, m_idx=m, k_idx=k, n_idx=n, + is_last_k=last_k, pipeline_id=pipeline_id, + a_tile_addr=a_tile_addr, b_tile_addr=b_tile_addr, + c_tile_addr=c_tile_addr, + )) + cmd_id += 1 + + commands.append(TileCommand( + cmd_id=cmd_id, cmd_type=CmdType.TENSOR_OP, + tile_id=tile_id, m_idx=m, k_idx=k, n_idx=n, + is_last_k=last_k, pipeline_id=pipeline_id, + a_tile_addr=a_tile_addr, b_tile_addr=b_tile_addr, + c_tile_addr=c_tile_addr, + )) + cmd_id += 1 + tile_id += 1 + + # One flush per (m, n) pair after all K tiles + commands.append(TileCommand( + cmd_id=cmd_id, cmd_type=CmdType.DMA_FLUSH, + tile_id=tile_id - 1, m_idx=m, k_idx=0, n_idx=n, + pipeline_id=pipeline_id, + c_tile_addr=c_tile_addr, + )) + cmd_id += 1 + + return SchedulePlan( + commands=commands, M_tiles=M_tiles, K_tiles=K_tiles, N_tiles=N_tiles + ) + + +def generate_math_tiles( + M_tiles: int, N_tiles: int, + M: int = 0, N: int = 0, + tile_m: int = 0, tile_n: int = 0, + bytes_per_element: int = 2, + src_addr: int = 0, dst_addr: int = 0, + pipeline_id: int = 0, +) -> MathSchedulePlan: + """Generate element-wise math tile commands in row-major order. + + Per-tile addresses (row-major layout): + src: src_addr + (m * tile_m * N + n * tile_n) * bpe + dst: dst_addr + (m * tile_m * N + n * tile_n) * bpe + """ + commands: list[MathTileCommand] = [] + cmd_id = 0 + tile_id = 0 + bpe = bytes_per_element + + for m in range(M_tiles): + for n in range(N_tiles): + offset = (m * tile_m * N + n * tile_n) * bpe + commands.append(MathTileCommand( + cmd_id=cmd_id, + tile_id=tile_id, + m_idx=m, + n_idx=n, + src_tile_addr=src_addr + offset, + dst_tile_addr=dst_addr + offset, + pipeline_id=pipeline_id, + )) + cmd_id += 1 + tile_id += 1 + + return MathSchedulePlan( + commands=commands, M_tiles=M_tiles, N_tiles=N_tiles + ) diff --git a/src/kernbench/components/custom/pe_accel_legacy/types.py b/src/kernbench/components/custom/pe_accel_legacy/types.py new file mode 100644 index 0000000..0e6babe --- /dev/null +++ b/src/kernbench/components/custom/pe_accel_legacy/types.py @@ -0,0 +1,148 @@ +"""Data types for pe_accel_v1: descriptors, triggers, tile commands. + +All types are frozen/plain dataclasses with no logic. +Schedule generators live in tiling/schedule.py. +""" +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum, auto + + +# -- Enums --------------------------------------------------------------------- + +class CmdType(Enum): + DMA_LOAD = auto() + TENSOR_OP = auto() + DMA_FLUSH = auto() + + +# -- Inter-block messaging ----------------------------------------------------- + +@dataclass +class Trigger: + """Completion token passed between hardware blocks.""" + + tile_id: int + pipeline_id: int + vc: int | None = None + source_block: str = "" + + +@dataclass +class DmaRequest: + """DMA load request — descriptor lookup key only. + + Transfer params (size, address) live in the pre-loaded DmaInDescriptor. + """ + + tile_id: int + pipeline_id: int + operand: str # "A" or "B" + + +# -- Descriptors (pre-loaded by pipelines, consumed by blocks) ----------------- + +@dataclass +class DmaInDescriptor: + """Per-operand DMA read descriptor.""" + + size_bytes: int + src_addr: int = 0 + next_block: str = "GEMM" # "GEMM" | "MATH" | "COMPLETION" + + +@dataclass +class GemmDescriptor: + """Per-tile GEMM descriptor.""" + + Tm: int + Tk: int + Tn: int + triggers_needed: int = 2 # 2 = both operands from DMA; 1 = SPMem bypass + gemm_load: bool = True + gemm_compute: bool = True + next_block: str = "MATH" # "MATH" | "DMAWB" | "DONE" + + +@dataclass +class MathDescriptor: + """Per-tile K-accumulation descriptor (used by GEMM pipeline).""" + + Tm: int + Tn: int + is_last_k: bool + skip_dmaWB: bool # True = C stays in SPMem; False = flush to HBM + + +@dataclass +class MathOpDescriptor: + """Per-tile element-wise math op descriptor (used by math pipeline).""" + + Tm: int + Tn: int + op: str # "exp", "log", "sqrt", "sigmoid", etc. + src_addr: int = 0 + dst_addr: int = 0 + + +@dataclass +class DmaWBDescriptor: + """Per-tile DMA writeback descriptor.""" + + Tm: int + Tn: int + dst_addr: int = 0 + + +# -- Tile commands (produced by schedule generators) --------------------------- + +@dataclass +class TileCommand: + """A single GEMM tile command.""" + + cmd_id: int + cmd_type: CmdType + tile_id: int + m_idx: int + k_idx: int + n_idx: int + is_last_k: bool = False + pipeline_id: int = 0 + a_tile_addr: int = 0 + b_tile_addr: int = 0 + c_tile_addr: int = 0 + + +@dataclass +class SchedulePlan: + """Full tile schedule for one GEMM operation.""" + + commands: list # list[TileCommand] + M_tiles: int + K_tiles: int + N_tiles: int + + +@dataclass +class MathTileCommand: + """A single element-wise math tile command.""" + + cmd_id: int + tile_id: int + m_idx: int + n_idx: int + src_tile_addr: int = 0 + dst_tile_addr: int = 0 + pipeline_id: int = 0 + + +@dataclass +class MathSchedulePlan: + """Full tile schedule for one element-wise math operation.""" + + commands: list # list[MathTileCommand] + M_tiles: int + N_tiles: int + + diff --git a/src/kernbench/topology/builder.py b/src/kernbench/topology/builder.py index a6173bc..c525f4f 100644 --- a/src/kernbench/topology/builder.py +++ b/src/kernbench/topology/builder.py @@ -20,6 +20,7 @@ _PE_COMP_OFFSETS = { "pe_cpu": (-0.3, 0.0), "pe_scheduler": (-0.15, 0.0), "pe_dma": (0.0, -0.15), + "pe_fetch_store": (0.15, 0.0), "pe_gemm": (0.0, 0.0), "pe_math": (0.0, 0.15), "pe_mmu": (0.15, -0.15), @@ -637,12 +638,13 @@ def _instantiate_cube( def _add_pe_internal_edges(edges: list[Edge], pp: str, pe_links: dict) -> None: - """Add PE-internal edges for a single PE instance.""" + """Add PE-internal edges for a single PE instance (ADR-0021).""" edges.append(Edge( src=f"{pp}.pe_cpu", dst=f"{pp}.pe_scheduler", distance_mm=pe_links["pe_cpu_to_scheduler_mm"], kind="pe_internal", )) + # Scheduler → engines (initial dispatch) for eng, key in [("pe_dma", "scheduler_to_dma_mm"), ("pe_gemm", "scheduler_to_gemm_mm"), ("pe_math", "scheduler_to_math_mm")]: @@ -651,6 +653,15 @@ def _add_pe_internal_edges(edges: list[Edge], pp: str, pe_links: dict) -> None: distance_mm=pe_links[key], kind="pe_internal", )) + # Scheduler → fetch_store (initial dispatch) + if "scheduler_to_fetch_store_mm" in pe_links: + edges.append(Edge( + src=f"{pp}.pe_scheduler", dst=f"{pp}.pe_fetch_store", + distance_mm=pe_links["scheduler_to_fetch_store_mm"], + kind="pe_internal", + )) + + # Engine → TCM (legacy BW edges) for eng, mm_key, bw_key in [("pe_dma", "dma_to_tcm_mm", "dma_to_tcm_bw_gbs"), ("pe_gemm", "gemm_to_tcm_mm", "gemm_to_tcm_bw_gbs"), ("pe_math", "math_to_tcm_mm", "math_to_tcm_bw_gbs")]: @@ -661,6 +672,32 @@ def _add_pe_internal_edges(edges: list[Edge], pp: str, pe_links: dict) -> None: kind="pe_internal", )) + # Fetch/Store → TCM (ADR-0021 D5) + if "fetch_store_to_tcm_mm" in pe_links: + edges.append(Edge( + src=f"{pp}.pe_fetch_store", dst=f"{pp}.pe_tcm", + distance_mm=pe_links["fetch_store_to_tcm_mm"], + bw_gbs=pe_links.get("fetch_store_to_tcm_bw_gbs", 512.0), + kind="pe_internal", + )) + + # Chaining edges (ADR-0021 D4 — token self-routing) + chaining = [ + ("pe_dma", "pe_fetch_store", "dma_to_fetch_store_mm"), + ("pe_fetch_store", "pe_gemm", "fetch_store_to_gemm_mm"), + ("pe_fetch_store", "pe_math", "fetch_store_to_math_mm"), + ("pe_gemm", "pe_fetch_store", "gemm_to_fetch_store_mm"), + ("pe_math", "pe_fetch_store", "math_to_fetch_store_mm"), + ("pe_fetch_store", "pe_dma", "fetch_store_to_dma_mm"), + ] + for src_eng, dst_eng, mm_key in chaining: + if mm_key in pe_links: + edges.append(Edge( + src=f"{pp}.{src_eng}", dst=f"{pp}.{dst_eng}", + distance_mm=pe_links[mm_key], + kind="pe_internal", + )) + # ── Inter-cube / IO / system edges ────────────────────────────────── @@ -1071,6 +1108,7 @@ def _build_pe_view(spec: dict) -> ViewGraph: "pe_cpu": (1.5, 4.0), "pe_scheduler": (4.0, 4.0), "pe_dma": (7.0, 1.5), + "pe_fetch_store": (8.5, 4.0), "pe_gemm": (7.0, 4.0), "pe_math": (7.0, 6.5), "pe_mmu": (4.0, 1.5), @@ -1101,6 +1139,12 @@ def _build_pe_view(spec: dict) -> ViewGraph: distance_mm=pe_links[key], kind="pe_internal", )) + if "scheduler_to_fetch_store_mm" in pe_links: + view_edges.append(Edge( + src="pe_scheduler", dst="pe_fetch_store", + distance_mm=pe_links["scheduler_to_fetch_store_mm"], + kind="pe_internal", + )) for eng, mm_key, bw_key in [("pe_dma", "dma_to_tcm_mm", "dma_to_tcm_bw_gbs"), ("pe_gemm", "gemm_to_tcm_mm", "gemm_to_tcm_bw_gbs"), ("pe_math", "math_to_tcm_mm", "math_to_tcm_bw_gbs")]: @@ -1110,6 +1154,13 @@ def _build_pe_view(spec: dict) -> ViewGraph: bw_gbs=pe_links[bw_key], kind="pe_internal", )) + if "fetch_store_to_tcm_mm" in pe_links: + view_edges.append(Edge( + src="pe_fetch_store", dst="pe_tcm", + distance_mm=pe_links["fetch_store_to_tcm_mm"], + bw_gbs=pe_links.get("fetch_store_to_tcm_bw_gbs", 512.0), + kind="pe_internal", + )) return ViewGraph( name="pe", nodes=nodes, edges=view_edges, diff --git a/tests/test_topology_compile.py b/tests/test_topology_compile.py index c2934ed..ae849aa 100644 --- a/tests/test_topology_compile.py +++ b/tests/test_topology_compile.py @@ -19,16 +19,16 @@ def test_full_graph_node_count(): # + 2 SIPs x (1 IO x 23 io_nodes # + 16 cubes x (32 routers + 1 hbm_ctrl + 1 m_cpu + 1 sram # + 20 ucie (4 ports x (1 port + 4 conn)) - # + 8 PEs x 7 pe_comps)) + # + 8 PEs x 8 pe_comps)) (ADR-0021: +pe_fetch_store) # IO: pcie_ep + io_cpu + noc + 4 io_ucie_ports + 4*4 io_ucie_conn = 23 - # cube: 32 + 3 + 20 + 56 = 111 - # = 1 + 2*(23 + 16*111) = 1 + 2*(23+1776) = 1 + 3598 = 3599 - assert len(g.nodes) == 3599 + # cube: 32 + 3 + 20 + 64 = 119 + # = 1 + 2*(23 + 16*119) = 1 + 2*(23+1904) = 1 + 3854 = 3855 + assert len(g.nodes) == 3855 def test_full_graph_edge_count(): g = _graph() - assert len(g.edges) == 10874 + assert len(g.edges) == 12922 # ADR-0021: +pe_fetch_store + chaining edges # -- Full graph: specific nodes exist ----------------------------------------- @@ -286,7 +286,8 @@ def test_cube_view_pe_to_router(): def test_pe_view_has_all_components(): v = _graph().pe_view assert set(v.nodes.keys()) == { - "pe_cpu", "pe_scheduler", "pe_dma", "pe_gemm", "pe_math", "pe_mmu", "pe_tcm" + "pe_cpu", "pe_scheduler", "pe_dma", "pe_fetch_store", + "pe_gemm", "pe_math", "pe_mmu", "pe_tcm", } diff --git a/tests/test_topology_load.py b/tests/test_topology_load.py index e16db62..82f2859 100644 --- a/tests/test_topology_load.py +++ b/tests/test_topology_load.py @@ -23,7 +23,8 @@ def test_pe_template_components(): spec = _read_spec(TOPOLOGY_PATH) comps = spec["cube"]["pe_template"]["components"] assert set(comps.keys()) == { - "pe_cpu", "pe_scheduler", "pe_dma", "pe_gemm", "pe_math", "pe_mmu", "pe_tcm" + "pe_cpu", "pe_scheduler", "pe_dma", "pe_fetch_store", + "pe_gemm", "pe_math", "pe_mmu", "pe_tcm", } diff --git a/topology.yaml b/topology.yaml index 52ed3ae..e7a0970 100644 --- a/topology.yaml +++ b/topology.yaml @@ -63,19 +63,28 @@ cube: pe_cpu: { kind: pe_cpu, impl: pe_cpu_v1, attrs: { overhead_ns: 2.0 } } pe_scheduler: { kind: pe_scheduler, impl: pe_scheduler_v2, attrs: { overhead_ns: 1.0 } } pe_dma: { kind: pe_dma, impl: pe_dma_v1, attrs: { rd_engines: 1, wr_engines: 1 } } - pe_gemm: { kind: pe_gemm, impl: pe_gemm_v1, attrs: { overhead_ns: 0.0, shared_resource: accel_slot, peak_tflops_f16: 8.0 } } - pe_math: { kind: pe_math, impl: pe_math_v1, attrs: { overhead_ns: 0.0, shared_resource: accel_slot } } - pe_mmu: { kind: pe_mmu, impl: pe_mmu_v1, attrs: { tlb_overhead_ns: 0.5, page_size: 4096 } } - pe_tcm: { kind: pe_tcm, impl: pe_tcm_v1, attrs: - { size_mb: 16 } } + pe_gemm: { kind: pe_gemm, impl: pe_gemm_v1, attrs: { overhead_ns: 0.0, shared_resource: accel_slot, peak_tflops_f16: 8.0 } } + pe_math: { kind: pe_math, impl: pe_math_v1, attrs: { overhead_ns: 0.0, shared_resource: accel_slot } } + pe_fetch_store: { kind: pe_fetch_store, impl: pe_fetch_store_v1, attrs: { overhead_ns: 0.0 } } + pe_mmu: { kind: pe_mmu, impl: pe_mmu_v1, attrs: { tlb_overhead_ns: 0.5, page_size: 4096 } } + pe_tcm: { kind: pe_tcm, impl: pe_tcm_v1, attrs: { size_mb: 16, read_bw_gbs: 512.0, write_bw_gbs: 512.0 } } links: pe_cpu_to_scheduler_mm: 0.5 scheduler_to_dma_mm: 0.5 scheduler_to_gemm_mm: 0.5 scheduler_to_math_mm: 0.5 + scheduler_to_fetch_store_mm: 0.5 dma_to_tcm_bw_gbs: 512.0 dma_to_tcm_mm: 0.5 - gemm_to_tcm_bw_gbs: 512.0 # GEMM reads inputs from TCM (ADR-0014 D5) + dma_to_fetch_store_mm: 0.0 # DMA → fetch_store chaining (ADR-0021) + fetch_store_to_tcm_bw_gbs: 512.0 + fetch_store_to_tcm_mm: 0.0 + fetch_store_to_gemm_mm: 0.0 # fetch → GEMM chaining (ADR-0021) + fetch_store_to_math_mm: 0.0 # fetch → MATH chaining (ADR-0021) + gemm_to_fetch_store_mm: 0.0 # GEMM → store chaining (ADR-0021) + math_to_fetch_store_mm: 0.0 # MATH → store chaining (ADR-0021) + fetch_store_to_dma_mm: 0.0 # store → DMA writeback chaining (ADR-0021) + gemm_to_tcm_bw_gbs: 512.0 gemm_to_tcm_mm: 0.5 math_to_tcm_bw_gbs: 512.0 math_to_tcm_mm: 0.5