Implement ADR-0020: 2-pass data execution with greenlet kernel runner
Step 1 — Foundation: - OpRecord/OpLogger: op log infrastructure with t_start stable ordering - MemoryStore: numpy ndarray tensor-granular storage (reference semantics) - data_op=True flag on DmaReadCmd, DmaWriteCmd, GemmCmd, MathCmd, CompositeCmd - numpy/greenlet dependencies added to pyproject.toml Step 2 — ComponentBase hooks: - _on_process_start/end hooks in _forward_txn (fabric messages) - _handle_with_hooks in PeEngineBase (PE-internal commands) - op_logger optional — zero overhead when disabled Step 3 — KernelRunner + greenlet: - KernelRunner: greenlet ↔ SimPy bridge in triton_emu/kernel_runner.py - TLContext: _emit() method routes to greenlet switch or command list - tl.load() returns real numpy data in greenlet mode - Dynamic control flow supported (memory-read based branching) Step 4 — PE_CPU integration: - Greenlet mode when ctx.memory_store is set, legacy fallback otherwise - Refactored into _execute_greenlet/_execute_legacy/_send_response - ComponentContext gains memory_store and op_logger fields Step 5 — DataExecutor: - Phase 2 numpy execution for GEMM/Math ops from op_log - _compute_math: all unary/binary/reduction ops - verify(): compare MemoryStore against expected with dtype tolerance 28 new tests, 366 total passing. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,157 @@
|
||||
"""DataExecutor: Phase 2 op_log-based data execution (ADR-0020 D6).
|
||||
|
||||
Executes GEMM/Math operations from the op_log using numpy.
|
||||
Memory ops are skipped (already handled in Phase 1 via MemoryStore).
|
||||
Same-timestamp independent ops can be batched for efficiency.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from itertools import groupby
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from kernbench.sim_engine.memory_store import MemoryStore, _resolve_dtype
|
||||
from kernbench.sim_engine.op_log import OpRecord
|
||||
|
||||
|
||||
class DataExecutor:
|
||||
"""Phase 2 executor: replay op_log with actual numpy computation.
|
||||
|
||||
Args:
|
||||
op_log: list of OpRecords from Phase 1.
|
||||
store: MemoryStore snapshot from Phase 1 (contains tensor data).
|
||||
"""
|
||||
|
||||
def __init__(self, op_log: list[OpRecord], store: MemoryStore) -> None:
|
||||
self._op_log = op_log
|
||||
self.store = store
|
||||
|
||||
def run(self) -> None:
|
||||
"""Execute all ops in op_log order, grouped by t_start."""
|
||||
for _t, ops_iter in groupby(self._op_log, key=lambda r: r.t_start):
|
||||
ops = list(ops_iter)
|
||||
for op in ops:
|
||||
self._execute_op(op)
|
||||
|
||||
def _execute_op(self, op: OpRecord) -> None:
|
||||
if op.op_kind == "memory":
|
||||
self._execute_memory(op)
|
||||
elif op.op_kind == "gemm":
|
||||
self._execute_gemm(op)
|
||||
elif op.op_kind == "math":
|
||||
self._execute_math(op)
|
||||
|
||||
def _execute_memory(self, op: OpRecord) -> None:
|
||||
"""Memory ops are already handled by Phase 1 MemoryStore. Skip."""
|
||||
|
||||
def _execute_gemm(self, op: OpRecord) -> None:
|
||||
"""Execute GEMM: out = a @ b."""
|
||||
p = op.params
|
||||
if "src_a_addr" not in p:
|
||||
return # composite record without full params
|
||||
space = p.get("addr_space", "tcm")
|
||||
dtype_in = p.get("dtype_in", "f16")
|
||||
dtype_out = p.get("dtype_out", dtype_in)
|
||||
|
||||
a = self.store.read(space, p["src_a_addr"], shape=p.get("shape_a"), dtype=dtype_in)
|
||||
b = self.store.read(space, p["src_b_addr"], shape=p.get("shape_b"), dtype=dtype_in)
|
||||
|
||||
# Compute in higher precision if specified
|
||||
dtype_acc = p.get("dtype_acc", "f32")
|
||||
a_f = a.astype(_resolve_dtype(dtype_acc))
|
||||
b_f = b.astype(_resolve_dtype(dtype_acc))
|
||||
result = np.matmul(a_f, b_f).astype(_resolve_dtype(dtype_out))
|
||||
|
||||
self.store.write(space, p["dst_addr"], result)
|
||||
|
||||
def _execute_math(self, op: OpRecord) -> None:
|
||||
"""Execute math op: unary, binary, or reduction."""
|
||||
p = op.params
|
||||
math_op = p.get("op", op.op_name)
|
||||
space = p.get("addr_space", "tcm")
|
||||
dtype = p.get("dtype", "f32")
|
||||
input_addrs = p.get("input_addrs", [])
|
||||
input_shapes = p.get("input_shapes", [])
|
||||
|
||||
inputs = []
|
||||
for addr, shape in zip(input_addrs, input_shapes):
|
||||
inputs.append(self.store.read(space, addr, shape=shape, dtype=dtype))
|
||||
|
||||
result = _compute_math(math_op, inputs, p.get("axis"))
|
||||
if result is not None:
|
||||
self.store.write(space, p["dst_addr"], result)
|
||||
|
||||
def verify(self, expected: dict[tuple[str, int], np.ndarray],
|
||||
rtol: float = 1e-3, atol: float = 1e-3) -> dict[str, bool]:
|
||||
"""Compare MemoryStore contents against expected tensors.
|
||||
|
||||
Args:
|
||||
expected: {(space, addr): expected_ndarray}
|
||||
rtol, atol: tolerance for floating-point comparison.
|
||||
|
||||
Returns:
|
||||
{key_str: passed} dict.
|
||||
"""
|
||||
results = {}
|
||||
for (space, addr), exp in expected.items():
|
||||
key = f"{space}:0x{addr:x}"
|
||||
try:
|
||||
actual = self.store.read(space, addr)
|
||||
if np.issubdtype(actual.dtype, np.integer):
|
||||
results[key] = bool(np.array_equal(actual, exp))
|
||||
else:
|
||||
results[key] = bool(np.allclose(actual, exp, rtol=rtol, atol=atol))
|
||||
except KeyError:
|
||||
results[key] = False
|
||||
return results
|
||||
|
||||
|
||||
def _compute_math(op: str, inputs: list[np.ndarray], axis: int | None) -> np.ndarray | None:
|
||||
"""Execute a math operation on numpy arrays."""
|
||||
if not inputs:
|
||||
return None
|
||||
|
||||
x = inputs[0]
|
||||
|
||||
# Unary
|
||||
if op == "exp":
|
||||
return np.exp(x)
|
||||
if op == "log":
|
||||
return np.log(x)
|
||||
if op == "sqrt":
|
||||
return np.sqrt(x)
|
||||
if op == "abs":
|
||||
return np.abs(x)
|
||||
if op == "sigmoid":
|
||||
return 1.0 / (1.0 + np.exp(-x))
|
||||
if op == "cos":
|
||||
return np.cos(x)
|
||||
if op == "sin":
|
||||
return np.sin(x)
|
||||
|
||||
# Reduction
|
||||
if op == "sum":
|
||||
return np.sum(x, axis=axis, keepdims=True)
|
||||
if op == "max":
|
||||
return np.max(x, axis=axis, keepdims=True)
|
||||
if op == "min":
|
||||
return np.min(x, axis=axis, keepdims=True)
|
||||
|
||||
# Binary
|
||||
if len(inputs) >= 2:
|
||||
y = inputs[1]
|
||||
if op == "add":
|
||||
return x + y
|
||||
if op == "sub":
|
||||
return x - y
|
||||
if op == "mul":
|
||||
return x * y
|
||||
if op == "div":
|
||||
return x / y
|
||||
|
||||
# Ternary
|
||||
if op == "where" and len(inputs) >= 3:
|
||||
return np.where(inputs[0], inputs[1], inputs[2])
|
||||
|
||||
return None
|
||||
Reference in New Issue
Block a user