kernbench2/tests/sccl/test_plot_buffer_kind_sweep.py
mukesh b610cb0d9a sccl: drive allreduce tests via torch.distributed; reorganize into tests/sccl/
Convert the multidevice allreduce correctness + latency/buffer-kind sweeps
to run through the real PyTorch-distributed path
(init_process_group(backend="ahbm") -> mp.spawn -> dist.all_reduce) instead
of direct ctx.launch, and reorganize the CCL/allreduce tests into a
tests/sccl/ package split one test per file.

Production change (required for the distributed path on non-square SIP grids):
- AhbmCCLBackend now reads explicit system.sips.w/h from the spec, with a
  square-only sqrt fallback that raises on ambiguity, instead of silently
  guessing round(sqrt(count)). This fixes the 2x3 / 3x2 torus + mesh cases,
  which previously resolved to a wrong 2x2 grid. Mirrors the test helper's
  _sip_topo_dims precedence (explicit w/h > square fallback > raise).
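
The precedence described above (explicit w/h > square fallback > raise) can be sketched as follows. This is a minimal illustration, not the actual AhbmCCLBackend or _sip_topo_dims code: the function name resolve_sip_grid, its signature, and the choice to reject a lone w or h are assumptions made for this example.

```python
import math


def resolve_sip_grid(count, w=None, h=None):
    """Return (w, h) for `count` SIPs.

    Hypothetical helper: explicit w/h wins, a perfect-square count falls
    back to (side, side), and anything else raises instead of guessing.
    """
    if w is not None and h is not None:
        if w * h != count:
            raise ValueError(f"sips.w*h = {w}*{h} does not match count {count}")
        return w, h
    if w is not None or h is not None:
        # Assumption: a lone w or h is treated as a spec error.
        raise ValueError("sips.w and sips.h must be given together")
    side = math.isqrt(count)
    if side * side != count:
        # A non-square count such as 6 could be 2x3, 3x2, 1x6, or 6x1.
        raise ValueError(f"cannot infer a grid for {count} SIPs; set sips.w/h")
    return side, side
```

With this shape, resolve_sip_grid(6, w=3, h=2) yields the 3x2 grid, resolve_sip_grid(4) falls back to 2x2, and resolve_sip_grid(6) raises rather than rounding sqrt(6) down to a 2x2 grid.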

Test reorganization (tests/sccl/):
- _allreduce_helpers.py: shared plumbing (distributed driver, config writers,
  direct-launch run_allreduce parity reference, sweep/buffer-kind constants,
  plot aggregators, topology-diagram + FSIM-comparison emitters).
- test_allreduce_ring_torus_mesh.py: correctness across ring/torus/mesh.
- test_distributed_default_topology.py: full distributed path on topology.yaml.
- test_plot_latency_sweep.py / test_plot_buffer_kind_sweep.py: sweep rows.
- test_plot_topology_diagram.py / test_plot_comparison_fsim.py: plot emitters.
- test_intercube_root_center.py: moved in (ADR-0032 center-root latency guard).

Also:
- Move the FSIM comparison plot generator out of scripts/ into the sccl suite.
- Delete superseded test files (test_allreduce_multidevice,
  test_distributed_lrab_hierarchical_allreduce, test_allreduce_buffer_kind_sweep)
  and repoint conftest aggregators + the ipcq buffer-kind importers.
- Regenerate the allreduce_latency_plots derived artifacts from the full sweep.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-20 22:24:43 -07:00

"""Buffer-kind sweep (TCM / SRAM / HBM) on torus_2d 6 SIPs (3×2), distributed.
Each parametrized case writes one JSON row; the conftest sessionfinish hook
calls ``aggregate_buffer_kind_plot`` to emit the comparison PNG + CSV. Until
slot latency is modeled, the three lines overlap exactly (slot access is
latency-free today).
"""
from __future__ import annotations

import json

import pytest
import yaml

from tests.sccl._allreduce_helpers import (
    _BK_ROWS_DIR,
    _ELEM_BYTES_F16,
    _bk_params,
    _crit_ns,
    _run_distributed,
    _write_temp_configs,
)


@pytest.mark.parametrize("buffer_kind,n_elem", _bk_params())
def test_buffer_kind_allreduce_one(tmp_path, monkeypatch, buffer_kind, n_elem):
    sub = tmp_path / f"{buffer_kind}_{n_elem}"
    sub.mkdir()
    topo_path, ccl_path = _write_temp_configs(
        sub,
        sip_topology="torus_2d",
        n_sips=6,
        algorithm="lrab_hierarchical_allreduce",
        sip_w=3, sip_h=2,
        n_elem_override=n_elem,
    )

    # Override buffer_kind in the temp ccl.yaml (read by the ahbm backend
    # at init_process_group time via load_ccl_config()).
    with open(ccl_path) as f:
        ccl_cfg = yaml.safe_load(f)
    ccl_cfg.setdefault("defaults", {})["buffer_kind"] = buffer_kind
    ccl_cfg.setdefault("algorithms", {}).setdefault(
        "lrab_hierarchical_allreduce", {},
    )["buffer_kind"] = buffer_kind
    with open(ccl_path, "w") as f:
        yaml.dump(ccl_cfg, f, default_flow_style=False)

    engine, _ = _run_distributed(
        sub, monkeypatch, topo_path,
        f"bk_sweep_{buffer_kind}_{n_elem}", n_elem,
    )
    crit_ns = _crit_ns(engine)
    bytes_per_pe = n_elem * _ELEM_BYTES_F16

    record = {
        "buffer_kind": buffer_kind,
        "sip_topology": "torus_2d",
        "n_sips": 6,
        "n_elem": n_elem,
        "bytes_per_pe": bytes_per_pe,
        "latency_ns": crit_ns,
    }
    _BK_ROWS_DIR.mkdir(parents=True, exist_ok=True)
    row_path = _BK_ROWS_DIR / f"{buffer_kind}_{n_elem}.json"
    with open(row_path, "w", encoding="utf-8") as f:
        json.dump(record, f)