kernbench2/tests/test_triton_emu.py
ywkang 998cc85762 Add PE-level IPCQ collective infra + unified ccl_allreduce bench (ADR-0023)
Major changes:

PE-level IPCQ infrastructure:
- New PE_IPCQ component: ring-buffer control plane with 4-direction
  neighbor mapping, head/tail pointers, backpressure (poll/sleep).
- PE_DMA extended with vc_comm channel for IPCQ outbound/inbound DMA,
  including in-flight data snapshot (D9) and op_log recording at
  outbound time for Phase 2 replay correctness.
- IpcqDmaToken piggyback model: data + metadata travel together,
  atomic visibility at receiver (invariant I6).
- Credit return fast path: bottleneck-BW latency, no fabric vc_comm.
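
The head/tail ring-buffer control plane with backpressure described above can be sketched as follows. This is a minimal illustration, not the PE_IPCQ implementation: `RingQueue`, `try_push`, and `try_pop` are hypothetical names, and the sketch assumes monotonically increasing head/tail counters with a single producer and a single consumer.

```python
from dataclasses import dataclass, field
from typing import Any, List, Tuple


@dataclass
class RingQueue:
    """Hypothetical fixed-capacity ring buffer driven by head/tail counters."""

    capacity: int
    head: int = 0  # total items consumed so far (next slot to read)
    tail: int = 0  # total items produced so far (next slot to write)
    slots: List[Any] = field(default_factory=list)

    def __post_init__(self) -> None:
        self.slots = [None] * self.capacity

    def is_full(self) -> bool:
        return self.tail - self.head == self.capacity

    def is_empty(self) -> bool:
        return self.tail == self.head

    def try_push(self, item: Any) -> bool:
        """Return False when full; the caller is expected to poll or sleep."""
        if self.is_full():
            return False
        self.slots[self.tail % self.capacity] = item
        self.tail += 1
        return True

    def try_pop(self) -> Tuple[bool, Any]:
        """Return (False, None) when empty, else (True, oldest item)."""
        if self.is_empty():
            return False, None
        item = self.slots[self.head % self.capacity]
        self.head += 1
        return True, item


q = RingQueue(capacity=2)
pushed = [q.try_push("a"), q.try_push("b"), q.try_push("c")]
first = q.try_pop()
```

Here the third push is refused because the queue is full, which is the point where a sender would back off and retry once the receiver has advanced `head`.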

Phase 2 data execution (ADR-0020 integration):
- op_log extended: DmaWriteCmd now captures src_space/src_addr for
  Phase 2 dma_write copy; ipcq_copy ops recorded at outbound time.
- DataExecutor replays dma_write + ipcq_copy in t_start order.
- Engine._flush_data_phase: incremental cursor-based replay after
  each engine.wait() so host reads see post-Phase-2 data.
- KernelRunner Phase 1 writes disabled when op_log is active to
  prevent stale data from corrupting the MemoryStore snapshot.
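
The cursor-based replay idea above can be sketched roughly as below. DataExecutor and Engine._flush_data_phase are not shown in this file, so everything here is an assumption: `Op`, `CursorReplayer`, and the dict-backed `memory` are hypothetical stand-ins that only illustrate replaying newly logged copies in t_start order and remembering how far the log has been consumed.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass(frozen=True)
class Op:
    """Hypothetical logged copy: at t_start, copy memory[src] into memory[dst]."""

    t_start: int
    src: str
    dst: str


class CursorReplayer:
    """Replays only the ops logged since the previous flush."""

    def __init__(self) -> None:
        self.op_log: List[Op] = []
        self.cursor = 0  # index of the first op that has not been replayed yet

    def record(self, op: Op) -> None:
        self.op_log.append(op)

    def flush(self, memory: Dict[str, int]) -> int:
        # Sort the pending slice by start time so copies land in timing order,
        # regardless of the order in which they were appended to the log.
        pending = sorted(self.op_log[self.cursor:], key=lambda op: op.t_start)
        for op in pending:
            memory[op.dst] = memory[op.src]
        self.cursor = len(self.op_log)
        return len(pending)


memory = {"a": 1, "b": 0, "c": 0}
replayer = CursorReplayer()
replayer.record(Op(t_start=5, src="b", dst="c"))  # logged first, starts later
replayer.record(Op(t_start=2, src="a", dst="b"))  # logged second, starts earlier
first_flush = replayer.flush(memory)
second_flush = replayer.flush(memory)
```

Because the t_start=2 copy is applied before the t_start=5 copy, `c` ends up with the value that originated in `a`; a second flush with no new ops does nothing.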

TLContext / kernel API:
- tl.send(dir, src=TensorHandle), tl.recv(dir, shape, dtype),
  tl.recv_async, tl.wait(RecvFuture), copy_to_dst mode.
- TensorHandle operator overloading (add/sub/mul/div) via thread-local
  active TLContext → MathCmd dispatch through PE_MATH.
- PE-local scratch allocator for math output handles.
- tl.load returns space="hbm" handles for correct Phase 2 addressing.
- Additional math functions: maximum, minimum, fma, clamp, softmax, cdiv.

Unified ccl_allreduce bench (PyTorch-compat host code):
- Single benches/ccl_allreduce.py with run() + worker(rank, ws, torch)
  split matching real PyTorch DDP worker pattern.
- torch.distributed facade: init_process_group, get_world_size,
  get_rank, get_backend, all_reduce, barrier — only real PyTorch names.
- AhbmCCLBackend: eager install_ipcq at init, all_reduce dispatches
  kernel via tensor shard metadata (n_elem from shards[0].nbytes).
- world_size derived from topology spec (sips × cubes × pes_per_cube)
  with optional algorithm-level override in ccl.yaml.
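
As a rough illustration of the world_size rule above, the following sketch multiplies the topology dimensions and lets an optional override take precedence. The function name and the positivity check are assumptions for illustration; how ccl.yaml is actually parsed is not shown here.

```python
from typing import Optional


def derive_world_size(
    sips: int,
    cubes: int,
    pes_per_cube: int,
    override: Optional[int] = None,
) -> int:
    """Hypothetical helper: topology product unless an override is given."""
    if override is not None:
        if override <= 0:
            raise ValueError("world_size override must be positive")
        return override
    return sips * cubes * pes_per_cube


default_ws = derive_world_size(sips=2, cubes=4, pes_per_cube=8)
overridden_ws = derive_world_size(sips=2, cubes=4, pes_per_cube=8, override=7)
```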

Tensor API (PyTorch-compat surface):
- Tensor.numpy(): gather-aware (all shards via VA-based addressing).
- Tensor.copy_(source): scatter from host tensor into sharded target.
- RuntimeContext.from_numpy(arr): host-side staging tensor.
- Tensor.data property fixed to use numpy() (was shards[0]-only).

Algorithm modules moved to src/kernbench/ccl/algorithms/:
- ring_allreduce, mesh_allreduce, tree_allreduce, hello_send.
- Each module exports kernel_args(world_size, n_elem) helper.
- ccl.yaml module paths updated to kernbench.ccl.algorithms.*.

Dead code removed:
- 7 per-variant bench files (ccl_allreduce_{tcm,hbm,sram}, etc.).
- _run_ccl_bench greenlet-per-SIP scheduler.
- benches.loader.is_ccl_bench + run_rank detection.
- benches/ccl/ directory.

Tests:
- New test_ccl_allreduce_matrix.py: 7 parametrized cases
  (ring×3 buffers, ring 8/16, mesh 4, tree 7).
- New test_runtime_api_tensor.py: copy_/numpy/from_numpy unit tests.
- Existing tests updated for new import paths + world_size_override.

Docs:
- Korean ccl-author-guide.md and ADR-0023 paths updated.
- New English versions: ccl-author-guide.en.md, ADR-0023.en.md.

502 tests pass.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-12 19:36:59 -07:00

398 lines
13 KiB
Python
"""Tests for Triton emulator: TLContext, command generation, kernel registry."""

from kernbench.common.pe_commands import (
    CompletionHandle,
    CompositeCmd,
    DmaReadCmd,
    DmaWriteCmd,
    GemmCmd,
    MathCmd,
    PeCpuOverheadCmd,
    TensorHandle,
    WaitCmd,
)
from kernbench.triton_emu.registry import clear_registry, get_kernel, register_kernel
from kernbench.triton_emu.tl_context import TLContext, run_kernel


def _ctx(**kwargs) -> TLContext:
    # dispatch_cycles=0 keeps PeCpuOverheadCmd out of the command stream.
    return TLContext(dispatch_cycles=0, **kwargs)


def _ctx_with_overhead(**kwargs) -> TLContext:
    # dispatch_cycles=1 makes every tl API call emit a PeCpuOverheadCmd first.
    return TLContext(dispatch_cycles=1, **kwargs)


# ── 1. tl.load → DmaReadCmd ──────────────────────────────────────
def test_tl_load_generates_dma_read():
    tl = _ctx()
    h = tl.load(0x1000, shape=(32, 64), dtype="f16")
    assert isinstance(h, TensorHandle)
    assert h.shape == (32, 64)
    assert h.nbytes == 32 * 64 * 2
    cmds = tl.commands
    assert len(cmds) == 1
    assert isinstance(cmds[0], DmaReadCmd)
    assert cmds[0].src_addr == 0x1000
    assert cmds[0].nbytes == 32 * 64 * 2


# ── 2. tl.store → DmaWriteCmd ────────────────────────────────────
def test_tl_store_generates_dma_write():
    tl = _ctx()
    h = tl.load(0x1000, shape=(16, 16), dtype="f32")
    tl.store(0x2000, h)
    cmds = [c for c in tl.commands if isinstance(c, DmaWriteCmd)]
    assert len(cmds) == 1
    assert cmds[0].dst_addr == 0x2000
    assert cmds[0].nbytes == 16 * 16 * 4


# ── 3. tl.dot → GemmCmd ──────────────────────────────────────────
def test_tl_dot_generates_gemm_cmd():
    tl = _ctx()
    a = tl.load(0x1000, shape=(32, 64), dtype="f16")
    b = tl.load(0x2000, shape=(64, 16), dtype="f16")
    out = tl.dot(a, b)
    assert out.shape == (32, 16)
    cmds = [c for c in tl.commands if isinstance(c, GemmCmd)]
    assert len(cmds) == 1
    assert cmds[0].m == 32
    assert cmds[0].k == 64
    assert cmds[0].n == 16


# ── 4. tl.exp, tl.sqrt etc. → MathCmd ────────────────────────────
def test_tl_math_unary_ops():
    tl = _ctx()
    x = tl.load(0x1000, shape=(8, 8), dtype="f16")
    for op_name, op_fn in [
        ("exp", tl.exp), ("log", tl.log), ("sqrt", tl.sqrt),
        ("abs", tl.abs), ("sigmoid", tl.sigmoid),
        ("cos", tl.cos), ("sin", tl.sin),
    ]:
        result = op_fn(x)
        assert isinstance(result, TensorHandle)
        assert result.shape == x.shape
    math_cmds = [c for c in tl.commands if isinstance(c, MathCmd)]
    ops = [c.op for c in math_cmds]
    assert ops == ["exp", "log", "sqrt", "abs", "sigmoid", "cos", "sin"]


def test_tl_math_extra_ops():
    """tl.maximum/minimum/fma/clamp/softmax + tl.cdiv (real-Triton parity)."""
    tl = _ctx()
    a = tl.load(0x1000, shape=(8, 8), dtype="f16")
    b = tl.load(0x2000, shape=(8, 8), dtype="f16")
    c = tl.load(0x3000, shape=(8, 8), dtype="f16")
    tl.maximum(a, b)
    tl.minimum(a, b)
    tl.fma(a, b, c)
    tl.clamp(a, b, c)
    tl.softmax(a, axis=1)
    math_cmds = [cm for cm in tl.commands if isinstance(cm, MathCmd)]
    ops = [cm.op for cm in math_cmds]
    assert ops == ["maximum", "minimum", "fma", "clamp", "softmax"]
    # ternary fma/clamp must record three inputs
    fma_cmd = math_cmds[2]
    assert len(fma_cmd.inputs) == 3
    clamp_cmd = math_cmds[3]
    assert len(clamp_cmd.inputs) == 3
    # softmax records the axis
    assert math_cmds[4].axis == 1
    # cdiv is a scalar helper, not a tensor op (TLContext is imported at module level)
    assert TLContext.cdiv(10, 3) == 4
    assert TLContext.cdiv(9, 3) == 3
    assert TLContext.cdiv(0, 4) == 0


# ── 5. a + b, a * b → MathCmd ────────────────────────────────────
def test_tl_math_binary_ops():
    # TensorHandle operators dispatch through the active TLContext, so the
    # operator expressions must execute inside run_kernel.
    tl = _ctx()

    def binary_kernel(tl):
        a = tl.load(0x1000, shape=(4, 4), dtype="f16")
        b = tl.load(0x2000, shape=(4, 4), dtype="f16")
        _ = a + b
        _ = a - b
        _ = a * b
        _ = a / b

    run_kernel(binary_kernel, tl)
    math_cmds = [c for c in tl.commands if isinstance(c, MathCmd)]
    ops = [c.op for c in math_cmds]
    assert ops == ["add", "sub", "mul", "div"]


# ── 6. tl.sum, tl.max → MathCmd with axis ────────────────────────
def test_tl_reduction_ops():
    tl = _ctx()
    x = tl.load(0x1000, shape=(32, 64), dtype="f16")
    s = tl.sum(x, axis=1)
    m = tl.max(x, axis=0)
    assert s.shape == (32, 1)
    assert m.shape == (1, 64)
    math_cmds = [c for c in tl.commands if isinstance(c, MathCmd)]
    assert math_cmds[0].op == "sum" and math_cmds[0].axis == 1
    assert math_cmds[1].op == "max" and math_cmds[1].axis == 0


# ── 7. tl.composite → CompositeCmd + CompletionHandle ────────────
def test_tl_composite_nonblocking():
    tl = _ctx()
    a = tl.load(0x1000, shape=(32, 64), dtype="f16")
    b = tl.load(0x2000, shape=(64, 32), dtype="f16")
    h = tl.composite(op="gemm", a=a, b=b, out_ptr=0x3000)
    assert isinstance(h, CompletionHandle)
    comp_cmds = [c for c in tl.commands if isinstance(c, CompositeCmd)]
    assert len(comp_cmds) == 1
    assert comp_cmds[0].op == "gemm"
    assert comp_cmds[0].out_addr == 0x3000
    assert comp_cmds[0].out_nbytes == 32 * 32 * 2  # M×N×dtype_bytes


# ── 8. tl.wait(handle) → WaitCmd ─────────────────────────────────
def test_tl_wait_specific():
    tl = _ctx()
    a = tl.load(0x1000, shape=(4, 4), dtype="f16")
    h = tl.composite(op="gemm", a=a, b=a, out_ptr=0x2000)
    tl.wait(h)
    wait_cmds = [c for c in tl.commands if isinstance(c, WaitCmd)]
    assert len(wait_cmds) == 1
    assert wait_cmds[0].handle == h


# ── 9. tl.wait() → WaitCmd(handle=None) ──────────────────────────
def test_tl_wait_all():
    tl = _ctx()
    tl.wait()
    wait_cmds = [c for c in tl.commands if isinstance(c, WaitCmd)]
    assert len(wait_cmds) == 1
    assert wait_cmds[0].handle is None


# ── 10. tl.cycles → PeCpuOverheadCmd ─────────────────────────────
def test_tl_cycles():
    tl = _ctx()
    tl.cycles(10)
    assert len(tl.commands) == 1
    assert isinstance(tl.commands[0], PeCpuOverheadCmd)
    assert tl.commands[0].cycles == 10


# ── 11. tl.program_id ────────────────────────────────────────────
def test_tl_program_id():
    tl = TLContext(pe_id=5, num_programs=8)
    assert tl.program_id(0) == 5
    assert tl.num_programs(0) == 8


def test_tl_program_id_axis1():
    """axis=1 returns cube_id and num_cubes."""
    tl = TLContext(pe_id=3, num_programs=8, cube_id=7, num_cubes=16)
    assert tl.program_id(0) == 3
    assert tl.program_id(1) == 7
    assert tl.num_programs(0) == 8
    assert tl.num_programs(1) == 16


def test_tl_program_id_global():
    """global_pid = cube_id * num_pes_per_cube + local_pe_id."""
    pe_id, cube_id, num_pes = 5, 3, 8
    tl = TLContext(pe_id=pe_id, num_programs=num_pes, cube_id=cube_id, num_cubes=16)
    global_pid = tl.program_id(1) * tl.num_programs(0) + tl.program_id(0)
    assert global_pid == cube_id * num_pes + pe_id


# ── 12. tl.arange, tl.zeros, tl.full ─────────────────────────────
def test_tl_arange_zeros_full():
    tl = _ctx()
    r = tl.arange(0, 16, dtype="i32")
    assert r.shape == (16,)
    assert r.dtype == "i32"
    z = tl.zeros((4, 8), dtype="f16")
    assert z.shape == (4, 8)
    assert z.nbytes == 4 * 8 * 2
    f = tl.full((2, 3), value=1.0, dtype="f32")
    assert f.shape == (2, 3)
    assert f.nbytes == 2 * 3 * 4


# ── 13. tl.trans → shape change, no command ───────────────────────
def test_tl_trans_shape():
    tl = _ctx()
    h = tl.load(0x1000, shape=(32, 64), dtype="f16")
    t = tl.trans(h)
    assert t.shape == (64, 32)
    assert t.id == h.id  # same underlying data
    # Only DmaReadCmd from load, no command from trans
    assert len(tl.commands) == 1
    assert isinstance(tl.commands[0], DmaReadCmd)


# ── 14. Kernel registry ──────────────────────────────────────────
def test_kernel_registry():
    clear_registry()

    def my_kernel(tl):
        pass

    register_kernel("test_kern", my_kernel)
    assert get_kernel("test_kern") is my_kernel
    clear_registry()


def test_kernel_registry_missing():
    clear_registry()
    import pytest
    with pytest.raises(KeyError):
        get_kernel("nonexistent")


def test_kernel_registry_duplicate():
    clear_registry()
    register_kernel("dup", lambda tl: None)
    import pytest
    with pytest.raises(ValueError):
        register_kernel("dup", lambda tl: None)
    clear_registry()


# ── 15. GEMM kernel → correct command sequence ───────────────────
def test_gemm_kernel_command_sequence():
    """32×64 × 64×32 GEMM kernel produces [DmaRead, DmaRead, Composite]."""
    def gemm_kernel(a_ptr, b_ptr, out_ptr, tl):
        pid = tl.program_id(0)
        a = tl.load(a_ptr, shape=(32, 64), dtype="f16")
        b = tl.load(b_ptr + pid * 64 * 32 * 2, shape=(64, 32), dtype="f16")
        tl.composite(op="gemm", a=a, b=b, out_ptr=out_ptr + pid * 32 * 32 * 2)

    tl = _ctx(pe_id=3)
    run_kernel(gemm_kernel, tl, a_ptr=0x1000, b_ptr=0x2000, out_ptr=0x3000)
    types = [type(c).__name__ for c in tl.commands]
    assert types == ["DmaReadCmd", "DmaReadCmd", "CompositeCmd"]


# ── 16. Attention kernel → correct command sequence ───────────────
def test_attention_kernel_command_sequence():
    """Attention kernel: load→dot→math ops→dot→store."""
    def attention_kernel(q_ptr, k_ptr, v_ptr, out_ptr, tl,
                         seq_len=16, head_dim=8):
        pid = tl.program_id(0)
        q = tl.load(q_ptr, shape=(seq_len, head_dim), dtype="f16")
        k = tl.load(k_ptr, shape=(head_dim, seq_len), dtype="f16")
        scores = tl.dot(q, k)
        row_max = tl.max(scores, axis=1)
        scores = scores - row_max
        scores = tl.exp(scores)
        row_sum = tl.sum(scores, axis=1)
        scores = scores / row_sum
        v = tl.load(v_ptr, shape=(seq_len, head_dim), dtype="f16")
        out = tl.dot(scores, v)
        tl.store(out_ptr, out)

    tl = _ctx(pe_id=0)
    run_kernel(
        attention_kernel, tl,
        q_ptr=0x1000, k_ptr=0x2000, v_ptr=0x3000, out_ptr=0x4000,
    )
    types = [type(c).__name__ for c in tl.commands]
    # load, load, dot, max, sub, exp, sum, div, load, dot, store
    assert types == [
        "DmaReadCmd", "DmaReadCmd",       # load Q, K
        "GemmCmd",                        # Q @ K
        "MathCmd", "MathCmd", "MathCmd",  # max, sub, exp
        "MathCmd", "MathCmd",             # sum, div
        "DmaReadCmd",                     # load V
        "GemmCmd",                        # scores @ V
        "DmaWriteCmd",                    # store output
    ]
    # Verify math ops
    math_cmds = [c for c in tl.commands if isinstance(c, MathCmd)]
    math_ops = [c.op for c in math_cmds]
    assert math_ops == ["max", "sub", "exp", "sum", "div"]


# ── 17. Dispatch overhead auto-inserted ───────────────────────────
def test_dispatch_overhead_inserted():
    """Each tl API call auto-inserts PeCpuOverheadCmd when dispatch_cycles > 0."""
    tl = _ctx_with_overhead()
    a = tl.load(0x1000, shape=(4, 4), dtype="f16")
    tl.store(0x2000, a)
    types = [type(c).__name__ for c in tl.commands]
    # overhead, load, overhead, store
    assert types == [
        "PeCpuOverheadCmd", "DmaReadCmd",
        "PeCpuOverheadCmd", "DmaWriteCmd",
    ]


# ── 18. where operation ──────────────────────────────────────────
def test_tl_where():
    tl = _ctx()
    cond = tl.load(0x1000, shape=(4, 4), dtype="i32")
    a = tl.load(0x2000, shape=(4, 4), dtype="f16")
    b = tl.load(0x3000, shape=(4, 4), dtype="f16")
    out = tl.where(cond, a, b)
    assert isinstance(out, TensorHandle)
    math_cmds = [c for c in tl.commands if isinstance(c, MathCmd)]
    assert len(math_cmds) == 1
    assert math_cmds[0].op == "where"
    assert len(math_cmds[0].inputs) == 3
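
Several assertions in this file rest on three small pieces of arithmetic: byte counts derived from shape and dtype, ceiling division (`cdiv`), and the global program ID built from cube and PE indices. The standalone sketch below reproduces that arithmetic without importing kernbench. The helper names and the `DTYPE_BYTES` table are hypothetical; the f16 and f32 sizes match the values the tests assert, while the i32 size is an assumption.

```python
import math

# Element sizes in bytes. f16=2 and f32=4 match the nbytes assertions in the
# tests; i32=4 is assumed here and is not asserted anywhere in this file.
DTYPE_BYTES = {"f16": 2, "f32": 4, "i32": 4}


def nbytes(shape, dtype):
    """Total bytes of a dense tensor: product of dims times element size."""
    return math.prod(shape) * DTYPE_BYTES[dtype]


def cdiv(a, b):
    """Ceiling division for non-negative a and positive b."""
    return -(-a // b)


def global_pid(cube_id, pes_per_cube, pe_id):
    """Flatten (cube_id, pe_id) into a single rank, cube-major."""
    return cube_id * pes_per_cube + pe_id


load_bytes = nbytes((32, 64), "f16")
store_bytes = nbytes((16, 16), "f32")
tiles = cdiv(10, 3)
rank = global_pid(cube_id=3, pes_per_cube=8, pe_id=5)
```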