Files
kernbench2/src/kernbench/runtime_api/kernel.py
T
ywkang 08812eda58 Add virtual memory support: PE_MMU, VA allocator, fabric MmuMapMsg
Implement VA/MMU layer (ADR-0011 Phase 1) enabling Triton kernels to use
contiguous virtual addresses on sharded tensors.

Key changes:
- PE_MMU component: hybrid inbox (MmuMapMsg) + sync translate() for PE_DMA
- VirtualAllocator + PEMemAllocator: free-list with coalescing
- MmuMapMsg/MmuUnmapMsg fabric path with SIP-level routing
- DPPolicy-based mapping: replicate=local, sharded=broadcast
- Tensor lifecycle: del + weakref cleanup, context manager
- Rename: TensorHandle.pa→addr, DmaReadCmd.src_pa→src_addr, ctx→torch

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 00:01:47 -07:00

155 lines
3.9 KiB
Python

from __future__ import annotations
from dataclasses import dataclass
from typing import Literal, TypeAlias
@dataclass(frozen=True)
class MemoryWriteMsg:
correlation_id: str
request_id: str
dst_sip: int
dst_cube: int
dst_pe: int
dst_pa: int
nbytes: int
src_kind: Literal["pattern", "host_buffer_ref"] = "pattern"
pattern: str | None = None
target_cubes: tuple[int, ...] | Literal["all"] = "all"
target_pe: int | Literal["all"] = "all"
msg_type: Literal["memory_write"] = "memory_write"
@dataclass(frozen=True)
class MemoryReadMsg:
correlation_id: str
request_id: str
src_sip: int
src_cube: int
src_pe: int
src_pa: int
nbytes: int
target_cubes: tuple[int, ...] | Literal["all"] = "all"
target_pe: int | Literal["all"] = "all"
msg_type: Literal["memory_read"] = "memory_read"
@dataclass(frozen=True)
class KernelRef:
"""Reference to a kernel binary or builtin timing model.
Kernel binaries must be pre-deployed to device memory via MemoryWriteMsg.
KernelLaunchMsg references the deployed location by PA — source code or IR
MUST NOT be embedded in launch messages.
- "deployed": kernel binary pre-deployed to HBM/SRAM at deploy_pa.
- "builtin": simulator built-in timing model, identified by name.
"""
name: str
kind: Literal["deployed", "builtin"]
deploy_pa: int | None = None
deploy_sip: int = 0
deploy_cube: int = 0
deploy_pe: int = 0
nbytes_code: int = 0
@dataclass(frozen=True)
class TensorArgShard:
    """Frozen record for one shard listed in ``TensorArg.shards``."""

    # Where the shard lives -- presumably SIP / cube / PE indices and a
    # physical address; confirm against the allocator that fills these in.
    sip: int
    cube: int
    pe: int
    pa: int

    # Size of the shard in bytes.
    nbytes: int

    # NOTE(review): looks like the shard's byte offset within the whole
    # tensor -- confirm against the producer.
    offset_bytes: int
@dataclass(frozen=True)
class TensorArg:
    """Frozen kernel argument of kind ``"tensor"``, made up of shards."""

    # The shards that together make up the tensor.
    shards: tuple[TensorArgShard, ...]

    # Fixed tag identifying this argument class.
    arg_kind: Literal["tensor"] = "tensor"

    # VA base address for the entire tensor.
    va_base: int = 0
@dataclass(frozen=True)
class ScalarArg:
dtype: str
value: float | int
arg_kind: Literal["scalar"] = "scalar"
# Any argument accepted in KernelLaunchMsg.args: a tensor or a scalar.
KernelArg: TypeAlias = TensorArg | ScalarArg
@dataclass(frozen=True)
class KernelLaunchMsg:
    """Frozen message tagged ``"kernel_launch"``.

    Carries a reference to the kernel to run together with its arguments.
    The kernel is referenced through ``kernel_ref`` rather than embedded
    in the message.
    """

    # --- identifiers -----------------------------------------------------
    correlation_id: str
    request_id: str

    # --- what to run, and with which arguments ---------------------------
    kernel_ref: KernelRef
    args: tuple[KernelArg, ...]

    # --- targeting (both default to the "all" wildcard) --------------------
    target_cubes: tuple[int, ...] | Literal["all"] = "all"
    target_pe: int | Literal["all"] = "all"

    # Fixed tag identifying this message class.
    msg_type: Literal["kernel_launch"] = "kernel_launch"
@dataclass(frozen=True)
class ResponseMsg:
    """Device→Host reply that reports the outcome of a PE execution."""

    # --- identifiers -----------------------------------------------------
    correlation_id: str
    request_id: str

    # --- origin of the reply ---------------------------------------------
    src_cube: int
    src_pe: int

    # Outcome flag for the execution being reported.
    success: bool

    # Fixed tag identifying this message class.
    msg_type: Literal["response"] = "response"
@dataclass(frozen=True)
class PeDmaMsg:
    """Host-injected transfer request issued directly at the PE_DMA level.

    The probe utility relies on this message to measure PE→HBM latency
    without going through the full PE_CPU → scheduler → DMA pipeline.
    """

    # --- identifiers -----------------------------------------------------
    correlation_id: str
    request_id: str

    # --- PE the transfer is injected at ----------------------------------
    src_sip: int
    src_cube: int
    src_pe: int

    # --- transfer description --------------------------------------------
    dst_pa: int
    nbytes: int

    # Direction flag; the default False presumably means a read -- confirm
    # against the PE_DMA handler.
    is_write: bool = False

    # Fixed tag identifying this message class.
    msg_type: Literal["pe_dma"] = "pe_dma"
@dataclass(frozen=True)
class MmuMapMsg:
"""MMU mapping install: broadcast VA→PA entries to target PEs.
Sent via fabric: Host → PCIE_EP → IO_CPU → M_CPU → NOC → PE_MMU.
target_sips controls which SIPs receive the message.
"""
correlation_id: str
request_id: str
entries: tuple[dict, ...] # ({"va": int, "pa": int, "size": int}, ...)
target_sips: tuple[int, ...] | Literal["all"] = "all"
target_cubes: tuple[int, ...] | Literal["all"] = "all"
target_pe: int | Literal["all"] = "all"
msg_type: Literal["mmu_map"] = "mmu_map"
@dataclass(frozen=True)
class MmuUnmapMsg:
"""MMU mapping removal: broadcast VA ranges to unmap from all PEs."""
correlation_id: str
request_id: str
entries: tuple[dict, ...] # ({"va": int, "size": int}, ...)
target_sips: tuple[int, ...] | Literal["all"] = "all"
target_cubes: tuple[int, ...] | Literal["all"] = "all"
target_pe: int | Literal["all"] = "all"
msg_type: Literal["mmu_unmap"] = "mmu_unmap"