Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 1d8b9401e5 | |||
| cfc2d74ec4 | |||
| 105f1dc09e | |||
| e7f376ebaa | |||
| 357cab525b | |||
| 787409ced1 | |||
| 79124daab1 | |||
| 4ba0a83e71 |
@@ -29,3 +29,4 @@ build/
|
||||
|
||||
# Logs
|
||||
*.log
|
||||
.claude/
|
||||
|
||||
+80
-106
@@ -1,129 +1,103 @@
|
||||
"""CCL all-reduce bench — single unified entry point.
|
||||
"""CCL all-reduce bench (ADR-0024 + ADR-0027).
|
||||
|
||||
Driven entirely by ``ccl.yaml`` + ``topology.yaml``:
|
||||
Pure TP launcher model: rank = SIP. Each rank owns a ``(N_CUBES, n_elem)``
|
||||
tensor sharded row-wise across the cube mesh (pe0 per cube). After
|
||||
``dist.all_reduce(op="sum")`` every cube on every rank must hold
|
||||
``N_CUBES * sum(1..world_size)``. Rank 0 prints the pass/fail line.
|
||||
|
||||
- ``defaults.algorithm`` in ``ccl.yaml`` picks which kernel to run
|
||||
(``ring_allreduce_{tcm,hbm,sram}`` / ``mesh_allreduce_4`` /
|
||||
``tree_allreduce_7``).
|
||||
- ``world_size`` is derived from the algorithm entry's override or from
|
||||
the topology spec (``sips × cubes_per_sip × pes_per_cube``).
|
||||
- The host code uses only real PyTorch ``torch.distributed`` names:
|
||||
``init_process_group``, ``get_world_size``, ``get_rank``, ``all_reduce``.
|
||||
|
||||
The bench is split into ``worker(rank, world_size, torch)`` — the
|
||||
per-rank business logic, designed to look like a real PyTorch DDP
|
||||
training worker so future model benches can reuse the same skeleton —
|
||||
and ``run(torch)`` — the kernbench-specific launcher that initializes
|
||||
the process group and invokes the worker.
|
||||
Driven by ``ccl.yaml`` (``defaults.algorithm``, ``n_elem``) + ``topology.yaml``
|
||||
(SIP count → world_size, cube_mesh → N_CUBES).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
|
||||
from kernbench.ccl.install import load_ccl_config, resolve_algorithm_config
|
||||
from kernbench.policy.placement.dp import DPPolicy
|
||||
|
||||
# Default per-rank tile size if ccl.yaml doesn't override it. Real
|
||||
# pytorch benches hardcode batch/feature dims similarly.
|
||||
DEFAULT_N_ELEM = 32
|
||||
DEFAULT_N_ELEM = 8
|
||||
|
||||
|
||||
def _derive_dp(spec: dict, world_size: int) -> DPPolicy:
|
||||
"""Pick a DPPolicy that fans the tensor across exactly ``world_size`` PEs.
|
||||
@dataclass(frozen=True)
|
||||
class _BenchCfg:
|
||||
algorithm: str
|
||||
n_elem: int
|
||||
n_cubes: int
|
||||
world_size: int
|
||||
|
||||
Mirrors what a real PyTorch DDP user does manually with
|
||||
``tensor.to(f"cuda:{rank}")``: the host code chooses the placement so
|
||||
that the collective sees the right number of participating ranks.
|
||||
"""
|
||||
sips = int(spec["system"]["sips"]["count"])
|
||||
cm = spec["sip"]["cube_mesh"]
|
||||
pl = spec["cube"]["pe_layout"]
|
||||
pes_per_cube = int(pl["pe_per_corner"]) * len(pl["corners"])
|
||||
cubes_per_sip = int(cm["w"]) * int(cm["h"])
|
||||
total = sips * cubes_per_sip * pes_per_cube
|
||||
if world_size == total:
|
||||
return DPPolicy(sip="column_wise", cube="column_wise", pe="column_wise")
|
||||
if world_size <= pes_per_cube:
|
||||
return DPPolicy(
|
||||
sip="replicate", cube="replicate", pe="column_wise",
|
||||
num_sips=1, num_cubes=1, num_pes=world_size,
|
||||
|
||||
def _resolve_cfg(torch) -> _BenchCfg:
|
||||
"""Read ccl.yaml + topology once at host side."""
|
||||
merged = resolve_algorithm_config(load_ccl_config())
|
||||
ws = torch.distributed.get_world_size()
|
||||
spec = torch.spec or {}
|
||||
n_sips = int(spec.get("system", {}).get("sips", {}).get("count", 1))
|
||||
if ws != n_sips:
|
||||
raise RuntimeError(
|
||||
f"ccl_allreduce bench requires world_size == topology SIP count "
|
||||
f"(world_size={ws}, n_sips={n_sips})."
|
||||
)
|
||||
if world_size <= cubes_per_sip * pes_per_cube:
|
||||
return DPPolicy(
|
||||
sip="replicate", cube="column_wise", pe="column_wise",
|
||||
num_sips=1, num_cubes=world_size // pes_per_cube,
|
||||
)
|
||||
return DPPolicy(sip="column_wise", cube="column_wise", pe="column_wise")
|
||||
|
||||
|
||||
def worker(rank: int, world_size: int, torch) -> None:
|
||||
"""Per-rank business logic. Mirrors a real PyTorch DDP worker.
|
||||
|
||||
In real PyTorch DDP, this function runs in N separate processes,
|
||||
each with its own ``rank``. In kernbench (single-process multi-device)
|
||||
it is invoked once with ``rank=0`` on the single host driver; the
|
||||
actual per-PE parallelism is handled by ``torch.launch`` fanning out
|
||||
the kernel across all participating PEs via the tensor's DPPolicy.
|
||||
The ``rank`` parameter is therefore always 0 today, and is kept as
|
||||
an explicit argument for parity with real DDP workers (``if rank ==
|
||||
0`` logging guards, future multi-host extensions).
|
||||
"""
|
||||
cfg = resolve_algorithm_config(load_ccl_config())
|
||||
algo_name = cfg["algorithm"]
|
||||
n_elem = int(cfg.get("n_elem", DEFAULT_N_ELEM))
|
||||
|
||||
# Pick a DP that produces exactly ``world_size`` shards on this topology.
|
||||
dp = _derive_dp(torch.spec, world_size)
|
||||
tensor = torch.zeros(
|
||||
(1, world_size * n_elem), dtype="f16", dp=dp, name="ccl_in",
|
||||
cm = spec.get("sip", {}).get("cube_mesh", {})
|
||||
n_cubes = int(cm.get("w", 4)) * int(cm.get("h", 4))
|
||||
return _BenchCfg(
|
||||
algorithm=merged["algorithm"],
|
||||
n_elem=int(merged.get("n_elem", DEFAULT_N_ELEM)),
|
||||
n_cubes=n_cubes,
|
||||
world_size=ws,
|
||||
)
|
||||
|
||||
# Initialize: CCL rank r's slice gets value (r + 1). Real PyTorch idiom:
|
||||
# target.copy_(torch.from_numpy(source))
|
||||
init = np.zeros((1, world_size * n_elem), dtype=np.float16)
|
||||
for r in range(world_size):
|
||||
init[0, r * n_elem : (r + 1) * n_elem] = float(r + 1)
|
||||
tensor.copy_(torch.from_numpy(init))
|
||||
|
||||
# The main act: one all_reduce call — the backend installs IPCQ at
|
||||
# init_process_group time and here only dispatches the kernel.
|
||||
def _rank_dp(n_cubes: int) -> DPPolicy:
|
||||
return DPPolicy(cube="row_wise", pe="replicate", num_cubes=n_cubes, num_pes=1)
|
||||
|
||||
|
||||
def _allocate_rank_tensor(torch, rank: int, cfg: _BenchCfg):
|
||||
"""Allocate this rank's ``(n_cubes, n_elem)`` tensor on its SIP."""
|
||||
return torch.zeros(
|
||||
(cfg.n_cubes, cfg.n_elem), dtype="f16",
|
||||
dp=_rank_dp(cfg.n_cubes), name=f"ccl_in_r{rank}",
|
||||
)
|
||||
|
||||
|
||||
def _init_with_rank_value(torch, tensor, rank: int, cfg: _BenchCfg) -> None:
|
||||
"""Fill all cubes with the scalar ``rank + 1``."""
|
||||
arr = np.full((cfg.n_cubes, cfg.n_elem), float(rank + 1), dtype=np.float16)
|
||||
tensor.copy_(torch.from_numpy(arr))
|
||||
|
||||
|
||||
def _report(result: np.ndarray, cfg: _BenchCfg) -> None:
|
||||
"""Single-line pass/fail printer (rank 0 only)."""
|
||||
expected = float(cfg.n_cubes * sum(range(1, cfg.world_size + 1)))
|
||||
ok = True
|
||||
for cube_id in range(cfg.n_cubes):
|
||||
if not np.allclose(result[cube_id], expected, rtol=1e-1, atol=1e-1):
|
||||
ok = False
|
||||
break
|
||||
if ok:
|
||||
total = cfg.world_size * cfg.n_cubes
|
||||
print(f" {cfg.algorithm} (ws={cfg.world_size}): {total} OK")
|
||||
return
|
||||
got = float(result.reshape(-1).mean())
|
||||
print(
|
||||
f" [FAIL] {cfg.algorithm} (ws={cfg.world_size}): "
|
||||
f"got mean={got:.3f}, expected={expected:.3f}"
|
||||
)
|
||||
|
||||
|
||||
def _worker(rank: int, cfg: _BenchCfg, torch) -> None:
|
||||
torch.ahbm.set_device(rank)
|
||||
tensor = _allocate_rank_tensor(torch, rank, cfg)
|
||||
_init_with_rank_value(torch, tensor, rank, cfg)
|
||||
torch.distributed.all_reduce(tensor, op="sum")
|
||||
|
||||
# Verify: each shard should hold sum(1..world_size) after all-reduce.
|
||||
result = tensor.numpy()
|
||||
expected = float(sum(range(1, world_size + 1)))
|
||||
all_ok = bool(np.allclose(result, expected, rtol=1e-1, atol=1e-1))
|
||||
|
||||
# Print only on rank 0 — real PyTorch DDP idiom for single-source logs.
|
||||
if rank == 0:
|
||||
if all_ok:
|
||||
print(f" {algo_name} (ws={world_size}): {world_size} OK")
|
||||
else:
|
||||
flat = result.reshape(-1)
|
||||
n_fail = 0
|
||||
for r in range(world_size):
|
||||
slice_r = flat[r * n_elem : (r + 1) * n_elem]
|
||||
if not np.allclose(slice_r, expected, rtol=1e-1, atol=1e-1):
|
||||
n_fail += 1
|
||||
if n_fail <= 5:
|
||||
print(
|
||||
f" [FAIL] rank {r} "
|
||||
f"(ws={world_size}, algo={algo_name}): "
|
||||
f"got mean={float(slice_r.mean()):.3f}, "
|
||||
f"expected={expected:.3f}"
|
||||
)
|
||||
print(
|
||||
f" {algo_name} (ws={world_size}): "
|
||||
f"{world_size - n_fail} OK / {n_fail} FAIL"
|
||||
)
|
||||
_report(tensor.numpy(), cfg)
|
||||
|
||||
|
||||
def run(torch) -> None:
|
||||
"""CLI entry point: initialize the process group, invoke worker."""
|
||||
dist = torch.distributed
|
||||
dist.init_process_group(backend="ahbm")
|
||||
worker(
|
||||
rank=dist.get_rank(),
|
||||
world_size=dist.get_world_size(),
|
||||
torch=torch,
|
||||
torch.distributed.init_process_group(backend="ahbm")
|
||||
cfg = _resolve_cfg(torch)
|
||||
torch.multiprocessing.spawn(
|
||||
_worker, args=(cfg, torch), nprocs=cfg.world_size,
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
Full host-to-PE pipeline:
|
||||
Host → PCIE_EP → IO_CPU → M_CPU → PE_CPU → SchedulerV2 → PE_DMA → HBM
|
||||
|
||||
Single PE: num_sips=1, num_cubes=1, num_pes=1 via DPPolicy override.
|
||||
Single PE: num_cubes=1, num_pes=1 via DPPolicy override.
|
||||
Both operands use tl.ref (HBM-resident); scheduler_v2 tiles and streams
|
||||
per-tile DMA internally.
|
||||
|
||||
@@ -30,7 +30,7 @@ def _gemm_kernel(a_ptr, b_ptr, out_ptr, M, K, N, tl, DTYPE="f16"):
|
||||
def run(torch):
|
||||
"""Run the single-PE GEMM benchmark."""
|
||||
dp = DPPolicy(cube="replicate", pe="replicate",
|
||||
num_sips=1, num_cubes=1, num_pes=1)
|
||||
num_cubes=1, num_pes=1)
|
||||
|
||||
a = torch.empty((M, K), dtype=DTYPE, dp=dp, name="a")
|
||||
b = torch.empty((K, N), dtype=DTYPE, dp=dp, name="b")
|
||||
|
||||
+8
-4
@@ -72,12 +72,16 @@ def run(torch):
|
||||
K = GPT3_D_MODEL
|
||||
N = COLS_PER_PE
|
||||
|
||||
# X: replicated across all PEs
|
||||
# ADR-0026: DPPolicy is intra-device only. For multi-SIP execution the
|
||||
# ADR-0024 launcher calls this bench once per SIP (each worker via
|
||||
# torch.ahbm.set_device(rank)); here the policy describes only the
|
||||
# cube × PE layout within a single SIP.
|
||||
# X: replicated across all PEs within the SIP
|
||||
dp_replicate = DPPolicy(cube="replicate", pe="replicate",
|
||||
num_sips=N_SIPS, num_cubes=N_CUBES, num_pes=N_PE_PER_CUBE)
|
||||
# W_Q/K/V, out_Q/K/V: column-wise sharded across all PEs
|
||||
num_cubes=N_CUBES, num_pes=N_PE_PER_CUBE)
|
||||
# W_Q/K/V, out_Q/K/V: column-wise sharded across all PEs within the SIP
|
||||
dp_sharded = DPPolicy(cube="column_wise", pe="column_wise",
|
||||
num_sips=N_SIPS, num_cubes=N_CUBES, num_pes=N_PE_PER_CUBE)
|
||||
num_cubes=N_CUBES, num_pes=N_PE_PER_CUBE)
|
||||
|
||||
x = torch.empty((M, K), dtype=DTYPE, dp=dp_replicate, name="x")
|
||||
wq = torch.empty((K, GPT3_D_MODEL), dtype=DTYPE, dp=dp_sharded, name="wq")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""VA offset verification benchmark.
|
||||
|
||||
Verifies that Triton-style base_ptr + pid * stride addressing works correctly
|
||||
with full TP sharding (sip/cube/pe all column_wise). Each PE loads its own
|
||||
with intra-SIP TP sharding (cube/pe column_wise). Each PE loads its own
|
||||
block from a sharded tensor and stores it back.
|
||||
|
||||
The kernel uses standard Triton patterns:
|
||||
@@ -28,7 +28,7 @@ def _copy_kernel(src_ptr, dst_ptr, M, K, tl, DTYPE="f16"):
|
||||
|
||||
def run(torch):
|
||||
"""Run the VA offset verification benchmark with full TP sharding."""
|
||||
dp = DPPolicy(sip="column_wise", cube="column_wise", pe="column_wise")
|
||||
dp = DPPolicy(cube="column_wise", pe="column_wise")
|
||||
src = torch.zeros((M, K), dtype=DTYPE, dp=dp, name="src")
|
||||
dst = torch.empty((M, K), dtype=DTYPE, dp=dp, name="dst")
|
||||
|
||||
|
||||
@@ -6,12 +6,7 @@
|
||||
|
||||
defaults:
|
||||
# Algorithm to run for this benchmark execution.
|
||||
algorithm: ring_allreduce_tcm
|
||||
|
||||
# NOTE: world_size is not set here by default. AhbmCCLBackend derives it
|
||||
# from the chosen algorithm's entry (if it sets ``world_size``) or from
|
||||
# topology.yaml (``sips × cubes_per_sip × pes_per_cube``). This mirrors
|
||||
# real PyTorch DDP where ranks/world_size come from env vars, not code.
|
||||
algorithm: intercube_allreduce
|
||||
|
||||
# IPCQ ring buffer location.
|
||||
# tcm — PE-local TCM (fast, small, conflicts with compute TCM access)
|
||||
@@ -30,59 +25,21 @@ defaults:
|
||||
# Slot size in bytes (must hold one tile worth of data).
|
||||
slot_size: 4096
|
||||
|
||||
# PE_DMA virtual channel chunk size (D8). First implementation does not
|
||||
# use chunk-level interleave; this is reserved for future precision.
|
||||
# PE_DMA virtual channel chunk size (D8).
|
||||
vc_chunk_size: 256
|
||||
|
||||
# Credit return fast path message size (D9). Used by bottleneck-BW
|
||||
# latency calculation. 16-64 bytes typical.
|
||||
# Credit return fast path message size (D9).
|
||||
ipcq_credit_size_bytes: 16
|
||||
|
||||
algorithms:
|
||||
# ── ring all-reduce, buffer in PE_TCM ──
|
||||
# Defaults to topology-derived world_size (full system, 256 ranks).
|
||||
# Use a smaller tile size at high rank counts so f16 sums stay within
|
||||
# the verification tolerance and op_log replay scales.
|
||||
ring_allreduce_tcm:
|
||||
module: kernbench.ccl.algorithms.ring_allreduce
|
||||
topology: ring_1d
|
||||
buffer_kind: tcm
|
||||
n_elem: 8
|
||||
|
||||
# ── ring all-reduce, buffer in PE-local HBM ──
|
||||
ring_allreduce_hbm:
|
||||
module: kernbench.ccl.algorithms.ring_allreduce
|
||||
topology: ring_1d
|
||||
buffer_kind: hbm
|
||||
n_elem: 8
|
||||
|
||||
# ── ring all-reduce, buffer in cube SRAM ──
|
||||
ring_allreduce_sram:
|
||||
module: kernbench.ccl.algorithms.ring_allreduce
|
||||
topology: ring_1d
|
||||
buffer_kind: sram
|
||||
n_elem: 8
|
||||
|
||||
# ── 2D mesh all-reduce: perfect square only (2×2 = 4 PEs) ──
|
||||
mesh_allreduce_4:
|
||||
module: kernbench.ccl.algorithms.mesh_allreduce
|
||||
topology: mesh_2d
|
||||
buffer_kind: tcm
|
||||
world_size: 4
|
||||
n_elem: 16
|
||||
|
||||
# ── tree all-reduce (binary, 7 PEs) ──
|
||||
tree_allreduce_7:
|
||||
module: kernbench.ccl.algorithms.tree_allreduce
|
||||
topology: tree_binary
|
||||
buffer_kind: tcm
|
||||
world_size: 7
|
||||
n_elem: 16
|
||||
|
||||
# ── hierarchical all-reduce (3-level: intra-cube → inter-cube → inter-SIP) ──
|
||||
# Uses bidirectional ring reduce + chain broadcast. ~25 rounds vs 255 flat.
|
||||
hierarchical_allreduce:
|
||||
module: kernbench.ccl.algorithms.hierarchical_allreduce
|
||||
# ── intercube all-reduce (pe0-only, cube mesh + inter-SIP) ──
|
||||
# Reduces across the 4×4 cube mesh within each SIP, then inter-SIP
|
||||
# exchange on root cube, then broadcast back. SIP topology is read
|
||||
# from topology.yaml → system.sips.topology. Kernel auto-selects
|
||||
# ring / torus / mesh inter-SIP exchange pattern.
|
||||
intercube_allreduce:
|
||||
module: kernbench.ccl.algorithms.intercube_allreduce
|
||||
topology: none
|
||||
buffer_kind: tcm
|
||||
n_elem: 16
|
||||
n_elem: 8
|
||||
root_cube: 15
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
## Status
|
||||
|
||||
Proposed (Revision 4 — 문서 일관성 + grep audit 구체화)
|
||||
Accepted (Revision 5 — Phase 2 landed 2026-04-14, 523 passed + 1 strict xfail)
|
||||
|
||||
## Context
|
||||
|
||||
@@ -69,9 +69,9 @@ class DPPolicy:
|
||||
class DPPolicy:
|
||||
"""Intra-device (cube × PE) data-parallel policy.
|
||||
|
||||
SIP-level placement is controlled by ``torch.cuda.set_device(rank)``
|
||||
(ADR-0024) and, for model-level TP, by Megatron-style parallel layers
|
||||
(ADR-0027). DPPolicy does not cross SIP boundaries.
|
||||
SIP-level placement is controlled by ``torch.ahbm.set_device(rank)``
|
||||
(ADR-0024 D10) and, for model-level TP, by Megatron-style parallel
|
||||
layers (ADR-0027). DPPolicy does not cross SIP boundaries.
|
||||
"""
|
||||
cube: Literal["replicate", "column_wise", "row_wise"] = "replicate"
|
||||
pe: Literal["replicate", "column_wise", "row_wise"] = "replicate"
|
||||
|
||||
+899
-132
File diff suppressed because it is too large
Load Diff
@@ -129,8 +129,8 @@ N_ELEM = 8
|
||||
def worker(rank: int, world_size: int, torch) -> None:
|
||||
"""Per-rank business logic — mirrors a real PyTorch DDP worker."""
|
||||
dp = DPPolicy(
|
||||
sip="replicate", cube="replicate", pe="column_wise",
|
||||
num_sips=1, num_cubes=1, num_pes=world_size,
|
||||
cube="replicate", pe="column_wise",
|
||||
num_cubes=1, num_pes=world_size,
|
||||
)
|
||||
tensor = torch.zeros(
|
||||
(1, world_size * N_ELEM), dtype="f16", dp=dp, name="hello_in",
|
||||
|
||||
@@ -114,8 +114,8 @@ def run(torch):
|
||||
a = torch.zeros(
|
||||
(1, WORLD_SIZE * N_ELEM), dtype="f16",
|
||||
dp=DPPolicy(
|
||||
sip="replicate", cube="replicate", pe="column_wise",
|
||||
num_sips=1, num_cubes=1,
|
||||
cube="replicate", pe="column_wise",
|
||||
num_cubes=1,
|
||||
),
|
||||
name="hello_in",
|
||||
)
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
"""Hello-world CCL kernel for the docs/ccl-author-guide.md walkthrough.
|
||||
|
||||
Each PE sends its tile to the E neighbor and receives one tile from W,
|
||||
then stores the received tile back into its own HBM slice. The simplest
|
||||
possible demonstration of ``tl.send`` / ``tl.recv``.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def kernel_args(world_size: int, n_elem: int) -> tuple:
|
||||
"""Return the positional kernel arguments for the ahbm backend."""
|
||||
return (n_elem,)
|
||||
|
||||
|
||||
def kernel(t_ptr, n_elem, tl):
|
||||
local_pe = tl.program_id(axis=0)
|
||||
cube_id = tl.program_id(axis=1)
|
||||
pes_per_cube = tl.num_programs(axis=0)
|
||||
rank = cube_id * pes_per_cube + local_pe
|
||||
nbytes = n_elem * 2
|
||||
pe_addr = t_ptr + rank * nbytes
|
||||
|
||||
# Send our local HBM tile to the E neighbor.
|
||||
src = tl.load(pe_addr, shape=(n_elem,), dtype="f16")
|
||||
tl.send(dir="E", src=src)
|
||||
|
||||
# Receive a tile from W and store it into our slice (overwrite).
|
||||
recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
|
||||
tl.store(pe_addr, recv)
|
||||
@@ -1,192 +0,0 @@
|
||||
"""Hierarchical all-reduce kernel (ADR-0023).
|
||||
|
||||
3-level reduce + broadcast exploiting the topology hierarchy:
|
||||
|
||||
Level 1 — Intra-cube (8 PEs, E/W, fastest link):
|
||||
Bidirectional ring reduce to PE 0.
|
||||
Level 2 — Inter-cube within SIP (16 cubes, N/S, UCIe):
|
||||
Bidirectional ring reduce of PE 0s to cube 0 PE 0.
|
||||
Level 3 — Inter-SIP (2 SIPs, parent):
|
||||
Pair exchange between SIP representatives.
|
||||
Broadcast — Reverse chain through levels 2 and 1.
|
||||
|
||||
Bidirectional reduce: left-half sends toward node 0 via dir_dec,
|
||||
right-half sends via dir_inc (wrapping). Representative receives from
|
||||
both sides. Rounds per level = ceil((group_size - 1) / 2).
|
||||
|
||||
Direction pairing (ring):
|
||||
Send via dir_dec at PE K → recv via dir_inc at PE K-1
|
||||
Send via dir_inc at PE K → recv via dir_dec at PE K+1
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def kernel_args(world_size: int, n_elem: int) -> tuple:
|
||||
"""Positional kernel args for the ahbm backend."""
|
||||
pes_per_cube = 8
|
||||
num_sips = max(1, world_size // 128) if world_size > 128 else 1
|
||||
cubes_per_sip = world_size // (pes_per_cube * num_sips)
|
||||
return (n_elem, pes_per_cube, cubes_per_sip, num_sips)
|
||||
|
||||
|
||||
def neighbors(rank: int, world_size: int, neighbor_map: dict) -> dict:
|
||||
"""Build the 3-level neighbor map."""
|
||||
pes_per_cube = 8
|
||||
num_sips = max(1, world_size // 128) if world_size > 128 else 1
|
||||
cubes_per_sip = world_size // (pes_per_cube * num_sips)
|
||||
|
||||
pe_id = rank % pes_per_cube
|
||||
cube_global = rank // pes_per_cube
|
||||
sip_id = cube_global // cubes_per_sip
|
||||
local_cube_id = cube_global % cubes_per_sip
|
||||
|
||||
result = {}
|
||||
|
||||
# Level 1: intra-cube ring (E/W, all PEs)
|
||||
cube_base = cube_global * pes_per_cube
|
||||
result["E"] = cube_base + (pe_id + 1) % pes_per_cube
|
||||
result["W"] = cube_base + (pe_id - 1) % pes_per_cube
|
||||
|
||||
# Level 2: inter-cube ring (N/S, PE 0 only)
|
||||
if pe_id == 0 and cubes_per_sip > 1:
|
||||
sip_base = sip_id * cubes_per_sip * pes_per_cube
|
||||
next_cube_pe0 = sip_base + ((local_cube_id + 1) % cubes_per_sip) * pes_per_cube
|
||||
prev_cube_pe0 = sip_base + ((local_cube_id - 1) % cubes_per_sip) * pes_per_cube
|
||||
result["N"] = next_cube_pe0
|
||||
result["S"] = prev_cube_pe0
|
||||
|
||||
# Level 3: inter-SIP (parent, PE 0 cube 0 only)
|
||||
if pe_id == 0 and local_cube_id == 0 and num_sips > 1:
|
||||
other_sip_pe0 = ((sip_id + 1) % num_sips) * cubes_per_sip * pes_per_cube
|
||||
result["parent"] = other_sip_pe0
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _bidir_reduce(tl, acc, my_id, group_size, dir_inc, dir_dec, shape, dtype):
|
||||
"""Bidirectional ring reduce to node 0.
|
||||
|
||||
Left half (1..half): chain reduces via dir_dec (toward lower IDs).
|
||||
Each PE recvs from higher PE (via dir_inc) and sends to lower (via dir_dec).
|
||||
Right half (half+1..N-1): chain reduces via dir_inc (wraps to node 0).
|
||||
Each PE recvs from lower PE (via dir_dec) and sends to higher (via dir_inc).
|
||||
Node 0: recvs left sum via dir_inc, right sum via dir_dec.
|
||||
|
||||
Direction pairing: send dir_dec at K → recv dir_inc at K-1.
|
||||
send dir_inc at K → recv dir_dec at K+1.
|
||||
"""
|
||||
if group_size <= 1:
|
||||
return acc
|
||||
|
||||
half = group_size // 2
|
||||
|
||||
if my_id == 0:
|
||||
# Representative: recv left-half sum via dir_inc (from PE 1)
|
||||
recv = tl.recv(dir=dir_inc, shape=shape, dtype=dtype)
|
||||
acc = acc + recv
|
||||
# Recv right-half sum via dir_dec (from PE N-1, wrapped)
|
||||
if group_size - half - 1 >= 1:
|
||||
recv = tl.recv(dir=dir_dec, shape=shape, dtype=dtype)
|
||||
acc = acc + recv
|
||||
|
||||
elif my_id <= half:
|
||||
# Left half: recv from PE my_id+1 via dir_inc, send to PE my_id-1 via dir_dec
|
||||
if my_id < half: # not the far-edge
|
||||
recv = tl.recv(dir=dir_inc, shape=shape, dtype=dtype)
|
||||
acc = acc + recv
|
||||
tl.send(dir=dir_dec, src=acc)
|
||||
|
||||
else:
|
||||
# Right half: recv from PE my_id-1 via dir_dec, send to PE my_id+1 via dir_inc
|
||||
if my_id > half + 1: # not the near-edge
|
||||
recv = tl.recv(dir=dir_dec, shape=shape, dtype=dtype)
|
||||
acc = acc + recv
|
||||
tl.send(dir=dir_inc, src=acc)
|
||||
|
||||
return acc
|
||||
|
||||
|
||||
def _chain_broadcast(tl, acc, my_id, group_size, dir_inc, shape, dtype):
|
||||
"""Linear chain broadcast from node 0 via dir_inc.
|
||||
|
||||
Node 0 sends via dir_inc → node 1. Node 1 recvs via dir_dec (implicit
|
||||
from the ring pairing), stores, sends via dir_inc → node 2. Etc.
|
||||
|
||||
Recv direction = the opposite: send dir_inc at K → recv dir_dec at K+1.
|
||||
"""
|
||||
if group_size <= 1:
|
||||
return acc
|
||||
|
||||
# In ring pairing: send via dir_inc at K → recv via dir_dec at K+1.
|
||||
# dir_dec is the "other" direction. We infer it from the ring:
|
||||
# if dir_inc is "E", peer recvs via "W"; if "N", peer recvs via "S".
|
||||
_recv_dir = {"E": "W", "W": "E", "N": "S", "S": "N"}.get(dir_inc, dir_inc)
|
||||
|
||||
if my_id == 0:
|
||||
tl.send(dir=dir_inc, src=acc)
|
||||
else:
|
||||
acc = tl.recv(dir=_recv_dir, shape=shape, dtype=dtype)
|
||||
if my_id < group_size - 1:
|
||||
tl.send(dir=dir_inc, src=acc)
|
||||
return acc
|
||||
|
||||
|
||||
def kernel(t_ptr, n_elem, pes_per_cube, cubes_per_sip, num_sips, tl):
|
||||
"""Hierarchical all-reduce.
|
||||
|
||||
Args:
|
||||
t_ptr: HBM base address (column-sharded VA).
|
||||
n_elem: f16 elements per tile.
|
||||
pes_per_cube: PEs per cube (typically 8).
|
||||
cubes_per_sip: cubes per SIP (typically 16).
|
||||
num_sips: number of SIPs (typically 2).
|
||||
tl: TLContext (auto-injected).
|
||||
"""
|
||||
pe_id = tl.program_id(axis=0)
|
||||
cube_global = tl.program_id(axis=1)
|
||||
sip_id = cube_global // cubes_per_sip
|
||||
local_cube_id = cube_global % cubes_per_sip
|
||||
|
||||
rank = cube_global * pes_per_cube + pe_id
|
||||
nbytes = n_elem * 2
|
||||
pe_addr = t_ptr + rank * nbytes
|
||||
|
||||
acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16")
|
||||
shape = (n_elem,)
|
||||
dtype = "f16"
|
||||
|
||||
# ── Level 1: intra-cube bidirectional reduce to PE 0 ──
|
||||
acc = _bidir_reduce(
|
||||
tl, acc, my_id=pe_id, group_size=pes_per_cube,
|
||||
dir_inc="E", dir_dec="W", shape=shape, dtype=dtype,
|
||||
)
|
||||
|
||||
# ── Level 2: inter-cube bidirectional reduce to cube 0 (PE 0 only) ──
|
||||
if pe_id == 0 and cubes_per_sip > 1:
|
||||
acc = _bidir_reduce(
|
||||
tl, acc, my_id=local_cube_id, group_size=cubes_per_sip,
|
||||
dir_inc="N", dir_dec="S", shape=shape, dtype=dtype,
|
||||
)
|
||||
|
||||
# ── Level 3: inter-SIP exchange (PE 0 cube 0 only) ──
|
||||
if pe_id == 0 and local_cube_id == 0 and num_sips > 1:
|
||||
tl.send(dir="parent", src=acc)
|
||||
recv = tl.recv(dir="parent", shape=shape, dtype=dtype)
|
||||
acc = acc + recv
|
||||
|
||||
# ── Broadcast back ──
|
||||
|
||||
# Level 2: cube 0 PE 0 → all PE 0s via chain
|
||||
if pe_id == 0 and cubes_per_sip > 1:
|
||||
acc = _chain_broadcast(
|
||||
tl, acc, my_id=local_cube_id, group_size=cubes_per_sip,
|
||||
dir_inc="N", shape=shape, dtype=dtype,
|
||||
)
|
||||
|
||||
# Level 1: PE 0 → all PEs in cube via chain
|
||||
acc = _chain_broadcast(
|
||||
tl, acc, my_id=pe_id, group_size=pes_per_cube,
|
||||
dir_inc="E", shape=shape, dtype=dtype,
|
||||
)
|
||||
|
||||
tl.store(pe_addr, acc)
|
||||
@@ -0,0 +1,189 @@
|
||||
"""Intercube all-reduce kernel (pe0-only, same-lane across cubes).
|
||||
|
||||
Reduces across the 4×4 cube mesh within each SIP, then exchanges
|
||||
between SIPs using the configured SIP topology, and broadcasts back.
|
||||
|
||||
Supported SIP topologies (selected via ``sip_topo_kind``):
|
||||
0 — ring_1d: global_E/global_W ring, n_sips-1 rounds
|
||||
1 — torus_2d: row ring (global_E/W) + col ring (global_S/N)
|
||||
2 — mesh_2d: row chain reduce+broadcast + col chain reduce+broadcast
|
||||
|
||||
IPCQ wiring is handled by ``configure_sfr_intercube_multisip``.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
SIP_TOPO_RING = 0
|
||||
SIP_TOPO_TORUS = 1
|
||||
SIP_TOPO_MESH = 2
|
||||
|
||||
TOPO_NAME_TO_KIND = {
|
||||
"ring_1d": SIP_TOPO_RING,
|
||||
"torus_2d": SIP_TOPO_TORUS,
|
||||
"mesh_2d": SIP_TOPO_TORUS,
|
||||
"mesh_2d_no_wrap": SIP_TOPO_MESH,
|
||||
}
|
||||
|
||||
|
||||
def kernel_args(world_size: int, n_elem: int) -> tuple:
|
||||
cube_w = 4
|
||||
cube_h = 4
|
||||
return (n_elem, cube_w, cube_h, world_size)
|
||||
|
||||
|
||||
def _inter_sip_ring(acc, n_sips, n_elem, tl):
|
||||
current = acc
|
||||
for _ in range(n_sips - 1):
|
||||
tl.send(dir="global_E", src=current)
|
||||
recv = tl.recv(dir="global_W", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
current = recv
|
||||
return acc
|
||||
|
||||
|
||||
def _inter_sip_torus_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl):
|
||||
# Row ring (global_E / global_W)
|
||||
current = acc
|
||||
for _ in range(sip_topo_w - 1):
|
||||
tl.send(dir="global_E", src=current)
|
||||
recv = tl.recv(dir="global_W", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
current = recv
|
||||
# Col ring (global_S / global_N)
|
||||
current = acc
|
||||
for _ in range(sip_topo_h - 1):
|
||||
tl.send(dir="global_S", src=current)
|
||||
recv = tl.recv(dir="global_N", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
current = recv
|
||||
return acc
|
||||
|
||||
|
||||
def _inter_sip_mesh_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl):
|
||||
sip_row = sip_rank // sip_topo_w
|
||||
sip_col = sip_rank % sip_topo_w
|
||||
|
||||
# Row reduce W → E
|
||||
if sip_col == 0:
|
||||
tl.send(dir="global_E", src=acc)
|
||||
elif sip_col < sip_topo_w - 1:
|
||||
recv = tl.recv(dir="global_W", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
tl.send(dir="global_E", src=acc)
|
||||
else:
|
||||
recv = tl.recv(dir="global_W", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
|
||||
# Row broadcast E → W
|
||||
if sip_col == sip_topo_w - 1:
|
||||
tl.send(dir="global_W", src=acc)
|
||||
elif sip_col > 0:
|
||||
acc = tl.recv(dir="global_E", shape=(n_elem,), dtype="f16")
|
||||
tl.send(dir="global_W", src=acc)
|
||||
else:
|
||||
acc = tl.recv(dir="global_E", shape=(n_elem,), dtype="f16")
|
||||
|
||||
# Col reduce N → S
|
||||
if sip_row == 0:
|
||||
tl.send(dir="global_S", src=acc)
|
||||
elif sip_row < sip_topo_h - 1:
|
||||
recv = tl.recv(dir="global_N", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
tl.send(dir="global_S", src=acc)
|
||||
else:
|
||||
recv = tl.recv(dir="global_N", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
|
||||
# Col broadcast S → N
|
||||
if sip_row == sip_topo_h - 1:
|
||||
tl.send(dir="global_N", src=acc)
|
||||
elif sip_row > 0:
|
||||
acc = tl.recv(dir="global_S", shape=(n_elem,), dtype="f16")
|
||||
tl.send(dir="global_N", src=acc)
|
||||
else:
|
||||
acc = tl.recv(dir="global_S", shape=(n_elem,), dtype="f16")
|
||||
|
||||
return acc
|
||||
|
||||
|
||||
def allreduce_intercube_multidevice(
|
||||
t_ptr, n_elem, cube_w, cube_h, n_sips, sip_rank,
|
||||
sip_topo_kind, sip_topo_w, sip_topo_h, tl,
|
||||
):
|
||||
"""Intercube all-reduce (pe0-only) with configurable SIP topology.
|
||||
|
||||
Args:
|
||||
t_ptr: VA base of the row-wise-sharded tensor on this SIP.
|
||||
n_elem: f16 elements per cube tile.
|
||||
cube_w: cube mesh width (columns).
|
||||
cube_h: cube mesh height (rows).
|
||||
n_sips: number of SIPs.
|
||||
sip_rank: this SIP's rank (0-based).
|
||||
sip_topo_kind: 0=ring, 1=torus_2d, 2=mesh_2d.
|
||||
sip_topo_w: SIP mesh width (for 2D topologies, 0 for ring).
|
||||
sip_topo_h: SIP mesh height (for 2D topologies, 0 for ring).
|
||||
tl: TLContext (auto-injected).
|
||||
"""
|
||||
cube_id = tl.program_id(axis=1)
|
||||
row = cube_id // cube_w
|
||||
col = cube_id % cube_w
|
||||
nbytes = n_elem * 2
|
||||
|
||||
pe_addr = t_ptr + cube_id * nbytes
|
||||
acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16")
|
||||
|
||||
# ── Phase 1: row reduce W → E ──
|
||||
if col == 0:
|
||||
tl.send(dir="E", src=acc)
|
||||
elif col < cube_w - 1:
|
||||
recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
tl.send(dir="E", src=acc)
|
||||
else:
|
||||
recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
|
||||
# ── Phase 2: col reduce N → S on rightmost column ──
|
||||
if col == cube_w - 1:
|
||||
if row == 0:
|
||||
tl.send(dir="S", src=acc)
|
||||
elif row < cube_h - 1:
|
||||
recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
tl.send(dir="S", src=acc)
|
||||
else:
|
||||
recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
|
||||
# ── Phase 3: inter-SIP exchange on root cube ──
|
||||
root_cube = (cube_h - 1) * cube_w + (cube_w - 1)
|
||||
if cube_id == root_cube and n_sips > 1:
|
||||
if sip_topo_kind == SIP_TOPO_RING:
|
||||
acc = _inter_sip_ring(acc, n_sips, n_elem, tl)
|
||||
elif sip_topo_kind == SIP_TOPO_TORUS:
|
||||
acc = _inter_sip_torus_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
|
||||
elif sip_topo_kind == SIP_TOPO_MESH:
|
||||
acc = _inter_sip_mesh_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
|
||||
|
||||
# ── Phase 4: col broadcast S → N on rightmost column ──
|
||||
if col == cube_w - 1:
|
||||
if row == cube_h - 1:
|
||||
tl.send(dir="N", src=acc)
|
||||
elif row > 0:
|
||||
acc = tl.recv(dir="S", shape=(n_elem,), dtype="f16")
|
||||
tl.send(dir="N", src=acc)
|
||||
else:
|
||||
acc = tl.recv(dir="S", shape=(n_elem,), dtype="f16")
|
||||
|
||||
# ── Phase 5: row broadcast E → W ──
|
||||
if col == cube_w - 1:
|
||||
tl.send(dir="W", src=acc)
|
||||
elif col > 0:
|
||||
acc = tl.recv(dir="E", shape=(n_elem,), dtype="f16")
|
||||
tl.send(dir="W", src=acc)
|
||||
else:
|
||||
acc = tl.recv(dir="E", shape=(n_elem,), dtype="f16")
|
||||
|
||||
tl.store(pe_addr, acc)
|
||||
|
||||
|
||||
kernel = allreduce_intercube_multidevice
|
||||
@@ -1,73 +0,0 @@
|
||||
"""2D-mesh all-reduce kernel (ADR-0023).
|
||||
|
||||
Two-phase reduce on a square mesh of side ``S`` (world_size = S*S):
|
||||
1. Row reduce: ring all-reduce along E/W within each row.
|
||||
2. Column reduce: ring all-reduce along N/S within each column.
|
||||
|
||||
After both phases, every rank holds the global sum.
|
||||
|
||||
Uses TensorHandle math (PE_MATH) for accumulation. Op_log captures the
|
||||
data flow so Phase 2 produces correct final HBM contents. Math/recv
|
||||
handles are passed directly to the next send, avoiding store→reload
|
||||
which doesn't propagate correctly with timing-only Phase 1 math.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
|
||||
|
||||
def kernel_args(world_size: int, n_elem: int) -> tuple:
|
||||
"""Return the positional kernel arguments for the ahbm backend.
|
||||
|
||||
Mesh all-reduce requires ``world_size`` to be a perfect square —
|
||||
the mesh side length is ``sqrt(world_size)``.
|
||||
"""
|
||||
side = int(round(math.sqrt(world_size)))
|
||||
if side * side != world_size:
|
||||
raise ValueError(
|
||||
f"mesh_allreduce requires a square world_size; got {world_size}"
|
||||
)
|
||||
return (n_elem, side)
|
||||
|
||||
|
||||
def kernel(t_ptr, n_elem, side, tl):
|
||||
"""All-reduce on a square mesh.
|
||||
|
||||
Args:
|
||||
t_ptr: HBM base address (column-sharded VA shared across ranks)
|
||||
n_elem: number of f16 elements per tile
|
||||
side: mesh side length (sqrt(world_size))
|
||||
tl: TLContext (ADR-0022).
|
||||
"""
|
||||
local_pe = tl.program_id(axis=0)
|
||||
cube_id = tl.program_id(axis=1)
|
||||
pes_per_cube = tl.num_programs(axis=0)
|
||||
rank = cube_id * pes_per_cube + local_pe
|
||||
nbytes = n_elem * 2
|
||||
|
||||
pe_addr = t_ptr + rank * nbytes
|
||||
acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16")
|
||||
current = acc
|
||||
|
||||
# ── Phase 1: row ring (E direction) ──
|
||||
# Ring forwards each received tile (not the cumulative acc) so every
|
||||
# tile passes through every rank exactly once.
|
||||
for _ in range(side - 1):
|
||||
tl.send(dir="E", src=current)
|
||||
recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
current = recv
|
||||
|
||||
# Phase 2 column ring starts from the row-phase accumulator. We do NOT
|
||||
# store/reload here — the math handle's scratch addr is the source for
|
||||
# the first column send and Phase 2 ipcq_copy replays from there.
|
||||
current = acc
|
||||
|
||||
# ── Phase 2: column ring (S direction) ──
|
||||
for _ in range(side - 1):
|
||||
tl.send(dir="S", src=current)
|
||||
recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
current = recv
|
||||
|
||||
tl.store(pe_addr, acc)
|
||||
@@ -1,80 +0,0 @@
|
||||
"""Ring all-reduce kernel for IPCQ-based PE collective (ADR-0023).
|
||||
|
||||
Algorithm: 1D ring of N PEs, each PE starts with one tile of data.
|
||||
After ``world_size - 1`` rounds, every PE's accumulator holds the sum
|
||||
of all PE tiles.
|
||||
|
||||
Strategy
|
||||
--------
|
||||
Each PE starts with its own tile in HBM. The kernel:
|
||||
1. Loads the local tile into a TensorHandle (the accumulator).
|
||||
2. In each of ``world_size - 1`` rounds:
|
||||
- Sends the current accumulator/recv slot to the E neighbor.
|
||||
- Receives a tile from the W neighbor — the recv handle points
|
||||
into the per-direction TCM slot.
|
||||
- Adds the received tile to the accumulator using the TensorHandle
|
||||
operator overload, which dispatches to ``MathCmd`` (PE_MATH).
|
||||
3. Stores the final accumulator back to HBM via tl.store. The store is
|
||||
recorded in op_log with both src and dst, so Phase 2 will copy the
|
||||
replayed math result from PE-local scratch into HBM.
|
||||
|
||||
ADR-0020 D3 split: Phase 1 simulates timing only — math results are
|
||||
not yet computed, so the accumulator data flowing through Phase 1 may
|
||||
be stale. Phase 2's DataExecutor replays math + IPCQ copies + dma_write
|
||||
in stable t_start order, producing correct final HBM contents.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def kernel_args(world_size: int, n_elem: int) -> tuple:
|
||||
"""Return the positional kernel arguments for the ahbm backend.
|
||||
|
||||
Ring all-reduce takes (n_elem, world_size) after the tensor pointer.
|
||||
"""
|
||||
return (n_elem, world_size)
|
||||
|
||||
|
||||
def kernel(t_ptr, n_elem, world_size, tl):
|
||||
"""Ring all-reduce.
|
||||
|
||||
Args:
|
||||
t_ptr: HBM base address of the column-sharded tensor — all PEs
|
||||
share this base. The per-PE slice lives at
|
||||
``t_ptr + global_rank * n_elem * 2``.
|
||||
n_elem: number of f16 elements per tile.
|
||||
world_size: total number of participating ranks (passed by host).
|
||||
tl: TLContext (auto-injected, ADR-0022). The kernel derives the
|
||||
global rank from ``program_id(axis=0)`` (local PE) and
|
||||
``program_id(axis=1)`` (cube id):
|
||||
|
||||
rank = cube_id * pes_per_cube + local_pe
|
||||
"""
|
||||
local_pe = tl.program_id(axis=0)
|
||||
cube_id = tl.program_id(axis=1)
|
||||
pes_per_cube = tl.num_programs(axis=0)
|
||||
rank = cube_id * pes_per_cube + local_pe
|
||||
nbytes = n_elem * 2 # f16
|
||||
|
||||
# Each PE reads from its own slice of the shared base address
|
||||
pe_addr = t_ptr + rank * nbytes
|
||||
|
||||
# Load the local tile — handle points at HBM[pe_addr].
|
||||
acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16")
|
||||
# The ring forwards each received tile to the next neighbor (NOT the
|
||||
# cumulative accumulator), so every rank's tile passes through every
|
||||
# rank exactly once. The accumulator sums the new arrival each round.
|
||||
current = acc
|
||||
|
||||
for _step in range(world_size - 1):
|
||||
tl.send(dir="E", src=current)
|
||||
recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
|
||||
# TensorHandle add → MathCmd → PE_MATH (timing in Phase 1, real
|
||||
# numpy in Phase 2 via DataExecutor). The result handle lives at
|
||||
# an auto-allocated PE-local scratch addr.
|
||||
acc = acc + recv
|
||||
current = recv # forward W's tile to E next round
|
||||
|
||||
# Final result back to this PE's HBM slice. Op_log captures the
|
||||
# source (scratch addr) and dst (HBM slice) so Phase 2 copies the
|
||||
# accumulated value into HBM for verification.
|
||||
tl.store(pe_addr, acc)
|
||||
@@ -1,80 +0,0 @@
|
||||
"""Tree all-reduce kernel for IPCQ-based PE collective (ADR-0023).
|
||||
|
||||
Two-phase binary tree all-reduce:
|
||||
|
||||
Phase 1 (reduce up):
|
||||
- leaf nodes send their value to ``parent``
|
||||
- internal nodes recv from each child, sum, then send to ``parent``
|
||||
- root accumulates child contributions; final acc holds global sum
|
||||
|
||||
Phase 2 (broadcast down):
|
||||
- root sends acc to ``child_left`` and ``child_right`` (if present)
|
||||
- internal nodes recv from ``parent``, then forward to children
|
||||
- all ranks store the final acc to HBM
|
||||
|
||||
Uses TensorHandle math (PE_MATH) for accumulation. Op_log captures the
|
||||
data flow so Phase 2 produces correct final HBM contents. The kernel
|
||||
deliberately avoids the store→reload→send pattern: math/recv handles
|
||||
are passed directly to the next send so PE_DMA snapshots a deterministic
|
||||
source addr that Phase 2 can replay.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def kernel_args(world_size: int, n_elem: int) -> tuple:
|
||||
"""Return the positional kernel arguments for the ahbm backend."""
|
||||
return (n_elem, world_size)
|
||||
|
||||
|
||||
def kernel(t_ptr, n_elem, world_size, tl):
|
||||
"""Tree all-reduce.
|
||||
|
||||
Args:
|
||||
t_ptr: HBM base address.
|
||||
n_elem: number of f16 elements per tile.
|
||||
world_size: total number of participating ranks (passed by host).
|
||||
tl: TLContext (ADR-0022). Global rank from program_id(0/1).
|
||||
"""
|
||||
local_pe = tl.program_id(axis=0)
|
||||
cube_id = tl.program_id(axis=1)
|
||||
pes_per_cube = tl.num_programs(axis=0)
|
||||
rank = cube_id * pes_per_cube + local_pe
|
||||
nbytes = n_elem * 2
|
||||
|
||||
pe_addr = t_ptr + rank * nbytes
|
||||
acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16")
|
||||
|
||||
# Compute children/parent existence (matches tree_binary topology generator)
|
||||
has_parent = rank > 0
|
||||
left = 2 * rank + 1
|
||||
right = 2 * rank + 2
|
||||
has_left = left < world_size
|
||||
has_right = right < world_size
|
||||
|
||||
# ── Phase 1: reduce up ──
|
||||
if has_left:
|
||||
recv = tl.recv(dir="child_left", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
if has_right:
|
||||
recv = tl.recv(dir="child_right", shape=(n_elem,), dtype="f16")
|
||||
acc = acc + recv
|
||||
|
||||
if has_parent:
|
||||
# Send the math/load handle directly — its addr is either the
|
||||
# original HBM tile (leaf) or the PE-local scratch where the
|
||||
# accumulator lives. Phase 2 ipcq_copy replays from the same addr.
|
||||
tl.send(dir="parent", src=acc)
|
||||
|
||||
# ── Phase 2: broadcast down ──
|
||||
if has_parent:
|
||||
# Replace acc with the value broadcast from the parent (the global
|
||||
# sum). The recv handle points at the parent-direction TCM slot.
|
||||
acc = tl.recv(dir="parent", shape=(n_elem,), dtype="f16")
|
||||
|
||||
if has_left:
|
||||
tl.send(dir="child_left", src=acc)
|
||||
if has_right:
|
||||
tl.send(dir="child_right", src=acc)
|
||||
|
||||
# Final store to HBM for the bench's verification path.
|
||||
tl.store(pe_addr, acc)
|
||||
@@ -219,7 +219,11 @@ def install_ipcq(
|
||||
"neighbor_table": neighbor_table,
|
||||
}
|
||||
|
||||
_OPPOSITE_DIR = {"E": "W", "W": "E", "N": "S", "S": "N"}
|
||||
_OPPOSITE_DIR = {
|
||||
"E": "W", "W": "E", "N": "S", "S": "N",
|
||||
"global_E": "global_W", "global_W": "global_E",
|
||||
"global_N": "global_S", "global_S": "global_N",
|
||||
}
|
||||
|
||||
def reverse_direction(my_rank: int, peer_rank: int, my_dir: str) -> str | None:
|
||||
"""Find peer's direction that reciprocates my_dir→peer_rank.
|
||||
|
||||
@@ -0,0 +1,104 @@
|
||||
"""SFR configuration for intercube + inter-SIP IPCQ wiring.
|
||||
|
||||
Provides ``configure_sfr_intercube_multisip`` which programs PE_IPCQ
|
||||
neighbor tables for:
|
||||
|
||||
1. Intercube within each SIP — pe0 of every cube connects to pe0 of
|
||||
its N/S/E/W mesh neighbors (no wrap-around).
|
||||
2. Inter-SIP on ALL cubes — pe0 of cube_c on sip_A connects to pe0 of
|
||||
cube_c on each peer SIP, using ``global_E``/``global_W`` (ring) or
|
||||
``global_N``/``global_S``/``global_E``/``global_W`` (mesh/torus)
|
||||
direction labels. Wiring all cubes allows the kernel to
|
||||
dynamically elect the root cube at runtime.
|
||||
|
||||
SIP-level topology is read from ``topology.yaml`` →
|
||||
``system.sips.topology`` (e.g. ``ring_1d``, ``mesh_2d``).
|
||||
Intercube mesh dimensions come from ``sip.cube_mesh.w/h``.
|
||||
|
||||
Internally delegates to ``install_ipcq`` with a computed ``rank_to_pe``
|
||||
(pe0-only) and a closure-captured ``neighbors()`` function.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import types
|
||||
from typing import Any
|
||||
|
||||
from kernbench.ccl.install import install_ipcq
|
||||
from kernbench.ccl.topologies import _BUILTIN as _TOPO_BUILTINS
|
||||
|
||||
|
||||
def configure_sfr_intercube_multisip(
|
||||
engine: Any,
|
||||
spec: dict,
|
||||
cfg: dict,
|
||||
) -> dict[str, Any]:
|
||||
"""Wire IPCQ for intercube (pe0, mesh) + inter-SIP (pe0, all cubes).
|
||||
|
||||
Args:
|
||||
engine: GraphEngine with ``_components``.
|
||||
spec: topology spec dict (from topology.yaml).
|
||||
cfg: merged algorithm config (from ``resolve_algorithm_config``).
|
||||
|
||||
Returns:
|
||||
The install plan dict from ``install_ipcq``.
|
||||
"""
|
||||
cm = spec["sip"]["cube_mesh"]
|
||||
mesh_w = int(cm["w"])
|
||||
mesh_h = int(cm["h"])
|
||||
n_cubes = mesh_w * mesh_h
|
||||
n_sips = int(spec.get("system", {}).get("sips", {}).get("count", 1))
|
||||
sip_topology = str(
|
||||
spec.get("system", {}).get("sips", {}).get("topology", "ring_1d")
|
||||
)
|
||||
|
||||
if sip_topology not in _TOPO_BUILTINS:
|
||||
raise ValueError(
|
||||
f"Unknown sip topology '{sip_topology}'. "
|
||||
f"Available: {list(_TOPO_BUILTINS)}"
|
||||
)
|
||||
sip_topo_fn = _TOPO_BUILTINS[sip_topology]
|
||||
|
||||
world_size = n_sips * n_cubes
|
||||
pe_idx_to_pe: list[tuple[int, int, int]] = [
|
||||
(sip, cube, 0)
|
||||
for sip in range(n_sips)
|
||||
for cube in range(n_cubes)
|
||||
]
|
||||
|
||||
def _neighbors(pe_idx: int, ws: int, _base: dict) -> dict[str, int]:
|
||||
sip = pe_idx // n_cubes
|
||||
cube = pe_idx % n_cubes
|
||||
row = cube // mesh_w
|
||||
col = cube % mesh_w
|
||||
|
||||
nbrs: dict[str, int] = {}
|
||||
|
||||
# Intercube within SIP (mesh, no wrap-around)
|
||||
if col < mesh_w - 1:
|
||||
nbrs["E"] = sip * n_cubes + (row * mesh_w + col + 1)
|
||||
if col > 0:
|
||||
nbrs["W"] = sip * n_cubes + (row * mesh_w + col - 1)
|
||||
if row < mesh_h - 1:
|
||||
nbrs["S"] = sip * n_cubes + ((row + 1) * mesh_w + col)
|
||||
if row > 0:
|
||||
nbrs["N"] = sip * n_cubes + ((row - 1) * mesh_w + col)
|
||||
|
||||
# Inter-SIP on ALL cubes
|
||||
if n_sips > 1:
|
||||
sip_nbrs = sip_topo_fn(sip, n_sips)
|
||||
for d, peer_sip in sip_nbrs.items():
|
||||
nbrs[f"global_{d}"] = peer_sip * n_cubes + cube
|
||||
|
||||
return nbrs
|
||||
|
||||
mock_module = types.SimpleNamespace(neighbors=_neighbors)
|
||||
|
||||
cfg_copy = dict(cfg)
|
||||
cfg_copy["world_size"] = world_size
|
||||
cfg_copy["topology"] = "none"
|
||||
|
||||
return install_ipcq(
|
||||
engine, spec, cfg_copy,
|
||||
algo_module=mock_module,
|
||||
rank_to_pe=pe_idx_to_pe,
|
||||
)
|
||||
@@ -1,492 +0,0 @@
|
||||
"""Mock CCL runtime for fast unit tests of algorithm kernels (ADR-0023 D15).
|
||||
|
||||
Runs a kernel function once per rank with a minimal ``tl`` shim — no SimPy,
|
||||
no PE_DMA, no fabric simulation. Just enough to verify *functional*
|
||||
correctness of an IPCQ-based collective algorithm.
|
||||
|
||||
Cross-rank send/recv is implemented with greenlet cooperative scheduling
|
||||
plus per-(rank, direction) FIFO queues. Backpressure is not modeled —
|
||||
queues are unbounded.
|
||||
|
||||
Typical usage in a test::
|
||||
|
||||
from kernbench.ccl.testing import run_kernel_in_mock
|
||||
from kernbench.ccl.algorithms.ring_allreduce import kernel
|
||||
|
||||
inputs = [np.full(16, r + 1, dtype="f16") for r in range(4)]
|
||||
outputs = run_kernel_in_mock(
|
||||
kernel_fn=kernel, world_size=4, topology="ring_1d",
|
||||
inputs=inputs, kernel_args=(16,),
|
||||
)
|
||||
for r in range(4):
|
||||
assert np.allclose(outputs[r], sum(inputs))
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
from typing import Any, Callable
|
||||
|
||||
import numpy as np
|
||||
from greenlet import greenlet
|
||||
|
||||
from kernbench.ccl.topologies import resolve_topology
|
||||
from kernbench.common.ipcq_types import IpcqInvalidDirection
|
||||
from kernbench.common.pe_commands import TensorHandle
|
||||
|
||||
|
||||
# ── Per-rank fake state ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class _MockRankState:
|
||||
"""Per-rank scratch holding HBM/recv slots and tl shim hooks."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
neighbors: dict[str, int],
|
||||
input_arr: np.ndarray,
|
||||
pes_per_cube: int = 0,
|
||||
) -> None:
|
||||
self.rank = rank
|
||||
self.world_size = world_size
|
||||
# PEs per cube for program_id(axis=0/1). If 0 or world_size,
|
||||
# all ranks are in one cube (legacy single-cube behavior).
|
||||
self.pes_per_cube = pes_per_cube if pes_per_cube > 0 else world_size
|
||||
self.neighbors = neighbors # direction → peer rank
|
||||
# HBM "memory": addr → ndarray. Per-rank, no cross-rank sharing.
|
||||
self._hbm: dict[int, np.ndarray] = {}
|
||||
self._tcm: dict[int, np.ndarray] = {}
|
||||
# ``t_ptr`` is the address the kernel sees. Real benches use a
|
||||
# column-sharded VA so each rank reads from ``t_ptr + rank*nbytes``.
|
||||
# Mirror that here: each rank's slice lives at the rank-specific addr.
|
||||
nbytes = int(input_arr.nbytes)
|
||||
self.t_ptr = 0 # base; per-rank offset is rank * nbytes
|
||||
self._slice_addr = rank * nbytes
|
||||
self._hbm[self._slice_addr] = input_arr.copy()
|
||||
# Inbound recv FIFOs: direction → deque[ndarray]
|
||||
self.recv_q: dict[str, deque[np.ndarray]] = {d: deque() for d in neighbors}
|
||||
# Output (set when kernel calls tl.store at slice address)
|
||||
self.output: np.ndarray | None = None
|
||||
# Greenlet for this rank — set later
|
||||
self.g: greenlet | None = None
|
||||
|
||||
|
||||
# ── Mock TLContext ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _MockTL:
|
||||
"""Drop-in tl shim for mock runtime.
|
||||
|
||||
Supports the subset of TLContext API that algorithm authors use:
|
||||
program_id, num_programs, load, store, send, recv, recv_async, wait,
|
||||
plus arithmetic operations on TensorHandle (eager numpy execution,
|
||||
no SimPy involved).
|
||||
"""
|
||||
|
||||
def __init__(self, state: _MockRankState, scheduler: "_MockScheduler") -> None:
|
||||
self._state = state
|
||||
self._scheduler = scheduler
|
||||
self._handle_counter = 0
|
||||
|
||||
def _next_id(self) -> str:
|
||||
self._handle_counter += 1
|
||||
return f"mt{self._handle_counter}"
|
||||
|
||||
@property
|
||||
def rank(self) -> int:
|
||||
return self._state.rank
|
||||
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return self._state.world_size
|
||||
|
||||
# axis-aware
|
||||
def program_id(self, axis: int = 0) -> int:
|
||||
# Multi-cube: axis=0 = PE within cube, axis=1 = global cube id.
|
||||
# Falls back to flat (all ranks in one cube) if pes_per_cube
|
||||
# is not set (legacy single-cube tests).
|
||||
ppc = self._state.pes_per_cube
|
||||
if axis == 1:
|
||||
return self._state.rank // ppc
|
||||
return self._state.rank % ppc
|
||||
|
||||
def num_programs(self, axis: int = 0) -> int:
|
||||
ppc = self._state.pes_per_cube
|
||||
if axis == 1:
|
||||
return self._state.world_size // ppc
|
||||
return ppc
|
||||
|
||||
# ── arithmetic ops (called by TensorHandle.__add__ etc.) ──
|
||||
|
||||
def _binary_math(self, op: str, a: TensorHandle, b: TensorHandle) -> TensorHandle:
|
||||
a_data = np.asarray(a.data) if a.data is not None else None
|
||||
b_data = np.asarray(b.data) if b.data is not None else None
|
||||
if a_data is None or b_data is None:
|
||||
result = None
|
||||
elif op == "add":
|
||||
result = a_data + b_data
|
||||
elif op == "sub":
|
||||
result = a_data - b_data
|
||||
elif op == "mul":
|
||||
result = a_data * b_data
|
||||
elif op == "div":
|
||||
result = a_data / b_data
|
||||
elif op == "maximum":
|
||||
result = np.maximum(a_data, b_data)
|
||||
elif op == "minimum":
|
||||
result = np.minimum(a_data, b_data)
|
||||
else:
|
||||
raise NotImplementedError(f"mock _binary_math: op {op!r} not implemented")
|
||||
return TensorHandle(
|
||||
id=self._next_id(),
|
||||
addr=0, shape=a.shape, dtype=a.dtype,
|
||||
nbytes=int(np.prod(a.shape)) * 2 if a.shape else 0,
|
||||
data=result, space="tcm",
|
||||
)
|
||||
|
||||
def maximum(self, a: TensorHandle, b: TensorHandle) -> TensorHandle:
|
||||
return self._binary_math("maximum", a, b)
|
||||
|
||||
def minimum(self, a: TensorHandle, b: TensorHandle) -> TensorHandle:
|
||||
return self._binary_math("minimum", a, b)
|
||||
|
||||
def fma(
|
||||
self, a: TensorHandle, b: TensorHandle, c: TensorHandle,
|
||||
) -> TensorHandle:
|
||||
a_data = np.asarray(a.data) if a.data is not None else None
|
||||
b_data = np.asarray(b.data) if b.data is not None else None
|
||||
c_data = np.asarray(c.data) if c.data is not None else None
|
||||
result = (
|
||||
a_data * b_data + c_data
|
||||
if (a_data is not None and b_data is not None and c_data is not None)
|
||||
else None
|
||||
)
|
||||
return TensorHandle(
|
||||
id=self._next_id(),
|
||||
addr=0, shape=a.shape, dtype=a.dtype,
|
||||
nbytes=int(np.prod(a.shape)) * 2 if a.shape else 0,
|
||||
data=result, space="tcm",
|
||||
)
|
||||
|
||||
def clamp(
|
||||
self,
|
||||
x: TensorHandle,
|
||||
min: TensorHandle,
|
||||
max: TensorHandle,
|
||||
) -> TensorHandle:
|
||||
x_data = np.asarray(x.data) if x.data is not None else None
|
||||
lo = np.asarray(min.data) if min.data is not None else None
|
||||
hi = np.asarray(max.data) if max.data is not None else None
|
||||
result = (
|
||||
np.minimum(np.maximum(x_data, lo), hi)
|
||||
if (x_data is not None and lo is not None and hi is not None)
|
||||
else None
|
||||
)
|
||||
return TensorHandle(
|
||||
id=self._next_id(),
|
||||
addr=0, shape=x.shape, dtype=x.dtype,
|
||||
nbytes=int(np.prod(x.shape)) * 2 if x.shape else 0,
|
||||
data=result, space="tcm",
|
||||
)
|
||||
|
||||
def softmax(self, x: TensorHandle, axis: int = -1) -> TensorHandle:
|
||||
x_data = np.asarray(x.data) if x.data is not None else None
|
||||
if x_data is None:
|
||||
result = None
|
||||
else:
|
||||
x_max = np.max(x_data, axis=axis, keepdims=True)
|
||||
e = np.exp(x_data - x_max)
|
||||
s = np.sum(e, axis=axis, keepdims=True)
|
||||
result = e / s
|
||||
return TensorHandle(
|
||||
id=self._next_id(),
|
||||
addr=0, shape=x.shape, dtype=x.dtype,
|
||||
nbytes=int(np.prod(x.shape)) * 2 if x.shape else 0,
|
||||
data=result, space="tcm",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def cdiv(a: int, b: int) -> int:
|
||||
return -(-int(a) // int(b))
|
||||
|
||||
def _unary_math(self, op: str, x: TensorHandle) -> TensorHandle:
|
||||
x_data = np.asarray(x.data) if x.data is not None else None
|
||||
if x_data is None:
|
||||
result = None
|
||||
elif op == "exp":
|
||||
result = np.exp(x_data)
|
||||
elif op == "log":
|
||||
result = np.log(x_data)
|
||||
elif op == "sqrt":
|
||||
result = np.sqrt(x_data)
|
||||
elif op == "abs":
|
||||
result = np.abs(x_data)
|
||||
elif op == "sigmoid":
|
||||
result = 1.0 / (1.0 + np.exp(-x_data))
|
||||
elif op == "cos":
|
||||
result = np.cos(x_data)
|
||||
elif op == "sin":
|
||||
result = np.sin(x_data)
|
||||
else:
|
||||
raise NotImplementedError(f"mock _unary_math: op {op!r} not implemented")
|
||||
return TensorHandle(
|
||||
id=self._next_id(),
|
||||
addr=0, shape=x.shape, dtype=x.dtype,
|
||||
nbytes=int(np.prod(x.shape)) * 2 if x.shape else 0,
|
||||
data=result, space="tcm",
|
||||
)
|
||||
|
||||
def load(self, ptr: int, shape: tuple[int, ...], dtype: str = "f16") -> TensorHandle:
|
||||
data = self._state._hbm.get(ptr)
|
||||
if data is None:
|
||||
data = np.zeros(shape, dtype=np.float16)
|
||||
return TensorHandle(
|
||||
id=f"load_{ptr}", addr=ptr, shape=shape, dtype=dtype,
|
||||
nbytes=int(np.prod(shape)) * 2, data=data, space="hbm",
|
||||
)
|
||||
|
||||
def store(self, ptr: int, handle: TensorHandle) -> None:
|
||||
if handle.data is not None:
|
||||
self._state._hbm[ptr] = np.asarray(handle.data)
|
||||
if ptr == self._state._slice_addr:
|
||||
self._state.output = self._state._hbm[ptr]
|
||||
|
||||
# IPCQ
|
||||
def send(
|
||||
self,
|
||||
dir: str,
|
||||
src: TensorHandle | None = None,
|
||||
*,
|
||||
src_addr: int | None = None,
|
||||
nbytes: int | None = None,
|
||||
shape: tuple[int, ...] | None = None,
|
||||
dtype: str = "f16",
|
||||
space: str = "tcm",
|
||||
) -> None:
|
||||
if dir not in self._state.neighbors:
|
||||
raise IpcqInvalidDirection(
|
||||
f"mock tl.send: direction {dir!r} not in neighbors {list(self._state.neighbors)}"
|
||||
)
|
||||
if src is not None:
|
||||
if src.data is not None:
|
||||
data = np.asarray(src.data)
|
||||
else:
|
||||
# Resolve from this rank's local memory at src.addr
|
||||
space_dict = self._state._hbm if src.space == "hbm" else self._state._tcm
|
||||
stored = space_dict.get(src.addr)
|
||||
if stored is None:
|
||||
raise RuntimeError(
|
||||
f"mock tl.send: no data at {src.space}:0x{src.addr:x}"
|
||||
)
|
||||
data = np.asarray(stored)
|
||||
else:
|
||||
data = None
|
||||
if data is None:
|
||||
raise RuntimeError("mock tl.send: src is None")
|
||||
peer_rank = self._state.neighbors[dir]
|
||||
# Find the reverse direction at the peer, mirroring real IPCQ
|
||||
# install pairing: N↔S, E↔W, parent↔parent, child_left↔child_left, etc.
|
||||
_REVERSE = {"N": "S", "S": "N", "E": "W", "W": "E",
|
||||
"parent": "parent", "child_left": "child_left",
|
||||
"child_right": "child_right"}
|
||||
peer_state = self._scheduler.states[peer_rank]
|
||||
reverse_dir = _REVERSE.get(dir)
|
||||
# Fall back to "first direction pointing at me" if the explicit
|
||||
# reverse doesn't exist at the peer (e.g. custom directions).
|
||||
if reverse_dir is None or reverse_dir not in peer_state.neighbors:
|
||||
reverse_dir = None
|
||||
for d, target in peer_state.neighbors.items():
|
||||
if target == self._state.rank:
|
||||
reverse_dir = d
|
||||
break
|
||||
if reverse_dir is None:
|
||||
raise RuntimeError(
|
||||
f"mock tl.send: peer rank {peer_rank} has no reverse direction"
|
||||
)
|
||||
peer_state.recv_q[reverse_dir].append(data.copy())
|
||||
self._scheduler._send_counter += 1
|
||||
# After delivering, hand control back to scheduler so the receiver
|
||||
# can wake up.
|
||||
self._scheduler.yield_()
|
||||
|
||||
def recv_async(
|
||||
self,
|
||||
dir: str,
|
||||
shape: tuple[int, ...] = (),
|
||||
dtype: str = "f16",
|
||||
) -> dict:
|
||||
"""Non-blocking recv. Returns a future dict to pass to tl.wait."""
|
||||
if dir not in self._state.neighbors:
|
||||
raise IpcqInvalidDirection(
|
||||
f"mock tl.recv_async: direction {dir!r} not in neighbors"
|
||||
)
|
||||
return {"_kind": "recv_future", "dir": dir, "shape": shape, "dtype": dtype}
|
||||
|
||||
def wait(self, future: Any) -> TensorHandle:
|
||||
"""Block until the recv future has data."""
|
||||
if not isinstance(future, dict) or future.get("_kind") != "recv_future":
|
||||
raise TypeError("tl.wait: expected recv future from tl.recv_async")
|
||||
d = future["dir"]
|
||||
while not self._state.recv_q[d]:
|
||||
self._scheduler.yield_()
|
||||
data = self._state.recv_q[d].popleft()
|
||||
return self._make_handle(data, d, future["dtype"])
|
||||
|
||||
def recv(
|
||||
self,
|
||||
dir: str | None = None,
|
||||
shape: tuple[int, ...] = (),
|
||||
dtype: str = "f16",
|
||||
) -> TensorHandle:
|
||||
if dir is not None and dir not in self._state.neighbors:
|
||||
raise IpcqInvalidDirection(
|
||||
f"mock tl.recv: direction {dir!r} not in neighbors {list(self._state.neighbors)}"
|
||||
)
|
||||
# Wait for data
|
||||
while True:
|
||||
if dir is None:
|
||||
# round-robin over directions
|
||||
for d in self._state.neighbors:
|
||||
if self._state.recv_q[d]:
|
||||
data = self._state.recv_q[d].popleft()
|
||||
return self._make_handle(data, d, dtype)
|
||||
else:
|
||||
if self._state.recv_q[dir]:
|
||||
data = self._state.recv_q[dir].popleft()
|
||||
return self._make_handle(data, dir, dtype)
|
||||
# Yield to other ranks
|
||||
self._scheduler.yield_()
|
||||
|
||||
def _make_handle(self, data: np.ndarray, direction: str, dtype: str) -> TensorHandle:
|
||||
return TensorHandle(
|
||||
id=f"recv_{direction}",
|
||||
addr=0, shape=data.shape, dtype=dtype,
|
||||
nbytes=int(data.nbytes), data=data, space="tcm",
|
||||
)
|
||||
|
||||
|
||||
# ── Cooperative scheduler ────────────────────────────────────────────
|
||||
|
||||
|
||||
class _MockScheduler:
|
||||
"""Round-robin cooperative scheduler over rank greenlets."""
|
||||
|
||||
def __init__(self, states: list[_MockRankState]) -> None:
|
||||
self.states = states
|
||||
self._parent: greenlet | None = None
|
||||
self._cur_idx = 0
|
||||
|
||||
def yield_(self) -> None:
|
||||
"""Called from inside a rank greenlet to give other ranks a turn."""
|
||||
assert self._parent is not None
|
||||
self._parent.switch()
|
||||
|
||||
def run(self, kernel_fn: Callable, kernel_args: tuple) -> list[np.ndarray]:
|
||||
from kernbench.triton_emu.tl_context import TLContext
|
||||
|
||||
self._parent = greenlet.getcurrent()
|
||||
n = len(self.states)
|
||||
|
||||
# Per-rank tl shim
|
||||
tls: dict[int, _MockTL] = {}
|
||||
|
||||
def _spawn(rank_idx: int) -> greenlet:
|
||||
state = self.states[rank_idx]
|
||||
tl = _MockTL(state, self)
|
||||
tls[rank_idx] = tl
|
||||
|
||||
def _entry():
|
||||
# Activate this rank's tl for TensorHandle operator overloads
|
||||
TLContext._set_active(tl) # type: ignore[attr-defined]
|
||||
try:
|
||||
kernel_fn(state.t_ptr, *kernel_args, tl=tl)
|
||||
finally:
|
||||
TLContext._set_active(None) # type: ignore[attr-defined]
|
||||
|
||||
return greenlet(_entry)
|
||||
|
||||
for state in self.states:
|
||||
state.g = _spawn(state.rank)
|
||||
|
||||
# Drive each rank round-robin until all dead. Detect global deadlock.
|
||||
# A global send counter tracks whether any greenlet delivered data
|
||||
# in the current round. This is more reliable than queue-depth
|
||||
# tracking because a recv+send pair in the same round nets to zero
|
||||
# depth change yet still represents real progress.
|
||||
self._send_counter = 0
|
||||
max_idle_rounds = 10_000
|
||||
idle_rounds = 0
|
||||
while True:
|
||||
alive = [s for s in self.states if s.g is not None and not s.g.dead]
|
||||
if not alive:
|
||||
break
|
||||
counter_before = self._send_counter
|
||||
for s in self.states:
|
||||
if s.g is None or s.g.dead:
|
||||
continue
|
||||
TLContext._set_active(tls[s.rank]) # type: ignore[attr-defined]
|
||||
s.g.switch()
|
||||
TLContext._set_active(None) # type: ignore[attr-defined]
|
||||
any_died = any(s.g is not None and s.g.dead for s in self.states)
|
||||
if self._send_counter > counter_before or any_died:
|
||||
idle_rounds = 0
|
||||
else:
|
||||
idle_rounds += 1
|
||||
if idle_rounds >= max_idle_rounds:
|
||||
raise RuntimeError(
|
||||
"mock CCL runtime: deadlock detected (no progress for "
|
||||
f"{max_idle_rounds} rounds)"
|
||||
)
|
||||
|
||||
return [
|
||||
s.output if s.output is not None else s._hbm.get(s._slice_addr)
|
||||
for s in self.states
|
||||
]
|
||||
|
||||
|
||||
# ── Public entry ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def run_kernel_in_mock(
|
||||
kernel_fn: Callable,
|
||||
world_size: int,
|
||||
topology: str,
|
||||
inputs: list[np.ndarray],
|
||||
kernel_args: tuple = (),
|
||||
algo_module: Any | None = None,
|
||||
pes_per_cube: int = 0,
|
||||
) -> list[np.ndarray]:
|
||||
"""Run a CCL kernel under the mock runtime with no SimPy/fabric.
|
||||
|
||||
Args:
|
||||
kernel_fn: ``kernel(t_ptr, *kernel_args, tl=...)``
|
||||
world_size: number of ranks
|
||||
topology: builtin topology name (e.g. "ring_1d")
|
||||
inputs: per-rank input ndarrays. ``inputs[r]`` becomes rank r's
|
||||
local tile at HBM address 0.
|
||||
kernel_args: extra positional args after t_ptr
|
||||
algo_module: optional module providing ``neighbors()`` override
|
||||
pes_per_cube: PEs per cube for multi-cube program_id mapping.
|
||||
0 → single-cube legacy (all ranks in one cube).
|
||||
|
||||
Returns:
|
||||
Per-rank output ndarrays — whatever the kernel wrote via tl.store
|
||||
(or the original input if the kernel didn't store).
|
||||
"""
|
||||
if len(inputs) != world_size:
|
||||
raise ValueError(f"len(inputs)={len(inputs)} != world_size={world_size}")
|
||||
|
||||
topo_fn = resolve_topology(topology, algo_module=algo_module)
|
||||
states = [
|
||||
_MockRankState(
|
||||
rank=r, world_size=world_size,
|
||||
neighbors=topo_fn(r, world_size),
|
||||
input_arr=inputs[r],
|
||||
pes_per_cube=pes_per_cube,
|
||||
)
|
||||
for r in range(world_size)
|
||||
]
|
||||
|
||||
sched = _MockScheduler(states)
|
||||
return sched.run(kernel_fn, kernel_args)
|
||||
@@ -73,6 +73,39 @@ def tree_binary(rank: int, world_size: int) -> NeighborMap:
|
||||
return n
|
||||
|
||||
|
||||
def torus_2d(rank: int, world_size: int) -> NeighborMap:
|
||||
"""Square 2D torus (N/S/E/W) with wrap-around on all edges.
|
||||
|
||||
Alias for mesh_2d (which already wraps). Explicit name for clarity
|
||||
when used as a SIP-level topology.
|
||||
"""
|
||||
return mesh_2d(rank, world_size)
|
||||
|
||||
|
||||
def mesh_2d_no_wrap(rank: int, world_size: int) -> NeighborMap:
|
||||
"""Square 2D mesh (N/S/E/W) WITHOUT wrap-around.
|
||||
|
||||
Edge nodes have fewer neighbors (no wrapping). Used for SIP-level
|
||||
topologies where physical links don't wrap.
|
||||
"""
|
||||
side = int(round(world_size ** 0.5))
|
||||
if side * side != world_size:
|
||||
raise ValueError(
|
||||
f"mesh_2d_no_wrap requires square world_size, got {world_size}"
|
||||
)
|
||||
r, c = divmod(rank, side)
|
||||
n: NeighborMap = {}
|
||||
if r > 0:
|
||||
n["N"] = (r - 1) * side + c
|
||||
if r < side - 1:
|
||||
n["S"] = (r + 1) * side + c
|
||||
if c > 0:
|
||||
n["W"] = r * side + (c - 1)
|
||||
if c < side - 1:
|
||||
n["E"] = r * side + (c + 1)
|
||||
return n
|
||||
|
||||
|
||||
def none(rank: int, world_size: int) -> NeighborMap:
|
||||
"""Empty map — algorithm's neighbors() must build from scratch."""
|
||||
return {}
|
||||
@@ -82,6 +115,8 @@ _BUILTIN: dict[str, TopologyFn] = {
|
||||
"ring_1d": ring_1d,
|
||||
"ring_1d_unidir": ring_1d_unidir,
|
||||
"mesh_2d": mesh_2d,
|
||||
"torus_2d": torus_2d,
|
||||
"mesh_2d_no_wrap": mesh_2d_no_wrap,
|
||||
"tree_binary": tree_binary,
|
||||
"none": none,
|
||||
}
|
||||
|
||||
@@ -1,3 +1,14 @@
|
||||
"""Data-parallel placement policy (ADR-0026: intra-device only).
|
||||
|
||||
``DPPolicy`` describes how a tensor is sharded *within a single SIP* across
|
||||
that SIP's cubes and PEs. Crossing the SIP boundary is not a DPPolicy
|
||||
concern: ADR-0024's ``torch.ahbm.set_device(rank)`` picks the SIP, and
|
||||
Megatron-style TP (ADR-0027) expresses multi-SIP tensors when needed.
|
||||
|
||||
``ShardSpec`` is expressed in structural ``(sip, cube, pe)`` coordinates.
|
||||
The former flat ``pe_index`` field/property is fully removed — callers
|
||||
needing a flat integer key compute it explicitly at the call site.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
@@ -7,25 +18,58 @@ from typing import Literal
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DPPolicy:
|
||||
"""Three-level data-parallel policy: sip-level + cube-level + pe-level.
|
||||
"""Intra-device (cube × PE) data-parallel policy.
|
||||
|
||||
Policies:
|
||||
SIP-level placement is controlled by ``torch.ahbm.set_device(rank)``
|
||||
(ADR-0024). For tensors that must cross SIP boundaries, use
|
||||
Megatron-style parallel layers (ADR-0027). DPPolicy itself never
|
||||
crosses a SIP boundary.
|
||||
|
||||
Policies (per axis):
|
||||
- "replicate": full copy at each unit
|
||||
- "column_wise": split K (column) axis across units
|
||||
- "row_wise": split M (row) axis across units
|
||||
|
||||
Optional overrides (default None = use topology dimensions):
|
||||
- num_pes: override PEs per cube (e.g., 1 for single-PE test)
|
||||
- num_cubes: override cubes per SIP (e.g., 1 for single-cube test)
|
||||
- num_sips: override SIP count
|
||||
Optional overrides (``None`` = use topology dimensions):
|
||||
- num_pes: override PEs per cube
|
||||
- num_cubes: override cubes per SIP
|
||||
"""
|
||||
|
||||
sip: Literal["replicate", "column_wise", "row_wise"] = "replicate"
|
||||
cube: Literal["replicate", "column_wise", "row_wise"] = "replicate"
|
||||
pe: Literal["replicate", "column_wise", "row_wise"] = "replicate"
|
||||
num_pes: int | None = None
|
||||
num_cubes: int | None = None
|
||||
num_sips: int | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ShardSpec:
|
||||
"""Structural shard placement — ``(sip, cube, pe)`` coord (ADR-0026).
|
||||
|
||||
Global-flat ``pe_index`` was removed: callers must use structural
|
||||
coords directly. If a flat integer key is needed in a local context
|
||||
(e.g. internal dict lookup), compute it explicitly at the call site
|
||||
and do not expose it in any public API.
|
||||
"""
|
||||
|
||||
sip: int
|
||||
cube: int
|
||||
pe: int
|
||||
offset_bytes: int
|
||||
nbytes: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _LocalPeShard:
|
||||
"""Internal — PE resolver's return type (ADR-0026 D3).
|
||||
|
||||
Holds a cube-local PE identifier (``local_pe``) plus the shard's
|
||||
byte payload. Lifted into ``ShardSpec`` with full ``(sip, cube, pe)``
|
||||
coordinates inside ``resolve_dp_policy``.
|
||||
"""
|
||||
|
||||
local_pe: int
|
||||
offset_bytes: int
|
||||
nbytes: int
|
||||
|
||||
|
||||
def _split_shape(
|
||||
@@ -52,14 +96,13 @@ def resolve_dp_policy(
|
||||
itemsize: int,
|
||||
num_pe: int,
|
||||
num_cubes: int = 1,
|
||||
num_sips: int = 1,
|
||||
target_sip: int,
|
||||
) -> list[ShardSpec]:
|
||||
"""Resolve a DPPolicy into a list[ShardSpec] with three-level resolution.
|
||||
"""Resolve a DPPolicy into a list[ShardSpec] on a single SIP.
|
||||
|
||||
SIP-level → cube-level → pe-level.
|
||||
num_cubes is cubes per SIP (not total).
|
||||
ShardSpec.pe_index uses flat indexing:
|
||||
sip_id * num_cubes * num_pe + cube_id * num_pe + pe_id
|
||||
Two-level resolution (cube × PE) within ``target_sip``. Each returned
|
||||
``ShardSpec`` carries ``sip=target_sip`` and cube/pe local to the SIP.
|
||||
No SIP-level split — DPPolicy is intra-device only (ADR-0026).
|
||||
"""
|
||||
_PE_RESOLVERS = {
|
||||
"replicate": replicate,
|
||||
@@ -70,84 +113,61 @@ def resolve_dp_policy(
|
||||
if resolver is None:
|
||||
raise ValueError(f"Unknown pe-level policy: {policy.pe}")
|
||||
|
||||
cubes_per_sip = num_cubes
|
||||
all_shards: list[ShardSpec] = []
|
||||
|
||||
# Level 1: SIP
|
||||
sip_splits = _split_shape(policy.sip, shape, num_sips, itemsize)
|
||||
|
||||
for sip_id, (sip_shape, sip_offset) in enumerate(sip_splits):
|
||||
# Level 2: Cube within SIP
|
||||
cube_splits = _split_shape(policy.cube, sip_shape, cubes_per_sip, itemsize)
|
||||
# Level 1: cube within SIP
|
||||
cube_splits = _split_shape(policy.cube, shape, num_cubes, itemsize)
|
||||
|
||||
for cube_id, (cube_shape, cube_offset) in enumerate(cube_splits):
|
||||
# Level 3: PE within cube
|
||||
pe_shards = resolver(shape=cube_shape, itemsize=itemsize, num_pe=num_pe)
|
||||
# Level 2: PE within cube — resolver returns _LocalPeShard
|
||||
local_shards = resolver(shape=cube_shape, itemsize=itemsize, num_pe=num_pe)
|
||||
|
||||
for ps in pe_shards:
|
||||
flat_idx = (
|
||||
sip_id * cubes_per_sip * num_pe
|
||||
+ cube_id * num_pe
|
||||
+ ps.pe_index
|
||||
)
|
||||
for ls in local_shards:
|
||||
all_shards.append(ShardSpec(
|
||||
pe_index=flat_idx,
|
||||
offset_bytes=sip_offset + cube_offset + ps.offset_bytes,
|
||||
nbytes=ps.nbytes,
|
||||
sip=target_sip,
|
||||
cube=cube_id,
|
||||
pe=ls.local_pe,
|
||||
offset_bytes=cube_offset + ls.offset_bytes,
|
||||
nbytes=ls.nbytes,
|
||||
))
|
||||
|
||||
return all_shards
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ShardSpec:
|
||||
pe_index: int
|
||||
offset_bytes: int
|
||||
nbytes: int
|
||||
|
||||
|
||||
def column_wise(
|
||||
*, shape: tuple[int, int], itemsize: int, num_pe: int,
|
||||
) -> list[ShardSpec]:
|
||||
) -> list[_LocalPeShard]:
|
||||
"""Split K axis into num_pe equal parts. Each PE gets (M, K/P)."""
|
||||
M, K = shape
|
||||
chunk_k = K // num_pe
|
||||
chunk_bytes = M * chunk_k * itemsize
|
||||
shards = []
|
||||
for i in range(num_pe):
|
||||
shards.append(ShardSpec(
|
||||
pe_index=i,
|
||||
offset_bytes=i * chunk_bytes,
|
||||
nbytes=chunk_bytes,
|
||||
))
|
||||
return shards
|
||||
return [
|
||||
_LocalPeShard(local_pe=i, offset_bytes=i * chunk_bytes, nbytes=chunk_bytes)
|
||||
for i in range(num_pe)
|
||||
]
|
||||
|
||||
|
||||
def row_wise(
|
||||
*, shape: tuple[int, int], itemsize: int, num_pe: int,
|
||||
) -> list[ShardSpec]:
|
||||
) -> list[_LocalPeShard]:
|
||||
"""Split M axis into num_pe equal parts. Each PE gets (M/P, K)."""
|
||||
M, K = shape
|
||||
chunk_m = M // num_pe
|
||||
chunk_bytes = chunk_m * K * itemsize
|
||||
shards = []
|
||||
for i in range(num_pe):
|
||||
shards.append(ShardSpec(
|
||||
pe_index=i,
|
||||
offset_bytes=i * chunk_bytes,
|
||||
nbytes=chunk_bytes,
|
||||
))
|
||||
return shards
|
||||
return [
|
||||
_LocalPeShard(local_pe=i, offset_bytes=i * chunk_bytes, nbytes=chunk_bytes)
|
||||
for i in range(num_pe)
|
||||
]
|
||||
|
||||
|
||||
def replicate(
|
||||
*, shape: tuple[int, int], itemsize: int, num_pe: int,
|
||||
) -> list[ShardSpec]:
|
||||
) -> list[_LocalPeShard]:
|
||||
"""Full copy per PE. Each PE gets (M, K)."""
|
||||
M, K = shape
|
||||
full_bytes = M * K * itemsize
|
||||
return [
|
||||
ShardSpec(pe_index=i, offset_bytes=0, nbytes=full_bytes)
|
||||
_LocalPeShard(local_pe=i, offset_bytes=0, nbytes=full_bytes)
|
||||
for i in range(num_pe)
|
||||
]
|
||||
|
||||
@@ -155,20 +175,20 @@ def replicate(
|
||||
def tiled_column_major(
|
||||
*, shape: tuple[int, int], itemsize: int, num_pe: int,
|
||||
tile_m: int, tile_k: int,
|
||||
) -> list[ShardSpec]:
|
||||
) -> list[_LocalPeShard]:
|
||||
"""2D tiling, column-major order (K axis first), round-robin across PEs."""
|
||||
M, K = shape
|
||||
tiles_m = ceil(M / tile_m)
|
||||
tiles_k = ceil(K / tile_k)
|
||||
tile_bytes = tile_m * tile_k * itemsize
|
||||
row_bytes = K * itemsize
|
||||
shards = []
|
||||
shards: list[_LocalPeShard] = []
|
||||
idx = 0
|
||||
for mi in range(tiles_m):
|
||||
for ki in range(tiles_k):
|
||||
offset = (mi * tile_m * row_bytes) + (ki * tile_k * itemsize)
|
||||
shards.append(ShardSpec(
|
||||
pe_index=idx % num_pe,
|
||||
shards.append(_LocalPeShard(
|
||||
local_pe=idx % num_pe,
|
||||
offset_bytes=offset,
|
||||
nbytes=tile_bytes,
|
||||
))
|
||||
@@ -179,20 +199,20 @@ def tiled_column_major(
|
||||
def tiled_row_major(
|
||||
*, shape: tuple[int, int], itemsize: int, num_pe: int,
|
||||
tile_m: int, tile_k: int,
|
||||
) -> list[ShardSpec]:
|
||||
) -> list[_LocalPeShard]:
|
||||
"""2D tiling, row-major order (M axis first), round-robin across PEs."""
|
||||
M, K = shape
|
||||
tiles_m = ceil(M / tile_m)
|
||||
tiles_k = ceil(K / tile_k)
|
||||
tile_bytes = tile_m * tile_k * itemsize
|
||||
row_bytes = K * itemsize
|
||||
shards = []
|
||||
shards: list[_LocalPeShard] = []
|
||||
idx = 0
|
||||
for ki in range(tiles_k):
|
||||
for mi in range(tiles_m):
|
||||
offset = (mi * tile_m * row_bytes) + (ki * tile_k * itemsize)
|
||||
shards.append(ShardSpec(
|
||||
pe_index=idx % num_pe,
|
||||
shards.append(_LocalPeShard(
|
||||
local_pe=idx % num_pe,
|
||||
offset_bytes=offset,
|
||||
nbytes=tile_bytes,
|
||||
))
|
||||
|
||||
@@ -42,6 +42,59 @@ def _numpy_to_dtype_str(np_dtype) -> str:
|
||||
raise ValueError(f"unsupported numpy dtype: {np_dtype!r}")
|
||||
|
||||
|
||||
# ADR-0027 D3: weak registry of the currently-active RuntimeContext so
|
||||
# module-level helpers (e.g. ``kernbench.tp.parallel_state``) can resolve
|
||||
# the ctx without threading it through every call.
|
||||
import weakref as _weakref
|
||||
|
||||
_ACTIVE_CTX_REF: _weakref.ref | None = None
|
||||
|
||||
|
||||
def _get_active_context():
|
||||
"""Return the most-recently-entered RuntimeContext, or None."""
|
||||
if _ACTIVE_CTX_REF is None:
|
||||
return None
|
||||
return _ACTIVE_CTX_REF()
|
||||
|
||||
|
||||
class _AhbmNamespace:
|
||||
"""torch.ahbm — per-greenlet SIP device binding (ADR-0024 D10).
|
||||
|
||||
Real-PyTorch parity idiom: ``torch.cuda.set_device(rank)``. KernBench's
|
||||
backend is 'ahbm' (not CUDA), so this namespace avoids pretending to be
|
||||
a CUDA runtime.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._device_by_greenlet: dict = {}
|
||||
|
||||
def set_device(self, device: int) -> None:
|
||||
from greenlet import getcurrent
|
||||
self._device_by_greenlet[getcurrent()] = int(device)
|
||||
|
||||
def current_device(self) -> int | None:
|
||||
from greenlet import getcurrent
|
||||
return self._device_by_greenlet.get(getcurrent())
|
||||
|
||||
|
||||
class _AcceleratorNamespace:
|
||||
"""torch.accelerator — device-agnostic alias (PyTorch 2.x style).
|
||||
|
||||
Wraps _AhbmNamespace. Bench code can pick either:
|
||||
torch.ahbm.set_device(rank) # explicit backend
|
||||
torch.accelerator.set_device_index(rank) # portable
|
||||
"""
|
||||
|
||||
def __init__(self, ahbm: "_AhbmNamespace") -> None:
|
||||
self._ahbm = ahbm
|
||||
|
||||
def set_device_index(self, device: int) -> None:
|
||||
self._ahbm.set_device(device)
|
||||
|
||||
def current_device_index(self) -> int | None:
|
||||
return self._ahbm.current_device()
|
||||
|
||||
|
||||
@dataclass
|
||||
class RuntimeContext:
|
||||
engine: SimEngine
|
||||
@@ -51,7 +104,11 @@ class RuntimeContext:
|
||||
|
||||
_handles: list[RequestHandle] = field(default_factory=list, init=False)
|
||||
_completed: set[RequestHandle] = field(default_factory=set, init=False)
|
||||
_allocators: dict[int, Any] = field(default_factory=dict, init=False)
|
||||
# ADR-0027 D0.1: worker-deferred wait queue. When a worker greenlet
|
||||
# calls ctx.wait(h), the handle is appended here and control yields to
|
||||
# main. Main's scheduler drain consumes this list.
|
||||
_pending_worker_waits: list[RequestHandle] = field(default_factory=list, init=False)
|
||||
_allocators: dict[tuple[int, int, int], Any] = field(default_factory=dict, init=False)
|
||||
_va_allocator: Any = field(default=None, init=False)
|
||||
_tensor_counter: int = field(default=0, init=False)
|
||||
_traces: list[dict] = field(default_factory=list, init=False)
|
||||
@@ -67,6 +124,13 @@ class RuntimeContext:
|
||||
dc = DistributedContext()
|
||||
dc._ctx_ref = self # back-reference for AhbmCCLBackend to reach ctx.launch etc.
|
||||
self.distributed = dc
|
||||
# ADR-0024 D10: torch.ahbm (KernBench-native) + torch.accelerator
|
||||
# (PyTorch 2.x portable) namespaces for per-greenlet device binding.
|
||||
self.ahbm = _AhbmNamespace()
|
||||
self.accelerator = _AcceleratorNamespace(self.ahbm)
|
||||
# ADR-0027 D1.3: torch.multiprocessing.spawn namespace.
|
||||
from kernbench.runtime_api.multiprocessing import _MultiprocessingNamespace
|
||||
self.multiprocessing = _MultiprocessingNamespace(self)
|
||||
|
||||
def install_ipcq(
|
||||
self,
|
||||
@@ -118,10 +182,16 @@ class RuntimeContext:
|
||||
return plan
|
||||
|
||||
def __enter__(self):
|
||||
global _ACTIVE_CTX_REF
|
||||
_ACTIVE_CTX_REF = _weakref.ref(self)
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc):
|
||||
global _ACTIVE_CTX_REF
|
||||
self.cleanup()
|
||||
# Clear active-context registry if we are it.
|
||||
if _ACTIVE_CTX_REF is not None and _ACTIVE_CTX_REF() is self:
|
||||
_ACTIVE_CTX_REF = None
|
||||
return False
|
||||
|
||||
def submit(self, request: Any) -> RequestHandle:
|
||||
@@ -136,10 +206,24 @@ class RuntimeContext:
|
||||
return handle in self._completed
|
||||
|
||||
def wait(self, handle: RequestHandle, *, _meta: dict | None = None) -> Completion:
|
||||
# ADR-0027 D0.2: fast-path for already-completed handles (avoid
|
||||
# redundant worker→main→worker round-trip).
|
||||
if handle in self._completed:
|
||||
completion, trace = self.engine.get_completion(handle)
|
||||
return completion
|
||||
|
||||
# ADR-0027 D0.2: if called from a worker greenlet (parent is main,
|
||||
# not dead), defer the wait to the main scheduler — enqueue and
|
||||
# yield. Main drains env.run, then switches back. On resume the
|
||||
# handle must be in _completed (D0.3 resume invariant).
|
||||
from greenlet import getcurrent
|
||||
g = getcurrent()
|
||||
if g.parent is not None and not g.parent.dead:
|
||||
self._pending_worker_waits.append(handle)
|
||||
g.parent.switch()
|
||||
# Resume: main drained. Fall through to completion/trace assembly.
|
||||
|
||||
# Main context (or single-driver): drive engine directly.
|
||||
wait_fn = getattr(self.engine, "wait", None)
|
||||
if wait_fn is not None:
|
||||
wait_fn(handle) # type: ignore[misc]
|
||||
@@ -228,12 +312,7 @@ class RuntimeContext:
|
||||
# Return PA space
|
||||
if self._allocators:
|
||||
for shard in handle.shards:
|
||||
flat_idx = (
|
||||
shard.sip * self._num_cubes * self._pes_per_cube
|
||||
+ shard.cube * self._pes_per_cube
|
||||
+ shard.pe
|
||||
)
|
||||
alloc = self._allocators.get(flat_idx)
|
||||
alloc = self._allocators.get((shard.sip, shard.cube, shard.pe))
|
||||
if alloc is not None:
|
||||
from kernbench.policy.address.phyaddr import PhysAddr
|
||||
alloc.free_hbm(PhysAddr.decode(shard.pa), shard.nbytes)
|
||||
@@ -297,17 +376,15 @@ class RuntimeContext:
|
||||
tcm_scheduler_reserved_bytes=4 * (1 << 20),
|
||||
sram_bytes_per_cube=32 * (1 << 20),
|
||||
)
|
||||
# Create allocators scoped to target SIP(s) only
|
||||
# Flat index: sip_id * cubes_per_sip * pes_per_cube + cube_id * pes_per_cube + pe_id
|
||||
# Create allocators scoped to target SIP(s) only.
|
||||
# ADR-0026 D5: dict key is the structural (sip, cube, pe) tuple.
|
||||
self._pes_per_cube = pes_per_cube
|
||||
self._num_cubes = cubes_per_sip
|
||||
self._num_sips = sip_count
|
||||
cubes_x_pes = cubes_per_sip * pes_per_cube
|
||||
for sip_id in sip_range:
|
||||
for cube_id in range(cubes_per_sip):
|
||||
for pe_id in range(pes_per_cube):
|
||||
flat_idx = sip_id * cubes_x_pes + cube_id * pes_per_cube + pe_id
|
||||
self._allocators[flat_idx] = PEMemAllocator(
|
||||
self._allocators[(sip_id, cube_id, pe_id)] = PEMemAllocator(
|
||||
rack_id=0, sip_id=sip_id, cube_id=cube_id, pe_id=pe_id, cfg=cfg,
|
||||
)
|
||||
|
||||
@@ -394,16 +471,23 @@ class RuntimeContext:
|
||||
# DPPolicy overrides take precedence over topology dimensions
|
||||
eff_num_pe = dp.num_pes if dp.num_pes is not None else self._pes_per_cube
|
||||
eff_num_cubes = dp.num_cubes if dp.num_cubes is not None else self._num_cubes
|
||||
eff_num_sips = dp.num_sips if dp.num_sips is not None else self._num_sips
|
||||
# ADR-0026 D4: resolve structural coords directly at resolve time.
|
||||
# ``torch.ahbm.set_device(rank)`` (ADR-0024 D10) selects the target
|
||||
# SIP; if unset, fall back to SIP 0 for single-driver compatibility.
|
||||
current_sip = (
|
||||
self.ahbm.current_device() if hasattr(self, "ahbm") else None
|
||||
)
|
||||
if current_sip is None:
|
||||
current_sip = 0
|
||||
placement = resolve_dp_policy(
|
||||
dp, shape=shape_2d, itemsize=itemsize,
|
||||
num_pe=eff_num_pe, num_cubes=eff_num_cubes,
|
||||
num_sips=eff_num_sips,
|
||||
target_sip=int(current_sip),
|
||||
)
|
||||
|
||||
# Infer target_pe from placement using local (within-cube) PE IDs.
|
||||
# This ensures M_CPU only fans out to PEs that own shards, not all PEs.
|
||||
local_pe_ids = sorted({s.pe_index % eff_num_pe for s in placement})
|
||||
local_pe_ids = sorted({s.pe for s in placement})
|
||||
if len(local_pe_ids) == 1:
|
||||
target_pe: int | tuple[int, ...] | str = local_pe_ids[0]
|
||||
elif len(local_pe_ids) == eff_num_pe and eff_num_pe == self._pes_per_cube:
|
||||
@@ -501,6 +585,21 @@ class RuntimeContext:
|
||||
"sip": shard.sip, "cube": shard.cube, "pe": shard.pe,
|
||||
"nbytes": shard.nbytes,
|
||||
})
|
||||
# ADR-0027: also populate MemoryStore at VA keys so kernels
|
||||
# reading via VA (the common ``tl.load`` path) see the init
|
||||
# data. Phase 1 MemoryWriteMsg writes via PA; kernels read via
|
||||
# VA; Phase 2 DataExecutor reads via the addresses captured in
|
||||
# op_log (VA for tl.load). Without this, zero-init tensors are
|
||||
# invisible to kernels in Phase 2.
|
||||
store = getattr(self.engine, "_memory_store", None)
|
||||
if store is not None and pattern == "zero" and handle.va_base:
|
||||
import numpy as np
|
||||
from kernbench.runtime_api.tensor import _numpy_dtype
|
||||
np_dtype = _numpy_dtype(dtype)
|
||||
for shard in handle.shards:
|
||||
count = shard.nbytes // itemsize
|
||||
addr = handle.va_base + shard.offset_bytes
|
||||
store.write("hbm", addr, np.zeros(count, dtype=np_dtype))
|
||||
|
||||
return t
|
||||
|
||||
@@ -509,6 +608,7 @@ class RuntimeContext:
|
||||
kernel_name: str,
|
||||
kernel_fn: Any,
|
||||
*args: Any,
|
||||
_defer_wait: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> RequestHandle:
|
||||
"""Register and launch a kernel (like a fused torch op).
|
||||
@@ -518,6 +618,11 @@ class RuntimeContext:
|
||||
|
||||
Creates per-SIP KernelLaunchMsg with local va_base per tensor
|
||||
(like host driver sending per-rank launch commands).
|
||||
|
||||
When ``_defer_wait=True`` (ADR-0024 D7), returns the list of
|
||||
``(handle, sip_id, meta)`` tuples instead of waiting. Caller is
|
||||
responsible for waiting — used by collective ops to yield between
|
||||
submit and wait so all sibling ranks can submit first.
|
||||
"""
|
||||
from collections import defaultdict
|
||||
|
||||
@@ -593,11 +698,8 @@ class RuntimeContext:
|
||||
dp = t._dp_metadata.dp_policy if t._dp_metadata else None
|
||||
if dp is None:
|
||||
return t.shape
|
||||
if dp.sip != "replicate":
|
||||
if dp.sip == "column_wise":
|
||||
K = K // self._num_sips
|
||||
elif dp.sip == "row_wise":
|
||||
M = M // self._num_sips
|
||||
# ADR-0026: DPPolicy no longer crosses SIP boundaries; cube + PE
|
||||
# are the only axes that shrink the local shape.
|
||||
if dp.cube != "replicate":
|
||||
if dp.cube == "column_wise":
|
||||
K = K // self._num_cubes
|
||||
@@ -683,6 +785,18 @@ class RuntimeContext:
|
||||
_pending_handles.append((h, sip_id))
|
||||
last_handle = h
|
||||
|
||||
if _defer_wait:
|
||||
# ADR-0024 D7: return the pending-list so the caller can yield
|
||||
# between submit and drain. Used by collective ops that need
|
||||
# all sibling ranks to submit before any rank waits.
|
||||
return [
|
||||
(h, sip_id, {
|
||||
"phase": "kernel", "name": kernel_name,
|
||||
"sip": sip_id, "target_pe": target_pe,
|
||||
})
|
||||
for h, sip_id in _pending_handles
|
||||
]
|
||||
|
||||
# Drain pending handles now that every SIP has a launch posted.
|
||||
for h, sip_id in _pending_handles:
|
||||
self.wait(h, _meta={
|
||||
|
||||
@@ -23,6 +23,7 @@ Host bench code uses only real-PyTorch names:
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import math
|
||||
from typing import Any
|
||||
|
||||
|
||||
@@ -40,20 +41,44 @@ class AhbmCCLBackend:
|
||||
self._merged = resolve_algorithm_config(self._cfg_all)
|
||||
self._algo_module = importlib.import_module(self._merged["module"])
|
||||
self._world_size = self._resolve_world_size()
|
||||
self._pending_collective_handles: list = []
|
||||
self._dist_ctx: Any = None
|
||||
|
||||
# Eager IPCQ install — ``init_process_group`` time. Mirrors NCCL
|
||||
# communicator creation: done once, reused across every subsequent
|
||||
# collective call on the same process group.
|
||||
self.ctx.install_ipcq(
|
||||
algorithm=self._merged["algorithm"],
|
||||
world_size_override=self._world_size,
|
||||
spec = self.ctx.spec or {}
|
||||
self._n_sips = int(spec.get("system", {}).get("sips", {}).get("count", 1))
|
||||
self._sip_topo = str(
|
||||
spec.get("system", {}).get("sips", {}).get("topology", "ring_1d")
|
||||
)
|
||||
cm = spec.get("sip", {}).get("cube_mesh", {})
|
||||
self._cube_w = int(cm.get("w", 4))
|
||||
self._cube_h = int(cm.get("h", 4))
|
||||
|
||||
# Resolve SIP topology dims for the kernel
|
||||
topo_map = getattr(self._algo_module, "TOPO_NAME_TO_KIND", None)
|
||||
if topo_map is not None:
|
||||
self._sip_topo_kind = topo_map.get(self._sip_topo, 0)
|
||||
else:
|
||||
self._sip_topo_kind = 0
|
||||
if self._sip_topo == "ring_1d":
|
||||
self._sip_topo_w, self._sip_topo_h = 0, 0
|
||||
else:
|
||||
side = int(round(math.sqrt(self._n_sips)))
|
||||
self._sip_topo_w, self._sip_topo_h = side, side
|
||||
|
||||
# IPCQ install: wire all pe0s across all cubes and SIPs
|
||||
engine = getattr(self.ctx, "engine", None)
|
||||
if engine is not None:
|
||||
from kernbench.ccl.sfr_config import configure_sfr_intercube_multisip
|
||||
configure_sfr_intercube_multisip(engine, spec, self._merged)
|
||||
|
||||
def _resolve_world_size(self) -> int:
|
||||
"""Derive world_size (priority: algorithm override > defaults > topology).
|
||||
|
||||
Topology derivation:
|
||||
sips × cubes_per_sip × pes_per_cube
|
||||
ADR-0024 D1: topology fallback is SIP count. Each rank represents one
|
||||
SIP (TP dimension). Intra-SIP parallelism is expressed via DPPolicy
|
||||
inside each worker and is independent of world_size.
|
||||
Explicit ``ccl.yaml`` override still respected — legacy "rank = flat
|
||||
PE index" tests use this path.
|
||||
"""
|
||||
if "world_size" in self._merged:
|
||||
return int(self._merged["world_size"])
|
||||
@@ -61,14 +86,7 @@ class AhbmCCLBackend:
|
||||
if "world_size" in defaults:
|
||||
return int(defaults["world_size"])
|
||||
spec = self.ctx.spec or {}
|
||||
sips = int(spec.get("system", {}).get("sips", {}).get("count", 1))
|
||||
cm = spec.get("sip", {}).get("cube_mesh", {})
|
||||
cubes_per_sip = int(cm.get("w", 1)) * int(cm.get("h", 1))
|
||||
pl = spec.get("cube", {}).get("pe_layout", {})
|
||||
corners = pl.get("corners", [])
|
||||
pe_per_corner = int(pl.get("pe_per_corner", 1))
|
||||
pes_per_cube = pe_per_corner * max(len(corners), 1)
|
||||
return sips * cubes_per_sip * pes_per_cube
|
||||
return int(spec.get("system", {}).get("sips", {}).get("count", 1))
|
||||
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
@@ -89,20 +107,48 @@ class AhbmCCLBackend:
|
||||
"with a DPPolicy first)"
|
||||
)
|
||||
shards = tensor._handle.shards
|
||||
if len(shards) != self._world_size:
|
||||
if not shards:
|
||||
raise RuntimeError(
|
||||
f"all_reduce tensor has {len(shards)} shards but the "
|
||||
f"ahbm backend was installed with world_size="
|
||||
f"{self._world_size}; adjust the tensor's DPPolicy or "
|
||||
"restart the process group"
|
||||
f"all_reduce tensor '{tensor.name}' has no shards"
|
||||
)
|
||||
n_elem = shards[0].nbytes // tensor.itemsize
|
||||
kernel_fn = self._algo_module.kernel
|
||||
kernel_args = self._algo_module.kernel_args(self._world_size, n_elem)
|
||||
self.ctx.launch(
|
||||
self._merged["algorithm"], kernel_fn, tensor, *kernel_args,
|
||||
|
||||
# Resolve sip_rank from the current greenlet's bound rank
|
||||
from greenlet import getcurrent as _gc
|
||||
g = _gc()
|
||||
dist_ctx = getattr(self, "_dist_ctx", None)
|
||||
if dist_ctx is not None:
|
||||
sip_rank = int(dist_ctx._rank_by_greenlet.get(g, 0))
|
||||
else:
|
||||
sip_rank = 0
|
||||
|
||||
extra_args = (
|
||||
sip_rank,
|
||||
self._sip_topo_kind,
|
||||
self._sip_topo_w,
|
||||
self._sip_topo_h,
|
||||
)
|
||||
|
||||
pending = self.ctx.launch(
|
||||
self._merged["algorithm"], kernel_fn, tensor,
|
||||
*kernel_args, *extra_args,
|
||||
_defer_wait=True,
|
||||
)
|
||||
from greenlet import getcurrent
|
||||
g = getcurrent()
|
||||
if g.parent is not None and not g.parent.dead:
|
||||
# Multi-greenlet mode: hand pending to the backend-level queue so
|
||||
# the main scheduler drains. Worker just yields.
|
||||
self._pending_collective_handles.extend(pending)
|
||||
g.parent.switch()
|
||||
# On resume, all pending handles have been drained by main.
|
||||
else:
|
||||
# Single-driver (no bench scheduler): drain inline.
|
||||
for h, _sip_id, meta in pending:
|
||||
self.ctx.wait(h, _meta=meta)
|
||||
|
||||
def barrier(self) -> None:
|
||||
# Single-driver model → no cross-process sync needed. Keeping the
|
||||
# method so ``dist.barrier()`` is callable (pytorch-compat surface).
|
||||
@@ -121,6 +167,11 @@ class DistributedContext:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._backend: AhbmCCLBackend | None = None
|
||||
# ADR-0024 D9: greenlet-local rank registry. Bench launcher calls
|
||||
# _bind_rank(g, rank) when spawning workers; get_rank() resolves the
|
||||
# current greenlet to its rank. Unbound greenlets fall back to 0 for
|
||||
# single-driver test compat.
|
||||
self._rank_by_greenlet: dict = {}
|
||||
|
||||
def init_process_group(
|
||||
self,
|
||||
@@ -146,6 +197,7 @@ class DistributedContext:
|
||||
"DistributedContext not bound to a RuntimeContext"
|
||||
)
|
||||
self._backend = AhbmCCLBackend(torch_ctx=ctx)
|
||||
self._backend._dist_ctx = self
|
||||
|
||||
def is_initialized(self) -> bool:
|
||||
return self._backend is not None
|
||||
@@ -155,9 +207,20 @@ class DistributedContext:
|
||||
return self._backend.world_size
|
||||
|
||||
def get_rank(self) -> int:
|
||||
# Single-driver kernbench: there is only one host rank.
|
||||
"""Return the rank bound to the current greenlet (default 0).
|
||||
|
||||
ADR-0024 D9: workers spawned by the bench launcher each get a rank
|
||||
registered via ``_bind_rank``. Callers outside any bound greenlet
|
||||
fall back to rank 0 for single-driver test compat.
|
||||
"""
|
||||
self._ensure_initialized()
|
||||
return 0
|
||||
from greenlet import getcurrent
|
||||
g = getcurrent()
|
||||
return int(self._rank_by_greenlet.get(g, 0))
|
||||
|
||||
def _bind_rank(self, g: Any, rank: int) -> None:
|
||||
"""Bind a greenlet to a rank so ``get_rank()`` returns it (ADR-0024 D9)."""
|
||||
self._rank_by_greenlet[g] = int(rank)
|
||||
|
||||
def get_backend(self) -> str:
|
||||
self._ensure_initialized()
|
||||
|
||||
@@ -0,0 +1,152 @@
|
||||
"""``torch.multiprocessing.spawn``-compatible namespace (ADR-0027 D1).
|
||||
|
||||
Real-PyTorch API *signature* parity only — execution model is a cooperative
|
||||
greenlet scheduler in a single Python process (D1.0). Non-goals: process
|
||||
isolation, independent address space, failure isolation, OS-level scheduler
|
||||
fairness, mp.Queue/Lock.
|
||||
|
||||
Attached to ``RuntimeContext`` as ``ctx.multiprocessing`` in
|
||||
``__post_init__`` (D1.3).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable
|
||||
|
||||
|
||||
class SpawnException(RuntimeError):
|
||||
"""Raised from ``_MultiprocessingNamespace.spawn`` on worker failure.
|
||||
|
||||
``errors`` contains only root-cause ranks — the rank(s) whose body
|
||||
raised. Sibling greenlets terminated via ``throw(SystemExit)`` during
|
||||
cleanup are NOT recorded (SystemExit does not satisfy ``except
|
||||
Exception`` in the entry wrapper).
|
||||
"""
|
||||
|
||||
def __init__(self, errors: dict[int, Exception]):
|
||||
self.errors = errors
|
||||
first = next(iter(errors.items()), None)
|
||||
msg = (
|
||||
f"spawn failed on ranks {sorted(errors.keys())}"
|
||||
+ (
|
||||
f": rank {first[0]} raised {first[1]!r}"
|
||||
if first is not None
|
||||
else ""
|
||||
)
|
||||
)
|
||||
super().__init__(msg)
|
||||
|
||||
|
||||
def _drain_pending(ctx: Any) -> None:
|
||||
"""Drain worker-wait + collective-pending queues in main context (D0.4/D0.5).
|
||||
|
||||
Loop-until-empty: runs until both queues are simultaneously empty. Safe
|
||||
under the current model where main-context ``ctx.wait`` never re-enqueues
|
||||
(D0.5 main-context non-reentrance invariant); also safe under future
|
||||
extensions where drain can add sub-handles (SimPy causality gives finite
|
||||
depth).
|
||||
"""
|
||||
distributed = getattr(ctx, "distributed", None)
|
||||
backend = getattr(distributed, "_backend", None) if distributed else None
|
||||
|
||||
def _collective_nonempty() -> bool:
|
||||
if backend is None:
|
||||
return False
|
||||
pending = getattr(backend, "_pending_collective_handles", None)
|
||||
return bool(pending)
|
||||
|
||||
while ctx._pending_worker_waits or _collective_nonempty():
|
||||
# (a) Worker-driven waits (D0.1). FIFO.
|
||||
while ctx._pending_worker_waits:
|
||||
h = ctx._pending_worker_waits.pop(0)
|
||||
if h not in ctx._completed:
|
||||
wait_fn = getattr(ctx.engine, "wait", None)
|
||||
if wait_fn is not None:
|
||||
wait_fn(h)
|
||||
# Populate _completed so fast-path in ctx.wait short-circuits
|
||||
# on the return leg.
|
||||
ctx._completed.add(h)
|
||||
# (b) Collective backend queue (ADR-0024 D7 + D0.4-(2)).
|
||||
if backend is not None:
|
||||
pending_list = getattr(backend, "_pending_collective_handles", None)
|
||||
if pending_list is not None:
|
||||
while pending_list:
|
||||
h, _sip_id, meta = pending_list.pop(0)
|
||||
# Main context: ctx.wait drives engine directly and does
|
||||
# NOT re-enqueue (D0.5 invariant).
|
||||
ctx.wait(h, _meta=meta)
|
||||
|
||||
|
||||
class _MultiprocessingNamespace:
|
||||
"""torch.multiprocessing-compat facade bound to a RuntimeContext."""
|
||||
|
||||
def __init__(self, ctx: Any) -> None:
|
||||
self._ctx = ctx
|
||||
|
||||
def spawn(
|
||||
self,
|
||||
fn: Callable,
|
||||
args: tuple = (),
|
||||
nprocs: int = 1,
|
||||
join: bool = True,
|
||||
) -> None:
|
||||
"""Spawn ``nprocs`` worker greenlets, each calling ``fn(rank, *args)``.
|
||||
|
||||
Mirrors ``torch.multiprocessing.spawn`` signature (minus ``daemon``).
|
||||
Runs the D0.4 round-robin scheduler loop until all workers finish,
|
||||
draining pending queues between rounds.
|
||||
"""
|
||||
from greenlet import greenlet
|
||||
|
||||
ctx = self._ctx
|
||||
dist = ctx.distributed
|
||||
gs: list = []
|
||||
errors: dict[int, Exception] = {}
|
||||
|
||||
for rank in range(nprocs):
|
||||
def _entry(r: int = rank) -> None:
|
||||
try:
|
||||
fn(r, *args)
|
||||
except Exception as e:
|
||||
errors[r] = e
|
||||
raise
|
||||
|
||||
g = greenlet(_entry)
|
||||
if dist is not None and hasattr(dist, "_bind_rank"):
|
||||
dist._bind_rank(g, rank)
|
||||
gs.append(g)
|
||||
|
||||
try:
|
||||
while True:
|
||||
alive = [g for g in gs if not g.dead]
|
||||
if not alive:
|
||||
break
|
||||
for g in alive:
|
||||
if not g.dead:
|
||||
g.switch()
|
||||
_drain_pending(ctx)
|
||||
except Exception as outer:
|
||||
# D0.4-(4) sibling cleanup. Abort live greenlets, clear state.
|
||||
for other in gs:
|
||||
if not other.dead:
|
||||
try:
|
||||
other.throw(SystemExit)
|
||||
except BaseException:
|
||||
# SystemExit inherits BaseException; greenlet.throw
|
||||
# re-raises in caller if target doesn't catch it.
|
||||
# Silent — we're already in cleanup.
|
||||
pass
|
||||
backend = getattr(dist, "_backend", None)
|
||||
if backend is not None:
|
||||
if hasattr(backend, "_barrier") and hasattr(backend._barrier, "reset"):
|
||||
try:
|
||||
backend._barrier.reset()
|
||||
except Exception:
|
||||
pass
|
||||
pending_collective = getattr(
|
||||
backend, "_pending_collective_handles", None,
|
||||
)
|
||||
if pending_collective is not None:
|
||||
pending_collective.clear()
|
||||
ctx._pending_worker_waits.clear()
|
||||
raise SpawnException(errors) from outer
|
||||
# join=True: we already waited for all workers above.
|
||||
@@ -66,13 +66,64 @@ def _numpy_dtype(dtype: str) -> np.dtype:
|
||||
return np.dtype(_NUMPY_DTYPE.get(dtype, np.float16))
|
||||
|
||||
|
||||
# ADR-0027 T5.g: closed-set registry of host-read barrier entry-points.
|
||||
# Any new Tensor API with host-observable read semantics must be added here
|
||||
# AND implement the barrier call. Code review + this registry keep the set
|
||||
# consistent (Python introspection-based auto-detection is a non-goal).
|
||||
# Note on ``copy_``: the source read is barriered via ``source.numpy()``.
|
||||
# A target-side write barrier was specified in an earlier revision of
|
||||
# ADR-0027 D0.5 but is intentionally not applied (global-pending target
|
||||
# barrier can prematurely drain cross-rank collectives → deadlock).
|
||||
_HOST_READ_BARRIERS: frozenset[str] = frozenset({
|
||||
"numpy",
|
||||
"data",
|
||||
"__getitem__",
|
||||
"__repr__",
|
||||
"copy_", # source-side via source.numpy(); target-side not barriered
|
||||
})
|
||||
|
||||
|
||||
def _host_read_barrier(tensor: "Tensor") -> None:
|
||||
"""ADR-0027 D0.5: drain pending worker-wait queue before a host-observable
|
||||
read/write.
|
||||
|
||||
Scope: the barrier yields to main when ``ctx._pending_worker_waits`` is
|
||||
non-empty AND the caller is a worker greenlet. Collective pending
|
||||
(``backend._pending_collective_handles``) is **deliberately excluded**
|
||||
from this check — collective handles represent cross-rank protocol that
|
||||
must be drained only at scheduler synchronisation points (all workers
|
||||
yielded). A collective's own yield (inside ``all_reduce``) already
|
||||
ensures that once the collective call returns to the worker, post-drain
|
||||
values are visible, so subsequent host reads see materialised data
|
||||
without needing to trigger drain themselves. Including collective
|
||||
pending here would cause an unrelated rank's barrier to prematurely
|
||||
request drain of a cross-rank operation → deadlock.
|
||||
|
||||
No-op when called from main context or when the worker-wait queue is
|
||||
empty (fast-path avoids needless context switches).
|
||||
"""
|
||||
ctx = None
|
||||
if tensor._ctx_ref is not None:
|
||||
ctx = tensor._ctx_ref()
|
||||
if ctx is None:
|
||||
return
|
||||
worker_pending = getattr(ctx, "_pending_worker_waits", None)
|
||||
if not worker_pending:
|
||||
return # fast-path
|
||||
from greenlet import getcurrent
|
||||
g = getcurrent()
|
||||
if g.parent is None or g.parent.dead:
|
||||
return # main context: caller drains directly when needed
|
||||
g.parent.switch()
|
||||
|
||||
|
||||
def deploy_tensor(
|
||||
*,
|
||||
name: str,
|
||||
shape: tuple[int, ...],
|
||||
dtype: str,
|
||||
placement: list[ShardSpec],
|
||||
allocators: dict[int, PEMemAllocator],
|
||||
allocators: dict[tuple[int, int, int], PEMemAllocator],
|
||||
mem_kind: Literal["hbm", "tcm"] = "hbm",
|
||||
va_allocator=None,
|
||||
) -> TensorHandle:
|
||||
@@ -86,15 +137,15 @@ def deploy_tensor(
|
||||
|
||||
shards: list[TensorShard] = []
|
||||
for spec in placement:
|
||||
alloc = allocators[spec.pe_index]
|
||||
alloc = allocators[(spec.sip, spec.cube, spec.pe)]
|
||||
if mem_kind == "hbm":
|
||||
pa = alloc.alloc_hbm(spec.nbytes)
|
||||
else:
|
||||
pa = alloc.alloc_tcm(spec.nbytes)
|
||||
shards.append(TensorShard(
|
||||
sip=alloc._sip_id,
|
||||
cube=alloc._cube_id,
|
||||
pe=alloc._pe_id,
|
||||
sip=spec.sip,
|
||||
cube=spec.cube,
|
||||
pe=spec.pe,
|
||||
pa=pa.encode(),
|
||||
nbytes=spec.nbytes,
|
||||
offset_bytes=spec.offset_bytes,
|
||||
@@ -217,7 +268,9 @@ class Tensor:
|
||||
"""Read a shard-aligned slice. Returns a numpy array.
|
||||
|
||||
Mirrors ``torch.Tensor.__getitem__`` for the shard-aligned case.
|
||||
ADR-0027 D0.5: host-read barrier.
|
||||
"""
|
||||
_host_read_barrier(self)
|
||||
start, stop = self._resolve_shard_index(key)
|
||||
shard = self._shard_for_range(start, stop)
|
||||
if self._memory_store is None:
|
||||
@@ -272,6 +325,8 @@ class Tensor:
|
||||
def __repr__(self) -> str:
|
||||
parts = [f"tensor(name={self.name}, shape={self.shape}, dtype={self.dtype}"]
|
||||
if self._memory_store is not None and self._handle is not None:
|
||||
# ADR-0027 D0.5: barrier on data-containing repr path.
|
||||
_host_read_barrier(self)
|
||||
arr = self.data
|
||||
parts.append(f", mean={float(arr.mean()):.4g}, norm={float(np.linalg.norm(arr)):.4g}")
|
||||
else:
|
||||
@@ -308,7 +363,11 @@ class Tensor:
|
||||
Mirrors ``torch.Tensor.numpy()``. In kernbench, sharded tensors are
|
||||
gathered into a single full-shape ndarray according to each shard's
|
||||
``offset_bytes`` / ``nbytes`` range.
|
||||
|
||||
ADR-0027 D0.5: acts as a host-read barrier — drains pending waits +
|
||||
collective handles before reading, ensuring post-drain values.
|
||||
"""
|
||||
_host_read_barrier(self)
|
||||
np_dtype = _numpy_dtype(self.dtype)
|
||||
# Host-side tensor (created via torch.from_numpy) has no shards.
|
||||
if self._host_buffer is not None:
|
||||
@@ -340,6 +399,12 @@ class Tensor:
|
||||
re-scattered into self's shard layout.
|
||||
|
||||
Shapes must match. Returns self.
|
||||
|
||||
ADR-0027 D0.5: source-side read barrier is triggered inside
|
||||
``source.numpy()``. Target-side write barrier is not applied here
|
||||
because it would require cross-rank coordination when other ranks
|
||||
have pending collectives (see _host_read_barrier docstring on
|
||||
collective pending being cross-rank).
|
||||
"""
|
||||
if self._handle is None or self._memory_store is None:
|
||||
raise RuntimeError(
|
||||
@@ -394,7 +459,8 @@ class Tensor:
|
||||
) -> Tensor:
|
||||
"""Set DP placement metadata (like torch.Tensor.to())."""
|
||||
if placement is None:
|
||||
placement = [ShardSpec(pe_index=0, offset_bytes=0, nbytes=self.nbytes)]
|
||||
placement = [ShardSpec(sip=0, cube=0, pe=0,
|
||||
offset_bytes=0, nbytes=self.nbytes)]
|
||||
self._dp_metadata = DPMetadata(
|
||||
placement=placement, dp_policy=dp_policy,
|
||||
sip=sip, cube=cube, target_pe=target_pe,
|
||||
|
||||
@@ -101,12 +101,19 @@ class DataExecutor:
|
||||
p = op.params
|
||||
if "src_a_addr" not in p:
|
||||
return # composite record without full params
|
||||
space = p.get("addr_space", "tcm")
|
||||
default_space = p.get("addr_space", "tcm")
|
||||
# ADR-0027: per-operand + output spaces (fall back to single space
|
||||
# for legacy records without explicit space keys).
|
||||
src_a_space = p.get("src_a_space", default_space)
|
||||
src_b_space = p.get("src_b_space", default_space)
|
||||
dst_space = p.get("dst_space", default_space)
|
||||
dtype_in = p.get("dtype_in", "f16")
|
||||
dtype_out = p.get("dtype_out", dtype_in)
|
||||
|
||||
a = self.store.read(space, p["src_a_addr"], shape=p.get("shape_a"), dtype=dtype_in)
|
||||
b = self.store.read(space, p["src_b_addr"], shape=p.get("shape_b"), dtype=dtype_in)
|
||||
a = self.store.read(src_a_space, p["src_a_addr"],
|
||||
shape=p.get("shape_a"), dtype=dtype_in)
|
||||
b = self.store.read(src_b_space, p["src_b_addr"],
|
||||
shape=p.get("shape_b"), dtype=dtype_in)
|
||||
|
||||
# Compute in higher precision if specified
|
||||
dtype_acc = p.get("dtype_acc", "f32")
|
||||
@@ -114,7 +121,7 @@ class DataExecutor:
|
||||
b_f = b.astype(_resolve_dtype(dtype_acc))
|
||||
result = np.matmul(a_f, b_f).astype(_resolve_dtype(dtype_out))
|
||||
|
||||
self.store.write(space, p["dst_addr"], result)
|
||||
self.store.write(dst_space, p["dst_addr"], result)
|
||||
|
||||
def _execute_math(self, op: OpRecord) -> None:
|
||||
"""Execute math op: unary, binary, or reduction."""
|
||||
|
||||
@@ -79,6 +79,14 @@ class OpLogger:
|
||||
snaps.append(None)
|
||||
params["input_snapshots"] = snaps
|
||||
elif op_name == "dma_write":
|
||||
# ADR-0027 fix: only snapshot HBM sources. TCM (PE scratch)
|
||||
# sources are repopulated by Phase 2 math/gemm replay —
|
||||
# capturing a Phase-1-time snapshot here would pick up stale
|
||||
# data from a PRIOR kernel's Phase 2 output that aliased the
|
||||
# same scratch address, causing the later kernel's replay
|
||||
# to write that stale value instead of the fresh math
|
||||
# result. See ADR-0027 postmortem (TP gemm → all_reduce).
|
||||
if params.get("src_space") == "hbm":
|
||||
try:
|
||||
arr = self._memory_store.read(
|
||||
params["src_space"], params["src_addr"],
|
||||
@@ -167,6 +175,13 @@ def _extract_op_info(msg: Any) -> tuple[str, str, dict[str, Any]]:
|
||||
"dtype_in": msg.a.dtype,
|
||||
"dtype_out": msg.out.dtype,
|
||||
"m": msg.m, "k": msg.k, "n": msg.n,
|
||||
# ADR-0027: preserve per-operand + output MemoryStore spaces so
|
||||
# Phase 2 replay can resolve HBM-resident operands (e.g. tl.load
|
||||
# results keep space="hbm"). Absent → DataExecutor falls back
|
||||
# to the legacy single-space mode via ``addr_space``.
|
||||
"src_a_space": getattr(msg.a, "space", "tcm"),
|
||||
"src_b_space": getattr(msg.b, "space", "tcm"),
|
||||
"dst_space": getattr(msg.out, "space", "tcm"),
|
||||
}
|
||||
if isinstance(msg, MathCmd):
|
||||
return "math", msg.op, {
|
||||
@@ -181,10 +196,27 @@ def _extract_op_info(msg: Any) -> tuple[str, str, dict[str, Any]]:
|
||||
"axis": msg.axis,
|
||||
}
|
||||
if isinstance(msg, CompositeCmd):
|
||||
return "gemm" if msg.op == "gemm" else "math", f"composite_{msg.op}", {
|
||||
params: dict[str, Any] = {
|
||||
"op": msg.op,
|
||||
"out_addr": msg.out_addr,
|
||||
"out_nbytes": msg.out_nbytes,
|
||||
}
|
||||
# ADR-0027: preserve operand info so Phase 2 DataExecutor can replay
|
||||
# the composite's numerical effect (treat it like a GemmCmd).
|
||||
if msg.op == "gemm" and msg.a is not None and msg.b is not None:
|
||||
params.update({
|
||||
"src_a_addr": msg.a.addr,
|
||||
"src_b_addr": msg.b.addr,
|
||||
"shape_a": msg.a.shape,
|
||||
"shape_b": msg.b.shape,
|
||||
"dtype_in": msg.a.dtype,
|
||||
"dtype_out": msg.a.dtype,
|
||||
"src_a_space": getattr(msg.a, "space", "hbm"),
|
||||
"src_b_space": getattr(msg.b, "space", "hbm"),
|
||||
"dst_space": "hbm",
|
||||
# dst_addr alias so DataExecutor._execute_gemm picks it up.
|
||||
"dst_addr": msg.out_addr,
|
||||
})
|
||||
return "gemm" if msg.op == "gemm" else "math", f"composite_{msg.op}", params
|
||||
# Fallback for unknown data_op messages
|
||||
return "unknown", type(msg).__name__, {}
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
"""kernbench.tp — Megatron-style Tensor Parallelism (ADR-0027).
|
||||
|
||||
Public API re-exports.
|
||||
"""
|
||||
from kernbench.tp.layers import (
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from kernbench.tp.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
initialize_model_parallel,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ColumnParallelLinear",
|
||||
"RowParallelLinear",
|
||||
"get_tensor_model_parallel_rank",
|
||||
"get_tensor_model_parallel_world_size",
|
||||
"initialize_model_parallel",
|
||||
]
|
||||
@@ -0,0 +1,23 @@
|
||||
"""Kernel used by ``kernbench.tp`` layers (ADR-0027 D4/D5).
|
||||
|
||||
Intentionally self-contained inside the ``tp`` package — the ``tp`` package
|
||||
must not import from ``benches/``. Future work: move to a shared
|
||||
``kernbench.kernels`` module so benches and TP can share.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def _gemm_kernel(a_ptr, b_ptr, out_ptr, M, K, N, tl, DTYPE: str = "f16") -> None:
|
||||
"""Single-PE GEMM: out = a @ b via load → dot → store.
|
||||
|
||||
Uses the ``tl.load + tl.dot + tl.store`` path. Unlike ``tl.composite``
|
||||
(which is absorbed by the PE scheduler into TileTokens that don't reach
|
||||
the op_log), this path emits explicit ``DmaReadCmd`` / ``GemmCmd`` /
|
||||
``DmaWriteCmd`` records, which DataExecutor replays numerically in
|
||||
Phase 2.
|
||||
"""
|
||||
M, K, N = int(M), int(K), int(N)
|
||||
a = tl.load(int(a_ptr), shape=(M, K), dtype=DTYPE)
|
||||
b = tl.load(int(b_ptr), shape=(K, N), dtype=DTYPE)
|
||||
out = tl.dot(a, b)
|
||||
tl.store(int(out_ptr), out)
|
||||
@@ -0,0 +1,150 @@
|
||||
"""Megatron-style parallel layers (ADR-0027 D4/D5).
|
||||
|
||||
- ``ColumnParallelLinear``: weight's out_features axis split across TP ranks.
|
||||
forward(x) is local gemm; no collective.
|
||||
- ``RowParallelLinear``: weight's in_features axis split across TP ranks.
|
||||
forward(x) ends with ``dist.all_reduce`` to sum partial products.
|
||||
|
||||
Both layers use the intra-device ``DPPolicy`` (ADR-0026). TP shard
|
||||
ownership is determined by ``torch.ahbm.set_device(rank)`` (ADR-0024 D10).
|
||||
|
||||
Yield-safety contract (ADR-0027 D4/D5): every forward path contains at
|
||||
least one ``ctx.wait`` (via ``torch.launch``) or one collective; this
|
||||
keeps the scheduler loop making progress.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from kernbench.policy.placement.dp import DPPolicy
|
||||
from kernbench.tp.kernels import _gemm_kernel
|
||||
from kernbench.tp.parallel_state import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
|
||||
|
||||
class ColumnParallelLinear:
|
||||
"""Weight's K (out_features) axis distributed across TP ranks.
|
||||
|
||||
forward(x):
|
||||
x: (M, N) — full-replicated across ranks
|
||||
W_k: (N, K / world_size) — this rank's slice (on its SIP)
|
||||
y_k = x @ W_k → (M, K / world_size)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = False,
|
||||
dtype: str = "f16",
|
||||
torch: Any = None,
|
||||
) -> None:
|
||||
if torch is None:
|
||||
raise TypeError("ColumnParallelLinear requires torch=<RuntimeContext>")
|
||||
ws = get_tensor_model_parallel_world_size()
|
||||
if out_features % ws != 0:
|
||||
raise ValueError(
|
||||
f"out_features ({out_features}) must be divisible by TP world "
|
||||
f"size ({ws})"
|
||||
)
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.k_local = out_features // ws
|
||||
self.dtype = dtype
|
||||
self._torch = torch
|
||||
# Per-rank weight slice. ``set_device(rank)`` (ADR-0024 D10) places
|
||||
# it on SIP ``rank``. Intra-SIP layout comes from DPPolicy (ADR-0026).
|
||||
self.weight = torch.zeros(
|
||||
(in_features, self.k_local),
|
||||
dtype=dtype,
|
||||
dp=DPPolicy(cube="replicate", pe="replicate",
|
||||
num_cubes=1, num_pes=1),
|
||||
name="col_parallel_w",
|
||||
)
|
||||
# Bias omitted in initial scope (ADR-0027 D9).
|
||||
self.bias = None
|
||||
if bias:
|
||||
raise NotImplementedError(
|
||||
"bias=True is deferred (ADR-0027 D9 initial scope)"
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
M = int(x.shape[0])
|
||||
out = self._torch.empty(
|
||||
(M, self.k_local),
|
||||
dtype=x.dtype,
|
||||
dp=DPPolicy(cube="replicate", pe="replicate",
|
||||
num_cubes=1, num_pes=1),
|
||||
name="col_parallel_out",
|
||||
)
|
||||
self._torch.launch(
|
||||
"col_parallel_gemm",
|
||||
_gemm_kernel,
|
||||
x, self.weight, out,
|
||||
M, self.in_features, self.k_local,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
class RowParallelLinear:
|
||||
"""Weight's N (in_features) axis distributed across TP ranks.
|
||||
|
||||
forward(x):
|
||||
x: (M, N / world_size) — rank-local slice (ColumnParallel output)
|
||||
W_k: (N / world_size, K) — this rank's slice
|
||||
y_k = x @ W_k → (M, K) — partial sum
|
||||
y = all_reduce(y_k, op="sum") → (M, K) on every rank
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = False,
|
||||
dtype: str = "f16",
|
||||
torch: Any = None,
|
||||
) -> None:
|
||||
if torch is None:
|
||||
raise TypeError("RowParallelLinear requires torch=<RuntimeContext>")
|
||||
ws = get_tensor_model_parallel_world_size()
|
||||
if in_features % ws != 0:
|
||||
raise ValueError(
|
||||
f"in_features ({in_features}) must be divisible by TP world "
|
||||
f"size ({ws})"
|
||||
)
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.n_local = in_features // ws
|
||||
self.dtype = dtype
|
||||
self._torch = torch
|
||||
self.weight = torch.zeros(
|
||||
(self.n_local, out_features),
|
||||
dtype=dtype,
|
||||
dp=DPPolicy(cube="replicate", pe="replicate",
|
||||
num_cubes=1, num_pes=1),
|
||||
name="row_parallel_w",
|
||||
)
|
||||
self.bias = None
|
||||
if bias:
|
||||
raise NotImplementedError(
|
||||
"bias=True is deferred (ADR-0027 D9 initial scope)"
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
M = int(x.shape[0])
|
||||
y_partial = self._torch.empty(
|
||||
(M, self.out_features),
|
||||
dtype=x.dtype,
|
||||
dp=DPPolicy(cube="replicate", pe="replicate",
|
||||
num_cubes=1, num_pes=1),
|
||||
name="row_parallel_partial",
|
||||
)
|
||||
self._torch.launch(
|
||||
"row_parallel_gemm",
|
||||
_gemm_kernel,
|
||||
x, self.weight, y_partial,
|
||||
M, self.n_local, self.out_features,
|
||||
)
|
||||
self._torch.distributed.all_reduce(y_partial, op="sum")
|
||||
return y_partial
|
||||
@@ -0,0 +1,5 @@
|
||||
"""Forward/backward mappings stub (ADR-0027 — future backward work).
|
||||
|
||||
Inference-only initial scope. Backward hooks land when training simulation
|
||||
arrives.
|
||||
"""
|
||||
@@ -0,0 +1,83 @@
|
||||
"""TP group state (ADR-0027 D3).
|
||||
|
||||
Single global TP group. Initial scope: TP size == world_size (pure TP;
|
||||
mixed DP+TP is future work).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
_TP_WORLD_SIZE: int | None = None
|
||||
|
||||
|
||||
def initialize_model_parallel(tensor_model_parallel_size: int) -> None:
|
||||
"""Initialize the TP process group.
|
||||
|
||||
Must be called after ``torch.distributed.init_process_group``.
|
||||
Only ``tensor_model_parallel_size == world_size`` is supported in the
|
||||
initial scope.
|
||||
"""
|
||||
global _TP_WORLD_SIZE
|
||||
# Import here to avoid cycle when tp is imported before a ctx exists.
|
||||
_ws = _current_world_size()
|
||||
if tensor_model_parallel_size != _ws:
|
||||
raise NotImplementedError(
|
||||
f"Only TP == world_size supported; got TP={tensor_model_parallel_size}, "
|
||||
f"world_size={_ws}"
|
||||
)
|
||||
_TP_WORLD_SIZE = tensor_model_parallel_size
|
||||
|
||||
|
||||
def get_tensor_model_parallel_world_size() -> int:
|
||||
"""Return the TP group's world size.
|
||||
|
||||
Raises if not initialised — callers must call
|
||||
:func:`initialize_model_parallel` first.
|
||||
"""
|
||||
if _TP_WORLD_SIZE is None:
|
||||
raise RuntimeError(
|
||||
"TP group not initialised; call initialize_model_parallel() first"
|
||||
)
|
||||
return _TP_WORLD_SIZE
|
||||
|
||||
|
||||
def get_tensor_model_parallel_rank() -> int:
|
||||
"""Return this worker's rank within the TP group.
|
||||
|
||||
Delegates to the greenlet-local rank registered by the spawn launcher
|
||||
(ADR-0024 D9 via ``torch.distributed.get_rank``).
|
||||
"""
|
||||
# Resolve via the global torch.distributed facade on the active ctx.
|
||||
return _current_rank()
|
||||
|
||||
|
||||
def _reset_for_tests() -> None:
|
||||
"""Clear _TP_WORLD_SIZE so ordering-sensitive tests can re-init."""
|
||||
global _TP_WORLD_SIZE
|
||||
_TP_WORLD_SIZE = None
|
||||
|
||||
|
||||
# ── helpers (resolve current ctx) ────────────────────────────────────
|
||||
|
||||
|
||||
def _current_ctx():
|
||||
"""Best-effort resolution of the currently-active RuntimeContext.
|
||||
|
||||
In KernBench, the ``ctx`` is passed as the ``torch`` positional in
|
||||
bench/worker code. Since parallel_state is a module-global helper,
|
||||
we look it up via a weak registry maintained by RuntimeContext.
|
||||
"""
|
||||
from kernbench.runtime_api.context import _get_active_context
|
||||
ctx = _get_active_context()
|
||||
if ctx is None:
|
||||
raise RuntimeError(
|
||||
"No active RuntimeContext; kernbench.tp requires one "
|
||||
"(call init_process_group / spawn under a live ctx)"
|
||||
)
|
||||
return ctx
|
||||
|
||||
|
||||
def _current_world_size() -> int:
|
||||
return _current_ctx().distributed.get_world_size()
|
||||
|
||||
|
||||
def _current_rank() -> int:
|
||||
return _current_ctx().distributed.get_rank()
|
||||
@@ -0,0 +1,34 @@
|
||||
"""TP primitive ops (ADR-0027 D6).
|
||||
|
||||
``copy_to_tp_region`` / ``reduce_from_tp_region`` are forward-only in the
|
||||
initial scope (backward pass is future work). ``scatter`` / ``gather`` are
|
||||
not implemented — they require an all-gather kernel that is not yet
|
||||
available in KernBench (see ADR-0027 D9).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def copy_to_tp_region(x: Any) -> Any:
|
||||
"""Forward: identity. Backward: all-reduce. (Training is future.)"""
|
||||
return x
|
||||
|
||||
|
||||
def reduce_from_tp_region(x: Any, torch: Any) -> Any:
|
||||
"""Forward: all-reduce. Backward: identity."""
|
||||
torch.distributed.all_reduce(x, op="sum")
|
||||
return x
|
||||
|
||||
|
||||
def scatter_to_tp_region(x: Any) -> Any:
|
||||
raise NotImplementedError(
|
||||
"scatter_to_tp_region deferred — caller should create the sharded "
|
||||
"tensor directly (ADR-0027 D9)"
|
||||
)
|
||||
|
||||
|
||||
def gather_from_tp_region(x: Any) -> Any:
|
||||
raise NotImplementedError(
|
||||
"gather_from_tp_region deferred — requires all-gather kernel (ADR-0027 D9)"
|
||||
)
|
||||
@@ -0,0 +1,239 @@
|
||||
"""ADR-0026 Phase 1 tests: DPPolicy intra-device only + ShardSpec structural.
|
||||
|
||||
These tests encode the contract from ADR-0026:
|
||||
|
||||
- DPPolicy no longer accepts ``sip`` or ``num_sips`` kwargs (TypeError).
|
||||
- ShardSpec carries structural ``(sip, cube, pe)`` coordinates; the old flat
|
||||
``pe_index`` field/property is fully removed (AttributeError).
|
||||
- ``resolve_dp_policy(..., target_sip=N)`` stamps every returned ShardSpec
|
||||
with ``sip=N``; cube and pe fields are local.
|
||||
- ``RuntimeContext._allocators`` is keyed by ``(sip, cube, pe)`` tuples.
|
||||
|
||||
Phase 1: production code is unchanged → these tests SHOULD FAIL until the
|
||||
Phase 2 diff lands. Phase 2 makes all of them pass.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator
|
||||
from kernbench.policy.placement.dp import DPPolicy, ShardSpec, resolve_dp_policy
|
||||
from kernbench.runtime_api.tensor import deploy_tensor
|
||||
|
||||
|
||||
# ── D1: DPPolicy no longer accepts sip / num_sips ─────────────────────
|
||||
|
||||
|
||||
def test_dppolicy_rejects_sip_kwarg():
|
||||
"""DPPolicy(sip=...) must raise TypeError after field removal."""
|
||||
with pytest.raises(TypeError):
|
||||
DPPolicy(sip="column_wise", cube="replicate", pe="replicate")
|
||||
|
||||
|
||||
def test_dppolicy_rejects_num_sips_kwarg():
|
||||
"""DPPolicy(num_sips=...) must raise TypeError after field removal."""
|
||||
with pytest.raises(TypeError):
|
||||
DPPolicy(cube="replicate", pe="replicate", num_sips=2)
|
||||
|
||||
|
||||
def test_dppolicy_accepts_only_intra_device_fields():
|
||||
"""Intra-device fields still work: cube, pe, num_cubes, num_pes."""
|
||||
dp = DPPolicy(cube="column_wise", pe="column_wise",
|
||||
num_cubes=2, num_pes=4)
|
||||
assert dp.cube == "column_wise"
|
||||
assert dp.pe == "column_wise"
|
||||
assert dp.num_cubes == 2
|
||||
assert dp.num_pes == 4
|
||||
# No sip / num_sips attributes — even reading them must fail.
|
||||
assert not hasattr(dp, "sip"), "DPPolicy.sip must be removed"
|
||||
assert not hasattr(dp, "num_sips"), "DPPolicy.num_sips must be removed"
|
||||
|
||||
|
||||
# ── D2: ShardSpec structural coords, no pe_index ──────────────────────
|
||||
|
||||
|
||||
def test_shardspec_has_structural_coords():
|
||||
"""ShardSpec constructs from (sip, cube, pe, offset_bytes, nbytes)."""
|
||||
s = ShardSpec(sip=1, cube=2, pe=3, offset_bytes=128, nbytes=64)
|
||||
assert s.sip == 1
|
||||
assert s.cube == 2
|
||||
assert s.pe == 3
|
||||
assert s.offset_bytes == 128
|
||||
assert s.nbytes == 64
|
||||
|
||||
|
||||
def test_shardspec_has_no_pe_index_attr():
|
||||
"""Flat pe_index must be fully removed — no field, no property."""
|
||||
s = ShardSpec(sip=0, cube=0, pe=0, offset_bytes=0, nbytes=8)
|
||||
with pytest.raises(AttributeError):
|
||||
_ = s.pe_index # noqa: F841
|
||||
|
||||
|
||||
def test_shardspec_rejects_pe_index_kwarg():
|
||||
"""ShardSpec(pe_index=...) must raise TypeError."""
|
||||
with pytest.raises(TypeError):
|
||||
ShardSpec(pe_index=0, offset_bytes=0, nbytes=8) # type: ignore[call-arg]
|
||||
|
||||
|
||||
# ── D3: resolve_dp_policy(target_sip=...) structural semantics ────────
|
||||
|
||||
|
||||
def test_resolve_dp_policy_target_sip_stamps_shards():
|
||||
"""All returned shards must carry sip == target_sip; cube/pe local."""
|
||||
dp = DPPolicy(cube="column_wise", pe="column_wise")
|
||||
shards = resolve_dp_policy(
|
||||
dp, shape=(4, 32), itemsize=2,
|
||||
num_pe=4, num_cubes=2, target_sip=1,
|
||||
)
|
||||
assert len(shards) == 2 * 4
|
||||
assert all(s.sip == 1 for s in shards)
|
||||
assert all(0 <= s.cube < 2 for s in shards)
|
||||
assert all(0 <= s.pe < 4 for s in shards)
|
||||
|
||||
|
||||
def test_resolve_dp_policy_target_sip_differ_only_in_sip():
|
||||
"""Same policy + dims on two SIPs → shards identical except .sip."""
|
||||
dp = DPPolicy(cube="replicate", pe="column_wise")
|
||||
shards_0 = resolve_dp_policy(
|
||||
dp, shape=(4, 32), itemsize=2,
|
||||
num_pe=4, num_cubes=1, target_sip=0,
|
||||
)
|
||||
shards_1 = resolve_dp_policy(
|
||||
dp, shape=(4, 32), itemsize=2,
|
||||
num_pe=4, num_cubes=1, target_sip=1,
|
||||
)
|
||||
assert len(shards_0) == len(shards_1)
|
||||
for a, b in zip(shards_0, shards_1):
|
||||
assert a.sip == 0 and b.sip == 1
|
||||
assert a.cube == b.cube
|
||||
assert a.pe == b.pe
|
||||
assert a.offset_bytes == b.offset_bytes
|
||||
assert a.nbytes == b.nbytes
|
||||
|
||||
|
||||
def test_resolve_dp_policy_no_num_sips_param():
|
||||
"""resolve_dp_policy must not accept num_sips anymore.
|
||||
|
||||
Post-Phase-2 signature drops ``num_sips`` (DPPolicy no longer crosses
|
||||
SIP boundaries) and adds required ``target_sip``. Calling with
|
||||
``num_sips=...`` must raise TypeError (unexpected keyword argument).
|
||||
"""
|
||||
dp = DPPolicy(cube="replicate", pe="replicate")
|
||||
with pytest.raises(TypeError, match="num_sips"):
|
||||
resolve_dp_policy(
|
||||
dp, shape=(4, 8), itemsize=2,
|
||||
num_pe=1, num_cubes=1, num_sips=2, # type: ignore[call-arg]
|
||||
)
|
||||
|
||||
|
||||
# ── D5: Allocator dict keyed by (sip, cube, pe) tuples ────────────────
|
||||
|
||||
|
||||
_MB = 1 << 20
|
||||
_GB = 1 << 30
|
||||
|
||||
_CFG = AddressConfig(
|
||||
sip_count=2,
|
||||
cubes_per_sip=2,
|
||||
pes_per_cube=4,
|
||||
hbm_bytes_per_cube=_GB,
|
||||
hbm_slices_per_cube=4,
|
||||
tcm_bytes_per_pe=_MB,
|
||||
tcm_scheduler_reserved_bytes=0,
|
||||
sram_bytes_per_cube=_MB,
|
||||
)
|
||||
|
||||
|
||||
def _make_tuple_allocators(
|
||||
num_sips: int = 1, num_cubes: int = 1, num_pe: int = 4,
|
||||
) -> dict[tuple[int, int, int], PEMemAllocator]:
|
||||
return {
|
||||
(s, c, p): PEMemAllocator(
|
||||
rack_id=0, sip_id=s, cube_id=c, pe_id=p, cfg=_CFG,
|
||||
)
|
||||
for s in range(num_sips)
|
||||
for c in range(num_cubes)
|
||||
for p in range(num_pe)
|
||||
}
|
||||
|
||||
|
||||
def test_deploy_tensor_uses_tuple_lookup():
|
||||
"""deploy_tensor(allocators={(sip,cube,pe): alloc, ...}) succeeds."""
|
||||
dp = DPPolicy(cube="replicate", pe="column_wise")
|
||||
placement = resolve_dp_policy(
|
||||
dp, shape=(4, 16), itemsize=2,
|
||||
num_pe=4, num_cubes=1, target_sip=0,
|
||||
)
|
||||
allocators = _make_tuple_allocators(num_sips=1, num_cubes=1, num_pe=4)
|
||||
handle = deploy_tensor(
|
||||
name="t", shape=(4, 16), dtype="f16",
|
||||
placement=placement, allocators=allocators,
|
||||
)
|
||||
assert len(handle.shards) == 4
|
||||
# Each shard's TensorShard carries structural coords; those coords
|
||||
# must match the shard's ShardSpec (sip, cube, pe).
|
||||
for spec, shard in zip(placement, handle.shards):
|
||||
assert shard.sip == spec.sip
|
||||
assert shard.cube == spec.cube
|
||||
assert shard.pe == spec.pe
|
||||
|
||||
|
||||
def test_runtime_context_allocator_keys_are_tuples(topology):
|
||||
"""After ctx tensor op, ctx._allocators keys are (sip, cube, pe) tuples.
|
||||
|
||||
Ensures D5 migration landed (allocator population + lookup).
|
||||
"""
|
||||
from kernbench.runtime_api.context import RuntimeContext
|
||||
from kernbench.runtime_api.types import DeviceSelector
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
|
||||
engine = GraphEngine(topology.topology_obj, enable_data=True)
|
||||
ctx = RuntimeContext(
|
||||
engine=engine,
|
||||
target_device=DeviceSelector("sip:0"),
|
||||
correlation_id="test_adr0026_tuple_keys",
|
||||
spec=topology.topology_obj.spec,
|
||||
)
|
||||
dp = DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=1)
|
||||
_ = ctx.zeros((1, 16), dtype="f16", dp=dp)
|
||||
|
||||
assert ctx._allocators, "allocators dict should be populated"
|
||||
keys = list(ctx._allocators.keys())
|
||||
assert all(isinstance(k, tuple) and len(k) == 3 for k in keys), (
|
||||
f"_allocators keys must be (sip, cube, pe) tuples; got {keys[:5]}"
|
||||
)
|
||||
|
||||
|
||||
# ── D4 (via regression): no SIP-crossing tensor without set_device ────
|
||||
|
||||
|
||||
def test_create_tensor_on_target_sip_via_set_device(topology):
|
||||
"""torch.ahbm.set_device(1) + DPPolicy(cube=replicate, pe=replicate)
|
||||
→ all shards land on SIP 1 structurally (no post-hoc shifting needed)."""
|
||||
from kernbench.runtime_api.context import RuntimeContext
|
||||
from kernbench.runtime_api.types import DeviceSelector
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
|
||||
# Skip the test if topology has only 1 SIP (nothing to verify).
|
||||
n_sips = int(
|
||||
topology.topology_obj.spec.get("system", {})
|
||||
.get("sips", {}).get("count", 1)
|
||||
)
|
||||
if n_sips < 2:
|
||||
pytest.skip("topology has <2 SIPs; set_device(1) not meaningful")
|
||||
|
||||
engine = GraphEngine(topology.topology_obj, enable_data=True)
|
||||
ctx = RuntimeContext(
|
||||
engine=engine,
|
||||
target_device=DeviceSelector("sip:1"),
|
||||
correlation_id="test_adr0026_set_device",
|
||||
spec=topology.topology_obj.spec,
|
||||
)
|
||||
ctx.ahbm.set_device(1)
|
||||
dp = DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=1)
|
||||
t = ctx.zeros((1, 16), dtype="f16", dp=dp)
|
||||
|
||||
assert t._handle is not None
|
||||
assert all(s.sip == 1 for s in t._handle.shards), (
|
||||
f"expected all shards on SIP 1; got {[s.sip for s in t._handle.shards]}"
|
||||
)
|
||||
@@ -0,0 +1,222 @@
|
||||
"""Config-driven multi-device allreduce test application.
|
||||
|
||||
Reads ``ccl.yaml`` + ``topology.yaml``, dynamically loads the kernel
|
||||
module from ``ccl.yaml → module``, and picks the inter-SIP exchange
|
||||
pattern from ``topology.yaml → system.sips.topology``.
|
||||
|
||||
Run directly::
|
||||
|
||||
python -m pytest tests/allreduce_app.py -v -s
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from kernbench.ccl.install import load_ccl_config, resolve_algorithm_config
|
||||
from kernbench.ccl.sfr_config import configure_sfr_intercube_multisip
|
||||
from kernbench.policy.placement.dp import DPPolicy
|
||||
|
||||
|
||||
def _sip_topo_dims(sip_topo: str, n_sips: int) -> tuple[int, int]:
|
||||
if sip_topo == "ring_1d":
|
||||
return (0, 0)
|
||||
side = int(round(math.sqrt(n_sips)))
|
||||
if side * side != n_sips:
|
||||
raise ValueError(
|
||||
f"SIP topology '{sip_topo}' requires square n_sips, got {n_sips}"
|
||||
)
|
||||
return (side, side)
|
||||
|
||||
|
||||
def run_allreduce(
|
||||
ctx: Any,
|
||||
engine: Any,
|
||||
spec: dict,
|
||||
*,
|
||||
algorithm: str | None = None,
|
||||
ccl_yaml: str | None = None,
|
||||
) -> dict:
|
||||
"""Config-driven allreduce: read yaml, load kernel, run.
|
||||
|
||||
Everything is resolved from config — no hardcoded kernel imports.
|
||||
"""
|
||||
cfg_all = load_ccl_config(ccl_yaml)
|
||||
cfg = resolve_algorithm_config(cfg_all, algorithm)
|
||||
|
||||
# Dynamic import from ccl.yaml → module
|
||||
algo_module = importlib.import_module(cfg["module"])
|
||||
kernel_fn = algo_module.kernel
|
||||
topo_name_to_kind = algo_module.TOPO_NAME_TO_KIND
|
||||
|
||||
n_elem = int(cfg.get("n_elem", 8))
|
||||
n_sips = int(spec.get("system", {}).get("sips", {}).get("count", 1))
|
||||
sip_topo = str(
|
||||
spec.get("system", {}).get("sips", {}).get("topology", "ring_1d")
|
||||
)
|
||||
|
||||
cm = spec["sip"]["cube_mesh"]
|
||||
cube_w = int(cm["w"])
|
||||
cube_h = int(cm["h"])
|
||||
n_cubes = cube_w * cube_h
|
||||
|
||||
sip_topo_kind = topo_name_to_kind.get(sip_topo, 0)
|
||||
sip_topo_w, sip_topo_h = _sip_topo_dims(sip_topo, n_sips)
|
||||
|
||||
algo_name = cfg.get("algorithm", "allreduce")
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"algorithm: {algo_name}")
|
||||
print(f"module: {cfg['module']}")
|
||||
print(f"sip_topology: {sip_topo}")
|
||||
print(f"kernel: {kernel_fn.__name__}")
|
||||
print(f"n_sips: {n_sips}")
|
||||
print(f"n_cubes: {n_cubes}")
|
||||
print(f"n_elem: {n_elem}")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
configure_sfr_intercube_multisip(engine, spec, cfg)
|
||||
|
||||
dp = DPPolicy(
|
||||
cube="row_wise", pe="replicate",
|
||||
num_pes=1, num_cubes=n_cubes,
|
||||
)
|
||||
|
||||
tensors = []
|
||||
for sip in range(n_sips):
|
||||
ctx.ahbm.set_device(sip)
|
||||
t = ctx.zeros(
|
||||
(n_cubes, n_elem), dtype="f16", dp=dp,
|
||||
name=f"sip{sip}",
|
||||
)
|
||||
t.copy_(ctx.from_numpy(
|
||||
np.full((n_cubes, n_elem), float(sip + 1), dtype=np.float16)
|
||||
))
|
||||
tensors.append(t)
|
||||
|
||||
for sip in range(n_sips):
|
||||
arr = tensors[sip].numpy()
|
||||
print(f"[SIP {sip}] input cube0[:4] = {arr[0][:4].tolist()} "
|
||||
f"cube{n_cubes - 1}[:4] = {arr[-1][:4].tolist()}")
|
||||
|
||||
t_start = engine._env.now
|
||||
|
||||
all_pending = []
|
||||
for sip_rank, t in enumerate(tensors):
|
||||
pending = ctx.launch(
|
||||
algo_name, kernel_fn, t,
|
||||
n_elem, cube_w, cube_h, n_sips, sip_rank,
|
||||
sip_topo_kind, sip_topo_w, sip_topo_h,
|
||||
_defer_wait=True,
|
||||
)
|
||||
all_pending.extend(pending)
|
||||
|
||||
for h, sip_id, meta in all_pending:
|
||||
ctx.wait(h, _meta=meta)
|
||||
|
||||
t_end = engine._env.now
|
||||
latency_ns = t_end - t_start
|
||||
print(f"\n[{algo_name} ws={n_sips}] sim latency = "
|
||||
f"{latency_ns:.1f} ns ({latency_ns / 1000:.3f} us)")
|
||||
|
||||
for key, (_, trace) in engine._results.items():
|
||||
if not isinstance(trace, dict):
|
||||
continue
|
||||
total = trace.get("total_ns", 0.0)
|
||||
pe_exec = trace.get("pe_exec_ns", 0.0) or 0.0
|
||||
network = total - pe_exec
|
||||
print(f" [{key}] total={total:.1f} ns "
|
||||
f"pe_exec={pe_exec:.1f} ns network={network:.1f} ns")
|
||||
|
||||
expected = float(n_cubes * sum(range(1, n_sips + 1)))
|
||||
|
||||
print()
|
||||
for sip in range(n_sips):
|
||||
arr = tensors[sip].numpy()
|
||||
print(f"[SIP {sip}] output cube0[:4] = {arr[0][:4].tolist()}")
|
||||
print(f"[SIP {sip}] output cube{n_cubes - 1}[:4] = {arr[-1][:4].tolist()}")
|
||||
|
||||
ok_cubes = 0
|
||||
for sip in range(n_sips):
|
||||
arr = tensors[sip].numpy()
|
||||
for cube_id in range(n_cubes):
|
||||
assert np.allclose(
|
||||
arr[cube_id], expected, rtol=1e-1, atol=1e-1,
|
||||
), (
|
||||
f"SIP{sip} cube {cube_id}: "
|
||||
f"got {arr[cube_id][:4]}, expected {expected}"
|
||||
)
|
||||
ok_cubes += 1
|
||||
|
||||
print(f"\n {algo_name} (ws={n_sips}): {ok_cubes} OK")
|
||||
|
||||
return {
|
||||
"expected": expected,
|
||||
"latency_ns": latency_ns,
|
||||
"ok_cubes": ok_cubes,
|
||||
}
|
||||
|
||||
|
||||
# ── pytest entry point ───────────────────────────────────────────────
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from kernbench.runtime_api.context import RuntimeContext
|
||||
from kernbench.runtime_api.types import DeviceSelector
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
from kernbench.topology.builder import resolve_topology
|
||||
|
||||
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
|
||||
|
||||
CONFIGS = [
|
||||
pytest.param("intercube_allreduce", "ring_1d", 2, id="ring_2sip"),
|
||||
pytest.param("intercube_allreduce", "torus_2d", 4, id="torus_4sip"),
|
||||
pytest.param("intercube_allreduce", "mesh_2d_no_wrap", 4, id="mesh_4sip"),
|
||||
]
|
||||
|
||||
|
||||
def _write_temp_configs(tmp_path, sip_topology, n_sips, algorithm):
|
||||
"""Write temp topology.yaml and ccl.yaml with the given overrides."""
|
||||
with open(TOPOLOGY_PATH) as f:
|
||||
topo_cfg = yaml.safe_load(f)
|
||||
topo_cfg["system"]["sips"]["count"] = n_sips
|
||||
topo_cfg["system"]["sips"]["topology"] = sip_topology
|
||||
topo_path = tmp_path / "topology.yaml"
|
||||
with open(topo_path, "w") as f:
|
||||
yaml.dump(topo_cfg, f, default_flow_style=False)
|
||||
|
||||
ccl_path = Path(__file__).parent.parent / "ccl.yaml"
|
||||
with open(ccl_path) as f:
|
||||
ccl_cfg = yaml.safe_load(f)
|
||||
ccl_cfg["defaults"]["algorithm"] = algorithm
|
||||
tmp_ccl = tmp_path / "ccl.yaml"
|
||||
with open(tmp_ccl, "w") as f:
|
||||
yaml.dump(ccl_cfg, f, default_flow_style=False)
|
||||
|
||||
return str(topo_path), str(tmp_ccl)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("algorithm,sip_topology,n_sips", CONFIGS)
|
||||
def test_allreduce(tmp_path, algorithm, sip_topology, n_sips):
|
||||
topo_path, ccl_path = _write_temp_configs(
|
||||
tmp_path, sip_topology, n_sips, algorithm,
|
||||
)
|
||||
topo = resolve_topology(topo_path)
|
||||
engine = GraphEngine(topo.topology_obj, enable_data=True)
|
||||
spec = topo.topology_obj.spec
|
||||
|
||||
with RuntimeContext(
|
||||
engine=engine,
|
||||
target_device=DeviceSelector("all"),
|
||||
correlation_id=f"test_{algorithm}_{sip_topology}",
|
||||
spec=spec,
|
||||
) as ctx:
|
||||
result = run_allreduce(
|
||||
ctx, engine, spec,
|
||||
algorithm=algorithm, ccl_yaml=ccl_path,
|
||||
)
|
||||
assert result["ok_cubes"] > 0
|
||||
@@ -1,150 +0,0 @@
|
||||
"""End-to-end matrix tests for the unified ``ccl_allreduce`` bench.
|
||||
|
||||
Each parametrized case writes a tmp ``ccl.yaml`` overlay that selects a
|
||||
specific (algorithm, world_size, buffer_kind, n_elem) combination, then
|
||||
runs the bench via the CLI and asserts the printed line reports all
|
||||
ranks OK.
|
||||
|
||||
This single test file replaces the per-variant bench tests
|
||||
(test_ccl_allreduce_e2e, test_ccl_mesh_allreduce, test_ccl_tree_allreduce,
|
||||
test_ccl_multicube, test_ccl_multisip).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import textwrap
|
||||
|
||||
import pytest
|
||||
|
||||
import kernbench.cli.main as cli_main
|
||||
|
||||
|
||||
CCL_YAML_TEMPLATE = textwrap.dedent("""\
|
||||
defaults:
|
||||
algorithm: {algorithm}
|
||||
buffer_kind: {buffer_kind}
|
||||
backpressure: sleep
|
||||
n_slots: 4
|
||||
slot_size: 4096
|
||||
vc_chunk_size: 256
|
||||
ipcq_credit_size_bytes: 16
|
||||
|
||||
algorithms:
|
||||
{algorithm}:
|
||||
module: {module}
|
||||
topology: {topology}
|
||||
buffer_kind: {buffer_kind}
|
||||
{world_size_line}{n_elem_line}
|
||||
""")
|
||||
|
||||
|
||||
def _write_ccl_yaml(
|
||||
tmp_path,
|
||||
*,
|
||||
algorithm: str,
|
||||
module: str,
|
||||
topology: str,
|
||||
buffer_kind: str = "tcm",
|
||||
world_size: int | None = None,
|
||||
n_elem: int | None = None,
|
||||
) -> str:
|
||||
"""Write a tmp ccl.yaml in tmp_path and return its directory."""
|
||||
ws_line = f" world_size: {world_size}\n" if world_size is not None else ""
|
||||
nel_line = f" n_elem: {n_elem}\n" if n_elem is not None else ""
|
||||
body = CCL_YAML_TEMPLATE.format(
|
||||
algorithm=algorithm,
|
||||
module=module,
|
||||
topology=topology,
|
||||
buffer_kind=buffer_kind,
|
||||
world_size_line=ws_line,
|
||||
n_elem_line=nel_line,
|
||||
)
|
||||
yaml_path = tmp_path / "ccl.yaml"
|
||||
yaml_path.write_text(body)
|
||||
return str(tmp_path)
|
||||
|
||||
|
||||
CASES = [
|
||||
# algorithm, module, topology, buffer_kind, world_size, n_elem, expected_ws
|
||||
#
|
||||
# Full-system (256-rank, cross-SIP) — run only ONCE (tcm). Buffer
|
||||
# variant differences are purely IPCQ slot placement; the compute path
|
||||
# is identical. Cross-SIP routing is the real thing being verified here.
|
||||
pytest.param(
|
||||
"ring_allreduce_tcm", "kernbench.ccl.algorithms.ring_allreduce",
|
||||
"ring_1d", "tcm", None, 8, 256,
|
||||
id="ring_full_system",
|
||||
marks=pytest.mark.slow,
|
||||
),
|
||||
# Buffer variants at 8-rank (fast — same kernel, different slot space).
|
||||
pytest.param(
|
||||
"ring_allreduce_tcm", "kernbench.ccl.algorithms.ring_allreduce",
|
||||
"ring_1d", "tcm", 8, 32, 8,
|
||||
id="ring_tcm_8",
|
||||
),
|
||||
pytest.param(
|
||||
"ring_allreduce_hbm", "kernbench.ccl.algorithms.ring_allreduce",
|
||||
"ring_1d", "hbm", 8, 32, 8,
|
||||
id="ring_hbm_8",
|
||||
),
|
||||
pytest.param(
|
||||
"ring_allreduce_sram", "kernbench.ccl.algorithms.ring_allreduce",
|
||||
"ring_1d", "sram", 8, 32, 8,
|
||||
id="ring_sram_8",
|
||||
),
|
||||
# Multi-cube (16-rank, cross-cube within 1 SIP).
|
||||
pytest.param(
|
||||
"ring_allreduce_16", "kernbench.ccl.algorithms.ring_allreduce",
|
||||
"ring_1d", "tcm", 16, 16, 16,
|
||||
id="ring_multi_cube",
|
||||
),
|
||||
# Mesh + tree algorithms.
|
||||
pytest.param(
|
||||
"mesh_allreduce_4", "kernbench.ccl.algorithms.mesh_allreduce",
|
||||
"mesh_2d", "tcm", 4, 16, 4,
|
||||
id="mesh_2x2",
|
||||
),
|
||||
pytest.param(
|
||||
"tree_allreduce_7", "kernbench.ccl.algorithms.tree_allreduce",
|
||||
"tree_binary", "tcm", 7, 16, 7,
|
||||
id="tree_binary_7",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"algorithm,module,topology,buffer_kind,world_size,n_elem,expected_ws",
|
||||
CASES,
|
||||
)
|
||||
def test_ccl_allreduce_matrix(
|
||||
tmp_path, capsys, monkeypatch,
|
||||
algorithm, module, topology, buffer_kind, world_size, n_elem, expected_ws,
|
||||
):
|
||||
"""Each (algorithm × buffer × world_size) combo passes through the
|
||||
unified bench and yields all ranks OK."""
|
||||
project_root = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "..")
|
||||
)
|
||||
yaml_dir = _write_ccl_yaml(
|
||||
tmp_path,
|
||||
algorithm=algorithm,
|
||||
module=module,
|
||||
topology=topology,
|
||||
buffer_kind=buffer_kind,
|
||||
world_size=world_size,
|
||||
n_elem=n_elem,
|
||||
)
|
||||
monkeypatch.chdir(yaml_dir)
|
||||
rc = cli_main.main([
|
||||
"run",
|
||||
"--topology", os.path.join(project_root, "topology.yaml"),
|
||||
"--bench", "ccl_allreduce",
|
||||
"--verify-data",
|
||||
])
|
||||
assert rc == 0
|
||||
out = capsys.readouterr().out
|
||||
assert "FAIL" not in out, f"unexpected FAIL in output:\n{out}"
|
||||
assert f"{algorithm} (ws={expected_ws}): {expected_ws} OK" in out, (
|
||||
f"expected '{algorithm} (ws={expected_ws}): {expected_ws} OK' "
|
||||
f"in output:\n{out}"
|
||||
)
|
||||
@@ -1,125 +0,0 @@
|
||||
"""Tests for IPCQ deadlock detection (ADR-0023 D14 F3)."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import simpy
|
||||
|
||||
from kernbench.ccl import diagnostics
|
||||
from kernbench.common.ipcq_types import (
|
||||
IpcqEndpoint,
|
||||
IpcqInitEntry,
|
||||
IpcqRecvCmd,
|
||||
IpcqRequest,
|
||||
)
|
||||
from kernbench.components.builtin.pe_ipcq import PeIpcqComponent
|
||||
from kernbench.runtime_api.kernel import IpcqInitMsg
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakeTxn:
|
||||
request: Any
|
||||
done: simpy.Event
|
||||
result_data: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
def _make_isolated_pe_ipcq(env):
|
||||
node = Node(
|
||||
id="sip0.cube0.pe0.pe_ipcq", kind="pe_ipcq",
|
||||
impl="builtin.pe_ipcq", attrs={}, pos_mm=None,
|
||||
)
|
||||
comp = PeIpcqComponent(node, ctx=None)
|
||||
comp.in_ports["host"] = simpy.Store(env)
|
||||
comp.out_ports["sip0.cube0.pe0.pe_dma"] = simpy.Store(env)
|
||||
comp.start(env)
|
||||
|
||||
peer_credit = simpy.Store(env)
|
||||
ep = IpcqEndpoint(
|
||||
sip=0, cube=0, pe=1, buffer_kind="tcm",
|
||||
rx_base_pa=0x10_000, rx_base_va=0,
|
||||
n_slots=4, slot_size=4096,
|
||||
)
|
||||
init_msg = IpcqInitMsg(
|
||||
correlation_id="t", request_id="t",
|
||||
target_sips=(0,), target_cubes=(0,), target_pe=0,
|
||||
entries=(IpcqInitEntry(
|
||||
direction="W", peer=ep,
|
||||
my_rx_base_pa=0x40_000, my_rx_base_va=0,
|
||||
n_slots=4, slot_size=4096,
|
||||
peer_credit_store=peer_credit,
|
||||
),),
|
||||
backpressure_mode="sleep",
|
||||
buffer_kind="tcm",
|
||||
credit_size_bytes=16,
|
||||
)
|
||||
done = env.event()
|
||||
comp.in_ports["host"].put(_FakeTxn(request=init_msg, done=done))
|
||||
env.run(until=done)
|
||||
return comp
|
||||
|
||||
|
||||
def test_pointer_dump_includes_blocked_state():
|
||||
"""A blocked recv should still be visible in the pointer dump."""
|
||||
env = simpy.Environment()
|
||||
comp = _make_isolated_pe_ipcq(env)
|
||||
|
||||
# Issue a recv that will block (no data has arrived)
|
||||
recv_cmd = IpcqRecvCmd(direction="W", shape=(8,), dtype="f16", handle_id="r1")
|
||||
req = IpcqRequest(command=recv_cmd, done=env.event())
|
||||
comp.in_ports["host"].put(req)
|
||||
env.run(until=10)
|
||||
assert not req.done.triggered
|
||||
|
||||
# Pointer dump should show my_tail=0 and peer_head_cache=0
|
||||
# We need to use the engine API but for an isolated component, just call directly
|
||||
class FakeEngine:
|
||||
_components = {"sip0.cube0.pe0.pe_ipcq": comp}
|
||||
|
||||
dump = diagnostics.pointer_dump(FakeEngine())
|
||||
assert "my_tail=0" in dump
|
||||
assert "peer_head_cache=0" in dump
|
||||
|
||||
|
||||
def test_deadlock_detection_recv_without_send():
|
||||
"""A recv with no matching sender → SimPy schedule empties → engine
|
||||
raises ``IpcqDeadlock`` with a pointer dump.
|
||||
"""
|
||||
from kernbench.ccl.diagnostics import IpcqDeadlock
|
||||
from kernbench.policy.placement.dp import DPPolicy
|
||||
from kernbench.runtime_api.bench_runner import run_bench
|
||||
from kernbench.runtime_api.types import resolve_device
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
from kernbench.topology.builder import resolve_topology
|
||||
|
||||
def deadlock_kernel(t_ptr, n_elem, tl):
|
||||
# Every PE just receives, no sends → no one delivers → deadlock
|
||||
tl.recv(dir="W", shape=(n_elem,), dtype="f16")
|
||||
|
||||
topo = resolve_topology("topology.yaml")
|
||||
|
||||
def run(torch):
|
||||
torch.install_ipcq(
|
||||
algorithm="ring_allreduce_tcm", world_size_override=8,
|
||||
)
|
||||
a = torch.zeros(
|
||||
(1, 8 * 8),
|
||||
dtype="f16",
|
||||
dp=DPPolicy(
|
||||
sip="replicate", cube="replicate", pe="column_wise",
|
||||
num_sips=1, num_cubes=1,
|
||||
),
|
||||
name="dl_in",
|
||||
)
|
||||
torch.launch("dl", deadlock_kernel, a, 8)
|
||||
|
||||
with pytest.raises(IpcqDeadlock):
|
||||
run_bench(
|
||||
topology=topo, bench_fn=run,
|
||||
device=resolve_device("all"),
|
||||
engine_factory=lambda t, d: GraphEngine(
|
||||
getattr(t, "topology_obj", t), enable_data=True
|
||||
),
|
||||
)
|
||||
@@ -1,70 +0,0 @@
|
||||
"""Tests for CCL diagnostics: trace + pointer dump (ADR-0023 D14)."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
from kernbench.ccl import diagnostics
|
||||
|
||||
|
||||
# ── trace toggle ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_trace_disabled_by_default(monkeypatch):
|
||||
monkeypatch.delenv("KERNBENCH_CCL_TRACE", raising=False)
|
||||
diagnostics.reload_trace_setting()
|
||||
assert diagnostics.trace_enabled() is False
|
||||
|
||||
|
||||
def test_trace_enabled_via_env(monkeypatch):
|
||||
monkeypatch.setenv("KERNBENCH_CCL_TRACE", "1")
|
||||
diagnostics.reload_trace_setting()
|
||||
assert diagnostics.trace_enabled() is True
|
||||
|
||||
|
||||
def test_trace_record_send(monkeypatch, capsys):
|
||||
monkeypatch.setenv("KERNBENCH_CCL_TRACE", "1")
|
||||
diagnostics.reload_trace_setting()
|
||||
diagnostics.log_send(t_ns=100.0, sender="sip0.cube0.pe0",
|
||||
direction="E", nbytes=64, sender_seq=0)
|
||||
out = capsys.readouterr().out
|
||||
assert "send" in out
|
||||
assert "sip0.cube0.pe0" in out
|
||||
assert "dir=E" in out
|
||||
monkeypatch.delenv("KERNBENCH_CCL_TRACE")
|
||||
diagnostics.reload_trace_setting()
|
||||
|
||||
|
||||
def test_trace_record_recv(monkeypatch, capsys):
|
||||
monkeypatch.setenv("KERNBENCH_CCL_TRACE", "1")
|
||||
diagnostics.reload_trace_setting()
|
||||
diagnostics.log_recv(t_ns=200.0, receiver="sip0.cube0.pe1",
|
||||
direction="W", nbytes=64)
|
||||
out = capsys.readouterr().out
|
||||
assert "recv" in out
|
||||
assert "sip0.cube0.pe1" in out
|
||||
monkeypatch.delenv("KERNBENCH_CCL_TRACE")
|
||||
diagnostics.reload_trace_setting()
|
||||
|
||||
|
||||
# ── pointer dump ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_pointer_dump_format():
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
from kernbench.topology.builder import resolve_topology
|
||||
from kernbench.ccl.install import (
|
||||
install_ipcq, load_ccl_config, resolve_algorithm_config,
|
||||
)
|
||||
|
||||
topo = resolve_topology("topology.yaml").topology_obj
|
||||
engine = GraphEngine(topo, enable_data=True)
|
||||
cfg = resolve_algorithm_config(load_ccl_config(), name="ring_allreduce_tcm")
|
||||
install_ipcq(engine, topo.spec, cfg)
|
||||
|
||||
dump = diagnostics.pointer_dump(engine)
|
||||
# 8 ranks × 2 directions = 16 lines (plus 8 PE headers)
|
||||
assert "sip0.cube0.pe0" in dump
|
||||
assert "E:" in dump
|
||||
assert "W:" in dump
|
||||
assert "my_head=" in dump
|
||||
assert "peer_tail_cache=" in dump
|
||||
@@ -1,81 +0,0 @@
|
||||
"""Validate the hello-world example from docs/ccl-author-guide.md.
|
||||
|
||||
This is the simplest possible CCL kernel — each PE sends its tile E
|
||||
and receives a tile from W. After running, each rank's slice should
|
||||
contain the data of the previous rank.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
from kernbench.ccl.algorithms import hello_send
|
||||
from kernbench.ccl.testing import run_kernel_in_mock
|
||||
|
||||
|
||||
def test_hello_send_4_ranks_mock():
|
||||
n_elem = 8
|
||||
inputs = [np.full((n_elem,), float(r + 1), dtype=np.float16) for r in range(4)]
|
||||
|
||||
outputs = run_kernel_in_mock(
|
||||
kernel_fn=hello_send.kernel,
|
||||
world_size=4,
|
||||
topology="ring_1d",
|
||||
inputs=inputs,
|
||||
kernel_args=(n_elem,),
|
||||
)
|
||||
|
||||
# rank r should have rank (r-1) % 4's data
|
||||
for r in range(4):
|
||||
prev = inputs[(r - 1) % 4]
|
||||
assert np.array_equal(outputs[r], prev), f"rank {r}: got {outputs[r]}"
|
||||
|
||||
|
||||
def test_hello_send_via_simpy_runner():
|
||||
"""Same but through real SimPy + IPCQ."""
|
||||
from kernbench.policy.placement.dp import DPPolicy
|
||||
from kernbench.runtime_api.bench_runner import run_bench
|
||||
from kernbench.runtime_api.types import resolve_device
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
from kernbench.topology.builder import resolve_topology
|
||||
|
||||
topo = resolve_topology("topology.yaml")
|
||||
n_elem = 8
|
||||
world_size = 8
|
||||
|
||||
def run(torch):
|
||||
# World size for this hello test is 8 (one cube). ccl.yaml no
|
||||
# longer carries a default world_size — pass it explicitly.
|
||||
plan = torch.install_ipcq(
|
||||
algorithm="ring_allreduce_tcm", world_size_override=world_size,
|
||||
)
|
||||
a = torch.zeros(
|
||||
(1, world_size * n_elem), dtype="f16",
|
||||
dp=DPPolicy(
|
||||
sip="replicate", cube="replicate", pe="column_wise",
|
||||
num_sips=1, num_cubes=1,
|
||||
),
|
||||
name="hello_in",
|
||||
)
|
||||
store = torch.engine.memory_store
|
||||
base = a._handle.va_base or a._handle.shards[0].pa
|
||||
nbytes = n_elem * 2
|
||||
for r in range(world_size):
|
||||
store.write("hbm", base + r * nbytes,
|
||||
np.full((n_elem,), float(r + 1), dtype=np.float16))
|
||||
|
||||
torch.launch("hello_send", hello_send.kernel, a, n_elem)
|
||||
|
||||
# Each rank should hold the previous rank's data after the round
|
||||
for r in range(world_size):
|
||||
arr = store.read("hbm", base + r * nbytes, shape=(n_elem,), dtype="f16")
|
||||
prev_value = float(((r - 1) % world_size) + 1)
|
||||
assert np.allclose(arr, prev_value), f"rank {r}: got {arr}, expected {prev_value}"
|
||||
|
||||
result = run_bench(
|
||||
topology=topo, bench_fn=run,
|
||||
device=resolve_device("all"),
|
||||
engine_factory=lambda t, d: GraphEngine(
|
||||
getattr(t, "topology_obj", t), enable_data=True
|
||||
),
|
||||
)
|
||||
assert result.completion.ok
|
||||
@@ -2,7 +2,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from kernbench.ccl.install import (
|
||||
install_ipcq,
|
||||
linear_rank_to_pe,
|
||||
load_ccl_config,
|
||||
resolve_algorithm_config,
|
||||
@@ -26,28 +25,14 @@ def test_resolve_algorithm_config_default():
|
||||
cfg = load_ccl_config()
|
||||
merged = resolve_algorithm_config(cfg)
|
||||
assert merged["algorithm"] == cfg["defaults"]["algorithm"]
|
||||
# ccl.yaml no longer carries defaults.world_size — backend derives
|
||||
# it from topology.yaml at install time. Just check the field is
|
||||
# absent here (verified per-test where install_ipcq is called).
|
||||
assert "world_size" not in merged or merged["world_size"] >= 1
|
||||
|
||||
|
||||
def test_resolve_algorithm_config_override():
|
||||
cfg = load_ccl_config()
|
||||
merged = resolve_algorithm_config(cfg, name="ring_allreduce_hbm")
|
||||
assert merged["algorithm"] == "ring_allreduce_hbm"
|
||||
assert merged["buffer_kind"] == "hbm" # algo override
|
||||
# defaults still apply
|
||||
assert merged["n_slots"] == cfg["defaults"]["n_slots"]
|
||||
|
||||
|
||||
def test_linear_rank_to_pe():
|
||||
engine, topo = _engine()
|
||||
spec = topo.spec
|
||||
# Cube 0 of SIP 0
|
||||
assert linear_rank_to_pe(0, spec) == (0, 0, 0)
|
||||
assert linear_rank_to_pe(7, spec) == (0, 0, 7)
|
||||
# Should not exceed total PE count
|
||||
pes_per_sip = (
|
||||
spec["sip"]["cube_mesh"]["w"] * spec["sip"]["cube_mesh"]["h"]
|
||||
* spec["cube"]["pe_layout"]["pe_per_corner"]
|
||||
@@ -56,105 +41,3 @@ def test_linear_rank_to_pe():
|
||||
sips = spec["system"]["sips"]["count"]
|
||||
total = sips * pes_per_sip
|
||||
assert total >= 8
|
||||
|
||||
|
||||
def test_install_ipcq_neighbors_correct():
|
||||
engine, topo = _engine()
|
||||
cfg = load_ccl_config()
|
||||
merged = resolve_algorithm_config(cfg, name="ring_allreduce_tcm")
|
||||
# Force a single-cube 8-rank install for the assertions below.
|
||||
merged["world_size"] = 8
|
||||
plan = install_ipcq(engine, topo.spec, merged)
|
||||
|
||||
assert plan["world_size"] == 8
|
||||
assert plan["buffer_kind"] == "tcm"
|
||||
|
||||
# Each rank should have E and W entries
|
||||
for r, nbrs in plan["neighbor_table"].items():
|
||||
assert "E" in nbrs
|
||||
assert "W" in nbrs
|
||||
|
||||
# Inspect installed PE_IPCQ for rank 0
|
||||
ipcq = engine._components["sip0.cube0.pe0.pe_ipcq"]
|
||||
qp_e = ipcq.queue_pairs["E"]
|
||||
qp_w = ipcq.queue_pairs["W"]
|
||||
assert qp_e["peer"].pe == 1 # rank 0's E neighbor is rank 1
|
||||
assert qp_w["peer"].pe == 7 # rank 0's W neighbor is rank 7
|
||||
# rx_base addresses should be unique
|
||||
assert qp_e["my_rx_base_pa"] != qp_w["my_rx_base_pa"]
|
||||
|
||||
|
||||
def test_install_ipcq_credit_stores_wired():
|
||||
engine, topo = _engine()
|
||||
cfg = load_ccl_config()
|
||||
merged = resolve_algorithm_config(cfg, name="ring_allreduce_tcm")
|
||||
merged["world_size"] = 8
|
||||
install_ipcq(engine, topo.spec, merged)
|
||||
|
||||
# rank 0 (pe0) sending E goes to rank 1 (pe1)
|
||||
# rank 0's peer_credit_store on E direction should equal rank 1's credit_inbox
|
||||
pe0 = engine._components["sip0.cube0.pe0.pe_ipcq"]
|
||||
pe1 = engine._components["sip0.cube0.pe1.pe_ipcq"]
|
||||
|
||||
qp_e = pe0.queue_pairs["E"]
|
||||
assert qp_e["peer_credit_store"] is pe1.credit_inbox
|
||||
|
||||
|
||||
# ── ADR-0025 D1: reverse_direction opposite-preference ───────────────
|
||||
|
||||
|
||||
def test_reverse_direction_opposite_preference_2rank_ring():
|
||||
"""ADR-0025 D1: In a 2-rank bidirectional ring both E and W point to the
|
||||
same peer; reverse_direction must pick the OPPOSITE direction (W for E,
|
||||
E for W) so rx_base targets the semantically-correct slot.
|
||||
|
||||
Concretely: rank 0 sending via E to rank 1 must target rank 1's W-rx
|
||||
buffer (not rank 1's E-rx), because rank 1's kernel recv(W) reads from
|
||||
its W-rx.
|
||||
"""
|
||||
engine, topo = _engine()
|
||||
cfg = load_ccl_config()
|
||||
merged = resolve_algorithm_config(cfg, name="ring_allreduce_tcm")
|
||||
merged["world_size"] = 2
|
||||
install_ipcq(engine, topo.spec, merged)
|
||||
|
||||
ipcq0 = engine._components["sip0.cube0.pe0.pe_ipcq"]
|
||||
ipcq1 = engine._components["sip0.cube0.pe1.pe_ipcq"]
|
||||
|
||||
rank1_e_rx = ipcq1.queue_pairs["E"]["my_rx_base_pa"]
|
||||
rank1_w_rx = ipcq1.queue_pairs["W"]["my_rx_base_pa"]
|
||||
|
||||
qp0_e = ipcq0.queue_pairs["E"]
|
||||
qp0_w = ipcq0.queue_pairs["W"]
|
||||
|
||||
# rank 0's E entry should target rank 1's W-rx (opposite), NOT rank 1's E-rx.
|
||||
assert qp0_e["peer"].rx_base_pa == rank1_w_rx, (
|
||||
f"expected rank 0's E peer.rx_base_pa == rank 1's W-rx ({rank1_w_rx:#x}), "
|
||||
f"got {qp0_e['peer'].rx_base_pa:#x} (matches E-rx: {rank1_e_rx:#x}) — "
|
||||
f"reverse_direction picked same-label instead of opposite"
|
||||
)
|
||||
# rank 0's W entry should target rank 1's E-rx (opposite).
|
||||
assert qp0_w["peer"].rx_base_pa == rank1_e_rx
|
||||
|
||||
|
||||
def test_reverse_direction_opposite_preference_4rank_ring_sanity():
|
||||
"""ADR-0025 D1 sanity: ws>=3 ring. E and W have distinct peers, so
|
||||
opposite-preference produces same result as old dict-order first-match.
|
||||
This test should PASS both under current and post-fix code.
|
||||
"""
|
||||
engine, topo = _engine()
|
||||
cfg = load_ccl_config()
|
||||
merged = resolve_algorithm_config(cfg, name="ring_allreduce_tcm")
|
||||
merged["world_size"] = 4
|
||||
install_ipcq(engine, topo.spec, merged)
|
||||
|
||||
ipcq0 = engine._components["sip0.cube0.pe0.pe_ipcq"]
|
||||
ipcq1 = engine._components["sip0.cube0.pe1.pe_ipcq"]
|
||||
ipcq3 = engine._components["sip0.cube0.pe3.pe_ipcq"]
|
||||
|
||||
# rank 0 E → rank 1 → rank 1's W-rx
|
||||
qp0_e = ipcq0.queue_pairs["E"]
|
||||
assert qp0_e["peer"].rx_base_pa == ipcq1.queue_pairs["W"]["my_rx_base_pa"]
|
||||
# rank 0 W → rank 3 (last in ring) → rank 3's E-rx
|
||||
qp0_w = ipcq0.queue_pairs["W"]
|
||||
assert qp0_w["peer"].rx_base_pa == ipcq3.queue_pairs["E"]["my_rx_base_pa"]
|
||||
|
||||
@@ -1,83 +0,0 @@
|
||||
"""Tests for the mock CCL runtime (ADR-0023 D15)."""
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
from kernbench.ccl.algorithms import ring_allreduce
|
||||
from kernbench.ccl.testing import run_kernel_in_mock
|
||||
|
||||
|
||||
def test_ring_allreduce_4_ranks():
|
||||
"""Run the ring all-reduce kernel under the mock runtime, no SimPy."""
|
||||
n_elem = 8
|
||||
inputs = [
|
||||
np.full((n_elem,), float(r + 1), dtype=np.float16)
|
||||
for r in range(4)
|
||||
]
|
||||
expected = sum(inputs) # [10, 10, ..., 10]
|
||||
|
||||
outputs = run_kernel_in_mock(
|
||||
kernel_fn=ring_allreduce.kernel,
|
||||
world_size=4,
|
||||
topology="ring_1d",
|
||||
inputs=inputs,
|
||||
kernel_args=(n_elem, 4),
|
||||
)
|
||||
|
||||
assert len(outputs) == 4
|
||||
for r in range(4):
|
||||
assert np.allclose(outputs[r], expected)
|
||||
|
||||
|
||||
def test_ring_allreduce_8_ranks():
|
||||
n_elem = 16
|
||||
inputs = [
|
||||
np.full((n_elem,), float(r + 1), dtype=np.float16)
|
||||
for r in range(8)
|
||||
]
|
||||
expected = sum(inputs) # [36, 36, ...]
|
||||
|
||||
outputs = run_kernel_in_mock(
|
||||
kernel_fn=ring_allreduce.kernel,
|
||||
world_size=8,
|
||||
topology="ring_1d",
|
||||
inputs=inputs,
|
||||
kernel_args=(n_elem, 8),
|
||||
)
|
||||
for r in range(8):
|
||||
assert np.allclose(outputs[r], expected)
|
||||
|
||||
|
||||
def test_ring_allreduce_random_data():
|
||||
n_elem = 32
|
||||
rng = np.random.default_rng(42)
|
||||
inputs = [rng.standard_normal(n_elem).astype(np.float16) for _ in range(4)]
|
||||
expected = sum(inputs)
|
||||
|
||||
outputs = run_kernel_in_mock(
|
||||
kernel_fn=ring_allreduce.kernel,
|
||||
world_size=4,
|
||||
topology="ring_1d",
|
||||
inputs=inputs,
|
||||
kernel_args=(n_elem, 4),
|
||||
)
|
||||
for r in range(4):
|
||||
assert np.allclose(outputs[r], expected, rtol=1e-2, atol=1e-2)
|
||||
|
||||
|
||||
def test_mock_runtime_invalid_direction_raises():
|
||||
"""A kernel that uses an unsupported direction should raise."""
|
||||
import pytest
|
||||
|
||||
def bad_kernel(t_ptr, n_elem, tl):
|
||||
tl.send(dir="N", src_addr=0, nbytes=2, shape=(1,), dtype="f16", space="hbm")
|
||||
|
||||
inputs = [np.array([1.0], dtype=np.float16) for _ in range(2)]
|
||||
with pytest.raises(Exception):
|
||||
run_kernel_in_mock(
|
||||
kernel_fn=bad_kernel,
|
||||
world_size=2,
|
||||
topology="ring_1d",
|
||||
inputs=inputs,
|
||||
kernel_args=(1,),
|
||||
)
|
||||
@@ -1,87 +0,0 @@
|
||||
"""CCL performance validation tests (ADR-0023 D13 T5).
|
||||
|
||||
Sanity-checks the simulated latency of the unified ``ccl_allreduce`` bench.
|
||||
|
||||
Uses 8-rank (single cube) for all buffer variants — the latency model
|
||||
is topology-aware, so buffer_kind differences are visible even at small
|
||||
scale. Full-system (256-rank) cross-SIP latency is covered by the
|
||||
``test_ccl_allreduce_matrix[ring_full_system]`` slow test.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from kernbench.runtime_api.bench_runner import run_bench
|
||||
from kernbench.runtime_api.types import resolve_device
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
from kernbench.topology.builder import resolve_topology
|
||||
|
||||
|
||||
def _engine_factory(topology, device):
|
||||
return GraphEngine(getattr(topology, "topology_obj", topology), enable_data=True)
|
||||
|
||||
|
||||
def _run_8rank(algorithm: str, buffer_kind: str = "tcm") -> float:
|
||||
"""Run an 8-rank ring via the unified bench with a tmp ccl.yaml overlay.
|
||||
Returns simulated kernel total_ns."""
|
||||
import tempfile
|
||||
|
||||
body = f"""\
|
||||
defaults:
|
||||
algorithm: {algorithm}
|
||||
buffer_kind: {buffer_kind}
|
||||
backpressure: sleep
|
||||
n_slots: 4
|
||||
slot_size: 4096
|
||||
vc_chunk_size: 256
|
||||
ipcq_credit_size_bytes: 16
|
||||
|
||||
algorithms:
|
||||
{algorithm}:
|
||||
module: kernbench.ccl.algorithms.ring_allreduce
|
||||
topology: ring_1d
|
||||
buffer_kind: {buffer_kind}
|
||||
world_size: 8
|
||||
n_elem: 32
|
||||
"""
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
with open(os.path.join(tmp, "ccl.yaml"), "w") as f:
|
||||
f.write(body)
|
||||
old_cwd = os.getcwd()
|
||||
os.chdir(tmp)
|
||||
try:
|
||||
topo = resolve_topology(os.path.join(project_root, "topology.yaml"))
|
||||
bench_mod = importlib.import_module("benches.ccl_allreduce")
|
||||
result = run_bench(
|
||||
topology=topo, bench_fn=bench_mod.run,
|
||||
device=resolve_device("all"),
|
||||
engine_factory=_engine_factory,
|
||||
)
|
||||
finally:
|
||||
os.chdir(old_cwd)
|
||||
|
||||
assert result.completion.ok, f"{algorithm} did not complete"
|
||||
last_kernel = None
|
||||
for tr in (result.traces or []):
|
||||
if tr.get("phase") == "kernel":
|
||||
last_kernel = tr
|
||||
assert last_kernel is not None, f"{algorithm} produced no kernel trace"
|
||||
return float(last_kernel.get("total_ns", 0.0))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("buffer_kind", ["tcm", "hbm", "sram"])
|
||||
def test_ccl_latency_positive(buffer_kind):
|
||||
"""Every buffer kind must produce a positive simulated latency."""
|
||||
algo = f"ring_allreduce_{buffer_kind}"
|
||||
ns = _run_8rank(algo, buffer_kind)
|
||||
assert ns > 0
|
||||
|
||||
|
||||
def test_ccl_latency_under_reasonable_bound():
|
||||
"""8-rank ring all-reduce (tile=32 f16) should finish well under 1ms."""
|
||||
ns = _run_8rank("ring_allreduce_tcm", "tcm")
|
||||
assert ns < 1_000_000 # < 1 ms simulated
|
||||
@@ -0,0 +1,119 @@
|
||||
"""End-to-end distributed test for intercube allreduce.
|
||||
|
||||
Exercises the full process-group path:
|
||||
dist.init_process_group(backend="ahbm")
|
||||
→ mp.spawn(nprocs=n_sips)
|
||||
→ each worker: set_device → allocate → fill → dist.all_reduce → verify
|
||||
|
||||
This is the same flow a real DDP training script would use.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import textwrap
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
|
||||
|
||||
N_CUBES = 16
|
||||
N_ELEM = 8
|
||||
|
||||
|
||||
def _write_ccl_yaml(tmp_path) -> str:
|
||||
body = textwrap.dedent("""\
|
||||
defaults:
|
||||
algorithm: intercube_allreduce
|
||||
buffer_kind: tcm
|
||||
backpressure: sleep
|
||||
n_slots: 4
|
||||
slot_size: 4096
|
||||
vc_chunk_size: 256
|
||||
ipcq_credit_size_bytes: 16
|
||||
|
||||
algorithms:
|
||||
intercube_allreduce:
|
||||
module: kernbench.ccl.algorithms.intercube_allreduce
|
||||
topology: none
|
||||
buffer_kind: tcm
|
||||
n_elem: 8
|
||||
root_cube: 15
|
||||
""")
|
||||
(tmp_path / "ccl.yaml").write_text(body)
|
||||
return str(tmp_path)
|
||||
|
||||
|
||||
def _worker(rank: int, n_sips: int, torch) -> None:
|
||||
"""Per-SIP worker: allocate, fill, all_reduce, verify."""
|
||||
from kernbench.policy.placement.dp import DPPolicy
|
||||
|
||||
torch.ahbm.set_device(rank)
|
||||
|
||||
dp = DPPolicy(
|
||||
cube="row_wise", pe="replicate",
|
||||
num_pes=1, num_cubes=N_CUBES,
|
||||
)
|
||||
tensor = torch.zeros(
|
||||
(N_CUBES, N_ELEM), dtype="f16", dp=dp,
|
||||
name=f"sip{rank}",
|
||||
)
|
||||
|
||||
init_arr = np.full((N_CUBES, N_ELEM), float(rank + 1), dtype=np.float16)
|
||||
tensor.copy_(torch.from_numpy(init_arr))
|
||||
|
||||
print(f"[SIP {rank}] input cube0[:4] = {tensor.numpy()[0][:4].tolist()}")
|
||||
|
||||
torch.distributed.all_reduce(tensor, op="sum")
|
||||
|
||||
arr = tensor.numpy()
|
||||
expected = float(N_CUBES * sum(range(1, n_sips + 1)))
|
||||
|
||||
print(f"[SIP {rank}] output cube0[:4] = {arr[0][:4].tolist()}")
|
||||
print(f"[SIP {rank}] output cube15[:4] = {arr[15][:4].tolist()}")
|
||||
|
||||
for cube_id in range(N_CUBES):
|
||||
assert np.allclose(arr[cube_id], expected, rtol=1e-1, atol=1e-1), (
|
||||
f"SIP{rank} cube {cube_id}: "
|
||||
f"got {arr[cube_id][:4]}, expected {expected}"
|
||||
)
|
||||
|
||||
if rank == 0:
|
||||
print(f"\n intercube_allreduce (ws={n_sips}): "
|
||||
f"{n_sips * N_CUBES} OK")
|
||||
|
||||
|
||||
def test_distributed_intercube_allreduce(tmp_path, monkeypatch):
|
||||
"""Full distributed path: init_process_group → mp.spawn → all_reduce."""
|
||||
from kernbench.runtime_api.context import RuntimeContext
|
||||
from kernbench.runtime_api.types import DeviceSelector
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
from kernbench.topology.builder import resolve_topology
|
||||
|
||||
monkeypatch.chdir(_write_ccl_yaml(tmp_path))
|
||||
|
||||
topo = resolve_topology(str(TOPOLOGY_PATH))
|
||||
engine = GraphEngine(topo.topology_obj, enable_data=True)
|
||||
spec = topo.topology_obj.spec
|
||||
n_sips = int(spec["system"]["sips"]["count"])
|
||||
|
||||
with RuntimeContext(
|
||||
engine=engine,
|
||||
target_device=DeviceSelector("all"),
|
||||
correlation_id="dist_intercube_ar",
|
||||
spec=spec,
|
||||
) as ctx:
|
||||
ctx.distributed.init_process_group(backend="ahbm")
|
||||
|
||||
assert ctx.distributed.get_world_size() == n_sips
|
||||
|
||||
t_start = engine._env.now
|
||||
|
||||
ctx.multiprocessing.spawn(
|
||||
_worker, args=(n_sips, ctx), nprocs=n_sips,
|
||||
)
|
||||
|
||||
t_end = engine._env.now
|
||||
print(f"\n[distributed] sim latency = "
|
||||
f"{t_end - t_start:.1f} ns ({(t_end - t_start) / 1000:.3f} us)")
|
||||
@@ -0,0 +1,113 @@
|
||||
"""Tests for configure_sfr_intercube_multisip neighbor table wiring.
|
||||
|
||||
Verifies that IPCQ neighbor tables are correctly installed for
|
||||
intercube (pe0, 4×4 mesh N/S/E/W) + inter-SIP (pe0, all cubes,
|
||||
global_E/global_W) communication.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from kernbench.ccl.install import load_ccl_config, resolve_algorithm_config
|
||||
from kernbench.ccl.sfr_config import configure_sfr_intercube_multisip
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
from kernbench.topology.builder import resolve_topology
|
||||
|
||||
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
|
||||
|
||||
N_CUBES = 16
|
||||
|
||||
|
||||
def _engine_and_spec():
|
||||
topo = resolve_topology(str(TOPOLOGY_PATH))
|
||||
engine = GraphEngine(topo.topology_obj, enable_data=True)
|
||||
return engine, topo.topology_obj.spec
|
||||
|
||||
|
||||
def _merged_cfg():
|
||||
cfg = load_ccl_config()
|
||||
return resolve_algorithm_config(cfg, name="intercube_allreduce")
|
||||
|
||||
|
||||
class TestConfigureSfrNeighborTables:
|
||||
def test_world_size_and_rank_to_pe(self):
|
||||
engine, spec = _engine_and_spec()
|
||||
cfg = _merged_cfg()
|
||||
plan = configure_sfr_intercube_multisip(engine, spec, cfg)
|
||||
|
||||
n_sips = int(spec["system"]["sips"]["count"])
|
||||
assert plan["world_size"] == n_sips * N_CUBES
|
||||
assert len(plan["rank_to_pe"]) == n_sips * N_CUBES
|
||||
for pe_idx, (sip, cube, pe) in enumerate(plan["rank_to_pe"]):
|
||||
assert pe == 0, f"pe_idx {pe_idx}: pe must be 0, got {pe}"
|
||||
|
||||
def test_corner_cube0_has_E_and_S_only(self):
|
||||
"""Cube 0 (row=0, col=0) is NW corner: only E and S neighbors."""
|
||||
engine, spec = _engine_and_spec()
|
||||
cfg = _merged_cfg()
|
||||
configure_sfr_intercube_multisip(engine, spec, cfg)
|
||||
|
||||
ipcq = engine._components["sip0.cube0.pe0.pe_ipcq"]
|
||||
qp = ipcq.queue_pairs
|
||||
assert "E" in qp, "cube 0 must have E neighbor"
|
||||
assert "S" in qp, "cube 0 must have S neighbor"
|
||||
assert "W" not in qp, "cube 0 (col=0) must NOT have W neighbor"
|
||||
assert "N" not in qp, "cube 0 (row=0) must NOT have N neighbor"
|
||||
assert qp["E"]["peer"].cube == 1
|
||||
assert qp["S"]["peer"].cube == 4
|
||||
|
||||
def test_interior_cube5_has_all_four(self):
|
||||
"""Cube 5 (row=1, col=1) is interior: N/S/E/W all present."""
|
||||
engine, spec = _engine_and_spec()
|
||||
cfg = _merged_cfg()
|
||||
configure_sfr_intercube_multisip(engine, spec, cfg)
|
||||
|
||||
ipcq = engine._components["sip0.cube5.pe0.pe_ipcq"]
|
||||
qp = ipcq.queue_pairs
|
||||
assert qp["N"]["peer"].cube == 1
|
||||
assert qp["S"]["peer"].cube == 9
|
||||
assert qp["E"]["peer"].cube == 6
|
||||
assert qp["W"]["peer"].cube == 4
|
||||
|
||||
def test_root_cube15_has_inter_sip(self):
|
||||
"""Cube 15 (root, SE corner) has N, W + global_E/global_W."""
|
||||
engine, spec = _engine_and_spec()
|
||||
cfg = _merged_cfg()
|
||||
configure_sfr_intercube_multisip(engine, spec, cfg)
|
||||
|
||||
ipcq0 = engine._components["sip0.cube15.pe0.pe_ipcq"]
|
||||
qp0 = ipcq0.queue_pairs
|
||||
assert "N" in qp0
|
||||
assert "W" in qp0
|
||||
assert "E" not in qp0, "cube 15 (col=3) must NOT have E"
|
||||
assert "S" not in qp0, "cube 15 (row=3) must NOT have S"
|
||||
assert "global_E" in qp0, "root cube must have global_E"
|
||||
assert "global_W" in qp0, "root cube must have global_W"
|
||||
assert qp0["global_E"]["peer"].sip == 1
|
||||
assert qp0["global_E"]["peer"].cube == 15
|
||||
|
||||
ipcq1 = engine._components["sip1.cube15.pe0.pe_ipcq"]
|
||||
qp1 = ipcq1.queue_pairs
|
||||
assert qp1["global_E"]["peer"].sip == 0
|
||||
assert qp1["global_E"]["peer"].cube == 15
|
||||
|
||||
def test_all_cubes_have_inter_sip(self):
|
||||
"""ALL cubes (not just root) are wired for inter-SIP."""
|
||||
engine, spec = _engine_and_spec()
|
||||
cfg = _merged_cfg()
|
||||
configure_sfr_intercube_multisip(engine, spec, cfg)
|
||||
|
||||
root_cube = int(cfg.get("root_cube", N_CUBES - 1))
|
||||
for cube_id in range(N_CUBES):
|
||||
ipcq = engine._components[f"sip0.cube{cube_id}.pe0.pe_ipcq"]
|
||||
qp = ipcq.queue_pairs
|
||||
assert "global_E" in qp, (
|
||||
f"sip0.cube{cube_id}.pe0 missing global_E"
|
||||
)
|
||||
assert "global_W" in qp, (
|
||||
f"sip0.cube{cube_id}.pe0 missing global_W"
|
||||
)
|
||||
if cube_id == root_cube:
|
||||
assert qp["global_E"]["peer"].sip != 0, (
|
||||
f"root cube {root_cube} global_E must point to another SIP"
|
||||
)
|
||||
@@ -1,80 +0,0 @@
|
||||
"""Tests for recv_mode='copy_to_dst' (ADR-0023 D9.5)."""
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def test_recv_copy_to_dst_via_simpy_runner():
|
||||
"""Run a kernel that uses tl.recv(..., dst_addr=..., dst_space=...).
|
||||
Verify the data is moved to the dst location after recv.
|
||||
"""
|
||||
import importlib
|
||||
|
||||
from kernbench.policy.placement.dp import DPPolicy
|
||||
from kernbench.runtime_api.bench_runner import run_bench
|
||||
from kernbench.runtime_api.types import resolve_device
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
from kernbench.topology.builder import resolve_topology
|
||||
from kernbench.common.pe_commands import TensorHandle
|
||||
|
||||
def kernel(t_ptr, n_elem, dst_buf_addr, tl):
|
||||
rank = tl.program_id(axis=0)
|
||||
ws = tl.num_programs(axis=0)
|
||||
nbytes = n_elem * 2
|
||||
# Each PE sends own data, then recv into a custom dst slot
|
||||
current = TensorHandle(
|
||||
id="loc", addr=t_ptr + rank * nbytes,
|
||||
shape=(n_elem,), dtype="f16",
|
||||
nbytes=nbytes, data=None, space="hbm",
|
||||
)
|
||||
tl.send(dir="E", src=current)
|
||||
# copy_to_dst: move into a per-rank scratch HBM addr
|
||||
recv = tl.recv(
|
||||
dir="W", shape=(n_elem,), dtype="f16",
|
||||
dst_addr=dst_buf_addr + rank * nbytes,
|
||||
dst_space="hbm",
|
||||
)
|
||||
# Sanity: recv handle should now point to our dst addr
|
||||
assert recv.addr == dst_buf_addr + rank * nbytes
|
||||
assert recv.space == "hbm"
|
||||
|
||||
topo = resolve_topology("topology.yaml")
|
||||
|
||||
def run(torch):
|
||||
plan = torch.install_ipcq(
|
||||
algorithm="ring_allreduce_tcm", world_size_override=8,
|
||||
)
|
||||
a = torch.zeros(
|
||||
(1, 8 * 8),
|
||||
dtype="f16",
|
||||
dp=DPPolicy(
|
||||
sip="replicate", cube="replicate", pe="column_wise",
|
||||
num_sips=1, num_cubes=1,
|
||||
),
|
||||
name="copy_in",
|
||||
)
|
||||
store = torch.engine.memory_store
|
||||
base = a._handle.va_base or a._handle.shards[0].pa
|
||||
nbytes = 8 * 2
|
||||
for r in range(8):
|
||||
store.write("hbm", base + r * nbytes,
|
||||
np.full((8,), float(r + 1), dtype=np.float16))
|
||||
|
||||
# Use a separate dst region (synthetic addresses)
|
||||
dst_buf = 0xC0FFEE_0000
|
||||
torch.launch("ring_allreduce_tcm", kernel, a, 8, dst_buf)
|
||||
|
||||
# After the kernel, dst_buf + r*16 should contain rank (r-1)%8's data
|
||||
for r in range(8):
|
||||
arr = store.read("hbm", dst_buf + r * nbytes, shape=(8,), dtype="f16")
|
||||
expected = float(((r - 1) % 8) + 1)
|
||||
assert np.allclose(arr, expected), f"rank {r}: got {arr}, expected {expected}"
|
||||
|
||||
result = run_bench(
|
||||
topology=topo, bench_fn=run,
|
||||
device=resolve_device("all"),
|
||||
engine_factory=lambda t, d: GraphEngine(
|
||||
getattr(t, "topology_obj", t), enable_data=True
|
||||
),
|
||||
)
|
||||
assert result.completion.ok
|
||||
@@ -48,8 +48,8 @@ def test_from_numpy_creates_host_tensor():
|
||||
assert h._handle is None
|
||||
# Submit a no-op so run_bench has at least one handle.
|
||||
torch.zeros((1, 8), dtype="f16",
|
||||
dp=DPPolicy(sip="replicate", cube="replicate", pe="replicate",
|
||||
num_sips=1, num_cubes=1, num_pes=1),
|
||||
dp=DPPolicy(cube="replicate", pe="replicate",
|
||||
num_cubes=1, num_pes=1),
|
||||
name="dummy")
|
||||
|
||||
_run_with(body)
|
||||
@@ -63,8 +63,8 @@ def test_copy_and_numpy_single_pe():
|
||||
a single-PE (no real sharding) tensor."""
|
||||
|
||||
def body(torch):
|
||||
dp = DPPolicy(sip="replicate", cube="replicate", pe="replicate",
|
||||
num_sips=1, num_cubes=1, num_pes=1)
|
||||
dp = DPPolicy(cube="replicate", pe="replicate",
|
||||
num_cubes=1, num_pes=1)
|
||||
t = torch.zeros((1, 16), dtype="f16", dp=dp, name="t")
|
||||
src = np.arange(16, dtype=np.float16).reshape(1, 16)
|
||||
t.copy_(torch.from_numpy(src))
|
||||
@@ -83,8 +83,8 @@ def test_copy_and_numpy_multi_pe_column_wise():
|
||||
|
||||
def body(torch):
|
||||
n_pe = 8
|
||||
dp = DPPolicy(sip="replicate", cube="replicate", pe="column_wise",
|
||||
num_sips=1, num_cubes=1, num_pes=n_pe)
|
||||
dp = DPPolicy(cube="replicate", pe="column_wise",
|
||||
num_cubes=1, num_pes=n_pe)
|
||||
t = torch.zeros((1, n_pe * 4), dtype="f16", dp=dp, name="t")
|
||||
src = np.arange(n_pe * 4, dtype=np.float16).reshape(1, n_pe * 4)
|
||||
t.copy_(torch.from_numpy(src))
|
||||
@@ -107,8 +107,8 @@ def test_copy_and_numpy_multi_cube():
|
||||
n_pe_per_cube = 8
|
||||
n_cubes = 2
|
||||
total = n_cubes * n_pe_per_cube # 16
|
||||
dp = DPPolicy(sip="replicate", cube="column_wise", pe="column_wise",
|
||||
num_sips=1, num_cubes=n_cubes)
|
||||
dp = DPPolicy(cube="column_wise", pe="column_wise",
|
||||
num_cubes=n_cubes)
|
||||
t = torch.zeros((1, total * 4), dtype="f16", dp=dp, name="t")
|
||||
src = np.arange(total * 4, dtype=np.float16).reshape(1, total * 4)
|
||||
t.copy_(torch.from_numpy(src))
|
||||
@@ -126,8 +126,8 @@ def test_copy_shape_mismatch_raises():
|
||||
"""copy_ with mismatched shapes raises ValueError."""
|
||||
|
||||
def body(torch):
|
||||
dp = DPPolicy(sip="replicate", cube="replicate", pe="replicate",
|
||||
num_sips=1, num_cubes=1, num_pes=1)
|
||||
dp = DPPolicy(cube="replicate", pe="replicate",
|
||||
num_cubes=1, num_pes=1)
|
||||
t = torch.zeros((1, 8), dtype="f16", dp=dp, name="t")
|
||||
src = np.zeros((1, 16), dtype=np.float16)
|
||||
with pytest.raises(ValueError, match="copy_ shape mismatch"):
|
||||
@@ -143,8 +143,8 @@ def test_setitem_getitem_single_pe():
|
||||
"""Scalar and slice assignment on a single-PE tensor round-trips."""
|
||||
|
||||
def body(torch):
|
||||
dp = DPPolicy(sip="replicate", cube="replicate", pe="replicate",
|
||||
num_sips=1, num_cubes=1, num_pes=1)
|
||||
dp = DPPolicy(cube="replicate", pe="replicate",
|
||||
num_cubes=1, num_pes=1)
|
||||
t = torch.zeros((1, 8), dtype="f16", dp=dp, name="t")
|
||||
|
||||
# Scalar broadcast
|
||||
@@ -169,8 +169,8 @@ def test_setitem_getitem_multi_pe_shard_aligned():
|
||||
def body(torch):
|
||||
n_pe = 8
|
||||
n_elem = 4 # per shard
|
||||
dp = DPPolicy(sip="replicate", cube="replicate", pe="column_wise",
|
||||
num_sips=1, num_cubes=1, num_pes=n_pe)
|
||||
dp = DPPolicy(cube="replicate", pe="column_wise",
|
||||
num_cubes=1, num_pes=n_pe)
|
||||
t = torch.zeros((1, n_pe * n_elem), dtype="f16", dp=dp, name="t")
|
||||
|
||||
# Write each shard with its rank value
|
||||
@@ -197,8 +197,8 @@ def test_setitem_cross_shard_raises():
|
||||
def body(torch):
|
||||
n_pe = 4
|
||||
n_elem = 4
|
||||
dp = DPPolicy(sip="replicate", cube="replicate", pe="column_wise",
|
||||
num_sips=1, num_cubes=1, num_pes=n_pe)
|
||||
dp = DPPolicy(cube="replicate", pe="column_wise",
|
||||
num_cubes=1, num_pes=n_pe)
|
||||
t = torch.zeros((1, n_pe * n_elem), dtype="f16", dp=dp, name="t")
|
||||
with pytest.raises(NotImplementedError, match="spans multiple shards"):
|
||||
t[0, 2:6] = 1.0 # crosses shard 0 (0:4) and shard 1 (4:8)
|
||||
|
||||
+90
-127
@@ -1,157 +1,120 @@
|
||||
"""Tests for SIP-level tensor parallelism.
|
||||
"""Tests for SIP-level tensor parallelism — ADR-0026 structural model.
|
||||
|
||||
Validates:
|
||||
SP1. DPPolicy accepts sip field (default "replicate", backward compat)
|
||||
SP2. sip="column_wise": tensor K-axis split across SIPs, each SIP gets K//num_sips
|
||||
SP3. sip="row_wise": tensor M-axis split across SIPs
|
||||
SP4. 3-level resolve: sip × cube × pe produces correct flat indices and offsets
|
||||
SP5. sip="replicate": all SIPs get full copy (existing behavior)
|
||||
SP6. PE_CPU sets num_programs from shard count per cube
|
||||
SP7. End-to-end: TP kernel with sip="column_wise" completes on multi-SIP topology
|
||||
DPPolicy no longer carries a ``sip`` axis (ADR-0026 D1). SIP placement is
|
||||
now expressed structurally: each call to ``resolve_dp_policy(target_sip=N)``
|
||||
emits shards pinned to SIP N. Multi-SIP parallelism is composed by calling
|
||||
the resolver once per SIP (typically driven by the ADR-0024 launcher, one
|
||||
worker greenlet per rank, each worker using ``torch.ahbm.set_device(rank)``).
|
||||
|
||||
Covered here:
|
||||
SP1. ``target_sip`` stamps every shard.
|
||||
SP2. Two-SIP placement: union of two resolver calls covers the whole
|
||||
tensor K-axis when the combined bench treats them as column-split.
|
||||
SP3. Same for row-wise.
|
||||
SP4. Cube + PE sharding within a SIP remains correct across SIPs.
|
||||
SP5. PE_CPU num_programs contract (unchanged by ADR-0026).
|
||||
"""
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from __future__ import annotations
|
||||
|
||||
from kernbench.policy.placement.dp import DPPolicy, ShardSpec, resolve_dp_policy
|
||||
from kernbench.policy.placement.dp import DPPolicy, resolve_dp_policy
|
||||
|
||||
|
||||
# ── SP1. DPPolicy sip field ──────────────────────────────────────────
|
||||
# ── SP1. target_sip stamps shards ────────────────────────────────────
|
||||
|
||||
|
||||
def test_dp_policy_sip_default_replicate():
|
||||
"""DPPolicy without sip= defaults to 'replicate'."""
|
||||
def test_target_sip_stamps_all_shards():
|
||||
dp = DPPolicy(cube="replicate", pe="column_wise")
|
||||
assert dp.sip == "replicate"
|
||||
|
||||
|
||||
def test_dp_policy_sip_column_wise():
|
||||
"""DPPolicy accepts sip='column_wise'."""
|
||||
dp = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise")
|
||||
assert dp.sip == "column_wise"
|
||||
|
||||
|
||||
# ── SP2. sip="column_wise" ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_sip_column_wise_splits_across_sips():
|
||||
"""sip='column_wise' with 2 SIPs: each SIP gets K//2 columns."""
|
||||
dp = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise")
|
||||
shards = resolve_dp_policy(
|
||||
dp, shape=(128, 256), itemsize=2,
|
||||
num_pe=8, num_cubes=1, num_sips=2,
|
||||
num_pe=8, num_cubes=1, target_sip=3,
|
||||
)
|
||||
# 2 SIPs × 1 cube × 8 PEs = 16 shards
|
||||
assert len(shards) == 16
|
||||
|
||||
# SIP0 shards: first half of K (0 to K//2)
|
||||
# SIP1 shards: second half of K (K//2 to K)
|
||||
total_bytes = 128 * 256 * 2 # 64KB
|
||||
sip0_shards = [s for s in shards if s.pe_index < 8]
|
||||
sip1_shards = [s for s in shards if s.pe_index >= 8]
|
||||
|
||||
# SIP0 offsets start at 0
|
||||
assert sip0_shards[0].offset_bytes == 0
|
||||
# SIP1 offsets start at half
|
||||
assert sip1_shards[0].offset_bytes == total_bytes // 2
|
||||
|
||||
# Total coverage
|
||||
assert sum(s.nbytes for s in sip0_shards) == total_bytes // 2
|
||||
assert sum(s.nbytes for s in sip1_shards) == total_bytes // 2
|
||||
assert all(s.sip == 3 for s in shards)
|
||||
assert all(0 <= s.pe < 8 for s in shards)
|
||||
assert all(s.cube == 0 for s in shards)
|
||||
|
||||
|
||||
# ── SP3. sip="row_wise" ──────────────────────────────────────────────
|
||||
# ── SP2. column-wise placement composed across two SIPs ─────────────
|
||||
|
||||
|
||||
def test_sip_row_wise_splits_across_sips():
|
||||
"""sip='row_wise' with 2 SIPs: each SIP gets M//2 rows."""
|
||||
dp = DPPolicy(sip="row_wise", cube="replicate", pe="column_wise")
|
||||
shards = resolve_dp_policy(
|
||||
def test_compose_two_sips_column_wise_covers_tensor():
|
||||
"""Bench splits K-axis across 2 SIPs by calling resolve twice and
|
||||
giving each SIP half of the tensor (half-shape + offset). Shards
|
||||
from both SIPs together cover the whole K axis."""
|
||||
full_shape = (128, 256)
|
||||
itemsize = 2
|
||||
# Per-SIP half-shape (K split across SIPs).
|
||||
half_shape = (128, 128)
|
||||
dp = DPPolicy(cube="replicate", pe="column_wise")
|
||||
|
||||
shards_sip0 = resolve_dp_policy(
|
||||
dp, shape=half_shape, itemsize=itemsize,
|
||||
num_pe=8, num_cubes=1, target_sip=0,
|
||||
)
|
||||
shards_sip1 = resolve_dp_policy(
|
||||
dp, shape=half_shape, itemsize=itemsize,
|
||||
num_pe=8, num_cubes=1, target_sip=1,
|
||||
)
|
||||
|
||||
total_bytes = full_shape[0] * full_shape[1] * itemsize
|
||||
sip0_bytes = sum(s.nbytes for s in shards_sip0)
|
||||
sip1_bytes = sum(s.nbytes for s in shards_sip1)
|
||||
assert sip0_bytes + sip1_bytes == total_bytes
|
||||
assert all(s.sip == 0 for s in shards_sip0)
|
||||
assert all(s.sip == 1 for s in shards_sip1)
|
||||
|
||||
|
||||
# ── SP3. row-wise placement composed across two SIPs ────────────────
|
||||
|
||||
|
||||
def test_compose_two_sips_row_wise_covers_tensor():
|
||||
full_shape = (128, 256)
|
||||
itemsize = 2
|
||||
half_shape = (64, 256) # per-SIP half of M
|
||||
dp = DPPolicy(cube="replicate", pe="column_wise")
|
||||
|
||||
shards_sip0 = resolve_dp_policy(
|
||||
dp, shape=half_shape, itemsize=itemsize,
|
||||
num_pe=8, num_cubes=1, target_sip=0,
|
||||
)
|
||||
shards_sip1 = resolve_dp_policy(
|
||||
dp, shape=half_shape, itemsize=itemsize,
|
||||
num_pe=8, num_cubes=1, target_sip=1,
|
||||
)
|
||||
|
||||
total_bytes = full_shape[0] * full_shape[1] * itemsize
|
||||
assert sum(s.nbytes for s in shards_sip0) + sum(s.nbytes for s in shards_sip1) == total_bytes
|
||||
|
||||
|
||||
# ── SP4. cube × PE sharding is independent per SIP ──────────────────
|
||||
|
||||
|
||||
def test_cube_pe_sharding_independent_per_sip():
|
||||
"""Intra-SIP cube + PE layout matches across SIPs; only sip field differs."""
|
||||
dp = DPPolicy(cube="column_wise", pe="column_wise")
|
||||
s0 = resolve_dp_policy(
|
||||
dp, shape=(128, 256), itemsize=2,
|
||||
num_pe=8, num_cubes=1, num_sips=2,
|
||||
num_pe=4, num_cubes=2, target_sip=0,
|
||||
)
|
||||
assert len(shards) == 16
|
||||
|
||||
sip0_shards = [s for s in shards if s.pe_index < 8]
|
||||
sip1_shards = [s for s in shards if s.pe_index >= 8]
|
||||
|
||||
# SIP0: rows 0..63, SIP1: rows 64..127
|
||||
total_bytes = 128 * 256 * 2
|
||||
assert sip0_shards[0].offset_bytes == 0
|
||||
assert sip1_shards[0].offset_bytes == total_bytes // 2
|
||||
|
||||
|
||||
# ── SP4. 3-level resolve ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_3level_resolve_flat_index():
|
||||
"""3-level: sip × cube × pe produces correct flat indices."""
|
||||
dp = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise")
|
||||
shards = resolve_dp_policy(
|
||||
s1 = resolve_dp_policy(
|
||||
dp, shape=(128, 256), itemsize=2,
|
||||
num_pe=8, num_cubes=2, num_sips=2,
|
||||
num_pe=4, num_cubes=2, target_sip=1,
|
||||
)
|
||||
# 2 SIPs × 2 cubes × 8 PEs = 32 shards
|
||||
assert len(shards) == 32
|
||||
|
||||
# Flat index: sip_id * cubes_per_sip * num_pe + cube_id * num_pe + pe_id
|
||||
indices = [s.pe_index for s in shards]
|
||||
# SIP0: 0..15, SIP1: 16..31
|
||||
assert min(indices) == 0
|
||||
assert max(indices) == 31
|
||||
assert len(set(indices)) == 32 # all unique
|
||||
|
||||
|
||||
def test_3level_offsets_cover_full_tensor():
|
||||
"""3-level sharding covers the entire tensor with no gaps."""
|
||||
dp = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise")
|
||||
shards = resolve_dp_policy(
|
||||
dp, shape=(128, 256), itemsize=2,
|
||||
num_pe=4, num_cubes=1, num_sips=2,
|
||||
assert len(s0) == len(s1) == 2 * 4
|
||||
for a, b in zip(s0, s1):
|
||||
assert a.sip == 0 and b.sip == 1
|
||||
assert (a.cube, a.pe, a.offset_bytes, a.nbytes) == (
|
||||
b.cube, b.pe, b.offset_bytes, b.nbytes
|
||||
)
|
||||
# 2 SIPs × 1 cube × 4 PEs = 8 shards
|
||||
# sip="column_wise": K=128 per SIP, pe="column_wise": 32 cols per PE
|
||||
total = 128 * 256 * 2
|
||||
# For non-replicate, total shard bytes == tensor bytes
|
||||
# (replicate within cube means cube shards overlap, but sip shards don't)
|
||||
sip0_bytes = sum(s.nbytes for s in shards if s.pe_index < 4)
|
||||
sip1_bytes = sum(s.nbytes for s in shards if s.pe_index >= 4)
|
||||
assert sip0_bytes + sip1_bytes == total
|
||||
|
||||
|
||||
# ── SP5. sip="replicate" backward compat ─────────────────────────────
|
||||
|
||||
|
||||
def test_sip_replicate_backward_compat():
|
||||
"""sip='replicate' produces same result as before (2-level)."""
|
||||
dp_old = DPPolicy(cube="replicate", pe="column_wise")
|
||||
dp_new = DPPolicy(sip="replicate", cube="replicate", pe="column_wise")
|
||||
|
||||
shards_old = resolve_dp_policy(
|
||||
dp_old, shape=(128, 256), itemsize=2,
|
||||
num_pe=8, num_cubes=2, num_sips=2,
|
||||
)
|
||||
shards_new = resolve_dp_policy(
|
||||
dp_new, shape=(128, 256), itemsize=2,
|
||||
num_pe=8, num_cubes=2, num_sips=2,
|
||||
)
|
||||
assert len(shards_old) == len(shards_new)
|
||||
for a, b in zip(shards_old, shards_new):
|
||||
assert a.pe_index == b.pe_index
|
||||
assert a.offset_bytes == b.offset_bytes
|
||||
assert a.nbytes == b.nbytes
|
||||
|
||||
|
||||
# ── SP6. PE_CPU num_programs ──────────────────────────────────────────
|
||||
# ── SP5. PE_CPU num_programs (contract unchanged) ───────────────────
|
||||
|
||||
|
||||
def test_pe_cpu_sets_num_programs():
|
||||
"""PE_CPU should create TLContext with num_programs = PEs per cube."""
|
||||
# This test validates the interface contract.
|
||||
# After implementation, PE_CPU should derive num_programs from the
|
||||
# number of PE shards in the kernel launch's target cube.
|
||||
"""TLContext reports num_programs from its initializer — used by PE_CPU
|
||||
when it launches a kernel on behalf of its shards."""
|
||||
from kernbench.triton_emu.tl_context import TLContext
|
||||
|
||||
# With 8 PEs per cube, num_programs should be 8
|
||||
tl = TLContext(pe_id=3, num_programs=8)
|
||||
assert tl.program_id(0) == 3
|
||||
assert tl.num_programs(0) == 8
|
||||
|
||||
+23
-17
@@ -2,11 +2,13 @@ import pytest
|
||||
|
||||
from kernbench.policy.address.allocator import AddressConfig, AllocationError, PEMemAllocator
|
||||
from kernbench.policy.placement.dp import (
|
||||
DPPolicy,
|
||||
ShardSpec,
|
||||
column_wise,
|
||||
tiled_column_major,
|
||||
replicate,
|
||||
resolve_dp_policy,
|
||||
row_wise,
|
||||
tiled_column_major,
|
||||
tiled_row_major,
|
||||
)
|
||||
from kernbench.runtime_api.kernel import (
|
||||
@@ -40,9 +42,9 @@ _CFG = AddressConfig(
|
||||
)
|
||||
|
||||
|
||||
def _make_allocators(num_pe: int = 8) -> dict[int, PEMemAllocator]:
|
||||
def _make_allocators(num_pe: int = 8) -> dict[tuple[int, int, int], PEMemAllocator]:
|
||||
return {
|
||||
i: PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=i, cfg=_CFG)
|
||||
(0, 0, i): PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=i, cfg=_CFG)
|
||||
for i in range(num_pe)
|
||||
}
|
||||
|
||||
@@ -133,7 +135,7 @@ def test_column_wise_placement():
|
||||
assert len(shards) == 8
|
||||
expected_nbytes = 1024 * 64 * 2 # 128 KB
|
||||
for i, s in enumerate(shards):
|
||||
assert s.pe_index == i
|
||||
assert s.local_pe == i
|
||||
assert s.nbytes == expected_nbytes
|
||||
# offsets are contiguous
|
||||
assert shards[0].offset_bytes == 0
|
||||
@@ -151,7 +153,7 @@ def test_row_wise_placement():
|
||||
assert len(shards) == 8
|
||||
expected_nbytes = 128 * 512 * 2 # 128 KB
|
||||
for i, s in enumerate(shards):
|
||||
assert s.pe_index == i
|
||||
assert s.local_pe == i
|
||||
assert s.nbytes == expected_nbytes
|
||||
assert shards[0].offset_bytes == 0
|
||||
assert sum(s.nbytes for s in shards) == 1024 * 512 * 2
|
||||
@@ -166,7 +168,7 @@ def test_replicate_placement():
|
||||
assert len(shards) == 8
|
||||
full_nbytes = 1024 * 512 * 2 # 1 MB
|
||||
for i, s in enumerate(shards):
|
||||
assert s.pe_index == i
|
||||
assert s.local_pe == i
|
||||
assert s.nbytes == full_nbytes
|
||||
assert s.offset_bytes == 0 # each is a full copy
|
||||
|
||||
@@ -188,10 +190,10 @@ def test_tiled_column_major():
|
||||
# tile (m=0,k=0) → PE0, tile (m=0,k=1) → PE1, ..., (m=0,k=3) → PE3
|
||||
# tile (m=1,k=0) → PE4, tile (m=1,k=1) → PE5, ..., (m=1,k=3) → PE7
|
||||
# tile (m=2,k=0) → PE0, ...
|
||||
assert shards[0].pe_index == 0
|
||||
assert shards[1].pe_index == 1
|
||||
assert shards[7].pe_index == 7
|
||||
assert shards[8].pe_index == 0 # wraps around
|
||||
assert shards[0].local_pe == 0
|
||||
assert shards[1].local_pe == 1
|
||||
assert shards[7].local_pe == 7
|
||||
assert shards[8].local_pe == 0 # wraps around
|
||||
# total coverage
|
||||
assert sum(s.nbytes for s in shards) == 1024 * 512 * 2
|
||||
|
||||
@@ -212,10 +214,10 @@ def test_tiled_row_major():
|
||||
# tile (m=0,k=0) → PE0, tile (m=1,k=0) → PE1, ..., (m=3,k=0) → PE3
|
||||
# tile (m=0,k=1) → PE4, tile (m=1,k=1) → PE5, ..., (m=3,k=1) → PE7
|
||||
# tile (m=0,k=2) → PE0, ...
|
||||
assert shards[0].pe_index == 0
|
||||
assert shards[1].pe_index == 1
|
||||
assert shards[7].pe_index == 7
|
||||
assert shards[8].pe_index == 0 # wraps around
|
||||
assert shards[0].local_pe == 0
|
||||
assert shards[1].local_pe == 1
|
||||
assert shards[7].local_pe == 7
|
||||
assert shards[8].local_pe == 0 # wraps around
|
||||
# total coverage
|
||||
assert sum(s.nbytes for s in shards) == 1024 * 512 * 2
|
||||
|
||||
@@ -226,7 +228,11 @@ def test_tiled_row_major():
|
||||
def test_deploy_tensor_hbm():
|
||||
"""Deploy with column_wise placement → TensorHandle with valid PA shards."""
|
||||
allocs = _make_allocators()
|
||||
placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8)
|
||||
placement = resolve_dp_policy(
|
||||
DPPolicy(cube="replicate", pe="column_wise"),
|
||||
shape=(1024, 512), itemsize=2,
|
||||
num_pe=8, num_cubes=1, target_sip=0,
|
||||
)
|
||||
th = deploy_tensor(
|
||||
name="W",
|
||||
shape=(1024, 512),
|
||||
@@ -253,7 +259,7 @@ def test_deploy_tensor_hbm():
|
||||
def test_deploy_tensor_tcm():
|
||||
"""Deploy with TCM → uses pe_tcm_addr allocation."""
|
||||
allocs = _make_allocators()
|
||||
placement = [ShardSpec(pe_index=0, offset_bytes=0, nbytes=256)]
|
||||
placement = [ShardSpec(sip=0, cube=0, pe=0, offset_bytes=0, nbytes=256)]
|
||||
th = deploy_tensor(
|
||||
name="small",
|
||||
shape=(128,),
|
||||
@@ -271,7 +277,7 @@ def test_deploy_tensor_overflow():
|
||||
"""Allocation exceeding PE HBM capacity raises AllocationError."""
|
||||
allocs = _make_allocators()
|
||||
# 6 GB per PE slice, try to allocate 7 GB
|
||||
big_shard = ShardSpec(pe_index=0, offset_bytes=0, nbytes=7 * _GB)
|
||||
big_shard = ShardSpec(sip=0, cube=0, pe=0, offset_bytes=0, nbytes=7 * _GB)
|
||||
with pytest.raises(AllocationError):
|
||||
deploy_tensor(
|
||||
name="toobig",
|
||||
|
||||
@@ -1,106 +0,0 @@
|
||||
"""Tests for tl.recv_async + tl.wait (ADR-0023 D4)."""
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
from kernbench.ccl.testing import run_kernel_in_mock
|
||||
|
||||
|
||||
def kernel_async_recv(t_ptr, n_elem, tl):
|
||||
"""Each PE issues recv_async first, then send, then wait — this exercises
|
||||
the non-blocking path. Uses TensorHandle math (PE_MATH) for accumulation
|
||||
so Phase 2 produces correct final HBM contents."""
|
||||
rank = tl.program_id(axis=0)
|
||||
world_size = tl.num_programs(axis=0)
|
||||
nbytes = n_elem * 2
|
||||
|
||||
pe_addr = t_ptr + rank * nbytes
|
||||
acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16")
|
||||
current = acc
|
||||
|
||||
for _step in range(world_size - 1):
|
||||
future = tl.recv_async(dir="W", shape=(n_elem,), dtype="f16")
|
||||
tl.send(dir="E", src=current)
|
||||
recv = tl.wait(future)
|
||||
acc = acc + recv
|
||||
current = recv # forward W's tile to E next round
|
||||
|
||||
tl.store(pe_addr, acc)
|
||||
|
||||
|
||||
def test_recv_async_mock_runtime():
|
||||
n_elem = 8
|
||||
inputs = [
|
||||
np.full((n_elem,), float(r + 1), dtype=np.float16)
|
||||
for r in range(4)
|
||||
]
|
||||
expected = sum(inputs)
|
||||
|
||||
outputs = run_kernel_in_mock(
|
||||
kernel_fn=kernel_async_recv,
|
||||
world_size=4,
|
||||
topology="ring_1d",
|
||||
inputs=inputs,
|
||||
kernel_args=(n_elem,),
|
||||
)
|
||||
for r in range(4):
|
||||
assert np.allclose(outputs[r], expected)
|
||||
|
||||
|
||||
def test_recv_async_simpy_runner():
|
||||
"""Run the async kernel through the real SimPy stack via the
|
||||
install_ipcq + launch path.
|
||||
"""
|
||||
import importlib
|
||||
|
||||
from kernbench.runtime_api.bench_runner import run_bench
|
||||
from kernbench.runtime_api.types import resolve_device
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
from kernbench.topology.builder import resolve_topology
|
||||
|
||||
# Re-use the standard 8-PE bench skeleton but swap in the async kernel.
|
||||
topo = resolve_topology("topology.yaml")
|
||||
|
||||
# Build a tiny inline bench module
|
||||
import types
|
||||
mod = types.ModuleType("inline_bench_async")
|
||||
|
||||
from kernbench.policy.placement.dp import DPPolicy
|
||||
|
||||
def run(torch):
|
||||
plan = torch.install_ipcq(
|
||||
algorithm="ring_allreduce_tcm", world_size_override=8,
|
||||
)
|
||||
a = torch.zeros(
|
||||
(1, 8 * 8),
|
||||
dtype="f16",
|
||||
dp=DPPolicy(
|
||||
sip="replicate", cube="replicate", pe="column_wise",
|
||||
num_sips=1, num_cubes=1,
|
||||
),
|
||||
name="async_in",
|
||||
)
|
||||
store = torch.engine.memory_store
|
||||
base = a._handle.va_base or a._handle.shards[0].pa
|
||||
nbytes = 8 * 2
|
||||
for r in range(8):
|
||||
store.write("hbm", base + r * nbytes,
|
||||
np.full((8,), float(r + 1), dtype=np.float16))
|
||||
|
||||
torch.launch("ring_allreduce_tcm", kernel_async_recv, a, 8)
|
||||
|
||||
for r in range(8):
|
||||
result = store.read("hbm", base + r * nbytes, shape=(8,), dtype="f16")
|
||||
expected = float(sum(range(1, 9))) # 36
|
||||
assert np.allclose(result, expected, rtol=1e-2, atol=1e-2), \
|
||||
f"rank {r}: got {result}, expected {expected}"
|
||||
|
||||
mod.run = run
|
||||
result = run_bench(
|
||||
topology=topo, bench_fn=mod.run,
|
||||
device=resolve_device("all"),
|
||||
engine_factory=lambda t, d: GraphEngine(
|
||||
getattr(t, "topology_obj", t), enable_data=True
|
||||
),
|
||||
)
|
||||
assert result.completion.ok
|
||||
@@ -0,0 +1,234 @@
|
||||
"""ADR-0027 T2: TP layer shape + numerical correctness (D4/D5).
|
||||
|
||||
Phase 1: ``kernbench.tp.layers`` doesn't exist → import failure. Phase 2
|
||||
lands D4/D5 and T2 passes with deterministic non-zero weight patterns.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
def _make_ctx(topology):
|
||||
from kernbench.runtime_api.context import RuntimeContext
|
||||
from kernbench.runtime_api.types import DeviceSelector
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
|
||||
engine = GraphEngine(topology.topology_obj, enable_data=True)
|
||||
return RuntimeContext(
|
||||
engine=engine,
|
||||
target_device=DeviceSelector("all"),
|
||||
correlation_id="test_t2",
|
||||
spec=topology.topology_obj.spec,
|
||||
)
|
||||
|
||||
|
||||
# ── Shape / structural ───────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_column_parallel_weight_shape_per_rank(topology):
|
||||
"""ColumnParallelLinear weight per rank is (in_features, out // ws)."""
|
||||
import kernbench.tp as tp
|
||||
from kernbench.runtime_api.tensor import Tensor
|
||||
|
||||
with _make_ctx(topology) as ctx:
|
||||
ctx.distributed.init_process_group(backend="ahbm")
|
||||
ws = ctx.distributed.get_world_size()
|
||||
tp.initialize_model_parallel(ws)
|
||||
|
||||
def _worker(rank: int):
|
||||
ctx.ahbm.set_device(rank)
|
||||
fc = tp.ColumnParallelLinear(
|
||||
in_features=256, out_features=512, torch=ctx,
|
||||
)
|
||||
assert fc.weight.shape == (256, 512 // ws)
|
||||
|
||||
ctx.multiprocessing.spawn(_worker, args=(), nprocs=ws)
|
||||
|
||||
|
||||
def test_row_parallel_weight_shape_per_rank(topology):
|
||||
"""RowParallelLinear weight per rank is (in_features // ws, out_features)."""
|
||||
import kernbench.tp as tp
|
||||
|
||||
with _make_ctx(topology) as ctx:
|
||||
ctx.distributed.init_process_group(backend="ahbm")
|
||||
ws = ctx.distributed.get_world_size()
|
||||
tp.initialize_model_parallel(ws)
|
||||
|
||||
def _worker(rank: int):
|
||||
ctx.ahbm.set_device(rank)
|
||||
fc = tp.RowParallelLinear(
|
||||
in_features=512, out_features=256, torch=ctx,
|
||||
)
|
||||
assert fc.weight.shape == (512 // ws, 256)
|
||||
|
||||
ctx.multiprocessing.spawn(_worker, args=(), nprocs=ws)
|
||||
|
||||
|
||||
# ── T2.a: ColumnParallel deterministic numerical ─────────────────────
|
||||
|
||||
|
||||
def test_column_parallel_forward_matches_matmul(topology):
|
||||
"""T2.a: ColumnParallelLinear.forward output == x @ W_rank (rtol 1e-2)."""
|
||||
import kernbench.tp as tp
|
||||
from kernbench.runtime_api.tensor import Tensor
|
||||
|
||||
with _make_ctx(topology) as ctx:
|
||||
ctx.distributed.init_process_group(backend="ahbm")
|
||||
ws = ctx.distributed.get_world_size()
|
||||
tp.initialize_model_parallel(ws)
|
||||
|
||||
M = 4
|
||||
D_in, D_out = 32, 32 * ws
|
||||
|
||||
def _worker(rank: int):
|
||||
ctx.ahbm.set_device(rank)
|
||||
fc = tp.ColumnParallelLinear(
|
||||
in_features=D_in, out_features=D_out, torch=ctx,
|
||||
)
|
||||
# Deterministic non-zero weight: rank-scaled constant.
|
||||
k_local = D_out // ws
|
||||
weight_np = np.full(
|
||||
(D_in, k_local), 0.01 * (rank + 1), dtype=np.float16,
|
||||
)
|
||||
src = Tensor(shape=(D_in, k_local), dtype="f16", name="host_w")
|
||||
src._host_buffer = weight_np
|
||||
fc.weight.copy_(src)
|
||||
|
||||
# Input: full-replicated constant.
|
||||
x_np = np.full((M, D_in), 0.5, dtype=np.float16)
|
||||
x = ctx.zeros(
|
||||
(M, D_in), dtype="f16",
|
||||
dp=_replicate_dp(), name=f"t2a_x_r{rank}",
|
||||
)
|
||||
hx = Tensor(shape=x_np.shape, dtype="f16", name="host_x")
|
||||
hx._host_buffer = x_np
|
||||
x.copy_(hx)
|
||||
|
||||
y = fc.forward(x)
|
||||
out = y.numpy()
|
||||
|
||||
expected = x_np.astype(np.float32) @ weight_np.astype(np.float32)
|
||||
assert out.shape == (M, k_local)
|
||||
assert np.allclose(out.astype(np.float32), expected,
|
||||
rtol=1e-2, atol=1e-2), (
|
||||
f"rank {rank}: output does not match x @ W_local"
|
||||
)
|
||||
|
||||
ctx.multiprocessing.spawn(_worker, args=(), nprocs=ws)
|
||||
|
||||
|
||||
# ── T2.b: RowParallel observable equality ────────────────────────────
|
||||
|
||||
|
||||
def test_row_parallel_forward_concat_matmul_equality(topology):
|
||||
"""T2.b (primary): RowParallel output == concat(x) @ concat(W) (all-reduced)."""
|
||||
import kernbench.tp as tp
|
||||
from kernbench.runtime_api.tensor import Tensor
|
||||
|
||||
with _make_ctx(topology) as ctx:
|
||||
ctx.distributed.init_process_group(backend="ahbm")
|
||||
ws = ctx.distributed.get_world_size()
|
||||
tp.initialize_model_parallel(ws)
|
||||
|
||||
M = 4
|
||||
D_in, D_out = 32 * ws, 32 # must divide ws evenly
|
||||
results: dict[int, np.ndarray] = {}
|
||||
|
||||
def _worker(rank: int):
|
||||
ctx.ahbm.set_device(rank)
|
||||
fc = tp.RowParallelLinear(
|
||||
in_features=D_in, out_features=D_out, torch=ctx,
|
||||
)
|
||||
# Per-rank W_k = constant 0.01 * (rank + 1)
|
||||
n_local = D_in // ws
|
||||
weight_np = np.full(
|
||||
(n_local, D_out), 0.01 * (rank + 1), dtype=np.float16,
|
||||
)
|
||||
src = Tensor(shape=weight_np.shape, dtype="f16", name="host_w")
|
||||
src._host_buffer = weight_np
|
||||
fc.weight.copy_(src)
|
||||
|
||||
# Input x_k = constant 0.1 * (rank + 1) (pretending it was
|
||||
# column-sharded from upstream).
|
||||
x_np = np.full((M, n_local), 0.1 * (rank + 1), dtype=np.float16)
|
||||
x = ctx.zeros(
|
||||
(M, n_local), dtype="f16",
|
||||
dp=_replicate_dp(), name=f"t2b_x_r{rank}",
|
||||
)
|
||||
hx = Tensor(shape=x_np.shape, dtype="f16", name="host_x")
|
||||
hx._host_buffer = x_np
|
||||
x.copy_(hx)
|
||||
|
||||
y = fc.forward(x)
|
||||
results[rank] = y.numpy().astype(np.float32)
|
||||
|
||||
ctx.multiprocessing.spawn(_worker, args=(), nprocs=ws)
|
||||
|
||||
# Host-side reference: compute sum_r (x_r @ W_r) = y (same on all ranks).
|
||||
expected = np.zeros((M, D_out), dtype=np.float32)
|
||||
n_local = D_in // ws
|
||||
for r in range(ws):
|
||||
x_r = np.full((M, n_local), 0.1 * (r + 1), dtype=np.float32)
|
||||
w_r = np.full((n_local, D_out), 0.01 * (r + 1), dtype=np.float32)
|
||||
expected += x_r @ w_r
|
||||
|
||||
for r, out in results.items():
|
||||
assert np.allclose(out, expected, rtol=1e-2, atol=1e-2), (
|
||||
f"rank {r}: all-reduced output != expected partial sum"
|
||||
)
|
||||
|
||||
|
||||
# ── T2.c: rank-consistency post all-reduce ───────────────────────────
|
||||
|
||||
|
||||
def test_row_parallel_rank_identity_post_all_reduce(topology):
|
||||
"""T2.c: after all_reduce, all ranks see elementwise-identical output."""
|
||||
import kernbench.tp as tp
|
||||
from kernbench.runtime_api.tensor import Tensor
|
||||
|
||||
with _make_ctx(topology) as ctx:
|
||||
ctx.distributed.init_process_group(backend="ahbm")
|
||||
ws = ctx.distributed.get_world_size()
|
||||
tp.initialize_model_parallel(ws)
|
||||
|
||||
M = 2
|
||||
D_in, D_out = 16 * ws, 16
|
||||
results: dict[int, np.ndarray] = {}
|
||||
|
||||
def _worker(rank: int):
|
||||
ctx.ahbm.set_device(rank)
|
||||
fc = tp.RowParallelLinear(
|
||||
in_features=D_in, out_features=D_out, torch=ctx,
|
||||
)
|
||||
n_local = D_in // ws
|
||||
weight_np = np.full((n_local, D_out), 0.01, dtype=np.float16)
|
||||
src = Tensor(shape=weight_np.shape, dtype="f16", name="host_w")
|
||||
src._host_buffer = weight_np
|
||||
fc.weight.copy_(src)
|
||||
|
||||
x_np = np.full((M, n_local), 0.1, dtype=np.float16)
|
||||
x = ctx.zeros(
|
||||
(M, n_local), dtype="f16",
|
||||
dp=_replicate_dp(), name=f"t2c_x_r{rank}",
|
||||
)
|
||||
hx = Tensor(shape=x_np.shape, dtype="f16", name="host_x")
|
||||
hx._host_buffer = x_np
|
||||
x.copy_(hx)
|
||||
|
||||
y = fc.forward(x)
|
||||
results[rank] = y.numpy()
|
||||
|
||||
ctx.multiprocessing.spawn(_worker, args=(), nprocs=ws)
|
||||
|
||||
ref = results[0]
|
||||
for r, out in results.items():
|
||||
assert np.allclose(out, ref, rtol=1e-2, atol=1e-2), (
|
||||
f"rank {r} output differs from rank 0 — all_reduce failed to make "
|
||||
f"outputs elementwise identical"
|
||||
)
|
||||
|
||||
|
||||
def _replicate_dp():
|
||||
from kernbench.policy.placement.dp import DPPolicy
|
||||
return DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=1)
|
||||
@@ -0,0 +1,238 @@
|
||||
"""ADR-0027 T6: End-to-end 2-layer MLP with TP.
|
||||
|
||||
Phase 1: fails at imports. Phase 2 lands the TP package + D7 bench pattern
|
||||
and these pass with numerical-correctness checks.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
def _make_ctx(topology):
|
||||
from kernbench.runtime_api.context import RuntimeContext
|
||||
from kernbench.runtime_api.types import DeviceSelector
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
|
||||
engine = GraphEngine(topology.topology_obj, enable_data=True)
|
||||
return RuntimeContext(
|
||||
engine=engine,
|
||||
target_device=DeviceSelector("all"),
|
||||
correlation_id="test_t6",
|
||||
spec=topology.topology_obj.spec,
|
||||
)
|
||||
|
||||
|
||||
def _replicate_dp():
|
||||
from kernbench.policy.placement.dp import DPPolicy
|
||||
return DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=1)
|
||||
|
||||
|
||||
# ── T6.a: zero-weight smoke ──────────────────────────────────────────
|
||||
|
||||
|
||||
def test_mlp_zero_weight_produces_zero_output(topology):
|
||||
"""T6.a: zero-init weight → output ≈ 0 for every rank."""
|
||||
import kernbench.tp as tp
|
||||
|
||||
with _make_ctx(topology) as ctx:
|
||||
ctx.distributed.init_process_group(backend="ahbm")
|
||||
ws = ctx.distributed.get_world_size()
|
||||
tp.initialize_model_parallel(ws)
|
||||
|
||||
B, D_in, D_hidden, D_out = 1, 32, 32 * ws, 32
|
||||
outputs: dict[int, np.ndarray] = {}
|
||||
|
||||
def _worker(rank: int):
|
||||
ctx.ahbm.set_device(rank)
|
||||
fc1 = tp.ColumnParallelLinear(D_in, D_hidden, torch=ctx)
|
||||
fc2 = tp.RowParallelLinear(D_hidden, D_out, torch=ctx)
|
||||
|
||||
x = ctx.zeros((B, D_in), dtype="f16",
|
||||
dp=_replicate_dp(), name=f"t6a_x_r{rank}")
|
||||
from kernbench.runtime_api.tensor import Tensor
|
||||
hx = Tensor(shape=(B, D_in), dtype="f16", name="host_x")
|
||||
hx._host_buffer = np.full((B, D_in), 0.1, dtype=np.float16)
|
||||
x.copy_(hx)
|
||||
|
||||
h = fc1.forward(x)
|
||||
y = fc2.forward(h)
|
||||
outputs[rank] = y.numpy()
|
||||
|
||||
ctx.multiprocessing.spawn(_worker, args=(), nprocs=ws)
|
||||
|
||||
for r, out in outputs.items():
|
||||
assert np.allclose(out, 0.0, atol=1e-2), (
|
||||
f"rank {r}: zero-weight output should be ~0; got mean={out.mean()}"
|
||||
)
|
||||
|
||||
|
||||
# ── T6.b: deterministic weight + numerical check ─────────────────────
|
||||
|
||||
|
||||
def test_mlp_deterministic_weight_matches_reference(topology):
|
||||
"""T6.b: non-zero deterministic weights → output matches numpy reference."""
|
||||
import kernbench.tp as tp
|
||||
from kernbench.runtime_api.tensor import Tensor
|
||||
|
||||
with _make_ctx(topology) as ctx:
|
||||
ctx.distributed.init_process_group(backend="ahbm")
|
||||
ws = ctx.distributed.get_world_size()
|
||||
tp.initialize_model_parallel(ws)
|
||||
|
||||
B, D_in, D_hidden, D_out = 1, 16, 16 * ws, 16
|
||||
# W1 (D_in, D_hidden) — column-sharded; per rank: (D_in, D_hidden/ws)
|
||||
# W2 (D_hidden, D_out) — row-sharded; per rank: (D_hidden/ws, D_out)
|
||||
# Constant values: W1 = 0.02, W2 = 0.03, x = 0.1 (all fp16).
|
||||
X_VAL, W1_VAL, W2_VAL = 0.1, 0.02, 0.03
|
||||
|
||||
outputs: dict[int, np.ndarray] = {}
|
||||
|
||||
def _worker(rank: int):
|
||||
ctx.ahbm.set_device(rank)
|
||||
fc1 = tp.ColumnParallelLinear(D_in, D_hidden, torch=ctx)
|
||||
fc2 = tp.RowParallelLinear(D_hidden, D_out, torch=ctx)
|
||||
|
||||
# W1 slice (per rank column slice)
|
||||
k_local_1 = D_hidden // ws
|
||||
w1_np = np.full((D_in, k_local_1), W1_VAL, dtype=np.float16)
|
||||
src1 = Tensor(shape=w1_np.shape, dtype="f16", name="host_w1")
|
||||
src1._host_buffer = w1_np
|
||||
fc1.weight.copy_(src1)
|
||||
|
||||
# W2 slice (per rank row slice)
|
||||
n_local_2 = D_hidden // ws
|
||||
w2_np = np.full((n_local_2, D_out), W2_VAL, dtype=np.float16)
|
||||
src2 = Tensor(shape=w2_np.shape, dtype="f16", name="host_w2")
|
||||
src2._host_buffer = w2_np
|
||||
fc2.weight.copy_(src2)
|
||||
|
||||
# Input x (full-replicated constant)
|
||||
x = ctx.zeros((B, D_in), dtype="f16",
|
||||
dp=_replicate_dp(), name=f"t6b_x_r{rank}")
|
||||
hx = Tensor(shape=(B, D_in), dtype="f16", name="host_x")
|
||||
hx._host_buffer = np.full((B, D_in), X_VAL, dtype=np.float16)
|
||||
x.copy_(hx)
|
||||
|
||||
h = fc1.forward(x)
|
||||
y = fc2.forward(h)
|
||||
outputs[rank] = y.numpy().astype(np.float32)
|
||||
|
||||
ctx.multiprocessing.spawn(_worker, args=(), nprocs=ws)
|
||||
|
||||
# Host reference: y = x @ W1_full @ W2_full
|
||||
w1_full = np.full((D_in, D_hidden), W1_VAL, dtype=np.float32)
|
||||
w2_full = np.full((D_hidden, D_out), W2_VAL, dtype=np.float32)
|
||||
x_full = np.full((B, D_in), X_VAL, dtype=np.float32)
|
||||
expected = x_full @ w1_full @ w2_full
|
||||
|
||||
for r, out in outputs.items():
|
||||
assert out.shape == (B, D_out)
|
||||
assert np.allclose(out, expected, rtol=1e-2, atol=1e-2), (
|
||||
f"rank {r}: MLP output != reference "
|
||||
f"(got mean={out.mean():.4f}, expected={expected.mean():.4f})"
|
||||
)
|
||||
|
||||
|
||||
# ── T6.c: rank-consistency after final all_reduce ────────────────────
|
||||
|
||||
|
||||
def test_mlp_rank_consistency_after_all_reduce(topology):
|
||||
"""T6.c: all ranks see elementwise-identical final output."""
|
||||
import kernbench.tp as tp
|
||||
from kernbench.runtime_api.tensor import Tensor
|
||||
|
||||
with _make_ctx(topology) as ctx:
|
||||
ctx.distributed.init_process_group(backend="ahbm")
|
||||
ws = ctx.distributed.get_world_size()
|
||||
tp.initialize_model_parallel(ws)
|
||||
|
||||
B, D_in, D_hidden, D_out = 1, 16, 16 * ws, 16
|
||||
outputs: dict[int, np.ndarray] = {}
|
||||
|
||||
def _worker(rank: int):
|
||||
ctx.ahbm.set_device(rank)
|
||||
fc1 = tp.ColumnParallelLinear(D_in, D_hidden, torch=ctx)
|
||||
fc2 = tp.RowParallelLinear(D_hidden, D_out, torch=ctx)
|
||||
|
||||
# Zero weights OK for this check — just need all_reduce to run.
|
||||
x = ctx.zeros((B, D_in), dtype="f16",
|
||||
dp=_replicate_dp(), name=f"t6c_x_r{rank}")
|
||||
hx = Tensor(shape=(B, D_in), dtype="f16", name="host_x")
|
||||
hx._host_buffer = np.full((B, D_in), 0.1, dtype=np.float16)
|
||||
x.copy_(hx)
|
||||
|
||||
h = fc1.forward(x)
|
||||
y = fc2.forward(h)
|
||||
outputs[rank] = y.numpy()
|
||||
|
||||
ctx.multiprocessing.spawn(_worker, args=(), nprocs=ws)
|
||||
|
||||
ref = outputs[0]
|
||||
for r, out in outputs.items():
|
||||
assert np.array_equal(out, ref), (
|
||||
f"rank {r} output differs from rank 0 — all-reduce should "
|
||||
f"make every rank see the same final tensor"
|
||||
)
|
||||
|
||||
|
||||
# ── T6.d: shape contract ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_mlp_shape_contract(topology):
|
||||
"""T6.d: ColumnParallel → (B, D_hidden/ws); RowParallel → (B, D_out)."""
|
||||
import kernbench.tp as tp
|
||||
|
||||
with _make_ctx(topology) as ctx:
|
||||
ctx.distributed.init_process_group(backend="ahbm")
|
||||
ws = ctx.distributed.get_world_size()
|
||||
tp.initialize_model_parallel(ws)
|
||||
|
||||
B, D_in, D_hidden, D_out = 1, 16, 16 * ws, 16
|
||||
|
||||
def _worker(rank: int):
|
||||
ctx.ahbm.set_device(rank)
|
||||
fc1 = tp.ColumnParallelLinear(D_in, D_hidden, torch=ctx)
|
||||
fc2 = tp.RowParallelLinear(D_hidden, D_out, torch=ctx)
|
||||
|
||||
x = ctx.zeros((B, D_in), dtype="f16",
|
||||
dp=_replicate_dp(), name=f"t6d_x_r{rank}")
|
||||
h = fc1.forward(x)
|
||||
assert h.shape == (B, D_hidden // ws), (
|
||||
f"ColumnParallel output shape: {h.shape} != (B, D_hidden/ws)"
|
||||
)
|
||||
y = fc2.forward(h)
|
||||
assert y.shape == (B, D_out), (
|
||||
f"RowParallel output shape: {y.shape} != (B, D_out)"
|
||||
)
|
||||
|
||||
ctx.multiprocessing.spawn(_worker, args=(), nprocs=ws)
|
||||
|
||||
|
||||
# ── liveness: deadlock 없음 (pytest timeout 간접 검증) ───────────────
|
||||
|
||||
|
||||
def test_mlp_completes_without_deadlock(topology):
|
||||
"""Structural: full E2E spawn returns within a reasonable wall-clock.
|
||||
|
||||
Relies on the test suite's overall timeout harness. If this hangs
|
||||
beyond ~60s it would surface as a pytest timeout — a deadlock
|
||||
regression in the scheduler loop would manifest here."""
|
||||
import kernbench.tp as tp
|
||||
|
||||
with _make_ctx(topology) as ctx:
|
||||
ctx.distributed.init_process_group(backend="ahbm")
|
||||
ws = ctx.distributed.get_world_size()
|
||||
tp.initialize_model_parallel(ws)
|
||||
|
||||
def _worker(rank: int):
|
||||
ctx.ahbm.set_device(rank)
|
||||
fc1 = tp.ColumnParallelLinear(16, 16 * ws, torch=ctx)
|
||||
fc2 = tp.RowParallelLinear(16 * ws, 16, torch=ctx)
|
||||
x = ctx.zeros((1, 16), dtype="f16",
|
||||
dp=_replicate_dp(), name=f"t6live_r{rank}")
|
||||
h = fc1.forward(x)
|
||||
y = fc2.forward(h)
|
||||
_ = y.numpy()
|
||||
|
||||
ctx.multiprocessing.spawn(_worker, args=(), nprocs=ws)
|
||||
@@ -0,0 +1,85 @@
|
||||
"""ADR-0027 T1: TP parallel_state (D3).
|
||||
|
||||
Phase 1: ``kernbench.tp`` module does not exist yet — tests fail at import.
|
||||
Phase 2 (D2/D3) lands the package and these pass.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _make_ctx(topology):
|
||||
from kernbench.runtime_api.context import RuntimeContext
|
||||
from kernbench.runtime_api.types import DeviceSelector
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
|
||||
engine = GraphEngine(topology.topology_obj, enable_data=True)
|
||||
return RuntimeContext(
|
||||
engine=engine,
|
||||
target_device=DeviceSelector("all"),
|
||||
correlation_id="test_t1",
|
||||
spec=topology.topology_obj.spec,
|
||||
)
|
||||
|
||||
|
||||
def test_tp_package_importable():
|
||||
"""D2: kernbench.tp must be importable."""
|
||||
import kernbench.tp as tp
|
||||
assert hasattr(tp, "initialize_model_parallel")
|
||||
assert hasattr(tp, "get_tensor_model_parallel_world_size")
|
||||
assert hasattr(tp, "get_tensor_model_parallel_rank")
|
||||
|
||||
|
||||
def test_initialize_model_parallel_matches_world_size(topology, tmp_path, monkeypatch):
|
||||
"""D3: TP size must equal dist world_size; otherwise NotImplementedError."""
|
||||
import kernbench.tp as tp
|
||||
|
||||
with _make_ctx(topology) as ctx:
|
||||
ctx.distributed.init_process_group(backend="ahbm")
|
||||
ws = ctx.distributed.get_world_size()
|
||||
|
||||
tp.initialize_model_parallel(ws)
|
||||
assert tp.get_tensor_model_parallel_world_size() == ws
|
||||
|
||||
|
||||
def test_initialize_mismatched_ws_raises(topology):
|
||||
"""D3: calling with tp_size != world_size raises NotImplementedError."""
|
||||
import kernbench.tp as tp
|
||||
|
||||
with _make_ctx(topology) as ctx:
|
||||
ctx.distributed.init_process_group(backend="ahbm")
|
||||
ws = ctx.distributed.get_world_size()
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
tp.initialize_model_parallel(ws + 1)
|
||||
|
||||
|
||||
def test_get_tp_rank_is_greenlet_local(topology):
|
||||
"""D3: get_tensor_model_parallel_rank returns greenlet-local rank
|
||||
(delegates to torch.distributed.get_rank, ADR-0024 D9)."""
|
||||
import kernbench.tp as tp
|
||||
|
||||
with _make_ctx(topology) as ctx:
|
||||
ctx.distributed.init_process_group(backend="ahbm")
|
||||
ws = ctx.distributed.get_world_size()
|
||||
tp.initialize_model_parallel(ws)
|
||||
|
||||
observed: list[int] = []
|
||||
|
||||
def _worker(rank: int):
|
||||
observed.append(tp.get_tensor_model_parallel_rank())
|
||||
|
||||
ctx.multiprocessing.spawn(_worker, args=(), nprocs=ws)
|
||||
|
||||
assert sorted(observed) == list(range(ws))
|
||||
|
||||
|
||||
def test_get_world_size_before_init_raises():
|
||||
"""D3: uninitialised TP group → accessing world_size fails informatively."""
|
||||
from kernbench.tp import parallel_state
|
||||
|
||||
# Reset internal state if previous tests (or parallel workers) left it set.
|
||||
parallel_state._reset_for_tests()
|
||||
|
||||
with pytest.raises((RuntimeError, AssertionError, TypeError)):
|
||||
_ = parallel_state.get_tensor_model_parallel_world_size() + 0
|
||||
@@ -12,7 +12,7 @@ import pytest
|
||||
from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator
|
||||
from kernbench.policy.address.pe_mmu import PeMMU
|
||||
from kernbench.policy.address.va_allocator import VirtualAllocator
|
||||
from kernbench.policy.placement.dp import column_wise, ShardSpec
|
||||
from kernbench.policy.placement.dp import DPPolicy, ShardSpec, resolve_dp_policy
|
||||
from kernbench.runtime_api.tensor import (
|
||||
TensorHandle,
|
||||
TensorShard,
|
||||
@@ -37,9 +37,9 @@ _CFG = AddressConfig(
|
||||
)
|
||||
|
||||
|
||||
def _make_allocators(num_pe: int = 8) -> dict[int, PEMemAllocator]:
|
||||
def _make_allocators(num_pe: int = 8) -> dict[tuple[int, int, int], PEMemAllocator]:
|
||||
return {
|
||||
i: PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=i, cfg=_CFG)
|
||||
(0, 0, i): PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=i, cfg=_CFG)
|
||||
for i in range(num_pe)
|
||||
}
|
||||
|
||||
@@ -88,7 +88,11 @@ def test_deploy_tensor_assigns_va_base():
|
||||
"""deploy_tensor with VA allocator assigns va_base to TensorHandle."""
|
||||
allocs = _make_allocators()
|
||||
va_alloc = _make_va_allocator()
|
||||
placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8)
|
||||
placement = resolve_dp_policy(
|
||||
DPPolicy(cube="replicate", pe="column_wise"),
|
||||
shape=(1024, 512), itemsize=2,
|
||||
num_pe=8, num_cubes=1, target_sip=0,
|
||||
)
|
||||
|
||||
th = deploy_tensor(
|
||||
name="W",
|
||||
@@ -107,7 +111,11 @@ def test_deploy_tensor_va_covers_all_shards():
|
||||
"""VA allocation covers the entire tensor; each shard is at va_base + offset."""
|
||||
allocs = _make_allocators()
|
||||
va_alloc = _make_va_allocator()
|
||||
placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8)
|
||||
placement = resolve_dp_policy(
|
||||
DPPolicy(cube="replicate", pe="column_wise"),
|
||||
shape=(1024, 512), itemsize=2,
|
||||
num_pe=8, num_cubes=1, target_sip=0,
|
||||
)
|
||||
|
||||
th = deploy_tensor(
|
||||
name="W",
|
||||
@@ -128,7 +136,11 @@ def test_deploy_tensor_does_not_install_mmu_mappings():
|
||||
allocs = _make_allocators()
|
||||
va_alloc = _make_va_allocator()
|
||||
mmus = _make_mmus()
|
||||
placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8)
|
||||
placement = resolve_dp_policy(
|
||||
DPPolicy(cube="replicate", pe="column_wise"),
|
||||
shape=(1024, 512), itemsize=2,
|
||||
num_pe=8, num_cubes=1, target_sip=0,
|
||||
)
|
||||
|
||||
deploy_tensor(
|
||||
name="W",
|
||||
@@ -153,7 +165,7 @@ def test_tensor_va_property():
|
||||
|
||||
allocs = _make_allocators(1)
|
||||
va_alloc = _make_va_allocator()
|
||||
placement = [ShardSpec(pe_index=0, offset_bytes=0, nbytes=4096)]
|
||||
placement = [ShardSpec(sip=0, cube=0, pe=0, offset_bytes=0, nbytes=4096)]
|
||||
|
||||
t = Tensor(shape=(2048,), dtype="f16", name="test")
|
||||
t._handle = deploy_tensor(
|
||||
|
||||
+15
-5
@@ -20,7 +20,7 @@ from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator
|
||||
from kernbench.policy.address.pe_mmu import PeMMU
|
||||
from kernbench.policy.address.phyaddr import PhysAddr
|
||||
from kernbench.policy.address.va_allocator import VirtualAllocator
|
||||
from kernbench.policy.placement.dp import DPPolicy, column_wise
|
||||
from kernbench.policy.placement.dp import DPPolicy, resolve_dp_policy
|
||||
from kernbench.runtime_api.tensor import deploy_tensor
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
from kernbench.runtime_api.context import RuntimeContext
|
||||
@@ -70,7 +70,7 @@ def _make_standalone(shape, num_pe=NUM_PE):
|
||||
sram_bytes_per_cube=32 * _MB,
|
||||
)
|
||||
allocators = {
|
||||
i: PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=i, cfg=cfg)
|
||||
(0, 0, i): PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=i, cfg=cfg)
|
||||
for i in range(num_pe)
|
||||
}
|
||||
va_alloc = VirtualAllocator(va_base=0x1_0000_0000, va_size=64 * _GB, page_size=4096)
|
||||
@@ -110,7 +110,11 @@ def test_2d_va_translates_to_local_hbm():
|
||||
cols_per_pe = K // NUM_PE
|
||||
block_bytes = M * cols_per_pe * ELEM_BYTES
|
||||
|
||||
placement = column_wise(shape=(M, K), itemsize=ELEM_BYTES, num_pe=NUM_PE)
|
||||
placement = resolve_dp_policy(
|
||||
DPPolicy(cube="replicate", pe="column_wise"),
|
||||
shape=(M, K), itemsize=ELEM_BYTES,
|
||||
num_pe=NUM_PE, num_cubes=1, target_sip=0,
|
||||
)
|
||||
handle = deploy_tensor(
|
||||
name="src", shape=(M, K), dtype="fp16",
|
||||
placement=placement, allocators=allocators, va_allocator=va_alloc,
|
||||
@@ -178,7 +182,11 @@ def test_1d_va_translates_to_local_hbm():
|
||||
elems_per_pe = N_1D // NUM_PE
|
||||
block_bytes = elems_per_pe * ELEM_BYTES
|
||||
|
||||
placement = column_wise(shape=(1, N_1D), itemsize=ELEM_BYTES, num_pe=NUM_PE)
|
||||
placement = resolve_dp_policy(
|
||||
DPPolicy(cube="replicate", pe="column_wise"),
|
||||
shape=(1, N_1D), itemsize=ELEM_BYTES,
|
||||
num_pe=NUM_PE, num_cubes=1, target_sip=0,
|
||||
)
|
||||
handle = deploy_tensor(
|
||||
name="src_1d", shape=(N_1D,), dtype="fp16",
|
||||
placement=placement, allocators=allocators, va_allocator=va_alloc,
|
||||
@@ -207,7 +215,9 @@ def test_1d_e2e_completes():
|
||||
correlation_id="vo6", spec=graph.spec,
|
||||
)
|
||||
|
||||
dp = DPPolicy(sip="column_wise", cube="column_wise", pe="column_wise")
|
||||
# ADR-0026: DPPolicy is intra-device only; SIP scoping comes from the
|
||||
# RuntimeContext's target_device. This 1D e2e runs on a single SIP.
|
||||
dp = DPPolicy(cube="column_wise", pe="column_wise")
|
||||
src = ctx.zeros((N_1D,), dtype=DTYPE, dp=dp, name="src_1d")
|
||||
dst = ctx.empty((N_1D,), dtype=DTYPE, dp=dp, name="dst_1d")
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ system:
|
||||
|
||||
sips:
|
||||
count: 2
|
||||
topology: ring_1d
|
||||
|
||||
components:
|
||||
switch: { kind: switch, impl: builtin.switch, attrs: { overhead_ns: 5.0 } }
|
||||
|
||||
Reference in New Issue
Block a user