from __future__ import annotations

from collections.abc import Generator
from typing import TYPE_CHECKING, Any

import simpy

from kernbench.components.base import ComponentBase
from kernbench.sim_engine.transaction import Transaction

if TYPE_CHECKING:
    from kernbench.components.context import ComponentContext
    from kernbench.topology.types import Node


class PeCpuComponent(ComponentBase):
    """PE_CPU: kernel execution controller (Stage 2).

    Two-phase kernel execution (ADR-0014 D1):
      Phase 1 (compile): look up kernel from registry, run it with TLContext
        to generate a PeCommand list.
      Phase 2 (replay): iterate commands, dispatch to PE_SCHEDULER via
        PeInternalTxn, wait for blocking commands.

    When the component context exposes a ``memory_store``, the greenlet
    execution mode (ADR-0020) is used instead of the two-phase path.

    Non-kernel Transactions are forwarded normally.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        """Derive the PE / cube / SIP indices from the dotted node id.

        Each index falls back to 0 when the corresponding id segment is
        missing or is not numeric.
        """
        super().__init__(node, ctx)
        self._pe_prefix = node.id.rsplit(".", 1)[0]  # e.g. "sip0.cube0.pe0"
        try:
            self._pe_idx = int(self._pe_prefix.rsplit("pe", 1)[1])
        except (IndexError, ValueError):
            self._pe_idx = 0

        # Extract sip/cube index for multi-SIP/cube shard matching
        parts = node.id.split(".")
        try:
            self._sip_idx = int(parts[0].replace("sip", ""))
        except (IndexError, ValueError):
            self._sip_idx = 0
        try:
            self._cube_idx = int(parts[1].replace("cube", ""))
        except (IndexError, ValueError):
            self._cube_idx = 0

        # num_cubes from spec (for tl.program_id(axis=1))
        spec = ctx.spec if ctx else {}
        self._num_cubes = (
            spec.get("system", {}).get("sips", {}).get("cubes_per_sip", 1)
        )

    def _find_shard(self, shards: tuple) -> Any:
        """Find the shard matching this PE's (sip, cube, pe).

        Falls back to a positional lookup by PE index, clamped to the last
        shard, when no shard matches exactly.

        Raises:
            IndexError: if ``shards`` is empty.
        """
        if not shards:
            # Same exception type the positional fallback would raise on an
            # empty tuple, but with a message that identifies the PE.
            raise IndexError(f"{self._pe_prefix}: tensor argument has no shards")
        for s in shards:
            if (
                s.sip == self._sip_idx
                and s.cube == self._cube_idx
                and s.pe == self._pe_idx
            ):
                return s
        return shards[min(self._pe_idx, len(shards) - 1)]

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Wait for the node's ``overhead_ns`` attribute (default 0.0).

        ``nbytes`` is accepted for interface compatibility and is not used.
        """
        overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(overhead_ns)

    def _worker(self, env: simpy.Environment) -> Generator:
        """Consume the inbox forever, executing kernels or forwarding.

        A transaction whose ``request`` is a ``KernelLaunchMsg`` is executed
        on this PE; anything else is passed to ``_forward_txn``.
        """
        # Imported inside the function (as in the original code) but hoisted
        # out of the loop so the import statement runs once per worker.
        from kernbench.runtime_api.kernel import KernelLaunchMsg

        while True:
            txn: Any = yield self._inbox.get()
            if hasattr(txn, "request") and isinstance(txn.request, KernelLaunchMsg):
                yield from self._execute_kernel(env, txn)
            else:
                yield from self._forward_txn(env, txn)

    def _execute_kernel(self, env: simpy.Environment, txn: Any) -> Generator:
        """Execute kernel: greenlet mode (ADR-0020) or legacy Phase 0 + replay.

        Records ``pe_exec_ns``, ``dma_ns`` and ``compute_ns`` in
        ``txn.result_data`` and then sends the response on the reverse path.
        """
        from kernbench.triton_emu.registry import get_kernel

        request = txn.request
        yield from self.run(env, 0)

        kernel_fn = get_kernel(request.kernel_ref.name)
        num_programs = self._derive_num_programs(request)
        kernel_args = self._unpack_kernel_args(request)

        pe_exec_start = env.now
        scheduler_id = f"{self._pe_prefix}.pe_scheduler"

        # Choose execution mode: greenlet (ADR-0020) or legacy command-list
        store = getattr(self.ctx, "memory_store", None) if self.ctx else None
        if store is not None:
            composite_results = yield from self._execute_greenlet(
                env, kernel_fn, kernel_args, num_programs, scheduler_id, store,
            )
        else:
            composite_results = yield from self._execute_legacy(
                env, kernel_fn, kernel_args, num_programs, scheduler_id,
            )

        # Record PE-internal execution time
        txn.result_data["pe_exec_ns"] = env.now - pe_exec_start
        # Start value 0.0 keeps the totals floats even with no results.
        txn.result_data["dma_ns"] = sum(
            (rd.get("dma_ns", 0.0) for rd in composite_results), 0.0
        )
        txn.result_data["compute_ns"] = sum(
            (rd.get("compute_ns", 0.0) for rd in composite_results), 0.0
        )

        # Send ResponseMsg on reverse path
        yield from self._send_response(env, txn, request)

    def _derive_num_programs(self, request: Any) -> int:
        """Return the largest per-tensor shard count on this PE's cube.

        For every tensor argument, counts the shards whose (sip, cube)
        matches this PE; the result is the maximum of those counts and is
        never less than 1.
        """
        num_programs = 1
        for arg in request.args:
            if arg.arg_kind != "tensor":
                continue
            cube_pe_count = sum(
                1
                for s in arg.shards
                if s.sip == self._sip_idx and s.cube == self._cube_idx
            )
            num_programs = max(num_programs, cube_pe_count)
        return num_programs

    def _unpack_kernel_args(self, request: Any) -> list:
        """Build the positional argument list passed to the kernel.

        Tensor arguments contribute ``va_base`` when it is truthy, otherwise
        the ``pa`` of this PE's shard. Scalar arguments contribute their
        ``value``. Arguments of any other kind are skipped.
        """
        kernel_args: list = []
        for arg in request.args:
            if arg.arg_kind == "tensor":
                if arg.va_base:
                    kernel_args.append(arg.va_base)
                else:
                    kernel_args.append(self._find_shard(arg.shards).pa)
            elif arg.arg_kind == "scalar":
                kernel_args.append(arg.value)
        return kernel_args

    def _execute_greenlet(
        self, env, kernel_fn, kernel_args, num_programs, scheduler_id, store,
    ) -> Generator:
        """Greenlet-based execution (ADR-0020 D3): kernel <-> SimPy interleaved.

        Returns the runner's ``_composite_results`` attribute, or an empty
        list if the runner does not define it.
        """
        from kernbench.triton_emu.kernel_runner import KernelRunner

        runner = KernelRunner(
            pe_prefix=self._pe_prefix,
            pe_idx=self._pe_idx,
            sip_idx=self._sip_idx,
            cube_idx=self._cube_idx,
            num_cubes=self._num_cubes,
            scheduler_id=scheduler_id,
            out_ports=self.out_ports,
            store=store,
        )
        yield from runner.run(env, kernel_fn, kernel_args, num_programs)
        return getattr(runner, "_composite_results", [])

    def _execute_legacy(
        self, env, kernel_fn, kernel_args, num_programs, scheduler_id,
    ) -> Generator:
        """Legacy Phase 0 + replay: generate command list, then dispatch.

        Replay rules, per command type:
          * ``PeCpuOverheadCmd``: wait ``cmd.cycles`` simulation time units.
          * ``WaitCmd``: wait on the pending event for ``cmd.handle`` if a
            handle is given (ignored when unknown), else on every pending
            event.
          * ``CompositeCmd``: dispatch to the scheduler without waiting; the
            completion event is kept under ``cmd.completion.id``.
          * anything else: dispatch to the scheduler and wait for it.

        Returns the ``result_data`` objects of all composite transactions.
        """
        from kernbench.common.pe_commands import (
            CompositeCmd,
            PeCpuOverheadCmd,
            PeInternalTxn,
            WaitCmd,
        )
        from kernbench.triton_emu.tl_context import TLContext, run_kernel

        tl = TLContext(
            pe_id=self._pe_idx,
            num_programs=num_programs,
            cube_id=self._cube_idx,
            num_cubes=self._num_cubes,
            dispatch_cycles=0,
        )
        run_kernel(kernel_fn, tl, *kernel_args)
        commands = tl.commands

        pending: dict[str, simpy.Event] = {}
        composite_results: list[dict] = []

        for cmd in commands:
            if isinstance(cmd, PeCpuOverheadCmd):
                yield env.timeout(cmd.cycles)
            elif isinstance(cmd, WaitCmd):
                if cmd.handle is not None:
                    evt = pending.pop(cmd.handle.id, None)
                    # Explicit None check: do not rely on event truthiness.
                    if evt is not None:
                        yield evt
                else:
                    for evt in pending.values():
                        yield evt
                    pending.clear()
            elif isinstance(cmd, CompositeCmd):
                done_evt = env.event()
                pe_txn = PeInternalTxn(
                    command=cmd,
                    done=done_evt,
                    pe_prefix=self._pe_prefix,
                )
                # Keep a reference now; the values are read after replay.
                composite_results.append(pe_txn.result_data)
                yield self.out_ports[scheduler_id].put(pe_txn)
                pending[cmd.completion.id] = done_evt
            else:
                done_evt = env.event()
                pe_txn = PeInternalTxn(
                    command=cmd,
                    done=done_evt,
                    pe_prefix=self._pe_prefix,
                )
                yield self.out_ports[scheduler_id].put(pe_txn)
                yield done_evt

        # Drain composites that were never explicitly waited on.
        for evt in pending.values():
            yield evt

        return composite_results

    def _send_response(self, env, txn, request) -> Generator:
        """Send a ``ResponseMsg`` back along the reversed transaction path.

        If the reversed path has fewer than two entries there is no next hop
        to send to, so ``txn.done`` is triggered directly instead.
        """
        reverse_path = list(reversed(txn.path))
        if len(reverse_path) >= 2:
            from kernbench.runtime_api.kernel import ResponseMsg

            resp_msg = ResponseMsg(
                correlation_id=request.correlation_id,
                request_id=request.request_id,
                src_cube=self._cube_idx,
                src_pe=self._pe_idx,
                success=True,
            )
            resp_txn = Transaction(
                request=resp_msg,
                path=reverse_path,
                step=0,
                nbytes=0,
                done=env.event(),
                is_response=True,
            )
            # reverse_path[0] is this node; hand off to the next hop.
            yield self.out_ports[reverse_path[1]].put(resp_txn.advance())
        else:
            txn.done.succeed()