Files
kernbench2/tests/sccl/_allreduce_helpers.py
mukesh cc1bbd0ab7 eval: fold GEMM/allreduce harnesses into self-contained milestone benches
Move the GEMM + allreduce sweep/render logic out of scripts/ and tests/
into two self-contained eval benches so a user can regenerate every
result + figure with one command per domain:

  kernbench run --bench milestone-1h-gemm   (MILESTONE_FAST=1 reuses JSON)
  kernbench run --bench milestone-1h-ccl

- benches/milestone_1h_{gemm,ccl}.py: single home for each domain; the
  run(torch) entry drives the sweeps and writes figures into
  benches/1H_milestone_output/{gemm,ccl}/ (gitignored), then submits a
  sentinel tensor to satisfy the run_bench contract.
- tests/gemm + tests/sccl helpers and scripts/gemm_sweep.py become thin
  re-export/wrapper shims over the benches (single source preserved); the
  pytest-only param builders + _run_distributed wrapper stay in the shim.
- eval-bench pattern: a bench may drive many configs + build its own
  per-config engines (extends ADR-0045 D5; reverses ADR-0044 D1/D2).

ADR-0054 (EN+KO) records the design; ADR-0043/0044/0045 + CLAUDE.md CLI
Semantics amended; ADR INDEX regenerated. Verified: milestone benches run
clean (ok=True, all artifacts), full suite 67 passed, lang-pairs OK.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-22 15:19:52 -07:00

113 lines
3.2 KiB
Python

"""Thin re-export shim for the sccl allreduce tests.
Not a test module (no ``test_`` prefix → pytest does not collect it).
The driver core, config writers, sweep constants, renderers, aggregators,
topology-diagram + FSIM-comparison emitters, and the direct-launch parity
reference now live in ``kernbench.benches.milestone_1h_ccl`` (production
single home, ADR-0054). This shim re-exports them and keeps
the pytest-specific pieces local: the ``pytest.param`` matrices
(``CONFIGS`` / ``_sweep_params`` / ``_bk_params``) and the fixture-coupled
``_run_distributed`` wrapper. Behavior is unchanged (defaults still target
``docs/diagrams/allreduce_latency_plots/``).
"""
from __future__ import annotations
import pytest
from kernbench.benches.milestone_1h_ccl import (
DEFAULT_N_ELEM,
TOPOLOGY_PATH,
_aggregate_sweep_plots,
_BK_N_ELEM_GRID,
_BK_ROWS_DIR,
_BUFFER_KINDS,
_crit_ns,
_drive_distributed,
_ELEM_BYTES_F16,
_SWEEP_N_ELEM,
_SWEEP_OUT_DIR,
_SWEEP_ROWS_DIR,
_SWEEP_TOPOLOGIES,
_worker,
_write_ccl_yaml,
_write_temp_configs,
aggregate_buffer_kind_plot,
emit_comparison_fsim_plot,
emit_topology_diagram,
run_allreduce,
)
# Names exported by ``from <this module> import *``; entries are kept in
# sorted order. ``CONFIGS``, ``_bk_params``, ``_run_distributed`` and
# ``_sweep_params`` are defined in this module; every other entry is
# re-exported from ``kernbench.benches.milestone_1h_ccl``.
# NOTE(review): ``_drive_distributed``, ``_SWEEP_TOPOLOGIES``,
# ``_SWEEP_N_ELEM``, ``_BUFFER_KINDS`` and ``_BK_N_ELEM_GRID`` are imported
# but not listed here -- they are only consumed by the local helpers below.
# Add them if a test module needs them via ``import *``.
__all__ = [
    "CONFIGS",  # local
    "DEFAULT_N_ELEM",
    "TOPOLOGY_PATH",
    "_BK_ROWS_DIR",
    "_ELEM_BYTES_F16",
    "_SWEEP_OUT_DIR",
    "_SWEEP_ROWS_DIR",
    "_aggregate_sweep_plots",
    "_bk_params",  # local
    "_crit_ns",
    "_run_distributed",  # local
    "_sweep_params",  # local
    "_worker",
    "_write_ccl_yaml",
    "_write_temp_configs",
    "aggregate_buffer_kind_plot",
    "emit_comparison_fsim_plot",
    "emit_topology_diagram",
    "run_allreduce",
]
# ── pytest-coupled distributed driver wrapper ──────────────────────────
def _run_distributed(tmp_path, monkeypatch, topo_path, correlation_id, n_elem):
    """Run ``_drive_distributed`` with the working directory set to ``tmp_path``.

    The backend's ``load_ccl_config()`` looks its config up relative to the
    current working directory, so switching cwd with ``monkeypatch.chdir``
    makes it pick up the temp ``ccl.yaml``; pytest's monkeypatch fixture
    restores the previous cwd at test teardown.

    Returns:
        Whatever ``_drive_distributed`` returns -- ``(engine, n_cubes)``.
    """
    # Must happen before driving: the config lookup is cwd-relative.
    monkeypatch.chdir(tmp_path)
    outcome = _drive_distributed(topo_path, correlation_id, n_elem)
    return outcome
# ── pytest.param matrices ──────────────────────────────────────────────
# Fixed correctness matrix: test id -> (algorithm, sip_topology, n_sips,
# sip_w, sip_h). The 1-D ring has no grid shape, hence ``None`` dimensions.
# Dict insertion order fixes the parametrization order.
CONFIGS = [
    pytest.param(*case_args, id=case_id)
    for case_id, case_args in {
        "ring_6sip": ("lrab_hierarchical_allreduce", "ring_1d", 6, None, None),
        "torus_6sip_2x3": ("lrab_hierarchical_allreduce", "torus_2d", 6, 2, 3),
        "mesh_6sip_2x3": (
            "lrab_hierarchical_allreduce", "mesh_2d_no_wrap", 6, 2, 3,
        ),
    }.items()
]
def _sweep_params():
    """Build the latency-sweep ``pytest.param`` matrix.

    Produces one param per (topology spec, ``n_elem``) pair drawn from
    ``_SWEEP_TOPOLOGIES`` x ``_SWEEP_N_ELEM``, topology-major. Each param
    carries ``(algorithm, sip_topology, n_sips, sip_w, sip_h, n_elem)`` and
    the id ``"<sip_topology>-n_elem<n_elem>"``.

    Returns:
        list of ``pytest.param`` objects.
    """
    # Outer clause = topology, inner clause = size: same ordering as a pair
    # of nested for-loops, without the manual append accumulator.
    return [
        pytest.param(
            algorithm, sip_topology, n_sips, sip_w, sip_h, n_elem,
            id=f"{sip_topology}-n_elem{n_elem}",
        )
        for algorithm, sip_topology, n_sips, sip_w, sip_h in _SWEEP_TOPOLOGIES
        for n_elem in _SWEEP_N_ELEM
    ]
def _bk_params():
    """Build the buffer-kind sweep ``pytest.param`` matrix.

    Produces one param per (buffer kind, ``n_elem``) pair drawn from
    ``_BUFFER_KINDS`` x ``_BK_N_ELEM_GRID``, buffer-kind-major. Each param
    carries ``(bk, n_elem)`` and the id ``"<bk>-n_elem<n_elem>"``.

    Returns:
        list of ``pytest.param`` objects.
    """
    # Outer clause = buffer kind, inner clause = size: same ordering as a
    # pair of nested for-loops, without the manual append accumulator.
    return [
        pytest.param(bk, n_elem, id=f"{bk}-n_elem{n_elem}")
        for bk in _BUFFER_KINDS
        for n_elem in _BK_N_ELEM_GRID
    ]