Add virtual memory support: PE_MMU, VA allocator, fabric MmuMapMsg

Implement VA/MMU layer (ADR-0011 Phase 1) enabling Triton kernels to use
contiguous virtual addresses on sharded tensors.

Key changes:
- PE_MMU component: hybrid inbox (MmuMapMsg) + sync translate() for PE_DMA
- VirtualAllocator + PEMemAllocator: free-list with coalescing
- MmuMapMsg/MmuUnmapMsg fabric path with SIP-level routing
- DPPolicy-based mapping: replicate=local, sharded=broadcast
- Tensor lifecycle: del + weakref cleanup, context manager
- Rename: TensorHandle.pa→addr, DmaReadCmd.src_pa→src_addr, ctx→torch

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-26 00:01:47 -07:00
parent 62fb01ae18
commit 08812eda58
34 changed files with 2131 additions and 139 deletions
+16 -16
View File
@@ -86,11 +86,11 @@ class TLContext:
self._commands.append(PeCpuOverheadCmd(cycles=self._dispatch_cycles))
def _make_handle(
self, pa: int, shape: tuple[int, ...], dtype: str,
self, addr: int, shape: tuple[int, ...], dtype: str,
) -> TensorHandle:
return TensorHandle(
id=self._next_handle_id(),
pa=pa, shape=shape, dtype=dtype,
addr=addr, shape=shape, dtype=dtype,
nbytes=self._nbytes(shape, dtype),
)
@@ -104,7 +104,7 @@ class TLContext:
Used when the scheduler will stream data per-tile (e.g., tensor b
in a composite GEMM). No command is generated.
"""
return self._make_handle(pa=ptr, shape=shape, dtype=dtype)
return self._make_handle(addr=ptr, shape=shape, dtype=dtype)
# ── Data Movement (blocking, DMA engine) ──────────────────────
@@ -113,9 +113,9 @@ class TLContext:
) -> TensorHandle:
"""Load tensor from HBM to TCM. Returns TensorHandle."""
self._emit_dispatch_overhead()
handle = self._make_handle(pa=ptr, shape=shape, dtype=dtype)
handle = self._make_handle(addr=ptr, shape=shape, dtype=dtype)
self._commands.append(DmaReadCmd(
handle=handle, src_pa=ptr, nbytes=handle.nbytes,
handle=handle, src_addr=ptr, nbytes=handle.nbytes,
))
return handle
@@ -123,7 +123,7 @@ class TLContext:
"""Store tensor from TCM to HBM."""
self._emit_dispatch_overhead()
self._commands.append(DmaWriteCmd(
handle=handle, dst_pa=ptr, nbytes=handle.nbytes,
handle=handle, dst_addr=ptr, nbytes=handle.nbytes,
))
# ── GEMM Engine (blocking) ────────────────────────────────────
@@ -141,7 +141,7 @@ class TLContext:
raise ValueError(f"dot shape mismatch: a.K={k} != b.K={k2}")
out_shape = (*a.shape[:-2], m, n)
out_dtype = a.dtype
out = self._make_handle(pa=0, shape=out_shape, dtype=out_dtype)
out = self._make_handle(addr=0, shape=out_shape, dtype=out_dtype)
self._emit_dispatch_overhead()
self._commands.append(GemmCmd(a=a, b=b, out=out, m=m, k=k, n=n))
return out
@@ -149,7 +149,7 @@ class TLContext:
# ── MATH Engine: unary (blocking) ─────────────────────────────
def _unary_math(self, op: str, x: TensorHandle) -> TensorHandle:
out = self._make_handle(pa=0, shape=x.shape, dtype=x.dtype)
out = self._make_handle(addr=0, shape=x.shape, dtype=x.dtype)
self._emit_dispatch_overhead()
self._commands.append(MathCmd(op=op, inputs=(x,), out=out))
return out
@@ -182,7 +182,7 @@ class TLContext:
) -> TensorHandle:
out_shape = list(x.shape)
out_shape[axis] = 1
out = self._make_handle(pa=0, shape=tuple(out_shape), dtype=x.dtype)
out = self._make_handle(addr=0, shape=tuple(out_shape), dtype=x.dtype)
self._emit_dispatch_overhead()
self._commands.append(MathCmd(op=op, inputs=(x,), out=out, axis=axis))
return out
@@ -201,7 +201,7 @@ class TLContext:
def _binary_math(
self, op: str, a: TensorHandle, b: TensorHandle,
) -> TensorHandle:
out = self._make_handle(pa=0, shape=a.shape, dtype=a.dtype)
out = self._make_handle(addr=0, shape=a.shape, dtype=a.dtype)
self._emit_dispatch_overhead()
self._commands.append(MathCmd(op=op, inputs=(a, b), out=out))
return out
@@ -209,7 +209,7 @@ class TLContext:
def where(
self, cond: TensorHandle, a: TensorHandle, b: TensorHandle,
) -> TensorHandle:
out = self._make_handle(pa=0, shape=a.shape, dtype=a.dtype)
out = self._make_handle(addr=0, shape=a.shape, dtype=a.dtype)
self._emit_dispatch_overhead()
self._commands.append(MathCmd(op="where", inputs=(cond, a, b), out=out))
return out
@@ -227,17 +227,17 @@ class TLContext:
def arange(self, start: int, end: int, dtype: str = "i32") -> TensorHandle:
"""Create index range tensor in TCM."""
n = end - start
return self._make_handle(pa=0, shape=(n,), dtype=dtype)
return self._make_handle(addr=0, shape=(n,), dtype=dtype)
def zeros(self, shape: tuple[int, ...], dtype: str = "f16") -> TensorHandle:
"""Create zero-filled tensor in TCM."""
return self._make_handle(pa=0, shape=shape, dtype=dtype)
return self._make_handle(addr=0, shape=shape, dtype=dtype)
def full(
self, shape: tuple[int, ...], value: float | int, dtype: str = "f16",
) -> TensorHandle:
"""Create constant-filled tensor in TCM."""
return self._make_handle(pa=0, shape=shape, dtype=dtype)
return self._make_handle(addr=0, shape=shape, dtype=dtype)
# ── Metadata (no compute, no DMA) ─────────────────────────────
@@ -247,7 +247,7 @@ class TLContext:
raise ValueError("trans requires at least 2D tensor")
new_shape = (*x.shape[:-2], x.shape[-1], x.shape[-2])
return TensorHandle(
id=x.id, pa=x.pa, shape=new_shape,
id=x.id, addr=x.addr, shape=new_shape,
dtype=x.dtype, nbytes=x.nbytes, data=x.data,
)
@@ -278,7 +278,7 @@ class TLContext:
self._emit_dispatch_overhead()
self._commands.append(CompositeCmd(
completion=completion, op=op,
a=a, b=b, out_pa=out_ptr, out_nbytes=out_nbytes,
a=a, b=b, out_addr=out_ptr, out_nbytes=out_nbytes,
math_op=math_op,
))
return completion