commit - release 1
This commit is contained in:
@@ -0,0 +1,349 @@
|
||||
"""Tests for Triton emulator: TLContext, command generation, kernel registry."""
|
||||
from kernbench.common.pe_commands import (
|
||||
CompletionHandle,
|
||||
CompositeCmd,
|
||||
DmaReadCmd,
|
||||
DmaWriteCmd,
|
||||
GemmCmd,
|
||||
MathCmd,
|
||||
PeCpuOverheadCmd,
|
||||
TensorHandle,
|
||||
WaitCmd,
|
||||
)
|
||||
from kernbench.triton_emu.registry import clear_registry, get_kernel, register_kernel
|
||||
from kernbench.triton_emu.tl_context import TLContext, run_kernel
|
||||
|
||||
|
||||
def _ctx(**kwargs) -> TLContext:
|
||||
return TLContext(dispatch_cycles=0, **kwargs)
|
||||
|
||||
|
||||
def _ctx_with_overhead(**kwargs) -> TLContext:
|
||||
return TLContext(dispatch_cycles=1, **kwargs)
|
||||
|
||||
|
||||
# ── 1. tl.load → DmaReadCmd ──────────────────────────────────────
|
||||
|
||||
|
||||
def test_tl_load_generates_dma_read():
|
||||
tl = _ctx()
|
||||
h = tl.load(0x1000, shape=(32, 64), dtype="f16")
|
||||
assert isinstance(h, TensorHandle)
|
||||
assert h.shape == (32, 64)
|
||||
assert h.nbytes == 32 * 64 * 2
|
||||
cmds = tl.commands
|
||||
assert len(cmds) == 1
|
||||
assert isinstance(cmds[0], DmaReadCmd)
|
||||
assert cmds[0].src_pa == 0x1000
|
||||
assert cmds[0].nbytes == 32 * 64 * 2
|
||||
|
||||
|
||||
# ── 2. tl.store → DmaWriteCmd ────────────────────────────────────
|
||||
|
||||
|
||||
def test_tl_store_generates_dma_write():
|
||||
tl = _ctx()
|
||||
h = tl.load(0x1000, shape=(16, 16), dtype="f32")
|
||||
tl.store(0x2000, h)
|
||||
cmds = [c for c in tl.commands if isinstance(c, DmaWriteCmd)]
|
||||
assert len(cmds) == 1
|
||||
assert cmds[0].dst_pa == 0x2000
|
||||
assert cmds[0].nbytes == 16 * 16 * 4
|
||||
|
||||
|
||||
# ── 3. tl.dot → GemmCmd ──────────────────────────────────────────
|
||||
|
||||
|
||||
def test_tl_dot_generates_gemm_cmd():
|
||||
tl = _ctx()
|
||||
a = tl.load(0x1000, shape=(32, 64), dtype="f16")
|
||||
b = tl.load(0x2000, shape=(64, 16), dtype="f16")
|
||||
out = tl.dot(a, b)
|
||||
assert out.shape == (32, 16)
|
||||
cmds = [c for c in tl.commands if isinstance(c, GemmCmd)]
|
||||
assert len(cmds) == 1
|
||||
assert cmds[0].m == 32
|
||||
assert cmds[0].k == 64
|
||||
assert cmds[0].n == 16
|
||||
|
||||
|
||||
# ── 4. tl.exp, tl.sqrt etc. → MathCmd ────────────────────────────
|
||||
|
||||
|
||||
def test_tl_math_unary_ops():
|
||||
tl = _ctx()
|
||||
x = tl.load(0x1000, shape=(8, 8), dtype="f16")
|
||||
for op_name, op_fn in [
|
||||
("exp", tl.exp), ("log", tl.log), ("sqrt", tl.sqrt),
|
||||
("abs", tl.abs), ("sigmoid", tl.sigmoid),
|
||||
("cos", tl.cos), ("sin", tl.sin),
|
||||
]:
|
||||
result = op_fn(x)
|
||||
assert isinstance(result, TensorHandle)
|
||||
assert result.shape == x.shape
|
||||
|
||||
math_cmds = [c for c in tl.commands if isinstance(c, MathCmd)]
|
||||
ops = [c.op for c in math_cmds]
|
||||
assert ops == ["exp", "log", "sqrt", "abs", "sigmoid", "cos", "sin"]
|
||||
|
||||
|
||||
# ── 5. a + b, a * b → MathCmd ────────────────────────────────────
|
||||
|
||||
|
||||
def test_tl_math_binary_ops():
|
||||
tl = _ctx()
|
||||
a = tl.load(0x1000, shape=(4, 4), dtype="f16")
|
||||
b = tl.load(0x2000, shape=(4, 4), dtype="f16")
|
||||
r1 = run_kernel(lambda tl: None, tl) # activate context for operators
|
||||
|
||||
# Need active context for operators
|
||||
tl2 = _ctx()
|
||||
a2 = tl2.load(0x1000, shape=(4, 4), dtype="f16")
|
||||
b2 = tl2.load(0x2000, shape=(4, 4), dtype="f16")
|
||||
|
||||
def kernel(tl):
|
||||
pass
|
||||
|
||||
# Use run_kernel to activate context, then test operators
|
||||
tl3 = _ctx()
|
||||
|
||||
def binary_kernel(tl):
|
||||
a = tl.load(0x1000, shape=(4, 4), dtype="f16")
|
||||
b = tl.load(0x2000, shape=(4, 4), dtype="f16")
|
||||
_ = a + b
|
||||
_ = a - b
|
||||
_ = a * b
|
||||
_ = a / b
|
||||
|
||||
run_kernel(binary_kernel, tl3)
|
||||
math_cmds = [c for c in tl3.commands if isinstance(c, MathCmd)]
|
||||
ops = [c.op for c in math_cmds]
|
||||
assert ops == ["add", "sub", "mul", "div"]
|
||||
|
||||
|
||||
# ── 6. tl.sum, tl.max → MathCmd with axis ────────────────────────
|
||||
|
||||
|
||||
def test_tl_reduction_ops():
|
||||
tl = _ctx()
|
||||
x = tl.load(0x1000, shape=(32, 64), dtype="f16")
|
||||
s = tl.sum(x, axis=1)
|
||||
m = tl.max(x, axis=0)
|
||||
assert s.shape == (32, 1)
|
||||
assert m.shape == (1, 64)
|
||||
math_cmds = [c for c in tl.commands if isinstance(c, MathCmd)]
|
||||
assert math_cmds[0].op == "sum" and math_cmds[0].axis == 1
|
||||
assert math_cmds[1].op == "max" and math_cmds[1].axis == 0
|
||||
|
||||
|
||||
# ── 7. tl.composite → CompositeCmd + CompletionHandle ────────────
|
||||
|
||||
|
||||
def test_tl_composite_nonblocking():
|
||||
tl = _ctx()
|
||||
a = tl.load(0x1000, shape=(32, 64), dtype="f16")
|
||||
b = tl.load(0x2000, shape=(64, 32), dtype="f16")
|
||||
h = tl.composite(op="gemm", a=a, b=b, out_ptr=0x3000)
|
||||
assert isinstance(h, CompletionHandle)
|
||||
comp_cmds = [c for c in tl.commands if isinstance(c, CompositeCmd)]
|
||||
assert len(comp_cmds) == 1
|
||||
assert comp_cmds[0].op == "gemm"
|
||||
assert comp_cmds[0].out_pa == 0x3000
|
||||
assert comp_cmds[0].out_nbytes == 32 * 32 * 2 # M×N×dtype_bytes
|
||||
|
||||
|
||||
# ── 8. tl.wait(handle) → WaitCmd ─────────────────────────────────
|
||||
|
||||
|
||||
def test_tl_wait_specific():
|
||||
tl = _ctx()
|
||||
a = tl.load(0x1000, shape=(4, 4), dtype="f16")
|
||||
h = tl.composite(op="gemm", a=a, b=a, out_ptr=0x2000)
|
||||
tl.wait(h)
|
||||
wait_cmds = [c for c in tl.commands if isinstance(c, WaitCmd)]
|
||||
assert len(wait_cmds) == 1
|
||||
assert wait_cmds[0].handle == h
|
||||
|
||||
|
||||
# ── 9. tl.wait() → WaitCmd(handle=None) ──────────────────────────
|
||||
|
||||
|
||||
def test_tl_wait_all():
|
||||
tl = _ctx()
|
||||
tl.wait()
|
||||
wait_cmds = [c for c in tl.commands if isinstance(c, WaitCmd)]
|
||||
assert len(wait_cmds) == 1
|
||||
assert wait_cmds[0].handle is None
|
||||
|
||||
|
||||
# ── 10. tl.cycles → PeCpuOverheadCmd ─────────────────────────────
|
||||
|
||||
|
||||
def test_tl_cycles():
|
||||
tl = _ctx()
|
||||
tl.cycles(10)
|
||||
assert len(tl.commands) == 1
|
||||
assert isinstance(tl.commands[0], PeCpuOverheadCmd)
|
||||
assert tl.commands[0].cycles == 10
|
||||
|
||||
|
||||
# ── 11. tl.program_id ────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_tl_program_id():
|
||||
tl = TLContext(pe_id=5, num_programs=8)
|
||||
assert tl.program_id(0) == 5
|
||||
assert tl.num_programs(0) == 8
|
||||
|
||||
|
||||
# ── 12. tl.arange, tl.zeros, tl.full ─────────────────────────────
|
||||
|
||||
|
||||
def test_tl_arange_zeros_full():
|
||||
tl = _ctx()
|
||||
r = tl.arange(0, 16, dtype="i32")
|
||||
assert r.shape == (16,)
|
||||
assert r.dtype == "i32"
|
||||
|
||||
z = tl.zeros((4, 8), dtype="f16")
|
||||
assert z.shape == (4, 8)
|
||||
assert z.nbytes == 4 * 8 * 2
|
||||
|
||||
f = tl.full((2, 3), value=1.0, dtype="f32")
|
||||
assert f.shape == (2, 3)
|
||||
assert f.nbytes == 2 * 3 * 4
|
||||
|
||||
|
||||
# ── 13. tl.trans → shape change, no command ───────────────────────
|
||||
|
||||
|
||||
def test_tl_trans_shape():
|
||||
tl = _ctx()
|
||||
h = tl.load(0x1000, shape=(32, 64), dtype="f16")
|
||||
t = tl.trans(h)
|
||||
assert t.shape == (64, 32)
|
||||
assert t.id == h.id # same underlying data
|
||||
# Only DmaReadCmd from load, no command from trans
|
||||
assert len(tl.commands) == 1
|
||||
assert isinstance(tl.commands[0], DmaReadCmd)
|
||||
|
||||
|
||||
# ── 14. Kernel registry ──────────────────────────────────────────
|
||||
|
||||
|
||||
def test_kernel_registry():
|
||||
clear_registry()
|
||||
|
||||
def my_kernel(tl):
|
||||
pass
|
||||
|
||||
register_kernel("test_kern", my_kernel)
|
||||
assert get_kernel("test_kern") is my_kernel
|
||||
clear_registry()
|
||||
|
||||
|
||||
def test_kernel_registry_missing():
|
||||
clear_registry()
|
||||
import pytest
|
||||
with pytest.raises(KeyError):
|
||||
get_kernel("nonexistent")
|
||||
|
||||
|
||||
def test_kernel_registry_duplicate():
|
||||
clear_registry()
|
||||
register_kernel("dup", lambda tl: None)
|
||||
import pytest
|
||||
with pytest.raises(ValueError):
|
||||
register_kernel("dup", lambda tl: None)
|
||||
clear_registry()
|
||||
|
||||
|
||||
# ── 15. GEMM kernel → correct command sequence ───────────────────
|
||||
|
||||
|
||||
def test_gemm_kernel_command_sequence():
|
||||
"""32×64 × 64×32 GEMM kernel produces [DmaRead, DmaRead, Composite]."""
|
||||
def gemm_kernel(a_ptr, b_ptr, out_ptr, tl):
|
||||
pid = tl.program_id(0)
|
||||
a = tl.load(a_ptr, shape=(32, 64), dtype="f16")
|
||||
b = tl.load(b_ptr + pid * 64 * 32 * 2, shape=(64, 32), dtype="f16")
|
||||
tl.composite(op="gemm", a=a, b=b, out_ptr=out_ptr + pid * 32 * 32 * 2)
|
||||
|
||||
tl = _ctx(pe_id=3)
|
||||
run_kernel(gemm_kernel, tl, a_ptr=0x1000, b_ptr=0x2000, out_ptr=0x3000)
|
||||
types = [type(c).__name__ for c in tl.commands]
|
||||
assert types == ["DmaReadCmd", "DmaReadCmd", "CompositeCmd"]
|
||||
|
||||
|
||||
# ── 16. Attention kernel → correct command sequence ───────────────
|
||||
|
||||
|
||||
def test_attention_kernel_command_sequence():
|
||||
"""Attention kernel: load→dot→math ops→dot→store."""
|
||||
def attention_kernel(q_ptr, k_ptr, v_ptr, out_ptr, tl,
|
||||
seq_len=16, head_dim=8):
|
||||
pid = tl.program_id(0)
|
||||
q = tl.load(q_ptr, shape=(seq_len, head_dim), dtype="f16")
|
||||
k = tl.load(k_ptr, shape=(head_dim, seq_len), dtype="f16")
|
||||
scores = tl.dot(q, k)
|
||||
row_max = tl.max(scores, axis=1)
|
||||
scores = scores - row_max
|
||||
scores = tl.exp(scores)
|
||||
row_sum = tl.sum(scores, axis=1)
|
||||
scores = scores / row_sum
|
||||
v = tl.load(v_ptr, shape=(seq_len, head_dim), dtype="f16")
|
||||
out = tl.dot(scores, v)
|
||||
tl.store(out_ptr, out)
|
||||
|
||||
tl = _ctx(pe_id=0)
|
||||
run_kernel(
|
||||
attention_kernel, tl,
|
||||
q_ptr=0x1000, k_ptr=0x2000, v_ptr=0x3000, out_ptr=0x4000,
|
||||
)
|
||||
types = [type(c).__name__ for c in tl.commands]
|
||||
# load, load, dot, max, sub, exp, sum, div, load, dot, store
|
||||
assert types == [
|
||||
"DmaReadCmd", "DmaReadCmd", # load Q, K
|
||||
"GemmCmd", # Q @ K
|
||||
"MathCmd", "MathCmd", "MathCmd", # max, sub, exp
|
||||
"MathCmd", "MathCmd", # sum, div
|
||||
"DmaReadCmd", # load V
|
||||
"GemmCmd", # scores @ V
|
||||
"DmaWriteCmd", # store output
|
||||
]
|
||||
# Verify math ops
|
||||
math_cmds = [c for c in tl.commands if isinstance(c, MathCmd)]
|
||||
math_ops = [c.op for c in math_cmds]
|
||||
assert math_ops == ["max", "sub", "exp", "sum", "div"]
|
||||
|
||||
|
||||
# ── 17. Dispatch overhead auto-inserted ───────────────────────────
|
||||
|
||||
|
||||
def test_dispatch_overhead_inserted():
|
||||
"""Each tl API call auto-inserts PeCpuOverheadCmd when dispatch_cycles > 0."""
|
||||
tl = _ctx_with_overhead()
|
||||
a = tl.load(0x1000, shape=(4, 4), dtype="f16")
|
||||
tl.store(0x2000, a)
|
||||
types = [type(c).__name__ for c in tl.commands]
|
||||
# overhead, load, overhead, store
|
||||
assert types == [
|
||||
"PeCpuOverheadCmd", "DmaReadCmd",
|
||||
"PeCpuOverheadCmd", "DmaWriteCmd",
|
||||
]
|
||||
|
||||
|
||||
# ── 18. where operation ──────────────────────────────────────────
|
||||
|
||||
|
||||
def test_tl_where():
|
||||
tl = _ctx()
|
||||
cond = tl.load(0x1000, shape=(4, 4), dtype="i32")
|
||||
a = tl.load(0x2000, shape=(4, 4), dtype="f16")
|
||||
b = tl.load(0x3000, shape=(4, 4), dtype="f16")
|
||||
out = tl.where(cond, a, b)
|
||||
assert isinstance(out, TensorHandle)
|
||||
math_cmds = [c for c in tl.commands if isinstance(c, MathCmd)]
|
||||
assert len(math_cmds) == 1
|
||||
assert math_cmds[0].op == "where"
|
||||
assert len(math_cmds[0].inputs) == 3
|
||||
Reference in New Issue
Block a user