Add virtual memory support: PE_MMU, VA allocator, fabric MmuMapMsg
Implement VA/MMU layer (ADR-0011 Phase 1) enabling Triton kernels to use contiguous virtual addresses on sharded tensors. Key changes: - PE_MMU component: hybrid inbox (MmuMapMsg) + sync translate() for PE_DMA - VirtualAllocator + PEMemAllocator: free-list with coalescing - MmuMapMsg/MmuUnmapMsg fabric path with SIP-level routing - DPPolicy-based mapping: replicate=local, sharded=broadcast - Tensor lifecycle: del + weakref cleanup, context manager - Rename: TensorHandle.pa→addr, DmaReadCmd.src_pa→src_addr, ctx→torch Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,226 @@
|
||||
"""Tests for PE_MMU component integration and MmuMapMsg fabric path.
|
||||
|
||||
Validates:
|
||||
T15-a. PE_MMU component registered in ComponentRegistry
|
||||
T15-b. PE_MMU component receives MmuMapMsg via inbox, updates page table
|
||||
T15-c. PE_DMA translates VA→PA via mmu before routing
|
||||
T16. MmuMapMsg/MmuUnmapMsg message types defined with correct fields
|
||||
T17. PE_CPU passes VA (not PA) to kernel when VA is available
|
||||
T18. End-to-end: deploy (MmuMapMsg broadcast) → kernel launch → DMA with VA
|
||||
"""
|
||||
import pytest
|
||||
|
||||
try:
|
||||
from kernbench.runtime_api.kernel import MmuMapMsg, MmuUnmapMsg
|
||||
except ImportError:
|
||||
pytest.skip("MmuMapMsg/MmuUnmapMsg not yet defined (Phase 2)", allow_module_level=True)
|
||||
|
||||
|
||||
# ── T16. MmuMapMsg / MmuUnmapMsg message types ──────────────────────
|
||||
|
||||
|
||||
def test_mmu_map_msg_fields():
|
||||
"""MmuMapMsg carries VA→PA mapping entries for broadcast."""
|
||||
msg = MmuMapMsg(
|
||||
correlation_id="c0",
|
||||
request_id="r0",
|
||||
entries=(
|
||||
{"va": 0x1_0000_0000, "pa": 0xA000_0000, "size": 4096},
|
||||
{"va": 0x1_0000_1000, "pa": 0xB000_0000, "size": 4096},
|
||||
),
|
||||
target_cubes="all",
|
||||
target_pe="all",
|
||||
)
|
||||
assert msg.msg_type == "mmu_map"
|
||||
assert len(msg.entries) == 2
|
||||
assert msg.entries[0]["va"] == 0x1_0000_0000
|
||||
assert msg.entries[0]["pa"] == 0xA000_0000
|
||||
assert msg.entries[0]["size"] == 4096
|
||||
|
||||
|
||||
def test_mmu_map_msg_immutable():
|
||||
"""MmuMapMsg is frozen."""
|
||||
msg = MmuMapMsg(
|
||||
correlation_id="c0",
|
||||
request_id="r0",
|
||||
entries=(),
|
||||
target_cubes="all",
|
||||
target_pe="all",
|
||||
)
|
||||
with pytest.raises(AttributeError):
|
||||
msg.entries = () # type: ignore[misc]
|
||||
|
||||
|
||||
def test_mmu_unmap_msg_fields():
|
||||
"""MmuUnmapMsg carries VA ranges to unmap."""
|
||||
msg = MmuUnmapMsg(
|
||||
correlation_id="c0",
|
||||
request_id="r0",
|
||||
entries=(
|
||||
{"va": 0x1_0000_0000, "size": 4096},
|
||||
),
|
||||
target_cubes="all",
|
||||
target_pe="all",
|
||||
)
|
||||
assert msg.msg_type == "mmu_unmap"
|
||||
assert len(msg.entries) == 1
|
||||
assert msg.entries[0]["va"] == 0x1_0000_0000
|
||||
|
||||
|
||||
# ── T15-a. PE_MMU component registry ────────────────────────────────
|
||||
|
||||
|
||||
def test_pe_mmu_registry():
|
||||
"""pe_mmu_v1 impl resolves in ComponentRegistry."""
|
||||
from kernbench.components.base import ComponentRegistry
|
||||
from kernbench.components.impls.pe_mmu import PeMmuComponent
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
node = Node(
|
||||
id="sip0.cube0.pe0.pe_mmu",
|
||||
kind="pe_mmu",
|
||||
impl="pe_mmu_v1",
|
||||
pos_mm=None,
|
||||
attrs={"tlb_overhead_ns": 0.5},
|
||||
)
|
||||
comp = ComponentRegistry.create(node)
|
||||
assert isinstance(comp, PeMmuComponent)
|
||||
|
||||
|
||||
# ── T15-b. PE_MMU receives MmuMapMsg and updates page table ─────────
|
||||
|
||||
|
||||
def test_pe_mmu_processes_map_msg():
|
||||
"""PE_MMU component receives MmuMapMsg → translate works."""
|
||||
import simpy
|
||||
from kernbench.components.impls.pe_mmu import PeMmuComponent
|
||||
from kernbench.sim_engine.transaction import Transaction
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
env = simpy.Environment()
|
||||
node = Node(
|
||||
id="sip0.cube0.pe0.pe_mmu",
|
||||
kind="pe_mmu",
|
||||
impl="pe_mmu_v1",
|
||||
pos_mm=None,
|
||||
attrs={"tlb_overhead_ns": 0.5, "page_size": 4096},
|
||||
)
|
||||
comp = PeMmuComponent(node)
|
||||
comp.in_ports["src"] = simpy.Store(env)
|
||||
comp.start(env)
|
||||
|
||||
# Submit MmuMapMsg via inbox
|
||||
map_msg = MmuMapMsg(
|
||||
correlation_id="c0",
|
||||
request_id="r0",
|
||||
entries=(
|
||||
{"va": 0x1_0000_0000, "pa": 0xABCD_0000, "size": 4096},
|
||||
),
|
||||
target_cubes="all",
|
||||
target_pe="all",
|
||||
)
|
||||
done = env.event()
|
||||
txn = Transaction(
|
||||
request=map_msg,
|
||||
path=["sip0.cube0.pe0.pe_mmu"],
|
||||
step=0, nbytes=0, done=done,
|
||||
)
|
||||
|
||||
def inject():
|
||||
yield comp._inbox.put(txn)
|
||||
|
||||
env.process(inject())
|
||||
env.run(until=100)
|
||||
|
||||
# After processing, the MMU's translate should work
|
||||
from kernbench.policy.address.pe_mmu import PeMMU
|
||||
mmu = comp.mmu # the underlying PeMMU utility object
|
||||
assert isinstance(mmu, PeMMU)
|
||||
assert mmu.translate(0x1_0000_0000) == 0xABCD_0000
|
||||
|
||||
|
||||
# ── T15-c. PE_DMA uses MMU translate ────────────────────────────────
|
||||
|
||||
|
||||
def test_pe_dma_translates_va():
|
||||
"""PE_DMA.handle_command calls mmu.translate(va) → PA before routing.
|
||||
|
||||
This test validates the contract: after Phase 2, DmaReadCmd carries VA,
|
||||
and PE_DMA must translate it to PA via the MMU before resolving the
|
||||
HBM node path.
|
||||
"""
|
||||
# This test validates the interface contract. Full integration test
|
||||
# requires the engine wiring which is validated in test_engine.
|
||||
# Here we check that PE_DMA has an mmu attribute it can call.
|
||||
from kernbench.components.impls.pe_dma import PeDmaComponent
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
node = Node(
|
||||
id="sip0.cube0.pe0.pe_dma",
|
||||
kind="pe_dma",
|
||||
impl="pe_dma_v1",
|
||||
pos_mm=None,
|
||||
attrs={"rd_engines": 1, "wr_engines": 1},
|
||||
)
|
||||
comp = PeDmaComponent(node)
|
||||
|
||||
# PE_DMA must have a way to access the MMU (via ctx or direct reference)
|
||||
# The exact wiring mechanism is flexible, but the attribute must exist
|
||||
assert hasattr(comp, '_mmu') or hasattr(comp, 'mmu') or (
|
||||
hasattr(comp, 'ctx') and comp.ctx is not None
|
||||
), "PE_DMA must have access to PE_MMU for VA translation"
|
||||
|
||||
|
||||
# ── T17. PE_CPU passes VA to kernel ──────────────────────────────────
|
||||
|
||||
|
||||
def test_pe_cpu_uses_va_base_from_tensor_arg():
|
||||
"""PE_CPU should use TensorArg.va_base for kernel pointer args.
|
||||
|
||||
After Phase 2, TensorArg carries va_base alongside shards.
|
||||
PE_CPU extracts va_base and passes it to the kernel function
|
||||
so kernels operate on VA (not PA).
|
||||
"""
|
||||
from kernbench.runtime_api.kernel import TensorArg, TensorArgShard
|
||||
|
||||
shard = TensorArgShard(sip=0, cube=0, pe=0, pa=0x1000,
|
||||
nbytes=4096, offset_bytes=0)
|
||||
targ = TensorArg(shards=(shard,), va_base=0x1_0000_0000)
|
||||
|
||||
# PE_CPU should use targ.va_base for kernel pointer arg
|
||||
assert targ.va_base == 0x1_0000_0000
|
||||
# PA still accessible via shard for direct-PA operations (IPCQ etc.)
|
||||
assert shard.pa == 0x1000
|
||||
|
||||
|
||||
# ── T18. MmuMapMsg broadcast pattern ─────────────────────────────────
|
||||
|
||||
|
||||
def test_mmu_map_msg_broadcast_target():
|
||||
"""MmuMapMsg with target_pe='all' is a broadcast to all PEs."""
|
||||
msg = MmuMapMsg(
|
||||
correlation_id="c0",
|
||||
request_id="r0",
|
||||
entries=({"va": 0x1000, "pa": 0x2000, "size": 4096},),
|
||||
target_cubes="all",
|
||||
target_pe="all",
|
||||
)
|
||||
assert msg.target_pe == "all"
|
||||
assert msg.target_cubes == "all"
|
||||
|
||||
|
||||
def test_mmu_map_msg_same_entries_all_pes():
|
||||
"""All PEs in a broadcast receive identical entries (not per-PE splits)."""
|
||||
entries = (
|
||||
{"va": 0x1_0000_0000, "pa": 0xA000_0000, "size": 8192},
|
||||
{"va": 0x1_0000_2000, "pa": 0xB000_0000, "size": 8192},
|
||||
)
|
||||
msg = MmuMapMsg(
|
||||
correlation_id="c0",
|
||||
request_id="r0",
|
||||
entries=entries,
|
||||
target_cubes="all",
|
||||
target_pe="all",
|
||||
)
|
||||
# The message carries the full mapping — every PE receives exactly this
|
||||
assert msg.entries == entries
|
||||
@@ -0,0 +1,241 @@
|
||||
"""Tests for MmuMapMsg fabric path and cross-cube mapping.
|
||||
|
||||
Validates:
|
||||
F1. MmuMapMsg traverses fabric: latency > 0 (not sideband)
|
||||
F2. MmuMapMsg fan-out: IO_CPU → cubes, M_CPU → PEs
|
||||
F3. After MmuMapMsg, PE_MMU has correct mappings
|
||||
F4. Cross-cube sharded tensor: all PEs get global mappings
|
||||
F5. Replicate tensor: each PE gets own cube's PA (local override)
|
||||
F6. Cross-cube DMA after sharded mapping: PE can access remote cube's HBM
|
||||
F7. Overlap detection: replicate vs sharded identified correctly
|
||||
F8. Existing regression: PA-only benchmarks still pass
|
||||
"""
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator
|
||||
from kernbench.policy.address.pe_mmu import PeMMU
|
||||
from kernbench.policy.address.va_allocator import VirtualAllocator
|
||||
from kernbench.policy.placement.dp import column_wise, replicate, ShardSpec
|
||||
from kernbench.runtime_api.tensor import deploy_tensor, TensorHandle
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
from kernbench.topology.builder import load_topology
|
||||
|
||||
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
|
||||
|
||||
_MB = 1 << 20
|
||||
_GB = 1 << 30
|
||||
|
||||
_CFG = AddressConfig(
|
||||
sip_count=2,
|
||||
cubes_per_sip=16,
|
||||
pes_per_cube=8,
|
||||
hbm_bytes_per_cube=48 * _GB,
|
||||
hbm_slices_per_cube=8,
|
||||
tcm_bytes_per_pe=16 * _MB,
|
||||
tcm_scheduler_reserved_bytes=4 * _MB,
|
||||
sram_bytes_per_cube=32 * _MB,
|
||||
)
|
||||
|
||||
|
||||
def _engine():
|
||||
return GraphEngine(load_topology(TOPOLOGY_PATH))
|
||||
|
||||
|
||||
# ── F1. MmuMapMsg fabric latency ─────────────────────────────────────
|
||||
|
||||
|
||||
def test_mmu_map_via_fabric_has_latency():
|
||||
"""MmuMapMsg submitted through engine.submit() completes with latency > 0."""
|
||||
from kernbench.runtime_api.kernel import MmuMapMsg
|
||||
|
||||
engine = _engine()
|
||||
msg = MmuMapMsg(
|
||||
correlation_id="c0",
|
||||
request_id="mmu_map_0",
|
||||
entries=({"va": 0x1_0000_0000, "pa": 0x2000_0000, "size": 4096},),
|
||||
target_cubes=(0,),
|
||||
target_pe="all",
|
||||
)
|
||||
h = engine.submit(msg)
|
||||
engine.wait(h)
|
||||
comp, trace = engine.get_completion(h)
|
||||
assert comp.ok is True
|
||||
# Fabric traversal must have non-zero latency
|
||||
assert trace is not None
|
||||
assert trace.get("total_ns", 0) > 0
|
||||
|
||||
|
||||
# ── F2. MmuMapMsg fan-out ────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_mmu_map_reaches_all_pes_in_cube():
|
||||
"""MmuMapMsg with target_pe='all' installs mapping in all 8 PE_MMUs of target cube."""
|
||||
from kernbench.runtime_api.kernel import MmuMapMsg
|
||||
|
||||
engine = _engine()
|
||||
va, pa, size = 0x1_0000_0000, 0xABCD_0000, 4096
|
||||
msg = MmuMapMsg(
|
||||
correlation_id="c0",
|
||||
request_id="mmu_map_1",
|
||||
entries=({"va": va, "pa": pa, "size": size},),
|
||||
target_cubes=(0,),
|
||||
target_pe="all",
|
||||
)
|
||||
h = engine.submit(msg)
|
||||
engine.wait(h)
|
||||
|
||||
# Verify all 8 PE_MMUs in cube 0 have the mapping
|
||||
for pe_id in range(8):
|
||||
mmu_id = f"sip0.cube0.pe{pe_id}.pe_mmu"
|
||||
mmu_comp = engine._components[mmu_id]
|
||||
assert mmu_comp.mmu.translate(va) == pa
|
||||
|
||||
|
||||
# ── F3. Multiple MmuMapMsg entries ───────────────────────────────────
|
||||
|
||||
|
||||
def test_mmu_map_multiple_entries():
|
||||
"""MmuMapMsg with multiple entries installs all of them."""
|
||||
from kernbench.runtime_api.kernel import MmuMapMsg
|
||||
|
||||
engine = _engine()
|
||||
entries = (
|
||||
{"va": 0x1_0000_0000, "pa": 0xA000_0000, "size": 4096},
|
||||
{"va": 0x1_0000_1000, "pa": 0xB000_0000, "size": 4096},
|
||||
)
|
||||
msg = MmuMapMsg(
|
||||
correlation_id="c0",
|
||||
request_id="mmu_map_2",
|
||||
entries=entries,
|
||||
target_cubes=(0,),
|
||||
target_pe="all",
|
||||
)
|
||||
h = engine.submit(msg)
|
||||
engine.wait(h)
|
||||
|
||||
mmu_comp = engine._components["sip0.cube0.pe0.pe_mmu"]
|
||||
assert mmu_comp.mmu.translate(0x1_0000_0000) == 0xA000_0000
|
||||
assert mmu_comp.mmu.translate(0x1_0000_1000) == 0xB000_0000
|
||||
|
||||
|
||||
# ── F4. Cross-cube sharded: global mapping ───────────────────────────
|
||||
|
||||
|
||||
def test_cross_cube_sharded_all_pes_get_global_mapping():
|
||||
"""For sharded tensor across cubes (unique offsets), all PEs get all mappings."""
|
||||
from kernbench.runtime_api.kernel import MmuMapMsg
|
||||
|
||||
engine = _engine()
|
||||
# Simulate 2-cube shard: cube0 has offset=0, cube1 has offset=4096
|
||||
entries = (
|
||||
{"va": 0x1_0000_0000, "pa": 0xA000_0000, "size": 4096}, # cube0
|
||||
{"va": 0x1_0000_1000, "pa": 0xB000_0000, "size": 4096}, # cube1
|
||||
)
|
||||
# Broadcast to both cubes
|
||||
msg = MmuMapMsg(
|
||||
correlation_id="c0",
|
||||
request_id="mmu_map_xc",
|
||||
entries=entries,
|
||||
target_cubes=(0, 1),
|
||||
target_pe="all",
|
||||
)
|
||||
h = engine.submit(msg)
|
||||
engine.wait(h)
|
||||
|
||||
# PE in cube0 can translate both cube0 and cube1 addresses
|
||||
mmu_c0 = engine._components["sip0.cube0.pe0.pe_mmu"]
|
||||
assert mmu_c0.mmu.translate(0x1_0000_0000) == 0xA000_0000 # local
|
||||
assert mmu_c0.mmu.translate(0x1_0000_1000) == 0xB000_0000 # remote
|
||||
|
||||
# PE in cube1 can also translate both
|
||||
mmu_c1 = engine._components["sip0.cube1.pe0.pe_mmu"]
|
||||
assert mmu_c1.mmu.translate(0x1_0000_0000) == 0xA000_0000 # remote
|
||||
assert mmu_c1.mmu.translate(0x1_0000_1000) == 0xB000_0000 # local
|
||||
|
||||
|
||||
# ── F5. Replicate: local PA override ─────────────────────────────────
|
||||
|
||||
|
||||
def test_replicate_local_pa_override():
|
||||
"""For replicated tensor (same VA range), each cube's PEs see local PA."""
|
||||
from kernbench.runtime_api.kernel import MmuMapMsg
|
||||
|
||||
engine = _engine()
|
||||
va, size = 0x1_0000_0000, 4096
|
||||
|
||||
# Cube 0 gets its own PA
|
||||
msg0 = MmuMapMsg(
|
||||
correlation_id="c0",
|
||||
request_id="mmu_rep_c0",
|
||||
entries=({"va": va, "pa": 0xA000_0000, "size": size},),
|
||||
target_cubes=(0,),
|
||||
target_pe="all",
|
||||
)
|
||||
h0 = engine.submit(msg0)
|
||||
engine.wait(h0)
|
||||
|
||||
# Cube 1 gets a different PA for the same VA
|
||||
msg1 = MmuMapMsg(
|
||||
correlation_id="c0",
|
||||
request_id="mmu_rep_c1",
|
||||
entries=({"va": va, "pa": 0xB000_0000, "size": size},),
|
||||
target_cubes=(1,),
|
||||
target_pe="all",
|
||||
)
|
||||
h1 = engine.submit(msg1)
|
||||
engine.wait(h1)
|
||||
|
||||
# Cube 0 PEs translate to cube 0's PA
|
||||
mmu_c0 = engine._components["sip0.cube0.pe0.pe_mmu"]
|
||||
assert mmu_c0.mmu.translate(va) == 0xA000_0000
|
||||
|
||||
# Cube 1 PEs translate to cube 1's PA
|
||||
mmu_c1 = engine._components["sip0.cube1.pe0.pe_mmu"]
|
||||
assert mmu_c1.mmu.translate(va) == 0xB000_0000
|
||||
|
||||
|
||||
# ── F7. Overlap detection ────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_detect_overlapping_shards():
|
||||
"""Utility: detect if shards have overlapping VA ranges (replicate indicator)."""
|
||||
from kernbench.runtime_api.tensor import TensorShard
|
||||
|
||||
# Sharded: unique offsets
|
||||
sharded = [
|
||||
TensorShard(sip=0, cube=0, pe=0, pa=0x100, nbytes=4096, offset_bytes=0),
|
||||
TensorShard(sip=0, cube=0, pe=1, pa=0x200, nbytes=4096, offset_bytes=4096),
|
||||
]
|
||||
offsets = [(s.offset_bytes, s.nbytes) for s in sharded]
|
||||
assert len(set(offsets)) == len(offsets), "Sharded should have unique offsets"
|
||||
|
||||
# Replicated: same offset
|
||||
replicated = [
|
||||
TensorShard(sip=0, cube=0, pe=0, pa=0x100, nbytes=4096, offset_bytes=0),
|
||||
TensorShard(sip=0, cube=1, pe=0, pa=0x200, nbytes=4096, offset_bytes=0),
|
||||
]
|
||||
offsets_r = [(s.offset_bytes, s.nbytes) for s in replicated]
|
||||
assert len(set(offsets_r)) < len(offsets_r), "Replicate should have duplicate offsets"
|
||||
|
||||
|
||||
# ── F8. Regression: existing benchmarks still pass ───────────────────
|
||||
|
||||
|
||||
def test_qkv_gemm_still_passes():
|
||||
"""QKV GEMM benchmark completes successfully with VA/MMU enabled."""
|
||||
from kernbench.runtime_api.context import RuntimeContext
|
||||
from kernbench.runtime_api.types import BenchResult, DeviceSelector
|
||||
|
||||
graph = load_topology(TOPOLOGY_PATH)
|
||||
engine = GraphEngine(graph)
|
||||
ctx = RuntimeContext(
|
||||
engine=engine,
|
||||
target_device=DeviceSelector("sip:0"),
|
||||
correlation_id="test_regression",
|
||||
spec=graph.spec,
|
||||
)
|
||||
from benches.qkv_gemm import run as bench_run
|
||||
bench_run(ctx)
|
||||
ctx.wait_all()
|
||||
# If we get here without exception, the benchmark succeeded
|
||||
@@ -308,9 +308,9 @@ def test_pe_gemm_handles_pe_internal_txn():
|
||||
gemm.in_ports["src"] = simpy.Store(env)
|
||||
gemm.start(env)
|
||||
|
||||
a = TensorHandle(id="t1", pa=0, shape=(4, 8), dtype="f16", nbytes=64)
|
||||
b = TensorHandle(id="t2", pa=0, shape=(8, 4), dtype="f16", nbytes=64)
|
||||
out = TensorHandle(id="t3", pa=0, shape=(4, 4), dtype="f16", nbytes=32)
|
||||
a = TensorHandle(id="t1", addr=0, shape=(4, 8), dtype="f16", nbytes=64)
|
||||
b = TensorHandle(id="t2", addr=0, shape=(8, 4), dtype="f16", nbytes=64)
|
||||
out = TensorHandle(id="t3", addr=0, shape=(4, 4), dtype="f16", nbytes=32)
|
||||
cmd = GemmCmd(a=a, b=b, out=out, m=4, k=8, n=4)
|
||||
done = env.event()
|
||||
pe_txn = PeInternalTxn(command=cmd, done=done, pe_prefix=pe_prefix)
|
||||
@@ -349,8 +349,8 @@ def test_pe_math_handles_pe_internal_txn():
|
||||
math_comp.in_ports["src"] = simpy.Store(env)
|
||||
math_comp.start(env)
|
||||
|
||||
x = TensorHandle(id="t1", pa=0, shape=(4, 4), dtype="f16", nbytes=32)
|
||||
out = TensorHandle(id="t2", pa=0, shape=(4, 4), dtype="f16", nbytes=32)
|
||||
x = TensorHandle(id="t1", addr=0, shape=(4, 4), dtype="f16", nbytes=32)
|
||||
out = TensorHandle(id="t2", addr=0, shape=(4, 4), dtype="f16", nbytes=32)
|
||||
cmd = MathCmd(op="exp", inputs=(x,), out=out)
|
||||
done = env.event()
|
||||
pe_txn = PeInternalTxn(command=cmd, done=done, pe_prefix=pe_prefix)
|
||||
@@ -777,7 +777,7 @@ def test_tl_ref_no_dma():
|
||||
|
||||
tl = TLContext(pe_id=0, dispatch_cycles=0)
|
||||
handle = tl.ref(0x1000, shape=(4, 4), dtype="f16")
|
||||
assert handle.pa == 0x1000
|
||||
assert handle.addr == 0x1000
|
||||
assert handle.shape == (4, 4)
|
||||
assert len(tl.commands) == 0, f"tl.ref should emit 0 commands, got {len(tl.commands)}"
|
||||
|
||||
|
||||
@@ -0,0 +1,203 @@
|
||||
"""Tests for PeMMU: per-PE virtual-to-physical address translation.
|
||||
|
||||
Validates:
|
||||
T1. Basic map + translate
|
||||
T2. Page-aligned dict lookup (O(1), multi-page range)
|
||||
T3. Multiple tensor mappings accumulate
|
||||
T4. unmap removes entries, translate raises PageFault
|
||||
T5. PageFault on unmapped VA
|
||||
T6. Identical mapping broadcast across multiple PEs
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from kernbench.policy.address.pe_mmu import PageFault, PeMMU
|
||||
|
||||
_2MB = 2 * 1024 * 1024
|
||||
|
||||
|
||||
# ── T1. Basic map + translate ────────────────────────────────────────
|
||||
|
||||
|
||||
def test_map_and_translate_basic():
|
||||
"""map(va, pa, size) → translate(va) returns pa; offset preserved."""
|
||||
mmu = PeMMU(page_size=4096)
|
||||
mmu.map(va=0x1_0000_0000, pa=0xABCD_0000, size=4096)
|
||||
assert mmu.translate(0x1_0000_0000) == 0xABCD_0000
|
||||
|
||||
|
||||
def test_translate_preserves_offset():
|
||||
"""translate(va + offset) returns pa + offset within a page."""
|
||||
mmu = PeMMU(page_size=4096)
|
||||
mmu.map(va=0x1_0000_0000, pa=0xABCD_0000, size=4096)
|
||||
assert mmu.translate(0x1_0000_0100) == 0xABCD_0100
|
||||
assert mmu.translate(0x1_0000_0FFF) == 0xABCD_0FFF
|
||||
|
||||
|
||||
# ── T2. Page-aligned dict lookup (multi-page) ───────────────────────
|
||||
|
||||
|
||||
def test_multi_page_mapping():
|
||||
"""8 MB mapping with 2 MB pages → 4 page entries, all translate correctly."""
|
||||
mmu = PeMMU(page_size=_2MB)
|
||||
va_base = 0x1_0000_0000
|
||||
pa_base = 0x2_0000_0000
|
||||
size = 8 * 1024 * 1024 # 8 MB = 4 pages
|
||||
|
||||
mmu.map(va=va_base, pa=pa_base, size=size)
|
||||
|
||||
# First page
|
||||
assert mmu.translate(va_base) == pa_base
|
||||
# Second page start
|
||||
assert mmu.translate(va_base + _2MB) == pa_base + _2MB
|
||||
# Third page with offset
|
||||
assert mmu.translate(va_base + 2 * _2MB + 0x100) == pa_base + 2 * _2MB + 0x100
|
||||
# Last byte of last page
|
||||
assert mmu.translate(va_base + size - 1) == pa_base + size - 1
|
||||
|
||||
|
||||
def test_page_table_entry_count():
|
||||
"""Mapping N bytes with page_size P creates ceil(N/P) entries."""
|
||||
mmu = PeMMU(page_size=_2MB)
|
||||
mmu.map(va=0x1000_0000, pa=0x2000_0000, size=8 * 1024 * 1024)
|
||||
assert mmu.num_entries == 4
|
||||
|
||||
mmu.map(va=0x2000_0000, pa=0x3000_0000, size=_2MB)
|
||||
assert mmu.num_entries == 5
|
||||
|
||||
|
||||
# ── T3. Multiple tensor mappings accumulate ──────────────────────────
|
||||
|
||||
|
||||
def test_multiple_mappings_accumulate():
|
||||
"""Two non-overlapping tensors → both translate correctly."""
|
||||
mmu = PeMMU(page_size=4096)
|
||||
# Tensor A
|
||||
mmu.map(va=0x1_0000_0000, pa=0xA000_0000, size=4096)
|
||||
# Tensor B (different VA range)
|
||||
mmu.map(va=0x1_0001_0000, pa=0xB000_0000, size=4096)
|
||||
|
||||
assert mmu.translate(0x1_0000_0000) == 0xA000_0000
|
||||
assert mmu.translate(0x1_0001_0000) == 0xB000_0000
|
||||
|
||||
|
||||
def test_mappings_do_not_interfere():
|
||||
"""Adjacent VA ranges map to completely independent PA ranges."""
|
||||
mmu = PeMMU(page_size=4096)
|
||||
mmu.map(va=0x0000_0000, pa=0xFFFF_0000, size=4096)
|
||||
mmu.map(va=0x0000_1000, pa=0x0000_0000, size=4096)
|
||||
|
||||
assert mmu.translate(0x0000_0000) == 0xFFFF_0000
|
||||
assert mmu.translate(0x0000_1000) == 0x0000_0000
|
||||
|
||||
|
||||
# ── T4. unmap removes entries ────────────────────────────────────────
|
||||
|
||||
|
||||
def test_unmap_removes_mapping():
|
||||
"""After unmap, translate raises PageFault."""
|
||||
mmu = PeMMU(page_size=4096)
|
||||
mmu.map(va=0x1_0000_0000, pa=0xABCD_0000, size=4096)
|
||||
assert mmu.translate(0x1_0000_0000) == 0xABCD_0000
|
||||
|
||||
mmu.unmap(va=0x1_0000_0000, size=4096)
|
||||
with pytest.raises(PageFault):
|
||||
mmu.translate(0x1_0000_0000)
|
||||
|
||||
|
||||
def test_unmap_partial_range():
|
||||
"""Unmap only part of a multi-page mapping; rest stays valid."""
|
||||
mmu = PeMMU(page_size=4096)
|
||||
mmu.map(va=0x1000_0000, pa=0x2000_0000, size=8192) # 2 pages
|
||||
assert mmu.num_entries == 2
|
||||
|
||||
# Unmap first page only
|
||||
mmu.unmap(va=0x1000_0000, size=4096)
|
||||
assert mmu.num_entries == 1
|
||||
|
||||
with pytest.raises(PageFault):
|
||||
mmu.translate(0x1000_0000)
|
||||
# Second page still valid
|
||||
assert mmu.translate(0x1000_1000) == 0x2000_1000
|
||||
|
||||
|
||||
def test_unmap_does_not_affect_other_mappings():
|
||||
"""Unmapping tensor A does not affect tensor B."""
|
||||
mmu = PeMMU(page_size=4096)
|
||||
mmu.map(va=0x1_0000_0000, pa=0xA000_0000, size=4096)
|
||||
mmu.map(va=0x1_0001_0000, pa=0xB000_0000, size=4096)
|
||||
|
||||
mmu.unmap(va=0x1_0000_0000, size=4096)
|
||||
with pytest.raises(PageFault):
|
||||
mmu.translate(0x1_0000_0000)
|
||||
assert mmu.translate(0x1_0001_0000) == 0xB000_0000
|
||||
|
||||
|
||||
# ── T5. PageFault on unmapped VA ─────────────────────────────────────
|
||||
|
||||
|
||||
def test_pagefault_on_unmapped_va():
|
||||
"""translate() on never-mapped VA raises PageFault."""
|
||||
mmu = PeMMU(page_size=4096)
|
||||
with pytest.raises(PageFault):
|
||||
mmu.translate(0xDEAD_BEEF)
|
||||
|
||||
|
||||
def test_pagefault_contains_va():
|
||||
"""PageFault exception carries the faulting VA."""
|
||||
mmu = PeMMU(page_size=4096)
|
||||
with pytest.raises(PageFault, match="0xdeadbeef"):
|
||||
mmu.translate(0xDEAD_BEEF)
|
||||
|
||||
|
||||
# ── T6. Identical mapping broadcast across PEs ───────────────────────
|
||||
|
||||
|
||||
def test_broadcast_same_mapping_to_all_pes():
|
||||
"""All PEs receive the same full mapping → identical translate results."""
|
||||
entries = [
|
||||
(0x1_0000_0000, 0xA000_0000, 4096), # shard 0
|
||||
(0x1_0000_1000, 0xB000_0000, 4096), # shard 1
|
||||
(0x1_0000_2000, 0xC000_0000, 4096), # shard 2
|
||||
(0x1_0000_3000, 0xD000_0000, 4096), # shard 3
|
||||
]
|
||||
num_pes = 8
|
||||
mmus = [PeMMU(page_size=4096) for _ in range(num_pes)]
|
||||
|
||||
# Broadcast: every PE gets the same entries
|
||||
for mmu in mmus:
|
||||
for va, pa, size in entries:
|
||||
mmu.map(va=va, pa=pa, size=size)
|
||||
|
||||
# All PEs translate identically
|
||||
for mmu in mmus:
|
||||
assert mmu.translate(0x1_0000_0000) == 0xA000_0000
|
||||
assert mmu.translate(0x1_0000_1000) == 0xB000_0000
|
||||
assert mmu.translate(0x1_0000_2000) == 0xC000_0000
|
||||
assert mmu.translate(0x1_0000_3000) == 0xD000_0000
|
||||
|
||||
|
||||
def test_cross_pe_access_via_broadcast():
|
||||
"""PE0 can translate a VA that maps to PE3's HBM PA (cross-PE DMA scenario)."""
|
||||
mmu_pe0 = PeMMU(page_size=4096)
|
||||
# Full mapping includes PE3's shard
|
||||
mmu_pe0.map(va=0x1_0000_0000, pa=0xA000_0000, size=4096) # PE0 shard
|
||||
mmu_pe0.map(va=0x1_0000_1000, pa=0xD000_0000, size=4096) # PE3 shard
|
||||
|
||||
# PE0 accesses PE3's region → valid translation
|
||||
pa = mmu_pe0.translate(0x1_0000_1000 + 0x100)
|
||||
assert pa == 0xD000_0100
|
||||
|
||||
|
||||
# ── TLB overhead attribute ───────────────────────────────────────────
|
||||
|
||||
|
||||
def test_tlb_overhead_default():
|
||||
"""Default tlb_overhead_ns is 0."""
|
||||
mmu = PeMMU()
|
||||
assert mmu.overhead_ns == 0.0
|
||||
|
||||
|
||||
def test_tlb_overhead_configurable():
|
||||
"""tlb_overhead_ns is configurable."""
|
||||
mmu = PeMMU(overhead_ns=0.5)
|
||||
assert mmu.overhead_ns == 0.5
|
||||
@@ -0,0 +1,193 @@
|
||||
"""Tests for tensor free: del-based + context manager cleanup.
|
||||
|
||||
Validates:
|
||||
TF1. PEMemAllocator.free_hbm/free_tcm reclaims space
|
||||
TF2. del tensor triggers cleanup (VA/PA returned, MMU unmapped)
|
||||
TF3. Context manager cleans up all tensors on exit
|
||||
TF4. del after context exit is safe (no crash)
|
||||
TF5. Alloc-del-alloc cycle reuses VA and PA
|
||||
TF6. del already-freed tensor is safe (no crash)
|
||||
"""
|
||||
import gc
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator
|
||||
from kernbench.policy.address.pe_mmu import PageFault
|
||||
from kernbench.policy.placement.dp import DPPolicy
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
from kernbench.runtime_api.context import RuntimeContext
|
||||
from kernbench.runtime_api.types import DeviceSelector
|
||||
from kernbench.topology.builder import load_topology
|
||||
|
||||
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
|
||||
|
||||
_MB = 1 << 20
|
||||
_GB = 1 << 30
|
||||
|
||||
_CFG = AddressConfig(
|
||||
sip_count=2,
|
||||
cubes_per_sip=16,
|
||||
pes_per_cube=8,
|
||||
hbm_bytes_per_cube=48 * _GB,
|
||||
hbm_slices_per_cube=8,
|
||||
tcm_bytes_per_pe=16 * _MB,
|
||||
tcm_scheduler_reserved_bytes=4 * _MB,
|
||||
sram_bytes_per_cube=32 * _MB,
|
||||
)
|
||||
|
||||
|
||||
def _make_ctx():
|
||||
graph = load_topology(TOPOLOGY_PATH)
|
||||
engine = GraphEngine(graph)
|
||||
ctx = RuntimeContext(
|
||||
engine=engine,
|
||||
target_device=DeviceSelector("sip:0"),
|
||||
correlation_id="test_free",
|
||||
spec=graph.spec,
|
||||
)
|
||||
return ctx, engine
|
||||
|
||||
|
||||
# ── TF1. PEMemAllocator free ─────────────────────────────────────────
|
||||
|
||||
|
||||
def test_allocator_free_hbm_reclaims_space():
|
||||
"""free_hbm returns HBM space; subsequent alloc can reuse it."""
|
||||
a = PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=0, cfg=_CFG)
|
||||
pa1 = a.alloc_hbm(4096)
|
||||
used_after_alloc = a.hbm_used
|
||||
a.free_hbm(pa1, 4096)
|
||||
assert a.hbm_used == used_after_alloc - 4096
|
||||
pa2 = a.alloc_hbm(4096)
|
||||
assert pa2 is not None
|
||||
|
||||
|
||||
def test_allocator_free_tcm_reclaims_space():
|
||||
"""free_tcm returns TCM space."""
|
||||
a = PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=0, cfg=_CFG)
|
||||
pa1 = a.alloc_tcm(256)
|
||||
used_after_alloc = a.tcm_used
|
||||
a.free_tcm(pa1, 256)
|
||||
assert a.tcm_used == used_after_alloc - 256
|
||||
|
||||
|
||||
# ── TF2. del tensor triggers cleanup ─────────────────────────────────
|
||||
|
||||
|
||||
def test_del_tensor_unmaps_mmu():
|
||||
"""del tensor removes MMU mappings."""
|
||||
ctx, engine = _make_ctx()
|
||||
dp = DPPolicy(cube="replicate", pe="replicate")
|
||||
t = ctx.zeros((128, 128), dtype="f16", dp=dp, name="del_test")
|
||||
va_base = t._handle.va_base
|
||||
|
||||
# Verify mapping exists
|
||||
mmu_comp = engine._components["sip0.cube0.pe0.pe_mmu"]
|
||||
assert mmu_comp.mmu.translate(va_base) is not None
|
||||
|
||||
# Delete tensor
|
||||
del t
|
||||
gc.collect()
|
||||
|
||||
# Mapping should be gone
|
||||
with pytest.raises(PageFault):
|
||||
mmu_comp.mmu.translate(va_base)
|
||||
|
||||
|
||||
def test_del_tensor_reclaims_va():
|
||||
"""del tensor returns VA space for reuse."""
|
||||
ctx, engine = _make_ctx()
|
||||
dp = DPPolicy(cube="replicate", pe="replicate")
|
||||
|
||||
t1 = ctx.zeros((128, 128), dtype="f16", dp=dp, name="va_reuse1")
|
||||
va1 = t1._handle.va_base
|
||||
|
||||
del t1
|
||||
gc.collect()
|
||||
|
||||
t2 = ctx.zeros((128, 128), dtype="f16", dp=dp, name="va_reuse2")
|
||||
va2 = t2._handle.va_base
|
||||
assert va2 == va1
|
||||
|
||||
|
||||
# ── TF3. Context manager cleanup ─────────────────────────────────────
|
||||
|
||||
|
||||
def test_context_manager_cleans_all():
|
||||
"""Exiting context manager cleans up all tensors."""
|
||||
graph = load_topology(TOPOLOGY_PATH)
|
||||
engine = GraphEngine(graph)
|
||||
|
||||
with RuntimeContext(
|
||||
engine=engine,
|
||||
target_device=DeviceSelector("sip:0"),
|
||||
correlation_id="ctx_mgr",
|
||||
spec=graph.spec,
|
||||
) as ctx:
|
||||
dp = DPPolicy(cube="replicate", pe="replicate")
|
||||
t = ctx.zeros((128, 128), dtype="f16", dp=dp, name="ctx_mgr_test")
|
||||
va_base = t._handle.va_base
|
||||
|
||||
# After context exit, MMU mappings should be cleared
|
||||
mmu_comp = engine._components["sip0.cube0.pe0.pe_mmu"]
|
||||
with pytest.raises(PageFault):
|
||||
mmu_comp.mmu.translate(va_base)
|
||||
|
||||
|
||||
# ── TF4. del after context exit is safe ───────────────────────────────
|
||||
|
||||
|
||||
def test_del_after_context_exit_no_crash():
|
||||
"""del tensor after context manager exit does not crash."""
|
||||
graph = load_topology(TOPOLOGY_PATH)
|
||||
engine = GraphEngine(graph)
|
||||
|
||||
ctx = RuntimeContext(
|
||||
engine=engine,
|
||||
target_device=DeviceSelector("sip:0"),
|
||||
correlation_id="safe_del",
|
||||
spec=graph.spec,
|
||||
)
|
||||
dp = DPPolicy(cube="replicate", pe="replicate")
|
||||
t = ctx.zeros((128, 128), dtype="f16", dp=dp, name="safe_del_test")
|
||||
|
||||
# Simulate context going away
|
||||
ctx.cleanup()
|
||||
|
||||
# del should not crash even though context already cleaned up
|
||||
del t
|
||||
gc.collect()
|
||||
# No exception = pass
|
||||
|
||||
|
||||
# ── TF5. Alloc-del-alloc cycle ────────────────────────────────────────
|
||||
|
||||
|
||||
def test_alloc_del_cycle():
|
||||
"""Multiple alloc-del cycles work correctly."""
|
||||
ctx, engine = _make_ctx()
|
||||
dp = DPPolicy(cube="replicate", pe="replicate")
|
||||
|
||||
for i in range(3):
|
||||
t = ctx.zeros((128, 128), dtype="f16", dp=dp, name=f"cycle_{i}")
|
||||
assert t._handle is not None
|
||||
assert t._handle.va_base > 0
|
||||
del t
|
||||
gc.collect()
|
||||
|
||||
|
||||
# ── TF6. del already-cleaned tensor is safe ──────────────────────────
|
||||
|
||||
|
||||
def test_del_already_cleaned_tensor_no_crash():
|
||||
"""del on a tensor whose handle is already None does not crash."""
|
||||
ctx, engine = _make_ctx()
|
||||
dp = DPPolicy(cube="replicate", pe="replicate")
|
||||
t = ctx.zeros((128, 128), dtype="f16", dp=dp, name="double_del")
|
||||
|
||||
ctx.cleanup() # clears all tensors
|
||||
# t._handle is now None
|
||||
del t # should not crash
|
||||
gc.collect()
|
||||
@@ -17,31 +17,32 @@ def test_full_graph_node_count():
|
||||
g = _graph()
|
||||
# 1 switch
|
||||
# + 2 SIPs × (1 IO × (3 comps + 4 io_ucie + 16 io_conn)
|
||||
# + 16 cubes × (cube_comps + 8 PEs × 6 pe_comps))
|
||||
# + 16 cubes × (cube_comps + 8 PEs × 7 pe_comps))
|
||||
# IO: pcie_ep + io_cpu + io_noc + 4 io_ucie + 4*4 io_conn = 23
|
||||
# cube_comps: 9 (noc, m_cpu, sram, 2 bridge, 4 ucie)
|
||||
# + 16 ucie_conn (4 ports × 4 connections)
|
||||
# + 2 xbar_top/bot
|
||||
# + 8 hbm_slices = 35
|
||||
# = 1 + 2*(23 + 16*(35+48)) = 1 + 2*(23+1328) = 1 + 2702 = 2703
|
||||
assert len(g.nodes) == 2703
|
||||
# pe_comps: 7 (pe_cpu, pe_scheduler, pe_dma, pe_gemm, pe_math, pe_mmu, pe_tcm)
|
||||
# = 1 + 2*(23 + 16*(35+56)) = 1 + 2*(23+1456) = 1 + 2958 = 2959
|
||||
assert len(g.nodes) == 2959
|
||||
|
||||
|
||||
def test_full_graph_edge_count():
|
||||
g = _graph()
|
||||
# Per cube: 184
|
||||
# Per cube: 192
|
||||
# PE-internal: 56
|
||||
# PE_DMA→noc: 8, noc→pe_dma: 8, noc→pe_cpu: 8, pe_cpu→noc: 8
|
||||
# PE_DMA→noc: 8, noc→pe_dma: 8, noc→pe_cpu: 8, pe_cpu→noc: 8, noc→pe_mmu: 8
|
||||
# xbar_top→hbm{0..3}: 4+4=8, xbar_bot→hbm{4..7}: 4+4=8
|
||||
# noc↔xbar_top: 2, noc↔xbar_bot: 2
|
||||
# xbar_top↔bridge.left: 2, bridge.left↔xbar_bot: 2
|
||||
# xbar_top↔bridge.right: 2, bridge.right↔xbar_bot: 2
|
||||
# ucie: 64, m_cpu↔noc: 2, noc↔sram: 2
|
||||
# Total: 56+8+8+8+8+8+8+2+2+2+2+2+2+64+2+2 = 184
|
||||
# Total: 56+8+8+8+8+8+8+8+2+2+2+2+2+2+64+2+2 = 192
|
||||
# IO edges per SIP: 77
|
||||
# Per SIP: 16*184 + 48 inter-cube + 77 IO = 3069
|
||||
# Total: 2 * 3069 = 6138
|
||||
assert len(g.edges) == 6138
|
||||
# Per SIP: 16*192 + 48 inter-cube + 77 IO = 3197
|
||||
# Total: 2 * 3197 = 6394
|
||||
assert len(g.edges) == 6394
|
||||
|
||||
|
||||
# ── Full graph: specific nodes exist ─────────────────────────────────
|
||||
@@ -267,7 +268,7 @@ def test_cube_view_pe_to_noc():
|
||||
def test_pe_view_has_all_components():
|
||||
v = _graph().pe_view
|
||||
assert set(v.nodes.keys()) == {
|
||||
"pe_cpu", "pe_scheduler", "pe_dma", "pe_gemm", "pe_math", "pe_tcm"
|
||||
"pe_cpu", "pe_scheduler", "pe_dma", "pe_gemm", "pe_math", "pe_mmu", "pe_tcm"
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ def test_pe_template_components():
|
||||
spec = _read_spec(TOPOLOGY_PATH)
|
||||
comps = spec["cube"]["pe_template"]["components"]
|
||||
assert set(comps.keys()) == {
|
||||
"pe_cpu", "pe_scheduler", "pe_dma", "pe_gemm", "pe_math", "pe_tcm"
|
||||
"pe_cpu", "pe_scheduler", "pe_dma", "pe_gemm", "pe_math", "pe_mmu", "pe_tcm"
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ def test_tl_load_generates_dma_read():
|
||||
cmds = tl.commands
|
||||
assert len(cmds) == 1
|
||||
assert isinstance(cmds[0], DmaReadCmd)
|
||||
assert cmds[0].src_pa == 0x1000
|
||||
assert cmds[0].src_addr == 0x1000
|
||||
assert cmds[0].nbytes == 32 * 64 * 2
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ def test_tl_store_generates_dma_write():
|
||||
tl.store(0x2000, h)
|
||||
cmds = [c for c in tl.commands if isinstance(c, DmaWriteCmd)]
|
||||
assert len(cmds) == 1
|
||||
assert cmds[0].dst_pa == 0x2000
|
||||
assert cmds[0].dst_addr == 0x2000
|
||||
assert cmds[0].nbytes == 16 * 16 * 4
|
||||
|
||||
|
||||
@@ -148,7 +148,7 @@ def test_tl_composite_nonblocking():
|
||||
comp_cmds = [c for c in tl.commands if isinstance(c, CompositeCmd)]
|
||||
assert len(comp_cmds) == 1
|
||||
assert comp_cmds[0].op == "gemm"
|
||||
assert comp_cmds[0].out_pa == 0x3000
|
||||
assert comp_cmds[0].out_addr == 0x3000
|
||||
assert comp_cmds[0].out_nbytes == 32 * 32 * 2 # M×N×dtype_bytes
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,140 @@
|
||||
"""Tests for VirtualAllocator: device-wide VA space management.
|
||||
|
||||
Validates:
|
||||
T7. Basic VA allocation (contiguous, non-overlapping)
|
||||
T8. VA free + reallocation (free-list reuse)
|
||||
T9. VA space exhaustion raises AllocationError
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from kernbench.policy.address.va_allocator import VirtualAllocator
|
||||
|
||||
_KB = 1024
|
||||
_MB = 1024 * 1024
|
||||
_GB = 1024 * 1024 * 1024
|
||||
|
||||
|
||||
# ── T7. Basic VA allocation ──────────────────────────────────────────
|
||||
|
||||
|
||||
def test_alloc_returns_aligned_va():
|
||||
"""First allocation returns va_base."""
|
||||
va = VirtualAllocator(va_base=0x1_0000_0000, va_size=1 * _GB, page_size=4096)
|
||||
addr = va.alloc(4096)
|
||||
assert addr == 0x1_0000_0000
|
||||
|
||||
|
||||
def test_alloc_sequential_non_overlapping():
|
||||
"""Two allocations return contiguous, non-overlapping VA ranges."""
|
||||
va = VirtualAllocator(va_base=0x1_0000_0000, va_size=1 * _GB, page_size=4096)
|
||||
a1 = va.alloc(4096)
|
||||
a2 = va.alloc(8192)
|
||||
assert a1 == 0x1_0000_0000
|
||||
assert a2 == 0x1_0000_1000 # a1 + 4096
|
||||
# No overlap
|
||||
assert a2 >= a1 + 4096
|
||||
|
||||
|
||||
def test_alloc_page_aligned():
|
||||
"""Allocations are page-aligned even if requested size is not page-multiple."""
|
||||
va = VirtualAllocator(va_base=0x1_0000_0000, va_size=1 * _GB, page_size=4096)
|
||||
a1 = va.alloc(100) # < 1 page, but occupies 1 page
|
||||
a2 = va.alloc(100)
|
||||
assert a2 == 0x1_0000_1000 # aligned to next page
|
||||
|
||||
|
||||
def test_alloc_large_contiguous():
|
||||
"""Large allocation (multiple pages) is contiguous."""
|
||||
va = VirtualAllocator(va_base=0x0, va_size=1 * _GB, page_size=2 * _MB)
|
||||
addr = va.alloc(8 * _MB) # 4 pages
|
||||
assert addr == 0x0
|
||||
# Next alloc starts after 8 MB
|
||||
addr2 = va.alloc(2 * _MB)
|
||||
assert addr2 == 8 * _MB
|
||||
|
||||
|
||||
# ── T8. VA free + reallocation ───────────────────────────────────────
|
||||
|
||||
|
||||
def test_free_and_realloc():
|
||||
"""Freed VA range can be reused by subsequent allocation."""
|
||||
va = VirtualAllocator(va_base=0x1_0000_0000, va_size=1 * _GB, page_size=4096)
|
||||
a1 = va.alloc(4096)
|
||||
a2 = va.alloc(4096)
|
||||
va.free(a1, 4096)
|
||||
|
||||
# New alloc should reuse a1's range
|
||||
a3 = va.alloc(4096)
|
||||
assert a3 == a1
|
||||
|
||||
|
||||
def test_free_coalesce():
|
||||
"""Freeing adjacent blocks allows larger reallocation."""
|
||||
va = VirtualAllocator(va_base=0x0, va_size=1 * _GB, page_size=4096)
|
||||
a1 = va.alloc(4096)
|
||||
a2 = va.alloc(4096)
|
||||
a3 = va.alloc(4096)
|
||||
|
||||
# Free first two (adjacent)
|
||||
va.free(a1, 4096)
|
||||
va.free(a2, 4096)
|
||||
|
||||
# Should be able to allocate 8192 contiguous from the freed region
|
||||
a4 = va.alloc(8192)
|
||||
assert a4 == a1 # reuses coalesced region
|
||||
|
||||
|
||||
def test_free_out_of_order():
|
||||
"""Free in non-sequential order still works."""
|
||||
va = VirtualAllocator(va_base=0x0, va_size=1 * _GB, page_size=4096)
|
||||
a1 = va.alloc(4096)
|
||||
a2 = va.alloc(4096)
|
||||
a3 = va.alloc(4096)
|
||||
|
||||
va.free(a2, 4096) # free middle
|
||||
a4 = va.alloc(4096)
|
||||
assert a4 == a2 # reuses middle slot
|
||||
|
||||
|
||||
# ── T9. VA space exhaustion ──────────────────────────────────────────
|
||||
|
||||
|
||||
def test_alloc_exhaustion():
|
||||
"""Allocation beyond VA space raises AllocationError."""
|
||||
va = VirtualAllocator(va_base=0x0, va_size=8192, page_size=4096)
|
||||
va.alloc(4096)
|
||||
va.alloc(4096)
|
||||
with pytest.raises(Exception, match="[Aa]lloc|[Ee]xhaust|[Oo]ut of"):
|
||||
va.alloc(4096)
|
||||
|
||||
|
||||
def test_alloc_after_partial_free():
|
||||
"""After freeing some, can allocate again within freed space."""
|
||||
va = VirtualAllocator(va_base=0x0, va_size=8192, page_size=4096)
|
||||
a1 = va.alloc(4096)
|
||||
a2 = va.alloc(4096)
|
||||
|
||||
# Space is full
|
||||
with pytest.raises(Exception):
|
||||
va.alloc(4096)
|
||||
|
||||
# Free one, now can allocate again
|
||||
va.free(a1, 4096)
|
||||
a3 = va.alloc(4096)
|
||||
assert a3 == a1
|
||||
|
||||
|
||||
# ── Stats / inspection ───────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_used_and_total():
|
||||
"""used and total properties reflect allocation state."""
|
||||
va = VirtualAllocator(va_base=0x0, va_size=1 * _MB, page_size=4096)
|
||||
assert va.used == 0
|
||||
assert va.total == 1 * _MB
|
||||
|
||||
va.alloc(4096)
|
||||
assert va.used == 4096
|
||||
|
||||
va.alloc(8192)
|
||||
assert va.used == 4096 + 8192 # page-aligned: 4096 + 8192 = 12288
|
||||
@@ -0,0 +1,230 @@
|
||||
"""Tests for VA integration: Tensor, TLContext, and DMA commands use VA.
|
||||
|
||||
Validates:
|
||||
T10. TensorHandle has va_base; TensorShard does NOT have va field
|
||||
T11. deploy_tensor allocates VA + creates mapping entries
|
||||
T12. Tensor.va returns the tensor's VA base
|
||||
T13. tl.load/tl.store generate DMA commands with VA (not PA)
|
||||
T14. Kernel VA-based offset calculation flows through DMA commands
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator
|
||||
from kernbench.policy.address.pe_mmu import PeMMU
|
||||
from kernbench.policy.address.va_allocator import VirtualAllocator
|
||||
from kernbench.policy.placement.dp import column_wise, ShardSpec
|
||||
from kernbench.runtime_api.tensor import (
|
||||
TensorHandle,
|
||||
TensorShard,
|
||||
deploy_tensor,
|
||||
)
|
||||
from kernbench.runtime_api.kernel import TensorArgShard
|
||||
from kernbench.common.pe_commands import DmaReadCmd, DmaWriteCmd
|
||||
from kernbench.triton_emu.tl_context import TLContext, run_kernel
|
||||
|
||||
_MB = 1 << 20
|
||||
_GB = 1 << 30
|
||||
|
||||
_CFG = AddressConfig(
|
||||
sip_count=2,
|
||||
cubes_per_sip=16,
|
||||
pes_per_cube=8,
|
||||
hbm_bytes_per_cube=48 * _GB,
|
||||
hbm_slices_per_cube=8,
|
||||
tcm_bytes_per_pe=16 * _MB,
|
||||
tcm_scheduler_reserved_bytes=4 * _MB,
|
||||
sram_bytes_per_cube=32 * _MB,
|
||||
)
|
||||
|
||||
|
||||
def _make_allocators(num_pe: int = 8) -> dict[int, PEMemAllocator]:
|
||||
return {
|
||||
i: PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=i, cfg=_CFG)
|
||||
for i in range(num_pe)
|
||||
}
|
||||
|
||||
|
||||
def _make_mmus(num_pe: int = 8, page_size: int = 4096) -> dict[int, PeMMU]:
|
||||
return {i: PeMMU(page_size=page_size) for i in range(num_pe)}
|
||||
|
||||
|
||||
def _make_va_allocator() -> VirtualAllocator:
|
||||
return VirtualAllocator(va_base=0x1_0000_0000, va_size=64 * _GB, page_size=2 * _MB)
|
||||
|
||||
|
||||
# ── T10. TensorHandle has va_base ────────────────────────────────────
|
||||
|
||||
|
||||
def test_tensor_handle_has_va_base():
|
||||
"""TensorHandle must have a 'va_base' field."""
|
||||
th = TensorHandle(
|
||||
name="A", shape=(1024, 512), dtype="fp16", itemsize=2,
|
||||
shards=(), va_base=0x1_0000_0000,
|
||||
)
|
||||
assert th.va_base == 0x1_0000_0000
|
||||
|
||||
|
||||
def test_tensor_handle_va_base_immutable():
|
||||
"""TensorHandle.va_base is immutable (frozen dataclass)."""
|
||||
th = TensorHandle(
|
||||
name="A", shape=(1024, 512), dtype="fp16", itemsize=2,
|
||||
shards=(), va_base=0x1_0000_0000,
|
||||
)
|
||||
with pytest.raises(AttributeError):
|
||||
th.va_base = 0x2_0000_0000 # type: ignore[misc]
|
||||
|
||||
|
||||
def test_tensor_shard_no_va_field():
|
||||
"""TensorShard should NOT have a va field — va is derived from
|
||||
TensorHandle.va_base + shard.offset_bytes."""
|
||||
ts = TensorShard(sip=0, cube=0, pe=0, pa=0x1000, nbytes=4096, offset_bytes=0)
|
||||
assert not hasattr(ts, "va"), "TensorShard should not have a 'va' field"
|
||||
|
||||
|
||||
# ── T11. deploy_tensor allocates VA + creates mappings ───────────────
|
||||
|
||||
|
||||
def test_deploy_tensor_assigns_va_base():
|
||||
"""deploy_tensor with VA allocator assigns va_base to TensorHandle."""
|
||||
allocs = _make_allocators()
|
||||
va_alloc = _make_va_allocator()
|
||||
mmus = _make_mmus()
|
||||
placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8)
|
||||
|
||||
th = deploy_tensor(
|
||||
name="W",
|
||||
shape=(1024, 512),
|
||||
dtype="fp16",
|
||||
placement=placement,
|
||||
allocators=allocs,
|
||||
va_allocator=va_alloc,
|
||||
mmus=mmus,
|
||||
)
|
||||
|
||||
assert th.va_base is not None
|
||||
assert th.va_base > 0
|
||||
|
||||
|
||||
def test_deploy_tensor_va_covers_all_shards():
|
||||
"""VA allocation covers the entire tensor; each shard is at va_base + offset."""
|
||||
allocs = _make_allocators()
|
||||
va_alloc = _make_va_allocator()
|
||||
mmus = _make_mmus()
|
||||
placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8)
|
||||
|
||||
th = deploy_tensor(
|
||||
name="W",
|
||||
shape=(1024, 512),
|
||||
dtype="fp16",
|
||||
placement=placement,
|
||||
allocators=allocs,
|
||||
va_allocator=va_alloc,
|
||||
mmus=mmus,
|
||||
)
|
||||
|
||||
# Each shard's VA is derivable: va_base + offset_bytes
|
||||
for s in th.shards:
|
||||
shard_va = th.va_base + s.offset_bytes
|
||||
assert shard_va > 0
|
||||
|
||||
|
||||
def test_deploy_tensor_registers_mmu_mappings():
|
||||
"""deploy_tensor registers VA→PA mappings in all PE MMUs."""
|
||||
allocs = _make_allocators()
|
||||
va_alloc = _make_va_allocator()
|
||||
mmus = _make_mmus()
|
||||
placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8)
|
||||
|
||||
th = deploy_tensor(
|
||||
name="W",
|
||||
shape=(1024, 512),
|
||||
dtype="fp16",
|
||||
placement=placement,
|
||||
allocators=allocs,
|
||||
va_allocator=va_alloc,
|
||||
mmus=mmus,
|
||||
)
|
||||
|
||||
# Every MMU should have entries (broadcast)
|
||||
for mmu in mmus.values():
|
||||
assert mmu.num_entries > 0
|
||||
|
||||
# Each shard's derived VA should translate to its PA in every MMU
|
||||
for mmu in mmus.values():
|
||||
for s in th.shards:
|
||||
shard_va = th.va_base + s.offset_bytes
|
||||
assert mmu.translate(shard_va) == s.pa
|
||||
|
||||
|
||||
# ── T12. Tensor.va property ──────────────────────────────────────────
|
||||
|
||||
|
||||
def test_tensor_va_property():
|
||||
"""Tensor.va returns the VA base of the entire tensor (from TensorHandle.va_base)."""
|
||||
from kernbench.runtime_api.tensor import Tensor
|
||||
|
||||
allocs = _make_allocators(1)
|
||||
va_alloc = _make_va_allocator()
|
||||
mmus = _make_mmus(1)
|
||||
placement = [ShardSpec(pe_index=0, offset_bytes=0, nbytes=4096)]
|
||||
|
||||
t = Tensor(shape=(2048,), dtype="f16", name="test")
|
||||
t._handle = deploy_tensor(
|
||||
name="test",
|
||||
shape=(2048,),
|
||||
dtype="f16",
|
||||
placement=placement,
|
||||
allocators=allocs,
|
||||
va_allocator=va_alloc,
|
||||
mmus=mmus,
|
||||
)
|
||||
assert t.va > 0
|
||||
assert t.va == t._handle.va_base
|
||||
|
||||
|
||||
# ── T13. tl.load/tl.store use VA ─────────────────────────────────────
|
||||
|
||||
|
||||
def test_tl_load_uses_va_in_dma_cmd():
|
||||
"""tl.load(va_ptr) generates DmaReadCmd with src_va (not src_pa)."""
|
||||
tl = TLContext(dispatch_cycles=0)
|
||||
va_ptr = 0x1_0000_0000
|
||||
h = tl.load(va_ptr, shape=(32, 64), dtype="f16")
|
||||
|
||||
dma_cmds = [c for c in tl.commands if isinstance(c, DmaReadCmd)]
|
||||
assert len(dma_cmds) == 1
|
||||
# The DMA command should carry the VA
|
||||
assert dma_cmds[0].src_addr == va_ptr
|
||||
|
||||
|
||||
def test_tl_store_uses_va_in_dma_cmd():
|
||||
"""tl.store(va_ptr, handle) generates DmaWriteCmd with dst_va."""
|
||||
tl = TLContext(dispatch_cycles=0)
|
||||
h = tl.load(0x1_0000_0000, shape=(16, 16), dtype="f32")
|
||||
va_out = 0x2_0000_0000
|
||||
tl.store(va_out, h)
|
||||
|
||||
dma_cmds = [c for c in tl.commands if isinstance(c, DmaWriteCmd)]
|
||||
assert len(dma_cmds) == 1
|
||||
assert dma_cmds[0].dst_addr == va_out
|
||||
|
||||
|
||||
# ── T14. Kernel VA offset calculation ─────────────────────────────────
|
||||
|
||||
|
||||
def test_kernel_va_offset_in_dma():
|
||||
"""Kernel using base_va + pid * stride generates correct VA in DmaReadCmd."""
|
||||
def tiled_kernel(a_ptr, tl, BLOCK_SIZE=1024, DTYPE="f16"):
|
||||
pid = tl.program_id(0)
|
||||
elem_bytes = 2 # f16
|
||||
offset = pid * BLOCK_SIZE * elem_bytes
|
||||
a = tl.load(a_ptr + offset, shape=(BLOCK_SIZE,), dtype=DTYPE)
|
||||
|
||||
va_base = 0x1_0000_0000
|
||||
tl = TLContext(pe_id=3, num_programs=8, dispatch_cycles=0)
|
||||
run_kernel(tiled_kernel, tl, a_ptr=va_base)
|
||||
|
||||
dma_cmds = [c for c in tl.commands if isinstance(c, DmaReadCmd)]
|
||||
assert len(dma_cmds) == 1
|
||||
expected_va = va_base + 3 * 1024 * 2 # pid=3, BLOCK_SIZE=1024, 2 bytes
|
||||
assert dma_cmds[0].src_addr == expected_va
|
||||
Reference in New Issue
Block a user