Files
ywkang 51004c311c Implement ADR-0020: 2-pass data execution with greenlet kernel runner
Step 1 — Foundation:
- OpRecord/OpLogger: op log infrastructure with t_start stable ordering
- MemoryStore: numpy ndarray tensor-granular storage (reference semantics)
- data_op=True flag on DmaReadCmd, DmaWriteCmd, GemmCmd, MathCmd, CompositeCmd
- numpy/greenlet dependencies added to pyproject.toml

Step 2 — ComponentBase hooks:
- _on_process_start/end hooks in _forward_txn (fabric messages)
- _handle_with_hooks in PeEngineBase (PE-internal commands)
- op_logger optional — zero overhead when disabled

Step 3 — KernelRunner + greenlet:
- KernelRunner: greenlet ↔ SimPy bridge in triton_emu/kernel_runner.py
- TLContext: _emit() method routes to greenlet switch or command list
- tl.load() returns real numpy data in greenlet mode
- Dynamic control flow supported (memory-read based branching)

Step 4 — PE_CPU integration:
- Greenlet mode when ctx.memory_store is set, legacy fallback otherwise
- Refactored into _execute_greenlet/_execute_legacy/_send_response
- ComponentContext gains memory_store and op_logger fields

Step 5 — DataExecutor:
- Phase 2 numpy execution for GEMM/Math ops from op_log
- _compute_math: all unary/binary/reduction ops
- verify(): compare MemoryStore against expected with dtype tolerance

28 new tests, 366 total passing.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-08 00:22:44 -07:00

88 lines
3.1 KiB
Python

"""Tests for OpLogger and OpRecord (ADR-0020 D2/D5)."""
import numpy as np
from kernbench.common.pe_commands import (
DmaReadCmd, DmaWriteCmd, GemmCmd, MathCmd, TensorHandle,
)
from kernbench.sim_engine.op_log import OpLogger, OpRecord
def _th(name="t0", addr=0, shape=(4, 4), dtype="f16"):
    """Build a ``TensorHandle`` fixture for the tests in this module.

    ``name`` is passed as the handle's ``id``; ``addr``, ``shape`` and
    ``dtype`` are forwarded unchanged.
    """
    return TensorHandle(
        id=name,
        addr=addr,
        shape=shape,
        dtype=dtype,
        nbytes=32,  # fixed size; not derived from shape or dtype
    )
def test_op_logger_record_start_end():
    """A matched start/end pair yields exactly one fully populated record."""
    component = "sip0.cube0.pe0.pe_dma"
    op_log = OpLogger()
    read_cmd = DmaReadCmd(handle=_th(), src_addr=0x1000, nbytes=64)

    op_log.record_start(10.0, component, read_cmd)
    op_log.record_end(15.0, component, read_cmd)

    assert len(op_log.records) == 1
    entry = op_log.records[0]
    # Timestamps come from the start and end calls respectively.
    assert (entry.t_start, entry.t_end) == (10.0, 15.0)
    # A DMA read is logged as a "memory" op named "dma_read".
    assert (entry.op_kind, entry.op_name) == ("memory", "dma_read")
    # Command fields are exposed through the record's params mapping.
    assert entry.params["src_addr"] == 0x1000
def test_op_logger_gemm():
    """A logged ``GemmCmd`` is recorded as kind "gemm" with its ``m`` param."""
    lhs = _th("a", 0, (128, 256), "f16")
    rhs = _th("b", 1024, (256, 128), "f16")
    result = _th("out", 2048, (128, 128), "f16")
    gemm_cmd = GemmCmd(a=lhs, b=rhs, out=result, m=128, k=256, n=128)

    op_log = OpLogger()
    op_log.record_start(0.0, "pe_gemm", gemm_cmd)
    op_log.record_end(100.0, "pe_gemm", gemm_cmd)

    entry = op_log.records[0]
    # NOTE(review): the "_f16" suffix presumably follows the operand dtype;
    # confirm against OpLogger.
    assert (entry.op_kind, entry.op_name) == ("gemm", "gemm_f16")
    assert entry.params["m"] == 128
def test_op_logger_math():
    """A logged ``MathCmd`` is recorded as kind "math", named after its op."""
    source = _th("x", 0, (32,), "f32")
    result = _th("out", 128, (32,), "f32")
    exp_cmd = MathCmd(op="exp", inputs=(source,), out=result)

    op_log = OpLogger()
    op_log.record_start(5.0, "pe_math", exp_cmd)
    op_log.record_end(6.0, "pe_math", exp_cmd)

    entry = op_log.records[0]
    assert (entry.op_kind, entry.op_name) == ("math", "exp")
def test_op_logger_stable_ordering():
    """Records are ordered by ``t_start``, with ties kept in insertion order.

    Five DMA reads are started at times 0, 1, 2, 0, 1 (``i % 3``), so the
    t_start values 0 and 1 are each shared by two commands.  ``src_addr`` is
    ``i * 100`` and therefore grows with insertion order, which lets the
    tie-break be observed through ``params["src_addr"]``.
    """
    logger = OpLogger()
    cmds = [
        DmaReadCmd(handle=_th(f"t{i}"), src_addr=i * 100, nbytes=64)
        for i in range(5)
    ]
    for i, cmd in enumerate(cmds):
        logger.record_start(float(i % 3), f"comp{i}", cmd)  # some share t_start
    for i, cmd in enumerate(cmds):
        logger.record_end(float(i % 3) + 1.0, f"comp{i}", cmd)

    records = logger.records
    # Every started command was also ended, so none may be missing.
    assert len(records) == len(cmds)

    # Primary key: t_start must be non-decreasing across the log.
    t_starts = [r.t_start for r in records]
    assert t_starts == sorted(t_starts)

    # Tie-break: among records sharing a t_start, insertion order must be
    # preserved.  src_addr increases with insertion index, so each group of
    # addresses has to come back in ascending order.
    addrs_by_start = {}
    for r in records:
        addrs_by_start.setdefault(r.t_start, []).append(r.params["src_addr"])
    for addrs in addrs_by_start.values():
        assert addrs == sorted(addrs)
def test_op_logger_unmatched_end_ignored():
    """An end event without a preceding start must not create a record."""
    op_log = OpLogger()
    orphan_cmd = DmaReadCmd(handle=_th(), src_addr=0, nbytes=32)

    # record_start is deliberately skipped, so this end has nothing to match.
    op_log.record_end(5.0, "comp", orphan_cmd)

    assert len(op_log.records) == 0
def test_data_op_flag():
    """Check which PE commands expose a truthy ``data_op`` attribute.

    Expected truthy: DmaReadCmd, DmaWriteCmd, GemmCmd, MathCmd.
    Expected falsy or absent: WaitCmd, PeCpuOverheadCmd.
    """
    a, b, out = _th("a"), _th("b"), _th("out")
    flagged_cmds = [
        DmaReadCmd(handle=_th(), src_addr=0, nbytes=32),
        DmaWriteCmd(handle=_th(), dst_addr=0, nbytes=32),
        GemmCmd(a=a, b=b, out=out, m=4, k=4, n=4),
        MathCmd(op="exp", inputs=(a,), out=out),
    ]
    for cmd in flagged_cmds:
        # The getattr default keeps a missing attribute a plain assertion failure.
        assert getattr(cmd, "data_op", False)

    # WaitCmd and PeCpuOverheadCmd should not have data_op
    from kernbench.common.pe_commands import WaitCmd, PeCpuOverheadCmd

    unflagged_cmds = [WaitCmd(), PeCpuOverheadCmd(cycles=10)]
    for cmd in unflagged_cmds:
        assert not getattr(cmd, "data_op", False)