"""DataExecutor: Phase 2 op_log-based data execution (ADR-0020 D6). Executes GEMM/Math operations from the op_log using numpy. Memory ops are skipped (already handled in Phase 1 via MemoryStore). Same-timestamp independent ops can be batched for efficiency. """ from __future__ import annotations from itertools import groupby from typing import Any import numpy as np from kernbench.sim_engine.memory_store import MemoryStore, _resolve_dtype from kernbench.sim_engine.op_log import OpRecord class DataExecutor: """Phase 2 executor: replay op_log with actual numpy computation. Args: op_log: list of OpRecords from Phase 1. store: MemoryStore snapshot from Phase 1 (contains tensor data). """ def __init__(self, op_log: list[OpRecord], store: MemoryStore) -> None: self._op_log = op_log self.store = store def run(self) -> None: """Execute all ops in op_log order, grouped by t_start.""" for _t, ops_iter in groupby(self._op_log, key=lambda r: r.t_start): ops = list(ops_iter) for op in ops: self._execute_op(op) def _execute_op(self, op: OpRecord) -> None: if op.op_kind == "memory": self._execute_memory(op) elif op.op_kind == "gemm": self._execute_gemm(op) elif op.op_kind == "math": self._execute_math(op) def _execute_memory(self, op: OpRecord) -> None: """Memory ops are already handled by Phase 1 MemoryStore. Skip.""" def _execute_gemm(self, op: OpRecord) -> None: """Execute GEMM: out = a @ b.""" p = op.params if "src_a_addr" not in p: return # composite record without full params space = p.get("addr_space", "tcm") dtype_in = p.get("dtype_in", "f16") dtype_out = p.get("dtype_out", dtype_in) a = self.store.read(space, p["src_a_addr"], shape=p.get("shape_a"), dtype=dtype_in) b = self.store.read(space, p["src_b_addr"], shape=p.get("shape_b"), dtype=dtype_in) # Compute in higher precision if specified dtype_acc = p.get("dtype_acc", "f32") a_f = a.astype(_resolve_dtype(dtype_acc)) b_f = b.astype(_resolve_dtype(dtype_acc)) result = np.matmul(a_f, b_f).astype(_resolve_dtype(dtype_out)) self.store.write(space, p["dst_addr"], result) def _execute_math(self, op: OpRecord) -> None: """Execute math op: unary, binary, or reduction.""" p = op.params math_op = p.get("op", op.op_name) space = p.get("addr_space", "tcm") dtype = p.get("dtype", "f32") input_addrs = p.get("input_addrs", []) input_shapes = p.get("input_shapes", []) inputs = [] for addr, shape in zip(input_addrs, input_shapes): inputs.append(self.store.read(space, addr, shape=shape, dtype=dtype)) result = _compute_math(math_op, inputs, p.get("axis")) if result is not None: self.store.write(space, p["dst_addr"], result) def verify(self, expected: dict[tuple[str, int], np.ndarray], rtol: float = 1e-3, atol: float = 1e-3) -> dict[str, bool]: """Compare MemoryStore contents against expected tensors. Args: expected: {(space, addr): expected_ndarray} rtol, atol: tolerance for floating-point comparison. Returns: {key_str: passed} dict. """ results = {} for (space, addr), exp in expected.items(): key = f"{space}:0x{addr:x}" try: actual = self.store.read(space, addr) if np.issubdtype(actual.dtype, np.integer): results[key] = bool(np.array_equal(actual, exp)) else: results[key] = bool(np.allclose(actual, exp, rtol=rtol, atol=atol)) except KeyError: results[key] = False return results def _compute_math(op: str, inputs: list[np.ndarray], axis: int | None) -> np.ndarray | None: """Execute a math operation on numpy arrays.""" if not inputs: return None x = inputs[0] # Unary if op == "exp": return np.exp(x) if op == "log": return np.log(x) if op == "sqrt": return np.sqrt(x) if op == "abs": return np.abs(x) if op == "sigmoid": return 1.0 / (1.0 + np.exp(-x)) if op == "cos": return np.cos(x) if op == "sin": return np.sin(x) # Reduction if op == "sum": return np.sum(x, axis=axis, keepdims=True) if op == "max": return np.max(x, axis=axis, keepdims=True) if op == "min": return np.min(x, axis=axis, keepdims=True) # Binary if len(inputs) >= 2: y = inputs[1] if op == "add": return x + y if op == "sub": return x - y if op == "mul": return x * y if op == "div": return x / y # Ternary if op == "where" and len(inputs) >= 3: return np.where(inputs[0], inputs[1], inputs[2]) return None