08812eda58
Implement VA/MMU layer (ADR-0011 Phase 1) enabling Triton kernels to use contiguous virtual addresses on sharded tensors. Key changes: - PE_MMU component: hybrid inbox (MmuMapMsg) + sync translate() for PE_DMA - VirtualAllocator + PEMemAllocator: free-list with coalescing - MmuMapMsg/MmuUnmapMsg fabric path with SIP-level routing - DPPolicy-based mapping: replicate=local, sharded=broadcast - Tensor lifecycle: del + weakref cleanup, context manager - Rename: TensorHandle.pa→addr, DmaReadCmd.src_pa→src_addr, ctx→torch Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
350 lines
12 KiB
Python
350 lines
12 KiB
Python
"""Tests for Triton emulator: TLContext, command generation, kernel registry."""
|
||
from kernbench.common.pe_commands import (
|
||
CompletionHandle,
|
||
CompositeCmd,
|
||
DmaReadCmd,
|
||
DmaWriteCmd,
|
||
GemmCmd,
|
||
MathCmd,
|
||
PeCpuOverheadCmd,
|
||
TensorHandle,
|
||
WaitCmd,
|
||
)
|
||
from kernbench.triton_emu.registry import clear_registry, get_kernel, register_kernel
|
||
from kernbench.triton_emu.tl_context import TLContext, run_kernel
|
||
|
||
|
||
def _ctx(**kwargs) -> TLContext:
    """Build a TLContext with per-call dispatch overhead disabled.

    With ``dispatch_cycles=0`` no PeCpuOverheadCmd is auto-inserted, so the
    recorded command list contains only what each tl API call emits.
    Any extra keyword arguments are forwarded to TLContext unchanged.
    """
    context = TLContext(dispatch_cycles=0, **kwargs)
    return context
|
||
|
||
|
||
def _ctx_with_overhead(**kwargs) -> TLContext:
    """Build a TLContext in which every tl API call costs one dispatch cycle.

    Used to check that PeCpuOverheadCmd entries are auto-inserted when
    ``dispatch_cycles > 0``. Extra keyword arguments go to TLContext as-is.
    """
    context = TLContext(dispatch_cycles=1, **kwargs)
    return context
|
||
|
||
|
||
# ── 1. tl.load → DmaReadCmd ──────────────────────────────────────
|
||
|
||
|
||
def test_tl_load_generates_dma_read():
    """tl.load returns a TensorHandle and records exactly one DmaReadCmd."""
    ctx = _ctx()
    expected_bytes = 32 * 64 * 2  # 32x64 elements of f16 (2 bytes each)

    handle = ctx.load(0x1000, shape=(32, 64), dtype="f16")
    assert isinstance(handle, TensorHandle)
    assert handle.shape == (32, 64)
    assert handle.nbytes == expected_bytes

    recorded = ctx.commands
    assert len(recorded) == 1
    read_cmd = recorded[0]
    assert isinstance(read_cmd, DmaReadCmd)
    assert read_cmd.src_addr == 0x1000
    assert read_cmd.nbytes == expected_bytes
|
||
|
||
|
||
# ── 2. tl.store → DmaWriteCmd ────────────────────────────────────
|
||
|
||
|
||
def test_tl_store_generates_dma_write():
    """tl.store records one DmaWriteCmd aimed at the destination pointer."""
    ctx = _ctx()
    tile = ctx.load(0x1000, shape=(16, 16), dtype="f32")
    ctx.store(0x2000, tile)

    writes = [cmd for cmd in ctx.commands if isinstance(cmd, DmaWriteCmd)]
    assert len(writes) == 1
    write = writes[0]
    assert write.dst_addr == 0x2000
    assert write.nbytes == 16 * 16 * 4  # 16x16 elements of f32 (4 bytes each)
|
||
|
||
|
||
# ── 3. tl.dot → GemmCmd ──────────────────────────────────────────
|
||
|
||
|
||
def test_tl_dot_generates_gemm_cmd():
    """tl.dot of (M, K) x (K, N) yields an (M, N) handle and one GemmCmd."""
    ctx = _ctx()
    rows, inner, cols = 32, 64, 16
    lhs = ctx.load(0x1000, shape=(rows, inner), dtype="f16")
    rhs = ctx.load(0x2000, shape=(inner, cols), dtype="f16")

    product = ctx.dot(lhs, rhs)
    assert product.shape == (rows, cols)

    gemms = [cmd for cmd in ctx.commands if isinstance(cmd, GemmCmd)]
    assert len(gemms) == 1
    gemm = gemms[0]
    assert gemm.m == rows
    assert gemm.k == inner
    assert gemm.n == cols
|
||
|
||
|
||
# ── 4. tl.exp, tl.sqrt etc. → MathCmd ────────────────────────────
|
||
|
||
|
||
def test_tl_math_unary_ops():
    """Each unary tl math op keeps the input shape and emits one MathCmd, in call order."""
    ctx = _ctx()
    src = ctx.load(0x1000, shape=(8, 8), dtype="f16")

    expected_ops = ["exp", "log", "sqrt", "abs", "sigmoid", "cos", "sin"]
    for name in expected_ops:
        # The tl method name matches the MathCmd.op string it should emit.
        out = getattr(ctx, name)(src)
        assert isinstance(out, TensorHandle)
        assert out.shape == src.shape

    recorded_ops = [cmd.op for cmd in ctx.commands if isinstance(cmd, MathCmd)]
    assert recorded_ops == expected_ops
|
||
|
||
|
||
# ── 5. a + b, a - b, a * b, a / b → MathCmd ──────────────────────
|
||
|
||
|
||
def test_tl_math_binary_ops():
    """Python arithmetic operators on TensorHandles emit MathCmds in call order.

    The operator overloads need an active context, so the loads and the
    arithmetic run inside a kernel executed through ``run_kernel``.
    """
    tl = _ctx()

    def binary_kernel(tl):
        a = tl.load(0x1000, shape=(4, 4), dtype="f16")
        b = tl.load(0x2000, shape=(4, 4), dtype="f16")
        _ = a + b
        _ = a - b
        _ = a * b
        _ = a / b

    run_kernel(binary_kernel, tl)

    ops = [c.op for c in tl.commands if isinstance(c, MathCmd)]
    assert ops == ["add", "sub", "mul", "div"]
|
||
|
||
|
||
# ── 6. tl.sum, tl.max → MathCmd with axis ────────────────────────
|
||
|
||
|
||
def test_tl_reduction_ops():
    """tl.sum / tl.max keep the reduced axis as size 1 and emit MathCmds carrying that axis."""
    tl = _ctx()
    x = tl.load(0x1000, shape=(32, 64), dtype="f16")
    s = tl.sum(x, axis=1)
    m = tl.max(x, axis=0)
    assert s.shape == (32, 1)
    assert m.shape == (1, 64)

    math_cmds = [c for c in tl.commands if isinstance(c, MathCmd)]
    # Exactly the two reductions. Checked before indexing so a missing or
    # extra command fails with a clear assertion rather than an IndexError.
    assert len(math_cmds) == 2
    assert math_cmds[0].op == "sum" and math_cmds[0].axis == 1
    assert math_cmds[1].op == "max" and math_cmds[1].axis == 0
|
||
|
||
|
||
# ── 7. tl.composite → CompositeCmd + CompletionHandle ────────────
|
||
|
||
|
||
def test_tl_composite_nonblocking():
    """tl.composite returns a CompletionHandle and records one CompositeCmd."""
    ctx = _ctx()
    lhs = ctx.load(0x1000, shape=(32, 64), dtype="f16")
    rhs = ctx.load(0x2000, shape=(64, 32), dtype="f16")

    done = ctx.composite(op="gemm", a=lhs, b=rhs, out_ptr=0x3000)
    assert isinstance(done, CompletionHandle)

    composites = [cmd for cmd in ctx.commands if isinstance(cmd, CompositeCmd)]
    assert len(composites) == 1
    comp = composites[0]
    assert comp.op == "gemm"
    assert comp.out_addr == 0x3000
    # Output tile is M x N = 32 x 32 elements of f16 (2 bytes each).
    assert comp.out_nbytes == 32 * 32 * 2
|
||
|
||
|
||
# ── 8. tl.wait(handle) → WaitCmd ─────────────────────────────────
|
||
|
||
|
||
def test_tl_wait_specific():
    """tl.wait(handle) records a single WaitCmd bound to that handle."""
    ctx = _ctx()
    tile = ctx.load(0x1000, shape=(4, 4), dtype="f16")
    pending = ctx.composite(op="gemm", a=tile, b=tile, out_ptr=0x2000)
    ctx.wait(pending)

    waits = [cmd for cmd in ctx.commands if isinstance(cmd, WaitCmd)]
    assert len(waits) == 1
    assert waits[0].handle == pending
|
||
|
||
|
||
# ── 9. tl.wait() → WaitCmd(handle=None) ──────────────────────────
|
||
|
||
|
||
def test_tl_wait_all():
    """tl.wait() with no argument records one WaitCmd whose handle is None."""
    ctx = _ctx()
    ctx.wait()

    waits = [cmd for cmd in ctx.commands if isinstance(cmd, WaitCmd)]
    assert len(waits) == 1
    assert waits[0].handle is None
|
||
|
||
|
||
# ── 10. tl.cycles → PeCpuOverheadCmd ─────────────────────────────
|
||
|
||
|
||
def test_tl_cycles():
    """tl.cycles(n) records exactly one PeCpuOverheadCmd carrying n cycles."""
    ctx = _ctx()
    ctx.cycles(10)

    recorded = ctx.commands
    assert len(recorded) == 1
    only_cmd = recorded[0]
    assert isinstance(only_cmd, PeCpuOverheadCmd)
    assert only_cmd.cycles == 10
|
||
|
||
|
||
# ── 11. tl.program_id ────────────────────────────────────────────
|
||
|
||
|
||
def test_tl_program_id():
    """program_id(0) / num_programs(0) report the pe_id and num_programs given at construction."""
    ctx = TLContext(pe_id=5, num_programs=8)
    assert ctx.program_id(0) == 5
    assert ctx.num_programs(0) == 8
|
||
|
||
|
||
# ── 12. tl.arange, tl.zeros, tl.full ─────────────────────────────
|
||
|
||
|
||
def test_tl_arange_zeros_full():
    """arange / zeros / full produce handles with the requested shape, dtype and byte size."""
    ctx = _ctx()

    indices = ctx.arange(0, 16, dtype="i32")
    assert indices.shape == (16,)
    assert indices.dtype == "i32"

    zero_tile = ctx.zeros((4, 8), dtype="f16")
    assert zero_tile.shape == (4, 8)
    assert zero_tile.nbytes == 4 * 8 * 2  # f16: 2 bytes per element

    filled = ctx.full((2, 3), value=1.0, dtype="f32")
    assert filled.shape == (2, 3)
    assert filled.nbytes == 2 * 3 * 4  # f32: 4 bytes per element
|
||
|
||
|
||
# ── 13. tl.trans → shape change, no command ───────────────────────
|
||
|
||
|
||
def test_tl_trans_shape():
    """tl.trans swaps the shape, keeps the handle id, and emits no command."""
    ctx = _ctx()
    src = ctx.load(0x1000, shape=(32, 64), dtype="f16")

    flipped = ctx.trans(src)
    assert flipped.shape == (64, 32)
    assert flipped.id == src.id  # same underlying data

    # The load's DmaReadCmd is the only recorded command; trans adds none.
    recorded = ctx.commands
    assert len(recorded) == 1
    assert isinstance(recorded[0], DmaReadCmd)
|
||
|
||
|
||
# ── 14. Kernel registry ──────────────────────────────────────────
|
||
|
||
|
||
def test_kernel_registry():
    """A kernel registered under a name is returned unchanged by get_kernel."""
    clear_registry()

    def my_kernel(tl):
        pass

    try:
        register_kernel("test_kern", my_kernel)
        assert get_kernel("test_kern") is my_kernel
    finally:
        # Reset the shared registry even when the assertion fails, so
        # "test_kern" cannot leak into later tests.
        clear_registry()
|
||
|
||
|
||
def test_kernel_registry_missing():
    """Looking up a name that was never registered raises KeyError."""
    import pytest

    clear_registry()
    with pytest.raises(KeyError):
        get_kernel("nonexistent")
|
||
|
||
|
||
def test_kernel_registry_duplicate():
    """Registering a second kernel under an existing name raises ValueError."""
    import pytest

    clear_registry()
    try:
        register_kernel("dup", lambda tl: None)
        with pytest.raises(ValueError):
            register_kernel("dup", lambda tl: None)
    finally:
        # Reset the shared registry even if the expectation fails, so "dup"
        # does not leak into later tests.
        clear_registry()
|
||
|
||
|
||
# ── 15. GEMM kernel → correct command sequence ───────────────────
|
||
|
||
|
||
def test_gemm_kernel_command_sequence():
    """32×64 × 64×32 GEMM kernel produces [DmaRead, DmaRead, Composite].

    Also checks that the program_id-based pointer offsets computed in the
    kernel reach the emitted commands.
    """
    def gemm_kernel(a_ptr, b_ptr, out_ptr, tl):
        pid = tl.program_id(0)
        a = tl.load(a_ptr, shape=(32, 64), dtype="f16")
        b = tl.load(b_ptr + pid * 64 * 32 * 2, shape=(64, 32), dtype="f16")
        tl.composite(op="gemm", a=a, b=b, out_ptr=out_ptr + pid * 32 * 32 * 2)

    pe_id = 3
    tl = _ctx(pe_id=pe_id)
    run_kernel(gemm_kernel, tl, a_ptr=0x1000, b_ptr=0x2000, out_ptr=0x3000)

    cmds = tl.commands
    types = [type(c).__name__ for c in cmds]
    assert types == ["DmaReadCmd", "DmaReadCmd", "CompositeCmd"]

    read_a, read_b, comp = cmds
    # The kernel reads A at a_ptr unchanged; B and the output are offset
    # by program_id(0), which equals the context's pe_id.
    assert read_a.src_addr == 0x1000
    assert read_b.src_addr == 0x2000 + pe_id * 64 * 32 * 2
    assert comp.out_addr == 0x3000 + pe_id * 32 * 32 * 2
|
||
|
||
|
||
# ── 16. Attention kernel → correct command sequence ───────────────
|
||
|
||
|
||
def test_attention_kernel_command_sequence():
    """Attention kernel emits load, load, dot, five math ops, load, dot, store."""

    def attention_kernel(q_ptr, k_ptr, v_ptr, out_ptr, tl,
                         seq_len=16, head_dim=8):
        # Queried but unused; the expected command list has no entry for it.
        _ = tl.program_id(0)
        query = tl.load(q_ptr, shape=(seq_len, head_dim), dtype="f16")
        key_t = tl.load(k_ptr, shape=(head_dim, seq_len), dtype="f16")
        logits = tl.dot(query, key_t)
        # Row-wise softmax: subtract the row max, exponentiate, then
        # normalise by the row sum (emits max, sub, exp, sum, div).
        logits = tl.exp(logits - tl.max(logits, axis=1))
        weights = logits / tl.sum(logits, axis=1)
        value = tl.load(v_ptr, shape=(seq_len, head_dim), dtype="f16")
        tl.store(out_ptr, tl.dot(weights, value))

    ctx = _ctx(pe_id=0)
    run_kernel(
        attention_kernel, ctx,
        q_ptr=0x1000, k_ptr=0x2000, v_ptr=0x3000, out_ptr=0x4000,
    )

    expected_types = (
        ["DmaReadCmd", "DmaReadCmd"]  # load Q, K
        + ["GemmCmd"]                 # Q @ K
        + ["MathCmd"] * 5             # max, sub, exp, sum, div
        + ["DmaReadCmd"]              # load V
        + ["GemmCmd"]                 # weights @ V
        + ["DmaWriteCmd"]             # store output
    )
    assert [type(cmd).__name__ for cmd in ctx.commands] == expected_types

    math_ops = [cmd.op for cmd in ctx.commands if isinstance(cmd, MathCmd)]
    assert math_ops == ["max", "sub", "exp", "sum", "div"]
|
||
|
||
|
||
# ── 17. Dispatch overhead auto-inserted ───────────────────────────
|
||
|
||
|
||
def test_dispatch_overhead_inserted():
    """Each tl API call auto-inserts PeCpuOverheadCmd when dispatch_cycles > 0."""
    ctx = _ctx_with_overhead()
    tile = ctx.load(0x1000, shape=(4, 4), dtype="f16")
    ctx.store(0x2000, tile)

    # One overhead command immediately precedes each API command.
    expected = []
    for api_cmd in ("DmaReadCmd", "DmaWriteCmd"):
        expected += ["PeCpuOverheadCmd", api_cmd]
    assert [type(cmd).__name__ for cmd in ctx.commands] == expected
|
||
|
||
|
||
# ── 18. where operation ──────────────────────────────────────────
|
||
|
||
|
||
def test_tl_where():
    """tl.where(cond, a, b) returns a TensorHandle and emits one three-input MathCmd."""
    ctx = _ctx()
    mask = ctx.load(0x1000, shape=(4, 4), dtype="i32")
    on_true = ctx.load(0x2000, shape=(4, 4), dtype="f16")
    on_false = ctx.load(0x3000, shape=(4, 4), dtype="f16")

    selected = ctx.where(mask, on_true, on_false)
    assert isinstance(selected, TensorHandle)

    math_cmds = [cmd for cmd in ctx.commands if isinstance(cmd, MathCmd)]
    assert len(math_cmds) == 1
    where_cmd = math_cmds[0]
    assert where_cmd.op == "where"
    assert len(where_cmd.inputs) == 3
|