Implement ADR-0020: 2-pass data execution with greenlet kernel runner
Step 1 — Foundation:
- OpRecord/OpLogger: op log infrastructure with t_start stable ordering
- MemoryStore: numpy ndarray tensor-granular storage (reference semantics)
- data_op=True flag on DmaReadCmd, DmaWriteCmd, GemmCmd, MathCmd, CompositeCmd
- numpy/greenlet dependencies added to pyproject.toml

Step 2 — ComponentBase hooks:
- _on_process_start/end hooks in _forward_txn (fabric messages)
- _handle_with_hooks in PeEngineBase (PE-internal commands)
- op_logger optional — zero overhead when disabled

Step 3 — KernelRunner + greenlet:
- KernelRunner: greenlet ↔ SimPy bridge in triton_emu/kernel_runner.py
- TLContext: _emit() method routes to greenlet switch or command list
- tl.load() returns real numpy data in greenlet mode
- Dynamic control flow supported (memory-read based branching)

Step 4 — PE_CPU integration:
- Greenlet mode when ctx.memory_store is set, legacy fallback otherwise
- Refactored into _execute_greenlet/_execute_legacy/_send_response
- ComponentContext gains memory_store and op_logger fields

Step 5 — DataExecutor:
- Phase 2 numpy execution for GEMM/Math ops from op_log
- _compute_math: all unary/binary/reduction ops
- verify(): compare MemoryStore against expected with dtype tolerance

28 new tests, 366 total passing.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -33,6 +33,7 @@ class ComponentBase(ABC):
|
||||
self.ctx = ctx
|
||||
self.in_ports: dict[str, simpy.Store] = {}
|
||||
self.out_ports: dict[str, simpy.Store] = {}
|
||||
self._op_logger: Any | None = None # OpLogger, set by GraphEngine if enabled
|
||||
|
||||
def start(self, env: simpy.Environment) -> None:
|
||||
"""Called once after all ports are wired.
|
||||
@@ -64,9 +65,21 @@ class ComponentBase(ABC):
|
||||
txn: Any = yield self._inbox.get()
|
||||
env.process(self._forward_txn(env, txn))
|
||||
|
||||
def _on_process_start(self, env: simpy.Environment, msg: Any) -> None:
    """Op log hook: record service start for data_op messages (ADR-0020 D2).

    Does nothing when no op logger is attached (``self._op_logger`` is
    ``None``) or when ``msg`` has no truthy ``data_op`` attribute.

    Args:
        env: Simulation environment; ``env.now`` is passed as the timestamp.
        msg: Message or command whose service is starting.
    """
    op_logger = self._op_logger
    # Compare against None instead of relying on truthiness: a logger object
    # that evaluates falsy while it is still empty (e.g. it defines __len__)
    # must not be mistaken for "logging disabled".
    if op_logger is not None and getattr(msg, "data_op", False):
        op_logger.record_start(env.now, self.node.id, msg)
|
||||
|
||||
def _on_process_end(self, env: simpy.Environment, msg: Any) -> None:
    """Op log hook: record service end for data_op messages (ADR-0020 D2).

    Does nothing when no op logger is attached (``self._op_logger`` is
    ``None``) or when ``msg`` has no truthy ``data_op`` attribute.

    Args:
        env: Simulation environment; ``env.now`` is passed as the timestamp.
        msg: Message or command whose service has finished.
    """
    op_logger = self._op_logger
    # Compare against None instead of relying on truthiness: a logger object
    # that evaluates falsy while it is still empty (e.g. it defines __len__)
    # must not be mistaken for "logging disabled".
    if op_logger is not None and getattr(msg, "data_op", False):
        op_logger.record_end(env.now, self.node.id, msg)
|
||||
|
||||
def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator:
|
||||
"""Apply run() latency, then forward to next hop or drain at terminal."""
|
||||
self._on_process_start(env, txn)
|
||||
yield from self.run(env, txn.nbytes)
|
||||
self._on_process_end(env, txn)
|
||||
next_hop = txn.next_hop # duck-typed: Transaction.next_hop
|
||||
if next_hop:
|
||||
yield self.out_ports[next_hop].put(txn.advance())
|
||||
@@ -120,10 +133,16 @@ class PeEngineBase(ComponentBase):
|
||||
while True:
|
||||
msg: Any = yield self._inbox.get()
|
||||
if isinstance(msg, PeInternalTxn):
|
||||
env.process(self.handle_command(env, msg))
|
||||
env.process(self._handle_with_hooks(env, msg))
|
||||
else:
|
||||
env.process(self._forward_txn(env, msg))
|
||||
|
||||
def _handle_with_hooks(self, env: simpy.Environment, pe_txn: Any) -> Generator:
    """Run ``handle_command`` bracketed by the op-log start/end hooks.

    The hooks receive the command carried by the transaction
    (``pe_txn.command``), not the transaction wrapper itself. The end hook
    runs only after ``handle_command`` has been fully consumed.

    Args:
        env: Simulation environment forwarded to the hooks and handler.
        pe_txn: PE-internal transaction; its ``command`` attribute is logged.

    Yields:
        Whatever ``handle_command`` yields, unchanged.
    """
    self._on_process_start(env, pe_txn.command)
    command_steps = self.handle_command(env, pe_txn)
    yield from command_steps
    # Re-read the attribute here (rather than caching it) so the end hook
    # sees whatever pe_txn.command refers to once handling has finished.
    self._on_process_end(env, pe_txn.command)
|
||||
|
||||
@abstractmethod
|
||||
def handle_command(self, env: simpy.Environment, pe_txn: Any) -> Generator:
|
||||
"""Process a PE-internal command (PeInternalTxn).
|
||||
|
||||
@@ -65,24 +65,45 @@ class PeCpuComponent(ComponentBase):
|
||||
yield from self._forward_txn(env, txn)
|
||||
|
||||
def _execute_kernel(self, env: simpy.Environment, txn: Any) -> Generator:
|
||||
"""Compile kernel function and replay command trace."""
|
||||
from kernbench.common.pe_commands import (
|
||||
CompositeCmd,
|
||||
PeCpuOverheadCmd,
|
||||
PeInternalTxn,
|
||||
WaitCmd,
|
||||
)
|
||||
"""Execute kernel: greenlet mode (ADR-0020) or legacy Phase 0 + replay."""
|
||||
from kernbench.triton_emu.registry import get_kernel
|
||||
from kernbench.triton_emu.tl_context import TLContext, run_kernel
|
||||
|
||||
request = txn.request
|
||||
|
||||
# Phase 1: Compile — apply PE_CPU setup overhead, then run kernel
|
||||
yield from self.run(env, 0)
|
||||
|
||||
kernel_fn = get_kernel(request.kernel_ref.name)
|
||||
num_programs = self._derive_num_programs(request)
|
||||
kernel_args = self._unpack_kernel_args(request)
|
||||
|
||||
# Derive num_programs from the number of PE shards in this cube
|
||||
pe_exec_start = env.now
|
||||
scheduler_id = f"{self._pe_prefix}.pe_scheduler"
|
||||
|
||||
# Choose execution mode: greenlet (ADR-0020) or legacy command-list
|
||||
store = getattr(self.ctx, "memory_store", None) if self.ctx else None
|
||||
|
||||
if store is not None:
|
||||
composite_results = yield from self._execute_greenlet(
|
||||
env, kernel_fn, kernel_args, num_programs, scheduler_id, store,
|
||||
)
|
||||
else:
|
||||
composite_results = yield from self._execute_legacy(
|
||||
env, kernel_fn, kernel_args, num_programs, scheduler_id,
|
||||
)
|
||||
|
||||
# Record PE-internal execution time
|
||||
txn.result_data["pe_exec_ns"] = env.now - pe_exec_start
|
||||
total_dma_ns = 0.0
|
||||
total_compute_ns = 0.0
|
||||
for rd in composite_results:
|
||||
total_dma_ns += rd.get("dma_ns", 0.0)
|
||||
total_compute_ns += rd.get("compute_ns", 0.0)
|
||||
txn.result_data["dma_ns"] = total_dma_ns
|
||||
txn.result_data["compute_ns"] = total_compute_ns
|
||||
|
||||
# Send ResponseMsg on reverse path
|
||||
yield from self._send_response(env, txn, request)
|
||||
|
||||
def _derive_num_programs(self, request: Any) -> int:
|
||||
num_programs = 1
|
||||
for arg in request.args:
|
||||
if arg.arg_kind == "tensor":
|
||||
@@ -92,11 +113,9 @@ class PeCpuComponent(ComponentBase):
|
||||
)
|
||||
if cube_pe_count > num_programs:
|
||||
num_programs = cube_pe_count
|
||||
return num_programs
|
||||
|
||||
tl = TLContext(pe_id=self._pe_idx, num_programs=num_programs, dispatch_cycles=0)
|
||||
|
||||
# Unpack KernelLaunchMsg.args into positional args for kernel function
|
||||
# TensorArg → va_base (already local, set by runtime) or PA fallback
|
||||
def _unpack_kernel_args(self, request: Any) -> list:
|
||||
kernel_args: list = []
|
||||
for arg in request.args:
|
||||
if arg.arg_kind == "tensor":
|
||||
@@ -107,15 +126,41 @@ class PeCpuComponent(ComponentBase):
|
||||
kernel_args.append(shard.pa)
|
||||
elif arg.arg_kind == "scalar":
|
||||
kernel_args.append(arg.value)
|
||||
return kernel_args
|
||||
|
||||
def _execute_greenlet(
    self, env, kernel_fn, kernel_args, num_programs, scheduler_id, store,
) -> Generator:
    """Greenlet-based execution (ADR-0020 D3): kernel ↔ SimPy interleaved.

    Builds a ``KernelRunner`` from this component's PE identifiers, its
    output ports and the given memory store, then delegates to
    ``runner.run``.

    Args:
        env: Simulation environment passed to ``runner.run``.
        kernel_fn: Kernel callable passed to ``runner.run``.
        kernel_args: Kernel arguments passed to ``runner.run``.
        num_programs: Program count passed to ``runner.run``.
        scheduler_id: Scheduler identifier handed to the runner.
        store: Memory store handed to the runner.

    Yields:
        Whatever ``runner.run`` yields.

    Returns:
        The runner's ``_composite_results`` attribute, or an empty list when
        the runner does not define it.
    """
    from kernbench.triton_emu.kernel_runner import KernelRunner

    runner_config = {
        "pe_prefix": self._pe_prefix,
        "pe_idx": self._pe_idx,
        "sip_idx": self._sip_idx,
        "cube_idx": self._cube_idx,
        "scheduler_id": scheduler_id,
        "out_ports": self.out_ports,
        "store": store,
    }
    kernel_runner = KernelRunner(**runner_config)

    yield from kernel_runner.run(env, kernel_fn, kernel_args, num_programs)

    return getattr(kernel_runner, "_composite_results", [])
|
||||
|
||||
def _execute_legacy(
|
||||
self, env, kernel_fn, kernel_args, num_programs, scheduler_id,
|
||||
) -> Generator:
|
||||
"""Legacy Phase 0 + replay: generate command list, then dispatch."""
|
||||
from kernbench.common.pe_commands import (
|
||||
CompositeCmd, PeCpuOverheadCmd, PeInternalTxn, WaitCmd,
|
||||
)
|
||||
from kernbench.triton_emu.tl_context import TLContext, run_kernel
|
||||
|
||||
tl = TLContext(pe_id=self._pe_idx, num_programs=num_programs, dispatch_cycles=0)
|
||||
run_kernel(kernel_fn, tl, *kernel_args)
|
||||
commands = tl.commands
|
||||
|
||||
# Phase 2: Replay — dispatch commands to PE_SCHEDULER
|
||||
pe_exec_start = env.now
|
||||
scheduler_id = f"{self._pe_prefix}.pe_scheduler"
|
||||
pending: dict[str, simpy.Event] = {} # completion_id → done event
|
||||
composite_results: list[dict] = [] # collect result_data from CompositeCmd txns
|
||||
pending: dict[str, simpy.Event] = {}
|
||||
composite_results: list[dict] = []
|
||||
|
||||
for cmd in commands:
|
||||
if isinstance(cmd, PeCpuOverheadCmd):
|
||||
@@ -126,47 +171,30 @@ class PeCpuComponent(ComponentBase):
|
||||
if evt:
|
||||
yield evt
|
||||
else:
|
||||
# Wait all pending completions
|
||||
for evt in pending.values():
|
||||
yield evt
|
||||
pending.clear()
|
||||
elif isinstance(cmd, CompositeCmd):
|
||||
# Non-blocking: dispatch to scheduler, track completion
|
||||
done_evt = env.event()
|
||||
pe_txn = PeInternalTxn(
|
||||
command=cmd, done=done_evt,
|
||||
pe_prefix=self._pe_prefix,
|
||||
command=cmd, done=done_evt, pe_prefix=self._pe_prefix,
|
||||
)
|
||||
composite_results.append(pe_txn.result_data)
|
||||
yield self.out_ports[scheduler_id].put(pe_txn)
|
||||
pending[cmd.completion.id] = done_evt
|
||||
else:
|
||||
# Blocking: dispatch and wait for completion
|
||||
done_evt = env.event()
|
||||
pe_txn = PeInternalTxn(
|
||||
command=cmd, done=done_evt,
|
||||
pe_prefix=self._pe_prefix,
|
||||
command=cmd, done=done_evt, pe_prefix=self._pe_prefix,
|
||||
)
|
||||
yield self.out_ports[scheduler_id].put(pe_txn)
|
||||
yield done_evt
|
||||
|
||||
# Wait for any remaining pending completions
|
||||
for evt in pending.values():
|
||||
yield evt
|
||||
return composite_results
|
||||
|
||||
# Record PE-internal execution time
|
||||
txn.result_data["pe_exec_ns"] = env.now - pe_exec_start
|
||||
|
||||
# Aggregate dma_ns / compute_ns from CompositeCmd results
|
||||
total_dma_ns = 0.0
|
||||
total_compute_ns = 0.0
|
||||
for rd in composite_results:
|
||||
total_dma_ns += rd.get("dma_ns", 0.0)
|
||||
total_compute_ns += rd.get("compute_ns", 0.0)
|
||||
txn.result_data["dma_ns"] = total_dma_ns
|
||||
txn.result_data["compute_ns"] = total_compute_ns
|
||||
|
||||
# Send ResponseMsg on reverse path (PE_CPU → NOC → M_CPU)
|
||||
def _send_response(self, env, txn, request) -> Generator:
|
||||
reverse_path = list(reversed(txn.path))
|
||||
if len(reverse_path) >= 2:
|
||||
from kernbench.runtime_api.kernel import ResponseMsg
|
||||
|
||||
@@ -24,6 +24,8 @@ class ComponentContext:
|
||||
ns_per_mm: float # wire propagation constant (from topology spec)
|
||||
edge_map: dict[tuple[str, str], Any] = field(default_factory=dict)
|
||||
spec: dict = field(default_factory=dict) # topology spec (cube layout, PE count, etc.)
|
||||
memory_store: Any = None # MemoryStore for Phase 1 data-aware execution (ADR-0020)
|
||||
op_logger: Any = None # OpLogger for Phase 1 op recording (ADR-0020)
|
||||
|
||||
def get_shared_resource(
|
||||
self, env: simpy.Environment, key: str, capacity: int = 1,
|
||||
|
||||
Reference in New Issue
Block a user