Implement ADR-0020: 2-pass data execution with greenlet kernel runner
Step 1 — Foundation: - OpRecord/OpLogger: op log infrastructure with t_start stable ordering - MemoryStore: numpy ndarray tensor-granular storage (reference semantics) - data_op=True flag on DmaReadCmd, DmaWriteCmd, GemmCmd, MathCmd, CompositeCmd - numpy/greenlet dependencies added to pyproject.toml Step 2 — ComponentBase hooks: - _on_process_start/end hooks in _forward_txn (fabric messages) - _handle_with_hooks in PeEngineBase (PE-internal commands) - op_logger optional — zero overhead when disabled Step 3 — KernelRunner + greenlet: - KernelRunner: greenlet ↔ SimPy bridge in triton_emu/kernel_runner.py - TLContext: _emit() method routes to greenlet switch or command list - tl.load() returns real numpy data in greenlet mode - Dynamic control flow supported (memory-read based branching) Step 4 — PE_CPU integration: - Greenlet mode when ctx.memory_store is set, legacy fallback otherwise - Refactored into _execute_greenlet/_execute_legacy/_send_response - ComponentContext gains memory_store and op_logger fields Step 5 — DataExecutor: - Phase 2 numpy execution for GEMM/Math ops from op_log - _compute_math: all unary/binary/reduction ops - verify(): compare MemoryStore against expected with dtype tolerance 28 new tests, 366 total passing. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,188 @@
|
||||
"""Tests for DataExecutor Phase 2 execution (ADR-0020 D6)."""
|
||||
import numpy as np
|
||||
|
||||
from kernbench.sim_engine.data_executor import DataExecutor
|
||||
from kernbench.sim_engine.memory_store import MemoryStore
|
||||
from kernbench.sim_engine.op_log import OpRecord
|
||||
|
||||
|
||||
def test_gemm_execution():
|
||||
"""Phase 2 GEMM: out = a @ b with f32 accumulation."""
|
||||
store = MemoryStore()
|
||||
a = np.ones((4, 8), dtype=np.float16)
|
||||
b = np.ones((8, 4), dtype=np.float16) * 2.0
|
||||
store.write("tcm", 0x0, a)
|
||||
store.write("tcm", 0x100, b)
|
||||
|
||||
op = OpRecord(
|
||||
t_start=0.0, t_end=100.0,
|
||||
component_id="pe_gemm",
|
||||
op_kind="gemm", op_name="gemm_f16",
|
||||
params={
|
||||
"src_a_addr": 0x0, "src_b_addr": 0x100, "dst_addr": 0x200,
|
||||
"shape_a": (4, 8), "shape_b": (8, 4), "shape_out": (4, 4),
|
||||
"dtype_in": "f16", "dtype_acc": "f32", "dtype_out": "f16",
|
||||
"addr_space": "tcm",
|
||||
},
|
||||
)
|
||||
|
||||
executor = DataExecutor([op], store)
|
||||
executor.run()
|
||||
|
||||
result = store.read("tcm", 0x200)
|
||||
expected = (a.astype(np.float32) @ b.astype(np.float32)).astype(np.float16)
|
||||
assert np.allclose(result, expected)
|
||||
|
||||
|
||||
def test_math_exp():
|
||||
store = MemoryStore()
|
||||
x = np.array([0.0, 1.0, 2.0], dtype=np.float32)
|
||||
store.write("tcm", 0x0, x)
|
||||
|
||||
op = OpRecord(
|
||||
t_start=0.0, t_end=10.0,
|
||||
component_id="pe_math",
|
||||
op_kind="math", op_name="exp",
|
||||
params={
|
||||
"op": "exp",
|
||||
"input_addrs": [0x0], "input_shapes": [(3,)],
|
||||
"dst_addr": 0x100, "shape_out": (3,),
|
||||
"dtype": "f32", "axis": None, "addr_space": "tcm",
|
||||
},
|
||||
)
|
||||
|
||||
executor = DataExecutor([op], store)
|
||||
executor.run()
|
||||
|
||||
result = store.read("tcm", 0x100)
|
||||
assert np.allclose(result, np.exp(x))
|
||||
|
||||
|
||||
def test_math_add():
|
||||
store = MemoryStore()
|
||||
a = np.array([1.0, 2.0], dtype=np.float32)
|
||||
b = np.array([3.0, 4.0], dtype=np.float32)
|
||||
store.write("tcm", 0x0, a)
|
||||
store.write("tcm", 0x100, b)
|
||||
|
||||
op = OpRecord(
|
||||
t_start=0.0, t_end=5.0,
|
||||
component_id="pe_math",
|
||||
op_kind="math", op_name="add",
|
||||
params={
|
||||
"op": "add",
|
||||
"input_addrs": [0x0, 0x100], "input_shapes": [(2,), (2,)],
|
||||
"dst_addr": 0x200, "shape_out": (2,),
|
||||
"dtype": "f32", "axis": None, "addr_space": "tcm",
|
||||
},
|
||||
)
|
||||
|
||||
executor = DataExecutor([op], store)
|
||||
executor.run()
|
||||
|
||||
result = store.read("tcm", 0x200)
|
||||
assert np.array_equal(result, np.array([4.0, 6.0], dtype=np.float32))
|
||||
|
||||
|
||||
def test_math_sum_reduction():
|
||||
store = MemoryStore()
|
||||
x = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
|
||||
store.write("tcm", 0x0, x)
|
||||
|
||||
op = OpRecord(
|
||||
t_start=0.0, t_end=5.0,
|
||||
component_id="pe_math",
|
||||
op_kind="math", op_name="sum",
|
||||
params={
|
||||
"op": "sum",
|
||||
"input_addrs": [0x0], "input_shapes": [(2, 2)],
|
||||
"dst_addr": 0x100, "shape_out": (1, 2),
|
||||
"dtype": "f32", "axis": 0, "addr_space": "tcm",
|
||||
},
|
||||
)
|
||||
|
||||
executor = DataExecutor([op], store)
|
||||
executor.run()
|
||||
|
||||
result = store.read("tcm", 0x100)
|
||||
assert np.array_equal(result, np.array([[4.0, 6.0]], dtype=np.float32))
|
||||
|
||||
|
||||
def test_verify_pass():
|
||||
store = MemoryStore()
|
||||
store.write("hbm", 0x0, np.array([1.0, 2.0], dtype=np.float32))
|
||||
|
||||
executor = DataExecutor([], store)
|
||||
results = executor.verify({
|
||||
("hbm", 0x0): np.array([1.0, 2.0], dtype=np.float32),
|
||||
})
|
||||
assert results["hbm:0x0"] is True
|
||||
|
||||
|
||||
def test_verify_fail():
|
||||
store = MemoryStore()
|
||||
store.write("hbm", 0x0, np.array([1.0, 2.0], dtype=np.float32))
|
||||
|
||||
executor = DataExecutor([], store)
|
||||
results = executor.verify({
|
||||
("hbm", 0x0): np.array([9.0, 9.0], dtype=np.float32),
|
||||
})
|
||||
assert results["hbm:0x0"] is False
|
||||
|
||||
|
||||
def test_memory_ops_skipped():
|
||||
"""Memory ops in op_log should be skipped (handled in Phase 1)."""
|
||||
store = MemoryStore()
|
||||
op = OpRecord(
|
||||
t_start=0.0, t_end=5.0,
|
||||
component_id="pe_dma",
|
||||
op_kind="memory", op_name="dma_read",
|
||||
params={"src_addr": 0x0, "nbytes": 64, "handle_id": "t0"},
|
||||
)
|
||||
# Should not raise
|
||||
executor = DataExecutor([op], store)
|
||||
executor.run()
|
||||
|
||||
|
||||
def test_sequential_gemm_then_math():
|
||||
"""GEMM output feeds into math op."""
|
||||
store = MemoryStore()
|
||||
a = np.eye(2, dtype=np.float16)
|
||||
b = np.ones((2, 2), dtype=np.float16)
|
||||
store.write("tcm", 0x0, a)
|
||||
store.write("tcm", 0x100, b)
|
||||
|
||||
ops = [
|
||||
OpRecord(
|
||||
t_start=0.0, t_end=50.0,
|
||||
component_id="pe_gemm",
|
||||
op_kind="gemm", op_name="gemm_f16",
|
||||
params={
|
||||
"src_a_addr": 0x0, "src_b_addr": 0x100, "dst_addr": 0x200,
|
||||
"shape_a": (2, 2), "shape_b": (2, 2), "shape_out": (2, 2),
|
||||
"dtype_in": "f16", "dtype_acc": "f32", "dtype_out": "f32",
|
||||
"addr_space": "tcm",
|
||||
},
|
||||
),
|
||||
OpRecord(
|
||||
t_start=50.0, t_end=55.0,
|
||||
component_id="pe_math",
|
||||
op_kind="math", op_name="exp",
|
||||
params={
|
||||
"op": "exp",
|
||||
"input_addrs": [0x200], "input_shapes": [(2, 2)],
|
||||
"dst_addr": 0x300, "shape_out": (2, 2),
|
||||
"dtype": "f32", "axis": None, "addr_space": "tcm",
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
executor = DataExecutor(ops, store)
|
||||
executor.run()
|
||||
|
||||
gemm_result = store.read("tcm", 0x200)
|
||||
expected_gemm = (a.astype(np.float32) @ b.astype(np.float32)).astype(np.float32)
|
||||
assert np.allclose(gemm_result, expected_gemm)
|
||||
|
||||
exp_result = store.read("tcm", 0x300)
|
||||
assert np.allclose(exp_result, np.exp(expected_gemm))
|
||||
@@ -0,0 +1,140 @@
|
||||
"""Tests for KernelRunner greenlet-based execution (ADR-0020 D3)."""
|
||||
import numpy as np
|
||||
import simpy
|
||||
|
||||
from kernbench.sim_engine.memory_store import MemoryStore
|
||||
from kernbench.triton_emu.kernel_runner import KernelRunner
|
||||
|
||||
|
||||
def _make_runner(env, store=None):
|
||||
"""Create a minimal KernelRunner with mock scheduler port."""
|
||||
scheduler_id = "sip0.cube0.pe0.pe_scheduler"
|
||||
out_ports = {scheduler_id: simpy.Store(env)}
|
||||
runner = KernelRunner(
|
||||
pe_prefix="sip0.cube0.pe0",
|
||||
pe_idx=0, sip_idx=0, cube_idx=0,
|
||||
scheduler_id=scheduler_id,
|
||||
out_ports=out_ports,
|
||||
store=store,
|
||||
)
|
||||
return runner, out_ports[scheduler_id]
|
||||
|
||||
|
||||
def _mock_scheduler(env, inbox):
|
||||
"""Consume PeInternalTxn from inbox and immediately succeed."""
|
||||
while True:
|
||||
pe_txn = yield inbox.get()
|
||||
pe_txn.done.succeed()
|
||||
|
||||
|
||||
def test_kernel_runner_basic_load():
|
||||
"""Kernel with tl.load runs through greenlet without hanging."""
|
||||
env = simpy.Environment()
|
||||
store = MemoryStore()
|
||||
data = np.ones((4, 4), dtype=np.float16)
|
||||
store.write("hbm", 0x1000, data)
|
||||
|
||||
runner, sched_port = _make_runner(env, store)
|
||||
env.process(_mock_scheduler(env, sched_port))
|
||||
|
||||
def kernel(a_ptr, tl):
|
||||
a = tl.load(a_ptr, (4, 4), "f16")
|
||||
assert a.data is not None
|
||||
assert a.data.shape == (4, 4)
|
||||
|
||||
def run():
|
||||
yield from runner.run(env, kernel, [0x1000], num_programs=1)
|
||||
|
||||
env.process(run())
|
||||
env.run()
|
||||
|
||||
|
||||
def test_kernel_runner_load_returns_data():
|
||||
"""tl.load returns actual numpy data from MemoryStore."""
|
||||
env = simpy.Environment()
|
||||
store = MemoryStore()
|
||||
data = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float16)
|
||||
store.write("hbm", 0x2000, data)
|
||||
|
||||
runner, sched_port = _make_runner(env, store)
|
||||
env.process(_mock_scheduler(env, sched_port))
|
||||
|
||||
results = {}
|
||||
|
||||
def kernel(ptr, tl):
|
||||
a = tl.load(ptr, (2, 2), "f16")
|
||||
results["data"] = a.data
|
||||
|
||||
def run():
|
||||
yield from runner.run(env, kernel, [0x2000], num_programs=1)
|
||||
|
||||
env.process(run())
|
||||
env.run()
|
||||
assert results["data"] is data # reference equality
|
||||
|
||||
|
||||
def test_kernel_runner_composite():
|
||||
"""Composite commands pass through without blocking kernel."""
|
||||
env = simpy.Environment()
|
||||
runner, sched_port = _make_runner(env)
|
||||
env.process(_mock_scheduler(env, sched_port))
|
||||
|
||||
def kernel(a_ptr, b_ptr, out_ptr, tl):
|
||||
a = tl.ref(a_ptr, (4, 8), "f16")
|
||||
b = tl.ref(b_ptr, (8, 4), "f16")
|
||||
h = tl.composite(op="gemm", a=a, b=b, out_ptr=out_ptr)
|
||||
tl.wait(h)
|
||||
|
||||
def run():
|
||||
yield from runner.run(env, kernel, [0, 64, 128], num_programs=1)
|
||||
|
||||
env.process(run())
|
||||
env.run()
|
||||
|
||||
|
||||
def test_kernel_runner_dynamic_branch():
|
||||
"""Kernel can branch based on loaded data (ADR-0020 D3)."""
|
||||
env = simpy.Environment()
|
||||
store = MemoryStore()
|
||||
store.write("hbm", 0x100, np.array([1.0], dtype=np.float32))
|
||||
store.write("hbm", 0x200, np.array([0.0], dtype=np.float32))
|
||||
|
||||
runner, sched_port = _make_runner(env, store)
|
||||
env.process(_mock_scheduler(env, sched_port))
|
||||
|
||||
results = {"branch": None}
|
||||
|
||||
def kernel(flag_ptr, tl):
|
||||
flag = tl.load(flag_ptr, (1,), "f32")
|
||||
if flag.data is not None and flag.data[0] > 0.5:
|
||||
results["branch"] = "taken"
|
||||
else:
|
||||
results["branch"] = "not_taken"
|
||||
|
||||
# Test with flag=1.0 → branch taken
|
||||
def run():
|
||||
yield from runner.run(env, kernel, [0x100], num_programs=1)
|
||||
|
||||
env.process(run())
|
||||
env.run()
|
||||
assert results["branch"] == "taken"
|
||||
|
||||
|
||||
def test_kernel_runner_no_store():
|
||||
"""Without MemoryStore, tl.load returns handle with data=None."""
|
||||
env = simpy.Environment()
|
||||
runner, sched_port = _make_runner(env, store=None)
|
||||
env.process(_mock_scheduler(env, sched_port))
|
||||
|
||||
results = {}
|
||||
|
||||
def kernel(ptr, tl):
|
||||
a = tl.load(ptr, (4,), "f16")
|
||||
results["data"] = a.data
|
||||
|
||||
def run():
|
||||
yield from runner.run(env, kernel, [0], num_programs=1)
|
||||
|
||||
env.process(run())
|
||||
env.run()
|
||||
assert results["data"] is None
|
||||
@@ -0,0 +1,85 @@
|
||||
"""Tests for MemoryStore (ADR-0020 D7)."""
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from kernbench.sim_engine.memory_store import MemoryStore
|
||||
|
||||
|
||||
def test_write_read_reference():
|
||||
"""Write and read return the same numpy array (no copy)."""
|
||||
store = MemoryStore()
|
||||
data = np.ones((4, 4), dtype=np.float16)
|
||||
store.write("tcm", 0x0, data)
|
||||
result = store.read("tcm", 0x0)
|
||||
assert result is data
|
||||
|
||||
|
||||
def test_overwrite_replaces():
|
||||
"""Same addr write replaces the previous tensor."""
|
||||
store = MemoryStore()
|
||||
data1 = np.zeros((4,), dtype=np.float32)
|
||||
data2 = np.ones((4,), dtype=np.float32)
|
||||
store.write("hbm", 0x100, data1)
|
||||
store.write("hbm", 0x100, data2)
|
||||
result = store.read("hbm", 0x100)
|
||||
assert result is data2
|
||||
|
||||
|
||||
def test_read_missing_raises():
|
||||
store = MemoryStore()
|
||||
with pytest.raises(KeyError):
|
||||
store.read("hbm", 0x999)
|
||||
|
||||
|
||||
def test_read_different_space():
|
||||
store = MemoryStore()
|
||||
data = np.array([1, 2, 3], dtype=np.int32)
|
||||
store.write("tcm", 0x0, data)
|
||||
with pytest.raises(KeyError):
|
||||
store.read("hbm", 0x0) # different space
|
||||
|
||||
|
||||
def test_dtype_reinterpret():
|
||||
"""Read with different dtype does view cast."""
|
||||
store = MemoryStore()
|
||||
data = np.array([1.0, 2.0], dtype=np.float32) # 8 bytes
|
||||
store.write("tcm", 0x0, data)
|
||||
result = store.read("tcm", 0x0, dtype="u8")
|
||||
assert result.dtype == np.uint8
|
||||
assert result.nbytes == data.nbytes
|
||||
|
||||
|
||||
def test_reshape():
|
||||
store = MemoryStore()
|
||||
data = np.arange(12, dtype=np.float32)
|
||||
store.write("tcm", 0x0, data)
|
||||
result = store.read("tcm", 0x0, shape=(3, 4))
|
||||
assert result.shape == (3, 4)
|
||||
|
||||
|
||||
def test_shape_mismatch_raises():
|
||||
store = MemoryStore()
|
||||
data = np.arange(12, dtype=np.float32)
|
||||
store.write("tcm", 0x0, data)
|
||||
with pytest.raises(ValueError, match="Shape mismatch"):
|
||||
store.read("tcm", 0x0, shape=(5, 5))
|
||||
|
||||
|
||||
def test_has():
|
||||
store = MemoryStore()
|
||||
assert not store.has("tcm", 0x0)
|
||||
store.write("tcm", 0x0, np.array([1]))
|
||||
assert store.has("tcm", 0x0)
|
||||
|
||||
|
||||
def test_snapshot():
|
||||
store = MemoryStore()
|
||||
data = np.ones((4,), dtype=np.float16)
|
||||
store.write("hbm", 0x0, data)
|
||||
|
||||
snap = store.snapshot()
|
||||
assert snap.read("hbm", 0x0) is data # same reference
|
||||
|
||||
# Modifying snap doesn't affect original
|
||||
snap.write("hbm", 0x0, np.zeros((4,), dtype=np.float16))
|
||||
assert store.read("hbm", 0x0) is data
|
||||
@@ -0,0 +1,87 @@
|
||||
"""Tests for OpLogger and OpRecord (ADR-0020 D2/D5)."""
|
||||
import numpy as np
|
||||
|
||||
from kernbench.common.pe_commands import (
|
||||
DmaReadCmd, DmaWriteCmd, GemmCmd, MathCmd, TensorHandle,
|
||||
)
|
||||
from kernbench.sim_engine.op_log import OpLogger, OpRecord
|
||||
|
||||
|
||||
def _th(name="t0", addr=0, shape=(4, 4), dtype="f16"):
|
||||
return TensorHandle(id=name, addr=addr, shape=shape, dtype=dtype, nbytes=32)
|
||||
|
||||
|
||||
def test_op_logger_record_start_end():
|
||||
logger = OpLogger()
|
||||
cmd = DmaReadCmd(handle=_th(), src_addr=0x1000, nbytes=64)
|
||||
logger.record_start(10.0, "sip0.cube0.pe0.pe_dma", cmd)
|
||||
logger.record_end(15.0, "sip0.cube0.pe0.pe_dma", cmd)
|
||||
assert len(logger.records) == 1
|
||||
r = logger.records[0]
|
||||
assert r.t_start == 10.0
|
||||
assert r.t_end == 15.0
|
||||
assert r.op_kind == "memory"
|
||||
assert r.op_name == "dma_read"
|
||||
assert r.params["src_addr"] == 0x1000
|
||||
|
||||
|
||||
def test_op_logger_gemm():
|
||||
logger = OpLogger()
|
||||
a = _th("a", 0, (128, 256), "f16")
|
||||
b = _th("b", 1024, (256, 128), "f16")
|
||||
out = _th("out", 2048, (128, 128), "f16")
|
||||
cmd = GemmCmd(a=a, b=b, out=out, m=128, k=256, n=128)
|
||||
logger.record_start(0.0, "pe_gemm", cmd)
|
||||
logger.record_end(100.0, "pe_gemm", cmd)
|
||||
r = logger.records[0]
|
||||
assert r.op_kind == "gemm"
|
||||
assert r.op_name == "gemm_f16"
|
||||
assert r.params["m"] == 128
|
||||
|
||||
|
||||
def test_op_logger_math():
|
||||
logger = OpLogger()
|
||||
x = _th("x", 0, (32,), "f32")
|
||||
out = _th("out", 128, (32,), "f32")
|
||||
cmd = MathCmd(op="exp", inputs=(x,), out=out)
|
||||
logger.record_start(5.0, "pe_math", cmd)
|
||||
logger.record_end(6.0, "pe_math", cmd)
|
||||
r = logger.records[0]
|
||||
assert r.op_kind == "math"
|
||||
assert r.op_name == "exp"
|
||||
|
||||
|
||||
def test_op_logger_stable_ordering():
|
||||
logger = OpLogger()
|
||||
cmds = [
|
||||
DmaReadCmd(handle=_th(f"t{i}"), src_addr=i * 100, nbytes=64)
|
||||
for i in range(5)
|
||||
]
|
||||
for i, cmd in enumerate(cmds):
|
||||
logger.record_start(float(i % 3), f"comp{i}", cmd) # some share t_start
|
||||
for i, cmd in enumerate(cmds):
|
||||
logger.record_end(float(i % 3) + 1.0, f"comp{i}", cmd)
|
||||
|
||||
# Verify insertion order preserved for same t_start
|
||||
for i in range(len(logger.records) - 1):
|
||||
assert logger.records[i].t_start <= logger.records[i + 1].t_start
|
||||
|
||||
|
||||
def test_op_logger_unmatched_end_ignored():
|
||||
logger = OpLogger()
|
||||
cmd = DmaReadCmd(handle=_th(), src_addr=0, nbytes=32)
|
||||
logger.record_end(5.0, "comp", cmd) # no matching start
|
||||
assert len(logger.records) == 0
|
||||
|
||||
|
||||
def test_data_op_flag():
|
||||
"""DmaReadCmd, GemmCmd, MathCmd have data_op=True; others don't."""
|
||||
assert getattr(DmaReadCmd(handle=_th(), src_addr=0, nbytes=32), "data_op", False)
|
||||
assert getattr(DmaWriteCmd(handle=_th(), dst_addr=0, nbytes=32), "data_op", False)
|
||||
a, b, out = _th("a"), _th("b"), _th("out")
|
||||
assert getattr(GemmCmd(a=a, b=b, out=out, m=4, k=4, n=4), "data_op", False)
|
||||
assert getattr(MathCmd(op="exp", inputs=(a,), out=out), "data_op", False)
|
||||
# WaitCmd and PeCpuOverheadCmd should not have data_op
|
||||
from kernbench.common.pe_commands import WaitCmd, PeCpuOverheadCmd
|
||||
assert not getattr(WaitCmd(), "data_op", False)
|
||||
assert not getattr(PeCpuOverheadCmd(cycles=10), "data_op", False)
|
||||
Reference in New Issue
Block a user