Allreduce latency sweep across topologies and data sizes
Adds test_allreduce_latency_sweep that runs the existing intercube allreduce kernel under three SIP topologies (ring_1d, torus_2d, mesh_2d_no_wrap, all at n_sips=4) across 11 data sizes from 256 B/SIP up to 1 MB/SIP. For each point, captures max(pe_exec_ns) — the critical-path kernel time — and emits CSV plus log-x and linear-x plots, both per-topology and combined overview, with KB/MB-formatted tick labels. Reuses run_allreduce + _write_temp_configs and adds a slot_size auto-bump when n_elem*2 exceeds the default IPCQ slot. Sweep skips n_elem=16 because the runtime's dim_map scalar-arg remapping (context.py:761) collides any int-valued kernel scalar that matches a global tensor dim with its local shard size. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -179,7 +179,9 @@ CONFIGS = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def _write_temp_configs(tmp_path, sip_topology, n_sips, algorithm):
|
def _write_temp_configs(
|
||||||
|
tmp_path, sip_topology, n_sips, algorithm, n_elem_override=None,
|
||||||
|
):
|
||||||
"""Write temp topology.yaml and ccl.yaml with the given overrides."""
|
"""Write temp topology.yaml and ccl.yaml with the given overrides."""
|
||||||
with open(TOPOLOGY_PATH) as f:
|
with open(TOPOLOGY_PATH) as f:
|
||||||
topo_cfg = yaml.safe_load(f)
|
topo_cfg = yaml.safe_load(f)
|
||||||
@@ -193,6 +195,15 @@ def _write_temp_configs(tmp_path, sip_topology, n_sips, algorithm):
|
|||||||
with open(ccl_path) as f:
|
with open(ccl_path) as f:
|
||||||
ccl_cfg = yaml.safe_load(f)
|
ccl_cfg = yaml.safe_load(f)
|
||||||
ccl_cfg["defaults"]["algorithm"] = algorithm
|
ccl_cfg["defaults"]["algorithm"] = algorithm
|
||||||
|
if n_elem_override is not None:
|
||||||
|
ccl_cfg.setdefault("algorithms", {}).setdefault(
|
||||||
|
algorithm, {},
|
||||||
|
)["n_elem"] = int(n_elem_override)
|
||||||
|
# Ensure IPCQ slot is big enough for the per-message payload.
|
||||||
|
per_msg_bytes = int(n_elem_override) * 2 # f16
|
||||||
|
default_slot = int(ccl_cfg["defaults"].get("slot_size", 4096))
|
||||||
|
if per_msg_bytes > default_slot:
|
||||||
|
ccl_cfg["defaults"]["slot_size"] = per_msg_bytes
|
||||||
tmp_ccl = tmp_path / "ccl.yaml"
|
tmp_ccl = tmp_path / "ccl.yaml"
|
||||||
with open(tmp_ccl, "w") as f:
|
with open(tmp_ccl, "w") as f:
|
||||||
yaml.dump(ccl_cfg, f, default_flow_style=False)
|
yaml.dump(ccl_cfg, f, default_flow_style=False)
|
||||||
@@ -220,3 +231,191 @@ def test_allreduce(tmp_path, algorithm, sip_topology, n_sips):
|
|||||||
algorithm=algorithm, ccl_yaml=ccl_path,
|
algorithm=algorithm, ccl_yaml=ccl_path,
|
||||||
)
|
)
|
||||||
assert result["ok_cubes"] > 0
|
assert result["ok_cubes"] > 0
|
||||||
|
|
||||||
|
|
||||||
|
# ── Latency sweep ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# avoid 16 (== n_cubes, dim_map collision). Goes up to 1 MB per SIP:
|
||||||
|
# bytes_per_sip = n_cubes * n_elem * 2 = 32 * n_elem.
|
||||||
|
_SWEEP_N_ELEM = [
|
||||||
|
8, 32, 64, 128, 512, 1024, 2048,
|
||||||
|
4096, 8192, 16384, 32768,
|
||||||
|
]
|
||||||
|
_ELEM_BYTES_F16 = 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_allreduce_latency_sweep(tmp_path):
|
||||||
|
"""Sweep n_elem across each SIP topology; record max(pe_exec_ns)
|
||||||
|
as the critical-path kernel latency. Emits CSV + PNG plots to
|
||||||
|
tests/allreduce_latency_plots/.
|
||||||
|
"""
|
||||||
|
import csv
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from matplotlib.ticker import FuncFormatter
|
||||||
|
|
||||||
|
def _fmt_bytes(x, _pos):
|
||||||
|
"""Format tick as B / KB / MB."""
|
||||||
|
if x <= 0:
|
||||||
|
return "0"
|
||||||
|
if x >= 1024 * 1024:
|
||||||
|
return f"{x / (1024 * 1024):.0f} MB"
|
||||||
|
if x >= 1024:
|
||||||
|
return f"{x / 1024:.0f} KB"
|
||||||
|
return f"{x:.0f} B"
|
||||||
|
|
||||||
|
_bytes_fmt = FuncFormatter(_fmt_bytes)
|
||||||
|
|
||||||
|
out_dir = Path(__file__).parent / "allreduce_latency_plots"
|
||||||
|
out_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
records: list[dict] = []
|
||||||
|
|
||||||
|
# Apples-to-apples: same n_sips across all three topologies.
|
||||||
|
for algorithm, sip_topology, n_sips in [
|
||||||
|
("intercube_allreduce", "ring_1d", 4),
|
||||||
|
("intercube_allreduce", "torus_2d", 4),
|
||||||
|
("intercube_allreduce", "mesh_2d_no_wrap", 4),
|
||||||
|
]:
|
||||||
|
for n_elem in _SWEEP_N_ELEM:
|
||||||
|
sub = tmp_path / f"{sip_topology}_{n_elem}"
|
||||||
|
sub.mkdir()
|
||||||
|
topo_path, ccl_path = _write_temp_configs(
|
||||||
|
sub, sip_topology, n_sips, algorithm,
|
||||||
|
n_elem_override=n_elem,
|
||||||
|
)
|
||||||
|
topo = resolve_topology(topo_path)
|
||||||
|
engine = GraphEngine(topo.topology_obj, enable_data=True)
|
||||||
|
spec = topo.topology_obj.spec
|
||||||
|
|
||||||
|
with RuntimeContext(
|
||||||
|
engine=engine,
|
||||||
|
target_device=DeviceSelector("all"),
|
||||||
|
correlation_id=f"sweep_{algorithm}_{sip_topology}_{n_elem}",
|
||||||
|
spec=spec,
|
||||||
|
) as ctx:
|
||||||
|
result = run_allreduce(
|
||||||
|
ctx, engine, spec,
|
||||||
|
algorithm=algorithm, ccl_yaml=ccl_path,
|
||||||
|
)
|
||||||
|
assert result["ok_cubes"] > 0
|
||||||
|
|
||||||
|
pe_exec_vals = [
|
||||||
|
float(tr.get("pe_exec_ns", 0.0) or 0.0)
|
||||||
|
for _, (_, tr) in engine._results.items()
|
||||||
|
if isinstance(tr, dict)
|
||||||
|
]
|
||||||
|
crit_ns = max(pe_exec_vals) if pe_exec_vals else 0.0
|
||||||
|
|
||||||
|
cm = spec["sip"]["cube_mesh"]
|
||||||
|
n_cubes = int(cm["w"]) * int(cm["h"])
|
||||||
|
bytes_per_sip = n_cubes * n_elem * _ELEM_BYTES_F16
|
||||||
|
# pe="replicate" + num_pes=1 → one active PE per cube owns
|
||||||
|
# the whole cube row. Per-PE bytes == per-cube-tile bytes ==
|
||||||
|
# per-message bytes over the IPCQ fabric.
|
||||||
|
bytes_per_pe = n_elem * _ELEM_BYTES_F16
|
||||||
|
|
||||||
|
records.append({
|
||||||
|
"algorithm": algorithm,
|
||||||
|
"sip_topology": sip_topology,
|
||||||
|
"n_sips": n_sips,
|
||||||
|
"n_elem": n_elem,
|
||||||
|
"bytes_per_pe": bytes_per_pe,
|
||||||
|
"bytes_per_sip": bytes_per_sip,
|
||||||
|
"latency_ns": crit_ns,
|
||||||
|
})
|
||||||
|
print(
|
||||||
|
f"[{sip_topology:<16} n_sips={n_sips} n_elem={n_elem:>5} "
|
||||||
|
f"bytes/pe={bytes_per_pe:>7} bytes/sip={bytes_per_sip:>9}] "
|
||||||
|
f"pe_exec_max = {crit_ns:8.1f} ns"
|
||||||
|
)
|
||||||
|
|
||||||
|
with open(out_dir / "summary.csv", "w", newline="", encoding="utf-8") as f:
|
||||||
|
w = csv.DictWriter(f, fieldnames=[
|
||||||
|
"algorithm", "sip_topology", "n_sips", "n_elem",
|
||||||
|
"bytes_per_pe", "bytes_per_sip", "latency_ns",
|
||||||
|
])
|
||||||
|
w.writeheader()
|
||||||
|
for r in records:
|
||||||
|
w.writerow(r)
|
||||||
|
|
||||||
|
topologies = sorted({r["sip_topology"] for r in records})
|
||||||
|
# Per-topology plots: log-scale + linear-scale side-by-side.
|
||||||
|
# X-axis = bytes per PE (per-message payload size).
|
||||||
|
for topo_name in topologies:
|
||||||
|
rs = sorted(
|
||||||
|
[r for r in records if r["sip_topology"] == topo_name],
|
||||||
|
key=lambda r: r["bytes_per_pe"],
|
||||||
|
)
|
||||||
|
xs = [r["bytes_per_pe"] for r in rs]
|
||||||
|
ys = [r["latency_ns"] for r in rs]
|
||||||
|
title = (
|
||||||
|
f"Allreduce latency — {topo_name} "
|
||||||
|
f"(n_sips={rs[0]['n_sips']})"
|
||||||
|
)
|
||||||
|
# Log-scale
|
||||||
|
fig, ax = plt.subplots(figsize=(8, 5))
|
||||||
|
ax.plot(xs, ys, marker="o", color="tab:blue")
|
||||||
|
ax.set_xscale("log", base=2)
|
||||||
|
ax.set_xlabel("Bytes per PE (log scale)")
|
||||||
|
ax.set_ylabel("max pe_exec_ns (critical path)")
|
||||||
|
ax.set_title(title)
|
||||||
|
ax.grid(True, alpha=0.3)
|
||||||
|
ax.xaxis.set_major_formatter(_bytes_fmt)
|
||||||
|
fig.tight_layout()
|
||||||
|
fig.savefig(out_dir / f"{topo_name}.png", dpi=120)
|
||||||
|
plt.close(fig)
|
||||||
|
# Linear-scale companion
|
||||||
|
fig, ax = plt.subplots(figsize=(8, 5))
|
||||||
|
ax.plot(xs, ys, marker="o", color="tab:blue")
|
||||||
|
ax.set_xlabel("Bytes per PE")
|
||||||
|
ax.set_ylabel("max pe_exec_ns (critical path)")
|
||||||
|
ax.set_title(title + " [linear scale]")
|
||||||
|
ax.grid(True, alpha=0.3)
|
||||||
|
ax.xaxis.set_major_formatter(_bytes_fmt)
|
||||||
|
fig.tight_layout()
|
||||||
|
fig.savefig(out_dir / f"{topo_name}_linear.png", dpi=120)
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
# Combined overview — two variants: log-scale (overview.png) and
|
||||||
|
# linear-scale (overview_linear.png).
|
||||||
|
colors = {"ring_1d": "tab:blue", "torus_2d": "tab:orange",
|
||||||
|
"mesh_2d_no_wrap": "tab:green"}
|
||||||
|
|
||||||
|
def _draw_overview(log_x: bool, filename: str, title_suffix: str) -> None:
|
||||||
|
fig, ax = plt.subplots(figsize=(9, 6))
|
||||||
|
for topo_name in topologies:
|
||||||
|
rs = sorted(
|
||||||
|
[r for r in records if r["sip_topology"] == topo_name],
|
||||||
|
key=lambda r: r["bytes_per_pe"],
|
||||||
|
)
|
||||||
|
ax.plot(
|
||||||
|
[r["bytes_per_pe"] for r in rs],
|
||||||
|
[r["latency_ns"] for r in rs],
|
||||||
|
marker="o",
|
||||||
|
label=f"{topo_name} (n_sips={rs[0]['n_sips']})",
|
||||||
|
color=colors.get(topo_name),
|
||||||
|
)
|
||||||
|
if log_x:
|
||||||
|
ax.set_xscale("log", base=2)
|
||||||
|
ax.set_xlabel("Bytes per PE (log scale)")
|
||||||
|
else:
|
||||||
|
ax.set_xlabel("Bytes per PE")
|
||||||
|
ax.set_ylabel("max pe_exec_ns (critical path)")
|
||||||
|
ax.set_title("Multi-device allreduce latency by topology" + title_suffix)
|
||||||
|
ax.grid(True, alpha=0.3)
|
||||||
|
ax.legend()
|
||||||
|
ax.xaxis.set_major_formatter(_bytes_fmt)
|
||||||
|
fig.tight_layout()
|
||||||
|
fig.savefig(out_dir / filename, dpi=120)
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
_draw_overview(log_x=True, filename="overview.png", title_suffix="")
|
||||||
|
_draw_overview(
|
||||||
|
log_x=False, filename="overview_linear.png",
|
||||||
|
title_suffix=" [linear scale]",
|
||||||
|
)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"\nWrote {out_dir / 'overview.png'} + "
|
||||||
|
f"{out_dir / 'overview_linear.png'}"
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user