Implement ADR-0021: PE pipeline refactor with token self-routing

Step 1-2: Backup existing code
- builtin/ → builtin_legacy/ (unchanged backup)
- custom/pe_accel/ → custom/pe_accel_legacy/ (unchanged backup)

Step 3-4: New pipeline types and tiling
- pe_types.py: StageType, Stage, TilePlan, PipelinePlan, PipelineContext, TileToken
- tiling.py: generate_gemm_plan, generate_math_plan (ported from pe_accel)

Step 5: Component implementations (ADR-0021 D4-D6)
- PE_SCHEDULER: _feed_loop (singleton FIFO feeder) + plan generation
- PE_FETCH_STORE: new component — TCM ↔ Register File
- PE_GEMM: TileToken pipeline + legacy PeInternalTxn dual-mode
- PE_MATH: TileToken pipeline + legacy dual-mode
- PE_DMA: TileToken pipeline + legacy + fabric Transaction triple-mode
- PE_TCM: TcmRequest handler with dual-channel BW serialization

Step 6: Infrastructure
- topology.yaml: pe_fetch_store component + chaining edges
- components.yaml: pe_fetch_store_v1 registration
- builder.py: PE_COMP_OFFSETS, _add_pe_internal_edges, PE view positions
- Tests: node/edge counts, PE component sets updated

All components handle both TileToken (pipeline) and PeInternalTxn (legacy).
Token self-routing: components read next stage from token.plan, chain via out_port.
366 tests passing.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-04-08 23:35:31 -07:00
parent 161132cdcb
commit b6eb97c49a
40 changed files with 4055 additions and 214 deletions
@@ -105,6 +105,73 @@ class PeDmaComponent(PeEngineBase):
yield sub_done
pe_txn.done.succeed()
def _worker(self, env: simpy.Environment) -> Generator:
"""Handle TileToken (pipeline), PeInternalTxn (legacy), and Transaction (fabric)."""
from kernbench.common.pe_commands import PeInternalTxn
from kernbench.components.builtin.pe_types import TileToken
while True:
msg: Any = yield self._inbox.get()
if isinstance(msg, TileToken):
env.process(self._pipeline_process(env, msg))
elif isinstance(msg, PeInternalTxn):
env.process(self._handle_with_hooks(env, msg))
else:
env.process(self._forward_txn(env, msg))
def _pipeline_process(self, env: simpy.Environment, token: Any) -> Generator:
"""Pipeline mode: DMA read/write via fabric, then self-route."""
from kernbench.common.pe_commands import DmaReadCmd, DmaWriteCmd, TensorHandle
from kernbench.policy.address.phyaddr import PhysAddr
from kernbench.runtime_api.kernel import PeDmaMsg
self._on_process_start(env, token)
params = token.params
stage_type = token.current_stage.stage_type
from kernbench.components.builtin.pe_types import StageType
is_write = stage_type == StageType.DMA_WRITE
addr = params.get("dst_addr" if is_write else "src_addr", 0)
nbytes = params.get("nbytes", 0)
if nbytes > 0 and self.ctx:
dma_res = self._dma_write if is_write else self._dma_read
assert dma_res is not None
pa = PhysAddr.decode(addr)
dst_node = self.ctx.resolver.resolve(pa)
path = self.ctx.router.find_path(self._pe_prefix, dst_node)
drain_ns = self.ctx.compute_drain_ns(path, nbytes)
with dma_res.request() as req:
yield req
sub_done = env.event()
sub_request = PeDmaMsg(
correlation_id="pipeline",
request_id=f"tile_{token.tile_id}",
src_sip=0, src_cube=0, src_pe=0,
dst_pa=addr, nbytes=nbytes,
is_write=is_write,
)
sub_txn = Transaction(
request=sub_request, path=path, step=0,
nbytes=nbytes, done=sub_done, drain_ns=drain_ns,
)
if len(path) > 1:
yield self.out_ports[path[1]].put(sub_txn.advance())
yield sub_done
self._on_process_end(env, token)
# Self-routing
next_stage = token.advance()
if next_stage is not None:
yield self.out_ports[next_stage.component].put(token)
else:
token.pipeline_ctx.complete_tile()
def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator:
"""Handle external Transaction (PeDmaMsg probe, M_CPU DMA) with channel acquisition."""
# Response transactions bypass DMA channel (no outbound resource needed)
@@ -0,0 +1,77 @@
"""PE_FETCH_STORE: TCM ↔ Register File transfer unit (ADR-0021 D5).
Handles both fetch (TCM → register) and store (register → TCM).
BW serialization is delegated to PE_TCM via port communication.
"""
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
import simpy
from kernbench.components.base import PeEngineBase
if TYPE_CHECKING:
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class PeFetchStoreComponent(PeEngineBase):
"""PE_FETCH_STORE: TCM ↔ Register File (ADR-0021 D5).
Receives TileTokens via pipeline self-routing.
Sends TcmRequest to PE_TCM for BW-based latency.
"""
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
super().__init__(node, ctx)
self._tcm_id = f"{self._pe_prefix}.pe_tcm"
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
yield env.timeout(overhead_ns)
def _worker(self, env: simpy.Environment) -> Generator:
"""Handle both PeInternalTxn (legacy) and TileToken (pipeline)."""
from kernbench.common.pe_commands import PeInternalTxn
from kernbench.components.builtin.pe_types import TileToken
while True:
msg: Any = yield self._inbox.get()
if isinstance(msg, TileToken):
env.process(self._pipeline_process(env, msg))
elif isinstance(msg, PeInternalTxn):
env.process(self.handle_command(env, msg))
else:
env.process(self._forward_txn(env, msg))
def _pipeline_process(self, env: simpy.Environment, token: Any) -> Generator:
"""Process a pipeline TileToken: fetch or store via TCM."""
from kernbench.components.builtin.pe_tcm import TcmRequest
self._on_process_start(env, token)
direction = token.params.get("direction", "read")
nbytes = token.params.get("nbytes", 0)
if nbytes > 0 and self._tcm_id in self.out_ports:
done = env.event()
yield self.out_ports[self._tcm_id].put(
TcmRequest(direction=direction, nbytes=nbytes, done=done)
)
yield done
self._on_process_end(env, token)
# Self-routing: advance to next stage
next_stage = token.advance()
if next_stage is not None:
yield self.out_ports[next_stage.component].put(token)
else:
token.pipeline_ctx.complete_tile()
def handle_command(self, env: simpy.Environment, pe_txn: Any) -> Generator:
"""Legacy PeInternalTxn handling."""
yield from self.run(env, 0)
pe_txn.done.succeed()
+77 -16
View File
@@ -1,6 +1,18 @@
"""PE_GEMM: matrix multiplication engine (ADR-0021 D6).
Handles both legacy PeInternalTxn (GemmCmd) and pipeline TileToken.
In pipeline mode, receives token after fetch stage, computes MAC, chains to next.
MAC latency model (from pe_accel):
cycles = ceil(Tm/mac_m) * ceil(Tk/mac_k) * ceil(Tn/mac_n)
latency_ns = cycles / clock_freq_ghz
Falls back to TFLOPS model when mac dimensions not configured.
"""
from __future__ import annotations
from collections.abc import Generator
from math import ceil
from typing import TYPE_CHECKING, Any
import simpy
@@ -12,33 +24,29 @@ if TYPE_CHECKING:
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
# dtype → bit width (for TFLOPS scaling)
_DTYPE_BITS: dict[str, int] = {
"f16": 16, "fp16": 16, "float16": 16, "bf16": 16,
"f32": 32, "fp32": 32, "float32": 32,
"i8": 8, "int8": 8,
"i16": 16, "int16": 16,
"i32": 32, "int32": 32,
"i8": 8, "int8": 8, "i16": 16, "int16": 16, "i32": 32, "int32": 32,
}
class PeGemmComponent(PeEngineBase):
"""PE_GEMM: matrix multiplication engine sharing accel_slot (ADR-0014 D4).
"""PE_GEMM: MAC array (ADR-0021 D6).
Uses a shared compute resource (PE_ACCEL capacity=1) that is mutually
exclusive with PE_MATH within the same PE.
Compute latency model:
FLOPs = 2 * M * K * N
effective_tflops = peak_tflops_f16 * (16 / dtype_bits)
compute_ns = FLOPs / (effective_tflops * 1e3)
In pipeline mode: pure compute — register data already fetched.
In legacy mode: handles PeInternalTxn(GemmCmd) with shared accel_slot.
"""
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
super().__init__(node, ctx)
self._accel: simpy.Resource | None = None
self._peak_tflops_f16: float = float(node.attrs.get("peak_tflops_f16", 0.0))
# Cycle-accurate MAC dimensions (from pe_accel)
self._mac_m: int = int(node.attrs.get("mac_m", 0))
self._mac_k: int = int(node.attrs.get("mac_k", 0))
self._mac_n: int = int(node.attrs.get("mac_n", 0))
self._clock_freq: float = float(node.attrs.get("clock_freq_ghz", 1.0))
def init_resources(self, env: simpy.Environment) -> None:
resource_name = self.node.attrs.get("shared_resource")
@@ -47,8 +55,15 @@ class PeGemmComponent(PeEngineBase):
env, f"{self._pe_prefix}.{resource_name}"
)
def _compute_ns(self, m: int, k: int, n: int, dtype: str) -> float:
"""Compute GEMM latency in nanoseconds."""
def _compute_ns_mac(self, m: int, k: int, n: int) -> float:
"""Cycle-accurate MAC latency (pe_accel model)."""
if self._mac_m > 0 and self._mac_k > 0 and self._mac_n > 0:
cycles = ceil(m / self._mac_m) * ceil(k / self._mac_k) * ceil(n / self._mac_n)
return cycles / self._clock_freq
return 0.0
def _compute_ns_tflops(self, m: int, k: int, n: int, dtype: str = "f16") -> float:
"""TFLOPS-based latency (legacy model)."""
if self._peak_tflops_f16 <= 0:
return float(self.node.attrs.get("overhead_ns", 0.0))
dtype_bits = _DTYPE_BITS.get(dtype, 16)
@@ -56,11 +71,58 @@ class PeGemmComponent(PeEngineBase):
flops = 2.0 * m * k * n
return flops / (effective_tflops * 1e3)
def _compute_ns(self, m: int, k: int, n: int, dtype: str = "f16") -> float:
"""Choose best available latency model."""
mac_ns = self._compute_ns_mac(m, k, n)
if mac_ns > 0:
return mac_ns
return self._compute_ns_tflops(m, k, n, dtype)
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
yield env.timeout(overhead_ns)
def _worker(self, env: simpy.Environment) -> Generator:
from kernbench.common.pe_commands import PeInternalTxn
from kernbench.components.builtin.pe_types import TileToken
while True:
msg: Any = yield self._inbox.get()
if isinstance(msg, TileToken):
env.process(self._pipeline_process(env, msg))
elif isinstance(msg, PeInternalTxn):
env.process(self._handle_with_hooks(env, msg))
else:
env.process(self._forward_txn(env, msg))
def _pipeline_process(self, env: simpy.Environment, token: Any) -> Generator:
"""Pipeline mode: pure MAC compute, then self-route."""
self._on_process_start(env, token)
m = token.params.get("m", 0)
k = token.params.get("k", 0)
n = token.params.get("n", 0)
if self._accel:
with self._accel.request() as req:
yield req
ns = self._compute_ns(m, k, n)
yield env.timeout(ns)
else:
ns = self._compute_ns(m, k, n)
yield env.timeout(ns)
self._on_process_end(env, token)
# Self-routing
next_stage = token.advance()
if next_stage is not None:
yield self.out_ports[next_stage.component].put(token)
else:
token.pipeline_ctx.complete_tile()
def handle_command(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
"""Legacy PeInternalTxn handling."""
from kernbench.common.pe_commands import GemmCmd
cmd = pe_txn.command
@@ -81,7 +143,6 @@ class PeGemmComponent(PeEngineBase):
pe_txn.done.succeed()
def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator:
"""Transaction forwarding with accel_slot acquisition."""
if self._accel:
with self._accel.request() as req:
yield req
+60 -4
View File
@@ -1,6 +1,16 @@
"""PE_MATH: element-wise / reduction computation engine (ADR-0021 D6).
Handles both legacy PeInternalTxn (MathCmd) and pipeline TileToken.
In pipeline mode, receives token after fetch stage, computes SIMD, chains to next.
SIMD latency model (from pe_accel):
cycles = ceil(num_elements / vector_width)
latency_ns = cycles / clock_freq_ghz
"""
from __future__ import annotations
from collections.abc import Generator
from math import ceil
from typing import TYPE_CHECKING, Any
import simpy
@@ -14,15 +24,17 @@ if TYPE_CHECKING:
class PeMathComponent(PeEngineBase):
"""PE_MATH: element-wise computation engine sharing accel_slot (ADR-0014 D4).
"""PE_MATH: SIMD/Vector unit (ADR-0021 D6).
Uses a shared compute resource (PE_ACCEL capacity=1) that is mutually
exclusive with PE_GEMM within the same PE.
In pipeline mode: pure compute — register data already fetched.
In legacy mode: handles PeInternalTxn(MathCmd) with shared accel_slot.
"""
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
super().__init__(node, ctx)
self._accel: simpy.Resource | None = None
self._vector_width: int = int(node.attrs.get("vector_width", 256))
self._clock_freq: float = float(node.attrs.get("clock_freq_ghz", 1.0))
def init_resources(self, env: simpy.Environment) -> None:
resource_name = self.node.attrs.get("shared_resource")
@@ -31,11 +43,56 @@ class PeMathComponent(PeEngineBase):
env, f"{self._pe_prefix}.{resource_name}"
)
def _compute_ns(self, num_elements: int) -> float:
"""SIMD latency (pe_accel model)."""
if self._vector_width > 0 and self._clock_freq > 0 and num_elements > 0:
cycles = ceil(num_elements / self._vector_width)
return cycles / self._clock_freq
return float(self.node.attrs.get("overhead_ns", 0.0))
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
yield env.timeout(overhead_ns)
def _worker(self, env: simpy.Environment) -> Generator:
from kernbench.common.pe_commands import PeInternalTxn
from kernbench.components.builtin.pe_types import TileToken
while True:
msg: Any = yield self._inbox.get()
if isinstance(msg, TileToken):
env.process(self._pipeline_process(env, msg))
elif isinstance(msg, PeInternalTxn):
env.process(self._handle_with_hooks(env, msg))
else:
env.process(self._forward_txn(env, msg))
def _pipeline_process(self, env: simpy.Environment, token: Any) -> Generator:
"""Pipeline mode: pure SIMD compute, then self-route."""
self._on_process_start(env, token)
num_elements = token.params.get("num_elements", 0)
if self._accel:
with self._accel.request() as req:
yield req
ns = self._compute_ns(num_elements)
yield env.timeout(ns)
else:
ns = self._compute_ns(num_elements)
yield env.timeout(ns)
self._on_process_end(env, token)
# Self-routing
next_stage = token.advance()
if next_stage is not None:
yield self.out_ports[next_stage.component].put(token)
else:
token.pipeline_ctx.complete_tile()
def handle_command(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
"""Legacy PeInternalTxn handling."""
if self._accel:
with self._accel.request() as req:
yield req
@@ -45,7 +102,6 @@ class PeMathComponent(PeEngineBase):
pe_txn.done.succeed()
def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator:
"""Transaction forwarding with accel_slot acquisition."""
if self._accel:
with self._accel.request() as req:
yield req
+101 -167
View File
@@ -1,3 +1,13 @@
"""PE_SCHEDULER: plan generation + tile dispatch (ADR-0021 D2).
Receives PeInternalTxn from PE_CPU, routes to engines:
- Simple commands (DmaReadCmd, GemmCmd, etc.) → direct dispatch to engine
- CompositeCmd → generate TilePlan, feed tiles via _feed_loop
Composite pipeline uses token self-routing (ADR-0021 D4):
Scheduler only does initial dispatch + completion tracking.
Tiles chain through components based on their plan's stage sequence.
"""
from __future__ import annotations
from collections.abc import Generator
@@ -14,29 +24,18 @@ if TYPE_CHECKING:
class PeSchedulerComponent(ComponentBase):
"""PE_SCHEDULER: sole dispatcher inside a PE (ADR-0014 D1).
"""PE_SCHEDULER: sole dispatcher inside a PE (ADR-0014 D1, ADR-0021 D2).
Receives PeInternalTxn from PE_CPU, routes to the appropriate engine:
- DmaReadCmd / DmaWriteCmd → PE_DMA
- GemmCmd → PE_GEMM
- MathCmd → PE_MATH
- CompositeCmd → tiled pipeline (Stage 3: ADR-0014 D3.2)
Simple commands are forwarded to the appropriate engine.
CompositeCmd creates a TilePlan and feeds tiles into the pipeline.
Composite GEMM pipeline (32x64x32 tiles):
DMA_READ(b_tile_t) → COMPUTE(t) → DMA_WRITE(out_tile_t)
with overlap: READ(t+1) || COMPUTE(t) || WRITE(t-1)
Applies scheduler overhead_ns before dispatching each command.
Non-PeInternalTxn messages are forwarded via inherited _forward_txn().
Single _feed_loop process per scheduler ensures FIFO command ordering.
"""
# Scheduler tile dimensions (ADR-0014 D3.2)
TILE_M = 32
TILE_K = 64
TILE_N = 32
# Command → engine suffix dispatch table.
# New engines: add a single entry here (e.g. ConvCmd: "pe_conv").
_CMD_DISPATCH: dict[type, str] = {}
@classmethod
@@ -44,7 +43,6 @@ class PeSchedulerComponent(ComponentBase):
if cls._CMD_DISPATCH:
return
from kernbench.common.pe_commands import DmaReadCmd, DmaWriteCmd, GemmCmd, MathCmd
cls._CMD_DISPATCH = {
DmaReadCmd: "pe_dma",
DmaWriteCmd: "pe_dma",
@@ -56,6 +54,13 @@ class PeSchedulerComponent(ComponentBase):
super().__init__(node, ctx)
self._pe_prefix = node.id.rsplit(".", 1)[0]
self._ensure_dispatch_table()
self._pending_feeds: simpy.Store | None = None
self._pipeline_counter = 0
def start(self, env: simpy.Environment) -> None:
self._pending_feeds = simpy.Store(env)
super().start(env)
env.process(self._feed_loop(env))
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
@@ -72,174 +77,103 @@ class PeSchedulerComponent(ComponentBase):
yield from self._forward_txn(env, msg)
def _dispatch(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
"""Route a PeInternalTxn to the correct engine via dispatch table."""
from kernbench.common.pe_commands import CompositeCmd
from kernbench.common.pe_commands import CompositeCmd, PeCpuOverheadCmd
# Scheduler overhead
yield from self.run(env, 0)
yield from self.run(env, 0) # scheduler overhead
cmd = pe_txn.command
# Check dispatch table first
# Simple command dispatch
engine_suffix = self._CMD_DISPATCH.get(type(cmd))
if engine_suffix is not None:
yield self.out_ports[f"{self._pe_prefix}.{engine_suffix}"].put(pe_txn)
return
# CompositeCmd: tiled pipeline (not a simple forward)
# CompositeCmd: generate plan and feed
if isinstance(cmd, CompositeCmd):
yield from self._dispatch_composite(env, pe_txn)
yield from self._dispatch_composite(env, pe_txn, cmd)
return
if isinstance(cmd, PeCpuOverheadCmd):
yield env.timeout(cmd.cycles)
pe_txn.done.succeed()
return
# Unknown command — signal done immediately
pe_txn.done.succeed()
def _dispatch_composite(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
"""Composite tiled pipeline (ADR-0014 D3.2).
def _dispatch_composite(
self, env: simpy.Environment, pe_txn: Any, cmd: Any,
) -> Generator:
"""Generate plan and enqueue to feeder. Non-blocking (ADR-0021 D4)."""
from kernbench.components.builtin.pe_types import PipelineContext
GEMM: 3-stage pipeline with b-tile streaming from HBM.
MATH: sequential compute + DMA_WRITE (no tiling).
plan = self._generate_plan(cmd)
self._pipeline_counter += 1
ctx = PipelineContext(
id=f"p{self._pipeline_counter}",
total_tiles=len(plan.tiles),
done_event=pe_txn.done,
)
# Enqueue to feeder — scheduler worker returns immediately
assert self._pending_feeds is not None
yield self._pending_feeds.put((plan, ctx))
def _feed_loop(self, env: simpy.Environment) -> Generator:
"""Single feeder process: FIFO command ordering (ADR-0021 D2).
No tile feed interleaving between commands.
Queue full → only this process blocks.
"""
from kernbench.common.pe_commands import CompositeCmd
from kernbench.components.builtin.pe_types import TileToken
assert self._pending_feeds is not None
while True:
plan, ctx = yield self._pending_feeds.get()
for tile in plan.tiles:
first_stage = tile.stages[0]
token = TileToken(
tile_id=tile.tile_id,
pipeline_ctx=ctx,
plan=tile,
stage_idx=0,
params=first_stage.params,
)
yield self.out_ports[first_stage.component].put(token)
def _generate_plan(self, cmd: Any) -> Any:
"""Generate a PipelinePlan from CompositeCmd."""
from kernbench.components.builtin.tiling import (
generate_gemm_plan,
generate_math_plan,
)
pp = self._pe_prefix
bpe = 2 # default bytes per element (f16)
cmd = pe_txn.command
assert isinstance(cmd, CompositeCmd)
if cmd.op == "gemm" and cmd.b is not None:
yield from self._pipeline_gemm(env, pe_txn, cmd)
a = cmd.a
b = cmd.b
M, K = a.shape[-2], a.shape[-1]
N = b.shape[-1]
return generate_gemm_plan(
M=M, K=K, N=N,
tile_m=self.TILE_M, tile_k=self.TILE_K, tile_n=self.TILE_N,
bytes_per_element=bpe,
A_addr=a.addr, B_addr=b.addr, C_addr=cmd.out_addr,
pe_prefix=pp,
)
else:
yield from self._pipeline_math(env, pe_txn, cmd)
def _pipeline_gemm(self, env: simpy.Environment, pe_txn: PeInternalTxn, cmd: Any) -> Generator:
"""Tiled GEMM pipeline: stream b tiles from HBM, compute, write results.
Tensor a is in TCM (loaded via tl.load). Tensor b is in HBM (via tl.ref).
Pipeline: DMA_READ(b_tile_t) -> COMPUTE(t) -> DMA_WRITE(out_tile_t)
Overlap: READ(t+1) || COMPUTE(t) || WRITE(t-1)
"""
from kernbench.common.pe_commands import (
DmaReadCmd,
DmaWriteCmd,
GemmCmd,
PeInternalTxn as PeTxn,
TensorHandle,
)
pp = self._pe_prefix
a = cmd.a # already in TCM
b = cmd.b # HBM reference (via tl.ref)
M, K_a = a.shape[-2], a.shape[-1]
K_b, N = b.shape[-2], b.shape[-1]
dtype = a.dtype
dtype_bytes = b.nbytes // (K_b * N) if (K_b * N) > 0 else 2
# Tile counts
n_tiles_k = max(1, (K_a + self.TILE_K - 1) // self.TILE_K)
n_tiles_n = max(1, (N + self.TILE_N - 1) // self.TILE_N)
n_tiles = n_tiles_k * n_tiles_n
prev_compute_done = None
prev_write_done = None
total_dma_ns = 0.0
total_compute_ns = 0.0
for tile_idx in range(n_tiles):
tk = tile_idx // n_tiles_n
tn = tile_idx % n_tiles_n
k_start = tk * self.TILE_K
n_start = tn * self.TILE_N
tile_k = min(self.TILE_K, K_a - k_start)
tile_n = min(self.TILE_N, N - n_start)
tile_nbytes = tile_k * tile_n * dtype_bytes
# --- Stage 1: DMA_READ b_tile from HBM ---
read_done = env.event()
b_tile_addr = b.addr + (k_start * N + n_start) * dtype_bytes
b_tile_handle = TensorHandle(
id=f"b_tile_{tile_idx}", addr=b_tile_addr,
shape=(tile_k, tile_n), dtype=dtype, nbytes=tile_nbytes,
# Math composite
a = cmd.a
M = a.shape[-2] if len(a.shape) >= 2 else a.shape[0]
N = a.shape[-1] if len(a.shape) >= 2 else 1
return generate_math_plan(
M=M, N=N,
tile_m=self.TILE_M, tile_n=self.TILE_N,
bytes_per_element=bpe,
math_op=cmd.math_op or "identity",
src_addr=a.addr, dst_addr=cmd.out_addr,
pe_prefix=pp,
)
read_cmd = DmaReadCmd(handle=b_tile_handle, src_addr=b_tile_addr, nbytes=tile_nbytes)
read_txn = PeTxn(command=read_cmd, done=read_done, pe_prefix=pp)
t0 = env.now
yield self.out_ports[f"{pp}.pe_dma"].put(read_txn)
# Wait for previous compute before starting this tile's compute
if prev_compute_done is not None:
yield prev_compute_done
# Wait for this tile's DMA_READ
yield read_done
total_dma_ns += env.now - t0
# --- Stage 2: COMPUTE (GEMM) ---
compute_done = env.event()
out_handle = TensorHandle(
id=f"out_tile_{tile_idx}", addr=0,
shape=(M, tile_n), dtype=dtype,
nbytes=M * tile_n * dtype_bytes,
)
compute_cmd = GemmCmd(a=a, b=b_tile_handle, out=out_handle,
m=M, k=tile_k, n=tile_n)
compute_txn = PeTxn(command=compute_cmd, done=compute_done, pe_prefix=pp)
t0 = env.now
yield self.out_ports[f"{pp}.pe_gemm"].put(compute_txn)
# Wait for previous write (DMA_WRITE serialization)
if prev_write_done is not None:
yield prev_write_done
# Wait for compute of THIS tile
yield compute_done
total_compute_ns += env.now - t0
prev_compute_done = compute_done
# --- Stage 3: DMA_WRITE out_tile to HBM ---
write_done = env.event()
out_tile_pa = cmd.out_addr + n_start * dtype_bytes
write_nbytes = M * tile_n * dtype_bytes
write_cmd = DmaWriteCmd(handle=out_handle, dst_addr=out_tile_pa, nbytes=write_nbytes)
write_txn = PeTxn(command=write_cmd, done=write_done, pe_prefix=pp)
t0 = env.now
yield self.out_ports[f"{pp}.pe_dma"].put(write_txn)
prev_write_done = write_done
# Wait for final write
if prev_write_done is not None:
t0 = env.now
yield prev_write_done
total_dma_ns += env.now - t0
pe_txn.result_data["dma_ns"] = total_dma_ns
pe_txn.result_data["compute_ns"] = total_compute_ns
pe_txn.done.succeed()
def _pipeline_math(self, env: simpy.Environment, pe_txn: PeInternalTxn, cmd: Any) -> Generator:
"""Non-GEMM composite: sequential compute + DMA_WRITE (no tiling)."""
from kernbench.common.pe_commands import (
DmaWriteCmd,
MathCmd,
PeInternalTxn as PeTxn,
)
pp = self._pe_prefix
# Step 1: Compute (MATH)
compute_done = env.event()
compute_cmd = MathCmd(
op=cmd.math_op or "identity",
inputs=(cmd.a,), out=cmd.a,
)
compute_txn = PeTxn(command=compute_cmd, done=compute_done, pe_prefix=pp)
yield self.out_ports[f"{pp}.pe_math"].put(compute_txn)
yield compute_done
# Step 2: DMA_WRITE result to HBM
write_done = env.event()
write_cmd = DmaWriteCmd(handle=cmd.a, dst_addr=cmd.out_addr, nbytes=cmd.out_nbytes)
write_txn = PeTxn(command=write_cmd, done=write_done, pe_prefix=pp)
yield self.out_ports[f"{pp}.pe_dma"].put(write_txn)
yield write_done
pe_txn.done.succeed()
+63 -6
View File
@@ -1,7 +1,18 @@
"""PE_TCM: tightly-coupled memory with BW-based access serialization (ADR-0021).
Models scratchpad memory inside the PE. Handles both legacy Transaction forwarding
and TcmRequest from PE_FETCH_STORE for BW-serialized read/write access.
Two channels (read/write) with independent serialization.
Ported from pe_accel TcmBlock timing model.
"""
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
import simpy
from kernbench.components.base import ComponentBase
@@ -10,16 +21,62 @@ if TYPE_CHECKING:
from kernbench.topology.types import Node
class PeTcmComponent(ComponentBase):
"""PE_TCM: tightly-coupled memory / local SRAM staging buffer.
@dataclass
class TcmRequest:
"""Request to read from or write to TCM (used by PE_FETCH_STORE)."""
Terminal storage component for PE-internal dataflow (ADR-0014 D5).
Phase 0: applies overhead_ns and drain_ns at terminal.
direction: str # "read" or "write"
nbytes: int
done: simpy.Event
tag: str = ""
class PeTcmComponent(ComponentBase):
"""PE_TCM: BW-serialized scratchpad memory (ADR-0021 D1).
Dual-channel: read and write can proceed in parallel,
but concurrent reads serialize, concurrent writes serialize.
BW from topology attrs or pe_template links.
"""
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
super().__init__(node, ctx)
self._read_bw: float = float(node.attrs.get("read_bw_gbs", 512.0))
self._write_bw: float = float(node.attrs.get("write_bw_gbs", 512.0))
self._read_res: simpy.Resource | None = None
self._write_res: simpy.Resource | None = None
def run(self, env, nbytes: int) -> Generator:
def start(self, env: simpy.Environment) -> None:
self._read_res = simpy.Resource(env, capacity=1)
self._write_res = simpy.Resource(env, capacity=1)
super().start(env)
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
yield env.timeout(overhead_ns)
def _worker(self, env: simpy.Environment) -> Generator:
"""Dispatch TcmRequest (from fetch_store) and Transaction (fabric)."""
while True:
msg: Any = yield self._inbox.get()
if isinstance(msg, TcmRequest):
env.process(self._handle_tcm_request(env, msg))
else:
env.process(self._forward_txn(env, msg))
def _handle_tcm_request(self, env: simpy.Environment, req: TcmRequest) -> Generator:
"""BW-serialized access: acquire channel, apply delay, signal done."""
if req.direction == "write":
res = self._write_res
bw = self._write_bw
else:
res = self._read_res
bw = self._read_bw
assert res is not None
with res.request() as lock:
yield lock
if bw > 0 and req.nbytes > 0:
delay_ns = req.nbytes / bw
yield env.timeout(delay_ns)
req.done.succeed()
@@ -0,0 +1,115 @@
"""PE pipeline types for ADR-0021: TileToken, TilePlan, Stage, PipelineContext.
These types are used by the PE_SCHEDULER and all PE engine components
for tile-based pipeline execution with self-routing.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
import simpy
# ── Stage types ──────────────────────────────────────────────────────
class StageType(Enum):
DMA_READ = auto()
FETCH = auto()
GEMM = auto()
MATH = auto()
STORE = auto()
DMA_WRITE = auto()
@dataclass
class Stage:
"""One stage in a tile's execution plan."""
stage_type: StageType
component: str # topology node ID (e.g. "sip0.cube0.pe0.pe_dma")
params: dict = field(default_factory=dict)
# ── Plan ─────────────────────────────────────────────────────────────
@dataclass
class TilePlan:
"""Execution plan for a single tile (immutable stage sequence)."""
tile_id: int
stages: tuple[Stage, ...]
@dataclass
class PipelinePlan:
"""Full pipeline plan for one CompositeCmd."""
tiles: list[TilePlan]
# Metadata for metrics
m_tiles: int = 0
k_tiles: int = 0
n_tiles: int = 0
# ── Pipeline Context ─────────────────────────────────────────────────
@dataclass
class PipelineContext:
"""Tracks completion of a pipeline (exactly-once contract).
Each tile's last stage calls complete_tile() exactly once.
When all tiles complete, done_event.succeed() is called.
"""
id: str
total_tiles: int
completed_tiles: int = 0
done_event: Any = None # simpy.Event
def complete_tile(self) -> None:
self.completed_tiles += 1
if self.completed_tiles == self.total_tiles:
if self.done_event is not None:
self.done_event.succeed()
# ── TileToken ────────────────────────────────────────────────────────
@dataclass
class TileToken:
"""Self-routing tile token passed between PE components (ADR-0021 D9).
Single-owner: only one component holds this token at any time.
params is a cache of plan.stages[stage_idx].params (canonical source).
"""
tile_id: int
pipeline_ctx: PipelineContext
plan: TilePlan
stage_idx: int
params: dict = field(default_factory=dict)
data_op: bool = True # op_log recording target (ADR-0020)
@property
def current_stage(self) -> Stage:
return self.plan.stages[self.stage_idx]
@property
def has_next_stage(self) -> bool:
return self.stage_idx + 1 < len(self.plan.stages)
def advance(self) -> Stage | None:
"""Advance to next stage. Returns next Stage or None if last."""
self.stage_idx += 1
if self.stage_idx < len(self.plan.stages):
next_stage = self.plan.stages[self.stage_idx]
self.params = next_stage.params
return next_stage
return None
+176
View File
@@ -0,0 +1,176 @@
"""Tile plan generators for PE pipeline (ADR-0021).
Generates TilePlan with stage sequences for GEMM and Math operations.
Ported from pe_accel tiling.py with stage-based plan structure.
"""
from __future__ import annotations
from math import ceil
from kernbench.components.builtin.pe_types import (
PipelinePlan,
Stage,
StageType,
TilePlan,
)
def generate_gemm_plan(
M: int, K: int, N: int,
tile_m: int, tile_k: int, tile_n: int,
bytes_per_element: int,
A_addr: int, B_addr: int, C_addr: int,
pe_prefix: str,
) -> PipelinePlan:
"""Generate GEMM tile plan: M→N→K order.
Each tile follows stage sequence:
DMA_READ(A) → DMA_READ(B) → FETCH → GEMM → STORE
On last K-tile per (m,n): → DMA_WRITE
Args:
pe_prefix: e.g. "sip0.cube0.pe0" — used to build component IDs.
"""
M_tiles = max(1, ceil(M / tile_m))
K_tiles = max(1, ceil(K / tile_k))
N_tiles = max(1, ceil(N / tile_n))
bpe = bytes_per_element
dma_id = f"{pe_prefix}.pe_dma"
fetch_id = f"{pe_prefix}.pe_fetch_store"
gemm_id = f"{pe_prefix}.pe_gemm"
# math_id = f"{pe_prefix}.pe_math" # for K-accumulation if needed
tiles: list[TilePlan] = []
tile_id = 0
for m in range(M_tiles):
for n in range(N_tiles):
c_addr = C_addr + (m * tile_m * N + n * tile_n) * bpe
for k in range(K_tiles):
last_k = k == K_tiles - 1
a_addr = A_addr + (m * tile_m * K + k * tile_k) * bpe
b_addr = B_addr + (k * tile_k * N + n * tile_n) * bpe
a_bytes = tile_m * tile_k * bpe
b_bytes = tile_k * tile_n * bpe
out_bytes = tile_m * tile_n * bpe
stages: list[Stage] = []
# DMA READ: load A and B tiles from HBM → TCM
stages.append(Stage(
stage_type=StageType.DMA_READ,
component=dma_id,
params={
"src_addr": a_addr, "nbytes": a_bytes,
"operand": "A", "tile_m": tile_m, "tile_k": tile_k,
},
))
stages.append(Stage(
stage_type=StageType.DMA_READ,
component=dma_id,
params={
"src_addr": b_addr, "nbytes": b_bytes,
"operand": "B", "tile_k": tile_k, "tile_n": tile_n,
},
))
# FETCH: TCM → Register File
stages.append(Stage(
stage_type=StageType.FETCH,
component=fetch_id,
params={
"direction": "read",
"nbytes": a_bytes + b_bytes,
},
))
# GEMM: MAC compute
stages.append(Stage(
stage_type=StageType.GEMM,
component=gemm_id,
params={
"m": tile_m, "k": tile_k, "n": tile_n,
"is_last_k": last_k,
},
))
# STORE: Register File → TCM
stages.append(Stage(
stage_type=StageType.STORE,
component=fetch_id,
params={
"direction": "write",
"nbytes": out_bytes,
},
))
# DMA WRITE: TCM → HBM (only on last K-tile)
if last_k:
stages.append(Stage(
stage_type=StageType.DMA_WRITE,
component=dma_id,
params={
"dst_addr": c_addr, "nbytes": out_bytes,
},
))
tiles.append(TilePlan(tile_id=tile_id, stages=tuple(stages)))
tile_id += 1
return PipelinePlan(
tiles=tiles, m_tiles=M_tiles, k_tiles=K_tiles, n_tiles=N_tiles,
)
def generate_math_plan(
M: int, N: int,
tile_m: int, tile_n: int,
bytes_per_element: int,
math_op: str,
src_addr: int, dst_addr: int,
pe_prefix: str,
) -> PipelinePlan:
"""Generate element-wise math tile plan.
Each tile: DMA_READ → FETCH → MATH → STORE → DMA_WRITE
"""
M_tiles = max(1, ceil(M / tile_m))
N_tiles = max(1, ceil(N / tile_n))
bpe = bytes_per_element
dma_id = f"{pe_prefix}.pe_dma"
fetch_id = f"{pe_prefix}.pe_fetch_store"
math_id = f"{pe_prefix}.pe_math"
tiles: list[TilePlan] = []
tile_id = 0
for m in range(M_tiles):
for n in range(N_tiles):
offset = (m * tile_m * N + n * tile_n) * bpe
tile_bytes = tile_m * tile_n * bpe
stages = [
Stage(StageType.DMA_READ, dma_id, {
"src_addr": src_addr + offset, "nbytes": tile_bytes,
}),
Stage(StageType.FETCH, fetch_id, {
"direction": "read", "nbytes": tile_bytes,
}),
Stage(StageType.MATH, math_id, {
"op": math_op, "num_elements": tile_m * tile_n,
}),
Stage(StageType.STORE, fetch_id, {
"direction": "write", "nbytes": tile_bytes,
}),
Stage(StageType.DMA_WRITE, dma_id, {
"dst_addr": dst_addr + offset, "nbytes": tile_bytes,
}),
]
tiles.append(TilePlan(tile_id=tile_id, stages=tuple(stages)))
tile_id += 1
return PipelinePlan(tiles=tiles, m_tiles=M_tiles, n_tiles=N_tiles)