Files
kernbench2/tests/test_ccl_allreduce_matrix.py
T
ywkang 787409ced1 ADR-0024 Phase B: update xfail reason with architectural blocker details
Phase B Option A (freeze + defer to ADR-0027): the root cause of
ring_default_ws strict-xfail is that bench workers call torch.zeros /
copy_ which drive env.run in the WORKER-greenlet context. Any pending
KernelLaunchMsg gets stepped inside that worker, spawning kernel_runner
with parent = worker (not main). When the worker yields/finishes, the
kernel greenlet is orphaned and its next switch_to_simpy raises
GreenletExit mid-add — producing rank 0 mean=1 (expected 3).

This is a larger architectural redesign (lazy-deploy tensor API,
coroutine worker, or setup/verify split) and is parked until ADR-0027
(Megatron TP) starts, where the proper solution ships with TP use cases.

No production changes; xfail reason + inline comment only.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-14 12:46:33 -07:00

169 lines
5.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""End-to-end matrix tests for the unified ``ccl_allreduce`` bench.
Each parametrized case writes a tmp ``ccl.yaml`` overlay that selects a
specific (algorithm, world_size, buffer_kind, n_elem) combination, then
runs the bench via the CLI and asserts the printed line reports all
ranks OK.
This single test file replaces the per-variant bench tests
(test_ccl_allreduce_e2e, test_ccl_mesh_allreduce, test_ccl_tree_allreduce,
test_ccl_multicube, test_ccl_multisip).
"""
from __future__ import annotations
import os
import textwrap
import pytest
import kernbench.cli.main as cli_main
CCL_YAML_TEMPLATE = textwrap.dedent("""\
defaults:
algorithm: {algorithm}
buffer_kind: {buffer_kind}
backpressure: sleep
n_slots: 4
slot_size: 4096
vc_chunk_size: 256
ipcq_credit_size_bytes: 16
algorithms:
{algorithm}:
module: {module}
topology: {topology}
buffer_kind: {buffer_kind}
{world_size_line}{n_elem_line}
""")
def _write_ccl_yaml(
tmp_path,
*,
algorithm: str,
module: str,
topology: str,
buffer_kind: str = "tcm",
world_size: int | None = None,
n_elem: int | None = None,
) -> str:
"""Write a tmp ccl.yaml in tmp_path and return its directory."""
ws_line = f" world_size: {world_size}\n" if world_size is not None else ""
nel_line = f" n_elem: {n_elem}\n" if n_elem is not None else ""
body = CCL_YAML_TEMPLATE.format(
algorithm=algorithm,
module=module,
topology=topology,
buffer_kind=buffer_kind,
world_size_line=ws_line,
n_elem_line=nel_line,
)
yaml_path = tmp_path / "ccl.yaml"
yaml_path.write_text(body)
return str(tmp_path)
# Dotted module paths of the allreduce implementations exercised below.
_RING_ALLREDUCE = "kernbench.ccl.algorithms.ring_allreduce"
_MESH_ALLREDUCE = "kernbench.ccl.algorithms.mesh_allreduce"
_TREE_ALLREDUCE = "kernbench.ccl.algorithms.tree_allreduce"

# One entry per bench configuration. Field order matches the parametrize
# signature of ``test_ccl_allreduce_matrix``:
#   algorithm, module, topology, buffer_kind, world_size, n_elem, expected_ws
# A ``world_size`` of None leaves the key out of the overlay.
CASES = [
    # --- Default world size -------------------------------------------------
    # No world_size override -> ADR-0024 D1 derives it from the topology
    # (SIP count = 2). Exercises the new SIP-level TP launcher + cross-SIP
    # ring.
    #
    # Strict XFAIL -- architectural blocker (ADR-0024 Phase B, future
    # redesign): bench workers call torch.zeros / copy_, which internally
    # drive env.run in the WORKER-greenlet context. Any KernelLaunchMsg
    # already pending in the SimPy queue is stepped inside that worker
    # context, which spawns kernel_runner + kernel greenlet with
    # parent = worker (not main). When the worker later yields / finishes,
    # the kernel greenlet is orphaned; its next switch_to_simpy raises
    # GreenletExit mid-add, producing rank 0 mean=1 (expected 3).
    # The fix requires redesigning worker semantics so env.run is only ever
    # driven from main (options: lazy-deploy tensor API, coroutine worker,
    # or setup/verify split). Not a single-PR change -- parked until
    # ADR-0027 (Megatron TP) starts, when a proper architectural solution
    # lands together with the TP use cases.
    pytest.param(
        "ring_allreduce_tcm", _RING_ALLREDUCE, "ring_1d", "tcm", None, 8, 2,
        id="ring_default_ws",
        marks=pytest.mark.xfail(
            reason=(
                "ADR-0024 Phase B: worker-greenlet env.run captures "
                "kernel greenlet as child → orphaned on worker yield. "
                "Needs architectural redesign (see test comment)."
            ),
            strict=True,
        ),
    ),
    # --- Buffer variants at 8 ranks -----------------------------------------
    # Fast: same kernel, only the slot buffer space differs.
    *(
        pytest.param(
            f"ring_allreduce_{kind}", _RING_ALLREDUCE, "ring_1d", kind,
            8, 32, 8,
            id=f"ring_{kind}_8",
        )
        for kind in ("tcm", "hbm", "sram")
    ),
    # --- Multi-cube: 16 ranks, cross-cube within one SIP --------------------
    pytest.param(
        "ring_allreduce_16", _RING_ALLREDUCE, "ring_1d", "tcm", 16, 16, 16,
        id="ring_multi_cube",
    ),
    # --- Mesh and tree algorithms -------------------------------------------
    pytest.param(
        "mesh_allreduce_4", _MESH_ALLREDUCE, "mesh_2d", "tcm", 4, 16, 4,
        id="mesh_2x2",
    ),
    pytest.param(
        "tree_allreduce_7", _TREE_ALLREDUCE, "tree_binary", "tcm", 7, 16, 7,
        id="tree_binary_7",
    ),
]
@pytest.mark.parametrize(
    "algorithm,module,topology,buffer_kind,world_size,n_elem,expected_ws",
    CASES,
)
def test_ccl_allreduce_matrix(
    tmp_path, capsys, monkeypatch,
    algorithm, module, topology, buffer_kind, world_size, n_elem, expected_ws,
):
    """Each (algorithm x buffer x world_size) combo passes through the
    unified bench and yields all ranks OK.

    Writes a per-case ``ccl.yaml`` overlay into ``tmp_path``, makes that
    directory the working directory, runs the ``ccl_allreduce`` bench via
    the CLI with ``--verify-data`` against the project's ``topology.yaml``,
    and checks that the exit code is 0, that no ``FAIL`` appears in stdout,
    and that stdout contains the ``<algorithm> (ws=N): N OK`` summary line.
    """
    # This file lives in <project_root>/tests/, next to topology.yaml's dir.
    project_root = os.path.abspath(
        os.path.join(os.path.dirname(__file__), "..")
    )
    yaml_dir = _write_ccl_yaml(
        tmp_path,
        algorithm=algorithm,
        module=module,
        topology=topology,
        buffer_kind=buffer_kind,
        world_size=world_size,
        n_elem=n_elem,
    )
    # chdir so the bench sees the tmp ccl.yaml overlay -- assumes the CLI
    # resolves ccl.yaml relative to the working directory.
    monkeypatch.chdir(yaml_dir)
    rc = cli_main.main([
        "run",
        "--topology", os.path.join(project_root, "topology.yaml"),
        "--bench", "ccl_allreduce",
        "--verify-data",
    ])
    # Read the captured output before asserting on rc so that a non-zero
    # exit still shows the bench log in the failure message.
    out = capsys.readouterr().out
    assert rc == 0, f"bench exited with rc={rc}; output:\n{out}"
    assert "FAIL" not in out, f"unexpected FAIL in output:\n{out}"
    expected_line = f"{algorithm} (ws={expected_ws}): {expected_ws} OK"
    assert expected_line in out, (
        f"expected '{expected_line}' in output:\n{out}"
    )