Implement ADR-0021: PE pipeline refactor with token self-routing

Step 1-2: Backup existing code
- builtin/ → builtin_legacy/ (unchanged backup)
- custom/pe_accel/ → custom/pe_accel_legacy/ (unchanged backup)

Step 3-4: New pipeline types and tiling
- pe_types.py: StageType, Stage, TilePlan, PipelinePlan, PipelineContext, TileToken
- tiling.py: generate_gemm_plan, generate_math_plan (ported from pe_accel)

Step 5: Component implementations (ADR-0021 D4-D6)
- PE_SCHEDULER: _feed_loop (singleton FIFO feeder) + plan generation
- PE_FETCH_STORE: new component — TCM ↔ Register File
- PE_GEMM: TileToken pipeline + legacy PeInternalTxn dual-mode
- PE_MATH: TileToken pipeline + legacy dual-mode
- PE_DMA: TileToken pipeline + legacy + fabric Transaction triple-mode
- PE_TCM: TcmRequest handler with dual-channel BW serialization

Step 6: Infrastructure
- topology.yaml: pe_fetch_store component + chaining edges
- components.yaml: pe_fetch_store_v1 registration
- builder.py: PE_COMP_OFFSETS, _add_pe_internal_edges, PE view positions
- Tests: node/edge counts, PE component sets updated

All components handle both TileToken (pipeline) and PeInternalTxn (legacy).
Token self-routing: components read next stage from token.plan, chain via out_port.
366 tests passing.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-04-08 23:35:31 -07:00
parent 161132cdcb
commit b6eb97c49a
40 changed files with 4055 additions and 214 deletions
+1
View File
@@ -43,6 +43,7 @@ components:
pe_dma_v1: kernbench.components.builtin.pe_dma:PeDmaComponent
pe_gemm_v1: kernbench.components.builtin.pe_gemm:PeGemmComponent
pe_math_v1: kernbench.components.builtin.pe_math:PeMathComponent
pe_fetch_store_v1: kernbench.components.builtin.pe_fetch_store:PeFetchStoreComponent
pe_mmu_v1: kernbench.components.builtin.pe_mmu:PeMmuComponent
pe_tcm_v1: kernbench.components.builtin.pe_tcm:PeTcmComponent
@@ -105,6 +105,73 @@ class PeDmaComponent(PeEngineBase):
yield sub_done
pe_txn.done.succeed()
def _worker(self, env: simpy.Environment) -> Generator:
"""Handle TileToken (pipeline), PeInternalTxn (legacy), and Transaction (fabric)."""
from kernbench.common.pe_commands import PeInternalTxn
from kernbench.components.builtin.pe_types import TileToken
while True:
msg: Any = yield self._inbox.get()
if isinstance(msg, TileToken):
env.process(self._pipeline_process(env, msg))
elif isinstance(msg, PeInternalTxn):
env.process(self._handle_with_hooks(env, msg))
else:
env.process(self._forward_txn(env, msg))
def _pipeline_process(self, env: simpy.Environment, token: Any) -> Generator:
"""Pipeline mode: DMA read/write via fabric, then self-route."""
from kernbench.common.pe_commands import DmaReadCmd, DmaWriteCmd, TensorHandle
from kernbench.policy.address.phyaddr import PhysAddr
from kernbench.runtime_api.kernel import PeDmaMsg
self._on_process_start(env, token)
params = token.params
stage_type = token.current_stage.stage_type
from kernbench.components.builtin.pe_types import StageType
is_write = stage_type == StageType.DMA_WRITE
addr = params.get("dst_addr" if is_write else "src_addr", 0)
nbytes = params.get("nbytes", 0)
if nbytes > 0 and self.ctx:
dma_res = self._dma_write if is_write else self._dma_read
assert dma_res is not None
pa = PhysAddr.decode(addr)
dst_node = self.ctx.resolver.resolve(pa)
path = self.ctx.router.find_path(self._pe_prefix, dst_node)
drain_ns = self.ctx.compute_drain_ns(path, nbytes)
with dma_res.request() as req:
yield req
sub_done = env.event()
sub_request = PeDmaMsg(
correlation_id="pipeline",
request_id=f"tile_{token.tile_id}",
src_sip=0, src_cube=0, src_pe=0,
dst_pa=addr, nbytes=nbytes,
is_write=is_write,
)
sub_txn = Transaction(
request=sub_request, path=path, step=0,
nbytes=nbytes, done=sub_done, drain_ns=drain_ns,
)
if len(path) > 1:
yield self.out_ports[path[1]].put(sub_txn.advance())
yield sub_done
self._on_process_end(env, token)
# Self-routing
next_stage = token.advance()
if next_stage is not None:
yield self.out_ports[next_stage.component].put(token)
else:
token.pipeline_ctx.complete_tile()
def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator:
"""Handle external Transaction (PeDmaMsg probe, M_CPU DMA) with channel acquisition."""
# Response transactions bypass DMA channel (no outbound resource needed)
@@ -0,0 +1,77 @@
"""PE_FETCH_STORE: TCM ↔ Register File transfer unit (ADR-0021 D5).
Handles both fetch (TCM → register) and store (register → TCM).
BW serialization is delegated to PE_TCM via port communication.
"""
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
import simpy
from kernbench.components.base import PeEngineBase
if TYPE_CHECKING:
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class PeFetchStoreComponent(PeEngineBase):
"""PE_FETCH_STORE: TCM ↔ Register File (ADR-0021 D5).
Receives TileTokens via pipeline self-routing.
Sends TcmRequest to PE_TCM for BW-based latency.
"""
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
super().__init__(node, ctx)
self._tcm_id = f"{self._pe_prefix}.pe_tcm"
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
yield env.timeout(overhead_ns)
def _worker(self, env: simpy.Environment) -> Generator:
"""Handle both PeInternalTxn (legacy) and TileToken (pipeline)."""
from kernbench.common.pe_commands import PeInternalTxn
from kernbench.components.builtin.pe_types import TileToken
while True:
msg: Any = yield self._inbox.get()
if isinstance(msg, TileToken):
env.process(self._pipeline_process(env, msg))
elif isinstance(msg, PeInternalTxn):
env.process(self.handle_command(env, msg))
else:
env.process(self._forward_txn(env, msg))
def _pipeline_process(self, env: simpy.Environment, token: Any) -> Generator:
"""Process a pipeline TileToken: fetch or store via TCM."""
from kernbench.components.builtin.pe_tcm import TcmRequest
self._on_process_start(env, token)
direction = token.params.get("direction", "read")
nbytes = token.params.get("nbytes", 0)
if nbytes > 0 and self._tcm_id in self.out_ports:
done = env.event()
yield self.out_ports[self._tcm_id].put(
TcmRequest(direction=direction, nbytes=nbytes, done=done)
)
yield done
self._on_process_end(env, token)
# Self-routing: advance to next stage
next_stage = token.advance()
if next_stage is not None:
yield self.out_ports[next_stage.component].put(token)
else:
token.pipeline_ctx.complete_tile()
def handle_command(self, env: simpy.Environment, pe_txn: Any) -> Generator:
"""Legacy PeInternalTxn handling."""
yield from self.run(env, 0)
pe_txn.done.succeed()
+77 -16
View File
@@ -1,6 +1,18 @@
"""PE_GEMM: matrix multiplication engine (ADR-0021 D6).
Handles both legacy PeInternalTxn (GemmCmd) and pipeline TileToken.
In pipeline mode, receives token after fetch stage, computes MAC, chains to next.
MAC latency model (from pe_accel):
cycles = ceil(Tm/mac_m) * ceil(Tk/mac_k) * ceil(Tn/mac_n)
latency_ns = cycles / clock_freq_ghz
Falls back to TFLOPS model when mac dimensions not configured.
"""
from __future__ import annotations
from collections.abc import Generator
from math import ceil
from typing import TYPE_CHECKING, Any
import simpy
@@ -12,33 +24,29 @@ if TYPE_CHECKING:
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
# dtype → bit width (for TFLOPS scaling)
_DTYPE_BITS: dict[str, int] = {
"f16": 16, "fp16": 16, "float16": 16, "bf16": 16,
"f32": 32, "fp32": 32, "float32": 32,
"i8": 8, "int8": 8,
"i16": 16, "int16": 16,
"i32": 32, "int32": 32,
"i8": 8, "int8": 8, "i16": 16, "int16": 16, "i32": 32, "int32": 32,
}
class PeGemmComponent(PeEngineBase):
"""PE_GEMM: matrix multiplication engine sharing accel_slot (ADR-0014 D4).
"""PE_GEMM: MAC array (ADR-0021 D6).
Uses a shared compute resource (PE_ACCEL capacity=1) that is mutually
exclusive with PE_MATH within the same PE.
Compute latency model:
FLOPs = 2 * M * K * N
effective_tflops = peak_tflops_f16 * (16 / dtype_bits)
compute_ns = FLOPs / (effective_tflops * 1e3)
In pipeline mode: pure compute — register data already fetched.
In legacy mode: handles PeInternalTxn(GemmCmd) with shared accel_slot.
"""
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
super().__init__(node, ctx)
self._accel: simpy.Resource | None = None
self._peak_tflops_f16: float = float(node.attrs.get("peak_tflops_f16", 0.0))
# Cycle-accurate MAC dimensions (from pe_accel)
self._mac_m: int = int(node.attrs.get("mac_m", 0))
self._mac_k: int = int(node.attrs.get("mac_k", 0))
self._mac_n: int = int(node.attrs.get("mac_n", 0))
self._clock_freq: float = float(node.attrs.get("clock_freq_ghz", 1.0))
def init_resources(self, env: simpy.Environment) -> None:
resource_name = self.node.attrs.get("shared_resource")
@@ -47,8 +55,15 @@ class PeGemmComponent(PeEngineBase):
env, f"{self._pe_prefix}.{resource_name}"
)
def _compute_ns(self, m: int, k: int, n: int, dtype: str) -> float:
"""Compute GEMM latency in nanoseconds."""
def _compute_ns_mac(self, m: int, k: int, n: int) -> float:
"""Cycle-accurate MAC latency (pe_accel model)."""
if self._mac_m > 0 and self._mac_k > 0 and self._mac_n > 0:
cycles = ceil(m / self._mac_m) * ceil(k / self._mac_k) * ceil(n / self._mac_n)
return cycles / self._clock_freq
return 0.0
def _compute_ns_tflops(self, m: int, k: int, n: int, dtype: str = "f16") -> float:
"""TFLOPS-based latency (legacy model)."""
if self._peak_tflops_f16 <= 0:
return float(self.node.attrs.get("overhead_ns", 0.0))
dtype_bits = _DTYPE_BITS.get(dtype, 16)
@@ -56,11 +71,58 @@ class PeGemmComponent(PeEngineBase):
flops = 2.0 * m * k * n
return flops / (effective_tflops * 1e3)
def _compute_ns(self, m: int, k: int, n: int, dtype: str = "f16") -> float:
"""Choose best available latency model."""
mac_ns = self._compute_ns_mac(m, k, n)
if mac_ns > 0:
return mac_ns
return self._compute_ns_tflops(m, k, n, dtype)
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
yield env.timeout(overhead_ns)
def _worker(self, env: simpy.Environment) -> Generator:
from kernbench.common.pe_commands import PeInternalTxn
from kernbench.components.builtin.pe_types import TileToken
while True:
msg: Any = yield self._inbox.get()
if isinstance(msg, TileToken):
env.process(self._pipeline_process(env, msg))
elif isinstance(msg, PeInternalTxn):
env.process(self._handle_with_hooks(env, msg))
else:
env.process(self._forward_txn(env, msg))
def _pipeline_process(self, env: simpy.Environment, token: Any) -> Generator:
"""Pipeline mode: pure MAC compute, then self-route."""
self._on_process_start(env, token)
m = token.params.get("m", 0)
k = token.params.get("k", 0)
n = token.params.get("n", 0)
if self._accel:
with self._accel.request() as req:
yield req
ns = self._compute_ns(m, k, n)
yield env.timeout(ns)
else:
ns = self._compute_ns(m, k, n)
yield env.timeout(ns)
self._on_process_end(env, token)
# Self-routing
next_stage = token.advance()
if next_stage is not None:
yield self.out_ports[next_stage.component].put(token)
else:
token.pipeline_ctx.complete_tile()
def handle_command(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
"""Legacy PeInternalTxn handling."""
from kernbench.common.pe_commands import GemmCmd
cmd = pe_txn.command
@@ -81,7 +143,6 @@ class PeGemmComponent(PeEngineBase):
pe_txn.done.succeed()
def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator:
"""Transaction forwarding with accel_slot acquisition."""
if self._accel:
with self._accel.request() as req:
yield req
+60 -4
View File
@@ -1,6 +1,16 @@
"""PE_MATH: element-wise / reduction computation engine (ADR-0021 D6).
Handles both legacy PeInternalTxn (MathCmd) and pipeline TileToken.
In pipeline mode, receives token after fetch stage, computes SIMD, chains to next.
SIMD latency model (from pe_accel):
cycles = ceil(num_elements / vector_width)
latency_ns = cycles / clock_freq_ghz
"""
from __future__ import annotations
from collections.abc import Generator
from math import ceil
from typing import TYPE_CHECKING, Any
import simpy
@@ -14,15 +24,17 @@ if TYPE_CHECKING:
class PeMathComponent(PeEngineBase):
"""PE_MATH: element-wise computation engine sharing accel_slot (ADR-0014 D4).
"""PE_MATH: SIMD/Vector unit (ADR-0021 D6).
Uses a shared compute resource (PE_ACCEL capacity=1) that is mutually
exclusive with PE_GEMM within the same PE.
In pipeline mode: pure compute — register data already fetched.
In legacy mode: handles PeInternalTxn(MathCmd) with shared accel_slot.
"""
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
super().__init__(node, ctx)
self._accel: simpy.Resource | None = None
self._vector_width: int = int(node.attrs.get("vector_width", 256))
self._clock_freq: float = float(node.attrs.get("clock_freq_ghz", 1.0))
def init_resources(self, env: simpy.Environment) -> None:
resource_name = self.node.attrs.get("shared_resource")
@@ -31,11 +43,56 @@ class PeMathComponent(PeEngineBase):
env, f"{self._pe_prefix}.{resource_name}"
)
def _compute_ns(self, num_elements: int) -> float:
"""SIMD latency (pe_accel model)."""
if self._vector_width > 0 and self._clock_freq > 0 and num_elements > 0:
cycles = ceil(num_elements / self._vector_width)
return cycles / self._clock_freq
return float(self.node.attrs.get("overhead_ns", 0.0))
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
yield env.timeout(overhead_ns)
def _worker(self, env: simpy.Environment) -> Generator:
from kernbench.common.pe_commands import PeInternalTxn
from kernbench.components.builtin.pe_types import TileToken
while True:
msg: Any = yield self._inbox.get()
if isinstance(msg, TileToken):
env.process(self._pipeline_process(env, msg))
elif isinstance(msg, PeInternalTxn):
env.process(self._handle_with_hooks(env, msg))
else:
env.process(self._forward_txn(env, msg))
def _pipeline_process(self, env: simpy.Environment, token: Any) -> Generator:
"""Pipeline mode: pure SIMD compute, then self-route."""
self._on_process_start(env, token)
num_elements = token.params.get("num_elements", 0)
if self._accel:
with self._accel.request() as req:
yield req
ns = self._compute_ns(num_elements)
yield env.timeout(ns)
else:
ns = self._compute_ns(num_elements)
yield env.timeout(ns)
self._on_process_end(env, token)
# Self-routing
next_stage = token.advance()
if next_stage is not None:
yield self.out_ports[next_stage.component].put(token)
else:
token.pipeline_ctx.complete_tile()
def handle_command(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
"""Legacy PeInternalTxn handling."""
if self._accel:
with self._accel.request() as req:
yield req
@@ -45,7 +102,6 @@ class PeMathComponent(PeEngineBase):
pe_txn.done.succeed()
def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator:
"""Transaction forwarding with accel_slot acquisition."""
if self._accel:
with self._accel.request() as req:
yield req
+101 -167
View File
@@ -1,3 +1,13 @@
"""PE_SCHEDULER: plan generation + tile dispatch (ADR-0021 D2).
Receives PeInternalTxn from PE_CPU, routes to engines:
- Simple commands (DmaReadCmd, GemmCmd, etc.) → direct dispatch to engine
- CompositeCmd → generate TilePlan, feed tiles via _feed_loop
Composite pipeline uses token self-routing (ADR-0021 D4):
Scheduler only does initial dispatch + completion tracking.
Tiles chain through components based on their plan's stage sequence.
"""
from __future__ import annotations
from collections.abc import Generator
@@ -14,29 +24,18 @@ if TYPE_CHECKING:
class PeSchedulerComponent(ComponentBase):
"""PE_SCHEDULER: sole dispatcher inside a PE (ADR-0014 D1).
"""PE_SCHEDULER: sole dispatcher inside a PE (ADR-0014 D1, ADR-0021 D2).
Receives PeInternalTxn from PE_CPU, routes to the appropriate engine:
- DmaReadCmd / DmaWriteCmd → PE_DMA
- GemmCmd → PE_GEMM
- MathCmd → PE_MATH
- CompositeCmd → tiled pipeline (Stage 3: ADR-0014 D3.2)
Simple commands are forwarded to the appropriate engine.
CompositeCmd creates a TilePlan and feeds tiles into the pipeline.
Composite GEMM pipeline (32x64x32 tiles):
DMA_READ(b_tile_t) → COMPUTE(t) → DMA_WRITE(out_tile_t)
with overlap: READ(t+1) || COMPUTE(t) || WRITE(t-1)
Applies scheduler overhead_ns before dispatching each command.
Non-PeInternalTxn messages are forwarded via inherited _forward_txn().
Single _feed_loop process per scheduler ensures FIFO command ordering.
"""
# Scheduler tile dimensions (ADR-0014 D3.2)
TILE_M = 32
TILE_K = 64
TILE_N = 32
# Command → engine suffix dispatch table.
# New engines: add a single entry here (e.g. ConvCmd: "pe_conv").
_CMD_DISPATCH: dict[type, str] = {}
@classmethod
@@ -44,7 +43,6 @@ class PeSchedulerComponent(ComponentBase):
if cls._CMD_DISPATCH:
return
from kernbench.common.pe_commands import DmaReadCmd, DmaWriteCmd, GemmCmd, MathCmd
cls._CMD_DISPATCH = {
DmaReadCmd: "pe_dma",
DmaWriteCmd: "pe_dma",
@@ -56,6 +54,13 @@ class PeSchedulerComponent(ComponentBase):
super().__init__(node, ctx)
self._pe_prefix = node.id.rsplit(".", 1)[0]
self._ensure_dispatch_table()
self._pending_feeds: simpy.Store | None = None
self._pipeline_counter = 0
def start(self, env: simpy.Environment) -> None:
self._pending_feeds = simpy.Store(env)
super().start(env)
env.process(self._feed_loop(env))
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
@@ -72,174 +77,103 @@ class PeSchedulerComponent(ComponentBase):
yield from self._forward_txn(env, msg)
def _dispatch(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
"""Route a PeInternalTxn to the correct engine via dispatch table."""
from kernbench.common.pe_commands import CompositeCmd
from kernbench.common.pe_commands import CompositeCmd, PeCpuOverheadCmd
# Scheduler overhead
yield from self.run(env, 0)
yield from self.run(env, 0) # scheduler overhead
cmd = pe_txn.command
# Check dispatch table first
# Simple command dispatch
engine_suffix = self._CMD_DISPATCH.get(type(cmd))
if engine_suffix is not None:
yield self.out_ports[f"{self._pe_prefix}.{engine_suffix}"].put(pe_txn)
return
# CompositeCmd: tiled pipeline (not a simple forward)
# CompositeCmd: generate plan and feed
if isinstance(cmd, CompositeCmd):
yield from self._dispatch_composite(env, pe_txn)
yield from self._dispatch_composite(env, pe_txn, cmd)
return
if isinstance(cmd, PeCpuOverheadCmd):
yield env.timeout(cmd.cycles)
pe_txn.done.succeed()
return
# Unknown command — signal done immediately
pe_txn.done.succeed()
def _dispatch_composite(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
"""Composite tiled pipeline (ADR-0014 D3.2).
def _dispatch_composite(
self, env: simpy.Environment, pe_txn: Any, cmd: Any,
) -> Generator:
"""Generate plan and enqueue to feeder. Non-blocking (ADR-0021 D4)."""
from kernbench.components.builtin.pe_types import PipelineContext
GEMM: 3-stage pipeline with b-tile streaming from HBM.
MATH: sequential compute + DMA_WRITE (no tiling).
plan = self._generate_plan(cmd)
self._pipeline_counter += 1
ctx = PipelineContext(
id=f"p{self._pipeline_counter}",
total_tiles=len(plan.tiles),
done_event=pe_txn.done,
)
# Enqueue to feeder — scheduler worker returns immediately
assert self._pending_feeds is not None
yield self._pending_feeds.put((plan, ctx))
def _feed_loop(self, env: simpy.Environment) -> Generator:
"""Single feeder process: FIFO command ordering (ADR-0021 D2).
No tile feed interleaving between commands.
Queue full → only this process blocks.
"""
from kernbench.common.pe_commands import CompositeCmd
from kernbench.components.builtin.pe_types import TileToken
assert self._pending_feeds is not None
while True:
plan, ctx = yield self._pending_feeds.get()
for tile in plan.tiles:
first_stage = tile.stages[0]
token = TileToken(
tile_id=tile.tile_id,
pipeline_ctx=ctx,
plan=tile,
stage_idx=0,
params=first_stage.params,
)
yield self.out_ports[first_stage.component].put(token)
def _generate_plan(self, cmd: Any) -> Any:
"""Generate a PipelinePlan from CompositeCmd."""
from kernbench.components.builtin.tiling import (
generate_gemm_plan,
generate_math_plan,
)
pp = self._pe_prefix
bpe = 2 # default bytes per element (f16)
cmd = pe_txn.command
assert isinstance(cmd, CompositeCmd)
if cmd.op == "gemm" and cmd.b is not None:
yield from self._pipeline_gemm(env, pe_txn, cmd)
a = cmd.a
b = cmd.b
M, K = a.shape[-2], a.shape[-1]
N = b.shape[-1]
return generate_gemm_plan(
M=M, K=K, N=N,
tile_m=self.TILE_M, tile_k=self.TILE_K, tile_n=self.TILE_N,
bytes_per_element=bpe,
A_addr=a.addr, B_addr=b.addr, C_addr=cmd.out_addr,
pe_prefix=pp,
)
else:
yield from self._pipeline_math(env, pe_txn, cmd)
def _pipeline_gemm(self, env: simpy.Environment, pe_txn: PeInternalTxn, cmd: Any) -> Generator:
"""Tiled GEMM pipeline: stream b tiles from HBM, compute, write results.
Tensor a is in TCM (loaded via tl.load). Tensor b is in HBM (via tl.ref).
Pipeline: DMA_READ(b_tile_t) -> COMPUTE(t) -> DMA_WRITE(out_tile_t)
Overlap: READ(t+1) || COMPUTE(t) || WRITE(t-1)
"""
from kernbench.common.pe_commands import (
DmaReadCmd,
DmaWriteCmd,
GemmCmd,
PeInternalTxn as PeTxn,
TensorHandle,
# Math composite
a = cmd.a
M = a.shape[-2] if len(a.shape) >= 2 else a.shape[0]
N = a.shape[-1] if len(a.shape) >= 2 else 1
return generate_math_plan(
M=M, N=N,
tile_m=self.TILE_M, tile_n=self.TILE_N,
bytes_per_element=bpe,
math_op=cmd.math_op or "identity",
src_addr=a.addr, dst_addr=cmd.out_addr,
pe_prefix=pp,
)
pp = self._pe_prefix
a = cmd.a # already in TCM
b = cmd.b # HBM reference (via tl.ref)
M, K_a = a.shape[-2], a.shape[-1]
K_b, N = b.shape[-2], b.shape[-1]
dtype = a.dtype
dtype_bytes = b.nbytes // (K_b * N) if (K_b * N) > 0 else 2
# Tile counts
n_tiles_k = max(1, (K_a + self.TILE_K - 1) // self.TILE_K)
n_tiles_n = max(1, (N + self.TILE_N - 1) // self.TILE_N)
n_tiles = n_tiles_k * n_tiles_n
prev_compute_done = None
prev_write_done = None
total_dma_ns = 0.0
total_compute_ns = 0.0
for tile_idx in range(n_tiles):
tk = tile_idx // n_tiles_n
tn = tile_idx % n_tiles_n
k_start = tk * self.TILE_K
n_start = tn * self.TILE_N
tile_k = min(self.TILE_K, K_a - k_start)
tile_n = min(self.TILE_N, N - n_start)
tile_nbytes = tile_k * tile_n * dtype_bytes
# --- Stage 1: DMA_READ b_tile from HBM ---
read_done = env.event()
b_tile_addr = b.addr + (k_start * N + n_start) * dtype_bytes
b_tile_handle = TensorHandle(
id=f"b_tile_{tile_idx}", addr=b_tile_addr,
shape=(tile_k, tile_n), dtype=dtype, nbytes=tile_nbytes,
)
read_cmd = DmaReadCmd(handle=b_tile_handle, src_addr=b_tile_addr, nbytes=tile_nbytes)
read_txn = PeTxn(command=read_cmd, done=read_done, pe_prefix=pp)
t0 = env.now
yield self.out_ports[f"{pp}.pe_dma"].put(read_txn)
# Wait for previous compute before starting this tile's compute
if prev_compute_done is not None:
yield prev_compute_done
# Wait for this tile's DMA_READ
yield read_done
total_dma_ns += env.now - t0
# --- Stage 2: COMPUTE (GEMM) ---
compute_done = env.event()
out_handle = TensorHandle(
id=f"out_tile_{tile_idx}", addr=0,
shape=(M, tile_n), dtype=dtype,
nbytes=M * tile_n * dtype_bytes,
)
compute_cmd = GemmCmd(a=a, b=b_tile_handle, out=out_handle,
m=M, k=tile_k, n=tile_n)
compute_txn = PeTxn(command=compute_cmd, done=compute_done, pe_prefix=pp)
t0 = env.now
yield self.out_ports[f"{pp}.pe_gemm"].put(compute_txn)
# Wait for previous write (DMA_WRITE serialization)
if prev_write_done is not None:
yield prev_write_done
# Wait for compute of THIS tile
yield compute_done
total_compute_ns += env.now - t0
prev_compute_done = compute_done
# --- Stage 3: DMA_WRITE out_tile to HBM ---
write_done = env.event()
out_tile_pa = cmd.out_addr + n_start * dtype_bytes
write_nbytes = M * tile_n * dtype_bytes
write_cmd = DmaWriteCmd(handle=out_handle, dst_addr=out_tile_pa, nbytes=write_nbytes)
write_txn = PeTxn(command=write_cmd, done=write_done, pe_prefix=pp)
t0 = env.now
yield self.out_ports[f"{pp}.pe_dma"].put(write_txn)
prev_write_done = write_done
# Wait for final write
if prev_write_done is not None:
t0 = env.now
yield prev_write_done
total_dma_ns += env.now - t0
pe_txn.result_data["dma_ns"] = total_dma_ns
pe_txn.result_data["compute_ns"] = total_compute_ns
pe_txn.done.succeed()
def _pipeline_math(self, env: simpy.Environment, pe_txn: PeInternalTxn, cmd: Any) -> Generator:
"""Non-GEMM composite: sequential compute + DMA_WRITE (no tiling)."""
from kernbench.common.pe_commands import (
DmaWriteCmd,
MathCmd,
PeInternalTxn as PeTxn,
)
pp = self._pe_prefix
# Step 1: Compute (MATH)
compute_done = env.event()
compute_cmd = MathCmd(
op=cmd.math_op or "identity",
inputs=(cmd.a,), out=cmd.a,
)
compute_txn = PeTxn(command=compute_cmd, done=compute_done, pe_prefix=pp)
yield self.out_ports[f"{pp}.pe_math"].put(compute_txn)
yield compute_done
# Step 2: DMA_WRITE result to HBM
write_done = env.event()
write_cmd = DmaWriteCmd(handle=cmd.a, dst_addr=cmd.out_addr, nbytes=cmd.out_nbytes)
write_txn = PeTxn(command=write_cmd, done=write_done, pe_prefix=pp)
yield self.out_ports[f"{pp}.pe_dma"].put(write_txn)
yield write_done
pe_txn.done.succeed()
+63 -6
View File
@@ -1,7 +1,18 @@
"""PE_TCM: tightly-coupled memory with BW-based access serialization (ADR-0021).
Models scratchpad memory inside the PE. Handles both legacy Transaction forwarding
and TcmRequest from PE_FETCH_STORE for BW-serialized read/write access.
Two channels (read/write) with independent serialization.
Ported from pe_accel TcmBlock timing model.
"""
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
import simpy
from kernbench.components.base import ComponentBase
@@ -10,16 +21,62 @@ if TYPE_CHECKING:
from kernbench.topology.types import Node
class PeTcmComponent(ComponentBase):
"""PE_TCM: tightly-coupled memory / local SRAM staging buffer.
@dataclass
class TcmRequest:
"""Request to read from or write to TCM (used by PE_FETCH_STORE)."""
Terminal storage component for PE-internal dataflow (ADR-0014 D5).
Phase 0: applies overhead_ns and drain_ns at terminal.
direction: str # "read" or "write"
nbytes: int
done: simpy.Event
tag: str = ""
class PeTcmComponent(ComponentBase):
"""PE_TCM: BW-serialized scratchpad memory (ADR-0021 D1).
Dual-channel: read and write can proceed in parallel,
but concurrent reads serialize, concurrent writes serialize.
BW from topology attrs or pe_template links.
"""
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
super().__init__(node, ctx)
self._read_bw: float = float(node.attrs.get("read_bw_gbs", 512.0))
self._write_bw: float = float(node.attrs.get("write_bw_gbs", 512.0))
self._read_res: simpy.Resource | None = None
self._write_res: simpy.Resource | None = None
def run(self, env, nbytes: int) -> Generator:
def start(self, env: simpy.Environment) -> None:
self._read_res = simpy.Resource(env, capacity=1)
self._write_res = simpy.Resource(env, capacity=1)
super().start(env)
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
yield env.timeout(overhead_ns)
def _worker(self, env: simpy.Environment) -> Generator:
"""Dispatch TcmRequest (from fetch_store) and Transaction (fabric)."""
while True:
msg: Any = yield self._inbox.get()
if isinstance(msg, TcmRequest):
env.process(self._handle_tcm_request(env, msg))
else:
env.process(self._forward_txn(env, msg))
def _handle_tcm_request(self, env: simpy.Environment, req: TcmRequest) -> Generator:
"""BW-serialized access: acquire channel, apply delay, signal done."""
if req.direction == "write":
res = self._write_res
bw = self._write_bw
else:
res = self._read_res
bw = self._read_bw
assert res is not None
with res.request() as lock:
yield lock
if bw > 0 and req.nbytes > 0:
delay_ns = req.nbytes / bw
yield env.timeout(delay_ns)
req.done.succeed()
@@ -0,0 +1,115 @@
"""PE pipeline types for ADR-0021: TileToken, TilePlan, Stage, PipelineContext.
These types are used by the PE_SCHEDULER and all PE engine components
for tile-based pipeline execution with self-routing.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
import simpy
# ── Stage types ──────────────────────────────────────────────────────
class StageType(Enum):
DMA_READ = auto()
FETCH = auto()
GEMM = auto()
MATH = auto()
STORE = auto()
DMA_WRITE = auto()
@dataclass
class Stage:
"""One stage in a tile's execution plan."""
stage_type: StageType
component: str # topology node ID (e.g. "sip0.cube0.pe0.pe_dma")
params: dict = field(default_factory=dict)
# ── Plan ─────────────────────────────────────────────────────────────
@dataclass
class TilePlan:
"""Execution plan for a single tile (immutable stage sequence)."""
tile_id: int
stages: tuple[Stage, ...]
@dataclass
class PipelinePlan:
"""Full pipeline plan for one CompositeCmd."""
tiles: list[TilePlan]
# Metadata for metrics
m_tiles: int = 0
k_tiles: int = 0
n_tiles: int = 0
# ── Pipeline Context ─────────────────────────────────────────────────
@dataclass
class PipelineContext:
"""Tracks completion of a pipeline (exactly-once contract).
Each tile's last stage calls complete_tile() exactly once.
When all tiles complete, done_event.succeed() is called.
"""
id: str
total_tiles: int
completed_tiles: int = 0
done_event: Any = None # simpy.Event
def complete_tile(self) -> None:
self.completed_tiles += 1
if self.completed_tiles == self.total_tiles:
if self.done_event is not None:
self.done_event.succeed()
# ── TileToken ────────────────────────────────────────────────────────
@dataclass
class TileToken:
"""Self-routing tile token passed between PE components (ADR-0021 D9).
Single-owner: only one component holds this token at any time.
params is a cache of plan.stages[stage_idx].params (canonical source).
"""
tile_id: int
pipeline_ctx: PipelineContext
plan: TilePlan
stage_idx: int
params: dict = field(default_factory=dict)
data_op: bool = True # op_log recording target (ADR-0020)
@property
def current_stage(self) -> Stage:
return self.plan.stages[self.stage_idx]
@property
def has_next_stage(self) -> bool:
return self.stage_idx + 1 < len(self.plan.stages)
def advance(self) -> Stage | None:
"""Advance to next stage. Returns next Stage or None if last."""
self.stage_idx += 1
if self.stage_idx < len(self.plan.stages):
next_stage = self.plan.stages[self.stage_idx]
self.params = next_stage.params
return next_stage
return None
+176
View File
@@ -0,0 +1,176 @@
"""Tile plan generators for PE pipeline (ADR-0021).
Generates TilePlan with stage sequences for GEMM and Math operations.
Ported from pe_accel tiling.py with stage-based plan structure.
"""
from __future__ import annotations
from math import ceil
from kernbench.components.builtin.pe_types import (
PipelinePlan,
Stage,
StageType,
TilePlan,
)
def generate_gemm_plan(
M: int, K: int, N: int,
tile_m: int, tile_k: int, tile_n: int,
bytes_per_element: int,
A_addr: int, B_addr: int, C_addr: int,
pe_prefix: str,
) -> PipelinePlan:
"""Generate GEMM tile plan: M→N→K order.
Each tile follows stage sequence:
DMA_READ(A) → DMA_READ(B) → FETCH → GEMM → STORE
On last K-tile per (m,n): → DMA_WRITE
Args:
pe_prefix: e.g. "sip0.cube0.pe0" — used to build component IDs.
"""
M_tiles = max(1, ceil(M / tile_m))
K_tiles = max(1, ceil(K / tile_k))
N_tiles = max(1, ceil(N / tile_n))
bpe = bytes_per_element
dma_id = f"{pe_prefix}.pe_dma"
fetch_id = f"{pe_prefix}.pe_fetch_store"
gemm_id = f"{pe_prefix}.pe_gemm"
# math_id = f"{pe_prefix}.pe_math" # for K-accumulation if needed
tiles: list[TilePlan] = []
tile_id = 0
for m in range(M_tiles):
for n in range(N_tiles):
c_addr = C_addr + (m * tile_m * N + n * tile_n) * bpe
for k in range(K_tiles):
last_k = k == K_tiles - 1
a_addr = A_addr + (m * tile_m * K + k * tile_k) * bpe
b_addr = B_addr + (k * tile_k * N + n * tile_n) * bpe
a_bytes = tile_m * tile_k * bpe
b_bytes = tile_k * tile_n * bpe
out_bytes = tile_m * tile_n * bpe
stages: list[Stage] = []
# DMA READ: load A and B tiles from HBM → TCM
stages.append(Stage(
stage_type=StageType.DMA_READ,
component=dma_id,
params={
"src_addr": a_addr, "nbytes": a_bytes,
"operand": "A", "tile_m": tile_m, "tile_k": tile_k,
},
))
stages.append(Stage(
stage_type=StageType.DMA_READ,
component=dma_id,
params={
"src_addr": b_addr, "nbytes": b_bytes,
"operand": "B", "tile_k": tile_k, "tile_n": tile_n,
},
))
# FETCH: TCM → Register File
stages.append(Stage(
stage_type=StageType.FETCH,
component=fetch_id,
params={
"direction": "read",
"nbytes": a_bytes + b_bytes,
},
))
# GEMM: MAC compute
stages.append(Stage(
stage_type=StageType.GEMM,
component=gemm_id,
params={
"m": tile_m, "k": tile_k, "n": tile_n,
"is_last_k": last_k,
},
))
# STORE: Register File → TCM
stages.append(Stage(
stage_type=StageType.STORE,
component=fetch_id,
params={
"direction": "write",
"nbytes": out_bytes,
},
))
# DMA WRITE: TCM → HBM (only on last K-tile)
if last_k:
stages.append(Stage(
stage_type=StageType.DMA_WRITE,
component=dma_id,
params={
"dst_addr": c_addr, "nbytes": out_bytes,
},
))
tiles.append(TilePlan(tile_id=tile_id, stages=tuple(stages)))
tile_id += 1
return PipelinePlan(
tiles=tiles, m_tiles=M_tiles, k_tiles=K_tiles, n_tiles=N_tiles,
)
def generate_math_plan(
M: int, N: int,
tile_m: int, tile_n: int,
bytes_per_element: int,
math_op: str,
src_addr: int, dst_addr: int,
pe_prefix: str,
) -> PipelinePlan:
"""Generate element-wise math tile plan.
Each tile: DMA_READ → FETCH → MATH → STORE → DMA_WRITE
"""
M_tiles = max(1, ceil(M / tile_m))
N_tiles = max(1, ceil(N / tile_n))
bpe = bytes_per_element
dma_id = f"{pe_prefix}.pe_dma"
fetch_id = f"{pe_prefix}.pe_fetch_store"
math_id = f"{pe_prefix}.pe_math"
tiles: list[TilePlan] = []
tile_id = 0
for m in range(M_tiles):
for n in range(N_tiles):
offset = (m * tile_m * N + n * tile_n) * bpe
tile_bytes = tile_m * tile_n * bpe
stages = [
Stage(StageType.DMA_READ, dma_id, {
"src_addr": src_addr + offset, "nbytes": tile_bytes,
}),
Stage(StageType.FETCH, fetch_id, {
"direction": "read", "nbytes": tile_bytes,
}),
Stage(StageType.MATH, math_id, {
"op": math_op, "num_elements": tile_m * tile_n,
}),
Stage(StageType.STORE, fetch_id, {
"direction": "write", "nbytes": tile_bytes,
}),
Stage(StageType.DMA_WRITE, dma_id, {
"dst_addr": dst_addr + offset, "nbytes": tile_bytes,
}),
]
tiles.append(TilePlan(tile_id=tile_id, stages=tuple(stages)))
tile_id += 1
return PipelinePlan(tiles=tiles, m_tiles=M_tiles, n_tiles=N_tiles)
@@ -0,0 +1,34 @@
"""Concrete component implementations.
Loaded from components.yaml via ComponentRegistry.load_components_yaml().
Manual imports are no longer needed — add new impls to components.yaml.
Classes are still importable from this package via lazy __getattr__.
"""
from kernbench.components.base import ComponentRegistry
ComponentRegistry.load_components_yaml()
# Lazy re-export: allow `from kernbench.components.builtin import FooComponent`
# without eagerly importing every module.
_CLASS_MAP: dict[str, str] = {} # ClassName → "module.path:ClassName"
def _build_class_map() -> None:
if _CLASS_MAP:
return
for class_path in ComponentRegistry._lazy.values():
module_path, class_name = class_path.rsplit(":", 1)
_CLASS_MAP[class_name] = class_path
def __getattr__(name: str):
_build_class_map()
class_path = _CLASS_MAP.get(name)
if class_path is None:
raise ImportError(f"cannot import name '{name}' from 'kernbench.components.builtin'")
import importlib
module_path, class_name = class_path.rsplit(":", 1)
mod = importlib.import_module(module_path)
return getattr(mod, class_name)
@@ -0,0 +1,27 @@
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING
import simpy
from kernbench.components.base import ComponentBase
if TYPE_CHECKING:
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class TransitComponent(ComponentBase):
"""Transit component for NOC, UCIe, XBAR nodes.
Applies overhead_ns processing delay (from node.attrs) then forwards the
Transaction to the next hop via inherited _forward_txn().
"""
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
super().__init__(node, ctx)
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
yield env.timeout(overhead_ns)
@@ -0,0 +1,129 @@
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
import simpy
from kernbench.components.base import ComponentBase
from kernbench.sim_engine.transaction import Transaction
if TYPE_CHECKING:
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class HbmCtrlComponent(ComponentBase):
"""HBM controller: terminal component that models HBM access latency.
Dual-channel model: separate read and write resources (each capacity=1)
allowing concurrent read/write like PE_DMA. Multiple reads or multiple
writes still serialize within their respective channel.
On completion, creates a ResponseMsg and sends it back on the reverse path
so that response latency is modeled through the fabric.
"""
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
super().__init__(node, ctx)
self._read: simpy.Resource | None = None
self._write: simpy.Resource | None = None
def start(self, env: simpy.Environment) -> None:
capacity = int(self.node.attrs.get("capacity", 1))
self._read = simpy.Resource(env, capacity=capacity)
self._write = simpy.Resource(env, capacity=capacity)
super().start(env)
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
yield env.timeout(overhead_ns)
def _select_channel(self, txn: Any) -> simpy.Resource:
"""Select channel based on request type: write requests → write, else → read."""
from kernbench.runtime_api.kernel import MemoryWriteMsg, PeDmaMsg
assert self._read is not None and self._write is not None
req = txn.request
if isinstance(req, MemoryWriteMsg):
return self._write
if isinstance(req, PeDmaMsg) and req.is_write:
return self._write
return self._read
def _worker(self, env: simpy.Environment) -> Generator:
"""Dispatch each incoming txn to a concurrent process for channel-level parallelism."""
while True:
txn: Any = yield self._inbox.get()
env.process(self._handle_txn(env, txn))
def _handle_txn(self, env: simpy.Environment, txn: Any) -> Generator:
"""Acquire channel, run, apply drain, send response."""
channel = self._select_channel(txn)
with channel.request() as req:
yield req
yield from self.run(env, txn.nbytes)
drain = getattr(txn, "drain_ns", 0.0)
if drain > 0:
yield env.timeout(drain)
yield from self._send_response(env, txn)
def _send_response(self, env: simpy.Environment, txn: Any) -> Generator:
"""Route completion based on path type.
- PeDmaMsg: succeed done directly (probe).
- Bypass path (no m_cpu): MemoryWrite succeeds done; MemoryRead sends
data back on reverse path with original done event.
- M_CPU DMA path: send ResponseMsg for m_cpu/io_cpu aggregation.
"""
from kernbench.runtime_api.kernel import MemoryReadMsg, PeDmaMsg
if isinstance(txn.request, PeDmaMsg):
reverse_path = list(reversed(txn.path))
if len(reverse_path) >= 2:
resp_txn = Transaction(
request=txn.request, path=reverse_path, step=0,
nbytes=0, done=txn.done, is_response=True,
)
yield self.out_ports[reverse_path[1]].put(resp_txn.advance())
return
txn.done.succeed()
return
# Bypass path: no m_cpu in the transaction path
is_bypass = not any("m_cpu" in n for n in txn.path)
if is_bypass:
if isinstance(txn.request, MemoryReadMsg):
# D2H: send data back on reverse path to pcie_ep
reverse_path = list(reversed(txn.path))
if len(reverse_path) >= 2:
resp_txn = Transaction(
request=txn.request, path=reverse_path, step=0,
nbytes=txn.request.nbytes, done=txn.done,
)
yield self.out_ports[reverse_path[1]].put(resp_txn.advance())
return
# MemoryWrite bypass or short path: done
txn.done.succeed()
return
# M_CPU DMA path: send ResponseMsg for aggregation
reverse_path = list(reversed(txn.path))
if len(reverse_path) >= 2 and self.ctx:
from kernbench.runtime_api.kernel import ResponseMsg
parts = self.node.id.split(".")
cube_id = int(parts[1].replace("cube", ""))
pe_id = 0 # single hbm_ctrl, PE info from request
resp_msg = ResponseMsg(
correlation_id=txn.request.correlation_id,
request_id=txn.request.request_id,
src_cube=cube_id, src_pe=pe_id, success=True,
)
resp_txn = Transaction(
request=resp_msg, path=reverse_path, step=0,
nbytes=0, done=env.event(), is_response=True,
)
yield self.out_ports[reverse_path[1]].put(resp_txn.advance())
else:
txn.done.succeed()
@@ -0,0 +1,157 @@
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
import simpy
from kernbench.components.base import ComponentBase
from kernbench.sim_engine.transaction import Transaction
if TYPE_CHECKING:
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class IoCpuComponent(ComponentBase):
"""IO_CPU component: multi-cube fan-out with response aggregation.
Forward path:
1. Applies overhead_ns processing overhead.
2. Resolves target cube(s) from request.target_cubes.
3. Fans out sub-Transactions to each target cube's M_CPU.
Response path:
Collects ResponseMsg from each M_CPU. When all cube responses are
received, succeeds the parent txn.done.
"""
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
super().__init__(node, ctx)
# Pending fan-out tracking: request_id → (expected, received, parent_txn_done)
self._pending: dict[str, tuple[int, int, simpy.Event]] = {}
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
yield env.timeout(overhead_ns)
def _worker(self, env: simpy.Environment) -> Generator:
while True:
txn: Any = yield self._inbox.get()
if getattr(txn, "is_response", False):
self._collect_response(txn)
else:
yield from self.run(env, txn.nbytes)
env.process(self._dispatch_to_m_cpus(env, txn))
def _collect_response(self, resp_txn: Any) -> None:
"""Receive a cube response and increment the aggregation counter."""
key = resp_txn.request.request_id
if key not in self._pending:
return
expected, received, parent_done = self._pending[key]
received += 1
if received >= expected:
parent_done.succeed()
del self._pending[key]
else:
self._pending[key] = (expected, received, parent_done)
def _dispatch_to_m_cpus(self, env: simpy.Environment, txn: Any) -> Generator:
"""Fan out sub-Transactions to target cube M_CPUs, wait for responses."""
from kernbench.runtime_api.kernel import KernelLaunchMsg, MemoryReadMsg, MemoryWriteMsg
request = txn.request
try:
cube_targets = self._resolve_cube_targets(request)
except Exception:
txn.done.succeed()
return
if not cube_targets:
txn.done.succeed()
return
# Setup aggregation
self._pending[request.request_id] = (len(cube_targets), 0, txn.done)
# Fan out to each target cube's M_CPU
for sip, cube in cube_targets:
try:
m_cpu_id = self.ctx.resolver.find_m_cpu(sip, cube)
path = self.ctx.router.find_node_path(self.node.id, m_cpu_id)
except Exception:
continue
if len(path) < 2:
continue
sub_txn = Transaction(
request=request, path=path, step=0,
nbytes=txn.nbytes, done=env.event(),
result_data=txn.result_data,
)
yield self.out_ports[path[1]].put(sub_txn.advance())
def _resolve_cube_targets(self, request: Any) -> list[tuple[int, int]]:
"""Return list of (sip, cube) pairs to fan out to."""
from kernbench.runtime_api.kernel import (
KernelLaunchMsg, MemoryReadMsg, MemoryWriteMsg, MmuMapMsg, MmuUnmapMsg,
)
target_cubes = getattr(request, "target_cubes", "all")
if isinstance(request, MemoryWriteMsg):
sip = request.dst_sip
if target_cubes == "all":
cube = self._cube_from_pa(request.dst_pa, fallback=request.dst_cube)
return [(sip, cube)]
return [(sip, c) for c in target_cubes]
if isinstance(request, MemoryReadMsg):
sip = request.src_sip
if target_cubes == "all":
cube = self._cube_from_pa(request.src_pa, fallback=request.src_cube)
return [(sip, cube)]
return [(sip, c) for c in target_cubes]
if isinstance(request, KernelLaunchMsg):
my_sip = self._my_sip()
if target_cubes != "all":
return [(my_sip, c) for c in target_cubes]
# "all": derive from tensor shards, filtered to this SIP
seen: set[tuple[int, int]] = set()
targets: list[tuple[int, int]] = []
for arg in request.args:
if arg.arg_kind != "tensor":
continue
for shard in arg.shards:
if shard.sip != my_sip:
continue
key = (shard.sip, shard.cube)
if key not in seen:
seen.add(key)
targets.append(key)
return targets
if isinstance(request, (MmuMapMsg, MmuUnmapMsg)):
my_sip = self._my_sip()
if target_cubes == "all":
n_cubes = 16
if self.ctx and self.ctx.spec:
sips = self.ctx.spec.get("system", {}).get("sips", {})
n_cubes = sips.get("cubes_per_sip", 16)
return [(my_sip, c) for c in range(n_cubes)]
return [(my_sip, c) for c in target_cubes]
return []
def _cube_from_pa(self, pa_val: int, fallback: int) -> int:
"""Extract cube_id from a physical address, with fallback."""
from kernbench.policy.address.phyaddr import PhysAddr
try:
return PhysAddr.decode(pa_val).cube_id
except Exception:
return fallback
def _my_sip(self) -> int:
"""Extract this IO_CPU's SIP ID from its node ID (e.g. 'sip0.io0.io_cpu' → 0)."""
return int(self.node.id.split(".")[0].replace("sip", ""))
@@ -0,0 +1,327 @@
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
import simpy
from kernbench.components.base import ComponentBase
from kernbench.sim_engine.transaction import Transaction
if TYPE_CHECKING:
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class MCpuComponent(ComponentBase):
"""M_CPU component: multi-PE DMA fan-out with response aggregation.
Forward path (ADR-0015 D5):
When a forward Transaction arrives at m_cpu (terminal hop), M_CPU fans out
DMA sub-Transactions to target PEs' HBM slices. target_pe on the request
controls fan-out: int → single PE, "all" → all PEs in the cube.
Response path:
ResponseMsg from each hbm_ctrl arrives back at m_cpu. Once all PE responses
are collected, m_cpu sends an aggregate ResponseMsg on the reverse command
path back to io_cpu.
Transit:
When m_cpu is NOT the terminal hop (transit or response relay), the
Transaction is forwarded normally to the next hop.
"""
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
super().__init__(node, ctx)
# Pending fan-out tracking: request_id → (expected, received, all_done_event)
self._pending: dict[str, tuple[int, int, simpy.Event]] = {}
# Store parent txn for response sending: request_id → parent_txn
self._parent_txns: dict[str, Any] = {}
# DMA engine resources (ADR-0015 D5, ADR-0014 D4): capacity=1 each
self._dma_write: simpy.Resource | None = None
self._dma_read: simpy.Resource | None = None
def start(self, env: simpy.Environment) -> None:
self._dma_write = simpy.Resource(env, capacity=1)
self._dma_read = simpy.Resource(env, capacity=1)
super().start(env)
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
yield env.timeout(overhead_ns)
def _worker(self, env: simpy.Environment) -> Generator:
"""Dispatch forward txns, collect response txns."""
from kernbench.runtime_api.kernel import KernelLaunchMsg, MmuMapMsg, MmuUnmapMsg
while True:
txn: Any = yield self._inbox.get()
if getattr(txn, "is_response", False):
self._collect_response(txn)
else:
yield from self.run(env, txn.nbytes)
next_hop = txn.next_hop
if next_hop:
yield self.out_ports[next_hop].put(txn.advance())
elif self.ctx is not None and txn.request is not None:
if isinstance(txn.request, KernelLaunchMsg):
env.process(self._kernel_launch_fanout(env, txn))
elif isinstance(txn.request, (MmuMapMsg, MmuUnmapMsg)):
env.process(self._mmu_msg_fanout(env, txn))
else:
env.process(self._dma_fanout(env, txn))
else:
txn.done.succeed()
def _collect_response(self, resp_txn: Any) -> None:
"""Receive a PE response and increment the aggregation counter."""
key = resp_txn.request.request_id
if key not in self._pending:
return
expected, received, all_done = self._pending[key]
received += 1
if received >= expected:
all_done.succeed()
del self._pending[key]
else:
self._pending[key] = (expected, received, all_done)
def _dma_fanout(self, env: simpy.Environment, txn: Any) -> Generator:
"""Fan out DMA sub-Transactions to target PE(s), wait for responses,
then send aggregate response on reverse command path.
Each DMA transfer acquires the DMA resource (capacity=1 per ADR-0014 D4),
so multi-PE fan-out is serialized through the DMA engine.
"""
from kernbench.runtime_api.kernel import MemoryWriteMsg
request = txn.request
target_pe = getattr(request, "target_pe", "all")
dst_nodes = self._resolve_dma_destinations(request, target_pe)
if not dst_nodes:
txn.done.succeed()
return
# Setup aggregation
all_done = env.event()
self._pending[request.request_id] = (len(dst_nodes), 0, all_done)
self._parent_txns[request.request_id] = txn
# Select DMA resource based on operation type
dma_res = self._dma_write if isinstance(request, MemoryWriteMsg) else self._dma_read
# Fan out DMA sub-txns (serialized through DMA resource)
max_drain_ns = 0.0
for dst_node in dst_nodes:
try:
dma_path = self.ctx.router.find_mcpu_dma_path(self.node.id, dst_node)
except Exception:
continue
if len(dma_path) < 2:
continue
drain_ns = self.ctx.compute_drain_ns(dma_path, txn.nbytes)
max_drain_ns = max(max_drain_ns, drain_ns)
sub_txn = Transaction(
request=request, path=dma_path, step=0,
nbytes=txn.nbytes, done=env.event(),
drain_ns=drain_ns,
)
with dma_res.request() as req:
yield req
yield self.out_ports[dma_path[1]].put(sub_txn.advance())
# Wait for all PE responses
yield all_done
txn.result_data["xfer_ns"] = max_drain_ns
del self._parent_txns[request.request_id]
# Send aggregate response on reverse command path
reverse_path = list(reversed(txn.path))
if len(reverse_path) >= 2:
from kernbench.runtime_api.kernel import ResponseMsg
parts = self.node.id.split(".")
cube_id = int(parts[1].replace("cube", ""))
resp_msg = ResponseMsg(
correlation_id=request.correlation_id,
request_id=request.request_id,
src_cube=cube_id, src_pe=-1, success=True,
)
resp_txn = Transaction(
request=resp_msg, path=reverse_path, step=0,
nbytes=0, done=env.event(), is_response=True,
)
yield self.out_ports[reverse_path[1]].put(resp_txn.advance())
else:
txn.done.succeed()
def _kernel_launch_fanout(self, env: simpy.Environment, txn: Any) -> Generator:
"""Fan out KernelLaunchMsg to target PE_CPU(s) via NOC (ADR-0009 D3).
Routes through find_node_path (M_CPU → NOC → PE_CPU command edges).
PE_CPU sends ResponseMsg back via NOC → M_CPU on completion.
Then sends aggregate ResponseMsg back to IO_CPU on the reverse path.
"""
request = txn.request
target_pe = getattr(request, "target_pe", "all")
cube_prefix = self.node.id.rsplit(".", 1)[0] # e.g. "sip0.cube0"
pe_ids = self._resolve_pe_ids(target_pe)
if not pe_ids:
txn.done.succeed()
return
# Fan out to each PE_CPU, using response-based aggregation
sub_txns: list[Transaction] = []
n_dispatched = 0
for pe_id in pe_ids:
pe_cpu_id = f"{cube_prefix}.pe{pe_id}.pe_cpu"
try:
path = self.ctx.router.find_node_path(self.node.id, pe_cpu_id)
except Exception:
continue
if len(path) < 2:
continue
sub_txn = Transaction(
request=request, path=path, step=0,
nbytes=0, done=env.event(),
)
yield self.out_ports[path[1]].put(sub_txn.advance())
sub_txns.append(sub_txn)
n_dispatched += 1
if n_dispatched == 0:
txn.done.succeed()
return
# Setup response aggregation (PE_CPU ResponseMsg arrives via _collect_response)
all_done = env.event()
self._pending[request.request_id] = (n_dispatched, 0, all_done)
self._parent_txns[request.request_id] = txn
# Wait for all PE_CPU responses via NOC
yield all_done
del self._parent_txns[request.request_id]
# Aggregate PE-internal metrics (max across PEs)
pe_exec_values = [st.result_data.get("pe_exec_ns", 0.0) for st in sub_txns]
if pe_exec_values:
txn.result_data["pe_exec_ns"] = max(pe_exec_values)
dma_values = [st.result_data.get("dma_ns", 0.0) for st in sub_txns]
if dma_values:
txn.result_data["dma_ns"] = max(dma_values)
compute_values = [st.result_data.get("compute_ns", 0.0) for st in sub_txns]
if compute_values:
txn.result_data["compute_ns"] = max(compute_values)
# Send aggregate response on reverse command path back to IO_CPU
reverse_path = list(reversed(txn.path))
if len(reverse_path) >= 2:
from kernbench.runtime_api.kernel import ResponseMsg
parts = self.node.id.split(".")
cube_id = int(parts[1].replace("cube", ""))
resp_msg = ResponseMsg(
correlation_id=request.correlation_id,
request_id=request.request_id,
src_cube=cube_id, src_pe=-1, success=True,
)
resp_txn = Transaction(
request=resp_msg, path=reverse_path, step=0,
nbytes=0, done=env.event(), is_response=True,
)
yield self.out_ports[reverse_path[1]].put(resp_txn.advance())
else:
txn.done.succeed()
def _resolve_dma_destinations(self, request: Any, target_pe: int | str) -> list[str]:
"""Return list of HBM destination node_ids for DMA fan-out.
With single hbm_ctrl per cube (ADR-0019), always returns one node.
PA-based resolution still used for cross-cube routing.
"""
cube_prefix = self.node.id.rsplit(".", 1)[0] # e.g. "sip0.cube0"
# PA-based resolution: extract actual target from physical address
pa_val = getattr(request, "dst_pa", None) or getattr(request, "src_pa", None)
if pa_val is not None:
from kernbench.policy.address.phyaddr import PhysAddr
try:
pa = PhysAddr.decode(pa_val)
return [self.ctx.resolver.resolve(pa)]
except Exception:
pass
# Default: single hbm_ctrl in local cube
return [f"{cube_prefix}.hbm_ctrl"]
def _mmu_msg_fanout(self, env: simpy.Environment, txn: Any) -> Generator:
"""Fan out MmuMapMsg/MmuUnmapMsg to target PE_MMU(s) via NOC.
Routes through find_node_path (M_CPU → NOC → PE_MMU command edges).
PE_MMU is a terminal node — completes the transaction directly.
"""
request = txn.request
target_pe = getattr(request, "target_pe", "all")
cube_prefix = self.node.id.rsplit(".", 1)[0] # e.g. "sip0.cube0"
pe_ids = self._resolve_pe_ids(target_pe)
if not pe_ids:
txn.done.succeed()
return
# Fan out to each PE_MMU
sub_dones: list[simpy.Event] = []
for pe_id in pe_ids:
pe_mmu_id = f"{cube_prefix}.pe{pe_id}.pe_mmu"
try:
path = self.ctx.router.find_node_path(self.node.id, pe_mmu_id)
except Exception:
continue
if len(path) < 2:
continue
sub_done = env.event()
sub_txn = Transaction(
request=request, path=path, step=0,
nbytes=0, done=sub_done,
)
yield self.out_ports[path[1]].put(sub_txn.advance())
sub_dones.append(sub_done)
# Wait for all PE_MMUs to complete
for sd in sub_dones:
yield sd
# Send aggregate response on reverse path
reverse_path = list(reversed(txn.path))
if len(reverse_path) >= 2:
from kernbench.runtime_api.kernel import ResponseMsg
parts = self.node.id.split(".")
cube_id = int(parts[1].replace("cube", ""))
resp_msg = ResponseMsg(
correlation_id=request.correlation_id,
request_id=request.request_id,
src_cube=cube_id, src_pe=-1, success=True,
)
resp_txn = Transaction(
request=resp_msg, path=reverse_path, step=0,
nbytes=0, done=env.event(), is_response=True,
)
yield self.out_ports[reverse_path[1]].put(resp_txn.advance())
else:
txn.done.succeed()
def _resolve_pe_ids(self, target_pe: int | tuple | str) -> list[int]:
"""Return list of PE IDs to fan out to (used by kernel launch fan-out)."""
if isinstance(target_pe, int):
return [target_pe]
if isinstance(target_pe, tuple):
return list(target_pe)
# "all": all PEs in local cube
n_slices = 8
if self.ctx and self.ctx.spec:
mm = self.ctx.spec.get("cube", {}).get("memory_map", {})
n_slices = mm.get("hbm_slices_per_cube", 8)
return list(range(n_slices))
@@ -0,0 +1,27 @@
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING
import simpy
from kernbench.components.base import ComponentBase
if TYPE_CHECKING:
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class PcieEpComponent(ComponentBase):
"""PCIe endpoint: protocol processing overhead before forwarding.
Applies overhead_ns (from node.attrs) for PCIe protocol handling,
then forwards via inherited _forward_txn().
"""
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
super().__init__(node, ctx)
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
yield env.timeout(overhead_ns)
@@ -0,0 +1,214 @@
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
import simpy
from kernbench.components.base import ComponentBase
from kernbench.sim_engine.transaction import Transaction
if TYPE_CHECKING:
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class PeCpuComponent(ComponentBase):
"""PE_CPU: kernel execution controller (Stage 2).
Two-phase kernel execution (ADR-0014 D1):
Phase 1 (compile): look up kernel from registry, run it with TLContext
to generate a PeCommand list.
Phase 2 (replay): iterate commands, dispatch to PE_SCHEDULER via
PeInternalTxn, wait for blocking commands.
Non-kernel Transactions are forwarded normally.
"""
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
super().__init__(node, ctx)
self._pe_prefix = node.id.rsplit(".", 1)[0] # e.g. "sip0.cube0.pe0"
try:
self._pe_idx = int(self._pe_prefix.rsplit("pe", 1)[1])
except (IndexError, ValueError):
self._pe_idx = 0
# Extract sip/cube index for multi-SIP/cube shard matching
parts = node.id.split(".")
try:
self._sip_idx = int(parts[0].replace("sip", ""))
except (IndexError, ValueError):
self._sip_idx = 0
try:
self._cube_idx = int(parts[1].replace("cube", ""))
except (IndexError, ValueError):
self._cube_idx = 0
def _find_shard(self, shards: tuple) -> Any:
"""Find shard matching this PE's (sip, cube, pe). Fallback to positional index."""
for s in shards:
if s.sip == self._sip_idx and s.cube == self._cube_idx and s.pe == self._pe_idx:
return s
return shards[min(self._pe_idx, len(shards) - 1)]
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
yield env.timeout(overhead_ns)
def _worker(self, env: simpy.Environment) -> Generator:
while True:
txn: Any = yield self._inbox.get()
from kernbench.runtime_api.kernel import KernelLaunchMsg
if hasattr(txn, "request") and isinstance(txn.request, KernelLaunchMsg):
yield from self._execute_kernel(env, txn)
else:
yield from self._forward_txn(env, txn)
def _execute_kernel(self, env: simpy.Environment, txn: Any) -> Generator:
"""Execute kernel: greenlet mode (ADR-0020) or legacy Phase 0 + replay."""
from kernbench.triton_emu.registry import get_kernel
request = txn.request
yield from self.run(env, 0)
kernel_fn = get_kernel(request.kernel_ref.name)
num_programs = self._derive_num_programs(request)
kernel_args = self._unpack_kernel_args(request)
pe_exec_start = env.now
scheduler_id = f"{self._pe_prefix}.pe_scheduler"
# Choose execution mode: greenlet (ADR-0020) or legacy command-list
store = getattr(self.ctx, "memory_store", None) if self.ctx else None
if store is not None:
composite_results = yield from self._execute_greenlet(
env, kernel_fn, kernel_args, num_programs, scheduler_id, store,
)
else:
composite_results = yield from self._execute_legacy(
env, kernel_fn, kernel_args, num_programs, scheduler_id,
)
# Record PE-internal execution time
txn.result_data["pe_exec_ns"] = env.now - pe_exec_start
total_dma_ns = 0.0
total_compute_ns = 0.0
for rd in composite_results:
total_dma_ns += rd.get("dma_ns", 0.0)
total_compute_ns += rd.get("compute_ns", 0.0)
txn.result_data["dma_ns"] = total_dma_ns
txn.result_data["compute_ns"] = total_compute_ns
# Send ResponseMsg on reverse path
yield from self._send_response(env, txn, request)
def _derive_num_programs(self, request: Any) -> int:
num_programs = 1
for arg in request.args:
if arg.arg_kind == "tensor":
cube_pe_count = sum(
1 for s in arg.shards
if s.sip == self._sip_idx and s.cube == self._cube_idx
)
if cube_pe_count > num_programs:
num_programs = cube_pe_count
return num_programs
def _unpack_kernel_args(self, request: Any) -> list:
kernel_args: list = []
for arg in request.args:
if arg.arg_kind == "tensor":
if arg.va_base:
kernel_args.append(arg.va_base)
else:
shard = self._find_shard(arg.shards)
kernel_args.append(shard.pa)
elif arg.arg_kind == "scalar":
kernel_args.append(arg.value)
return kernel_args
def _execute_greenlet(
self, env, kernel_fn, kernel_args, num_programs, scheduler_id, store,
) -> Generator:
"""Greenlet-based execution (ADR-0020 D3): kernel ↔ SimPy interleaved."""
from kernbench.triton_emu.kernel_runner import KernelRunner
runner = KernelRunner(
pe_prefix=self._pe_prefix,
pe_idx=self._pe_idx,
sip_idx=self._sip_idx,
cube_idx=self._cube_idx,
scheduler_id=scheduler_id,
out_ports=self.out_ports,
store=store,
)
yield from runner.run(env, kernel_fn, kernel_args, num_programs)
return getattr(runner, "_composite_results", [])
def _execute_legacy(
self, env, kernel_fn, kernel_args, num_programs, scheduler_id,
) -> Generator:
"""Legacy Phase 0 + replay: generate command list, then dispatch."""
from kernbench.common.pe_commands import (
CompositeCmd, PeCpuOverheadCmd, PeInternalTxn, WaitCmd,
)
from kernbench.triton_emu.tl_context import TLContext, run_kernel
tl = TLContext(pe_id=self._pe_idx, num_programs=num_programs, dispatch_cycles=0)
run_kernel(kernel_fn, tl, *kernel_args)
commands = tl.commands
pending: dict[str, simpy.Event] = {}
composite_results: list[dict] = []
for cmd in commands:
if isinstance(cmd, PeCpuOverheadCmd):
yield env.timeout(cmd.cycles)
elif isinstance(cmd, WaitCmd):
if cmd.handle is not None:
evt = pending.pop(cmd.handle.id, None)
if evt:
yield evt
else:
for evt in pending.values():
yield evt
pending.clear()
elif isinstance(cmd, CompositeCmd):
done_evt = env.event()
pe_txn = PeInternalTxn(
command=cmd, done=done_evt, pe_prefix=self._pe_prefix,
)
composite_results.append(pe_txn.result_data)
yield self.out_ports[scheduler_id].put(pe_txn)
pending[cmd.completion.id] = done_evt
else:
done_evt = env.event()
pe_txn = PeInternalTxn(
command=cmd, done=done_evt, pe_prefix=self._pe_prefix,
)
yield self.out_ports[scheduler_id].put(pe_txn)
yield done_evt
for evt in pending.values():
yield evt
return composite_results
def _send_response(self, env, txn, request) -> Generator:
reverse_path = list(reversed(txn.path))
if len(reverse_path) >= 2:
from kernbench.runtime_api.kernel import ResponseMsg
resp_msg = ResponseMsg(
correlation_id=request.correlation_id,
request_id=request.request_id,
src_cube=self._cube_idx, src_pe=self._pe_idx,
success=True,
)
resp_txn = Transaction(
request=resp_msg, path=reverse_path, step=0,
nbytes=0, done=env.event(), is_response=True,
)
yield self.out_ports[reverse_path[1]].put(resp_txn.advance())
else:
txn.done.succeed()
@@ -0,0 +1,138 @@
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
import simpy
from kernbench.components.base import PeEngineBase
from kernbench.sim_engine.transaction import Transaction
if TYPE_CHECKING:
from kernbench.common.pe_commands import PeInternalTxn
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class PeDmaComponent(PeEngineBase):
"""PE_DMA: dual-channel DMA engine with READ and WRITE resources.
Each channel has capacity=1 (ADR-0014 D4):
- DMA_READ and DMA_WRITE may execute concurrently.
- Multiple READs cannot overlap; multiple WRITEs cannot overlap.
Handles two message types:
- Transaction: external fabric messages (PeDmaMsg probes, M_CPU DMA)
- PeInternalTxn: PE-internal commands from PE_SCHEDULER
(DmaReadCmd → HBM read, DmaWriteCmd → HBM write)
"""
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
super().__init__(node, ctx)
self._dma_read: simpy.Resource | None = None
self._dma_write: simpy.Resource | None = None
self._mmu = None # PeMMU instance, set by engine wiring
def init_resources(self, env: simpy.Environment) -> None:
self._dma_read = simpy.Resource(env, capacity=1)
self._dma_write = simpy.Resource(env, capacity=1)
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
yield env.timeout(0)
def handle_command(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
"""Handle PE-internal DMA command: resolve PA → HBM path → transfer."""
from kernbench.common.pe_commands import DmaReadCmd, DmaWriteCmd
from kernbench.policy.address.phyaddr import PhysAddr
from kernbench.runtime_api.kernel import PeDmaMsg
cmd = pe_txn.command
assert self._dma_read is not None and self._dma_write is not None
# Determine direction and target address (VA → PA via MMU)
if isinstance(cmd, DmaReadCmd):
dma_res = self._dma_read
raw_addr = cmd.src_addr
is_write = False
elif isinstance(cmd, DmaWriteCmd):
dma_res = self._dma_write
raw_addr = cmd.dst_addr
is_write = True
else:
pe_txn.done.succeed()
return
# Translate VA → PA via MMU (if available), then resolve HBM node
# If MMU has no mapping for this address (PageFault), treat as PA directly
# (backward-compatible with PA-only mode)
if self._mmu is not None:
from kernbench.policy.address.pe_mmu import PageFault
try:
target_pa = self._mmu.translate(raw_addr)
if self._mmu.overhead_ns > 0:
yield env.timeout(self._mmu.overhead_ns)
except PageFault:
target_pa = raw_addr
else:
target_pa = raw_addr # fallback: treat as PA directly
pa = PhysAddr.decode(target_pa)
dst_node = self.ctx.resolver.resolve(pa)
path = self.ctx.router.find_path(self._pe_prefix, dst_node)
drain_ns = self.ctx.compute_drain_ns(path, cmd.nbytes)
# Acquire DMA channel (command issue serialization)
with dma_res.request() as req:
yield req
# Create sub-Transaction with PeDmaMsg (HbmCtrl handles it directly)
sub_done = env.event()
sub_request = PeDmaMsg(
correlation_id="pe_internal",
request_id=f"dma_{id(pe_txn)}",
src_sip=0, src_cube=0, src_pe=0,
dst_pa=target_pa, nbytes=cmd.nbytes,
is_write=is_write,
)
sub_txn = Transaction(
request=sub_request, path=path, step=0,
nbytes=cmd.nbytes, done=sub_done, drain_ns=drain_ns,
)
# Send to next hop (path[0] is pe_dma itself, path[1] is router)
if len(path) > 1:
yield self.out_ports[path[1]].put(sub_txn.advance())
# DMA channel released after issue
# Wait for HBM transfer completion
yield sub_done
pe_txn.done.succeed()
def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator:
"""Handle external Transaction (PeDmaMsg probe, M_CPU DMA) with channel acquisition."""
# Response transactions bypass DMA channel (no outbound resource needed)
if getattr(txn, "is_response", False):
next_hop = txn.next_hop
if next_hop:
yield self.out_ports[next_hop].put(txn.advance())
else:
txn.done.succeed()
return
dma_res = self._select_channel(txn)
with dma_res.request() as req:
yield req
next_hop = txn.next_hop
if next_hop:
yield self.out_ports[next_hop].put(txn.advance())
else:
drain = getattr(txn, "drain_ns", 0.0)
if drain > 0:
yield env.timeout(drain)
txn.done.succeed()
def _select_channel(self, txn: Any) -> simpy.Resource:
"""Select DMA channel based on request type."""
from kernbench.runtime_api.kernel import MemoryWriteMsg
assert self._dma_read is not None and self._dma_write is not None
if isinstance(txn.request, MemoryWriteMsg):
return self._dma_write
return self._dma_read
@@ -0,0 +1,90 @@
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
import simpy
from kernbench.components.base import PeEngineBase
if TYPE_CHECKING:
from kernbench.common.pe_commands import PeInternalTxn
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
# dtype → bit width (for TFLOPS scaling)
_DTYPE_BITS: dict[str, int] = {
"f16": 16, "fp16": 16, "float16": 16, "bf16": 16,
"f32": 32, "fp32": 32, "float32": 32,
"i8": 8, "int8": 8,
"i16": 16, "int16": 16,
"i32": 32, "int32": 32,
}
class PeGemmComponent(PeEngineBase):
"""PE_GEMM: matrix multiplication engine sharing accel_slot (ADR-0014 D4).
Uses a shared compute resource (PE_ACCEL capacity=1) that is mutually
exclusive with PE_MATH within the same PE.
Compute latency model:
FLOPs = 2 * M * K * N
effective_tflops = peak_tflops_f16 * (16 / dtype_bits)
compute_ns = FLOPs / (effective_tflops * 1e3)
"""
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
super().__init__(node, ctx)
self._accel: simpy.Resource | None = None
self._peak_tflops_f16: float = float(node.attrs.get("peak_tflops_f16", 0.0))
def init_resources(self, env: simpy.Environment) -> None:
resource_name = self.node.attrs.get("shared_resource")
if resource_name and self.ctx:
self._accel = self.ctx.get_shared_resource(
env, f"{self._pe_prefix}.{resource_name}"
)
def _compute_ns(self, m: int, k: int, n: int, dtype: str) -> float:
"""Compute GEMM latency in nanoseconds."""
if self._peak_tflops_f16 <= 0:
return float(self.node.attrs.get("overhead_ns", 0.0))
dtype_bits = _DTYPE_BITS.get(dtype, 16)
effective_tflops = self._peak_tflops_f16 * (16.0 / dtype_bits)
flops = 2.0 * m * k * n
return flops / (effective_tflops * 1e3)
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
yield env.timeout(overhead_ns)
def handle_command(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
from kernbench.common.pe_commands import GemmCmd
cmd = pe_txn.command
if self._accel:
with self._accel.request() as req:
yield req
if isinstance(cmd, GemmCmd):
ns = self._compute_ns(cmd.m, cmd.k, cmd.n, cmd.a.dtype)
yield env.timeout(ns)
else:
yield from self.run(env, 0)
else:
if isinstance(cmd, GemmCmd):
ns = self._compute_ns(cmd.m, cmd.k, cmd.n, cmd.a.dtype)
yield env.timeout(ns)
else:
yield from self.run(env, 0)
pe_txn.done.succeed()
def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator:
"""Transaction forwarding with accel_slot acquisition."""
if self._accel:
with self._accel.request() as req:
yield req
yield from super()._forward_txn(env, txn)
else:
yield from super()._forward_txn(env, txn)
@@ -0,0 +1,54 @@
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
import simpy
from kernbench.components.base import PeEngineBase
if TYPE_CHECKING:
from kernbench.common.pe_commands import PeInternalTxn
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class PeMathComponent(PeEngineBase):
"""PE_MATH: element-wise computation engine sharing accel_slot (ADR-0014 D4).
Uses a shared compute resource (PE_ACCEL capacity=1) that is mutually
exclusive with PE_GEMM within the same PE.
"""
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
super().__init__(node, ctx)
self._accel: simpy.Resource | None = None
def init_resources(self, env: simpy.Environment) -> None:
resource_name = self.node.attrs.get("shared_resource")
if resource_name and self.ctx:
self._accel = self.ctx.get_shared_resource(
env, f"{self._pe_prefix}.{resource_name}"
)
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
yield env.timeout(overhead_ns)
def handle_command(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
if self._accel:
with self._accel.request() as req:
yield req
yield from self.run(env, 0)
else:
yield from self.run(env, 0)
pe_txn.done.succeed()
def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator:
"""Transaction forwarding with accel_slot acquisition."""
if self._accel:
with self._accel.request() as req:
yield req
yield from super()._forward_txn(env, txn)
else:
yield from super()._forward_txn(env, txn)
@@ -0,0 +1,66 @@
"""PE_MMU component: address translation unit.
Component role: receives MmuMapMsg/MmuUnmapMsg via inbox (independent of PE_CPU).
Utility role: PE_DMA/PE_GEMM call mmu.translate() directly (no SimPy overhead).
"""
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
import simpy
from kernbench.components.base import ComponentBase, ComponentRegistry
from kernbench.policy.address.pe_mmu import PeMMU
if TYPE_CHECKING:
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class PeMmuComponent(ComponentBase):
"""PE_MMU: per-PE virtual-to-physical address translation.
Receives MmuMapMsg/MmuUnmapMsg via inbox and updates the internal
page table. PE_DMA and PE_GEMM access the underlying PeMMU object
via the ``mmu`` property for synchronous VA→PA translation.
"""
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
super().__init__(node, ctx)
page_size = int(node.attrs.get("page_size", 2 * 1024 * 1024))
overhead_ns = float(node.attrs.get("tlb_overhead_ns", 0.0))
self._mmu = PeMMU(page_size=page_size, overhead_ns=overhead_ns)
@property
def mmu(self) -> PeMMU:
"""The underlying PeMMU utility object for direct translate() calls."""
return self._mmu
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
yield env.timeout(0)
def _worker(self, env: simpy.Environment) -> Generator:
"""Process MmuMapMsg/MmuUnmapMsg from inbox."""
from kernbench.runtime_api.kernel import MmuMapMsg, MmuUnmapMsg
while True:
txn: Any = yield self._inbox.get()
if hasattr(txn, "request"):
request = txn.request
if isinstance(request, MmuMapMsg):
for entry in request.entries:
self._mmu.map(
va=entry["va"], pa=entry["pa"], size=entry["size"],
)
txn.done.succeed()
elif isinstance(request, MmuUnmapMsg):
for entry in request.entries:
self._mmu.unmap(va=entry["va"], size=entry["size"])
txn.done.succeed()
else:
# Forward non-MMU transactions normally
yield from self._forward_txn(env, txn)
else:
yield from self._forward_txn(env, txn)
@@ -0,0 +1,245 @@
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
import simpy
from kernbench.components.base import ComponentBase
if TYPE_CHECKING:
from kernbench.common.pe_commands import PeInternalTxn
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class PeSchedulerComponent(ComponentBase):
"""PE_SCHEDULER: sole dispatcher inside a PE (ADR-0014 D1).
Receives PeInternalTxn from PE_CPU, routes to the appropriate engine:
- DmaReadCmd / DmaWriteCmd → PE_DMA
- GemmCmd → PE_GEMM
- MathCmd → PE_MATH
- CompositeCmd → tiled pipeline (Stage 3: ADR-0014 D3.2)
Composite GEMM pipeline (32x64x32 tiles):
DMA_READ(b_tile_t) → COMPUTE(t) → DMA_WRITE(out_tile_t)
with overlap: READ(t+1) || COMPUTE(t) || WRITE(t-1)
Applies scheduler overhead_ns before dispatching each command.
Non-PeInternalTxn messages are forwarded via inherited _forward_txn().
"""
# Scheduler tile dimensions (ADR-0014 D3.2)
TILE_M = 32
TILE_K = 64
TILE_N = 32
# Command → engine suffix dispatch table.
# New engines: add a single entry here (e.g. ConvCmd: "pe_conv").
_CMD_DISPATCH: dict[type, str] = {}
@classmethod
def _ensure_dispatch_table(cls) -> None:
if cls._CMD_DISPATCH:
return
from kernbench.common.pe_commands import DmaReadCmd, DmaWriteCmd, GemmCmd, MathCmd
cls._CMD_DISPATCH = {
DmaReadCmd: "pe_dma",
DmaWriteCmd: "pe_dma",
GemmCmd: "pe_gemm",
MathCmd: "pe_math",
}
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
super().__init__(node, ctx)
self._pe_prefix = node.id.rsplit(".", 1)[0]
self._ensure_dispatch_table()
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
yield env.timeout(overhead_ns)
def _worker(self, env: simpy.Environment) -> Generator:
from kernbench.common.pe_commands import PeInternalTxn
while True:
msg: Any = yield self._inbox.get()
if isinstance(msg, PeInternalTxn):
env.process(self._dispatch(env, msg))
else:
yield from self._forward_txn(env, msg)
def _dispatch(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
"""Route a PeInternalTxn to the correct engine via dispatch table."""
from kernbench.common.pe_commands import CompositeCmd
# Scheduler overhead
yield from self.run(env, 0)
cmd = pe_txn.command
# Check dispatch table first
engine_suffix = self._CMD_DISPATCH.get(type(cmd))
if engine_suffix is not None:
yield self.out_ports[f"{self._pe_prefix}.{engine_suffix}"].put(pe_txn)
return
# CompositeCmd: tiled pipeline (not a simple forward)
if isinstance(cmd, CompositeCmd):
yield from self._dispatch_composite(env, pe_txn)
return
# Unknown command — signal done immediately
pe_txn.done.succeed()
def _dispatch_composite(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
"""Composite tiled pipeline (ADR-0014 D3.2).
GEMM: 3-stage pipeline with b-tile streaming from HBM.
MATH: sequential compute + DMA_WRITE (no tiling).
"""
from kernbench.common.pe_commands import CompositeCmd
cmd = pe_txn.command
assert isinstance(cmd, CompositeCmd)
if cmd.op == "gemm" and cmd.b is not None:
yield from self._pipeline_gemm(env, pe_txn, cmd)
else:
yield from self._pipeline_math(env, pe_txn, cmd)
def _pipeline_gemm(self, env: simpy.Environment, pe_txn: PeInternalTxn, cmd: Any) -> Generator:
"""Tiled GEMM pipeline: stream b tiles from HBM, compute, write results.
Tensor a is in TCM (loaded via tl.load). Tensor b is in HBM (via tl.ref).
Pipeline: DMA_READ(b_tile_t) -> COMPUTE(t) -> DMA_WRITE(out_tile_t)
Overlap: READ(t+1) || COMPUTE(t) || WRITE(t-1)
"""
from kernbench.common.pe_commands import (
DmaReadCmd,
DmaWriteCmd,
GemmCmd,
PeInternalTxn as PeTxn,
TensorHandle,
)
pp = self._pe_prefix
a = cmd.a # already in TCM
b = cmd.b # HBM reference (via tl.ref)
M, K_a = a.shape[-2], a.shape[-1]
K_b, N = b.shape[-2], b.shape[-1]
dtype = a.dtype
dtype_bytes = b.nbytes // (K_b * N) if (K_b * N) > 0 else 2
# Tile counts
n_tiles_k = max(1, (K_a + self.TILE_K - 1) // self.TILE_K)
n_tiles_n = max(1, (N + self.TILE_N - 1) // self.TILE_N)
n_tiles = n_tiles_k * n_tiles_n
prev_compute_done = None
prev_write_done = None
total_dma_ns = 0.0
total_compute_ns = 0.0
for tile_idx in range(n_tiles):
tk = tile_idx // n_tiles_n
tn = tile_idx % n_tiles_n
k_start = tk * self.TILE_K
n_start = tn * self.TILE_N
tile_k = min(self.TILE_K, K_a - k_start)
tile_n = min(self.TILE_N, N - n_start)
tile_nbytes = tile_k * tile_n * dtype_bytes
# --- Stage 1: DMA_READ b_tile from HBM ---
read_done = env.event()
b_tile_addr = b.addr + (k_start * N + n_start) * dtype_bytes
b_tile_handle = TensorHandle(
id=f"b_tile_{tile_idx}", addr=b_tile_addr,
shape=(tile_k, tile_n), dtype=dtype, nbytes=tile_nbytes,
)
read_cmd = DmaReadCmd(handle=b_tile_handle, src_addr=b_tile_addr, nbytes=tile_nbytes)
read_txn = PeTxn(command=read_cmd, done=read_done, pe_prefix=pp)
t0 = env.now
yield self.out_ports[f"{pp}.pe_dma"].put(read_txn)
# Wait for previous compute before starting this tile's compute
if prev_compute_done is not None:
yield prev_compute_done
# Wait for this tile's DMA_READ
yield read_done
total_dma_ns += env.now - t0
# --- Stage 2: COMPUTE (GEMM) ---
compute_done = env.event()
out_handle = TensorHandle(
id=f"out_tile_{tile_idx}", addr=0,
shape=(M, tile_n), dtype=dtype,
nbytes=M * tile_n * dtype_bytes,
)
compute_cmd = GemmCmd(a=a, b=b_tile_handle, out=out_handle,
m=M, k=tile_k, n=tile_n)
compute_txn = PeTxn(command=compute_cmd, done=compute_done, pe_prefix=pp)
t0 = env.now
yield self.out_ports[f"{pp}.pe_gemm"].put(compute_txn)
# Wait for previous write (DMA_WRITE serialization)
if prev_write_done is not None:
yield prev_write_done
# Wait for compute of THIS tile
yield compute_done
total_compute_ns += env.now - t0
prev_compute_done = compute_done
# --- Stage 3: DMA_WRITE out_tile to HBM ---
write_done = env.event()
out_tile_pa = cmd.out_addr + n_start * dtype_bytes
write_nbytes = M * tile_n * dtype_bytes
write_cmd = DmaWriteCmd(handle=out_handle, dst_addr=out_tile_pa, nbytes=write_nbytes)
write_txn = PeTxn(command=write_cmd, done=write_done, pe_prefix=pp)
t0 = env.now
yield self.out_ports[f"{pp}.pe_dma"].put(write_txn)
prev_write_done = write_done
# Wait for final write
if prev_write_done is not None:
t0 = env.now
yield prev_write_done
total_dma_ns += env.now - t0
pe_txn.result_data["dma_ns"] = total_dma_ns
pe_txn.result_data["compute_ns"] = total_compute_ns
pe_txn.done.succeed()
def _pipeline_math(self, env: simpy.Environment, pe_txn: PeInternalTxn, cmd: Any) -> Generator:
"""Non-GEMM composite: sequential compute + DMA_WRITE (no tiling)."""
from kernbench.common.pe_commands import (
DmaWriteCmd,
MathCmd,
PeInternalTxn as PeTxn,
)
pp = self._pe_prefix
# Step 1: Compute (MATH)
compute_done = env.event()
compute_cmd = MathCmd(
op=cmd.math_op or "identity",
inputs=(cmd.a,), out=cmd.a,
)
compute_txn = PeTxn(command=compute_cmd, done=compute_done, pe_prefix=pp)
yield self.out_ports[f"{pp}.pe_math"].put(compute_txn)
yield compute_done
# Step 2: DMA_WRITE result to HBM
write_done = env.event()
write_cmd = DmaWriteCmd(handle=cmd.a, dst_addr=cmd.out_addr, nbytes=cmd.out_nbytes)
write_txn = PeTxn(command=write_cmd, done=write_done, pe_prefix=pp)
yield self.out_ports[f"{pp}.pe_dma"].put(write_txn)
yield write_done
pe_txn.done.succeed()
@@ -0,0 +1,25 @@
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING
from kernbench.components.base import ComponentBase
if TYPE_CHECKING:
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class PeTcmComponent(ComponentBase):
"""PE_TCM: tightly-coupled memory / local SRAM staging buffer.
Terminal storage component for PE-internal dataflow (ADR-0014 D5).
Phase 0: applies overhead_ns and drain_ns at terminal.
"""
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
super().__init__(node, ctx)
def run(self, env, nbytes: int) -> Generator:
overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
yield env.timeout(overhead_ns)
@@ -0,0 +1,59 @@
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
import simpy
from kernbench.components.base import ComponentBase
from kernbench.sim_engine.transaction import Transaction
if TYPE_CHECKING:
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class SramComponent(ComponentBase):
"""Cube SRAM: terminal component that models SRAM access latency.
Applies overhead_ns processing overhead (from node.attrs).
On completion, sends a ResponseMsg back on the reverse path.
"""
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
super().__init__(node, ctx)
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
yield env.timeout(overhead_ns)
def _worker(self, env: simpy.Environment) -> Generator:
"""Terminal worker: process, apply drain, send response."""
while True:
txn: Any = yield self._inbox.get()
yield from self.run(env, txn.nbytes)
drain = getattr(txn, "drain_ns", 0.0)
if drain > 0:
yield env.timeout(drain)
yield from self._send_response(env, txn)
def _send_response(self, env: simpy.Environment, txn: Any) -> Generator:
"""Create ResponseMsg and send on reverse path."""
reverse_path = list(reversed(txn.path))
if len(reverse_path) >= 2 and self.ctx:
from kernbench.runtime_api.kernel import ResponseMsg
parts = self.node.id.split(".")
cube_id = int(parts[1].replace("cube", ""))
resp_msg = ResponseMsg(
correlation_id=txn.request.correlation_id,
request_id=txn.request.request_id,
src_cube=cube_id, src_pe=-1, success=True,
)
resp_txn = Transaction(
request=resp_msg, path=reverse_path, step=0,
nbytes=0, done=env.event(), is_response=True,
)
yield self.out_ports[reverse_path[1]].put(resp_txn.advance())
else:
txn.done.succeed()
@@ -0,0 +1,20 @@
"""PeAccel: cycle-accurate accelerator component for pe_scheduler slot.
Register in components.yaml as:
pe_scheduler_v2: kernbench.components.custom.pe_accel.scheduler:SchedulerV2Component
Then reference in topology.yaml:
pe_scheduler: { kind: pe_scheduler, impl: pe_scheduler_v2, attrs: { ... } }
Package layout:
scheduler/ — scheduler block (component + dispatch + tiling)
scheduler.py — SchedulerV2Component
gemm_pipeline.py — tiled GEMM coordinator
math_pipeline.py — tiled element-wise math coordinator
tile_address.py — per-tile address computation
blocks/ — hardware blocks (DMA_IN, DMA_WB, GEMM, MATH, TCM)
types.py — data classes (descriptors, triggers, tile commands)
"""
from kernbench.components.custom.pe_accel.scheduler import SchedulerV2Component
__all__ = ["SchedulerV2Component"]
@@ -0,0 +1,16 @@
"""Hardware blocks for pe_accel.
Each block is a concurrent SimPy process modeling one functional unit:
- DmaInBlock: HBM → TCM tile reads (issues real DmaReadCmd to PE_DMA)
- DmaWbBlock: TCM → HBM tile writes (issues real DmaWriteCmd to PE_DMA)
- GemmBlock: 2-stage MAC pipeline (fetch + compute)
- MathBlock: K-accumulation (GEMM helper) + element-wise ops (exp, log, etc.)
- TcmBlock: TCM access serialization with BW-based timing
"""
from kernbench.components.custom.pe_accel.blocks.dma_in import DmaInBlock
from kernbench.components.custom.pe_accel.blocks.dma_wb import DmaWbBlock
from kernbench.components.custom.pe_accel.blocks.gemm import GemmBlock
from kernbench.components.custom.pe_accel.blocks.math import MathBlock
from kernbench.components.custom.pe_accel.blocks.tcm import TcmBlock, TcmRequest
__all__ = ["DmaInBlock", "DmaWbBlock", "GemmBlock", "MathBlock", "TcmBlock", "TcmRequest"]
@@ -0,0 +1,96 @@
"""DMA IN Block: reads tiles from HBM into TCM via real PE_DMA fabric.
Flow per tile:
1. Receive DmaRequest from tiling pipeline
2. Look up DmaInDescriptor for address and size
3. Issue DmaReadCmd → PE_DMA → fabric → HBM controller → response
4. Route completion Trigger to next block (GEMM, MATH, or COMPLETION)
Timing is real fabric latency (not analytical) — includes BW contention,
propagation delay, and HBM controller serialization.
"""
from __future__ import annotations
import simpy
from kernbench.components.custom.pe_accel.types import DmaInDescriptor, DmaRequest, Trigger
class DmaInBlock:
"""HBM → TCM tile reader. Shared across all concurrent pipelines.
Pipelines pre-load DmaInDescriptors keyed by (pipeline_id, tile_id, operand).
The _load_loop process reads DmaRequests, issues real DmaReadCmd to PE_DMA,
and routes completion triggers based on descriptor.next_block.
"""
def __init__(
self,
env: simpy.Environment,
cmd_q: simpy.Store,
to_fetch_trig: simpy.Store,
to_math_trig: simpy.Store,
completion_q: simpy.Store,
*,
pe_dma_port: simpy.Store | None,
pe_prefix: str,
) -> None:
self.env = env
self.cmd_q = cmd_q
self.to_fetch_trig = to_fetch_trig
self.to_math_trig = to_math_trig
self.completion_q = completion_q
self._pe_dma_port = pe_dma_port
self._pe_prefix = pe_prefix
self._descriptor_table: dict[tuple[int, int, str], DmaInDescriptor] = {}
# Per-pipeline timing histogram (keyed by pipeline_id)
self.t_dma_read_per_request: dict[int, list[float]] = {}
def load_descriptors(self, descs: dict[tuple, DmaInDescriptor]) -> None:
"""Pre-load per-operand DMA descriptors (cumulative across pipelines)."""
self._descriptor_table.update(descs)
def _load_loop(self):
"""Main process: receive DmaRequests, issue DmaReadCmd, route triggers."""
from kernbench.common.pe_commands import DmaReadCmd, PeInternalTxn, TensorHandle
while True:
req: DmaRequest = yield self.cmd_q.get()
if req is None:
break
desc = self._descriptor_table[(req.pipeline_id, req.tile_id, req.operand)]
# Issue real DMA read through PE_DMA → fabric → HBM
read_done = self.env.event()
handle = TensorHandle(
id=f"accel_rd_{req.pipeline_id}_{req.tile_id}_{req.operand}",
addr=desc.src_addr,
shape=(desc.size_bytes,),
dtype="uint8",
nbytes=desc.size_bytes,
)
txn = PeInternalTxn(
command=DmaReadCmd(handle=handle, src_addr=desc.src_addr, nbytes=desc.size_bytes),
done=read_done,
pe_prefix=self._pe_prefix,
)
t0 = self.env.now
yield self._pe_dma_port.put(txn)
yield read_done
self.t_dma_read_per_request.setdefault(req.pipeline_id, []).append(self.env.now - t0)
# Route trigger to next block
trig = Trigger(
tile_id=req.tile_id,
pipeline_id=req.pipeline_id,
vc=0 if req.operand == "A" else 1,
source_block="DMA_IN",
)
if desc.next_block == "MATH":
yield self.to_math_trig.put(trig)
elif desc.next_block == "COMPLETION":
yield self.completion_q.put(trig)
else: # "GEMM" (default)
yield self.to_fetch_trig.put(trig)
@@ -0,0 +1,88 @@
"""DMA Writeback Block: writes result tiles from TCM to HBM via real PE_DMA fabric.
Flow per tile:
1. Receive flush Trigger from GEMM or MATH block
2. Look up DmaWBDescriptor for address and tile size
3. Issue DmaWriteCmd → PE_DMA → fabric → HBM controller → response
4. Send completion Trigger to pipeline
Two _flush_loop processes run concurrently:
- One drains GEMM → DMA_WB triggers (direct writeback path)
- One drains MATH → DMA_WB triggers (K-accumulation or element-wise flush)
"""
from __future__ import annotations
import simpy
from kernbench.components.custom.pe_accel.types import DmaWBDescriptor, Trigger
class DmaWbBlock:
"""TCM → HBM tile writer. Shared across all concurrent pipelines.
Pipelines pre-load DmaWBDescriptors keyed by (pipeline_id, tile_id).
Each _flush_loop process reads triggers, issues real DmaWriteCmd to PE_DMA,
and forwards completion to the pipeline's reply queue.
"""
def __init__(
self,
env: simpy.Environment,
completion_q: simpy.Store,
*,
pe_dma_port: simpy.Store | None,
pe_prefix: str,
bytes_per_element: int,
) -> None:
self.env = env
self.completion_q = completion_q
self._pe_dma_port = pe_dma_port
self._pe_prefix = pe_prefix
self._bpe = bytes_per_element
self._descriptor_table: dict[tuple[int, int], DmaWBDescriptor] = {}
# Per-pipeline timing histogram (keyed by pipeline_id)
self.t_dma_write_per_tile: dict[int, list[float]] = {}
def load_descriptors(self, descs: dict[tuple, DmaWBDescriptor]) -> None:
"""Pre-load per-tile writeback descriptors (cumulative across pipelines)."""
self._descriptor_table.update(descs)
def _flush_loop(self, trig_q: simpy.Store):
"""Main process: receive flush triggers, issue DmaWriteCmd, send completion."""
from kernbench.common.pe_commands import DmaWriteCmd, PeInternalTxn, TensorHandle
while True:
trigger: Trigger = yield trig_q.get()
if trigger is None:
break
pid = trigger.pipeline_id
tile_id = trigger.tile_id
desc = self._descriptor_table.get((pid, tile_id))
if desc:
c_bytes = desc.Tm * desc.Tn * self._bpe
# Issue real DMA write through PE_DMA → fabric → HBM
write_done = self.env.event()
handle = TensorHandle(
id=f"accel_wb_{pid}_{tile_id}",
addr=desc.dst_addr,
shape=(desc.Tm, desc.Tn),
dtype="float16",
nbytes=c_bytes,
)
txn = PeInternalTxn(
command=DmaWriteCmd(handle=handle, dst_addr=desc.dst_addr, nbytes=c_bytes),
done=write_done,
pe_prefix=self._pe_prefix,
)
t0 = self.env.now
yield self._pe_dma_port.put(txn)
yield write_done
self.t_dma_write_per_tile.setdefault(pid, []).append(self.env.now - t0)
yield self.completion_q.put(
Trigger(tile_id=tile_id, pipeline_id=pid, source_block="DMA_WB")
)
@@ -0,0 +1,160 @@
"""GEMM Block: 2-stage MAC pipeline (fetch + compute).
Stage 1 — Fetch (_fetch_stage):
Collects DMA completion triggers (one per operand per tile).
When all operands arrive, issues TCM read request (SPMem → MAC registers).
Stage 2 — Compute (_gemm_stage):
Models MAC array computation time.
Issues TCM write request (MAC result → SPMem).
Routes output trigger to MathBlock, DmaWbBlock, or completion.
TCM access goes through TcmBlock for real BW serialization.
MAC compute time is cycle-accurate: ceil(Tm/mac_m) * ceil(Tk/mac_k) * ceil(Tn/mac_n).
"""
from __future__ import annotations
from math import ceil
import simpy
from kernbench.components.custom.pe_accel.blocks.tcm import TcmRequest
from kernbench.components.custom.pe_accel.types import GemmDescriptor, Trigger
class GemmBlock:
"""2-stage MAC pipeline shared across all concurrent pipelines.
Pipelines pre-load GemmDescriptors keyed by (pipeline_id, tile_id).
Two SimPy processes run concurrently: _fetch_stage and _gemm_stage.
"""
def __init__(
self,
env: simpy.Environment,
trig_in: simpy.Store,
fetch_to_gemm_trig: simpy.Store,
to_math_trig: simpy.Store,
to_dmaWB_trig: simpy.Store,
completion_q: simpy.Store,
*,
tcm_port: simpy.Store,
mac_m: int,
mac_k: int,
mac_n: int,
bytes_per_element: int,
clock_freq_ghz: float,
) -> None:
self.env = env
self.trig_in = trig_in
self.fetch_to_gemm_trig = fetch_to_gemm_trig
self.to_math_trig = to_math_trig
self.to_dmaWB_trig = to_dmaWB_trig
self.completion_q = completion_q
self._tcm_port = tcm_port
self._mac_m = mac_m
self._mac_k = mac_k
self._mac_n = mac_n
self._bpe = bytes_per_element
self._freq = clock_freq_ghz
self._descriptor_table: dict[tuple[int, int], GemmDescriptor] = {}
# Per-pipeline timing histograms
self.t_tcm_load_per_tile: dict[int, list[float]] = {}
self.t_compute_per_tile: dict[int, list[float]] = {}
def load_descriptors(self, descs: dict[tuple, GemmDescriptor]) -> None:
"""Pre-load per-tile GEMM descriptors (cumulative across pipelines)."""
self._descriptor_table.update(descs)
def _compute_ns(self, desc: GemmDescriptor) -> float:
"""MAC array compute time for one tile (ns)."""
cycles = ceil(desc.Tm / self._mac_m) * ceil(desc.Tk / self._mac_k) * ceil(desc.Tn / self._mac_n)
return cycles / self._freq
# -- Stage 1: Fetch (TCM → MAC load) --------------------------------------
def _fetch_stage(self):
"""Collect DMA triggers per tile, issue TCM read for operand load."""
pending: dict[tuple[int, int], list[Trigger]] = {}
while True:
trigger = yield self.trig_in.get()
if trigger is None:
yield self.fetch_to_gemm_trig.put(None)
break
key = (trigger.pipeline_id, trigger.tile_id)
pending.setdefault(key, []).append(trigger)
desc = self._descriptor_table.get(key)
needed = desc.triggers_needed if desc else 2
if len(pending[key]) < needed:
continue
del pending[key]
# TCM load: read A and B tile data from SPMem → MAC registers
if desc and desc.gemm_load:
a_bytes = desc.Tm * desc.Tk * self._bpe
b_bytes = desc.Tk * desc.Tn * self._bpe
load_bytes = a_bytes + b_bytes
t0 = self.env.now
done = self.env.event()
yield self._tcm_port.put(TcmRequest("read", load_bytes, done, tag="gemm_load"))
yield done
self.t_tcm_load_per_tile.setdefault(trigger.pipeline_id, []).append(
self.env.now - t0
)
yield self.fetch_to_gemm_trig.put(trigger)
# -- Stage 2: Compute (MAC array) + Store (MAC → TCM) ---------------------
def _gemm_stage(self):
"""MAC computation, then TCM store, then route to next block."""
while True:
trigger = yield self.fetch_to_gemm_trig.get()
if trigger is None:
break
key = (trigger.pipeline_id, trigger.tile_id)
desc = self._descriptor_table.get(key)
# MAC compute
if desc and desc.gemm_compute:
t_compute = self._compute_ns(desc)
t0 = self.env.now
if t_compute > 0:
yield self.env.timeout(t_compute)
self.t_compute_per_tile.setdefault(trigger.pipeline_id, []).append(
self.env.now - t0
)
# Route output
route = desc.next_block if desc else "MATH"
out_trig = Trigger(
tile_id=trigger.tile_id,
pipeline_id=trigger.pipeline_id,
source_block="GEMM",
)
if route == "MATH":
yield self.to_math_trig.put(out_trig)
elif route == "DMAWB":
# TCM store before writeback
if desc:
c_bytes = desc.Tm * desc.Tn * self._bpe
done = self.env.event()
yield self._tcm_port.put(TcmRequest("write", c_bytes, done, tag="gemm_store"))
yield done
yield self.to_dmaWB_trig.put(out_trig)
else: # "DONE" — C stays in SPMem, no flush
if desc:
c_bytes = desc.Tm * desc.Tn * self._bpe
done = self.env.event()
yield self._tcm_port.put(TcmRequest("write", c_bytes, done, tag="gemm_store"))
yield done
yield self.completion_q.put(out_trig)
@@ -0,0 +1,181 @@
"""Math Block: K-accumulation (GEMM helper) + element-wise ops (exp, log, etc.).
Two concurrent processing modes:
1. K-accumulation (_run_k_accumulation):
Receives triggers from GemmBlock after each K-tile compute.
Issues TCM write for partial-result store.
On final K-tile, routes to DMA_WB or completion.
2. Element-wise ops (_run_element_wise):
Receives triggers from DMA_IN after each tile read.
Issues TCM read (load input), compute (SIMD), TCM write (store result).
Routes output to DMA_WB for writeback.
TCM access goes through TcmBlock for real BW serialization.
SIMD compute time: ceil(num_elements / vector_width) / clock_freq.
"""
from __future__ import annotations
from math import ceil
import simpy
from kernbench.components.custom.pe_accel.blocks.tcm import TcmRequest
from kernbench.components.custom.pe_accel.types import MathDescriptor, MathOpDescriptor, Trigger
class MathBlock:
"""K-accumulation + element-wise math unit.
Descriptor tables:
- _accum_table: MathDescriptor (pipeline_id, tile_id) — for GEMM K-accumulation
- _elemwise_table: MathOpDescriptor (pipeline_id, tile_id) — for element-wise ops
"""
def __init__(
self,
env: simpy.Environment,
trig_in: simpy.Store,
to_dmaWB_trig: simpy.Store,
completion_q: simpy.Store,
*,
tcm_port: simpy.Store,
bytes_per_element: int,
clock_freq_ghz: float,
vector_width: int = 256,
) -> None:
self.env = env
self.trig_in = trig_in # from GemmBlock (K-accumulation)
self.to_dmaWB_trig = to_dmaWB_trig
self.completion_q = completion_q
self._tcm_port = tcm_port
self._bpe = bytes_per_element
self._freq = clock_freq_ghz
self._vector_width = vector_width
# Descriptor tables
self._accum_table: dict[tuple[int, int], MathDescriptor] = {}
self._elemwise_table: dict[tuple[int, int], MathOpDescriptor] = {}
# -- Timing histograms (per pipeline_id) --
# K-accumulation
self.t_tcm_store_per_tile: dict[int, list[float]] = {}
# Element-wise ops
self.t_math_op_load_per_tile: dict[int, list[float]] = {}
self.t_math_op_compute_per_tile: dict[int, list[float]] = {}
self.t_math_op_store_per_tile: dict[int, list[float]] = {}
# -- Descriptor loading ----------------------------------------------------
def load_descriptors(self, descs: dict[tuple, MathDescriptor]) -> None:
"""Pre-load K-accumulation descriptors (cumulative across pipelines)."""
self._accum_table.update(descs)
def load_math_op_descriptors(self, descs: dict[tuple, MathOpDescriptor]) -> None:
"""Pre-load element-wise op descriptors (cumulative across pipelines)."""
self._elemwise_table.update(descs)
# -- Mode 1: K-accumulation ------------------------------------------------
def _run(self):
"""Backward-compat alias."""
yield from self._run_k_accumulation()
def _run_k_accumulation(self):
"""Receive GEMM output triggers, TCM store partial result, flush on last-K."""
while True:
trigger = yield self.trig_in.get()
if trigger is None:
break
key = (trigger.pipeline_id, trigger.tile_id)
desc = self._accum_table.get(key)
# TCM store: write partial sum to SPMem
if desc:
c_bytes = desc.Tm * desc.Tn * self._bpe
t0 = self.env.now
done = self.env.event()
yield self._tcm_port.put(TcmRequest("write", c_bytes, done, tag="k_accum_store"))
yield done
self.t_tcm_store_per_tile.setdefault(trigger.pipeline_id, []).append(
self.env.now - t0
)
if not desc or not desc.is_last_k:
continue # intermediate K-tile: store done, no flush yet
out_trig = Trigger(
tile_id=trigger.tile_id,
pipeline_id=trigger.pipeline_id,
source_block="MATH",
)
if desc.skip_dmaWB:
yield self.completion_q.put(out_trig)
else:
yield self.to_dmaWB_trig.put(out_trig)
# -- Mode 2: Element-wise ops ----------------------------------------------
def _run_math_op(self, trig_q: simpy.Store):
"""Backward-compat alias."""
yield from self._run_element_wise(trig_q)
def _run_element_wise(self, trig_q: simpy.Store):
"""Receive DMA_IN triggers, apply element-wise op via TCM, route to DMA_WB.
Per tile:
1. TCM read — load input tile from SPMem to SIMD
2. Compute — SIMD operation (exp/log/etc.)
3. TCM write — store result from SIMD to SPMem
4. Route to DMA_WB
"""
while True:
trigger = yield trig_q.get()
if trigger is None:
break
key = (trigger.pipeline_id, trigger.tile_id)
desc = self._elemwise_table.get(key)
if desc:
tile_bytes = desc.Tm * desc.Tn * self._bpe
num_elements = desc.Tm * desc.Tn
# 1. TCM read
t0 = self.env.now
done = self.env.event()
yield self._tcm_port.put(TcmRequest("read", tile_bytes, done, tag="elemwise_load"))
yield done
self.t_math_op_load_per_tile.setdefault(trigger.pipeline_id, []).append(
self.env.now - t0
)
# 2. SIMD compute
t0 = self.env.now
compute_cycles = ceil(num_elements / self._vector_width)
compute_ns = compute_cycles / self._freq
if compute_ns > 0:
yield self.env.timeout(compute_ns)
self.t_math_op_compute_per_tile.setdefault(trigger.pipeline_id, []).append(
self.env.now - t0
)
# 3. TCM write
t0 = self.env.now
done = self.env.event()
yield self._tcm_port.put(TcmRequest("write", tile_bytes, done, tag="elemwise_store"))
yield done
self.t_math_op_store_per_tile.setdefault(trigger.pipeline_id, []).append(
self.env.now - t0
)
yield self.to_dmaWB_trig.put(Trigger(
tile_id=trigger.tile_id,
pipeline_id=trigger.pipeline_id,
source_block="MATH_OP",
))
@@ -0,0 +1,80 @@
"""TCM Block: tightly-coupled memory with BW-based access serialization.
Models SPMem (scratchpad memory) inside the PE. Compute blocks (GEMM, MATH)
send TcmRequests for load/store operations. The TCM block serializes access
per channel and computes timing based on data size and bandwidth.
Two channels:
- READ (SPMem → compute unit): models operand fetch for MAC/SIMD
- WRITE (compute unit → SPMem): models result store from MAC/SIMD
Each channel has capacity=1: concurrent reads serialize, concurrent writes
serialize, but a read and a write can proceed in parallel.
"""
from __future__ import annotations
from dataclasses import dataclass
import simpy
@dataclass
class TcmRequest:
"""Request to read from or write to TCM."""
direction: str # "read" or "write"
nbytes: int
done: simpy.Event
tag: str = "" # optional label for debugging
class TcmBlock:
"""BW-serialized TCM model with dual read/write channels.
Args:
env: SimPy environment.
read_bw_gbs: read bandwidth in GB/s (SPMem → compute).
write_bw_gbs: write bandwidth in GB/s (compute → SPMem).
"""
def __init__(
self,
env: simpy.Environment,
read_bw_gbs: float = 512.0,
write_bw_gbs: float = 512.0,
) -> None:
self.env = env
self._read_bw = read_bw_gbs
self._write_bw = write_bw_gbs
self._read_res = simpy.Resource(env, capacity=1)
self._write_res = simpy.Resource(env, capacity=1)
self._port: simpy.Store = simpy.Store(env)
@property
def port(self) -> simpy.Store:
"""The SimPy Store that blocks send TcmRequests to."""
return self._port
def _run(self):
"""Main process: receive TcmRequests, dispatch to channel processes."""
while True:
req: TcmRequest = yield self._port.get()
if req is None:
break
self.env.process(self._handle(req))
def _handle(self, req: TcmRequest):
"""Acquire channel, apply BW-based delay, signal done."""
if req.direction == "write":
res = self._write_res
bw = self._write_bw
else:
res = self._read_res
bw = self._read_bw
with res.request() as lock:
yield lock
if bw > 0 and req.nbytes > 0:
delay_ns = req.nbytes / bw
yield self.env.timeout(delay_ns)
req.done.succeed()
@@ -0,0 +1,10 @@
"""Scheduler: accelerator component + dispatch + tiling pipelines.
scheduler.py — SchedulerV2Component (init, wiring, dispatch, metrics)
gemm_pipeline.py — GemmPipeline (tiled GEMM coordinator)
math_pipeline.py — MathPipeline (tiled element-wise math coordinator)
tile_address.py — per-tile address computation
"""
from kernbench.components.custom.pe_accel.scheduler.scheduler import SchedulerV2Component
__all__ = ["SchedulerV2Component"]
@@ -0,0 +1,157 @@
"""GEMM Tiling Pipeline: splits (M,K)×(K,N) into tiles and coordinates execution.
Flow per tile:
DMA_IN(A tile) + DMA_IN(B tile) → GEMM(fetch + compute) → MATH(K-accum) → DMA_WB
The pipeline does NOT own hardware blocks — it uses the component's shared
blocks via descriptor tables and SimPy queues.
Constructor starts two SimPy processes:
- _feed_commands(): sends DmaRequests to shared dmaIN_cmd_q
- _collect_completions(): waits for all output tiles to flush
"""
from __future__ import annotations
from math import ceil
import simpy
from kernbench.components.custom.pe_accel.scheduler.tiling import generate_gemm_tiles
from kernbench.components.custom.pe_accel.types import (
CmdType,
DmaInDescriptor,
DmaRequest,
DmaWBDescriptor,
GemmDescriptor,
MathDescriptor,
Trigger,
)
class GemmPipeline:
"""Coordinates one tiled GEMM operation across shared hardware blocks."""
def __init__(
self,
env: simpy.Environment,
M: int, K: int, N: int,
tile_m: int, tile_k: int, tile_n: int,
bytes_per_element: int,
pipeline_id: int,
reply_queue: simpy.Store,
dmaIN_cmd_q: simpy.Store,
dmaIN_to_fetch_trig: simpy.Store,
A_addr: int = 0,
B_addr: int = 0,
C_addr: int = 0,
dma_a: bool = True,
dma_b: bool = True,
dma_c: bool = True,
) -> None:
self.env = env
self.M, self.K, self.N = M, K, N
self.pipeline_id = pipeline_id
self.reply_queue = reply_queue
self.dmaIN_cmd_q = dmaIN_cmd_q
self.dmaIN_to_fetch_trig = dmaIN_to_fetch_trig
self._dma_a = dma_a
self._dma_b = dma_b
self._skip_dmaWB = not dma_c
_Tm = min(tile_m, M)
_Tk = min(tile_k, K)
_Tn = min(tile_n, N)
self.M_tiles = ceil(M / tile_m)
self.K_tiles = ceil(K / tile_k)
self.N_tiles = ceil(N / tile_n)
triggers_per_tile = 2 if (dma_a and dma_b) else 1
# Generate tile schedule with pre-computed addresses
self.schedule = generate_gemm_tiles(
self.M_tiles, self.K_tiles, self.N_tiles,
M=M, K=K, N=N,
tile_m=_Tm, tile_k=_Tk, tile_n=_Tn,
bytes_per_element=bytes_per_element,
A_addr=A_addr, B_addr=B_addr, C_addr=C_addr,
pipeline_id=pipeline_id,
)
# Build descriptor tables for shared blocks
pid = pipeline_id
a_tile_bytes = _Tm * _Tk * bytes_per_element
b_tile_bytes = _Tk * _Tn * bytes_per_element
self.dmaIN_descs: dict[tuple, DmaInDescriptor] = {}
self.gemm_descs: dict[tuple, GemmDescriptor] = {}
self.math_descs: dict[tuple, MathDescriptor] = {}
self.dmaWB_descs: dict[tuple, DmaWBDescriptor] = {}
for cmd in self.schedule.commands:
if cmd.cmd_type != CmdType.DMA_LOAD:
continue
t = cmd.tile_id
if dma_a:
self.dmaIN_descs[(pid, t, "A")] = DmaInDescriptor(
size_bytes=a_tile_bytes, src_addr=cmd.a_tile_addr
)
if dma_b:
self.dmaIN_descs[(pid, t, "B")] = DmaInDescriptor(
size_bytes=b_tile_bytes, src_addr=cmd.b_tile_addr
)
self.gemm_descs[(pid, t)] = GemmDescriptor(
Tm=_Tm, Tk=_Tk, Tn=_Tn,
triggers_needed=triggers_per_tile,
next_block="MATH",
)
self.math_descs[(pid, t)] = MathDescriptor(
Tm=_Tm, Tn=_Tn,
is_last_k=cmd.is_last_k,
skip_dmaWB=self._skip_dmaWB,
)
if not self._skip_dmaWB and cmd.is_last_k:
self.dmaWB_descs[(pid, t)] = DmaWBDescriptor(
Tm=_Tm, Tn=_Tn, dst_addr=cmd.c_tile_addr
)
self.expected_flushes = self.M_tiles * self.N_tiles
self.completed_flushes = 0
self.done_at: int = 0
self.done: simpy.Event = env.event()
env.process(self._feed_commands())
env.process(self._collect_completions())
def _feed_commands(self):
"""Send DmaRequests for each tile's operands to dmaIN_cmd_q."""
for cmd in self.schedule.commands:
if cmd.cmd_type != CmdType.DMA_LOAD:
continue
if self._dma_a:
yield self.dmaIN_cmd_q.put(DmaRequest(
tile_id=cmd.tile_id, pipeline_id=self.pipeline_id, operand="A",
))
if self._dma_b:
yield self.dmaIN_cmd_q.put(DmaRequest(
tile_id=cmd.tile_id, pipeline_id=self.pipeline_id, operand="B",
))
if not self._dma_a and not self._dma_b:
yield self.dmaIN_to_fetch_trig.put(Trigger(
tile_id=cmd.tile_id, pipeline_id=self.pipeline_id, source_block="PIPELINE",
))
def _collect_completions(self):
"""Wait for all output tile flush completions, then signal done."""
while self.completed_flushes < self.expected_flushes:
yield self.reply_queue.get()
self.completed_flushes += 1
self.done_at = int(self.env.now)
self.done.succeed()
@@ -0,0 +1,132 @@
"""Math Tiling Pipeline: splits element-wise ops into tiles for pipelined execution.
Flow per tile:
DMA_IN(input tile) → MATH_OP(exp/log/etc.) → DMA_WB(output tile)
Mirrors GemmTilingPipeline but for unary element-wise operations.
Pipeline overlap across tiles: while one tile is in MATH_OP, the next
tile's DMA_IN can proceed concurrently.
Constructor starts two SimPy processes:
- _feed_commands(): sends DmaRequests to shared dmaIN_cmd_q
- _collect_completions(): waits for all tiles to writeback
"""
from __future__ import annotations
from math import ceil
import simpy
from kernbench.components.custom.pe_accel.scheduler.tiling import generate_math_tiles
from kernbench.components.custom.pe_accel.types import (
DmaInDescriptor,
DmaRequest,
DmaWBDescriptor,
MathOpDescriptor,
Trigger,
)
class MathPipeline:
"""Coordinates one tiled element-wise math operation across shared blocks."""
def __init__(
self,
env: simpy.Environment,
M: int, N: int,
tile_m: int, tile_n: int,
bytes_per_element: int,
pipeline_id: int,
reply_queue: simpy.Store,
dmaIN_cmd_q: simpy.Store,
dmaIN_to_fetch_trig: simpy.Store,
op: str = "exp",
src_addr: int = 0,
dst_addr: int = 0,
dma_in: bool = True,
dma_out: bool = True,
) -> None:
self.env = env
self.M, self.N = M, N
self.pipeline_id = pipeline_id
self.reply_queue = reply_queue
self.dmaIN_cmd_q = dmaIN_cmd_q
self.dmaIN_to_fetch_trig = dmaIN_to_fetch_trig
self.op = op
self._dma_in = dma_in
self._skip_dmaWB = not dma_out
_Tm = min(tile_m, M)
_Tn = min(tile_n, N)
self.M_tiles = ceil(M / tile_m)
self.N_tiles = ceil(N / tile_n)
# Generate tile schedule with pre-computed addresses
self.schedule = generate_math_tiles(
self.M_tiles, self.N_tiles,
M=M, N=N,
tile_m=_Tm, tile_n=_Tn,
bytes_per_element=bytes_per_element,
src_addr=src_addr, dst_addr=dst_addr,
pipeline_id=pipeline_id,
)
# Build descriptor tables for shared blocks
pid = pipeline_id
tile_bytes = _Tm * _Tn * bytes_per_element
self.dmaIN_descs: dict[tuple, DmaInDescriptor] = {}
self.math_op_descs: dict[tuple, MathOpDescriptor] = {}
self.dmaWB_descs: dict[tuple, DmaWBDescriptor] = {}
for cmd in self.schedule.commands:
t = cmd.tile_id
if dma_in:
self.dmaIN_descs[(pid, t, "A")] = DmaInDescriptor(
size_bytes=tile_bytes,
src_addr=cmd.src_tile_addr,
next_block="MATH",
)
self.math_op_descs[(pid, t)] = MathOpDescriptor(
Tm=_Tm, Tn=_Tn, op=op,
src_addr=cmd.src_tile_addr,
dst_addr=cmd.dst_tile_addr,
)
if not self._skip_dmaWB:
self.dmaWB_descs[(pid, t)] = DmaWBDescriptor(
Tm=_Tm, Tn=_Tn, dst_addr=cmd.dst_tile_addr,
)
self.expected_flushes = self.M_tiles * self.N_tiles
self.completed_flushes = 0
self.done_at: int = 0
self.done: simpy.Event = env.event()
env.process(self._feed_commands())
env.process(self._collect_completions())
def _feed_commands(self):
"""Send DmaRequests for each tile's input to dmaIN_cmd_q."""
for cmd in self.schedule.commands:
if self._dma_in:
yield self.dmaIN_cmd_q.put(DmaRequest(
tile_id=cmd.tile_id, pipeline_id=self.pipeline_id, operand="A",
))
else:
yield self.dmaIN_to_fetch_trig.put(Trigger(
tile_id=cmd.tile_id, pipeline_id=self.pipeline_id, source_block="PIPELINE",
))
def _collect_completions(self):
"""Wait for all tile completions, then signal done."""
while self.completed_flushes < self.expected_flushes:
yield self.reply_queue.get()
self.completed_flushes += 1
self.done_at = int(self.env.now)
self.done.succeed()
@@ -0,0 +1,434 @@
"""SchedulerV2Component: accelerator scheduler block (pe_scheduler_v2).
Replaces the pe_scheduler slot in topology (impl: pe_scheduler_v2).
Hosts four internal hardware blocks as concurrent SimPy processes:
- DmaInBlock — HBM → TCM tile reads (real fabric DMA)
- DmaWbBlock — TCM → HBM tile writes (real fabric DMA)
- GemmBlock — 2-stage MAC pipeline (fetch + compute)
- MathBlock — K-accumulation (GEMM) + element-wise ops (exp, log, etc.)
Command dispatch routes PeInternalTxn to the correct engine or tiling pipeline:
- DmaReadCmd / DmaWriteCmd → PE_DMA out_port
- GemmCmd → PE_GEMM out_port
- MathCmd → PE_MATH out_port
- CompositeCmd(op="gemm") → GemmPipeline (tiled DMA + GEMM + K-accum)
- CompositeCmd(op="math") → MathPipeline (tiled DMA + element-wise + DMA)
- PeCpuOverheadCmd → yield timeout
Config via node.attrs (all optional):
overhead_ns, clock_freq_ghz, bytes_per_element,
mac_m, mac_k, mac_n, tile_m, tile_k, tile_n,
port_a_bw, port_b_bw, store_port_bw, vector_width
"""
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
import simpy
from kernbench.components.base import ComponentBase
if TYPE_CHECKING:
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
# ==============================================================================
# Component
# ==============================================================================
class SchedulerV2Component(ComponentBase):
"""PE accelerator scheduler: wires internal blocks and dispatches commands."""
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
super().__init__(node, ctx)
self._pe_prefix: str = node.id.rsplit(".", 1)[0]
attrs = node.attrs
# Hardware config
self._overhead_ns: float = float(attrs.get("overhead_ns", 0.0))
self._clock_freq_ghz: float = float(attrs.get("clock_freq_ghz", 1.0))
self._bpe: int = int(attrs.get("bytes_per_element", 2))
# MAC array dimensions
self._mac_m: int = int(attrs.get("mac_m", 8))
self._mac_k: int = int(attrs.get("mac_k", 16))
self._mac_n: int = int(attrs.get("mac_n", 32))
# Tile dimensions
self._tile_m: int = int(attrs.get("tile_m", 32))
self._tile_k: int = int(attrs.get("tile_k", 32))
self._tile_n: int = int(attrs.get("tile_n", 32))
# Bandwidth (bytes/cycle)
self._port_a_bw: int = int(attrs.get("port_a_bw", 256))
self._port_b_bw: int = int(attrs.get("port_b_bw", 256))
self._store_port_bw: int = int(attrs.get("store_port_bw", 256))
self._vector_width: int = int(attrs.get("vector_width", 256))
# Initialized in start()
self._dmaIN_block: Any = None
self._dmaWB_block: Any = None
self._gemm_block: Any = None
self._math_block: Any = None
self._dmaIN_cmd_q: simpy.Store | None = None
self._dmaIN_to_fetch_trig: simpy.Store | None = None
# Pipeline tracking
self._next_pipeline_id: int = 0
self._pipeline_queues: dict[int, simpy.Store] = {}
# -- SimPy lifecycle -------------------------------------------------------
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
"""Scheduler overhead per dispatch."""
yield env.timeout(self._overhead_ns)
def start(self, env: simpy.Environment) -> None:
"""Create internal blocks, wire queues, start SimPy processes."""
from kernbench.components.custom.pe_accel.blocks import (
DmaInBlock, DmaWbBlock, GemmBlock, MathBlock, TcmBlock,
)
pe_dma_port = self.out_ports.get(f"{self._pe_prefix}.pe_dma")
# -- TCM block (shared BW-serialized scratchpad) -----------------------
# Read TCM BW from topology spec (gemm_to_tcm / math_to_tcm)
pe_links = {}
if self.ctx and self.ctx.spec:
pe_links = self.ctx.spec.get("cube", {}).get("pe_template", {}).get("links", {})
tcm_read_bw = float(pe_links.get("gemm_to_tcm_bw_gbs", 512.0))
tcm_write_bw = float(pe_links.get("math_to_tcm_bw_gbs", 512.0))
self._tcm_block = TcmBlock(
env=env,
read_bw_gbs=tcm_read_bw,
write_bw_gbs=tcm_write_bw,
)
# -- Internal queues ---------------------------------------------------
self._dmaIN_cmd_q = simpy.Store(env)
self._dmaIN_to_fetch_trig = simpy.Store(env)
dmaIN_to_math_trig = simpy.Store(env) # DMA_IN → element-wise math
fetch_to_gemm_trig = simpy.Store(env)
gemm_to_math_trig = simpy.Store(env)
gemm_to_dmaWB_trig = simpy.Store(env)
math_to_dmaWB_trig = simpy.Store(env)
# Completion queues (block → _completion_router → pipeline reply_q)
dmaIN_completion_q = simpy.Store(env)
dmaWB_completion_q = simpy.Store(env)
gemm_completion_q = simpy.Store(env)
math_completion_q = simpy.Store(env)
# -- Create blocks -----------------------------------------------------
self._dmaIN_block = DmaInBlock(
env=env,
cmd_q=self._dmaIN_cmd_q,
to_fetch_trig=self._dmaIN_to_fetch_trig,
to_math_trig=dmaIN_to_math_trig,
completion_q=dmaIN_completion_q,
pe_dma_port=pe_dma_port,
pe_prefix=self._pe_prefix,
)
self._dmaWB_block = DmaWbBlock(
env=env,
completion_q=dmaWB_completion_q,
pe_dma_port=pe_dma_port,
pe_prefix=self._pe_prefix,
bytes_per_element=self._bpe,
)
self._gemm_block = GemmBlock(
env=env,
trig_in=self._dmaIN_to_fetch_trig,
fetch_to_gemm_trig=fetch_to_gemm_trig,
to_math_trig=gemm_to_math_trig,
to_dmaWB_trig=gemm_to_dmaWB_trig,
completion_q=gemm_completion_q,
tcm_port=self._tcm_block.port,
mac_m=self._mac_m, mac_k=self._mac_k, mac_n=self._mac_n,
bytes_per_element=self._bpe,
clock_freq_ghz=self._clock_freq_ghz,
)
self._math_block = MathBlock(
env=env,
trig_in=gemm_to_math_trig,
to_dmaWB_trig=math_to_dmaWB_trig,
completion_q=math_completion_q,
tcm_port=self._tcm_block.port,
bytes_per_element=self._bpe,
clock_freq_ghz=self._clock_freq_ghz,
vector_width=self._vector_width,
)
# -- Start block processes ---------------------------------------------
env.process(self._tcm_block._run())
env.process(self._dmaIN_block._load_loop())
env.process(self._dmaWB_block._flush_loop(gemm_to_dmaWB_trig))
env.process(self._dmaWB_block._flush_loop(math_to_dmaWB_trig))
env.process(self._gemm_block._fetch_stage())
env.process(self._gemm_block._gemm_stage())
env.process(self._math_block._run_k_accumulation())
env.process(self._math_block._run_element_wise(dmaIN_to_math_trig))
# Wire in-ports → inbox, start _worker
super().start(env)
# Start completion routers
for q in (dmaIN_completion_q, dmaWB_completion_q, gemm_completion_q, math_completion_q):
env.process(self._completion_router(q))
# -- Internal processes ----------------------------------------------------
def _completion_router(self, queue: simpy.Store) -> Generator:
"""Route block completion triggers to the correct pipeline's reply queue."""
while True:
trigger = yield queue.get()
if trigger is None:
break
reply_q = self._pipeline_queues.get(trigger.pipeline_id)
if reply_q is not None:
yield reply_q.put(trigger)
def _worker(self, env: simpy.Environment) -> Generator:
"""Main inbox loop: dispatch PE commands, forward fabric transactions."""
from kernbench.common.pe_commands import PeInternalTxn
while True:
msg: Any = yield self._inbox.get()
if isinstance(msg, PeInternalTxn):
env.process(self._dispatch(env, msg))
else:
env.process(self._forward_txn(env, msg))
# ==========================================================================
# Command Dispatch
# ==========================================================================
def _dispatch(self, env: simpy.Environment, pe_txn: Any) -> Generator:
"""Route a PeInternalTxn to the appropriate engine or pipeline."""
from kernbench.common.pe_commands import (
CompositeCmd, DmaReadCmd, DmaWriteCmd,
GemmCmd, MathCmd, PeCpuOverheadCmd,
)
yield from self.run(env, 0) # scheduler overhead
cmd = pe_txn.command
pp = self._pe_prefix
if isinstance(cmd, (DmaReadCmd, DmaWriteCmd)):
yield self.out_ports[f"{pp}.pe_dma"].put(pe_txn)
elif isinstance(cmd, GemmCmd):
yield self.out_ports[f"{pp}.pe_gemm"].put(pe_txn)
elif isinstance(cmd, MathCmd):
yield self.out_ports[f"{pp}.pe_math"].put(pe_txn)
elif isinstance(cmd, CompositeCmd):
if cmd.op == "gemm" and cmd.b is not None:
yield from self._dispatch_composite_gemm(env, pe_txn, cmd)
else:
yield from self._dispatch_composite_math(env, pe_txn, cmd)
elif isinstance(cmd, PeCpuOverheadCmd):
yield env.timeout(cmd.cycles / self._clock_freq_ghz)
pe_txn.done.succeed()
else:
pe_txn.done.succeed()
# -- GEMM composite --------------------------------------------------------
def _dispatch_composite_gemm(
self, env: simpy.Environment, pe_txn: Any, cmd: Any,
) -> Generator:
"""Run tiled GEMM pipeline and collect per-stage metrics."""
from kernbench.components.custom.pe_accel.scheduler.gemm_pipeline import GemmPipeline
a, b = cmd.a, cmd.b
M, K, N = a.shape[-2], a.shape[-1], b.shape[-1]
pid, reply_q = self._alloc_pipeline(env)
pipeline = GemmPipeline(
env=env, M=M, K=K, N=N,
tile_m=self._tile_m, tile_k=self._tile_k, tile_n=self._tile_n,
bytes_per_element=self._bpe,
pipeline_id=pid,
reply_queue=reply_q,
dmaIN_cmd_q=self._dmaIN_cmd_q,
dmaIN_to_fetch_trig=self._dmaIN_to_fetch_trig,
A_addr=a.addr, B_addr=b.addr, C_addr=cmd.out_addr,
)
# Load descriptors into shared blocks
self._dmaIN_block.load_descriptors(pipeline.dmaIN_descs)
self._gemm_block.load_descriptors(pipeline.gemm_descs)
self._math_block.load_descriptors(pipeline.math_descs)
self._dmaWB_block.load_descriptors(pipeline.dmaWB_descs)
start_ns = env.now
yield pipeline.done
wall_clock_ns = env.now - start_ns
del self._pipeline_queues[pid]
# -- Collect metrics ---------------------------------------------------
rd = pe_txn.result_data
rd["M_tiles"] = pipeline.M_tiles
rd["K_tiles"] = pipeline.K_tiles
rd["N_tiles"] = pipeline.N_tiles
rd["total_tiles"] = pipeline.M_tiles * pipeline.K_tiles * pipeline.N_tiles
rd["output_tiles"] = pipeline.M_tiles * pipeline.N_tiles
rd["total_dma_read_bytes"] = sum(d.size_bytes for d in pipeline.dmaIN_descs.values())
rd["total_dma_write_bytes"] = sum(
d.Tm * d.Tn * self._bpe for d in pipeline.dmaWB_descs.values()
)
rd["wall_clock_ns"] = wall_clock_ns
# Per-stage timing
rd["t_tcm_load_per_tile"], rd["t_tcm_load_max_ns"], rd["t_tcm_load_total_ns"] = (
_collect_timing(self._gemm_block.t_tcm_load_per_tile.pop(pid, []))
)
rd["t_compute_per_tile"], rd["t_compute_max_ns"], rd["t_compute_total_ns"] = (
_collect_timing(self._gemm_block.t_compute_per_tile.pop(pid, []))
)
rd["t_tcm_store_per_tile"], rd["t_tcm_store_max_ns"], rd["t_tcm_store_total_ns"] = (
_collect_timing(self._math_block.t_tcm_store_per_tile.pop(pid, []))
)
rd["t_dma_read_per_request"], rd["t_dma_read_max_ns"], rd["t_dma_read_total_ns"] = (
_collect_timing(self._dmaIN_block.t_dma_read_per_request.pop(pid, []))
)
rd["t_dma_write_per_tile"], rd["t_dma_write_max_ns"], rd["t_dma_write_total_ns"] = (
_collect_timing(self._dmaWB_block.t_dma_write_per_tile.pop(pid, []))
)
# Derived metrics
rd["mac_utilization"] = rd["t_compute_total_ns"] / wall_clock_ns if wall_clock_ns > 0 else 0.0
rd["effective_tflops"] = 2 * M * K * N / wall_clock_ns / 1e3 if wall_clock_ns > 0 else 0.0
rd["pipeline_parallelism"] = (
(rd["t_tcm_load_total_ns"] + rd["t_compute_total_ns"] + rd["t_tcm_store_total_ns"])
/ wall_clock_ns if wall_clock_ns > 0 else 0.0
)
rd["effective_read_bw_gbs"] = (
rd["total_dma_read_bytes"] / rd["t_dma_read_total_ns"]
if rd["t_dma_read_total_ns"] > 0 else None
)
rd["effective_write_bw_gbs"] = (
rd["total_dma_write_bytes"] / rd["t_dma_write_total_ns"]
if rd["t_dma_write_total_ns"] > 0 else None
)
rd["bottleneck_stage"] = max(
[("load", rd["t_tcm_load_max_ns"]), ("compute", rd["t_compute_max_ns"]),
("store", rd["t_tcm_store_max_ns"])],
key=lambda x: x[1],
)[0]
# pe_cpu.py compatibility aliases
rd["dma_ns"] = rd["t_dma_read_total_ns"] + rd["t_dma_write_total_ns"]
rd["compute_ns"] = rd["t_compute_total_ns"]
pe_txn.done.succeed()
# -- Math composite --------------------------------------------------------
def _dispatch_composite_math(
self, env: simpy.Environment, pe_txn: Any, cmd: Any,
) -> Generator:
"""Run tiled element-wise math pipeline and collect metrics."""
from kernbench.components.custom.pe_accel.scheduler.math_pipeline import MathPipeline
assert self._dmaIN_cmd_q is not None
assert self._dmaIN_to_fetch_trig is not None
a = cmd.a
M = a.shape[-2] if len(a.shape) >= 2 else 1
N = a.shape[-1]
op = cmd.math_op or "identity"
pid, reply_q = self._alloc_pipeline(env)
pipeline = MathPipeline(
env=env, M=M, N=N,
tile_m=self._tile_m, tile_n=self._tile_n,
bytes_per_element=self._bpe,
pipeline_id=pid,
reply_queue=reply_q,
dmaIN_cmd_q=self._dmaIN_cmd_q,
dmaIN_to_fetch_trig=self._dmaIN_to_fetch_trig,
op=op,
src_addr=a.addr, dst_addr=cmd.out_addr,
)
# Load descriptors into shared blocks
self._dmaIN_block.load_descriptors(pipeline.dmaIN_descs)
self._math_block.load_math_op_descriptors(pipeline.math_op_descs)
self._dmaWB_block.load_descriptors(pipeline.dmaWB_descs)
start_ns = env.now
yield pipeline.done
wall_clock_ns = env.now - start_ns
del self._pipeline_queues[pid]
# -- Collect metrics ---------------------------------------------------
rd = pe_txn.result_data
rd["M_tiles"] = pipeline.M_tiles
rd["N_tiles"] = pipeline.N_tiles
rd["total_tiles"] = pipeline.M_tiles * pipeline.N_tiles
rd["wall_clock_ns"] = wall_clock_ns
rd["math_op"] = op
# DMA timing
rd["t_dma_read_per_request"], rd["t_dma_read_max_ns"], rd["t_dma_read_total_ns"] = (
_collect_timing(self._dmaIN_block.t_dma_read_per_request.pop(pid, []))
)
rd["t_dma_write_per_tile"], rd["t_dma_write_max_ns"], rd["t_dma_write_total_ns"] = (
_collect_timing(self._dmaWB_block.t_dma_write_per_tile.pop(pid, []))
)
# Math op timing (load + compute + store)
rd["t_math_load_per_tile"], rd["t_math_load_max_ns"], rd["t_math_load_total_ns"] = (
_collect_timing(self._math_block.t_math_op_load_per_tile.pop(pid, []))
)
rd["t_math_compute_per_tile"], rd["t_math_compute_max_ns"], rd["t_math_compute_total_ns"] = (
_collect_timing(self._math_block.t_math_op_compute_per_tile.pop(pid, []))
)
rd["t_math_store_per_tile"], rd["t_math_store_max_ns"], rd["t_math_store_total_ns"] = (
_collect_timing(self._math_block.t_math_op_store_per_tile.pop(pid, []))
)
# pe_cpu.py compatibility aliases
rd["dma_ns"] = rd["t_dma_read_total_ns"] + rd["t_dma_write_total_ns"]
rd["compute_ns"] = rd["t_math_compute_total_ns"]
pe_txn.done.succeed()
# -- Helpers ---------------------------------------------------------------
def _alloc_pipeline(self, env: simpy.Environment) -> tuple[int, simpy.Store]:
"""Allocate a pipeline ID and reply queue."""
pid = self._next_pipeline_id
self._next_pipeline_id += 1
reply_q = simpy.Store(env)
self._pipeline_queues[pid] = reply_q
return pid, reply_q
# ==============================================================================
# Utility
# ==============================================================================
def _collect_timing(times: list) -> tuple[list, float, float]:
"""Return (raw_list, max_ns, total_ns) from a timing list."""
return times, max(times, default=0.0), sum(times)
@@ -0,0 +1,121 @@
"""Tile schedule generators for GEMM and element-wise math operations.
Each generator produces a plan of tile commands with pre-computed addresses.
Pipelines use these plans to build descriptor tables and feed commands
to the shared hardware blocks.
"""
from __future__ import annotations
from math import ceil
from kernbench.components.custom.pe_accel.types import (
CmdType,
MathSchedulePlan,
MathTileCommand,
SchedulePlan,
TileCommand,
)
def generate_gemm_tiles(
M_tiles: int, K_tiles: int, N_tiles: int,
M: int = 0, K: int = 0, N: int = 0,
tile_m: int = 0, tile_k: int = 0, tile_n: int = 0,
bytes_per_element: int = 2,
A_addr: int = 0, B_addr: int = 0, C_addr: int = 0,
pipeline_id: int = 0,
) -> SchedulePlan:
"""Generate GEMM tile commands in M (outer) -> N -> K (inner) order.
Stamps is_last_k=True on the final K-tile per (m, n) pair.
Emits one DMA_FLUSH per (m, n) pair after all K tiles.
Per-tile addresses (row-major layout):
A (M,K): A_addr + (m * tile_m * K + k * tile_k) * bpe
B (K,N): B_addr + (k * tile_k * N + n * tile_n) * bpe
C (M,N): C_addr + (m * tile_m * N + n * tile_n) * bpe
"""
commands: list[TileCommand] = []
cmd_id = 0
tile_id = 0
bpe = bytes_per_element
for m in range(M_tiles):
for n in range(N_tiles):
c_tile_addr = C_addr + (m * tile_m * N + n * tile_n) * bpe
for k in range(K_tiles):
last_k = k == K_tiles - 1
a_tile_addr = A_addr + (m * tile_m * K + k * tile_k) * bpe
b_tile_addr = B_addr + (k * tile_k * N + n * tile_n) * bpe
commands.append(TileCommand(
cmd_id=cmd_id, cmd_type=CmdType.DMA_LOAD,
tile_id=tile_id, m_idx=m, k_idx=k, n_idx=n,
is_last_k=last_k, pipeline_id=pipeline_id,
a_tile_addr=a_tile_addr, b_tile_addr=b_tile_addr,
c_tile_addr=c_tile_addr,
))
cmd_id += 1
commands.append(TileCommand(
cmd_id=cmd_id, cmd_type=CmdType.TENSOR_OP,
tile_id=tile_id, m_idx=m, k_idx=k, n_idx=n,
is_last_k=last_k, pipeline_id=pipeline_id,
a_tile_addr=a_tile_addr, b_tile_addr=b_tile_addr,
c_tile_addr=c_tile_addr,
))
cmd_id += 1
tile_id += 1
# One flush per (m, n) pair after all K tiles
commands.append(TileCommand(
cmd_id=cmd_id, cmd_type=CmdType.DMA_FLUSH,
tile_id=tile_id - 1, m_idx=m, k_idx=0, n_idx=n,
pipeline_id=pipeline_id,
c_tile_addr=c_tile_addr,
))
cmd_id += 1
return SchedulePlan(
commands=commands, M_tiles=M_tiles, K_tiles=K_tiles, N_tiles=N_tiles
)
def generate_math_tiles(
M_tiles: int, N_tiles: int,
M: int = 0, N: int = 0,
tile_m: int = 0, tile_n: int = 0,
bytes_per_element: int = 2,
src_addr: int = 0, dst_addr: int = 0,
pipeline_id: int = 0,
) -> MathSchedulePlan:
"""Generate element-wise math tile commands in row-major order.
Per-tile addresses (row-major layout):
src: src_addr + (m * tile_m * N + n * tile_n) * bpe
dst: dst_addr + (m * tile_m * N + n * tile_n) * bpe
"""
commands: list[MathTileCommand] = []
cmd_id = 0
tile_id = 0
bpe = bytes_per_element
for m in range(M_tiles):
for n in range(N_tiles):
offset = (m * tile_m * N + n * tile_n) * bpe
commands.append(MathTileCommand(
cmd_id=cmd_id,
tile_id=tile_id,
m_idx=m,
n_idx=n,
src_tile_addr=src_addr + offset,
dst_tile_addr=dst_addr + offset,
pipeline_id=pipeline_id,
))
cmd_id += 1
tile_id += 1
return MathSchedulePlan(
commands=commands, M_tiles=M_tiles, N_tiles=N_tiles
)
@@ -0,0 +1,148 @@
"""Data types for pe_accel_v1: descriptors, triggers, tile commands.
All types are frozen/plain dataclasses with no logic.
Schedule generators live in tiling/schedule.py.
"""
from __future__ import annotations
from dataclasses import dataclass
from enum import Enum, auto
# -- Enums ---------------------------------------------------------------------
class CmdType(Enum):
DMA_LOAD = auto()
TENSOR_OP = auto()
DMA_FLUSH = auto()
# -- Inter-block messaging -----------------------------------------------------
@dataclass
class Trigger:
"""Completion token passed between hardware blocks."""
tile_id: int
pipeline_id: int
vc: int | None = None
source_block: str = ""
@dataclass
class DmaRequest:
"""DMA load request — descriptor lookup key only.
Transfer params (size, address) live in the pre-loaded DmaInDescriptor.
"""
tile_id: int
pipeline_id: int
operand: str # "A" or "B"
# -- Descriptors (pre-loaded by pipelines, consumed by blocks) -----------------
@dataclass
class DmaInDescriptor:
"""Per-operand DMA read descriptor."""
size_bytes: int
src_addr: int = 0
next_block: str = "GEMM" # "GEMM" | "MATH" | "COMPLETION"
@dataclass
class GemmDescriptor:
"""Per-tile GEMM descriptor."""
Tm: int
Tk: int
Tn: int
triggers_needed: int = 2 # 2 = both operands from DMA; 1 = SPMem bypass
gemm_load: bool = True
gemm_compute: bool = True
next_block: str = "MATH" # "MATH" | "DMAWB" | "DONE"
@dataclass
class MathDescriptor:
"""Per-tile K-accumulation descriptor (used by GEMM pipeline)."""
Tm: int
Tn: int
is_last_k: bool
skip_dmaWB: bool # True = C stays in SPMem; False = flush to HBM
@dataclass
class MathOpDescriptor:
"""Per-tile element-wise math op descriptor (used by math pipeline)."""
Tm: int
Tn: int
op: str # "exp", "log", "sqrt", "sigmoid", etc.
src_addr: int = 0
dst_addr: int = 0
@dataclass
class DmaWBDescriptor:
"""Per-tile DMA writeback descriptor."""
Tm: int
Tn: int
dst_addr: int = 0
# -- Tile commands (produced by schedule generators) ---------------------------
@dataclass
class TileCommand:
"""A single GEMM tile command."""
cmd_id: int
cmd_type: CmdType
tile_id: int
m_idx: int
k_idx: int
n_idx: int
is_last_k: bool = False
pipeline_id: int = 0
a_tile_addr: int = 0
b_tile_addr: int = 0
c_tile_addr: int = 0
@dataclass
class SchedulePlan:
"""Full tile schedule for one GEMM operation."""
commands: list # list[TileCommand]
M_tiles: int
K_tiles: int
N_tiles: int
@dataclass
class MathTileCommand:
"""A single element-wise math tile command."""
cmd_id: int
tile_id: int
m_idx: int
n_idx: int
src_tile_addr: int = 0
dst_tile_addr: int = 0
pipeline_id: int = 0
@dataclass
class MathSchedulePlan:
"""Full tile schedule for one element-wise math operation."""
commands: list # list[MathTileCommand]
M_tiles: int
N_tiles: int
+52 -1
View File
@@ -20,6 +20,7 @@ _PE_COMP_OFFSETS = {
"pe_cpu": (-0.3, 0.0),
"pe_scheduler": (-0.15, 0.0),
"pe_dma": (0.0, -0.15),
"pe_fetch_store": (0.15, 0.0),
"pe_gemm": (0.0, 0.0),
"pe_math": (0.0, 0.15),
"pe_mmu": (0.15, -0.15),
@@ -637,12 +638,13 @@ def _instantiate_cube(
def _add_pe_internal_edges(edges: list[Edge], pp: str, pe_links: dict) -> None:
"""Add PE-internal edges for a single PE instance."""
"""Add PE-internal edges for a single PE instance (ADR-0021)."""
edges.append(Edge(
src=f"{pp}.pe_cpu", dst=f"{pp}.pe_scheduler",
distance_mm=pe_links["pe_cpu_to_scheduler_mm"],
kind="pe_internal",
))
# Scheduler → engines (initial dispatch)
for eng, key in [("pe_dma", "scheduler_to_dma_mm"),
("pe_gemm", "scheduler_to_gemm_mm"),
("pe_math", "scheduler_to_math_mm")]:
@@ -651,6 +653,15 @@ def _add_pe_internal_edges(edges: list[Edge], pp: str, pe_links: dict) -> None:
distance_mm=pe_links[key],
kind="pe_internal",
))
# Scheduler → fetch_store (initial dispatch)
if "scheduler_to_fetch_store_mm" in pe_links:
edges.append(Edge(
src=f"{pp}.pe_scheduler", dst=f"{pp}.pe_fetch_store",
distance_mm=pe_links["scheduler_to_fetch_store_mm"],
kind="pe_internal",
))
# Engine → TCM (legacy BW edges)
for eng, mm_key, bw_key in [("pe_dma", "dma_to_tcm_mm", "dma_to_tcm_bw_gbs"),
("pe_gemm", "gemm_to_tcm_mm", "gemm_to_tcm_bw_gbs"),
("pe_math", "math_to_tcm_mm", "math_to_tcm_bw_gbs")]:
@@ -661,6 +672,32 @@ def _add_pe_internal_edges(edges: list[Edge], pp: str, pe_links: dict) -> None:
kind="pe_internal",
))
# Fetch/Store → TCM (ADR-0021 D5)
if "fetch_store_to_tcm_mm" in pe_links:
edges.append(Edge(
src=f"{pp}.pe_fetch_store", dst=f"{pp}.pe_tcm",
distance_mm=pe_links["fetch_store_to_tcm_mm"],
bw_gbs=pe_links.get("fetch_store_to_tcm_bw_gbs", 512.0),
kind="pe_internal",
))
# Chaining edges (ADR-0021 D4 — token self-routing)
chaining = [
("pe_dma", "pe_fetch_store", "dma_to_fetch_store_mm"),
("pe_fetch_store", "pe_gemm", "fetch_store_to_gemm_mm"),
("pe_fetch_store", "pe_math", "fetch_store_to_math_mm"),
("pe_gemm", "pe_fetch_store", "gemm_to_fetch_store_mm"),
("pe_math", "pe_fetch_store", "math_to_fetch_store_mm"),
("pe_fetch_store", "pe_dma", "fetch_store_to_dma_mm"),
]
for src_eng, dst_eng, mm_key in chaining:
if mm_key in pe_links:
edges.append(Edge(
src=f"{pp}.{src_eng}", dst=f"{pp}.{dst_eng}",
distance_mm=pe_links[mm_key],
kind="pe_internal",
))
# ── Inter-cube / IO / system edges ──────────────────────────────────
@@ -1071,6 +1108,7 @@ def _build_pe_view(spec: dict) -> ViewGraph:
"pe_cpu": (1.5, 4.0),
"pe_scheduler": (4.0, 4.0),
"pe_dma": (7.0, 1.5),
"pe_fetch_store": (8.5, 4.0),
"pe_gemm": (7.0, 4.0),
"pe_math": (7.0, 6.5),
"pe_mmu": (4.0, 1.5),
@@ -1101,6 +1139,12 @@ def _build_pe_view(spec: dict) -> ViewGraph:
distance_mm=pe_links[key],
kind="pe_internal",
))
if "scheduler_to_fetch_store_mm" in pe_links:
view_edges.append(Edge(
src="pe_scheduler", dst="pe_fetch_store",
distance_mm=pe_links["scheduler_to_fetch_store_mm"],
kind="pe_internal",
))
for eng, mm_key, bw_key in [("pe_dma", "dma_to_tcm_mm", "dma_to_tcm_bw_gbs"),
("pe_gemm", "gemm_to_tcm_mm", "gemm_to_tcm_bw_gbs"),
("pe_math", "math_to_tcm_mm", "math_to_tcm_bw_gbs")]:
@@ -1110,6 +1154,13 @@ def _build_pe_view(spec: dict) -> ViewGraph:
bw_gbs=pe_links[bw_key],
kind="pe_internal",
))
if "fetch_store_to_tcm_mm" in pe_links:
view_edges.append(Edge(
src="pe_fetch_store", dst="pe_tcm",
distance_mm=pe_links["fetch_store_to_tcm_mm"],
bw_gbs=pe_links.get("fetch_store_to_tcm_bw_gbs", 512.0),
kind="pe_internal",
))
return ViewGraph(
name="pe", nodes=nodes, edges=view_edges,
+7 -6
View File
@@ -19,16 +19,16 @@ def test_full_graph_node_count():
# + 2 SIPs x (1 IO x 23 io_nodes
# + 16 cubes x (32 routers + 1 hbm_ctrl + 1 m_cpu + 1 sram
# + 20 ucie (4 ports x (1 port + 4 conn))
# + 8 PEs x 7 pe_comps))
# + 8 PEs x 8 pe_comps)) (ADR-0021: +pe_fetch_store)
# IO: pcie_ep + io_cpu + noc + 4 io_ucie_ports + 4*4 io_ucie_conn = 23
# cube: 32 + 3 + 20 + 56 = 111
# = 1 + 2*(23 + 16*111) = 1 + 2*(23+1776) = 1 + 3598 = 3599
assert len(g.nodes) == 3599
# cube: 32 + 3 + 20 + 64 = 119
# = 1 + 2*(23 + 16*119) = 1 + 2*(23+1904) = 1 + 3854 = 3855
assert len(g.nodes) == 3855
def test_full_graph_edge_count():
g = _graph()
assert len(g.edges) == 10874
assert len(g.edges) == 12922 # ADR-0021: +pe_fetch_store + chaining edges
# -- Full graph: specific nodes exist -----------------------------------------
@@ -286,7 +286,8 @@ def test_cube_view_pe_to_router():
def test_pe_view_has_all_components():
v = _graph().pe_view
assert set(v.nodes.keys()) == {
"pe_cpu", "pe_scheduler", "pe_dma", "pe_gemm", "pe_math", "pe_mmu", "pe_tcm"
"pe_cpu", "pe_scheduler", "pe_dma", "pe_fetch_store",
"pe_gemm", "pe_math", "pe_mmu", "pe_tcm",
}
+2 -1
View File
@@ -23,7 +23,8 @@ def test_pe_template_components():
spec = _read_spec(TOPOLOGY_PATH)
comps = spec["cube"]["pe_template"]["components"]
assert set(comps.keys()) == {
"pe_cpu", "pe_scheduler", "pe_dma", "pe_gemm", "pe_math", "pe_mmu", "pe_tcm"
"pe_cpu", "pe_scheduler", "pe_dma", "pe_fetch_store",
"pe_gemm", "pe_math", "pe_mmu", "pe_tcm",
}
+12 -3
View File
@@ -65,17 +65,26 @@ cube:
pe_dma: { kind: pe_dma, impl: pe_dma_v1, attrs: { rd_engines: 1, wr_engines: 1 } }
pe_gemm: { kind: pe_gemm, impl: pe_gemm_v1, attrs: { overhead_ns: 0.0, shared_resource: accel_slot, peak_tflops_f16: 8.0 } }
pe_math: { kind: pe_math, impl: pe_math_v1, attrs: { overhead_ns: 0.0, shared_resource: accel_slot } }
pe_fetch_store: { kind: pe_fetch_store, impl: pe_fetch_store_v1, attrs: { overhead_ns: 0.0 } }
pe_mmu: { kind: pe_mmu, impl: pe_mmu_v1, attrs: { tlb_overhead_ns: 0.5, page_size: 4096 } }
pe_tcm: { kind: pe_tcm, impl: pe_tcm_v1, attrs:
{ size_mb: 16 } }
pe_tcm: { kind: pe_tcm, impl: pe_tcm_v1, attrs: { size_mb: 16, read_bw_gbs: 512.0, write_bw_gbs: 512.0 } }
links:
pe_cpu_to_scheduler_mm: 0.5
scheduler_to_dma_mm: 0.5
scheduler_to_gemm_mm: 0.5
scheduler_to_math_mm: 0.5
scheduler_to_fetch_store_mm: 0.5
dma_to_tcm_bw_gbs: 512.0
dma_to_tcm_mm: 0.5
gemm_to_tcm_bw_gbs: 512.0 # GEMM reads inputs from TCM (ADR-0014 D5)
dma_to_fetch_store_mm: 0.0 # DMA → fetch_store chaining (ADR-0021)
fetch_store_to_tcm_bw_gbs: 512.0
fetch_store_to_tcm_mm: 0.0
fetch_store_to_gemm_mm: 0.0 # fetch → GEMM chaining (ADR-0021)
fetch_store_to_math_mm: 0.0 # fetch → MATH chaining (ADR-0021)
gemm_to_fetch_store_mm: 0.0 # GEMM → store chaining (ADR-0021)
math_to_fetch_store_mm: 0.0 # MATH → store chaining (ADR-0021)
fetch_store_to_dma_mm: 0.0 # store → DMA writeback chaining (ADR-0021)
gemm_to_tcm_bw_gbs: 512.0
gemm_to_tcm_mm: 0.5
math_to_tcm_bw_gbs: 512.0
math_to_tcm_mm: 0.5