"""ADR-0009 D5: synchronized launch barrier. M_CPU stamps KernelLaunchMsg with target_start_ns = env.now + max path latency; PE_CPU yields until that time before recording pe_exec_start. Every PE in a single launch MUST begin kernel execution at the same env.now regardless of its dispatch path length. We verify this indirectly: for a no-op kernel, pe_exec_ns = env.now - pe_exec_start. If every PE's pe_exec_start is identical and every PE runs the same no-op body, every pe_exec_ns value must be identical. Without D5, pe_exec_start varies by dispatch-path length and so does pe_exec_ns. """ from __future__ import annotations from pathlib import Path import numpy as np from kernbench.policy.placement.dp import DPPolicy from kernbench.runtime_api.context import RuntimeContext from kernbench.runtime_api.types import DeviceSelector from kernbench.sim_engine.engine import GraphEngine from kernbench.topology.builder import resolve_topology TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml" def test_kernel_launch_sync_all_pes_have_equal_exec_time(): """No-op kernel: every PE's pe_exec_ns must be identical under D5.""" topo = resolve_topology(str(TOPOLOGY_PATH)) engine = GraphEngine(topo.topology_obj, enable_data=True) spec = topo.topology_obj.spec with RuntimeContext(engine=engine, target_device=DeviceSelector("all"), correlation_id="sync_test", spec=spec) as ctx: dp = DPPolicy(cube="row_wise", pe="column_wise", num_cubes=16, num_pes=8) def kernel(t_ptr, n_elem, tl): pass # no-op ctx.ahbm.set_device(0) t = ctx.zeros((16, 8 * 64), dtype="f16", dp=dp, name="probe") t.copy_(ctx.from_numpy(np.zeros((16, 8 * 64), dtype=np.float16))) pending = ctx.launch("sync_probe", kernel, t, 64, _defer_wait=True) for h, _sip, meta in pending: ctx.wait(h, _meta=meta) pe_exec_vals = [] for h, _sip, _meta in pending: _, trace = engine.get_completion(h) if trace and trace.get("pe_exec_ns") is not None: pe_exec_vals.append(float(trace["pe_exec_ns"])) assert pe_exec_vals, "expected completion traces with pe_exec_ns" spread = max(pe_exec_vals) - min(pe_exec_vals) assert spread < 1e-6, ( f"ADR-0009 D5 violated: pe_exec_ns spread across PEs = " f"{spread:.6f} ns (expected 0). Values: {pe_exec_vals}" )