"""Tests for Triton emulator: TLContext, command generation, kernel registry."""

from kernbench.common.pe_commands import (
    CompletionHandle,
    CompositeCmd,
    DmaReadCmd,
    DmaWriteCmd,
    GemmCmd,
    MathCmd,
    PeCpuOverheadCmd,
    TensorHandle,
    WaitCmd,
)
from kernbench.triton_emu.registry import clear_registry, get_kernel, register_kernel
from kernbench.triton_emu.tl_context import TLContext, run_kernel


def _ctx(**kwargs) -> TLContext:
    """Return a TLContext built with ``dispatch_cycles=0``.

    Any extra keyword arguments are forwarded to the TLContext constructor.
    """
    return TLContext(dispatch_cycles=0, **kwargs)


def _ctx_with_overhead(**kwargs) -> TLContext:
    """Return a TLContext built with ``dispatch_cycles=1``.

    Any extra keyword arguments are forwarded to the TLContext constructor.
    """
    return TLContext(dispatch_cycles=1, **kwargs)


# ── 1. tl.load → DmaReadCmd ──────────────────────────────────────


def test_tl_load_generates_dma_read():
    """tl.load yields a TensorHandle and records a single DmaReadCmd."""
    tl = _ctx()
    expected_nbytes = 32 * 64 * 2

    handle = tl.load(0x1000, shape=(32, 64), dtype="f16")

    assert isinstance(handle, TensorHandle)
    assert handle.shape == (32, 64)
    assert handle.nbytes == expected_nbytes

    recorded = tl.commands
    assert len(recorded) == 1
    read_cmd = recorded[0]
    assert isinstance(read_cmd, DmaReadCmd)
    assert read_cmd.src_addr == 0x1000
    assert read_cmd.nbytes == expected_nbytes


# ── 2. tl.store → DmaWriteCmd ────────────────────────────────────


def test_tl_store_generates_dma_write():
    """tl.store of a loaded handle records exactly one DmaWriteCmd."""
    tl = _ctx()
    loaded = tl.load(0x1000, shape=(16, 16), dtype="f32")
    tl.store(0x2000, loaded)

    writes = [cmd for cmd in tl.commands if isinstance(cmd, DmaWriteCmd)]
    assert len(writes) == 1
    write_cmd = writes[0]
    assert write_cmd.dst_addr == 0x2000
    assert write_cmd.nbytes == 16 * 16 * 4


# ── 3. tl.dot → GemmCmd ──────────────────────────────────────────


def test_tl_dot_generates_gemm_cmd():
    """tl.dot of (32, 64) by (64, 16) records one GemmCmd with m/k/n set."""
    tl = _ctx()
    lhs = tl.load(0x1000, shape=(32, 64), dtype="f16")
    rhs = tl.load(0x2000, shape=(64, 16), dtype="f16")

    product = tl.dot(lhs, rhs)
    assert product.shape == (32, 16)

    gemms = [cmd for cmd in tl.commands if isinstance(cmd, GemmCmd)]
    assert len(gemms) == 1
    gemm = gemms[0]
    assert gemm.m == 32
    assert gemm.k == 64
    assert gemm.n == 16


# ── 4. tl.exp, tl.sqrt etc.
# → MathCmd ────────────────────────────


def test_tl_math_unary_ops():
    """Each unary math helper returns a same-shaped handle and logs a MathCmd.

    The (name, function) pairs below are the single source of truth: the
    expected ``op`` sequence is derived from them rather than repeated by hand.
    """
    tl = _ctx()
    x = tl.load(0x1000, shape=(8, 8), dtype="f16")
    unary_cases = [
        ("exp", tl.exp),
        ("log", tl.log),
        ("sqrt", tl.sqrt),
        ("abs", tl.abs),
        ("sigmoid", tl.sigmoid),
        ("cos", tl.cos),
        ("sin", tl.sin),
    ]

    for _, op_fn in unary_cases:
        result = op_fn(x)
        assert isinstance(result, TensorHandle)
        assert result.shape == x.shape

    math_cmds = [c for c in tl.commands if isinstance(c, MathCmd)]
    recorded_ops = [c.op for c in math_cmds]
    # Commands must appear in call order, one per helper invoked above.
    assert recorded_ops == [op_name for op_name, _ in unary_cases]


def test_tl_math_extra_ops():
    """tl.maximum/minimum/fma/clamp/softmax + tl.cdiv (real-Triton parity)."""
    tl = _ctx()
    a = tl.load(0x1000, shape=(8, 8), dtype="f16")
    b = tl.load(0x2000, shape=(8, 8), dtype="f16")
    c = tl.load(0x3000, shape=(8, 8), dtype="f16")

    tl.maximum(a, b)
    tl.minimum(a, b)
    tl.fma(a, b, c)
    tl.clamp(a, b, c)
    tl.softmax(a, axis=1)

    math_cmds = [cm for cm in tl.commands if isinstance(cm, MathCmd)]
    ops = [cm.op for cm in math_cmds]
    assert ops == ["maximum", "minimum", "fma", "clamp", "softmax"]

    # ternary fma/clamp must record three inputs
    fma_cmd = math_cmds[2]
    assert len(fma_cmd.inputs) == 3
    clamp_cmd = math_cmds[3]
    assert len(clamp_cmd.inputs) == 3

    # softmax records the axis
    assert math_cmds[4].axis == 1

    # cdiv is a scalar helper, not a tensor op; it is called on the class
    # itself (TLContext is already imported at module level).
    assert TLContext.cdiv(10, 3) == 4
    assert TLContext.cdiv(9, 3) == 3
    assert TLContext.cdiv(0, 4) == 0


# ── 5.
# a + b, a * b → MathCmd ────────────────────────────────────


def test_tl_math_binary_ops():
    """Arithmetic operators on handles record add/sub/mul/div MathCmds in order.

    The operators are exercised inside a kernel launched through run_kernel;
    per the original author's note, the operators need an active context,
    which run_kernel provides.
    """
    ctx = _ctx()

    def binary_kernel(tl):
        a = tl.load(0x1000, shape=(4, 4), dtype="f16")
        b = tl.load(0x2000, shape=(4, 4), dtype="f16")
        # Results are discarded; only the recorded commands matter here.
        _ = a + b
        _ = a - b
        _ = a * b
        _ = a / b

    run_kernel(binary_kernel, ctx)

    math_cmds = [c for c in ctx.commands if isinstance(c, MathCmd)]
    ops = [c.op for c in math_cmds]
    assert ops == ["add", "sub", "mul", "div"]


# ── 6. tl.sum, tl.max → MathCmd with axis ────────────────────────


def test_tl_reduction_ops():
    """tl.sum/tl.max keep the reduced axis as size 1 and record the axis."""
    tl = _ctx()
    x = tl.load(0x1000, shape=(32, 64), dtype="f16")

    s = tl.sum(x, axis=1)
    m = tl.max(x, axis=0)
    assert s.shape == (32, 1)
    assert m.shape == (1, 64)

    math_cmds = [c for c in tl.commands if isinstance(c, MathCmd)]
    sum_cmd, max_cmd = math_cmds[0], math_cmds[1]
    # Separate asserts so a failure names the exact field that is wrong.
    assert sum_cmd.op == "sum"
    assert sum_cmd.axis == 1
    assert max_cmd.op == "max"
    assert max_cmd.axis == 0


# ── 7. tl.composite → CompositeCmd + CompletionHandle ────────────


def test_tl_composite_nonblocking():
    """tl.composite returns a CompletionHandle and records one CompositeCmd."""
    tl = _ctx()
    a = tl.load(0x1000, shape=(32, 64), dtype="f16")
    b = tl.load(0x2000, shape=(64, 32), dtype="f16")

    h = tl.composite(op="gemm", a=a, b=b, out_ptr=0x3000)
    assert isinstance(h, CompletionHandle)

    comp_cmds = [c for c in tl.commands if isinstance(c, CompositeCmd)]
    assert len(comp_cmds) == 1
    comp = comp_cmds[0]
    assert comp.op == "gemm"
    assert comp.out_addr == 0x3000
    assert comp.out_nbytes == 32 * 32 * 2  # M×N×dtype_bytes


# ── 8.
# tl.wait(handle) → WaitCmd ─────────────────────────────────


def test_tl_wait_specific():
    """tl.wait(handle) records one WaitCmd carrying that handle."""
    tl = _ctx()
    operand = tl.load(0x1000, shape=(4, 4), dtype="f16")
    completion = tl.composite(op="gemm", a=operand, b=operand, out_ptr=0x2000)
    tl.wait(completion)

    waits = [cmd for cmd in tl.commands if isinstance(cmd, WaitCmd)]
    assert len(waits) == 1
    assert waits[0].handle == completion


# ── 9. tl.wait() → WaitCmd(handle=None) ──────────────────────────


def test_tl_wait_all():
    """tl.wait() with no argument records one WaitCmd whose handle is None."""
    tl = _ctx()
    tl.wait()

    waits = [cmd for cmd in tl.commands if isinstance(cmd, WaitCmd)]
    assert len(waits) == 1
    assert waits[0].handle is None


# ── 10. tl.cycles → PeCpuOverheadCmd ─────────────────────────────


def test_tl_cycles():
    """tl.cycles(10) records a single PeCpuOverheadCmd with cycles == 10."""
    tl = _ctx()
    tl.cycles(10)

    assert len(tl.commands) == 1
    assert isinstance(tl.commands[0], PeCpuOverheadCmd)
    assert tl.commands[0].cycles == 10


# ── 11. tl.program_id ────────────────────────────────────────────


def test_tl_program_id():
    """Axis 0 reports the pe_id and num_programs given to the constructor."""
    tl = TLContext(pe_id=5, num_programs=8)

    assert tl.program_id(0) == 5
    assert tl.num_programs(0) == 8


def test_tl_program_id_axis1():
    """axis=1 returns cube_id and num_cubes."""
    tl = TLContext(pe_id=3, num_programs=8, cube_id=7, num_cubes=16)

    assert tl.program_id(0) == 3
    assert tl.program_id(1) == 7
    assert tl.num_programs(0) == 8
    assert tl.num_programs(1) == 16


def test_tl_program_id_global():
    """global_pid = cube_id * num_pes_per_cube + local_pe_id."""
    pe_id = 5
    cube_id = 3
    num_pes = 8
    tl = TLContext(pe_id=pe_id, num_programs=num_pes, cube_id=cube_id, num_cubes=16)

    global_pid = tl.program_id(1) * tl.num_programs(0) + tl.program_id(0)

    assert global_pid == cube_id * num_pes + pe_id


# ── 12. tl.arange, tl.zeros, tl.full ─────────────────────────────


def test_tl_arange_zeros_full():
    """tl.arange/zeros/full return handles with the requested shape and size."""
    tl = _ctx()

    index_range = tl.arange(0, 16, dtype="i32")
    assert index_range.shape == (16,)
    assert index_range.dtype == "i32"

    zeros = tl.zeros((4, 8), dtype="f16")
    assert zeros.shape == (4, 8)
    assert zeros.nbytes == 4 * 8 * 2

    filled = tl.full((2, 3), value=1.0, dtype="f32")
    assert filled.shape == (2, 3)
    assert filled.nbytes == 2 * 3 * 4


# ── 13.
# tl.trans → shape change, no command ───────────────────────


def test_tl_trans_shape():
    """tl.trans swaps the two dims, keeps the handle id, and adds no command."""
    tl = _ctx()
    h = tl.load(0x1000, shape=(32, 64), dtype="f16")

    t = tl.trans(h)
    assert t.shape == (64, 32)
    assert t.id == h.id  # same underlying data

    # Only DmaReadCmd from load, no command from trans
    assert len(tl.commands) == 1
    assert isinstance(tl.commands[0], DmaReadCmd)


# ── 14. Kernel registry ──────────────────────────────────────────


def test_kernel_registry():
    """A registered kernel is returned by get_kernel under the same name."""
    clear_registry()

    def my_kernel(tl):
        pass

    try:
        register_kernel("test_kern", my_kernel)
        assert get_kernel("test_kern") is my_kernel
    finally:
        # Always reset the registry so a failed assertion cannot leak
        # the registered kernel into later tests.
        clear_registry()


def test_kernel_registry_missing():
    """get_kernel raises KeyError for a name that was never registered."""
    import pytest

    clear_registry()
    with pytest.raises(KeyError):
        get_kernel("nonexistent")


def test_kernel_registry_duplicate():
    """Registering the same name twice raises ValueError."""
    import pytest

    clear_registry()
    try:
        register_kernel("dup", lambda tl: None)
        with pytest.raises(ValueError):
            register_kernel("dup", lambda tl: None)
    finally:
        # Always reset the registry, even if the expected error is not raised.
        clear_registry()


# ── 15. GEMM kernel → correct command sequence ───────────────────


def test_gemm_kernel_command_sequence():
    """32×64 × 64×32 GEMM kernel produces [DmaRead, DmaRead, Composite]."""

    def gemm_kernel(a_ptr, b_ptr, out_ptr, tl):
        pid = tl.program_id(0)
        a = tl.load(a_ptr, shape=(32, 64), dtype="f16")
        # B and the output are offset per program by one tile's worth of bytes.
        b = tl.load(b_ptr + pid * 64 * 32 * 2, shape=(64, 32), dtype="f16")
        tl.composite(op="gemm", a=a, b=b, out_ptr=out_ptr + pid * 32 * 32 * 2)

    tl = _ctx(pe_id=3)
    run_kernel(gemm_kernel, tl, a_ptr=0x1000, b_ptr=0x2000, out_ptr=0x3000)

    types = [type(c).__name__ for c in tl.commands]
    assert types == ["DmaReadCmd", "DmaReadCmd", "CompositeCmd"]


# ── 16.
# Attention kernel → correct command sequence ───────────────


def test_attention_kernel_command_sequence():
    """Attention kernel: load→dot→math ops→dot→store."""

    def attention_kernel(q_ptr, k_ptr, v_ptr, out_ptr, tl, seq_len=16, head_dim=8):
        pid = tl.program_id(0)  # result is not used further in this kernel

        q_tile = tl.load(q_ptr, shape=(seq_len, head_dim), dtype="f16")
        k_tile = tl.load(k_ptr, shape=(head_dim, seq_len), dtype="f16")
        logits = tl.dot(q_tile, k_tile)

        # Softmax written out step by step: max, subtract, exp, sum, divide.
        row_max = tl.max(logits, axis=1)
        shifted = logits - row_max
        exponentiated = tl.exp(shifted)
        row_sum = tl.sum(exponentiated, axis=1)
        weights = exponentiated / row_sum

        v_tile = tl.load(v_ptr, shape=(seq_len, head_dim), dtype="f16")
        attended = tl.dot(weights, v_tile)
        tl.store(out_ptr, attended)

    tl = _ctx(pe_id=0)
    pointer_args = {
        "q_ptr": 0x1000,
        "k_ptr": 0x2000,
        "v_ptr": 0x3000,
        "out_ptr": 0x4000,
    }
    run_kernel(attention_kernel, tl, **pointer_args)

    # load, load, dot, max, sub, exp, sum, div, load, dot, store
    expected_types = (
        ["DmaReadCmd"] * 2  # load Q, K
        + ["GemmCmd"]  # Q @ K
        + ["MathCmd"] * 5  # max, sub, exp, sum, div
        + ["DmaReadCmd"]  # load V
        + ["GemmCmd"]  # scores @ V
        + ["DmaWriteCmd"]  # store output
    )
    recorded_types = [type(cmd).__name__ for cmd in tl.commands]
    assert recorded_types == expected_types

    # Verify math ops
    recorded_math_ops = [cmd.op for cmd in tl.commands if isinstance(cmd, MathCmd)]
    assert recorded_math_ops == ["max", "sub", "exp", "sum", "div"]


# ── 17. Dispatch overhead auto-inserted ───────────────────────────


def test_dispatch_overhead_inserted():
    """Each tl API call auto-inserts PeCpuOverheadCmd when dispatch_cycles > 0."""
    tl = _ctx_with_overhead()
    loaded = tl.load(0x1000, shape=(4, 4), dtype="f16")
    tl.store(0x2000, loaded)

    recorded_types = [type(cmd).__name__ for cmd in tl.commands]
    # overhead, load, overhead, store
    expected_types = [
        "PeCpuOverheadCmd",
        "DmaReadCmd",
        "PeCpuOverheadCmd",
        "DmaWriteCmd",
    ]
    assert recorded_types == expected_types


# ── 18.
# where operation ──────────────────────────────────────────


def test_tl_where():
    """tl.where returns a TensorHandle and records one three-input MathCmd."""
    tl = _ctx()
    condition = tl.load(0x1000, shape=(4, 4), dtype="i32")
    if_true = tl.load(0x2000, shape=(4, 4), dtype="f16")
    if_false = tl.load(0x3000, shape=(4, 4), dtype="f16")

    selected = tl.where(condition, if_true, if_false)
    assert isinstance(selected, TensorHandle)

    where_cmds = [cmd for cmd in tl.commands if isinstance(cmd, MathCmd)]
    assert len(where_cmds) == 1
    where_cmd = where_cmds[0]
    assert where_cmd.op == "where"
    assert len(where_cmd.inputs) == 3