Files
kernbench2/tests/test_ccl_allreduce_matrix.py
T
ywkang 4ba0a83e71 Implement ADR-0024 Phase A: SIP-level TP launcher MVP
Scope (Phase A):
- D1: world_size fallback = SIP count (rank = SIP, TP boundary)
- D9: greenlet-local get_rank + _bind_rank (single-driver fallback = 0)
- D10: torch.ahbm.set_device + torch.accelerator.set_device_index alias
- D11: tensor placement scoped to current-device SIP (post-hoc pe_index
  shift — ADR-0026 replaces with structural coords)
- D12/D13: multi-greenlet run() with simple round-robin scheduler;
  hybrid dispatch (ws == SIP count → multi-greenlet, else legacy
  single-worker for ccl.yaml override compat)
- D7 partial: backend.all_reduce submit + yield + wait via launch()'s
  new _defer_wait flag; parent-less greenlets skip yield
- Relaxed shard-count check (len(shards) > 0 instead of == world_size)
- rank_to_pe = SIP-representative [(r, 0, 0)] when ws <= n_sips

Deferred to Phase B:
- Engine-routed install (D2) — keeps sideband
- install_plan.py module (D6) — keeps install.py
- Epoch barrier (D7 full) — simple yield is sufficient for ring ws=2 mock
- Validator registry (D8)
- Cross-SIP multi-greenlet + real kernel integration — matrix
  ring_default_ws hangs in SimPy drain despite ADR-0025 direction fix;
  marked xfail(run=False) pending Phase B diagnosis (suspected per-rank
  kernel_args / program_id mismatch)

Tests:
- test_ccl_ddp_launcher.py (6 new tests) — D1/D9/D10/D11/D12/D13
- test_ccl_allreduce_matrix.py — ring_default_ws xfail'd, override
  cases (ring_tcm_8 / hbm_8 / sram_8 / multi_cube / mesh_2x2 /
  tree_binary_7) all pass via legacy path

514 tests pass, 1 xfail.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-14 09:00:28 -07:00

159 lines
4.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""End-to-end matrix tests for the unified ``ccl_allreduce`` bench.
Each parametrized case writes a tmp ``ccl.yaml`` overlay that selects a
specific (algorithm, world_size, buffer_kind, n_elem) combination, then
runs the bench via the CLI and asserts the printed line reports all
ranks OK.
This single test file replaces the per-variant bench tests
(test_ccl_allreduce_e2e, test_ccl_mesh_allreduce, test_ccl_tree_allreduce,
test_ccl_multicube, test_ccl_multisip).
"""
from __future__ import annotations
import os
import textwrap
import pytest
import kernbench.cli.main as cli_main
CCL_YAML_TEMPLATE = textwrap.dedent("""\
defaults:
algorithm: {algorithm}
buffer_kind: {buffer_kind}
backpressure: sleep
n_slots: 4
slot_size: 4096
vc_chunk_size: 256
ipcq_credit_size_bytes: 16
algorithms:
{algorithm}:
module: {module}
topology: {topology}
buffer_kind: {buffer_kind}
{world_size_line}{n_elem_line}
""")
def _write_ccl_yaml(
tmp_path,
*,
algorithm: str,
module: str,
topology: str,
buffer_kind: str = "tcm",
world_size: int | None = None,
n_elem: int | None = None,
) -> str:
"""Write a tmp ccl.yaml in tmp_path and return its directory."""
ws_line = f" world_size: {world_size}\n" if world_size is not None else ""
nel_line = f" n_elem: {n_elem}\n" if n_elem is not None else ""
body = CCL_YAML_TEMPLATE.format(
algorithm=algorithm,
module=module,
topology=topology,
buffer_kind=buffer_kind,
world_size_line=ws_line,
n_elem_line=nel_line,
)
yaml_path = tmp_path / "ccl.yaml"
yaml_path.write_text(body)
return str(tmp_path)
CASES = [
# algorithm, module, topology, buffer_kind, world_size, n_elem, expected_ws
#
# Default fallback — no world_size override → ADR-0024 D1 derives
# from topology (SIP count = 2). Exercises the new SIP-level TP
# launcher + cross-SIP ring.
# XFAIL: ADR-0024 Phase A delivers launcher infrastructure; Phase B
# will finish cross-SIP ring kernel integration. Today this hangs in
# the SimPy drain despite ADR-0025's direction-addressing fix —
# suspected per-rank-tensor kernel_args / program_id mismatch under
# multi-greenlet dispatch. Separate Phase will diagnose.
pytest.param(
"ring_allreduce_tcm", "kernbench.ccl.algorithms.ring_allreduce",
"ring_1d", "tcm", None, 8, 2,
id="ring_default_ws",
marks=pytest.mark.xfail(
reason="ADR-0024 Phase B: cross-SIP multi-greenlet kernel integration",
run=False, # skip execution to avoid hang; revisit in Phase B
),
),
# Buffer variants at 8-rank (fast — same kernel, different slot space).
pytest.param(
"ring_allreduce_tcm", "kernbench.ccl.algorithms.ring_allreduce",
"ring_1d", "tcm", 8, 32, 8,
id="ring_tcm_8",
),
pytest.param(
"ring_allreduce_hbm", "kernbench.ccl.algorithms.ring_allreduce",
"ring_1d", "hbm", 8, 32, 8,
id="ring_hbm_8",
),
pytest.param(
"ring_allreduce_sram", "kernbench.ccl.algorithms.ring_allreduce",
"ring_1d", "sram", 8, 32, 8,
id="ring_sram_8",
),
# Multi-cube (16-rank, cross-cube within 1 SIP).
pytest.param(
"ring_allreduce_16", "kernbench.ccl.algorithms.ring_allreduce",
"ring_1d", "tcm", 16, 16, 16,
id="ring_multi_cube",
),
# Mesh + tree algorithms.
pytest.param(
"mesh_allreduce_4", "kernbench.ccl.algorithms.mesh_allreduce",
"mesh_2d", "tcm", 4, 16, 4,
id="mesh_2x2",
),
pytest.param(
"tree_allreduce_7", "kernbench.ccl.algorithms.tree_allreduce",
"tree_binary", "tcm", 7, 16, 7,
id="tree_binary_7",
),
]
@pytest.mark.parametrize(
"algorithm,module,topology,buffer_kind,world_size,n_elem,expected_ws",
CASES,
)
def test_ccl_allreduce_matrix(
tmp_path, capsys, monkeypatch,
algorithm, module, topology, buffer_kind, world_size, n_elem, expected_ws,
):
"""Each (algorithm × buffer × world_size) combo passes through the
unified bench and yields all ranks OK."""
project_root = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..")
)
yaml_dir = _write_ccl_yaml(
tmp_path,
algorithm=algorithm,
module=module,
topology=topology,
buffer_kind=buffer_kind,
world_size=world_size,
n_elem=n_elem,
)
monkeypatch.chdir(yaml_dir)
rc = cli_main.main([
"run",
"--topology", os.path.join(project_root, "topology.yaml"),
"--bench", "ccl_allreduce",
"--verify-data",
])
assert rc == 0
out = capsys.readouterr().out
assert "FAIL" not in out, f"unexpected FAIL in output:\n{out}"
assert f"{algorithm} (ws={expected_ws}): {expected_ws} OK" in out, (
f"expected '{algorithm} (ws={expected_ws}): {expected_ws} OK' "
f"in output:\n{out}"
)