Implement ADR-0020: 2-pass data execution with greenlet kernel runner

Step 1 — Foundation:
- OpRecord/OpLogger: op log infrastructure with t_start stable ordering
- MemoryStore: numpy ndarray tensor-granular storage (reference semantics)
- data_op=True flag on DmaReadCmd, DmaWriteCmd, GemmCmd, MathCmd, CompositeCmd
- numpy/greenlet dependencies added to pyproject.toml

Step 2 — ComponentBase hooks:
- _on_process_start/end hooks in _forward_txn (fabric messages)
- _handle_with_hooks in PeEngineBase (PE-internal commands)
- op_logger optional — zero overhead when disabled

Step 3 — KernelRunner + greenlet:
- KernelRunner: greenlet ↔ SimPy bridge in triton_emu/kernel_runner.py
- TLContext: _emit() method routes to greenlet switch or command list
- tl.load() returns real numpy data in greenlet mode
- Dynamic control flow supported (memory-read based branching)

Step 4 — PE_CPU integration:
- Greenlet mode when ctx.memory_store is set, legacy fallback otherwise
- Refactored into _execute_greenlet/_execute_legacy/_send_response
- ComponentContext gains memory_store and op_logger fields

Step 5 — DataExecutor:
- Phase 2 numpy execution for GEMM/Math ops from op_log
- _compute_math: all unary/binary/reduction ops
- verify(): compare MemoryStore against expected with dtype tolerance

28 new tests, 366 total passing.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-04-08 00:22:44 -07:00
parent 140b85436a
commit 51004c311c
14 changed files with 1181 additions and 59 deletions
+157
View File
@@ -0,0 +1,157 @@
"""DataExecutor: Phase 2 op_log-based data execution (ADR-0020 D6).
Executes GEMM/Math operations from the op_log using numpy.
Memory ops are skipped (already handled in Phase 1 via MemoryStore).
Same-timestamp independent ops can be batched for efficiency.
"""
from __future__ import annotations
from itertools import groupby
from typing import Any
import numpy as np
from kernbench.sim_engine.memory_store import MemoryStore, _resolve_dtype
from kernbench.sim_engine.op_log import OpRecord
class DataExecutor:
"""Phase 2 executor: replay op_log with actual numpy computation.
Args:
op_log: list of OpRecords from Phase 1.
store: MemoryStore snapshot from Phase 1 (contains tensor data).
"""
def __init__(self, op_log: list[OpRecord], store: MemoryStore) -> None:
self._op_log = op_log
self.store = store
def run(self) -> None:
"""Execute all ops in op_log order, grouped by t_start."""
for _t, ops_iter in groupby(self._op_log, key=lambda r: r.t_start):
ops = list(ops_iter)
for op in ops:
self._execute_op(op)
def _execute_op(self, op: OpRecord) -> None:
if op.op_kind == "memory":
self._execute_memory(op)
elif op.op_kind == "gemm":
self._execute_gemm(op)
elif op.op_kind == "math":
self._execute_math(op)
def _execute_memory(self, op: OpRecord) -> None:
"""Memory ops are already handled by Phase 1 MemoryStore. Skip."""
def _execute_gemm(self, op: OpRecord) -> None:
"""Execute GEMM: out = a @ b."""
p = op.params
if "src_a_addr" not in p:
return # composite record without full params
space = p.get("addr_space", "tcm")
dtype_in = p.get("dtype_in", "f16")
dtype_out = p.get("dtype_out", dtype_in)
a = self.store.read(space, p["src_a_addr"], shape=p.get("shape_a"), dtype=dtype_in)
b = self.store.read(space, p["src_b_addr"], shape=p.get("shape_b"), dtype=dtype_in)
# Compute in higher precision if specified
dtype_acc = p.get("dtype_acc", "f32")
a_f = a.astype(_resolve_dtype(dtype_acc))
b_f = b.astype(_resolve_dtype(dtype_acc))
result = np.matmul(a_f, b_f).astype(_resolve_dtype(dtype_out))
self.store.write(space, p["dst_addr"], result)
def _execute_math(self, op: OpRecord) -> None:
"""Execute math op: unary, binary, or reduction."""
p = op.params
math_op = p.get("op", op.op_name)
space = p.get("addr_space", "tcm")
dtype = p.get("dtype", "f32")
input_addrs = p.get("input_addrs", [])
input_shapes = p.get("input_shapes", [])
inputs = []
for addr, shape in zip(input_addrs, input_shapes):
inputs.append(self.store.read(space, addr, shape=shape, dtype=dtype))
result = _compute_math(math_op, inputs, p.get("axis"))
if result is not None:
self.store.write(space, p["dst_addr"], result)
def verify(self, expected: dict[tuple[str, int], np.ndarray],
rtol: float = 1e-3, atol: float = 1e-3) -> dict[str, bool]:
"""Compare MemoryStore contents against expected tensors.
Args:
expected: {(space, addr): expected_ndarray}
rtol, atol: tolerance for floating-point comparison.
Returns:
{key_str: passed} dict.
"""
results = {}
for (space, addr), exp in expected.items():
key = f"{space}:0x{addr:x}"
try:
actual = self.store.read(space, addr)
if np.issubdtype(actual.dtype, np.integer):
results[key] = bool(np.array_equal(actual, exp))
else:
results[key] = bool(np.allclose(actual, exp, rtol=rtol, atol=atol))
except KeyError:
results[key] = False
return results
def _compute_math(op: str, inputs: list[np.ndarray], axis: int | None) -> np.ndarray | None:
"""Execute a math operation on numpy arrays."""
if not inputs:
return None
x = inputs[0]
# Unary
if op == "exp":
return np.exp(x)
if op == "log":
return np.log(x)
if op == "sqrt":
return np.sqrt(x)
if op == "abs":
return np.abs(x)
if op == "sigmoid":
return 1.0 / (1.0 + np.exp(-x))
if op == "cos":
return np.cos(x)
if op == "sin":
return np.sin(x)
# Reduction
if op == "sum":
return np.sum(x, axis=axis, keepdims=True)
if op == "max":
return np.max(x, axis=axis, keepdims=True)
if op == "min":
return np.min(x, axis=axis, keepdims=True)
# Binary
if len(inputs) >= 2:
y = inputs[1]
if op == "add":
return x + y
if op == "sub":
return x - y
if op == "mul":
return x * y
if op == "div":
return x / y
# Ternary
if op == "where" and len(inputs) >= 3:
return np.where(inputs[0], inputs[1], inputs[2])
return None
+84
View File
@@ -0,0 +1,84 @@
"""MemoryStore: tensor-granular storage for Phase 1 and Phase 2 (ADR-0020 D7).
Logically byte-addressable, implemented as addr → numpy ndarray mapping.
Read/write are reference-based (no copy) for Phase 1 performance.
"""
from __future__ import annotations
import numpy as np
# numpy dtype string → numpy dtype mapping
_DTYPE_MAP = {
"f16": np.float16,
"f32": np.float32,
"f64": np.float64,
"bf16": np.float16, # numpy has no bfloat16; use float16 as proxy
"i8": np.int8,
"i16": np.int16,
"i32": np.int32,
"i64": np.int64,
"u8": np.uint8,
"u16": np.uint16,
"u32": np.uint32,
}
def _resolve_dtype(dtype: str) -> np.dtype:
if dtype in _DTYPE_MAP:
return np.dtype(_DTYPE_MAP[dtype])
return np.dtype(dtype)
class MemoryStore:
"""Tensor-granular memory storage (ADR-0020 D7).
Stores numpy ndarrays by (space, addr) key.
Write = reference store (no copy), read = reference return (no copy).
Overwrite at same addr replaces the entire tensor.
"""
def __init__(self) -> None:
# {space: {addr: ndarray}}
self._storage: dict[str, dict[int, np.ndarray]] = {}
def write(self, space: str, addr: int, data: np.ndarray) -> None:
"""Store tensor at (space, addr). Reference-only, no copy."""
if space not in self._storage:
self._storage[space] = {}
self._storage[space][addr] = data
def read(self, space: str, addr: int, shape: tuple[int, ...] | None = None,
dtype: str | None = None) -> np.ndarray:
"""Read tensor from (space, addr). Returns reference, no copy.
If shape/dtype match stored tensor, returns as-is.
If dtype differs, performs reinterpret cast (view).
If shape differs but nbytes match, reshapes.
"""
store = self._storage.get(space)
if store is None or addr not in store:
raise KeyError(f"No data at ({space}, 0x{addr:x})")
arr = store[addr]
if dtype is not None:
np_dtype = _resolve_dtype(dtype)
if arr.dtype != np_dtype:
arr = arr.view(np_dtype)
if shape is not None and arr.shape != shape:
if arr.nbytes != np.prod(shape) * arr.dtype.itemsize:
raise ValueError(
f"Shape mismatch: stored {arr.shape} ({arr.nbytes}B) "
f"vs requested {shape} ({np.prod(shape) * arr.dtype.itemsize}B)"
)
arr = arr.reshape(shape)
return arr
def has(self, space: str, addr: int) -> bool:
return addr in self._storage.get(space, {})
def snapshot(self) -> MemoryStore:
"""Create a shallow copy for Phase 2 initialization."""
new = MemoryStore()
for space, addrs in self._storage.items():
new._storage[space] = dict(addrs) # shallow copy of addr→ndarray map
return new
+111
View File
@@ -0,0 +1,111 @@
"""Op log infrastructure for 2-pass data execution (ADR-0020 D2, D5).
OpRecord: single data operation with timing, params, and dependencies.
OpLogger: collects OpRecords from ComponentBase hooks during Phase 1.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
@dataclass
class OpRecord:
"""Single data operation record (ADR-0020 D5)."""
t_start: float
t_end: float
component_id: str
op_kind: str # "memory" | "gemm" | "math"
op_name: str # e.g. "dma_read", "gemm_f16", "exp"
params: dict[str, Any]
dependency_ids: list[int] = field(default_factory=list)
class OpLogger:
"""Collects OpRecords during Phase 1 simulation (ADR-0020 D2).
Thread-safe is not required — SimPy is single-threaded.
Records are maintained in t_start stable ordering (insertion order).
"""
def __init__(self) -> None:
self._records: list[OpRecord] = []
self._pending: dict[int, dict[str, Any]] = {} # msg id → partial record
@property
def records(self) -> list[OpRecord]:
"""Records sorted by t_start (stable ordering per ADR-0020 D5)."""
self._records.sort(key=lambda r: r.t_start)
return self._records
def record_start(self, t: float, component_id: str, msg: Any) -> None:
"""Called by ComponentBase._on_process_start."""
self._pending[id(msg)] = {
"t_start": t,
"component_id": component_id,
"msg": msg,
}
def record_end(self, t: float, component_id: str, msg: Any) -> None:
"""Called by ComponentBase._on_process_end."""
pending = self._pending.pop(id(msg), None)
if pending is None:
return
op_kind, op_name, params = _extract_op_info(msg)
self._records.append(OpRecord(
t_start=pending["t_start"],
t_end=t,
component_id=pending["component_id"],
op_kind=op_kind,
op_name=op_name,
params=params,
))
def _extract_op_info(msg: Any) -> tuple[str, str, dict[str, Any]]:
"""Extract op_kind, op_name, params from a data_op message."""
from kernbench.common.pe_commands import (
DmaReadCmd, DmaWriteCmd, GemmCmd, MathCmd, CompositeCmd,
)
if isinstance(msg, DmaReadCmd):
return "memory", "dma_read", {
"src_addr": msg.src_addr,
"nbytes": msg.nbytes,
"handle_id": msg.handle.id,
}
if isinstance(msg, DmaWriteCmd):
return "memory", "dma_write", {
"dst_addr": msg.dst_addr,
"nbytes": msg.nbytes,
"handle_id": msg.handle.id,
}
if isinstance(msg, GemmCmd):
return "gemm", f"gemm_{msg.a.dtype}", {
"src_a_addr": msg.a.addr,
"src_b_addr": msg.b.addr,
"dst_addr": msg.out.addr,
"shape_a": msg.a.shape,
"shape_b": msg.b.shape,
"shape_out": msg.out.shape,
"dtype_in": msg.a.dtype,
"dtype_out": msg.out.dtype,
"m": msg.m, "k": msg.k, "n": msg.n,
}
if isinstance(msg, MathCmd):
return "math", msg.op, {
"input_addrs": [h.addr for h in msg.inputs],
"input_shapes": [h.shape for h in msg.inputs],
"dst_addr": msg.out.addr,
"shape_out": msg.out.shape,
"dtype": msg.out.dtype,
"axis": msg.axis,
}
if isinstance(msg, CompositeCmd):
return "gemm" if msg.op == "gemm" else "math", f"composite_{msg.op}", {
"op": msg.op,
"out_addr": msg.out_addr,
"out_nbytes": msg.out_nbytes,
}
# Fallback for unknown data_op messages
return "unknown", type(msg).__name__, {}