Add virtual memory support: PE_MMU, VA allocator, fabric MmuMapMsg

Implement VA/MMU layer (ADR-0011 Phase 1) enabling Triton kernels to use
contiguous virtual addresses on sharded tensors.

Key changes:
- PE_MMU component: hybrid inbox (MmuMapMsg) + sync translate() for PE_DMA
- VirtualAllocator + PEMemAllocator: free-list with coalescing
- MmuMapMsg/MmuUnmapMsg fabric path with SIP-level routing
- DPPolicy-based mapping: replicate=local, sharded=broadcast
- Tensor lifecycle: del + weakref cleanup, context manager
- Rename: TensorHandle.pa→addr, DmaReadCmd.src_pa→src_addr, ctx→torch

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-26 00:01:47 -07:00
parent 62fb01ae18
commit 08812eda58
34 changed files with 2131 additions and 139 deletions
+6 -6
View File
@@ -28,7 +28,7 @@ class TensorHandle:
"""
id: str
pa: int # physical address in HBM/TCM
addr: int # address (VA when MMU enabled, PA otherwise)
shape: tuple[int, ...]
dtype: str
nbytes: int # total byte size
@@ -50,19 +50,19 @@ class CompletionHandle:
@dataclass(frozen=True)
class DmaReadCmd:
"""DMA READ: HBM → PE_TCM."""
"""DMA READ: HBM → PE_TCM. src_addr is VA (translated to PA by PE_DMA)."""
handle: TensorHandle
src_pa: int
src_addr: int
nbytes: int
@dataclass(frozen=True)
class DmaWriteCmd:
"""DMA WRITE: PE_TCM → HBM."""
"""DMA WRITE: PE_TCM → HBM. dst_addr is VA (translated to PA by PE_DMA)."""
handle: TensorHandle
dst_pa: int
dst_addr: int
nbytes: int
@@ -108,7 +108,7 @@ class CompositeCmd:
op: Literal["gemm", "math"]
a: TensorHandle
b: TensorHandle | None
out_pa: int
out_addr: int
out_nbytes: int
math_op: str | None = None # for op="math": which math operation
@@ -16,6 +16,7 @@ from kernbench.components.impls.pe_dma import PeDmaComponent
from kernbench.components.impls.pe_gemm import PeGemmComponent
from kernbench.components.impls.pe_math import PeMathComponent
from kernbench.components.impls.pe_scheduler import PeSchedulerComponent
from kernbench.components.impls.pe_mmu import PeMmuComponent
from kernbench.components.impls.pe_tcm import PeTcmComponent
from kernbench.components.impls.sram import SramComponent
from kernbench.components.impls.xbar import PositionAwareXbarComponent
@@ -36,6 +37,7 @@ ComponentRegistry.register("pe_scheduler_v1", PeSchedulerComponent)
ComponentRegistry.register("pe_dma_v1", PeDmaComponent)
ComponentRegistry.register("pe_gemm_v1", PeGemmComponent)
ComponentRegistry.register("pe_math_v1", PeMathComponent)
ComponentRegistry.register("pe_mmu_v1", PeMmuComponent)
ComponentRegistry.register("pe_tcm_v1", PeTcmComponent)
__all__ = [
@@ -47,6 +49,7 @@ __all__ = [
"PeDmaComponent",
"PeGemmComponent",
"PeMathComponent",
"PeMmuComponent",
"PeSchedulerComponent",
"PeTcmComponent",
"TransitComponent",
+13 -1
View File
@@ -93,7 +93,9 @@ class IoCpuComponent(ComponentBase):
def _resolve_cube_targets(self, request: Any) -> list[tuple[int, int]]:
"""Return list of (sip, cube) pairs to fan out to."""
from kernbench.runtime_api.kernel import KernelLaunchMsg, MemoryReadMsg, MemoryWriteMsg
from kernbench.runtime_api.kernel import (
KernelLaunchMsg, MemoryReadMsg, MemoryWriteMsg, MmuMapMsg, MmuUnmapMsg,
)
target_cubes = getattr(request, "target_cubes", "all")
@@ -130,6 +132,16 @@ class IoCpuComponent(ComponentBase):
targets.append(key)
return targets
if isinstance(request, (MmuMapMsg, MmuUnmapMsg)):
my_sip = self._my_sip()
if target_cubes == "all":
n_cubes = 16
if self.ctx and self.ctx.spec:
sips = self.ctx.spec.get("system", {}).get("sips", {})
n_cubes = sips.get("cubes_per_sip", 16)
return [(my_sip, c) for c in range(n_cubes)]
return [(my_sip, c) for c in target_cubes]
return []
def _cube_from_pa(self, pa_val: int, fallback: int) -> int:
+60 -1
View File
@@ -52,7 +52,7 @@ class MCpuComponent(ComponentBase):
def _worker(self, env: simpy.Environment) -> Generator:
"""Dispatch forward txns, collect response txns."""
from kernbench.runtime_api.kernel import KernelLaunchMsg
from kernbench.runtime_api.kernel import KernelLaunchMsg, MmuMapMsg, MmuUnmapMsg
while True:
txn: Any = yield self._inbox.get()
@@ -66,6 +66,8 @@ class MCpuComponent(ComponentBase):
elif self.ctx is not None and txn.request is not None:
if isinstance(txn.request, KernelLaunchMsg):
env.process(self._kernel_launch_fanout(env, txn))
elif isinstance(txn.request, (MmuMapMsg, MmuUnmapMsg)):
env.process(self._mmu_msg_fanout(env, txn))
else:
env.process(self._dma_fanout(env, txn))
else:
@@ -261,6 +263,63 @@ class MCpuComponent(ComponentBase):
n_slices = mm.get("hbm_slices_per_cube", 8)
return [f"{cube_prefix}.hbm_ctrl.slice{i}" for i in range(n_slices)]
def _mmu_msg_fanout(self, env: simpy.Environment, txn: Any) -> Generator:
"""Fan out MmuMapMsg/MmuUnmapMsg to target PE_MMU(s) via NOC.
Routes through find_node_path (M_CPU → NOC → PE_MMU command edges).
PE_MMU is a terminal node — completes the transaction directly.
"""
request = txn.request
target_pe = getattr(request, "target_pe", "all")
cube_prefix = self.node.id.rsplit(".", 1)[0] # e.g. "sip0.cube0"
pe_ids = self._resolve_pe_ids(target_pe)
if not pe_ids:
txn.done.succeed()
return
# Fan out to each PE_MMU
sub_dones: list[simpy.Event] = []
for pe_id in pe_ids:
pe_mmu_id = f"{cube_prefix}.pe{pe_id}.pe_mmu"
try:
path = self.ctx.router.find_node_path(self.node.id, pe_mmu_id)
except Exception:
continue
if len(path) < 2:
continue
sub_done = env.event()
sub_txn = Transaction(
request=request, path=path, step=0,
nbytes=0, done=sub_done,
)
yield self.out_ports[path[1]].put(sub_txn.advance())
sub_dones.append(sub_done)
# Wait for all PE_MMUs to complete
for sd in sub_dones:
yield sd
# Send aggregate response on reverse path
reverse_path = list(reversed(txn.path))
if len(reverse_path) >= 2:
from kernbench.runtime_api.kernel import ResponseMsg
parts = self.node.id.split(".")
cube_id = int(parts[1].replace("cube", ""))
resp_msg = ResponseMsg(
correlation_id=request.correlation_id,
request_id=request.request_id,
src_cube=cube_id, src_pe=-1, success=True,
)
resp_txn = Transaction(
request=resp_msg, path=reverse_path, step=0,
nbytes=0, done=env.event(), is_response=True,
)
yield self.out_ports[reverse_path[1]].put(resp_txn.advance())
else:
txn.done.succeed()
def _resolve_pe_ids(self, target_pe: int | str) -> list[int]:
"""Return list of PE IDs to fan out to (used by kernel launch fan-out)."""
if isinstance(target_pe, int):
+6 -3
View File
@@ -84,12 +84,15 @@ class PeCpuComponent(ComponentBase):
tl = TLContext(pe_id=self._pe_idx, dispatch_cycles=0)
# Unpack KernelLaunchMsg.args into positional args for kernel function
# TensorArg → PA (pointer), ScalarArg → value
# TensorArg → VA base (or PA fallback), ScalarArg → value
kernel_args: list = []
for arg in request.args:
if arg.arg_kind == "tensor":
shard = self._find_shard(arg.shards)
kernel_args.append(shard.pa)
if arg.va_base:
kernel_args.append(arg.va_base)
else:
shard = self._find_shard(arg.shards)
kernel_args.append(shard.pa)
elif arg.arg_kind == "scalar":
kernel_args.append(arg.value)
+17 -4
View File
@@ -31,6 +31,7 @@ class PeDmaComponent(PeEngineBase):
super().__init__(node, ctx)
self._dma_read: simpy.Resource | None = None
self._dma_write: simpy.Resource | None = None
self._mmu = None # PeMMU instance, set by engine wiring
def init_resources(self, env: simpy.Environment) -> None:
self._dma_read = simpy.Resource(env, capacity=1)
@@ -48,20 +49,32 @@ class PeDmaComponent(PeEngineBase):
cmd = pe_txn.command
assert self._dma_read is not None and self._dma_write is not None
# Determine direction and target PA
# Determine direction and target address (VA → PA via MMU)
if isinstance(cmd, DmaReadCmd):
dma_res = self._dma_read
target_pa = cmd.src_pa
raw_addr = cmd.src_addr
is_write = False
elif isinstance(cmd, DmaWriteCmd):
dma_res = self._dma_write
target_pa = cmd.dst_pa
raw_addr = cmd.dst_addr
is_write = True
else:
pe_txn.done.succeed()
return
# Resolve PA → HBM node and compute path
# Translate VA → PA via MMU (if available), then resolve HBM node
# If MMU has no mapping for this address (PageFault), treat as PA directly
# (backward-compatible with PA-only mode)
if self._mmu is not None:
from kernbench.policy.address.pe_mmu import PageFault
try:
target_pa = self._mmu.translate(raw_addr)
if self._mmu.overhead_ns > 0:
yield env.timeout(self._mmu.overhead_ns)
except PageFault:
target_pa = raw_addr
else:
target_pa = raw_addr # fallback: treat as PA directly
pa = PhysAddr.decode(target_pa)
dst_node = self.ctx.resolver.resolve(pa)
path = self.ctx.router.find_path(self._pe_prefix, dst_node)
+66
View File
@@ -0,0 +1,66 @@
"""PE_MMU component: address translation unit.
Component role: receives MmuMapMsg/MmuUnmapMsg via inbox (independent of PE_CPU).
Utility role: PE_DMA/PE_GEMM call mmu.translate() directly (no SimPy overhead).
"""
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
import simpy
from kernbench.components.base import ComponentBase, ComponentRegistry
from kernbench.policy.address.pe_mmu import PeMMU
if TYPE_CHECKING:
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class PeMmuComponent(ComponentBase):
"""PE_MMU: per-PE virtual-to-physical address translation.
Receives MmuMapMsg/MmuUnmapMsg via inbox and updates the internal
page table. PE_DMA and PE_GEMM access the underlying PeMMU object
via the ``mmu`` property for synchronous VA→PA translation.
"""
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
super().__init__(node, ctx)
page_size = int(node.attrs.get("page_size", 2 * 1024 * 1024))
overhead_ns = float(node.attrs.get("tlb_overhead_ns", 0.0))
self._mmu = PeMMU(page_size=page_size, overhead_ns=overhead_ns)
@property
def mmu(self) -> PeMMU:
"""The underlying PeMMU utility object for direct translate() calls."""
return self._mmu
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
yield env.timeout(0)
def _worker(self, env: simpy.Environment) -> Generator:
"""Process MmuMapMsg/MmuUnmapMsg from inbox."""
from kernbench.runtime_api.kernel import MmuMapMsg, MmuUnmapMsg
while True:
txn: Any = yield self._inbox.get()
if hasattr(txn, "request"):
request = txn.request
if isinstance(request, MmuMapMsg):
for entry in request.entries:
self._mmu.map(
va=entry["va"], pa=entry["pa"], size=entry["size"],
)
txn.done.succeed()
elif isinstance(request, MmuUnmapMsg):
for entry in request.entries:
self._mmu.unmap(va=entry["va"], size=entry["size"])
txn.done.succeed()
else:
# Forward non-MMU transactions normally
yield from self._forward_txn(env, txn)
else:
yield from self._forward_txn(env, txn)
@@ -155,12 +155,12 @@ class PeSchedulerComponent(ComponentBase):
# --- Stage 1: DMA_READ b_tile from HBM ---
read_done = env.event()
b_tile_pa = b.pa + (k_start * N + n_start) * dtype_bytes
b_tile_addr = b.addr + (k_start * N + n_start) * dtype_bytes
b_tile_handle = TensorHandle(
id=f"b_tile_{tile_idx}", pa=b_tile_pa,
id=f"b_tile_{tile_idx}", addr=b_tile_addr,
shape=(tile_k, tile_n), dtype=dtype, nbytes=tile_nbytes,
)
read_cmd = DmaReadCmd(handle=b_tile_handle, src_pa=b_tile_pa, nbytes=tile_nbytes)
read_cmd = DmaReadCmd(handle=b_tile_handle, src_addr=b_tile_addr, nbytes=tile_nbytes)
read_txn = PeTxn(command=read_cmd, done=read_done, pe_prefix=pp)
t0 = env.now
yield self.out_ports[f"{pp}.pe_dma"].put(read_txn)
@@ -176,7 +176,7 @@ class PeSchedulerComponent(ComponentBase):
# --- Stage 2: COMPUTE (GEMM) ---
compute_done = env.event()
out_handle = TensorHandle(
id=f"out_tile_{tile_idx}", pa=0,
id=f"out_tile_{tile_idx}", addr=0,
shape=(M, tile_n), dtype=dtype,
nbytes=M * tile_n * dtype_bytes,
)
@@ -197,9 +197,9 @@ class PeSchedulerComponent(ComponentBase):
# --- Stage 3: DMA_WRITE out_tile to HBM ---
write_done = env.event()
out_tile_pa = cmd.out_pa + n_start * dtype_bytes
out_tile_pa = cmd.out_addr + n_start * dtype_bytes
write_nbytes = M * tile_n * dtype_bytes
write_cmd = DmaWriteCmd(handle=out_handle, dst_pa=out_tile_pa, nbytes=write_nbytes)
write_cmd = DmaWriteCmd(handle=out_handle, dst_addr=out_tile_pa, nbytes=write_nbytes)
write_txn = PeTxn(command=write_cmd, done=write_done, pe_prefix=pp)
t0 = env.now
yield self.out_ports[f"{pp}.pe_dma"].put(write_txn)
@@ -237,7 +237,7 @@ class PeSchedulerComponent(ComponentBase):
# Step 2: DMA_WRITE result to HBM
write_done = env.event()
write_cmd = DmaWriteCmd(handle=cmd.a, dst_pa=cmd.out_pa, nbytes=cmd.out_nbytes)
write_cmd = DmaWriteCmd(handle=cmd.a, dst_addr=cmd.out_addr, nbytes=cmd.out_nbytes)
write_txn = PeTxn(command=write_cmd, done=write_done, pe_prefix=pp)
yield self.out_ports[f"{pp}.pe_dma"].put(write_txn)
yield write_done
+83 -16
View File
@@ -1,5 +1,6 @@
from __future__ import annotations
import bisect
from dataclasses import dataclass
from kernbench.policy.address.phyaddr import PhysAddr
@@ -29,6 +30,63 @@ class AddressConfig:
return self.tcm_bytes_per_pe - self.tcm_scheduler_reserved_bytes
class _FreeList:
"""Offset-based free-list allocator with coalescing."""
def __init__(self, capacity: int) -> None:
self._capacity = capacity
self._used = 0
self._free: list[tuple[int, int]] = [(0, capacity)] # (offset, size)
@property
def used(self) -> int:
return self._used
@property
def total(self) -> int:
return self._capacity
def alloc(self, nbytes: int) -> int:
"""Allocate nbytes, return offset. Raises AllocationError if full."""
for i, (start, size) in enumerate(self._free):
if size >= nbytes:
if size == nbytes:
self._free.pop(i)
else:
self._free[i] = (start + nbytes, size - nbytes)
self._used += nbytes
return start
raise AllocationError(
f"overflow: need {nbytes}, "
f"largest free block {max((s for _, s in self._free), default=0)}"
)
def free(self, offset: int, nbytes: int) -> None:
"""Return a range to the free-list with coalescing."""
self._used -= nbytes
new_start = offset
new_end = offset + nbytes
idx = bisect.bisect_left(self._free, (offset,))
# Coalesce with previous block
if idx > 0:
prev_start, prev_size = self._free[idx - 1]
if prev_start + prev_size == new_start:
new_start = prev_start
idx -= 1
self._free.pop(idx)
# Coalesce with next block
if idx < len(self._free):
next_start, next_size = self._free[idx]
if new_end == next_start:
new_end = next_start + next_size
self._free.pop(idx)
self._free.insert(idx, (new_start, new_end - new_start))
class PEMemAllocator:
def __init__(
self, rack_id: int, sip_id: int, cube_id: int, pe_id: int, cfg: AddressConfig,
@@ -38,39 +96,48 @@ class PEMemAllocator:
self._cube_id = cube_id
self._pe_id = pe_id
self._cfg = cfg
self._hbm_cursor = 0
self._tcm_cursor = 0
self._hbm = _FreeList(cfg.hbm_slice_bytes)
self._tcm = _FreeList(cfg.tcm_allocatable_bytes)
def alloc_hbm(self, nbytes: int) -> PhysAddr:
if self._hbm_cursor + nbytes > self._cfg.hbm_slice_bytes:
try:
offset = self._hbm.alloc(nbytes)
except AllocationError:
raise AllocationError(
f"HBM overflow: need {nbytes}, "
f"available {self._cfg.hbm_slice_bytes - self._hbm_cursor}"
f"available {self._cfg.hbm_slice_bytes - self._hbm.used}"
)
pa = PhysAddr.pe_hbm_addr(
return PhysAddr.pe_hbm_addr(
rack_id=self._rack_id, sip_id=self._sip_id, cube_id=self._cube_id,
pe_id=self._pe_id, pe_local_hbm_offset=self._hbm_cursor,
pe_id=self._pe_id, pe_local_hbm_offset=offset,
slice_size_bytes=self._cfg.hbm_slice_bytes,
)
self._hbm_cursor += nbytes
return pa
def free_hbm(self, pa: PhysAddr, nbytes: int) -> None:
# Extract PE-local offset from the PA's hbm_offset
pe_slice_start = self._pe_id * self._cfg.hbm_slice_bytes
offset = pa.hbm_offset - pe_slice_start
self._hbm.free(offset, nbytes)
def alloc_tcm(self, nbytes: int) -> PhysAddr:
if self._tcm_cursor + nbytes > self._cfg.tcm_allocatable_bytes:
try:
offset = self._tcm.alloc(nbytes)
except AllocationError:
raise AllocationError(
f"TCM overflow: need {nbytes}, "
f"available {self._cfg.tcm_allocatable_bytes - self._tcm_cursor}"
f"available {self._cfg.tcm_allocatable_bytes - self._tcm.used}"
)
pa = PhysAddr.pe_tcm_addr(
return PhysAddr.pe_tcm_addr(
rack_id=self._rack_id, sip_id=self._sip_id, cube_id=self._cube_id,
pe_id=self._pe_id, tcm_offset=self._tcm_cursor,
pe_id=self._pe_id, tcm_offset=offset,
)
self._tcm_cursor += nbytes
return pa
def free_tcm(self, pa: PhysAddr, nbytes: int) -> None:
self._tcm.free(pa.sub_offset, nbytes)
@property
def hbm_used(self) -> int:
return self._hbm_cursor
return self._hbm.used
@property
def hbm_total(self) -> int:
@@ -78,7 +145,7 @@ class PEMemAllocator:
@property
def tcm_used(self) -> int:
return self._tcm_cursor
return self._tcm.used
@property
def tcm_total(self) -> int:
+66
View File
@@ -0,0 +1,66 @@
"""PeMMU: per-PE virtual-to-physical address translation.
Page-aligned dict lookup for O(1) VA→PA translation.
Used as a utility class by PE_DMA, PE_GEMM (direct call),
and as a component inbox target for MmuMapMsg/MmuUnmapMsg.
"""
from __future__ import annotations
class PageFault(Exception):
"""Raised when VA has no mapping in the page table."""
def __init__(self, va: int | str) -> None:
if isinstance(va, str):
super().__init__(va)
else:
self.va = va
super().__init__(f"PageFault at VA 0x{va:x}")
class PeMMU:
"""Per-PE MMU with page-aligned VA→PA translation table.
Args:
page_size: Page size in bytes (default 2 MB).
overhead_ns: Per-access TLB lookup latency in nanoseconds.
"""
def __init__(
self,
page_size: int = 2 * 1024 * 1024,
overhead_ns: float = 0.0,
) -> None:
self._page_size = page_size
self._page_shift = (page_size - 1).bit_length()
self._page_mask = page_size - 1
self._table: dict[int, int] = {} # va_page_number → pa_page_base
self._overhead_ns = overhead_ns
@property
def overhead_ns(self) -> float:
return self._overhead_ns
@property
def num_entries(self) -> int:
return len(self._table)
def map(self, va: int, pa: int, size: int) -> None:
"""Register VA→PA mapping for a contiguous range."""
for off in range(0, size, self._page_size):
vpn = (va + off) >> self._page_shift
self._table[vpn] = pa + off
def unmap(self, va: int, size: int) -> None:
"""Remove VA mapping for a contiguous range."""
for off in range(0, size, self._page_size):
vpn = (va + off) >> self._page_shift
self._table.pop(vpn, None)
def translate(self, va: int) -> int:
"""Translate VA to PA. Raises PageFault if unmapped."""
vpn = va >> self._page_shift
pa_page_base = self._table.get(vpn)
if pa_page_base is None:
raise PageFault(va)
return pa_page_base + (va & self._page_mask)
@@ -0,0 +1,93 @@
"""VirtualAllocator: device-wide VA space management with free-list.
Allocations are page-aligned. Freed ranges are coalesced with adjacent
free blocks to allow larger re-allocations.
"""
from __future__ import annotations
import bisect
class VaAllocationError(Exception):
pass
class VirtualAllocator:
"""Manages a contiguous VA address space with page-aligned alloc/free.
Args:
va_base: Start of the VA range.
va_size: Total size of the VA range in bytes.
page_size: Page granularity in bytes.
"""
def __init__(
self,
va_base: int,
va_size: int,
page_size: int = 2 * 1024 * 1024,
) -> None:
self._va_base = va_base
self._va_size = va_size
self._page_size = page_size
self._used = 0
# Free list: sorted list of (start, size) tuples
self._free: list[tuple[int, int]] = [(va_base, va_size)]
def _align_up(self, nbytes: int) -> int:
"""Round up to page boundary."""
return ((nbytes + self._page_size - 1) // self._page_size) * self._page_size
def alloc(self, nbytes: int) -> int:
"""Allocate a contiguous VA range. Returns the start VA."""
aligned = self._align_up(nbytes)
for i, (start, size) in enumerate(self._free):
if size >= aligned:
# Take from the beginning of this free block
if size == aligned:
self._free.pop(i)
else:
self._free[i] = (start + aligned, size - aligned)
self._used += aligned
return start
raise VaAllocationError(
f"Out of VA space: need {aligned}, largest free block "
f"{max((s for _, s in self._free), default=0)}"
)
def free(self, va: int, nbytes: int) -> None:
"""Free a VA range and coalesce with adjacent free blocks."""
aligned = self._align_up(nbytes)
self._used -= aligned
# Insert into sorted free list and coalesce
new_start = va
new_end = va + aligned
# Find insertion point
idx = bisect.bisect_left(self._free, (va,))
# Try coalesce with previous block
if idx > 0:
prev_start, prev_size = self._free[idx - 1]
if prev_start + prev_size == new_start:
new_start = prev_start
idx -= 1
self._free.pop(idx)
# Try coalesce with next block
if idx < len(self._free):
next_start, next_size = self._free[idx]
if new_end == next_start:
new_end = next_start + next_size
self._free.pop(idx)
self._free.insert(idx, (new_start, new_end - new_start))
@property
def used(self) -> int:
return self._used
@property
def total(self) -> int:
return self._va_size
+188 -22
View File
@@ -19,8 +19,18 @@ class RuntimeContext:
_handles: list[RequestHandle] = field(default_factory=list, init=False)
_completed: set[RequestHandle] = field(default_factory=set, init=False)
_allocators: dict[int, Any] = field(default_factory=dict, init=False)
_va_allocator: Any = field(default=None, init=False)
_mmus: dict[int, Any] = field(default_factory=dict, init=False)
_tensor_counter: int = field(default=0, init=False)
_traces: list[dict] = field(default_factory=list, init=False)
_tensors: list[Any] = field(default_factory=list, init=False)
def __enter__(self):
return self
def __exit__(self, *exc):
self.cleanup()
return False
def submit(self, request: Any) -> RequestHandle:
submit_fn = getattr(self.engine, "submit", None)
@@ -58,6 +68,92 @@ class RuntimeContext:
def handles(self) -> list[RequestHandle]:
return list(self._handles)
# ── Tensor lifecycle ─────────────────────────────────────────────
def _free_tensor(self, tensor: Any) -> None:
"""Free a single tensor: unmap MMU, return VA and PA."""
handle = tensor._handle
if handle is None:
return
tensor._handle = None
if not handle.va_base:
return
from kernbench.runtime_api.kernel import MmuUnmapMsg
dp_policy = None
if tensor._dp_metadata is not None:
dp_policy = tensor._dp_metadata.dp_policy
is_cube_replicate = (
dp_policy is not None and dp_policy.cube == "replicate"
)
# Send MmuUnmapMsg through fabric
from collections import defaultdict
if is_cube_replicate:
cube_groups: dict[tuple[int, int], list] = defaultdict(list)
for shard in handle.shards:
cube_groups[(shard.sip, shard.cube)].append(shard)
for (sip, cube), group_shards in cube_groups.items():
entries = tuple(
{"va": handle.va_base + s.offset_bytes, "size": s.nbytes}
for s in group_shards
)
msg = MmuUnmapMsg(
correlation_id=self.correlation_id,
request_id=f"unmap_{tensor.name}_s{sip}c{cube}",
entries=entries,
target_sips=(sip,),
target_cubes=(cube,),
target_pe="all",
)
h = self.submit(msg)
self.wait(h)
else:
entries = tuple(
{"va": handle.va_base + s.offset_bytes, "size": s.nbytes}
for s in handle.shards
)
sip_set = sorted({s.sip for s in handle.shards})
cube_set = sorted({s.cube for s in handle.shards})
msg = MmuUnmapMsg(
correlation_id=self.correlation_id,
request_id=f"unmap_{tensor.name}",
entries=entries,
target_sips=tuple(sip_set),
target_cubes=tuple(cube_set),
target_pe="all",
)
h = self.submit(msg)
self.wait(h)
# Return VA space
if self._va_allocator is not None:
self._va_allocator.free(handle.va_base, handle.nbytes)
# Return PA space
if self._allocators:
for shard in handle.shards:
flat_idx = (
shard.sip * self._num_cubes * self._pes_per_cube
+ shard.cube * self._pes_per_cube
+ shard.pe
)
alloc = self._allocators.get(flat_idx)
if alloc is not None:
from kernbench.policy.address.phyaddr import PhysAddr
alloc.free_hbm(PhysAddr.decode(shard.pa), shard.nbytes)
def cleanup(self) -> None:
"""Free all tensors created by this context."""
for ref in self._tensors:
t = ref()
if t is not None and t._handle is not None:
self._free_tensor(t)
self._tensors.clear()
# ── PyTorch-like tensor API ──────────────────────────────────────
def _ensure_allocators(self) -> dict:
@@ -111,6 +207,26 @@ class RuntimeContext:
self._allocators[flat_idx] = PEMemAllocator(
rack_id=0, sip_id=sip_id, cube_id=cube_id, pe_id=pe_id, cfg=cfg,
)
# Initialize VA allocator and per-PE MMUs
from kernbench.policy.address.pe_mmu import PeMMU
from kernbench.policy.address.va_allocator import VirtualAllocator
pe_mmu_attrs = pe_comps.get("pe_mmu", {}).get("attrs", {})
page_size = int(pe_mmu_attrs.get("page_size", 4096))
tlb_overhead_ns = float(pe_mmu_attrs.get("tlb_overhead_ns", 0.0))
self._va_allocator = VirtualAllocator(
va_base=0x1_0000_0000,
va_size=64 * (1 << 30), # 64 GB VA space
page_size=page_size,
)
total_pes = sip_count * cubes_per_sip * pes_per_cube
for flat_idx in range(total_pes):
self._mmus[flat_idx] = PeMMU(
page_size=page_size, overhead_ns=tlb_overhead_ns,
)
return self._allocators
def _next_tensor_name(self) -> str:
@@ -122,63 +238,57 @@ class RuntimeContext:
shape: tuple[int, ...],
dtype: str = "f16",
*,
placement: list | None = None,
dp: Any = None,
name: str | None = None,
):
"""Create a tensor and deploy to HBM with zero-fill (like torch.zeros)."""
return self._create_tensor(shape, dtype, placement, name, pattern="zero", dp=dp)
return self._create_tensor(shape, dtype, name, pattern="zero", dp=dp)
def empty(
self,
shape: tuple[int, ...],
dtype: str = "f16",
*,
placement: list | None = None,
dp: Any = None,
name: str | None = None,
):
"""Allocate a tensor in HBM without initialization (like torch.empty)."""
return self._create_tensor(shape, dtype, placement, name, pattern=None, dp=dp)
return self._create_tensor(shape, dtype, name, pattern=None, dp=dp)
def _create_tensor(
self,
shape: tuple[int, ...],
dtype: str,
placement: list | None,
name: str | None,
pattern: str | None,
dp: Any = None,
):
from kernbench.policy.placement.dp import DPPolicy, ShardSpec, resolve_dp_policy
from kernbench.policy.placement.dp import DPPolicy, resolve_dp_policy
from kernbench.runtime_api.kernel import MemoryWriteMsg
from kernbench.runtime_api.tensor import Tensor, deploy_tensor, dtype_itemsize
if not isinstance(dp, DPPolicy):
raise ValueError("dp=DPPolicy(...) is required for tensor creation")
tensor_name = name or self._next_tensor_name()
t = Tensor(shape=shape, dtype=dtype, name=tensor_name)
dp_policy: DPPolicy | None = None
# Resolve placement: dp= takes priority over placement=
if dp is not None and isinstance(dp, DPPolicy):
dp_policy = dp
allocators = self._ensure_allocators()
itemsize = dtype_itemsize(dtype)
shape_2d = (shape[0], shape[1]) # type: tuple[int, int]
total_cubes = self._num_sips * self._num_cubes
placement = resolve_dp_policy(
dp, shape=shape_2d, itemsize=itemsize,
num_pe=self._pes_per_cube, num_cubes=total_cubes,
)
elif placement is None:
placement = [ShardSpec(pe_index=0, offset_bytes=0, nbytes=t.nbytes)]
dp_policy = dp
allocators = self._ensure_allocators()
itemsize = dtype_itemsize(dtype)
shape_2d = (shape[0], shape[1]) # type: tuple[int, int]
total_cubes = self._num_sips * self._num_cubes
placement = resolve_dp_policy(
dp, shape=shape_2d, itemsize=itemsize,
num_pe=self._pes_per_cube, num_cubes=total_cubes,
)
# Infer target_pe from placement: multi-PE → "all", single PE → pe_index
pe_indices = {s.pe_index for s in placement}
target_pe: int | str = "all" if len(pe_indices) > 1 else next(iter(pe_indices))
t.to(placement=placement, target_pe=target_pe, dp_policy=dp_policy)
# Allocate PAs via PEMemAllocator
# Allocate PAs via PEMemAllocator + VA via VirtualAllocator
allocators = self._ensure_allocators()
handle = deploy_tensor(
name=tensor_name,
@@ -186,8 +296,64 @@ class RuntimeContext:
dtype=dtype,
placement=placement,
allocators=allocators,
va_allocator=self._va_allocator,
mmus=self._mmus,
)
t._handle = handle
import weakref
t._ctx_ref = weakref.ref(self)
self._tensors.append(weakref.ref(t))
# Install VA→PA mappings via fabric MmuMapMsg
if handle.va_base:
from collections import defaultdict
from kernbench.runtime_api.kernel import MmuMapMsg
is_cube_replicate = (
dp_policy is not None and dp_policy.cube == "replicate"
)
if is_cube_replicate:
# Replicate: each (sip, cube) gets only its own local PA mappings
cube_groups: dict[tuple[int, int], list] = defaultdict(list)
for shard in handle.shards:
cube_groups[(shard.sip, shard.cube)].append(shard)
for (sip, cube), group_shards in cube_groups.items():
entries = tuple(
{"va": handle.va_base + s.offset_bytes,
"pa": s.pa, "size": s.nbytes}
for s in group_shards
)
msg = MmuMapMsg(
correlation_id=self.correlation_id,
request_id=f"mmu_{tensor_name}_s{sip}c{cube}",
entries=entries,
target_sips=(sip,),
target_cubes=(cube,),
target_pe="all",
)
h = self.submit(msg)
self.wait(h)
else:
# Sharded: broadcast all mappings to all target (sip, cube)s
entries = tuple(
{"va": handle.va_base + s.offset_bytes,
"pa": s.pa, "size": s.nbytes}
for s in handle.shards
)
sip_set = sorted({s.sip for s in handle.shards})
cube_set = sorted({s.cube for s in handle.shards})
msg = MmuMapMsg(
correlation_id=self.correlation_id,
request_id=f"mmu_{tensor_name}",
entries=entries,
target_sips=tuple(sip_set),
target_cubes=tuple(cube_set),
target_pe="all",
)
h = self.submit(msg)
self.wait(h)
# Submit MemoryWriteMsg per shard (deploy data to device)
if pattern is not None:
+31
View File
@@ -69,6 +69,7 @@ class TensorArgShard:
class TensorArg:
shards: tuple[TensorArgShard, ...]
arg_kind: Literal["tensor"] = "tensor"
va_base: int = 0 # VA base address for the entire tensor
@dataclass(frozen=True)
@@ -121,3 +122,33 @@ class PeDmaMsg:
nbytes: int
is_write: bool = False
msg_type: Literal["pe_dma"] = "pe_dma"
@dataclass(frozen=True)
class MmuMapMsg:
"""MMU mapping install: broadcast VA→PA entries to target PEs.
Sent via fabric: Host → PCIE_EP → IO_CPU → M_CPU → NOC → PE_MMU.
target_sips controls which SIPs receive the message.
"""
correlation_id: str
request_id: str
entries: tuple[dict, ...] # ({"va": int, "pa": int, "size": int}, ...)
target_sips: tuple[int, ...] | Literal["all"] = "all"
target_cubes: tuple[int, ...] | Literal["all"] = "all"
target_pe: int | Literal["all"] = "all"
msg_type: Literal["mmu_map"] = "mmu_map"
@dataclass(frozen=True)
class MmuUnmapMsg:
"""MMU mapping removal: broadcast VA ranges to unmap from all PEs."""
correlation_id: str
request_id: str
entries: tuple[dict, ...] # ({"va": int, "size": int}, ...)
target_sips: tuple[int, ...] | Literal["all"] = "all"
target_cubes: tuple[int, ...] | Literal["all"] = "all"
target_pe: int | Literal["all"] = "all"
msg_type: Literal["mmu_unmap"] = "mmu_unmap"
+40 -3
View File
@@ -1,6 +1,7 @@
from __future__ import annotations
import math
import weakref
from dataclasses import dataclass
from typing import Literal
@@ -26,6 +27,7 @@ class TensorHandle:
dtype: str
itemsize: int
shards: tuple[TensorShard, ...]
va_base: int = 0 # VA base address for the entire tensor
@property
def nbytes(self) -> int:
@@ -56,8 +58,19 @@ def deploy_tensor(
placement: list[ShardSpec],
allocators: dict[int, PEMemAllocator],
mem_kind: Literal["hbm", "tcm"] = "hbm",
va_allocator=None,
mmus: dict | None = None,
) -> TensorHandle:
from kernbench.policy.address.pe_mmu import PeMMU
isize = dtype_itemsize(dtype)
total_nbytes = math.prod(shape) * isize
# Allocate VA range for the entire tensor (if VA allocator provided)
va_base = 0
if va_allocator is not None:
va_base = va_allocator.alloc(total_nbytes)
shards: list[TensorShard] = []
for spec in placement:
alloc = allocators[spec.pe_index]
@@ -65,20 +78,29 @@ def deploy_tensor(
pa = alloc.alloc_hbm(spec.nbytes)
else:
pa = alloc.alloc_tcm(spec.nbytes)
encoded_pa = pa.encode()
shards.append(TensorShard(
sip=alloc._sip_id,
cube=alloc._cube_id,
pe=alloc._pe_id,
pa=pa.encode(),
pa=encoded_pa,
nbytes=spec.nbytes,
offset_bytes=spec.offset_bytes,
))
# Register VA→PA mapping in all MMUs (broadcast)
if va_base and mmus is not None:
shard_va = va_base + spec.offset_bytes
for mmu in mmus.values():
mmu.map(va=shard_va, pa=encoded_pa, size=spec.nbytes)
return TensorHandle(
name=name,
shape=shape,
dtype=dtype,
itemsize=isize,
shards=tuple(shards),
va_base=va_base,
)
@@ -101,8 +123,7 @@ class Tensor:
Usage::
a = ctx.zeros((M, K), dtype="f16")
a = ctx.zeros((M, K), dtype="f16", placement=dp.replicate(num_pe=8))
a = ctx.zeros((M, K), dtype="f16", dp=DPPolicy(cube="replicate", pe="replicate"))
ctx.launch("kernel_name", kernel_fn, a, b, out, M=M, K=K)
"""
@@ -117,6 +138,14 @@ class Tensor:
self.name = name
self._dp_metadata: DPMetadata | None = None
self._handle: TensorHandle | None = None
self._ctx_ref: weakref.ref | None = None # set by RuntimeContext
def __del__(self) -> None:
if self._ctx_ref is None or self._handle is None:
return
ctx = self._ctx_ref()
if ctx is not None:
ctx._free_tensor(self)
@property
def itemsize(self) -> int:
@@ -133,6 +162,13 @@ class Tensor:
raise RuntimeError(f"Tensor '{self.name}' is not deployed yet")
return self._handle.shards[0].pa
@property
def va(self) -> int:
"""VA base address for the entire tensor."""
if self._handle is None:
raise RuntimeError(f"Tensor '{self.name}' is not deployed yet")
return self._handle.va_base
def to(
self,
placement: list[ShardSpec] | None = None,
@@ -163,4 +199,5 @@ class Tensor:
)
for s in self._handle.shards
),
va_base=self._handle.va_base,
)
+92
View File
@@ -98,6 +98,16 @@ class GraphEngine:
self._components[node_id].in_ports["host"] = host_q
self._pe_dma_queues[node_id] = host_q
# Wire PE_DMA._mmu to PE_MMU's underlying PeMMU utility object
for node_id, node in graph.nodes.items():
if node.kind == "pe_dma":
# Derive PE_MMU node ID from PE_DMA node ID
pe_prefix = node_id.rsplit(".", 1)[0] # e.g. "sip0.cube0.pe0"
mmu_id = f"{pe_prefix}.pe_mmu"
mmu_comp = self._components.get(mmu_id)
if mmu_comp is not None and hasattr(mmu_comp, "mmu"):
self._components[node_id]._mmu = mmu_comp.mmu
# Start components after all ports are wired (ADR-0015 D3)
for comp in self._components.values():
comp.start(self._env)
@@ -119,6 +129,27 @@ class GraphEngine:
def get_completion(self, handle: RequestHandle) -> tuple[Completion, Trace | None]:
return self._results[str(handle)]
def mmu_map(self, va: int, pa: int, size: int) -> None:
"""Sideband: install VA→PA mapping in all PE_MMU components."""
for node_id, comp in self._components.items():
if hasattr(comp, "mmu"):
comp.mmu.map(va=va, pa=pa, size=size)
def mmu_map_pe(
self, sip: int, cube: int, pe: int, va: int, pa: int, size: int,
) -> None:
"""Sideband: install VA→PA mapping in a specific PE's MMU only."""
mmu_id = f"sip{sip}.cube{cube}.pe{pe}.pe_mmu"
comp = self._components.get(mmu_id)
if comp is not None and hasattr(comp, "mmu"):
comp.mmu.map(va=va, pa=pa, size=size)
def mmu_unmap(self, va: int, size: int) -> None:
"""Sideband: remove VA mapping from all PE_MMU components."""
for node_id, comp in self._components.items():
if hasattr(comp, "mmu"):
comp.mmu.unmap(va=va, size=size)
# ── internal ────────────────────────────────────────────────────
def _wire(
@@ -166,6 +197,11 @@ class GraphEngine:
yield from self._process_memory_direct(key, request, done)
return
from kernbench.runtime_api.kernel import MmuMapMsg, MmuUnmapMsg
if isinstance(request, (MmuMapMsg, MmuUnmapMsg)):
yield from self._process_mmu_msg(key, request, done)
return
entries = self._entry_points(request)
if not entries:
self._results[key] = (
@@ -341,3 +377,59 @@ class GraphEngine:
return entries
raise ValueError(f"unsupported request type: {type(request)}")
def _process_mmu_msg(self, key: str, request: Any, done: simpy.Event):
"""Route MmuMapMsg/MmuUnmapMsg through fabric like KernelLaunchMsg.
Path: Host → PCIE_EP → IO_NOC → IO_CPU → (fan-out) → M_CPU → (fan-out) → PE_MMU
"""
start_ns = self._env.now
target_sips = getattr(request, "target_sips", "all")
# Determine target SIPs
sip_set: set[int] = set()
if target_sips == "all":
for ep_id in self._resolver.find_all_pcie_eps():
sip_id = int(ep_id.split(".")[0].replace("sip", ""))
sip_set.add(sip_id)
else:
sip_set = set(target_sips)
entries = []
for sip_id in sorted(sip_set):
entries.append((
self._resolver.find_pcie_ep(sip_id),
self._resolver.find_io_cpu(sip_id),
0, # MmuMapMsg has no data payload
))
if not entries:
self._results[key] = (Completion(ok=True), {"total_ns": 0.0})
done.succeed()
return
if len(entries) == 1:
pcie_ep_id, io_cpu_id, _ = entries[0]
path = self._router.find_node_path(pcie_ep_id, io_cpu_id)
txn_done = self._env.event()
txn = Transaction(request=request, path=path, step=0, nbytes=0, done=txn_done)
yield self._host_queues[pcie_ep_id].put(txn)
yield txn_done
else:
# Multi-SIP fan-out
sub_dones = []
for pcie_ep_id, io_cpu_id, _ in entries:
path = self._router.find_node_path(pcie_ep_id, io_cpu_id)
sub_done = self._env.event()
sub_txn = Transaction(request=request, path=path, step=0, nbytes=0, done=sub_done)
yield self._host_queues[pcie_ep_id].put(sub_txn)
sub_dones.append(sub_done)
for sd in sub_dones:
yield sd
elapsed = self._env.now - start_ns
self._results[key] = (
Completion(ok=True),
{"total_ns": elapsed, "msg_type": request.msg_type},
)
done.succeed()
+11
View File
@@ -22,6 +22,7 @@ _PE_COMP_OFFSETS = {
"pe_dma": (0.0, -0.15),
"pe_gemm": (0.0, 0.0),
"pe_math": (0.0, 0.15),
"pe_mmu": (0.15, -0.15),
"pe_tcm": (0.3, 0.0),
}
@@ -495,6 +496,15 @@ def _instantiate_cube(
kind="pe_response",
))
# noc → PE_MMU (MMU mapping install)
pe_mmu_id = f"{pp}.pe_mmu"
if pe_mmu_id in nodes:
edges.append(Edge(
src=f"{cp}.noc", dst=pe_mmu_id,
distance_mm=clinks.get("noc_to_pe_mmu_mm", 0.0),
kind="command",
))
pe_idx += 1
# ── xbar_top/bot → HBM slices ──
@@ -1073,6 +1083,7 @@ def _build_pe_view(spec: dict) -> ViewGraph:
"pe_dma": (7.0, 1.5),
"pe_gemm": (7.0, 4.0),
"pe_math": (7.0, 6.5),
"pe_mmu": (4.0, 1.5),
"pe_tcm": (10.0, 4.0),
}
+16 -16
View File
@@ -86,11 +86,11 @@ class TLContext:
self._commands.append(PeCpuOverheadCmd(cycles=self._dispatch_cycles))
def _make_handle(
self, pa: int, shape: tuple[int, ...], dtype: str,
self, addr: int, shape: tuple[int, ...], dtype: str,
) -> TensorHandle:
return TensorHandle(
id=self._next_handle_id(),
pa=pa, shape=shape, dtype=dtype,
addr=addr, shape=shape, dtype=dtype,
nbytes=self._nbytes(shape, dtype),
)
@@ -104,7 +104,7 @@ class TLContext:
Used when the scheduler will stream data per-tile (e.g., tensor b
in a composite GEMM). No command is generated.
"""
return self._make_handle(pa=ptr, shape=shape, dtype=dtype)
return self._make_handle(addr=ptr, shape=shape, dtype=dtype)
# ── Data Movement (blocking, DMA engine) ──────────────────────
@@ -113,9 +113,9 @@ class TLContext:
) -> TensorHandle:
"""Load tensor from HBM to TCM. Returns TensorHandle."""
self._emit_dispatch_overhead()
handle = self._make_handle(pa=ptr, shape=shape, dtype=dtype)
handle = self._make_handle(addr=ptr, shape=shape, dtype=dtype)
self._commands.append(DmaReadCmd(
handle=handle, src_pa=ptr, nbytes=handle.nbytes,
handle=handle, src_addr=ptr, nbytes=handle.nbytes,
))
return handle
@@ -123,7 +123,7 @@ class TLContext:
"""Store tensor from TCM to HBM."""
self._emit_dispatch_overhead()
self._commands.append(DmaWriteCmd(
handle=handle, dst_pa=ptr, nbytes=handle.nbytes,
handle=handle, dst_addr=ptr, nbytes=handle.nbytes,
))
# ── GEMM Engine (blocking) ────────────────────────────────────
@@ -141,7 +141,7 @@ class TLContext:
raise ValueError(f"dot shape mismatch: a.K={k} != b.K={k2}")
out_shape = (*a.shape[:-2], m, n)
out_dtype = a.dtype
out = self._make_handle(pa=0, shape=out_shape, dtype=out_dtype)
out = self._make_handle(addr=0, shape=out_shape, dtype=out_dtype)
self._emit_dispatch_overhead()
self._commands.append(GemmCmd(a=a, b=b, out=out, m=m, k=k, n=n))
return out
@@ -149,7 +149,7 @@ class TLContext:
# ── MATH Engine: unary (blocking) ─────────────────────────────
def _unary_math(self, op: str, x: TensorHandle) -> TensorHandle:
out = self._make_handle(pa=0, shape=x.shape, dtype=x.dtype)
out = self._make_handle(addr=0, shape=x.shape, dtype=x.dtype)
self._emit_dispatch_overhead()
self._commands.append(MathCmd(op=op, inputs=(x,), out=out))
return out
@@ -182,7 +182,7 @@ class TLContext:
) -> TensorHandle:
out_shape = list(x.shape)
out_shape[axis] = 1
out = self._make_handle(pa=0, shape=tuple(out_shape), dtype=x.dtype)
out = self._make_handle(addr=0, shape=tuple(out_shape), dtype=x.dtype)
self._emit_dispatch_overhead()
self._commands.append(MathCmd(op=op, inputs=(x,), out=out, axis=axis))
return out
@@ -201,7 +201,7 @@ class TLContext:
def _binary_math(
self, op: str, a: TensorHandle, b: TensorHandle,
) -> TensorHandle:
out = self._make_handle(pa=0, shape=a.shape, dtype=a.dtype)
out = self._make_handle(addr=0, shape=a.shape, dtype=a.dtype)
self._emit_dispatch_overhead()
self._commands.append(MathCmd(op=op, inputs=(a, b), out=out))
return out
@@ -209,7 +209,7 @@ class TLContext:
def where(
self, cond: TensorHandle, a: TensorHandle, b: TensorHandle,
) -> TensorHandle:
out = self._make_handle(pa=0, shape=a.shape, dtype=a.dtype)
out = self._make_handle(addr=0, shape=a.shape, dtype=a.dtype)
self._emit_dispatch_overhead()
self._commands.append(MathCmd(op="where", inputs=(cond, a, b), out=out))
return out
@@ -227,17 +227,17 @@ class TLContext:
def arange(self, start: int, end: int, dtype: str = "i32") -> TensorHandle:
"""Create index range tensor in TCM."""
n = end - start
return self._make_handle(pa=0, shape=(n,), dtype=dtype)
return self._make_handle(addr=0, shape=(n,), dtype=dtype)
def zeros(self, shape: tuple[int, ...], dtype: str = "f16") -> TensorHandle:
"""Create zero-filled tensor in TCM."""
return self._make_handle(pa=0, shape=shape, dtype=dtype)
return self._make_handle(addr=0, shape=shape, dtype=dtype)
def full(
self, shape: tuple[int, ...], value: float | int, dtype: str = "f16",
) -> TensorHandle:
"""Create constant-filled tensor in TCM."""
return self._make_handle(pa=0, shape=shape, dtype=dtype)
return self._make_handle(addr=0, shape=shape, dtype=dtype)
# ── Metadata (no compute, no DMA) ─────────────────────────────
@@ -247,7 +247,7 @@ class TLContext:
raise ValueError("trans requires at least 2D tensor")
new_shape = (*x.shape[:-2], x.shape[-1], x.shape[-2])
return TensorHandle(
id=x.id, pa=x.pa, shape=new_shape,
id=x.id, addr=x.addr, shape=new_shape,
dtype=x.dtype, nbytes=x.nbytes, data=x.data,
)
@@ -278,7 +278,7 @@ class TLContext:
self._emit_dispatch_overhead()
self._commands.append(CompositeCmd(
completion=completion, op=op,
a=a, b=b, out_pa=out_ptr, out_nbytes=out_nbytes,
a=a, b=b, out_addr=out_ptr, out_nbytes=out_nbytes,
math_op=math_op,
))
return completion