Add virtual memory support: PE_MMU, VA allocator, fabric MmuMapMsg
Implement VA/MMU layer (ADR-0011 Phase 1) enabling Triton kernels to use contiguous virtual addresses on sharded tensors. Key changes: - PE_MMU component: hybrid inbox (MmuMapMsg) + sync translate() for PE_DMA - VirtualAllocator + PEMemAllocator: free-list with coalescing - MmuMapMsg/MmuUnmapMsg fabric path with SIP-level routing - DPPolicy-based mapping: replicate=local, sharded=broadcast - Tensor lifecycle: del + weakref cleanup, context manager - Rename: TensorHandle.pa→addr, DmaReadCmd.src_pa→src_addr, ctx→torch Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -207,12 +207,15 @@ benchmark instances by default.
|
|||||||
|
|
||||||
## R10. Memory Addressing (Phase 0)
|
## R10. Memory Addressing (Phase 0)
|
||||||
|
|
||||||
In Phase 0, the simulator uses a **PA-first memory model**:
|
The simulator uses a **VA/PA memory model** (ADR-0011):
|
||||||
|
|
||||||
- All memory operations use device physical addresses (PA) only.
|
- Tensors are assigned a contiguous virtual address (VA) range at deployment.
|
||||||
- Virtual addressing, MMU/IOMMU, and address translation latency are out of scope.
|
- PE_MMU translates VA→PA per access; TLB overhead is configurable.
|
||||||
|
- Mapping installation (MmuMapMsg) traverses the fabric with measured latency.
|
||||||
|
- Replicate tensors use per-cube local PA mapping; sharded tensors broadcast.
|
||||||
|
- PA-only fallback is retained for backward compatibility.
|
||||||
- Tensor placement is represented as a list of PA shards, each explicitly tagged
|
- Tensor placement is represented as a list of PA shards, each explicitly tagged
|
||||||
with `(sip, cube, pe)`.
|
with `(sip, cube, pe)`, plus a tensor-wide `va_base`.
|
||||||
|
|
||||||
All memory access latency MUST be modeled explicitly via graph traversal.
|
All memory access latency MUST be modeled explicitly via graph traversal.
|
||||||
No implicit translation or hidden latency is allowed.
|
No implicit translation or hidden latency is allowed.
|
||||||
|
|||||||
@@ -1,2 +1,2 @@
|
|||||||
def run(ctx):
|
def run(torch):
|
||||||
print("IPCQ all reduce kernel bench")
|
print("IPCQ all reduce kernel bench")
|
||||||
|
|||||||
+2
-2
@@ -15,7 +15,7 @@ def resolve_bench(bench_id: str) -> BenchFn:
|
|||||||
|
|
||||||
Expected layout (repo root):
|
Expected layout (repo root):
|
||||||
benches/<bench_id>.py
|
benches/<bench_id>.py
|
||||||
def run(ctx: RuntimeContext) -> Any
|
def run(torch: RuntimeContext) -> Any
|
||||||
"""
|
"""
|
||||||
bench_id = bench_id.strip()
|
bench_id = bench_id.strip()
|
||||||
if not bench_id:
|
if not bench_id:
|
||||||
@@ -30,7 +30,7 @@ def resolve_bench(bench_id: str) -> BenchFn:
|
|||||||
|
|
||||||
run_fn = getattr(mod, "run", None)
|
run_fn = getattr(mod, "run", None)
|
||||||
if run_fn is None:
|
if run_fn is None:
|
||||||
raise ValueError(f"Bench module {module_path} must define a 'run(ctx)' function.")
|
raise ValueError(f"Bench module {module_path} must define a 'run(torch)' function.")
|
||||||
if not callable(run_fn):
|
if not callable(run_fn):
|
||||||
raise ValueError(f"'run' in {module_path} is not callable.")
|
raise ValueError(f"'run' in {module_path} is not callable.")
|
||||||
|
|
||||||
|
|||||||
+5
-5
@@ -26,14 +26,14 @@ def _gemm_kernel(a_ptr, b_ptr, out_ptr, M, K, N, tl, DTYPE="f16"):
|
|||||||
tl.wait(handle)
|
tl.wait(handle)
|
||||||
|
|
||||||
|
|
||||||
def run(ctx):
|
def run(torch):
|
||||||
"""Run the QKV GEMM benchmark."""
|
"""Run the QKV GEMM benchmark."""
|
||||||
# DP placement: a=replicate (cube-level), b/out=column_wise (N-axis, single PE)
|
# DP placement: a=replicate (cube-level), b/out=column_wise (N-axis, single PE)
|
||||||
a = ctx.zeros((M, K), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="replicate"), name="a")
|
a = torch.zeros((M, K), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="replicate"), name="a")
|
||||||
b = ctx.zeros((K, N), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="column_wise"), name="b")
|
b = torch.zeros((K, N), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="column_wise"), name="b")
|
||||||
out = ctx.empty(
|
out = torch.empty(
|
||||||
(M, N), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="column_wise"), name="out",
|
(M, N), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="column_wise"), name="out",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Launch GEMM kernel
|
# Launch GEMM kernel
|
||||||
ctx.launch("qkv_gemm", _gemm_kernel, a, b, out, M, K, N)
|
torch.launch("qkv_gemm", _gemm_kernel, a, b, out, M, K, N)
|
||||||
|
|||||||
@@ -26,14 +26,14 @@ def _gemm_kernel(a_ptr, b_ptr, out_ptr, M, K, N, tl, DTYPE="f16"):
|
|||||||
tl.wait(handle)
|
tl.wait(handle)
|
||||||
|
|
||||||
|
|
||||||
def run(ctx):
|
def run(torch):
|
||||||
"""Run the multi-PE QKV GEMM benchmark."""
|
"""Run the multi-PE QKV GEMM benchmark."""
|
||||||
# DP placement: a=replicate (cube-level), b/out=column_wise (N-axis split)
|
# DP placement: a=replicate (cube-level), b/out=column_wise (N-axis split)
|
||||||
a = ctx.zeros((M, K), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="replicate"), name="a")
|
a = torch.zeros((M, K), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="replicate"), name="a")
|
||||||
b = ctx.zeros((K, N), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="column_wise"), name="b")
|
b = torch.zeros((K, N), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="column_wise"), name="b")
|
||||||
out = ctx.empty(
|
out = torch.empty(
|
||||||
(M, N), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="column_wise"), name="out",
|
(M, N), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="column_wise"), name="out",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Launch GEMM kernel on all PEs
|
# Launch GEMM kernel on all PEs
|
||||||
ctx.launch("qkv_gemm_multi", _gemm_kernel, a, b, out, M, K, N)
|
torch.launch("qkv_gemm_multi", _gemm_kernel, a, b, out, M, K, N)
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
# ADR-0011: Memory Addressing Simplification (PA-first)
|
# ADR-0011: Memory Addressing — PA-first with VA/MMU Extension
|
||||||
|
|
||||||
## Status
|
## Status
|
||||||
|
|
||||||
Accepted
|
Accepted (Phase 1 VA/MMU implemented)
|
||||||
|
|
||||||
## Context
|
## Context
|
||||||
|
|
||||||
@@ -11,49 +11,82 @@ translation path for DMA: host allocates physical memory at PE level, maps it
|
|||||||
into a virtual address space, installs mappings, and DMA requests use virtual
|
into a virtual address space, installs mappings, and DMA requests use virtual
|
||||||
addresses that are translated to physical addresses.
|
addresses that are translated to physical addresses.
|
||||||
|
|
||||||
For early development, we want a minimal, deterministic model that enables:
|
The PA-only model (Phase 0) was insufficient for running standard Triton kernels
|
||||||
|
that use `base_addr + offset` patterns on sharded tensors — each PE's shard has
|
||||||
- correct routing and latency accounting through the graph,
|
a different PA, but the kernel needs a single contiguous address space.
|
||||||
- stable tensor deployment and kernel execution semantics,
|
|
||||||
- future extension toward VA/MMU without rewriting workflows.
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Decision
|
## Decision
|
||||||
|
|
||||||
### D1. Phase 0 model is PA-only
|
### D1. Phase 0 model is PA-only (original, retained as fallback)
|
||||||
|
|
||||||
The simulator uses a PA-first model:
|
|
||||||
|
|
||||||
- All device memory accesses (MemoryRead/MemoryWrite) operate on device physical
|
- All device memory accesses (MemoryRead/MemoryWrite) operate on device physical
|
||||||
addresses (PA) plus size.
|
addresses (PA) plus size.
|
||||||
- Tensor handles store PA-based shard mappings after deployment.
|
- PA-only mode remains functional via PageFault fallback in PE_DMA.
|
||||||
- KernelLaunch passes tensor arguments as PA-based mappings (or references to them).
|
|
||||||
- MMU/IOMMU concepts (virtual address spaces, page tables, translation latency)
|
|
||||||
are NOT modeled in Phase 0.
|
|
||||||
|
|
||||||
### D2. Allocation produces PA mappings
|
### D2. Allocation produces PA mappings
|
||||||
|
|
||||||
Device allocation selects PE-local memory regions and returns PA mappings
|
Device allocation selects PE-local memory regions and returns PA mappings
|
||||||
sufficient to execute kernels and issue DMA requests.
|
sufficient to execute kernels and issue DMA requests.
|
||||||
|
|
||||||
### D3. Extension path (non-breaking)
|
### D3. Phase 1: VA/MMU layer (implemented)
|
||||||
|
|
||||||
A future ADR MAY introduce an optional VA/MMU layer by:
|
#### D3.1 Virtual Address Model
|
||||||
|
|
||||||
- introducing virtual addresses in tensor handles,
|
- Each tensor gets a single contiguous VA range (`TensorHandle.va_base`).
|
||||||
- adding a mapping-install step,
|
- `TensorShard` does NOT carry a `va` field — shard VA is derived as
|
||||||
- modeling translation latency and page granularity.
|
`va_base + offset_bytes`.
|
||||||
|
- Kernels receive `va_base` as their pointer argument (via `TensorArg.va_base`).
|
||||||
|
- `DmaReadCmd.src_addr` and `DmaWriteCmd.dst_addr` carry VA (not PA).
|
||||||
|
|
||||||
The Phase 0 PA model remains a valid fast-path configuration.
|
#### D3.2 PE_MMU Component
|
||||||
|
|
||||||
|
- Hybrid design: SimPy component (inbox for MmuMapMsg) + utility (synchronous
|
||||||
|
`translate()` called by PE_DMA).
|
||||||
|
- Page-aligned dict lookup for O(1) VA→PA translation.
|
||||||
|
- `tlb_overhead_ns` configurable per-access latency.
|
||||||
|
- PageFault fallback: if VA has no mapping, PE_DMA treats it as PA directly
|
||||||
|
(backward compatibility with PA-only tests).
|
||||||
|
|
||||||
|
#### D3.3 Mapping Installation
|
||||||
|
|
||||||
|
- `MmuMapMsg` traverses the fabric: Host → PCIE_EP → IO_CPU (cube fan-out) →
|
||||||
|
M_CPU (PE fan-out) → NOC → PE_MMU. Latency is measured end-to-end.
|
||||||
|
- `MmuMapMsg.target_sips` controls SIP-level routing to prevent cross-SIP
|
||||||
|
mapping contamination for replicated tensors.
|
||||||
|
- Mapping strategy based on `DPPolicy.cube`:
|
||||||
|
- **Replicate** (`cube="replicate"`): per-(sip, cube) local mapping only.
|
||||||
|
Each cube's PEs see only their local PA. No cross-cube mapping installed.
|
||||||
|
- **Sharded** (`cube="shard_m"`, etc.): broadcast all shard mappings to all
|
||||||
|
target cubes. Enables cross-PE and cross-cube DMA.
|
||||||
|
|
||||||
|
#### D3.4 Tensor Lifecycle
|
||||||
|
|
||||||
|
- `del tensor` triggers automatic cleanup via `Tensor.__del__` + `weakref` to
|
||||||
|
RuntimeContext. Sends `MmuUnmapMsg` through fabric, returns VA and PA space.
|
||||||
|
- `with RuntimeContext(...) as ctx:` provides scope-based bulk cleanup.
|
||||||
|
- `RuntimeContext._tensors` uses `weakref.ref` to avoid preventing GC.
|
||||||
|
- `PEMemAllocator` uses free-list with coalescing (not bump allocator).
|
||||||
|
- `VirtualAllocator` uses free-list with coalescing for VA space.
|
||||||
|
|
||||||
|
#### D3.5 Allocators
|
||||||
|
|
||||||
|
- `VirtualAllocator`: device-wide VA space, page-aligned alloc/free with
|
||||||
|
coalescing.
|
||||||
|
- `PEMemAllocator`: per-PE HBM/TCM, free-list based alloc/free with coalescing.
|
||||||
|
- Page size configurable via `topology.yaml` pe_mmu attrs (default 4096).
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Consequences
|
## Consequences
|
||||||
|
|
||||||
- Early implementation stays simple and testable.
|
- Triton kernels use `base_addr + offset` patterns naturally on sharded tensors.
|
||||||
- All latency remains explicit via graph traversal, not hidden translation.
|
- All latency remains explicit via graph traversal, including MMU mapping
|
||||||
- Future VA/MMU modeling can be added without breaking existing benchmarks.
|
installation and per-access TLB overhead.
|
||||||
|
- PA-only mode retained as fallback (PageFault → treat as PA).
|
||||||
|
- Benchmark parameter renamed `ctx` → `torch` for PyTorch code compatibility.
|
||||||
|
- IPCQ and other fixed-address resources bypass MMU (use PA directly).
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -62,4 +95,6 @@ The Phase 0 PA model remains a valid fast-path configuration.
|
|||||||
- ADR-0007 (runtime_api vs sim_engine boundaries)
|
- ADR-0007 (runtime_api vs sim_engine boundaries)
|
||||||
- ADR-0008 (tensor deployment)
|
- ADR-0008 (tensor deployment)
|
||||||
- ADR-0009 (kernel execution)
|
- ADR-0009 (kernel execution)
|
||||||
|
- ADR-0014 (PE-internal execution model)
|
||||||
|
- ADR-0015 (component port/wire model)
|
||||||
- SPEC R2 (latency by traversal)
|
- SPEC R2 (latency by traversal)
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ class TensorHandle:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
pa: int # physical address in HBM/TCM
|
addr: int # address (VA when MMU enabled, PA otherwise)
|
||||||
shape: tuple[int, ...]
|
shape: tuple[int, ...]
|
||||||
dtype: str
|
dtype: str
|
||||||
nbytes: int # total byte size
|
nbytes: int # total byte size
|
||||||
@@ -50,19 +50,19 @@ class CompletionHandle:
|
|||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class DmaReadCmd:
|
class DmaReadCmd:
|
||||||
"""DMA READ: HBM → PE_TCM."""
|
"""DMA READ: HBM → PE_TCM. src_addr is VA (translated to PA by PE_DMA)."""
|
||||||
|
|
||||||
handle: TensorHandle
|
handle: TensorHandle
|
||||||
src_pa: int
|
src_addr: int
|
||||||
nbytes: int
|
nbytes: int
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class DmaWriteCmd:
|
class DmaWriteCmd:
|
||||||
"""DMA WRITE: PE_TCM → HBM."""
|
"""DMA WRITE: PE_TCM → HBM. dst_addr is VA (translated to PA by PE_DMA)."""
|
||||||
|
|
||||||
handle: TensorHandle
|
handle: TensorHandle
|
||||||
dst_pa: int
|
dst_addr: int
|
||||||
nbytes: int
|
nbytes: int
|
||||||
|
|
||||||
|
|
||||||
@@ -108,7 +108,7 @@ class CompositeCmd:
|
|||||||
op: Literal["gemm", "math"]
|
op: Literal["gemm", "math"]
|
||||||
a: TensorHandle
|
a: TensorHandle
|
||||||
b: TensorHandle | None
|
b: TensorHandle | None
|
||||||
out_pa: int
|
out_addr: int
|
||||||
out_nbytes: int
|
out_nbytes: int
|
||||||
math_op: str | None = None # for op="math": which math operation
|
math_op: str | None = None # for op="math": which math operation
|
||||||
|
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from kernbench.components.impls.pe_dma import PeDmaComponent
|
|||||||
from kernbench.components.impls.pe_gemm import PeGemmComponent
|
from kernbench.components.impls.pe_gemm import PeGemmComponent
|
||||||
from kernbench.components.impls.pe_math import PeMathComponent
|
from kernbench.components.impls.pe_math import PeMathComponent
|
||||||
from kernbench.components.impls.pe_scheduler import PeSchedulerComponent
|
from kernbench.components.impls.pe_scheduler import PeSchedulerComponent
|
||||||
|
from kernbench.components.impls.pe_mmu import PeMmuComponent
|
||||||
from kernbench.components.impls.pe_tcm import PeTcmComponent
|
from kernbench.components.impls.pe_tcm import PeTcmComponent
|
||||||
from kernbench.components.impls.sram import SramComponent
|
from kernbench.components.impls.sram import SramComponent
|
||||||
from kernbench.components.impls.xbar import PositionAwareXbarComponent
|
from kernbench.components.impls.xbar import PositionAwareXbarComponent
|
||||||
@@ -36,6 +37,7 @@ ComponentRegistry.register("pe_scheduler_v1", PeSchedulerComponent)
|
|||||||
ComponentRegistry.register("pe_dma_v1", PeDmaComponent)
|
ComponentRegistry.register("pe_dma_v1", PeDmaComponent)
|
||||||
ComponentRegistry.register("pe_gemm_v1", PeGemmComponent)
|
ComponentRegistry.register("pe_gemm_v1", PeGemmComponent)
|
||||||
ComponentRegistry.register("pe_math_v1", PeMathComponent)
|
ComponentRegistry.register("pe_math_v1", PeMathComponent)
|
||||||
|
ComponentRegistry.register("pe_mmu_v1", PeMmuComponent)
|
||||||
ComponentRegistry.register("pe_tcm_v1", PeTcmComponent)
|
ComponentRegistry.register("pe_tcm_v1", PeTcmComponent)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@@ -47,6 +49,7 @@ __all__ = [
|
|||||||
"PeDmaComponent",
|
"PeDmaComponent",
|
||||||
"PeGemmComponent",
|
"PeGemmComponent",
|
||||||
"PeMathComponent",
|
"PeMathComponent",
|
||||||
|
"PeMmuComponent",
|
||||||
"PeSchedulerComponent",
|
"PeSchedulerComponent",
|
||||||
"PeTcmComponent",
|
"PeTcmComponent",
|
||||||
"TransitComponent",
|
"TransitComponent",
|
||||||
|
|||||||
@@ -93,7 +93,9 @@ class IoCpuComponent(ComponentBase):
|
|||||||
|
|
||||||
def _resolve_cube_targets(self, request: Any) -> list[tuple[int, int]]:
|
def _resolve_cube_targets(self, request: Any) -> list[tuple[int, int]]:
|
||||||
"""Return list of (sip, cube) pairs to fan out to."""
|
"""Return list of (sip, cube) pairs to fan out to."""
|
||||||
from kernbench.runtime_api.kernel import KernelLaunchMsg, MemoryReadMsg, MemoryWriteMsg
|
from kernbench.runtime_api.kernel import (
|
||||||
|
KernelLaunchMsg, MemoryReadMsg, MemoryWriteMsg, MmuMapMsg, MmuUnmapMsg,
|
||||||
|
)
|
||||||
|
|
||||||
target_cubes = getattr(request, "target_cubes", "all")
|
target_cubes = getattr(request, "target_cubes", "all")
|
||||||
|
|
||||||
@@ -130,6 +132,16 @@ class IoCpuComponent(ComponentBase):
|
|||||||
targets.append(key)
|
targets.append(key)
|
||||||
return targets
|
return targets
|
||||||
|
|
||||||
|
if isinstance(request, (MmuMapMsg, MmuUnmapMsg)):
|
||||||
|
my_sip = self._my_sip()
|
||||||
|
if target_cubes == "all":
|
||||||
|
n_cubes = 16
|
||||||
|
if self.ctx and self.ctx.spec:
|
||||||
|
sips = self.ctx.spec.get("system", {}).get("sips", {})
|
||||||
|
n_cubes = sips.get("cubes_per_sip", 16)
|
||||||
|
return [(my_sip, c) for c in range(n_cubes)]
|
||||||
|
return [(my_sip, c) for c in target_cubes]
|
||||||
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def _cube_from_pa(self, pa_val: int, fallback: int) -> int:
|
def _cube_from_pa(self, pa_val: int, fallback: int) -> int:
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ class MCpuComponent(ComponentBase):
|
|||||||
|
|
||||||
def _worker(self, env: simpy.Environment) -> Generator:
|
def _worker(self, env: simpy.Environment) -> Generator:
|
||||||
"""Dispatch forward txns, collect response txns."""
|
"""Dispatch forward txns, collect response txns."""
|
||||||
from kernbench.runtime_api.kernel import KernelLaunchMsg
|
from kernbench.runtime_api.kernel import KernelLaunchMsg, MmuMapMsg, MmuUnmapMsg
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
txn: Any = yield self._inbox.get()
|
txn: Any = yield self._inbox.get()
|
||||||
@@ -66,6 +66,8 @@ class MCpuComponent(ComponentBase):
|
|||||||
elif self.ctx is not None and txn.request is not None:
|
elif self.ctx is not None and txn.request is not None:
|
||||||
if isinstance(txn.request, KernelLaunchMsg):
|
if isinstance(txn.request, KernelLaunchMsg):
|
||||||
env.process(self._kernel_launch_fanout(env, txn))
|
env.process(self._kernel_launch_fanout(env, txn))
|
||||||
|
elif isinstance(txn.request, (MmuMapMsg, MmuUnmapMsg)):
|
||||||
|
env.process(self._mmu_msg_fanout(env, txn))
|
||||||
else:
|
else:
|
||||||
env.process(self._dma_fanout(env, txn))
|
env.process(self._dma_fanout(env, txn))
|
||||||
else:
|
else:
|
||||||
@@ -261,6 +263,63 @@ class MCpuComponent(ComponentBase):
|
|||||||
n_slices = mm.get("hbm_slices_per_cube", 8)
|
n_slices = mm.get("hbm_slices_per_cube", 8)
|
||||||
return [f"{cube_prefix}.hbm_ctrl.slice{i}" for i in range(n_slices)]
|
return [f"{cube_prefix}.hbm_ctrl.slice{i}" for i in range(n_slices)]
|
||||||
|
|
||||||
|
def _mmu_msg_fanout(self, env: simpy.Environment, txn: Any) -> Generator:
|
||||||
|
"""Fan out MmuMapMsg/MmuUnmapMsg to target PE_MMU(s) via NOC.
|
||||||
|
|
||||||
|
Routes through find_node_path (M_CPU → NOC → PE_MMU command edges).
|
||||||
|
PE_MMU is a terminal node — completes the transaction directly.
|
||||||
|
"""
|
||||||
|
request = txn.request
|
||||||
|
target_pe = getattr(request, "target_pe", "all")
|
||||||
|
cube_prefix = self.node.id.rsplit(".", 1)[0] # e.g. "sip0.cube0"
|
||||||
|
pe_ids = self._resolve_pe_ids(target_pe)
|
||||||
|
|
||||||
|
if not pe_ids:
|
||||||
|
txn.done.succeed()
|
||||||
|
return
|
||||||
|
|
||||||
|
# Fan out to each PE_MMU
|
||||||
|
sub_dones: list[simpy.Event] = []
|
||||||
|
for pe_id in pe_ids:
|
||||||
|
pe_mmu_id = f"{cube_prefix}.pe{pe_id}.pe_mmu"
|
||||||
|
try:
|
||||||
|
path = self.ctx.router.find_node_path(self.node.id, pe_mmu_id)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
if len(path) < 2:
|
||||||
|
continue
|
||||||
|
sub_done = env.event()
|
||||||
|
sub_txn = Transaction(
|
||||||
|
request=request, path=path, step=0,
|
||||||
|
nbytes=0, done=sub_done,
|
||||||
|
)
|
||||||
|
yield self.out_ports[path[1]].put(sub_txn.advance())
|
||||||
|
sub_dones.append(sub_done)
|
||||||
|
|
||||||
|
# Wait for all PE_MMUs to complete
|
||||||
|
for sd in sub_dones:
|
||||||
|
yield sd
|
||||||
|
|
||||||
|
# Send aggregate response on reverse path
|
||||||
|
reverse_path = list(reversed(txn.path))
|
||||||
|
if len(reverse_path) >= 2:
|
||||||
|
from kernbench.runtime_api.kernel import ResponseMsg
|
||||||
|
|
||||||
|
parts = self.node.id.split(".")
|
||||||
|
cube_id = int(parts[1].replace("cube", ""))
|
||||||
|
resp_msg = ResponseMsg(
|
||||||
|
correlation_id=request.correlation_id,
|
||||||
|
request_id=request.request_id,
|
||||||
|
src_cube=cube_id, src_pe=-1, success=True,
|
||||||
|
)
|
||||||
|
resp_txn = Transaction(
|
||||||
|
request=resp_msg, path=reverse_path, step=0,
|
||||||
|
nbytes=0, done=env.event(), is_response=True,
|
||||||
|
)
|
||||||
|
yield self.out_ports[reverse_path[1]].put(resp_txn.advance())
|
||||||
|
else:
|
||||||
|
txn.done.succeed()
|
||||||
|
|
||||||
def _resolve_pe_ids(self, target_pe: int | str) -> list[int]:
|
def _resolve_pe_ids(self, target_pe: int | str) -> list[int]:
|
||||||
"""Return list of PE IDs to fan out to (used by kernel launch fan-out)."""
|
"""Return list of PE IDs to fan out to (used by kernel launch fan-out)."""
|
||||||
if isinstance(target_pe, int):
|
if isinstance(target_pe, int):
|
||||||
|
|||||||
@@ -84,10 +84,13 @@ class PeCpuComponent(ComponentBase):
|
|||||||
tl = TLContext(pe_id=self._pe_idx, dispatch_cycles=0)
|
tl = TLContext(pe_id=self._pe_idx, dispatch_cycles=0)
|
||||||
|
|
||||||
# Unpack KernelLaunchMsg.args into positional args for kernel function
|
# Unpack KernelLaunchMsg.args into positional args for kernel function
|
||||||
# TensorArg → PA (pointer), ScalarArg → value
|
# TensorArg → VA base (or PA fallback), ScalarArg → value
|
||||||
kernel_args: list = []
|
kernel_args: list = []
|
||||||
for arg in request.args:
|
for arg in request.args:
|
||||||
if arg.arg_kind == "tensor":
|
if arg.arg_kind == "tensor":
|
||||||
|
if arg.va_base:
|
||||||
|
kernel_args.append(arg.va_base)
|
||||||
|
else:
|
||||||
shard = self._find_shard(arg.shards)
|
shard = self._find_shard(arg.shards)
|
||||||
kernel_args.append(shard.pa)
|
kernel_args.append(shard.pa)
|
||||||
elif arg.arg_kind == "scalar":
|
elif arg.arg_kind == "scalar":
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ class PeDmaComponent(PeEngineBase):
|
|||||||
super().__init__(node, ctx)
|
super().__init__(node, ctx)
|
||||||
self._dma_read: simpy.Resource | None = None
|
self._dma_read: simpy.Resource | None = None
|
||||||
self._dma_write: simpy.Resource | None = None
|
self._dma_write: simpy.Resource | None = None
|
||||||
|
self._mmu = None # PeMMU instance, set by engine wiring
|
||||||
|
|
||||||
def init_resources(self, env: simpy.Environment) -> None:
|
def init_resources(self, env: simpy.Environment) -> None:
|
||||||
self._dma_read = simpy.Resource(env, capacity=1)
|
self._dma_read = simpy.Resource(env, capacity=1)
|
||||||
@@ -48,20 +49,32 @@ class PeDmaComponent(PeEngineBase):
|
|||||||
cmd = pe_txn.command
|
cmd = pe_txn.command
|
||||||
assert self._dma_read is not None and self._dma_write is not None
|
assert self._dma_read is not None and self._dma_write is not None
|
||||||
|
|
||||||
# Determine direction and target PA
|
# Determine direction and target address (VA → PA via MMU)
|
||||||
if isinstance(cmd, DmaReadCmd):
|
if isinstance(cmd, DmaReadCmd):
|
||||||
dma_res = self._dma_read
|
dma_res = self._dma_read
|
||||||
target_pa = cmd.src_pa
|
raw_addr = cmd.src_addr
|
||||||
is_write = False
|
is_write = False
|
||||||
elif isinstance(cmd, DmaWriteCmd):
|
elif isinstance(cmd, DmaWriteCmd):
|
||||||
dma_res = self._dma_write
|
dma_res = self._dma_write
|
||||||
target_pa = cmd.dst_pa
|
raw_addr = cmd.dst_addr
|
||||||
is_write = True
|
is_write = True
|
||||||
else:
|
else:
|
||||||
pe_txn.done.succeed()
|
pe_txn.done.succeed()
|
||||||
return
|
return
|
||||||
|
|
||||||
# Resolve PA → HBM node and compute path
|
# Translate VA → PA via MMU (if available), then resolve HBM node
|
||||||
|
# If MMU has no mapping for this address (PageFault), treat as PA directly
|
||||||
|
# (backward-compatible with PA-only mode)
|
||||||
|
if self._mmu is not None:
|
||||||
|
from kernbench.policy.address.pe_mmu import PageFault
|
||||||
|
try:
|
||||||
|
target_pa = self._mmu.translate(raw_addr)
|
||||||
|
if self._mmu.overhead_ns > 0:
|
||||||
|
yield env.timeout(self._mmu.overhead_ns)
|
||||||
|
except PageFault:
|
||||||
|
target_pa = raw_addr
|
||||||
|
else:
|
||||||
|
target_pa = raw_addr # fallback: treat as PA directly
|
||||||
pa = PhysAddr.decode(target_pa)
|
pa = PhysAddr.decode(target_pa)
|
||||||
dst_node = self.ctx.resolver.resolve(pa)
|
dst_node = self.ctx.resolver.resolve(pa)
|
||||||
path = self.ctx.router.find_path(self._pe_prefix, dst_node)
|
path = self.ctx.router.find_path(self._pe_prefix, dst_node)
|
||||||
|
|||||||
@@ -0,0 +1,66 @@
|
|||||||
|
"""PE_MMU component: address translation unit.
|
||||||
|
|
||||||
|
Component role: receives MmuMapMsg/MmuUnmapMsg via inbox (independent of PE_CPU).
|
||||||
|
Utility role: PE_DMA/PE_GEMM call mmu.translate() directly (no SimPy overhead).
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Generator
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
import simpy
|
||||||
|
|
||||||
|
from kernbench.components.base import ComponentBase, ComponentRegistry
|
||||||
|
from kernbench.policy.address.pe_mmu import PeMMU
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from kernbench.components.context import ComponentContext
|
||||||
|
from kernbench.topology.types import Node
|
||||||
|
|
||||||
|
|
||||||
|
class PeMmuComponent(ComponentBase):
|
||||||
|
"""PE_MMU: per-PE virtual-to-physical address translation.
|
||||||
|
|
||||||
|
Receives MmuMapMsg/MmuUnmapMsg via inbox and updates the internal
|
||||||
|
page table. PE_DMA and PE_GEMM access the underlying PeMMU object
|
||||||
|
via the ``mmu`` property for synchronous VA→PA translation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
|
||||||
|
super().__init__(node, ctx)
|
||||||
|
page_size = int(node.attrs.get("page_size", 2 * 1024 * 1024))
|
||||||
|
overhead_ns = float(node.attrs.get("tlb_overhead_ns", 0.0))
|
||||||
|
self._mmu = PeMMU(page_size=page_size, overhead_ns=overhead_ns)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mmu(self) -> PeMMU:
|
||||||
|
"""The underlying PeMMU utility object for direct translate() calls."""
|
||||||
|
return self._mmu
|
||||||
|
|
||||||
|
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
|
||||||
|
yield env.timeout(0)
|
||||||
|
|
||||||
|
def _worker(self, env: simpy.Environment) -> Generator:
|
||||||
|
"""Process MmuMapMsg/MmuUnmapMsg from inbox."""
|
||||||
|
from kernbench.runtime_api.kernel import MmuMapMsg, MmuUnmapMsg
|
||||||
|
|
||||||
|
while True:
|
||||||
|
txn: Any = yield self._inbox.get()
|
||||||
|
|
||||||
|
if hasattr(txn, "request"):
|
||||||
|
request = txn.request
|
||||||
|
if isinstance(request, MmuMapMsg):
|
||||||
|
for entry in request.entries:
|
||||||
|
self._mmu.map(
|
||||||
|
va=entry["va"], pa=entry["pa"], size=entry["size"],
|
||||||
|
)
|
||||||
|
txn.done.succeed()
|
||||||
|
elif isinstance(request, MmuUnmapMsg):
|
||||||
|
for entry in request.entries:
|
||||||
|
self._mmu.unmap(va=entry["va"], size=entry["size"])
|
||||||
|
txn.done.succeed()
|
||||||
|
else:
|
||||||
|
# Forward non-MMU transactions normally
|
||||||
|
yield from self._forward_txn(env, txn)
|
||||||
|
else:
|
||||||
|
yield from self._forward_txn(env, txn)
|
||||||
@@ -155,12 +155,12 @@ class PeSchedulerComponent(ComponentBase):
|
|||||||
|
|
||||||
# --- Stage 1: DMA_READ b_tile from HBM ---
|
# --- Stage 1: DMA_READ b_tile from HBM ---
|
||||||
read_done = env.event()
|
read_done = env.event()
|
||||||
b_tile_pa = b.pa + (k_start * N + n_start) * dtype_bytes
|
b_tile_addr = b.addr + (k_start * N + n_start) * dtype_bytes
|
||||||
b_tile_handle = TensorHandle(
|
b_tile_handle = TensorHandle(
|
||||||
id=f"b_tile_{tile_idx}", pa=b_tile_pa,
|
id=f"b_tile_{tile_idx}", addr=b_tile_addr,
|
||||||
shape=(tile_k, tile_n), dtype=dtype, nbytes=tile_nbytes,
|
shape=(tile_k, tile_n), dtype=dtype, nbytes=tile_nbytes,
|
||||||
)
|
)
|
||||||
read_cmd = DmaReadCmd(handle=b_tile_handle, src_pa=b_tile_pa, nbytes=tile_nbytes)
|
read_cmd = DmaReadCmd(handle=b_tile_handle, src_addr=b_tile_addr, nbytes=tile_nbytes)
|
||||||
read_txn = PeTxn(command=read_cmd, done=read_done, pe_prefix=pp)
|
read_txn = PeTxn(command=read_cmd, done=read_done, pe_prefix=pp)
|
||||||
t0 = env.now
|
t0 = env.now
|
||||||
yield self.out_ports[f"{pp}.pe_dma"].put(read_txn)
|
yield self.out_ports[f"{pp}.pe_dma"].put(read_txn)
|
||||||
@@ -176,7 +176,7 @@ class PeSchedulerComponent(ComponentBase):
|
|||||||
# --- Stage 2: COMPUTE (GEMM) ---
|
# --- Stage 2: COMPUTE (GEMM) ---
|
||||||
compute_done = env.event()
|
compute_done = env.event()
|
||||||
out_handle = TensorHandle(
|
out_handle = TensorHandle(
|
||||||
id=f"out_tile_{tile_idx}", pa=0,
|
id=f"out_tile_{tile_idx}", addr=0,
|
||||||
shape=(M, tile_n), dtype=dtype,
|
shape=(M, tile_n), dtype=dtype,
|
||||||
nbytes=M * tile_n * dtype_bytes,
|
nbytes=M * tile_n * dtype_bytes,
|
||||||
)
|
)
|
||||||
@@ -197,9 +197,9 @@ class PeSchedulerComponent(ComponentBase):
|
|||||||
|
|
||||||
# --- Stage 3: DMA_WRITE out_tile to HBM ---
|
# --- Stage 3: DMA_WRITE out_tile to HBM ---
|
||||||
write_done = env.event()
|
write_done = env.event()
|
||||||
out_tile_pa = cmd.out_pa + n_start * dtype_bytes
|
out_tile_pa = cmd.out_addr + n_start * dtype_bytes
|
||||||
write_nbytes = M * tile_n * dtype_bytes
|
write_nbytes = M * tile_n * dtype_bytes
|
||||||
write_cmd = DmaWriteCmd(handle=out_handle, dst_pa=out_tile_pa, nbytes=write_nbytes)
|
write_cmd = DmaWriteCmd(handle=out_handle, dst_addr=out_tile_pa, nbytes=write_nbytes)
|
||||||
write_txn = PeTxn(command=write_cmd, done=write_done, pe_prefix=pp)
|
write_txn = PeTxn(command=write_cmd, done=write_done, pe_prefix=pp)
|
||||||
t0 = env.now
|
t0 = env.now
|
||||||
yield self.out_ports[f"{pp}.pe_dma"].put(write_txn)
|
yield self.out_ports[f"{pp}.pe_dma"].put(write_txn)
|
||||||
@@ -237,7 +237,7 @@ class PeSchedulerComponent(ComponentBase):
|
|||||||
|
|
||||||
# Step 2: DMA_WRITE result to HBM
|
# Step 2: DMA_WRITE result to HBM
|
||||||
write_done = env.event()
|
write_done = env.event()
|
||||||
write_cmd = DmaWriteCmd(handle=cmd.a, dst_pa=cmd.out_pa, nbytes=cmd.out_nbytes)
|
write_cmd = DmaWriteCmd(handle=cmd.a, dst_addr=cmd.out_addr, nbytes=cmd.out_nbytes)
|
||||||
write_txn = PeTxn(command=write_cmd, done=write_done, pe_prefix=pp)
|
write_txn = PeTxn(command=write_cmd, done=write_done, pe_prefix=pp)
|
||||||
yield self.out_ports[f"{pp}.pe_dma"].put(write_txn)
|
yield self.out_ports[f"{pp}.pe_dma"].put(write_txn)
|
||||||
yield write_done
|
yield write_done
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import bisect
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from kernbench.policy.address.phyaddr import PhysAddr
|
from kernbench.policy.address.phyaddr import PhysAddr
|
||||||
@@ -29,6 +30,63 @@ class AddressConfig:
|
|||||||
return self.tcm_bytes_per_pe - self.tcm_scheduler_reserved_bytes
|
return self.tcm_bytes_per_pe - self.tcm_scheduler_reserved_bytes
|
||||||
|
|
||||||
|
|
||||||
|
class _FreeList:
|
||||||
|
"""Offset-based free-list allocator with coalescing."""
|
||||||
|
|
||||||
|
def __init__(self, capacity: int) -> None:
|
||||||
|
self._capacity = capacity
|
||||||
|
self._used = 0
|
||||||
|
self._free: list[tuple[int, int]] = [(0, capacity)] # (offset, size)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def used(self) -> int:
|
||||||
|
return self._used
|
||||||
|
|
||||||
|
@property
|
||||||
|
def total(self) -> int:
|
||||||
|
return self._capacity
|
||||||
|
|
||||||
|
def alloc(self, nbytes: int) -> int:
|
||||||
|
"""Allocate nbytes, return offset. Raises AllocationError if full."""
|
||||||
|
for i, (start, size) in enumerate(self._free):
|
||||||
|
if size >= nbytes:
|
||||||
|
if size == nbytes:
|
||||||
|
self._free.pop(i)
|
||||||
|
else:
|
||||||
|
self._free[i] = (start + nbytes, size - nbytes)
|
||||||
|
self._used += nbytes
|
||||||
|
return start
|
||||||
|
raise AllocationError(
|
||||||
|
f"overflow: need {nbytes}, "
|
||||||
|
f"largest free block {max((s for _, s in self._free), default=0)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def free(self, offset: int, nbytes: int) -> None:
|
||||||
|
"""Return a range to the free-list with coalescing."""
|
||||||
|
self._used -= nbytes
|
||||||
|
new_start = offset
|
||||||
|
new_end = offset + nbytes
|
||||||
|
|
||||||
|
idx = bisect.bisect_left(self._free, (offset,))
|
||||||
|
|
||||||
|
# Coalesce with previous block
|
||||||
|
if idx > 0:
|
||||||
|
prev_start, prev_size = self._free[idx - 1]
|
||||||
|
if prev_start + prev_size == new_start:
|
||||||
|
new_start = prev_start
|
||||||
|
idx -= 1
|
||||||
|
self._free.pop(idx)
|
||||||
|
|
||||||
|
# Coalesce with next block
|
||||||
|
if idx < len(self._free):
|
||||||
|
next_start, next_size = self._free[idx]
|
||||||
|
if new_end == next_start:
|
||||||
|
new_end = next_start + next_size
|
||||||
|
self._free.pop(idx)
|
||||||
|
|
||||||
|
self._free.insert(idx, (new_start, new_end - new_start))
|
||||||
|
|
||||||
|
|
||||||
class PEMemAllocator:
|
class PEMemAllocator:
|
||||||
def __init__(
|
def __init__(
|
||||||
self, rack_id: int, sip_id: int, cube_id: int, pe_id: int, cfg: AddressConfig,
|
self, rack_id: int, sip_id: int, cube_id: int, pe_id: int, cfg: AddressConfig,
|
||||||
@@ -38,39 +96,48 @@ class PEMemAllocator:
|
|||||||
self._cube_id = cube_id
|
self._cube_id = cube_id
|
||||||
self._pe_id = pe_id
|
self._pe_id = pe_id
|
||||||
self._cfg = cfg
|
self._cfg = cfg
|
||||||
self._hbm_cursor = 0
|
self._hbm = _FreeList(cfg.hbm_slice_bytes)
|
||||||
self._tcm_cursor = 0
|
self._tcm = _FreeList(cfg.tcm_allocatable_bytes)
|
||||||
|
|
||||||
def alloc_hbm(self, nbytes: int) -> PhysAddr:
|
def alloc_hbm(self, nbytes: int) -> PhysAddr:
|
||||||
if self._hbm_cursor + nbytes > self._cfg.hbm_slice_bytes:
|
try:
|
||||||
|
offset = self._hbm.alloc(nbytes)
|
||||||
|
except AllocationError:
|
||||||
raise AllocationError(
|
raise AllocationError(
|
||||||
f"HBM overflow: need {nbytes}, "
|
f"HBM overflow: need {nbytes}, "
|
||||||
f"available {self._cfg.hbm_slice_bytes - self._hbm_cursor}"
|
f"available {self._cfg.hbm_slice_bytes - self._hbm.used}"
|
||||||
)
|
)
|
||||||
pa = PhysAddr.pe_hbm_addr(
|
return PhysAddr.pe_hbm_addr(
|
||||||
rack_id=self._rack_id, sip_id=self._sip_id, cube_id=self._cube_id,
|
rack_id=self._rack_id, sip_id=self._sip_id, cube_id=self._cube_id,
|
||||||
pe_id=self._pe_id, pe_local_hbm_offset=self._hbm_cursor,
|
pe_id=self._pe_id, pe_local_hbm_offset=offset,
|
||||||
slice_size_bytes=self._cfg.hbm_slice_bytes,
|
slice_size_bytes=self._cfg.hbm_slice_bytes,
|
||||||
)
|
)
|
||||||
self._hbm_cursor += nbytes
|
|
||||||
return pa
|
def free_hbm(self, pa: PhysAddr, nbytes: int) -> None:
|
||||||
|
# Extract PE-local offset from the PA's hbm_offset
|
||||||
|
pe_slice_start = self._pe_id * self._cfg.hbm_slice_bytes
|
||||||
|
offset = pa.hbm_offset - pe_slice_start
|
||||||
|
self._hbm.free(offset, nbytes)
|
||||||
|
|
||||||
def alloc_tcm(self, nbytes: int) -> PhysAddr:
|
def alloc_tcm(self, nbytes: int) -> PhysAddr:
|
||||||
if self._tcm_cursor + nbytes > self._cfg.tcm_allocatable_bytes:
|
try:
|
||||||
|
offset = self._tcm.alloc(nbytes)
|
||||||
|
except AllocationError:
|
||||||
raise AllocationError(
|
raise AllocationError(
|
||||||
f"TCM overflow: need {nbytes}, "
|
f"TCM overflow: need {nbytes}, "
|
||||||
f"available {self._cfg.tcm_allocatable_bytes - self._tcm_cursor}"
|
f"available {self._cfg.tcm_allocatable_bytes - self._tcm.used}"
|
||||||
)
|
)
|
||||||
pa = PhysAddr.pe_tcm_addr(
|
return PhysAddr.pe_tcm_addr(
|
||||||
rack_id=self._rack_id, sip_id=self._sip_id, cube_id=self._cube_id,
|
rack_id=self._rack_id, sip_id=self._sip_id, cube_id=self._cube_id,
|
||||||
pe_id=self._pe_id, tcm_offset=self._tcm_cursor,
|
pe_id=self._pe_id, tcm_offset=offset,
|
||||||
)
|
)
|
||||||
self._tcm_cursor += nbytes
|
|
||||||
return pa
|
def free_tcm(self, pa: PhysAddr, nbytes: int) -> None:
|
||||||
|
self._tcm.free(pa.sub_offset, nbytes)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def hbm_used(self) -> int:
|
def hbm_used(self) -> int:
|
||||||
return self._hbm_cursor
|
return self._hbm.used
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def hbm_total(self) -> int:
|
def hbm_total(self) -> int:
|
||||||
@@ -78,7 +145,7 @@ class PEMemAllocator:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def tcm_used(self) -> int:
|
def tcm_used(self) -> int:
|
||||||
return self._tcm_cursor
|
return self._tcm.used
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def tcm_total(self) -> int:
|
def tcm_total(self) -> int:
|
||||||
|
|||||||
@@ -0,0 +1,66 @@
|
|||||||
|
"""PeMMU: per-PE virtual-to-physical address translation.
|
||||||
|
|
||||||
|
Page-aligned dict lookup for O(1) VA→PA translation.
|
||||||
|
Used as a utility class by PE_DMA, PE_GEMM (direct call),
|
||||||
|
and as a component inbox target for MmuMapMsg/MmuUnmapMsg.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
||||||
|
class PageFault(Exception):
|
||||||
|
"""Raised when VA has no mapping in the page table."""
|
||||||
|
|
||||||
|
def __init__(self, va: int | str) -> None:
|
||||||
|
if isinstance(va, str):
|
||||||
|
super().__init__(va)
|
||||||
|
else:
|
||||||
|
self.va = va
|
||||||
|
super().__init__(f"PageFault at VA 0x{va:x}")
|
||||||
|
|
||||||
|
|
||||||
|
class PeMMU:
|
||||||
|
"""Per-PE MMU with page-aligned VA→PA translation table.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
page_size: Page size in bytes (default 2 MB).
|
||||||
|
overhead_ns: Per-access TLB lookup latency in nanoseconds.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
page_size: int = 2 * 1024 * 1024,
|
||||||
|
overhead_ns: float = 0.0,
|
||||||
|
) -> None:
|
||||||
|
self._page_size = page_size
|
||||||
|
self._page_shift = (page_size - 1).bit_length()
|
||||||
|
self._page_mask = page_size - 1
|
||||||
|
self._table: dict[int, int] = {} # va_page_number → pa_page_base
|
||||||
|
self._overhead_ns = overhead_ns
|
||||||
|
|
||||||
|
@property
|
||||||
|
def overhead_ns(self) -> float:
|
||||||
|
return self._overhead_ns
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_entries(self) -> int:
|
||||||
|
return len(self._table)
|
||||||
|
|
||||||
|
def map(self, va: int, pa: int, size: int) -> None:
|
||||||
|
"""Register VA→PA mapping for a contiguous range."""
|
||||||
|
for off in range(0, size, self._page_size):
|
||||||
|
vpn = (va + off) >> self._page_shift
|
||||||
|
self._table[vpn] = pa + off
|
||||||
|
|
||||||
|
def unmap(self, va: int, size: int) -> None:
|
||||||
|
"""Remove VA mapping for a contiguous range."""
|
||||||
|
for off in range(0, size, self._page_size):
|
||||||
|
vpn = (va + off) >> self._page_shift
|
||||||
|
self._table.pop(vpn, None)
|
||||||
|
|
||||||
|
def translate(self, va: int) -> int:
|
||||||
|
"""Translate VA to PA. Raises PageFault if unmapped."""
|
||||||
|
vpn = va >> self._page_shift
|
||||||
|
pa_page_base = self._table.get(vpn)
|
||||||
|
if pa_page_base is None:
|
||||||
|
raise PageFault(va)
|
||||||
|
return pa_page_base + (va & self._page_mask)
|
||||||
@@ -0,0 +1,93 @@
|
|||||||
|
"""VirtualAllocator: device-wide VA space management with free-list.
|
||||||
|
|
||||||
|
Allocations are page-aligned. Freed ranges are coalesced with adjacent
|
||||||
|
free blocks to allow larger re-allocations.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import bisect
|
||||||
|
|
||||||
|
|
||||||
|
class VaAllocationError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class VirtualAllocator:
|
||||||
|
"""Manages a contiguous VA address space with page-aligned alloc/free.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
va_base: Start of the VA range.
|
||||||
|
va_size: Total size of the VA range in bytes.
|
||||||
|
page_size: Page granularity in bytes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
va_base: int,
|
||||||
|
va_size: int,
|
||||||
|
page_size: int = 2 * 1024 * 1024,
|
||||||
|
) -> None:
|
||||||
|
self._va_base = va_base
|
||||||
|
self._va_size = va_size
|
||||||
|
self._page_size = page_size
|
||||||
|
self._used = 0
|
||||||
|
# Free list: sorted list of (start, size) tuples
|
||||||
|
self._free: list[tuple[int, int]] = [(va_base, va_size)]
|
||||||
|
|
||||||
|
def _align_up(self, nbytes: int) -> int:
|
||||||
|
"""Round up to page boundary."""
|
||||||
|
return ((nbytes + self._page_size - 1) // self._page_size) * self._page_size
|
||||||
|
|
||||||
|
def alloc(self, nbytes: int) -> int:
|
||||||
|
"""Allocate a contiguous VA range. Returns the start VA."""
|
||||||
|
aligned = self._align_up(nbytes)
|
||||||
|
for i, (start, size) in enumerate(self._free):
|
||||||
|
if size >= aligned:
|
||||||
|
# Take from the beginning of this free block
|
||||||
|
if size == aligned:
|
||||||
|
self._free.pop(i)
|
||||||
|
else:
|
||||||
|
self._free[i] = (start + aligned, size - aligned)
|
||||||
|
self._used += aligned
|
||||||
|
return start
|
||||||
|
raise VaAllocationError(
|
||||||
|
f"Out of VA space: need {aligned}, largest free block "
|
||||||
|
f"{max((s for _, s in self._free), default=0)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def free(self, va: int, nbytes: int) -> None:
|
||||||
|
"""Free a VA range and coalesce with adjacent free blocks."""
|
||||||
|
aligned = self._align_up(nbytes)
|
||||||
|
self._used -= aligned
|
||||||
|
|
||||||
|
# Insert into sorted free list and coalesce
|
||||||
|
new_start = va
|
||||||
|
new_end = va + aligned
|
||||||
|
|
||||||
|
# Find insertion point
|
||||||
|
idx = bisect.bisect_left(self._free, (va,))
|
||||||
|
|
||||||
|
# Try coalesce with previous block
|
||||||
|
if idx > 0:
|
||||||
|
prev_start, prev_size = self._free[idx - 1]
|
||||||
|
if prev_start + prev_size == new_start:
|
||||||
|
new_start = prev_start
|
||||||
|
idx -= 1
|
||||||
|
self._free.pop(idx)
|
||||||
|
|
||||||
|
# Try coalesce with next block
|
||||||
|
if idx < len(self._free):
|
||||||
|
next_start, next_size = self._free[idx]
|
||||||
|
if new_end == next_start:
|
||||||
|
new_end = next_start + next_size
|
||||||
|
self._free.pop(idx)
|
||||||
|
|
||||||
|
self._free.insert(idx, (new_start, new_end - new_start))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def used(self) -> int:
|
||||||
|
return self._used
|
||||||
|
|
||||||
|
@property
|
||||||
|
def total(self) -> int:
|
||||||
|
return self._va_size
|
||||||
@@ -19,8 +19,18 @@ class RuntimeContext:
|
|||||||
_handles: list[RequestHandle] = field(default_factory=list, init=False)
|
_handles: list[RequestHandle] = field(default_factory=list, init=False)
|
||||||
_completed: set[RequestHandle] = field(default_factory=set, init=False)
|
_completed: set[RequestHandle] = field(default_factory=set, init=False)
|
||||||
_allocators: dict[int, Any] = field(default_factory=dict, init=False)
|
_allocators: dict[int, Any] = field(default_factory=dict, init=False)
|
||||||
|
_va_allocator: Any = field(default=None, init=False)
|
||||||
|
_mmus: dict[int, Any] = field(default_factory=dict, init=False)
|
||||||
_tensor_counter: int = field(default=0, init=False)
|
_tensor_counter: int = field(default=0, init=False)
|
||||||
_traces: list[dict] = field(default_factory=list, init=False)
|
_traces: list[dict] = field(default_factory=list, init=False)
|
||||||
|
_tensors: list[Any] = field(default_factory=list, init=False)
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, *exc):
|
||||||
|
self.cleanup()
|
||||||
|
return False
|
||||||
|
|
||||||
def submit(self, request: Any) -> RequestHandle:
|
def submit(self, request: Any) -> RequestHandle:
|
||||||
submit_fn = getattr(self.engine, "submit", None)
|
submit_fn = getattr(self.engine, "submit", None)
|
||||||
@@ -58,6 +68,92 @@ class RuntimeContext:
|
|||||||
def handles(self) -> list[RequestHandle]:
|
def handles(self) -> list[RequestHandle]:
|
||||||
return list(self._handles)
|
return list(self._handles)
|
||||||
|
|
||||||
|
# ── Tensor lifecycle ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _free_tensor(self, tensor: Any) -> None:
|
||||||
|
"""Free a single tensor: unmap MMU, return VA and PA."""
|
||||||
|
handle = tensor._handle
|
||||||
|
if handle is None:
|
||||||
|
return
|
||||||
|
tensor._handle = None
|
||||||
|
|
||||||
|
if not handle.va_base:
|
||||||
|
return
|
||||||
|
|
||||||
|
from kernbench.runtime_api.kernel import MmuUnmapMsg
|
||||||
|
|
||||||
|
dp_policy = None
|
||||||
|
if tensor._dp_metadata is not None:
|
||||||
|
dp_policy = tensor._dp_metadata.dp_policy
|
||||||
|
|
||||||
|
is_cube_replicate = (
|
||||||
|
dp_policy is not None and dp_policy.cube == "replicate"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Send MmuUnmapMsg through fabric
|
||||||
|
from collections import defaultdict
|
||||||
|
if is_cube_replicate:
|
||||||
|
cube_groups: dict[tuple[int, int], list] = defaultdict(list)
|
||||||
|
for shard in handle.shards:
|
||||||
|
cube_groups[(shard.sip, shard.cube)].append(shard)
|
||||||
|
for (sip, cube), group_shards in cube_groups.items():
|
||||||
|
entries = tuple(
|
||||||
|
{"va": handle.va_base + s.offset_bytes, "size": s.nbytes}
|
||||||
|
for s in group_shards
|
||||||
|
)
|
||||||
|
msg = MmuUnmapMsg(
|
||||||
|
correlation_id=self.correlation_id,
|
||||||
|
request_id=f"unmap_{tensor.name}_s{sip}c{cube}",
|
||||||
|
entries=entries,
|
||||||
|
target_sips=(sip,),
|
||||||
|
target_cubes=(cube,),
|
||||||
|
target_pe="all",
|
||||||
|
)
|
||||||
|
h = self.submit(msg)
|
||||||
|
self.wait(h)
|
||||||
|
else:
|
||||||
|
entries = tuple(
|
||||||
|
{"va": handle.va_base + s.offset_bytes, "size": s.nbytes}
|
||||||
|
for s in handle.shards
|
||||||
|
)
|
||||||
|
sip_set = sorted({s.sip for s in handle.shards})
|
||||||
|
cube_set = sorted({s.cube for s in handle.shards})
|
||||||
|
msg = MmuUnmapMsg(
|
||||||
|
correlation_id=self.correlation_id,
|
||||||
|
request_id=f"unmap_{tensor.name}",
|
||||||
|
entries=entries,
|
||||||
|
target_sips=tuple(sip_set),
|
||||||
|
target_cubes=tuple(cube_set),
|
||||||
|
target_pe="all",
|
||||||
|
)
|
||||||
|
h = self.submit(msg)
|
||||||
|
self.wait(h)
|
||||||
|
|
||||||
|
# Return VA space
|
||||||
|
if self._va_allocator is not None:
|
||||||
|
self._va_allocator.free(handle.va_base, handle.nbytes)
|
||||||
|
|
||||||
|
# Return PA space
|
||||||
|
if self._allocators:
|
||||||
|
for shard in handle.shards:
|
||||||
|
flat_idx = (
|
||||||
|
shard.sip * self._num_cubes * self._pes_per_cube
|
||||||
|
+ shard.cube * self._pes_per_cube
|
||||||
|
+ shard.pe
|
||||||
|
)
|
||||||
|
alloc = self._allocators.get(flat_idx)
|
||||||
|
if alloc is not None:
|
||||||
|
from kernbench.policy.address.phyaddr import PhysAddr
|
||||||
|
alloc.free_hbm(PhysAddr.decode(shard.pa), shard.nbytes)
|
||||||
|
|
||||||
|
def cleanup(self) -> None:
|
||||||
|
"""Free all tensors created by this context."""
|
||||||
|
for ref in self._tensors:
|
||||||
|
t = ref()
|
||||||
|
if t is not None and t._handle is not None:
|
||||||
|
self._free_tensor(t)
|
||||||
|
self._tensors.clear()
|
||||||
|
|
||||||
# ── PyTorch-like tensor API ──────────────────────────────────────
|
# ── PyTorch-like tensor API ──────────────────────────────────────
|
||||||
|
|
||||||
def _ensure_allocators(self) -> dict:
|
def _ensure_allocators(self) -> dict:
|
||||||
@@ -111,6 +207,26 @@ class RuntimeContext:
|
|||||||
self._allocators[flat_idx] = PEMemAllocator(
|
self._allocators[flat_idx] = PEMemAllocator(
|
||||||
rack_id=0, sip_id=sip_id, cube_id=cube_id, pe_id=pe_id, cfg=cfg,
|
rack_id=0, sip_id=sip_id, cube_id=cube_id, pe_id=pe_id, cfg=cfg,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Initialize VA allocator and per-PE MMUs
|
||||||
|
from kernbench.policy.address.pe_mmu import PeMMU
|
||||||
|
from kernbench.policy.address.va_allocator import VirtualAllocator
|
||||||
|
|
||||||
|
pe_mmu_attrs = pe_comps.get("pe_mmu", {}).get("attrs", {})
|
||||||
|
page_size = int(pe_mmu_attrs.get("page_size", 4096))
|
||||||
|
tlb_overhead_ns = float(pe_mmu_attrs.get("tlb_overhead_ns", 0.0))
|
||||||
|
|
||||||
|
self._va_allocator = VirtualAllocator(
|
||||||
|
va_base=0x1_0000_0000,
|
||||||
|
va_size=64 * (1 << 30), # 64 GB VA space
|
||||||
|
page_size=page_size,
|
||||||
|
)
|
||||||
|
total_pes = sip_count * cubes_per_sip * pes_per_cube
|
||||||
|
for flat_idx in range(total_pes):
|
||||||
|
self._mmus[flat_idx] = PeMMU(
|
||||||
|
page_size=page_size, overhead_ns=tlb_overhead_ns,
|
||||||
|
)
|
||||||
|
|
||||||
return self._allocators
|
return self._allocators
|
||||||
|
|
||||||
def _next_tensor_name(self) -> str:
|
def _next_tensor_name(self) -> str:
|
||||||
@@ -122,45 +238,41 @@ class RuntimeContext:
|
|||||||
shape: tuple[int, ...],
|
shape: tuple[int, ...],
|
||||||
dtype: str = "f16",
|
dtype: str = "f16",
|
||||||
*,
|
*,
|
||||||
placement: list | None = None,
|
|
||||||
dp: Any = None,
|
dp: Any = None,
|
||||||
name: str | None = None,
|
name: str | None = None,
|
||||||
):
|
):
|
||||||
"""Create a tensor and deploy to HBM with zero-fill (like torch.zeros)."""
|
"""Create a tensor and deploy to HBM with zero-fill (like torch.zeros)."""
|
||||||
return self._create_tensor(shape, dtype, placement, name, pattern="zero", dp=dp)
|
return self._create_tensor(shape, dtype, name, pattern="zero", dp=dp)
|
||||||
|
|
||||||
def empty(
|
def empty(
|
||||||
self,
|
self,
|
||||||
shape: tuple[int, ...],
|
shape: tuple[int, ...],
|
||||||
dtype: str = "f16",
|
dtype: str = "f16",
|
||||||
*,
|
*,
|
||||||
placement: list | None = None,
|
|
||||||
dp: Any = None,
|
dp: Any = None,
|
||||||
name: str | None = None,
|
name: str | None = None,
|
||||||
):
|
):
|
||||||
"""Allocate a tensor in HBM without initialization (like torch.empty)."""
|
"""Allocate a tensor in HBM without initialization (like torch.empty)."""
|
||||||
return self._create_tensor(shape, dtype, placement, name, pattern=None, dp=dp)
|
return self._create_tensor(shape, dtype, name, pattern=None, dp=dp)
|
||||||
|
|
||||||
def _create_tensor(
|
def _create_tensor(
|
||||||
self,
|
self,
|
||||||
shape: tuple[int, ...],
|
shape: tuple[int, ...],
|
||||||
dtype: str,
|
dtype: str,
|
||||||
placement: list | None,
|
|
||||||
name: str | None,
|
name: str | None,
|
||||||
pattern: str | None,
|
pattern: str | None,
|
||||||
dp: Any = None,
|
dp: Any = None,
|
||||||
):
|
):
|
||||||
from kernbench.policy.placement.dp import DPPolicy, ShardSpec, resolve_dp_policy
|
from kernbench.policy.placement.dp import DPPolicy, resolve_dp_policy
|
||||||
from kernbench.runtime_api.kernel import MemoryWriteMsg
|
from kernbench.runtime_api.kernel import MemoryWriteMsg
|
||||||
from kernbench.runtime_api.tensor import Tensor, deploy_tensor, dtype_itemsize
|
from kernbench.runtime_api.tensor import Tensor, deploy_tensor, dtype_itemsize
|
||||||
|
|
||||||
|
if not isinstance(dp, DPPolicy):
|
||||||
|
raise ValueError("dp=DPPolicy(...) is required for tensor creation")
|
||||||
|
|
||||||
tensor_name = name or self._next_tensor_name()
|
tensor_name = name or self._next_tensor_name()
|
||||||
t = Tensor(shape=shape, dtype=dtype, name=tensor_name)
|
t = Tensor(shape=shape, dtype=dtype, name=tensor_name)
|
||||||
|
|
||||||
dp_policy: DPPolicy | None = None
|
|
||||||
|
|
||||||
# Resolve placement: dp= takes priority over placement=
|
|
||||||
if dp is not None and isinstance(dp, DPPolicy):
|
|
||||||
dp_policy = dp
|
dp_policy = dp
|
||||||
allocators = self._ensure_allocators()
|
allocators = self._ensure_allocators()
|
||||||
itemsize = dtype_itemsize(dtype)
|
itemsize = dtype_itemsize(dtype)
|
||||||
@@ -170,15 +282,13 @@ class RuntimeContext:
|
|||||||
dp, shape=shape_2d, itemsize=itemsize,
|
dp, shape=shape_2d, itemsize=itemsize,
|
||||||
num_pe=self._pes_per_cube, num_cubes=total_cubes,
|
num_pe=self._pes_per_cube, num_cubes=total_cubes,
|
||||||
)
|
)
|
||||||
elif placement is None:
|
|
||||||
placement = [ShardSpec(pe_index=0, offset_bytes=0, nbytes=t.nbytes)]
|
|
||||||
|
|
||||||
# Infer target_pe from placement: multi-PE → "all", single PE → pe_index
|
# Infer target_pe from placement: multi-PE → "all", single PE → pe_index
|
||||||
pe_indices = {s.pe_index for s in placement}
|
pe_indices = {s.pe_index for s in placement}
|
||||||
target_pe: int | str = "all" if len(pe_indices) > 1 else next(iter(pe_indices))
|
target_pe: int | str = "all" if len(pe_indices) > 1 else next(iter(pe_indices))
|
||||||
t.to(placement=placement, target_pe=target_pe, dp_policy=dp_policy)
|
t.to(placement=placement, target_pe=target_pe, dp_policy=dp_policy)
|
||||||
|
|
||||||
# Allocate PAs via PEMemAllocator
|
# Allocate PAs via PEMemAllocator + VA via VirtualAllocator
|
||||||
allocators = self._ensure_allocators()
|
allocators = self._ensure_allocators()
|
||||||
handle = deploy_tensor(
|
handle = deploy_tensor(
|
||||||
name=tensor_name,
|
name=tensor_name,
|
||||||
@@ -186,8 +296,64 @@ class RuntimeContext:
|
|||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
placement=placement,
|
placement=placement,
|
||||||
allocators=allocators,
|
allocators=allocators,
|
||||||
|
va_allocator=self._va_allocator,
|
||||||
|
mmus=self._mmus,
|
||||||
)
|
)
|
||||||
t._handle = handle
|
t._handle = handle
|
||||||
|
import weakref
|
||||||
|
t._ctx_ref = weakref.ref(self)
|
||||||
|
self._tensors.append(weakref.ref(t))
|
||||||
|
|
||||||
|
# Install VA→PA mappings via fabric MmuMapMsg
|
||||||
|
if handle.va_base:
|
||||||
|
from collections import defaultdict
|
||||||
|
from kernbench.runtime_api.kernel import MmuMapMsg
|
||||||
|
|
||||||
|
is_cube_replicate = (
|
||||||
|
dp_policy is not None and dp_policy.cube == "replicate"
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_cube_replicate:
|
||||||
|
# Replicate: each (sip, cube) gets only its own local PA mappings
|
||||||
|
cube_groups: dict[tuple[int, int], list] = defaultdict(list)
|
||||||
|
for shard in handle.shards:
|
||||||
|
cube_groups[(shard.sip, shard.cube)].append(shard)
|
||||||
|
|
||||||
|
for (sip, cube), group_shards in cube_groups.items():
|
||||||
|
entries = tuple(
|
||||||
|
{"va": handle.va_base + s.offset_bytes,
|
||||||
|
"pa": s.pa, "size": s.nbytes}
|
||||||
|
for s in group_shards
|
||||||
|
)
|
||||||
|
msg = MmuMapMsg(
|
||||||
|
correlation_id=self.correlation_id,
|
||||||
|
request_id=f"mmu_{tensor_name}_s{sip}c{cube}",
|
||||||
|
entries=entries,
|
||||||
|
target_sips=(sip,),
|
||||||
|
target_cubes=(cube,),
|
||||||
|
target_pe="all",
|
||||||
|
)
|
||||||
|
h = self.submit(msg)
|
||||||
|
self.wait(h)
|
||||||
|
else:
|
||||||
|
# Sharded: broadcast all mappings to all target (sip, cube)s
|
||||||
|
entries = tuple(
|
||||||
|
{"va": handle.va_base + s.offset_bytes,
|
||||||
|
"pa": s.pa, "size": s.nbytes}
|
||||||
|
for s in handle.shards
|
||||||
|
)
|
||||||
|
sip_set = sorted({s.sip for s in handle.shards})
|
||||||
|
cube_set = sorted({s.cube for s in handle.shards})
|
||||||
|
msg = MmuMapMsg(
|
||||||
|
correlation_id=self.correlation_id,
|
||||||
|
request_id=f"mmu_{tensor_name}",
|
||||||
|
entries=entries,
|
||||||
|
target_sips=tuple(sip_set),
|
||||||
|
target_cubes=tuple(cube_set),
|
||||||
|
target_pe="all",
|
||||||
|
)
|
||||||
|
h = self.submit(msg)
|
||||||
|
self.wait(h)
|
||||||
|
|
||||||
# Submit MemoryWriteMsg per shard (deploy data to device)
|
# Submit MemoryWriteMsg per shard (deploy data to device)
|
||||||
if pattern is not None:
|
if pattern is not None:
|
||||||
|
|||||||
@@ -69,6 +69,7 @@ class TensorArgShard:
|
|||||||
class TensorArg:
|
class TensorArg:
|
||||||
shards: tuple[TensorArgShard, ...]
|
shards: tuple[TensorArgShard, ...]
|
||||||
arg_kind: Literal["tensor"] = "tensor"
|
arg_kind: Literal["tensor"] = "tensor"
|
||||||
|
va_base: int = 0 # VA base address for the entire tensor
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@@ -121,3 +122,33 @@ class PeDmaMsg:
|
|||||||
nbytes: int
|
nbytes: int
|
||||||
is_write: bool = False
|
is_write: bool = False
|
||||||
msg_type: Literal["pe_dma"] = "pe_dma"
|
msg_type: Literal["pe_dma"] = "pe_dma"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class MmuMapMsg:
|
||||||
|
"""MMU mapping install: broadcast VA→PA entries to target PEs.
|
||||||
|
|
||||||
|
Sent via fabric: Host → PCIE_EP → IO_CPU → M_CPU → NOC → PE_MMU.
|
||||||
|
target_sips controls which SIPs receive the message.
|
||||||
|
"""
|
||||||
|
|
||||||
|
correlation_id: str
|
||||||
|
request_id: str
|
||||||
|
entries: tuple[dict, ...] # ({"va": int, "pa": int, "size": int}, ...)
|
||||||
|
target_sips: tuple[int, ...] | Literal["all"] = "all"
|
||||||
|
target_cubes: tuple[int, ...] | Literal["all"] = "all"
|
||||||
|
target_pe: int | Literal["all"] = "all"
|
||||||
|
msg_type: Literal["mmu_map"] = "mmu_map"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class MmuUnmapMsg:
|
||||||
|
"""MMU mapping removal: broadcast VA ranges to unmap from all PEs."""
|
||||||
|
|
||||||
|
correlation_id: str
|
||||||
|
request_id: str
|
||||||
|
entries: tuple[dict, ...] # ({"va": int, "size": int}, ...)
|
||||||
|
target_sips: tuple[int, ...] | Literal["all"] = "all"
|
||||||
|
target_cubes: tuple[int, ...] | Literal["all"] = "all"
|
||||||
|
target_pe: int | Literal["all"] = "all"
|
||||||
|
msg_type: Literal["mmu_unmap"] = "mmu_unmap"
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
import weakref
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
@@ -26,6 +27,7 @@ class TensorHandle:
|
|||||||
dtype: str
|
dtype: str
|
||||||
itemsize: int
|
itemsize: int
|
||||||
shards: tuple[TensorShard, ...]
|
shards: tuple[TensorShard, ...]
|
||||||
|
va_base: int = 0 # VA base address for the entire tensor
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def nbytes(self) -> int:
|
def nbytes(self) -> int:
|
||||||
@@ -56,8 +58,19 @@ def deploy_tensor(
|
|||||||
placement: list[ShardSpec],
|
placement: list[ShardSpec],
|
||||||
allocators: dict[int, PEMemAllocator],
|
allocators: dict[int, PEMemAllocator],
|
||||||
mem_kind: Literal["hbm", "tcm"] = "hbm",
|
mem_kind: Literal["hbm", "tcm"] = "hbm",
|
||||||
|
va_allocator=None,
|
||||||
|
mmus: dict | None = None,
|
||||||
) -> TensorHandle:
|
) -> TensorHandle:
|
||||||
|
from kernbench.policy.address.pe_mmu import PeMMU
|
||||||
|
|
||||||
isize = dtype_itemsize(dtype)
|
isize = dtype_itemsize(dtype)
|
||||||
|
total_nbytes = math.prod(shape) * isize
|
||||||
|
|
||||||
|
# Allocate VA range for the entire tensor (if VA allocator provided)
|
||||||
|
va_base = 0
|
||||||
|
if va_allocator is not None:
|
||||||
|
va_base = va_allocator.alloc(total_nbytes)
|
||||||
|
|
||||||
shards: list[TensorShard] = []
|
shards: list[TensorShard] = []
|
||||||
for spec in placement:
|
for spec in placement:
|
||||||
alloc = allocators[spec.pe_index]
|
alloc = allocators[spec.pe_index]
|
||||||
@@ -65,20 +78,29 @@ def deploy_tensor(
|
|||||||
pa = alloc.alloc_hbm(spec.nbytes)
|
pa = alloc.alloc_hbm(spec.nbytes)
|
||||||
else:
|
else:
|
||||||
pa = alloc.alloc_tcm(spec.nbytes)
|
pa = alloc.alloc_tcm(spec.nbytes)
|
||||||
|
encoded_pa = pa.encode()
|
||||||
shards.append(TensorShard(
|
shards.append(TensorShard(
|
||||||
sip=alloc._sip_id,
|
sip=alloc._sip_id,
|
||||||
cube=alloc._cube_id,
|
cube=alloc._cube_id,
|
||||||
pe=alloc._pe_id,
|
pe=alloc._pe_id,
|
||||||
pa=pa.encode(),
|
pa=encoded_pa,
|
||||||
nbytes=spec.nbytes,
|
nbytes=spec.nbytes,
|
||||||
offset_bytes=spec.offset_bytes,
|
offset_bytes=spec.offset_bytes,
|
||||||
))
|
))
|
||||||
|
|
||||||
|
# Register VA→PA mapping in all MMUs (broadcast)
|
||||||
|
if va_base and mmus is not None:
|
||||||
|
shard_va = va_base + spec.offset_bytes
|
||||||
|
for mmu in mmus.values():
|
||||||
|
mmu.map(va=shard_va, pa=encoded_pa, size=spec.nbytes)
|
||||||
|
|
||||||
return TensorHandle(
|
return TensorHandle(
|
||||||
name=name,
|
name=name,
|
||||||
shape=shape,
|
shape=shape,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
itemsize=isize,
|
itemsize=isize,
|
||||||
shards=tuple(shards),
|
shards=tuple(shards),
|
||||||
|
va_base=va_base,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -101,8 +123,7 @@ class Tensor:
|
|||||||
|
|
||||||
Usage::
|
Usage::
|
||||||
|
|
||||||
a = ctx.zeros((M, K), dtype="f16")
|
a = ctx.zeros((M, K), dtype="f16", dp=DPPolicy(cube="replicate", pe="replicate"))
|
||||||
a = ctx.zeros((M, K), dtype="f16", placement=dp.replicate(num_pe=8))
|
|
||||||
ctx.launch("kernel_name", kernel_fn, a, b, out, M=M, K=K)
|
ctx.launch("kernel_name", kernel_fn, a, b, out, M=M, K=K)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -117,6 +138,14 @@ class Tensor:
|
|||||||
self.name = name
|
self.name = name
|
||||||
self._dp_metadata: DPMetadata | None = None
|
self._dp_metadata: DPMetadata | None = None
|
||||||
self._handle: TensorHandle | None = None
|
self._handle: TensorHandle | None = None
|
||||||
|
self._ctx_ref: weakref.ref | None = None # set by RuntimeContext
|
||||||
|
|
||||||
|
def __del__(self) -> None:
|
||||||
|
if self._ctx_ref is None or self._handle is None:
|
||||||
|
return
|
||||||
|
ctx = self._ctx_ref()
|
||||||
|
if ctx is not None:
|
||||||
|
ctx._free_tensor(self)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def itemsize(self) -> int:
|
def itemsize(self) -> int:
|
||||||
@@ -133,6 +162,13 @@ class Tensor:
|
|||||||
raise RuntimeError(f"Tensor '{self.name}' is not deployed yet")
|
raise RuntimeError(f"Tensor '{self.name}' is not deployed yet")
|
||||||
return self._handle.shards[0].pa
|
return self._handle.shards[0].pa
|
||||||
|
|
||||||
|
@property
|
||||||
|
def va(self) -> int:
|
||||||
|
"""VA base address for the entire tensor."""
|
||||||
|
if self._handle is None:
|
||||||
|
raise RuntimeError(f"Tensor '{self.name}' is not deployed yet")
|
||||||
|
return self._handle.va_base
|
||||||
|
|
||||||
def to(
|
def to(
|
||||||
self,
|
self,
|
||||||
placement: list[ShardSpec] | None = None,
|
placement: list[ShardSpec] | None = None,
|
||||||
@@ -163,4 +199,5 @@ class Tensor:
|
|||||||
)
|
)
|
||||||
for s in self._handle.shards
|
for s in self._handle.shards
|
||||||
),
|
),
|
||||||
|
va_base=self._handle.va_base,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -98,6 +98,16 @@ class GraphEngine:
|
|||||||
self._components[node_id].in_ports["host"] = host_q
|
self._components[node_id].in_ports["host"] = host_q
|
||||||
self._pe_dma_queues[node_id] = host_q
|
self._pe_dma_queues[node_id] = host_q
|
||||||
|
|
||||||
|
# Wire PE_DMA._mmu to PE_MMU's underlying PeMMU utility object
|
||||||
|
for node_id, node in graph.nodes.items():
|
||||||
|
if node.kind == "pe_dma":
|
||||||
|
# Derive PE_MMU node ID from PE_DMA node ID
|
||||||
|
pe_prefix = node_id.rsplit(".", 1)[0] # e.g. "sip0.cube0.pe0"
|
||||||
|
mmu_id = f"{pe_prefix}.pe_mmu"
|
||||||
|
mmu_comp = self._components.get(mmu_id)
|
||||||
|
if mmu_comp is not None and hasattr(mmu_comp, "mmu"):
|
||||||
|
self._components[node_id]._mmu = mmu_comp.mmu
|
||||||
|
|
||||||
# Start components after all ports are wired (ADR-0015 D3)
|
# Start components after all ports are wired (ADR-0015 D3)
|
||||||
for comp in self._components.values():
|
for comp in self._components.values():
|
||||||
comp.start(self._env)
|
comp.start(self._env)
|
||||||
@@ -119,6 +129,27 @@ class GraphEngine:
|
|||||||
def get_completion(self, handle: RequestHandle) -> tuple[Completion, Trace | None]:
|
def get_completion(self, handle: RequestHandle) -> tuple[Completion, Trace | None]:
|
||||||
return self._results[str(handle)]
|
return self._results[str(handle)]
|
||||||
|
|
||||||
|
def mmu_map(self, va: int, pa: int, size: int) -> None:
|
||||||
|
"""Sideband: install VA→PA mapping in all PE_MMU components."""
|
||||||
|
for node_id, comp in self._components.items():
|
||||||
|
if hasattr(comp, "mmu"):
|
||||||
|
comp.mmu.map(va=va, pa=pa, size=size)
|
||||||
|
|
||||||
|
def mmu_map_pe(
|
||||||
|
self, sip: int, cube: int, pe: int, va: int, pa: int, size: int,
|
||||||
|
) -> None:
|
||||||
|
"""Sideband: install VA→PA mapping in a specific PE's MMU only."""
|
||||||
|
mmu_id = f"sip{sip}.cube{cube}.pe{pe}.pe_mmu"
|
||||||
|
comp = self._components.get(mmu_id)
|
||||||
|
if comp is not None and hasattr(comp, "mmu"):
|
||||||
|
comp.mmu.map(va=va, pa=pa, size=size)
|
||||||
|
|
||||||
|
def mmu_unmap(self, va: int, size: int) -> None:
|
||||||
|
"""Sideband: remove VA mapping from all PE_MMU components."""
|
||||||
|
for node_id, comp in self._components.items():
|
||||||
|
if hasattr(comp, "mmu"):
|
||||||
|
comp.mmu.unmap(va=va, size=size)
|
||||||
|
|
||||||
# ── internal ────────────────────────────────────────────────────
|
# ── internal ────────────────────────────────────────────────────
|
||||||
|
|
||||||
def _wire(
|
def _wire(
|
||||||
@@ -166,6 +197,11 @@ class GraphEngine:
|
|||||||
yield from self._process_memory_direct(key, request, done)
|
yield from self._process_memory_direct(key, request, done)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
from kernbench.runtime_api.kernel import MmuMapMsg, MmuUnmapMsg
|
||||||
|
if isinstance(request, (MmuMapMsg, MmuUnmapMsg)):
|
||||||
|
yield from self._process_mmu_msg(key, request, done)
|
||||||
|
return
|
||||||
|
|
||||||
entries = self._entry_points(request)
|
entries = self._entry_points(request)
|
||||||
if not entries:
|
if not entries:
|
||||||
self._results[key] = (
|
self._results[key] = (
|
||||||
@@ -341,3 +377,59 @@ class GraphEngine:
|
|||||||
return entries
|
return entries
|
||||||
|
|
||||||
raise ValueError(f"unsupported request type: {type(request)}")
|
raise ValueError(f"unsupported request type: {type(request)}")
|
||||||
|
|
||||||
|
def _process_mmu_msg(self, key: str, request: Any, done: simpy.Event):
|
||||||
|
"""Route MmuMapMsg/MmuUnmapMsg through fabric like KernelLaunchMsg.
|
||||||
|
|
||||||
|
Path: Host → PCIE_EP → IO_NOC → IO_CPU → (fan-out) → M_CPU → (fan-out) → PE_MMU
|
||||||
|
"""
|
||||||
|
start_ns = self._env.now
|
||||||
|
target_sips = getattr(request, "target_sips", "all")
|
||||||
|
|
||||||
|
# Determine target SIPs
|
||||||
|
sip_set: set[int] = set()
|
||||||
|
if target_sips == "all":
|
||||||
|
for ep_id in self._resolver.find_all_pcie_eps():
|
||||||
|
sip_id = int(ep_id.split(".")[0].replace("sip", ""))
|
||||||
|
sip_set.add(sip_id)
|
||||||
|
else:
|
||||||
|
sip_set = set(target_sips)
|
||||||
|
|
||||||
|
entries = []
|
||||||
|
for sip_id in sorted(sip_set):
|
||||||
|
entries.append((
|
||||||
|
self._resolver.find_pcie_ep(sip_id),
|
||||||
|
self._resolver.find_io_cpu(sip_id),
|
||||||
|
0, # MmuMapMsg has no data payload
|
||||||
|
))
|
||||||
|
|
||||||
|
if not entries:
|
||||||
|
self._results[key] = (Completion(ok=True), {"total_ns": 0.0})
|
||||||
|
done.succeed()
|
||||||
|
return
|
||||||
|
|
||||||
|
if len(entries) == 1:
|
||||||
|
pcie_ep_id, io_cpu_id, _ = entries[0]
|
||||||
|
path = self._router.find_node_path(pcie_ep_id, io_cpu_id)
|
||||||
|
txn_done = self._env.event()
|
||||||
|
txn = Transaction(request=request, path=path, step=0, nbytes=0, done=txn_done)
|
||||||
|
yield self._host_queues[pcie_ep_id].put(txn)
|
||||||
|
yield txn_done
|
||||||
|
else:
|
||||||
|
# Multi-SIP fan-out
|
||||||
|
sub_dones = []
|
||||||
|
for pcie_ep_id, io_cpu_id, _ in entries:
|
||||||
|
path = self._router.find_node_path(pcie_ep_id, io_cpu_id)
|
||||||
|
sub_done = self._env.event()
|
||||||
|
sub_txn = Transaction(request=request, path=path, step=0, nbytes=0, done=sub_done)
|
||||||
|
yield self._host_queues[pcie_ep_id].put(sub_txn)
|
||||||
|
sub_dones.append(sub_done)
|
||||||
|
for sd in sub_dones:
|
||||||
|
yield sd
|
||||||
|
|
||||||
|
elapsed = self._env.now - start_ns
|
||||||
|
self._results[key] = (
|
||||||
|
Completion(ok=True),
|
||||||
|
{"total_ns": elapsed, "msg_type": request.msg_type},
|
||||||
|
)
|
||||||
|
done.succeed()
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ _PE_COMP_OFFSETS = {
|
|||||||
"pe_dma": (0.0, -0.15),
|
"pe_dma": (0.0, -0.15),
|
||||||
"pe_gemm": (0.0, 0.0),
|
"pe_gemm": (0.0, 0.0),
|
||||||
"pe_math": (0.0, 0.15),
|
"pe_math": (0.0, 0.15),
|
||||||
|
"pe_mmu": (0.15, -0.15),
|
||||||
"pe_tcm": (0.3, 0.0),
|
"pe_tcm": (0.3, 0.0),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -495,6 +496,15 @@ def _instantiate_cube(
|
|||||||
kind="pe_response",
|
kind="pe_response",
|
||||||
))
|
))
|
||||||
|
|
||||||
|
# noc → PE_MMU (MMU mapping install)
|
||||||
|
pe_mmu_id = f"{pp}.pe_mmu"
|
||||||
|
if pe_mmu_id in nodes:
|
||||||
|
edges.append(Edge(
|
||||||
|
src=f"{cp}.noc", dst=pe_mmu_id,
|
||||||
|
distance_mm=clinks.get("noc_to_pe_mmu_mm", 0.0),
|
||||||
|
kind="command",
|
||||||
|
))
|
||||||
|
|
||||||
pe_idx += 1
|
pe_idx += 1
|
||||||
|
|
||||||
# ── xbar_top/bot → HBM slices ──
|
# ── xbar_top/bot → HBM slices ──
|
||||||
@@ -1073,6 +1083,7 @@ def _build_pe_view(spec: dict) -> ViewGraph:
|
|||||||
"pe_dma": (7.0, 1.5),
|
"pe_dma": (7.0, 1.5),
|
||||||
"pe_gemm": (7.0, 4.0),
|
"pe_gemm": (7.0, 4.0),
|
||||||
"pe_math": (7.0, 6.5),
|
"pe_math": (7.0, 6.5),
|
||||||
|
"pe_mmu": (4.0, 1.5),
|
||||||
"pe_tcm": (10.0, 4.0),
|
"pe_tcm": (10.0, 4.0),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -86,11 +86,11 @@ class TLContext:
|
|||||||
self._commands.append(PeCpuOverheadCmd(cycles=self._dispatch_cycles))
|
self._commands.append(PeCpuOverheadCmd(cycles=self._dispatch_cycles))
|
||||||
|
|
||||||
def _make_handle(
|
def _make_handle(
|
||||||
self, pa: int, shape: tuple[int, ...], dtype: str,
|
self, addr: int, shape: tuple[int, ...], dtype: str,
|
||||||
) -> TensorHandle:
|
) -> TensorHandle:
|
||||||
return TensorHandle(
|
return TensorHandle(
|
||||||
id=self._next_handle_id(),
|
id=self._next_handle_id(),
|
||||||
pa=pa, shape=shape, dtype=dtype,
|
addr=addr, shape=shape, dtype=dtype,
|
||||||
nbytes=self._nbytes(shape, dtype),
|
nbytes=self._nbytes(shape, dtype),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -104,7 +104,7 @@ class TLContext:
|
|||||||
Used when the scheduler will stream data per-tile (e.g., tensor b
|
Used when the scheduler will stream data per-tile (e.g., tensor b
|
||||||
in a composite GEMM). No command is generated.
|
in a composite GEMM). No command is generated.
|
||||||
"""
|
"""
|
||||||
return self._make_handle(pa=ptr, shape=shape, dtype=dtype)
|
return self._make_handle(addr=ptr, shape=shape, dtype=dtype)
|
||||||
|
|
||||||
# ── Data Movement (blocking, DMA engine) ──────────────────────
|
# ── Data Movement (blocking, DMA engine) ──────────────────────
|
||||||
|
|
||||||
@@ -113,9 +113,9 @@ class TLContext:
|
|||||||
) -> TensorHandle:
|
) -> TensorHandle:
|
||||||
"""Load tensor from HBM to TCM. Returns TensorHandle."""
|
"""Load tensor from HBM to TCM. Returns TensorHandle."""
|
||||||
self._emit_dispatch_overhead()
|
self._emit_dispatch_overhead()
|
||||||
handle = self._make_handle(pa=ptr, shape=shape, dtype=dtype)
|
handle = self._make_handle(addr=ptr, shape=shape, dtype=dtype)
|
||||||
self._commands.append(DmaReadCmd(
|
self._commands.append(DmaReadCmd(
|
||||||
handle=handle, src_pa=ptr, nbytes=handle.nbytes,
|
handle=handle, src_addr=ptr, nbytes=handle.nbytes,
|
||||||
))
|
))
|
||||||
return handle
|
return handle
|
||||||
|
|
||||||
@@ -123,7 +123,7 @@ class TLContext:
|
|||||||
"""Store tensor from TCM to HBM."""
|
"""Store tensor from TCM to HBM."""
|
||||||
self._emit_dispatch_overhead()
|
self._emit_dispatch_overhead()
|
||||||
self._commands.append(DmaWriteCmd(
|
self._commands.append(DmaWriteCmd(
|
||||||
handle=handle, dst_pa=ptr, nbytes=handle.nbytes,
|
handle=handle, dst_addr=ptr, nbytes=handle.nbytes,
|
||||||
))
|
))
|
||||||
|
|
||||||
# ── GEMM Engine (blocking) ────────────────────────────────────
|
# ── GEMM Engine (blocking) ────────────────────────────────────
|
||||||
@@ -141,7 +141,7 @@ class TLContext:
|
|||||||
raise ValueError(f"dot shape mismatch: a.K={k} != b.K={k2}")
|
raise ValueError(f"dot shape mismatch: a.K={k} != b.K={k2}")
|
||||||
out_shape = (*a.shape[:-2], m, n)
|
out_shape = (*a.shape[:-2], m, n)
|
||||||
out_dtype = a.dtype
|
out_dtype = a.dtype
|
||||||
out = self._make_handle(pa=0, shape=out_shape, dtype=out_dtype)
|
out = self._make_handle(addr=0, shape=out_shape, dtype=out_dtype)
|
||||||
self._emit_dispatch_overhead()
|
self._emit_dispatch_overhead()
|
||||||
self._commands.append(GemmCmd(a=a, b=b, out=out, m=m, k=k, n=n))
|
self._commands.append(GemmCmd(a=a, b=b, out=out, m=m, k=k, n=n))
|
||||||
return out
|
return out
|
||||||
@@ -149,7 +149,7 @@ class TLContext:
|
|||||||
# ── MATH Engine: unary (blocking) ─────────────────────────────
|
# ── MATH Engine: unary (blocking) ─────────────────────────────
|
||||||
|
|
||||||
def _unary_math(self, op: str, x: TensorHandle) -> TensorHandle:
|
def _unary_math(self, op: str, x: TensorHandle) -> TensorHandle:
|
||||||
out = self._make_handle(pa=0, shape=x.shape, dtype=x.dtype)
|
out = self._make_handle(addr=0, shape=x.shape, dtype=x.dtype)
|
||||||
self._emit_dispatch_overhead()
|
self._emit_dispatch_overhead()
|
||||||
self._commands.append(MathCmd(op=op, inputs=(x,), out=out))
|
self._commands.append(MathCmd(op=op, inputs=(x,), out=out))
|
||||||
return out
|
return out
|
||||||
@@ -182,7 +182,7 @@ class TLContext:
|
|||||||
) -> TensorHandle:
|
) -> TensorHandle:
|
||||||
out_shape = list(x.shape)
|
out_shape = list(x.shape)
|
||||||
out_shape[axis] = 1
|
out_shape[axis] = 1
|
||||||
out = self._make_handle(pa=0, shape=tuple(out_shape), dtype=x.dtype)
|
out = self._make_handle(addr=0, shape=tuple(out_shape), dtype=x.dtype)
|
||||||
self._emit_dispatch_overhead()
|
self._emit_dispatch_overhead()
|
||||||
self._commands.append(MathCmd(op=op, inputs=(x,), out=out, axis=axis))
|
self._commands.append(MathCmd(op=op, inputs=(x,), out=out, axis=axis))
|
||||||
return out
|
return out
|
||||||
@@ -201,7 +201,7 @@ class TLContext:
|
|||||||
def _binary_math(
|
def _binary_math(
|
||||||
self, op: str, a: TensorHandle, b: TensorHandle,
|
self, op: str, a: TensorHandle, b: TensorHandle,
|
||||||
) -> TensorHandle:
|
) -> TensorHandle:
|
||||||
out = self._make_handle(pa=0, shape=a.shape, dtype=a.dtype)
|
out = self._make_handle(addr=0, shape=a.shape, dtype=a.dtype)
|
||||||
self._emit_dispatch_overhead()
|
self._emit_dispatch_overhead()
|
||||||
self._commands.append(MathCmd(op=op, inputs=(a, b), out=out))
|
self._commands.append(MathCmd(op=op, inputs=(a, b), out=out))
|
||||||
return out
|
return out
|
||||||
@@ -209,7 +209,7 @@ class TLContext:
|
|||||||
def where(
|
def where(
|
||||||
self, cond: TensorHandle, a: TensorHandle, b: TensorHandle,
|
self, cond: TensorHandle, a: TensorHandle, b: TensorHandle,
|
||||||
) -> TensorHandle:
|
) -> TensorHandle:
|
||||||
out = self._make_handle(pa=0, shape=a.shape, dtype=a.dtype)
|
out = self._make_handle(addr=0, shape=a.shape, dtype=a.dtype)
|
||||||
self._emit_dispatch_overhead()
|
self._emit_dispatch_overhead()
|
||||||
self._commands.append(MathCmd(op="where", inputs=(cond, a, b), out=out))
|
self._commands.append(MathCmd(op="where", inputs=(cond, a, b), out=out))
|
||||||
return out
|
return out
|
||||||
@@ -227,17 +227,17 @@ class TLContext:
|
|||||||
def arange(self, start: int, end: int, dtype: str = "i32") -> TensorHandle:
|
def arange(self, start: int, end: int, dtype: str = "i32") -> TensorHandle:
|
||||||
"""Create index range tensor in TCM."""
|
"""Create index range tensor in TCM."""
|
||||||
n = end - start
|
n = end - start
|
||||||
return self._make_handle(pa=0, shape=(n,), dtype=dtype)
|
return self._make_handle(addr=0, shape=(n,), dtype=dtype)
|
||||||
|
|
||||||
def zeros(self, shape: tuple[int, ...], dtype: str = "f16") -> TensorHandle:
|
def zeros(self, shape: tuple[int, ...], dtype: str = "f16") -> TensorHandle:
|
||||||
"""Create zero-filled tensor in TCM."""
|
"""Create zero-filled tensor in TCM."""
|
||||||
return self._make_handle(pa=0, shape=shape, dtype=dtype)
|
return self._make_handle(addr=0, shape=shape, dtype=dtype)
|
||||||
|
|
||||||
def full(
|
def full(
|
||||||
self, shape: tuple[int, ...], value: float | int, dtype: str = "f16",
|
self, shape: tuple[int, ...], value: float | int, dtype: str = "f16",
|
||||||
) -> TensorHandle:
|
) -> TensorHandle:
|
||||||
"""Create constant-filled tensor in TCM."""
|
"""Create constant-filled tensor in TCM."""
|
||||||
return self._make_handle(pa=0, shape=shape, dtype=dtype)
|
return self._make_handle(addr=0, shape=shape, dtype=dtype)
|
||||||
|
|
||||||
# ── Metadata (no compute, no DMA) ─────────────────────────────
|
# ── Metadata (no compute, no DMA) ─────────────────────────────
|
||||||
|
|
||||||
@@ -247,7 +247,7 @@ class TLContext:
|
|||||||
raise ValueError("trans requires at least 2D tensor")
|
raise ValueError("trans requires at least 2D tensor")
|
||||||
new_shape = (*x.shape[:-2], x.shape[-1], x.shape[-2])
|
new_shape = (*x.shape[:-2], x.shape[-1], x.shape[-2])
|
||||||
return TensorHandle(
|
return TensorHandle(
|
||||||
id=x.id, pa=x.pa, shape=new_shape,
|
id=x.id, addr=x.addr, shape=new_shape,
|
||||||
dtype=x.dtype, nbytes=x.nbytes, data=x.data,
|
dtype=x.dtype, nbytes=x.nbytes, data=x.data,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -278,7 +278,7 @@ class TLContext:
|
|||||||
self._emit_dispatch_overhead()
|
self._emit_dispatch_overhead()
|
||||||
self._commands.append(CompositeCmd(
|
self._commands.append(CompositeCmd(
|
||||||
completion=completion, op=op,
|
completion=completion, op=op,
|
||||||
a=a, b=b, out_pa=out_ptr, out_nbytes=out_nbytes,
|
a=a, b=b, out_addr=out_ptr, out_nbytes=out_nbytes,
|
||||||
math_op=math_op,
|
math_op=math_op,
|
||||||
))
|
))
|
||||||
return completion
|
return completion
|
||||||
|
|||||||
@@ -0,0 +1,226 @@
|
|||||||
|
"""Tests for PE_MMU component integration and MmuMapMsg fabric path.
|
||||||
|
|
||||||
|
Validates:
|
||||||
|
T15-a. PE_MMU component registered in ComponentRegistry
|
||||||
|
T15-b. PE_MMU component receives MmuMapMsg via inbox, updates page table
|
||||||
|
T15-c. PE_DMA translates VA→PA via mmu before routing
|
||||||
|
T16. MmuMapMsg/MmuUnmapMsg message types defined with correct fields
|
||||||
|
T17. PE_CPU passes VA (not PA) to kernel when VA is available
|
||||||
|
T18. End-to-end: deploy (MmuMapMsg broadcast) → kernel launch → DMA with VA
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
try:
|
||||||
|
from kernbench.runtime_api.kernel import MmuMapMsg, MmuUnmapMsg
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("MmuMapMsg/MmuUnmapMsg not yet defined (Phase 2)", allow_module_level=True)
|
||||||
|
|
||||||
|
|
||||||
|
# ── T16. MmuMapMsg / MmuUnmapMsg message types ──────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_mmu_map_msg_fields():
|
||||||
|
"""MmuMapMsg carries VA→PA mapping entries for broadcast."""
|
||||||
|
msg = MmuMapMsg(
|
||||||
|
correlation_id="c0",
|
||||||
|
request_id="r0",
|
||||||
|
entries=(
|
||||||
|
{"va": 0x1_0000_0000, "pa": 0xA000_0000, "size": 4096},
|
||||||
|
{"va": 0x1_0000_1000, "pa": 0xB000_0000, "size": 4096},
|
||||||
|
),
|
||||||
|
target_cubes="all",
|
||||||
|
target_pe="all",
|
||||||
|
)
|
||||||
|
assert msg.msg_type == "mmu_map"
|
||||||
|
assert len(msg.entries) == 2
|
||||||
|
assert msg.entries[0]["va"] == 0x1_0000_0000
|
||||||
|
assert msg.entries[0]["pa"] == 0xA000_0000
|
||||||
|
assert msg.entries[0]["size"] == 4096
|
||||||
|
|
||||||
|
|
||||||
|
def test_mmu_map_msg_immutable():
|
||||||
|
"""MmuMapMsg is frozen."""
|
||||||
|
msg = MmuMapMsg(
|
||||||
|
correlation_id="c0",
|
||||||
|
request_id="r0",
|
||||||
|
entries=(),
|
||||||
|
target_cubes="all",
|
||||||
|
target_pe="all",
|
||||||
|
)
|
||||||
|
with pytest.raises(AttributeError):
|
||||||
|
msg.entries = () # type: ignore[misc]
|
||||||
|
|
||||||
|
|
||||||
|
def test_mmu_unmap_msg_fields():
|
||||||
|
"""MmuUnmapMsg carries VA ranges to unmap."""
|
||||||
|
msg = MmuUnmapMsg(
|
||||||
|
correlation_id="c0",
|
||||||
|
request_id="r0",
|
||||||
|
entries=(
|
||||||
|
{"va": 0x1_0000_0000, "size": 4096},
|
||||||
|
),
|
||||||
|
target_cubes="all",
|
||||||
|
target_pe="all",
|
||||||
|
)
|
||||||
|
assert msg.msg_type == "mmu_unmap"
|
||||||
|
assert len(msg.entries) == 1
|
||||||
|
assert msg.entries[0]["va"] == 0x1_0000_0000
|
||||||
|
|
||||||
|
|
||||||
|
# ── T15-a. PE_MMU component registry ────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_pe_mmu_registry():
|
||||||
|
"""pe_mmu_v1 impl resolves in ComponentRegistry."""
|
||||||
|
from kernbench.components.base import ComponentRegistry
|
||||||
|
from kernbench.components.impls.pe_mmu import PeMmuComponent
|
||||||
|
from kernbench.topology.types import Node
|
||||||
|
|
||||||
|
node = Node(
|
||||||
|
id="sip0.cube0.pe0.pe_mmu",
|
||||||
|
kind="pe_mmu",
|
||||||
|
impl="pe_mmu_v1",
|
||||||
|
pos_mm=None,
|
||||||
|
attrs={"tlb_overhead_ns": 0.5},
|
||||||
|
)
|
||||||
|
comp = ComponentRegistry.create(node)
|
||||||
|
assert isinstance(comp, PeMmuComponent)
|
||||||
|
|
||||||
|
|
||||||
|
# ── T15-b. PE_MMU receives MmuMapMsg and updates page table ─────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_pe_mmu_processes_map_msg():
|
||||||
|
"""PE_MMU component receives MmuMapMsg → translate works."""
|
||||||
|
import simpy
|
||||||
|
from kernbench.components.impls.pe_mmu import PeMmuComponent
|
||||||
|
from kernbench.sim_engine.transaction import Transaction
|
||||||
|
from kernbench.topology.types import Node
|
||||||
|
|
||||||
|
env = simpy.Environment()
|
||||||
|
node = Node(
|
||||||
|
id="sip0.cube0.pe0.pe_mmu",
|
||||||
|
kind="pe_mmu",
|
||||||
|
impl="pe_mmu_v1",
|
||||||
|
pos_mm=None,
|
||||||
|
attrs={"tlb_overhead_ns": 0.5, "page_size": 4096},
|
||||||
|
)
|
||||||
|
comp = PeMmuComponent(node)
|
||||||
|
comp.in_ports["src"] = simpy.Store(env)
|
||||||
|
comp.start(env)
|
||||||
|
|
||||||
|
# Submit MmuMapMsg via inbox
|
||||||
|
map_msg = MmuMapMsg(
|
||||||
|
correlation_id="c0",
|
||||||
|
request_id="r0",
|
||||||
|
entries=(
|
||||||
|
{"va": 0x1_0000_0000, "pa": 0xABCD_0000, "size": 4096},
|
||||||
|
),
|
||||||
|
target_cubes="all",
|
||||||
|
target_pe="all",
|
||||||
|
)
|
||||||
|
done = env.event()
|
||||||
|
txn = Transaction(
|
||||||
|
request=map_msg,
|
||||||
|
path=["sip0.cube0.pe0.pe_mmu"],
|
||||||
|
step=0, nbytes=0, done=done,
|
||||||
|
)
|
||||||
|
|
||||||
|
def inject():
|
||||||
|
yield comp._inbox.put(txn)
|
||||||
|
|
||||||
|
env.process(inject())
|
||||||
|
env.run(until=100)
|
||||||
|
|
||||||
|
# After processing, the MMU's translate should work
|
||||||
|
from kernbench.policy.address.pe_mmu import PeMMU
|
||||||
|
mmu = comp.mmu # the underlying PeMMU utility object
|
||||||
|
assert isinstance(mmu, PeMMU)
|
||||||
|
assert mmu.translate(0x1_0000_0000) == 0xABCD_0000
|
||||||
|
|
||||||
|
|
||||||
|
# ── T15-c. PE_DMA uses MMU translate ────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_pe_dma_translates_va():
|
||||||
|
"""PE_DMA.handle_command calls mmu.translate(va) → PA before routing.
|
||||||
|
|
||||||
|
This test validates the contract: after Phase 2, DmaReadCmd carries VA,
|
||||||
|
and PE_DMA must translate it to PA via the MMU before resolving the
|
||||||
|
HBM node path.
|
||||||
|
"""
|
||||||
|
# This test validates the interface contract. Full integration test
|
||||||
|
# requires the engine wiring which is validated in test_engine.
|
||||||
|
# Here we check that PE_DMA has an mmu attribute it can call.
|
||||||
|
from kernbench.components.impls.pe_dma import PeDmaComponent
|
||||||
|
from kernbench.topology.types import Node
|
||||||
|
|
||||||
|
node = Node(
|
||||||
|
id="sip0.cube0.pe0.pe_dma",
|
||||||
|
kind="pe_dma",
|
||||||
|
impl="pe_dma_v1",
|
||||||
|
pos_mm=None,
|
||||||
|
attrs={"rd_engines": 1, "wr_engines": 1},
|
||||||
|
)
|
||||||
|
comp = PeDmaComponent(node)
|
||||||
|
|
||||||
|
# PE_DMA must have a way to access the MMU (via ctx or direct reference)
|
||||||
|
# The exact wiring mechanism is flexible, but the attribute must exist
|
||||||
|
assert hasattr(comp, '_mmu') or hasattr(comp, 'mmu') or (
|
||||||
|
hasattr(comp, 'ctx') and comp.ctx is not None
|
||||||
|
), "PE_DMA must have access to PE_MMU for VA translation"
|
||||||
|
|
||||||
|
|
||||||
|
# ── T17. PE_CPU passes VA to kernel ──────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_pe_cpu_uses_va_base_from_tensor_arg():
|
||||||
|
"""PE_CPU should use TensorArg.va_base for kernel pointer args.
|
||||||
|
|
||||||
|
After Phase 2, TensorArg carries va_base alongside shards.
|
||||||
|
PE_CPU extracts va_base and passes it to the kernel function
|
||||||
|
so kernels operate on VA (not PA).
|
||||||
|
"""
|
||||||
|
from kernbench.runtime_api.kernel import TensorArg, TensorArgShard
|
||||||
|
|
||||||
|
shard = TensorArgShard(sip=0, cube=0, pe=0, pa=0x1000,
|
||||||
|
nbytes=4096, offset_bytes=0)
|
||||||
|
targ = TensorArg(shards=(shard,), va_base=0x1_0000_0000)
|
||||||
|
|
||||||
|
# PE_CPU should use targ.va_base for kernel pointer arg
|
||||||
|
assert targ.va_base == 0x1_0000_0000
|
||||||
|
# PA still accessible via shard for direct-PA operations (IPCQ etc.)
|
||||||
|
assert shard.pa == 0x1000
|
||||||
|
|
||||||
|
|
||||||
|
# ── T18. MmuMapMsg broadcast pattern ─────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_mmu_map_msg_broadcast_target():
|
||||||
|
"""MmuMapMsg with target_pe='all' is a broadcast to all PEs."""
|
||||||
|
msg = MmuMapMsg(
|
||||||
|
correlation_id="c0",
|
||||||
|
request_id="r0",
|
||||||
|
entries=({"va": 0x1000, "pa": 0x2000, "size": 4096},),
|
||||||
|
target_cubes="all",
|
||||||
|
target_pe="all",
|
||||||
|
)
|
||||||
|
assert msg.target_pe == "all"
|
||||||
|
assert msg.target_cubes == "all"
|
||||||
|
|
||||||
|
|
||||||
|
def test_mmu_map_msg_same_entries_all_pes():
|
||||||
|
"""All PEs in a broadcast receive identical entries (not per-PE splits)."""
|
||||||
|
entries = (
|
||||||
|
{"va": 0x1_0000_0000, "pa": 0xA000_0000, "size": 8192},
|
||||||
|
{"va": 0x1_0000_2000, "pa": 0xB000_0000, "size": 8192},
|
||||||
|
)
|
||||||
|
msg = MmuMapMsg(
|
||||||
|
correlation_id="c0",
|
||||||
|
request_id="r0",
|
||||||
|
entries=entries,
|
||||||
|
target_cubes="all",
|
||||||
|
target_pe="all",
|
||||||
|
)
|
||||||
|
# The message carries the full mapping — every PE receives exactly this
|
||||||
|
assert msg.entries == entries
|
||||||
@@ -0,0 +1,241 @@
|
|||||||
|
"""Tests for MmuMapMsg fabric path and cross-cube mapping.
|
||||||
|
|
||||||
|
Validates:
|
||||||
|
F1. MmuMapMsg traverses fabric: latency > 0 (not sideband)
|
||||||
|
F2. MmuMapMsg fan-out: IO_CPU → cubes, M_CPU → PEs
|
||||||
|
F3. After MmuMapMsg, PE_MMU has correct mappings
|
||||||
|
F4. Cross-cube sharded tensor: all PEs get global mappings
|
||||||
|
F5. Replicate tensor: each PE gets own cube's PA (local override)
|
||||||
|
F6. Cross-cube DMA after sharded mapping: PE can access remote cube's HBM
|
||||||
|
F7. Overlap detection: replicate vs sharded identified correctly
|
||||||
|
F8. Existing regression: PA-only benchmarks still pass
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator
|
||||||
|
from kernbench.policy.address.pe_mmu import PeMMU
|
||||||
|
from kernbench.policy.address.va_allocator import VirtualAllocator
|
||||||
|
from kernbench.policy.placement.dp import column_wise, replicate, ShardSpec
|
||||||
|
from kernbench.runtime_api.tensor import deploy_tensor, TensorHandle
|
||||||
|
from kernbench.sim_engine.engine import GraphEngine
|
||||||
|
from kernbench.topology.builder import load_topology
|
||||||
|
|
||||||
|
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
|
||||||
|
|
||||||
|
_MB = 1 << 20
|
||||||
|
_GB = 1 << 30
|
||||||
|
|
||||||
|
_CFG = AddressConfig(
|
||||||
|
sip_count=2,
|
||||||
|
cubes_per_sip=16,
|
||||||
|
pes_per_cube=8,
|
||||||
|
hbm_bytes_per_cube=48 * _GB,
|
||||||
|
hbm_slices_per_cube=8,
|
||||||
|
tcm_bytes_per_pe=16 * _MB,
|
||||||
|
tcm_scheduler_reserved_bytes=4 * _MB,
|
||||||
|
sram_bytes_per_cube=32 * _MB,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _engine():
|
||||||
|
return GraphEngine(load_topology(TOPOLOGY_PATH))
|
||||||
|
|
||||||
|
|
||||||
|
# ── F1. MmuMapMsg fabric latency ─────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_mmu_map_via_fabric_has_latency():
|
||||||
|
"""MmuMapMsg submitted through engine.submit() completes with latency > 0."""
|
||||||
|
from kernbench.runtime_api.kernel import MmuMapMsg
|
||||||
|
|
||||||
|
engine = _engine()
|
||||||
|
msg = MmuMapMsg(
|
||||||
|
correlation_id="c0",
|
||||||
|
request_id="mmu_map_0",
|
||||||
|
entries=({"va": 0x1_0000_0000, "pa": 0x2000_0000, "size": 4096},),
|
||||||
|
target_cubes=(0,),
|
||||||
|
target_pe="all",
|
||||||
|
)
|
||||||
|
h = engine.submit(msg)
|
||||||
|
engine.wait(h)
|
||||||
|
comp, trace = engine.get_completion(h)
|
||||||
|
assert comp.ok is True
|
||||||
|
# Fabric traversal must have non-zero latency
|
||||||
|
assert trace is not None
|
||||||
|
assert trace.get("total_ns", 0) > 0
|
||||||
|
|
||||||
|
|
||||||
|
# ── F2. MmuMapMsg fan-out ────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_mmu_map_reaches_all_pes_in_cube():
|
||||||
|
"""MmuMapMsg with target_pe='all' installs mapping in all 8 PE_MMUs of target cube."""
|
||||||
|
from kernbench.runtime_api.kernel import MmuMapMsg
|
||||||
|
|
||||||
|
engine = _engine()
|
||||||
|
va, pa, size = 0x1_0000_0000, 0xABCD_0000, 4096
|
||||||
|
msg = MmuMapMsg(
|
||||||
|
correlation_id="c0",
|
||||||
|
request_id="mmu_map_1",
|
||||||
|
entries=({"va": va, "pa": pa, "size": size},),
|
||||||
|
target_cubes=(0,),
|
||||||
|
target_pe="all",
|
||||||
|
)
|
||||||
|
h = engine.submit(msg)
|
||||||
|
engine.wait(h)
|
||||||
|
|
||||||
|
# Verify all 8 PE_MMUs in cube 0 have the mapping
|
||||||
|
for pe_id in range(8):
|
||||||
|
mmu_id = f"sip0.cube0.pe{pe_id}.pe_mmu"
|
||||||
|
mmu_comp = engine._components[mmu_id]
|
||||||
|
assert mmu_comp.mmu.translate(va) == pa
|
||||||
|
|
||||||
|
|
||||||
|
# ── F3. Multiple MmuMapMsg entries ───────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_mmu_map_multiple_entries():
|
||||||
|
"""MmuMapMsg with multiple entries installs all of them."""
|
||||||
|
from kernbench.runtime_api.kernel import MmuMapMsg
|
||||||
|
|
||||||
|
engine = _engine()
|
||||||
|
entries = (
|
||||||
|
{"va": 0x1_0000_0000, "pa": 0xA000_0000, "size": 4096},
|
||||||
|
{"va": 0x1_0000_1000, "pa": 0xB000_0000, "size": 4096},
|
||||||
|
)
|
||||||
|
msg = MmuMapMsg(
|
||||||
|
correlation_id="c0",
|
||||||
|
request_id="mmu_map_2",
|
||||||
|
entries=entries,
|
||||||
|
target_cubes=(0,),
|
||||||
|
target_pe="all",
|
||||||
|
)
|
||||||
|
h = engine.submit(msg)
|
||||||
|
engine.wait(h)
|
||||||
|
|
||||||
|
mmu_comp = engine._components["sip0.cube0.pe0.pe_mmu"]
|
||||||
|
assert mmu_comp.mmu.translate(0x1_0000_0000) == 0xA000_0000
|
||||||
|
assert mmu_comp.mmu.translate(0x1_0000_1000) == 0xB000_0000
|
||||||
|
|
||||||
|
|
||||||
|
# ── F4. Cross-cube sharded: global mapping ───────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_cross_cube_sharded_all_pes_get_global_mapping():
|
||||||
|
"""For sharded tensor across cubes (unique offsets), all PEs get all mappings."""
|
||||||
|
from kernbench.runtime_api.kernel import MmuMapMsg
|
||||||
|
|
||||||
|
engine = _engine()
|
||||||
|
# Simulate 2-cube shard: cube0 has offset=0, cube1 has offset=4096
|
||||||
|
entries = (
|
||||||
|
{"va": 0x1_0000_0000, "pa": 0xA000_0000, "size": 4096}, # cube0
|
||||||
|
{"va": 0x1_0000_1000, "pa": 0xB000_0000, "size": 4096}, # cube1
|
||||||
|
)
|
||||||
|
# Broadcast to both cubes
|
||||||
|
msg = MmuMapMsg(
|
||||||
|
correlation_id="c0",
|
||||||
|
request_id="mmu_map_xc",
|
||||||
|
entries=entries,
|
||||||
|
target_cubes=(0, 1),
|
||||||
|
target_pe="all",
|
||||||
|
)
|
||||||
|
h = engine.submit(msg)
|
||||||
|
engine.wait(h)
|
||||||
|
|
||||||
|
# PE in cube0 can translate both cube0 and cube1 addresses
|
||||||
|
mmu_c0 = engine._components["sip0.cube0.pe0.pe_mmu"]
|
||||||
|
assert mmu_c0.mmu.translate(0x1_0000_0000) == 0xA000_0000 # local
|
||||||
|
assert mmu_c0.mmu.translate(0x1_0000_1000) == 0xB000_0000 # remote
|
||||||
|
|
||||||
|
# PE in cube1 can also translate both
|
||||||
|
mmu_c1 = engine._components["sip0.cube1.pe0.pe_mmu"]
|
||||||
|
assert mmu_c1.mmu.translate(0x1_0000_0000) == 0xA000_0000 # remote
|
||||||
|
assert mmu_c1.mmu.translate(0x1_0000_1000) == 0xB000_0000 # local
|
||||||
|
|
||||||
|
|
||||||
|
# ── F5. Replicate: local PA override ─────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_replicate_local_pa_override():
|
||||||
|
"""For replicated tensor (same VA range), each cube's PEs see local PA."""
|
||||||
|
from kernbench.runtime_api.kernel import MmuMapMsg
|
||||||
|
|
||||||
|
engine = _engine()
|
||||||
|
va, size = 0x1_0000_0000, 4096
|
||||||
|
|
||||||
|
# Cube 0 gets its own PA
|
||||||
|
msg0 = MmuMapMsg(
|
||||||
|
correlation_id="c0",
|
||||||
|
request_id="mmu_rep_c0",
|
||||||
|
entries=({"va": va, "pa": 0xA000_0000, "size": size},),
|
||||||
|
target_cubes=(0,),
|
||||||
|
target_pe="all",
|
||||||
|
)
|
||||||
|
h0 = engine.submit(msg0)
|
||||||
|
engine.wait(h0)
|
||||||
|
|
||||||
|
# Cube 1 gets a different PA for the same VA
|
||||||
|
msg1 = MmuMapMsg(
|
||||||
|
correlation_id="c0",
|
||||||
|
request_id="mmu_rep_c1",
|
||||||
|
entries=({"va": va, "pa": 0xB000_0000, "size": size},),
|
||||||
|
target_cubes=(1,),
|
||||||
|
target_pe="all",
|
||||||
|
)
|
||||||
|
h1 = engine.submit(msg1)
|
||||||
|
engine.wait(h1)
|
||||||
|
|
||||||
|
# Cube 0 PEs translate to cube 0's PA
|
||||||
|
mmu_c0 = engine._components["sip0.cube0.pe0.pe_mmu"]
|
||||||
|
assert mmu_c0.mmu.translate(va) == 0xA000_0000
|
||||||
|
|
||||||
|
# Cube 1 PEs translate to cube 1's PA
|
||||||
|
mmu_c1 = engine._components["sip0.cube1.pe0.pe_mmu"]
|
||||||
|
assert mmu_c1.mmu.translate(va) == 0xB000_0000
|
||||||
|
|
||||||
|
|
||||||
|
# ── F7. Overlap detection ────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_overlapping_shards():
|
||||||
|
"""Utility: detect if shards have overlapping VA ranges (replicate indicator)."""
|
||||||
|
from kernbench.runtime_api.tensor import TensorShard
|
||||||
|
|
||||||
|
# Sharded: unique offsets
|
||||||
|
sharded = [
|
||||||
|
TensorShard(sip=0, cube=0, pe=0, pa=0x100, nbytes=4096, offset_bytes=0),
|
||||||
|
TensorShard(sip=0, cube=0, pe=1, pa=0x200, nbytes=4096, offset_bytes=4096),
|
||||||
|
]
|
||||||
|
offsets = [(s.offset_bytes, s.nbytes) for s in sharded]
|
||||||
|
assert len(set(offsets)) == len(offsets), "Sharded should have unique offsets"
|
||||||
|
|
||||||
|
# Replicated: same offset
|
||||||
|
replicated = [
|
||||||
|
TensorShard(sip=0, cube=0, pe=0, pa=0x100, nbytes=4096, offset_bytes=0),
|
||||||
|
TensorShard(sip=0, cube=1, pe=0, pa=0x200, nbytes=4096, offset_bytes=0),
|
||||||
|
]
|
||||||
|
offsets_r = [(s.offset_bytes, s.nbytes) for s in replicated]
|
||||||
|
assert len(set(offsets_r)) < len(offsets_r), "Replicate should have duplicate offsets"
|
||||||
|
|
||||||
|
|
||||||
|
# ── F8. Regression: existing benchmarks still pass ───────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_qkv_gemm_still_passes():
|
||||||
|
"""QKV GEMM benchmark completes successfully with VA/MMU enabled."""
|
||||||
|
from kernbench.runtime_api.context import RuntimeContext
|
||||||
|
from kernbench.runtime_api.types import BenchResult, DeviceSelector
|
||||||
|
|
||||||
|
graph = load_topology(TOPOLOGY_PATH)
|
||||||
|
engine = GraphEngine(graph)
|
||||||
|
ctx = RuntimeContext(
|
||||||
|
engine=engine,
|
||||||
|
target_device=DeviceSelector("sip:0"),
|
||||||
|
correlation_id="test_regression",
|
||||||
|
spec=graph.spec,
|
||||||
|
)
|
||||||
|
from benches.qkv_gemm import run as bench_run
|
||||||
|
bench_run(ctx)
|
||||||
|
ctx.wait_all()
|
||||||
|
# If we get here without exception, the benchmark succeeded
|
||||||
@@ -308,9 +308,9 @@ def test_pe_gemm_handles_pe_internal_txn():
|
|||||||
gemm.in_ports["src"] = simpy.Store(env)
|
gemm.in_ports["src"] = simpy.Store(env)
|
||||||
gemm.start(env)
|
gemm.start(env)
|
||||||
|
|
||||||
a = TensorHandle(id="t1", pa=0, shape=(4, 8), dtype="f16", nbytes=64)
|
a = TensorHandle(id="t1", addr=0, shape=(4, 8), dtype="f16", nbytes=64)
|
||||||
b = TensorHandle(id="t2", pa=0, shape=(8, 4), dtype="f16", nbytes=64)
|
b = TensorHandle(id="t2", addr=0, shape=(8, 4), dtype="f16", nbytes=64)
|
||||||
out = TensorHandle(id="t3", pa=0, shape=(4, 4), dtype="f16", nbytes=32)
|
out = TensorHandle(id="t3", addr=0, shape=(4, 4), dtype="f16", nbytes=32)
|
||||||
cmd = GemmCmd(a=a, b=b, out=out, m=4, k=8, n=4)
|
cmd = GemmCmd(a=a, b=b, out=out, m=4, k=8, n=4)
|
||||||
done = env.event()
|
done = env.event()
|
||||||
pe_txn = PeInternalTxn(command=cmd, done=done, pe_prefix=pe_prefix)
|
pe_txn = PeInternalTxn(command=cmd, done=done, pe_prefix=pe_prefix)
|
||||||
@@ -349,8 +349,8 @@ def test_pe_math_handles_pe_internal_txn():
|
|||||||
math_comp.in_ports["src"] = simpy.Store(env)
|
math_comp.in_ports["src"] = simpy.Store(env)
|
||||||
math_comp.start(env)
|
math_comp.start(env)
|
||||||
|
|
||||||
x = TensorHandle(id="t1", pa=0, shape=(4, 4), dtype="f16", nbytes=32)
|
x = TensorHandle(id="t1", addr=0, shape=(4, 4), dtype="f16", nbytes=32)
|
||||||
out = TensorHandle(id="t2", pa=0, shape=(4, 4), dtype="f16", nbytes=32)
|
out = TensorHandle(id="t2", addr=0, shape=(4, 4), dtype="f16", nbytes=32)
|
||||||
cmd = MathCmd(op="exp", inputs=(x,), out=out)
|
cmd = MathCmd(op="exp", inputs=(x,), out=out)
|
||||||
done = env.event()
|
done = env.event()
|
||||||
pe_txn = PeInternalTxn(command=cmd, done=done, pe_prefix=pe_prefix)
|
pe_txn = PeInternalTxn(command=cmd, done=done, pe_prefix=pe_prefix)
|
||||||
@@ -777,7 +777,7 @@ def test_tl_ref_no_dma():
|
|||||||
|
|
||||||
tl = TLContext(pe_id=0, dispatch_cycles=0)
|
tl = TLContext(pe_id=0, dispatch_cycles=0)
|
||||||
handle = tl.ref(0x1000, shape=(4, 4), dtype="f16")
|
handle = tl.ref(0x1000, shape=(4, 4), dtype="f16")
|
||||||
assert handle.pa == 0x1000
|
assert handle.addr == 0x1000
|
||||||
assert handle.shape == (4, 4)
|
assert handle.shape == (4, 4)
|
||||||
assert len(tl.commands) == 0, f"tl.ref should emit 0 commands, got {len(tl.commands)}"
|
assert len(tl.commands) == 0, f"tl.ref should emit 0 commands, got {len(tl.commands)}"
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,203 @@
|
|||||||
|
"""Tests for PeMMU: per-PE virtual-to-physical address translation.
|
||||||
|
|
||||||
|
Validates:
|
||||||
|
T1. Basic map + translate
|
||||||
|
T2. Page-aligned dict lookup (O(1), multi-page range)
|
||||||
|
T3. Multiple tensor mappings accumulate
|
||||||
|
T4. unmap removes entries, translate raises PageFault
|
||||||
|
T5. PageFault on unmapped VA
|
||||||
|
T6. Identical mapping broadcast across multiple PEs
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from kernbench.policy.address.pe_mmu import PageFault, PeMMU
|
||||||
|
|
||||||
|
_2MB = 2 * 1024 * 1024
|
||||||
|
|
||||||
|
|
||||||
|
# ── T1. Basic map + translate ────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_map_and_translate_basic():
|
||||||
|
"""map(va, pa, size) → translate(va) returns pa; offset preserved."""
|
||||||
|
mmu = PeMMU(page_size=4096)
|
||||||
|
mmu.map(va=0x1_0000_0000, pa=0xABCD_0000, size=4096)
|
||||||
|
assert mmu.translate(0x1_0000_0000) == 0xABCD_0000
|
||||||
|
|
||||||
|
|
||||||
|
def test_translate_preserves_offset():
|
||||||
|
"""translate(va + offset) returns pa + offset within a page."""
|
||||||
|
mmu = PeMMU(page_size=4096)
|
||||||
|
mmu.map(va=0x1_0000_0000, pa=0xABCD_0000, size=4096)
|
||||||
|
assert mmu.translate(0x1_0000_0100) == 0xABCD_0100
|
||||||
|
assert mmu.translate(0x1_0000_0FFF) == 0xABCD_0FFF
|
||||||
|
|
||||||
|
|
||||||
|
# ── T2. Page-aligned dict lookup (multi-page) ───────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi_page_mapping():
|
||||||
|
"""8 MB mapping with 2 MB pages → 4 page entries, all translate correctly."""
|
||||||
|
mmu = PeMMU(page_size=_2MB)
|
||||||
|
va_base = 0x1_0000_0000
|
||||||
|
pa_base = 0x2_0000_0000
|
||||||
|
size = 8 * 1024 * 1024 # 8 MB = 4 pages
|
||||||
|
|
||||||
|
mmu.map(va=va_base, pa=pa_base, size=size)
|
||||||
|
|
||||||
|
# First page
|
||||||
|
assert mmu.translate(va_base) == pa_base
|
||||||
|
# Second page start
|
||||||
|
assert mmu.translate(va_base + _2MB) == pa_base + _2MB
|
||||||
|
# Third page with offset
|
||||||
|
assert mmu.translate(va_base + 2 * _2MB + 0x100) == pa_base + 2 * _2MB + 0x100
|
||||||
|
# Last byte of last page
|
||||||
|
assert mmu.translate(va_base + size - 1) == pa_base + size - 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_page_table_entry_count():
|
||||||
|
"""Mapping N bytes with page_size P creates ceil(N/P) entries."""
|
||||||
|
mmu = PeMMU(page_size=_2MB)
|
||||||
|
mmu.map(va=0x1000_0000, pa=0x2000_0000, size=8 * 1024 * 1024)
|
||||||
|
assert mmu.num_entries == 4
|
||||||
|
|
||||||
|
mmu.map(va=0x2000_0000, pa=0x3000_0000, size=_2MB)
|
||||||
|
assert mmu.num_entries == 5
|
||||||
|
|
||||||
|
|
||||||
|
# ── T3. Multiple tensor mappings accumulate ──────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiple_mappings_accumulate():
|
||||||
|
"""Two non-overlapping tensors → both translate correctly."""
|
||||||
|
mmu = PeMMU(page_size=4096)
|
||||||
|
# Tensor A
|
||||||
|
mmu.map(va=0x1_0000_0000, pa=0xA000_0000, size=4096)
|
||||||
|
# Tensor B (different VA range)
|
||||||
|
mmu.map(va=0x1_0001_0000, pa=0xB000_0000, size=4096)
|
||||||
|
|
||||||
|
assert mmu.translate(0x1_0000_0000) == 0xA000_0000
|
||||||
|
assert mmu.translate(0x1_0001_0000) == 0xB000_0000
|
||||||
|
|
||||||
|
|
||||||
|
def test_mappings_do_not_interfere():
|
||||||
|
"""Adjacent VA ranges map to completely independent PA ranges."""
|
||||||
|
mmu = PeMMU(page_size=4096)
|
||||||
|
mmu.map(va=0x0000_0000, pa=0xFFFF_0000, size=4096)
|
||||||
|
mmu.map(va=0x0000_1000, pa=0x0000_0000, size=4096)
|
||||||
|
|
||||||
|
assert mmu.translate(0x0000_0000) == 0xFFFF_0000
|
||||||
|
assert mmu.translate(0x0000_1000) == 0x0000_0000
|
||||||
|
|
||||||
|
|
||||||
|
# ── T4. unmap removes entries ────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_unmap_removes_mapping():
|
||||||
|
"""After unmap, translate raises PageFault."""
|
||||||
|
mmu = PeMMU(page_size=4096)
|
||||||
|
mmu.map(va=0x1_0000_0000, pa=0xABCD_0000, size=4096)
|
||||||
|
assert mmu.translate(0x1_0000_0000) == 0xABCD_0000
|
||||||
|
|
||||||
|
mmu.unmap(va=0x1_0000_0000, size=4096)
|
||||||
|
with pytest.raises(PageFault):
|
||||||
|
mmu.translate(0x1_0000_0000)
|
||||||
|
|
||||||
|
|
||||||
|
def test_unmap_partial_range():
|
||||||
|
"""Unmap only part of a multi-page mapping; rest stays valid."""
|
||||||
|
mmu = PeMMU(page_size=4096)
|
||||||
|
mmu.map(va=0x1000_0000, pa=0x2000_0000, size=8192) # 2 pages
|
||||||
|
assert mmu.num_entries == 2
|
||||||
|
|
||||||
|
# Unmap first page only
|
||||||
|
mmu.unmap(va=0x1000_0000, size=4096)
|
||||||
|
assert mmu.num_entries == 1
|
||||||
|
|
||||||
|
with pytest.raises(PageFault):
|
||||||
|
mmu.translate(0x1000_0000)
|
||||||
|
# Second page still valid
|
||||||
|
assert mmu.translate(0x1000_1000) == 0x2000_1000
|
||||||
|
|
||||||
|
|
||||||
|
def test_unmap_does_not_affect_other_mappings():
|
||||||
|
"""Unmapping tensor A does not affect tensor B."""
|
||||||
|
mmu = PeMMU(page_size=4096)
|
||||||
|
mmu.map(va=0x1_0000_0000, pa=0xA000_0000, size=4096)
|
||||||
|
mmu.map(va=0x1_0001_0000, pa=0xB000_0000, size=4096)
|
||||||
|
|
||||||
|
mmu.unmap(va=0x1_0000_0000, size=4096)
|
||||||
|
with pytest.raises(PageFault):
|
||||||
|
mmu.translate(0x1_0000_0000)
|
||||||
|
assert mmu.translate(0x1_0001_0000) == 0xB000_0000
|
||||||
|
|
||||||
|
|
||||||
|
# ── T5. PageFault on unmapped VA ─────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_pagefault_on_unmapped_va():
|
||||||
|
"""translate() on never-mapped VA raises PageFault."""
|
||||||
|
mmu = PeMMU(page_size=4096)
|
||||||
|
with pytest.raises(PageFault):
|
||||||
|
mmu.translate(0xDEAD_BEEF)
|
||||||
|
|
||||||
|
|
||||||
|
def test_pagefault_contains_va():
|
||||||
|
"""PageFault exception carries the faulting VA."""
|
||||||
|
mmu = PeMMU(page_size=4096)
|
||||||
|
with pytest.raises(PageFault, match="0xdeadbeef"):
|
||||||
|
mmu.translate(0xDEAD_BEEF)
|
||||||
|
|
||||||
|
|
||||||
|
# ── T6. Identical mapping broadcast across PEs ───────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_broadcast_same_mapping_to_all_pes():
|
||||||
|
"""All PEs receive the same full mapping → identical translate results."""
|
||||||
|
entries = [
|
||||||
|
(0x1_0000_0000, 0xA000_0000, 4096), # shard 0
|
||||||
|
(0x1_0000_1000, 0xB000_0000, 4096), # shard 1
|
||||||
|
(0x1_0000_2000, 0xC000_0000, 4096), # shard 2
|
||||||
|
(0x1_0000_3000, 0xD000_0000, 4096), # shard 3
|
||||||
|
]
|
||||||
|
num_pes = 8
|
||||||
|
mmus = [PeMMU(page_size=4096) for _ in range(num_pes)]
|
||||||
|
|
||||||
|
# Broadcast: every PE gets the same entries
|
||||||
|
for mmu in mmus:
|
||||||
|
for va, pa, size in entries:
|
||||||
|
mmu.map(va=va, pa=pa, size=size)
|
||||||
|
|
||||||
|
# All PEs translate identically
|
||||||
|
for mmu in mmus:
|
||||||
|
assert mmu.translate(0x1_0000_0000) == 0xA000_0000
|
||||||
|
assert mmu.translate(0x1_0000_1000) == 0xB000_0000
|
||||||
|
assert mmu.translate(0x1_0000_2000) == 0xC000_0000
|
||||||
|
assert mmu.translate(0x1_0000_3000) == 0xD000_0000
|
||||||
|
|
||||||
|
|
||||||
|
def test_cross_pe_access_via_broadcast():
|
||||||
|
"""PE0 can translate a VA that maps to PE3's HBM PA (cross-PE DMA scenario)."""
|
||||||
|
mmu_pe0 = PeMMU(page_size=4096)
|
||||||
|
# Full mapping includes PE3's shard
|
||||||
|
mmu_pe0.map(va=0x1_0000_0000, pa=0xA000_0000, size=4096) # PE0 shard
|
||||||
|
mmu_pe0.map(va=0x1_0000_1000, pa=0xD000_0000, size=4096) # PE3 shard
|
||||||
|
|
||||||
|
# PE0 accesses PE3's region → valid translation
|
||||||
|
pa = mmu_pe0.translate(0x1_0000_1000 + 0x100)
|
||||||
|
assert pa == 0xD000_0100
|
||||||
|
|
||||||
|
|
||||||
|
# ── TLB overhead attribute ───────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_tlb_overhead_default():
|
||||||
|
"""Default tlb_overhead_ns is 0."""
|
||||||
|
mmu = PeMMU()
|
||||||
|
assert mmu.overhead_ns == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_tlb_overhead_configurable():
|
||||||
|
"""tlb_overhead_ns is configurable."""
|
||||||
|
mmu = PeMMU(overhead_ns=0.5)
|
||||||
|
assert mmu.overhead_ns == 0.5
|
||||||
@@ -0,0 +1,193 @@
|
|||||||
|
"""Tests for tensor free: del-based + context manager cleanup.
|
||||||
|
|
||||||
|
Validates:
|
||||||
|
TF1. PEMemAllocator.free_hbm/free_tcm reclaims space
|
||||||
|
TF2. del tensor triggers cleanup (VA/PA returned, MMU unmapped)
|
||||||
|
TF3. Context manager cleans up all tensors on exit
|
||||||
|
TF4. del after context exit is safe (no crash)
|
||||||
|
TF5. Alloc-del-alloc cycle reuses VA and PA
|
||||||
|
TF6. del already-freed tensor is safe (no crash)
|
||||||
|
"""
|
||||||
|
import gc
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator
|
||||||
|
from kernbench.policy.address.pe_mmu import PageFault
|
||||||
|
from kernbench.policy.placement.dp import DPPolicy
|
||||||
|
from kernbench.sim_engine.engine import GraphEngine
|
||||||
|
from kernbench.runtime_api.context import RuntimeContext
|
||||||
|
from kernbench.runtime_api.types import DeviceSelector
|
||||||
|
from kernbench.topology.builder import load_topology
|
||||||
|
|
||||||
|
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
|
||||||
|
|
||||||
|
_MB = 1 << 20
|
||||||
|
_GB = 1 << 30
|
||||||
|
|
||||||
|
_CFG = AddressConfig(
|
||||||
|
sip_count=2,
|
||||||
|
cubes_per_sip=16,
|
||||||
|
pes_per_cube=8,
|
||||||
|
hbm_bytes_per_cube=48 * _GB,
|
||||||
|
hbm_slices_per_cube=8,
|
||||||
|
tcm_bytes_per_pe=16 * _MB,
|
||||||
|
tcm_scheduler_reserved_bytes=4 * _MB,
|
||||||
|
sram_bytes_per_cube=32 * _MB,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_ctx():
|
||||||
|
graph = load_topology(TOPOLOGY_PATH)
|
||||||
|
engine = GraphEngine(graph)
|
||||||
|
ctx = RuntimeContext(
|
||||||
|
engine=engine,
|
||||||
|
target_device=DeviceSelector("sip:0"),
|
||||||
|
correlation_id="test_free",
|
||||||
|
spec=graph.spec,
|
||||||
|
)
|
||||||
|
return ctx, engine
|
||||||
|
|
||||||
|
|
||||||
|
# ── TF1. PEMemAllocator free ─────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_allocator_free_hbm_reclaims_space():
|
||||||
|
"""free_hbm returns HBM space; subsequent alloc can reuse it."""
|
||||||
|
a = PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=0, cfg=_CFG)
|
||||||
|
pa1 = a.alloc_hbm(4096)
|
||||||
|
used_after_alloc = a.hbm_used
|
||||||
|
a.free_hbm(pa1, 4096)
|
||||||
|
assert a.hbm_used == used_after_alloc - 4096
|
||||||
|
pa2 = a.alloc_hbm(4096)
|
||||||
|
assert pa2 is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_allocator_free_tcm_reclaims_space():
|
||||||
|
"""free_tcm returns TCM space."""
|
||||||
|
a = PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=0, cfg=_CFG)
|
||||||
|
pa1 = a.alloc_tcm(256)
|
||||||
|
used_after_alloc = a.tcm_used
|
||||||
|
a.free_tcm(pa1, 256)
|
||||||
|
assert a.tcm_used == used_after_alloc - 256
|
||||||
|
|
||||||
|
|
||||||
|
# ── TF2. del tensor triggers cleanup ─────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_del_tensor_unmaps_mmu():
|
||||||
|
"""del tensor removes MMU mappings."""
|
||||||
|
ctx, engine = _make_ctx()
|
||||||
|
dp = DPPolicy(cube="replicate", pe="replicate")
|
||||||
|
t = ctx.zeros((128, 128), dtype="f16", dp=dp, name="del_test")
|
||||||
|
va_base = t._handle.va_base
|
||||||
|
|
||||||
|
# Verify mapping exists
|
||||||
|
mmu_comp = engine._components["sip0.cube0.pe0.pe_mmu"]
|
||||||
|
assert mmu_comp.mmu.translate(va_base) is not None
|
||||||
|
|
||||||
|
# Delete tensor
|
||||||
|
del t
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
# Mapping should be gone
|
||||||
|
with pytest.raises(PageFault):
|
||||||
|
mmu_comp.mmu.translate(va_base)
|
||||||
|
|
||||||
|
|
||||||
|
def test_del_tensor_reclaims_va():
|
||||||
|
"""del tensor returns VA space for reuse."""
|
||||||
|
ctx, engine = _make_ctx()
|
||||||
|
dp = DPPolicy(cube="replicate", pe="replicate")
|
||||||
|
|
||||||
|
t1 = ctx.zeros((128, 128), dtype="f16", dp=dp, name="va_reuse1")
|
||||||
|
va1 = t1._handle.va_base
|
||||||
|
|
||||||
|
del t1
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
t2 = ctx.zeros((128, 128), dtype="f16", dp=dp, name="va_reuse2")
|
||||||
|
va2 = t2._handle.va_base
|
||||||
|
assert va2 == va1
|
||||||
|
|
||||||
|
|
||||||
|
# ── TF3. Context manager cleanup ─────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_context_manager_cleans_all():
|
||||||
|
"""Exiting context manager cleans up all tensors."""
|
||||||
|
graph = load_topology(TOPOLOGY_PATH)
|
||||||
|
engine = GraphEngine(graph)
|
||||||
|
|
||||||
|
with RuntimeContext(
|
||||||
|
engine=engine,
|
||||||
|
target_device=DeviceSelector("sip:0"),
|
||||||
|
correlation_id="ctx_mgr",
|
||||||
|
spec=graph.spec,
|
||||||
|
) as ctx:
|
||||||
|
dp = DPPolicy(cube="replicate", pe="replicate")
|
||||||
|
t = ctx.zeros((128, 128), dtype="f16", dp=dp, name="ctx_mgr_test")
|
||||||
|
va_base = t._handle.va_base
|
||||||
|
|
||||||
|
# After context exit, MMU mappings should be cleared
|
||||||
|
mmu_comp = engine._components["sip0.cube0.pe0.pe_mmu"]
|
||||||
|
with pytest.raises(PageFault):
|
||||||
|
mmu_comp.mmu.translate(va_base)
|
||||||
|
|
||||||
|
|
||||||
|
# ── TF4. del after context exit is safe ───────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_del_after_context_exit_no_crash():
|
||||||
|
"""del tensor after context manager exit does not crash."""
|
||||||
|
graph = load_topology(TOPOLOGY_PATH)
|
||||||
|
engine = GraphEngine(graph)
|
||||||
|
|
||||||
|
ctx = RuntimeContext(
|
||||||
|
engine=engine,
|
||||||
|
target_device=DeviceSelector("sip:0"),
|
||||||
|
correlation_id="safe_del",
|
||||||
|
spec=graph.spec,
|
||||||
|
)
|
||||||
|
dp = DPPolicy(cube="replicate", pe="replicate")
|
||||||
|
t = ctx.zeros((128, 128), dtype="f16", dp=dp, name="safe_del_test")
|
||||||
|
|
||||||
|
# Simulate context going away
|
||||||
|
ctx.cleanup()
|
||||||
|
|
||||||
|
# del should not crash even though context already cleaned up
|
||||||
|
del t
|
||||||
|
gc.collect()
|
||||||
|
# No exception = pass
|
||||||
|
|
||||||
|
|
||||||
|
# ── TF5. Alloc-del-alloc cycle ────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_alloc_del_cycle():
|
||||||
|
"""Multiple alloc-del cycles work correctly."""
|
||||||
|
ctx, engine = _make_ctx()
|
||||||
|
dp = DPPolicy(cube="replicate", pe="replicate")
|
||||||
|
|
||||||
|
for i in range(3):
|
||||||
|
t = ctx.zeros((128, 128), dtype="f16", dp=dp, name=f"cycle_{i}")
|
||||||
|
assert t._handle is not None
|
||||||
|
assert t._handle.va_base > 0
|
||||||
|
del t
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
|
||||||
|
# ── TF6. del already-cleaned tensor is safe ──────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_del_already_cleaned_tensor_no_crash():
|
||||||
|
"""del on a tensor whose handle is already None does not crash."""
|
||||||
|
ctx, engine = _make_ctx()
|
||||||
|
dp = DPPolicy(cube="replicate", pe="replicate")
|
||||||
|
t = ctx.zeros((128, 128), dtype="f16", dp=dp, name="double_del")
|
||||||
|
|
||||||
|
ctx.cleanup() # clears all tensors
|
||||||
|
# t._handle is now None
|
||||||
|
del t # should not crash
|
||||||
|
gc.collect()
|
||||||
@@ -17,31 +17,32 @@ def test_full_graph_node_count():
|
|||||||
g = _graph()
|
g = _graph()
|
||||||
# 1 switch
|
# 1 switch
|
||||||
# + 2 SIPs × (1 IO × (3 comps + 4 io_ucie + 16 io_conn)
|
# + 2 SIPs × (1 IO × (3 comps + 4 io_ucie + 16 io_conn)
|
||||||
# + 16 cubes × (cube_comps + 8 PEs × 6 pe_comps))
|
# + 16 cubes × (cube_comps + 8 PEs × 7 pe_comps))
|
||||||
# IO: pcie_ep + io_cpu + io_noc + 4 io_ucie + 4*4 io_conn = 23
|
# IO: pcie_ep + io_cpu + io_noc + 4 io_ucie + 4*4 io_conn = 23
|
||||||
# cube_comps: 9 (noc, m_cpu, sram, 2 bridge, 4 ucie)
|
# cube_comps: 9 (noc, m_cpu, sram, 2 bridge, 4 ucie)
|
||||||
# + 16 ucie_conn (4 ports × 4 connections)
|
# + 16 ucie_conn (4 ports × 4 connections)
|
||||||
# + 2 xbar_top/bot
|
# + 2 xbar_top/bot
|
||||||
# + 8 hbm_slices = 35
|
# + 8 hbm_slices = 35
|
||||||
# = 1 + 2*(23 + 16*(35+48)) = 1 + 2*(23+1328) = 1 + 2702 = 2703
|
# pe_comps: 7 (pe_cpu, pe_scheduler, pe_dma, pe_gemm, pe_math, pe_mmu, pe_tcm)
|
||||||
assert len(g.nodes) == 2703
|
# = 1 + 2*(23 + 16*(35+56)) = 1 + 2*(23+1456) = 1 + 2958 = 2959
|
||||||
|
assert len(g.nodes) == 2959
|
||||||
|
|
||||||
|
|
||||||
def test_full_graph_edge_count():
|
def test_full_graph_edge_count():
|
||||||
g = _graph()
|
g = _graph()
|
||||||
# Per cube: 184
|
# Per cube: 192
|
||||||
# PE-internal: 56
|
# PE-internal: 56
|
||||||
# PE_DMA→noc: 8, noc→pe_dma: 8, noc→pe_cpu: 8, pe_cpu→noc: 8
|
# PE_DMA→noc: 8, noc→pe_dma: 8, noc→pe_cpu: 8, pe_cpu→noc: 8, noc→pe_mmu: 8
|
||||||
# xbar_top→hbm{0..3}: 4+4=8, xbar_bot→hbm{4..7}: 4+4=8
|
# xbar_top→hbm{0..3}: 4+4=8, xbar_bot→hbm{4..7}: 4+4=8
|
||||||
# noc↔xbar_top: 2, noc↔xbar_bot: 2
|
# noc↔xbar_top: 2, noc↔xbar_bot: 2
|
||||||
# xbar_top↔bridge.left: 2, bridge.left↔xbar_bot: 2
|
# xbar_top↔bridge.left: 2, bridge.left↔xbar_bot: 2
|
||||||
# xbar_top↔bridge.right: 2, bridge.right↔xbar_bot: 2
|
# xbar_top↔bridge.right: 2, bridge.right↔xbar_bot: 2
|
||||||
# ucie: 64, m_cpu↔noc: 2, noc↔sram: 2
|
# ucie: 64, m_cpu↔noc: 2, noc↔sram: 2
|
||||||
# Total: 56+8+8+8+8+8+8+2+2+2+2+2+2+64+2+2 = 184
|
# Total: 56+8+8+8+8+8+8+8+2+2+2+2+2+2+64+2+2 = 192
|
||||||
# IO edges per SIP: 77
|
# IO edges per SIP: 77
|
||||||
# Per SIP: 16*184 + 48 inter-cube + 77 IO = 3069
|
# Per SIP: 16*192 + 48 inter-cube + 77 IO = 3197
|
||||||
# Total: 2 * 3069 = 6138
|
# Total: 2 * 3197 = 6394
|
||||||
assert len(g.edges) == 6138
|
assert len(g.edges) == 6394
|
||||||
|
|
||||||
|
|
||||||
# ── Full graph: specific nodes exist ─────────────────────────────────
|
# ── Full graph: specific nodes exist ─────────────────────────────────
|
||||||
@@ -267,7 +268,7 @@ def test_cube_view_pe_to_noc():
|
|||||||
def test_pe_view_has_all_components():
|
def test_pe_view_has_all_components():
|
||||||
v = _graph().pe_view
|
v = _graph().pe_view
|
||||||
assert set(v.nodes.keys()) == {
|
assert set(v.nodes.keys()) == {
|
||||||
"pe_cpu", "pe_scheduler", "pe_dma", "pe_gemm", "pe_math", "pe_tcm"
|
"pe_cpu", "pe_scheduler", "pe_dma", "pe_gemm", "pe_math", "pe_mmu", "pe_tcm"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ def test_pe_template_components():
|
|||||||
spec = _read_spec(TOPOLOGY_PATH)
|
spec = _read_spec(TOPOLOGY_PATH)
|
||||||
comps = spec["cube"]["pe_template"]["components"]
|
comps = spec["cube"]["pe_template"]["components"]
|
||||||
assert set(comps.keys()) == {
|
assert set(comps.keys()) == {
|
||||||
"pe_cpu", "pe_scheduler", "pe_dma", "pe_gemm", "pe_math", "pe_tcm"
|
"pe_cpu", "pe_scheduler", "pe_dma", "pe_gemm", "pe_math", "pe_mmu", "pe_tcm"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ def test_tl_load_generates_dma_read():
|
|||||||
cmds = tl.commands
|
cmds = tl.commands
|
||||||
assert len(cmds) == 1
|
assert len(cmds) == 1
|
||||||
assert isinstance(cmds[0], DmaReadCmd)
|
assert isinstance(cmds[0], DmaReadCmd)
|
||||||
assert cmds[0].src_pa == 0x1000
|
assert cmds[0].src_addr == 0x1000
|
||||||
assert cmds[0].nbytes == 32 * 64 * 2
|
assert cmds[0].nbytes == 32 * 64 * 2
|
||||||
|
|
||||||
|
|
||||||
@@ -47,7 +47,7 @@ def test_tl_store_generates_dma_write():
|
|||||||
tl.store(0x2000, h)
|
tl.store(0x2000, h)
|
||||||
cmds = [c for c in tl.commands if isinstance(c, DmaWriteCmd)]
|
cmds = [c for c in tl.commands if isinstance(c, DmaWriteCmd)]
|
||||||
assert len(cmds) == 1
|
assert len(cmds) == 1
|
||||||
assert cmds[0].dst_pa == 0x2000
|
assert cmds[0].dst_addr == 0x2000
|
||||||
assert cmds[0].nbytes == 16 * 16 * 4
|
assert cmds[0].nbytes == 16 * 16 * 4
|
||||||
|
|
||||||
|
|
||||||
@@ -148,7 +148,7 @@ def test_tl_composite_nonblocking():
|
|||||||
comp_cmds = [c for c in tl.commands if isinstance(c, CompositeCmd)]
|
comp_cmds = [c for c in tl.commands if isinstance(c, CompositeCmd)]
|
||||||
assert len(comp_cmds) == 1
|
assert len(comp_cmds) == 1
|
||||||
assert comp_cmds[0].op == "gemm"
|
assert comp_cmds[0].op == "gemm"
|
||||||
assert comp_cmds[0].out_pa == 0x3000
|
assert comp_cmds[0].out_addr == 0x3000
|
||||||
assert comp_cmds[0].out_nbytes == 32 * 32 * 2 # M×N×dtype_bytes
|
assert comp_cmds[0].out_nbytes == 32 * 32 * 2 # M×N×dtype_bytes
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,140 @@
|
|||||||
|
"""Tests for VirtualAllocator: device-wide VA space management.
|
||||||
|
|
||||||
|
Validates:
|
||||||
|
T7. Basic VA allocation (contiguous, non-overlapping)
|
||||||
|
T8. VA free + reallocation (free-list reuse)
|
||||||
|
T9. VA space exhaustion raises AllocationError
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from kernbench.policy.address.va_allocator import VirtualAllocator
|
||||||
|
|
||||||
|
_KB = 1024
|
||||||
|
_MB = 1024 * 1024
|
||||||
|
_GB = 1024 * 1024 * 1024
|
||||||
|
|
||||||
|
|
||||||
|
# ── T7. Basic VA allocation ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_alloc_returns_aligned_va():
|
||||||
|
"""First allocation returns va_base."""
|
||||||
|
va = VirtualAllocator(va_base=0x1_0000_0000, va_size=1 * _GB, page_size=4096)
|
||||||
|
addr = va.alloc(4096)
|
||||||
|
assert addr == 0x1_0000_0000
|
||||||
|
|
||||||
|
|
||||||
|
def test_alloc_sequential_non_overlapping():
|
||||||
|
"""Two allocations return contiguous, non-overlapping VA ranges."""
|
||||||
|
va = VirtualAllocator(va_base=0x1_0000_0000, va_size=1 * _GB, page_size=4096)
|
||||||
|
a1 = va.alloc(4096)
|
||||||
|
a2 = va.alloc(8192)
|
||||||
|
assert a1 == 0x1_0000_0000
|
||||||
|
assert a2 == 0x1_0000_1000 # a1 + 4096
|
||||||
|
# No overlap
|
||||||
|
assert a2 >= a1 + 4096
|
||||||
|
|
||||||
|
|
||||||
|
def test_alloc_page_aligned():
|
||||||
|
"""Allocations are page-aligned even if requested size is not page-multiple."""
|
||||||
|
va = VirtualAllocator(va_base=0x1_0000_0000, va_size=1 * _GB, page_size=4096)
|
||||||
|
a1 = va.alloc(100) # < 1 page, but occupies 1 page
|
||||||
|
a2 = va.alloc(100)
|
||||||
|
assert a2 == 0x1_0000_1000 # aligned to next page
|
||||||
|
|
||||||
|
|
||||||
|
def test_alloc_large_contiguous():
|
||||||
|
"""Large allocation (multiple pages) is contiguous."""
|
||||||
|
va = VirtualAllocator(va_base=0x0, va_size=1 * _GB, page_size=2 * _MB)
|
||||||
|
addr = va.alloc(8 * _MB) # 4 pages
|
||||||
|
assert addr == 0x0
|
||||||
|
# Next alloc starts after 8 MB
|
||||||
|
addr2 = va.alloc(2 * _MB)
|
||||||
|
assert addr2 == 8 * _MB
|
||||||
|
|
||||||
|
|
||||||
|
# ── T8. VA free + reallocation ───────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_free_and_realloc():
|
||||||
|
"""Freed VA range can be reused by subsequent allocation."""
|
||||||
|
va = VirtualAllocator(va_base=0x1_0000_0000, va_size=1 * _GB, page_size=4096)
|
||||||
|
a1 = va.alloc(4096)
|
||||||
|
a2 = va.alloc(4096)
|
||||||
|
va.free(a1, 4096)
|
||||||
|
|
||||||
|
# New alloc should reuse a1's range
|
||||||
|
a3 = va.alloc(4096)
|
||||||
|
assert a3 == a1
|
||||||
|
|
||||||
|
|
||||||
|
def test_free_coalesce():
|
||||||
|
"""Freeing adjacent blocks allows larger reallocation."""
|
||||||
|
va = VirtualAllocator(va_base=0x0, va_size=1 * _GB, page_size=4096)
|
||||||
|
a1 = va.alloc(4096)
|
||||||
|
a2 = va.alloc(4096)
|
||||||
|
a3 = va.alloc(4096)
|
||||||
|
|
||||||
|
# Free first two (adjacent)
|
||||||
|
va.free(a1, 4096)
|
||||||
|
va.free(a2, 4096)
|
||||||
|
|
||||||
|
# Should be able to allocate 8192 contiguous from the freed region
|
||||||
|
a4 = va.alloc(8192)
|
||||||
|
assert a4 == a1 # reuses coalesced region
|
||||||
|
|
||||||
|
|
||||||
|
def test_free_out_of_order():
|
||||||
|
"""Free in non-sequential order still works."""
|
||||||
|
va = VirtualAllocator(va_base=0x0, va_size=1 * _GB, page_size=4096)
|
||||||
|
a1 = va.alloc(4096)
|
||||||
|
a2 = va.alloc(4096)
|
||||||
|
a3 = va.alloc(4096)
|
||||||
|
|
||||||
|
va.free(a2, 4096) # free middle
|
||||||
|
a4 = va.alloc(4096)
|
||||||
|
assert a4 == a2 # reuses middle slot
|
||||||
|
|
||||||
|
|
||||||
|
# ── T9. VA space exhaustion ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_alloc_exhaustion():
|
||||||
|
"""Allocation beyond VA space raises AllocationError."""
|
||||||
|
va = VirtualAllocator(va_base=0x0, va_size=8192, page_size=4096)
|
||||||
|
va.alloc(4096)
|
||||||
|
va.alloc(4096)
|
||||||
|
with pytest.raises(Exception, match="[Aa]lloc|[Ee]xhaust|[Oo]ut of"):
|
||||||
|
va.alloc(4096)
|
||||||
|
|
||||||
|
|
||||||
|
def test_alloc_after_partial_free():
|
||||||
|
"""After freeing some, can allocate again within freed space."""
|
||||||
|
va = VirtualAllocator(va_base=0x0, va_size=8192, page_size=4096)
|
||||||
|
a1 = va.alloc(4096)
|
||||||
|
a2 = va.alloc(4096)
|
||||||
|
|
||||||
|
# Space is full
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
va.alloc(4096)
|
||||||
|
|
||||||
|
# Free one, now can allocate again
|
||||||
|
va.free(a1, 4096)
|
||||||
|
a3 = va.alloc(4096)
|
||||||
|
assert a3 == a1
|
||||||
|
|
||||||
|
|
||||||
|
# ── Stats / inspection ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_used_and_total():
|
||||||
|
"""used and total properties reflect allocation state."""
|
||||||
|
va = VirtualAllocator(va_base=0x0, va_size=1 * _MB, page_size=4096)
|
||||||
|
assert va.used == 0
|
||||||
|
assert va.total == 1 * _MB
|
||||||
|
|
||||||
|
va.alloc(4096)
|
||||||
|
assert va.used == 4096
|
||||||
|
|
||||||
|
va.alloc(8192)
|
||||||
|
assert va.used == 4096 + 8192 # page-aligned: 4096 + 8192 = 12288
|
||||||
@@ -0,0 +1,230 @@
|
|||||||
|
"""Tests for VA integration: Tensor, TLContext, and DMA commands use VA.
|
||||||
|
|
||||||
|
Validates:
|
||||||
|
T10. TensorHandle has va_base; TensorShard does NOT have va field
|
||||||
|
T11. deploy_tensor allocates VA + creates mapping entries
|
||||||
|
T12. Tensor.va returns the tensor's VA base
|
||||||
|
T13. tl.load/tl.store generate DMA commands with VA (not PA)
|
||||||
|
T14. Kernel VA-based offset calculation flows through DMA commands
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator
|
||||||
|
from kernbench.policy.address.pe_mmu import PeMMU
|
||||||
|
from kernbench.policy.address.va_allocator import VirtualAllocator
|
||||||
|
from kernbench.policy.placement.dp import column_wise, ShardSpec
|
||||||
|
from kernbench.runtime_api.tensor import (
|
||||||
|
TensorHandle,
|
||||||
|
TensorShard,
|
||||||
|
deploy_tensor,
|
||||||
|
)
|
||||||
|
from kernbench.runtime_api.kernel import TensorArgShard
|
||||||
|
from kernbench.common.pe_commands import DmaReadCmd, DmaWriteCmd
|
||||||
|
from kernbench.triton_emu.tl_context import TLContext, run_kernel
|
||||||
|
|
||||||
|
_MB = 1 << 20
|
||||||
|
_GB = 1 << 30
|
||||||
|
|
||||||
|
_CFG = AddressConfig(
|
||||||
|
sip_count=2,
|
||||||
|
cubes_per_sip=16,
|
||||||
|
pes_per_cube=8,
|
||||||
|
hbm_bytes_per_cube=48 * _GB,
|
||||||
|
hbm_slices_per_cube=8,
|
||||||
|
tcm_bytes_per_pe=16 * _MB,
|
||||||
|
tcm_scheduler_reserved_bytes=4 * _MB,
|
||||||
|
sram_bytes_per_cube=32 * _MB,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_allocators(num_pe: int = 8) -> dict[int, PEMemAllocator]:
|
||||||
|
return {
|
||||||
|
i: PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=i, cfg=_CFG)
|
||||||
|
for i in range(num_pe)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mmus(num_pe: int = 8, page_size: int = 4096) -> dict[int, PeMMU]:
|
||||||
|
return {i: PeMMU(page_size=page_size) for i in range(num_pe)}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_va_allocator() -> VirtualAllocator:
|
||||||
|
return VirtualAllocator(va_base=0x1_0000_0000, va_size=64 * _GB, page_size=2 * _MB)
|
||||||
|
|
||||||
|
|
||||||
|
# ── T10. TensorHandle has va_base ────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_tensor_handle_has_va_base():
|
||||||
|
"""TensorHandle must have a 'va_base' field."""
|
||||||
|
th = TensorHandle(
|
||||||
|
name="A", shape=(1024, 512), dtype="fp16", itemsize=2,
|
||||||
|
shards=(), va_base=0x1_0000_0000,
|
||||||
|
)
|
||||||
|
assert th.va_base == 0x1_0000_0000
|
||||||
|
|
||||||
|
|
||||||
|
def test_tensor_handle_va_base_immutable():
|
||||||
|
"""TensorHandle.va_base is immutable (frozen dataclass)."""
|
||||||
|
th = TensorHandle(
|
||||||
|
name="A", shape=(1024, 512), dtype="fp16", itemsize=2,
|
||||||
|
shards=(), va_base=0x1_0000_0000,
|
||||||
|
)
|
||||||
|
with pytest.raises(AttributeError):
|
||||||
|
th.va_base = 0x2_0000_0000 # type: ignore[misc]
|
||||||
|
|
||||||
|
|
||||||
|
def test_tensor_shard_no_va_field():
|
||||||
|
"""TensorShard should NOT have a va field — va is derived from
|
||||||
|
TensorHandle.va_base + shard.offset_bytes."""
|
||||||
|
ts = TensorShard(sip=0, cube=0, pe=0, pa=0x1000, nbytes=4096, offset_bytes=0)
|
||||||
|
assert not hasattr(ts, "va"), "TensorShard should not have a 'va' field"
|
||||||
|
|
||||||
|
|
||||||
|
# ── T11. deploy_tensor allocates VA + creates mappings ───────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_deploy_tensor_assigns_va_base():
|
||||||
|
"""deploy_tensor with VA allocator assigns va_base to TensorHandle."""
|
||||||
|
allocs = _make_allocators()
|
||||||
|
va_alloc = _make_va_allocator()
|
||||||
|
mmus = _make_mmus()
|
||||||
|
placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8)
|
||||||
|
|
||||||
|
th = deploy_tensor(
|
||||||
|
name="W",
|
||||||
|
shape=(1024, 512),
|
||||||
|
dtype="fp16",
|
||||||
|
placement=placement,
|
||||||
|
allocators=allocs,
|
||||||
|
va_allocator=va_alloc,
|
||||||
|
mmus=mmus,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert th.va_base is not None
|
||||||
|
assert th.va_base > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_deploy_tensor_va_covers_all_shards():
|
||||||
|
"""VA allocation covers the entire tensor; each shard is at va_base + offset."""
|
||||||
|
allocs = _make_allocators()
|
||||||
|
va_alloc = _make_va_allocator()
|
||||||
|
mmus = _make_mmus()
|
||||||
|
placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8)
|
||||||
|
|
||||||
|
th = deploy_tensor(
|
||||||
|
name="W",
|
||||||
|
shape=(1024, 512),
|
||||||
|
dtype="fp16",
|
||||||
|
placement=placement,
|
||||||
|
allocators=allocs,
|
||||||
|
va_allocator=va_alloc,
|
||||||
|
mmus=mmus,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Each shard's VA is derivable: va_base + offset_bytes
|
||||||
|
for s in th.shards:
|
||||||
|
shard_va = th.va_base + s.offset_bytes
|
||||||
|
assert shard_va > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_deploy_tensor_registers_mmu_mappings():
|
||||||
|
"""deploy_tensor registers VA→PA mappings in all PE MMUs."""
|
||||||
|
allocs = _make_allocators()
|
||||||
|
va_alloc = _make_va_allocator()
|
||||||
|
mmus = _make_mmus()
|
||||||
|
placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8)
|
||||||
|
|
||||||
|
th = deploy_tensor(
|
||||||
|
name="W",
|
||||||
|
shape=(1024, 512),
|
||||||
|
dtype="fp16",
|
||||||
|
placement=placement,
|
||||||
|
allocators=allocs,
|
||||||
|
va_allocator=va_alloc,
|
||||||
|
mmus=mmus,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Every MMU should have entries (broadcast)
|
||||||
|
for mmu in mmus.values():
|
||||||
|
assert mmu.num_entries > 0
|
||||||
|
|
||||||
|
# Each shard's derived VA should translate to its PA in every MMU
|
||||||
|
for mmu in mmus.values():
|
||||||
|
for s in th.shards:
|
||||||
|
shard_va = th.va_base + s.offset_bytes
|
||||||
|
assert mmu.translate(shard_va) == s.pa
|
||||||
|
|
||||||
|
|
||||||
|
# ── T12. Tensor.va property ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_tensor_va_property():
|
||||||
|
"""Tensor.va returns the VA base of the entire tensor (from TensorHandle.va_base)."""
|
||||||
|
from kernbench.runtime_api.tensor import Tensor
|
||||||
|
|
||||||
|
allocs = _make_allocators(1)
|
||||||
|
va_alloc = _make_va_allocator()
|
||||||
|
mmus = _make_mmus(1)
|
||||||
|
placement = [ShardSpec(pe_index=0, offset_bytes=0, nbytes=4096)]
|
||||||
|
|
||||||
|
t = Tensor(shape=(2048,), dtype="f16", name="test")
|
||||||
|
t._handle = deploy_tensor(
|
||||||
|
name="test",
|
||||||
|
shape=(2048,),
|
||||||
|
dtype="f16",
|
||||||
|
placement=placement,
|
||||||
|
allocators=allocs,
|
||||||
|
va_allocator=va_alloc,
|
||||||
|
mmus=mmus,
|
||||||
|
)
|
||||||
|
assert t.va > 0
|
||||||
|
assert t.va == t._handle.va_base
|
||||||
|
|
||||||
|
|
||||||
|
# ── T13. tl.load/tl.store use VA ─────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_tl_load_uses_va_in_dma_cmd():
|
||||||
|
"""tl.load(va_ptr) generates DmaReadCmd with src_va (not src_pa)."""
|
||||||
|
tl = TLContext(dispatch_cycles=0)
|
||||||
|
va_ptr = 0x1_0000_0000
|
||||||
|
h = tl.load(va_ptr, shape=(32, 64), dtype="f16")
|
||||||
|
|
||||||
|
dma_cmds = [c for c in tl.commands if isinstance(c, DmaReadCmd)]
|
||||||
|
assert len(dma_cmds) == 1
|
||||||
|
# The DMA command should carry the VA
|
||||||
|
assert dma_cmds[0].src_addr == va_ptr
|
||||||
|
|
||||||
|
|
||||||
|
def test_tl_store_uses_va_in_dma_cmd():
|
||||||
|
"""tl.store(va_ptr, handle) generates DmaWriteCmd with dst_va."""
|
||||||
|
tl = TLContext(dispatch_cycles=0)
|
||||||
|
h = tl.load(0x1_0000_0000, shape=(16, 16), dtype="f32")
|
||||||
|
va_out = 0x2_0000_0000
|
||||||
|
tl.store(va_out, h)
|
||||||
|
|
||||||
|
dma_cmds = [c for c in tl.commands if isinstance(c, DmaWriteCmd)]
|
||||||
|
assert len(dma_cmds) == 1
|
||||||
|
assert dma_cmds[0].dst_addr == va_out
|
||||||
|
|
||||||
|
|
||||||
|
# ── T14. Kernel VA offset calculation ─────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_kernel_va_offset_in_dma():
|
||||||
|
"""Kernel using base_va + pid * stride generates correct VA in DmaReadCmd."""
|
||||||
|
def tiled_kernel(a_ptr, tl, BLOCK_SIZE=1024, DTYPE="f16"):
|
||||||
|
pid = tl.program_id(0)
|
||||||
|
elem_bytes = 2 # f16
|
||||||
|
offset = pid * BLOCK_SIZE * elem_bytes
|
||||||
|
a = tl.load(a_ptr + offset, shape=(BLOCK_SIZE,), dtype=DTYPE)
|
||||||
|
|
||||||
|
va_base = 0x1_0000_0000
|
||||||
|
tl = TLContext(pe_id=3, num_programs=8, dispatch_cycles=0)
|
||||||
|
run_kernel(tiled_kernel, tl, a_ptr=va_base)
|
||||||
|
|
||||||
|
dma_cmds = [c for c in tl.commands if isinstance(c, DmaReadCmd)]
|
||||||
|
assert len(dma_cmds) == 1
|
||||||
|
expected_va = va_base + 3 * 1024 * 2 # pid=3, BLOCK_SIZE=1024, 2 bytes
|
||||||
|
assert dma_cmds[0].src_addr == expected_va
|
||||||
@@ -65,6 +65,7 @@ cube:
|
|||||||
pe_dma: { kind: pe_dma, impl: pe_dma_v1, attrs: { rd_engines: 1, wr_engines: 1 } }
|
pe_dma: { kind: pe_dma, impl: pe_dma_v1, attrs: { rd_engines: 1, wr_engines: 1 } }
|
||||||
pe_gemm: { kind: pe_gemm, impl: pe_gemm_v1, attrs: { overhead_ns: 0.0, shared_resource: accel_slot, peak_tflops_f16: 8.0 } }
|
pe_gemm: { kind: pe_gemm, impl: pe_gemm_v1, attrs: { overhead_ns: 0.0, shared_resource: accel_slot, peak_tflops_f16: 8.0 } }
|
||||||
pe_math: { kind: pe_math, impl: pe_math_v1, attrs: { overhead_ns: 0.0, shared_resource: accel_slot } }
|
pe_math: { kind: pe_math, impl: pe_math_v1, attrs: { overhead_ns: 0.0, shared_resource: accel_slot } }
|
||||||
|
pe_mmu: { kind: pe_mmu, impl: pe_mmu_v1, attrs: { tlb_overhead_ns: 0.5, page_size: 4096 } }
|
||||||
pe_tcm: { kind: pe_tcm, impl: pe_tcm_v1, attrs:
|
pe_tcm: { kind: pe_tcm, impl: pe_tcm_v1, attrs:
|
||||||
{ size_mb: 16 } }
|
{ size_mb: 16 } }
|
||||||
links:
|
links:
|
||||||
|
|||||||
Reference in New Issue
Block a user