Add SIP-level tensor parallelism, component registry YAML, VA offset verification

- DPPolicy: 3-level (sip/cube/pe), unified naming (column_wise/row_wise)
- PE_CPU: auto num_programs from cube shard count
- context.launch(): per-SIP KernelLaunchMsg with local va_base + auto local shape
- deploy_tensor: removed mmus param, MMU mapping is context-only responsibility
- ComponentRegistry: YAML-based lazy loading (components.yaml), impls→builtin rename
- VA offset bench + tests: 2D/1D, standard Triton kernel pattern

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
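
The "auto local shape" behaviour of context.launch() can be illustrated with a small sketch. This is not the project's implementation: the hierarchy sizes (NUM_SIPS, CUBES_PER_SIP, PES_PER_CUBE) and the helpers split_shape and cube_local_shape are hypothetical, and the sketch assumes that column_wise divides the column count while row_wise divides the row count.

```python
# Hypothetical hierarchy sizes; the commit does not state the real counts.
NUM_SIPS = 2
CUBES_PER_SIP = 2
PES_PER_CUBE = 4


def split_shape(shape, mode, parts):
    """Split a 2D (rows, cols) shape evenly across `parts` shards.

    Assumption: column_wise divides cols, row_wise divides rows.
    """
    rows, cols = shape
    if mode == "column_wise":
        if cols % parts:
            raise ValueError(f"{cols} columns do not divide into {parts} shards")
        return rows, cols // parts
    if mode == "row_wise":
        if rows % parts:
            raise ValueError(f"{rows} rows do not divide into {parts} shards")
        return rows // parts, cols
    raise ValueError(f"unknown sharding mode: {mode!r}")


def cube_local_shape(shape, sip_mode, cube_mode):
    """Apply the SIP-level split, then the cube-level split."""
    per_sip = split_shape(shape, sip_mode, NUM_SIPS)
    return split_shape(per_sip, cube_mode, CUBES_PER_SIP)


# Global (M, K) = (128, 256), sharded column_wise at every level.
cube_shape = cube_local_shape((128, 256), "column_wise", "column_wise")
pe_shape = split_shape(cube_shape, "column_wise", PES_PER_CUBE)
print("cube-local shape:", cube_shape)
print("per-PE block shape:", pe_shape)
```

Under these assumed sizes, the kernel would receive the cube-local M and K and derive its own per-PE column count from tl.num_programs(0).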
@@ -0,0 +1,42 @@
"""VA offset verification benchmark.

Verifies that Triton-style base_ptr + pid * stride addressing works correctly
with full TP sharding (sip/cube/pe all column_wise). Each PE loads its own
block from a sharded tensor and stores it back.

The kernel uses standard Triton patterns:
- tl.program_id(0) for PE index within cube
- tl.num_programs(0) for PE count within cube
- Shape args are automatically localized by launch()
"""
from kernbench.policy.placement.dp import DPPolicy

M, K = 128, 256
DTYPE = "f16"


def _copy_kernel(src_ptr, dst_ptr, M, K, tl, DTYPE="f16"):
    """Standard Triton copy kernel. M and K are cube-local (set by launch)."""
    pid = tl.program_id(0)  # PE index within the cube
    num_pe = tl.num_programs(0)  # PE count within the cube
    cols_per_pe = K // num_pe
    elem_bytes = 2  # f16; must match DTYPE
    # Byte offset of this PE's (M, cols_per_pe) block: pid whole blocks precede it.
    offset = pid * M * cols_per_pe * elem_bytes
    data = tl.load(src_ptr + offset, shape=(M, cols_per_pe), dtype=DTYPE)
    tl.store(dst_ptr + offset, data)


def run(torch):
    """Run the VA offset verification benchmark with full TP sharding."""
    dp = DPPolicy(sip="column_wise", cube="column_wise", pe="column_wise")
    src = torch.zeros((M, K), dtype=DTYPE, dp=dp, name="src")
    dst = torch.empty((M, K), dtype=DTYPE, dp=dp, name="dst")

    # launch() automatically converts M, K to cube-local values
    torch.launch("va_offset_copy", _copy_kernel, src, dst, M, K)

    # Sanity check: kernel completed with non-zero latency
    kernel_traces = [t for t in torch._traces if t["phase"] == "kernel"]
    assert len(kernel_traces) > 0, "No kernel traces recorded"
    for kt in kernel_traces:
        assert kt["total_ns"] > 0, f"Kernel latency is zero for {kt}"
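
The offset arithmetic in _copy_kernel can be checked on its own, without the simulator. The sketch below reuses the kernel's formula (pid * M * cols_per_pe * elem_bytes) in plain Python and confirms that the per-PE byte ranges tile the cube-local buffer without gaps or overlap. The helper pe_byte_ranges, the cube-local K of 64, and the PE count of 4 are hypothetical values chosen for illustration; the real values come from launch() and tl.num_programs(0).

```python
ELEM_BYTES = 2  # f16, matching the kernel's hard-coded elem_bytes


def pe_byte_ranges(m, k_local, num_pe, elem_bytes=ELEM_BYTES):
    """Return the [start, end) byte range each PE touches.

    Uses the same offset formula as _copy_kernel.
    """
    cols_per_pe = k_local // num_pe
    block_bytes = m * cols_per_pe * elem_bytes
    ranges = []
    for pid in range(num_pe):
        offset = pid * m * cols_per_pe * elem_bytes
        ranges.append((offset, offset + block_bytes))
    return ranges


# Hypothetical cube-local shape (128, 64) and 4 PEs per cube.
ranges = pe_byte_ranges(m=128, k_local=64, num_pe=4)

# Each block must start exactly where the previous one ends.
assert ranges[0][0] == 0
for (_, prev_end), (next_start, _) in zip(ranges, ranges[1:]):
    assert prev_end == next_start

# Together the blocks must cover the whole cube-local buffer.
assert ranges[-1][1] == 128 * 64 * ELEM_BYTES

for pid, (start, end) in enumerate(ranges):
    print(f"PE {pid}: bytes [{start}, {end})")
```

Because cols_per_pe uses floor division, a cube-local K that is not a multiple of the PE count would leave the trailing columns uncovered, so the coverage assertion doubles as a divisibility check.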