Files
kernbench2/tests/test_va_integration.py
T
ywkang 63669f82cb Add SIP-level tensor parallelism, component registry YAML, VA offset verification
- DPPolicy: 3-level (sip/cube/pe), unified naming (column_wise/row_wise)
- PE_CPU: auto num_programs from cube shard count
- context.launch(): per-SIP KernelLaunchMsg with local va_base + auto local shape
- deploy_tensor: removed mmus param, MMU mapping is context-only responsibility
- ComponentRegistry: YAML-based lazy loading (components.yaml), impls→builtin rename
- VA offset bench + tests: 2D/1D, standard Triton kernel pattern

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 01:13:17 -07:00

217 lines
7.0 KiB
Python

"""Tests for VA integration: Tensor, TLContext, and DMA commands use VA.
Validates:
T10. TensorHandle has va_base; TensorShard does NOT have va field
T11. deploy_tensor allocates VA; it does NOT install MMU mappings (context's job)
T12. Tensor.va returns the tensor's VA base
T13. tl.load/tl.store generate DMA commands with VA (not PA)
T14. Kernel VA-based offset calculation flows through DMA commands
"""
import pytest
from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator
from kernbench.policy.address.pe_mmu import PeMMU
from kernbench.policy.address.va_allocator import VirtualAllocator
from kernbench.policy.placement.dp import column_wise, ShardSpec
from kernbench.runtime_api.tensor import (
TensorHandle,
TensorShard,
deploy_tensor,
)
from kernbench.runtime_api.kernel import TensorArgShard
from kernbench.common.pe_commands import DmaReadCmd, DmaWriteCmd
from kernbench.triton_emu.tl_context import TLContext, run_kernel
# Byte-size units used to express the capacities below.
_MB = 2 ** 20
_GB = 2 ** 30

# Address configuration shared by every PEMemAllocator built in this module.
_CFG = AddressConfig(
    sip_count=2,
    cubes_per_sip=16,
    pes_per_cube=8,
    hbm_bytes_per_cube=48 * _GB,
    hbm_slices_per_cube=8,
    tcm_bytes_per_pe=16 * _MB,
    tcm_scheduler_reserved_bytes=4 * _MB,
    sram_bytes_per_cube=32 * _MB,
)
def _make_allocators(num_pe: int = 8) -> dict[int, PEMemAllocator]:
    """Build one ``PEMemAllocator`` per PE index, keyed by that index.

    Every allocator is created for rack 0, SIP 0, cube 0 and shares ``_CFG``.
    """
    allocators: dict[int, PEMemAllocator] = {}
    for pe_id in range(num_pe):
        allocators[pe_id] = PEMemAllocator(
            rack_id=0, sip_id=0, cube_id=0, pe_id=pe_id, cfg=_CFG,
        )
    return allocators
def _make_mmus(num_pe: int = 8, page_size: int = 4096) -> dict[int, PeMMU]:
    """Build a fresh ``PeMMU`` per PE index, all using the same ``page_size``."""
    mmus: dict[int, PeMMU] = {}
    for pe_id in range(num_pe):
        mmus[pe_id] = PeMMU(page_size=page_size)
    return mmus
def _make_va_allocator() -> VirtualAllocator:
    """Create the ``VirtualAllocator`` used by the deploy tests.

    The VA window starts at 0x1_0000_0000, spans 64 GiB and uses 2 MiB pages.
    """
    window_start = 0x1_0000_0000
    window_bytes = 64 * _GB
    va_page_bytes = 2 * _MB
    return VirtualAllocator(
        va_base=window_start,
        va_size=window_bytes,
        page_size=va_page_bytes,
    )
# ── T10. TensorHandle has va_base ────────────────────────────────────
def test_tensor_handle_has_va_base():
    """A TensorHandle built with ``va_base`` exposes that value unchanged."""
    expected_base = 0x1_0000_0000
    handle = TensorHandle(
        name="A",
        shape=(1024, 512),
        dtype="fp16",
        itemsize=2,
        shards=(),
        va_base=expected_base,
    )
    assert handle.va_base == expected_base
def test_tensor_handle_va_base_immutable():
    """Re-assigning ``TensorHandle.va_base`` must raise ``AttributeError``.

    This is the error a frozen dataclass raises on attribute assignment.
    """
    handle_fields = {
        "name": "A",
        "shape": (1024, 512),
        "dtype": "fp16",
        "itemsize": 2,
        "shards": (),
        "va_base": 0x1_0000_0000,
    }
    handle = TensorHandle(**handle_fields)
    with pytest.raises(AttributeError):
        handle.va_base = 0x2_0000_0000  # type: ignore[misc]
def test_tensor_shard_no_va_field():
    """TensorShard carries no ``va`` attribute.

    A shard's VA is derived as ``TensorHandle.va_base + shard.offset_bytes``
    rather than being stored on the shard itself.
    """
    shard = TensorShard(
        sip=0, cube=0, pe=0, pa=0x1000, nbytes=4096, offset_bytes=0,
    )
    has_va = hasattr(shard, "va")
    assert not has_va, "TensorShard should not have a 'va' field"
# ── T11. deploy_tensor allocates VA (no MMU mappings) ────────────────
def test_deploy_tensor_assigns_va_base():
    """deploy_tensor given a VA allocator returns a handle with a positive va_base."""
    tensor_shape = (1024, 512)
    pe_allocators = _make_allocators()
    virtual_allocator = _make_va_allocator()
    shard_plan = column_wise(shape=tensor_shape, itemsize=2, num_pe=8)

    handle = deploy_tensor(
        name="W",
        shape=tensor_shape,
        dtype="fp16",
        placement=shard_plan,
        allocators=pe_allocators,
        va_allocator=virtual_allocator,
    )

    assert handle.va_base is not None
    assert handle.va_base > 0
def test_deploy_tensor_va_covers_all_shards():
    """Every shard of a deployed tensor sits at ``va_base + offset_bytes``.

    Checks that the deployment produced at least one shard and that each
    shard's derived VA is positive and not below the tensor's ``va_base``.
    """
    allocs = _make_allocators()
    va_alloc = _make_va_allocator()
    placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8)
    th = deploy_tensor(
        name="W",
        shape=(1024, 512),
        dtype="fp16",
        placement=placement,
        allocators=allocs,
        va_allocator=va_alloc,
    )
    # Without this guard the loop below would pass vacuously on an empty
    # shard collection.
    assert th.shards, "deploy_tensor produced no shards"
    for s in th.shards:
        shard_va = th.va_base + s.offset_bytes
        assert shard_va > 0
        # A shard must never be addressed below the start of its tensor.
        assert shard_va >= th.va_base
def test_deploy_tensor_does_not_install_mmu_mappings():
    """deploy_tensor does NOT install MMU mappings — that's context's job.

    The MMUs created here are never handed to ``deploy_tensor``, so the
    entry-count loop alone cannot fail. The signature check at the end pins
    the actual contract: ``deploy_tensor`` has no ``mmus`` parameter through
    which it could be given MMUs to populate.
    """
    import inspect

    allocs = _make_allocators()
    va_alloc = _make_va_allocator()
    mmus = _make_mmus()
    placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8)
    deploy_tensor(
        name="W",
        shape=(1024, 512),
        dtype="fp16",
        placement=placement,
        allocators=allocs,
        va_allocator=va_alloc,
    )
    # No MMU should have any entries (mappings come from fabric MmuMapMsg)
    for mmu in mmus.values():
        assert mmu.num_entries == 0
    # Guard against the MMU hook being reintroduced on deploy_tensor.
    assert "mmus" not in inspect.signature(deploy_tensor).parameters, (
        "deploy_tensor must not accept an 'mmus' parameter"
    )
# ── T12. Tensor.va property ──────────────────────────────────────────
def test_tensor_va_property():
    """``Tensor.va`` is positive and equals the ``va_base`` of its handle."""
    from kernbench.runtime_api.tensor import Tensor

    tensor_shape = (2048,)
    pe_allocators = _make_allocators(1)
    virtual_allocator = _make_va_allocator()
    # A single shard on PE 0 that starts at the beginning of the tensor.
    single_shard_plan = [ShardSpec(pe_index=0, offset_bytes=0, nbytes=4096)]

    tensor = Tensor(shape=tensor_shape, dtype="f16", name="test")
    tensor._handle = deploy_tensor(
        name="test",
        shape=tensor_shape,
        dtype="f16",
        placement=single_shard_plan,
        allocators=pe_allocators,
        va_allocator=virtual_allocator,
    )

    assert tensor.va > 0
    assert tensor.va == tensor._handle.va_base
# ── T13. tl.load/tl.store use VA ─────────────────────────────────────
def test_tl_load_uses_va_in_dma_cmd():
    """tl.load(va_ptr) records exactly one DmaReadCmd whose ``src_addr`` is the VA."""
    tl = TLContext(dispatch_cycles=0)
    va_ptr = 0x1_0000_0000
    # Only the recorded command matters here; the returned handle is unused.
    tl.load(va_ptr, shape=(32, 64), dtype="f16")
    dma_cmds = [c for c in tl.commands if isinstance(c, DmaReadCmd)]
    assert len(dma_cmds) == 1
    # The DMA command should carry the VA
    assert dma_cmds[0].src_addr == va_ptr
def test_tl_store_uses_va_in_dma_cmd():
    """tl.store(va_ptr, handle) records one DmaWriteCmd whose ``dst_addr`` is the VA."""
    ctx = TLContext(dispatch_cycles=0)
    loaded = ctx.load(0x1_0000_0000, shape=(16, 16), dtype="f32")
    destination_va = 0x2_0000_0000
    ctx.store(destination_va, loaded)

    write_cmds = []
    for cmd in ctx.commands:
        if isinstance(cmd, DmaWriteCmd):
            write_cmds.append(cmd)

    assert len(write_cmds) == 1
    assert write_cmds[0].dst_addr == destination_va
# ── T14. Kernel VA offset calculation ─────────────────────────────────
def test_kernel_va_offset_in_dma():
    """Kernel using base_va + pid * stride generates correct VA in DmaReadCmd.

    The kernel loads one block at ``a_ptr + pid * BLOCK_SIZE * elem_bytes``.
    The test expects ``program_id(0)`` to match the ``pe_id`` the context was
    built with, so a single name drives both the setup and the expected VA.
    """
    pe_id = 3

    def tiled_kernel(a_ptr, tl, BLOCK_SIZE=1024, DTYPE="f16"):
        pid = tl.program_id(0)
        elem_bytes = 2  # f16
        offset = pid * BLOCK_SIZE * elem_bytes
        # Only the emitted DMA command is inspected; the loaded value is unused.
        tl.load(a_ptr + offset, shape=(BLOCK_SIZE,), dtype=DTYPE)

    va_base = 0x1_0000_0000
    tl = TLContext(pe_id=pe_id, num_programs=8, dispatch_cycles=0)
    run_kernel(tiled_kernel, tl, a_ptr=va_base)
    dma_cmds = [c for c in tl.commands if isinstance(c, DmaReadCmd)]
    assert len(dma_cmds) == 1
    expected_va = va_base + pe_id * 1024 * 2  # pid * BLOCK_SIZE * 2 bytes (f16)
    assert dma_cmds[0].src_addr == expected_va