Implement ADR-0021: PE pipeline refactor with token self-routing
Step 1-2: Backup existing code - builtin/ → builtin_legacy/ (unchanged backup) - custom/pe_accel/ → custom/pe_accel_legacy/ (unchanged backup) Step 3-4: New pipeline types and tiling - pe_types.py: StageType, Stage, TilePlan, PipelinePlan, PipelineContext, TileToken - tiling.py: generate_gemm_plan, generate_math_plan (ported from pe_accel) Step 5: Component implementations (ADR-0021 D4-D6) - PE_SCHEDULER: _feed_loop (singleton FIFO feeder) + plan generation - PE_FETCH_STORE: new component — TCM ↔ Register File - PE_GEMM: TileToken pipeline + legacy PeInternalTxn dual-mode - PE_MATH: TileToken pipeline + legacy dual-mode - PE_DMA: TileToken pipeline + legacy + fabric Transaction triple-mode - PE_TCM: TcmRequest handler with dual-channel BW serialization Step 6: Infrastructure - topology.yaml: pe_fetch_store component + chaining edges - components.yaml: pe_fetch_store_v1 registration - builder.py: PE_COMP_OFFSETS, _add_pe_internal_edges, PE view positions - Tests: node/edge counts, PE component sets updated All components handle both TileToken (pipeline) and PeInternalTxn (legacy). Token self-routing: components read next stage from token.plan, chain via out_port. 366 tests passing. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -105,6 +105,73 @@ class PeDmaComponent(PeEngineBase):
|
||||
yield sub_done
|
||||
pe_txn.done.succeed()
|
||||
|
||||
def _worker(self, env: simpy.Environment) -> Generator:
    """Drain the inbox forever, spawning one SimPy process per message.

    Routing by message type:
      * ``TileToken``     -> ``_pipeline_process`` (pipeline mode)
      * ``PeInternalTxn`` -> ``_handle_with_hooks`` (legacy mode)
      * anything else     -> ``_forward_txn`` (handled as a fabric Transaction)
    """
    from kernbench.common.pe_commands import PeInternalTxn
    from kernbench.components.builtin.pe_types import TileToken

    while True:
        msg: Any = yield self._inbox.get()

        # Pick the handler first, then launch it; the type checks run in
        # the same order as the original if/elif chain.
        if isinstance(msg, TileToken):
            handler = self._pipeline_process
        elif isinstance(msg, PeInternalTxn):
            handler = self._handle_with_hooks
        else:
            handler = self._forward_txn

        env.process(handler(env, msg))
|
||||
|
||||
def _pipeline_process(self, env: simpy.Environment, token: Any) -> Generator:
    """Pipeline mode: perform the tile's DMA transfer over the fabric, then self-route.

    The current stage's type selects the direction: ``StageType.DMA_WRITE``
    uses ``params["dst_addr"]`` and the write channel, any other stage type
    uses ``params["src_addr"]`` and the read channel.  When ``nbytes`` is
    positive and a component context is available, the matching DMA channel
    is held while a sub-transaction travels along the routed path.

    Args:
        env: SimPy environment driving the simulation.
        token: TileToken-like object; this method reads ``params``,
            ``current_stage``, ``tile_id``, ``pipeline_ctx`` and calls
            ``advance()``.

    Yields:
        SimPy events (channel request, port put, sub-transaction completion).
    """
    from kernbench.components.builtin.pe_types import StageType
    from kernbench.policy.address.phyaddr import PhysAddr
    from kernbench.runtime_api.kernel import PeDmaMsg

    self._on_process_start(env, token)

    params = token.params
    is_write = token.current_stage.stage_type == StageType.DMA_WRITE
    addr = params.get("dst_addr" if is_write else "src_addr", 0)
    nbytes = params.get("nbytes", 0)

    if nbytes > 0 and self.ctx:
        dma_res = self._dma_write if is_write else self._dma_read
        assert dma_res is not None

        pa = PhysAddr.decode(addr)
        dst_node = self.ctx.resolver.resolve(pa)
        path = self.ctx.router.find_path(self._pe_prefix, dst_node)
        drain_ns = self.ctx.compute_drain_ns(path, nbytes)

        with dma_res.request() as req:
            yield req
            sub_done = env.event()
            sub_request = PeDmaMsg(
                correlation_id="pipeline",
                request_id=f"tile_{token.tile_id}",
                src_sip=0, src_cube=0, src_pe=0,
                dst_pa=addr, nbytes=nbytes,
                is_write=is_write,
            )
            sub_txn = Transaction(
                request=sub_request, path=path, step=0,
                nbytes=nbytes, done=sub_done, drain_ns=drain_ns,
            )
            if len(path) > 1:
                yield self.out_ports[path[1]].put(sub_txn.advance())
                # Only wait when the transaction was actually sent: nothing
                # else in this method triggers sub_done, so waiting on it
                # for a single-hop path would never resume.
                yield sub_done

    self._on_process_end(env, token)

    # Self-routing: hand the token to the next stage's component, or mark
    # the tile complete when the plan has no further stage.
    next_stage = token.advance()
    if next_stage is not None:
        yield self.out_ports[next_stage.component].put(token)
    else:
        token.pipeline_ctx.complete_tile()
|
||||
|
||||
def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator:
|
||||
"""Handle external Transaction (PeDmaMsg probe, M_CPU DMA) with channel acquisition."""
|
||||
# Response transactions bypass DMA channel (no outbound resource needed)
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
"""PE_FETCH_STORE: TCM ↔ Register File transfer unit (ADR-0021 D5).
|
||||
|
||||
Handles both fetch (TCM → register) and store (register → TCM).
|
||||
BW serialization is delegated to PE_TCM via port communication.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import PeEngineBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class PeFetchStoreComponent(PeEngineBase):
    """PE_FETCH_STORE: TCM ↔ Register File transfer unit (ADR-0021 D5).

    TileTokens arrive through pipeline self-routing.  For each token a
    ``TcmRequest`` is sent to the PE_TCM port, which supplies the
    bandwidth-based latency; the token is then passed on to its next stage.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        """Initialise the base engine and derive the id of the sibling PE_TCM node."""
        super().__init__(node, ctx)
        self._tcm_id = f"{self._pe_prefix}.pe_tcm"

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Wait for the node's ``overhead_ns`` attribute (0.0 if unset); ``nbytes`` is unused."""
        yield env.timeout(float(self.node.attrs.get("overhead_ns", 0.0)))

    def _worker(self, env: simpy.Environment) -> Generator:
        """Dispatch inbox messages: TileToken (pipeline), PeInternalTxn (legacy), other (forward)."""
        from kernbench.common.pe_commands import PeInternalTxn
        from kernbench.components.builtin.pe_types import TileToken

        while True:
            msg: Any = yield self._inbox.get()
            if isinstance(msg, TileToken):
                handler = self._pipeline_process
            elif isinstance(msg, PeInternalTxn):
                # NOTE(review): the GEMM/MATH/DMA engines route legacy
                # transactions through _handle_with_hooks; this component
                # calls handle_command directly — confirm that is intended.
                handler = self.handle_command
            else:
                handler = self._forward_txn
            env.process(handler(env, msg))

    def _pipeline_process(self, env: simpy.Environment, token: Any) -> Generator:
        """Process one pipeline TileToken: issue a TCM access, then self-route.

        The access is skipped when ``nbytes`` is not positive or when this
        component has no out-port to the PE_TCM node.
        """
        from kernbench.components.builtin.pe_tcm import TcmRequest

        self._on_process_start(env, token)

        tile_params = token.params
        direction = tile_params.get("direction", "read")
        nbytes = tile_params.get("nbytes", 0)

        if nbytes > 0 and self._tcm_id in self.out_ports:
            done = env.event()
            request = TcmRequest(direction=direction, nbytes=nbytes, done=done)
            yield self.out_ports[self._tcm_id].put(request)
            # PE_TCM triggers `done` once the access has been served.
            yield done

        self._on_process_end(env, token)

        # Self-routing: forward to the next stage, or finish the tile.
        following = token.advance()
        if following is None:
            token.pipeline_ctx.complete_tile()
        else:
            yield self.out_ports[following.component].put(token)

    def handle_command(self, env: simpy.Environment, pe_txn: Any) -> Generator:
        """Legacy PeInternalTxn handling: apply the overhead delay, then signal completion."""
        yield from self.run(env, 0)
        pe_txn.done.succeed()
|
||||
@@ -1,6 +1,18 @@
|
||||
"""PE_GEMM: matrix multiplication engine (ADR-0021 D6).
|
||||
|
||||
Handles both legacy PeInternalTxn (GemmCmd) and pipeline TileToken.
|
||||
In pipeline mode, receives token after fetch stage, computes MAC, chains to next.
|
||||
|
||||
MAC latency model (from pe_accel):
|
||||
cycles = ceil(Tm/mac_m) * ceil(Tk/mac_k) * ceil(Tn/mac_n)
|
||||
latency_ns = cycles / clock_freq_ghz
|
||||
|
||||
Falls back to TFLOPS model when mac dimensions not configured.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from math import ceil
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
@@ -12,33 +24,29 @@ if TYPE_CHECKING:
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
# dtype → bit width (for TFLOPS scaling)
|
||||
_DTYPE_BITS: dict[str, int] = {
|
||||
"f16": 16, "fp16": 16, "float16": 16, "bf16": 16,
|
||||
"f32": 32, "fp32": 32, "float32": 32,
|
||||
"i8": 8, "int8": 8,
|
||||
"i16": 16, "int16": 16,
|
||||
"i32": 32, "int32": 32,
|
||||
"i8": 8, "int8": 8, "i16": 16, "int16": 16, "i32": 32, "int32": 32,
|
||||
}
|
||||
|
||||
|
||||
class PeGemmComponent(PeEngineBase):
|
||||
"""PE_GEMM: matrix multiplication engine sharing accel_slot (ADR-0014 D4).
|
||||
"""PE_GEMM: MAC array (ADR-0021 D6).
|
||||
|
||||
Uses a shared compute resource (PE_ACCEL capacity=1) that is mutually
|
||||
exclusive with PE_MATH within the same PE.
|
||||
|
||||
Compute latency model:
|
||||
FLOPs = 2 * M * K * N
|
||||
effective_tflops = peak_tflops_f16 * (16 / dtype_bits)
|
||||
compute_ns = FLOPs / (effective_tflops * 1e3)
|
||||
In pipeline mode: pure compute — register data already fetched.
|
||||
In legacy mode: handles PeInternalTxn(GemmCmd) with shared accel_slot.
|
||||
"""
|
||||
|
||||
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
    """Read the GEMM timing parameters from the node's attributes.

    Attributes read (with defaults): ``peak_tflops_f16`` (0.0), ``mac_m`` /
    ``mac_k`` / ``mac_n`` (0 each) and ``clock_freq_ghz`` (1.0).
    """
    super().__init__(node, ctx)
    attrs = node.attrs

    # Shared compute resource handle; starts unset.
    self._accel: simpy.Resource | None = None

    # TFLOPS-based model parameter.
    self._peak_tflops_f16: float = float(attrs.get("peak_tflops_f16", 0.0))

    # Cycle-accurate MAC dimensions (from pe_accel).
    self._mac_m: int = int(attrs.get("mac_m", 0))
    self._mac_k: int = int(attrs.get("mac_k", 0))
    self._mac_n: int = int(attrs.get("mac_n", 0))
    self._clock_freq: float = float(attrs.get("clock_freq_ghz", 1.0))
|
||||
|
||||
def init_resources(self, env: simpy.Environment) -> None:
|
||||
resource_name = self.node.attrs.get("shared_resource")
|
||||
@@ -47,8 +55,15 @@ class PeGemmComponent(PeEngineBase):
|
||||
env, f"{self._pe_prefix}.{resource_name}"
|
||||
)
|
||||
|
||||
def _compute_ns(self, m: int, k: int, n: int, dtype: str) -> float:
|
||||
"""Compute GEMM latency in nanoseconds."""
|
||||
def _compute_ns_mac(self, m: int, k: int, n: int) -> float:
|
||||
"""Cycle-accurate MAC latency (pe_accel model)."""
|
||||
if self._mac_m > 0 and self._mac_k > 0 and self._mac_n > 0:
|
||||
cycles = ceil(m / self._mac_m) * ceil(k / self._mac_k) * ceil(n / self._mac_n)
|
||||
return cycles / self._clock_freq
|
||||
return 0.0
|
||||
|
||||
def _compute_ns_tflops(self, m: int, k: int, n: int, dtype: str = "f16") -> float:
|
||||
"""TFLOPS-based latency (legacy model)."""
|
||||
if self._peak_tflops_f16 <= 0:
|
||||
return float(self.node.attrs.get("overhead_ns", 0.0))
|
||||
dtype_bits = _DTYPE_BITS.get(dtype, 16)
|
||||
@@ -56,11 +71,58 @@ class PeGemmComponent(PeEngineBase):
|
||||
flops = 2.0 * m * k * n
|
||||
return flops / (effective_tflops * 1e3)
|
||||
|
||||
def _compute_ns(self, m: int, k: int, n: int, dtype: str = "f16") -> float:
|
||||
"""Choose best available latency model."""
|
||||
mac_ns = self._compute_ns_mac(m, k, n)
|
||||
if mac_ns > 0:
|
||||
return mac_ns
|
||||
return self._compute_ns_tflops(m, k, n, dtype)
|
||||
|
||||
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
    """Wait for the node's ``overhead_ns`` attribute (0.0 if unset); ``nbytes`` is unused."""
    yield env.timeout(float(self.node.attrs.get("overhead_ns", 0.0)))
|
||||
|
||||
def _worker(self, env: simpy.Environment) -> Generator:
    """Dispatch inbox messages: TileToken (pipeline), PeInternalTxn (legacy), other (forward)."""
    from kernbench.common.pe_commands import PeInternalTxn
    from kernbench.components.builtin.pe_types import TileToken

    while True:
        msg: Any = yield self._inbox.get()

        # Resolve the handler, then start it as its own SimPy process.
        if isinstance(msg, TileToken):
            handler = self._pipeline_process
        elif isinstance(msg, PeInternalTxn):
            handler = self._handle_with_hooks
        else:
            handler = self._forward_txn

        env.process(handler(env, msg))
|
||||
|
||||
def _pipeline_process(self, env: simpy.Environment, token: Any) -> Generator:
|
||||
"""Pipeline mode: pure MAC compute, then self-route."""
|
||||
self._on_process_start(env, token)
|
||||
|
||||
m = token.params.get("m", 0)
|
||||
k = token.params.get("k", 0)
|
||||
n = token.params.get("n", 0)
|
||||
|
||||
if self._accel:
|
||||
with self._accel.request() as req:
|
||||
yield req
|
||||
ns = self._compute_ns(m, k, n)
|
||||
yield env.timeout(ns)
|
||||
else:
|
||||
ns = self._compute_ns(m, k, n)
|
||||
yield env.timeout(ns)
|
||||
|
||||
self._on_process_end(env, token)
|
||||
|
||||
# Self-routing
|
||||
next_stage = token.advance()
|
||||
if next_stage is not None:
|
||||
yield self.out_ports[next_stage.component].put(token)
|
||||
else:
|
||||
token.pipeline_ctx.complete_tile()
|
||||
|
||||
def handle_command(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
|
||||
"""Legacy PeInternalTxn handling."""
|
||||
from kernbench.common.pe_commands import GemmCmd
|
||||
|
||||
cmd = pe_txn.command
|
||||
@@ -81,7 +143,6 @@ class PeGemmComponent(PeEngineBase):
|
||||
pe_txn.done.succeed()
|
||||
|
||||
def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator:
|
||||
"""Transaction forwarding with accel_slot acquisition."""
|
||||
if self._accel:
|
||||
with self._accel.request() as req:
|
||||
yield req
|
||||
|
||||
@@ -1,6 +1,16 @@
|
||||
"""PE_MATH: element-wise / reduction computation engine (ADR-0021 D6).
|
||||
|
||||
Handles both legacy PeInternalTxn (MathCmd) and pipeline TileToken.
|
||||
In pipeline mode, receives token after fetch stage, computes SIMD, chains to next.
|
||||
|
||||
SIMD latency model (from pe_accel):
|
||||
cycles = ceil(num_elements / vector_width)
|
||||
latency_ns = cycles / clock_freq_ghz
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from math import ceil
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
@@ -14,15 +24,17 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class PeMathComponent(PeEngineBase):
|
||||
"""PE_MATH: element-wise computation engine sharing accel_slot (ADR-0014 D4).
|
||||
"""PE_MATH: SIMD/Vector unit (ADR-0021 D6).
|
||||
|
||||
Uses a shared compute resource (PE_ACCEL capacity=1) that is mutually
|
||||
exclusive with PE_GEMM within the same PE.
|
||||
In pipeline mode: pure compute — register data already fetched.
|
||||
In legacy mode: handles PeInternalTxn(MathCmd) with shared accel_slot.
|
||||
"""
|
||||
|
||||
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
    """Read the SIMD timing parameters from the node's attributes.

    Attributes read (with defaults): ``vector_width`` (256) and
    ``clock_freq_ghz`` (1.0).
    """
    super().__init__(node, ctx)
    attrs = node.attrs

    # Shared compute resource handle; starts unset.
    self._accel: simpy.Resource | None = None
    self._vector_width: int = int(attrs.get("vector_width", 256))
    self._clock_freq: float = float(attrs.get("clock_freq_ghz", 1.0))
|
||||
|
||||
def init_resources(self, env: simpy.Environment) -> None:
|
||||
resource_name = self.node.attrs.get("shared_resource")
|
||||
@@ -31,11 +43,56 @@ class PeMathComponent(PeEngineBase):
|
||||
env, f"{self._pe_prefix}.{resource_name}"
|
||||
)
|
||||
|
||||
def _compute_ns(self, num_elements: int) -> float:
|
||||
"""SIMD latency (pe_accel model)."""
|
||||
if self._vector_width > 0 and self._clock_freq > 0 and num_elements > 0:
|
||||
cycles = ceil(num_elements / self._vector_width)
|
||||
return cycles / self._clock_freq
|
||||
return float(self.node.attrs.get("overhead_ns", 0.0))
|
||||
|
||||
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
    """Wait for the node's ``overhead_ns`` attribute (0.0 if unset); ``nbytes`` is unused."""
    yield env.timeout(float(self.node.attrs.get("overhead_ns", 0.0)))
|
||||
|
||||
def _worker(self, env: simpy.Environment) -> Generator:
    """Dispatch inbox messages: TileToken (pipeline), PeInternalTxn (legacy), other (forward)."""
    from kernbench.common.pe_commands import PeInternalTxn
    from kernbench.components.builtin.pe_types import TileToken

    while True:
        msg: Any = yield self._inbox.get()

        # Resolve the handler, then start it as its own SimPy process.
        if isinstance(msg, TileToken):
            handler = self._pipeline_process
        elif isinstance(msg, PeInternalTxn):
            handler = self._handle_with_hooks
        else:
            handler = self._forward_txn

        env.process(handler(env, msg))
|
||||
|
||||
def _pipeline_process(self, env: simpy.Environment, token: Any) -> Generator:
|
||||
"""Pipeline mode: pure SIMD compute, then self-route."""
|
||||
self._on_process_start(env, token)
|
||||
|
||||
num_elements = token.params.get("num_elements", 0)
|
||||
|
||||
if self._accel:
|
||||
with self._accel.request() as req:
|
||||
yield req
|
||||
ns = self._compute_ns(num_elements)
|
||||
yield env.timeout(ns)
|
||||
else:
|
||||
ns = self._compute_ns(num_elements)
|
||||
yield env.timeout(ns)
|
||||
|
||||
self._on_process_end(env, token)
|
||||
|
||||
# Self-routing
|
||||
next_stage = token.advance()
|
||||
if next_stage is not None:
|
||||
yield self.out_ports[next_stage.component].put(token)
|
||||
else:
|
||||
token.pipeline_ctx.complete_tile()
|
||||
|
||||
def handle_command(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
|
||||
"""Legacy PeInternalTxn handling."""
|
||||
if self._accel:
|
||||
with self._accel.request() as req:
|
||||
yield req
|
||||
@@ -45,7 +102,6 @@ class PeMathComponent(PeEngineBase):
|
||||
pe_txn.done.succeed()
|
||||
|
||||
def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator:
|
||||
"""Transaction forwarding with accel_slot acquisition."""
|
||||
if self._accel:
|
||||
with self._accel.request() as req:
|
||||
yield req
|
||||
|
||||
@@ -1,3 +1,13 @@
|
||||
"""PE_SCHEDULER: plan generation + tile dispatch (ADR-0021 D2).
|
||||
|
||||
Receives PeInternalTxn from PE_CPU, routes to engines:
|
||||
- Simple commands (DmaReadCmd, GemmCmd, etc.) → direct dispatch to engine
|
||||
- CompositeCmd → generate TilePlan, feed tiles via _feed_loop
|
||||
|
||||
Composite pipeline uses token self-routing (ADR-0021 D4):
|
||||
Scheduler only does initial dispatch + completion tracking.
|
||||
Tiles chain through components based on their plan's stage sequence.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
@@ -14,29 +24,18 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class PeSchedulerComponent(ComponentBase):
|
||||
"""PE_SCHEDULER: sole dispatcher inside a PE (ADR-0014 D1).
|
||||
"""PE_SCHEDULER: sole dispatcher inside a PE (ADR-0014 D1, ADR-0021 D2).
|
||||
|
||||
Receives PeInternalTxn from PE_CPU, routes to the appropriate engine:
|
||||
- DmaReadCmd / DmaWriteCmd → PE_DMA
|
||||
- GemmCmd → PE_GEMM
|
||||
- MathCmd → PE_MATH
|
||||
- CompositeCmd → tiled pipeline (Stage 3: ADR-0014 D3.2)
|
||||
Simple commands are forwarded to the appropriate engine.
|
||||
CompositeCmd creates a TilePlan and feeds tiles into the pipeline.
|
||||
|
||||
Composite GEMM pipeline (32x64x32 tiles):
|
||||
DMA_READ(b_tile_t) → COMPUTE(t) → DMA_WRITE(out_tile_t)
|
||||
with overlap: READ(t+1) || COMPUTE(t) || WRITE(t-1)
|
||||
|
||||
Applies scheduler overhead_ns before dispatching each command.
|
||||
Non-PeInternalTxn messages are forwarded via inherited _forward_txn().
|
||||
Single _feed_loop process per scheduler ensures FIFO command ordering.
|
||||
"""
|
||||
|
||||
# Scheduler tile dimensions (ADR-0014 D3.2)
|
||||
TILE_M = 32
|
||||
TILE_K = 64
|
||||
TILE_N = 32
|
||||
|
||||
# Command → engine suffix dispatch table.
|
||||
# New engines: add a single entry here (e.g. ConvCmd: "pe_conv").
|
||||
_CMD_DISPATCH: dict[type, str] = {}
|
||||
|
||||
@classmethod
|
||||
@@ -44,7 +43,6 @@ class PeSchedulerComponent(ComponentBase):
|
||||
if cls._CMD_DISPATCH:
|
||||
return
|
||||
from kernbench.common.pe_commands import DmaReadCmd, DmaWriteCmd, GemmCmd, MathCmd
|
||||
|
||||
cls._CMD_DISPATCH = {
|
||||
DmaReadCmd: "pe_dma",
|
||||
DmaWriteCmd: "pe_dma",
|
||||
@@ -56,6 +54,13 @@ class PeSchedulerComponent(ComponentBase):
|
||||
super().__init__(node, ctx)
|
||||
self._pe_prefix = node.id.rsplit(".", 1)[0]
|
||||
self._ensure_dispatch_table()
|
||||
self._pending_feeds: simpy.Store | None = None
|
||||
self._pipeline_counter = 0
|
||||
|
||||
def start(self, env: simpy.Environment) -> None:
    """Create the pending-feed queue, start the base component, launch the feeder.

    The queue is created before ``super().start`` and the single
    ``_feed_loop`` process is started after it.
    """
    feed_queue = simpy.Store(env)
    self._pending_feeds = feed_queue
    super().start(env)
    feeder = self._feed_loop(env)
    env.process(feeder)
|
||||
|
||||
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
|
||||
overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
|
||||
@@ -72,174 +77,103 @@ class PeSchedulerComponent(ComponentBase):
|
||||
yield from self._forward_txn(env, msg)
|
||||
|
||||
def _dispatch(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
    """Route a PeInternalTxn to the correct engine via dispatch table.

    Order of handling after the scheduler overhead has elapsed:
      1. Command types in ``_CMD_DISPATCH`` are put on the matching
         engine's out-port (the engine is then responsible for ``done``).
      2. ``CompositeCmd`` is handed to ``_dispatch_composite``.
      3. ``PeCpuOverheadCmd`` waits ``cmd.cycles`` and signals ``done``.
      4. Any other command signals ``done`` immediately.
    """
    from kernbench.common.pe_commands import CompositeCmd, PeCpuOverheadCmd

    # Scheduler overhead is applied exactly once per command.
    yield from self.run(env, 0)

    cmd = pe_txn.command

    # Simple command dispatch
    engine_suffix = self._CMD_DISPATCH.get(type(cmd))
    if engine_suffix is not None:
        yield self.out_ports[f"{self._pe_prefix}.{engine_suffix}"].put(pe_txn)
        return

    # CompositeCmd: generate plan and feed (dispatched exactly once)
    if isinstance(cmd, CompositeCmd):
        yield from self._dispatch_composite(env, pe_txn, cmd)
        return

    if isinstance(cmd, PeCpuOverheadCmd):
        # NOTE(review): cmd.cycles is used directly as the timeout delay —
        # confirm whether a cycles-to-ns conversion is expected here.
        yield env.timeout(cmd.cycles)
        pe_txn.done.succeed()
        return

    # Unknown command — signal done immediately
    pe_txn.done.succeed()
|
||||
|
||||
def _dispatch_composite(
    self, env: simpy.Environment, pe_txn: Any, cmd: Any = None,
) -> Generator:
    """Generate plan and enqueue to feeder. Non-blocking (ADR-0021 D4).

    Args:
        env: SimPy environment (not used directly here).
        pe_txn: Transaction whose ``done`` event becomes the pipeline's
            completion event.
        cmd: The composite command; defaults to ``pe_txn.command`` so that
            two-argument callers keep working.

    Yields:
        The put event on the pending-feed queue.
    """
    from kernbench.components.builtin.pe_types import PipelineContext

    if cmd is None:
        cmd = pe_txn.command

    plan = self._generate_plan(cmd)

    self._pipeline_counter += 1
    # NOTE(review): if a plan can contain zero tiles, confirm that
    # PipelineContext still triggers done_event; no tile would ever call
    # complete_tile() in that case.
    ctx = PipelineContext(
        id=f"p{self._pipeline_counter}",
        total_tiles=len(plan.tiles),
        done_event=pe_txn.done,
    )

    # Enqueue to feeder — scheduler worker returns immediately
    assert self._pending_feeds is not None
    yield self._pending_feeds.put((plan, ctx))
|
||||
|
||||
def _feed_loop(self, env: simpy.Environment) -> Generator:
    """Single feeder process: FIFO command ordering (ADR-0021 D2).

    Takes ``(plan, ctx)`` pairs from the pending-feed queue one at a time
    and, for each tile in the plan, puts a fresh ``TileToken`` on the
    out-port of the tile's first stage.  All tiles of one plan are issued
    before the next plan is taken, so tile feeds of different commands do
    not interleave.  A blocking put stalls only this process.
    """
    from kernbench.components.builtin.pe_types import TileToken

    assert self._pending_feeds is not None
    while True:
        plan, ctx = yield self._pending_feeds.get()
        for tile in plan.tiles:
            first_stage = tile.stages[0]
            token = TileToken(
                tile_id=tile.tile_id,
                pipeline_ctx=ctx,
                plan=tile,
                stage_idx=0,
                params=first_stage.params,
            )
            yield self.out_ports[first_stage.component].put(token)
|
||||
|
||||
def _generate_plan(self, cmd: Any) -> Any:
    """Generate a PipelinePlan from CompositeCmd.

    A command with ``op == "gemm"`` and a ``b`` operand yields a GEMM plan
    with M, K taken from the last two dims of ``a`` and N from the last
    dim of ``b``.  Every other command yields a math plan over ``a``; a
    one-dimensional ``a`` is treated as M = shape[0], N = 1.

    Args:
        cmd: Composite command; reads ``op``, ``a``, ``b``, ``out_addr``
            and ``math_op``.

    Returns:
        Whatever ``generate_gemm_plan`` / ``generate_math_plan`` return.
    """
    from kernbench.components.builtin.tiling import (
        generate_gemm_plan,
        generate_math_plan,
    )

    pp = self._pe_prefix
    # NOTE(review): element size is fixed at 2 bytes (f16) regardless of
    # the operands' dtype — confirm this is intended for other dtypes.
    bpe = 2  # default bytes per element (f16)

    if cmd.op == "gemm" and cmd.b is not None:
        a = cmd.a
        b = cmd.b
        M, K = a.shape[-2], a.shape[-1]
        N = b.shape[-1]
        return generate_gemm_plan(
            M=M, K=K, N=N,
            tile_m=self.TILE_M, tile_k=self.TILE_K, tile_n=self.TILE_N,
            bytes_per_element=bpe,
            A_addr=a.addr, B_addr=b.addr, C_addr=cmd.out_addr,
            pe_prefix=pp,
        )

    # Math composite
    a = cmd.a
    M = a.shape[-2] if len(a.shape) >= 2 else a.shape[0]
    N = a.shape[-1] if len(a.shape) >= 2 else 1
    return generate_math_plan(
        M=M, N=N,
        tile_m=self.TILE_M, tile_n=self.TILE_N,
        bytes_per_element=bpe,
        math_op=cmd.math_op or "identity",
        src_addr=a.addr, dst_addr=cmd.out_addr,
        pe_prefix=pp,
    )
|
||||
|
||||
@@ -1,7 +1,18 @@
|
||||
"""PE_TCM: tightly-coupled memory with BW-based access serialization (ADR-0021).
|
||||
|
||||
Models scratchpad memory inside the PE. Handles both legacy Transaction forwarding
|
||||
and TcmRequest from PE_FETCH_STORE for BW-serialized read/write access.
|
||||
|
||||
Two channels (read/write) with independent serialization.
|
||||
Ported from pe_accel TcmBlock timing model.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import ComponentBase
|
||||
|
||||
@@ -10,16 +21,62 @@ if TYPE_CHECKING:
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class PeTcmComponent(ComponentBase):
|
||||
"""PE_TCM: tightly-coupled memory / local SRAM staging buffer.
|
||||
@dataclass
class TcmRequest:
    """Request to read from or write to TCM (used by PE_FETCH_STORE)."""

    # "write" selects the write channel; the TCM handler treats any other
    # value as a read.
    direction: str  # "read" or "write"
    # Transfer size; the handler divides it by the channel bandwidth.
    nbytes: int
    # Triggered by the TCM handler once the access has been served.
    done: simpy.Event
    # Optional free-form label; not read by the handler shown in this file.
    tag: str = ""
|
||||
|
||||
|
||||
class PeTcmComponent(ComponentBase):
    """PE_TCM: BW-serialized scratchpad memory (ADR-0021 D1).

    Dual-channel: read and write can proceed in parallel,
    but concurrent reads serialize, concurrent writes serialize.
    Bandwidths come from the node attributes ``read_bw_gbs`` and
    ``write_bw_gbs`` (default 512.0 each).
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        """Read the per-channel bandwidths; channel resources are created in ``start``."""
        super().__init__(node, ctx)
        self._read_bw: float = float(node.attrs.get("read_bw_gbs", 512.0))
        self._write_bw: float = float(node.attrs.get("write_bw_gbs", 512.0))
        self._read_res: simpy.Resource | None = None
        self._write_res: simpy.Resource | None = None

    def start(self, env: simpy.Environment) -> None:
        """Create one capacity-1 resource per channel, then start the base component."""
        self._read_res = simpy.Resource(env, capacity=1)
        self._write_res = simpy.Resource(env, capacity=1)
        super().start(env)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Wait for the node's ``overhead_ns`` attribute (0.0 if unset); ``nbytes`` is unused."""
        overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(overhead_ns)

    def _worker(self, env: simpy.Environment) -> Generator:
        """Dispatch TcmRequest (from fetch_store) and Transaction (fabric)."""
        while True:
            msg: Any = yield self._inbox.get()
            if isinstance(msg, TcmRequest):
                env.process(self._handle_tcm_request(env, msg))
            else:
                env.process(self._forward_txn(env, msg))

    def _handle_tcm_request(self, env: simpy.Environment, req: TcmRequest) -> Generator:
        """BW-serialized access: acquire channel, apply delay, signal done.

        The delay is ``req.nbytes / bandwidth`` and is skipped when either
        value is not positive.  Any direction other than ``"write"`` uses
        the read channel.
        """
        if req.direction == "write":
            res, bw = self._write_res, self._write_bw
        else:
            res, bw = self._read_res, self._read_bw

        assert res is not None
        with res.request() as lock:
            yield lock
            if bw > 0 and req.nbytes > 0:
                delay_ns = req.nbytes / bw
                yield env.timeout(delay_ns)
            req.done.succeed()
|
||||
|
||||
@@ -0,0 +1,115 @@
|
||||
"""PE pipeline types for ADR-0021: TileToken, TilePlan, Stage, PipelineContext.
|
||||
|
||||
These types are used by the PE_SCHEDULER and all PE engine components
|
||||
for tile-based pipeline execution with self-routing.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum, auto
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import simpy
|
||||
|
||||
|
||||
# ── Stage types ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class StageType(Enum):
    """Kinds of work a tile can go through in the PE pipeline.

    Members are listed in the order the tile plans chain them.
    """

    DMA_READ = auto()   # HBM -> TCM load
    FETCH = auto()      # TCM -> Register File
    GEMM = auto()       # MAC compute
    MATH = auto()       # element-wise math
    STORE = auto()      # Register File -> TCM
    DMA_WRITE = auto()  # TCM -> HBM write-back
|
||||
|
||||
|
||||
@dataclass
class Stage:
    """A single step of a tile's execution plan: what runs, where, with what."""

    # Kind of work performed at this step.
    stage_type: StageType
    # Topology node ID of the executing component,
    # e.g. "sip0.cube0.pe0.pe_dma".
    component: str
    # Stage-specific arguments; every instance gets its own empty dict
    # unless one is passed in.
    params: dict = field(default_factory=dict)
|
||||
|
||||
|
||||
# ── Plan ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass
class TilePlan:
    """The ordered stage sequence one tile goes through.

    The sequence is meant to be immutable and is therefore held as a tuple;
    the dataclass itself is not frozen.
    """

    # Identifier of the tile within its pipeline plan.
    tile_id: int
    # Stages in execution order.
    stages: tuple[Stage, ...]
|
||||
|
||||
|
||||
@dataclass
class PipelinePlan:
    """All tile plans produced for one CompositeCmd, plus tile-count metadata."""

    # One plan per tile, in generation order.
    tiles: list[TilePlan]
    # Tile counts along each dimension, kept for metrics; 0 when not set.
    m_tiles: int = 0
    k_tiles: int = 0
    n_tiles: int = 0
|
||||
|
||||
|
||||
# ── Pipeline Context ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass
class PipelineContext:
    """Completion tracker for one pipeline (exactly-once contract).

    The last stage of every tile is expected to call ``complete_tile()``
    exactly once.  ``done_event`` is succeeded by the call that brings
    ``completed_tiles`` to exactly ``total_tiles``; earlier and later calls
    only bump the counter.
    """

    # Identifier of the pipeline.
    id: str
    # Number of tiles that must complete.
    total_tiles: int
    # Number of complete_tile() calls seen so far.
    completed_tiles: int = 0
    # simpy.Event to succeed on completion; nothing is signalled when None.
    done_event: Any = None

    def complete_tile(self) -> None:
        """Count one finished tile and signal ``done_event`` on the final one."""
        self.completed_tiles += 1
        if self.completed_tiles != self.total_tiles:
            return
        event = self.done_event
        if event is not None:
            event.succeed()
|
||||
|
||||
|
||||
# ── TileToken ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass
class TileToken:
    """Self-routing token handed from one PE component to the next (ADR-0021 D9).

    Single-owner: only one component holds the token at any time.
    ``params`` caches ``plan.stages[stage_idx].params`` (the plan is the
    canonical source).  The cache is refreshed by ``advance()``; on
    construction it is whatever the caller passes (an empty dict by default).
    """

    # Identifier of the tile this token represents.
    tile_id: int
    # Completion tracker of the pipeline the tile belongs to.
    pipeline_ctx: PipelineContext
    # Stage sequence the token walks through.
    plan: TilePlan
    # Index of the current stage within ``plan.stages``.
    stage_idx: int
    # Cached parameters of the current stage.
    params: dict = field(default_factory=dict)
    # op_log recording target (ADR-0020).
    data_op: bool = True

    @property
    def current_stage(self) -> Stage:
        """The stage at ``stage_idx``."""
        return self.plan.stages[self.stage_idx]

    @property
    def has_next_stage(self) -> bool:
        """True when at least one stage follows the current one."""
        return len(self.plan.stages) > self.stage_idx + 1

    def advance(self) -> Stage | None:
        """Move to the next stage and return it, or ``None`` past the last one.

        ``stage_idx`` is incremented on every call, including the call that
        runs off the end of the plan; ``params`` is left untouched in that case.
        """
        self.stage_idx += 1
        stages = self.plan.stages
        if self.stage_idx >= len(stages):
            return None
        upcoming = stages[self.stage_idx]
        self.params = upcoming.params
        return upcoming
|
||||
@@ -0,0 +1,176 @@
|
||||
"""Tile plan generators for PE pipeline (ADR-0021).
|
||||
|
||||
Generates TilePlan with stage sequences for GEMM and Math operations.
|
||||
Ported from pe_accel tiling.py with stage-based plan structure.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from math import ceil
|
||||
|
||||
from kernbench.components.builtin.pe_types import (
|
||||
PipelinePlan,
|
||||
Stage,
|
||||
StageType,
|
||||
TilePlan,
|
||||
)
|
||||
|
||||
|
||||
def generate_gemm_plan(
    M: int, K: int, N: int,
    tile_m: int, tile_k: int, tile_n: int,
    bytes_per_element: int,
    A_addr: int, B_addr: int, C_addr: int,
    pe_prefix: str,
) -> PipelinePlan:
    """Build the tile plans for a tiled GEMM, iterating M → N → K (K innermost).

    Every tile gets the stage sequence
        DMA_READ(A) → DMA_READ(B) → FETCH → GEMM → STORE
    and the last K-tile of each (m, n) pair additionally ends with DMA_WRITE.

    Tile IDs are assigned sequentially from 0 in generation order.  Byte
    counts always use the full ``tile_*`` dimensions, also for trailing tiles
    when M/K/N are not multiples of the tile sizes.  At least one tile is
    generated per dimension.

    Args:
        M, K, N: Problem dimensions.
        tile_m, tile_k, tile_n: Tile dimensions.
        bytes_per_element: Element size in bytes.
        A_addr, B_addr, C_addr: Base addresses of the operands and the output.
        pe_prefix: e.g. "sip0.cube0.pe0" — used to build component IDs.

    Returns:
        A PipelinePlan holding one TilePlan per (m, n, k) tile plus the tile
        counts along each dimension.
    """
    m_count = max(1, ceil(M / tile_m))
    k_count = max(1, ceil(K / tile_k))
    n_count = max(1, ceil(N / tile_n))
    elem_size = bytes_per_element

    dma_node = f"{pe_prefix}.pe_dma"
    fetch_store_node = f"{pe_prefix}.pe_fetch_store"
    gemm_node = f"{pe_prefix}.pe_gemm"
    # NOTE: no PE_MATH stage is emitted here; "{pe_prefix}.pe_math" was left
    # as a placeholder for K-accumulation if needed.

    # Tile footprints do not depend on the tile position.
    a_tile_bytes = tile_m * tile_k * elem_size
    b_tile_bytes = tile_k * tile_n * elem_size
    c_tile_bytes = tile_m * tile_n * elem_size

    plans: list[TilePlan] = []
    for mi in range(m_count):
        for ni in range(n_count):
            # Offset of the output tile's first element, using N as row stride.
            c_tile_addr = C_addr + (mi * tile_m * N + ni * tile_n) * elem_size
            for ki in range(k_count):
                is_last_k = ki == k_count - 1
                a_tile_addr = A_addr + (mi * tile_m * K + ki * tile_k) * elem_size
                b_tile_addr = B_addr + (ki * tile_k * N + ni * tile_n) * elem_size

                sequence: list[Stage] = [
                    # DMA READ: load the A tile from HBM → TCM
                    Stage(
                        stage_type=StageType.DMA_READ,
                        component=dma_node,
                        params={
                            "src_addr": a_tile_addr, "nbytes": a_tile_bytes,
                            "operand": "A", "tile_m": tile_m, "tile_k": tile_k,
                        },
                    ),
                    # DMA READ: load the B tile from HBM → TCM
                    Stage(
                        stage_type=StageType.DMA_READ,
                        component=dma_node,
                        params={
                            "src_addr": b_tile_addr, "nbytes": b_tile_bytes,
                            "operand": "B", "tile_k": tile_k, "tile_n": tile_n,
                        },
                    ),
                    # FETCH: TCM → Register File (both operands)
                    Stage(
                        stage_type=StageType.FETCH,
                        component=fetch_store_node,
                        params={
                            "direction": "read",
                            "nbytes": a_tile_bytes + b_tile_bytes,
                        },
                    ),
                    # GEMM: MAC compute
                    Stage(
                        stage_type=StageType.GEMM,
                        component=gemm_node,
                        params={
                            "m": tile_m, "k": tile_k, "n": tile_n,
                            "is_last_k": is_last_k,
                        },
                    ),
                    # STORE: Register File → TCM
                    Stage(
                        stage_type=StageType.STORE,
                        component=fetch_store_node,
                        params={
                            "direction": "write",
                            "nbytes": c_tile_bytes,
                        },
                    ),
                ]

                # DMA WRITE: TCM → HBM (only on the last K-tile)
                if is_last_k:
                    sequence.append(Stage(
                        stage_type=StageType.DMA_WRITE,
                        component=dma_node,
                        params={
                            "dst_addr": c_tile_addr, "nbytes": c_tile_bytes,
                        },
                    ))

                # len(plans) is the running tile counter.
                plans.append(TilePlan(tile_id=len(plans), stages=tuple(sequence)))

    return PipelinePlan(
        tiles=plans, m_tiles=m_count, k_tiles=k_count, n_tiles=n_count,
    )
|
||||
|
||||
|
||||
def generate_math_plan(
    M: int, N: int,
    tile_m: int, tile_n: int,
    bytes_per_element: int,
    math_op: str,
    src_addr: int, dst_addr: int,
    pe_prefix: str,
) -> PipelinePlan:
    """Build the tile plans for a tiled element-wise math operation.

    Tiles are generated in M → N order and each one runs
        DMA_READ → FETCH → MATH → STORE → DMA_WRITE

    Tile IDs are assigned sequentially from 0.  Source and destination use
    the same per-tile byte offset, and byte / element counts always use the
    full ``tile_m * tile_n`` footprint, also for trailing tiles.

    Args:
        M, N: Problem dimensions.
        tile_m, tile_n: Tile dimensions.
        bytes_per_element: Element size in bytes.
        math_op: Operation name, passed through as the MATH stage's "op".
        src_addr, dst_addr: Base addresses of the input and the output.
        pe_prefix: e.g. "sip0.cube0.pe0" — used to build component IDs.

    Returns:
        A PipelinePlan with one TilePlan per (m, n) tile; ``k_tiles`` keeps
        its default.
    """
    m_count = max(1, ceil(M / tile_m))
    n_count = max(1, ceil(N / tile_n))
    elem_size = bytes_per_element

    dma_node = f"{pe_prefix}.pe_dma"
    fetch_store_node = f"{pe_prefix}.pe_fetch_store"
    math_node = f"{pe_prefix}.pe_math"

    # The tile footprint does not depend on the tile position.
    tile_elems = tile_m * tile_n
    tile_nbytes = tile_m * tile_n * elem_size

    plans: list[TilePlan] = []
    for mi in range(m_count):
        for ni in range(n_count):
            # Offset of the tile's first element, using N as row stride.
            tile_offset = (mi * tile_m * N + ni * tile_n) * elem_size

            sequence = (
                Stage(
                    stage_type=StageType.DMA_READ,
                    component=dma_node,
                    params={"src_addr": src_addr + tile_offset, "nbytes": tile_nbytes},
                ),
                Stage(
                    stage_type=StageType.FETCH,
                    component=fetch_store_node,
                    params={"direction": "read", "nbytes": tile_nbytes},
                ),
                Stage(
                    stage_type=StageType.MATH,
                    component=math_node,
                    params={"op": math_op, "num_elements": tile_elems},
                ),
                Stage(
                    stage_type=StageType.STORE,
                    component=fetch_store_node,
                    params={"direction": "write", "nbytes": tile_nbytes},
                ),
                Stage(
                    stage_type=StageType.DMA_WRITE,
                    component=dma_node,
                    params={"dst_addr": dst_addr + tile_offset, "nbytes": tile_nbytes},
                ),
            )

            # len(plans) is the running tile counter.
            plans.append(TilePlan(tile_id=len(plans), stages=sequence))

    return PipelinePlan(tiles=plans, m_tiles=m_count, n_tiles=n_count)
|
||||
Reference in New Issue
Block a user