kernbench2/benches/va_offset_verify.py
ywkang 63669f82cb Add SIP-level tensor parallelism, component registry YAML, VA offset verification
- DPPolicy: 3-level (sip/cube/pe), unified naming (column_wise/row_wise)
- PE_CPU: auto num_programs from cube shard count
- context.launch(): per-SIP KernelLaunchMsg with local va_base + auto local shape
- deploy_tensor: removed mmus param, MMU mapping is context-only responsibility
- ComponentRegistry: YAML-based lazy loading (components.yaml), impls→builtin rename
- VA offset bench + tests: 2D/1D, standard Triton kernel pattern

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 01:13:17 -07:00
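
The DPPolicy, PE_CPU and launch() bullets above describe how a global shape becomes cube-local and how the PE program count is chosen. Below is a minimal sketch of that bookkeeping for an all-column_wise policy. `Topology`, `column_wise_split` and `localize` are hypothetical names, not kernbench APIs. Two details are assumptions: requiring even splits, and treating "cube shard count" as the number of PE shards per cube.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Topology:
    """Hypothetical device counts for illustration; not a kernbench type."""
    num_sips: int
    num_cubes_per_sip: int
    num_pes_per_cube: int


def column_wise_split(cols: int, parts: int) -> int:
    """Columns per shard when `cols` is split evenly into `parts` shards."""
    if cols % parts != 0:
        raise ValueError(f"{cols} columns do not split evenly into {parts} shards")
    return cols // parts


def localize(M: int, K: int, topo: Topology):
    """Localize a global (M, K) shape under sip/cube/pe all column_wise.

    Returns (sip_local_shape, cube_local_shape, num_programs).
    """
    sip_K = column_wise_split(K, topo.num_sips)                # SIP level splits K
    cube_K = column_wise_split(sip_K, topo.num_cubes_per_sip)  # cube level splits again
    column_wise_split(cube_K, topo.num_pes_per_cube)           # PE shards must divide evenly too
    num_programs = topo.num_pes_per_cube  # assumed: one program per PE shard in the cube
    return (M, sip_K), (M, cube_K), num_programs


topo = Topology(num_sips=2, num_cubes_per_sip=4, num_pes_per_cube=4)
print(localize(128, 256, topo))  # ((128, 128), (128, 32), 4)
```

Column_wise sharding only divides K at each level, so M stays global. This matches the docstring's statement that launch() passes cube-local M and K to the kernel.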

"""VA offset verification benchmark.

Verifies that Triton-style base_ptr + pid * stride addressing works correctly
with full TP sharding (sip/cube/pe all column_wise). Each PE loads its own
block from a sharded tensor and stores it back.

The kernel uses standard Triton patterns:
- tl.program_id(0) for the PE index within a cube
- tl.num_programs(0) for the PE count within a cube
- Shape args are automatically localized to cube-local values by launch()
"""

from kernbench.policy.placement.dp import DPPolicy

M, K = 128, 256  # global tensor shape (elements)
DTYPE = "f16"


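def _copy_kernel(src_ptr, dst_ptr, M, K, tl, DTYPE="f16"):
    """Standard Triton copy kernel. M and K are cube-local (set by launch)."""
    pid = tl.program_id(0)       # PE index within the cube
    num_pe = tl.num_programs(0)  # PE count within the cube
    cols_per_pe = K // num_pe    # column_wise split of the cube-local K
    elem_bytes = 2  # f16 only; must be updated if DTYPE changes
    # The offset formula treats each PE's (M, cols_per_pe) shard as one
    # contiguous block, so PE `pid` starts pid full blocks into the buffer.
    offset = pid * M * cols_per_pe * elem_bytes
    data = tl.load(src_ptr + offset, shape=(M, cols_per_pe), dtype=DTYPE)
    tl.store(dst_ptr + offset, data)

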
def run(torch):
    """Run the VA offset verification benchmark with full TP sharding."""
    dp = DPPolicy(sip="column_wise", cube="column_wise", pe="column_wise")
    src = torch.zeros((M, K), dtype=DTYPE, dp=dp, name="src")
    dst = torch.empty((M, K), dtype=DTYPE, dp=dp, name="dst")

    # launch() automatically converts M, K to cube-local values
    torch.launch("va_offset_copy", _copy_kernel, src, dst, M, K)

    # Sanity check: the kernel completed with non-zero latency
    kernel_traces = [t for t in torch._traces if t["phase"] == "kernel"]
    assert len(kernel_traces) > 0, "No kernel traces recorded"
    for kt in kernel_traces:
        assert kt["total_ns"] > 0, f"Kernel latency is zero for {kt}"
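
The offset line in `_copy_kernel` places each PE's shard as one contiguous block of M x cols_per_pe elements. Below is a minimal plain-Python sketch that reproduces that arithmetic outside the simulator and checks that the per-PE blocks tile the cube-local buffer exactly. `pe_offsets` and `check_tiling` are hypothetical helpers, not part of kernbench. The contiguous-shard layout is inferred from the kernel's formula rather than from kernbench documentation.

```python
def pe_offsets(M, K_local, num_pe, elem_bytes=2):
    """(byte_offset, byte_length) of each PE's block, mirroring _copy_kernel.

    Assumes each PE's (M, K_local // num_pe) shard is stored contiguously,
    which is what the kernel's offset formula implies.
    """
    cols_per_pe = K_local // num_pe  # same integer division as the kernel
    block_bytes = M * cols_per_pe * elem_bytes
    return [(pid * block_bytes, block_bytes) for pid in range(num_pe)]


def check_tiling(blocks, total_bytes):
    """True if blocks are disjoint and cover [0, total_bytes) with no gaps."""
    expected_start = 0
    for start, length in sorted(blocks):
        if start != expected_start:  # gap or overlap
            return False
        expected_start = start + length
    return expected_start == total_bytes


blocks = pe_offsets(M=128, K_local=32, num_pe=4)
print(blocks)  # [(0, 2048), (2048, 2048), (4096, 2048), (6144, 2048)]
print(check_tiling(blocks, total_bytes=128 * 32 * 2))  # True
```

The same check with `K_local=30` returns False. `K // num_pe` rounds down to 7 columns per PE, so 2 columns are never read. A cube-local K that is not a multiple of the PE count would therefore silently skip data with the kernel's formula.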