8 Commits

Author SHA1 Message Date
mukesh 1d8b9401e5 Intercube allreduce: pe0 cube-mesh reduce + multi-SIP ring/torus/mesh
New intercube allreduce kernel replacing the old flat ring algorithms.
Reduces across the 4x4 cube mesh within each SIP (pe0-only, same-lane),
then inter-SIP exchange on root cube, then broadcast back. Supports
ring_1d, torus_2d, and mesh_2d_no_wrap SIP topologies driven by
topology.yaml. Integrated with dist.init_process_group / dist.all_reduce.
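
As a rough illustration of the three phases described above (not the actual kernel), the NumPy sketch below models each SIP as an array of per-cube tiles. The function name `intercube_allreduce_sketch` is hypothetical, and the plain sum standing in for the ring/torus/mesh inter-SIP exchange is an assumption made to keep the example short.

```python
import numpy as np


def intercube_allreduce_sketch(sips):
    """sips: list of arrays shaped (n_cubes, n_elem), one per SIP (assumed layout)."""
    # Phase 1: reduce across the cube mesh inside each SIP (one partial per SIP).
    partials = [s.sum(axis=0) for s in sips]
    # Phase 2: inter-SIP exchange between root cubes, modelled here as a plain sum.
    total = np.sum(partials, axis=0)
    # Phase 3: broadcast the reduced tile back to every cube of every SIP.
    return [np.broadcast_to(total, s.shape).copy() for s in sips]


n_cubes, n_elem, world_size = 16, 8, 2
# Rank r fills all of its cubes with the scalar r + 1, as the bench does.
sips = [np.full((n_cubes, n_elem), float(r + 1), dtype=np.float32)
        for r in range(world_size)]
out = intercube_allreduce_sketch(sips)
expected = n_cubes * sum(range(1, world_size + 1))
```

With these inputs every cube should end up holding `N_CUBES * sum(1..world_size)`, which is the value the bench docstring says it checks.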

New files:
- src/kernbench/ccl/algorithms/intercube_allreduce.py (kernel)
- src/kernbench/ccl/sfr_config.py (configure_sfr_intercube_multisip)
- tests/test_allreduce_multidevice.py (config-driven, 3 topologies)
- tests/test_distributed_intercube_allreduce.py (full distributed path)
- tests/test_intercube_sfr_config.py (SFR wiring verification)

Modified:
- distributed.py: AhbmCCLBackend uses configure_sfr_intercube_multisip
- topologies.py: added torus_2d, mesh_2d_no_wrap
- install.py: global_E/W/N/S in _OPPOSITE_DIR
- topology.yaml: added system.sips.topology
- ccl.yaml: single intercube_allreduce algorithm
- benches/ccl_allreduce.py: row_wise cube-mesh tensor layout

Removed old flat-ring algorithms and their tests.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-16 17:33:42 -07:00
ywkang cfc2d74ec4 Refactor ccl_allreduce bench: rank=SIP only, remove rank=PE legacy path
The unified ccl_allreduce bench previously carried two execution models
in one worker with ``if world_size == n_sips:`` branching:
  - TP mode (rank = SIP, ADR-0024/0027): proper ProcessGroup semantics.
  - Legacy rank = PE mode: single-driver worker allocating one big tensor
    distributed across all PEs via _derive_dp, with kernel-level SPMD via
    program_id.

The second model is unnecessary — intra-SIP PE-level collectives are
expressed inside the kernel (tl.send/tl.recv with program_id, IPCQ) and
do not need a host-side ProcessGroup. Removing it lets the bench be a
clean reference implementation of the TP launcher.

benches/ccl_allreduce.py:
- Config resolved once in run() via _resolve_cfg -> _BenchCfg dataclass.
- world_size != n_sips now raises RuntimeError explicitly.
- _worker / _allocate_rank_tile / _init_with_rank_value / _report each
  have one concern; duplicated init + verification paths collapsed.
- _derive_dp and the second verify+print block deleted.
- 166 lines -> 91 lines.

ccl.yaml:
- mesh_allreduce_4 (world_size: 4) and tree_allreduce_7 (world_size: 7)
  algorithm entries removed (rank = PE only).
- Algorithm kernel files (kernbench.ccl.algorithms.mesh_allreduce,
  tree_allreduce) kept as-is for direct-dispatch future use.

tests/test_ccl_allreduce_matrix.py:
- Matrix shrinks from 7 cases to 3: ring × {tcm, hbm, sram} at ws =
  topology SIP count (= 2). mesh_2x2, tree_binary_7, ring_multi_cube,
  and the three ring_*_8 cases removed.

tests/test_ccl_performance.py:
- _run_8rank renamed to _run_ring; world_size: 8 override dropped; now
  exercises rank = SIP ring all-reduce.

tests/test_mp_spawn.py, tests/test_ccl_ddp_launcher.py:
- Monkeypatch target updated from bench.worker to bench._worker
  (signature now takes _BenchCfg instead of (rank, world_size)).

555 passed, 1 intentional skip. Tests that directly call
install_ipcq(world_size_override=N) for kernel-level sanity
(test_ccl_hello_world_guide, test_recv_copy_to_dst, test_tl_recv_async,
test_ccl_deadlock_detection) are unchanged — they never went through
the bench and still exercise the kernel-only path.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-14 16:45:27 -07:00
ywkang 105f1dc09e ADR-0027: Megatron TP API + worker-wait generalization + mp.spawn
Implements ADR-0027 Phase 2 end-to-end. All 559 tests pass (was 523 +
1 xfail; ring_default_ws strict-xfail is now resolved).

D0 — Worker-wait generalization (context.py):
- _pending_worker_waits queue on RuntimeContext.
- ctx.wait(h) in worker context defers to main via g.parent.switch().
  Fast-path for already-completed handles.
- Worker API is unchanged: tensor deploy, launch, etc. still look
  synchronous; they're transparently cooperatively scheduled.
- Solves ADR-0024 Phase B kernel-greenlet orphan bug (env.run now
  only ever drives from main; kernel _parent is always main).

D0.5 — Host-read barrier (tensor.py):
- Explicit _HOST_READ_BARRIERS registry (T5.g closed-set via code
  review, not reflection-magic).
- numpy/data/__getitem__/__repr__ drain pending worker-waits before
  host-observable read.
- copy_: source-side barrier via source.numpy(). Target-side write
  barrier is intentionally NOT applied — global pending target barrier
  prematurely drains cross-rank collectives → deadlock.
- Collective pending is excluded from barrier drain condition
  (collective is cross-rank; its own yield in all_reduce covers the
  invariant naturally).

D1 — torch.multiprocessing.spawn (runtime_api/multiprocessing.py):
- API signature parity with real PyTorch spawn; execution is
  cooperative greenlet scheduler (process isolation etc. are explicit
  non-goals per D1.0).
- _drain_pending drains worker-waits then collectives in one barrier,
  loop-until-empty.
- Round-based exception handling with SystemExit sibling abort +
  SpawnException(errors) wrapping root-cause ranks.
- RuntimeContext attaches ctx.multiprocessing in __post_init__.
- benches/ccl_allreduce.py hand-rolled loop collapses to one
  torch.multiprocessing.spawn call.

D2–D6 — kernbench.tp package:
- parallel_state: initialize_model_parallel, get_*_rank,
  get_*_world_size, with weak active-ctx registry in context.py.
- layers: ColumnParallelLinear, RowParallelLinear (shape-only
  primitives — fp16 gemm via tl.load + tl.dot + tl.store).
- kernels: _gemm_kernel used by TP layers (self-contained; no bench
  dependency).
- primitives / mappings stubs per D6/D8.

Data-path fixes (surfaced by TP gemm + all_reduce sequence):
- sim_engine/op_log.py: dma_write snapshot is skipped for TCM
  sources (PE scratch is repopulated by Phase 2 math/gemm replay —
  capturing Phase-1-time snapshot picked up STALE data from prior
  kernel's output aliased at the same scratch addr, causing the later
  kernel's dma_write to overwrite Phase 2 result with stale value).
- sim_engine/op_log.py + sim_engine/data_executor.py: per-operand
  space recorded on GemmCmd and composite gemm records so HBM-resident
  operands (tl.load output) don't default to TCM during replay.
- runtime_api/context.py: ctx.zeros writes zero-init to MemoryStore
  at VA keys so kernels reading via VA see deterministic init even
  without explicit copy_().

Tests (Phase 1 + Phase 2):
- test_worker_wait_drain (T3): orphan invariant + resume + multi-rank
  drain + idempotency + exception propagation.
- test_mp_spawn (T4): spawn shape + bind + SpawnException scope.
- test_host_read_barrier (T5): barrier contract per entry-point +
  closed-set registry check.
- test_tp_parallel_state (T1): initialize + rank lookup.
- test_tp_layers (T2): shape + deterministic numerical correctness
  (concat-matmul equality for RowParallel, not mean-only).
- test_tp_mlp (T6): full 2-layer MLP with deterministic weight
  numerical match + rank-consistency post all-reduce.
- test_ccl_allreduce_matrix: ring_default_ws xfail removed (T7).

Regression: 523 pre + 35 new + 1 ex-xfail = 559 passed, 1 intentional
skip (T3.e historical failure documentation).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-14 16:31:13 -07:00
ywkang e7f376ebaa ADR-0027 rev7 (Megatron TP + worker-wait generalization) + ADR-0026 typo fix
ADR-0027 is a design-only change (no production code). Rev 7 closes design
across 7 iterations of review. Key decisions:

- D0 (worker-wait generalization): ctx.wait in worker context yields to
  main scheduler, which drains env.run. Solves ADR-0024 Phase B orphan
  bug (ring_default_ws strict xfail). Normative contracts on resume
  invariant, fast-path, main-context non-reentrance, barrier
  loop-until-empty, and scheduler non-progress as user contract.
- D0.5 (host-read barrier): Tensor.numpy/data/__getitem__/__repr__/copy_
  auto-drain pending before reading. Closed-set via explicit registry
  (T5.g). copy_ uses global-pending barrier with explicit
  over-serialization tradeoff.
- D1 (torch.multiprocessing.spawn): real-PyTorch API-signature parity,
  cooperative greenlet scheduler internally. Explicit non-goal on
  process isolation / address space / failure isolation. Sibling
  cleanup via SystemExit + SpawnException(errors) wrapping root-cause
  ranks.
- D4/D5 (TP layers): ColumnParallelLinear / RowParallelLinear use
  torch.launch(gemm_kernel) — no host-side torch.matmul. Yield-safety
  contract normatively required for all TP forward paths.
- Supersedes ADR-0024 D7/D12/D13 as design (none landed). Source of
  truth declared normative.

Test strategy: T1-T8 with numerical-correctness primary (not mean/
aggregate-only), orphan invariant direct assertion, host-read barrier
closed-set via registry. Phase 2 acceptance = 524 passed + 0 xfail
(ring_default_ws unblocked by D0).

ADR-0026 typo fix: torch.cuda.set_device → torch.ahbm.set_device in
DPPolicy docstring (ADR-0024 D10 convention).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-14 14:13:26 -07:00
ywkang 357cab525b ADR-0026: DPPolicy intra-device only + ShardSpec structural coords
DPPolicy no longer carries a cross-SIP axis. SIP-level placement is
solely controlled by torch.ahbm.set_device(rank) (ADR-0024); DPPolicy
itself describes only the cube × PE layout within one SIP. ShardSpec
switches to structural (sip, cube, pe) coordinates; the flat pe_index
field/property is fully removed — silent drift between global-flat and
SIP-local interpretations was a foot-gun flagged by ADR-0024 D11.

Breaking API (explicit TypeError / AttributeError):
- DPPolicy(sip=...) / DPPolicy(num_sips=...) -> TypeError
- ShardSpec.pe_index -> AttributeError
- ShardSpec(pe_index=...) -> TypeError
- resolve_dp_policy now takes target_sip= (required), no num_sips.
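
The breaking behaviour listed above falls out naturally if the classes simply stop declaring the removed fields. The sketch below uses stand-in dataclasses (`DPPolicySketch` and `ShardSpecSketch` are hypothetical, not the real kernbench classes) to show how the errors would surface under that assumption.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class DPPolicySketch:
    # Intra-device axes only; there is deliberately no sip / num_sips field.
    cube: str = "replicate"
    pe: str = "replicate"
    num_cubes: Optional[int] = None
    num_pes: Optional[int] = None


@dataclass(frozen=True)
class ShardSpecSketch:
    # Structural coordinates; there is deliberately no flat pe_index.
    sip: int
    cube: int
    pe: int


def raises(exc_type, fn):
    """Return True if calling fn() raises exc_type."""
    try:
        fn()
    except exc_type:
        return True
    return False


sip_kw_rejected = raises(TypeError, lambda: DPPolicySketch(sip="column_wise"))
num_sips_rejected = raises(TypeError, lambda: DPPolicySketch(num_sips=1))
pe_index_read_rejected = raises(AttributeError, lambda: ShardSpecSketch(0, 3, 1).pe_index)
pe_index_kw_rejected = raises(TypeError, lambda: ShardSpecSketch(pe_index=5))
```

Failing loudly at the call site is the point: code that still passes `sip=` or reads `pe_index` breaks immediately instead of silently mixing global-flat and SIP-local indices.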

Downstream migration:
- PE allocator dict keyed by (sip, cube, pe) tuples, in both
  _ensure_allocators and _free_tensor. deploy_tensor uses tuple lookup.
- _create_tensor passes target_sip=current_sip; post-hoc pe_index
  shifting removed entirely.
- launch._compute_local_shape drops the dp.sip branch.
- Internal resolvers (column_wise / row_wise / replicate / tiled_*)
  return _LocalPeShard (cube-local identifier) instead of ShardSpec —
  resolve_dp_policy lifts them to full structural coords.

Tests:
- New tests/test_adr0026_dppolicy_intra_device.py (12 tests) pins the
  contract end-to-end.
- test_sip_parallel.py rewritten: SIP composition now modeled as two
  resolve_dp_policy(target_sip=...) calls (ADR-0024 launcher style).
- Call-site migration: test_tensor, test_va_integration, test_va_offset,
  test_runtime_api_tensor, test_tl_recv_async, test_ccl_* and benches
  gemm_single_pe, gpt3_qkv, va_offset_verify, ccl_allreduce (legacy
  branch) all use intra-device DPPolicy and structural ShardSpec.

Result: 523 passed, 1 strict xfail (ring_default_ws — unchanged
ADR-0024 Phase B blocker; architectural fix deferred to ADR-0027).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-14 13:02:19 -07:00
ywkang 787409ced1 ADR-0024 Phase B: update xfail reason with architectural blocker details
Phase B Option A (freeze + defer to ADR-0027): the root cause of
ring_default_ws strict-xfail is that bench workers call torch.zeros /
copy_ which drive env.run in the WORKER-greenlet context. Any pending
KernelLaunchMsg gets stepped inside that worker, spawning kernel_runner
with parent = worker (not main). When the worker yields/finishes, the
kernel greenlet is orphaned and its next switch_to_simpy raises
GreenletExit mid-add — producing rank 0 mean=1 (expected 3).

This is a larger architectural redesign (lazy-deploy tensor API,
coroutine worker, or setup/verify split) and is parked until ADR-0027
(Megatron TP) starts, where the proper solution ships with TP use cases.

No production changes; xfail reason + inline comment only.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-14 12:46:33 -07:00
ywkang 79124daab1 ADR-0024 Phase B (partial): scheduler-level collective drain
Root cause (hang diagnosis):
`kernel_runner.run()` captures `greenlet.getcurrent()` at spawn time as
the kernel greenlet's `_parent`. When a worker greenlet (say g0) calls
`dist.all_reduce` → `ctx.wait(h)` → `env.run(until=h0)`, the SimPy
scheduler steps pe_cpu processes, which in turn spawn kernel greenlets.
Those kernels' `_parent` becomes g0 (current greenlet at spawn). When a
kernel yields via switch_to_simpy, control jumps back up to g0's LAST
switch point — which is the main scheduler's `g.switch()` call — rather
than the kernel_runner's generator frame. Main then re-enters its
`for g in alive: g.switch()` loop mid-wait, producing nested greenlet
re-entry. Scheduler spins: g0 never completes, g1 appears to complete
out of order, infinite loop at 100% CPU.

Fix:
- AhbmCCLBackend.all_reduce: in multi-greenlet mode, submit via
  launch(_defer_wait=True), extend backend._pending_collective_handles,
  and yield to the parent greenlet. Worker does NOT call wait.
- benches/ccl_allreduce.py run(): after each scheduler round, the MAIN
  greenlet drains backend._pending_collective_handles. This keeps
  env.run invocation in the main context, so kernel_runner's spawned
  kernel greenlets have main as their _parent — no nested re-entry.
- Legacy single-driver path (no bench scheduler): all_reduce falls back
  to inline wait when g.parent is None.
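
The submit, yield, and main-side drain pattern in the fix can be sketched with plain Python generators standing in for greenlets (a substitution made here so the example needs only the standard library). `worker`, `run`, and the string handles are hypothetical; the real code uses `launch(_defer_wait=True)` and `backend._pending_collective_handles`.

```python
def worker(rank, pending, log):
    # Submit the collective without waiting, then yield to the scheduler.
    pending.append(f"allreduce-r{rank}")
    yield
    # Execution resumes here only after the main context has drained.
    log.append(f"rank {rank} resumed")


def run(world_size):
    pending, log = [], []
    alive = [worker(r, pending, log) for r in range(world_size)]
    while alive:
        still_alive = []
        for g in alive:
            try:
                next(g)              # one scheduler round per worker
                still_alive.append(g)
            except StopIteration:
                pass
        # Only the main context drains handles, so anything spawned while
        # draining has main (not a worker) as its parent.
        while pending:
            log.append(f"main drained {pending.pop(0)}")
        alive = still_alive
    return log


log = run(2)
```

Both handles are drained by the main loop before either worker resumes, which is the ordering the fix relies on.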

Result:
- Multi-greenlet cross-SIP ring no longer hangs (was 100% CPU infinite
  loop in kernel_runner._switch_kernel).
- ring_default_ws still xfail(strict=True): now fails as a data
  correctness issue — DataExecutor reports only 1 math op for a 2-rank
  ring (expected 2). Cross-SIP op_log replay integration is the
  remaining Phase B task.

514 passed, 1 xfailed (strict).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-14 09:14:03 -07:00
ywkang 4ba0a83e71 Implement ADR-0024 Phase A: SIP-level TP launcher MVP
Scope (Phase A):
- D1: world_size fallback = SIP count (rank = SIP, TP boundary)
- D9: greenlet-local get_rank + _bind_rank (single-driver fallback = 0)
- D10: torch.ahbm.set_device + torch.accelerator.set_device_index alias
- D11: tensor placement scoped to current-device SIP (post-hoc pe_index
  shift — ADR-0026 replaces with structural coords)
- D12/D13: multi-greenlet run() with simple round-robin scheduler;
  hybrid dispatch (ws == SIP count → multi-greenlet, else legacy
  single-worker for ccl.yaml override compat)
- D7 partial: backend.all_reduce submit + yield + wait via launch()'s
  new _defer_wait flag; parent-less greenlets skip yield
- Relaxed shard-count check (len(shards) > 0 instead of == world_size)
- rank_to_pe = SIP-representative [(r, 0, 0)] when ws <= n_sips

Deferred to Phase B:
- Engine-routed install (D2) — keeps sideband
- install_plan.py module (D6) — keeps install.py
- Epoch barrier (D7 full) — simple yield is sufficient for ring ws=2 mock
- Validator registry (D8)
- Cross-SIP multi-greenlet + real kernel integration — matrix
  ring_default_ws hangs in SimPy drain despite ADR-0025 direction fix;
  marked xfail(run=False) pending Phase B diagnosis (suspected per-rank
  kernel_args / program_id mismatch)

Tests:
- test_ccl_ddp_launcher.py (6 new tests) — D1/D9/D10/D11/D12/D13
- test_ccl_allreduce_matrix.py — ring_default_ws xfail'd, override
  cases (ring_tcm_8 / hbm_8 / sram_8 / multi_cube / mesh_2x2 /
  tree_binary_7) all pass via legacy path

514 tests pass, 1 xfail.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-14 09:00:28 -07:00
55 changed files with 3666 additions and 2464 deletions
+1
@@ -29,3 +29,4 @@ build/
 # Logs
 *.log
+.claude/
+80 -106
@@ -1,129 +1,103 @@
"""CCL all-reduce bench — single unified entry point. """CCL all-reduce bench (ADR-0024 + ADR-0027).
Driven entirely by ``ccl.yaml`` + ``topology.yaml``: Pure TP launcher model: rank = SIP. Each rank owns a ``(N_CUBES, n_elem)``
tensor sharded row-wise across the cube mesh (pe0 per cube). After
``dist.all_reduce(op="sum")`` every cube on every rank must hold
``N_CUBES * sum(1..world_size)``. Rank 0 prints the pass/fail line.
- ``defaults.algorithm`` in ``ccl.yaml`` picks which kernel to run Driven by ``ccl.yaml`` (``defaults.algorithm``, ``n_elem``) + ``topology.yaml``
(``ring_allreduce_{tcm,hbm,sram}`` / ``mesh_allreduce_4`` / (SIP count → world_size, cube_mesh → N_CUBES).
``tree_allreduce_7``).
- ``world_size`` is derived from the algorithm entry's override or from
the topology spec (``sips × cubes_per_sip × pes_per_cube``).
- The host code uses only real PyTorch ``torch.distributed`` names:
``init_process_group``, ``get_world_size``, ``get_rank``, ``all_reduce``.
The bench is split into ``worker(rank, world_size, torch)`` — the
per-rank business logic, designed to look like a real PyTorch DDP
training worker so future model benches can reuse the same skeleton —
and ``run(torch)`` — the kernbench-specific launcher that initializes
the process group and invokes the worker.
""" """
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass
import numpy as np import numpy as np
from kernbench.ccl.install import load_ccl_config, resolve_algorithm_config from kernbench.ccl.install import load_ccl_config, resolve_algorithm_config
from kernbench.policy.placement.dp import DPPolicy from kernbench.policy.placement.dp import DPPolicy
# Default per-rank tile size if ccl.yaml doesn't override it. Real DEFAULT_N_ELEM = 8
# pytorch benches hardcode batch/feature dims similarly.
DEFAULT_N_ELEM = 32
def _derive_dp(spec: dict, world_size: int) -> DPPolicy: @dataclass(frozen=True)
"""Pick a DPPolicy that fans the tensor across exactly ``world_size`` PEs. class _BenchCfg:
algorithm: str
n_elem: int
n_cubes: int
world_size: int
Mirrors what a real PyTorch DDP user does manually with
``tensor.to(f"cuda:{rank}")``: the host code chooses the placement so def _resolve_cfg(torch) -> _BenchCfg:
that the collective sees the right number of participating ranks. """Read ccl.yaml + topology once at host side."""
""" merged = resolve_algorithm_config(load_ccl_config())
sips = int(spec["system"]["sips"]["count"]) ws = torch.distributed.get_world_size()
cm = spec["sip"]["cube_mesh"] spec = torch.spec or {}
pl = spec["cube"]["pe_layout"] n_sips = int(spec.get("system", {}).get("sips", {}).get("count", 1))
pes_per_cube = int(pl["pe_per_corner"]) * len(pl["corners"]) if ws != n_sips:
cubes_per_sip = int(cm["w"]) * int(cm["h"]) raise RuntimeError(
total = sips * cubes_per_sip * pes_per_cube f"ccl_allreduce bench requires world_size == topology SIP count "
if world_size == total: f"(world_size={ws}, n_sips={n_sips})."
return DPPolicy(sip="column_wise", cube="column_wise", pe="column_wise")
if world_size <= pes_per_cube:
return DPPolicy(
sip="replicate", cube="replicate", pe="column_wise",
num_sips=1, num_cubes=1, num_pes=world_size,
) )
if world_size <= cubes_per_sip * pes_per_cube: cm = spec.get("sip", {}).get("cube_mesh", {})
return DPPolicy( n_cubes = int(cm.get("w", 4)) * int(cm.get("h", 4))
sip="replicate", cube="column_wise", pe="column_wise", return _BenchCfg(
num_sips=1, num_cubes=world_size // pes_per_cube, algorithm=merged["algorithm"],
) n_elem=int(merged.get("n_elem", DEFAULT_N_ELEM)),
return DPPolicy(sip="column_wise", cube="column_wise", pe="column_wise") n_cubes=n_cubes,
world_size=ws,
def worker(rank: int, world_size: int, torch) -> None:
"""Per-rank business logic. Mirrors a real PyTorch DDP worker.
In real PyTorch DDP, this function runs in N separate processes,
each with its own ``rank``. In kernbench (single-process multi-device)
it is invoked once with ``rank=0`` on the single host driver; the
actual per-PE parallelism is handled by ``torch.launch`` fanning out
the kernel across all participating PEs via the tensor's DPPolicy.
The ``rank`` parameter is therefore always 0 today, and is kept as
an explicit argument for parity with real DDP workers (``if rank ==
0`` logging guards, future multi-host extensions).
"""
cfg = resolve_algorithm_config(load_ccl_config())
algo_name = cfg["algorithm"]
n_elem = int(cfg.get("n_elem", DEFAULT_N_ELEM))
# Pick a DP that produces exactly ``world_size`` shards on this topology.
dp = _derive_dp(torch.spec, world_size)
tensor = torch.zeros(
(1, world_size * n_elem), dtype="f16", dp=dp, name="ccl_in",
) )
# Initialize: CCL rank r's slice gets value (r + 1). Real PyTorch idiom:
# target.copy_(torch.from_numpy(source))
init = np.zeros((1, world_size * n_elem), dtype=np.float16)
for r in range(world_size):
init[0, r * n_elem : (r + 1) * n_elem] = float(r + 1)
tensor.copy_(torch.from_numpy(init))
# The main act: one all_reduce call — the backend installs IPCQ at def _rank_dp(n_cubes: int) -> DPPolicy:
# init_process_group time and here only dispatches the kernel. return DPPolicy(cube="row_wise", pe="replicate", num_cubes=n_cubes, num_pes=1)
def _allocate_rank_tensor(torch, rank: int, cfg: _BenchCfg):
"""Allocate this rank's ``(n_cubes, n_elem)`` tensor on its SIP."""
return torch.zeros(
(cfg.n_cubes, cfg.n_elem), dtype="f16",
dp=_rank_dp(cfg.n_cubes), name=f"ccl_in_r{rank}",
)
def _init_with_rank_value(torch, tensor, rank: int, cfg: _BenchCfg) -> None:
"""Fill all cubes with the scalar ``rank + 1``."""
arr = np.full((cfg.n_cubes, cfg.n_elem), float(rank + 1), dtype=np.float16)
tensor.copy_(torch.from_numpy(arr))
def _report(result: np.ndarray, cfg: _BenchCfg) -> None:
"""Single-line pass/fail printer (rank 0 only)."""
expected = float(cfg.n_cubes * sum(range(1, cfg.world_size + 1)))
ok = True
for cube_id in range(cfg.n_cubes):
if not np.allclose(result[cube_id], expected, rtol=1e-1, atol=1e-1):
ok = False
break
if ok:
total = cfg.world_size * cfg.n_cubes
print(f" {cfg.algorithm} (ws={cfg.world_size}): {total} OK")
return
got = float(result.reshape(-1).mean())
print(
f" [FAIL] {cfg.algorithm} (ws={cfg.world_size}): "
f"got mean={got:.3f}, expected={expected:.3f}"
)
def _worker(rank: int, cfg: _BenchCfg, torch) -> None:
torch.ahbm.set_device(rank)
tensor = _allocate_rank_tensor(torch, rank, cfg)
_init_with_rank_value(torch, tensor, rank, cfg)
torch.distributed.all_reduce(tensor, op="sum") torch.distributed.all_reduce(tensor, op="sum")
# Verify: each shard should hold sum(1..world_size) after all-reduce.
result = tensor.numpy()
expected = float(sum(range(1, world_size + 1)))
all_ok = bool(np.allclose(result, expected, rtol=1e-1, atol=1e-1))
# Print only on rank 0 — real PyTorch DDP idiom for single-source logs.
if rank == 0: if rank == 0:
if all_ok: _report(tensor.numpy(), cfg)
print(f" {algo_name} (ws={world_size}): {world_size} OK")
else:
flat = result.reshape(-1)
n_fail = 0
for r in range(world_size):
slice_r = flat[r * n_elem : (r + 1) * n_elem]
if not np.allclose(slice_r, expected, rtol=1e-1, atol=1e-1):
n_fail += 1
if n_fail <= 5:
print(
f" [FAIL] rank {r} "
f"(ws={world_size}, algo={algo_name}): "
f"got mean={float(slice_r.mean()):.3f}, "
f"expected={expected:.3f}"
)
print(
f" {algo_name} (ws={world_size}): "
f"{world_size - n_fail} OK / {n_fail} FAIL"
)
def run(torch) -> None: def run(torch) -> None:
"""CLI entry point: initialize the process group, invoke worker.""" torch.distributed.init_process_group(backend="ahbm")
dist = torch.distributed cfg = _resolve_cfg(torch)
dist.init_process_group(backend="ahbm") torch.multiprocessing.spawn(
worker( _worker, args=(cfg, torch), nprocs=cfg.world_size,
rank=dist.get_rank(),
world_size=dist.get_world_size(),
torch=torch,
) )
+2 -2
@@ -3,7 +3,7 @@
 Full host-to-PE pipeline:
 Host → PCIE_EP → IO_CPU → M_CPU → PE_CPU → SchedulerV2 → PE_DMA → HBM
-Single PE: num_sips=1, num_cubes=1, num_pes=1 via DPPolicy override.
+Single PE: num_cubes=1, num_pes=1 via DPPolicy override.
 Both operands use tl.ref (HBM-resident); scheduler_v2 tiles and streams
 per-tile DMA internally.
@@ -30,7 +30,7 @@ def _gemm_kernel(a_ptr, b_ptr, out_ptr, M, K, N, tl, DTYPE="f16"):
 def run(torch):
     """Run the single-PE GEMM benchmark."""
     dp = DPPolicy(cube="replicate", pe="replicate",
-                  num_sips=1, num_cubes=1, num_pes=1)
+                  num_cubes=1, num_pes=1)
     a = torch.empty((M, K), dtype=DTYPE, dp=dp, name="a")
     b = torch.empty((K, N), dtype=DTYPE, dp=dp, name="b")
+8 -4
@@ -72,12 +72,16 @@ def run(torch):
     K = GPT3_D_MODEL
     N = COLS_PER_PE
-    # X: replicated across all PEs
+    # ADR-0026: DPPolicy is intra-device only. For multi-SIP execution the
+    # ADR-0024 launcher calls this bench once per SIP (each worker via
+    # torch.ahbm.set_device(rank)); here the policy describes only the
+    # cube × PE layout within a single SIP.
+    # X: replicated across all PEs within the SIP
     dp_replicate = DPPolicy(cube="replicate", pe="replicate",
-                            num_sips=N_SIPS, num_cubes=N_CUBES, num_pes=N_PE_PER_CUBE)
+                            num_cubes=N_CUBES, num_pes=N_PE_PER_CUBE)
-    # W_Q/K/V, out_Q/K/V: column-wise sharded across all PEs
+    # W_Q/K/V, out_Q/K/V: column-wise sharded across all PEs within the SIP
     dp_sharded = DPPolicy(cube="column_wise", pe="column_wise",
-                          num_sips=N_SIPS, num_cubes=N_CUBES, num_pes=N_PE_PER_CUBE)
+                          num_cubes=N_CUBES, num_pes=N_PE_PER_CUBE)
     x = torch.empty((M, K), dtype=DTYPE, dp=dp_replicate, name="x")
     wq = torch.empty((K, GPT3_D_MODEL), dtype=DTYPE, dp=dp_sharded, name="wq")
+2 -2
@@ -1,7 +1,7 @@
 """VA offset verification benchmark.
 Verifies that Triton-style base_ptr + pid * stride addressing works correctly
-with full TP sharding (sip/cube/pe all column_wise). Each PE loads its own
+with intra-SIP TP sharding (cube/pe column_wise). Each PE loads its own
 block from a sharded tensor and stores it back.
 The kernel uses standard Triton patterns:
@@ -28,7 +28,7 @@ def _copy_kernel(src_ptr, dst_ptr, M, K, tl, DTYPE="f16"):
 def run(torch):
     """Run the VA offset verification benchmark with full TP sharding."""
-    dp = DPPolicy(sip="column_wise", cube="column_wise", pe="column_wise")
+    dp = DPPolicy(cube="column_wise", pe="column_wise")
     src = torch.zeros((M, K), dtype=DTYPE, dp=dp, name="src")
     dst = torch.empty((M, K), dtype=DTYPE, dp=dp, name="dst")
+12 -55
@@ -6,12 +6,7 @@
 defaults:
   # Algorithm to run for this benchmark execution.
-  algorithm: ring_allreduce_tcm
+  algorithm: intercube_allreduce
-  # NOTE: world_size is not set here by default. AhbmCCLBackend derives it
-  # from the chosen algorithm's entry (if it sets ``world_size``) or from
-  # topology.yaml (``sips × cubes_per_sip × pes_per_cube``). This mirrors
-  # real PyTorch DDP where ranks/world_size come from env vars, not code.
   # IPCQ ring buffer location.
   # tcm — PE-local TCM (fast, small, conflicts with compute TCM access)
@@ -30,59 +25,21 @@ defaults:
# Slot size in bytes (must hold one tile worth of data). # Slot size in bytes (must hold one tile worth of data).
slot_size: 4096 slot_size: 4096
# PE_DMA virtual channel chunk size (D8). First implementation does not # PE_DMA virtual channel chunk size (D8).
# use chunk-level interleave; this is reserved for future precision.
vc_chunk_size: 256 vc_chunk_size: 256
# Credit return fast path message size (D9). Used by bottleneck-BW # Credit return fast path message size (D9).
# latency calculation. 16-64 bytes typical.
ipcq_credit_size_bytes: 16 ipcq_credit_size_bytes: 16
algorithms: algorithms:
# ── ring all-reduce, buffer in PE_TCM ── # ── intercube all-reduce (pe0-only, cube mesh + inter-SIP) ──
# Defaults to topology-derived world_size (full system, 256 ranks). # Reduces across the 4×4 cube mesh within each SIP, then inter-SIP
# Use a smaller tile size at high rank counts so f16 sums stay within # exchange on root cube, then broadcast back. SIP topology is read
# the verification tolerance and op_log replay scales. # from topology.yaml → system.sips.topology. Kernel auto-selects
ring_allreduce_tcm: # ring / torus / mesh inter-SIP exchange pattern.
module: kernbench.ccl.algorithms.ring_allreduce intercube_allreduce:
topology: ring_1d module: kernbench.ccl.algorithms.intercube_allreduce
buffer_kind: tcm
n_elem: 8
# ── ring all-reduce, buffer in PE-local HBM ──
ring_allreduce_hbm:
module: kernbench.ccl.algorithms.ring_allreduce
topology: ring_1d
buffer_kind: hbm
n_elem: 8
# ── ring all-reduce, buffer in cube SRAM ──
ring_allreduce_sram:
module: kernbench.ccl.algorithms.ring_allreduce
topology: ring_1d
buffer_kind: sram
n_elem: 8
# ── 2D mesh all-reduce: perfect square only (2×2 = 4 PEs) ──
mesh_allreduce_4:
module: kernbench.ccl.algorithms.mesh_allreduce
topology: mesh_2d
buffer_kind: tcm
world_size: 4
n_elem: 16
# ── tree all-reduce (binary, 7 PEs) ──
tree_allreduce_7:
module: kernbench.ccl.algorithms.tree_allreduce
topology: tree_binary
buffer_kind: tcm
world_size: 7
n_elem: 16
# ── hierarchical all-reduce (3-level: intra-cube → inter-cube → inter-SIP) ──
# Uses bidirectional ring reduce + chain broadcast. ~25 rounds vs 255 flat.
hierarchical_allreduce:
module: kernbench.ccl.algorithms.hierarchical_allreduce
topology: none topology: none
buffer_kind: tcm buffer_kind: tcm
n_elem: 16 n_elem: 8
root_cube: 15
+4 -4
@@ -2,7 +2,7 @@
 ## Status
-Proposed (Revision 4: document consistency + concrete grep audit)
+Accepted (Revision 5: Phase 2 landed 2026-04-14, 523 passed + 1 strict xfail)
 ## Context
@@ -69,9 +69,9 @@ class DPPolicy:
 class DPPolicy:
     """Intra-device (cube × PE) data-parallel policy.
-    SIP-level placement is controlled by ``torch.cuda.set_device(rank)``
-    (ADR-0024) and, for model-level TP, by Megatron-style parallel layers
-    (ADR-0027). DPPolicy does not cross SIP boundaries.
+    SIP-level placement is controlled by ``torch.ahbm.set_device(rank)``
+    (ADR-0024 D10) and, for model-level TP, by Megatron-style parallel
+    layers (ADR-0027). DPPolicy does not cross SIP boundaries.
     """
     cube: Literal["replicate", "column_wise", "row_wise"] = "replicate"
     pe: Literal["replicate", "column_wise", "row_wise"] = "replicate"
File diff suppressed because it is too large
+2 -2
@@ -129,8 +129,8 @@ N_ELEM = 8
 def worker(rank: int, world_size: int, torch) -> None:
     """Per-rank business logic — mirrors a real PyTorch DDP worker."""
     dp = DPPolicy(
-        sip="replicate", cube="replicate", pe="column_wise",
-        num_sips=1, num_cubes=1, num_pes=world_size,
+        cube="replicate", pe="column_wise",
+        num_cubes=1, num_pes=world_size,
     )
     tensor = torch.zeros(
         (1, world_size * N_ELEM), dtype="f16", dp=dp, name="hello_in",
+2 -2
@@ -114,8 +114,8 @@ def run(torch):
a = torch.zeros( a = torch.zeros(
(1, WORLD_SIZE * N_ELEM), dtype="f16", (1, WORLD_SIZE * N_ELEM), dtype="f16",
dp=DPPolicy( dp=DPPolicy(
sip="replicate", cube="replicate", pe="column_wise", cube="replicate", pe="column_wise",
num_sips=1, num_cubes=1, num_cubes=1,
), ),
name="hello_in", name="hello_in",
) )
@@ -1,29 +0,0 @@
"""Hello-world CCL kernel for the docs/ccl-author-guide.md walkthrough.
Each PE sends its tile to the E neighbor and receives one tile from W,
then stores the received tile back into its own HBM slice. The simplest
possible demonstration of ``tl.send`` / ``tl.recv``.
"""
from __future__ import annotations
def kernel_args(world_size: int, n_elem: int) -> tuple:
"""Return the positional kernel arguments for the ahbm backend."""
return (n_elem,)
def kernel(t_ptr, n_elem, tl):
local_pe = tl.program_id(axis=0)
cube_id = tl.program_id(axis=1)
pes_per_cube = tl.num_programs(axis=0)
rank = cube_id * pes_per_cube + local_pe
nbytes = n_elem * 2
pe_addr = t_ptr + rank * nbytes
# Send our local HBM tile to the E neighbor.
src = tl.load(pe_addr, shape=(n_elem,), dtype="f16")
tl.send(dir="E", src=src)
# Receive a tile from W and store it into our slice (overwrite).
recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
tl.store(pe_addr, recv)
@@ -1,192 +0,0 @@
"""Hierarchical all-reduce kernel (ADR-0023).
3-level reduce + broadcast exploiting the topology hierarchy:
Level 1 — Intra-cube (8 PEs, E/W, fastest link):
Bidirectional ring reduce to PE 0.
Level 2 — Inter-cube within SIP (16 cubes, N/S, UCIe):
Bidirectional ring reduce of PE 0s to cube 0 PE 0.
Level 3 — Inter-SIP (2 SIPs, parent):
Pair exchange between SIP representatives.
Broadcast — Reverse chain through levels 2 and 1.
Bidirectional reduce: left-half sends toward node 0 via dir_dec,
right-half sends via dir_inc (wrapping). Representative receives from
both sides. Rounds per level = ceil((group_size - 1) / 2).
Direction pairing (ring):
Send via dir_dec at PE K → recv via dir_inc at PE K-1
Send via dir_inc at PE K → recv via dir_dec at PE K+1
"""
from __future__ import annotations
def kernel_args(world_size: int, n_elem: int) -> tuple:
"""Positional kernel args for the ahbm backend."""
pes_per_cube = 8
num_sips = max(1, world_size // 128) if world_size > 128 else 1
cubes_per_sip = world_size // (pes_per_cube * num_sips)
return (n_elem, pes_per_cube, cubes_per_sip, num_sips)
def neighbors(rank: int, world_size: int, neighbor_map: dict) -> dict:
"""Build the 3-level neighbor map."""
pes_per_cube = 8
num_sips = max(1, world_size // 128) if world_size > 128 else 1
cubes_per_sip = world_size // (pes_per_cube * num_sips)
pe_id = rank % pes_per_cube
cube_global = rank // pes_per_cube
sip_id = cube_global // cubes_per_sip
local_cube_id = cube_global % cubes_per_sip
result = {}
# Level 1: intra-cube ring (E/W, all PEs)
cube_base = cube_global * pes_per_cube
result["E"] = cube_base + (pe_id + 1) % pes_per_cube
result["W"] = cube_base + (pe_id - 1) % pes_per_cube
# Level 2: inter-cube ring (N/S, PE 0 only)
if pe_id == 0 and cubes_per_sip > 1:
sip_base = sip_id * cubes_per_sip * pes_per_cube
next_cube_pe0 = sip_base + ((local_cube_id + 1) % cubes_per_sip) * pes_per_cube
prev_cube_pe0 = sip_base + ((local_cube_id - 1) % cubes_per_sip) * pes_per_cube
result["N"] = next_cube_pe0
result["S"] = prev_cube_pe0
# Level 3: inter-SIP (parent, PE 0 cube 0 only)
if pe_id == 0 and local_cube_id == 0 and num_sips > 1:
other_sip_pe0 = ((sip_id + 1) % num_sips) * cubes_per_sip * pes_per_cube
result["parent"] = other_sip_pe0
return result
def _bidir_reduce(tl, acc, my_id, group_size, dir_inc, dir_dec, shape, dtype):
"""Bidirectional ring reduce to node 0.
Left half (1..half): chain reduces via dir_dec (toward lower IDs).
Each PE recvs from higher PE (via dir_inc) and sends to lower (via dir_dec).
Right half (half+1..N-1): chain reduces via dir_inc (wraps to node 0).
Each PE recvs from lower PE (via dir_dec) and sends to higher (via dir_inc).
Node 0: recvs left sum via dir_inc, right sum via dir_dec.
Direction pairing: send dir_dec at K → recv dir_inc at K-1.
send dir_inc at K → recv dir_dec at K+1.
"""
if group_size <= 1:
return acc
half = group_size // 2
if my_id == 0:
# Representative: recv left-half sum via dir_inc (from PE 1)
recv = tl.recv(dir=dir_inc, shape=shape, dtype=dtype)
acc = acc + recv
# Recv right-half sum via dir_dec (from PE N-1, wrapped)
if group_size - half - 1 >= 1:
recv = tl.recv(dir=dir_dec, shape=shape, dtype=dtype)
acc = acc + recv
elif my_id <= half:
# Left half: recv from PE my_id+1 via dir_inc, send to PE my_id-1 via dir_dec
if my_id < half: # not the far-edge
recv = tl.recv(dir=dir_inc, shape=shape, dtype=dtype)
acc = acc + recv
tl.send(dir=dir_dec, src=acc)
else:
# Right half: recv from PE my_id-1 via dir_dec, send to PE my_id+1 via dir_inc
if my_id > half + 1: # not the near-edge
recv = tl.recv(dir=dir_dec, shape=shape, dtype=dtype)
acc = acc + recv
tl.send(dir=dir_inc, src=acc)
return acc
def _chain_broadcast(tl, acc, my_id, group_size, dir_inc, shape, dtype):
"""Linear chain broadcast from node 0 via dir_inc.
Node 0 sends via dir_inc → node 1. Node 1 recvs via dir_dec (implicit
from the ring pairing), stores, sends via dir_inc → node 2. Etc.
Recv direction = the opposite: send dir_inc at K → recv dir_dec at K+1.
"""
if group_size <= 1:
return acc
# In ring pairing: send via dir_inc at K → recv via dir_dec at K+1.
# dir_dec is the "other" direction. We infer it from the ring:
# if dir_inc is "E", peer recvs via "W"; if "N", peer recvs via "S".
_recv_dir = {"E": "W", "W": "E", "N": "S", "S": "N"}.get(dir_inc, dir_inc)
if my_id == 0:
tl.send(dir=dir_inc, src=acc)
else:
acc = tl.recv(dir=_recv_dir, shape=shape, dtype=dtype)
if my_id < group_size - 1:
tl.send(dir=dir_inc, src=acc)
return acc
def kernel(t_ptr, n_elem, pes_per_cube, cubes_per_sip, num_sips, tl):
"""Hierarchical all-reduce.
Args:
t_ptr: HBM base address (column-sharded VA).
n_elem: f16 elements per tile.
pes_per_cube: PEs per cube (typically 8).
cubes_per_sip: cubes per SIP (typically 16).
num_sips: number of SIPs (typically 2).
tl: TLContext (auto-injected).
"""
pe_id = tl.program_id(axis=0)
cube_global = tl.program_id(axis=1)
sip_id = cube_global // cubes_per_sip
local_cube_id = cube_global % cubes_per_sip
rank = cube_global * pes_per_cube + pe_id
nbytes = n_elem * 2
pe_addr = t_ptr + rank * nbytes
acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16")
shape = (n_elem,)
dtype = "f16"
# ── Level 1: intra-cube bidirectional reduce to PE 0 ──
acc = _bidir_reduce(
tl, acc, my_id=pe_id, group_size=pes_per_cube,
dir_inc="E", dir_dec="W", shape=shape, dtype=dtype,
)
# ── Level 2: inter-cube bidirectional reduce to cube 0 (PE 0 only) ──
if pe_id == 0 and cubes_per_sip > 1:
acc = _bidir_reduce(
tl, acc, my_id=local_cube_id, group_size=cubes_per_sip,
dir_inc="N", dir_dec="S", shape=shape, dtype=dtype,
)
# ── Level 3: inter-SIP exchange (PE 0 cube 0 only) ──
if pe_id == 0 and local_cube_id == 0 and num_sips > 1:
tl.send(dir="parent", src=acc)
recv = tl.recv(dir="parent", shape=shape, dtype=dtype)
acc = acc + recv
# ── Broadcast back ──
# Level 2: cube 0 PE 0 → all PE 0s via chain
if pe_id == 0 and cubes_per_sip > 1:
acc = _chain_broadcast(
tl, acc, my_id=local_cube_id, group_size=cubes_per_sip,
dir_inc="N", shape=shape, dtype=dtype,
)
# Level 1: PE 0 → all PEs in cube via chain
acc = _chain_broadcast(
tl, acc, my_id=pe_id, group_size=pes_per_cube,
dir_inc="E", shape=shape, dtype=dtype,
)
tl.store(pe_addr, acc)
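The docstring of the removed kernel gives the per-level round count as ceil((group_size - 1) / 2). The sketch below checks that figure against the chain lengths implied by `half = group_size // 2`, assuming the left and right chains run concurrently and each hop costs one round. The helper name `bidir_reduce_rounds` is hypothetical and is not part of kernbench.

```python
import math


def bidir_reduce_rounds(group_size: int) -> int:
    """Rounds until node 0 holds the group sum (illustrative model only)."""
    if group_size <= 1:
        return 0
    half = group_size // 2
    # Left chain: half -> half-1 -> ... -> 1 -> 0, sent via dir_dec.
    left_hops = half
    # Right chain: half+1 -> ... -> N-1 -> 0 (wrap), sent via dir_inc.
    right_hops = group_size - 1 - half
    # The chains are independent, so the longer one sets the round count.
    return max(left_hops, right_hops)


# Group sizes used by the kernel: 2 SIPs, 8 PEs per cube, 16 cubes per SIP.
for n in (2, 8, 16):
    print(n, bidir_reduce_rounds(n), math.ceil((n - 1) / 2))
```

Under this model the longer of the two chains always matches the docstring formula, for odd and even group sizes alike.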
@@ -0,0 +1,189 @@
"""Intercube all-reduce kernel (pe0-only, same-lane across cubes).
Reduces across the 4×4 cube mesh within each SIP, then exchanges
between SIPs using the configured SIP topology, and broadcasts back.
Supported SIP topologies (selected via ``sip_topo_kind``):
0 — ring_1d: global_E/global_W ring, n_sips-1 rounds
1 — torus_2d: row ring (global_E/W) + col ring (global_S/N)
2 — mesh_2d_no_wrap: row chain reduce+broadcast + col chain reduce+broadcast
IPCQ wiring is handled by ``configure_sfr_intercube_multisip``.
"""
from __future__ import annotations
SIP_TOPO_RING = 0
SIP_TOPO_TORUS = 1
SIP_TOPO_MESH = 2
TOPO_NAME_TO_KIND = {
"ring_1d": SIP_TOPO_RING,
"torus_2d": SIP_TOPO_TORUS,
"mesh_2d": SIP_TOPO_TORUS,
"mesh_2d_no_wrap": SIP_TOPO_MESH,
}
def kernel_args(world_size: int, n_elem: int) -> tuple:
    cube_w = 4
    cube_h = 4
    return (n_elem, cube_w, cube_h, world_size)


def _inter_sip_ring(acc, n_sips, n_elem, tl):
    current = acc
    for _ in range(n_sips - 1):
        tl.send(dir="global_E", src=current)
        recv = tl.recv(dir="global_W", shape=(n_elem,), dtype="f16")
        acc = acc + recv
        current = recv
    return acc


def _inter_sip_torus_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl):
    # Row ring (global_E / global_W)
    current = acc
    for _ in range(sip_topo_w - 1):
        tl.send(dir="global_E", src=current)
        recv = tl.recv(dir="global_W", shape=(n_elem,), dtype="f16")
        acc = acc + recv
        current = recv
    # Col ring (global_S / global_N)
    current = acc
    for _ in range(sip_topo_h - 1):
        tl.send(dir="global_S", src=current)
        recv = tl.recv(dir="global_N", shape=(n_elem,), dtype="f16")
        acc = acc + recv
        current = recv
    return acc


def _inter_sip_mesh_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl):
    sip_row = sip_rank // sip_topo_w
    sip_col = sip_rank % sip_topo_w
    # Row reduce W → E
    if sip_col == 0:
        tl.send(dir="global_E", src=acc)
    elif sip_col < sip_topo_w - 1:
        recv = tl.recv(dir="global_W", shape=(n_elem,), dtype="f16")
        acc = acc + recv
        tl.send(dir="global_E", src=acc)
    else:
        recv = tl.recv(dir="global_W", shape=(n_elem,), dtype="f16")
        acc = acc + recv
    # Row broadcast E → W
    if sip_col == sip_topo_w - 1:
        tl.send(dir="global_W", src=acc)
    elif sip_col > 0:
        acc = tl.recv(dir="global_E", shape=(n_elem,), dtype="f16")
        tl.send(dir="global_W", src=acc)
    else:
        acc = tl.recv(dir="global_E", shape=(n_elem,), dtype="f16")
    # Col reduce N → S
    if sip_row == 0:
        tl.send(dir="global_S", src=acc)
    elif sip_row < sip_topo_h - 1:
        recv = tl.recv(dir="global_N", shape=(n_elem,), dtype="f16")
        acc = acc + recv
        tl.send(dir="global_S", src=acc)
    else:
        recv = tl.recv(dir="global_N", shape=(n_elem,), dtype="f16")
        acc = acc + recv
    # Col broadcast S → N
    if sip_row == sip_topo_h - 1:
        tl.send(dir="global_N", src=acc)
    elif sip_row > 0:
        acc = tl.recv(dir="global_S", shape=(n_elem,), dtype="f16")
        tl.send(dir="global_N", src=acc)
    else:
        acc = tl.recv(dir="global_S", shape=(n_elem,), dtype="f16")
    return acc
def allreduce_intercube_multidevice(
    t_ptr, n_elem, cube_w, cube_h, n_sips, sip_rank,
    sip_topo_kind, sip_topo_w, sip_topo_h, tl,
):
    """Intercube all-reduce (pe0-only) with configurable SIP topology.

    Args:
        t_ptr: VA base of the row-wise-sharded tensor on this SIP.
        n_elem: f16 elements per cube tile.
        cube_w: cube mesh width (columns).
        cube_h: cube mesh height (rows).
        n_sips: number of SIPs.
        sip_rank: this SIP's rank (0-based).
        sip_topo_kind: 0=ring_1d, 1=torus_2d, 2=mesh_2d_no_wrap.
        sip_topo_w: SIP mesh width (for 2D topologies, 0 for ring).
        sip_topo_h: SIP mesh height (for 2D topologies, 0 for ring).
        tl: TLContext (auto-injected).
    """
    cube_id = tl.program_id(axis=1)
    row = cube_id // cube_w
    col = cube_id % cube_w
    nbytes = n_elem * 2  # f16 is 2 bytes per element
    pe_addr = t_ptr + cube_id * nbytes
    acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16")

    # ── Phase 1: row reduce W → E ──
    if col == 0:
        tl.send(dir="E", src=acc)
    elif col < cube_w - 1:
        recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
        acc = acc + recv
        tl.send(dir="E", src=acc)
    else:
        recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
        acc = acc + recv

    # ── Phase 2: col reduce N → S on rightmost column ──
    if col == cube_w - 1:
        if row == 0:
            tl.send(dir="S", src=acc)
        elif row < cube_h - 1:
            recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
            acc = acc + recv
            tl.send(dir="S", src=acc)
        else:
            recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
            acc = acc + recv

    # ── Phase 3: inter-SIP exchange on root cube ──
    root_cube = (cube_h - 1) * cube_w + (cube_w - 1)
    if cube_id == root_cube and n_sips > 1:
        if sip_topo_kind == SIP_TOPO_RING:
            acc = _inter_sip_ring(acc, n_sips, n_elem, tl)
        elif sip_topo_kind == SIP_TOPO_TORUS:
            acc = _inter_sip_torus_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)
        elif sip_topo_kind == SIP_TOPO_MESH:
            acc = _inter_sip_mesh_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl)

    # ── Phase 4: col broadcast S → N on rightmost column ──
    if col == cube_w - 1:
        if row == cube_h - 1:
            tl.send(dir="N", src=acc)
        elif row > 0:
            acc = tl.recv(dir="S", shape=(n_elem,), dtype="f16")
            tl.send(dir="N", src=acc)
        else:
            acc = tl.recv(dir="S", shape=(n_elem,), dtype="f16")

    # ── Phase 5: row broadcast E → W ──
    if col == cube_w - 1:
        tl.send(dir="W", src=acc)
    elif col > 0:
        acc = tl.recv(dir="E", shape=(n_elem,), dtype="f16")
        tl.send(dir="W", src=acc)
    else:
        acc = tl.recv(dir="E", shape=(n_elem,), dtype="f16")

    tl.store(pe_addr, acc)


kernel = allreduce_intercube_multidevice
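The five phases of the intercube kernel can be traced on plain integers. The sketch below is a sequential model of a single SIP only: Phase 3 (the inter-SIP exchange) is left out, each cube holds one scalar instead of an f16 tile, and the function name `simulate_cube_mesh_allreduce` is hypothetical. It illustrates the data flow and is not a substitute for the simulator.

```python
def simulate_cube_mesh_allreduce(values, cube_w, cube_h):
    """Replay Phases 1, 2, 4, 5 on a row-major cube_h x cube_w mesh."""
    acc = list(values)
    last = cube_w - 1

    # Phase 1: row reduce W -> E; each cube adds what its W neighbor sent.
    for row in range(cube_h):
        for col in range(1, cube_w):
            acc[row * cube_w + col] += acc[row * cube_w + col - 1]

    # Phase 2: col reduce N -> S on the rightmost column.
    for row in range(1, cube_h):
        acc[row * cube_w + last] += acc[(row - 1) * cube_w + last]

    # The bottom-right cube (root_cube) now holds the SIP-local sum.
    # Phase 3 would exchange it with other SIPs; omitted in this sketch.

    # Phase 4: col broadcast S -> N on the rightmost column.
    for row in range(cube_h - 2, -1, -1):
        acc[row * cube_w + last] = acc[(row + 1) * cube_w + last]

    # Phase 5: row broadcast E -> W in every row.
    for row in range(cube_h):
        for col in range(cube_w - 2, -1, -1):
            acc[row * cube_w + col] = acc[row * cube_w + col + 1]

    return acc


result = simulate_cube_mesh_allreduce(list(range(1, 17)), cube_w=4, cube_h=4)
print(result)
```

With inputs 1..16 on the 4×4 mesh, every cube ends with the same value, the sum of all sixteen inputs.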
@@ -1,73 +0,0 @@
"""2D-mesh all-reduce kernel (ADR-0023).
Two-phase reduce on a square mesh of side ``S`` (world_size = S*S):
1. Row reduce: ring all-reduce along E/W within each row.
2. Column reduce: ring all-reduce along N/S within each column.
After both phases, every rank holds the global sum.
Uses TensorHandle math (PE_MATH) for accumulation. Op_log captures the
data flow so Phase 2 produces correct final HBM contents. Math/recv
handles are passed directly to the next send, avoiding store→reload
which doesn't propagate correctly with timing-only Phase 1 math.
"""
from __future__ import annotations
import math
def kernel_args(world_size: int, n_elem: int) -> tuple:
"""Return the positional kernel arguments for the ahbm backend.
Mesh all-reduce requires ``world_size`` to be a perfect square —
the mesh side length is ``sqrt(world_size)``.
"""
side = int(round(math.sqrt(world_size)))
if side * side != world_size:
raise ValueError(
f"mesh_allreduce requires a square world_size; got {world_size}"
)
return (n_elem, side)
def kernel(t_ptr, n_elem, side, tl):
"""All-reduce on a square mesh.
Args:
t_ptr: HBM base address (column-sharded VA shared across ranks)
n_elem: number of f16 elements per tile
side: mesh side length (sqrt(world_size))
tl: TLContext (ADR-0022).
"""
local_pe = tl.program_id(axis=0)
cube_id = tl.program_id(axis=1)
pes_per_cube = tl.num_programs(axis=0)
rank = cube_id * pes_per_cube + local_pe
nbytes = n_elem * 2
pe_addr = t_ptr + rank * nbytes
acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16")
current = acc
# ── Phase 1: row ring (E direction) ──
# Ring forwards each received tile (not the cumulative acc) so every
# tile passes through every rank exactly once.
for _ in range(side - 1):
tl.send(dir="E", src=current)
recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
acc = acc + recv
current = recv
# Phase 2 column ring starts from the row-phase accumulator. We do NOT
# store/reload here — the math handle's scratch addr is the source for
# the first column send and Phase 2 ipcq_copy replays from there.
current = acc
# ── Phase 2: column ring (S direction) ──
for _ in range(side - 1):
tl.send(dir="S", src=current)
recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
acc = acc + recv
current = recv
tl.store(pe_addr, acc)
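The two-phase structure of the removed mesh kernel (a row ring followed by a column ring) is the same pattern the new `_inter_sip_torus_2d` helper uses. The sketch below models it on scalars, assuming a row-major rank layout with wrap-around links in both dimensions. The helper names `ring_phase` and `simulate_mesh_allreduce` are hypothetical.

```python
def ring_phase(acc, n_rounds, recv_from):
    """Run n_rounds of 'send current, receive, accumulate, forward received'."""
    acc = list(acc)
    current = list(acc)
    for _ in range(n_rounds):
        # All ranks send simultaneously; rank r receives from recv_from(r).
        recv = [current[recv_from(r)] for r in range(len(acc))]
        acc = [a + x for a, x in zip(acc, recv)]
        current = recv
    return acc


def simulate_mesh_allreduce(values, side):
    """Row ring (send E, recv from W), then column ring (send S, recv from N)."""
    def west(r):
        row, col = divmod(r, side)
        return row * side + (col - 1) % side

    def north(r):
        row, col = divmod(r, side)
        return ((row - 1) % side) * side + col

    acc = ring_phase(values, side - 1, west)   # every rank holds its row sum
    return ring_phase(acc, side - 1, north)    # every rank holds the global sum


print(simulate_mesh_allreduce(list(range(1, 10)), side=3))
```

After the row phase each rank holds its row's sum, so the column phase only has to combine `side` row sums, for 2 * (side - 1) rounds in total.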
@@ -1,80 +0,0 @@
"""Ring all-reduce kernel for IPCQ-based PE collective (ADR-0023).
Algorithm: 1D ring of N PEs, each PE starts with one tile of data.
After ``world_size - 1`` rounds, every PE's accumulator holds the sum
of all PE tiles.
Strategy
--------
Each PE starts with its own tile in HBM. The kernel:
1. Loads the local tile into a TensorHandle (the accumulator).
2. In each of ``world_size - 1`` rounds:
- Sends the current accumulator/recv slot to the E neighbor.
- Receives a tile from the W neighbor — the recv handle points
into the per-direction TCM slot.
- Adds the received tile to the accumulator using the TensorHandle
operator overload, which dispatches to ``MathCmd`` (PE_MATH).
3. Stores the final accumulator back to HBM via tl.store. The store is
recorded in op_log with both src and dst, so Phase 2 will copy the
replayed math result from PE-local scratch into HBM.
ADR-0020 D3 split: Phase 1 simulates timing only — math results are
not yet computed, so the accumulator data flowing through Phase 1 may
be stale. Phase 2's DataExecutor replays math + IPCQ copies + dma_write
in stable t_start order, producing correct final HBM contents.
"""
from __future__ import annotations
def kernel_args(world_size: int, n_elem: int) -> tuple:
"""Return the positional kernel arguments for the ahbm backend.
Ring all-reduce takes (n_elem, world_size) after the tensor pointer.
"""
return (n_elem, world_size)
def kernel(t_ptr, n_elem, world_size, tl):
"""Ring all-reduce.
Args:
t_ptr: HBM base address of the column-sharded tensor — all PEs
share this base. The per-PE slice lives at
``t_ptr + global_rank * n_elem * 2``.
n_elem: number of f16 elements per tile.
world_size: total number of participating ranks (passed by host).
tl: TLContext (auto-injected, ADR-0022). The kernel derives the
global rank from ``program_id(axis=0)`` (local PE) and
``program_id(axis=1)`` (cube id):
rank = cube_id * pes_per_cube + local_pe
"""
local_pe = tl.program_id(axis=0)
cube_id = tl.program_id(axis=1)
pes_per_cube = tl.num_programs(axis=0)
rank = cube_id * pes_per_cube + local_pe
nbytes = n_elem * 2 # f16
# Each PE reads from its own slice of the shared base address
pe_addr = t_ptr + rank * nbytes
# Load the local tile — handle points at HBM[pe_addr].
acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16")
# The ring forwards each received tile to the next neighbor (NOT the
# cumulative accumulator), so every rank's tile passes through every
# rank exactly once. The accumulator sums the new arrival each round.
current = acc
for _step in range(world_size - 1):
tl.send(dir="E", src=current)
recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
# TensorHandle add → MathCmd → PE_MATH (timing in Phase 1, real
# numpy in Phase 2 via DataExecutor). The result handle lives at
# an auto-allocated PE-local scratch addr.
acc = acc + recv
current = recv # forward W's tile to E next round
# Final result back to this PE's HBM slice. Op_log captures the
# source (scratch addr) and dst (HBM slice) so Phase 2 copies the
# accumulated value into HBM for verification.
tl.store(pe_addr, acc)
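The comment in the removed ring kernel stresses that each rank forwards the tile it just received, not its running accumulator. A small scalar model makes the difference visible; the function name and the `forward_received` flag are hypothetical and exist only for this illustration.

```python
def simulate_ring_allreduce(values, forward_received=True):
    """1D ring: send to E (rank + 1), receive from W (rank - 1)."""
    n = len(values)
    acc = list(values)
    current = list(values)
    for _ in range(n - 1):
        recv = [current[(r - 1) % n] for r in range(n)]
        acc = [a + x for a, x in zip(acc, recv)]
        # Correct ring: pass along the received tile unchanged.
        # Incorrect variant: pass along the accumulator, which re-sends
        # contributions the downstream ranks have already counted.
        current = recv if forward_received else list(acc)
    return acc


print(simulate_ring_allreduce([1, 2, 3]))                          # every rank: 6
print(simulate_ring_allreduce([1, 2, 3], forward_received=False))  # over-counts
```

Forwarding the received tile means each original tile visits every rank exactly once over `world_size - 1` rounds, which is why the accumulator ends at the exact sum.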
@@ -1,80 +0,0 @@
"""Tree all-reduce kernel for IPCQ-based PE collective (ADR-0023).
Two-phase binary tree all-reduce:
Phase 1 (reduce up):
- leaf nodes send their value to ``parent``
- internal nodes recv from each child, sum, then send to ``parent``
- root accumulates child contributions; final acc holds global sum
Phase 2 (broadcast down):
- root sends acc to ``child_left`` and ``child_right`` (if present)
- internal nodes recv from ``parent``, then forward to children
- all ranks store the final acc to HBM
Uses TensorHandle math (PE_MATH) for accumulation. Op_log captures the
data flow so Phase 2 produces correct final HBM contents. The kernel
deliberately avoids the store→reload→send pattern: math/recv handles
are passed directly to the next send so PE_DMA snapshots a deterministic
source addr that Phase 2 can replay.
"""
from __future__ import annotations
def kernel_args(world_size: int, n_elem: int) -> tuple:
"""Return the positional kernel arguments for the ahbm backend."""
return (n_elem, world_size)
def kernel(t_ptr, n_elem, world_size, tl):
"""Tree all-reduce.
Args:
t_ptr: HBM base address.
n_elem: number of f16 elements per tile.
world_size: total number of participating ranks (passed by host).
tl: TLContext (ADR-0022). Global rank from program_id(0/1).
"""
local_pe = tl.program_id(axis=0)
cube_id = tl.program_id(axis=1)
pes_per_cube = tl.num_programs(axis=0)
rank = cube_id * pes_per_cube + local_pe
nbytes = n_elem * 2
pe_addr = t_ptr + rank * nbytes
acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16")
# Compute children/parent existence (matches tree_binary topology generator)
has_parent = rank > 0
left = 2 * rank + 1
right = 2 * rank + 2
has_left = left < world_size
has_right = right < world_size
# ── Phase 1: reduce up ──
if has_left:
recv = tl.recv(dir="child_left", shape=(n_elem,), dtype="f16")
acc = acc + recv
if has_right:
recv = tl.recv(dir="child_right", shape=(n_elem,), dtype="f16")
acc = acc + recv
if has_parent:
# Send the math/load handle directly — its addr is either the
# original HBM tile (leaf) or the PE-local scratch where the
# accumulator lives. Phase 2 ipcq_copy replays from the same addr.
tl.send(dir="parent", src=acc)
# ── Phase 2: broadcast down ──
if has_parent:
# Replace acc with the value broadcast from the parent (the global
# sum). The recv handle points at the parent-direction TCM slot.
acc = tl.recv(dir="parent", shape=(n_elem,), dtype="f16")
if has_left:
tl.send(dir="child_left", src=acc)
if has_right:
tl.send(dir="child_right", src=acc)
# Final store to HBM for the bench's verification path.
tl.store(pe_addr, acc)
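The removed tree kernel derives its children from heap-style indexing (`left = 2 * rank + 1`, `right = 2 * rank + 2`). The sketch below replays the reduce-up and broadcast-down phases on scalars under that same indexing, with the parent taken as `(rank - 1) // 2`. It is a sequential model, so the send/recv ordering is replaced by a loop order that respects the dependencies; the function name is hypothetical.

```python
def simulate_tree_allreduce(values):
    """Binary-tree all-reduce on a heap-indexed tree rooted at rank 0."""
    n = len(values)
    acc = list(values)

    # Reduce up: visit ranks from highest to lowest so that a rank's
    # children (which have larger indices) are folded in before it sends.
    for rank in range(n - 1, 0, -1):
        parent = (rank - 1) // 2
        acc[parent] += acc[rank]

    # Broadcast down: visit ranks from lowest to highest so that a parent
    # already holds the global sum when its children copy it.
    for rank in range(1, n):
        acc[rank] = acc[(rank - 1) // 2]

    return acc


print(simulate_tree_allreduce([1, 2, 3, 4, 5, 6, 7]))
```

The example uses seven ranks, matching the `world_size: 7` entry configured for `tree_allreduce`.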
+5 -1
@@ -219,7 +219,11 @@ def install_ipcq(
"neighbor_table": neighbor_table, "neighbor_table": neighbor_table,
} }
_OPPOSITE_DIR = {"E": "W", "W": "E", "N": "S", "S": "N"} _OPPOSITE_DIR = {
"E": "W", "W": "E", "N": "S", "S": "N",
"global_E": "global_W", "global_W": "global_E",
"global_N": "global_S", "global_S": "global_N",
}
def reverse_direction(my_rank: int, peer_rank: int, my_dir: str) -> str | None: def reverse_direction(my_rank: int, peer_rank: int, my_dir: str) -> str | None:
"""Find peer's direction that reciprocates my_dir→peer_rank. """Find peer's direction that reciprocates my_dir→peer_rank.
+104
@@ -0,0 +1,104 @@
"""SFR configuration for intercube + inter-SIP IPCQ wiring.
Provides ``configure_sfr_intercube_multisip`` which programs PE_IPCQ
neighbor tables for:
1. Intercube within each SIP — pe0 of every cube connects to pe0 of
its N/S/E/W mesh neighbors (no wrap-around).
2. Inter-SIP on ALL cubes — pe0 of cube_c on sip_A connects to pe0 of
cube_c on each peer SIP, using ``global_E``/``global_W`` (ring) or
``global_N``/``global_S``/``global_E``/``global_W`` (mesh/torus)
direction labels. Wiring all cubes allows the kernel to
dynamically elect the root cube at runtime.
SIP-level topology is read from ``topology.yaml`` →
``system.sips.topology`` (e.g. ``ring_1d``, ``mesh_2d``).
Intercube mesh dimensions come from ``sip.cube_mesh.w/h``.
Internally delegates to ``install_ipcq`` with a computed ``rank_to_pe``
(pe0-only) and a closure-captured ``neighbors()`` function.
"""
from __future__ import annotations
import types
from typing import Any
from kernbench.ccl.install import install_ipcq
from kernbench.ccl.topologies import _BUILTIN as _TOPO_BUILTINS
def configure_sfr_intercube_multisip(
    engine: Any,
    spec: dict,
    cfg: dict,
) -> dict[str, Any]:
    """Wire IPCQ for intercube (pe0, mesh) + inter-SIP (pe0, all cubes).

    Args:
        engine: GraphEngine with ``_components``.
        spec: topology spec dict (from topology.yaml).
        cfg: merged algorithm config (from ``resolve_algorithm_config``).

    Returns:
        The install plan dict from ``install_ipcq``.
    """
    cm = spec["sip"]["cube_mesh"]
    mesh_w = int(cm["w"])
    mesh_h = int(cm["h"])
    n_cubes = mesh_w * mesh_h
    n_sips = int(spec.get("system", {}).get("sips", {}).get("count", 1))
    sip_topology = str(
        spec.get("system", {}).get("sips", {}).get("topology", "ring_1d")
    )
    if sip_topology not in _TOPO_BUILTINS:
        raise ValueError(
            f"Unknown sip topology '{sip_topology}'. "
            f"Available: {list(_TOPO_BUILTINS)}"
        )
    sip_topo_fn = _TOPO_BUILTINS[sip_topology]
    world_size = n_sips * n_cubes
    pe_idx_to_pe: list[tuple[int, int, int]] = [
        (sip, cube, 0)
        for sip in range(n_sips)
        for cube in range(n_cubes)
    ]

    def _neighbors(pe_idx: int, ws: int, _base: dict) -> dict[str, int]:
        sip = pe_idx // n_cubes
        cube = pe_idx % n_cubes
        row = cube // mesh_w
        col = cube % mesh_w
        nbrs: dict[str, int] = {}
        # Intercube within SIP (mesh, no wrap-around)
        if col < mesh_w - 1:
            nbrs["E"] = sip * n_cubes + (row * mesh_w + col + 1)
        if col > 0:
            nbrs["W"] = sip * n_cubes + (row * mesh_w + col - 1)
        if row < mesh_h - 1:
            nbrs["S"] = sip * n_cubes + ((row + 1) * mesh_w + col)
        if row > 0:
            nbrs["N"] = sip * n_cubes + ((row - 1) * mesh_w + col)
        # Inter-SIP on ALL cubes
        if n_sips > 1:
            sip_nbrs = sip_topo_fn(sip, n_sips)
            for d, peer_sip in sip_nbrs.items():
                nbrs[f"global_{d}"] = peer_sip * n_cubes + cube
        return nbrs

    mock_module = types.SimpleNamespace(neighbors=_neighbors)
    cfg_copy = dict(cfg)
    cfg_copy["world_size"] = world_size
    cfg_copy["topology"] = "none"
    return install_ipcq(
        engine, spec, cfg_copy,
        algo_module=mock_module,
        rank_to_pe=pe_idx_to_pe,
    )
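To see what neighbor table the `_neighbors` closure produces, the index arithmetic can be reproduced without an engine. The sketch below is standalone: `ring_1d` is an assumed stand-in for the corresponding entry in `kernbench.ccl.topologies` (its real return shape is not shown in this commit; an `{"E": ..., "W": ...}` dict is assumed because the caller prefixes each key with `global_`), and `pe0_neighbors` is a hypothetical name.

```python
from __future__ import annotations


def ring_1d(sip: int, n_sips: int) -> dict[str, int]:
    """Assumed SIP-level ring: E is the next SIP, W is the previous one."""
    return {"E": (sip + 1) % n_sips, "W": (sip - 1) % n_sips}


def pe0_neighbors(pe_idx, n_sips, mesh_w, mesh_h, sip_topo_fn=ring_1d):
    """Neighbor table for the pe0 of one cube, indexed as sip * n_cubes + cube."""
    n_cubes = mesh_w * mesh_h
    sip, cube = divmod(pe_idx, n_cubes)
    row, col = divmod(cube, mesh_w)
    nbrs: dict[str, int] = {}
    # Intercube links inside the SIP: plain mesh, edges have no wrap-around.
    if col < mesh_w - 1:
        nbrs["E"] = sip * n_cubes + (row * mesh_w + col + 1)
    if col > 0:
        nbrs["W"] = sip * n_cubes + (row * mesh_w + col - 1)
    if row < mesh_h - 1:
        nbrs["S"] = sip * n_cubes + ((row + 1) * mesh_w + col)
    if row > 0:
        nbrs["N"] = sip * n_cubes + ((row - 1) * mesh_w + col)
    # Inter-SIP links: same cube index on the peer SIP.
    if n_sips > 1:
        for d, peer_sip in sip_topo_fn(sip, n_sips).items():
            nbrs[f"global_{d}"] = peer_sip * n_cubes + cube
    return nbrs


# Corner cube of SIP 0 and an interior cube of SIP 1, for 2 SIPs of 4x4 cubes.
print(pe0_neighbors(0, n_sips=2, mesh_w=4, mesh_h=4))
print(pe0_neighbors(21, n_sips=2, mesh_w=4, mesh_h=4))
```

Corner cubes get only two intercube directions, while every cube gets the `global_*` links, which is what lets the kernel pick its root cube without rewiring.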
-492
@@ -1,492 +0,0 @@
"""Mock CCL runtime for fast unit tests of algorithm kernels (ADR-0023 D15).
Runs a kernel function once per rank with a minimal ``tl`` shim — no SimPy,
no PE_DMA, no fabric simulation. Just enough to verify *functional*
correctness of an IPCQ-based collective algorithm.
Cross-rank send/recv is implemented with greenlet cooperative scheduling
plus per-(rank, direction) FIFO queues. Backpressure is not modeled —
queues are unbounded.
Typical usage in a test::
from kernbench.ccl.testing import run_kernel_in_mock
from kernbench.ccl.algorithms.ring_allreduce import kernel
inputs = [np.full(16, r + 1, dtype="f16") for r in range(4)]
outputs = run_kernel_in_mock(
kernel_fn=kernel, world_size=4, topology="ring_1d",
inputs=inputs, kernel_args=(16,),
)
for r in range(4):
assert np.allclose(outputs[r], sum(inputs))
"""
from __future__ import annotations
from collections import deque
from typing import Any, Callable
import numpy as np
from greenlet import greenlet
from kernbench.ccl.topologies import resolve_topology
from kernbench.common.ipcq_types import IpcqInvalidDirection
from kernbench.common.pe_commands import TensorHandle
# ── Per-rank fake state ──────────────────────────────────────────────
class _MockRankState:
"""Per-rank scratch holding HBM/recv slots and tl shim hooks."""
def __init__(
self,
rank: int,
world_size: int,
neighbors: dict[str, int],
input_arr: np.ndarray,
pes_per_cube: int = 0,
) -> None:
self.rank = rank
self.world_size = world_size
# PEs per cube for program_id(axis=0/1). If 0 or world_size,
# all ranks are in one cube (legacy single-cube behavior).
self.pes_per_cube = pes_per_cube if pes_per_cube > 0 else world_size
self.neighbors = neighbors # direction → peer rank
# HBM "memory": addr → ndarray. Per-rank, no cross-rank sharing.
self._hbm: dict[int, np.ndarray] = {}
self._tcm: dict[int, np.ndarray] = {}
# ``t_ptr`` is the address the kernel sees. Real benches use a
# column-sharded VA so each rank reads from ``t_ptr + rank*nbytes``.
# Mirror that here: each rank's slice lives at the rank-specific addr.
nbytes = int(input_arr.nbytes)
self.t_ptr = 0 # base; per-rank offset is rank * nbytes
self._slice_addr = rank * nbytes
self._hbm[self._slice_addr] = input_arr.copy()
# Inbound recv FIFOs: direction → deque[ndarray]
self.recv_q: dict[str, deque[np.ndarray]] = {d: deque() for d in neighbors}
# Output (set when kernel calls tl.store at slice address)
self.output: np.ndarray | None = None
# Greenlet for this rank — set later
self.g: greenlet | None = None
# ── Mock TLContext ───────────────────────────────────────────────────
class _MockTL:
"""Drop-in tl shim for mock runtime.
Supports the subset of TLContext API that algorithm authors use:
program_id, num_programs, load, store, send, recv, recv_async, wait,
plus arithmetic operations on TensorHandle (eager numpy execution,
no SimPy involved).
"""
def __init__(self, state: _MockRankState, scheduler: "_MockScheduler") -> None:
self._state = state
self._scheduler = scheduler
self._handle_counter = 0
def _next_id(self) -> str:
self._handle_counter += 1
return f"mt{self._handle_counter}"
@property
def rank(self) -> int:
return self._state.rank
@property
def world_size(self) -> int:
return self._state.world_size
# axis-aware
def program_id(self, axis: int = 0) -> int:
# Multi-cube: axis=0 = PE within cube, axis=1 = global cube id.
# Falls back to flat (all ranks in one cube) if pes_per_cube
# is not set (legacy single-cube tests).
ppc = self._state.pes_per_cube
if axis == 1:
return self._state.rank // ppc
return self._state.rank % ppc
def num_programs(self, axis: int = 0) -> int:
ppc = self._state.pes_per_cube
if axis == 1:
return self._state.world_size // ppc
return ppc
# ── arithmetic ops (called by TensorHandle.__add__ etc.) ──
def _binary_math(self, op: str, a: TensorHandle, b: TensorHandle) -> TensorHandle:
a_data = np.asarray(a.data) if a.data is not None else None
b_data = np.asarray(b.data) if b.data is not None else None
if a_data is None or b_data is None:
result = None
elif op == "add":
result = a_data + b_data
elif op == "sub":
result = a_data - b_data
elif op == "mul":
result = a_data * b_data
elif op == "div":
result = a_data / b_data
elif op == "maximum":
result = np.maximum(a_data, b_data)
elif op == "minimum":
result = np.minimum(a_data, b_data)
else:
raise NotImplementedError(f"mock _binary_math: op {op!r} not implemented")
return TensorHandle(
id=self._next_id(),
addr=0, shape=a.shape, dtype=a.dtype,
nbytes=int(np.prod(a.shape)) * 2 if a.shape else 0,
data=result, space="tcm",
)
def maximum(self, a: TensorHandle, b: TensorHandle) -> TensorHandle:
return self._binary_math("maximum", a, b)
def minimum(self, a: TensorHandle, b: TensorHandle) -> TensorHandle:
return self._binary_math("minimum", a, b)
def fma(
self, a: TensorHandle, b: TensorHandle, c: TensorHandle,
) -> TensorHandle:
a_data = np.asarray(a.data) if a.data is not None else None
b_data = np.asarray(b.data) if b.data is not None else None
c_data = np.asarray(c.data) if c.data is not None else None
result = (
a_data * b_data + c_data
if (a_data is not None and b_data is not None and c_data is not None)
else None
)
return TensorHandle(
id=self._next_id(),
addr=0, shape=a.shape, dtype=a.dtype,
nbytes=int(np.prod(a.shape)) * 2 if a.shape else 0,
data=result, space="tcm",
)
def clamp(
self,
x: TensorHandle,
min: TensorHandle,
max: TensorHandle,
) -> TensorHandle:
x_data = np.asarray(x.data) if x.data is not None else None
lo = np.asarray(min.data) if min.data is not None else None
hi = np.asarray(max.data) if max.data is not None else None
result = (
np.minimum(np.maximum(x_data, lo), hi)
if (x_data is not None and lo is not None and hi is not None)
else None
)
return TensorHandle(
id=self._next_id(),
addr=0, shape=x.shape, dtype=x.dtype,
nbytes=int(np.prod(x.shape)) * 2 if x.shape else 0,
data=result, space="tcm",
)
def softmax(self, x: TensorHandle, axis: int = -1) -> TensorHandle:
x_data = np.asarray(x.data) if x.data is not None else None
if x_data is None:
result = None
else:
x_max = np.max(x_data, axis=axis, keepdims=True)
e = np.exp(x_data - x_max)
s = np.sum(e, axis=axis, keepdims=True)
result = e / s
return TensorHandle(
id=self._next_id(),
addr=0, shape=x.shape, dtype=x.dtype,
nbytes=int(np.prod(x.shape)) * 2 if x.shape else 0,
data=result, space="tcm",
)
@staticmethod
def cdiv(a: int, b: int) -> int:
return -(-int(a) // int(b))
def _unary_math(self, op: str, x: TensorHandle) -> TensorHandle:
x_data = np.asarray(x.data) if x.data is not None else None
if x_data is None:
result = None
elif op == "exp":
result = np.exp(x_data)
elif op == "log":
result = np.log(x_data)
elif op == "sqrt":
result = np.sqrt(x_data)
elif op == "abs":
result = np.abs(x_data)
elif op == "sigmoid":
result = 1.0 / (1.0 + np.exp(-x_data))
elif op == "cos":
result = np.cos(x_data)
elif op == "sin":
result = np.sin(x_data)
else:
raise NotImplementedError(f"mock _unary_math: op {op!r} not implemented")
return TensorHandle(
id=self._next_id(),
addr=0, shape=x.shape, dtype=x.dtype,
nbytes=int(np.prod(x.shape)) * 2 if x.shape else 0,
data=result, space="tcm",
)
def load(self, ptr: int, shape: tuple[int, ...], dtype: str = "f16") -> TensorHandle:
data = self._state._hbm.get(ptr)
if data is None:
data = np.zeros(shape, dtype=np.float16)
return TensorHandle(
id=f"load_{ptr}", addr=ptr, shape=shape, dtype=dtype,
nbytes=int(np.prod(shape)) * 2, data=data, space="hbm",
)
def store(self, ptr: int, handle: TensorHandle) -> None:
if handle.data is not None:
self._state._hbm[ptr] = np.asarray(handle.data)
if ptr == self._state._slice_addr:
self._state.output = self._state._hbm[ptr]
# IPCQ
def send(
self,
dir: str,
src: TensorHandle | None = None,
*,
src_addr: int | None = None,
nbytes: int | None = None,
shape: tuple[int, ...] | None = None,
dtype: str = "f16",
space: str = "tcm",
) -> None:
if dir not in self._state.neighbors:
raise IpcqInvalidDirection(
f"mock tl.send: direction {dir!r} not in neighbors {list(self._state.neighbors)}"
)
if src is not None:
if src.data is not None:
data = np.asarray(src.data)
else:
# Resolve from this rank's local memory at src.addr
space_dict = self._state._hbm if src.space == "hbm" else self._state._tcm
stored = space_dict.get(src.addr)
if stored is None:
raise RuntimeError(
f"mock tl.send: no data at {src.space}:0x{src.addr:x}"
)
data = np.asarray(stored)
else:
data = None
if data is None:
raise RuntimeError("mock tl.send: src is None")
peer_rank = self._state.neighbors[dir]
# Find the reverse direction at the peer, mirroring real IPCQ
# install pairing: N↔S, E↔W, parent↔parent, child_left↔child_left, etc.
_REVERSE = {"N": "S", "S": "N", "E": "W", "W": "E",
"parent": "parent", "child_left": "child_left",
"child_right": "child_right"}
peer_state = self._scheduler.states[peer_rank]
reverse_dir = _REVERSE.get(dir)
# Fall back to "first direction pointing at me" if the explicit
# reverse doesn't exist at the peer (e.g. custom directions).
if reverse_dir is None or reverse_dir not in peer_state.neighbors:
reverse_dir = None
for d, target in peer_state.neighbors.items():
if target == self._state.rank:
reverse_dir = d
break
if reverse_dir is None:
raise RuntimeError(
f"mock tl.send: peer rank {peer_rank} has no reverse direction"
)
peer_state.recv_q[reverse_dir].append(data.copy())
self._scheduler._send_counter += 1
# After delivering, hand control back to scheduler so the receiver
# can wake up.
self._scheduler.yield_()
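The reverse-direction lookup inside `send` can be read in isolation. The sketch below mirrors its two steps (explicit reverse first, then "first direction pointing back at me"), assuming neighbor maps are plain `dict[str, int]`; the helper name `find_reverse_dir` and the sample direction names `next`/`prev` are hypothetical.

```python
REVERSE = {"N": "S", "S": "N", "E": "W", "W": "E"}

def find_reverse_dir(direction: str, my_rank: int, peer_neighbors: dict) -> str:
    """Return the queue name at the peer that should receive my message."""
    candidate = REVERSE.get(direction)
    if candidate is not None and candidate in peer_neighbors:
        return candidate
    # Fallback for custom direction names: first peer direction that targets me.
    for d, target in peer_neighbors.items():
        if target == my_rank:
            return d
    raise RuntimeError("peer has no direction pointing back at rank %d" % my_rank)

# Rank 0 sends "E" to a peer whose "W" neighbor is rank 0.
explicit = find_reverse_dir("E", 0, {"W": 0, "E": 2})
# Custom names have no entry in REVERSE, so the fallback scan is used.
fallback = find_reverse_dir("next", 0, {"prev": 0, "next": 2})
```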
def recv_async(
self,
dir: str,
shape: tuple[int, ...] = (),
dtype: str = "f16",
) -> dict:
"""Non-blocking recv. Returns a future dict to pass to tl.wait."""
if dir not in self._state.neighbors:
raise IpcqInvalidDirection(
f"mock tl.recv_async: direction {dir!r} not in neighbors"
)
return {"_kind": "recv_future", "dir": dir, "shape": shape, "dtype": dtype}
def wait(self, future: Any) -> TensorHandle:
"""Block until the recv future has data."""
if not isinstance(future, dict) or future.get("_kind") != "recv_future":
raise TypeError("tl.wait: expected recv future from tl.recv_async")
d = future["dir"]
while not self._state.recv_q[d]:
self._scheduler.yield_()
data = self._state.recv_q[d].popleft()
return self._make_handle(data, d, future["dtype"])
def recv(
self,
dir: str | None = None,
shape: tuple[int, ...] = (),
dtype: str = "f16",
) -> TensorHandle:
if dir is not None and dir not in self._state.neighbors:
raise IpcqInvalidDirection(
f"mock tl.recv: direction {dir!r} not in neighbors {list(self._state.neighbors)}"
)
# Wait for data
while True:
if dir is None:
# Scan directions in neighbor-map order; take the first non-empty queue
for d in self._state.neighbors:
if self._state.recv_q[d]:
data = self._state.recv_q[d].popleft()
return self._make_handle(data, d, dtype)
else:
if self._state.recv_q[dir]:
data = self._state.recv_q[dir].popleft()
return self._make_handle(data, dir, dtype)
# Yield to other ranks
self._scheduler.yield_()
def _make_handle(self, data: np.ndarray, direction: str, dtype: str) -> TensorHandle:
return TensorHandle(
id=f"recv_{direction}",
addr=0, shape=data.shape, dtype=dtype,
nbytes=int(data.nbytes), data=data, space="tcm",
)
# ── Cooperative scheduler ────────────────────────────────────────────
class _MockScheduler:
"""Round-robin cooperative scheduler over rank greenlets."""
def __init__(self, states: list[_MockRankState]) -> None:
self.states = states
self._parent: greenlet | None = None
self._cur_idx = 0
def yield_(self) -> None:
"""Called from inside a rank greenlet to give other ranks a turn."""
assert self._parent is not None
self._parent.switch()
def run(self, kernel_fn: Callable, kernel_args: tuple) -> list[np.ndarray]:
from kernbench.triton_emu.tl_context import TLContext
self._parent = greenlet.getcurrent()
n = len(self.states)
# Per-rank tl shim
tls: dict[int, _MockTL] = {}
def _spawn(rank_idx: int) -> greenlet:
state = self.states[rank_idx]
tl = _MockTL(state, self)
tls[rank_idx] = tl
def _entry():
# Activate this rank's tl for TensorHandle operator overloads
TLContext._set_active(tl) # type: ignore[attr-defined]
try:
kernel_fn(state.t_ptr, *kernel_args, tl=tl)
finally:
TLContext._set_active(None) # type: ignore[attr-defined]
return greenlet(_entry)
for state in self.states:
state.g = _spawn(state.rank)
# Drive each rank round-robin until all dead. Detect global deadlock.
# A global send counter tracks whether any greenlet delivered data
# in the current round. This is more reliable than queue-depth
# tracking because a recv+send pair in the same round nets to zero
# depth change yet still represents real progress.
self._send_counter = 0
max_idle_rounds = 10_000
idle_rounds = 0
while True:
alive = [s for s in self.states if s.g is not None and not s.g.dead]
if not alive:
break
counter_before = self._send_counter
for s in self.states:
if s.g is None or s.g.dead:
continue
TLContext._set_active(tls[s.rank]) # type: ignore[attr-defined]
s.g.switch()
TLContext._set_active(None) # type: ignore[attr-defined]
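# Only ranks that finished during this round count as progress; a rank
# that died in an earlier round must not mask a deadlock among the rest.
any_died = any(s.g is not None and s.g.dead for s in alive)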
if self._send_counter > counter_before or any_died:
idle_rounds = 0
else:
idle_rounds += 1
if idle_rounds >= max_idle_rounds:
raise RuntimeError(
"mock CCL runtime: deadlock detected (no progress for "
f"{max_idle_rounds} rounds)"
)
return [
s.output if s.output is not None else s._hbm.get(s._slice_addr)
for s in self.states
]
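The scheduling loop above can be sketched without greenlets. The version below uses Python generators as a stand-in for rank greenlets (an assumption made only to keep the example dependency-free): each task yields `True` when it made progress and `False` when it is blocked, and the driver raises once no task has progressed for `max_idle_rounds` rounds. All names are illustrative.

```python
from collections import deque

def run_round_robin(tasks, max_idle_rounds=100):
    """Drive generator tasks round-robin until all finish or none can progress."""
    alive = dict(enumerate(tasks))
    idle_rounds = 0
    while alive:
        progressed = False
        for rank in list(alive):
            try:
                if next(alive[rank]):
                    progressed = True
            except StopIteration:
                del alive[rank]
                progressed = True  # finishing counts as progress for this round only
        idle_rounds = 0 if progressed else idle_rounds + 1
        if idle_rounds >= max_idle_rounds:
            raise RuntimeError("deadlock: no progress for %d rounds" % max_idle_rounds)

def sender(queue, count):
    for i in range(count):
        queue.append(i)
        yield True

def receiver(queue, count, out):
    while len(out) < count:
        if queue:
            out.append(queue.popleft())
            yield True
        else:
            yield False  # blocked: nothing to receive yet

q, received = deque(), []
run_round_robin([sender(q, 3), receiver(q, 3, received)])
```

A receiver with no matching sender never reports progress, so the driver raises instead of spinning forever.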
# ── Public entry ────────────────────────────────────────────────────
def run_kernel_in_mock(
kernel_fn: Callable,
world_size: int,
topology: str,
inputs: list[np.ndarray],
kernel_args: tuple = (),
algo_module: Any | None = None,
pes_per_cube: int = 0,
) -> list[np.ndarray]:
"""Run a CCL kernel under the mock runtime with no SimPy/fabric.
Args:
kernel_fn: ``kernel(t_ptr, *kernel_args, tl=...)``
world_size: number of ranks
topology: builtin topology name (e.g. "ring_1d")
inputs: per-rank input ndarrays. ``inputs[r]`` becomes rank r's
local tile at HBM address 0.
kernel_args: extra positional args after t_ptr
algo_module: optional module providing ``neighbors()`` override
pes_per_cube: PEs per cube for multi-cube program_id mapping.
0 → single-cube legacy (all ranks in one cube).
Returns:
Per-rank output ndarrays — whatever the kernel wrote via tl.store
(or the original input if the kernel didn't store).
"""
if len(inputs) != world_size:
raise ValueError(f"len(inputs)={len(inputs)} != world_size={world_size}")
topo_fn = resolve_topology(topology, algo_module=algo_module)
states = [
_MockRankState(
rank=r, world_size=world_size,
neighbors=topo_fn(r, world_size),
input_arr=inputs[r],
pes_per_cube=pes_per_cube,
)
for r in range(world_size)
]
sched = _MockScheduler(states)
return sched.run(kernel_fn, kernel_args)
+35
@@ -73,6 +73,39 @@ def tree_binary(rank: int, world_size: int) -> NeighborMap:
return n
def torus_2d(rank: int, world_size: int) -> NeighborMap:
"""Square 2D torus (N/S/E/W) with wrap-around on all edges.
Alias for mesh_2d (which already wraps). Explicit name for clarity
when used as a SIP-level topology.
"""
return mesh_2d(rank, world_size)
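The body of `mesh_2d` is not part of this diff, so the following is only a sketch of what a wrap-around neighbor map could look like, assuming the same row-major rank layout that `mesh_2d_no_wrap` uses. The function name `torus_neighbors` is hypothetical.

```python
def torus_neighbors(rank: int, world_size: int) -> dict:
    """N/S/E/W neighbors on a square grid with wrap-around on every edge."""
    side = int(round(world_size ** 0.5))
    if side * side != world_size:
        raise ValueError("torus requires a square world_size, got %d" % world_size)
    r, c = divmod(rank, side)
    return {
        "N": ((r - 1) % side) * side + c,  # row above, wrapping to the last row
        "S": ((r + 1) % side) * side + c,
        "W": r * side + (c - 1) % side,    # column to the left, wrapping
        "E": r * side + (c + 1) % side,
    }

# Top-left corner of a 3x3 torus still has four neighbors.
corner = torus_neighbors(0, 9)
```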
def mesh_2d_no_wrap(rank: int, world_size: int) -> NeighborMap:
"""Square 2D mesh (N/S/E/W) WITHOUT wrap-around.
Edge nodes have fewer neighbors (no wrapping). Used for SIP-level
topologies where physical links don't wrap.
"""
side = int(round(world_size ** 0.5))
if side * side != world_size:
raise ValueError(
f"mesh_2d_no_wrap requires square world_size, got {world_size}"
)
r, c = divmod(rank, side)
n: NeighborMap = {}
if r > 0:
n["N"] = (r - 1) * side + c
if r < side - 1:
n["S"] = (r + 1) * side + c
if c > 0:
n["W"] = r * side + (c - 1)
if c < side - 1:
n["E"] = r * side + (c + 1)
return n
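To make the edge behavior concrete, here is the same logic as `mesh_2d_no_wrap` run on a 3x3 grid. The function is copied into a standalone form with `NeighborMap` replaced by a plain `dict`, which is an assumption about that alias.

```python
def mesh_2d_no_wrap(rank: int, world_size: int) -> dict:
    side = int(round(world_size ** 0.5))
    if side * side != world_size:
        raise ValueError(
            "mesh_2d_no_wrap requires square world_size, got %d" % world_size
        )
    r, c = divmod(rank, side)
    n = {}
    if r > 0:
        n["N"] = (r - 1) * side + c
    if r < side - 1:
        n["S"] = (r + 1) * side + c
    if c > 0:
        n["W"] = r * side + (c - 1)
    if c < side - 1:
        n["E"] = r * side + (c + 1)
    return n

corner = mesh_2d_no_wrap(0, 9)  # two neighbors
edge = mesh_2d_no_wrap(1, 9)    # three neighbors
center = mesh_2d_no_wrap(4, 9)  # four neighbors
```

Kernels that iterate over directions should therefore check which keys exist instead of assuming all four are present.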
def none(rank: int, world_size: int) -> NeighborMap:
"""Empty map; the algorithm's neighbors() must build from scratch."""
return {}
@@ -82,6 +115,8 @@ _BUILTIN: dict[str, TopologyFn] = {
"ring_1d": ring_1d,
"ring_1d_unidir": ring_1d_unidir,
"mesh_2d": mesh_2d,
"torus_2d": torus_2d,
"mesh_2d_no_wrap": mesh_2d_no_wrap,
"tree_binary": tree_binary,
"none": none,
}
+87 -67
@@ -1,3 +1,14 @@
"""Data-parallel placement policy (ADR-0026: intra-device only).
``DPPolicy`` describes how a tensor is sharded *within a single SIP* across
that SIP's cubes and PEs. Crossing the SIP boundary is not a DPPolicy
concern: ADR-0024's ``torch.ahbm.set_device(rank)`` picks the SIP, and
Megatron-style TP (ADR-0027) expresses multi-SIP tensors when needed.
``ShardSpec`` is expressed in structural ``(sip, cube, pe)`` coordinates.
The former flat ``pe_index`` field/property is fully removed — callers
needing a flat integer key compute it explicitly at the call site.
"""
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
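The module docstring says callers that need a flat integer key should compute it at the call site. A minimal sketch of such a local helper, using the indexing formula from the docstring this commit removes (`sip * num_cubes * num_pe + cube * num_pe + pe`); the helper name and its keyword parameters are hypothetical:

```python
def flat_pe_key(sip: int, cube: int, pe: int, *, cubes_per_sip: int, pes_per_cube: int) -> int:
    """Collapse a structural (sip, cube, pe) coordinate into one integer."""
    return sip * cubes_per_sip * pes_per_cube + cube * pes_per_cube + pe

# SIP 1, cube 2, PE 3 on a device with 16 cubes per SIP and 8 PEs per cube.
key = flat_pe_key(1, 2, 3, cubes_per_sip=16, pes_per_cube=8)  # 128 + 16 + 3
```

Keeping the helper private to the call site matches the docstring's intent that the flat form never appears in a public API.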
@@ -7,25 +18,58 @@ from typing import Literal
@dataclass(frozen=True) @dataclass(frozen=True)
class DPPolicy: class DPPolicy:
"""Three-level data-parallel policy: sip-level + cube-level + pe-level. """Intra-device (cube × PE) data-parallel policy.
Policies: SIP-level placement is controlled by ``torch.ahbm.set_device(rank)``
(ADR-0024). For tensors that must cross SIP boundaries, use
Megatron-style parallel layers (ADR-0027). DPPolicy itself never
crosses a SIP boundary.
Policies (per axis):
- "replicate": full copy at each unit - "replicate": full copy at each unit
- "column_wise": split K (column) axis across units - "column_wise": split K (column) axis across units
- "row_wise": split M (row) axis across units - "row_wise": split M (row) axis across units
Optional overrides (default None = use topology dimensions): Optional overrides (``None`` = use topology dimensions):
- num_pes: override PEs per cube (e.g., 1 for single-PE test) - num_pes: override PEs per cube
- num_cubes: override cubes per SIP (e.g., 1 for single-cube test) - num_cubes: override cubes per SIP
- num_sips: override SIP count
""" """
sip: Literal["replicate", "column_wise", "row_wise"] = "replicate"
cube: Literal["replicate", "column_wise", "row_wise"] = "replicate" cube: Literal["replicate", "column_wise", "row_wise"] = "replicate"
pe: Literal["replicate", "column_wise", "row_wise"] = "replicate" pe: Literal["replicate", "column_wise", "row_wise"] = "replicate"
num_pes: int | None = None num_pes: int | None = None
num_cubes: int | None = None num_cubes: int | None = None
num_sips: int | None = None
@dataclass(frozen=True)
class ShardSpec:
"""Structural shard placement — ``(sip, cube, pe)`` coord (ADR-0026).
Global-flat ``pe_index`` was removed: callers must use structural
coords directly. If a flat integer key is needed in a local context
(e.g. internal dict lookup), compute it explicitly at the call site
and do not expose it in any public API.
"""
sip: int
cube: int
pe: int
offset_bytes: int
nbytes: int
@dataclass(frozen=True)
class _LocalPeShard:
"""Internal — PE resolver's return type (ADR-0026 D3).
Holds a cube-local PE identifier (``local_pe``) plus the shard's
byte payload. Lifted into ``ShardSpec`` with full ``(sip, cube, pe)``
coordinates inside ``resolve_dp_policy``.
"""
local_pe: int
offset_bytes: int
nbytes: int
def _split_shape( def _split_shape(
@@ -52,14 +96,13 @@ def resolve_dp_policy(
itemsize: int, itemsize: int,
num_pe: int, num_pe: int,
num_cubes: int = 1, num_cubes: int = 1,
num_sips: int = 1, target_sip: int,
) -> list[ShardSpec]: ) -> list[ShardSpec]:
"""Resolve a DPPolicy into a list[ShardSpec] with three-level resolution. """Resolve a DPPolicy into a list[ShardSpec] on a single SIP.
SIP-level → cube-level → pe-level. Two-level resolution (cube × PE) within ``target_sip``. Each returned
num_cubes is cubes per SIP (not total). ``ShardSpec`` carries ``sip=target_sip`` and cube/pe local to the SIP.
ShardSpec.pe_index uses flat indexing: No SIP-level split — DPPolicy is intra-device only (ADR-0026).
sip_id * num_cubes * num_pe + cube_id * num_pe + pe_id
""" """
_PE_RESOLVERS = { _PE_RESOLVERS = {
"replicate": replicate, "replicate": replicate,
@@ -70,84 +113,61 @@ def resolve_dp_policy(
if resolver is None: if resolver is None:
raise ValueError(f"Unknown pe-level policy: {policy.pe}") raise ValueError(f"Unknown pe-level policy: {policy.pe}")
cubes_per_sip = num_cubes
all_shards: list[ShardSpec] = [] all_shards: list[ShardSpec] = []
# Level 1: SIP # Level 1: cube within SIP
sip_splits = _split_shape(policy.sip, shape, num_sips, itemsize) cube_splits = _split_shape(policy.cube, shape, num_cubes, itemsize)
for sip_id, (sip_shape, sip_offset) in enumerate(sip_splits):
# Level 2: Cube within SIP
cube_splits = _split_shape(policy.cube, sip_shape, cubes_per_sip, itemsize)
for cube_id, (cube_shape, cube_offset) in enumerate(cube_splits): for cube_id, (cube_shape, cube_offset) in enumerate(cube_splits):
# Level 3: PE within cube # Level 2: PE within cube — resolver returns _LocalPeShard
pe_shards = resolver(shape=cube_shape, itemsize=itemsize, num_pe=num_pe) local_shards = resolver(shape=cube_shape, itemsize=itemsize, num_pe=num_pe)
for ps in pe_shards: for ls in local_shards:
flat_idx = (
sip_id * cubes_per_sip * num_pe
+ cube_id * num_pe
+ ps.pe_index
)
all_shards.append(ShardSpec( all_shards.append(ShardSpec(
pe_index=flat_idx, sip=target_sip,
offset_bytes=sip_offset + cube_offset + ps.offset_bytes, cube=cube_id,
nbytes=ps.nbytes, pe=ls.local_pe,
offset_bytes=cube_offset + ls.offset_bytes,
nbytes=ls.nbytes,
)) ))
return all_shards return all_shards
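The two-level resolution in the new `resolve_dp_policy` (cube within SIP, then PE within cube) can be illustrated with a row-wise split at both levels. This is a sketch only: `_split_shape` is not shown in this diff, so `split_rows` below is a hypothetical stand-in, and `Shard` mirrors the fields of the new `ShardSpec`.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Shard:
    sip: int
    cube: int
    pe: int
    offset_bytes: int
    nbytes: int

def split_rows(shape, parts, itemsize):
    """Hypothetical row-wise split: list of (sub_shape, byte_offset)."""
    m, k = shape
    chunk = m // parts
    return [((chunk, k), i * chunk * k * itemsize) for i in range(parts)]

def resolve_row_wise(shape, itemsize, num_cubes, num_pe, target_sip):
    shards = []
    for cube_id, (cube_shape, cube_off) in enumerate(split_rows(shape, num_cubes, itemsize)):
        for pe_id, (pe_shape, pe_off) in enumerate(split_rows(cube_shape, num_pe, itemsize)):
            m, k = pe_shape
            # Offsets compose additively: cube offset plus PE offset inside the cube.
            shards.append(Shard(target_sip, cube_id, pe_id, cube_off + pe_off, m * k * itemsize))
    return shards

# An (8, 4) f16 tensor on SIP 1, split across 2 cubes with 2 PEs each.
shards = resolve_row_wise((8, 4), itemsize=2, num_cubes=2, num_pe=2, target_sip=1)
```

Every shard carries `sip=target_sip`, which is the point of the change: the SIP is chosen by the caller, not by the policy.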
@dataclass(frozen=True)
class ShardSpec:
pe_index: int
offset_bytes: int
nbytes: int
def column_wise( def column_wise(
*, shape: tuple[int, int], itemsize: int, num_pe: int, *, shape: tuple[int, int], itemsize: int, num_pe: int,
) -> list[ShardSpec]: ) -> list[_LocalPeShard]:
"""Split K axis into num_pe equal parts. Each PE gets (M, K/P).""" """Split K axis into num_pe equal parts. Each PE gets (M, K/P)."""
M, K = shape M, K = shape
chunk_k = K // num_pe chunk_k = K // num_pe
chunk_bytes = M * chunk_k * itemsize chunk_bytes = M * chunk_k * itemsize
shards = [] return [
for i in range(num_pe): _LocalPeShard(local_pe=i, offset_bytes=i * chunk_bytes, nbytes=chunk_bytes)
shards.append(ShardSpec( for i in range(num_pe)
pe_index=i, ]
offset_bytes=i * chunk_bytes,
nbytes=chunk_bytes,
))
return shards
def row_wise( def row_wise(
*, shape: tuple[int, int], itemsize: int, num_pe: int, *, shape: tuple[int, int], itemsize: int, num_pe: int,
) -> list[ShardSpec]: ) -> list[_LocalPeShard]:
"""Split M axis into num_pe equal parts. Each PE gets (M/P, K).""" """Split M axis into num_pe equal parts. Each PE gets (M/P, K)."""
M, K = shape M, K = shape
chunk_m = M // num_pe chunk_m = M // num_pe
chunk_bytes = chunk_m * K * itemsize chunk_bytes = chunk_m * K * itemsize
shards = [] return [
for i in range(num_pe): _LocalPeShard(local_pe=i, offset_bytes=i * chunk_bytes, nbytes=chunk_bytes)
shards.append(ShardSpec( for i in range(num_pe)
pe_index=i, ]
offset_bytes=i * chunk_bytes,
nbytes=chunk_bytes,
))
return shards
def replicate( def replicate(
*, shape: tuple[int, int], itemsize: int, num_pe: int, *, shape: tuple[int, int], itemsize: int, num_pe: int,
) -> list[ShardSpec]: ) -> list[_LocalPeShard]:
"""Full copy per PE. Each PE gets (M, K).""" """Full copy per PE. Each PE gets (M, K)."""
M, K = shape M, K = shape
full_bytes = M * K * itemsize full_bytes = M * K * itemsize
return [ return [
ShardSpec(pe_index=i, offset_bytes=0, nbytes=full_bytes) _LocalPeShard(local_pe=i, offset_bytes=0, nbytes=full_bytes)
for i in range(num_pe) for i in range(num_pe)
] ]
@@ -155,20 +175,20 @@ def replicate(
def tiled_column_major( def tiled_column_major(
*, shape: tuple[int, int], itemsize: int, num_pe: int, *, shape: tuple[int, int], itemsize: int, num_pe: int,
tile_m: int, tile_k: int, tile_m: int, tile_k: int,
) -> list[ShardSpec]: ) -> list[_LocalPeShard]:
"""2D tiling, column-major order (K axis first), round-robin across PEs.""" """2D tiling, column-major order (K axis first), round-robin across PEs."""
M, K = shape M, K = shape
tiles_m = ceil(M / tile_m) tiles_m = ceil(M / tile_m)
tiles_k = ceil(K / tile_k) tiles_k = ceil(K / tile_k)
tile_bytes = tile_m * tile_k * itemsize tile_bytes = tile_m * tile_k * itemsize
row_bytes = K * itemsize row_bytes = K * itemsize
shards = [] shards: list[_LocalPeShard] = []
idx = 0 idx = 0
for mi in range(tiles_m): for mi in range(tiles_m):
for ki in range(tiles_k): for ki in range(tiles_k):
offset = (mi * tile_m * row_bytes) + (ki * tile_k * itemsize) offset = (mi * tile_m * row_bytes) + (ki * tile_k * itemsize)
shards.append(ShardSpec( shards.append(_LocalPeShard(
pe_index=idx % num_pe, local_pe=idx % num_pe,
offset_bytes=offset, offset_bytes=offset,
nbytes=tile_bytes, nbytes=tile_bytes,
)) ))
@@ -179,20 +199,20 @@ def tiled_column_major(
def tiled_row_major( def tiled_row_major(
*, shape: tuple[int, int], itemsize: int, num_pe: int, *, shape: tuple[int, int], itemsize: int, num_pe: int,
tile_m: int, tile_k: int, tile_m: int, tile_k: int,
) -> list[ShardSpec]: ) -> list[_LocalPeShard]:
"""2D tiling, row-major order (M axis first), round-robin across PEs.""" """2D tiling, row-major order (M axis first), round-robin across PEs."""
M, K = shape M, K = shape
tiles_m = ceil(M / tile_m) tiles_m = ceil(M / tile_m)
tiles_k = ceil(K / tile_k) tiles_k = ceil(K / tile_k)
tile_bytes = tile_m * tile_k * itemsize tile_bytes = tile_m * tile_k * itemsize
row_bytes = K * itemsize row_bytes = K * itemsize
shards = [] shards: list[_LocalPeShard] = []
idx = 0 idx = 0
for ki in range(tiles_k): for ki in range(tiles_k):
for mi in range(tiles_m): for mi in range(tiles_m):
offset = (mi * tile_m * row_bytes) + (ki * tile_k * itemsize) offset = (mi * tile_m * row_bytes) + (ki * tile_k * itemsize)
shards.append(ShardSpec( shards.append(_LocalPeShard(
pe_index=idx % num_pe, local_pe=idx % num_pe,
offset_bytes=offset, offset_bytes=offset,
nbytes=tile_bytes, nbytes=tile_bytes,
)) ))
+134 -20
@@ -42,6 +42,59 @@ def _numpy_to_dtype_str(np_dtype) -> str:
raise ValueError(f"unsupported numpy dtype: {np_dtype!r}") raise ValueError(f"unsupported numpy dtype: {np_dtype!r}")
# ADR-0027 D3: weak registry of the currently-active RuntimeContext so
# module-level helpers (e.g. ``kernbench.tp.parallel_state``) can resolve
# the ctx without threading it through every call.
import weakref as _weakref
_ACTIVE_CTX_REF: _weakref.ref | None = None
def _get_active_context():
"""Return the most-recently-entered RuntimeContext, or None."""
if _ACTIVE_CTX_REF is None:
return None
return _ACTIVE_CTX_REF()
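The weak active-context registry, together with the `__enter__`/`__exit__` hooks further down in this file's diff, follows a common pattern. A self-contained sketch with a toy `Ctx` class (hypothetical; the real `RuntimeContext` carries much more state):

```python
import weakref

_ACTIVE_REF = None  # weak reference to the most recently entered context

class Ctx:
    def __enter__(self):
        global _ACTIVE_REF
        _ACTIVE_REF = weakref.ref(self)
        return self

    def __exit__(self, *exc):
        global _ACTIVE_REF
        # Clear the registry only if this context is the registered one.
        if _ACTIVE_REF is not None and _ACTIVE_REF() is self:
            _ACTIVE_REF = None
        return False

def get_active():
    return None if _ACTIVE_REF is None else _ACTIVE_REF()

with Ctx() as ctx:
    inside = get_active() is ctx
outside = get_active()
```

The weak reference means the registry never keeps a context alive on its own; if the context is garbage-collected without `__exit__`, the lookup simply returns `None`.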
class _AhbmNamespace:
"""torch.ahbm — per-greenlet SIP device binding (ADR-0024 D10).
Real-PyTorch parity idiom: ``torch.cuda.set_device(rank)``. KernBench's
backend is 'ahbm' (not CUDA), so this namespace avoids pretending to be
a CUDA runtime.
"""
def __init__(self) -> None:
self._device_by_greenlet: dict = {}
def set_device(self, device: int) -> None:
from greenlet import getcurrent
self._device_by_greenlet[getcurrent()] = int(device)
def current_device(self) -> int | None:
from greenlet import getcurrent
return self._device_by_greenlet.get(getcurrent())
class _AcceleratorNamespace:
"""torch.accelerator — device-agnostic alias (PyTorch 2.x style).
Wraps _AhbmNamespace. Bench code can pick either:
torch.ahbm.set_device(rank) # explicit backend
torch.accelerator.set_device_index(rank) # portable
"""
def __init__(self, ahbm: "_AhbmNamespace") -> None:
self._ahbm = ahbm
def set_device_index(self, device: int) -> None:
self._ahbm.set_device(device)
def current_device_index(self) -> int | None:
return self._ahbm.current_device()
@dataclass @dataclass
class RuntimeContext: class RuntimeContext:
engine: SimEngine engine: SimEngine
@@ -51,7 +104,11 @@ class RuntimeContext:
_handles: list[RequestHandle] = field(default_factory=list, init=False) _handles: list[RequestHandle] = field(default_factory=list, init=False)
_completed: set[RequestHandle] = field(default_factory=set, init=False) _completed: set[RequestHandle] = field(default_factory=set, init=False)
_allocators: dict[int, Any] = field(default_factory=dict, init=False) # ADR-0027 D0.1: worker-deferred wait queue. When a worker greenlet
# calls ctx.wait(h), the handle is appended here and control yields to
# main. Main's scheduler drain consumes this list.
_pending_worker_waits: list[RequestHandle] = field(default_factory=list, init=False)
_allocators: dict[tuple[int, int, int], Any] = field(default_factory=dict, init=False)
_va_allocator: Any = field(default=None, init=False) _va_allocator: Any = field(default=None, init=False)
_tensor_counter: int = field(default=0, init=False) _tensor_counter: int = field(default=0, init=False)
_traces: list[dict] = field(default_factory=list, init=False) _traces: list[dict] = field(default_factory=list, init=False)
@@ -67,6 +124,13 @@ class RuntimeContext:
dc = DistributedContext() dc = DistributedContext()
dc._ctx_ref = self # back-reference for AhbmCCLBackend to reach ctx.launch etc. dc._ctx_ref = self # back-reference for AhbmCCLBackend to reach ctx.launch etc.
self.distributed = dc self.distributed = dc
# ADR-0024 D10: torch.ahbm (KernBench-native) + torch.accelerator
# (PyTorch 2.x portable) namespaces for per-greenlet device binding.
self.ahbm = _AhbmNamespace()
self.accelerator = _AcceleratorNamespace(self.ahbm)
# ADR-0027 D1.3: torch.multiprocessing.spawn namespace.
from kernbench.runtime_api.multiprocessing import _MultiprocessingNamespace
self.multiprocessing = _MultiprocessingNamespace(self)
def install_ipcq( def install_ipcq(
self, self,
@@ -118,10 +182,16 @@ class RuntimeContext:
return plan return plan
def __enter__(self): def __enter__(self):
global _ACTIVE_CTX_REF
_ACTIVE_CTX_REF = _weakref.ref(self)
return self return self
def __exit__(self, *exc): def __exit__(self, *exc):
global _ACTIVE_CTX_REF
self.cleanup() self.cleanup()
# Clear active-context registry if we are it.
if _ACTIVE_CTX_REF is not None and _ACTIVE_CTX_REF() is self:
_ACTIVE_CTX_REF = None
return False return False
def submit(self, request: Any) -> RequestHandle: def submit(self, request: Any) -> RequestHandle:
@@ -136,10 +206,24 @@ class RuntimeContext:
return handle in self._completed return handle in self._completed
def wait(self, handle: RequestHandle, *, _meta: dict | None = None) -> Completion: def wait(self, handle: RequestHandle, *, _meta: dict | None = None) -> Completion:
# ADR-0027 D0.2: fast-path for already-completed handles (avoid
# redundant worker→main→worker round-trip).
if handle in self._completed: if handle in self._completed:
completion, trace = self.engine.get_completion(handle) completion, trace = self.engine.get_completion(handle)
return completion return completion
# ADR-0027 D0.2: if called from a worker greenlet (parent is main,
# not dead), defer the wait to the main scheduler — enqueue and
# yield. Main drains env.run, then switches back. On resume the
# handle must be in _completed (D0.3 resume invariant).
from greenlet import getcurrent
g = getcurrent()
if g.parent is not None and not g.parent.dead:
self._pending_worker_waits.append(handle)
g.parent.switch()
# Resume: main drained. Fall through to completion/trace assembly.
# Main context (or single-driver): drive engine directly.
wait_fn = getattr(self.engine, "wait", None) wait_fn = getattr(self.engine, "wait", None)
if wait_fn is not None: if wait_fn is not None:
wait_fn(handle) # type: ignore[misc] wait_fn(handle) # type: ignore[misc]
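The worker-deferred wait described in the comments above (enqueue the handle, yield to main, resume once main has drained) can be sketched with generators in place of greenlets. This is an assumption made to keep the example free of third-party dependencies; `yield` plays the role of `g.parent.switch()`, and all names are illustrative.

```python
def worker(rank, completed, log):
    handle = "req-%d" % rank
    log.append(("submit", rank))
    yield handle                  # enqueue the wait and hand control to main
    assert handle in completed    # resume invariant: main drained before resuming
    log.append(("resumed", rank))

def main_loop(n_ranks):
    completed, log = set(), []
    workers = [worker(r, completed, log) for r in range(n_ranks)]
    # Every rank submits and yields before any wait is served.
    pending = [next(w) for w in workers]
    completed.update(pending)     # stand-in for draining the simulation engine
    for w in workers:
        for _ in w:               # resume each worker and run it to completion
            pass
    return log

log = main_loop(2)
```

Ordering is the design point: all submissions land before any rank resumes, which is what a collective needs.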
@@ -228,12 +312,7 @@ class RuntimeContext:
# Return PA space # Return PA space
if self._allocators: if self._allocators:
for shard in handle.shards: for shard in handle.shards:
flat_idx = ( alloc = self._allocators.get((shard.sip, shard.cube, shard.pe))
shard.sip * self._num_cubes * self._pes_per_cube
+ shard.cube * self._pes_per_cube
+ shard.pe
)
alloc = self._allocators.get(flat_idx)
if alloc is not None: if alloc is not None:
from kernbench.policy.address.phyaddr import PhysAddr from kernbench.policy.address.phyaddr import PhysAddr
alloc.free_hbm(PhysAddr.decode(shard.pa), shard.nbytes) alloc.free_hbm(PhysAddr.decode(shard.pa), shard.nbytes)
@@ -297,17 +376,15 @@ class RuntimeContext:
tcm_scheduler_reserved_bytes=4 * (1 << 20), tcm_scheduler_reserved_bytes=4 * (1 << 20),
sram_bytes_per_cube=32 * (1 << 20), sram_bytes_per_cube=32 * (1 << 20),
) )
# Create allocators scoped to target SIP(s) only # Create allocators scoped to target SIP(s) only.
# Flat index: sip_id * cubes_per_sip * pes_per_cube + cube_id * pes_per_cube + pe_id # ADR-0026 D5: dict key is the structural (sip, cube, pe) tuple.
self._pes_per_cube = pes_per_cube self._pes_per_cube = pes_per_cube
self._num_cubes = cubes_per_sip self._num_cubes = cubes_per_sip
self._num_sips = sip_count self._num_sips = sip_count
cubes_x_pes = cubes_per_sip * pes_per_cube
for sip_id in sip_range: for sip_id in sip_range:
for cube_id in range(cubes_per_sip): for cube_id in range(cubes_per_sip):
for pe_id in range(pes_per_cube): for pe_id in range(pes_per_cube):
flat_idx = sip_id * cubes_x_pes + cube_id * pes_per_cube + pe_id self._allocators[(sip_id, cube_id, pe_id)] = PEMemAllocator(
self._allocators[flat_idx] = PEMemAllocator(
rack_id=0, sip_id=sip_id, cube_id=cube_id, pe_id=pe_id, cfg=cfg, rack_id=0, sip_id=sip_id, cube_id=cube_id, pe_id=pe_id, cfg=cfg,
) )
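Keying the allocator dict by the structural tuple removes the flat-index arithmetic at every lookup. A small sketch, with strings standing in for `PEMemAllocator` instances (an assumption to keep the example self-contained):

```python
from itertools import product

def build_allocators(sip_range, cubes_per_sip, pes_per_cube):
    """Map each (sip, cube, pe) coordinate to its allocator."""
    return {
        (s, c, p): "alloc-s%d-c%d-p%d" % (s, c, p)
        for s, c, p in product(sip_range, range(cubes_per_sip), range(pes_per_cube))
    }

# Allocators scoped to SIP 1 only, with 2 cubes and 2 PEs per cube.
allocs = build_allocators(range(1, 2), cubes_per_sip=2, pes_per_cube=2)
hit = allocs.get((1, 1, 0))
miss = allocs.get((0, 0, 0))  # SIP 0 was not in scope, so the lookup misses cleanly
```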
@@ -394,16 +471,23 @@ class RuntimeContext:
# DPPolicy overrides take precedence over topology dimensions # DPPolicy overrides take precedence over topology dimensions
eff_num_pe = dp.num_pes if dp.num_pes is not None else self._pes_per_cube eff_num_pe = dp.num_pes if dp.num_pes is not None else self._pes_per_cube
eff_num_cubes = dp.num_cubes if dp.num_cubes is not None else self._num_cubes eff_num_cubes = dp.num_cubes if dp.num_cubes is not None else self._num_cubes
eff_num_sips = dp.num_sips if dp.num_sips is not None else self._num_sips # ADR-0026 D4: resolve structural coords directly at resolve time.
# ``torch.ahbm.set_device(rank)`` (ADR-0024 D10) selects the target
# SIP; if unset, fall back to SIP 0 for single-driver compatibility.
current_sip = (
self.ahbm.current_device() if hasattr(self, "ahbm") else None
)
if current_sip is None:
current_sip = 0
placement = resolve_dp_policy( placement = resolve_dp_policy(
dp, shape=shape_2d, itemsize=itemsize, dp, shape=shape_2d, itemsize=itemsize,
num_pe=eff_num_pe, num_cubes=eff_num_cubes, num_pe=eff_num_pe, num_cubes=eff_num_cubes,
num_sips=eff_num_sips, target_sip=int(current_sip),
) )
# Infer target_pe from placement using local (within-cube) PE IDs. # Infer target_pe from placement using local (within-cube) PE IDs.
# This ensures M_CPU only fans out to PEs that own shards, not all PEs. # This ensures M_CPU only fans out to PEs that own shards, not all PEs.
local_pe_ids = sorted({s.pe_index % eff_num_pe for s in placement}) local_pe_ids = sorted({s.pe for s in placement})
if len(local_pe_ids) == 1: if len(local_pe_ids) == 1:
target_pe: int | tuple[int, ...] | str = local_pe_ids[0] target_pe: int | tuple[int, ...] | str = local_pe_ids[0]
elif len(local_pe_ids) == eff_num_pe and eff_num_pe == self._pes_per_cube: elif len(local_pe_ids) == eff_num_pe and eff_num_pe == self._pes_per_cube:
@@ -501,6 +585,21 @@ class RuntimeContext:
"sip": shard.sip, "cube": shard.cube, "pe": shard.pe, "sip": shard.sip, "cube": shard.cube, "pe": shard.pe,
"nbytes": shard.nbytes, "nbytes": shard.nbytes,
}) })
# ADR-0027: also populate MemoryStore at VA keys so kernels
# reading via VA (the common ``tl.load`` path) see the init
# data. Phase 1 MemoryWriteMsg writes via PA; kernels read via
# VA; Phase 2 DataExecutor reads via the addresses captured in
# op_log (VA for tl.load). Without this, zero-init tensors are
# invisible to kernels in Phase 2.
store = getattr(self.engine, "_memory_store", None)
if store is not None and pattern == "zero" and handle.va_base:
import numpy as np
from kernbench.runtime_api.tensor import _numpy_dtype
np_dtype = _numpy_dtype(dtype)
for shard in handle.shards:
count = shard.nbytes // itemsize
addr = handle.va_base + shard.offset_bytes
store.write("hbm", addr, np.zeros(count, dtype=np_dtype))
return t return t
@@ -509,6 +608,7 @@ class RuntimeContext:
kernel_name: str, kernel_name: str,
kernel_fn: Any, kernel_fn: Any,
*args: Any, *args: Any,
_defer_wait: bool = False,
**kwargs: Any, **kwargs: Any,
) -> RequestHandle: ) -> RequestHandle:
"""Register and launch a kernel (like a fused torch op). """Register and launch a kernel (like a fused torch op).
@@ -518,6 +618,11 @@ class RuntimeContext:
Creates per-SIP KernelLaunchMsg with local va_base per tensor Creates per-SIP KernelLaunchMsg with local va_base per tensor
(like host driver sending per-rank launch commands). (like host driver sending per-rank launch commands).
When ``_defer_wait=True`` (ADR-0024 D7), returns the list of
``(handle, sip_id, meta)`` tuples instead of waiting. Caller is
responsible for waiting — used by collective ops to yield between
submit and wait so all sibling ranks can submit first.
""" """
from collections import defaultdict from collections import defaultdict
@@ -593,11 +698,8 @@ class RuntimeContext:
dp = t._dp_metadata.dp_policy if t._dp_metadata else None dp = t._dp_metadata.dp_policy if t._dp_metadata else None
if dp is None: if dp is None:
return t.shape return t.shape
if dp.sip != "replicate": # ADR-0026: DPPolicy no longer crosses SIP boundaries; cube + PE
if dp.sip == "column_wise": # are the only axes that shrink the local shape.
K = K // self._num_sips
elif dp.sip == "row_wise":
M = M // self._num_sips
if dp.cube != "replicate": if dp.cube != "replicate":
if dp.cube == "column_wise": if dp.cube == "column_wise":
K = K // self._num_cubes K = K // self._num_cubes
@@ -683,6 +785,18 @@ class RuntimeContext:
_pending_handles.append((h, sip_id)) _pending_handles.append((h, sip_id))
last_handle = h last_handle = h
if _defer_wait:
# ADR-0024 D7: return the pending-list so the caller can yield
# between submit and drain. Used by collective ops that need
# all sibling ranks to submit before any rank waits.
return [
(h, sip_id, {
"phase": "kernel", "name": kernel_name,
"sip": sip_id, "target_pe": target_pe,
})
for h, sip_id in _pending_handles
]
# Drain pending handles now that every SIP has a launch posted. # Drain pending handles now that every SIP has a launch posted.
for h, sip_id in _pending_handles: for h, sip_id in _pending_handles:
self.wait(h, _meta={ self.wait(h, _meta={
+88 -25
@@ -23,6 +23,7 @@ Host bench code uses only real-PyTorch names:
from __future__ import annotations from __future__ import annotations
import importlib import importlib
import math
from typing import Any from typing import Any
@@ -40,20 +41,44 @@ class AhbmCCLBackend:
self._merged = resolve_algorithm_config(self._cfg_all) self._merged = resolve_algorithm_config(self._cfg_all)
self._algo_module = importlib.import_module(self._merged["module"]) self._algo_module = importlib.import_module(self._merged["module"])
self._world_size = self._resolve_world_size() self._world_size = self._resolve_world_size()
self._pending_collective_handles: list = []
self._dist_ctx: Any = None
# Eager IPCQ install — ``init_process_group`` time. Mirrors NCCL spec = self.ctx.spec or {}
# communicator creation: done once, reused across every subsequent self._n_sips = int(spec.get("system", {}).get("sips", {}).get("count", 1))
# collective call on the same process group. self._sip_topo = str(
self.ctx.install_ipcq( spec.get("system", {}).get("sips", {}).get("topology", "ring_1d")
algorithm=self._merged["algorithm"],
world_size_override=self._world_size,
) )
cm = spec.get("sip", {}).get("cube_mesh", {})
self._cube_w = int(cm.get("w", 4))
self._cube_h = int(cm.get("h", 4))
# Resolve SIP topology dims for the kernel
topo_map = getattr(self._algo_module, "TOPO_NAME_TO_KIND", None)
if topo_map is not None:
self._sip_topo_kind = topo_map.get(self._sip_topo, 0)
else:
self._sip_topo_kind = 0
if self._sip_topo == "ring_1d":
self._sip_topo_w, self._sip_topo_h = 0, 0
else:
side = int(round(math.sqrt(self._n_sips)))
self._sip_topo_w, self._sip_topo_h = side, side
# IPCQ install: wire all pe0s across all cubes and SIPs
engine = getattr(self.ctx, "engine", None)
if engine is not None:
from kernbench.ccl.sfr_config import configure_sfr_intercube_multisip
configure_sfr_intercube_multisip(engine, spec, self._merged)
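The SIP grid dimensions above are derived with `int(round(math.sqrt(n_sips)))`, which also returns a side length for non-square counts (6 SIPs gives a side of 2, i.e. a 2x2 grid). Whether such counts are rejected elsewhere is not visible in this diff. A minimal sketch of a stricter variant follows; `resolve_sip_topo_dims` is a hypothetical helper name, not part of the commit.

```python
import math


def resolve_sip_topo_dims(topology: str, n_sips: int) -> tuple[int, int]:
    """Return (w, h) of the SIP grid; (0, 0) means "no 2D grid" (ring_1d).

    Hypothetical helper: mirrors the logic in AhbmCCLBackend.__init__ but
    rejects SIP counts that are not perfect squares for 2D topologies.
    """
    if topology == "ring_1d":
        return 0, 0
    side = math.isqrt(n_sips)  # exact integer square root, no float rounding
    if side * side != n_sips:
        raise ValueError(
            f"topology {topology!r} needs a square SIP count, got {n_sips}"
        )
    return side, side
```

Failing early here keeps a mismatched `topology.yaml` from reaching the kernel as a grid that covers fewer SIPs than exist.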
def _resolve_world_size(self) -> int: def _resolve_world_size(self) -> int:
"""Derive world_size (priority: algorithm override > defaults > topology). """Derive world_size (priority: algorithm override > defaults > topology).
Topology derivation: ADR-0024 D1: topology fallback is SIP count. Each rank represents one
sips × cubes_per_sip × pes_per_cube SIP (TP dimension). Intra-SIP parallelism is expressed via DPPolicy
inside each worker and is independent of world_size.
Explicit ``ccl.yaml`` override still respected — legacy "rank = flat
PE index" tests use this path.
""" """
if "world_size" in self._merged: if "world_size" in self._merged:
return int(self._merged["world_size"]) return int(self._merged["world_size"])
@@ -61,14 +86,7 @@ class AhbmCCLBackend:
if "world_size" in defaults: if "world_size" in defaults:
return int(defaults["world_size"]) return int(defaults["world_size"])
spec = self.ctx.spec or {} spec = self.ctx.spec or {}
sips = int(spec.get("system", {}).get("sips", {}).get("count", 1)) return int(spec.get("system", {}).get("sips", {}).get("count", 1))
cm = spec.get("sip", {}).get("cube_mesh", {})
cubes_per_sip = int(cm.get("w", 1)) * int(cm.get("h", 1))
pl = spec.get("cube", {}).get("pe_layout", {})
corners = pl.get("corners", [])
pe_per_corner = int(pl.get("pe_per_corner", 1))
pes_per_cube = pe_per_corner * max(len(corners), 1)
return sips * cubes_per_sip * pes_per_cube
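The priority order described in the docstring (algorithm override, then defaults, then topology) can be sketched as a standalone function. This assumes the three inputs are plain dicts; how `defaults` is obtained is elided by the hunk boundary, so it is passed in explicitly here, and the function name is illustrative only.

```python
from typing import Any, Mapping


def resolve_world_size(
    merged: Mapping[str, Any],
    defaults: Mapping[str, Any],
    spec: Mapping[str, Any],
) -> int:
    """Sketch of the lookup order: algorithm override > defaults > SIP count."""
    if "world_size" in merged:
        return int(merged["world_size"])
    if "world_size" in defaults:
        return int(defaults["world_size"])
    # Topology fallback after this commit: one rank per SIP.
    return int(spec.get("system", {}).get("sips", {}).get("count", 1))
```

With this fallback, a 4-SIP topology yields `world_size == 4` regardless of how many cubes or PEs each SIP contains.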
@property @property
def world_size(self) -> int: def world_size(self) -> int:
@@ -89,20 +107,48 @@ class AhbmCCLBackend:
"with a DPPolicy first)" "with a DPPolicy first)"
) )
shards = tensor._handle.shards shards = tensor._handle.shards
if len(shards) != self._world_size: if not shards:
raise RuntimeError( raise RuntimeError(
f"all_reduce tensor has {len(shards)} shards but the " f"all_reduce tensor '{tensor.name}' has no shards"
f"ahbm backend was installed with world_size="
f"{self._world_size}; adjust the tensor's DPPolicy or "
"restart the process group"
) )
n_elem = shards[0].nbytes // tensor.itemsize n_elem = shards[0].nbytes // tensor.itemsize
kernel_fn = self._algo_module.kernel kernel_fn = self._algo_module.kernel
kernel_args = self._algo_module.kernel_args(self._world_size, n_elem) kernel_args = self._algo_module.kernel_args(self._world_size, n_elem)
self.ctx.launch(
self._merged["algorithm"], kernel_fn, tensor, *kernel_args, # Resolve sip_rank from the current greenlet's bound rank
from greenlet import getcurrent as _gc
g = _gc()
dist_ctx = getattr(self, "_dist_ctx", None)
if dist_ctx is not None:
sip_rank = int(dist_ctx._rank_by_greenlet.get(g, 0))
else:
sip_rank = 0
extra_args = (
sip_rank,
self._sip_topo_kind,
self._sip_topo_w,
self._sip_topo_h,
) )
pending = self.ctx.launch(
self._merged["algorithm"], kernel_fn, tensor,
*kernel_args, *extra_args,
_defer_wait=True,
)
from greenlet import getcurrent
g = getcurrent()
if g.parent is not None and not g.parent.dead:
# Multi-greenlet mode: hand pending to the backend-level queue so
# the main scheduler drains. Worker just yields.
self._pending_collective_handles.extend(pending)
g.parent.switch()
# On resume, all pending handles have been drained by main.
else:
# Single-driver (no bench scheduler): drain inline.
for h, _sip_id, meta in pending:
self.ctx.wait(h, _meta=meta)
def barrier(self) -> None: def barrier(self) -> None:
# Single-driver model → no cross-process sync needed. Keeping the # Single-driver model → no cross-process sync needed. Keeping the
# method so ``dist.barrier()`` is callable (pytorch-compat surface). # method so ``dist.barrier()`` is callable (pytorch-compat surface).
@@ -121,6 +167,11 @@ class DistributedContext:
def __init__(self) -> None: def __init__(self) -> None:
self._backend: AhbmCCLBackend | None = None self._backend: AhbmCCLBackend | None = None
# ADR-0024 D9: greenlet-local rank registry. Bench launcher calls
# _bind_rank(g, rank) when spawning workers; get_rank() resolves the
# current greenlet to its rank. Unbound greenlets fall back to 0 for
# single-driver test compat.
self._rank_by_greenlet: dict = {}
def init_process_group( def init_process_group(
self, self,
@@ -146,6 +197,7 @@ class DistributedContext:
"DistributedContext not bound to a RuntimeContext" "DistributedContext not bound to a RuntimeContext"
) )
self._backend = AhbmCCLBackend(torch_ctx=ctx) self._backend = AhbmCCLBackend(torch_ctx=ctx)
self._backend._dist_ctx = self
def is_initialized(self) -> bool: def is_initialized(self) -> bool:
return self._backend is not None return self._backend is not None
@@ -155,9 +207,20 @@ class DistributedContext:
return self._backend.world_size return self._backend.world_size
def get_rank(self) -> int: def get_rank(self) -> int:
# Single-driver kernbench: there is only one host rank. """Return the rank bound to the current greenlet (default 0).
ADR-0024 D9: workers spawned by the bench launcher each get a rank
registered via ``_bind_rank``. Callers outside any bound greenlet
fall back to rank 0 for single-driver test compat.
"""
self._ensure_initialized() self._ensure_initialized()
return 0 from greenlet import getcurrent
g = getcurrent()
return int(self._rank_by_greenlet.get(g, 0))
def _bind_rank(self, g: Any, rank: int) -> None:
"""Bind a greenlet to a rank so ``get_rank()`` returns it (ADR-0024 D9)."""
self._rank_by_greenlet[g] = int(rank)
def get_backend(self) -> str: def get_backend(self) -> str:
self._ensure_initialized() self._ensure_initialized()
@@ -0,0 +1,152 @@
"""``torch.multiprocessing.spawn``-compatible namespace (ADR-0027 D1).
Real-PyTorch API *signature* parity only — execution model is a cooperative
greenlet scheduler in a single Python process (D1.0). Non-goals: process
isolation, independent address space, failure isolation, OS-level scheduler
fairness, mp.Queue/Lock.
Attached to ``RuntimeContext`` as ``ctx.multiprocessing`` in
``__post_init__`` (D1.3).
"""
from __future__ import annotations
from typing import Any, Callable
class SpawnException(RuntimeError):
"""Raised from ``_MultiprocessingNamespace.spawn`` on worker failure.
``errors`` contains only root-cause ranks — the rank(s) whose body
raised. Sibling greenlets terminated via ``throw(SystemExit)`` during
cleanup are NOT recorded (SystemExit does not satisfy ``except
Exception`` in the entry wrapper).
"""
def __init__(self, errors: dict[int, Exception]):
self.errors = errors
first = next(iter(errors.items()), None)
msg = (
f"spawn failed on ranks {sorted(errors.keys())}"
+ (
f": rank {first[0]} raised {first[1]!r}"
if first is not None
else ""
)
)
super().__init__(msg)
def _drain_pending(ctx: Any) -> None:
"""Drain worker-wait + collective-pending queues in main context (D0.4/D0.5).
Loop-until-empty: runs until both queues are simultaneously empty. Safe
under the current model where main-context ``ctx.wait`` never re-enqueues
(D0.5 main-context non-reentrance invariant); also safe under future
extensions where drain can add sub-handles (SimPy causality gives finite
depth).
"""
distributed = getattr(ctx, "distributed", None)
backend = getattr(distributed, "_backend", None) if distributed else None
def _collective_nonempty() -> bool:
if backend is None:
return False
pending = getattr(backend, "_pending_collective_handles", None)
return bool(pending)
while ctx._pending_worker_waits or _collective_nonempty():
# (a) Worker-driven waits (D0.1). FIFO.
while ctx._pending_worker_waits:
h = ctx._pending_worker_waits.pop(0)
if h not in ctx._completed:
wait_fn = getattr(ctx.engine, "wait", None)
if wait_fn is not None:
wait_fn(h)
# Populate _completed so fast-path in ctx.wait short-circuits
# on the return leg.
ctx._completed.add(h)
# (b) Collective backend queue (ADR-0024 D7 + D0.4-(2)).
if backend is not None:
pending_list = getattr(backend, "_pending_collective_handles", None)
if pending_list is not None:
while pending_list:
h, _sip_id, meta = pending_list.pop(0)
# Main context: ctx.wait drives engine directly and does
# NOT re-enqueue (D0.5 invariant).
ctx.wait(h, _meta=meta)
class _MultiprocessingNamespace:
"""torch.multiprocessing-compat facade bound to a RuntimeContext."""
def __init__(self, ctx: Any) -> None:
self._ctx = ctx
def spawn(
self,
fn: Callable,
args: tuple = (),
nprocs: int = 1,
join: bool = True,
) -> None:
"""Spawn ``nprocs`` worker greenlets, each calling ``fn(rank, *args)``.
Mirrors ``torch.multiprocessing.spawn`` signature (minus ``daemon`` and
``start_method``).
Runs the D0.4 round-robin scheduler loop until all workers finish,
draining pending queues between rounds.
"""
from greenlet import greenlet
ctx = self._ctx
dist = ctx.distributed
gs: list = []
errors: dict[int, Exception] = {}
for rank in range(nprocs):
def _entry(r: int = rank) -> None:
try:
fn(r, *args)
except Exception as e:
errors[r] = e
raise
g = greenlet(_entry)
if dist is not None and hasattr(dist, "_bind_rank"):
dist._bind_rank(g, rank)
gs.append(g)
try:
while True:
alive = [g for g in gs if not g.dead]
if not alive:
break
for g in alive:
if not g.dead:
g.switch()
_drain_pending(ctx)
except Exception as outer:
# D0.4-(4) sibling cleanup. Abort live greenlets, clear state.
for other in gs:
if not other.dead:
try:
other.throw(SystemExit)
except BaseException:
# SystemExit inherits BaseException; greenlet.throw
# re-raises in caller if target doesn't catch it.
# Silent — we're already in cleanup.
pass
backend = getattr(dist, "_backend", None)
if backend is not None:
if hasattr(backend, "_barrier") and hasattr(backend._barrier, "reset"):
try:
backend._barrier.reset()
except Exception:
pass
pending_collective = getattr(
backend, "_pending_collective_handles", None,
)
if pending_collective is not None:
pending_collective.clear()
ctx._pending_worker_waits.clear()
raise SpawnException(errors) from outer
# join=True: we already waited for all workers above.
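The submit, yield, drain pattern used by `all_reduce` and the `spawn` loop can be illustrated without greenlets. The sketch below uses Python generators as a stand-in for worker greenlets and a plain list as a stand-in for `_pending_collective_handles`; all names are hypothetical and nothing here touches the real engine.

```python
from typing import Callable, Iterator

Handle = tuple[int, str]  # (rank, op name), a stand-in for a collective handle


def all_reduce_worker(rank: int, pending: list[Handle]) -> Iterator[None]:
    """Worker body: submit the collective without waiting, then yield."""
    pending.append((rank, "all_reduce"))
    yield  # hand control back to the scheduler (g.parent.switch() in the diff)
    # On resume, the scheduler has already drained this worker's handle.


def run_workers(
    n_ranks: int,
    worker: Callable[[int, list[Handle]], Iterator[None]],
) -> list[list[Handle]]:
    """Round-robin scheduler: step every live worker once, then drain."""
    pending: list[Handle] = []
    drained_batches: list[list[Handle]] = []
    alive = [worker(rank, pending) for rank in range(n_ranks)]
    while alive:
        still_alive = []
        for gen in alive:
            try:
                next(gen)
                still_alive.append(gen)
            except StopIteration:
                pass  # worker finished
        alive = still_alive
        if pending:
            # Drain only after every worker in this round has had its turn,
            # so all sibling ranks have submitted before anyone "waits".
            drained_batches.append(list(pending))
            pending.clear()
    return drained_batches
```

Running three ranks produces a single drained batch containing all three handles, which is the property the deferred wait is meant to guarantee: no rank waits before its siblings have posted.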
+72 -6
@@ -66,13 +66,64 @@ def _numpy_dtype(dtype: str) -> np.dtype:
return np.dtype(_NUMPY_DTYPE.get(dtype, np.float16)) return np.dtype(_NUMPY_DTYPE.get(dtype, np.float16))
# ADR-0027 T5.g: closed-set registry of host-read barrier entry-points.
# Any new Tensor API with host-observable read semantics must be added here
# AND implement the barrier call. Code review + this registry keep the set
# consistent (Python introspection-based auto-detection is a non-goal).
# Note on ``copy_``: the source read is barriered via ``source.numpy()``.
# A target-side write barrier was specified in an earlier revision of
# ADR-0027 D0.5 but is intentionally not applied (global-pending target
# barrier can prematurely drain cross-rank collectives → deadlock).
_HOST_READ_BARRIERS: frozenset[str] = frozenset({
"numpy",
"data",
"__getitem__",
"__repr__",
"copy_", # source-side via source.numpy(); target-side not barriered
})
def _host_read_barrier(tensor: "Tensor") -> None:
"""ADR-0027 D0.5: drain pending worker-wait queue before a host-observable
read/write.
Scope: the barrier yields to main when ``ctx._pending_worker_waits`` is
non-empty AND the caller is a worker greenlet. Collective pending
(``backend._pending_collective_handles``) is **deliberately excluded**
from this check — collective handles represent cross-rank protocol that
must be drained only at scheduler synchronisation points (all workers
yielded). A collective's own yield (inside ``all_reduce``) already
ensures that once the collective call returns to the worker, post-drain
values are visible, so subsequent host reads see materialised data
without needing to trigger drain themselves. Including collective
pending here would cause an unrelated rank's barrier to prematurely
request drain of a cross-rank operation → deadlock.
No-op when called from main context or when the worker-wait queue is
empty (fast-path avoids needless context switches).
"""
ctx = None
if tensor._ctx_ref is not None:
ctx = tensor._ctx_ref()
if ctx is None:
return
worker_pending = getattr(ctx, "_pending_worker_waits", None)
if not worker_pending:
return # fast-path
from greenlet import getcurrent
g = getcurrent()
if g.parent is None or g.parent.dead:
return # main context: caller drains directly when needed
g.parent.switch()
def deploy_tensor( def deploy_tensor(
*, *,
name: str, name: str,
shape: tuple[int, ...], shape: tuple[int, ...],
dtype: str, dtype: str,
placement: list[ShardSpec], placement: list[ShardSpec],
allocators: dict[int, PEMemAllocator], allocators: dict[tuple[int, int, int], PEMemAllocator],
mem_kind: Literal["hbm", "tcm"] = "hbm", mem_kind: Literal["hbm", "tcm"] = "hbm",
va_allocator=None, va_allocator=None,
) -> TensorHandle: ) -> TensorHandle:
@@ -86,15 +137,15 @@ def deploy_tensor(
shards: list[TensorShard] = [] shards: list[TensorShard] = []
for spec in placement: for spec in placement:
alloc = allocators[spec.pe_index] alloc = allocators[(spec.sip, spec.cube, spec.pe)]
if mem_kind == "hbm": if mem_kind == "hbm":
pa = alloc.alloc_hbm(spec.nbytes) pa = alloc.alloc_hbm(spec.nbytes)
else: else:
pa = alloc.alloc_tcm(spec.nbytes) pa = alloc.alloc_tcm(spec.nbytes)
shards.append(TensorShard( shards.append(TensorShard(
sip=alloc._sip_id, sip=spec.sip,
cube=alloc._cube_id, cube=spec.cube,
pe=alloc._pe_id, pe=spec.pe,
pa=pa.encode(), pa=pa.encode(),
nbytes=spec.nbytes, nbytes=spec.nbytes,
offset_bytes=spec.offset_bytes, offset_bytes=spec.offset_bytes,
@@ -217,7 +268,9 @@ class Tensor:
"""Read a shard-aligned slice. Returns a numpy array. """Read a shard-aligned slice. Returns a numpy array.
Mirrors ``torch.Tensor.__getitem__`` for the shard-aligned case. Mirrors ``torch.Tensor.__getitem__`` for the shard-aligned case.
ADR-0027 D0.5: host-read barrier.
""" """
_host_read_barrier(self)
start, stop = self._resolve_shard_index(key) start, stop = self._resolve_shard_index(key)
shard = self._shard_for_range(start, stop) shard = self._shard_for_range(start, stop)
if self._memory_store is None: if self._memory_store is None:
@@ -272,6 +325,8 @@ class Tensor:
def __repr__(self) -> str: def __repr__(self) -> str:
parts = [f"tensor(name={self.name}, shape={self.shape}, dtype={self.dtype}"] parts = [f"tensor(name={self.name}, shape={self.shape}, dtype={self.dtype}"]
if self._memory_store is not None and self._handle is not None: if self._memory_store is not None and self._handle is not None:
# ADR-0027 D0.5: barrier on data-containing repr path.
_host_read_barrier(self)
arr = self.data arr = self.data
parts.append(f", mean={float(arr.mean()):.4g}, norm={float(np.linalg.norm(arr)):.4g}") parts.append(f", mean={float(arr.mean()):.4g}, norm={float(np.linalg.norm(arr)):.4g}")
else: else:
@@ -308,7 +363,11 @@ class Tensor:
Mirrors ``torch.Tensor.numpy()``. In kernbench, sharded tensors are Mirrors ``torch.Tensor.numpy()``. In kernbench, sharded tensors are
gathered into a single full-shape ndarray according to each shard's gathered into a single full-shape ndarray according to each shard's
``offset_bytes`` / ``nbytes`` range. ``offset_bytes`` / ``nbytes`` range.
ADR-0027 D0.5: acts as a host-read barrier. Drains pending worker waits
before reading; collective handles are deliberately excluded (see
``_host_read_barrier``) because ``all_reduce`` drains them at its own yield.
""" """
_host_read_barrier(self)
np_dtype = _numpy_dtype(self.dtype) np_dtype = _numpy_dtype(self.dtype)
# Host-side tensor (created via torch.from_numpy) has no shards. # Host-side tensor (created via torch.from_numpy) has no shards.
if self._host_buffer is not None: if self._host_buffer is not None:
@@ -340,6 +399,12 @@ class Tensor:
re-scattered into self's shard layout. re-scattered into self's shard layout.
Shapes must match. Returns self. Shapes must match. Returns self.
ADR-0027 D0.5: source-side read barrier is triggered inside
``source.numpy()``. Target-side write barrier is not applied here
because it would require cross-rank coordination when other ranks
have pending collectives (see _host_read_barrier docstring on
collective pending being cross-rank).
""" """
if self._handle is None or self._memory_store is None: if self._handle is None or self._memory_store is None:
raise RuntimeError( raise RuntimeError(
@@ -394,7 +459,8 @@ class Tensor:
) -> Tensor: ) -> Tensor:
"""Set DP placement metadata (like torch.Tensor.to()).""" """Set DP placement metadata (like torch.Tensor.to())."""
if placement is None: if placement is None:
placement = [ShardSpec(pe_index=0, offset_bytes=0, nbytes=self.nbytes)] placement = [ShardSpec(sip=0, cube=0, pe=0,
offset_bytes=0, nbytes=self.nbytes)]
self._dp_metadata = DPMetadata( self._dp_metadata = DPMetadata(
placement=placement, dp_policy=dp_policy, placement=placement, dp_policy=dp_policy,
sip=sip, cube=cube, target_pe=target_pe, sip=sip, cube=cube, target_pe=target_pe,
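The switch from a flat `pe_index` key to a structural `(sip, cube, pe)` key in `deploy_tensor` can be sketched in isolation. `ShardSpec` below copies only the five fields the ADR-0026 tests exercise; `BumpAllocator` and `place` are hypothetical stand-ins for `PEMemAllocator` and the allocation loop.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ShardSpec:
    sip: int
    cube: int
    pe: int
    offset_bytes: int
    nbytes: int


class BumpAllocator:
    """Toy per-PE allocator: hands out increasing offsets."""

    def __init__(self) -> None:
        self._next = 0

    def alloc(self, nbytes: int) -> int:
        addr = self._next
        self._next += nbytes
        return addr


def place(
    placement: list[ShardSpec],
    allocators: dict[tuple[int, int, int], BumpAllocator],
) -> list[tuple[tuple[int, int, int], int]]:
    """Look up each shard's allocator by its structural coordinates."""
    out = []
    for spec in placement:
        key = (spec.sip, spec.cube, spec.pe)
        out.append((key, allocators[key].alloc(spec.nbytes)))
    return out
```

Because the key carries the SIP, PE 0 of cube 0 on SIP 0 and on SIP 1 resolve to different allocators, which a flat per-device index cannot express without a global numbering scheme.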
+11 -4
@@ -101,12 +101,19 @@ class DataExecutor:
p = op.params p = op.params
if "src_a_addr" not in p: if "src_a_addr" not in p:
return # composite record without full params return # composite record without full params
space = p.get("addr_space", "tcm") default_space = p.get("addr_space", "tcm")
# ADR-0027: per-operand + output spaces (fall back to single space
# for legacy records without explicit space keys).
src_a_space = p.get("src_a_space", default_space)
src_b_space = p.get("src_b_space", default_space)
dst_space = p.get("dst_space", default_space)
dtype_in = p.get("dtype_in", "f16") dtype_in = p.get("dtype_in", "f16")
dtype_out = p.get("dtype_out", dtype_in) dtype_out = p.get("dtype_out", dtype_in)
a = self.store.read(space, p["src_a_addr"], shape=p.get("shape_a"), dtype=dtype_in) a = self.store.read(src_a_space, p["src_a_addr"],
b = self.store.read(space, p["src_b_addr"], shape=p.get("shape_b"), dtype=dtype_in) shape=p.get("shape_a"), dtype=dtype_in)
b = self.store.read(src_b_space, p["src_b_addr"],
shape=p.get("shape_b"), dtype=dtype_in)
# Compute in higher precision if specified # Compute in higher precision if specified
dtype_acc = p.get("dtype_acc", "f32") dtype_acc = p.get("dtype_acc", "f32")
@@ -114,7 +121,7 @@ class DataExecutor:
b_f = b.astype(_resolve_dtype(dtype_acc)) b_f = b.astype(_resolve_dtype(dtype_acc))
result = np.matmul(a_f, b_f).astype(_resolve_dtype(dtype_out)) result = np.matmul(a_f, b_f).astype(_resolve_dtype(dtype_out))
self.store.write(space, p["dst_addr"], result) self.store.write(dst_space, p["dst_addr"], result)
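The per-operand space fallback in `_execute_gemm` reduces to three dictionary lookups that share one default. A minimal sketch, with `resolve_gemm_spaces` as a hypothetical helper name:

```python
from typing import Any, Mapping


def resolve_gemm_spaces(params: Mapping[str, Any]) -> tuple[str, str, str]:
    """Return (src_a_space, src_b_space, dst_space) for a gemm record.

    Records without explicit per-operand keys fall back to the legacy
    single ``addr_space`` value, which itself defaults to "tcm".
    """
    default_space = params.get("addr_space", "tcm")
    return (
        params.get("src_a_space", default_space),
        params.get("src_b_space", default_space),
        params.get("dst_space", default_space),
    )
```

A legacy record such as `{"addr_space": "hbm"}` resolves all three spaces to `"hbm"`, while a record that only sets `src_a_space` keeps the other two on the default.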
def _execute_math(self, op: OpRecord) -> None: def _execute_math(self, op: OpRecord) -> None:
"""Execute math op: unary, binary, or reduction.""" """Execute math op: unary, binary, or reduction."""
+33 -1
@@ -79,6 +79,14 @@ class OpLogger:
snaps.append(None) snaps.append(None)
params["input_snapshots"] = snaps params["input_snapshots"] = snaps
elif op_name == "dma_write": elif op_name == "dma_write":
# ADR-0027 fix: only snapshot HBM sources. TCM (PE scratch)
# sources are repopulated by Phase 2 math/gemm replay —
# capturing a Phase-1-time snapshot here would pick up stale
# data from a PRIOR kernel's Phase 2 output that aliased the
# same scratch address, causing the later kernel's replay
# to write that stale value instead of the fresh math
# result. See ADR-0027 postmortem (TP gemm → all_reduce).
if params.get("src_space") == "hbm":
try: try:
arr = self._memory_store.read( arr = self._memory_store.read(
params["src_space"], params["src_addr"], params["src_space"], params["src_addr"],
@@ -167,6 +175,13 @@ def _extract_op_info(msg: Any) -> tuple[str, str, dict[str, Any]]:
"dtype_in": msg.a.dtype, "dtype_in": msg.a.dtype,
"dtype_out": msg.out.dtype, "dtype_out": msg.out.dtype,
"m": msg.m, "k": msg.k, "n": msg.n, "m": msg.m, "k": msg.k, "n": msg.n,
# ADR-0027: preserve per-operand + output MemoryStore spaces so
# Phase 2 replay can resolve HBM-resident operands (e.g. tl.load
# results keep space="hbm"). Absent → DataExecutor falls back
# to the legacy single-space mode via ``addr_space``.
"src_a_space": getattr(msg.a, "space", "tcm"),
"src_b_space": getattr(msg.b, "space", "tcm"),
"dst_space": getattr(msg.out, "space", "tcm"),
} }
if isinstance(msg, MathCmd): if isinstance(msg, MathCmd):
return "math", msg.op, { return "math", msg.op, {
@@ -181,10 +196,27 @@ def _extract_op_info(msg: Any) -> tuple[str, str, dict[str, Any]]:
"axis": msg.axis, "axis": msg.axis,
} }
if isinstance(msg, CompositeCmd): if isinstance(msg, CompositeCmd):
return "gemm" if msg.op == "gemm" else "math", f"composite_{msg.op}", { params: dict[str, Any] = {
"op": msg.op, "op": msg.op,
"out_addr": msg.out_addr, "out_addr": msg.out_addr,
"out_nbytes": msg.out_nbytes, "out_nbytes": msg.out_nbytes,
} }
# ADR-0027: preserve operand info so Phase 2 DataExecutor can replay
# the composite's numerical effect (treat it like a GemmCmd).
if msg.op == "gemm" and msg.a is not None and msg.b is not None:
params.update({
"src_a_addr": msg.a.addr,
"src_b_addr": msg.b.addr,
"shape_a": msg.a.shape,
"shape_b": msg.b.shape,
"dtype_in": msg.a.dtype,
"dtype_out": msg.a.dtype,
"src_a_space": getattr(msg.a, "space", "hbm"),
"src_b_space": getattr(msg.b, "space", "hbm"),
"dst_space": "hbm",
# dst_addr alias so DataExecutor._execute_gemm picks it up.
"dst_addr": msg.out_addr,
})
return "gemm" if msg.op == "gemm" else "math", f"composite_{msg.op}", params
# Fallback for unknown data_op messages # Fallback for unknown data_op messages
return "unknown", type(msg).__name__, {} return "unknown", type(msg).__name__, {}
+21
@@ -0,0 +1,21 @@
"""kernbench.tp — Megatron-style Tensor Parallelism (ADR-0027).
Public API re-exports.
"""
from kernbench.tp.layers import (
ColumnParallelLinear,
RowParallelLinear,
)
from kernbench.tp.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
initialize_model_parallel,
)
__all__ = [
"ColumnParallelLinear",
"RowParallelLinear",
"get_tensor_model_parallel_rank",
"get_tensor_model_parallel_world_size",
"initialize_model_parallel",
]
+23
@@ -0,0 +1,23 @@
"""Kernel used by ``kernbench.tp`` layers (ADR-0027 D4/D5).
Intentionally self-contained inside the ``tp`` package — the ``tp`` package
must not import from ``benches/``. Future work: move to a shared
``kernbench.kernels`` module so benches and TP can share.
"""
from __future__ import annotations
def _gemm_kernel(a_ptr, b_ptr, out_ptr, M, K, N, tl, DTYPE: str = "f16") -> None:
"""Single-PE GEMM: out = a @ b via load → dot → store.
Uses the ``tl.load + tl.dot + tl.store`` path. Unlike ``tl.composite``
(which is absorbed by the PE scheduler into TileTokens that don't reach
the op_log), this path emits explicit ``DmaReadCmd`` / ``GemmCmd`` /
``DmaWriteCmd`` records, which DataExecutor replays numerically in
Phase 2.
"""
M, K, N = int(M), int(K), int(N)
a = tl.load(int(a_ptr), shape=(M, K), dtype=DTYPE)
b = tl.load(int(b_ptr), shape=(K, N), dtype=DTYPE)
out = tl.dot(a, b)
tl.store(int(out_ptr), out)
+150
@@ -0,0 +1,150 @@
"""Megatron-style parallel layers (ADR-0027 D4/D5).
- ``ColumnParallelLinear``: weight's out_features axis split across TP ranks.
forward(x) is local gemm; no collective.
- ``RowParallelLinear``: weight's in_features axis split across TP ranks.
forward(x) ends with ``dist.all_reduce`` to sum partial products.
Both layers use the intra-device ``DPPolicy`` (ADR-0026). TP shard
ownership is determined by ``torch.ahbm.set_device(rank)`` (ADR-0024 D10).
Yield-safety contract (ADR-0027 D4/D5): every forward path contains at
least one ``ctx.wait`` (via ``torch.launch``) or one collective; this
keeps the scheduler loop making progress.
"""
from __future__ import annotations
from typing import Any
from kernbench.policy.placement.dp import DPPolicy
from kernbench.tp.kernels import _gemm_kernel
from kernbench.tp.parallel_state import (
get_tensor_model_parallel_world_size,
)
class ColumnParallelLinear:
"""Weight's K (out_features) axis distributed across TP ranks.
forward(x):
x: (M, N) — full-replicated across ranks
W_k: (N, K / world_size) — this rank's slice (on its SIP)
y_k = x @ W_k → (M, K / world_size)
"""
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = False,
dtype: str = "f16",
torch: Any = None,
) -> None:
if torch is None:
raise TypeError("ColumnParallelLinear requires torch=<RuntimeContext>")
ws = get_tensor_model_parallel_world_size()
if out_features % ws != 0:
raise ValueError(
f"out_features ({out_features}) must be divisible by TP world "
f"size ({ws})"
)
self.in_features = in_features
self.out_features = out_features
self.k_local = out_features // ws
self.dtype = dtype
self._torch = torch
# Per-rank weight slice. ``set_device(rank)`` (ADR-0024 D10) places
# it on SIP ``rank``. Intra-SIP layout comes from DPPolicy (ADR-0026).
self.weight = torch.zeros(
(in_features, self.k_local),
dtype=dtype,
dp=DPPolicy(cube="replicate", pe="replicate",
num_cubes=1, num_pes=1),
name="col_parallel_w",
)
# Bias omitted in initial scope (ADR-0027 D9).
self.bias = None
if bias:
raise NotImplementedError(
"bias=True is deferred (ADR-0027 D9 initial scope)"
)
def forward(self, x):
M = int(x.shape[0])
out = self._torch.empty(
(M, self.k_local),
dtype=x.dtype,
dp=DPPolicy(cube="replicate", pe="replicate",
num_cubes=1, num_pes=1),
name="col_parallel_out",
)
self._torch.launch(
"col_parallel_gemm",
_gemm_kernel,
x, self.weight, out,
M, self.in_features, self.k_local,
)
return out
class RowParallelLinear:
"""Weight's N (in_features) axis distributed across TP ranks.
forward(x):
x: (M, N / world_size) — rank-local slice (ColumnParallel output)
W_k: (N / world_size, K) — this rank's slice
y_k = x @ W_k → (M, K) — partial sum
y = all_reduce(y_k, op="sum") → (M, K) on every rank
"""
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = False,
dtype: str = "f16",
torch: Any = None,
) -> None:
if torch is None:
raise TypeError("RowParallelLinear requires torch=<RuntimeContext>")
ws = get_tensor_model_parallel_world_size()
if in_features % ws != 0:
raise ValueError(
f"in_features ({in_features}) must be divisible by TP world "
f"size ({ws})"
)
self.in_features = in_features
self.out_features = out_features
self.n_local = in_features // ws
self.dtype = dtype
self._torch = torch
self.weight = torch.zeros(
(self.n_local, out_features),
dtype=dtype,
dp=DPPolicy(cube="replicate", pe="replicate",
num_cubes=1, num_pes=1),
name="row_parallel_w",
)
self.bias = None
if bias:
raise NotImplementedError(
"bias=True is deferred (ADR-0027 D9 initial scope)"
)
def forward(self, x):
M = int(x.shape[0])
y_partial = self._torch.empty(
(M, self.out_features),
dtype=x.dtype,
dp=DPPolicy(cube="replicate", pe="replicate",
num_cubes=1, num_pes=1),
name="row_parallel_partial",
)
self._torch.launch(
"row_parallel_gemm",
_gemm_kernel,
x, self.weight, y_partial,
M, self.n_local, self.out_features,
)
self._torch.distributed.all_reduce(y_partial, op="sum")
return y_partial
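The arithmetic behind chaining `ColumnParallelLinear` into `RowParallelLinear` can be checked with plain Python lists: splitting the first weight by columns and the second by rows, then summing the per-rank partial products, reproduces the unsharded result. This is a numerical sketch only; the helper names are hypothetical and no KernBench API is involved.

```python
Matrix = list[list[int]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def split_cols(w: Matrix, parts: int) -> list[Matrix]:
    step = len(w[0]) // parts
    return [[row[r * step:(r + 1) * step] for row in w] for r in range(parts)]


def split_rows(w: Matrix, parts: int) -> list[Matrix]:
    step = len(w) // parts
    return [w[r * step:(r + 1) * step] for r in range(parts)]


def all_reduce_sum(partials: list[Matrix]) -> Matrix:
    """Element-wise sum across ranks (what dist.all_reduce(op="sum") yields)."""
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]


def tp_forward(x: Matrix, w1: Matrix, w2: Matrix, world_size: int) -> Matrix:
    w1_shards = split_cols(w1, world_size)  # ColumnParallel: split out_features
    w2_shards = split_rows(w2, world_size)  # RowParallel: split in_features
    partials = []
    for rank in range(world_size):
        h_local = matmul(x, w1_shards[rank])               # local gemm, no collective
        partials.append(matmul(h_local, w2_shards[rank]))  # partial sum per rank
    return all_reduce_sum(partials)
```

Only one collective is needed for the pair of layers, because the column-parallel output is already the row-parallel input slice on each rank.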
+5
@@ -0,0 +1,5 @@
"""Forward/backward mappings stub (ADR-0027 — future backward work).
Inference-only initial scope. Backward hooks land when training simulation
arrives.
"""
+83
@@ -0,0 +1,83 @@
"""TP group state (ADR-0027 D3).
Single global TP group. Initial scope: TP size == world_size (pure TP;
mixed DP+TP is future work).
"""
from __future__ import annotations
_TP_WORLD_SIZE: int | None = None
def initialize_model_parallel(tensor_model_parallel_size: int) -> None:
"""Initialize the TP process group.
Must be called after ``torch.distributed.init_process_group``.
Only ``tensor_model_parallel_size == world_size`` is supported in the
initial scope.
"""
global _TP_WORLD_SIZE
# Resolved lazily through the active ctx; the ctx import lives inside
# ``_current_ctx`` to avoid a cycle when tp is imported before a ctx exists.
_ws = _current_world_size()
if tensor_model_parallel_size != _ws:
raise NotImplementedError(
f"Only TP == world_size supported; got TP={tensor_model_parallel_size}, "
f"world_size={_ws}"
)
_TP_WORLD_SIZE = tensor_model_parallel_size
def get_tensor_model_parallel_world_size() -> int:
"""Return the TP group's world size.
Raises if not initialised — callers must call
:func:`initialize_model_parallel` first.
"""
if _TP_WORLD_SIZE is None:
raise RuntimeError(
"TP group not initialised; call initialize_model_parallel() first"
)
return _TP_WORLD_SIZE
def get_tensor_model_parallel_rank() -> int:
"""Return this worker's rank within the TP group.
Delegates to the greenlet-local rank registered by the spawn launcher
(ADR-0024 D9 via ``torch.distributed.get_rank``).
"""
# Resolve via the global torch.distributed facade on the active ctx.
return _current_rank()
def _reset_for_tests() -> None:
"""Clear _TP_WORLD_SIZE so ordering-sensitive tests can re-init."""
global _TP_WORLD_SIZE
_TP_WORLD_SIZE = None
# ── helpers (resolve current ctx) ────────────────────────────────────
def _current_ctx():
"""Best-effort resolution of the currently-active RuntimeContext.
In KernBench, the ``ctx`` is passed as the ``torch`` positional in
bench/worker code. Since parallel_state is a module-global helper,
we look it up via a weak registry maintained by RuntimeContext.
"""
from kernbench.runtime_api.context import _get_active_context
ctx = _get_active_context()
if ctx is None:
raise RuntimeError(
"No active RuntimeContext; kernbench.tp requires one "
"(call init_process_group / spawn under a live ctx)"
)
return ctx
def _current_world_size() -> int:
return _current_ctx().distributed.get_world_size()
def _current_rank() -> int:
return _current_ctx().distributed.get_rank()
+34
@@ -0,0 +1,34 @@
"""TP primitive ops (ADR-0027 D6).
``copy_to_tp_region`` / ``reduce_from_tp_region`` are forward-only in the
initial scope (backward pass is future work). ``scatter`` / ``gather`` are
not implemented — they require an all-gather kernel that is not yet
available in KernBench (see ADR-0027 D9).
"""
from __future__ import annotations
from typing import Any
def copy_to_tp_region(x: Any) -> Any:
"""Forward: identity. Backward: all-reduce. (Training is future work.)"""
return x
def reduce_from_tp_region(x: Any, torch: Any) -> Any:
"""Forward: all-reduce. Backward: identity."""
torch.distributed.all_reduce(x, op="sum")
return x
def scatter_to_tp_region(x: Any) -> Any:
raise NotImplementedError(
"scatter_to_tp_region deferred — caller should create the sharded "
"tensor directly (ADR-0027 D9)"
)
def gather_from_tp_region(x: Any) -> Any:
raise NotImplementedError(
"gather_from_tp_region deferred — requires all-gather kernel (ADR-0027 D9)"
)
+239
@@ -0,0 +1,239 @@
"""ADR-0026 Phase 1 tests: DPPolicy intra-device only + ShardSpec structural.
These tests encode the contract from ADR-0026:
- DPPolicy no longer accepts ``sip`` or ``num_sips`` kwargs (TypeError).
- ShardSpec carries structural ``(sip, cube, pe)`` coordinates; the old flat
``pe_index`` field/property is fully removed (AttributeError).
- ``resolve_dp_policy(..., target_sip=N)`` stamps every returned ShardSpec
with ``sip=N``; cube and pe fields are local.
- ``RuntimeContext._allocators`` is keyed by ``(sip, cube, pe)`` tuples.
Phase 1: production code is unchanged → these tests SHOULD FAIL until the
Phase 2 diff lands. Phase 2 makes all of them pass.
"""
from __future__ import annotations
import pytest
from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator
from kernbench.policy.placement.dp import DPPolicy, ShardSpec, resolve_dp_policy
from kernbench.runtime_api.tensor import deploy_tensor
# ── D1: DPPolicy no longer accepts sip / num_sips ─────────────────────
def test_dppolicy_rejects_sip_kwarg():
"""DPPolicy(sip=...) must raise TypeError after field removal."""
with pytest.raises(TypeError):
DPPolicy(sip="column_wise", cube="replicate", pe="replicate")
def test_dppolicy_rejects_num_sips_kwarg():
"""DPPolicy(num_sips=...) must raise TypeError after field removal."""
with pytest.raises(TypeError):
DPPolicy(cube="replicate", pe="replicate", num_sips=2)
def test_dppolicy_accepts_only_intra_device_fields():
"""Intra-device fields still work: cube, pe, num_cubes, num_pes."""
dp = DPPolicy(cube="column_wise", pe="column_wise",
num_cubes=2, num_pes=4)
assert dp.cube == "column_wise"
assert dp.pe == "column_wise"
assert dp.num_cubes == 2
assert dp.num_pes == 4
# No sip / num_sips attributes — even reading them must fail.
assert not hasattr(dp, "sip"), "DPPolicy.sip must be removed"
assert not hasattr(dp, "num_sips"), "DPPolicy.num_sips must be removed"
# ── D2: ShardSpec structural coords, no pe_index ──────────────────────
def test_shardspec_has_structural_coords():
"""ShardSpec constructs from (sip, cube, pe, offset_bytes, nbytes)."""
s = ShardSpec(sip=1, cube=2, pe=3, offset_bytes=128, nbytes=64)
assert s.sip == 1
assert s.cube == 2
assert s.pe == 3
assert s.offset_bytes == 128
assert s.nbytes == 64
def test_shardspec_has_no_pe_index_attr():
"""Flat pe_index must be fully removed — no field, no property."""
s = ShardSpec(sip=0, cube=0, pe=0, offset_bytes=0, nbytes=8)
with pytest.raises(AttributeError):
_ = s.pe_index # noqa: F841
def test_shardspec_rejects_pe_index_kwarg():
"""ShardSpec(pe_index=...) must raise TypeError."""
with pytest.raises(TypeError):
ShardSpec(pe_index=0, offset_bytes=0, nbytes=8) # type: ignore[call-arg]
# ── D3: resolve_dp_policy(target_sip=...) structural semantics ────────
def test_resolve_dp_policy_target_sip_stamps_shards():
"""All returned shards must carry sip == target_sip; cube/pe local."""
dp = DPPolicy(cube="column_wise", pe="column_wise")
shards = resolve_dp_policy(
dp, shape=(4, 32), itemsize=2,
num_pe=4, num_cubes=2, target_sip=1,
)
assert len(shards) == 2 * 4
assert all(s.sip == 1 for s in shards)
assert all(0 <= s.cube < 2 for s in shards)
assert all(0 <= s.pe < 4 for s in shards)
def test_resolve_dp_policy_target_sip_differ_only_in_sip():
"""Same policy + dims on two SIPs → shards identical except .sip."""
dp = DPPolicy(cube="replicate", pe="column_wise")
shards_0 = resolve_dp_policy(
dp, shape=(4, 32), itemsize=2,
num_pe=4, num_cubes=1, target_sip=0,
)
shards_1 = resolve_dp_policy(
dp, shape=(4, 32), itemsize=2,
num_pe=4, num_cubes=1, target_sip=1,
)
assert len(shards_0) == len(shards_1)
for a, b in zip(shards_0, shards_1):
assert a.sip == 0 and b.sip == 1
assert a.cube == b.cube
assert a.pe == b.pe
assert a.offset_bytes == b.offset_bytes
assert a.nbytes == b.nbytes
def test_resolve_dp_policy_no_num_sips_param():
"""resolve_dp_policy must not accept num_sips anymore.
Post-Phase-2 signature drops ``num_sips`` (DPPolicy no longer crosses
SIP boundaries) and adds required ``target_sip``. Calling with
``num_sips=...`` must raise TypeError (unexpected keyword argument).
"""
dp = DPPolicy(cube="replicate", pe="replicate")
with pytest.raises(TypeError, match="num_sips"):
resolve_dp_policy(
dp, shape=(4, 8), itemsize=2,
num_pe=1, num_cubes=1, num_sips=2, # type: ignore[call-arg]
)
# ── D5: Allocator dict keyed by (sip, cube, pe) tuples ────────────────
_MB = 1 << 20
_GB = 1 << 30
_CFG = AddressConfig(
sip_count=2,
cubes_per_sip=2,
pes_per_cube=4,
hbm_bytes_per_cube=_GB,
hbm_slices_per_cube=4,
tcm_bytes_per_pe=_MB,
tcm_scheduler_reserved_bytes=0,
sram_bytes_per_cube=_MB,
)
def _make_tuple_allocators(
num_sips: int = 1, num_cubes: int = 1, num_pe: int = 4,
) -> dict[tuple[int, int, int], PEMemAllocator]:
return {
(s, c, p): PEMemAllocator(
rack_id=0, sip_id=s, cube_id=c, pe_id=p, cfg=_CFG,
)
for s in range(num_sips)
for c in range(num_cubes)
for p in range(num_pe)
}
def test_deploy_tensor_uses_tuple_lookup():
"""deploy_tensor(allocators={(sip,cube,pe): alloc, ...}) succeeds."""
dp = DPPolicy(cube="replicate", pe="column_wise")
placement = resolve_dp_policy(
dp, shape=(4, 16), itemsize=2,
num_pe=4, num_cubes=1, target_sip=0,
)
allocators = _make_tuple_allocators(num_sips=1, num_cubes=1, num_pe=4)
handle = deploy_tensor(
name="t", shape=(4, 16), dtype="f16",
placement=placement, allocators=allocators,
)
assert len(handle.shards) == 4
# Each shard's TensorShard carries structural coords; those coords
# must match the shard's ShardSpec (sip, cube, pe).
for spec, shard in zip(placement, handle.shards):
assert shard.sip == spec.sip
assert shard.cube == spec.cube
assert shard.pe == spec.pe
def test_runtime_context_allocator_keys_are_tuples(topology):
"""After ctx tensor op, ctx._allocators keys are (sip, cube, pe) tuples.
Ensures D5 migration landed (allocator population + lookup).
"""
from kernbench.runtime_api.context import RuntimeContext
from kernbench.runtime_api.types import DeviceSelector
from kernbench.sim_engine.engine import GraphEngine
engine = GraphEngine(topology.topology_obj, enable_data=True)
ctx = RuntimeContext(
engine=engine,
target_device=DeviceSelector("sip:0"),
correlation_id="test_adr0026_tuple_keys",
spec=topology.topology_obj.spec,
)
dp = DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=1)
_ = ctx.zeros((1, 16), dtype="f16", dp=dp)
assert ctx._allocators, "allocators dict should be populated"
keys = list(ctx._allocators.keys())
assert all(isinstance(k, tuple) and len(k) == 3 for k in keys), (
f"_allocators keys must be (sip, cube, pe) tuples; got {keys[:5]}"
)
# ── D4 (via regression): no SIP-crossing tensor without set_device ────
def test_create_tensor_on_target_sip_via_set_device(topology):
"""torch.ahbm.set_device(1) + DPPolicy(cube=replicate, pe=replicate)
→ all shards land on SIP 1 structurally (no post-hoc shifting needed)."""
from kernbench.runtime_api.context import RuntimeContext
from kernbench.runtime_api.types import DeviceSelector
from kernbench.sim_engine.engine import GraphEngine
# Skip the test if topology has only 1 SIP (nothing to verify).
n_sips = int(
topology.topology_obj.spec.get("system", {})
.get("sips", {}).get("count", 1)
)
if n_sips < 2:
pytest.skip("topology has <2 SIPs; set_device(1) not meaningful")
engine = GraphEngine(topology.topology_obj, enable_data=True)
ctx = RuntimeContext(
engine=engine,
target_device=DeviceSelector("sip:1"),
correlation_id="test_adr0026_set_device",
spec=topology.topology_obj.spec,
)
ctx.ahbm.set_device(1)
dp = DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=1)
t = ctx.zeros((1, 16), dtype="f16", dp=dp)
assert t._handle is not None
assert all(s.sip == 1 for s in t._handle.shards), (
f"expected all shards on SIP 1; got {[s.sip for s in t._handle.shards]}"
)
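Review note: the contract these tests encode (structural `(sip, cube, pe)` coordinates, no flat `pe_index`, tuple-keyed allocators) can be pictured with a small stand-alone model. Everything below is an assumption for illustration: `ShardSpecSketch` and `resolve_replicated` are hypothetical names, not the real `ShardSpec` or `resolve_dp_policy`, and the allocators are plain strings.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ShardSpecSketch:
    """Hypothetical mirror of the fields the tests exercise."""
    sip: int
    cube: int
    pe: int
    offset_bytes: int
    nbytes: int


def resolve_replicated(nbytes, num_cubes, num_pe, target_sip):
    """Replicate one buffer on every (cube, pe) of a single SIP."""
    return [
        ShardSpecSketch(sip=target_sip, cube=c, pe=p, offset_bytes=0, nbytes=nbytes)
        for c in range(num_cubes)
        for p in range(num_pe)
    ]


shards = resolve_replicated(nbytes=64, num_cubes=2, num_pe=4, target_sip=1)
# Allocators (strings here) are looked up by the structural tuple, not a flat index.
allocators = {(s.sip, s.cube, s.pe): f"alloc-{s.sip}-{s.cube}-{s.pe}" for s in shards}
```

Because the dataclass has no `pe_index` field, passing `pe_index=...` raises `TypeError` and reading `.pe_index` raises `AttributeError`, which is the same failure mode the D2 tests expect from the real class.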
+222
@@ -0,0 +1,222 @@
"""Config-driven multi-device allreduce test application.
Reads ``ccl.yaml`` + ``topology.yaml``, dynamically loads the kernel
module from ``ccl.yaml → module``, and picks the inter-SIP exchange
pattern from ``topology.yaml → system.sips.topology``.
Run directly::
python -m pytest tests/test_allreduce_multidevice.py -v -s
"""
from __future__ import annotations
import importlib
import math
from pathlib import Path
from typing import Any
import numpy as np
from kernbench.ccl.install import load_ccl_config, resolve_algorithm_config
from kernbench.ccl.sfr_config import configure_sfr_intercube_multisip
from kernbench.policy.placement.dp import DPPolicy
def _sip_topo_dims(sip_topo: str, n_sips: int) -> tuple[int, int]:
if sip_topo == "ring_1d":
return (0, 0)
side = int(round(math.sqrt(n_sips)))
if side * side != n_sips:
raise ValueError(
f"SIP topology '{sip_topo}' requires square n_sips, got {n_sips}"
)
return (side, side)
def run_allreduce(
ctx: Any,
engine: Any,
spec: dict,
*,
algorithm: str | None = None,
ccl_yaml: str | None = None,
) -> dict:
"""Config-driven allreduce: read yaml, load kernel, run.
Everything is resolved from config — no hardcoded kernel imports.
"""
cfg_all = load_ccl_config(ccl_yaml)
cfg = resolve_algorithm_config(cfg_all, algorithm)
# Dynamic import from ccl.yaml → module
algo_module = importlib.import_module(cfg["module"])
kernel_fn = algo_module.kernel
topo_name_to_kind = algo_module.TOPO_NAME_TO_KIND
n_elem = int(cfg.get("n_elem", 8))
n_sips = int(spec.get("system", {}).get("sips", {}).get("count", 1))
sip_topo = str(
spec.get("system", {}).get("sips", {}).get("topology", "ring_1d")
)
cm = spec["sip"]["cube_mesh"]
cube_w = int(cm["w"])
cube_h = int(cm["h"])
n_cubes = cube_w * cube_h
sip_topo_kind = topo_name_to_kind.get(sip_topo, 0)
sip_topo_w, sip_topo_h = _sip_topo_dims(sip_topo, n_sips)
algo_name = cfg.get("algorithm", "allreduce")
print(f"\n{'=' * 60}")
print(f"algorithm: {algo_name}")
print(f"module: {cfg['module']}")
print(f"sip_topology: {sip_topo}")
print(f"kernel: {kernel_fn.__name__}")
print(f"n_sips: {n_sips}")
print(f"n_cubes: {n_cubes}")
print(f"n_elem: {n_elem}")
print(f"{'=' * 60}")
configure_sfr_intercube_multisip(engine, spec, cfg)
dp = DPPolicy(
cube="row_wise", pe="replicate",
num_pes=1, num_cubes=n_cubes,
)
tensors = []
for sip in range(n_sips):
ctx.ahbm.set_device(sip)
t = ctx.zeros(
(n_cubes, n_elem), dtype="f16", dp=dp,
name=f"sip{sip}",
)
t.copy_(ctx.from_numpy(
np.full((n_cubes, n_elem), float(sip + 1), dtype=np.float16)
))
tensors.append(t)
for sip in range(n_sips):
arr = tensors[sip].numpy()
print(f"[SIP {sip}] input cube0[:4] = {arr[0][:4].tolist()} "
f"cube{n_cubes - 1}[:4] = {arr[-1][:4].tolist()}")
t_start = engine._env.now
all_pending = []
for sip_rank, t in enumerate(tensors):
pending = ctx.launch(
algo_name, kernel_fn, t,
n_elem, cube_w, cube_h, n_sips, sip_rank,
sip_topo_kind, sip_topo_w, sip_topo_h,
_defer_wait=True,
)
all_pending.extend(pending)
for h, sip_id, meta in all_pending:
ctx.wait(h, _meta=meta)
t_end = engine._env.now
latency_ns = t_end - t_start
print(f"\n[{algo_name} ws={n_sips}] sim latency = "
f"{latency_ns:.1f} ns ({latency_ns / 1000:.3f} us)")
for key, (_, trace) in engine._results.items():
if not isinstance(trace, dict):
continue
total = trace.get("total_ns", 0.0)
pe_exec = trace.get("pe_exec_ns", 0.0) or 0.0
network = total - pe_exec
print(f" [{key}] total={total:.1f} ns "
f"pe_exec={pe_exec:.1f} ns network={network:.1f} ns")
expected = float(n_cubes * sum(range(1, n_sips + 1)))
print()
for sip in range(n_sips):
arr = tensors[sip].numpy()
print(f"[SIP {sip}] output cube0[:4] = {arr[0][:4].tolist()}")
print(f"[SIP {sip}] output cube{n_cubes - 1}[:4] = {arr[-1][:4].tolist()}")
ok_cubes = 0
for sip in range(n_sips):
arr = tensors[sip].numpy()
for cube_id in range(n_cubes):
assert np.allclose(
arr[cube_id], expected, rtol=1e-1, atol=1e-1,
), (
f"SIP{sip} cube {cube_id}: "
f"got {arr[cube_id][:4]}, expected {expected}"
)
ok_cubes += 1
print(f"\n {algo_name} (ws={n_sips}): {ok_cubes} OK")
return {
"expected": expected,
"latency_ns": latency_ns,
"ok_cubes": ok_cubes,
}
# ── pytest entry point ───────────────────────────────────────────────
import pytest
import yaml
from kernbench.runtime_api.context import RuntimeContext
from kernbench.runtime_api.types import DeviceSelector
from kernbench.sim_engine.engine import GraphEngine
from kernbench.topology.builder import resolve_topology
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
CONFIGS = [
pytest.param("intercube_allreduce", "ring_1d", 2, id="ring_2sip"),
pytest.param("intercube_allreduce", "torus_2d", 4, id="torus_4sip"),
pytest.param("intercube_allreduce", "mesh_2d_no_wrap", 4, id="mesh_4sip"),
]
def _write_temp_configs(tmp_path, sip_topology, n_sips, algorithm):
"""Write temp topology.yaml and ccl.yaml with the given overrides."""
with open(TOPOLOGY_PATH) as f:
topo_cfg = yaml.safe_load(f)
topo_cfg["system"]["sips"]["count"] = n_sips
topo_cfg["system"]["sips"]["topology"] = sip_topology
topo_path = tmp_path / "topology.yaml"
with open(topo_path, "w") as f:
yaml.dump(topo_cfg, f, default_flow_style=False)
ccl_path = Path(__file__).parent.parent / "ccl.yaml"
with open(ccl_path) as f:
ccl_cfg = yaml.safe_load(f)
ccl_cfg["defaults"]["algorithm"] = algorithm
tmp_ccl = tmp_path / "ccl.yaml"
with open(tmp_ccl, "w") as f:
yaml.dump(ccl_cfg, f, default_flow_style=False)
return str(topo_path), str(tmp_ccl)
@pytest.mark.parametrize("algorithm,sip_topology,n_sips", CONFIGS)
def test_allreduce(tmp_path, algorithm, sip_topology, n_sips):
topo_path, ccl_path = _write_temp_configs(
tmp_path, sip_topology, n_sips, algorithm,
)
topo = resolve_topology(topo_path)
engine = GraphEngine(topo.topology_obj, enable_data=True)
spec = topo.topology_obj.spec
with RuntimeContext(
engine=engine,
target_device=DeviceSelector("all"),
correlation_id=f"test_{algorithm}_{sip_topology}",
spec=spec,
) as ctx:
result = run_allreduce(
ctx, engine, spec,
algorithm=algorithm, ccl_yaml=ccl_path,
)
assert result["ok_cubes"] > 0
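Review note: the `expected` value asserted in `run_allreduce` follows from the three stages named in the commit message (cube-mesh reduce within each SIP, inter-SIP exchange, broadcast back). The sketch below models only that arithmetic on nested lists. `reference_allreduce` is a hypothetical helper and says nothing about how the kernel actually moves data.

```python
def reference_allreduce(n_sips, n_cubes, n_elem):
    """Arithmetic-only model of cube reduce, inter-SIP exchange, broadcast."""
    # Input mirrors the test: every element on SIP s holds the value s + 1.
    inputs = [
        [[float(s + 1)] * n_elem for _ in range(n_cubes)]
        for s in range(n_sips)
    ]
    # Stage 1: sum the cube rows within each SIP.
    per_sip = [[sum(col) for col in zip(*cubes)] for cubes in inputs]
    # Stage 2: sum the per-SIP results across SIPs.
    total = [sum(col) for col in zip(*per_sip)]
    # Stage 3: broadcast the total back to every cube of every SIP.
    return [[list(total) for _ in range(n_cubes)] for _ in range(n_sips)]


out = reference_allreduce(n_sips=4, n_cubes=16, n_elem=8)
expected = float(16 * sum(range(1, 5)))
```

With 4 SIPs and a 4x4 cube mesh, every element ends at `16 * (1 + 2 + 3 + 4)`, matching the formula `n_cubes * sum(range(1, n_sips + 1))` used in the test.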
-150
@@ -1,150 +0,0 @@
"""End-to-end matrix tests for the unified ``ccl_allreduce`` bench.
Each parametrized case writes a tmp ``ccl.yaml`` overlay that selects a
specific (algorithm, world_size, buffer_kind, n_elem) combination, then
runs the bench via the CLI and asserts the printed line reports all
ranks OK.
This single test file replaces the per-variant bench tests
(test_ccl_allreduce_e2e, test_ccl_mesh_allreduce, test_ccl_tree_allreduce,
test_ccl_multicube, test_ccl_multisip).
"""
from __future__ import annotations
import os
import textwrap
import pytest
import kernbench.cli.main as cli_main
CCL_YAML_TEMPLATE = textwrap.dedent("""\
defaults:
algorithm: {algorithm}
buffer_kind: {buffer_kind}
backpressure: sleep
n_slots: 4
slot_size: 4096
vc_chunk_size: 256
ipcq_credit_size_bytes: 16
algorithms:
{algorithm}:
module: {module}
topology: {topology}
buffer_kind: {buffer_kind}
{world_size_line}{n_elem_line}
""")
def _write_ccl_yaml(
tmp_path,
*,
algorithm: str,
module: str,
topology: str,
buffer_kind: str = "tcm",
world_size: int | None = None,
n_elem: int | None = None,
) -> str:
"""Write a tmp ccl.yaml in tmp_path and return its directory."""
ws_line = f" world_size: {world_size}\n" if world_size is not None else ""
nel_line = f" n_elem: {n_elem}\n" if n_elem is not None else ""
body = CCL_YAML_TEMPLATE.format(
algorithm=algorithm,
module=module,
topology=topology,
buffer_kind=buffer_kind,
world_size_line=ws_line,
n_elem_line=nel_line,
)
yaml_path = tmp_path / "ccl.yaml"
yaml_path.write_text(body)
return str(tmp_path)
CASES = [
# algorithm, module, topology, buffer_kind, world_size, n_elem, expected_ws
#
# Full-system (256-rank, cross-SIP) — run only ONCE (tcm). Buffer
# variant differences are purely IPCQ slot placement; the compute path
# is identical. Cross-SIP routing is the real thing being verified here.
pytest.param(
"ring_allreduce_tcm", "kernbench.ccl.algorithms.ring_allreduce",
"ring_1d", "tcm", None, 8, 256,
id="ring_full_system",
marks=pytest.mark.slow,
),
# Buffer variants at 8-rank (fast — same kernel, different slot space).
pytest.param(
"ring_allreduce_tcm", "kernbench.ccl.algorithms.ring_allreduce",
"ring_1d", "tcm", 8, 32, 8,
id="ring_tcm_8",
),
pytest.param(
"ring_allreduce_hbm", "kernbench.ccl.algorithms.ring_allreduce",
"ring_1d", "hbm", 8, 32, 8,
id="ring_hbm_8",
),
pytest.param(
"ring_allreduce_sram", "kernbench.ccl.algorithms.ring_allreduce",
"ring_1d", "sram", 8, 32, 8,
id="ring_sram_8",
),
# Multi-cube (16-rank, cross-cube within 1 SIP).
pytest.param(
"ring_allreduce_16", "kernbench.ccl.algorithms.ring_allreduce",
"ring_1d", "tcm", 16, 16, 16,
id="ring_multi_cube",
),
# Mesh + tree algorithms.
pytest.param(
"mesh_allreduce_4", "kernbench.ccl.algorithms.mesh_allreduce",
"mesh_2d", "tcm", 4, 16, 4,
id="mesh_2x2",
),
pytest.param(
"tree_allreduce_7", "kernbench.ccl.algorithms.tree_allreduce",
"tree_binary", "tcm", 7, 16, 7,
id="tree_binary_7",
),
]
@pytest.mark.parametrize(
"algorithm,module,topology,buffer_kind,world_size,n_elem,expected_ws",
CASES,
)
def test_ccl_allreduce_matrix(
tmp_path, capsys, monkeypatch,
algorithm, module, topology, buffer_kind, world_size, n_elem, expected_ws,
):
"""Each (algorithm × buffer × world_size) combo passes through the
unified bench and yields all ranks OK."""
project_root = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..")
)
yaml_dir = _write_ccl_yaml(
tmp_path,
algorithm=algorithm,
module=module,
topology=topology,
buffer_kind=buffer_kind,
world_size=world_size,
n_elem=n_elem,
)
monkeypatch.chdir(yaml_dir)
rc = cli_main.main([
"run",
"--topology", os.path.join(project_root, "topology.yaml"),
"--bench", "ccl_allreduce",
"--verify-data",
])
assert rc == 0
out = capsys.readouterr().out
assert "FAIL" not in out, f"unexpected FAIL in output:\n{out}"
assert f"{algorithm} (ws={expected_ws}): {expected_ws} OK" in out, (
f"expected '{algorithm} (ws={expected_ws}): {expected_ws} OK' "
f"in output:\n{out}"
)
-125
@@ -1,125 +0,0 @@
"""Tests for IPCQ deadlock detection (ADR-0023 D14 F3)."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
import pytest
import simpy
from kernbench.ccl import diagnostics
from kernbench.common.ipcq_types import (
IpcqEndpoint,
IpcqInitEntry,
IpcqRecvCmd,
IpcqRequest,
)
from kernbench.components.builtin.pe_ipcq import PeIpcqComponent
from kernbench.runtime_api.kernel import IpcqInitMsg
from kernbench.topology.types import Node
@dataclass
class _FakeTxn:
request: Any
done: simpy.Event
result_data: dict[str, Any] = field(default_factory=dict)
def _make_isolated_pe_ipcq(env):
node = Node(
id="sip0.cube0.pe0.pe_ipcq", kind="pe_ipcq",
impl="builtin.pe_ipcq", attrs={}, pos_mm=None,
)
comp = PeIpcqComponent(node, ctx=None)
comp.in_ports["host"] = simpy.Store(env)
comp.out_ports["sip0.cube0.pe0.pe_dma"] = simpy.Store(env)
comp.start(env)
peer_credit = simpy.Store(env)
ep = IpcqEndpoint(
sip=0, cube=0, pe=1, buffer_kind="tcm",
rx_base_pa=0x10_000, rx_base_va=0,
n_slots=4, slot_size=4096,
)
init_msg = IpcqInitMsg(
correlation_id="t", request_id="t",
target_sips=(0,), target_cubes=(0,), target_pe=0,
entries=(IpcqInitEntry(
direction="W", peer=ep,
my_rx_base_pa=0x40_000, my_rx_base_va=0,
n_slots=4, slot_size=4096,
peer_credit_store=peer_credit,
),),
backpressure_mode="sleep",
buffer_kind="tcm",
credit_size_bytes=16,
)
done = env.event()
comp.in_ports["host"].put(_FakeTxn(request=init_msg, done=done))
env.run(until=done)
return comp
def test_pointer_dump_includes_blocked_state():
"""A blocked recv should still be visible in the pointer dump."""
env = simpy.Environment()
comp = _make_isolated_pe_ipcq(env)
# Issue a recv that will block (no data has arrived)
recv_cmd = IpcqRecvCmd(direction="W", shape=(8,), dtype="f16", handle_id="r1")
req = IpcqRequest(command=recv_cmd, done=env.event())
comp.in_ports["host"].put(req)
env.run(until=10)
assert not req.done.triggered
# Pointer dump should show my_tail=0 and peer_head_cache=0
# We need to use the engine API but for an isolated component, just call directly
class FakeEngine:
_components = {"sip0.cube0.pe0.pe_ipcq": comp}
dump = diagnostics.pointer_dump(FakeEngine())
assert "my_tail=0" in dump
assert "peer_head_cache=0" in dump
def test_deadlock_detection_recv_without_send():
"""A recv with no matching sender → SimPy schedule empties → engine
raises ``IpcqDeadlock`` with a pointer dump.
"""
from kernbench.ccl.diagnostics import IpcqDeadlock
from kernbench.policy.placement.dp import DPPolicy
from kernbench.runtime_api.bench_runner import run_bench
from kernbench.runtime_api.types import resolve_device
from kernbench.sim_engine.engine import GraphEngine
from kernbench.topology.builder import resolve_topology
def deadlock_kernel(t_ptr, n_elem, tl):
# Every PE just receives, no sends → no one delivers → deadlock
tl.recv(dir="W", shape=(n_elem,), dtype="f16")
topo = resolve_topology("topology.yaml")
def run(torch):
torch.install_ipcq(
algorithm="ring_allreduce_tcm", world_size_override=8,
)
a = torch.zeros(
(1, 8 * 8),
dtype="f16",
dp=DPPolicy(
sip="replicate", cube="replicate", pe="column_wise",
num_sips=1, num_cubes=1,
),
name="dl_in",
)
torch.launch("dl", deadlock_kernel, a, 8)
with pytest.raises(IpcqDeadlock):
run_bench(
topology=topo, bench_fn=run,
device=resolve_device("all"),
engine_factory=lambda t, d: GraphEngine(
getattr(t, "topology_obj", t), enable_data=True
),
)
-70
@@ -1,70 +0,0 @@
"""Tests for CCL diagnostics: trace + pointer dump (ADR-0023 D14)."""
from __future__ import annotations
import os
from kernbench.ccl import diagnostics
# ── trace toggle ─────────────────────────────────────────────────────
def test_trace_disabled_by_default(monkeypatch):
monkeypatch.delenv("KERNBENCH_CCL_TRACE", raising=False)
diagnostics.reload_trace_setting()
assert diagnostics.trace_enabled() is False
def test_trace_enabled_via_env(monkeypatch):
monkeypatch.setenv("KERNBENCH_CCL_TRACE", "1")
diagnostics.reload_trace_setting()
assert diagnostics.trace_enabled() is True
def test_trace_record_send(monkeypatch, capsys):
monkeypatch.setenv("KERNBENCH_CCL_TRACE", "1")
diagnostics.reload_trace_setting()
diagnostics.log_send(t_ns=100.0, sender="sip0.cube0.pe0",
direction="E", nbytes=64, sender_seq=0)
out = capsys.readouterr().out
assert "send" in out
assert "sip0.cube0.pe0" in out
assert "dir=E" in out
monkeypatch.delenv("KERNBENCH_CCL_TRACE")
diagnostics.reload_trace_setting()
def test_trace_record_recv(monkeypatch, capsys):
monkeypatch.setenv("KERNBENCH_CCL_TRACE", "1")
diagnostics.reload_trace_setting()
diagnostics.log_recv(t_ns=200.0, receiver="sip0.cube0.pe1",
direction="W", nbytes=64)
out = capsys.readouterr().out
assert "recv" in out
assert "sip0.cube0.pe1" in out
monkeypatch.delenv("KERNBENCH_CCL_TRACE")
diagnostics.reload_trace_setting()
# ── pointer dump ────────────────────────────────────────────────────
def test_pointer_dump_format():
from kernbench.sim_engine.engine import GraphEngine
from kernbench.topology.builder import resolve_topology
from kernbench.ccl.install import (
install_ipcq, load_ccl_config, resolve_algorithm_config,
)
topo = resolve_topology("topology.yaml").topology_obj
engine = GraphEngine(topo, enable_data=True)
cfg = resolve_algorithm_config(load_ccl_config(), name="ring_allreduce_tcm")
install_ipcq(engine, topo.spec, cfg)
dump = diagnostics.pointer_dump(engine)
# 8 ranks × 2 directions = 16 lines (plus 8 PE headers)
assert "sip0.cube0.pe0" in dump
assert "E:" in dump
assert "W:" in dump
assert "my_head=" in dump
assert "peer_tail_cache=" in dump
-81
@@ -1,81 +0,0 @@
"""Validate the hello-world example from docs/ccl-author-guide.md.
This is the simplest possible CCL kernel — each PE sends its tile E
and receives a tile from W. After running, each rank's slice should
contain the data of the previous rank.
"""
from __future__ import annotations
import numpy as np
from kernbench.ccl.algorithms import hello_send
from kernbench.ccl.testing import run_kernel_in_mock
def test_hello_send_4_ranks_mock():
n_elem = 8
inputs = [np.full((n_elem,), float(r + 1), dtype=np.float16) for r in range(4)]
outputs = run_kernel_in_mock(
kernel_fn=hello_send.kernel,
world_size=4,
topology="ring_1d",
inputs=inputs,
kernel_args=(n_elem,),
)
# rank r should have rank (r-1) % 4's data
for r in range(4):
prev = inputs[(r - 1) % 4]
assert np.array_equal(outputs[r], prev), f"rank {r}: got {outputs[r]}"
def test_hello_send_via_simpy_runner():
"""Same but through real SimPy + IPCQ."""
from kernbench.policy.placement.dp import DPPolicy
from kernbench.runtime_api.bench_runner import run_bench
from kernbench.runtime_api.types import resolve_device
from kernbench.sim_engine.engine import GraphEngine
from kernbench.topology.builder import resolve_topology
topo = resolve_topology("topology.yaml")
n_elem = 8
world_size = 8
def run(torch):
# World size for this hello test is 8 (one cube). ccl.yaml no
# longer carries a default world_size — pass it explicitly.
plan = torch.install_ipcq(
algorithm="ring_allreduce_tcm", world_size_override=world_size,
)
a = torch.zeros(
(1, world_size * n_elem), dtype="f16",
dp=DPPolicy(
sip="replicate", cube="replicate", pe="column_wise",
num_sips=1, num_cubes=1,
),
name="hello_in",
)
store = torch.engine.memory_store
base = a._handle.va_base or a._handle.shards[0].pa
nbytes = n_elem * 2
for r in range(world_size):
store.write("hbm", base + r * nbytes,
np.full((n_elem,), float(r + 1), dtype=np.float16))
torch.launch("hello_send", hello_send.kernel, a, n_elem)
# Each rank should hold the previous rank's data after the round
for r in range(world_size):
arr = store.read("hbm", base + r * nbytes, shape=(n_elem,), dtype="f16")
prev_value = float(((r - 1) % world_size) + 1)
assert np.allclose(arr, prev_value), f"rank {r}: got {arr}, expected {prev_value}"
result = run_bench(
topology=topo, bench_fn=run,
device=resolve_device("all"),
engine_factory=lambda t, d: GraphEngine(
getattr(t, "topology_obj", t), enable_data=True
),
)
assert result.completion.ok
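Review note: the data movement this removed test checked is a single eastward ring shift. A minimal model, assuming a hypothetical `ring_shift_east` helper and plain lists in place of tiles:

```python
def ring_shift_east(inputs):
    """Each rank sends its tile east, so rank r ends with rank (r - 1) % n's tile."""
    n = len(inputs)
    return [list(inputs[(r - 1) % n]) for r in range(n)]


inputs = [[float(r + 1)] * 8 for r in range(4)]
outputs = ring_shift_east(inputs)
```

Rank 0 wraps around and receives the last rank's tile, which is why the test computes the expected value as `((r - 1) % world_size) + 1`.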
-117
@@ -2,7 +2,6 @@
 from __future__ import annotations
 
 from kernbench.ccl.install import (
-    install_ipcq,
     linear_rank_to_pe,
     load_ccl_config,
     resolve_algorithm_config,
@@ -26,28 +25,14 @@ def test_resolve_algorithm_config_default():
     cfg = load_ccl_config()
     merged = resolve_algorithm_config(cfg)
     assert merged["algorithm"] == cfg["defaults"]["algorithm"]
-    # ccl.yaml no longer carries defaults.world_size — backend derives
-    # it from topology.yaml at install time. Just check the field is
-    # absent here (verified per-test where install_ipcq is called).
     assert "world_size" not in merged or merged["world_size"] >= 1
 
 
-def test_resolve_algorithm_config_override():
-    cfg = load_ccl_config()
-    merged = resolve_algorithm_config(cfg, name="ring_allreduce_hbm")
-    assert merged["algorithm"] == "ring_allreduce_hbm"
-    assert merged["buffer_kind"] == "hbm"  # algo override
-    # defaults still apply
-    assert merged["n_slots"] == cfg["defaults"]["n_slots"]
-
-
 def test_linear_rank_to_pe():
     engine, topo = _engine()
     spec = topo.spec
-    # Cube 0 of SIP 0
     assert linear_rank_to_pe(0, spec) == (0, 0, 0)
     assert linear_rank_to_pe(7, spec) == (0, 0, 7)
-    # Should not exceed total PE count
     pes_per_sip = (
         spec["sip"]["cube_mesh"]["w"] * spec["sip"]["cube_mesh"]["h"]
         * spec["cube"]["pe_layout"]["pe_per_corner"]
@@ -56,105 +41,3 @@ def test_linear_rank_to_pe():
     sips = spec["system"]["sips"]["count"]
     total = sips * pes_per_sip
     assert total >= 8
-
-
-def test_install_ipcq_neighbors_correct():
-    engine, topo = _engine()
-    cfg = load_ccl_config()
-    merged = resolve_algorithm_config(cfg, name="ring_allreduce_tcm")
-    # Force a single-cube 8-rank install for the assertions below.
-    merged["world_size"] = 8
-    plan = install_ipcq(engine, topo.spec, merged)
-
-    assert plan["world_size"] == 8
-    assert plan["buffer_kind"] == "tcm"
-    # Each rank should have E and W entries
-    for r, nbrs in plan["neighbor_table"].items():
-        assert "E" in nbrs
-        assert "W" in nbrs
-
-    # Inspect installed PE_IPCQ for rank 0
-    ipcq = engine._components["sip0.cube0.pe0.pe_ipcq"]
-    qp_e = ipcq.queue_pairs["E"]
-    qp_w = ipcq.queue_pairs["W"]
-    assert qp_e["peer"].pe == 1  # rank 0's E neighbor is rank 1
-    assert qp_w["peer"].pe == 7  # rank 0's W neighbor is rank 7
-    # rx_base addresses should be unique
-    assert qp_e["my_rx_base_pa"] != qp_w["my_rx_base_pa"]
-
-
-def test_install_ipcq_credit_stores_wired():
-    engine, topo = _engine()
-    cfg = load_ccl_config()
-    merged = resolve_algorithm_config(cfg, name="ring_allreduce_tcm")
-    merged["world_size"] = 8
-    install_ipcq(engine, topo.spec, merged)
-
-    # rank 0 (pe0) sending E goes to rank 1 (pe1)
-    # rank 0's peer_credit_store on E direction should equal rank 1's credit_inbox
-    pe0 = engine._components["sip0.cube0.pe0.pe_ipcq"]
-    pe1 = engine._components["sip0.cube0.pe1.pe_ipcq"]
-    qp_e = pe0.queue_pairs["E"]
-    assert qp_e["peer_credit_store"] is pe1.credit_inbox
-
-
-# ── ADR-0025 D1: reverse_direction opposite-preference ───────────────
-
-
-def test_reverse_direction_opposite_preference_2rank_ring():
-    """ADR-0025 D1: In a 2-rank bidirectional ring both E and W point to the
-    same peer; reverse_direction must pick the OPPOSITE direction (W for E,
-    E for W) so rx_base targets the semantically-correct slot.
-
-    Concretely: rank 0 sending via E to rank 1 must target rank 1's W-rx
-    buffer (not rank 1's E-rx), because rank 1's kernel recv(W) reads from
-    its W-rx.
-    """
-    engine, topo = _engine()
-    cfg = load_ccl_config()
-    merged = resolve_algorithm_config(cfg, name="ring_allreduce_tcm")
-    merged["world_size"] = 2
-    install_ipcq(engine, topo.spec, merged)
-
-    ipcq0 = engine._components["sip0.cube0.pe0.pe_ipcq"]
-    ipcq1 = engine._components["sip0.cube0.pe1.pe_ipcq"]
-    rank1_e_rx = ipcq1.queue_pairs["E"]["my_rx_base_pa"]
-    rank1_w_rx = ipcq1.queue_pairs["W"]["my_rx_base_pa"]
-
-    qp0_e = ipcq0.queue_pairs["E"]
-    qp0_w = ipcq0.queue_pairs["W"]
-
-    # rank 0's E entry should target rank 1's W-rx (opposite), NOT rank 1's E-rx.
-    assert qp0_e["peer"].rx_base_pa == rank1_w_rx, (
-        f"expected rank 0's E peer.rx_base_pa == rank 1's W-rx ({rank1_w_rx:#x}), "
-        f"got {qp0_e['peer'].rx_base_pa:#x} (matches E-rx: {rank1_e_rx:#x}) — "
-        f"reverse_direction picked same-label instead of opposite"
-    )
-
-    # rank 0's W entry should target rank 1's E-rx (opposite).
-    assert qp0_w["peer"].rx_base_pa == rank1_e_rx
-
-
-def test_reverse_direction_opposite_preference_4rank_ring_sanity():
-    """ADR-0025 D1 sanity: ws>=3 ring. E and W have distinct peers, so
-    opposite-preference produces same result as old dict-order first-match.
-
-    This test should PASS both under current and post-fix code.
-    """
-    engine, topo = _engine()
-    cfg = load_ccl_config()
-    merged = resolve_algorithm_config(cfg, name="ring_allreduce_tcm")
-    merged["world_size"] = 4
-    install_ipcq(engine, topo.spec, merged)
-
-    ipcq0 = engine._components["sip0.cube0.pe0.pe_ipcq"]
-    ipcq1 = engine._components["sip0.cube0.pe1.pe_ipcq"]
-    ipcq3 = engine._components["sip0.cube0.pe3.pe_ipcq"]
-
-    # rank 0 E → rank 1 → rank 1's W-rx
-    qp0_e = ipcq0.queue_pairs["E"]
-    assert qp0_e["peer"].rx_base_pa == ipcq1.queue_pairs["W"]["my_rx_base_pa"]
-
-    # rank 0 W → rank 3 (last in ring) → rank 3's E-rx
-    qp0_w = ipcq0.queue_pairs["W"]
-    assert qp0_w["peer"].rx_base_pa == ipcq3.queue_pairs["E"]["my_rx_base_pa"]
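Review note: the opposite-preference rule from ADR-0025 D1 that these removed tests covered can be sketched in isolation. The real `reverse_direction` signature is not visible in this diff, so the function below is a hypothetical reconstruction from the docstrings: prefer the opposite label when it points back at the sender, and fall back to the first matching direction otherwise. `OPPOSITE` and `ring_neighbors` are illustrative names.

```python
OPPOSITE = {"E": "W", "W": "E", "N": "S", "S": "N"}


def ring_neighbors(rank, world_size):
    """Neighbor table of one rank in a bidirectional 1D ring."""
    return {"E": (rank + 1) % world_size, "W": (rank - 1) % world_size}


def reverse_direction(peer_neighbors, my_rank, sent_via):
    """Pick which of the peer's directions points back at my_rank."""
    preferred = OPPOSITE[sent_via]
    if peer_neighbors.get(preferred) == my_rank:
        return preferred
    # Fallback (old behaviour per the docstring): first direction that matches.
    for direction, rank in peer_neighbors.items():
        if rank == my_rank:
            return direction
    raise ValueError(f"rank {my_rank} is not a neighbor of this peer")


# 2-rank ring: rank 1 sees rank 0 on both E and W, so first-match would pick "E".
two_rank = reverse_direction(ring_neighbors(1, 2), my_rank=0, sent_via="E")
# 4-rank ring: peers are distinct, so both strategies agree.
four_rank = reverse_direction(ring_neighbors(1, 4), my_rank=0, sent_via="E")
```

In the 2-rank case the preference matters: rank 0 sending via E must land in rank 1's W receive slot, because that is the slot rank 1's `recv(W)` reads from.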
-83
@@ -1,83 +0,0 @@
"""Tests for the mock CCL runtime (ADR-0023 D15)."""
from __future__ import annotations
import numpy as np
from kernbench.ccl.algorithms import ring_allreduce
from kernbench.ccl.testing import run_kernel_in_mock
def test_ring_allreduce_4_ranks():
"""Run the ring all-reduce kernel under the mock runtime, no SimPy."""
n_elem = 8
inputs = [
np.full((n_elem,), float(r + 1), dtype=np.float16)
for r in range(4)
]
expected = sum(inputs) # [10, 10, ..., 10]
outputs = run_kernel_in_mock(
kernel_fn=ring_allreduce.kernel,
world_size=4,
topology="ring_1d",
inputs=inputs,
kernel_args=(n_elem, 4),
)
assert len(outputs) == 4
for r in range(4):
assert np.allclose(outputs[r], expected)
def test_ring_allreduce_8_ranks():
n_elem = 16
inputs = [
np.full((n_elem,), float(r + 1), dtype=np.float16)
for r in range(8)
]
expected = sum(inputs) # [36, 36, ...]
outputs = run_kernel_in_mock(
kernel_fn=ring_allreduce.kernel,
world_size=8,
topology="ring_1d",
inputs=inputs,
kernel_args=(n_elem, 8),
)
for r in range(8):
assert np.allclose(outputs[r], expected)
def test_ring_allreduce_random_data():
n_elem = 32
rng = np.random.default_rng(42)
inputs = [rng.standard_normal(n_elem).astype(np.float16) for _ in range(4)]
expected = sum(inputs)
outputs = run_kernel_in_mock(
kernel_fn=ring_allreduce.kernel,
world_size=4,
topology="ring_1d",
inputs=inputs,
kernel_args=(n_elem, 4),
)
for r in range(4):
assert np.allclose(outputs[r], expected, rtol=1e-2, atol=1e-2)
def test_mock_runtime_invalid_direction_raises():
"""A kernel that uses an unsupported direction should raise."""
import pytest
def bad_kernel(t_ptr, n_elem, tl):
tl.send(dir="N", src_addr=0, nbytes=2, shape=(1,), dtype="f16", space="hbm")
inputs = [np.array([1.0], dtype=np.float16) for _ in range(2)]
with pytest.raises(Exception):
run_kernel_in_mock(
kernel_fn=bad_kernel,
world_size=2,
topology="ring_1d",
inputs=inputs,
kernel_args=(1,),
)
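Editor's note: the expected values in the removed mock-runtime tests (10 for 4 ranks, 36 for 8 ranks) follow from every rank ending up with the sum of all inputs. The sketch below is a plain NumPy model of an accumulate-and-forward ring, assuming each rank sends its current tile east and receives from the west on every step. It is not the removed `ring_allreduce.kernel`, and `simulate_ring_allreduce` is a hypothetical name.

```python
import numpy as np


def simulate_ring_allreduce(inputs: list[np.ndarray]) -> list[np.ndarray]:
    """Model a naive ring all-reduce: world_size - 1 forward-and-accumulate steps."""
    world_size = len(inputs)
    # Accumulate in float32 to keep the reference free of f16 rounding.
    acc = [x.astype(np.float32) for x in inputs]
    current = [x.astype(np.float32) for x in inputs]
    for _ in range(world_size - 1):
        # Every rank sends `current` east, so rank r receives rank (r - 1)'s tile.
        received = [current[(r - 1) % world_size] for r in range(world_size)]
        acc = [a + rx for a, rx in zip(acc, received)]
        # Forward what was just received on the next step.
        current = received
    return acc


if __name__ == "__main__":
    tiles = [np.full((8,), float(r + 1), dtype=np.float16) for r in range(4)]
    for rank, out in enumerate(simulate_ring_allreduce(tiles)):
        print(rank, out[:4])
```

After `world_size - 1` steps each rank has seen every other rank's original tile exactly once, so all outputs equal the element-wise sum of the inputs.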
-87
@@ -1,87 +0,0 @@
"""CCL performance validation tests (ADR-0023 D13 T5).
Sanity-checks the simulated latency of the unified ``ccl_allreduce`` bench.
Uses 8-rank (single cube) for all buffer variants — the latency model
is topology-aware, so buffer_kind differences are visible even at small
scale. Full-system (256-rank) cross-SIP latency is covered by the
``test_ccl_allreduce_matrix[ring_full_system]`` slow test.
"""
from __future__ import annotations
import importlib
import os
import pytest
from kernbench.runtime_api.bench_runner import run_bench
from kernbench.runtime_api.types import resolve_device
from kernbench.sim_engine.engine import GraphEngine
from kernbench.topology.builder import resolve_topology
def _engine_factory(topology, device):
return GraphEngine(getattr(topology, "topology_obj", topology), enable_data=True)
def _run_8rank(algorithm: str, buffer_kind: str = "tcm") -> float:
"""Run an 8-rank ring via the unified bench with a tmp ccl.yaml overlay.
Returns simulated kernel total_ns."""
import tempfile
body = f"""\
defaults:
algorithm: {algorithm}
buffer_kind: {buffer_kind}
backpressure: sleep
n_slots: 4
slot_size: 4096
vc_chunk_size: 256
ipcq_credit_size_bytes: 16
algorithms:
{algorithm}:
module: kernbench.ccl.algorithms.ring_allreduce
topology: ring_1d
buffer_kind: {buffer_kind}
world_size: 8
n_elem: 32
"""
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
with tempfile.TemporaryDirectory() as tmp:
with open(os.path.join(tmp, "ccl.yaml"), "w") as f:
f.write(body)
old_cwd = os.getcwd()
os.chdir(tmp)
try:
topo = resolve_topology(os.path.join(project_root, "topology.yaml"))
bench_mod = importlib.import_module("benches.ccl_allreduce")
result = run_bench(
topology=topo, bench_fn=bench_mod.run,
device=resolve_device("all"),
engine_factory=_engine_factory,
)
finally:
os.chdir(old_cwd)
assert result.completion.ok, f"{algorithm} did not complete"
last_kernel = None
for tr in (result.traces or []):
if tr.get("phase") == "kernel":
last_kernel = tr
assert last_kernel is not None, f"{algorithm} produced no kernel trace"
return float(last_kernel.get("total_ns", 0.0))
@pytest.mark.parametrize("buffer_kind", ["tcm", "hbm", "sram"])
def test_ccl_latency_positive(buffer_kind):
"""Every buffer kind must produce a positive simulated latency."""
algo = f"ring_allreduce_{buffer_kind}"
ns = _run_8rank(algo, buffer_kind)
assert ns > 0
def test_ccl_latency_under_reasonable_bound():
"""8-rank ring all-reduce (tile=32 f16) should finish well under 1ms."""
ns = _run_8rank("ring_allreduce_tcm", "tcm")
assert ns < 1_000_000 # < 1 ms simulated
@@ -0,0 +1,119 @@
"""End-to-end distributed test for intercube allreduce.
Exercises the full process-group path:
dist.init_process_group(backend="ahbm")
→ mp.spawn(nprocs=n_sips)
→ each worker: set_device → allocate → fill → dist.all_reduce → verify
This is the same flow a real DDP training script would use.
"""
from __future__ import annotations
import os
import textwrap
from pathlib import Path
import numpy as np
import pytest
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
N_CUBES = 16
N_ELEM = 8
def _write_ccl_yaml(tmp_path) -> str:
body = textwrap.dedent("""\
defaults:
algorithm: intercube_allreduce
buffer_kind: tcm
backpressure: sleep
n_slots: 4
slot_size: 4096
vc_chunk_size: 256
ipcq_credit_size_bytes: 16
algorithms:
intercube_allreduce:
module: kernbench.ccl.algorithms.intercube_allreduce
topology: none
buffer_kind: tcm
n_elem: 8
root_cube: 15
""")
(tmp_path / "ccl.yaml").write_text(body)
return str(tmp_path)
def _worker(rank: int, n_sips: int, torch) -> None:
"""Per-SIP worker: allocate, fill, all_reduce, verify."""
from kernbench.policy.placement.dp import DPPolicy
torch.ahbm.set_device(rank)
dp = DPPolicy(
cube="row_wise", pe="replicate",
num_pes=1, num_cubes=N_CUBES,
)
tensor = torch.zeros(
(N_CUBES, N_ELEM), dtype="f16", dp=dp,
name=f"sip{rank}",
)
init_arr = np.full((N_CUBES, N_ELEM), float(rank + 1), dtype=np.float16)
tensor.copy_(torch.from_numpy(init_arr))
print(f"[SIP {rank}] input cube0[:4] = {tensor.numpy()[0][:4].tolist()}")
torch.distributed.all_reduce(tensor, op="sum")
arr = tensor.numpy()
expected = float(N_CUBES * sum(range(1, n_sips + 1)))
print(f"[SIP {rank}] output cube0[:4] = {arr[0][:4].tolist()}")
print(f"[SIP {rank}] output cube15[:4] = {arr[15][:4].tolist()}")
for cube_id in range(N_CUBES):
assert np.allclose(arr[cube_id], expected, rtol=1e-1, atol=1e-1), (
f"SIP{rank} cube {cube_id}: "
f"got {arr[cube_id][:4]}, expected {expected}"
)
if rank == 0:
print(f"\n intercube_allreduce (ws={n_sips}): "
f"{n_sips * N_CUBES} OK")
def test_distributed_intercube_allreduce(tmp_path, monkeypatch):
"""Full distributed path: init_process_group → mp.spawn → all_reduce."""
from kernbench.runtime_api.context import RuntimeContext
from kernbench.runtime_api.types import DeviceSelector
from kernbench.sim_engine.engine import GraphEngine
from kernbench.topology.builder import resolve_topology
monkeypatch.chdir(_write_ccl_yaml(tmp_path))
topo = resolve_topology(str(TOPOLOGY_PATH))
engine = GraphEngine(topo.topology_obj, enable_data=True)
spec = topo.topology_obj.spec
n_sips = int(spec["system"]["sips"]["count"])
with RuntimeContext(
engine=engine,
target_device=DeviceSelector("all"),
correlation_id="dist_intercube_ar",
spec=spec,
) as ctx:
ctx.distributed.init_process_group(backend="ahbm")
assert ctx.distributed.get_world_size() == n_sips
t_start = engine._env.now
ctx.multiprocessing.spawn(
_worker, args=(n_sips, ctx), nprocs=n_sips,
)
t_end = engine._env.now
print(f"\n[distributed] sim latency = "
f"{t_end - t_start:.1f} ns ({(t_end - t_start) / 1000:.3f} us)")
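Editor's note: the worker's `expected = N_CUBES * sum(range(1, n_sips + 1))` follows from the three stages named in the commit message (reduce across the cube mesh within each SIP, exchange between SIPs, broadcast back). The NumPy sketch below models only the arithmetic under that assumption; `reference_intercube_allreduce` is a hypothetical name and does not reproduce the kernel's communication pattern.

```python
import numpy as np

N_CUBES = 16
N_ELEM = 8


def reference_intercube_allreduce(per_sip: list[np.ndarray]) -> list[np.ndarray]:
    """Arithmetic reference: per_sip[i] has shape (N_CUBES, N_ELEM)."""
    # Stage 1: reduce the cube rows inside each SIP to one (N_ELEM,) partial.
    sip_partials = [x.astype(np.float32).sum(axis=0) for x in per_sip]
    # Stage 2: reduce the per-SIP partials across SIPs.
    total = np.sum(sip_partials, axis=0)
    # Stage 3: broadcast the total back to every cube of every SIP.
    return [np.tile(total, (N_CUBES, 1)) for _ in per_sip]


if __name__ == "__main__":
    n_sips = 2
    inputs = [
        np.full((N_CUBES, N_ELEM), float(rank + 1), dtype=np.float16)
        for rank in range(n_sips)
    ]
    outputs = reference_intercube_allreduce(inputs)
    print(outputs[0][0][:4], N_CUBES * sum(range(1, n_sips + 1)))
```

Because SIP `rank` fills every cube row with `rank + 1`, each SIP contributes `N_CUBES * (rank + 1)` per element, which sums to the formula used in `_worker`.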
+113
@@ -0,0 +1,113 @@
"""Tests for configure_sfr_intercube_multisip neighbor table wiring.
Verifies that IPCQ neighbor tables are correctly installed for
intercube (pe0, 4×4 mesh N/S/E/W) + inter-SIP (pe0, all cubes,
global_E/global_W) communication.
"""
from __future__ import annotations
from pathlib import Path
from kernbench.ccl.install import load_ccl_config, resolve_algorithm_config
from kernbench.ccl.sfr_config import configure_sfr_intercube_multisip
from kernbench.sim_engine.engine import GraphEngine
from kernbench.topology.builder import resolve_topology
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
N_CUBES = 16
def _engine_and_spec():
topo = resolve_topology(str(TOPOLOGY_PATH))
engine = GraphEngine(topo.topology_obj, enable_data=True)
return engine, topo.topology_obj.spec
def _merged_cfg():
cfg = load_ccl_config()
return resolve_algorithm_config(cfg, name="intercube_allreduce")
class TestConfigureSfrNeighborTables:
def test_world_size_and_rank_to_pe(self):
engine, spec = _engine_and_spec()
cfg = _merged_cfg()
plan = configure_sfr_intercube_multisip(engine, spec, cfg)
n_sips = int(spec["system"]["sips"]["count"])
assert plan["world_size"] == n_sips * N_CUBES
assert len(plan["rank_to_pe"]) == n_sips * N_CUBES
for pe_idx, (sip, cube, pe) in enumerate(plan["rank_to_pe"]):
assert pe == 0, f"pe_idx {pe_idx}: pe must be 0, got {pe}"
def test_corner_cube0_has_E_and_S_only(self):
"""Cube 0 (row=0, col=0) is NW corner: only E and S neighbors."""
engine, spec = _engine_and_spec()
cfg = _merged_cfg()
configure_sfr_intercube_multisip(engine, spec, cfg)
ipcq = engine._components["sip0.cube0.pe0.pe_ipcq"]
qp = ipcq.queue_pairs
assert "E" in qp, "cube 0 must have E neighbor"
assert "S" in qp, "cube 0 must have S neighbor"
assert "W" not in qp, "cube 0 (col=0) must NOT have W neighbor"
assert "N" not in qp, "cube 0 (row=0) must NOT have N neighbor"
assert qp["E"]["peer"].cube == 1
assert qp["S"]["peer"].cube == 4
def test_interior_cube5_has_all_four(self):
"""Cube 5 (row=1, col=1) is interior: N/S/E/W all present."""
engine, spec = _engine_and_spec()
cfg = _merged_cfg()
configure_sfr_intercube_multisip(engine, spec, cfg)
ipcq = engine._components["sip0.cube5.pe0.pe_ipcq"]
qp = ipcq.queue_pairs
assert qp["N"]["peer"].cube == 1
assert qp["S"]["peer"].cube == 9
assert qp["E"]["peer"].cube == 6
assert qp["W"]["peer"].cube == 4
def test_root_cube15_has_inter_sip(self):
"""Cube 15 (root, SE corner) has N, W + global_E/global_W."""
engine, spec = _engine_and_spec()
cfg = _merged_cfg()
configure_sfr_intercube_multisip(engine, spec, cfg)
ipcq0 = engine._components["sip0.cube15.pe0.pe_ipcq"]
qp0 = ipcq0.queue_pairs
assert "N" in qp0
assert "W" in qp0
assert "E" not in qp0, "cube 15 (col=3) must NOT have E"
assert "S" not in qp0, "cube 15 (row=3) must NOT have S"
assert "global_E" in qp0, "root cube must have global_E"
assert "global_W" in qp0, "root cube must have global_W"
assert qp0["global_E"]["peer"].sip == 1
assert qp0["global_E"]["peer"].cube == 15
ipcq1 = engine._components["sip1.cube15.pe0.pe_ipcq"]
qp1 = ipcq1.queue_pairs
assert qp1["global_E"]["peer"].sip == 0
assert qp1["global_E"]["peer"].cube == 15
def test_all_cubes_have_inter_sip(self):
"""ALL cubes (not just root) are wired for inter-SIP."""
engine, spec = _engine_and_spec()
cfg = _merged_cfg()
configure_sfr_intercube_multisip(engine, spec, cfg)
root_cube = int(cfg.get("root_cube", N_CUBES - 1))
for cube_id in range(N_CUBES):
ipcq = engine._components[f"sip0.cube{cube_id}.pe0.pe_ipcq"]
qp = ipcq.queue_pairs
assert "global_E" in qp, (
f"sip0.cube{cube_id}.pe0 missing global_E"
)
assert "global_W" in qp, (
f"sip0.cube{cube_id}.pe0 missing global_W"
)
if cube_id == root_cube:
assert qp["global_E"]["peer"].sip != 0, (
f"root cube {root_cube} global_E must point to another SIP"
)
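Editor's note: the neighbour expectations in these tests (cube 0 has only E and S, cube 5 has all four, cube 15 has only N and W) are consistent with a row-major 4x4 mesh without wrap-around. The sketch below assumes that numbering; `mesh_neighbors` is a hypothetical helper, not part of `configure_sfr_intercube_multisip`.

```python
def mesh_neighbors(cube_id: int, rows: int = 4, cols: int = 4) -> dict[str, int]:
    """Return N/S/E/W neighbour cube ids on a row-major mesh with no wrap-around."""
    row, col = divmod(cube_id, cols)
    neighbors: dict[str, int] = {}
    if row > 0:
        neighbors["N"] = cube_id - cols  # one row up
    if row < rows - 1:
        neighbors["S"] = cube_id + cols  # one row down
    if col < cols - 1:
        neighbors["E"] = cube_id + 1     # one column right
    if col > 0:
        neighbors["W"] = cube_id - 1     # one column left
    return neighbors


if __name__ == "__main__":
    for cube in (0, 5, 15):
        print(cube, mesh_neighbors(cube))
```

Edge and corner cubes simply omit the missing directions, which matches the `"W" not in qp` style of assertion used above. The `global_E`/`global_W` inter-SIP links are a separate layer and are not modelled here.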
-80
@@ -1,80 +0,0 @@
"""Tests for recv_mode='copy_to_dst' (ADR-0023 D9.5)."""
from __future__ import annotations
import numpy as np
def test_recv_copy_to_dst_via_simpy_runner():
"""Run a kernel that uses tl.recv(..., dst_addr=..., dst_space=...).
Verify the data is moved to the dst location after recv.
"""
import importlib
from kernbench.policy.placement.dp import DPPolicy
from kernbench.runtime_api.bench_runner import run_bench
from kernbench.runtime_api.types import resolve_device
from kernbench.sim_engine.engine import GraphEngine
from kernbench.topology.builder import resolve_topology
from kernbench.common.pe_commands import TensorHandle
def kernel(t_ptr, n_elem, dst_buf_addr, tl):
rank = tl.program_id(axis=0)
ws = tl.num_programs(axis=0)
nbytes = n_elem * 2
# Each PE sends own data, then recv into a custom dst slot
current = TensorHandle(
id="loc", addr=t_ptr + rank * nbytes,
shape=(n_elem,), dtype="f16",
nbytes=nbytes, data=None, space="hbm",
)
tl.send(dir="E", src=current)
# copy_to_dst: move into a per-rank scratch HBM addr
recv = tl.recv(
dir="W", shape=(n_elem,), dtype="f16",
dst_addr=dst_buf_addr + rank * nbytes,
dst_space="hbm",
)
# Sanity: recv handle should now point to our dst addr
assert recv.addr == dst_buf_addr + rank * nbytes
assert recv.space == "hbm"
topo = resolve_topology("topology.yaml")
def run(torch):
plan = torch.install_ipcq(
algorithm="ring_allreduce_tcm", world_size_override=8,
)
a = torch.zeros(
(1, 8 * 8),
dtype="f16",
dp=DPPolicy(
sip="replicate", cube="replicate", pe="column_wise",
num_sips=1, num_cubes=1,
),
name="copy_in",
)
store = torch.engine.memory_store
base = a._handle.va_base or a._handle.shards[0].pa
nbytes = 8 * 2
for r in range(8):
store.write("hbm", base + r * nbytes,
np.full((8,), float(r + 1), dtype=np.float16))
# Use a separate dst region (synthetic addresses)
dst_buf = 0xC0FFEE_0000
torch.launch("ring_allreduce_tcm", kernel, a, 8, dst_buf)
# After the kernel, dst_buf + r*16 should contain rank (r-1)%8's data
for r in range(8):
arr = store.read("hbm", dst_buf + r * nbytes, shape=(8,), dtype="f16")
expected = float(((r - 1) % 8) + 1)
assert np.allclose(arr, expected), f"rank {r}: got {arr}, expected {expected}"
result = run_bench(
topology=topo, bench_fn=run,
device=resolve_device("all"),
engine_factory=lambda t, d: GraphEngine(
getattr(t, "topology_obj", t), enable_data=True
),
)
assert result.completion.ok
+16 -16
@@ -48,8 +48,8 @@ def test_from_numpy_creates_host_tensor():
 assert h._handle is None
 # Submit a no-op so run_bench has at least one handle.
 torch.zeros((1, 8), dtype="f16",
-dp=DPPolicy(sip="replicate", cube="replicate", pe="replicate",
-num_sips=1, num_cubes=1, num_pes=1),
+dp=DPPolicy(cube="replicate", pe="replicate",
+num_cubes=1, num_pes=1),
 name="dummy")
 _run_with(body)
@@ -63,8 +63,8 @@ def test_copy_and_numpy_single_pe():
 a single-PE (no real sharding) tensor."""
 def body(torch):
-dp = DPPolicy(sip="replicate", cube="replicate", pe="replicate",
-num_sips=1, num_cubes=1, num_pes=1)
+dp = DPPolicy(cube="replicate", pe="replicate",
+num_cubes=1, num_pes=1)
 t = torch.zeros((1, 16), dtype="f16", dp=dp, name="t")
 src = np.arange(16, dtype=np.float16).reshape(1, 16)
 t.copy_(torch.from_numpy(src))
@@ -83,8 +83,8 @@ def test_copy_and_numpy_multi_pe_column_wise():
 def body(torch):
 n_pe = 8
-dp = DPPolicy(sip="replicate", cube="replicate", pe="column_wise",
-num_sips=1, num_cubes=1, num_pes=n_pe)
+dp = DPPolicy(cube="replicate", pe="column_wise",
+num_cubes=1, num_pes=n_pe)
 t = torch.zeros((1, n_pe * 4), dtype="f16", dp=dp, name="t")
 src = np.arange(n_pe * 4, dtype=np.float16).reshape(1, n_pe * 4)
 t.copy_(torch.from_numpy(src))
@@ -107,8 +107,8 @@ def test_copy_and_numpy_multi_cube():
 n_pe_per_cube = 8
 n_cubes = 2
 total = n_cubes * n_pe_per_cube # 16
-dp = DPPolicy(sip="replicate", cube="column_wise", pe="column_wise",
-num_sips=1, num_cubes=n_cubes)
+dp = DPPolicy(cube="column_wise", pe="column_wise",
+num_cubes=n_cubes)
 t = torch.zeros((1, total * 4), dtype="f16", dp=dp, name="t")
 src = np.arange(total * 4, dtype=np.float16).reshape(1, total * 4)
 t.copy_(torch.from_numpy(src))
@@ -126,8 +126,8 @@ def test_copy_shape_mismatch_raises():
 """copy_ with mismatched shapes raises ValueError."""
 def body(torch):
-dp = DPPolicy(sip="replicate", cube="replicate", pe="replicate",
-num_sips=1, num_cubes=1, num_pes=1)
+dp = DPPolicy(cube="replicate", pe="replicate",
+num_cubes=1, num_pes=1)
 t = torch.zeros((1, 8), dtype="f16", dp=dp, name="t")
 src = np.zeros((1, 16), dtype=np.float16)
 with pytest.raises(ValueError, match="copy_ shape mismatch"):
@@ -143,8 +143,8 @@ def test_setitem_getitem_single_pe():
 """Scalar and slice assignment on a single-PE tensor round-trips."""
 def body(torch):
-dp = DPPolicy(sip="replicate", cube="replicate", pe="replicate",
-num_sips=1, num_cubes=1, num_pes=1)
+dp = DPPolicy(cube="replicate", pe="replicate",
+num_cubes=1, num_pes=1)
 t = torch.zeros((1, 8), dtype="f16", dp=dp, name="t")
 # Scalar broadcast
@@ -169,8 +169,8 @@ def test_setitem_getitem_multi_pe_shard_aligned():
 def body(torch):
 n_pe = 8
 n_elem = 4 # per shard
-dp = DPPolicy(sip="replicate", cube="replicate", pe="column_wise",
-num_sips=1, num_cubes=1, num_pes=n_pe)
+dp = DPPolicy(cube="replicate", pe="column_wise",
+num_cubes=1, num_pes=n_pe)
 t = torch.zeros((1, n_pe * n_elem), dtype="f16", dp=dp, name="t")
 # Write each shard with its rank value
@@ -197,8 +197,8 @@ def test_setitem_cross_shard_raises():
 def body(torch):
 n_pe = 4
 n_elem = 4
-dp = DPPolicy(sip="replicate", cube="replicate", pe="column_wise",
-num_sips=1, num_cubes=1, num_pes=n_pe)
+dp = DPPolicy(cube="replicate", pe="column_wise",
+num_cubes=1, num_pes=n_pe)
 t = torch.zeros((1, n_pe * n_elem), dtype="f16", dp=dp, name="t")
 with pytest.raises(NotImplementedError, match="spans multiple shards"):
 t[0, 2:6] = 1.0 # crosses shard 0 (0:4) and shard 1 (4:8)
+90 -127
@@ -1,157 +1,120 @@
-"""Tests for SIP-level tensor parallelism.
-Validates:
-SP1. DPPolicy accepts sip field (default "replicate", backward compat)
-SP2. sip="column_wise": tensor K-axis split across SIPs, each SIP gets K//num_sips
-SP3. sip="row_wise": tensor M-axis split across SIPs
-SP4. 3-level resolve: sip × cube × pe produces correct flat indices and offsets
-SP5. sip="replicate": all SIPs get full copy (existing behavior)
-SP6. PE_CPU sets num_programs from shard count per cube
-SP7. End-to-end: TP kernel with sip="column_wise" completes on multi-SIP topology
+"""Tests for SIP-level tensor parallelism — ADR-0026 structural model.
+DPPolicy no longer carries a ``sip`` axis (ADR-0026 D1). SIP placement is
+now expressed structurally: each call to ``resolve_dp_policy(target_sip=N)``
+emits shards pinned to SIP N. Multi-SIP parallelism is composed by calling
+the resolver once per SIP (typically driven by the ADR-0024 launcher, one
+worker greenlet per rank, each worker using ``torch.ahbm.set_device(rank)``).
+Covered here:
+SP1. ``target_sip`` stamps every shard.
+SP2. Two-SIP placement: union of two resolver calls covers the whole
+tensor K-axis when the combined bench treats them as column-split.
+SP3. Same for row-wise.
+SP4. Cube + PE sharding within a SIP remains correct across SIPs.
+SP5. PE_CPU num_programs contract (unchanged by ADR-0026).
 """
-import pytest
-from pathlib import Path
-from kernbench.policy.placement.dp import DPPolicy, ShardSpec, resolve_dp_policy
-# ── SP1. DPPolicy sip field ──────────────────────────────────────────
-def test_dp_policy_sip_default_replicate():
-"""DPPolicy without sip= defaults to 'replicate'."""
+from __future__ import annotations
+from kernbench.policy.placement.dp import DPPolicy, resolve_dp_policy
+# ── SP1. target_sip stamps shards ────────────────────────────────────
+def test_target_sip_stamps_all_shards():
 dp = DPPolicy(cube="replicate", pe="column_wise")
-assert dp.sip == "replicate"
-def test_dp_policy_sip_column_wise():
-"""DPPolicy accepts sip='column_wise'."""
-dp = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise")
-assert dp.sip == "column_wise"
-# ── SP2. sip="column_wise" ──────────────────────────────────────────────
-def test_sip_column_wise_splits_across_sips():
-"""sip='column_wise' with 2 SIPs: each SIP gets K//2 columns."""
-dp = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise")
 shards = resolve_dp_policy(
 dp, shape=(128, 256), itemsize=2,
-num_pe=8, num_cubes=1, num_sips=2,
+num_pe=8, num_cubes=1, target_sip=3,
 )
-# 2 SIPs × 1 cube × 8 PEs = 16 shards
-assert len(shards) == 16
-# SIP0 shards: first half of K (0 to K//2)
-# SIP1 shards: second half of K (K//2 to K)
-total_bytes = 128 * 256 * 2 # 64KB
-sip0_shards = [s for s in shards if s.pe_index < 8]
-sip1_shards = [s for s in shards if s.pe_index >= 8]
-# SIP0 offsets start at 0
-assert sip0_shards[0].offset_bytes == 0
-# SIP1 offsets start at half
-assert sip1_shards[0].offset_bytes == total_bytes // 2
-# Total coverage
-assert sum(s.nbytes for s in sip0_shards) == total_bytes // 2
-assert sum(s.nbytes for s in sip1_shards) == total_bytes // 2
-# ── SP3. sip="row_wise" ──────────────────────────────────────────────
-def test_sip_row_wise_splits_across_sips():
-"""sip='row_wise' with 2 SIPs: each SIP gets M//2 rows."""
-dp = DPPolicy(sip="row_wise", cube="replicate", pe="column_wise")
-shards = resolve_dp_policy(
+assert all(s.sip == 3 for s in shards)
+assert all(0 <= s.pe < 8 for s in shards)
+assert all(s.cube == 0 for s in shards)
+# ── SP2. column-wise placement composed across two SIPs ─────────────
+def test_compose_two_sips_column_wise_covers_tensor():
+"""Bench splits K-axis across 2 SIPs by calling resolve twice and
+giving each SIP half of the tensor (half-shape + offset). Shards
+from both SIPs together cover the whole K axis."""
+full_shape = (128, 256)
+itemsize = 2
+# Per-SIP half-shape (K split across SIPs).
+half_shape = (128, 128)
+dp = DPPolicy(cube="replicate", pe="column_wise")
+shards_sip0 = resolve_dp_policy(
+dp, shape=half_shape, itemsize=itemsize,
+num_pe=8, num_cubes=1, target_sip=0,
+)
+shards_sip1 = resolve_dp_policy(
+dp, shape=half_shape, itemsize=itemsize,
+num_pe=8, num_cubes=1, target_sip=1,
+)
+total_bytes = full_shape[0] * full_shape[1] * itemsize
+sip0_bytes = sum(s.nbytes for s in shards_sip0)
+sip1_bytes = sum(s.nbytes for s in shards_sip1)
+assert sip0_bytes + sip1_bytes == total_bytes
+assert all(s.sip == 0 for s in shards_sip0)
+assert all(s.sip == 1 for s in shards_sip1)
+# ── SP3. row-wise placement composed across two SIPs ────────────────
+def test_compose_two_sips_row_wise_covers_tensor():
+full_shape = (128, 256)
+itemsize = 2
+half_shape = (64, 256) # per-SIP half of M
+dp = DPPolicy(cube="replicate", pe="column_wise")
+shards_sip0 = resolve_dp_policy(
+dp, shape=half_shape, itemsize=itemsize,
+num_pe=8, num_cubes=1, target_sip=0,
+)
+shards_sip1 = resolve_dp_policy(
+dp, shape=half_shape, itemsize=itemsize,
+num_pe=8, num_cubes=1, target_sip=1,
+)
+total_bytes = full_shape[0] * full_shape[1] * itemsize
+assert sum(s.nbytes for s in shards_sip0) + sum(s.nbytes for s in shards_sip1) == total_bytes
+# ── SP4. cube × PE sharding is independent per SIP ──────────────────
+def test_cube_pe_sharding_independent_per_sip():
+"""Intra-SIP cube + PE layout matches across SIPs; only sip field differs."""
+dp = DPPolicy(cube="column_wise", pe="column_wise")
+s0 = resolve_dp_policy(
 dp, shape=(128, 256), itemsize=2,
-num_pe=8, num_cubes=1, num_sips=2,
+num_pe=4, num_cubes=2, target_sip=0,
 )
-assert len(shards) == 16
-sip0_shards = [s for s in shards if s.pe_index < 8]
-sip1_shards = [s for s in shards if s.pe_index >= 8]
-# SIP0: rows 0..63, SIP1: rows 64..127
-total_bytes = 128 * 256 * 2
-assert sip0_shards[0].offset_bytes == 0
-assert sip1_shards[0].offset_bytes == total_bytes // 2
-# ── SP4. 3-level resolve ─────────────────────────────────────────────
-def test_3level_resolve_flat_index():
-"""3-level: sip × cube × pe produces correct flat indices."""
-dp = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise")
-shards = resolve_dp_policy(
+s1 = resolve_dp_policy(
 dp, shape=(128, 256), itemsize=2,
-num_pe=8, num_cubes=2, num_sips=2,
+num_pe=4, num_cubes=2, target_sip=1,
 )
-# 2 SIPs × 2 cubes × 8 PEs = 32 shards
-assert len(shards) == 32
-# Flat index: sip_id * cubes_per_sip * num_pe + cube_id * num_pe + pe_id
-indices = [s.pe_index for s in shards]
-# SIP0: 0..15, SIP1: 16..31
-assert min(indices) == 0
-assert max(indices) == 31
-assert len(set(indices)) == 32 # all unique
-def test_3level_offsets_cover_full_tensor():
-"""3-level sharding covers the entire tensor with no gaps."""
-dp = DPPolicy(sip="column_wise", cube="replicate", pe="column_wise")
-shards = resolve_dp_policy(
-dp, shape=(128, 256), itemsize=2,
-num_pe=4, num_cubes=1, num_sips=2,
+assert len(s0) == len(s1) == 2 * 4
+for a, b in zip(s0, s1):
+assert a.sip == 0 and b.sip == 1
+assert (a.cube, a.pe, a.offset_bytes, a.nbytes) == (
+b.cube, b.pe, b.offset_bytes, b.nbytes
 )
-# 2 SIPs × 1 cube × 4 PEs = 8 shards
-# sip="column_wise": K=128 per SIP, pe="column_wise": 32 cols per PE
-total = 128 * 256 * 2
-# For non-replicate, total shard bytes == tensor bytes
-# (replicate within cube means cube shards overlap, but sip shards don't)
-sip0_bytes = sum(s.nbytes for s in shards if s.pe_index < 4)
-sip1_bytes = sum(s.nbytes for s in shards if s.pe_index >= 4)
-assert sip0_bytes + sip1_bytes == total
-# ── SP5. sip="replicate" backward compat ─────────────────────────────
-def test_sip_replicate_backward_compat():
-"""sip='replicate' produces same result as before (2-level)."""
-dp_old = DPPolicy(cube="replicate", pe="column_wise")
-dp_new = DPPolicy(sip="replicate", cube="replicate", pe="column_wise")
-shards_old = resolve_dp_policy(
-dp_old, shape=(128, 256), itemsize=2,
-num_pe=8, num_cubes=2, num_sips=2,
-)
-shards_new = resolve_dp_policy(
-dp_new, shape=(128, 256), itemsize=2,
-num_pe=8, num_cubes=2, num_sips=2,
-)
-assert len(shards_old) == len(shards_new)
-for a, b in zip(shards_old, shards_new):
-assert a.pe_index == b.pe_index
-assert a.offset_bytes == b.offset_bytes
-assert a.nbytes == b.nbytes
-# ── SP6. PE_CPU num_programs ──────────────────────────────────────────
+# ── SP5. PE_CPU num_programs (contract unchanged) ───────────────────
 def test_pe_cpu_sets_num_programs():
-"""PE_CPU should create TLContext with num_programs = PEs per cube."""
-# This test validates the interface contract.
-# After implementation, PE_CPU should derive num_programs from the
-# number of PE shards in the kernel launch's target cube.
+"""TLContext reports num_programs from its initializer — used by PE_CPU
+when it launches a kernel on behalf of its shards."""
 from kernbench.triton_emu.tl_context import TLContext
-# With 8 PEs per cube, num_programs should be 8
 tl = TLContext(pe_id=3, num_programs=8)
 assert tl.program_id(0) == 3
 assert tl.num_programs(0) == 8
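Editor's note: the composed-placement tests in this file rely on byte accounting: two per-SIP resolver calls on a half-shape must together cover the full tensor. The sketch below reproduces that accounting with a stand-in column-wise splitter. `Shard` and `column_wise_shards` are hypothetical; they are not the real `ShardSpec` or `resolve_dp_policy`, and the contiguous `offset_bytes` layout is an assumption.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Shard:
    """Hypothetical stand-in for ShardSpec (field names taken from the diff)."""
    sip: int
    cube: int
    pe: int
    offset_bytes: int
    nbytes: int


def column_wise_shards(shape: tuple[int, int], itemsize: int,
                       num_pe: int, target_sip: int) -> list[Shard]:
    """Split the K (column) axis evenly across PEs of one cube on one SIP."""
    rows, cols = shape
    if cols % num_pe:
        raise ValueError("columns must divide evenly across PEs")
    nbytes = rows * (cols // num_pe) * itemsize
    # Assumption: shard offsets are laid out contiguously, one per PE.
    return [
        Shard(sip=target_sip, cube=0, pe=pe, offset_bytes=pe * nbytes, nbytes=nbytes)
        for pe in range(num_pe)
    ]


if __name__ == "__main__":
    full_shape, itemsize = (128, 256), 2
    half_shape = (128, 128)  # K axis split across two SIPs
    sip0 = column_wise_shards(half_shape, itemsize, num_pe=8, target_sip=0)
    sip1 = column_wise_shards(half_shape, itemsize, num_pe=8, target_sip=1)
    covered = sum(s.nbytes for s in sip0) + sum(s.nbytes for s in sip1)
    print(covered, full_shape[0] * full_shape[1] * itemsize)
```

The design point this illustrates is that the SIP axis is no longer a policy field: the caller decides how to slice the tensor per SIP, and each resolver call only stamps `sip` onto an otherwise identical intra-SIP layout.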
+23 -17
@@ -2,11 +2,13 @@ import pytest
 from kernbench.policy.address.allocator import AddressConfig, AllocationError, PEMemAllocator
 from kernbench.policy.placement.dp import (
+DPPolicy,
 ShardSpec,
 column_wise,
-tiled_column_major,
 replicate,
+resolve_dp_policy,
 row_wise,
+tiled_column_major,
 tiled_row_major,
 )
 from kernbench.runtime_api.kernel import (
@@ -40,9 +42,9 @@ _CFG = AddressConfig(
 )
-def _make_allocators(num_pe: int = 8) -> dict[int, PEMemAllocator]:
+def _make_allocators(num_pe: int = 8) -> dict[tuple[int, int, int], PEMemAllocator]:
 return {
-i: PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=i, cfg=_CFG)
+(0, 0, i): PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=i, cfg=_CFG)
 for i in range(num_pe)
 }
@@ -133,7 +135,7 @@ def test_column_wise_placement():
 assert len(shards) == 8
 expected_nbytes = 1024 * 64 * 2 # 128 KB
 for i, s in enumerate(shards):
-assert s.pe_index == i
+assert s.local_pe == i
 assert s.nbytes == expected_nbytes
 # offsets are contiguous
 assert shards[0].offset_bytes == 0
@@ -151,7 +153,7 @@ def test_row_wise_placement():
 assert len(shards) == 8
 expected_nbytes = 128 * 512 * 2 # 128 KB
 for i, s in enumerate(shards):
-assert s.pe_index == i
+assert s.local_pe == i
 assert s.nbytes == expected_nbytes
 assert shards[0].offset_bytes == 0
 assert sum(s.nbytes for s in shards) == 1024 * 512 * 2
@@ -166,7 +168,7 @@ def test_replicate_placement():
 assert len(shards) == 8
 full_nbytes = 1024 * 512 * 2 # 1 MB
 for i, s in enumerate(shards):
-assert s.pe_index == i
+assert s.local_pe == i
 assert s.nbytes == full_nbytes
 assert s.offset_bytes == 0 # each is a full copy
@@ -188,10 +190,10 @@ def test_tiled_column_major():
 # tile (m=0,k=0) → PE0, tile (m=0,k=1) → PE1, ..., (m=0,k=3) → PE3
 # tile (m=1,k=0) → PE4, tile (m=1,k=1) → PE5, ..., (m=1,k=3) → PE7
 # tile (m=2,k=0) → PE0, ...
-assert shards[0].pe_index == 0
-assert shards[1].pe_index == 1
-assert shards[7].pe_index == 7
-assert shards[8].pe_index == 0 # wraps around
+assert shards[0].local_pe == 0
+assert shards[1].local_pe == 1
+assert shards[7].local_pe == 7
+assert shards[8].local_pe == 0 # wraps around
 # total coverage
 assert sum(s.nbytes for s in shards) == 1024 * 512 * 2
@@ -212,10 +214,10 @@ def test_tiled_row_major():
 # tile (m=0,k=0) → PE0, tile (m=1,k=0) → PE1, ..., (m=3,k=0) → PE3
 # tile (m=0,k=1) → PE4, tile (m=1,k=1) → PE5, ..., (m=3,k=1) → PE7
 # tile (m=0,k=2) → PE0, ...
-assert shards[0].pe_index == 0
-assert shards[1].pe_index == 1
-assert shards[7].pe_index == 7
-assert shards[8].pe_index == 0 # wraps around
+assert shards[0].local_pe == 0
+assert shards[1].local_pe == 1
+assert shards[7].local_pe == 7
+assert shards[8].local_pe == 0 # wraps around
 # total coverage
 assert sum(s.nbytes for s in shards) == 1024 * 512 * 2
@@ -226,7 +228,11 @@ def test_tiled_row_major():
 def test_deploy_tensor_hbm():
 """Deploy with column_wise placement → TensorHandle with valid PA shards."""
 allocs = _make_allocators()
-placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8)
+placement = resolve_dp_policy(
+DPPolicy(cube="replicate", pe="column_wise"),
+shape=(1024, 512), itemsize=2,
+num_pe=8, num_cubes=1, target_sip=0,
+)
 th = deploy_tensor(
 name="W",
 shape=(1024, 512),
@@ -253,7 +259,7 @@ def test_deploy_tensor_hbm():
 def test_deploy_tensor_tcm():
 """Deploy with TCM → uses pe_tcm_addr allocation."""
 allocs = _make_allocators()
-placement = [ShardSpec(pe_index=0, offset_bytes=0, nbytes=256)]
+placement = [ShardSpec(sip=0, cube=0, pe=0, offset_bytes=0, nbytes=256)]
 th = deploy_tensor(
 name="small",
 shape=(128,),
@@ -271,7 +277,7 @@ def test_deploy_tensor_overflow():
 """Allocation exceeding PE HBM capacity raises AllocationError."""
 allocs = _make_allocators()
 # 6 GB per PE slice, try to allocate 7 GB
-big_shard = ShardSpec(pe_index=0, offset_bytes=0, nbytes=7 * _GB)
+big_shard = ShardSpec(sip=0, cube=0, pe=0, offset_bytes=0, nbytes=7 * _GB)
 with pytest.raises(AllocationError):
 deploy_tensor(
 name="toobig",
-106
@@ -1,106 +0,0 @@
"""Tests for tl.recv_async + tl.wait (ADR-0023 D4)."""
from __future__ import annotations
import numpy as np
from kernbench.ccl.testing import run_kernel_in_mock
def kernel_async_recv(t_ptr, n_elem, tl):
"""Each PE issues recv_async first, then send, then wait — this exercises
the non-blocking path. Uses TensorHandle math (PE_MATH) for accumulation
so Phase 2 produces correct final HBM contents."""
rank = tl.program_id(axis=0)
world_size = tl.num_programs(axis=0)
nbytes = n_elem * 2
pe_addr = t_ptr + rank * nbytes
acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16")
current = acc
for _step in range(world_size - 1):
future = tl.recv_async(dir="W", shape=(n_elem,), dtype="f16")
tl.send(dir="E", src=current)
recv = tl.wait(future)
acc = acc + recv
current = recv # forward W's tile to E next round
tl.store(pe_addr, acc)
def test_recv_async_mock_runtime():
n_elem = 8
inputs = [
np.full((n_elem,), float(r + 1), dtype=np.float16)
for r in range(4)
]
expected = sum(inputs)
outputs = run_kernel_in_mock(
kernel_fn=kernel_async_recv,
world_size=4,
topology="ring_1d",
inputs=inputs,
kernel_args=(n_elem,),
)
for r in range(4):
assert np.allclose(outputs[r], expected)
def test_recv_async_simpy_runner():
"""Run the async kernel through the real SimPy stack via the
install_ipcq + launch path.
"""
import importlib
from kernbench.runtime_api.bench_runner import run_bench
from kernbench.runtime_api.types import resolve_device
from kernbench.sim_engine.engine import GraphEngine
from kernbench.topology.builder import resolve_topology
# Re-use the standard 8-PE bench skeleton but swap in the async kernel.
topo = resolve_topology("topology.yaml")
# Build a tiny inline bench module
import types
mod = types.ModuleType("inline_bench_async")
from kernbench.policy.placement.dp import DPPolicy
def run(torch):
plan = torch.install_ipcq(
algorithm="ring_allreduce_tcm", world_size_override=8,
)
a = torch.zeros(
(1, 8 * 8),
dtype="f16",
dp=DPPolicy(
sip="replicate", cube="replicate", pe="column_wise",
num_sips=1, num_cubes=1,
),
name="async_in",
)
store = torch.engine.memory_store
base = a._handle.va_base or a._handle.shards[0].pa
nbytes = 8 * 2
for r in range(8):
store.write("hbm", base + r * nbytes,
np.full((8,), float(r + 1), dtype=np.float16))
torch.launch("ring_allreduce_tcm", kernel_async_recv, a, 8)
for r in range(8):
result = store.read("hbm", base + r * nbytes, shape=(8,), dtype="f16")
expected = float(sum(range(1, 9))) # 36
assert np.allclose(result, expected, rtol=1e-2, atol=1e-2), \
f"rank {r}: got {result}, expected {expected}"
mod.run = run
result = run_bench(
topology=topo, bench_fn=mod.run,
device=resolve_device("all"),
engine_factory=lambda t, d: GraphEngine(
getattr(t, "topology_obj", t), enable_data=True
),
)
assert result.completion.ok
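The removed `kernel_async_recv` kernel forwards each received tile one hop eastward per step, so after `world_size - 1` steps every rank has accumulated every input. A host-side sketch of that data flow, assuming a `ring_1d` layout where the west neighbour of rank `r` is `(r - 1) % world_size` (this neighbour rule and the `simulate_ring_allreduce` helper are illustrative assumptions, not kernbench APIs):

```python
import numpy as np


def simulate_ring_allreduce(inputs):
    """Emulate the recv-from-W / send-to-E accumulation loop on the host."""
    ws = len(inputs)
    acc = [x.astype(np.float32).copy() for x in inputs]
    current = [x.astype(np.float32).copy() for x in inputs]
    for _step in range(ws - 1):
        # Each rank receives what its west neighbour sent this step.
        received = [current[(r - 1) % ws] for r in range(ws)]
        for r in range(ws):
            acc[r] = acc[r] + received[r]
        # The received tile is forwarded eastward in the next step.
        current = received
    return acc


inputs = [np.full((8,), float(r + 1), dtype=np.float16) for r in range(4)]
outputs = simulate_ring_allreduce(inputs)
expected = sum(x.astype(np.float32) for x in inputs)
```

With the same inputs as `test_recv_async_mock_runtime`, every entry of `outputs` should equal `expected`, which is the property the mock-runtime test asserts.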
+234
@@ -0,0 +1,234 @@
"""ADR-0027 T2: TP layer shape + numerical correctness (D4/D5).
Phase 1: ``kernbench.tp.layers`` doesn't exist → import failure. Phase 2
lands D4/D5 and T2 passes with deterministic non-zero weight patterns.
"""
from __future__ import annotations
import numpy as np
import pytest
def _make_ctx(topology):
from kernbench.runtime_api.context import RuntimeContext
from kernbench.runtime_api.types import DeviceSelector
from kernbench.sim_engine.engine import GraphEngine
engine = GraphEngine(topology.topology_obj, enable_data=True)
return RuntimeContext(
engine=engine,
target_device=DeviceSelector("all"),
correlation_id="test_t2",
spec=topology.topology_obj.spec,
)
# ── Shape / structural ───────────────────────────────────────────────
def test_column_parallel_weight_shape_per_rank(topology):
"""ColumnParallelLinear weight per rank is (in_features, out // ws)."""
import kernbench.tp as tp
from kernbench.runtime_api.tensor import Tensor
with _make_ctx(topology) as ctx:
ctx.distributed.init_process_group(backend="ahbm")
ws = ctx.distributed.get_world_size()
tp.initialize_model_parallel(ws)
def _worker(rank: int):
ctx.ahbm.set_device(rank)
fc = tp.ColumnParallelLinear(
in_features=256, out_features=512, torch=ctx,
)
assert fc.weight.shape == (256, 512 // ws)
ctx.multiprocessing.spawn(_worker, args=(), nprocs=ws)
def test_row_parallel_weight_shape_per_rank(topology):
"""RowParallelLinear weight per rank is (in_features // ws, out_features)."""
import kernbench.tp as tp
with _make_ctx(topology) as ctx:
ctx.distributed.init_process_group(backend="ahbm")
ws = ctx.distributed.get_world_size()
tp.initialize_model_parallel(ws)
def _worker(rank: int):
ctx.ahbm.set_device(rank)
fc = tp.RowParallelLinear(
in_features=512, out_features=256, torch=ctx,
)
assert fc.weight.shape == (512 // ws, 256)
ctx.multiprocessing.spawn(_worker, args=(), nprocs=ws)
# ── T2.a: ColumnParallel deterministic numerical ─────────────────────
def test_column_parallel_forward_matches_matmul(topology):
"""T2.a: ColumnParallelLinear.forward output == x @ W_rank (rtol 1e-2)."""
import kernbench.tp as tp
from kernbench.runtime_api.tensor import Tensor
with _make_ctx(topology) as ctx:
ctx.distributed.init_process_group(backend="ahbm")
ws = ctx.distributed.get_world_size()
tp.initialize_model_parallel(ws)
M = 4
D_in, D_out = 32, 32 * ws
def _worker(rank: int):
ctx.ahbm.set_device(rank)
fc = tp.ColumnParallelLinear(
in_features=D_in, out_features=D_out, torch=ctx,
)
# Deterministic non-zero weight: rank-scaled constant.
k_local = D_out // ws
weight_np = np.full(
(D_in, k_local), 0.01 * (rank + 1), dtype=np.float16,
)
src = Tensor(shape=(D_in, k_local), dtype="f16", name="host_w")
src._host_buffer = weight_np
fc.weight.copy_(src)
# Input: full-replicated constant.
x_np = np.full((M, D_in), 0.5, dtype=np.float16)
x = ctx.zeros(
(M, D_in), dtype="f16",
dp=_replicate_dp(), name=f"t2a_x_r{rank}",
)
hx = Tensor(shape=x_np.shape, dtype="f16", name="host_x")
hx._host_buffer = x_np
x.copy_(hx)
y = fc.forward(x)
out = y.numpy()
expected = x_np.astype(np.float32) @ weight_np.astype(np.float32)
assert out.shape == (M, k_local)
assert np.allclose(out.astype(np.float32), expected,
rtol=1e-2, atol=1e-2), (
f"rank {rank}: output does not match x @ W_local"
)
ctx.multiprocessing.spawn(_worker, args=(), nprocs=ws)
# ── T2.b: RowParallel observable equality ────────────────────────────
def test_row_parallel_forward_concat_matmul_equality(topology):
"""T2.b (primary): RowParallel output == concat(x) @ concat(W) (all-reduced)."""
import kernbench.tp as tp
from kernbench.runtime_api.tensor import Tensor
with _make_ctx(topology) as ctx:
ctx.distributed.init_process_group(backend="ahbm")
ws = ctx.distributed.get_world_size()
tp.initialize_model_parallel(ws)
M = 4
D_in, D_out = 32 * ws, 32 # D_in must be divisible by ws
results: dict[int, np.ndarray] = {}
def _worker(rank: int):
ctx.ahbm.set_device(rank)
fc = tp.RowParallelLinear(
in_features=D_in, out_features=D_out, torch=ctx,
)
# Per-rank W_k = constant 0.01 * (rank + 1)
n_local = D_in // ws
weight_np = np.full(
(n_local, D_out), 0.01 * (rank + 1), dtype=np.float16,
)
src = Tensor(shape=weight_np.shape, dtype="f16", name="host_w")
src._host_buffer = weight_np
fc.weight.copy_(src)
# Input x_k = constant 0.1 * (rank + 1) (pretending it was
# column-sharded from upstream).
x_np = np.full((M, n_local), 0.1 * (rank + 1), dtype=np.float16)
x = ctx.zeros(
(M, n_local), dtype="f16",
dp=_replicate_dp(), name=f"t2b_x_r{rank}",
)
hx = Tensor(shape=x_np.shape, dtype="f16", name="host_x")
hx._host_buffer = x_np
x.copy_(hx)
y = fc.forward(x)
results[rank] = y.numpy().astype(np.float32)
ctx.multiprocessing.spawn(_worker, args=(), nprocs=ws)
# Host-side reference: compute sum_r (x_r @ W_r) = y (same on all ranks).
expected = np.zeros((M, D_out), dtype=np.float32)
n_local = D_in // ws
for r in range(ws):
x_r = np.full((M, n_local), 0.1 * (r + 1), dtype=np.float32)
w_r = np.full((n_local, D_out), 0.01 * (r + 1), dtype=np.float32)
expected += x_r @ w_r
for r, out in results.items():
assert np.allclose(out, expected, rtol=1e-2, atol=1e-2), (
f"rank {r}: all-reduced output != sum of per-rank partial products"
)
# ── T2.c: rank-consistency post all-reduce ───────────────────────────
def test_row_parallel_rank_identity_post_all_reduce(topology):
"""T2.c: after all_reduce, all ranks see elementwise-identical output."""
import kernbench.tp as tp
from kernbench.runtime_api.tensor import Tensor
with _make_ctx(topology) as ctx:
ctx.distributed.init_process_group(backend="ahbm")
ws = ctx.distributed.get_world_size()
tp.initialize_model_parallel(ws)
M = 2
D_in, D_out = 16 * ws, 16
results: dict[int, np.ndarray] = {}
def _worker(rank: int):
ctx.ahbm.set_device(rank)
fc = tp.RowParallelLinear(
in_features=D_in, out_features=D_out, torch=ctx,
)
n_local = D_in // ws
weight_np = np.full((n_local, D_out), 0.01, dtype=np.float16)
src = Tensor(shape=weight_np.shape, dtype="f16", name="host_w")
src._host_buffer = weight_np
fc.weight.copy_(src)
x_np = np.full((M, n_local), 0.1, dtype=np.float16)
x = ctx.zeros(
(M, n_local), dtype="f16",
dp=_replicate_dp(), name=f"t2c_x_r{rank}",
)
hx = Tensor(shape=x_np.shape, dtype="f16", name="host_x")
hx._host_buffer = x_np
x.copy_(hx)
y = fc.forward(x)
results[rank] = y.numpy()
ctx.multiprocessing.spawn(_worker, args=(), nprocs=ws)
ref = results[0]
for r, out in results.items():
assert np.allclose(out, ref, rtol=1e-2, atol=1e-2), (
f"rank {r} output differs from rank 0 — all_reduce failed to make "
f"outputs elementwise identical"
)
def _replicate_dp():
from kernbench.policy.placement.dp import DPPolicy
return DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=1)
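The T2.b reference relies on the identity that summing per-rank partial products equals one matmul over the concatenated shards. A small NumPy check of that identity, using the same constants as the test and assuming `ws = 4` purely for illustration (the real world size comes from the topology); `row_parallel_reference` is a hypothetical helper:

```python
import numpy as np


def row_parallel_reference(x_shards, w_shards):
    """Sum of per-rank partial products, i.e. what all_reduce(sum) yields."""
    return sum(x @ w for x, w in zip(x_shards, w_shards))


ws, M, n_local, D_out = 4, 4, 32, 32  # ws is an assumed value
x_shards = [np.full((M, n_local), 0.1 * (r + 1), dtype=np.float32) for r in range(ws)]
w_shards = [np.full((n_local, D_out), 0.01 * (r + 1), dtype=np.float32) for r in range(ws)]

partial_sum = row_parallel_reference(x_shards, w_shards)
# Concatenate x along columns and W along rows, then do one full matmul.
full = np.concatenate(x_shards, axis=1) @ np.concatenate(w_shards, axis=0)
```

`partial_sum` and `full` agree up to float32 rounding, which is why the test can build `expected` with a plain loop over ranks.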
+238
@@ -0,0 +1,238 @@
"""ADR-0027 T6: End-to-end 2-layer MLP with TP.
Phase 1: fails at imports. Phase 2 lands the TP package + D7 bench pattern
and these pass with numerical-correctness checks.
"""
from __future__ import annotations
import numpy as np
import pytest
def _make_ctx(topology):
from kernbench.runtime_api.context import RuntimeContext
from kernbench.runtime_api.types import DeviceSelector
from kernbench.sim_engine.engine import GraphEngine
engine = GraphEngine(topology.topology_obj, enable_data=True)
return RuntimeContext(
engine=engine,
target_device=DeviceSelector("all"),
correlation_id="test_t6",
spec=topology.topology_obj.spec,
)
def _replicate_dp():
from kernbench.policy.placement.dp import DPPolicy
return DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=1)
# ── T6.a: zero-weight smoke ──────────────────────────────────────────
def test_mlp_zero_weight_produces_zero_output(topology):
"""T6.a: zero-init weight → output ≈ 0 for every rank."""
import kernbench.tp as tp
with _make_ctx(topology) as ctx:
ctx.distributed.init_process_group(backend="ahbm")
ws = ctx.distributed.get_world_size()
tp.initialize_model_parallel(ws)
B, D_in, D_hidden, D_out = 1, 32, 32 * ws, 32
outputs: dict[int, np.ndarray] = {}
def _worker(rank: int):
ctx.ahbm.set_device(rank)
fc1 = tp.ColumnParallelLinear(D_in, D_hidden, torch=ctx)
fc2 = tp.RowParallelLinear(D_hidden, D_out, torch=ctx)
x = ctx.zeros((B, D_in), dtype="f16",
dp=_replicate_dp(), name=f"t6a_x_r{rank}")
from kernbench.runtime_api.tensor import Tensor
hx = Tensor(shape=(B, D_in), dtype="f16", name="host_x")
hx._host_buffer = np.full((B, D_in), 0.1, dtype=np.float16)
x.copy_(hx)
h = fc1.forward(x)
y = fc2.forward(h)
outputs[rank] = y.numpy()
ctx.multiprocessing.spawn(_worker, args=(), nprocs=ws)
for r, out in outputs.items():
assert np.allclose(out, 0.0, atol=1e-2), (
f"rank {r}: zero-weight output should be ~0; got mean={out.mean()}"
)
# ── T6.b: deterministic weight + numerical check ─────────────────────
def test_mlp_deterministic_weight_matches_reference(topology):
"""T6.b: non-zero deterministic weights → output matches numpy reference."""
import kernbench.tp as tp
from kernbench.runtime_api.tensor import Tensor
with _make_ctx(topology) as ctx:
ctx.distributed.init_process_group(backend="ahbm")
ws = ctx.distributed.get_world_size()
tp.initialize_model_parallel(ws)
B, D_in, D_hidden, D_out = 1, 16, 16 * ws, 16
# W1 (D_in, D_hidden) — column-sharded; per rank: (D_in, D_hidden/ws)
# W2 (D_hidden, D_out) — row-sharded; per rank: (D_hidden/ws, D_out)
# Constant values: W1 = 0.02, W2 = 0.03, x = 0.1 (all fp16).
X_VAL, W1_VAL, W2_VAL = 0.1, 0.02, 0.03
outputs: dict[int, np.ndarray] = {}
def _worker(rank: int):
ctx.ahbm.set_device(rank)
fc1 = tp.ColumnParallelLinear(D_in, D_hidden, torch=ctx)
fc2 = tp.RowParallelLinear(D_hidden, D_out, torch=ctx)
# W1 slice (per rank column slice)
k_local_1 = D_hidden // ws
w1_np = np.full((D_in, k_local_1), W1_VAL, dtype=np.float16)
src1 = Tensor(shape=w1_np.shape, dtype="f16", name="host_w1")
src1._host_buffer = w1_np
fc1.weight.copy_(src1)
# W2 slice (per rank row slice)
n_local_2 = D_hidden // ws
w2_np = np.full((n_local_2, D_out), W2_VAL, dtype=np.float16)
src2 = Tensor(shape=w2_np.shape, dtype="f16", name="host_w2")
src2._host_buffer = w2_np
fc2.weight.copy_(src2)
# Input x (full-replicated constant)
x = ctx.zeros((B, D_in), dtype="f16",
dp=_replicate_dp(), name=f"t6b_x_r{rank}")
hx = Tensor(shape=(B, D_in), dtype="f16", name="host_x")
hx._host_buffer = np.full((B, D_in), X_VAL, dtype=np.float16)
x.copy_(hx)
h = fc1.forward(x)
y = fc2.forward(h)
outputs[rank] = y.numpy().astype(np.float32)
ctx.multiprocessing.spawn(_worker, args=(), nprocs=ws)
# Host reference: y = x @ W1_full @ W2_full
w1_full = np.full((D_in, D_hidden), W1_VAL, dtype=np.float32)
w2_full = np.full((D_hidden, D_out), W2_VAL, dtype=np.float32)
x_full = np.full((B, D_in), X_VAL, dtype=np.float32)
expected = x_full @ w1_full @ w2_full
for r, out in outputs.items():
assert out.shape == (B, D_out)
assert np.allclose(out, expected, rtol=1e-2, atol=1e-2), (
f"rank {r}: MLP output != reference "
f"(got mean={out.mean():.4f}, expected={expected.mean():.4f})"
)
# ── T6.c: rank-consistency after final all_reduce ────────────────────
def test_mlp_rank_consistency_after_all_reduce(topology):
"""T6.c: all ranks see elementwise-identical final output."""
import kernbench.tp as tp
from kernbench.runtime_api.tensor import Tensor
with _make_ctx(topology) as ctx:
ctx.distributed.init_process_group(backend="ahbm")
ws = ctx.distributed.get_world_size()
tp.initialize_model_parallel(ws)
B, D_in, D_hidden, D_out = 1, 16, 16 * ws, 16
outputs: dict[int, np.ndarray] = {}
def _worker(rank: int):
ctx.ahbm.set_device(rank)
fc1 = tp.ColumnParallelLinear(D_in, D_hidden, torch=ctx)
fc2 = tp.RowParallelLinear(D_hidden, D_out, torch=ctx)
# Zero weights OK for this check — just need all_reduce to run.
x = ctx.zeros((B, D_in), dtype="f16",
dp=_replicate_dp(), name=f"t6c_x_r{rank}")
hx = Tensor(shape=(B, D_in), dtype="f16", name="host_x")
hx._host_buffer = np.full((B, D_in), 0.1, dtype=np.float16)
x.copy_(hx)
h = fc1.forward(x)
y = fc2.forward(h)
outputs[rank] = y.numpy()
ctx.multiprocessing.spawn(_worker, args=(), nprocs=ws)
ref = outputs[0]
for r, out in outputs.items():
assert np.array_equal(out, ref), (
f"rank {r} output differs from rank 0 — all-reduce should "
f"make every rank see the same final tensor"
)
# ── T6.d: shape contract ─────────────────────────────────────────────
def test_mlp_shape_contract(topology):
"""T6.d: ColumnParallel → (B, D_hidden/ws); RowParallel → (B, D_out)."""
import kernbench.tp as tp
with _make_ctx(topology) as ctx:
ctx.distributed.init_process_group(backend="ahbm")
ws = ctx.distributed.get_world_size()
tp.initialize_model_parallel(ws)
B, D_in, D_hidden, D_out = 1, 16, 16 * ws, 16
def _worker(rank: int):
ctx.ahbm.set_device(rank)
fc1 = tp.ColumnParallelLinear(D_in, D_hidden, torch=ctx)
fc2 = tp.RowParallelLinear(D_hidden, D_out, torch=ctx)
x = ctx.zeros((B, D_in), dtype="f16",
dp=_replicate_dp(), name=f"t6d_x_r{rank}")
h = fc1.forward(x)
assert h.shape == (B, D_hidden // ws), (
f"ColumnParallel output shape: {h.shape} != (B, D_hidden/ws)"
)
y = fc2.forward(h)
assert y.shape == (B, D_out), (
f"RowParallel output shape: {y.shape} != (B, D_out)"
)
ctx.multiprocessing.spawn(_worker, args=(), nprocs=ws)
# ── liveness: no deadlock (verified indirectly via pytest timeout) ───────────────
def test_mlp_completes_without_deadlock(topology):
"""Structural: full E2E spawn returns within a reasonable wall-clock.
Relies on the test suite's overall timeout harness. If this hangs
beyond ~60s it would surface as a pytest timeout — a deadlock
regression in the scheduler loop would manifest here."""
import kernbench.tp as tp
with _make_ctx(topology) as ctx:
ctx.distributed.init_process_group(backend="ahbm")
ws = ctx.distributed.get_world_size()
tp.initialize_model_parallel(ws)
def _worker(rank: int):
ctx.ahbm.set_device(rank)
fc1 = tp.ColumnParallelLinear(16, 16 * ws, torch=ctx)
fc2 = tp.RowParallelLinear(16 * ws, 16, torch=ctx)
x = ctx.zeros((1, 16), dtype="f16",
dp=_replicate_dp(), name=f"t6live_r{rank}")
h = fc1.forward(x)
y = fc2.forward(h)
_ = y.numpy()
ctx.multiprocessing.spawn(_worker, args=(), nprocs=ws)
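The T6 tests compare the TP pipeline against `x @ W1_full @ W2_full`. A host-side sketch of why a column-parallel layer followed by a row-parallel layer reproduces that product, assuming no activation between the layers (as in the test reference) and `ws = 2` for illustration; `tp_mlp_reference` is a hypothetical helper, not a kernbench function:

```python
import numpy as np


def tp_mlp_reference(x, w1, w2, ws):
    """Emulate ColumnParallel -> RowParallel on the host for ws ranks."""
    k_local = w1.shape[1] // ws
    y = np.zeros((x.shape[0], w2.shape[1]), dtype=np.float32)
    for rank in range(ws):
        cols = slice(rank * k_local, (rank + 1) * k_local)
        h_rank = x @ w1[:, cols]    # column-parallel output: (B, D_hidden / ws)
        y += h_rank @ w2[cols, :]   # row-parallel partial, summed as all_reduce would
    return y


ws = 2  # assumed world size
rng = np.random.default_rng(0)
x = rng.standard_normal((1, 16)).astype(np.float32)
w1 = rng.standard_normal((16, 16 * ws)).astype(np.float32)
w2 = rng.standard_normal((16 * ws, 16)).astype(np.float32)

y_tp = tp_mlp_reference(x, w1, w2, ws)
y_full = x @ w1 @ w2
```

Each rank only ever needs its own column slice of `W1` and the matching row slice of `W2`, so the single all-reduce at the end is the only cross-rank communication in this pattern.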
+85
@@ -0,0 +1,85 @@
"""ADR-0027 T1: TP parallel_state (D3).
Phase 1: ``kernbench.tp`` module does not exist yet — tests fail at import.
Phase 2 (D2/D3) lands the package and these pass.
"""
from __future__ import annotations
import pytest
def _make_ctx(topology):
from kernbench.runtime_api.context import RuntimeContext
from kernbench.runtime_api.types import DeviceSelector
from kernbench.sim_engine.engine import GraphEngine
engine = GraphEngine(topology.topology_obj, enable_data=True)
return RuntimeContext(
engine=engine,
target_device=DeviceSelector("all"),
correlation_id="test_t1",
spec=topology.topology_obj.spec,
)
def test_tp_package_importable():
"""D2: kernbench.tp must be importable."""
import kernbench.tp as tp
assert hasattr(tp, "initialize_model_parallel")
assert hasattr(tp, "get_tensor_model_parallel_world_size")
assert hasattr(tp, "get_tensor_model_parallel_rank")
def test_initialize_model_parallel_matches_world_size(topology, tmp_path, monkeypatch):
"""D3: TP size must equal dist world_size; otherwise NotImplementedError."""
import kernbench.tp as tp
with _make_ctx(topology) as ctx:
ctx.distributed.init_process_group(backend="ahbm")
ws = ctx.distributed.get_world_size()
tp.initialize_model_parallel(ws)
assert tp.get_tensor_model_parallel_world_size() == ws
def test_initialize_mismatched_ws_raises(topology):
"""D3: calling with tp_size != world_size raises NotImplementedError."""
import kernbench.tp as tp
with _make_ctx(topology) as ctx:
ctx.distributed.init_process_group(backend="ahbm")
ws = ctx.distributed.get_world_size()
with pytest.raises(NotImplementedError):
tp.initialize_model_parallel(ws + 1)
def test_get_tp_rank_is_greenlet_local(topology):
"""D3: get_tensor_model_parallel_rank returns greenlet-local rank
(delegates to torch.distributed.get_rank, ADR-0024 D9)."""
import kernbench.tp as tp
with _make_ctx(topology) as ctx:
ctx.distributed.init_process_group(backend="ahbm")
ws = ctx.distributed.get_world_size()
tp.initialize_model_parallel(ws)
observed: list[int] = []
def _worker(rank: int):
observed.append(tp.get_tensor_model_parallel_rank())
ctx.multiprocessing.spawn(_worker, args=(), nprocs=ws)
assert sorted(observed) == list(range(ws))
def test_get_world_size_before_init_raises():
"""D3: uninitialised TP group → accessing world_size fails informatively."""
from kernbench.tp import parallel_state
# Reset internal state if previous tests (or parallel workers) left it set.
parallel_state._reset_for_tests()
with pytest.raises((RuntimeError, AssertionError, TypeError)):
_ = parallel_state.get_tensor_model_parallel_world_size() + 0
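The T1 tests pin down a small contract: initialisation rejects a TP size that differs from the distributed world size, and reading the world size before initialisation fails. A hypothetical stand-in that satisfies that contract is sketched below; the class name, method signatures, and error messages are assumptions for illustration and do not describe the actual `kernbench.tp.parallel_state` module.

```python
class ParallelState:
    """Hypothetical minimal TP state holder (not the kernbench implementation)."""

    def __init__(self):
        self._tp_world_size = None

    def initialize_model_parallel(self, tp_size: int, dist_world_size: int) -> None:
        # Only TP size == world size is supported in this sketch.
        if tp_size != dist_world_size:
            raise NotImplementedError(
                f"tp_size={tp_size} must equal world_size={dist_world_size}"
            )
        self._tp_world_size = tp_size

    def get_tensor_model_parallel_world_size(self) -> int:
        if self._tp_world_size is None:
            raise RuntimeError("TP group is not initialised")
        return self._tp_world_size

    def reset_for_tests(self) -> None:
        self._tp_world_size = None


state = ParallelState()
state.initialize_model_parallel(tp_size=2, dist_world_size=2)
world_size = state.get_tensor_model_parallel_world_size()
```

Raising an explicit `RuntimeError` on uninitialised access is one of the three exception types the last test accepts; the real module may fail differently.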
+19 -7
@@ -12,7 +12,7 @@ import pytest
from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator
from kernbench.policy.address.pe_mmu import PeMMU from kernbench.policy.address.pe_mmu import PeMMU
from kernbench.policy.address.va_allocator import VirtualAllocator from kernbench.policy.address.va_allocator import VirtualAllocator
from kernbench.policy.placement.dp import column_wise, ShardSpec from kernbench.policy.placement.dp import DPPolicy, ShardSpec, resolve_dp_policy
from kernbench.runtime_api.tensor import ( from kernbench.runtime_api.tensor import (
TensorHandle, TensorHandle,
TensorShard, TensorShard,
@@ -37,9 +37,9 @@ _CFG = AddressConfig(
) )
def _make_allocators(num_pe: int = 8) -> dict[int, PEMemAllocator]: def _make_allocators(num_pe: int = 8) -> dict[tuple[int, int, int], PEMemAllocator]:
return { return {
i: PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=i, cfg=_CFG) (0, 0, i): PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=i, cfg=_CFG)
for i in range(num_pe) for i in range(num_pe)
} }
@@ -88,7 +88,11 @@ def test_deploy_tensor_assigns_va_base():
"""deploy_tensor with VA allocator assigns va_base to TensorHandle.""" """deploy_tensor with VA allocator assigns va_base to TensorHandle."""
allocs = _make_allocators() allocs = _make_allocators()
va_alloc = _make_va_allocator() va_alloc = _make_va_allocator()
placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8) placement = resolve_dp_policy(
DPPolicy(cube="replicate", pe="column_wise"),
shape=(1024, 512), itemsize=2,
num_pe=8, num_cubes=1, target_sip=0,
)
th = deploy_tensor( th = deploy_tensor(
name="W", name="W",
@@ -107,7 +111,11 @@ def test_deploy_tensor_va_covers_all_shards():
"""VA allocation covers the entire tensor; each shard is at va_base + offset.""" """VA allocation covers the entire tensor; each shard is at va_base + offset."""
allocs = _make_allocators() allocs = _make_allocators()
va_alloc = _make_va_allocator() va_alloc = _make_va_allocator()
placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8) placement = resolve_dp_policy(
DPPolicy(cube="replicate", pe="column_wise"),
shape=(1024, 512), itemsize=2,
num_pe=8, num_cubes=1, target_sip=0,
)
th = deploy_tensor( th = deploy_tensor(
name="W", name="W",
@@ -128,7 +136,11 @@ def test_deploy_tensor_does_not_install_mmu_mappings():
allocs = _make_allocators() allocs = _make_allocators()
va_alloc = _make_va_allocator() va_alloc = _make_va_allocator()
mmus = _make_mmus() mmus = _make_mmus()
placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8) placement = resolve_dp_policy(
DPPolicy(cube="replicate", pe="column_wise"),
shape=(1024, 512), itemsize=2,
num_pe=8, num_cubes=1, target_sip=0,
)
deploy_tensor( deploy_tensor(
name="W", name="W",
@@ -153,7 +165,7 @@ def test_tensor_va_property():
allocs = _make_allocators(1) allocs = _make_allocators(1)
va_alloc = _make_va_allocator() va_alloc = _make_va_allocator()
placement = [ShardSpec(pe_index=0, offset_bytes=0, nbytes=4096)] placement = [ShardSpec(sip=0, cube=0, pe=0, offset_bytes=0, nbytes=4096)]
t = Tensor(shape=(2048,), dtype="f16", name="test") t = Tensor(shape=(2048,), dtype="f16", name="test")
t._handle = deploy_tensor( t._handle = deploy_tensor(
+15 -5
@@ -20,7 +20,7 @@ from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator
from kernbench.policy.address.pe_mmu import PeMMU from kernbench.policy.address.pe_mmu import PeMMU
from kernbench.policy.address.phyaddr import PhysAddr from kernbench.policy.address.phyaddr import PhysAddr
from kernbench.policy.address.va_allocator import VirtualAllocator from kernbench.policy.address.va_allocator import VirtualAllocator
from kernbench.policy.placement.dp import DPPolicy, column_wise from kernbench.policy.placement.dp import DPPolicy, resolve_dp_policy
from kernbench.runtime_api.tensor import deploy_tensor from kernbench.runtime_api.tensor import deploy_tensor
from kernbench.sim_engine.engine import GraphEngine from kernbench.sim_engine.engine import GraphEngine
from kernbench.runtime_api.context import RuntimeContext from kernbench.runtime_api.context import RuntimeContext
@@ -70,7 +70,7 @@ def _make_standalone(shape, num_pe=NUM_PE):
sram_bytes_per_cube=32 * _MB, sram_bytes_per_cube=32 * _MB,
) )
allocators = { allocators = {
i: PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=i, cfg=cfg) (0, 0, i): PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=i, cfg=cfg)
for i in range(num_pe) for i in range(num_pe)
} }
va_alloc = VirtualAllocator(va_base=0x1_0000_0000, va_size=64 * _GB, page_size=4096) va_alloc = VirtualAllocator(va_base=0x1_0000_0000, va_size=64 * _GB, page_size=4096)
@@ -110,7 +110,11 @@ def test_2d_va_translates_to_local_hbm():
cols_per_pe = K // NUM_PE cols_per_pe = K // NUM_PE
block_bytes = M * cols_per_pe * ELEM_BYTES block_bytes = M * cols_per_pe * ELEM_BYTES
placement = column_wise(shape=(M, K), itemsize=ELEM_BYTES, num_pe=NUM_PE) placement = resolve_dp_policy(
DPPolicy(cube="replicate", pe="column_wise"),
shape=(M, K), itemsize=ELEM_BYTES,
num_pe=NUM_PE, num_cubes=1, target_sip=0,
)
handle = deploy_tensor( handle = deploy_tensor(
name="src", shape=(M, K), dtype="fp16", name="src", shape=(M, K), dtype="fp16",
placement=placement, allocators=allocators, va_allocator=va_alloc, placement=placement, allocators=allocators, va_allocator=va_alloc,
@@ -178,7 +182,11 @@ def test_1d_va_translates_to_local_hbm():
elems_per_pe = N_1D // NUM_PE elems_per_pe = N_1D // NUM_PE
block_bytes = elems_per_pe * ELEM_BYTES block_bytes = elems_per_pe * ELEM_BYTES
placement = column_wise(shape=(1, N_1D), itemsize=ELEM_BYTES, num_pe=NUM_PE) placement = resolve_dp_policy(
DPPolicy(cube="replicate", pe="column_wise"),
shape=(1, N_1D), itemsize=ELEM_BYTES,
num_pe=NUM_PE, num_cubes=1, target_sip=0,
)
handle = deploy_tensor( handle = deploy_tensor(
name="src_1d", shape=(N_1D,), dtype="fp16", name="src_1d", shape=(N_1D,), dtype="fp16",
placement=placement, allocators=allocators, va_allocator=va_alloc, placement=placement, allocators=allocators, va_allocator=va_alloc,
@@ -207,7 +215,9 @@ def test_1d_e2e_completes():
correlation_id="vo6", spec=graph.spec, correlation_id="vo6", spec=graph.spec,
) )
dp = DPPolicy(sip="column_wise", cube="column_wise", pe="column_wise") # ADR-0026: DPPolicy is intra-device only; SIP scoping comes from the
# RuntimeContext's target_device. This 1D e2e runs on a single SIP.
dp = DPPolicy(cube="column_wise", pe="column_wise")
src = ctx.zeros((N_1D,), dtype=DTYPE, dp=dp, name="src_1d") src = ctx.zeros((N_1D,), dtype=DTYPE, dp=dp, name="src_1d")
dst = ctx.empty((N_1D,), dtype=DTYPE, dp=dp, name="dst_1d") dst = ctx.empty((N_1D,), dtype=DTYPE, dp=dp, name="dst_1d")
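The VA tests above compute `block_bytes = M * (K // NUM_PE) * ELEM_BYTES` and expect each shard at `va_base + offset`. A sketch of that arithmetic, assuming shards are laid out back to back so that shard `i` starts at `i * block_bytes` (the layout and the `column_wise_blocks` helper are illustrative assumptions, not the output of `resolve_dp_policy`):

```python
def column_wise_blocks(shape, itemsize, num_pe):
    """Return (pe, offset_bytes, nbytes) per PE for an even column split."""
    rows, cols = shape
    if cols % num_pe != 0:
        raise ValueError("columns must divide evenly across PEs in this sketch")
    block_bytes = rows * (cols // num_pe) * itemsize
    # Assumption: blocks are contiguous in VA space, one per PE.
    return [(pe, pe * block_bytes, block_bytes) for pe in range(num_pe)]


blocks = column_wise_blocks(shape=(1024, 512), itemsize=2, num_pe=8)
total_bytes = sum(nbytes for _pe, _offset, nbytes in blocks)
```

For the `(1024, 512)` fp16 tensor used in the tests, the blocks together cover `1024 * 512 * 2` bytes, the same total-coverage figure asserted in `test_tiled_row_major`.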
+1
@@ -4,6 +4,7 @@ system:
sips: sips:
count: 2 count: 2
topology: ring_1d
components: components:
switch: { kind: switch, impl: builtin.switch, attrs: { overhead_ns: 5.0 } } switch: { kind: switch, impl: builtin.switch, attrs: { overhead_ns: 5.0 } }