Add virtual memory support: PE_MMU, VA allocator, fabric MmuMapMsg

Implement VA/MMU layer (ADR-0011 Phase 1) enabling Triton kernels to use
contiguous virtual addresses on sharded tensors.

Key changes:
- PE_MMU component: hybrid inbox (MmuMapMsg) + sync translate() for PE_DMA
- VirtualAllocator + PEMemAllocator: free-list with coalescing
- MmuMapMsg/MmuUnmapMsg fabric path with SIP-level routing
- DPPolicy-based mapping: replicate=local, sharded=broadcast
- Tensor lifecycle: del + weakref cleanup, context manager
- Rename: TensorHandle.pa→addr, DmaReadCmd.src_pa→src_addr, ctx→torch

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-26 00:01:47 -07:00
parent 62fb01ae18
commit 08812eda58
34 changed files with 2131 additions and 139 deletions
+6 -6
View File
@@ -308,9 +308,9 @@ def test_pe_gemm_handles_pe_internal_txn():
gemm.in_ports["src"] = simpy.Store(env)
gemm.start(env)
a = TensorHandle(id="t1", pa=0, shape=(4, 8), dtype="f16", nbytes=64)
b = TensorHandle(id="t2", pa=0, shape=(8, 4), dtype="f16", nbytes=64)
out = TensorHandle(id="t3", pa=0, shape=(4, 4), dtype="f16", nbytes=32)
a = TensorHandle(id="t1", addr=0, shape=(4, 8), dtype="f16", nbytes=64)
b = TensorHandle(id="t2", addr=0, shape=(8, 4), dtype="f16", nbytes=64)
out = TensorHandle(id="t3", addr=0, shape=(4, 4), dtype="f16", nbytes=32)
cmd = GemmCmd(a=a, b=b, out=out, m=4, k=8, n=4)
done = env.event()
pe_txn = PeInternalTxn(command=cmd, done=done, pe_prefix=pe_prefix)
@@ -349,8 +349,8 @@ def test_pe_math_handles_pe_internal_txn():
math_comp.in_ports["src"] = simpy.Store(env)
math_comp.start(env)
x = TensorHandle(id="t1", pa=0, shape=(4, 4), dtype="f16", nbytes=32)
out = TensorHandle(id="t2", pa=0, shape=(4, 4), dtype="f16", nbytes=32)
x = TensorHandle(id="t1", addr=0, shape=(4, 4), dtype="f16", nbytes=32)
out = TensorHandle(id="t2", addr=0, shape=(4, 4), dtype="f16", nbytes=32)
cmd = MathCmd(op="exp", inputs=(x,), out=out)
done = env.event()
pe_txn = PeInternalTxn(command=cmd, done=done, pe_prefix=pe_prefix)
@@ -777,7 +777,7 @@ def test_tl_ref_no_dma():
tl = TLContext(pe_id=0, dispatch_cycles=0)
handle = tl.ref(0x1000, shape=(4, 4), dtype="f16")
assert handle.pa == 0x1000
assert handle.addr == 0x1000
assert handle.shape == (4, 4)
assert len(tl.commands) == 0, f"tl.ref should emit 0 commands, got {len(tl.commands)}"