Add virtual memory support: PE_MMU, VA allocator, fabric MmuMapMsg
Implement VA/MMU layer (ADR-0011 Phase 1) enabling Triton kernels to use contiguous virtual addresses on sharded tensors.

Key changes:
- PE_MMU component: hybrid inbox (MmuMapMsg) + sync translate() for PE_DMA
- VirtualAllocator + PEMemAllocator: free-list with coalescing
- MmuMapMsg/MmuUnmapMsg fabric path with SIP-level routing
- DPPolicy-based mapping: replicate=local, sharded=broadcast
- Tensor lifecycle: del + weakref cleanup, context manager
- Rename: TensorHandle.pa→addr, DmaReadCmd.src_pa→src_addr, ctx→torch

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -1,2 +1,2 @@
|
||||
def run(ctx):
|
||||
def run(torch):
|
||||
print("IPCQ all reduce kernel bench")
|
||||
|
||||
+2
-2
@@ -15,7 +15,7 @@ def resolve_bench(bench_id: str) -> BenchFn:
|
||||
|
||||
Expected layout (repo root):
|
||||
benches/<bench_id>.py
|
||||
def run(ctx: RuntimeContext) -> Any
|
||||
def run(torch: RuntimeContext) -> Any
|
||||
"""
|
||||
bench_id = bench_id.strip()
|
||||
if not bench_id:
|
||||
@@ -30,7 +30,7 @@ def resolve_bench(bench_id: str) -> BenchFn:
|
||||
|
||||
run_fn = getattr(mod, "run", None)
|
||||
if run_fn is None:
|
||||
raise ValueError(f"Bench module {module_path} must define a 'run(ctx)' function.")
|
||||
raise ValueError(f"Bench module {module_path} must define a 'run(torch)' function.")
|
||||
if not callable(run_fn):
|
||||
raise ValueError(f"'run' in {module_path} is not callable.")
|
||||
|
||||
|
||||
+5
-5
@@ -26,14 +26,14 @@ def _gemm_kernel(a_ptr, b_ptr, out_ptr, M, K, N, tl, DTYPE="f16"):
|
||||
tl.wait(handle)
|
||||
|
||||
|
||||
def run(ctx):
|
||||
def run(torch):
|
||||
"""Run the QKV GEMM benchmark."""
|
||||
# DP placement: a=replicate (cube-level), b/out=column_wise (N-axis, single PE)
|
||||
a = ctx.zeros((M, K), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="replicate"), name="a")
|
||||
b = ctx.zeros((K, N), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="column_wise"), name="b")
|
||||
out = ctx.empty(
|
||||
a = torch.zeros((M, K), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="replicate"), name="a")
|
||||
b = torch.zeros((K, N), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="column_wise"), name="b")
|
||||
out = torch.empty(
|
||||
(M, N), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="column_wise"), name="out",
|
||||
)
|
||||
|
||||
# Launch GEMM kernel
|
||||
ctx.launch("qkv_gemm", _gemm_kernel, a, b, out, M, K, N)
|
||||
torch.launch("qkv_gemm", _gemm_kernel, a, b, out, M, K, N)
|
||||
|
||||
@@ -26,14 +26,14 @@ def _gemm_kernel(a_ptr, b_ptr, out_ptr, M, K, N, tl, DTYPE="f16"):
|
||||
tl.wait(handle)
|
||||
|
||||
|
||||
def run(ctx):
|
||||
def run(torch):
|
||||
"""Run the multi-PE QKV GEMM benchmark."""
|
||||
# DP placement: a=replicate (cube-level), b/out=column_wise (N-axis split)
|
||||
a = ctx.zeros((M, K), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="replicate"), name="a")
|
||||
b = ctx.zeros((K, N), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="column_wise"), name="b")
|
||||
out = ctx.empty(
|
||||
a = torch.zeros((M, K), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="replicate"), name="a")
|
||||
b = torch.zeros((K, N), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="column_wise"), name="b")
|
||||
out = torch.empty(
|
||||
(M, N), dtype=DTYPE, dp=DPPolicy(cube="replicate", pe="column_wise"), name="out",
|
||||
)
|
||||
|
||||
# Launch GEMM kernel on all PEs
|
||||
ctx.launch("qkv_gemm_multi", _gemm_kernel, a, b, out, M, K, N)
|
||||
torch.launch("qkv_gemm_multi", _gemm_kernel, a, b, out, M, K, N)
|
||||
|
||||
Reference in New Issue
Block a user