Add virtual memory support: PE_MMU, VA allocator, fabric MmuMapMsg

Implement VA/MMU layer (ADR-0011 Phase 1) enabling Triton kernels to use
contiguous virtual addresses on sharded tensors.

Key changes:
- PE_MMU component: hybrid inbox (MmuMapMsg) + sync translate() for PE_DMA
- VirtualAllocator + PEMemAllocator: free-list with coalescing
- MmuMapMsg/MmuUnmapMsg fabric path with SIP-level routing
- DPPolicy-based mapping: replicate=local, sharded=broadcast
- Tensor lifecycle: del + weakref cleanup, context manager
- Rename: TensorHandle.pa→addr, DmaReadCmd.src_pa→src_addr, ctx→torch

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-26 00:01:47 -07:00
parent 62fb01ae18
commit 08812eda58
34 changed files with 2131 additions and 139 deletions
+226
View File
@@ -0,0 +1,226 @@
"""Tests for PE_MMU component integration and MmuMapMsg fabric path.
Validates:
T15-a. PE_MMU component registered in ComponentRegistry
T15-b. PE_MMU component receives MmuMapMsg via inbox, updates page table
T15-c. PE_DMA translates VA→PA via mmu before routing
T16. MmuMapMsg/MmuUnmapMsg message types defined with correct fields
T17. PE_CPU passes VA (not PA) to kernel when VA is available
T18. End-to-end: deploy (MmuMapMsg broadcast) → kernel launch → DMA with VA
"""
import pytest
try:
from kernbench.runtime_api.kernel import MmuMapMsg, MmuUnmapMsg
except ImportError:
pytest.skip("MmuMapMsg/MmuUnmapMsg not yet defined (Phase 2)", allow_module_level=True)
# ── T16. MmuMapMsg / MmuUnmapMsg message types ──────────────────────
def test_mmu_map_msg_fields():
    """MmuMapMsg exposes its type tag and the VA→PA entries it was built with."""
    mapping_entries = (
        {"va": 0x1_0000_0000, "pa": 0xA000_0000, "size": 4096},
        {"va": 0x1_0000_1000, "pa": 0xB000_0000, "size": 4096},
    )
    msg = MmuMapMsg(
        correlation_id="c0",
        request_id="r0",
        entries=mapping_entries,
        target_cubes="all",
        target_pe="all",
    )

    assert msg.msg_type == "mmu_map"
    assert len(msg.entries) == 2

    # Spot-check every field of the first entry.
    head = msg.entries[0]
    assert head["va"] == 0x1_0000_0000
    assert head["pa"] == 0xA000_0000
    assert head["size"] == 4096
def test_mmu_map_msg_immutable():
    """Assigning to a field of a constructed MmuMapMsg raises AttributeError."""
    frozen_msg = MmuMapMsg(
        correlation_id="c0",
        request_id="r0",
        entries=(),
        target_cubes="all",
        target_pe="all",
    )

    with pytest.raises(AttributeError):
        frozen_msg.entries = ()  # type: ignore[misc]
def test_mmu_unmap_msg_fields():
    """MmuUnmapMsg exposes its type tag and the VA ranges it was built with."""
    unmap_entries = ({"va": 0x1_0000_0000, "size": 4096},)
    msg = MmuUnmapMsg(
        correlation_id="c0",
        request_id="r0",
        entries=unmap_entries,
        target_cubes="all",
        target_pe="all",
    )

    assert msg.msg_type == "mmu_unmap"
    assert len(msg.entries) == 1
    assert msg.entries[0]["va"] == 0x1_0000_0000
# ── T15-a. PE_MMU component registry ────────────────────────────────
def test_pe_mmu_registry():
    """ComponentRegistry.create returns a PeMmuComponent for a pe_mmu_v1 node."""
    from kernbench.components.base import ComponentRegistry
    from kernbench.components.impls.pe_mmu import PeMmuComponent
    from kernbench.topology.types import Node

    mmu_node = Node(
        id="sip0.cube0.pe0.pe_mmu",
        kind="pe_mmu",
        impl="pe_mmu_v1",
        pos_mm=None,
        attrs={"tlb_overhead_ns": 0.5},
    )

    created = ComponentRegistry.create(mmu_node)
    assert isinstance(created, PeMmuComponent)
# ── T15-b. PE_MMU receives MmuMapMsg and updates page table ─────────
def test_pe_mmu_processes_map_msg():
    """A MmuMapMsg placed on the component inbox becomes translatable via comp.mmu."""
    import simpy
    from kernbench.components.impls.pe_mmu import PeMmuComponent
    from kernbench.sim_engine.transaction import Transaction
    from kernbench.topology.types import Node

    env = simpy.Environment()
    comp = PeMmuComponent(
        Node(
            id="sip0.cube0.pe0.pe_mmu",
            kind="pe_mmu",
            impl="pe_mmu_v1",
            pos_mm=None,
            attrs={"tlb_overhead_ns": 0.5, "page_size": 4096},
        )
    )
    comp.in_ports["src"] = simpy.Store(env)
    comp.start(env)

    # Wrap a single-entry map request in a transaction addressed to the MMU.
    request = MmuMapMsg(
        correlation_id="c0",
        request_id="r0",
        entries=({"va": 0x1_0000_0000, "pa": 0xABCD_0000, "size": 4096},),
        target_cubes="all",
        target_pe="all",
    )
    completion_event = env.event()
    txn = Transaction(
        request=request,
        path=["sip0.cube0.pe0.pe_mmu"],
        step=0,
        nbytes=0,
        done=completion_event,
    )

    def _deliver():
        # Hand the transaction to the component through its inbox.
        yield comp._inbox.put(txn)

    env.process(_deliver())
    env.run(until=100)

    # Once the simulation has run, the underlying PeMMU must hold the mapping.
    from kernbench.policy.address.pe_mmu import PeMMU

    page_table = comp.mmu
    assert isinstance(page_table, PeMMU)
    assert page_table.translate(0x1_0000_0000) == 0xABCD_0000
# ── T15-c. PE_DMA uses MMU translate ────────────────────────────────
def test_pe_dma_translates_va():
    """PE_DMA exposes some handle through which it can reach the PE_MMU.

    Contract being checked: after Phase 2, DmaReadCmd carries a VA and PE_DMA
    has to translate it to a PA through the MMU before resolving the HBM
    node path.

    Only the interface is verified here (an MMU-related attribute exists);
    the full engine wiring is validated in test_engine.
    """
    from kernbench.components.impls.pe_dma import PeDmaComponent
    from kernbench.topology.types import Node

    dma_node = Node(
        id="sip0.cube0.pe0.pe_dma",
        kind="pe_dma",
        impl="pe_dma_v1",
        pos_mm=None,
        attrs={"rd_engines": 1, "wr_engines": 1},
    )
    comp = PeDmaComponent(dma_node)

    # The wiring mechanism is flexible: a direct reference (`_mmu` / `mmu`)
    # or a non-None `ctx` are all accepted.
    mmu_reachable = (
        any(hasattr(comp, attr) for attr in ("_mmu", "mmu"))
        or getattr(comp, "ctx", None) is not None
    )
    assert mmu_reachable, "PE_DMA must have access to PE_MMU for VA translation"
# ── T17. PE_CPU passes VA to kernel ──────────────────────────────────
def test_pe_cpu_uses_va_base_from_tensor_arg():
    """TensorArg carries va_base while each shard still exposes its PA.

    After Phase 2, PE_CPU is expected to hand TensorArg.va_base to the kernel
    function so kernels work on VA rather than PA; this test pins the data
    shape that contract relies on.
    """
    from kernbench.runtime_api.kernel import TensorArg, TensorArgShard

    only_shard = TensorArgShard(
        sip=0, cube=0, pe=0, pa=0x1000, nbytes=4096, offset_bytes=0
    )
    tensor_arg = TensorArg(shards=(only_shard,), va_base=0x1_0000_0000)

    # Kernel pointer argument comes from va_base ...
    assert tensor_arg.va_base == 0x1_0000_0000
    # ... while the PA stays reachable for direct-PA operations (IPCQ etc.).
    assert only_shard.pa == 0x1000
# ── T18. MmuMapMsg broadcast pattern ─────────────────────────────────
def test_mmu_map_msg_broadcast_target():
    """A MmuMapMsg built with 'all' targets reports 'all' for PEs and cubes."""
    broadcast = MmuMapMsg(
        correlation_id="c0",
        request_id="r0",
        entries=({"va": 0x1000, "pa": 0x2000, "size": 4096},),
        target_cubes="all",
        target_pe="all",
    )

    assert broadcast.target_pe == "all"
    assert broadcast.target_cubes == "all"
def test_mmu_map_msg_same_entries_all_pes():
    """The message keeps the full entry tuple unchanged (no per-PE split)."""
    full_mapping = (
        {"va": 0x1_0000_0000, "pa": 0xA000_0000, "size": 8192},
        {"va": 0x1_0000_2000, "pa": 0xB000_0000, "size": 8192},
    )
    broadcast = MmuMapMsg(
        correlation_id="c0",
        request_id="r0",
        entries=full_mapping,
        target_cubes="all",
        target_pe="all",
    )

    # What was passed in is exactly what the message carries.
    assert broadcast.entries == full_mapping
+241
View File
@@ -0,0 +1,241 @@
"""Tests for MmuMapMsg fabric path and cross-cube mapping.
Validates:
F1. MmuMapMsg traverses fabric: latency > 0 (not sideband)
F2. MmuMapMsg fan-out: IO_CPU → cubes, M_CPU → PEs
F3. After MmuMapMsg, PE_MMU has correct mappings
F4. Cross-cube sharded tensor: all PEs get global mappings
F5. Replicate tensor: each PE gets own cube's PA (local override)
F6. Cross-cube DMA after sharded mapping: PE can access remote cube's HBM
F7. Overlap detection: replicate vs sharded identified correctly
F8. Existing regression: PA-only benchmarks still pass
"""
import pytest
from pathlib import Path
from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator
from kernbench.policy.address.pe_mmu import PeMMU
from kernbench.policy.address.va_allocator import VirtualAllocator
from kernbench.policy.placement.dp import column_wise, replicate, ShardSpec
from kernbench.runtime_api.tensor import deploy_tensor, TensorHandle
from kernbench.sim_engine.engine import GraphEngine
from kernbench.topology.builder import load_topology
# topology.yaml lives in the directory above the one containing this module.
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
# Byte-size multipliers (2**20 and 2**30).
_MB = 1 << 20
_GB = 1 << 30
# Address-space configuration: 2 SIPs x 16 cubes x 8 PEs with per-cube HBM/SRAM
# and per-PE TCM sizes.
# NOTE(review): not referenced by the tests visible in this file section —
# confirm it is still needed.
_CFG = AddressConfig(
sip_count=2,
cubes_per_sip=16,
pes_per_cube=8,
hbm_bytes_per_cube=48 * _GB,
hbm_slices_per_cube=8,
tcm_bytes_per_pe=16 * _MB,
tcm_scheduler_reserved_bytes=4 * _MB,
sram_bytes_per_cube=32 * _MB,
)
def _engine():
    """Build a fresh GraphEngine from the topology at TOPOLOGY_PATH."""
    topology = load_topology(TOPOLOGY_PATH)
    return GraphEngine(topology)
# ── F1. MmuMapMsg fabric latency ─────────────────────────────────────
def test_mmu_map_via_fabric_has_latency():
    """A MmuMapMsg sent through engine.submit() completes ok with total_ns > 0."""
    from kernbench.runtime_api.kernel import MmuMapMsg

    engine = _engine()
    request = MmuMapMsg(
        correlation_id="c0",
        request_id="mmu_map_0",
        entries=({"va": 0x1_0000_0000, "pa": 0x2000_0000, "size": 4096},),
        target_cubes=(0,),
        target_pe="all",
    )

    handle = engine.submit(request)
    engine.wait(handle)
    completion, trace = engine.get_completion(handle)

    assert completion.ok is True
    # A trace must exist and report a strictly positive traversal time.
    assert trace is not None
    assert trace.get("total_ns", 0) > 0
# ── F2. MmuMapMsg fan-out ────────────────────────────────────────────
def test_mmu_map_reaches_all_pes_in_cube():
    """With target_pe='all', all 8 PE_MMUs of the target cube get the mapping."""
    from kernbench.runtime_api.kernel import MmuMapMsg

    engine = _engine()
    va, pa, size = 0x1_0000_0000, 0xABCD_0000, 4096

    handle = engine.submit(
        MmuMapMsg(
            correlation_id="c0",
            request_id="mmu_map_1",
            entries=({"va": va, "pa": pa, "size": size},),
            target_cubes=(0,),
            target_pe="all",
        )
    )
    engine.wait(handle)

    # Check each of the 8 PE_MMU components in sip0.cube0.
    for pe_index in range(8):
        pe_mmu = engine._components[f"sip0.cube0.pe{pe_index}.pe_mmu"]
        assert pe_mmu.mmu.translate(va) == pa
# ── F3. Multiple MmuMapMsg entries ───────────────────────────────────
def test_mmu_map_multiple_entries():
    """Every entry of a multi-entry MmuMapMsg ends up in the PE_MMU."""
    from kernbench.runtime_api.kernel import MmuMapMsg

    engine = _engine()
    request = MmuMapMsg(
        correlation_id="c0",
        request_id="mmu_map_2",
        entries=(
            {"va": 0x1_0000_0000, "pa": 0xA000_0000, "size": 4096},
            {"va": 0x1_0000_1000, "pa": 0xB000_0000, "size": 4096},
        ),
        target_cubes=(0,),
        target_pe="all",
    )
    handle = engine.submit(request)
    engine.wait(handle)

    pe0_mmu = engine._components["sip0.cube0.pe0.pe_mmu"].mmu
    assert pe0_mmu.translate(0x1_0000_0000) == 0xA000_0000
    assert pe0_mmu.translate(0x1_0000_1000) == 0xB000_0000
# ── F4. Cross-cube sharded: global mapping ───────────────────────────
def test_cross_cube_sharded_all_pes_get_global_mapping():
    """A map message sent to cubes 0 and 1 installs both entries in both cubes."""
    from kernbench.runtime_api.kernel import MmuMapMsg

    engine = _engine()

    # Two-cube shard layout: the first entry stands for cube0 (offset 0),
    # the second for cube1 (offset 4096).
    cube0_va, cube0_pa = 0x1_0000_0000, 0xA000_0000
    cube1_va, cube1_pa = 0x1_0000_1000, 0xB000_0000
    request = MmuMapMsg(
        correlation_id="c0",
        request_id="mmu_map_xc",
        entries=(
            {"va": cube0_va, "pa": cube0_pa, "size": 4096},
            {"va": cube1_va, "pa": cube1_pa, "size": 4096},
        ),
        target_cubes=(0, 1),
        target_pe="all",
    )
    handle = engine.submit(request)
    engine.wait(handle)

    # A PE in cube0 resolves its local entry and cube1's entry.
    mmu_cube0 = engine._components["sip0.cube0.pe0.pe_mmu"]
    assert mmu_cube0.mmu.translate(cube0_va) == cube0_pa
    assert mmu_cube0.mmu.translate(cube1_va) == cube1_pa

    # A PE in cube1 resolves cube0's entry and its local entry.
    mmu_cube1 = engine._components["sip0.cube1.pe0.pe_mmu"]
    assert mmu_cube1.mmu.translate(cube0_va) == cube0_pa
    assert mmu_cube1.mmu.translate(cube1_va) == cube1_pa
# ── F5. Replicate: local PA override ─────────────────────────────────
def test_replicate_local_pa_override():
    """The same VA mapped per cube translates to that cube's own PA."""
    from kernbench.runtime_api.kernel import MmuMapMsg

    engine = _engine()
    va, size = 0x1_0000_0000, 4096

    # (cube index, request id, PA installed for that cube) — sent one at a
    # time, each awaited before the next is submitted.
    per_cube = (
        (0, "mmu_rep_c0", 0xA000_0000),
        (1, "mmu_rep_c1", 0xB000_0000),
    )
    for cube, request_id, pa in per_cube:
        handle = engine.submit(
            MmuMapMsg(
                correlation_id="c0",
                request_id=request_id,
                entries=({"va": va, "pa": pa, "size": size},),
                target_cubes=(cube,),
                target_pe="all",
            )
        )
        engine.wait(handle)

    # Each cube's pe0 must translate the shared VA to its own PA.
    for cube, _request_id, expected_pa in per_cube:
        cube_mmu = engine._components[f"sip0.cube{cube}.pe0.pe_mmu"]
        assert cube_mmu.mmu.translate(va) == expected_pa
# ── F7. Overlap detection ────────────────────────────────────────────
def test_detect_overlapping_shards():
    """Duplicate (offset_bytes, nbytes) pairs distinguish replicate from sharded."""
    from kernbench.runtime_api.tensor import TensorShard

    # Sharded layout: every shard covers a different offset.
    sharded = [
        TensorShard(sip=0, cube=0, pe=0, pa=0x100, nbytes=4096, offset_bytes=0),
        TensorShard(sip=0, cube=0, pe=1, pa=0x200, nbytes=4096, offset_bytes=4096),
    ]
    distinct_sharded = {(s.offset_bytes, s.nbytes) for s in sharded}
    assert len(distinct_sharded) == len(sharded), "Sharded should have unique offsets"

    # Replicated layout: both shards cover the same offset.
    replicated = [
        TensorShard(sip=0, cube=0, pe=0, pa=0x100, nbytes=4096, offset_bytes=0),
        TensorShard(sip=0, cube=1, pe=0, pa=0x200, nbytes=4096, offset_bytes=0),
    ]
    distinct_replicated = {(s.offset_bytes, s.nbytes) for s in replicated}
    assert len(distinct_replicated) < len(replicated), "Replicate should have duplicate offsets"
# ── F8. Regression: existing benchmarks still pass ───────────────────
def test_qkv_gemm_still_passes():
    """The QKV GEMM benchmark runs to completion without raising."""
    from kernbench.runtime_api.context import RuntimeContext
    from kernbench.runtime_api.types import BenchResult, DeviceSelector

    graph = load_topology(TOPOLOGY_PATH)
    ctx = RuntimeContext(
        engine=GraphEngine(graph),
        target_device=DeviceSelector("sip:0"),
        correlation_id="test_regression",
        spec=graph.spec,
    )

    from benches.qkv_gemm import run as bench_run

    bench_run(ctx)
    ctx.wait_all()
    # Reaching this point without an exception is the success criterion.
+6 -6
View File
@@ -308,9 +308,9 @@ def test_pe_gemm_handles_pe_internal_txn():
gemm.in_ports["src"] = simpy.Store(env)
gemm.start(env)
a = TensorHandle(id="t1", pa=0, shape=(4, 8), dtype="f16", nbytes=64)
b = TensorHandle(id="t2", pa=0, shape=(8, 4), dtype="f16", nbytes=64)
out = TensorHandle(id="t3", pa=0, shape=(4, 4), dtype="f16", nbytes=32)
a = TensorHandle(id="t1", addr=0, shape=(4, 8), dtype="f16", nbytes=64)
b = TensorHandle(id="t2", addr=0, shape=(8, 4), dtype="f16", nbytes=64)
out = TensorHandle(id="t3", addr=0, shape=(4, 4), dtype="f16", nbytes=32)
cmd = GemmCmd(a=a, b=b, out=out, m=4, k=8, n=4)
done = env.event()
pe_txn = PeInternalTxn(command=cmd, done=done, pe_prefix=pe_prefix)
@@ -349,8 +349,8 @@ def test_pe_math_handles_pe_internal_txn():
math_comp.in_ports["src"] = simpy.Store(env)
math_comp.start(env)
x = TensorHandle(id="t1", pa=0, shape=(4, 4), dtype="f16", nbytes=32)
out = TensorHandle(id="t2", pa=0, shape=(4, 4), dtype="f16", nbytes=32)
x = TensorHandle(id="t1", addr=0, shape=(4, 4), dtype="f16", nbytes=32)
out = TensorHandle(id="t2", addr=0, shape=(4, 4), dtype="f16", nbytes=32)
cmd = MathCmd(op="exp", inputs=(x,), out=out)
done = env.event()
pe_txn = PeInternalTxn(command=cmd, done=done, pe_prefix=pe_prefix)
@@ -777,7 +777,7 @@ def test_tl_ref_no_dma():
tl = TLContext(pe_id=0, dispatch_cycles=0)
handle = tl.ref(0x1000, shape=(4, 4), dtype="f16")
assert handle.pa == 0x1000
assert handle.addr == 0x1000
assert handle.shape == (4, 4)
assert len(tl.commands) == 0, f"tl.ref should emit 0 commands, got {len(tl.commands)}"
+203
View File
@@ -0,0 +1,203 @@
"""Tests for PeMMU: per-PE virtual-to-physical address translation.
Validates:
T1. Basic map + translate
T2. Page-aligned dict lookup (O(1), multi-page range)
T3. Multiple tensor mappings accumulate
T4. unmap removes entries, translate raises PageFault
T5. PageFault on unmapped VA
T6. Identical mapping broadcast across multiple PEs
"""
import pytest
from kernbench.policy.address.pe_mmu import PageFault, PeMMU
# 2 MiB in bytes; used as the page size in the multi-page tests below.
_2MB = 2 * 1024 * 1024
# ── T1. Basic map + translate ────────────────────────────────────────
def test_map_and_translate_basic():
    """After map(va, pa, size), translate(va) yields pa."""
    va, pa = 0x1_0000_0000, 0xABCD_0000
    mmu = PeMMU(page_size=4096)
    mmu.map(va=va, pa=pa, size=4096)

    assert mmu.translate(va) == pa
def test_translate_preserves_offset():
    """An offset inside a mapped page carries over to the translated PA."""
    mmu = PeMMU(page_size=4096)
    mmu.map(va=0x1_0000_0000, pa=0xABCD_0000, size=4096)

    # One address in the middle of the page, one at its last byte.
    expected = {
        0x1_0000_0100: 0xABCD_0100,
        0x1_0000_0FFF: 0xABCD_0FFF,
    }
    for va, pa in expected.items():
        assert mmu.translate(va) == pa
# ── T2. Page-aligned dict lookup (multi-page) ───────────────────────
def test_multi_page_mapping():
    """An 8 MB mapping over 2 MB pages translates correctly on every page."""
    va_base = 0x1_0000_0000
    pa_base = 0x2_0000_0000
    size = 8 * 1024 * 1024  # 8 MB, i.e. four 2 MB pages

    mmu = PeMMU(page_size=_2MB)
    mmu.map(va=va_base, pa=pa_base, size=size)

    probe_offsets = (
        0,                  # start of the first page
        _2MB,               # start of the second page
        2 * _2MB + 0x100,   # inside the third page
        size - 1,           # last byte of the last page
    )
    for offset in probe_offsets:
        assert mmu.translate(va_base + offset) == pa_base + offset
def test_page_table_entry_count():
    """num_entries grows by ceil(size / page_size) for each mapping."""
    mmu = PeMMU(page_size=_2MB)

    # 8 MB over 2 MB pages -> 4 entries.
    mmu.map(va=0x1000_0000, pa=0x2000_0000, size=8 * 1024 * 1024)
    assert mmu.num_entries == 4

    # One more single-page mapping -> 5 entries in total.
    mmu.map(va=0x2000_0000, pa=0x3000_0000, size=_2MB)
    assert mmu.num_entries == 5
# ── T3. Multiple tensor mappings accumulate ──────────────────────────
def test_multiple_mappings_accumulate():
    """Two mappings at different VA ranges both stay translatable."""
    tensor_a = (0x1_0000_0000, 0xA000_0000)
    tensor_b = (0x1_0001_0000, 0xB000_0000)  # separate VA range

    mmu = PeMMU(page_size=4096)
    for va, pa in (tensor_a, tensor_b):
        mmu.map(va=va, pa=pa, size=4096)

    for va, pa in (tensor_a, tensor_b):
        assert mmu.translate(va) == pa
def test_mappings_do_not_interfere():
    """Neighbouring VA pages can point at unrelated PA ranges."""
    mmu = PeMMU(page_size=4096)
    neighbours = (
        (0x0000_0000, 0xFFFF_0000),
        (0x0000_1000, 0x0000_0000),
    )
    for va, pa in neighbours:
        mmu.map(va=va, pa=pa, size=4096)

    for va, pa in neighbours:
        assert mmu.translate(va) == pa
# ── T4. unmap removes entries ────────────────────────────────────────
def test_unmap_removes_mapping():
    """translate raises PageFault once the mapping has been unmapped."""
    va, pa = 0x1_0000_0000, 0xABCD_0000
    mmu = PeMMU(page_size=4096)

    mmu.map(va=va, pa=pa, size=4096)
    assert mmu.translate(va) == pa

    mmu.unmap(va=va, size=4096)
    with pytest.raises(PageFault):
        mmu.translate(va)
def test_unmap_partial_range():
    """Unmapping the first page of a two-page mapping keeps the second valid."""
    page = 4096
    va_base, pa_base = 0x1000_0000, 0x2000_0000

    mmu = PeMMU(page_size=page)
    mmu.map(va=va_base, pa=pa_base, size=8192)  # two pages
    assert mmu.num_entries == 2

    # Drop only the first page.
    mmu.unmap(va=va_base, size=page)
    assert mmu.num_entries == 1
    with pytest.raises(PageFault):
        mmu.translate(va_base)

    # The second page is untouched.
    assert mmu.translate(0x1000_1000) == 0x2000_1000
def test_unmap_does_not_affect_other_mappings():
    """Removing one tensor's mapping leaves another tensor's mapping intact."""
    va_a, pa_a = 0x1_0000_0000, 0xA000_0000
    va_b, pa_b = 0x1_0001_0000, 0xB000_0000

    mmu = PeMMU(page_size=4096)
    mmu.map(va=va_a, pa=pa_a, size=4096)
    mmu.map(va=va_b, pa=pa_b, size=4096)

    mmu.unmap(va=va_a, size=4096)

    with pytest.raises(PageFault):
        mmu.translate(va_a)
    assert mmu.translate(va_b) == pa_b
# ── T5. PageFault on unmapped VA ─────────────────────────────────────
def test_pagefault_on_unmapped_va():
    """Translating a VA that was never mapped raises PageFault."""
    empty_mmu = PeMMU(page_size=4096)

    with pytest.raises(PageFault):
        empty_mmu.translate(0xDEAD_BEEF)
def test_pagefault_contains_va():
    """The PageFault message includes the faulting VA in lowercase hex."""
    empty_mmu = PeMMU(page_size=4096)

    with pytest.raises(PageFault, match="0xdeadbeef"):
        empty_mmu.translate(0xDEAD_BEEF)
# ── T6. Identical mapping broadcast across PEs ───────────────────────
def test_broadcast_same_mapping_to_all_pes():
    """Eight PeMMUs given the same entries all translate identically."""
    # (va, pa, size) for shards 0..3
    shard_entries = [
        (0x1_0000_0000, 0xA000_0000, 4096),
        (0x1_0000_1000, 0xB000_0000, 4096),
        (0x1_0000_2000, 0xC000_0000, 4096),
        (0x1_0000_3000, 0xD000_0000, 4096),
    ]
    pe_count = 8
    pe_mmus = [PeMMU(page_size=4096) for _ in range(pe_count)]

    # Install the complete entry list into every PE's MMU.
    for pe_mmu in pe_mmus:
        for va, pa, size in shard_entries:
            pe_mmu.map(va=va, pa=pa, size=size)

    # Each PE must resolve every shard's base VA to that shard's PA.
    for pe_mmu in pe_mmus:
        for va, pa, _size in shard_entries:
            assert pe_mmu.translate(va) == pa
def test_cross_pe_access_via_broadcast():
    """PE0's MMU can translate a VA inside the shard mapped for PE3."""
    pe3_va, pe3_pa = 0x1_0000_1000, 0xD000_0000

    mmu_pe0 = PeMMU(page_size=4096)
    mmu_pe0.map(va=0x1_0000_0000, pa=0xA000_0000, size=4096)  # PE0's own shard
    mmu_pe0.map(va=pe3_va, pa=pe3_pa, size=4096)              # PE3's shard

    # Access 0x100 bytes into PE3's region from PE0.
    translated = mmu_pe0.translate(pe3_va + 0x100)
    assert translated == 0xD000_0100
# ── TLB overhead attribute ───────────────────────────────────────────
def test_tlb_overhead_default():
    """A PeMMU built without arguments reports overhead_ns == 0.0."""
    default_mmu = PeMMU()
    assert default_mmu.overhead_ns == 0.0
def test_tlb_overhead_configurable():
    """The overhead_ns constructor argument is exposed unchanged."""
    tuned_mmu = PeMMU(overhead_ns=0.5)
    assert tuned_mmu.overhead_ns == 0.5
+193
View File
@@ -0,0 +1,193 @@
"""Tests for tensor free: del-based + context manager cleanup.
Validates:
TF1. PEMemAllocator.free_hbm/free_tcm reclaims space
TF2. del tensor triggers cleanup (VA/PA returned, MMU unmapped)
TF3. Context manager cleans up all tensors on exit
TF4. del after context exit is safe (no crash)
TF5. Alloc-del-alloc cycle reuses VA and PA
TF6. del already-freed tensor is safe (no crash)
"""
import gc
import pytest
from pathlib import Path
from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator
from kernbench.policy.address.pe_mmu import PageFault
from kernbench.policy.placement.dp import DPPolicy
from kernbench.sim_engine.engine import GraphEngine
from kernbench.runtime_api.context import RuntimeContext
from kernbench.runtime_api.types import DeviceSelector
from kernbench.topology.builder import load_topology
# topology.yaml lives in the directory above the one containing this module.
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
# Byte-size multipliers (2**20 and 2**30).
_MB = 1 << 20
_GB = 1 << 30
# Address-space configuration handed to PEMemAllocator in the allocator tests
# below: 2 SIPs x 16 cubes x 8 PEs with per-cube HBM/SRAM and per-PE TCM sizes.
_CFG = AddressConfig(
sip_count=2,
cubes_per_sip=16,
pes_per_cube=8,
hbm_bytes_per_cube=48 * _GB,
hbm_slices_per_cube=8,
tcm_bytes_per_pe=16 * _MB,
tcm_scheduler_reserved_bytes=4 * _MB,
sram_bytes_per_cube=32 * _MB,
)
def _make_ctx():
    """Return a (RuntimeContext, GraphEngine) pair targeting device "sip:0"."""
    graph = load_topology(TOPOLOGY_PATH)
    engine = GraphEngine(graph)
    return (
        RuntimeContext(
            engine=engine,
            target_device=DeviceSelector("sip:0"),
            correlation_id="test_free",
            spec=graph.spec,
        ),
        engine,
    )
# ── TF1. PEMemAllocator free ─────────────────────────────────────────
def test_allocator_free_hbm_reclaims_space():
    """free_hbm lowers hbm_used by the freed size and a new alloc succeeds."""
    block = 4096
    allocator = PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=0, cfg=_CFG)

    first_pa = allocator.alloc_hbm(block)
    used_before_free = allocator.hbm_used

    allocator.free_hbm(first_pa, block)
    assert allocator.hbm_used == used_before_free - block

    second_pa = allocator.alloc_hbm(block)
    assert second_pa is not None
def test_allocator_free_tcm_reclaims_space():
    """free_tcm lowers tcm_used by the freed size."""
    block = 256
    allocator = PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=0, cfg=_CFG)

    tcm_pa = allocator.alloc_tcm(block)
    used_before_free = allocator.tcm_used

    allocator.free_tcm(tcm_pa, block)
    assert allocator.tcm_used == used_before_free - block
# ── TF2. del tensor triggers cleanup ─────────────────────────────────
def test_del_tensor_unmaps_mmu():
    """Deleting a tensor makes its base VA fault in the PE_MMU."""
    ctx, engine = _make_ctx()
    tensor = ctx.zeros(
        (128, 128),
        dtype="f16",
        dp=DPPolicy(cube="replicate", pe="replicate"),
        name="del_test",
    )
    # Only the integer VA is kept, so no extra reference pins the tensor.
    va_base = tensor._handle.va_base

    pe0_mmu = engine._components["sip0.cube0.pe0.pe_mmu"]
    # While the tensor is alive the VA translates.
    assert pe0_mmu.mmu.translate(va_base) is not None

    del tensor
    gc.collect()

    # After deletion the same VA must fault.
    with pytest.raises(PageFault):
        pe0_mmu.mmu.translate(va_base)
def test_del_tensor_reclaims_va():
    """A tensor created after a delete gets the deleted tensor's va_base."""
    ctx, engine = _make_ctx()
    policy = DPPolicy(cube="replicate", pe="replicate")

    first = ctx.zeros((128, 128), dtype="f16", dp=policy, name="va_reuse1")
    first_va = first._handle.va_base
    del first
    gc.collect()

    second = ctx.zeros((128, 128), dtype="f16", dp=policy, name="va_reuse2")
    assert second._handle.va_base == first_va
# ── TF3. Context manager cleanup ─────────────────────────────────────
def test_context_manager_cleans_all():
    """Leaving the RuntimeContext `with` block removes the tensor's MMU mapping."""
    graph = load_topology(TOPOLOGY_PATH)
    engine = GraphEngine(graph)
    runtime = RuntimeContext(
        engine=engine,
        target_device=DeviceSelector("sip:0"),
        correlation_id="ctx_mgr",
        spec=graph.spec,
    )

    with runtime as ctx:
        dp = DPPolicy(cube="replicate", pe="replicate")
        t = ctx.zeros((128, 128), dtype="f16", dp=dp, name="ctx_mgr_test")
        va_base = t._handle.va_base

    # The tensor object is still referenced here; the mapping must be gone
    # purely because the context exited.
    pe0_mmu = engine._components["sip0.cube0.pe0.pe_mmu"]
    with pytest.raises(PageFault):
        pe0_mmu.mmu.translate(va_base)
# ── TF4. del after context exit is safe ───────────────────────────────
def test_del_after_context_exit_no_crash():
    """Deleting a tensor after ctx.cleanup() has run does not raise."""
    graph = load_topology(TOPOLOGY_PATH)
    ctx = RuntimeContext(
        engine=GraphEngine(graph),
        target_device=DeviceSelector("sip:0"),
        correlation_id="safe_del",
        spec=graph.spec,
    )
    tensor = ctx.zeros(
        (128, 128),
        dtype="f16",
        dp=DPPolicy(cube="replicate", pe="replicate"),
        name="safe_del_test",
    )

    # Stand-in for the context going away before the tensor does.
    ctx.cleanup()

    # Tensor teardown after cleanup must be a harmless no-op.
    del tensor
    gc.collect()
    # Passing means no exception was raised above.
# ── TF5. Alloc-del-alloc cycle ────────────────────────────────────────
def test_alloc_del_cycle():
    """Three consecutive allocate/delete rounds each yield a valid handle."""
    ctx, engine = _make_ctx()
    policy = DPPolicy(cube="replicate", pe="replicate")

    for round_no in range(3):
        tensor = ctx.zeros(
            (128, 128), dtype="f16", dp=policy, name=f"cycle_{round_no}"
        )
        assert tensor._handle is not None
        assert tensor._handle.va_base > 0
        del tensor
        gc.collect()
# ── TF6. del already-cleaned tensor is safe ──────────────────────────
def test_del_already_cleaned_tensor_no_crash():
    """Deleting a tensor that ctx.cleanup() already released does not raise."""
    ctx, engine = _make_ctx()
    tensor = ctx.zeros(
        (128, 128),
        dtype="f16",
        dp=DPPolicy(cube="replicate", pe="replicate"),
        name="double_del",
    )

    # cleanup() releases every tensor; the handle is expected to be None after.
    ctx.cleanup()

    del tensor  # must not crash
    gc.collect()
+11 -10
View File
@@ -17,31 +17,32 @@ def test_full_graph_node_count():
g = _graph()
# 1 switch
# + 2 SIPs × (1 IO × (3 comps + 4 io_ucie + 16 io_conn)
# + 16 cubes × (cube_comps + 8 PEs × 6 pe_comps))
# + 16 cubes × (cube_comps + 8 PEs × 7 pe_comps))
# IO: pcie_ep + io_cpu + io_noc + 4 io_ucie + 4*4 io_conn = 23
# cube_comps: 9 (noc, m_cpu, sram, 2 bridge, 4 ucie)
# + 16 ucie_conn (4 ports × 4 connections)
# + 2 xbar_top/bot
# + 8 hbm_slices = 35
# = 1 + 2*(23 + 16*(35+48)) = 1 + 2*(23+1328) = 1 + 2702 = 2703
assert len(g.nodes) == 2703
# pe_comps: 7 (pe_cpu, pe_scheduler, pe_dma, pe_gemm, pe_math, pe_mmu, pe_tcm)
# = 1 + 2*(23 + 16*(35+56)) = 1 + 2*(23+1456) = 1 + 2958 = 2959
assert len(g.nodes) == 2959
def test_full_graph_edge_count():
g = _graph()
# Per cube: 184
# Per cube: 192
# PE-internal: 56
# PE_DMA→noc: 8, noc→pe_dma: 8, noc→pe_cpu: 8, pe_cpu→noc: 8
# PE_DMA→noc: 8, noc→pe_dma: 8, noc→pe_cpu: 8, pe_cpu→noc: 8, noc→pe_mmu: 8
# xbar_top→hbm{0..3}: 4+4=8, xbar_bot→hbm{4..7}: 4+4=8
# noc↔xbar_top: 2, noc↔xbar_bot: 2
# xbar_top↔bridge.left: 2, bridge.left↔xbar_bot: 2
# xbar_top↔bridge.right: 2, bridge.right↔xbar_bot: 2
# ucie: 64, m_cpu↔noc: 2, noc↔sram: 2
# Total: 56+8+8+8+8+8+8+2+2+2+2+2+2+64+2+2 = 184
# Total: 56+8+8+8+8+8+8+8+2+2+2+2+2+2+64+2+2 = 192
# IO edges per SIP: 77
# Per SIP: 16*184 + 48 inter-cube + 77 IO = 3069
# Total: 2 * 3069 = 6138
assert len(g.edges) == 6138
# Per SIP: 16*192 + 48 inter-cube + 77 IO = 3197
# Total: 2 * 3197 = 6394
assert len(g.edges) == 6394
# ── Full graph: specific nodes exist ─────────────────────────────────
@@ -267,7 +268,7 @@ def test_cube_view_pe_to_noc():
def test_pe_view_has_all_components():
v = _graph().pe_view
assert set(v.nodes.keys()) == {
"pe_cpu", "pe_scheduler", "pe_dma", "pe_gemm", "pe_math", "pe_tcm"
"pe_cpu", "pe_scheduler", "pe_dma", "pe_gemm", "pe_math", "pe_mmu", "pe_tcm"
}
+1 -1
View File
@@ -23,7 +23,7 @@ def test_pe_template_components():
spec = _read_spec(TOPOLOGY_PATH)
comps = spec["cube"]["pe_template"]["components"]
assert set(comps.keys()) == {
"pe_cpu", "pe_scheduler", "pe_dma", "pe_gemm", "pe_math", "pe_tcm"
"pe_cpu", "pe_scheduler", "pe_dma", "pe_gemm", "pe_math", "pe_mmu", "pe_tcm"
}
+3 -3
View File
@@ -34,7 +34,7 @@ def test_tl_load_generates_dma_read():
cmds = tl.commands
assert len(cmds) == 1
assert isinstance(cmds[0], DmaReadCmd)
assert cmds[0].src_pa == 0x1000
assert cmds[0].src_addr == 0x1000
assert cmds[0].nbytes == 32 * 64 * 2
@@ -47,7 +47,7 @@ def test_tl_store_generates_dma_write():
tl.store(0x2000, h)
cmds = [c for c in tl.commands if isinstance(c, DmaWriteCmd)]
assert len(cmds) == 1
assert cmds[0].dst_pa == 0x2000
assert cmds[0].dst_addr == 0x2000
assert cmds[0].nbytes == 16 * 16 * 4
@@ -148,7 +148,7 @@ def test_tl_composite_nonblocking():
comp_cmds = [c for c in tl.commands if isinstance(c, CompositeCmd)]
assert len(comp_cmds) == 1
assert comp_cmds[0].op == "gemm"
assert comp_cmds[0].out_pa == 0x3000
assert comp_cmds[0].out_addr == 0x3000
assert comp_cmds[0].out_nbytes == 32 * 32 * 2 # M×N×dtype_bytes
+140
View File
@@ -0,0 +1,140 @@
"""Tests for VirtualAllocator: device-wide VA space management.
Validates:
T7. Basic VA allocation (contiguous, non-overlapping)
T8. VA free + reallocation (free-list reuse)
T9. VA space exhaustion raises AllocationError
"""
import pytest
from kernbench.policy.address.va_allocator import VirtualAllocator
# Byte-size multipliers (2**10, 2**20, 2**30) used for allocator sizes below.
_KB = 1024
_MB = 1024 * 1024
_GB = 1024 * 1024 * 1024
# ── T7. Basic VA allocation ──────────────────────────────────────────
def test_alloc_returns_aligned_va():
    """The very first allocation lands at va_base."""
    base = 0x1_0000_0000
    allocator = VirtualAllocator(va_base=base, va_size=1 * _GB, page_size=4096)

    assert allocator.alloc(4096) == 0x1_0000_0000
def test_alloc_sequential_non_overlapping():
    """Back-to-back allocations are contiguous and do not overlap."""
    allocator = VirtualAllocator(
        va_base=0x1_0000_0000, va_size=1 * _GB, page_size=4096
    )
    first = allocator.alloc(4096)
    second = allocator.alloc(8192)

    assert first == 0x1_0000_0000
    assert second == 0x1_0000_1000  # directly after the first 4096 bytes
    # The second range starts at or beyond the end of the first.
    assert second >= first + 4096
def test_alloc_page_aligned():
    """A sub-page request still consumes a whole page of VA space."""
    allocator = VirtualAllocator(
        va_base=0x1_0000_0000, va_size=1 * _GB, page_size=4096
    )
    first = allocator.alloc(100)  # under one page, yet takes a full page
    second = allocator.alloc(100)

    assert second == 0x1_0000_1000  # begins on the following page
def test_alloc_large_contiguous():
    """A four-page allocation is contiguous; the next one starts right after."""
    allocator = VirtualAllocator(va_base=0x0, va_size=1 * _GB, page_size=2 * _MB)

    big = allocator.alloc(8 * _MB)  # four 2 MB pages
    assert big == 0x0

    # The following allocation begins where the 8 MB block ends.
    follower = allocator.alloc(2 * _MB)
    assert follower == 8 * _MB
# ── T8. VA free + reallocation ───────────────────────────────────────
def test_free_and_realloc():
    """A freed range is handed out again by the next same-size allocation."""
    allocator = VirtualAllocator(
        va_base=0x1_0000_0000, va_size=1 * _GB, page_size=4096
    )
    first = allocator.alloc(4096)
    second = allocator.alloc(4096)

    allocator.free(first, 4096)

    # The new allocation is expected to land in the slot just released.
    reused = allocator.alloc(4096)
    assert reused == first
def test_free_coalesce():
    """Two adjacent freed pages can be re-allocated as one 2-page range."""
    allocator = VirtualAllocator(va_base=0x0, va_size=1 * _GB, page_size=4096)
    first = allocator.alloc(4096)
    second = allocator.alloc(4096)
    _third = allocator.alloc(4096)  # remains allocated
    # Release the first two pages, which sit next to each other.
    for addr in (first, second):
        allocator.free(addr, 4096)
    # An 8192-byte request should fit into the merged hole at `first`.
    merged = allocator.alloc(8192)
    assert merged == first
def test_free_out_of_order():
    """Freeing a block from the middle makes that slot reusable."""
    allocator = VirtualAllocator(va_base=0x0, va_size=1 * _GB, page_size=4096)
    _low = allocator.alloc(4096)
    middle = allocator.alloc(4096)
    _high = allocator.alloc(4096)
    allocator.free(middle, 4096)  # punch a hole between the two live blocks
    refill = allocator.alloc(4096)
    assert refill == middle  # the hole is reused
# ── T9. VA space exhaustion ──────────────────────────────────────────
def test_alloc_exhaustion():
    """Requesting a page once the 2-page VA space is full raises an error."""
    allocator = VirtualAllocator(va_base=0x0, va_size=8192, page_size=4096)
    for _ in range(2):  # consume both available pages
        allocator.alloc(4096)
    expected_message = "[Aa]lloc|[Ee]xhaust|[Oo]ut of"
    with pytest.raises(Exception, match=expected_message):
        allocator.alloc(4096)
def test_alloc_after_partial_free():
    """After freeing one page of a full VA space, allocation succeeds again.

    The space holds exactly two 4096-byte pages. Once both are taken a third
    request must fail with an allocation/exhaustion error; after the first
    page is freed, the next request must be served from that freed page.
    """
    va = VirtualAllocator(va_base=0x0, va_size=8192, page_size=4096)
    a1 = va.alloc(4096)
    va.alloc(4096)  # second page; its address is not needed
    # Space is full. Match on the error message (same pattern as
    # test_alloc_exhaustion) so an unrelated exception such as a TypeError
    # cannot satisfy the check.
    with pytest.raises(Exception, match="[Aa]lloc|[Ee]xhaust|[Oo]ut of"):
        va.alloc(4096)
    # Free one, now can allocate again
    va.free(a1, 4096)
    a3 = va.alloc(4096)
    assert a3 == a1
# ── Stats / inspection ───────────────────────────────────────────────
def test_used_and_total():
    """``used`` grows with each allocation while ``total`` reports the VA size."""
    allocator = VirtualAllocator(va_base=0x0, va_size=1 * _MB, page_size=4096)
    assert allocator.used == 0
    assert allocator.total == 1 * _MB
    expected_used = 0
    # Requests are page multiples: 4096, then 8192 (running total 12288).
    for request in (4096, 8192):
        allocator.alloc(request)
        expected_used += request
        assert allocator.used == expected_used
+230
View File
@@ -0,0 +1,230 @@
"""Tests for VA integration: Tensor, TLContext, and DMA commands use VA.
Validates:
T10. TensorHandle has va_base; TensorShard does NOT have va field
T11. deploy_tensor allocates VA + creates mapping entries
T12. Tensor.va returns the tensor's VA base
T13. tl.load/tl.store generate DMA commands with VA (not PA)
T14. Kernel VA-based offset calculation flows through DMA commands
"""
import pytest
from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator
from kernbench.policy.address.pe_mmu import PeMMU
from kernbench.policy.address.va_allocator import VirtualAllocator
from kernbench.policy.placement.dp import column_wise, ShardSpec
from kernbench.runtime_api.tensor import (
TensorHandle,
TensorShard,
deploy_tensor,
)
from kernbench.runtime_api.kernel import TensorArgShard
from kernbench.common.pe_commands import DmaReadCmd, DmaWriteCmd
from kernbench.triton_emu.tl_context import TLContext, run_kernel
# Byte-size units (binary multiples).
_MB = 2**20
_GB = 2**30

# Address configuration passed to every PEMemAllocator built by the helpers
# in this module.
_CFG = AddressConfig(
    sip_count=2,
    cubes_per_sip=16,
    pes_per_cube=8,
    hbm_bytes_per_cube=48 * _GB,
    hbm_slices_per_cube=8,
    tcm_bytes_per_pe=16 * _MB,
    tcm_scheduler_reserved_bytes=4 * _MB,
    sram_bytes_per_cube=32 * _MB,
)
def _make_allocators(num_pe: int = 8) -> dict[int, PEMemAllocator]:
    """Build one ``PEMemAllocator`` per PE index in ``range(num_pe)``.

    Every allocator is created for rack 0, SIP 0, cube 0 and uses ``_CFG``.
    The returned dict is keyed by PE index.
    """
    allocators = {}
    for pe in range(num_pe):
        allocators[pe] = PEMemAllocator(
            rack_id=0,
            sip_id=0,
            cube_id=0,
            pe_id=pe,
            cfg=_CFG,
        )
    return allocators
def _make_mmus(num_pe: int = 8, page_size: int = 4096) -> dict[int, PeMMU]:
    """Build a separate ``PeMMU`` with the given page size for each PE index."""
    mmus = {}
    for pe in range(num_pe):
        mmus[pe] = PeMMU(page_size=page_size)
    return mmus
def _make_va_allocator() -> VirtualAllocator:
    """Create a ``VirtualAllocator``: 64 GB at 0x1_0000_0000 with 2 MB pages."""
    va_allocator = VirtualAllocator(
        va_base=0x1_0000_0000,
        va_size=64 * _GB,
        page_size=2 * _MB,
    )
    return va_allocator
# ── T10. TensorHandle has va_base ────────────────────────────────────
def test_tensor_handle_has_va_base():
    """``TensorHandle`` accepts a ``va_base`` argument and exposes it unchanged."""
    base = 0x1_0000_0000
    handle = TensorHandle(
        name="A",
        shape=(1024, 512),
        dtype="fp16",
        itemsize=2,
        shards=(),
        va_base=base,
    )
    assert handle.va_base == base
def test_tensor_handle_va_base_immutable():
    """Assigning to ``TensorHandle.va_base`` after construction raises AttributeError."""
    fields = dict(
        name="A",
        shape=(1024, 512),
        dtype="fp16",
        itemsize=2,
        shards=(),
        va_base=0x1_0000_0000,
    )
    handle = TensorHandle(**fields)
    with pytest.raises(AttributeError):
        handle.va_base = 0x2_0000_0000  # type: ignore[misc]
def test_tensor_shard_no_va_field():
    """``TensorShard`` carries no ``va`` attribute.

    A shard's VA is meant to be derived as
    ``TensorHandle.va_base + shard.offset_bytes`` instead of being stored.
    """
    shard = TensorShard(sip=0, cube=0, pe=0, pa=0x1000, nbytes=4096, offset_bytes=0)
    has_va = hasattr(shard, "va")
    assert not has_va, "TensorShard should not have a 'va' field"
# ── T11. deploy_tensor allocates VA + creates mappings ───────────────
def test_deploy_tensor_assigns_va_base():
    """With a VA allocator supplied, ``deploy_tensor`` sets a positive ``va_base``."""
    pe_allocators = _make_allocators()
    va_allocator = _make_va_allocator()
    pe_mmus = _make_mmus()
    shape = (1024, 512)
    layout = column_wise(shape=shape, itemsize=2, num_pe=8)
    handle = deploy_tensor(
        name="W",
        shape=shape,
        dtype="fp16",
        placement=layout,
        allocators=pe_allocators,
        va_allocator=va_allocator,
        mmus=pe_mmus,
    )
    assert handle.va_base is not None
    assert handle.va_base > 0
def test_deploy_tensor_va_covers_all_shards():
    """Every shard's VA is derivable as ``va_base + offset_bytes``.

    Deploys a (1024, 512) fp16 tensor column-wise across 8 PEs and checks
    that each resulting shard yields a positive derived VA.
    """
    allocs = _make_allocators()
    va_alloc = _make_va_allocator()
    mmus = _make_mmus()
    placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8)
    th = deploy_tensor(
        name="W",
        shape=(1024, 512),
        dtype="fp16",
        placement=placement,
        allocators=allocs,
        va_allocator=va_alloc,
        mmus=mmus,
    )
    # Guard against a vacuous pass: without any shards the loop below would
    # execute zero assertions.
    assert th.shards, "deploy_tensor produced no shards"
    # Each shard's VA is derivable: va_base + offset_bytes
    for s in th.shards:
        shard_va = th.va_base + s.offset_bytes
        assert shard_va > 0
def test_deploy_tensor_registers_mmu_mappings():
    """``deploy_tensor`` installs VA→PA mappings for every shard in every PE MMU."""
    pe_allocators = _make_allocators()
    va_allocator = _make_va_allocator()
    pe_mmus = _make_mmus()
    shape = (1024, 512)
    layout = column_wise(shape=shape, itemsize=2, num_pe=8)
    handle = deploy_tensor(
        name="W",
        shape=shape,
        dtype="fp16",
        placement=layout,
        allocators=pe_allocators,
        va_allocator=va_allocator,
        mmus=pe_mmus,
    )
    all_mmus = pe_mmus.values()
    # First pass: the mappings were broadcast, so no MMU may be left empty.
    for unit in all_mmus:
        assert unit.num_entries > 0
    # Second pass: every MMU resolves each shard's derived VA to that shard's PA.
    for unit in all_mmus:
        for shard in handle.shards:
            derived_va = handle.va_base + shard.offset_bytes
            assert unit.translate(derived_va) == shard.pa
# ── T12. Tensor.va property ──────────────────────────────────────────
def test_tensor_va_property():
    """``Tensor.va`` mirrors ``va_base`` of the handle attached to the tensor."""
    from kernbench.runtime_api.tensor import Tensor

    pe_allocators = _make_allocators(1)
    va_allocator = _make_va_allocator()
    pe_mmus = _make_mmus(1)
    # A single shard on PE 0 holding the whole 4096-byte tensor.
    single_shard = [ShardSpec(pe_index=0, offset_bytes=0, nbytes=4096)]
    tensor = Tensor(shape=(2048,), dtype="f16", name="test")
    handle = deploy_tensor(
        name="test",
        shape=(2048,),
        dtype="f16",
        placement=single_shard,
        allocators=pe_allocators,
        va_allocator=va_allocator,
        mmus=pe_mmus,
    )
    tensor._handle = handle
    assert tensor.va > 0
    assert tensor.va == tensor._handle.va_base
# ── T13. tl.load/tl.store use VA ─────────────────────────────────────
def test_tl_load_uses_va_in_dma_cmd():
    """tl.load(va_ptr) emits one DmaReadCmd whose ``src_addr`` is the VA.

    The pointer handed to ``tl.load`` must be forwarded unchanged in the
    command's ``src_addr`` field.
    """
    tl = TLContext(dispatch_cycles=0)
    va_ptr = 0x1_0000_0000
    # The returned handle is not needed; only the recorded commands matter.
    tl.load(va_ptr, shape=(32, 64), dtype="f16")
    dma_cmds = [c for c in tl.commands if isinstance(c, DmaReadCmd)]
    assert len(dma_cmds) == 1
    # The DMA command should carry the VA
    assert dma_cmds[0].src_addr == va_ptr
def test_tl_store_uses_va_in_dma_cmd():
    """tl.store(va_ptr, handle) emits one DmaWriteCmd whose ``dst_addr`` is the VA.

    A value is first loaded to obtain a handle, then stored to a different
    VA; the store must record exactly one write command targeting that VA.
    """
    tl = TLContext(dispatch_cycles=0)
    h = tl.load(0x1_0000_0000, shape=(16, 16), dtype="f32")
    va_out = 0x2_0000_0000
    tl.store(va_out, h)
    # Only write commands are of interest; the preceding load is filtered out.
    dma_cmds = [c for c in tl.commands if isinstance(c, DmaWriteCmd)]
    assert len(dma_cmds) == 1
    assert dma_cmds[0].dst_addr == va_out
# ── T14. Kernel VA offset calculation ─────────────────────────────────
def test_kernel_va_offset_in_dma():
    """A kernel computing ``base + pid * stride`` yields that VA in its DmaReadCmd."""

    def tiled_kernel(a_ptr, tl, BLOCK_SIZE=1024, DTYPE="f16"):
        # Each program loads the tile at its own byte offset from a_ptr.
        bytes_per_elem = 2  # f16
        tile_offset = tl.program_id(0) * BLOCK_SIZE * bytes_per_elem
        _tile = tl.load(a_ptr + tile_offset, shape=(BLOCK_SIZE,), dtype=DTYPE)

    base = 0x1_0000_0000
    ctx = TLContext(pe_id=3, num_programs=8, dispatch_cycles=0)
    run_kernel(tiled_kernel, ctx, a_ptr=base)
    reads = [cmd for cmd in ctx.commands if isinstance(cmd, DmaReadCmd)]
    assert len(reads) == 1
    # Expected offset: pid=3, BLOCK_SIZE=1024, 2 bytes per element.
    assert reads[0].src_addr == base + 3 * 1024 * 2