commit - release 1

This commit is contained in:
2026-03-18 11:47:48 -07:00
commit 6f43807900
109 changed files with 14909 additions and 0 deletions
+349
View File
@@ -0,0 +1,349 @@
"""Tests for Triton emulator: TLContext, command generation, kernel registry."""
from kernbench.common.pe_commands import (
CompletionHandle,
CompositeCmd,
DmaReadCmd,
DmaWriteCmd,
GemmCmd,
MathCmd,
PeCpuOverheadCmd,
TensorHandle,
WaitCmd,
)
from kernbench.triton_emu.registry import clear_registry, get_kernel, register_kernel
from kernbench.triton_emu.tl_context import TLContext, run_kernel
def _ctx(**kwargs) -> TLContext:
    """Build a TLContext with ``dispatch_cycles=0``.

    Any extra keyword arguments are forwarded to ``TLContext`` unchanged.
    """
    context = TLContext(dispatch_cycles=0, **kwargs)
    return context
def _ctx_with_overhead(**kwargs) -> TLContext:
    """Build a TLContext with ``dispatch_cycles=1``.

    Any extra keyword arguments are forwarded to ``TLContext`` unchanged.
    """
    context = TLContext(dispatch_cycles=1, **kwargs)
    return context
# ── 1. tl.load → DmaReadCmd ──────────────────────────────────────
def test_tl_load_generates_dma_read():
    """tl.load returns a TensorHandle and records exactly one DmaReadCmd."""
    ctx = _ctx()
    handle = ctx.load(0x1000, shape=(32, 64), dtype="f16")

    # The returned handle describes the loaded tile.
    assert isinstance(handle, TensorHandle)
    assert handle.shape == (32, 64)
    assert handle.nbytes == 32 * 64 * 2

    # The only recorded command is the read from the source address.
    recorded = ctx.commands
    assert len(recorded) == 1
    read_cmd = recorded[0]
    assert isinstance(read_cmd, DmaReadCmd)
    assert read_cmd.src_pa == 0x1000
    assert read_cmd.nbytes == 32 * 64 * 2
# ── 2. tl.store → DmaWriteCmd ────────────────────────────────────
def test_tl_store_generates_dma_write():
    """tl.store of a loaded handle records exactly one DmaWriteCmd."""
    ctx = _ctx()
    tile = ctx.load(0x1000, shape=(16, 16), dtype="f32")
    ctx.store(0x2000, tile)

    # Filter out the DmaReadCmd produced by the load above.
    writes = [cmd for cmd in ctx.commands if isinstance(cmd, DmaWriteCmd)]
    assert len(writes) == 1
    write_cmd = writes[0]
    assert write_cmd.dst_pa == 0x2000
    assert write_cmd.nbytes == 16 * 16 * 4
# ── 3. tl.dot → GemmCmd ──────────────────────────────────────────
def test_tl_dot_generates_gemm_cmd():
    """tl.dot of (32, 64) x (64, 16) yields a (32, 16) handle and one GemmCmd."""
    ctx = _ctx()
    lhs = ctx.load(0x1000, shape=(32, 64), dtype="f16")
    rhs = ctx.load(0x2000, shape=(64, 16), dtype="f16")

    product = ctx.dot(lhs, rhs)
    assert product.shape == (32, 16)

    gemms = [cmd for cmd in ctx.commands if isinstance(cmd, GemmCmd)]
    assert len(gemms) == 1
    gemm = gemms[0]
    assert gemm.m == 32
    assert gemm.k == 64
    assert gemm.n == 16
# ── 4. tl.exp, tl.sqrt etc. → MathCmd ────────────────────────────
def test_tl_math_unary_ops():
    """Each unary math helper returns a same-shape handle and records a MathCmd.

    The recorded ``op`` names must appear in the order the helpers were called.
    """
    ctx = _ctx()
    x = ctx.load(0x1000, shape=(8, 8), dtype="f16")

    unary_fns = [
        ctx.exp, ctx.log, ctx.sqrt,
        ctx.abs, ctx.sigmoid,
        ctx.cos, ctx.sin,
    ]
    for unary in unary_fns:
        result = unary(x)
        assert isinstance(result, TensorHandle)
        assert result.shape == x.shape

    recorded_ops = [cmd.op for cmd in ctx.commands if isinstance(cmd, MathCmd)]
    assert recorded_ops == ["exp", "log", "sqrt", "abs", "sigmoid", "cos", "sin"]
# ── 5. a + b, a * b → MathCmd ────────────────────────────────────
def test_tl_math_binary_ops():
    """Arithmetic operators on tensor handles record add/sub/mul/div MathCmds.

    The operators are exercised inside a kernel executed through
    ``run_kernel``; per the original author's note, the operators need an
    active context, which ``run_kernel`` provides.
    """
    ctx = _ctx()

    def binary_kernel(tl):
        a = tl.load(0x1000, shape=(4, 4), dtype="f16")
        b = tl.load(0x2000, shape=(4, 4), dtype="f16")
        # Results are discarded; only the recorded commands matter.
        _ = a + b
        _ = a - b
        _ = a * b
        _ = a / b

    run_kernel(binary_kernel, ctx)

    ops = [cmd.op for cmd in ctx.commands if isinstance(cmd, MathCmd)]
    assert ops == ["add", "sub", "mul", "div"]
# ── 6. tl.sum, tl.max → MathCmd with axis ────────────────────────
def test_tl_reduction_ops():
    """tl.sum / tl.max keep the reduced axis as size 1 and record the axis."""
    ctx = _ctx()
    x = ctx.load(0x1000, shape=(32, 64), dtype="f16")

    row_sums = ctx.sum(x, axis=1)
    col_maxes = ctx.max(x, axis=0)
    assert row_sums.shape == (32, 1)
    assert col_maxes.shape == (1, 64)

    reductions = [cmd for cmd in ctx.commands if isinstance(cmd, MathCmd)]
    sum_cmd = reductions[0]
    assert sum_cmd.op == "sum"
    assert sum_cmd.axis == 1
    max_cmd = reductions[1]
    assert max_cmd.op == "max"
    assert max_cmd.axis == 0
# ── 7. tl.composite → CompositeCmd + CompletionHandle ────────────
def test_tl_composite_nonblocking():
    """tl.composite returns a CompletionHandle and records one CompositeCmd."""
    ctx = _ctx()
    lhs = ctx.load(0x1000, shape=(32, 64), dtype="f16")
    rhs = ctx.load(0x2000, shape=(64, 32), dtype="f16")

    completion = ctx.composite(op="gemm", a=lhs, b=rhs, out_ptr=0x3000)
    assert isinstance(completion, CompletionHandle)

    composites = [cmd for cmd in ctx.commands if isinstance(cmd, CompositeCmd)]
    assert len(composites) == 1
    composite = composites[0]
    assert composite.op == "gemm"
    assert composite.out_pa == 0x3000
    # Output is M x N elements times the dtype size in bytes.
    assert composite.out_nbytes == 32 * 32 * 2
# ── 8. tl.wait(handle) → WaitCmd ─────────────────────────────────
def test_tl_wait_specific():
    """tl.wait(handle) records a single WaitCmd carrying that handle."""
    ctx = _ctx()
    tile = ctx.load(0x1000, shape=(4, 4), dtype="f16")
    completion = ctx.composite(op="gemm", a=tile, b=tile, out_ptr=0x2000)

    ctx.wait(completion)

    waits = [cmd for cmd in ctx.commands if isinstance(cmd, WaitCmd)]
    assert len(waits) == 1
    assert waits[0].handle == completion
# ── 9. tl.wait() → WaitCmd(handle=None) ──────────────────────────
def test_tl_wait_all():
    """tl.wait() with no argument records a WaitCmd whose handle is None."""
    ctx = _ctx()
    ctx.wait()

    waits = [cmd for cmd in ctx.commands if isinstance(cmd, WaitCmd)]
    assert len(waits) == 1
    only_wait = waits[0]
    assert only_wait.handle is None
# ── 10. tl.cycles → PeCpuOverheadCmd ─────────────────────────────
def test_tl_cycles():
    """tl.cycles(n) records one PeCpuOverheadCmd with ``cycles == n``."""
    ctx = _ctx()
    ctx.cycles(10)

    assert len(ctx.commands) == 1
    overhead = ctx.commands[0]
    assert isinstance(overhead, PeCpuOverheadCmd)
    assert overhead.cycles == 10
# ── 11. tl.program_id ────────────────────────────────────────────
def test_tl_program_id():
    """program_id(0) / num_programs(0) return the constructor's pe_id / num_programs."""
    ctx = TLContext(pe_id=5, num_programs=8)

    pid = ctx.program_id(0)
    assert pid == 5

    program_count = ctx.num_programs(0)
    assert program_count == 8
# ── 12. tl.arange, tl.zeros, tl.full ─────────────────────────────
def test_tl_arange_zeros_full():
    """arange, zeros and full return handles with the requested shape metadata."""
    ctx = _ctx()

    # arange(0, 16) -> 1-D handle of 16 elements with the requested dtype.
    index_range = ctx.arange(0, 16, dtype="i32")
    assert index_range.shape == (16,)
    assert index_range.dtype == "i32"

    # zeros -> byte size is element count times 2 for f16.
    zero_tile = ctx.zeros((4, 8), dtype="f16")
    assert zero_tile.shape == (4, 8)
    assert zero_tile.nbytes == 4 * 8 * 2

    # full -> byte size is element count times 4 for f32.
    filled_tile = ctx.full((2, 3), value=1.0, dtype="f32")
    assert filled_tile.shape == (2, 3)
    assert filled_tile.nbytes == 2 * 3 * 4
# ── 13. tl.trans → shape change, no command ───────────────────────
def test_tl_trans_shape():
    """tl.trans swaps the shape, keeps the handle id, and records no command."""
    ctx = _ctx()
    source = ctx.load(0x1000, shape=(32, 64), dtype="f16")

    transposed = ctx.trans(source)
    assert transposed.shape == (64, 32)
    # Same id means the transposed handle refers to the same underlying data.
    assert transposed.id == source.id

    # The load's DmaReadCmd is the only command; trans added nothing.
    recorded = ctx.commands
    assert len(recorded) == 1
    assert isinstance(recorded[0], DmaReadCmd)
# ── 14. Kernel registry ──────────────────────────────────────────
def test_kernel_registry():
    """A registered kernel is returned by get_kernel as the identical object.

    The registry is cleared before the test, and again in a ``finally``
    block so the "test_kern" entry is removed even if registration or the
    assertion fails.
    """
    clear_registry()

    def my_kernel(tl):
        pass

    try:
        register_kernel("test_kern", my_kernel)
        assert get_kernel("test_kern") is my_kernel
    finally:
        clear_registry()
def test_kernel_registry_missing():
    """get_kernel raises KeyError for a name that is not in the registry."""
    import pytest

    clear_registry()
    with pytest.raises(KeyError):
        get_kernel("nonexistent")
def test_kernel_registry_duplicate():
    """Registering the same kernel name twice raises ValueError.

    Cleanup runs in a ``finally`` block so the "dup" entry is removed from
    the registry even when the expected exception is not raised.
    """
    import pytest

    clear_registry()
    try:
        register_kernel("dup", lambda tl: None)
        with pytest.raises(ValueError):
            register_kernel("dup", lambda tl: None)
    finally:
        clear_registry()
# ── 15. GEMM kernel → correct command sequence ───────────────────
def test_gemm_kernel_command_sequence():
    """32×64 × 64×32 GEMM kernel produces [DmaRead, DmaRead, Composite]."""

    def gemm_kernel(a_ptr, b_ptr, out_ptr, tl):
        pid = tl.program_id(0)
        # Per-program byte offsets into the B operand and the output buffer.
        b_offset = pid * 64 * 32 * 2
        out_offset = pid * 32 * 32 * 2
        a = tl.load(a_ptr, shape=(32, 64), dtype="f16")
        b = tl.load(b_ptr + b_offset, shape=(64, 32), dtype="f16")
        tl.composite(op="gemm", a=a, b=b, out_ptr=out_ptr + out_offset)

    ctx = _ctx(pe_id=3)
    run_kernel(gemm_kernel, ctx, a_ptr=0x1000, b_ptr=0x2000, out_ptr=0x3000)

    emitted = [type(cmd).__name__ for cmd in ctx.commands]
    assert emitted == ["DmaReadCmd", "DmaReadCmd", "CompositeCmd"]
# ── 16. Attention kernel → correct command sequence ───────────────
def test_attention_kernel_command_sequence():
    """Attention kernel: load→dot→math ops→dot→store."""

    def attention_kernel(q_ptr, k_ptr, v_ptr, out_ptr, tl,
                         seq_len=16, head_dim=8):
        # The program id is queried but not used by this kernel.
        pid = tl.program_id(0)

        # scores = Q @ K
        q = tl.load(q_ptr, shape=(seq_len, head_dim), dtype="f16")
        k = tl.load(k_ptr, shape=(head_dim, seq_len), dtype="f16")
        scores = tl.dot(q, k)

        # Row-wise normalisation: subtract the max, exponentiate,
        # then divide by the row sum.
        row_max = tl.max(scores, axis=1)
        shifted = scores - row_max
        exponentiated = tl.exp(shifted)
        row_sum = tl.sum(exponentiated, axis=1)
        weights = exponentiated / row_sum

        # out = weights @ V, written back to out_ptr
        v = tl.load(v_ptr, shape=(seq_len, head_dim), dtype="f16")
        out = tl.dot(weights, v)
        tl.store(out_ptr, out)

    ctx = _ctx(pe_id=0)
    run_kernel(
        attention_kernel,
        ctx,
        q_ptr=0x1000,
        k_ptr=0x2000,
        v_ptr=0x3000,
        out_ptr=0x4000,
    )

    # load, load, dot, max, sub, exp, sum, div, load, dot, store
    expected_types = [
        "DmaReadCmd",   # load Q
        "DmaReadCmd",   # load K
        "GemmCmd",      # Q @ K
        "MathCmd",      # max
        "MathCmd",      # sub
        "MathCmd",      # exp
        "MathCmd",      # sum
        "MathCmd",      # div
        "DmaReadCmd",   # load V
        "GemmCmd",      # scores @ V
        "DmaWriteCmd",  # store output
    ]
    emitted = [type(cmd).__name__ for cmd in ctx.commands]
    assert emitted == expected_types

    # The math commands must appear in exactly this order.
    math_ops = [cmd.op for cmd in ctx.commands if isinstance(cmd, MathCmd)]
    assert math_ops == ["max", "sub", "exp", "sum", "div"]
# ── 17. Dispatch overhead auto-inserted ───────────────────────────
def test_dispatch_overhead_inserted():
    """Each tl API call auto-inserts PeCpuOverheadCmd when dispatch_cycles > 0."""
    ctx = _ctx_with_overhead()
    tile = ctx.load(0x1000, shape=(4, 4), dtype="f16")
    ctx.store(0x2000, tile)

    # Each API call is preceded by its own overhead command.
    expected_types = [
        "PeCpuOverheadCmd",
        "DmaReadCmd",
        "PeCpuOverheadCmd",
        "DmaWriteCmd",
    ]
    emitted = [type(cmd).__name__ for cmd in ctx.commands]
    assert emitted == expected_types
# ── 18. where operation ──────────────────────────────────────────
def test_tl_where():
    """tl.where returns a TensorHandle and records one three-input MathCmd."""
    ctx = _ctx()
    condition = ctx.load(0x1000, shape=(4, 4), dtype="i32")
    if_true = ctx.load(0x2000, shape=(4, 4), dtype="f16")
    if_false = ctx.load(0x3000, shape=(4, 4), dtype="f16")

    selected = ctx.where(condition, if_true, if_false)
    assert isinstance(selected, TensorHandle)

    math_cmds = [cmd for cmd in ctx.commands if isinstance(cmd, MathCmd)]
    assert len(math_cmds) == 1
    where_cmd = math_cmds[0]
    assert where_cmd.op == "where"
    assert len(where_cmd.inputs) == 3