14d800b0ae
- KernelLaunchMsg gains target_start_ns: IO_CPU stamps a global barrier (max path latency across every target PE), M_CPU passes it through, PE_CPU yields until that time before recording pe_exec_start. Every PE in a launch begins kernel execution at the same env.now regardless of its dispatch path length — eliminates the per-PE dispatch-offset artifact in cross-PE and cross-cube latency measurements. - PE_DMA._handle_ipcq_inbound now pays Transaction.drain_ns at the top, matching the terminal-drain behavior of ComponentBase._forward_txn for every non-IPCQ Transaction. SRC-side tl.send stays fire-and-forget (sender doesn't yield on sub_done); tl.recv now blocks until bytes have actually drained into its inbox. - ComponentContext: new compute_path_latency_ns helper + node_overhead_ns field populated by GraphEngine. - tests/test_kernel_launch_sync.py: asserts all PEs in one launch produce identical pe_exec_ns for a no-op kernel (zero spread). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
222 lines
8.6 KiB
Python
222 lines
8.6 KiB
Python
from __future__ import annotations
|
|
|
|
from collections.abc import Generator
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
import simpy
|
|
|
|
from kernbench.components.base import ComponentBase
|
|
from kernbench.sim_engine.transaction import Transaction
|
|
|
|
if TYPE_CHECKING:
|
|
from kernbench.components.context import ComponentContext
|
|
from kernbench.topology.types import Node
|
|
|
|
|
|
class PeCpuComponent(ComponentBase):
    """PE_CPU: kernel execution controller (Stage 2).

    Two-phase kernel execution (ADR-0014 D1):
      Phase 1 (compile): look up kernel from registry, run it with TLContext
        to generate a PeCommand list.
      Phase 2 (replay): iterate commands, dispatch to PE_SCHEDULER via
        PeInternalTxn, wait for blocking commands.

    When the component context exposes a ``memory_store``, the greenlet
    execution mode (ADR-0020) is used instead of the command-list replay.

    Non-kernel Transactions are forwarded normally.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        # The PE prefix is the node id minus its last dotted component,
        # e.g. "sip0.cube0.pe0"; it is used to name sibling units such as
        # the PE scheduler.
        self._pe_prefix = node.id.rsplit(".", 1)[0]
        # Each coordinate falls back to 0 when the id cannot be parsed.
        try:
            self._pe_idx = int(self._pe_prefix.rsplit("pe", 1)[1])
        except (IndexError, ValueError):
            self._pe_idx = 0
        # Extract sip/cube index for multi-SIP/cube shard matching
        parts = node.id.split(".")
        try:
            self._sip_idx = int(parts[0].replace("sip", ""))
        except (IndexError, ValueError):
            self._sip_idx = 0
        try:
            self._cube_idx = int(parts[1].replace("cube", ""))
        except (IndexError, ValueError):
            self._cube_idx = 0

    def _find_shard(self, shards: tuple) -> Any:
        """Return the shard located at this PE's (sip, cube, pe) coordinate.

        If no shard matches exactly, fall back to positional indexing by PE
        index, clamped to the last shard.

        Raises:
            IndexError: if ``shards`` is empty.
        """
        for s in shards:
            if s.sip == self._sip_idx and s.cube == self._cube_idx and s.pe == self._pe_idx:
                return s
        if not shards:
            # Previously surfaced as an opaque "tuple index out of range"
            # from shards[-1]; keep the exception type, improve the message.
            raise IndexError(
                f"{self._pe_prefix}: tensor argument has no shards to select from"
            )
        return shards[min(self._pe_idx, len(shards) - 1)]

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Advance simulated time by this node's ``overhead_ns`` attribute.

        ``overhead_ns`` defaults to 0.0 when absent. ``nbytes`` is ignored
        by this implementation.
        """
        overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(overhead_ns)

    def _worker(self, env: simpy.Environment) -> Generator:
        """Serve the inbox forever: execute kernel launches, forward the rest."""
        # Kept as a function-local import (presumably to avoid an import cycle
        # at module load), but hoisted out of the loop so it runs once rather
        # than on every message.
        from kernbench.runtime_api.kernel import KernelLaunchMsg

        while True:
            txn: Any = yield self._inbox.get()
            if isinstance(getattr(txn, "request", None), KernelLaunchMsg):
                yield from self._execute_kernel(env, txn)
            else:
                yield from self._forward_txn(env, txn)

    def _execute_kernel(self, env: simpy.Environment, txn: Any) -> Generator:
        """Execute one kernel launch and send the response.

        Steps: pay the PE_CPU overhead (``run``), wait for the launch's
        synchronized start time if one is set, run the kernel in greenlet
        mode (ADR-0020) when a memory store is available or in legacy
        Phase 0 + replay mode otherwise, record timing into
        ``txn.result_data``, and send a ResponseMsg on the reverse path.
        """
        from kernbench.triton_emu.registry import get_kernel

        request = txn.request
        yield from self.run(env, 0)

        # ADR-0009 D5: synchronized launch barrier. target_start_ns is stamped
        # upstream (IO_CPU computes it, M_CPU passes it through). Waiting until
        # then makes every PE in this launch begin pe_exec measurement at the
        # same simulated time, regardless of its dispatch path length.
        target_start = getattr(request, "target_start_ns", None)
        if target_start is not None and target_start > env.now:
            yield env.timeout(float(target_start) - env.now)

        kernel_fn = get_kernel(request.kernel_ref.name)
        num_programs = self._derive_num_programs(request)
        kernel_args = self._unpack_kernel_args(request)

        # pe_exec_ns is measured from here: after the overhead and the barrier.
        pe_exec_start = env.now
        scheduler_id = f"{self._pe_prefix}.pe_scheduler"

        # Choose execution mode: greenlet (ADR-0020) or legacy command-list
        store = getattr(self.ctx, "memory_store", None) if self.ctx else None
        if store is not None:
            composite_results = yield from self._execute_greenlet(
                env, kernel_fn, kernel_args, num_programs, scheduler_id, store,
            )
        else:
            composite_results = yield from self._execute_legacy(
                env, kernel_fn, kernel_args, num_programs, scheduler_id,
            )

        # Record PE-internal execution time and per-composite DMA/compute totals.
        txn.result_data["pe_exec_ns"] = env.now - pe_exec_start
        total_dma_ns = 0.0
        total_compute_ns = 0.0
        for rd in composite_results:
            total_dma_ns += rd.get("dma_ns", 0.0)
            total_compute_ns += rd.get("compute_ns", 0.0)
        txn.result_data["dma_ns"] = total_dma_ns
        txn.result_data["compute_ns"] = total_compute_ns

        yield from self._send_response(env, txn, request)

    def _derive_num_programs(self, request: Any) -> int:
        """Return the number of programs to run for this launch.

        For each tensor argument, count its shards on this PE's (sip, cube).
        The result is the largest such count, and never less than 1.
        """
        per_tensor_counts = [
            sum(
                1 for s in arg.shards
                if s.sip == self._sip_idx and s.cube == self._cube_idx
            )
            for arg in request.args
            if arg.arg_kind == "tensor"
        ]
        return max([1, *per_tensor_counts])

    def _unpack_kernel_args(self, request: Any) -> list:
        """Translate launch arguments into positional kernel arguments.

        Tensor args become an address: the tensor's ``va_base`` when it is
        truthy, otherwise the ``pa`` of the shard chosen by ``_find_shard``.
        Scalar args contribute their ``value``. Args of any other kind are
        skipped.
        """
        kernel_args: list = []
        for arg in request.args:
            if arg.arg_kind == "tensor":
                # NOTE(review): a va_base of 0 is treated as "unset" and falls
                # back to the shard PA — confirm 0 is never a valid VA base.
                if arg.va_base:
                    kernel_args.append(arg.va_base)
                else:
                    shard = self._find_shard(arg.shards)
                    kernel_args.append(shard.pa)
            elif arg.arg_kind == "scalar":
                kernel_args.append(arg.value)
        return kernel_args

    def _execute_greenlet(
        self, env, kernel_fn, kernel_args, num_programs, scheduler_id, store,
    ) -> Generator:
        """Greenlet-based execution (ADR-0020 D3): kernel ↔ SimPy interleaved.

        Returns the runner's ``_composite_results`` list, or an empty list if
        the runner did not set that attribute.
        """
        from kernbench.triton_emu.kernel_runner import KernelRunner

        runner = KernelRunner(
            pe_prefix=self._pe_prefix,
            pe_idx=self._pe_idx,
            sip_idx=self._sip_idx,
            cube_idx=self._cube_idx,
            scheduler_id=scheduler_id,
            out_ports=self.out_ports,
            store=store,
        )
        yield from runner.run(env, kernel_fn, kernel_args, num_programs)
        return getattr(runner, "_composite_results", [])

    def _execute_legacy(
        self, env, kernel_fn, kernel_args, num_programs, scheduler_id,
    ) -> Generator:
        """Legacy Phase 0 + replay: generate a command list, then dispatch it.

        Returns the ``result_data`` dicts of every dispatched CompositeCmd, in
        dispatch order. Returns only after all outstanding composites have
        completed.
        """
        from kernbench.common.pe_commands import (
            CompositeCmd, PeCpuOverheadCmd, PeInternalTxn, WaitCmd,
        )
        from kernbench.triton_emu.tl_context import TLContext, run_kernel

        # Phase 0: run the kernel against TLContext to record its commands.
        tl = TLContext(pe_id=self._pe_idx, num_programs=num_programs, dispatch_cycles=0)
        run_kernel(kernel_fn, tl, *kernel_args)
        commands = tl.commands

        # Completion events of in-flight CompositeCmds, keyed by completion id.
        pending: dict[str, simpy.Event] = {}
        composite_results: list[dict] = []

        for cmd in commands:
            if isinstance(cmd, PeCpuOverheadCmd):
                # NOTE(review): cycles are passed straight to env.timeout, i.e.
                # treated as simulation time units — confirm 1 cycle == 1 ns.
                yield env.timeout(cmd.cycles)
            elif isinstance(cmd, WaitCmd):
                if cmd.handle is not None:
                    # Wait for one specific composite (if still outstanding).
                    evt = pending.pop(cmd.handle.id, None)
                    if evt is not None:
                        yield evt
                else:
                    # Handle-less wait: drain every outstanding composite.
                    for evt in pending.values():
                        yield evt
                    pending.clear()
            elif isinstance(cmd, CompositeCmd):
                # Non-blocking: dispatch and track completion for a later wait.
                done_evt = env.event()
                pe_txn = PeInternalTxn(
                    command=cmd, done=done_evt, pe_prefix=self._pe_prefix,
                )
                composite_results.append(pe_txn.result_data)
                yield self.out_ports[scheduler_id].put(pe_txn)
                pending[cmd.completion.id] = done_evt
            else:
                # Any other command blocks until the scheduler completes it.
                done_evt = env.event()
                pe_txn = PeInternalTxn(
                    command=cmd, done=done_evt, pe_prefix=self._pe_prefix,
                )
                yield self.out_ports[scheduler_id].put(pe_txn)
                yield done_evt

        # Never return with composites still in flight.
        for evt in pending.values():
            yield evt
        return composite_results

    def _send_response(self, env, txn, request) -> Generator:
        """Reply to a kernel launch with a successful ResponseMsg.

        The response travels the reversed request path, starting at the hop
        after this node. If the reversed path has fewer than two nodes (no
        next hop), the launch transaction's ``done`` event is succeeded
        directly instead. In the response case ``txn.done`` is not touched
        here; completion is presumably signalled when the response arrives —
        confirm against the receiver.
        """
        reverse_path = list(reversed(txn.path))
        if len(reverse_path) >= 2:
            from kernbench.runtime_api.kernel import ResponseMsg

            resp_msg = ResponseMsg(
                correlation_id=request.correlation_id,
                request_id=request.request_id,
                src_cube=self._cube_idx, src_pe=self._pe_idx,
                success=True,
            )
            resp_txn = Transaction(
                request=resp_msg, path=reverse_path, step=0,
                nbytes=0, done=env.event(), is_response=True,
            )
            yield self.out_ports[reverse_path[1]].put(resp_txn.advance())
        else:
            txn.done.succeed()