Files
kernbench2/src/kernbench/components/builtin/pe_scheduler.py
T
ywkang 63669f82cb Add SIP-level tensor parallelism, component registry YAML, VA offset verification
- DPPolicy: 3-level (sip/cube/pe), unified naming (column_wise/row_wise)
- PE_CPU: auto num_programs from cube shard count
- context.launch(): per-SIP KernelLaunchMsg with local va_base + auto local shape
- deploy_tensor: removed mmus param, MMU mapping is context-only responsibility
- ComponentRegistry: YAML-based lazy loading (components.yaml), impls→builtin rename
- VA offset bench + tests: 2D/1D, standard Triton kernel pattern

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 01:13:17 -07:00

246 lines
9.0 KiB
Python

from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
import simpy
from kernbench.components.base import ComponentBase
if TYPE_CHECKING:
from kernbench.common.pe_commands import PeInternalTxn
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class PeSchedulerComponent(ComponentBase):
    """PE_SCHEDULER: sole dispatcher inside a PE (ADR-0014 D1).

    Receives PeInternalTxn from PE_CPU and routes it to the appropriate engine:

    - DmaReadCmd / DmaWriteCmd -> PE_DMA
    - GemmCmd -> PE_GEMM
    - MathCmd -> PE_MATH
    - CompositeCmd -> tiled pipeline (Stage 3: ADR-0014 D3.2)

    Composite GEMM pipeline:

    - Tiles are TILE_K x TILE_N. The full M extent of ``a`` is processed per
      tile, so TILE_M is currently unused.
    - Per tile, the stages are DMA_READ(b_tile_t) -> COMPUTE(t) -> DMA_WRITE(out_tile).
    - READ(t+1) overlaps COMPUTE(t).
    - Each output tile is written once, after its last K tile.
    - At most one DMA_WRITE is in flight at a time.

    Applies the scheduler ``overhead_ns`` before dispatching each command.
    Non-PeInternalTxn messages are forwarded via the inherited ``_forward_txn()``.
    """

    # Scheduler tile dimensions (ADR-0014 D3.2)
    TILE_M = 32  # NOTE: not referenced by _pipeline_gemm, which keeps M whole
    TILE_K = 64
    TILE_N = 32

    # Command type -> engine suffix dispatch table.
    # Filled lazily by _ensure_dispatch_table(), which imports the command
    # classes at call time (module scope only imports them under TYPE_CHECKING).
    # New engines: add a single entry there (e.g. ConvCmd: "pe_conv").
    _CMD_DISPATCH: dict[type, str] = {}

    @classmethod
    def _ensure_dispatch_table(cls) -> None:
        """Populate ``_CMD_DISPATCH`` on first use; no-op once it is non-empty."""
        if cls._CMD_DISPATCH:
            return
        from kernbench.common.pe_commands import DmaReadCmd, DmaWriteCmd, GemmCmd, MathCmd

        cls._CMD_DISPATCH = {
            DmaReadCmd: "pe_dma",
            DmaWriteCmd: "pe_dma",
            GemmCmd: "pe_gemm",
            MathCmd: "pe_math",
        }

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        # Engine out-ports are addressed as "<pe_prefix>.<engine_suffix>".
        # The prefix is this node's id with its last dotted component removed.
        self._pe_prefix = node.id.rsplit(".", 1)[0]
        self._ensure_dispatch_table()

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Model the fixed scheduler overhead.

        The delay comes from the ``overhead_ns`` node attribute (default 0.0).
        ``nbytes`` is not used here.
        """
        overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(overhead_ns)

    def _worker(self, env: simpy.Environment) -> Generator:
        """Inbox loop.

        Each PeInternalTxn is dispatched in its own process. Any other
        message is forwarded inline.
        """
        from kernbench.common.pe_commands import PeInternalTxn

        while True:
            msg: Any = yield self._inbox.get()
            if isinstance(msg, PeInternalTxn):
                # A separate process keeps a long composite pipeline from
                # blocking the inbox.
                env.process(self._dispatch(env, msg))
            else:
                yield from self._forward_txn(env, msg)

    def _dispatch(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
        """Route a PeInternalTxn to the correct engine via the dispatch table.

        - Simple commands (matched on their exact type) are put on the
          engine's out-port. Presumably the engine then succeeds
          ``pe_txn.done``.
        - CompositeCmd runs the tiled pipeline in this process.
        - Any other command completes immediately without doing work.
        """
        from kernbench.common.pe_commands import CompositeCmd

        # Scheduler overhead
        yield from self.run(env, 0)

        cmd = pe_txn.command
        engine_suffix = self._CMD_DISPATCH.get(type(cmd))
        if engine_suffix is not None:
            yield self.out_ports[f"{self._pe_prefix}.{engine_suffix}"].put(pe_txn)
            return

        # CompositeCmd: tiled pipeline (not a simple forward)
        if isinstance(cmd, CompositeCmd):
            yield from self._dispatch_composite(env, pe_txn)
            return

        # Unknown command: signal done immediately so the sender is not stranded.
        pe_txn.done.succeed()

    def _dispatch_composite(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
        """Composite tiled pipeline (ADR-0014 D3.2).

        GEMM with a ``b`` operand uses the 3-stage pipeline with b-tile
        streaming from HBM. Everything else uses the sequential MATH +
        DMA_WRITE path, with no tiling.
        """
        from kernbench.common.pe_commands import CompositeCmd

        cmd = pe_txn.command
        assert isinstance(cmd, CompositeCmd)
        if cmd.op == "gemm" and cmd.b is not None:
            yield from self._pipeline_gemm(env, pe_txn, cmd)
        else:
            yield from self._pipeline_math(env, pe_txn, cmd)

    def _pipeline_gemm(self, env: simpy.Environment, pe_txn: PeInternalTxn, cmd: Any) -> Generator:
        """Tiled GEMM pipeline: stream b tiles from HBM, compute, write results.

        Tensor ``a`` is in TCM (loaded via tl.load). Tensor ``b`` is in HBM
        (via tl.ref).

        Visit order: for each N tile, all K tiles are streamed before moving
        to the next N tile. Each output tile is therefore accumulated across
        K and written exactly once, after its last K tile.

        Pipeline per tile t:
            DMA_READ(b_tile_t) -> COMPUTE(t) -> [last K tile] DMA_WRITE(out_tile)
        Overlap:
            - READ(t+1) is issued immediately after COMPUTE(t) is issued.
            - A DMA_WRITE waits only for the previous DMA_WRITE (one in flight).

        Records in ``pe_txn.result_data``:
            dma_ns: time blocked on DMA. This covers b-tile reads not hidden
                behind compute, write serialization, and the final write.
            compute_ns: time from issuing each GEMM tile until it completes.
        Then succeeds ``pe_txn.done``.
        """
        from kernbench.common.pe_commands import (
            DmaReadCmd,
            DmaWriteCmd,
            GemmCmd,
            PeInternalTxn as PeTxn,
            TensorHandle,
        )

        pp = self._pe_prefix
        dma_port = self.out_ports[f"{pp}.pe_dma"]
        gemm_port = self.out_ports[f"{pp}.pe_gemm"]

        a = cmd.a  # already in TCM
        b = cmd.b  # HBM reference (via tl.ref)
        M, K_a = a.shape[-2], a.shape[-1]
        K_b, N = b.shape[-2], b.shape[-1]
        dtype = a.dtype
        # Element size inferred from b; fall back to 2 bytes when b is empty.
        dtype_bytes = b.nbytes // (K_b * N) if (K_b * N) > 0 else 2

        # Tile counts (at least one tile per dimension, even for empty dims)
        n_tiles_k = max(1, (K_a + self.TILE_K - 1) // self.TILE_K)
        n_tiles_n = max(1, (N + self.TILE_N - 1) // self.TILE_N)
        n_tiles = n_tiles_k * n_tiles_n

        def tile_extent(tile_idx: int) -> tuple[int, int, int, int, int, int]:
            """Return (tk, tn, k_start, n_start, tile_k, tile_n) for a linear
            tile index; K varies fastest."""
            tn, tk = divmod(tile_idx, n_tiles_k)
            k_start = tk * self.TILE_K
            n_start = tn * self.TILE_N
            return (
                tk,
                tn,
                k_start,
                n_start,
                min(self.TILE_K, K_a - k_start),
                min(self.TILE_N, N - n_start),
            )

        def issue_b_read(tile_idx: int) -> Generator:
            """Enqueue the DMA_READ of one b tile and return (done_event, handle)."""
            _, _, k_start, n_start, tile_k, tile_n = tile_extent(tile_idx)
            nbytes = tile_k * tile_n * dtype_bytes
            # NOTE(review): the read is modelled as one contiguous range
            # starting at the tile's first element. In b, consecutive rows
            # of the tile are N elements apart.
            addr = b.addr + (k_start * N + n_start) * dtype_bytes
            handle = TensorHandle(
                id=f"b_tile_{tile_idx}", addr=addr,
                shape=(tile_k, tile_n), dtype=dtype, nbytes=nbytes,
            )
            done = env.event()
            read_cmd = DmaReadCmd(handle=handle, src_addr=addr, nbytes=nbytes)
            yield dma_port.put(PeTxn(command=read_cmd, done=done, pe_prefix=pp))
            return done, handle

        total_dma_ns = 0.0
        total_compute_ns = 0.0
        prev_write_done = None
        out_handle = None

        # Prologue: start fetching the first b tile.
        next_read = yield from issue_b_read(0)

        for tile_idx in range(n_tiles):
            tk, tn, _, n_start, tile_k, tile_n = tile_extent(tile_idx)
            read_done, b_tile_handle = next_read

            # --- Stage 1: wait for this tile's b data ---
            # Only the part of the read not hidden behind the previous
            # compute counts as DMA stall.
            t0 = env.now
            yield read_done
            total_dma_ns += env.now - t0

            # --- Stage 2: COMPUTE (GEMM) ---
            # One output tile per N tile, shared by all of its K tiles.
            if tk == 0:
                out_handle = TensorHandle(
                    id=f"out_tile_{tn}", addr=0,
                    shape=(M, tile_n), dtype=dtype,
                    nbytes=M * tile_n * dtype_bytes,
                )
            compute_done = env.event()
            compute_cmd = GemmCmd(a=a, b=b_tile_handle, out=out_handle,
                                  m=M, k=tile_k, n=tile_n)
            t0 = env.now
            yield gemm_port.put(PeTxn(command=compute_cmd, done=compute_done, pe_prefix=pp))

            # Prefetch the next b tile so READ(t+1) overlaps COMPUTE(t).
            if tile_idx + 1 < n_tiles:
                next_read = yield from issue_b_read(tile_idx + 1)

            yield compute_done
            total_compute_ns += env.now - t0

            # --- Stage 3: DMA_WRITE the finished output tile (last K tile only) ---
            if tk == n_tiles_k - 1:
                # Keep at most one DMA_WRITE in flight.
                if prev_write_done is not None:
                    t0 = env.now
                    yield prev_write_done
                    total_dma_ns += env.now - t0
                write_done = env.event()
                # NOTE(review): the write is modelled as one contiguous
                # M*tile_n range starting at the tile's first column. Output
                # rows are presumably N elements apart; confirm.
                out_tile_pa = cmd.out_addr + n_start * dtype_bytes
                write_nbytes = M * tile_n * dtype_bytes
                write_cmd = DmaWriteCmd(handle=out_handle, dst_addr=out_tile_pa, nbytes=write_nbytes)
                yield dma_port.put(PeTxn(command=write_cmd, done=write_done, pe_prefix=pp))
                prev_write_done = write_done

        # Drain: wait for the final write.
        if prev_write_done is not None:
            t0 = env.now
            yield prev_write_done
            total_dma_ns += env.now - t0

        pe_txn.result_data["dma_ns"] = total_dma_ns
        pe_txn.result_data["compute_ns"] = total_compute_ns
        pe_txn.done.succeed()

    def _pipeline_math(self, env: simpy.Environment, pe_txn: PeInternalTxn, cmd: Any) -> Generator:
        """Non-GEMM composite: sequential MATH then DMA_WRITE (no tiling).

        1. Runs ``cmd.math_op`` (``"identity"`` when unset) in place on ``cmd.a``.
        2. Writes ``cmd.out_nbytes`` bytes of ``cmd.a`` to ``cmd.out_addr``.

        Records ``dma_ns`` / ``compute_ns`` in ``pe_txn.result_data``, in the
        same way as the GEMM pipeline, then succeeds ``pe_txn.done``.
        """
        from kernbench.common.pe_commands import (
            DmaWriteCmd,
            MathCmd,
            PeInternalTxn as PeTxn,
        )

        pp = self._pe_prefix

        # Step 1: Compute (MATH), in place on a
        compute_done = env.event()
        compute_cmd = MathCmd(
            op=cmd.math_op or "identity",
            inputs=(cmd.a,), out=cmd.a,
        )
        t0 = env.now
        yield self.out_ports[f"{pp}.pe_math"].put(
            PeTxn(command=compute_cmd, done=compute_done, pe_prefix=pp)
        )
        yield compute_done
        compute_ns = float(env.now - t0)

        # Step 2: DMA_WRITE result to HBM
        write_done = env.event()
        write_cmd = DmaWriteCmd(handle=cmd.a, dst_addr=cmd.out_addr, nbytes=cmd.out_nbytes)
        t0 = env.now
        yield self.out_ports[f"{pp}.pe_dma"].put(
            PeTxn(command=write_cmd, done=write_done, pe_prefix=pp)
        )
        yield write_done
        dma_ns = float(env.now - t0)

        pe_txn.result_data["dma_ns"] = dma_ns
        pe_txn.result_data["compute_ns"] = compute_ns
        pe_txn.done.succeed()