Implement ADR-0020: 2-pass data execution with greenlet kernel runner
Step 1 — Foundation: - OpRecord/OpLogger: op log infrastructure with t_start stable ordering - MemoryStore: numpy ndarray tensor-granular storage (reference semantics) - data_op=True flag on DmaReadCmd, DmaWriteCmd, GemmCmd, MathCmd, CompositeCmd - numpy/greenlet dependencies added to pyproject.toml Step 2 — ComponentBase hooks: - _on_process_start/end hooks in _forward_txn (fabric messages) - _handle_with_hooks in PeEngineBase (PE-internal commands) - op_logger optional — zero overhead when disabled Step 3 — KernelRunner + greenlet: - KernelRunner: greenlet ↔ SimPy bridge in triton_emu/kernel_runner.py - TLContext: _emit() method routes to greenlet switch or command list - tl.load() returns real numpy data in greenlet mode - Dynamic control flow supported (memory-read based branching) Step 4 — PE_CPU integration: - Greenlet mode when ctx.memory_store is set, legacy fallback otherwise - Refactored into _execute_greenlet/_execute_legacy/_send_response - ComponentContext gains memory_store and op_logger fields Step 5 — DataExecutor: - Phase 2 numpy execution for GEMM/Math ops from op_log - _compute_math: all unary/binary/reduction ops - verify(): compare MemoryStore against expected with dtype tolerance 28 new tests, 366 total passing. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,87 @@
|
||||
"""Tests for OpLogger and OpRecord (ADR-0020 D2/D5)."""
|
||||
import numpy as np
|
||||
|
||||
from kernbench.common.pe_commands import (
|
||||
DmaReadCmd, DmaWriteCmd, GemmCmd, MathCmd, TensorHandle,
|
||||
)
|
||||
from kernbench.sim_engine.op_log import OpLogger, OpRecord
|
||||
|
||||
|
||||
def _th(name="t0", addr=0, shape=(4, 4), dtype="f16"):
|
||||
return TensorHandle(id=name, addr=addr, shape=shape, dtype=dtype, nbytes=32)
|
||||
|
||||
|
||||
def test_op_logger_record_start_end():
|
||||
logger = OpLogger()
|
||||
cmd = DmaReadCmd(handle=_th(), src_addr=0x1000, nbytes=64)
|
||||
logger.record_start(10.0, "sip0.cube0.pe0.pe_dma", cmd)
|
||||
logger.record_end(15.0, "sip0.cube0.pe0.pe_dma", cmd)
|
||||
assert len(logger.records) == 1
|
||||
r = logger.records[0]
|
||||
assert r.t_start == 10.0
|
||||
assert r.t_end == 15.0
|
||||
assert r.op_kind == "memory"
|
||||
assert r.op_name == "dma_read"
|
||||
assert r.params["src_addr"] == 0x1000
|
||||
|
||||
|
||||
def test_op_logger_gemm():
|
||||
logger = OpLogger()
|
||||
a = _th("a", 0, (128, 256), "f16")
|
||||
b = _th("b", 1024, (256, 128), "f16")
|
||||
out = _th("out", 2048, (128, 128), "f16")
|
||||
cmd = GemmCmd(a=a, b=b, out=out, m=128, k=256, n=128)
|
||||
logger.record_start(0.0, "pe_gemm", cmd)
|
||||
logger.record_end(100.0, "pe_gemm", cmd)
|
||||
r = logger.records[0]
|
||||
assert r.op_kind == "gemm"
|
||||
assert r.op_name == "gemm_f16"
|
||||
assert r.params["m"] == 128
|
||||
|
||||
|
||||
def test_op_logger_math():
|
||||
logger = OpLogger()
|
||||
x = _th("x", 0, (32,), "f32")
|
||||
out = _th("out", 128, (32,), "f32")
|
||||
cmd = MathCmd(op="exp", inputs=(x,), out=out)
|
||||
logger.record_start(5.0, "pe_math", cmd)
|
||||
logger.record_end(6.0, "pe_math", cmd)
|
||||
r = logger.records[0]
|
||||
assert r.op_kind == "math"
|
||||
assert r.op_name == "exp"
|
||||
|
||||
|
||||
def test_op_logger_stable_ordering():
|
||||
logger = OpLogger()
|
||||
cmds = [
|
||||
DmaReadCmd(handle=_th(f"t{i}"), src_addr=i * 100, nbytes=64)
|
||||
for i in range(5)
|
||||
]
|
||||
for i, cmd in enumerate(cmds):
|
||||
logger.record_start(float(i % 3), f"comp{i}", cmd) # some share t_start
|
||||
for i, cmd in enumerate(cmds):
|
||||
logger.record_end(float(i % 3) + 1.0, f"comp{i}", cmd)
|
||||
|
||||
# Verify insertion order preserved for same t_start
|
||||
for i in range(len(logger.records) - 1):
|
||||
assert logger.records[i].t_start <= logger.records[i + 1].t_start
|
||||
|
||||
|
||||
def test_op_logger_unmatched_end_ignored():
|
||||
logger = OpLogger()
|
||||
cmd = DmaReadCmd(handle=_th(), src_addr=0, nbytes=32)
|
||||
logger.record_end(5.0, "comp", cmd) # no matching start
|
||||
assert len(logger.records) == 0
|
||||
|
||||
|
||||
def test_data_op_flag():
|
||||
"""DmaReadCmd, GemmCmd, MathCmd have data_op=True; others don't."""
|
||||
assert getattr(DmaReadCmd(handle=_th(), src_addr=0, nbytes=32), "data_op", False)
|
||||
assert getattr(DmaWriteCmd(handle=_th(), dst_addr=0, nbytes=32), "data_op", False)
|
||||
a, b, out = _th("a"), _th("b"), _th("out")
|
||||
assert getattr(GemmCmd(a=a, b=b, out=out, m=4, k=4, n=4), "data_op", False)
|
||||
assert getattr(MathCmd(op="exp", inputs=(a,), out=out), "data_op", False)
|
||||
# WaitCmd and PeCpuOverheadCmd should not have data_op
|
||||
from kernbench.common.pe_commands import WaitCmd, PeCpuOverheadCmd
|
||||
assert not getattr(WaitCmd(), "data_op", False)
|
||||
assert not getattr(PeCpuOverheadCmd(cycles=10), "data_op", False)
|
||||
Reference in New Issue
Block a user