Implement ADR-0020: 2-pass data execution with greenlet kernel runner

Step 1 — Foundation:
- OpRecord/OpLogger: op log infrastructure with t_start stable ordering
- MemoryStore: numpy ndarray tensor-granular storage (reference semantics)
- data_op=True flag on DmaReadCmd, DmaWriteCmd, GemmCmd, MathCmd, CompositeCmd
- numpy/greenlet dependencies added to pyproject.toml

Step 2 — ComponentBase hooks:
- _on_process_start/end hooks in _forward_txn (fabric messages)
- _handle_with_hooks in PeEngineBase (PE-internal commands)
- op_logger optional — zero overhead when disabled

Step 3 — KernelRunner + greenlet:
- KernelRunner: greenlet ↔ SimPy bridge in triton_emu/kernel_runner.py
- TLContext: _emit() method routes to greenlet switch or command list
- tl.load() returns real numpy data in greenlet mode
- Dynamic control flow supported (memory-read based branching)

Step 4 — PE_CPU integration:
- Greenlet mode when ctx.memory_store is set, legacy fallback otherwise
- Refactored into _execute_greenlet/_execute_legacy/_send_response
- ComponentContext gains memory_store and op_logger fields

Step 5 — DataExecutor:
- Phase 2 numpy execution for GEMM/Math ops from op_log
- _compute_math: all unary/binary/reduction ops
- verify(): compare MemoryStore against expected with dtype tolerance

28 new tests, 366 total passing.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-04-08 00:22:44 -07:00
parent 140b85436a
commit 51004c311c
14 changed files with 1181 additions and 59 deletions
+188
View File
@@ -0,0 +1,188 @@
"""Tests for DataExecutor Phase 2 execution (ADR-0020 D6)."""
import numpy as np
from kernbench.sim_engine.data_executor import DataExecutor
from kernbench.sim_engine.memory_store import MemoryStore
from kernbench.sim_engine.op_log import OpRecord
def test_gemm_execution():
"""Phase 2 GEMM: out = a @ b with f32 accumulation."""
store = MemoryStore()
a = np.ones((4, 8), dtype=np.float16)
b = np.ones((8, 4), dtype=np.float16) * 2.0
store.write("tcm", 0x0, a)
store.write("tcm", 0x100, b)
op = OpRecord(
t_start=0.0, t_end=100.0,
component_id="pe_gemm",
op_kind="gemm", op_name="gemm_f16",
params={
"src_a_addr": 0x0, "src_b_addr": 0x100, "dst_addr": 0x200,
"shape_a": (4, 8), "shape_b": (8, 4), "shape_out": (4, 4),
"dtype_in": "f16", "dtype_acc": "f32", "dtype_out": "f16",
"addr_space": "tcm",
},
)
executor = DataExecutor([op], store)
executor.run()
result = store.read("tcm", 0x200)
expected = (a.astype(np.float32) @ b.astype(np.float32)).astype(np.float16)
assert np.allclose(result, expected)
def test_math_exp():
store = MemoryStore()
x = np.array([0.0, 1.0, 2.0], dtype=np.float32)
store.write("tcm", 0x0, x)
op = OpRecord(
t_start=0.0, t_end=10.0,
component_id="pe_math",
op_kind="math", op_name="exp",
params={
"op": "exp",
"input_addrs": [0x0], "input_shapes": [(3,)],
"dst_addr": 0x100, "shape_out": (3,),
"dtype": "f32", "axis": None, "addr_space": "tcm",
},
)
executor = DataExecutor([op], store)
executor.run()
result = store.read("tcm", 0x100)
assert np.allclose(result, np.exp(x))
def test_math_add():
store = MemoryStore()
a = np.array([1.0, 2.0], dtype=np.float32)
b = np.array([3.0, 4.0], dtype=np.float32)
store.write("tcm", 0x0, a)
store.write("tcm", 0x100, b)
op = OpRecord(
t_start=0.0, t_end=5.0,
component_id="pe_math",
op_kind="math", op_name="add",
params={
"op": "add",
"input_addrs": [0x0, 0x100], "input_shapes": [(2,), (2,)],
"dst_addr": 0x200, "shape_out": (2,),
"dtype": "f32", "axis": None, "addr_space": "tcm",
},
)
executor = DataExecutor([op], store)
executor.run()
result = store.read("tcm", 0x200)
assert np.array_equal(result, np.array([4.0, 6.0], dtype=np.float32))
def test_math_sum_reduction():
store = MemoryStore()
x = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
store.write("tcm", 0x0, x)
op = OpRecord(
t_start=0.0, t_end=5.0,
component_id="pe_math",
op_kind="math", op_name="sum",
params={
"op": "sum",
"input_addrs": [0x0], "input_shapes": [(2, 2)],
"dst_addr": 0x100, "shape_out": (1, 2),
"dtype": "f32", "axis": 0, "addr_space": "tcm",
},
)
executor = DataExecutor([op], store)
executor.run()
result = store.read("tcm", 0x100)
assert np.array_equal(result, np.array([[4.0, 6.0]], dtype=np.float32))
def test_verify_pass():
store = MemoryStore()
store.write("hbm", 0x0, np.array([1.0, 2.0], dtype=np.float32))
executor = DataExecutor([], store)
results = executor.verify({
("hbm", 0x0): np.array([1.0, 2.0], dtype=np.float32),
})
assert results["hbm:0x0"] is True
def test_verify_fail():
store = MemoryStore()
store.write("hbm", 0x0, np.array([1.0, 2.0], dtype=np.float32))
executor = DataExecutor([], store)
results = executor.verify({
("hbm", 0x0): np.array([9.0, 9.0], dtype=np.float32),
})
assert results["hbm:0x0"] is False
def test_memory_ops_skipped():
"""Memory ops in op_log should be skipped (handled in Phase 1)."""
store = MemoryStore()
op = OpRecord(
t_start=0.0, t_end=5.0,
component_id="pe_dma",
op_kind="memory", op_name="dma_read",
params={"src_addr": 0x0, "nbytes": 64, "handle_id": "t0"},
)
# Should not raise
executor = DataExecutor([op], store)
executor.run()
def test_sequential_gemm_then_math():
"""GEMM output feeds into math op."""
store = MemoryStore()
a = np.eye(2, dtype=np.float16)
b = np.ones((2, 2), dtype=np.float16)
store.write("tcm", 0x0, a)
store.write("tcm", 0x100, b)
ops = [
OpRecord(
t_start=0.0, t_end=50.0,
component_id="pe_gemm",
op_kind="gemm", op_name="gemm_f16",
params={
"src_a_addr": 0x0, "src_b_addr": 0x100, "dst_addr": 0x200,
"shape_a": (2, 2), "shape_b": (2, 2), "shape_out": (2, 2),
"dtype_in": "f16", "dtype_acc": "f32", "dtype_out": "f32",
"addr_space": "tcm",
},
),
OpRecord(
t_start=50.0, t_end=55.0,
component_id="pe_math",
op_kind="math", op_name="exp",
params={
"op": "exp",
"input_addrs": [0x200], "input_shapes": [(2, 2)],
"dst_addr": 0x300, "shape_out": (2, 2),
"dtype": "f32", "axis": None, "addr_space": "tcm",
},
),
]
executor = DataExecutor(ops, store)
executor.run()
gemm_result = store.read("tcm", 0x200)
expected_gemm = (a.astype(np.float32) @ b.astype(np.float32)).astype(np.float32)
assert np.allclose(gemm_result, expected_gemm)
exp_result = store.read("tcm", 0x300)
assert np.allclose(exp_result, np.exp(expected_gemm))
+140
View File
@@ -0,0 +1,140 @@
"""Tests for KernelRunner greenlet-based execution (ADR-0020 D3)."""
import numpy as np
import simpy
from kernbench.sim_engine.memory_store import MemoryStore
from kernbench.triton_emu.kernel_runner import KernelRunner
def _make_runner(env, store=None):
"""Create a minimal KernelRunner with mock scheduler port."""
scheduler_id = "sip0.cube0.pe0.pe_scheduler"
out_ports = {scheduler_id: simpy.Store(env)}
runner = KernelRunner(
pe_prefix="sip0.cube0.pe0",
pe_idx=0, sip_idx=0, cube_idx=0,
scheduler_id=scheduler_id,
out_ports=out_ports,
store=store,
)
return runner, out_ports[scheduler_id]
def _mock_scheduler(env, inbox):
"""Consume PeInternalTxn from inbox and immediately succeed."""
while True:
pe_txn = yield inbox.get()
pe_txn.done.succeed()
def test_kernel_runner_basic_load():
"""Kernel with tl.load runs through greenlet without hanging."""
env = simpy.Environment()
store = MemoryStore()
data = np.ones((4, 4), dtype=np.float16)
store.write("hbm", 0x1000, data)
runner, sched_port = _make_runner(env, store)
env.process(_mock_scheduler(env, sched_port))
def kernel(a_ptr, tl):
a = tl.load(a_ptr, (4, 4), "f16")
assert a.data is not None
assert a.data.shape == (4, 4)
def run():
yield from runner.run(env, kernel, [0x1000], num_programs=1)
env.process(run())
env.run()
def test_kernel_runner_load_returns_data():
"""tl.load returns actual numpy data from MemoryStore."""
env = simpy.Environment()
store = MemoryStore()
data = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float16)
store.write("hbm", 0x2000, data)
runner, sched_port = _make_runner(env, store)
env.process(_mock_scheduler(env, sched_port))
results = {}
def kernel(ptr, tl):
a = tl.load(ptr, (2, 2), "f16")
results["data"] = a.data
def run():
yield from runner.run(env, kernel, [0x2000], num_programs=1)
env.process(run())
env.run()
assert results["data"] is data # reference equality
def test_kernel_runner_composite():
"""Composite commands pass through without blocking kernel."""
env = simpy.Environment()
runner, sched_port = _make_runner(env)
env.process(_mock_scheduler(env, sched_port))
def kernel(a_ptr, b_ptr, out_ptr, tl):
a = tl.ref(a_ptr, (4, 8), "f16")
b = tl.ref(b_ptr, (8, 4), "f16")
h = tl.composite(op="gemm", a=a, b=b, out_ptr=out_ptr)
tl.wait(h)
def run():
yield from runner.run(env, kernel, [0, 64, 128], num_programs=1)
env.process(run())
env.run()
def test_kernel_runner_dynamic_branch():
"""Kernel can branch based on loaded data (ADR-0020 D3)."""
env = simpy.Environment()
store = MemoryStore()
store.write("hbm", 0x100, np.array([1.0], dtype=np.float32))
store.write("hbm", 0x200, np.array([0.0], dtype=np.float32))
runner, sched_port = _make_runner(env, store)
env.process(_mock_scheduler(env, sched_port))
results = {"branch": None}
def kernel(flag_ptr, tl):
flag = tl.load(flag_ptr, (1,), "f32")
if flag.data is not None and flag.data[0] > 0.5:
results["branch"] = "taken"
else:
results["branch"] = "not_taken"
# Test with flag=1.0 → branch taken
def run():
yield from runner.run(env, kernel, [0x100], num_programs=1)
env.process(run())
env.run()
assert results["branch"] == "taken"
def test_kernel_runner_no_store():
"""Without MemoryStore, tl.load returns handle with data=None."""
env = simpy.Environment()
runner, sched_port = _make_runner(env, store=None)
env.process(_mock_scheduler(env, sched_port))
results = {}
def kernel(ptr, tl):
a = tl.load(ptr, (4,), "f16")
results["data"] = a.data
def run():
yield from runner.run(env, kernel, [0], num_programs=1)
env.process(run())
env.run()
assert results["data"] is None
+85
View File
@@ -0,0 +1,85 @@
"""Tests for MemoryStore (ADR-0020 D7)."""
import numpy as np
import pytest
from kernbench.sim_engine.memory_store import MemoryStore
def test_write_read_reference():
"""Write and read return the same numpy array (no copy)."""
store = MemoryStore()
data = np.ones((4, 4), dtype=np.float16)
store.write("tcm", 0x0, data)
result = store.read("tcm", 0x0)
assert result is data
def test_overwrite_replaces():
"""Same addr write replaces the previous tensor."""
store = MemoryStore()
data1 = np.zeros((4,), dtype=np.float32)
data2 = np.ones((4,), dtype=np.float32)
store.write("hbm", 0x100, data1)
store.write("hbm", 0x100, data2)
result = store.read("hbm", 0x100)
assert result is data2
def test_read_missing_raises():
store = MemoryStore()
with pytest.raises(KeyError):
store.read("hbm", 0x999)
def test_read_different_space():
store = MemoryStore()
data = np.array([1, 2, 3], dtype=np.int32)
store.write("tcm", 0x0, data)
with pytest.raises(KeyError):
store.read("hbm", 0x0) # different space
def test_dtype_reinterpret():
"""Read with different dtype does view cast."""
store = MemoryStore()
data = np.array([1.0, 2.0], dtype=np.float32) # 8 bytes
store.write("tcm", 0x0, data)
result = store.read("tcm", 0x0, dtype="u8")
assert result.dtype == np.uint8
assert result.nbytes == data.nbytes
def test_reshape():
store = MemoryStore()
data = np.arange(12, dtype=np.float32)
store.write("tcm", 0x0, data)
result = store.read("tcm", 0x0, shape=(3, 4))
assert result.shape == (3, 4)
def test_shape_mismatch_raises():
store = MemoryStore()
data = np.arange(12, dtype=np.float32)
store.write("tcm", 0x0, data)
with pytest.raises(ValueError, match="Shape mismatch"):
store.read("tcm", 0x0, shape=(5, 5))
def test_has():
store = MemoryStore()
assert not store.has("tcm", 0x0)
store.write("tcm", 0x0, np.array([1]))
assert store.has("tcm", 0x0)
def test_snapshot():
store = MemoryStore()
data = np.ones((4,), dtype=np.float16)
store.write("hbm", 0x0, data)
snap = store.snapshot()
assert snap.read("hbm", 0x0) is data # same reference
# Modifying snap doesn't affect original
snap.write("hbm", 0x0, np.zeros((4,), dtype=np.float16))
assert store.read("hbm", 0x0) is data
+87
View File
@@ -0,0 +1,87 @@
"""Tests for OpLogger and OpRecord (ADR-0020 D2/D5)."""
import numpy as np
from kernbench.common.pe_commands import (
DmaReadCmd, DmaWriteCmd, GemmCmd, MathCmd, TensorHandle,
)
from kernbench.sim_engine.op_log import OpLogger, OpRecord
def _th(name="t0", addr=0, shape=(4, 4), dtype="f16"):
return TensorHandle(id=name, addr=addr, shape=shape, dtype=dtype, nbytes=32)
def test_op_logger_record_start_end():
logger = OpLogger()
cmd = DmaReadCmd(handle=_th(), src_addr=0x1000, nbytes=64)
logger.record_start(10.0, "sip0.cube0.pe0.pe_dma", cmd)
logger.record_end(15.0, "sip0.cube0.pe0.pe_dma", cmd)
assert len(logger.records) == 1
r = logger.records[0]
assert r.t_start == 10.0
assert r.t_end == 15.0
assert r.op_kind == "memory"
assert r.op_name == "dma_read"
assert r.params["src_addr"] == 0x1000
def test_op_logger_gemm():
logger = OpLogger()
a = _th("a", 0, (128, 256), "f16")
b = _th("b", 1024, (256, 128), "f16")
out = _th("out", 2048, (128, 128), "f16")
cmd = GemmCmd(a=a, b=b, out=out, m=128, k=256, n=128)
logger.record_start(0.0, "pe_gemm", cmd)
logger.record_end(100.0, "pe_gemm", cmd)
r = logger.records[0]
assert r.op_kind == "gemm"
assert r.op_name == "gemm_f16"
assert r.params["m"] == 128
def test_op_logger_math():
logger = OpLogger()
x = _th("x", 0, (32,), "f32")
out = _th("out", 128, (32,), "f32")
cmd = MathCmd(op="exp", inputs=(x,), out=out)
logger.record_start(5.0, "pe_math", cmd)
logger.record_end(6.0, "pe_math", cmd)
r = logger.records[0]
assert r.op_kind == "math"
assert r.op_name == "exp"
def test_op_logger_stable_ordering():
logger = OpLogger()
cmds = [
DmaReadCmd(handle=_th(f"t{i}"), src_addr=i * 100, nbytes=64)
for i in range(5)
]
for i, cmd in enumerate(cmds):
logger.record_start(float(i % 3), f"comp{i}", cmd) # some share t_start
for i, cmd in enumerate(cmds):
logger.record_end(float(i % 3) + 1.0, f"comp{i}", cmd)
# Verify insertion order preserved for same t_start
for i in range(len(logger.records) - 1):
assert logger.records[i].t_start <= logger.records[i + 1].t_start
def test_op_logger_unmatched_end_ignored():
logger = OpLogger()
cmd = DmaReadCmd(handle=_th(), src_addr=0, nbytes=32)
logger.record_end(5.0, "comp", cmd) # no matching start
assert len(logger.records) == 0
def test_data_op_flag():
"""DmaReadCmd, GemmCmd, MathCmd have data_op=True; others don't."""
assert getattr(DmaReadCmd(handle=_th(), src_addr=0, nbytes=32), "data_op", False)
assert getattr(DmaWriteCmd(handle=_th(), dst_addr=0, nbytes=32), "data_op", False)
a, b, out = _th("a"), _th("b"), _th("out")
assert getattr(GemmCmd(a=a, b=b, out=out, m=4, k=4, n=4), "data_op", False)
assert getattr(MathCmd(op="exp", inputs=(a,), out=out), "data_op", False)
# WaitCmd and PeCpuOverheadCmd should not have data_op
from kernbench.common.pe_commands import WaitCmd, PeCpuOverheadCmd
assert not getattr(WaitCmd(), "data_op", False)
assert not getattr(PeCpuOverheadCmd(cycles=10), "data_op", False)