e9cc40f74d
mesh_2d, torus_2d, and mesh_2d_no_wrap accept optional w,h kwargs; sqrt fall-back preserved for square layouts (back-compat tests confirm 4-SIP and 9-SIP square configs still work). sfr_config reads system.sips.w/h from spec and threads dims through to the topology fn. test_allreduce_multidevice CONFIGS switched from 4 SIPs (square) to 6 SIPs: ring_1d_6sip, torus_2d_6sip_2x3, mesh_2d_no_wrap_6sip_2x3. _write_temp_configs writes system.sips.w/h when supplied; _sip_topo_dims reads them back. Latency sweep loop also moved to 6-SIP layouts. Linear-scale plot variants dropped -- only log-scale *.png + summary.csv emitted. Plots in tests/allreduce_latency_plots regenerated. New tests/test_sip_topology_rectangular.py asserts neighbor correctness for 2x3 layouts and back-compat for square fallback. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
430 lines
14 KiB
Python
430 lines
14 KiB
Python
"""Config-driven multi-device allreduce test application.

Reads ``ccl.yaml`` + ``topology.yaml``, dynamically loads the kernel
module from ``ccl.yaml → module``, and picks the inter-SIP exchange
pattern from ``topology.yaml → system.sips.topology``.

Run directly::

    python -m pytest tests/allreduce_app.py -v -s
"""
|
|
from __future__ import annotations
|
|
|
|
import importlib
|
|
import math
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import numpy as np
|
|
|
|
from kernbench.ccl.install import load_ccl_config, resolve_algorithm_config
|
|
from kernbench.ccl.sfr_config import configure_sfr_intercube_multisip
|
|
from kernbench.policy.placement.dp import DPPolicy
|
|
|
|
|
|
def _sip_topo_dims(
    sip_topo: str, n_sips: int,
    spec_w: int | None = None, spec_h: int | None = None,
) -> tuple[int, int]:
    """Return the ``(w, h)`` SIP-grid dimensions for a SIP topology.

    * ``ring_1d`` has no 2-D layout and always yields ``(0, 0)``; any
      ``spec_w``/``spec_h`` is ignored.
    * For every other topology an explicit ``spec_w`` x ``spec_h`` layout is
      used when *both* values are given.
    * Otherwise the layout falls back to a square grid whose side is
      ``sqrt(n_sips)``. A lone ``spec_w`` or ``spec_h`` is ignored and the
      square fallback applies.

    Args:
        sip_topo: Topology name, e.g. ``"ring_1d"`` or ``"torus_2d"``.
        n_sips: Total number of SIPs the layout must cover.
        spec_w: Optional explicit grid width.
        spec_h: Optional explicit grid height.

    Returns:
        The ``(w, h)`` pair.

    Raises:
        ValueError: If an explicit layout has a negative dimension or does
            not multiply out to ``n_sips``, or if no explicit layout is
            given and ``n_sips`` is not a non-negative perfect square.
    """
    if sip_topo == "ring_1d":
        return (0, 0)

    if spec_w is not None and spec_h is not None:
        # The product check alone is not enough: two negative factors
        # (e.g. -2 x -3) also multiply to a positive n_sips.
        if spec_w < 0 or spec_h < 0:
            raise ValueError(
                f"sip layout {spec_w}x{spec_h} has a negative dimension"
            )
        if spec_w * spec_h != n_sips:
            raise ValueError(
                f"sip layout {spec_w}x{spec_h} != n_sips ({n_sips})"
            )
        return (spec_w, spec_h)

    # Square fallback. Guard negatives first so the caller gets the
    # descriptive message below instead of math.sqrt's "math domain error".
    side = int(round(math.sqrt(n_sips))) if n_sips >= 0 else -1
    if side < 0 or side * side != n_sips:
        raise ValueError(
            f"SIP topology '{sip_topo}' requires square n_sips or "
            f"explicit w/h in spec, got {n_sips}"
        )
    return (side, side)
|
|
|
|
|
|
def run_allreduce(
    ctx: Any,
    engine: Any,
    spec: dict,
    *,
    algorithm: str | None = None,
    ccl_yaml: str | None = None,
) -> dict:
    """Run one config-driven multi-SIP allreduce and verify the result.

    The kernel module, element count and algorithm name come from the CCL
    config (``ccl_yaml``); the SIP count, SIP topology and optional SIP grid
    ``w``/``h`` come from ``spec["system"]["sips"]``; the cube mesh comes
    from ``spec["sip"]["cube_mesh"]``. Nothing is imported by hard-coded
    name.

    Each SIP ``s`` gets an ``(n_cubes, n_elem)`` f16 tensor filled with
    ``s + 1``. After all launches complete, every cube row on every SIP must
    be close to ``n_cubes * (1 + 2 + ... + n_sips)``.

    Args:
        ctx: Runtime context used to allocate tensors, launch and wait.
        engine: Simulation engine; its ``_env.now`` clock and ``_results``
            traces are read for the latency report.
        spec: Resolved topology spec dictionary.
        algorithm: Algorithm name forwarded to ``resolve_algorithm_config``.
        ccl_yaml: Path forwarded to ``load_ccl_config``.

    Returns:
        ``{"expected": float, "latency_ns": ..., "ok_cubes": int}``.

    Raises:
        AssertionError: If any cube row differs from the expected value.
        ValueError: Propagated from ``_sip_topo_dims`` for a bad SIP layout.
    """
    cfg = resolve_algorithm_config(load_ccl_config(ccl_yaml), algorithm)

    # The kernel module is named by ccl.yaml → module and imported lazily.
    algo_module = importlib.import_module(cfg["module"])
    kernel_fn = algo_module.kernel
    topo_name_to_kind = algo_module.TOPO_NAME_TO_KIND

    n_elem = int(cfg.get("n_elem", 8))

    sips_cfg = spec.get("system", {}).get("sips", {})
    n_sips = int(sips_cfg.get("count", 1))
    sip_topo = str(sips_cfg.get("topology", "ring_1d"))
    raw_w, raw_h = sips_cfg.get("w"), sips_cfg.get("h")
    spec_sip_w = None if raw_w is None else int(raw_w)
    spec_sip_h = None if raw_h is None else int(raw_h)

    cube_mesh = spec["sip"]["cube_mesh"]
    cube_w, cube_h = int(cube_mesh["w"]), int(cube_mesh["h"])
    n_cubes = cube_w * cube_h

    # Topology names missing from the kernel module's table map to kind 0.
    sip_topo_kind = topo_name_to_kind.get(sip_topo, 0)
    sip_topo_w, sip_topo_h = _sip_topo_dims(
        sip_topo, n_sips, spec_w=spec_sip_w, spec_h=spec_sip_h,
    )

    algo_name = cfg.get("algorithm", "allreduce")
    rule = "=" * 60
    print(f"\n{rule}")
    print(f"algorithm: {algo_name}")
    print(f"module: {cfg['module']}")
    print(f"sip_topology: {sip_topo}")
    print(f"kernel: {kernel_fn.__name__}")
    print(f"n_sips: {n_sips}")
    print(f"n_cubes: {n_cubes}")
    print(f"n_elem: {n_elem}")
    print(rule)

    configure_sfr_intercube_multisip(engine, spec, cfg)

    dp = DPPolicy(
        cube="row_wise", pe="replicate",
        num_pes=1, num_cubes=n_cubes,
    )

    # One input tensor per SIP, filled with (sip index + 1).
    shape = (n_cubes, n_elem)
    tensors = []
    for sip in range(n_sips):
        ctx.ahbm.set_device(sip)
        tensor = ctx.zeros(shape, dtype="f16", dp=dp, name=f"sip{sip}")
        fill = np.full(shape, float(sip + 1), dtype=np.float16)
        tensor.copy_(ctx.from_numpy(fill))
        tensors.append(tensor)

    for sip, tensor in enumerate(tensors):
        arr = tensor.numpy()
        print(f"[SIP {sip}] input cube0[:4] = {arr[0][:4].tolist()} "
              f"cube{n_cubes - 1}[:4] = {arr[-1][:4].tolist()}")

    t_start = engine._env.now

    # Launch on every SIP first (deferred wait), then wait on all handles.
    all_pending = []
    for sip_rank, tensor in enumerate(tensors):
        all_pending.extend(
            ctx.launch(
                algo_name, kernel_fn, tensor,
                n_elem, cube_w, cube_h, n_sips, sip_rank,
                sip_topo_kind, sip_topo_w, sip_topo_h,
                _defer_wait=True,
            )
        )

    for handle, _sip_id, meta in all_pending:
        ctx.wait(handle, _meta=meta)

    latency_ns = engine._env.now - t_start
    print(f"\n[{algo_name} ws={n_sips}] sim latency = "
          f"{latency_ns:.1f} ns ({latency_ns / 1000:.3f} us)")

    # Per-result breakdown; whatever is not PE execution is reported as network.
    for key, (_, trace) in engine._results.items():
        if isinstance(trace, dict):
            total = trace.get("total_ns", 0.0)
            pe_exec = trace.get("pe_exec_ns", 0.0) or 0.0
            network = total - pe_exec
            print(f" [{key}] total={total:.1f} ns "
                  f"pe_exec={pe_exec:.1f} ns network={network:.1f} ns")

    expected = float(n_cubes * sum(range(1, n_sips + 1)))

    print()
    last_cube = n_cubes - 1
    for sip in range(n_sips):
        arr = tensors[sip].numpy()
        print(f"[SIP {sip}] output cube0[:4] = {arr[0][:4].tolist()}")
        print(f"[SIP {sip}] output cube{last_cube}[:4] = {arr[-1][:4].tolist()}")

    ok_cubes = 0
    for sip in range(n_sips):
        arr = tensors[sip].numpy()
        for cube_id in range(n_cubes):
            row = arr[cube_id]
            assert np.allclose(row, expected, rtol=1e-1, atol=1e-1), (
                f"SIP{sip} cube {cube_id}: "
                f"got {row[:4]}, expected {expected}"
            )
            ok_cubes += 1

    print(f"\n {algo_name} (ws={n_sips}): {ok_cubes} OK")

    return {
        "expected": expected,
        "latency_ns": latency_ns,
        "ok_cubes": ok_cubes,
    }
|
|
|
|
|
|
# ── pytest entry point ───────────────────────────────────────────────
|
|
|
|
import pytest
|
|
import yaml
|
|
|
|
from kernbench.runtime_api.context import RuntimeContext
|
|
from kernbench.runtime_api.types import DeviceSelector
|
|
from kernbench.sim_engine.engine import GraphEngine
|
|
from kernbench.topology.builder import resolve_topology
|
|
|
|
# Base topology file, one directory above this module; the tests copy it and
# override the SIP section.
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"

# Parametrization for test_allreduce, one entry per SIP topology:
# (algorithm, sip_topology, n_sips, sip_w, sip_h). ring_1d carries no grid
# layout, so its w/h are None.
CONFIGS = [
    pytest.param(
        "intercube_allreduce", "ring_1d", 6, None, None, id="ring_6sip",
    ),
    pytest.param(
        "intercube_allreduce", "torus_2d", 6, 2, 3, id="torus_6sip_2x3",
    ),
    pytest.param(
        "intercube_allreduce", "mesh_2d_no_wrap", 6, 2, 3, id="mesh_6sip_2x3",
    ),
]
|
|
|
|
|
|
def _write_temp_configs(
    tmp_path, sip_topology, n_sips, algorithm, n_elem_override=None,
    sip_w=None, sip_h=None,
):
    """Write temp ``topology.yaml`` and ``ccl.yaml`` with the given overrides.

    The base files are ``TOPOLOGY_PATH`` and the ``ccl.yaml`` one directory
    above this module. All YAML I/O is done explicitly as UTF-8 so the
    result does not depend on the platform's locale encoding.

    Args:
        tmp_path: Existing directory (a ``pathlib.Path``) to write into.
        sip_topology: Value stored in ``system.sips.topology``.
        n_sips: Value stored in ``system.sips.count``.
        algorithm: Value stored in ``defaults.algorithm`` of ``ccl.yaml``.
        n_elem_override: If given, stored as ``algorithms.<algorithm>.n_elem``;
            ``defaults.slot_size`` is raised when one f16 message of that
            many elements would not fit.
        sip_w: SIP grid width; written only together with ``sip_h``.
        sip_h: SIP grid height; written only together with ``sip_w``.

    Returns:
        ``(topology_path, ccl_path)`` as strings.
    """
    with open(TOPOLOGY_PATH, encoding="utf-8") as f:
        topo_cfg = yaml.safe_load(f)

    sips_cfg = topo_cfg["system"]["sips"]
    sips_cfg["count"] = n_sips
    sips_cfg["topology"] = sip_topology
    if sip_w is not None and sip_h is not None:
        sips_cfg["w"] = int(sip_w)
        sips_cfg["h"] = int(sip_h)
    else:
        # Drop any layout inherited from the base file so a stale w/h cannot
        # leak into a run that did not ask for one.
        sips_cfg.pop("w", None)
        sips_cfg.pop("h", None)

    topo_path = tmp_path / "topology.yaml"
    with open(topo_path, "w", encoding="utf-8") as f:
        yaml.dump(topo_cfg, f, default_flow_style=False)

    ccl_path = Path(__file__).parent.parent / "ccl.yaml"
    with open(ccl_path, encoding="utf-8") as f:
        ccl_cfg = yaml.safe_load(f)

    ccl_cfg["defaults"]["algorithm"] = algorithm
    if n_elem_override is not None:
        n_elem = int(n_elem_override)
        ccl_cfg.setdefault("algorithms", {}).setdefault(
            algorithm, {},
        )["n_elem"] = n_elem
        # Ensure IPCQ slot is big enough for the per-message payload.
        per_msg_bytes = n_elem * 2  # f16
        default_slot = int(ccl_cfg["defaults"].get("slot_size", 4096))
        if per_msg_bytes > default_slot:
            ccl_cfg["defaults"]["slot_size"] = per_msg_bytes

    tmp_ccl = tmp_path / "ccl.yaml"
    with open(tmp_ccl, "w", encoding="utf-8") as f:
        yaml.dump(ccl_cfg, f, default_flow_style=False)

    return str(topo_path), str(tmp_ccl)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "algorithm,sip_topology,n_sips,sip_w,sip_h", CONFIGS,
)
def test_allreduce(
    tmp_path, algorithm, sip_topology, n_sips, sip_w, sip_h,
):
    """Run one allreduce per ``CONFIGS`` entry and require verified cubes.

    Temp copies of the topology and CCL configs are written with the
    parametrized SIP count, topology and optional ``w`` x ``h`` layout, then
    ``run_allreduce`` is executed against a fresh engine.
    """
    topo_path, ccl_path = _write_temp_configs(
        tmp_path, sip_topology, n_sips, algorithm,
        sip_w=sip_w, sip_h=sip_h,
    )
    resolved = resolve_topology(topo_path)
    engine = GraphEngine(resolved.topology_obj, enable_data=True)
    spec = resolved.topology_obj.spec

    runtime = RuntimeContext(
        engine=engine,
        target_device=DeviceSelector("all"),
        correlation_id=f"test_{algorithm}_{sip_topology}",
        spec=spec,
    )
    with runtime as ctx:
        result = run_allreduce(
            ctx, engine, spec,
            algorithm=algorithm, ccl_yaml=ccl_path,
        )
    assert result["ok_cubes"] > 0
|
|
|
|
|
|
# ── Latency sweep ─────────────────────────────────────────────────────
|
|
|
|
# n_elem values for the latency sweep. 16 is deliberately avoided: it equals
# n_cubes and collides in dim_map. The sweep goes up to 1 MB per SIP, since
# bytes_per_sip = n_cubes * n_elem * 2 = 32 * n_elem.
_SWEEP_N_ELEM = [
    8, 32, 64, 128,
    512, 1024, 2048, 4096,
    8192, 16384, 32768,
]

# Size in bytes of one f16 element.
_ELEM_BYTES_F16 = 2
|
|
|
|
|
|
def test_allreduce_latency_sweep(tmp_path):
    """Sweep ``n_elem`` over each SIP topology and plot the latency.

    For every (topology, n_elem) pair one allreduce is run and the largest
    ``pe_exec_ns`` among the engine's result traces is recorded as the
    critical-path kernel latency. Writes ``summary.csv``, one log-scale PNG
    per topology and a combined ``overview.png`` into the
    ``allreduce_latency_plots`` directory next to this file.
    """
    import csv

    import matplotlib.pyplot as plt
    from matplotlib.ticker import FuncFormatter

    def _fmt_bytes(x, _pos):
        """Format tick as B / KB / MB."""
        kib = 1024
        mib = 1024 * 1024
        if x <= 0:
            return "0"
        if x >= mib:
            return f"{x / (1024 * 1024):.0f} MB"
        if x >= kib:
            return f"{x / 1024:.0f} KB"
        return f"{x:.0f} B"

    _bytes_fmt = FuncFormatter(_fmt_bytes)

    out_dir = Path(__file__).parent / "allreduce_latency_plots"
    out_dir.mkdir(parents=True, exist_ok=True)
    records: list[dict] = []

    # Apples-to-apples: same n_sips across all three topologies.
    cases = [
        ("intercube_allreduce", "ring_1d", 6, None, None),
        ("intercube_allreduce", "torus_2d", 6, 2, 3),
        ("intercube_allreduce", "mesh_2d_no_wrap", 6, 2, 3),
    ]
    for algorithm, sip_topology, n_sips, sip_w, sip_h in cases:
        for n_elem in _SWEEP_N_ELEM:
            # Each run gets its own directory for the temp config files.
            sub = tmp_path / f"{sip_topology}_{n_elem}"
            sub.mkdir()
            topo_path, ccl_path = _write_temp_configs(
                sub, sip_topology, n_sips, algorithm,
                sip_w=sip_w, sip_h=sip_h,
                n_elem_override=n_elem,
            )
            resolved = resolve_topology(topo_path)
            engine = GraphEngine(resolved.topology_obj, enable_data=True)
            spec = resolved.topology_obj.spec

            runtime = RuntimeContext(
                engine=engine,
                target_device=DeviceSelector("all"),
                correlation_id=f"sweep_{algorithm}_{sip_topology}_{n_elem}",
                spec=spec,
            )
            with runtime as ctx:
                result = run_allreduce(
                    ctx, engine, spec,
                    algorithm=algorithm, ccl_yaml=ccl_path,
                )
            assert result["ok_cubes"] > 0

            # Critical path = slowest PE execution among the dict traces.
            pe_exec_vals = [
                float(tr.get("pe_exec_ns", 0.0) or 0.0)
                for _, tr in engine._results.values()
                if isinstance(tr, dict)
            ]
            crit_ns = max(pe_exec_vals) if pe_exec_vals else 0.0

            cube_mesh = spec["sip"]["cube_mesh"]
            n_cubes = int(cube_mesh["w"]) * int(cube_mesh["h"])
            bytes_per_sip = n_cubes * n_elem * _ELEM_BYTES_F16
            # pe="replicate" + num_pes=1 → one active PE per cube owns
            # the whole cube row. Per-PE bytes == per-cube-tile bytes ==
            # per-message bytes over the IPCQ fabric.
            bytes_per_pe = n_elem * _ELEM_BYTES_F16

            records.append({
                "algorithm": algorithm,
                "sip_topology": sip_topology,
                "n_sips": n_sips,
                "n_elem": n_elem,
                "bytes_per_pe": bytes_per_pe,
                "bytes_per_sip": bytes_per_sip,
                "latency_ns": crit_ns,
            })
            print(
                f"[{sip_topology:<16} n_sips={n_sips} n_elem={n_elem:>5} "
                f"bytes/pe={bytes_per_pe:>7} bytes/sip={bytes_per_sip:>9}] "
                f"pe_exec_max = {crit_ns:8.1f} ns"
            )

    fieldnames = [
        "algorithm", "sip_topology", "n_sips", "n_elem",
        "bytes_per_pe", "bytes_per_sip", "latency_ns",
    ]
    with open(out_dir / "summary.csv", "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(records)

    def _series(name):
        """Return the records of one topology ordered by bytes per PE."""
        return sorted(
            (r for r in records if r["sip_topology"] == name),
            key=lambda r: r["bytes_per_pe"],
        )

    topologies = sorted({r["sip_topology"] for r in records})

    # Per-topology plots, log-scale x-axis = bytes per PE.
    for topo_name in topologies:
        rows = _series(topo_name)
        xs = [r["bytes_per_pe"] for r in rows]
        ys = [r["latency_ns"] for r in rows]
        title = (
            f"Allreduce latency — {topo_name} "
            f"(n_sips={rows[0]['n_sips']})"
        )
        fig, ax = plt.subplots(figsize=(8, 5))
        ax.plot(xs, ys, marker="o", color="tab:blue")
        ax.set_xscale("log", base=2)
        ax.set_xlabel("Bytes per PE (log scale)")
        ax.set_ylabel("max pe_exec_ns (critical path)")
        ax.set_title(title)
        ax.grid(True, alpha=0.3)
        ax.xaxis.set_major_formatter(_bytes_fmt)
        fig.tight_layout()
        fig.savefig(out_dir / f"{topo_name}.png", dpi=120)
        plt.close(fig)

    # Combined overview: one coloured line per topology on shared axes.
    colors = {
        "ring_1d": "tab:blue",
        "torus_2d": "tab:orange",
        "mesh_2d_no_wrap": "tab:green",
    }
    fig, ax = plt.subplots(figsize=(9, 6))
    for topo_name in topologies:
        rows = _series(topo_name)
        ax.plot(
            [r["bytes_per_pe"] for r in rows],
            [r["latency_ns"] for r in rows],
            marker="o",
            label=f"{topo_name} (n_sips={rows[0]['n_sips']})",
            color=colors.get(topo_name),
        )
    ax.set_xscale("log", base=2)
    ax.set_xlabel("Bytes per PE (log scale)")
    ax.set_ylabel("max pe_exec_ns (critical path)")
    ax.set_title("Multi-device allreduce latency by topology")
    ax.grid(True, alpha=0.3)
    ax.legend()
    ax.xaxis.set_major_formatter(_bytes_fmt)
    fig.tight_layout()
    fig.savefig(out_dir / "overview.png", dpi=120)
    plt.close(fig)

    print(f"\nWrote {out_dir / 'overview.png'}")
|