Add SIP-level tensor parallelism, component registry YAML, VA offset verification
- DPPolicy: 3-level (sip/cube/pe), unified naming (column_wise/row_wise) - PE_CPU: auto num_programs from cube shard count - context.launch(): per-SIP KernelLaunchMsg with local va_base + auto local shape - deploy_tensor: removed mmus param, MMU mapping is context-only responsibility - ComponentRegistry: YAML-based lazy loading (components.yaml), impls→builtin rename - VA offset bench + tests: 2D/1D, standard Triton kernel pattern Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -20,12 +20,12 @@ from kernbench.common.pe_commands import (
|
||||
TensorHandle,
|
||||
)
|
||||
from kernbench.components.base import ComponentRegistry
|
||||
from kernbench.components.impls.pe_cpu import PeCpuComponent
|
||||
from kernbench.components.impls.pe_dma import PeDmaComponent
|
||||
from kernbench.components.impls.pe_gemm import PeGemmComponent
|
||||
from kernbench.components.impls.pe_math import PeMathComponent
|
||||
from kernbench.components.impls.pe_scheduler import PeSchedulerComponent
|
||||
from kernbench.components.impls.pe_tcm import PeTcmComponent
|
||||
from kernbench.components.builtin.pe_cpu import PeCpuComponent
|
||||
from kernbench.components.builtin.pe_dma import PeDmaComponent
|
||||
from kernbench.components.builtin.pe_gemm import PeGemmComponent
|
||||
from kernbench.components.builtin.pe_math import PeMathComponent
|
||||
from kernbench.components.builtin.pe_scheduler import PeSchedulerComponent
|
||||
from kernbench.components.builtin.pe_tcm import PeTcmComponent
|
||||
from kernbench.policy.address.phyaddr import PhysAddr
|
||||
from kernbench.runtime_api.kernel import (
|
||||
KernelLaunchMsg,
|
||||
@@ -888,11 +888,9 @@ def test_qkv_gemm_bench_completes():
|
||||
deploy_traces = [t for t in ctx._traces if t["phase"] in ("deploy", "memory_write")]
|
||||
kernel_traces = [t for t in ctx._traces if t["phase"] == "kernel"]
|
||||
assert len(deploy_traces) >= 2 # at least a, b (out is empty, no deploy)
|
||||
assert len(kernel_traces) == 1
|
||||
assert len(kernel_traces) >= 1 # one per SIP (2 SIPs in topology)
|
||||
assert kernel_traces[0]["name"] == "qkv_gemm"
|
||||
assert kernel_traces[0]["total_ns"] > 0
|
||||
# Scalars should contain M, K, N
|
||||
assert len(kernel_traces[0]["scalars"]) >= 3
|
||||
|
||||
clear_registry()
|
||||
|
||||
@@ -982,7 +980,7 @@ def test_qkv_gemm_bench_multi_pe_completes():
|
||||
deploy_traces = [t for t in ctx._traces if t["phase"] in ("deploy", "memory_write")]
|
||||
kernel_traces = [t for t in ctx._traces if t["phase"] == "kernel"]
|
||||
assert len(deploy_traces) >= 8 # replicate(a)*8 + column_wise(b)*8
|
||||
assert len(kernel_traces) == 1
|
||||
assert len(kernel_traces) >= 1 # one per SIP
|
||||
assert kernel_traces[0]["target_pe"] == "all"
|
||||
|
||||
clear_registry()
|
||||
|
||||
Reference in New Issue
Block a user