"""Tests for KernelRunner greenlet-based execution (ADR-0020 D3).""" import numpy as np import simpy from kernbench.sim_engine.memory_store import MemoryStore from kernbench.triton_emu.kernel_runner import KernelRunner def _make_runner(env, store=None): """Create a minimal KernelRunner with mock scheduler port.""" scheduler_id = "sip0.cube0.pe0.pe_scheduler" out_ports = {scheduler_id: simpy.Store(env)} runner = KernelRunner( pe_prefix="sip0.cube0.pe0", pe_idx=0, sip_idx=0, cube_idx=0, scheduler_id=scheduler_id, out_ports=out_ports, store=store, ) return runner, out_ports[scheduler_id] def _mock_scheduler(env, inbox): """Consume PeInternalTxn from inbox and immediately succeed.""" while True: pe_txn = yield inbox.get() pe_txn.done.succeed() def test_kernel_runner_basic_load(): """Kernel with tl.load runs through greenlet without hanging.""" env = simpy.Environment() store = MemoryStore() data = np.ones((4, 4), dtype=np.float16) store.write("hbm", 0x1000, data) runner, sched_port = _make_runner(env, store) env.process(_mock_scheduler(env, sched_port)) def kernel(a_ptr, tl): a = tl.load(a_ptr, (4, 4), "f16") assert a.data is not None assert a.data.shape == (4, 4) def run(): yield from runner.run(env, kernel, [0x1000], num_programs=1) env.process(run()) env.run() def test_kernel_runner_load_returns_data(): """tl.load returns actual numpy data from MemoryStore.""" env = simpy.Environment() store = MemoryStore() data = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float16) store.write("hbm", 0x2000, data) runner, sched_port = _make_runner(env, store) env.process(_mock_scheduler(env, sched_port)) results = {} def kernel(ptr, tl): a = tl.load(ptr, (2, 2), "f16") results["data"] = a.data def run(): yield from runner.run(env, kernel, [0x2000], num_programs=1) env.process(run()) env.run() assert results["data"] is data # reference equality def test_kernel_runner_composite(): """Composite commands pass through without blocking kernel.""" env = simpy.Environment() runner, sched_port = _make_runner(env) env.process(_mock_scheduler(env, sched_port)) def kernel(a_ptr, b_ptr, out_ptr, tl): a = tl.ref(a_ptr, (4, 8), "f16") b = tl.ref(b_ptr, (8, 4), "f16") h = tl.composite(op="gemm", a=a, b=b, out_ptr=out_ptr) tl.wait(h) def run(): yield from runner.run(env, kernel, [0, 64, 128], num_programs=1) env.process(run()) env.run() def test_kernel_runner_dynamic_branch(): """Kernel can branch based on loaded data (ADR-0020 D3).""" env = simpy.Environment() store = MemoryStore() store.write("hbm", 0x100, np.array([1.0], dtype=np.float32)) store.write("hbm", 0x200, np.array([0.0], dtype=np.float32)) runner, sched_port = _make_runner(env, store) env.process(_mock_scheduler(env, sched_port)) results = {"branch": None} def kernel(flag_ptr, tl): flag = tl.load(flag_ptr, (1,), "f32") if flag.data is not None and flag.data[0] > 0.5: results["branch"] = "taken" else: results["branch"] = "not_taken" # Test with flag=1.0 → branch taken def run(): yield from runner.run(env, kernel, [0x100], num_programs=1) env.process(run()) env.run() assert results["branch"] == "taken" def test_kernel_runner_no_store(): """Without MemoryStore, tl.load returns handle with data=None.""" env = simpy.Environment() runner, sched_port = _make_runner(env, store=None) env.process(_mock_scheduler(env, sched_port)) results = {} def kernel(ptr, tl): a = tl.load(ptr, (4,), "f16") results["data"] = a.data def run(): yield from runner.run(env, kernel, [0], num_programs=1) env.process(run()) env.run() assert results["data"] is None