Implement ADR-0021: PE pipeline refactor with token self-routing
Step 1-2: Backup existing code
- builtin/ → builtin_legacy/ (unchanged backup)
- custom/pe_accel/ → custom/pe_accel_legacy/ (unchanged backup)

Step 3-4: New pipeline types and tiling
- pe_types.py: StageType, Stage, TilePlan, PipelinePlan, PipelineContext, TileToken
- tiling.py: generate_gemm_plan, generate_math_plan (ported from pe_accel)

Step 5: Component implementations (ADR-0021 D4-D6)
- PE_SCHEDULER: _feed_loop (singleton FIFO feeder) + plan generation
- PE_FETCH_STORE: new component — TCM ↔ Register File
- PE_GEMM: TileToken pipeline + legacy PeInternalTxn dual-mode
- PE_MATH: TileToken pipeline + legacy dual-mode
- PE_DMA: TileToken pipeline + legacy + fabric Transaction triple-mode
- PE_TCM: TcmRequest handler with dual-channel BW serialization

Step 6: Infrastructure
- topology.yaml: pe_fetch_store component + chaining edges
- components.yaml: pe_fetch_store_v1 registration
- builder.py: PE_COMP_OFFSETS, _add_pe_internal_edges, PE view positions
- Tests: node/edge counts, PE component sets updated

All components handle both TileToken (pipeline) and PeInternalTxn (legacy).
Token self-routing: components read next stage from token.plan, chain via out_port.

366 tests passing.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,245 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import ComponentBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.common.pe_commands import PeInternalTxn
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class PeSchedulerComponent(ComponentBase):
    """PE_SCHEDULER: sole dispatcher inside a PE (ADR-0014 D1).

    Receives PeInternalTxn from PE_CPU and routes it to the appropriate engine:
      - DmaReadCmd / DmaWriteCmd → PE_DMA
      - GemmCmd                  → PE_GEMM
      - MathCmd                  → PE_MATH
      - CompositeCmd             → tiled pipeline (Stage 3: ADR-0014 D3.2)

    Composite GEMM pipeline (32x64x32 tiles):
      DMA_READ(b_tile_t) → COMPUTE(t) → DMA_WRITE(out_tile_t)

    Overlap as implemented by ``_pipeline_gemm``: DMA_WRITE(t-1) is left in
    flight while DMA_READ(t) and COMPUTE(t) are issued.  DMA_READ(t+1) is NOT
    issued until COMPUTE(t) has completed, so the
    ``READ(t+1) || COMPUTE(t) || WRITE(t-1)`` overlap described as the design
    target is only partially realised.

    Applies scheduler overhead_ns before dispatching each command.
    Non-PeInternalTxn messages are forwarded via inherited _forward_txn().
    """

    # Scheduler tile dimensions (ADR-0014 D3.2).
    # NOTE: TILE_M is not referenced anywhere in this class; the M dimension
    # is never tiled by the GEMM pipeline below.
    TILE_M = 32
    TILE_K = 64
    TILE_N = 32

    # Command type → engine-suffix dispatch table.  Starts empty and is filled
    # on first instantiation by _ensure_dispatch_table(); to support a new
    # engine, add one entry there (e.g. ConvCmd: "pe_conv").
    _CMD_DISPATCH: dict[type, str] = {}

    @classmethod
    def _ensure_dispatch_table(cls) -> None:
        """Populate ``_CMD_DISPATCH`` the first time it is needed.

        The command classes are imported at call time rather than at module
        import time (presumably to avoid an import cycle — TODO confirm).
        Subsequent calls are no-ops once the table is non-empty.
        """
        if cls._CMD_DISPATCH:
            return

        from kernbench.common.pe_commands import DmaReadCmd, DmaWriteCmd, GemmCmd, MathCmd

        cls._CMD_DISPATCH = {
            DmaReadCmd: "pe_dma",
            DmaWriteCmd: "pe_dma",
            GemmCmd: "pe_gemm",
            MathCmd: "pe_math",
        }

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        """Initialise the scheduler for ``node``.

        ``_pe_prefix`` is the node id with its last dot-separated segment
        removed; sibling engine ports are addressed as
        ``f"{_pe_prefix}.<engine_suffix>"``.
        """
        super().__init__(node, ctx)
        self._pe_prefix = node.id.rsplit(".", 1)[0]
        self._ensure_dispatch_table()

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Model the scheduler's fixed overhead.

        Waits ``overhead_ns`` (read from ``self.node.attrs``, default 0.0).
        ``nbytes`` is accepted for interface compatibility and is not used.
        """
        overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(overhead_ns)

    def _worker(self, env: simpy.Environment) -> Generator:
        """Main loop: pull messages from the inbox forever.

        PeInternalTxn messages are dispatched in their own simpy process so
        the worker can keep draining the inbox; anything else is forwarded
        inline through the inherited ``_forward_txn``.
        """
        from kernbench.common.pe_commands import PeInternalTxn

        while True:
            msg: Any = yield self._inbox.get()
            if isinstance(msg, PeInternalTxn):
                env.process(self._dispatch(env, msg))
            else:
                yield from self._forward_txn(env, msg)

    def _dispatch(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
        """Route a PeInternalTxn to the correct engine via the dispatch table.

        Order of checks: scheduler overhead is applied first; then an exact
        type match in ``_CMD_DISPATCH`` forwards the transaction unchanged to
        that engine's out port; then CompositeCmd is expanded into a pipeline;
        any other command type completes immediately.
        """
        from kernbench.common.pe_commands import CompositeCmd

        # Scheduler overhead
        yield from self.run(env, 0)

        cmd = pe_txn.command

        # Simple commands: forward as-is.  The lookup is by exact type, so a
        # subclass of a registered command is not matched here.
        engine_suffix = self._CMD_DISPATCH.get(type(cmd))
        if engine_suffix is not None:
            yield self.out_ports[f"{self._pe_prefix}.{engine_suffix}"].put(pe_txn)
            return

        # CompositeCmd: tiled pipeline (not a simple forward)
        if isinstance(cmd, CompositeCmd):
            yield from self._dispatch_composite(env, pe_txn)
            return

        # Unknown command — signal done immediately so the issuer is not
        # left waiting on an event nobody will trigger.
        pe_txn.done.succeed()

    def _dispatch_composite(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
        """Composite tiled pipeline (ADR-0014 D3.2).

        GEMM (``op == "gemm"`` with a ``b`` operand): tiled pipeline with
        b-tile streaming from HBM.
        Everything else: sequential compute + DMA_WRITE (no tiling).
        """
        from kernbench.common.pe_commands import CompositeCmd

        cmd = pe_txn.command
        assert isinstance(cmd, CompositeCmd)  # type narrowing; caller already checked

        if cmd.op == "gemm" and cmd.b is not None:
            yield from self._pipeline_gemm(env, pe_txn, cmd)
        else:
            yield from self._pipeline_math(env, pe_txn, cmd)

    def _pipeline_gemm(self, env: simpy.Environment, pe_txn: PeInternalTxn, cmd: Any) -> Generator:
        """Tiled GEMM pipeline: stream b tiles from HBM, compute, write results.

        Tensor a is in TCM (loaded via tl.load). Tensor b is in HBM (via tl.ref).
        Per tile: DMA_READ(b_tile_t) -> COMPUTE(t) -> DMA_WRITE(out_tile_t).

        Tiles are enumerated K-major (tile index = tk * n_tiles_n + tn); the M
        dimension is not tiled.  DMA_WRITE(t-1) stays in flight while
        DMA_READ(t) and COMPUTE(t) are issued; DMA_READ(t+1) is issued only
        after COMPUTE(t) has completed.

        On completion stores ``dma_ns`` and ``compute_ns`` in
        ``pe_txn.result_data`` and triggers ``pe_txn.done``.  ``dma_ns`` covers
        the read waits plus the final write wait; ``compute_ns`` is measured
        from the compute put to compute completion and therefore also includes
        any time spent waiting on the previous tile's write.
        """
        from math import prod

        from kernbench.common.pe_commands import (
            DmaReadCmd,
            DmaWriteCmd,
            GemmCmd,
            PeInternalTxn as PeTxn,
            TensorHandle,
        )

        pp = self._pe_prefix
        a = cmd.a  # already in TCM
        b = cmd.b  # HBM reference (via tl.ref)

        m_rows, k_depth = a.shape[-2], a.shape[-1]
        n_cols = b.shape[-1]
        dtype = a.dtype

        # Element size is derived from b's *total* element count so that any
        # leading dimensions do not inflate it (for a rank-2 b this is the
        # same as dividing by K*N).  Falls back to 2 bytes for an empty
        # tensor — presumably a 16-bit default; TODO confirm.
        b_elems = prod(b.shape)
        dtype_bytes = b.nbytes // b_elems if b_elems > 0 else 2

        # Tile counts (at least one tile per axis, even for a zero extent)
        n_tiles_k = max(1, (k_depth + self.TILE_K - 1) // self.TILE_K)
        n_tiles_n = max(1, (n_cols + self.TILE_N - 1) // self.TILE_N)
        n_tiles = n_tiles_k * n_tiles_n

        prev_compute_done = None
        prev_write_done = None
        total_dma_ns = 0.0
        total_compute_ns = 0.0

        for tile_idx in range(n_tiles):
            tk, tn = divmod(tile_idx, n_tiles_n)

            k_start = tk * self.TILE_K
            n_start = tn * self.TILE_N
            # Edge tiles may be smaller than the nominal tile size.
            tile_k = min(self.TILE_K, k_depth - k_start)
            tile_n = min(self.TILE_N, n_cols - n_start)
            tile_nbytes = tile_k * tile_n * dtype_bytes

            # --- Stage 1: DMA_READ b_tile from HBM ---
            read_done = env.event()
            # Address of the tile's first element, using n_cols as row stride.
            b_tile_addr = b.addr + (k_start * n_cols + n_start) * dtype_bytes
            b_tile_handle = TensorHandle(
                id=f"b_tile_{tile_idx}", addr=b_tile_addr,
                shape=(tile_k, tile_n), dtype=dtype, nbytes=tile_nbytes,
            )
            read_cmd = DmaReadCmd(handle=b_tile_handle, src_addr=b_tile_addr, nbytes=tile_nbytes)
            read_txn = PeTxn(command=read_cmd, done=read_done, pe_prefix=pp)
            t0 = env.now
            yield self.out_ports[f"{pp}.pe_dma"].put(read_txn)

            # Wait for previous compute before starting this tile's compute.
            # This event was already awaited at the end of the previous
            # iteration, so it is expected to resume without delay.
            if prev_compute_done is not None:
                yield prev_compute_done

            # Wait for this tile's DMA_READ
            yield read_done
            total_dma_ns += env.now - t0

            # --- Stage 2: COMPUTE (GEMM) ---
            compute_done = env.event()
            out_handle = TensorHandle(
                id=f"out_tile_{tile_idx}", addr=0,
                shape=(m_rows, tile_n), dtype=dtype,
                nbytes=m_rows * tile_n * dtype_bytes,
            )
            compute_cmd = GemmCmd(a=a, b=b_tile_handle, out=out_handle,
                                  m=m_rows, k=tile_k, n=tile_n)
            compute_txn = PeTxn(command=compute_cmd, done=compute_done, pe_prefix=pp)
            t0 = env.now
            yield self.out_ports[f"{pp}.pe_gemm"].put(compute_txn)

            # Wait for previous write (DMA_WRITE serialization)
            if prev_write_done is not None:
                yield prev_write_done

            # Wait for compute of THIS tile
            yield compute_done
            total_compute_ns += env.now - t0
            prev_compute_done = compute_done

            # --- Stage 3: DMA_WRITE out_tile to HBM ---
            # The write is issued but not awaited here; it is awaited in the
            # next iteration (or after the loop for the last tile).
            # NOTE(review): the destination depends only on n_start, so every
            # K-tile that shares a column block targets the same address.
            write_done = env.event()
            out_tile_pa = cmd.out_addr + n_start * dtype_bytes
            write_nbytes = m_rows * tile_n * dtype_bytes
            write_cmd = DmaWriteCmd(handle=out_handle, dst_addr=out_tile_pa, nbytes=write_nbytes)
            write_txn = PeTxn(command=write_cmd, done=write_done, pe_prefix=pp)
            yield self.out_ports[f"{pp}.pe_dma"].put(write_txn)
            prev_write_done = write_done

        # Wait for final write
        if prev_write_done is not None:
            t0 = env.now
            yield prev_write_done
            total_dma_ns += env.now - t0

        pe_txn.result_data["dma_ns"] = total_dma_ns
        pe_txn.result_data["compute_ns"] = total_compute_ns
        pe_txn.done.succeed()

    def _pipeline_math(self, env: simpy.Environment, pe_txn: PeInternalTxn, cmd: Any) -> Generator:
        """Non-GEMM composite: sequential compute + DMA_WRITE (no tiling).

        Issues a MathCmd on ``cmd.a`` (in place: ``out`` is ``cmd.a``), waits
        for it, then writes ``cmd.out_nbytes`` bytes to ``cmd.out_addr`` and
        finally triggers ``pe_txn.done``.  No timing is recorded in
        ``pe_txn.result_data`` on this path.
        """
        from kernbench.common.pe_commands import (
            DmaWriteCmd,
            MathCmd,
            PeInternalTxn as PeTxn,
        )

        pp = self._pe_prefix

        # Step 1: Compute (MATH); a missing/empty math_op becomes "identity".
        compute_done = env.event()
        compute_cmd = MathCmd(
            op=cmd.math_op or "identity",
            inputs=(cmd.a,), out=cmd.a,
        )
        compute_txn = PeTxn(command=compute_cmd, done=compute_done, pe_prefix=pp)
        yield self.out_ports[f"{pp}.pe_math"].put(compute_txn)
        yield compute_done

        # Step 2: DMA_WRITE result to HBM
        write_done = env.event()
        write_cmd = DmaWriteCmd(handle=cmd.a, dst_addr=cmd.out_addr, nbytes=cmd.out_nbytes)
        write_txn = PeTxn(command=write_cmd, done=write_done, pe_prefix=pp)
        yield self.out_ports[f"{pp}.pe_dma"].put(write_txn)
        yield write_done

        pe_txn.done.succeed()
|
||||
Reference in New Issue
Block a user