Add virtual memory support: PE_MMU, VA allocator, fabric MmuMapMsg
Implement VA/MMU layer (ADR-0011 Phase 1) enabling Triton kernels to use contiguous virtual addresses on sharded tensors. Key changes: - PE_MMU component: hybrid inbox (MmuMapMsg) + sync translate() for PE_DMA - VirtualAllocator + PEMemAllocator: free-list with coalescing - MmuMapMsg/MmuUnmapMsg fabric path with SIP-level routing - DPPolicy-based mapping: replicate=local, sharded=broadcast - Tensor lifecycle: del + weakref cleanup, context manager - Rename: TensorHandle.pa→addr, DmaReadCmd.src_pa→src_addr, ctx→torch Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,230 @@
|
||||
"""Tests for VA integration: Tensor, TLContext, and DMA commands use VA.
|
||||
|
||||
Validates:
|
||||
T10. TensorHandle has va_base; TensorShard does NOT have va field
|
||||
T11. deploy_tensor allocates VA + creates mapping entries
|
||||
T12. Tensor.va returns the tensor's VA base
|
||||
T13. tl.load/tl.store generate DMA commands with VA (not PA)
|
||||
T14. Kernel VA-based offset calculation flows through DMA commands
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator
|
||||
from kernbench.policy.address.pe_mmu import PeMMU
|
||||
from kernbench.policy.address.va_allocator import VirtualAllocator
|
||||
from kernbench.policy.placement.dp import column_wise, ShardSpec
|
||||
from kernbench.runtime_api.tensor import (
|
||||
TensorHandle,
|
||||
TensorShard,
|
||||
deploy_tensor,
|
||||
)
|
||||
from kernbench.runtime_api.kernel import TensorArgShard
|
||||
from kernbench.common.pe_commands import DmaReadCmd, DmaWriteCmd
|
||||
from kernbench.triton_emu.tl_context import TLContext, run_kernel
|
||||
|
||||
_MB = 1 << 20
|
||||
_GB = 1 << 30
|
||||
|
||||
_CFG = AddressConfig(
|
||||
sip_count=2,
|
||||
cubes_per_sip=16,
|
||||
pes_per_cube=8,
|
||||
hbm_bytes_per_cube=48 * _GB,
|
||||
hbm_slices_per_cube=8,
|
||||
tcm_bytes_per_pe=16 * _MB,
|
||||
tcm_scheduler_reserved_bytes=4 * _MB,
|
||||
sram_bytes_per_cube=32 * _MB,
|
||||
)
|
||||
|
||||
|
||||
def _make_allocators(num_pe: int = 8) -> dict[int, PEMemAllocator]:
|
||||
return {
|
||||
i: PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=i, cfg=_CFG)
|
||||
for i in range(num_pe)
|
||||
}
|
||||
|
||||
|
||||
def _make_mmus(num_pe: int = 8, page_size: int = 4096) -> dict[int, PeMMU]:
|
||||
return {i: PeMMU(page_size=page_size) for i in range(num_pe)}
|
||||
|
||||
|
||||
def _make_va_allocator() -> VirtualAllocator:
|
||||
return VirtualAllocator(va_base=0x1_0000_0000, va_size=64 * _GB, page_size=2 * _MB)
|
||||
|
||||
|
||||
# ── T10. TensorHandle has va_base ────────────────────────────────────
|
||||
|
||||
|
||||
def test_tensor_handle_has_va_base():
|
||||
"""TensorHandle must have a 'va_base' field."""
|
||||
th = TensorHandle(
|
||||
name="A", shape=(1024, 512), dtype="fp16", itemsize=2,
|
||||
shards=(), va_base=0x1_0000_0000,
|
||||
)
|
||||
assert th.va_base == 0x1_0000_0000
|
||||
|
||||
|
||||
def test_tensor_handle_va_base_immutable():
|
||||
"""TensorHandle.va_base is immutable (frozen dataclass)."""
|
||||
th = TensorHandle(
|
||||
name="A", shape=(1024, 512), dtype="fp16", itemsize=2,
|
||||
shards=(), va_base=0x1_0000_0000,
|
||||
)
|
||||
with pytest.raises(AttributeError):
|
||||
th.va_base = 0x2_0000_0000 # type: ignore[misc]
|
||||
|
||||
|
||||
def test_tensor_shard_no_va_field():
|
||||
"""TensorShard should NOT have a va field — va is derived from
|
||||
TensorHandle.va_base + shard.offset_bytes."""
|
||||
ts = TensorShard(sip=0, cube=0, pe=0, pa=0x1000, nbytes=4096, offset_bytes=0)
|
||||
assert not hasattr(ts, "va"), "TensorShard should not have a 'va' field"
|
||||
|
||||
|
||||
# ── T11. deploy_tensor allocates VA + creates mappings ───────────────
|
||||
|
||||
|
||||
def test_deploy_tensor_assigns_va_base():
|
||||
"""deploy_tensor with VA allocator assigns va_base to TensorHandle."""
|
||||
allocs = _make_allocators()
|
||||
va_alloc = _make_va_allocator()
|
||||
mmus = _make_mmus()
|
||||
placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8)
|
||||
|
||||
th = deploy_tensor(
|
||||
name="W",
|
||||
shape=(1024, 512),
|
||||
dtype="fp16",
|
||||
placement=placement,
|
||||
allocators=allocs,
|
||||
va_allocator=va_alloc,
|
||||
mmus=mmus,
|
||||
)
|
||||
|
||||
assert th.va_base is not None
|
||||
assert th.va_base > 0
|
||||
|
||||
|
||||
def test_deploy_tensor_va_covers_all_shards():
|
||||
"""VA allocation covers the entire tensor; each shard is at va_base + offset."""
|
||||
allocs = _make_allocators()
|
||||
va_alloc = _make_va_allocator()
|
||||
mmus = _make_mmus()
|
||||
placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8)
|
||||
|
||||
th = deploy_tensor(
|
||||
name="W",
|
||||
shape=(1024, 512),
|
||||
dtype="fp16",
|
||||
placement=placement,
|
||||
allocators=allocs,
|
||||
va_allocator=va_alloc,
|
||||
mmus=mmus,
|
||||
)
|
||||
|
||||
# Each shard's VA is derivable: va_base + offset_bytes
|
||||
for s in th.shards:
|
||||
shard_va = th.va_base + s.offset_bytes
|
||||
assert shard_va > 0
|
||||
|
||||
|
||||
def test_deploy_tensor_registers_mmu_mappings():
|
||||
"""deploy_tensor registers VA→PA mappings in all PE MMUs."""
|
||||
allocs = _make_allocators()
|
||||
va_alloc = _make_va_allocator()
|
||||
mmus = _make_mmus()
|
||||
placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8)
|
||||
|
||||
th = deploy_tensor(
|
||||
name="W",
|
||||
shape=(1024, 512),
|
||||
dtype="fp16",
|
||||
placement=placement,
|
||||
allocators=allocs,
|
||||
va_allocator=va_alloc,
|
||||
mmus=mmus,
|
||||
)
|
||||
|
||||
# Every MMU should have entries (broadcast)
|
||||
for mmu in mmus.values():
|
||||
assert mmu.num_entries > 0
|
||||
|
||||
# Each shard's derived VA should translate to its PA in every MMU
|
||||
for mmu in mmus.values():
|
||||
for s in th.shards:
|
||||
shard_va = th.va_base + s.offset_bytes
|
||||
assert mmu.translate(shard_va) == s.pa
|
||||
|
||||
|
||||
# ── T12. Tensor.va property ──────────────────────────────────────────
|
||||
|
||||
|
||||
def test_tensor_va_property():
|
||||
"""Tensor.va returns the VA base of the entire tensor (from TensorHandle.va_base)."""
|
||||
from kernbench.runtime_api.tensor import Tensor
|
||||
|
||||
allocs = _make_allocators(1)
|
||||
va_alloc = _make_va_allocator()
|
||||
mmus = _make_mmus(1)
|
||||
placement = [ShardSpec(pe_index=0, offset_bytes=0, nbytes=4096)]
|
||||
|
||||
t = Tensor(shape=(2048,), dtype="f16", name="test")
|
||||
t._handle = deploy_tensor(
|
||||
name="test",
|
||||
shape=(2048,),
|
||||
dtype="f16",
|
||||
placement=placement,
|
||||
allocators=allocs,
|
||||
va_allocator=va_alloc,
|
||||
mmus=mmus,
|
||||
)
|
||||
assert t.va > 0
|
||||
assert t.va == t._handle.va_base
|
||||
|
||||
|
||||
# ── T13. tl.load/tl.store use VA ─────────────────────────────────────
|
||||
|
||||
|
||||
def test_tl_load_uses_va_in_dma_cmd():
|
||||
"""tl.load(va_ptr) generates DmaReadCmd with src_va (not src_pa)."""
|
||||
tl = TLContext(dispatch_cycles=0)
|
||||
va_ptr = 0x1_0000_0000
|
||||
h = tl.load(va_ptr, shape=(32, 64), dtype="f16")
|
||||
|
||||
dma_cmds = [c for c in tl.commands if isinstance(c, DmaReadCmd)]
|
||||
assert len(dma_cmds) == 1
|
||||
# The DMA command should carry the VA
|
||||
assert dma_cmds[0].src_addr == va_ptr
|
||||
|
||||
|
||||
def test_tl_store_uses_va_in_dma_cmd():
|
||||
"""tl.store(va_ptr, handle) generates DmaWriteCmd with dst_va."""
|
||||
tl = TLContext(dispatch_cycles=0)
|
||||
h = tl.load(0x1_0000_0000, shape=(16, 16), dtype="f32")
|
||||
va_out = 0x2_0000_0000
|
||||
tl.store(va_out, h)
|
||||
|
||||
dma_cmds = [c for c in tl.commands if isinstance(c, DmaWriteCmd)]
|
||||
assert len(dma_cmds) == 1
|
||||
assert dma_cmds[0].dst_addr == va_out
|
||||
|
||||
|
||||
# ── T14. Kernel VA offset calculation ─────────────────────────────────
|
||||
|
||||
|
||||
def test_kernel_va_offset_in_dma():
|
||||
"""Kernel using base_va + pid * stride generates correct VA in DmaReadCmd."""
|
||||
def tiled_kernel(a_ptr, tl, BLOCK_SIZE=1024, DTYPE="f16"):
|
||||
pid = tl.program_id(0)
|
||||
elem_bytes = 2 # f16
|
||||
offset = pid * BLOCK_SIZE * elem_bytes
|
||||
a = tl.load(a_ptr + offset, shape=(BLOCK_SIZE,), dtype=DTYPE)
|
||||
|
||||
va_base = 0x1_0000_0000
|
||||
tl = TLContext(pe_id=3, num_programs=8, dispatch_cycles=0)
|
||||
run_kernel(tiled_kernel, tl, a_ptr=va_base)
|
||||
|
||||
dma_cmds = [c for c in tl.commands if isinstance(c, DmaReadCmd)]
|
||||
assert len(dma_cmds) == 1
|
||||
expected_va = va_base + 3 * 1024 * 2 # pid=3, BLOCK_SIZE=1024, 2 bytes
|
||||
assert dma_cmds[0].src_addr == expected_va
|
||||
Reference in New Issue
Block a user