diff --git a/pyproject.toml b/pyproject.toml index ef6ba8e..8a5863f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta" name = "kernbench" version = "0.1.0" requires-python = ">=3.10" -dependencies = ["pytest", "simpy", "pyyaml", "fastapi>=0.110", "uvicorn[standard]>=0.29", "websockets>=12", "numpy>=1.24", "greenlet>=3.0"] +dependencies = ["pytest", "simpy", "pyyaml", "fastapi>=0.110", "uvicorn[standard]>=0.29", "websockets>=12", "numpy>=1.24", "greenlet>=3.0", "matplotlib>=3.7"] [project.scripts] kernbench = "kernbench.cli.main:main" diff --git a/src/kernbench/ccl/algorithms/intercube_allreduce.py b/src/kernbench/ccl/algorithms/intercube_allreduce.py index 32be7cd..e68e598 100644 --- a/src/kernbench/ccl/algorithms/intercube_allreduce.py +++ b/src/kernbench/ccl/algorithms/intercube_allreduce.py @@ -24,9 +24,7 @@ TOPO_NAME_TO_KIND = { } -def kernel_args(world_size: int, n_elem: int) -> tuple: - cube_w = 4 - cube_h = 4 +def kernel_args(world_size: int, n_elem: int, *, cube_w: int = 4, cube_h: int = 4) -> tuple: return (n_elem, cube_w, cube_h, world_size) @@ -127,61 +125,79 @@ def allreduce_intercube_multidevice( row = cube_id // cube_w col = cube_id % cube_w nbytes = n_elem * 2 + single_cube = (cube_w == 1 and cube_h == 1) pe_addr = t_ptr + cube_id * nbytes acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16") - # ── Phase 1: row reduce W → E ── - if col == 0: - tl.send(dir="E", src=acc) - elif col < cube_w - 1: - recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16") - acc = acc + recv - tl.send(dir="E", src=acc) + if single_cube: + # ── Single-cube mode: skip intra-SIP reduce, go directly to + # inter-SIP exchange (TP use case: one cube per rank). ── + if n_sips > 1: + if sip_topo_kind == SIP_TOPO_RING: + acc = _inter_sip_ring(acc, n_sips, n_elem, tl) + elif sip_topo_kind == SIP_TOPO_TORUS: + acc = _inter_sip_torus_2d( + acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl) + elif sip_topo_kind == SIP_TOPO_MESH: + acc = _inter_sip_mesh_2d( + acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl) else: - recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16") - acc = acc + recv + # ── Multi-cube mode: full mesh reduce + inter-SIP + broadcast ── - # ── Phase 2: col reduce N → S on rightmost column ── - if col == cube_w - 1: - if row == 0: - tl.send(dir="S", src=acc) - elif row < cube_h - 1: - recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16") + # Phase 1: row reduce W → E + if col == 0: + tl.send(dir="E", src=acc) + elif col < cube_w - 1: + recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16") acc = acc + recv - tl.send(dir="S", src=acc) + tl.send(dir="E", src=acc) else: - recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16") + recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16") acc = acc + recv - # ── Phase 3: inter-SIP exchange on root cube ── - root_cube = (cube_h - 1) * cube_w + (cube_w - 1) - if cube_id == root_cube and n_sips > 1: - if sip_topo_kind == SIP_TOPO_RING: - acc = _inter_sip_ring(acc, n_sips, n_elem, tl) - elif sip_topo_kind == SIP_TOPO_TORUS: - acc = _inter_sip_torus_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl) - elif sip_topo_kind == SIP_TOPO_MESH: - acc = _inter_sip_mesh_2d(acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl) + # Phase 2: col reduce N → S on rightmost column + if col == cube_w - 1: + if row == 0: + tl.send(dir="S", src=acc) + elif row < cube_h - 1: + recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16") + acc = acc + recv + tl.send(dir="S", src=acc) + else: + recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16") + acc = acc + recv - # ── Phase 4: col broadcast S → N on rightmost column ── - if col == cube_w - 1: - if row == cube_h - 1: - tl.send(dir="N", src=acc) - elif row > 0: - acc = tl.recv(dir="S", shape=(n_elem,), dtype="f16") - tl.send(dir="N", src=acc) + # Phase 3: inter-SIP exchange on root cube + root_cube = (cube_h - 1) * cube_w + (cube_w - 1) + if cube_id == root_cube and n_sips > 1: + if sip_topo_kind == SIP_TOPO_RING: + acc = _inter_sip_ring(acc, n_sips, n_elem, tl) + elif sip_topo_kind == SIP_TOPO_TORUS: + acc = _inter_sip_torus_2d( + acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl) + elif sip_topo_kind == SIP_TOPO_MESH: + acc = _inter_sip_mesh_2d( + acc, sip_rank, sip_topo_w, sip_topo_h, n_elem, tl) + + # Phase 4: col broadcast S → N on rightmost column + if col == cube_w - 1: + if row == cube_h - 1: + tl.send(dir="N", src=acc) + elif row > 0: + acc = tl.recv(dir="S", shape=(n_elem,), dtype="f16") + tl.send(dir="N", src=acc) + else: + acc = tl.recv(dir="S", shape=(n_elem,), dtype="f16") + + # Phase 5: row broadcast E → W + if col == cube_w - 1: + tl.send(dir="W", src=acc) + elif col > 0: + acc = tl.recv(dir="E", shape=(n_elem,), dtype="f16") + tl.send(dir="W", src=acc) else: - acc = tl.recv(dir="S", shape=(n_elem,), dtype="f16") - - # ── Phase 5: row broadcast E → W ── - if col == cube_w - 1: - tl.send(dir="W", src=acc) - elif col > 0: - acc = tl.recv(dir="E", shape=(n_elem,), dtype="f16") - tl.send(dir="W", src=acc) - else: - acc = tl.recv(dir="E", shape=(n_elem,), dtype="f16") + acc = tl.recv(dir="E", shape=(n_elem,), dtype="f16") tl.store(pe_addr, acc) diff --git a/src/kernbench/runtime_api/distributed.py b/src/kernbench/runtime_api/distributed.py index f87a268..a56086f 100644 --- a/src/kernbench/runtime_api/distributed.py +++ b/src/kernbench/runtime_api/distributed.py @@ -113,7 +113,18 @@ class AhbmCCLBackend: ) n_elem = shards[0].nbytes // tensor.itemsize kernel_fn = self._algo_module.kernel - kernel_args = self._algo_module.kernel_args(self._world_size, n_elem) + # Derive effective cube dims from tensor's actual shard placement + # (may differ from topology mesh when TP uses fewer cubes). + sip0_cubes = sorted({s.cube for s in shards if s.sip == shards[0].sip}) + eff_n_cubes = len(sip0_cubes) if sip0_cubes else 1 + if eff_n_cubes == 1: + eff_cube_w, eff_cube_h = 1, 1 + else: + eff_cube_w, eff_cube_h = self._cube_w, self._cube_h + kernel_args = self._algo_module.kernel_args( + self._world_size, n_elem, + cube_w=eff_cube_w, cube_h=eff_cube_h, + ) # Resolve sip_rank from the current greenlet's bound rank from greenlet import getcurrent as _gc diff --git a/tests/allreduce_latency_plots/mesh_2d_no_wrap.png b/tests/allreduce_latency_plots/mesh_2d_no_wrap.png index 6a4d0ca..fe83aaf 100644 Binary files a/tests/allreduce_latency_plots/mesh_2d_no_wrap.png and b/tests/allreduce_latency_plots/mesh_2d_no_wrap.png differ diff --git a/tests/allreduce_latency_plots/overview.png b/tests/allreduce_latency_plots/overview.png index 0622007..172333c 100644 Binary files a/tests/allreduce_latency_plots/overview.png and b/tests/allreduce_latency_plots/overview.png differ diff --git a/tests/allreduce_latency_plots/ring_1d.png b/tests/allreduce_latency_plots/ring_1d.png index beb73fa..9cf75cb 100644 Binary files a/tests/allreduce_latency_plots/ring_1d.png and b/tests/allreduce_latency_plots/ring_1d.png differ diff --git a/tests/allreduce_latency_plots/summary.csv b/tests/allreduce_latency_plots/summary.csv index d40f782..4cfdd0d 100644 --- a/tests/allreduce_latency_plots/summary.csv +++ b/tests/allreduce_latency_plots/summary.csv @@ -1,37 +1,37 @@ -algorithm,sip_topology,n_sips,n_elem,bytes_per_pe,bytes_per_sip,latency_ns -intercube_allreduce,mesh_2d_no_wrap,6,8,16,256,3508.4249999999993 -intercube_allreduce,mesh_2d_no_wrap,6,32,64,1024,3515.55 -intercube_allreduce,mesh_2d_no_wrap,6,64,128,2048,3525.0499999999975 -intercube_allreduce,mesh_2d_no_wrap,6,128,256,4096,3544.049999999992 -intercube_allreduce,mesh_2d_no_wrap,6,512,1024,16384,3667.049999999992 -intercube_allreduce,mesh_2d_no_wrap,6,1024,2048,32768,3837.049999999992 -intercube_allreduce,mesh_2d_no_wrap,6,2048,4096,65536,4177.049999999992 -intercube_allreduce,mesh_2d_no_wrap,6,4096,8192,131072,4857.049999999959 -intercube_allreduce,mesh_2d_no_wrap,6,8192,16384,262144,6217.049999999945 -intercube_allreduce,mesh_2d_no_wrap,6,16384,32768,524288,8937.049999999937 -intercube_allreduce,mesh_2d_no_wrap,6,32768,65536,1048576,14377.049999999872 -intercube_allreduce,mesh_2d_no_wrap,6,49152,98304,1572864,19817.049999999872 -intercube_allreduce,ring_1d,6,8,16,256,3073.1299999999937 -intercube_allreduce,ring_1d,6,32,64,1024,3079.8799999999947 -intercube_allreduce,ring_1d,6,64,128,2048,3088.879999999992 -intercube_allreduce,ring_1d,6,128,256,4096,3106.8799999999865 -intercube_allreduce,ring_1d,6,512,1024,16384,3225.8799999999865 -intercube_allreduce,ring_1d,6,1024,2048,32768,3391.8799999999865 -intercube_allreduce,ring_1d,6,2048,4096,65536,3723.8799999999865 -intercube_allreduce,ring_1d,6,4096,8192,131072,4387.879999999965 -intercube_allreduce,ring_1d,6,8192,16384,262144,5715.879999999957 -intercube_allreduce,ring_1d,6,16384,32768,524288,8371.879999999932 -intercube_allreduce,ring_1d,6,32768,65536,1048576,13683.879999999903 -intercube_allreduce,ring_1d,6,49152,98304,1572864,18995.879999999917 -intercube_allreduce,torus_2d,6,8,16,256,2190.4799999999923 -intercube_allreduce,torus_2d,6,32,64,1024,2196.479999999993 -intercube_allreduce,torus_2d,6,64,128,2048,2204.4799999999905 -intercube_allreduce,torus_2d,6,128,256,4096,2220.479999999985 -intercube_allreduce,torus_2d,6,512,1024,16384,2325.479999999985 -intercube_allreduce,torus_2d,6,1024,2048,32768,2471.479999999985 -intercube_allreduce,torus_2d,6,2048,4096,65536,2763.479999999985 -intercube_allreduce,torus_2d,6,4096,8192,131072,3347.4799999999777 -intercube_allreduce,torus_2d,6,8192,16384,262144,4515.4799999999705 -intercube_allreduce,torus_2d,6,16384,32768,524288,6851.479999999952 -intercube_allreduce,torus_2d,6,32768,65536,1048576,11523.479999999923 -intercube_allreduce,torus_2d,6,49152,98304,1572864,16195.479999999952 +algorithm,sip_topology,n_sips,n_elem,bytes_per_pe,bytes_per_sip,latency_ns +intercube_allreduce,mesh_2d_no_wrap,6,8,16,256,3508.4249999999993 +intercube_allreduce,mesh_2d_no_wrap,6,32,64,1024,3515.55 +intercube_allreduce,mesh_2d_no_wrap,6,64,128,2048,3525.0499999999975 +intercube_allreduce,mesh_2d_no_wrap,6,128,256,4096,3544.049999999992 +intercube_allreduce,mesh_2d_no_wrap,6,512,1024,16384,3667.049999999992 +intercube_allreduce,mesh_2d_no_wrap,6,1024,2048,32768,3837.049999999992 +intercube_allreduce,mesh_2d_no_wrap,6,2048,4096,65536,4177.049999999992 +intercube_allreduce,mesh_2d_no_wrap,6,4096,8192,131072,4857.049999999959 +intercube_allreduce,mesh_2d_no_wrap,6,8192,16384,262144,6217.049999999945 +intercube_allreduce,mesh_2d_no_wrap,6,16384,32768,524288,8937.049999999937 +intercube_allreduce,mesh_2d_no_wrap,6,32768,65536,1048576,14377.049999999872 +intercube_allreduce,mesh_2d_no_wrap,6,49152,98304,1572864,19817.049999999872 +intercube_allreduce,ring_1d,6,8,16,256,3073.1299999999937 +intercube_allreduce,ring_1d,6,32,64,1024,3079.8799999999947 +intercube_allreduce,ring_1d,6,64,128,2048,3088.879999999992 +intercube_allreduce,ring_1d,6,128,256,4096,3106.8799999999865 +intercube_allreduce,ring_1d,6,512,1024,16384,3225.8799999999865 +intercube_allreduce,ring_1d,6,1024,2048,32768,3391.8799999999865 +intercube_allreduce,ring_1d,6,2048,4096,65536,3723.8799999999865 +intercube_allreduce,ring_1d,6,4096,8192,131072,4387.879999999965 +intercube_allreduce,ring_1d,6,8192,16384,262144,5715.879999999957 +intercube_allreduce,ring_1d,6,16384,32768,524288,8371.879999999932 +intercube_allreduce,ring_1d,6,32768,65536,1048576,13683.879999999903 +intercube_allreduce,ring_1d,6,49152,98304,1572864,18995.879999999917 +intercube_allreduce,torus_2d,6,8,16,256,2190.4799999999923 +intercube_allreduce,torus_2d,6,32,64,1024,2196.479999999993 +intercube_allreduce,torus_2d,6,64,128,2048,2204.4799999999905 +intercube_allreduce,torus_2d,6,128,256,4096,2220.479999999985 +intercube_allreduce,torus_2d,6,512,1024,16384,2325.479999999985 +intercube_allreduce,torus_2d,6,1024,2048,32768,2471.479999999985 +intercube_allreduce,torus_2d,6,2048,4096,65536,2763.479999999985 +intercube_allreduce,torus_2d,6,4096,8192,131072,3347.4799999999777 +intercube_allreduce,torus_2d,6,8192,16384,262144,4515.4799999999705 +intercube_allreduce,torus_2d,6,16384,32768,524288,6851.479999999952 +intercube_allreduce,torus_2d,6,32768,65536,1048576,11523.479999999923 +intercube_allreduce,torus_2d,6,49152,98304,1572864,16195.479999999952 diff --git a/tests/allreduce_latency_plots/topology.png b/tests/allreduce_latency_plots/topology.png index 40e8719..1990768 100644 Binary files a/tests/allreduce_latency_plots/topology.png and b/tests/allreduce_latency_plots/topology.png differ diff --git a/tests/allreduce_latency_plots/torus_2d.png b/tests/allreduce_latency_plots/torus_2d.png index ce4b502..d689f24 100644 Binary files a/tests/allreduce_latency_plots/torus_2d.png and b/tests/allreduce_latency_plots/torus_2d.png differ diff --git a/tests/pe2pe_latency_plots/h1_intra_horizontal.png b/tests/pe2pe_latency_plots/h1_intra_horizontal.png index 23a4db0..22f7eb4 100644 Binary files a/tests/pe2pe_latency_plots/h1_intra_horizontal.png and b/tests/pe2pe_latency_plots/h1_intra_horizontal.png differ diff --git a/tests/pe2pe_latency_plots/h2_intra_vertical.png b/tests/pe2pe_latency_plots/h2_intra_vertical.png index a7af541..6ed9e58 100644 Binary files a/tests/pe2pe_latency_plots/h2_intra_vertical.png and b/tests/pe2pe_latency_plots/h2_intra_vertical.png differ diff --git a/tests/pe2pe_latency_plots/h3_inter_cube_horizontal.png b/tests/pe2pe_latency_plots/h3_inter_cube_horizontal.png index 94b9eef..99278a8 100644 Binary files a/tests/pe2pe_latency_plots/h3_inter_cube_horizontal.png and b/tests/pe2pe_latency_plots/h3_inter_cube_horizontal.png differ diff --git a/tests/pe2pe_latency_plots/h4_inter_cube_vertical.png b/tests/pe2pe_latency_plots/h4_inter_cube_vertical.png index 3f685da..0a89ee5 100644 Binary files a/tests/pe2pe_latency_plots/h4_inter_cube_vertical.png and b/tests/pe2pe_latency_plots/h4_inter_cube_vertical.png differ diff --git a/tests/pe2pe_latency_plots/overview.png b/tests/pe2pe_latency_plots/overview.png index 8914ae7..3aba2ad 100644 Binary files a/tests/pe2pe_latency_plots/overview.png and b/tests/pe2pe_latency_plots/overview.png differ diff --git a/tests/pe2pe_latency_plots/summary.csv b/tests/pe2pe_latency_plots/summary.csv index 7362353..03bb499 100644 --- a/tests/pe2pe_latency_plots/summary.csv +++ b/tests/pe2pe_latency_plots/summary.csv @@ -79,13 +79,3 @@ h4_inter_cube_vertical,Inter-cube vertical (cube0 to cube4),8192,ipcq,181.659999 h4_inter_cube_vertical,Inter-cube vertical (cube0 to cube4),8192,raw,183.04000000000087 h4_inter_cube_vertical,Inter-cube vertical (cube0 to cube4),10240,ipcq,205.65999999999985 h4_inter_cube_vertical,Inter-cube vertical (cube0 to cube4),10240,raw,207.04000000000087 -h5_inter_sip,"Inter-SIP (sip0 to sip1, same cube/pe)",128,ipcq,6.015000000003056 -h5_inter_sip,"Inter-SIP (sip0 to sip1, same cube/pe)",256,ipcq,6.515000000003056 -h5_inter_sip,"Inter-SIP (sip0 to sip1, same cube/pe)",384,ipcq,7.015000000003056 -h5_inter_sip,"Inter-SIP (sip0 to sip1, same cube/pe)",512,ipcq,7.515000000003056 -h5_inter_sip,"Inter-SIP (sip0 to sip1, same cube/pe)",768,ipcq,8.515000000003056 -h5_inter_sip,"Inter-SIP (sip0 to sip1, same cube/pe)",1024,ipcq,9.515000000003056 -h5_inter_sip,"Inter-SIP (sip0 to sip1, same cube/pe)",2048,ipcq,13.515000000003056 -h5_inter_sip,"Inter-SIP (sip0 to sip1, same cube/pe)",4096,ipcq,21.515000000003056 -h5_inter_sip,"Inter-SIP (sip0 to sip1, same cube/pe)",8192,ipcq,37.51499999999214 -h5_inter_sip,"Inter-SIP (sip0 to sip1, same cube/pe)",10240,ipcq,45.51499999999214