63669f82cb
- DPPolicy: 3-level (sip/cube/pe), unified naming (column_wise/row_wise) - PE_CPU: auto num_programs from cube shard count - context.launch(): per-SIP KernelLaunchMsg with local va_base + auto local shape - deploy_tensor: removed mmus param, MMU mapping is context-only responsibility - ComponentRegistry: YAML-based lazy loading (components.yaml), impls→builtin rename - VA offset bench + tests: 2D/1D, standard Triton kernel pattern Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
246 lines
9.0 KiB
Python
246 lines
9.0 KiB
Python
from __future__ import annotations
|
|
|
|
from collections.abc import Generator
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
import simpy
|
|
|
|
from kernbench.components.base import ComponentBase
|
|
|
|
if TYPE_CHECKING:
|
|
from kernbench.common.pe_commands import PeInternalTxn
|
|
from kernbench.components.context import ComponentContext
|
|
from kernbench.topology.types import Node
|
|
|
|
|
|
class PeSchedulerComponent(ComponentBase):
    """PE_SCHEDULER: the single dispatch point inside a PE (ADR-0014 D1).

    Each ``PeInternalTxn`` taken from the inbox is routed by the exact type
    of its ``command``:

    - DmaReadCmd / DmaWriteCmd → ``<pe>.pe_dma``
    - GemmCmd                  → ``<pe>.pe_gemm``
    - MathCmd                  → ``<pe>.pe_math``
    - CompositeCmd             → expanded here into engine sub-commands
      (Stage 3: ADR-0014 D3.2)

    Composite GEMM walks the ``b`` operand in TILE_K x TILE_N tiles and, for
    every tile, issues DMA_READ(b_tile) → GEMM → DMA_WRITE(out_tile).  The
    write of a tile is left in flight while the next tile is read and
    computed.

    The node attribute ``overhead_ns`` (default 0) is charged once before each
    command is dispatched.  Messages that are not ``PeInternalTxn`` go to the
    inherited ``_forward_txn()``.
    """

    # Scheduler tile dimensions (ADR-0014 D3.2).
    # TILE_M is not read by any method of this class; M is never tiled below.
    TILE_M = 32
    TILE_K = 64
    TILE_N = 32

    # Command type → engine-name suffix.  Starts empty; the entries are filled
    # in by _ensure_dispatch_table(), which is where a new engine mapping
    # (e.g. ConvCmd: "pe_conv") has to be added.
    _CMD_DISPATCH: dict[type, str] = {}

    @classmethod
    def _ensure_dispatch_table(cls) -> None:
        """Populate ``_CMD_DISPATCH`` once; later calls are no-ops.

        The command classes are imported inside the method rather than at
        module level (presumably to avoid an import cycle — TODO confirm).
        """
        if not cls._CMD_DISPATCH:
            from kernbench.common.pe_commands import DmaReadCmd, DmaWriteCmd, GemmCmd, MathCmd

            cls._CMD_DISPATCH = {
                DmaReadCmd: "pe_dma",
                DmaWriteCmd: "pe_dma",
                GemmCmd: "pe_gemm",
                MathCmd: "pe_math",
            }

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        """Initialise the base component and derive the owning PE's id prefix.

        The prefix is ``node.id`` with its last dot-separated segment removed
        (the whole id when it contains no dot); engine port names are built as
        ``f"{prefix}.<engine>"``.
        """
        super().__init__(node, ctx)
        pe_prefix, *_ = node.id.rsplit(".", 1)
        self._pe_prefix = pe_prefix
        self._ensure_dispatch_table()

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Advance simulated time by the scheduler overhead.

        ``nbytes`` is accepted but not used: the delay comes only from the
        ``overhead_ns`` node attribute (0.0 when absent).
        """
        delay = self.node.attrs.get("overhead_ns", 0.0)
        yield env.timeout(float(delay))

    def _worker(self, env: simpy.Environment) -> Generator:
        """Inbox loop: spawn a dispatch process per PeInternalTxn, forward the rest.

        Dispatch runs as its own process, so the loop goes straight back to
        the inbox.  Forwarding is done inline, so the loop does not read the
        next message until ``_forward_txn`` has finished.
        """
        from kernbench.common.pe_commands import PeInternalTxn

        while True:
            item: Any = yield self._inbox.get()
            if not isinstance(item, PeInternalTxn):
                yield from self._forward_txn(env, item)
                continue
            env.process(self._dispatch(env, item))

    def _dispatch(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
        """Send one PeInternalTxn to its engine, or expand it if it is composite.

        A command whose type is neither in the dispatch table nor a
        CompositeCmd has its ``done`` event succeeded immediately.
        """
        from kernbench.common.pe_commands import CompositeCmd

        # Per-command scheduler overhead is paid before any routing decision.
        yield from self.run(env, 0)

        command = pe_txn.command

        # Exact-type lookup: a subclass of a registered command class does not
        # match its parent's entry.
        suffix = self._CMD_DISPATCH.get(type(command))

        if suffix is None:
            if isinstance(command, CompositeCmd):
                # Not a plain forward: the scheduler drives the sub-commands.
                yield from self._dispatch_composite(env, pe_txn)
            else:
                # Unknown command — signal done immediately.
                pe_txn.done.succeed()
            return

        port_key = f"{self._pe_prefix}.{suffix}"
        yield self.out_ports[port_key].put(pe_txn)

    def _dispatch_composite(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
        """Pick the composite pipeline for ``pe_txn`` (ADR-0014 D3.2).

        ``op == "gemm"`` with a non-None ``b`` operand takes the tiled GEMM
        pipeline; every other composite takes the untiled compute-then-write
        path.
        """
        from kernbench.common.pe_commands import CompositeCmd

        command = pe_txn.command
        assert isinstance(command, CompositeCmd)

        is_tiled_gemm = command.op == "gemm" and command.b is not None
        pipeline = self._pipeline_gemm if is_tiled_gemm else self._pipeline_math
        yield from pipeline(env, pe_txn, command)

    def _pipeline_gemm(self, env: simpy.Environment, pe_txn: PeInternalTxn, cmd: Any) -> Generator:
        """Tiled GEMM: read each b tile, multiply, write the output tile.

        Tiles are visited K-major (all N tiles of the first K slab, then the
        next slab).  For each tile the scheduler issues, in order, a DmaReadCmd
        to ``pe_dma``, a GemmCmd to ``pe_gemm`` and a DmaWriteCmd to
        ``pe_dma``.  It waits for the read and the compute of the current tile,
        but only waits for a write while the following tile's compute is
        outstanding (or, for the last tile, after the loop).

        On completion ``pe_txn.result_data`` receives ``"dma_ns"`` and
        ``"compute_ns"`` and ``pe_txn.done`` is succeeded.

        NOTE(review): the original description promised
        READ(t+1) || COMPUTE(t) || WRITE(t-1).  As written, the read of tile
        t+1 is not issued until the compute of tile t has completed, so only
        the write overlaps with the next tile's work — confirm whether that is
        the intended model.
        """
        from kernbench.common.pe_commands import (
            DmaReadCmd,
            DmaWriteCmd,
            GemmCmd,
            PeInternalTxn as PeTxn,
            TensorHandle,
        )

        prefix = self._pe_prefix
        dma_key = f"{prefix}.pe_dma"
        gemm_key = f"{prefix}.pe_gemm"

        lhs = cmd.a  # already in TCM
        rhs = cmd.b  # HBM reference (via tl.ref)

        rows, depth = lhs.shape[-2], lhs.shape[-1]
        rhs_depth, cols = rhs.shape[-2], rhs.shape[-1]
        elem_type = lhs.dtype

        # Element size is inferred from b's byte count over its last two dims,
        # falling back to 2 when that product is zero.
        # NOTE(review): for a b with more than two dims this would not be the
        # per-element size — confirm b is always 2-D here.
        rhs_elems = rhs_depth * cols
        if rhs_elems > 0:
            elem_bytes = rhs.nbytes // rhs_elems
        else:
            elem_bytes = 2

        # Ceil-divide, but always run at least one tile per axis.
        k_tile_count = max(1, (depth + self.TILE_K - 1) // self.TILE_K)
        n_tile_count = max(1, (cols + self.TILE_N - 1) // self.TILE_N)

        prev_compute_done = None
        prev_write_done = None
        total_dma_ns = 0.0
        total_compute_ns = 0.0

        for tile_idx in range(k_tile_count * n_tile_count):
            k_idx, n_idx = divmod(tile_idx, n_tile_count)
            k_off = k_idx * self.TILE_K
            n_off = n_idx * self.TILE_N

            # Edge tiles are clipped to what is left of the matrix.
            cur_k = min(self.TILE_K, depth - k_off)
            cur_n = min(self.TILE_N, cols - n_off)
            rhs_tile_bytes = cur_k * cur_n * elem_bytes

            # --- Stage 1: DMA_READ of the b tile ---
            read_done = env.event()
            # Row-major offset of element (k_off, n_off) in a matrix `cols` wide.
            rhs_tile_addr = rhs.addr + (k_off * cols + n_off) * elem_bytes
            rhs_tile = TensorHandle(
                id=f"b_tile_{tile_idx}",
                addr=rhs_tile_addr,
                shape=(cur_k, cur_n),
                dtype=elem_type,
                nbytes=rhs_tile_bytes,
            )
            read_txn = PeTxn(
                command=DmaReadCmd(handle=rhs_tile, src_addr=rhs_tile_addr, nbytes=rhs_tile_bytes),
                done=read_done,
                pe_prefix=prefix,
            )
            issued_at = env.now
            yield self.out_ports[dma_key].put(read_txn)

            # NOTE(review): this event was already awaited at the end of the
            # previous iteration, so this wait is not expected to block.
            if prev_compute_done is not None:
                yield prev_compute_done

            yield read_done
            total_dma_ns += env.now - issued_at

            # --- Stage 2: GEMM on (a, b_tile) ---
            compute_done = env.event()
            out_tile_bytes = rows * cur_n * elem_bytes
            out_tile = TensorHandle(
                id=f"out_tile_{tile_idx}",
                addr=0,
                shape=(rows, cur_n),
                dtype=elem_type,
                nbytes=out_tile_bytes,
            )
            gemm_txn = PeTxn(
                command=GemmCmd(a=lhs, b=rhs_tile, out=out_tile, m=rows, k=cur_k, n=cur_n),
                done=compute_done,
                pe_prefix=prefix,
            )
            issued_at = env.now
            yield self.out_ports[gemm_key].put(gemm_txn)

            # At most one write is left in flight: the previous tile's write
            # must finish before this tile's write is issued below.
            if prev_write_done is not None:
                yield prev_write_done

            yield compute_done
            # The interval measured here also spans the wait on the previous
            # write above, not only the GEMM itself.
            total_compute_ns += env.now - issued_at
            prev_compute_done = compute_done

            # --- Stage 3: DMA_WRITE of the output tile (not awaited here) ---
            write_done = env.event()
            # Destination depends only on the column offset; the K index does
            # not enter the address.
            # NOTE(review): confirm K-slab partials are meant to target the
            # same output location.
            write_cmd = DmaWriteCmd(
                handle=out_tile,
                dst_addr=cmd.out_addr + n_off * elem_bytes,
                nbytes=out_tile_bytes,
            )
            write_txn = PeTxn(command=write_cmd, done=write_done, pe_prefix=prefix)
            yield self.out_ports[dma_key].put(write_txn)
            prev_write_done = write_done

        # Drain the last outstanding write; only this final wait on a write is
        # added to the DMA total.
        if prev_write_done is not None:
            drain_start = env.now
            yield prev_write_done
            total_dma_ns += env.now - drain_start

        pe_txn.result_data["dma_ns"] = total_dma_ns
        pe_txn.result_data["compute_ns"] = total_compute_ns
        pe_txn.done.succeed()

    def _pipeline_math(self, env: simpy.Environment, pe_txn: PeInternalTxn, cmd: Any) -> Generator:
        """Non-GEMM composite: one MathCmd, then one DmaWriteCmd, strictly in sequence.

        The math op is ``cmd.math_op`` (``"identity"`` when falsy), with
        ``cmd.a`` as both its only input and its output.  The write then moves
        ``cmd.out_nbytes`` bytes to ``cmd.out_addr``.  ``pe_txn.done`` is
        succeeded after the write completes; no timing entries are written to
        ``pe_txn.result_data`` on this path.
        """
        from kernbench.common.pe_commands import (
            DmaWriteCmd,
            MathCmd,
            PeInternalTxn as PeTxn,
        )

        prefix = self._pe_prefix

        # Step 1: compute on the math engine and wait for it.
        math_done = env.event()
        math_cmd = MathCmd(
            op=cmd.math_op or "identity",
            inputs=(cmd.a,),
            out=cmd.a,
        )
        math_txn = PeTxn(command=math_cmd, done=math_done, pe_prefix=prefix)
        yield self.out_ports[f"{prefix}.pe_math"].put(math_txn)
        yield math_done

        # Step 2: write the result out through the DMA engine and wait for it.
        store_done = env.event()
        store_cmd = DmaWriteCmd(handle=cmd.a, dst_addr=cmd.out_addr, nbytes=cmd.out_nbytes)
        store_txn = PeTxn(command=store_cmd, done=store_done, pe_prefix=prefix)
        yield self.out_ports[f"{prefix}.pe_dma"].put(store_txn)
        yield store_done

        pe_txn.done.succeed()
|