"""Tests for OpLogger and OpRecord (ADR-0020 D2/D5).""" import numpy as np from kernbench.common.pe_commands import ( DmaReadCmd, DmaWriteCmd, GemmCmd, MathCmd, TensorHandle, ) from kernbench.sim_engine.op_log import OpLogger, OpRecord def _th(name="t0", addr=0, shape=(4, 4), dtype="f16"): return TensorHandle(id=name, addr=addr, shape=shape, dtype=dtype, nbytes=32) def test_op_logger_record_start_end(): logger = OpLogger() cmd = DmaReadCmd(handle=_th(), src_addr=0x1000, nbytes=64) logger.record_start(10.0, "sip0.cube0.pe0.pe_dma", cmd) logger.record_end(15.0, "sip0.cube0.pe0.pe_dma", cmd) assert len(logger.records) == 1 r = logger.records[0] assert r.t_start == 10.0 assert r.t_end == 15.0 assert r.op_kind == "memory" assert r.op_name == "dma_read" assert r.params["src_addr"] == 0x1000 def test_op_logger_gemm(): logger = OpLogger() a = _th("a", 0, (128, 256), "f16") b = _th("b", 1024, (256, 128), "f16") out = _th("out", 2048, (128, 128), "f16") cmd = GemmCmd(a=a, b=b, out=out, m=128, k=256, n=128) logger.record_start(0.0, "pe_gemm", cmd) logger.record_end(100.0, "pe_gemm", cmd) r = logger.records[0] assert r.op_kind == "gemm" assert r.op_name == "gemm_f16" assert r.params["m"] == 128 def test_op_logger_math(): logger = OpLogger() x = _th("x", 0, (32,), "f32") out = _th("out", 128, (32,), "f32") cmd = MathCmd(op="exp", inputs=(x,), out=out) logger.record_start(5.0, "pe_math", cmd) logger.record_end(6.0, "pe_math", cmd) r = logger.records[0] assert r.op_kind == "math" assert r.op_name == "exp" def test_op_logger_stable_ordering(): logger = OpLogger() cmds = [ DmaReadCmd(handle=_th(f"t{i}"), src_addr=i * 100, nbytes=64) for i in range(5) ] for i, cmd in enumerate(cmds): logger.record_start(float(i % 3), f"comp{i}", cmd) # some share t_start for i, cmd in enumerate(cmds): logger.record_end(float(i % 3) + 1.0, f"comp{i}", cmd) # Verify insertion order preserved for same t_start for i in range(len(logger.records) - 1): assert logger.records[i].t_start <= logger.records[i + 1].t_start def test_op_logger_unmatched_end_ignored(): logger = OpLogger() cmd = DmaReadCmd(handle=_th(), src_addr=0, nbytes=32) logger.record_end(5.0, "comp", cmd) # no matching start assert len(logger.records) == 0 def test_data_op_flag(): """DmaReadCmd, GemmCmd, MathCmd have data_op=True; others don't.""" assert getattr(DmaReadCmd(handle=_th(), src_addr=0, nbytes=32), "data_op", False) assert getattr(DmaWriteCmd(handle=_th(), dst_addr=0, nbytes=32), "data_op", False) a, b, out = _th("a"), _th("b"), _th("out") assert getattr(GemmCmd(a=a, b=b, out=out, m=4, k=4, n=4), "data_op", False) assert getattr(MathCmd(op="exp", inputs=(a,), out=out), "data_op", False) # WaitCmd and PeCpuOverheadCmd should not have data_op from kernbench.common.pe_commands import WaitCmd, PeCpuOverheadCmd assert not getattr(WaitCmd(), "data_op", False) assert not getattr(PeCpuOverheadCmd(cycles=10), "data_op", False)