Files
kernbench2/tests/test_va_integration.py
ywkang 357cab525b ADR-0026: DPPolicy intra-device only + ShardSpec structural coords
DPPolicy no longer carries a cross-SIP axis. SIP-level placement is
solely controlled by torch.ahbm.set_device(rank) (ADR-0024); DPPolicy
itself describes only the cube × PE layout within one SIP. ShardSpec
switches to structural (sip, cube, pe) coordinates; the flat pe_index
field/property is fully removed — silent drift between global-flat and
SIP-local interpretations was a foot-gun flagged by ADR-0024 D11.

Breaking API (explicit TypeError / AttributeError):
- DPPolicy(sip=...) / DPPolicy(num_sips=...) -> TypeError
- ShardSpec.pe_index -> AttributeError
- ShardSpec(pe_index=...) -> TypeError
- resolve_dp_policy now takes target_sip= (required), no num_sips.

Downstream migration:
- PE allocator dict keyed by (sip, cube, pe) tuples, in both
  _ensure_allocators and _free_tensor. deploy_tensor uses tuple lookup.
- _create_tensor passes target_sip=current_sip; post-hoc pe_index
  shifting removed entirely.
- launch._compute_local_shape drops the dp.sip branch.
- Internal resolvers (column_wise / row_wise / replicate / tiled_*)
  return _LocalPeShard (cube-local identifier) instead of ShardSpec —
  resolve_dp_policy lifts them to full structural coords.

Tests:
- New tests/test_adr0026_dppolicy_intra_device.py (12 tests) pins the
  contract end-to-end.
- test_sip_parallel.py rewritten: SIP composition now modeled as two
  resolve_dp_policy(target_sip=...) calls (ADR-0024 launcher style).
- Call-site migration: test_tensor, test_va_integration, test_va_offset,
  test_runtime_api_tensor, test_tl_recv_async, test_ccl_* and benches
  gemm_single_pe, gpt3_qkv, va_offset_verify, ccl_allreduce (legacy
  branch) all use intra-device DPPolicy and structural ShardSpec.

Result: 523 passed, 1 strict xfail (ring_default_ws — unchanged
ADR-0024 Phase B blocker; architectural fix deferred to ADR-0027).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-14 13:02:19 -07:00

229 lines
7.4 KiB
Python

"""Tests for VA integration: Tensor, TLContext, and DMA commands use VA.
Validates:
T10. TensorHandle has va_base; TensorShard does NOT have va field
T11. deploy_tensor allocates VA + creates mapping entries
T12. Tensor.va returns the tensor's VA base
T13. tl.load/tl.store generate DMA commands with VA (not PA)
T14. Kernel VA-based offset calculation flows through DMA commands
"""
import pytest
from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator
from kernbench.policy.address.pe_mmu import PeMMU
from kernbench.policy.address.va_allocator import VirtualAllocator
from kernbench.policy.placement.dp import DPPolicy, ShardSpec, resolve_dp_policy
from kernbench.runtime_api.tensor import (
TensorHandle,
TensorShard,
deploy_tensor,
)
from kernbench.runtime_api.kernel import TensorArgShard
from kernbench.common.pe_commands import DmaReadCmd, DmaWriteCmd
from kernbench.triton_emu.tl_context import TLContext, run_kernel
# Byte-size units, used to keep the capacities below readable.
_MB = 2**20
_GB = 2**30

# Address configuration handed to every PEMemAllocator created by
# _make_allocators.
_CFG = AddressConfig(
    sip_count=2,
    cubes_per_sip=16,
    pes_per_cube=8,
    hbm_bytes_per_cube=48 * _GB,
    hbm_slices_per_cube=8,
    tcm_bytes_per_pe=16 * _MB,
    tcm_scheduler_reserved_bytes=4 * _MB,
    sram_bytes_per_cube=32 * _MB,
)
def _make_allocators(num_pe: int = 8) -> dict[tuple[int, int, int], PEMemAllocator]:
    """Build one PEMemAllocator per PE, all on rack 0 / SIP 0 / cube 0.

    Args:
        num_pe: Number of PEs to create allocators for.

    Returns:
        A dict keyed by ``(sip, cube, pe)`` tuples — here ``(0, 0, i)`` for
        each ``i`` in ``range(num_pe)`` — mapping to an allocator configured
        with the module-level ``_CFG``.
    """
    sip_id = cube_id = 0
    allocators: dict[tuple[int, int, int], PEMemAllocator] = {}
    for pe_id in range(num_pe):
        allocators[(sip_id, cube_id, pe_id)] = PEMemAllocator(
            rack_id=0, sip_id=sip_id, cube_id=cube_id, pe_id=pe_id, cfg=_CFG
        )
    return allocators
def _make_mmus(num_pe: int = 8, page_size: int = 4096) -> dict[int, PeMMU]:
    """Build a fresh PeMMU for each of ``num_pe`` PEs.

    Args:
        num_pe: Number of MMUs to create.
        page_size: Page size forwarded to every ``PeMMU``.

    Returns:
        A dict mapping the integer index ``0 .. num_pe - 1`` to its own
        ``PeMMU`` instance (no instance is shared between keys).
    """
    mmus: dict[int, PeMMU] = {}
    for pe_idx in range(num_pe):
        mmus[pe_idx] = PeMMU(page_size=page_size)
    return mmus
def _make_va_allocator() -> VirtualAllocator:
    """Return a VirtualAllocator for a 64 GiB window starting at 0x1_0000_0000.

    The allocator is created with a 2 MiB page size.
    """
    window_start = 0x1_0000_0000
    window_bytes = 64 * _GB
    page_bytes = 2 * _MB
    return VirtualAllocator(
        va_base=window_start,
        va_size=window_bytes,
        page_size=page_bytes,
    )
# ── T10. TensorHandle has va_base ────────────────────────────────────
def test_tensor_handle_has_va_base():
    """TensorHandle accepts a ``va_base`` argument and exposes it unchanged."""
    base = 0x1_0000_0000
    handle = TensorHandle(
        name="A",
        shape=(1024, 512),
        dtype="fp16",
        itemsize=2,
        shards=(),
        va_base=base,
    )
    assert handle.va_base == base
def test_tensor_handle_va_base_immutable():
    """Re-assigning ``TensorHandle.va_base`` must raise AttributeError.

    This is the behaviour a frozen dataclass gives; the test only pins the
    exception type, not how TensorHandle achieves it.
    """
    handle = TensorHandle(
        name="A",
        shape=(1024, 512),
        dtype="fp16",
        itemsize=2,
        shards=(),
        va_base=0x1_0000_0000,
    )
    with pytest.raises(AttributeError):
        # setattr is the same operation as `handle.va_base = ...`.
        setattr(handle, "va_base", 0x2_0000_0000)
def test_tensor_shard_no_va_field():
    """TensorShard must not expose a ``va`` attribute.

    A shard's VA is meant to be derived as
    ``TensorHandle.va_base + shard.offset_bytes`` rather than stored.
    """
    shard = TensorShard(sip=0, cube=0, pe=0, pa=0x1000, nbytes=4096, offset_bytes=0)
    missing = object()
    # getattr falls back to the sentinel exactly when hasattr would be False.
    assert getattr(shard, "va", missing) is missing, "TensorShard should not have a 'va' field"
# ── T11. deploy_tensor allocates VA + creates mappings ───────────────
def test_deploy_tensor_assigns_va_base():
    """A tensor deployed with a VA allocator gets a non-None, positive va_base."""
    pe_allocators = _make_allocators()
    va_space = _make_va_allocator()
    tensor_shape = (1024, 512)

    shard_plan = resolve_dp_policy(
        DPPolicy(cube="replicate", pe="column_wise"),
        shape=tensor_shape,
        itemsize=2,
        num_pe=8,
        num_cubes=1,
        target_sip=0,
    )
    handle = deploy_tensor(
        name="W",
        shape=tensor_shape,
        dtype="fp16",
        placement=shard_plan,
        allocators=pe_allocators,
        va_allocator=va_space,
    )

    assert handle.va_base is not None
    assert handle.va_base > 0
def test_deploy_tensor_va_covers_all_shards():
    """Every shard of a deployed tensor has a positive VA.

    A shard's VA is computed as ``th.va_base + shard.offset_bytes``. The
    tensor is a (1024, 512) fp16 array placed with a replicate-cube /
    column-wise-PE policy over 8 PEs and 1 cube on SIP 0.
    """
    allocs = _make_allocators()
    va_alloc = _make_va_allocator()
    placement = resolve_dp_policy(
        DPPolicy(cube="replicate", pe="column_wise"),
        shape=(1024, 512),
        itemsize=2,
        num_pe=8,
        num_cubes=1,
        target_sip=0,
    )
    th = deploy_tensor(
        name="W",
        shape=(1024, 512),
        dtype="fp16",
        placement=placement,
        allocators=allocs,
        va_allocator=va_alloc,
    )
    # Without this guard the per-shard loop below would pass vacuously if
    # deployment ever produced zero shards.
    assert len(th.shards) > 0, "deploy_tensor produced no shards"
    for s in th.shards:
        shard_va = th.va_base + s.offset_bytes
        assert shard_va > 0
def test_deploy_tensor_does_not_install_mmu_mappings():
    """deploy_tensor leaves the PE MMUs empty — installing mappings is the context's job.

    NOTE(review): the MMUs created here are never passed to deploy_tensor,
    so this check can only fail if deploy_tensor reaches them by some other
    route — confirm that this is the intended strength of the test.
    """
    pe_allocators = _make_allocators()
    va_space = _make_va_allocator()
    pe_mmus = _make_mmus()
    tensor_shape = (1024, 512)

    shard_plan = resolve_dp_policy(
        DPPolicy(cube="replicate", pe="column_wise"),
        shape=tensor_shape,
        itemsize=2,
        num_pe=8,
        num_cubes=1,
        target_sip=0,
    )
    deploy_tensor(
        name="W",
        shape=tensor_shape,
        dtype="fp16",
        placement=shard_plan,
        allocators=pe_allocators,
        va_allocator=va_space,
    )

    # No MMU should have any entries (mappings come from fabric MmuMapMsg)
    entry_counts = [mmu.num_entries for mmu in pe_mmus.values()]
    assert all(count == 0 for count in entry_counts)
# ── T12. Tensor.va property ──────────────────────────────────────────
def test_tensor_va_property():
    """``Tensor.va`` is positive and equals the ``va_base`` of the attached handle."""
    from kernbench.runtime_api.tensor import Tensor

    pe_allocators = _make_allocators(1)
    va_space = _make_va_allocator()
    # One 4096-byte shard on (sip=0, cube=0, pe=0): 2048 elements of f16.
    single_shard = [ShardSpec(sip=0, cube=0, pe=0, offset_bytes=0, nbytes=4096)]

    tensor = Tensor(shape=(2048,), dtype="f16", name="test")
    tensor._handle = deploy_tensor(
        name="test",
        shape=(2048,),
        dtype="f16",
        placement=single_shard,
        allocators=pe_allocators,
        va_allocator=va_space,
    )

    assert tensor.va > 0
    assert tensor.va == tensor._handle.va_base
# ── T13. tl.load/tl.store use VA ─────────────────────────────────────
def test_tl_load_uses_va_in_dma_cmd():
    """``tl.load(va_ptr)`` records exactly one DmaReadCmd whose ``src_addr`` is the VA."""
    ctx = TLContext(dispatch_cycles=0)
    va_ptr = 0x1_0000_0000

    # Only the recorded commands matter here; the returned handle is unused.
    ctx.load(va_ptr, shape=(32, 64), dtype="f16")

    # Exactly one read command, and it carries the VA as its source address.
    read_sources = [cmd.src_addr for cmd in ctx.commands if isinstance(cmd, DmaReadCmd)]
    assert read_sources == [va_ptr]
def test_tl_store_uses_va_in_dma_cmd():
    """``tl.store(va_ptr, handle)`` records exactly one DmaWriteCmd whose ``dst_addr`` is the VA."""
    ctx = TLContext(dispatch_cycles=0)
    tile = ctx.load(0x1_0000_0000, shape=(16, 16), dtype="f32")
    va_out = 0x2_0000_0000

    ctx.store(va_out, tile)

    # Exactly one write command, and it targets the destination VA.
    write_targets = [cmd.dst_addr for cmd in ctx.commands if isinstance(cmd, DmaWriteCmd)]
    assert write_targets == [va_out]
# ── T14. Kernel VA offset calculation ─────────────────────────────────
def test_kernel_va_offset_in_dma():
    """A kernel loading from ``base_va + pid * stride`` yields that VA in its DmaReadCmd.

    The context is built with ``pe_id=3`` and the test expects
    ``tl.program_id(0)`` to be 3, so the single read must start at
    ``va_base + 3 * 1024 * 2``.
    """

    def tiled_kernel(a_ptr, tl, BLOCK_SIZE=1024, DTYPE="f16"):
        bytes_per_elem = 2  # f16
        byte_offset = tl.program_id(0) * BLOCK_SIZE * bytes_per_elem
        # The load is issued only for the DMA command it records.
        tl.load(a_ptr + byte_offset, shape=(BLOCK_SIZE,), dtype=DTYPE)

    pid, block_size, elem_bytes = 3, 1024, 2
    va_base = 0x1_0000_0000

    ctx = TLContext(pe_id=pid, num_programs=8, dispatch_cycles=0)
    run_kernel(tiled_kernel, ctx, a_ptr=va_base)

    expected_va = va_base + pid * block_size * elem_bytes
    read_sources = [cmd.src_addr for cmd in ctx.commands if isinstance(cmd, DmaReadCmd)]
    assert read_sources == [expected_va]