"""Tests for CCL builtin topology generators (ADR-0023 D11).""" import pytest from kernbench.ccl.topologies import ( mesh_2d, none, resolve_topology, ring_1d, ring_1d_unidir, tree_binary, ) # ── ring_1d ────────────────────────────────────────────────────────── def test_ring_1d_4_ranks(): assert ring_1d(0, 4) == {"E": 1, "W": 3} assert ring_1d(1, 4) == {"E": 2, "W": 0} assert ring_1d(2, 4) == {"E": 3, "W": 1} assert ring_1d(3, 4) == {"E": 0, "W": 2} def test_ring_1d_2_ranks(): assert ring_1d(0, 2) == {"E": 1, "W": 1} assert ring_1d(1, 2) == {"E": 0, "W": 0} # ── ring_1d_unidir ─────────────────────────────────────────────────── def test_ring_1d_unidir(): assert ring_1d_unidir(0, 4) == {"E": 1} assert ring_1d_unidir(3, 4) == {"E": 0} # ── mesh_2d ────────────────────────────────────────────────────────── def test_mesh_2d_2x2(): # 2x2 mesh: # 0 1 # 2 3 assert mesh_2d(0, 4) == {"N": 2, "S": 2, "E": 1, "W": 1} assert mesh_2d(1, 4) == {"N": 3, "S": 3, "E": 0, "W": 0} assert mesh_2d(2, 4) == {"N": 0, "S": 0, "E": 3, "W": 3} assert mesh_2d(3, 4) == {"N": 1, "S": 1, "E": 2, "W": 2} def test_mesh_2d_4x4(): # 4x4 mesh: rank = r*4 + c n = mesh_2d(5, 16) # r=1, c=1 assert n["N"] == 1 # ((1-1)%4)*4 + 1 assert n["S"] == 9 # ((1+1)%4)*4 + 1 assert n["W"] == 4 # 1*4 + (1-1)%4 assert n["E"] == 6 # 1*4 + (1+1)%4 def test_mesh_2d_non_square_raises(): with pytest.raises(ValueError): mesh_2d(0, 5) # ── tree_binary ────────────────────────────────────────────────────── def test_tree_binary_root(): n = tree_binary(0, 7) assert "parent" not in n assert n["child_left"] == 1 assert n["child_right"] == 2 def test_tree_binary_internal(): n = tree_binary(1, 7) assert n["parent"] == 0 assert n["child_left"] == 3 assert n["child_right"] == 4 def test_tree_binary_leaf(): n = tree_binary(6, 7) assert n["parent"] == 2 assert "child_left" not in n assert "child_right" not in n # ── none ───────────────────────────────────────────────────────────── def test_none_returns_empty(): assert none(0, 4) == {} assert none(3, 7) == {} # ── resolve_topology ───────────────────────────────────────────────── def test_resolve_topology_builtin(): fn = resolve_topology("ring_1d") assert fn(0, 4) == {"E": 1, "W": 3} def test_resolve_topology_unknown_raises(): with pytest.raises(ValueError): resolve_topology("nonsense") def test_resolve_topology_with_neighbors_override_pattern_a(): """Algorithm module with neighbors() that mutates builtin map.""" class FakeModule: @staticmethod def neighbors(rank, world_size, neighbor_map): if rank % 2 == 1: neighbor_map.pop("W", None) return neighbor_map fn = resolve_topology("ring_1d", algo_module=FakeModule) assert fn(0, 4) == {"E": 1, "W": 3} assert fn(1, 4) == {"E": 2} # W removed def test_resolve_topology_with_neighbors_override_pattern_b(): """Algorithm module with neighbors() that returns brand-new dict.""" class FakeModule: @staticmethod def neighbors(rank, world_size, neighbor_map): return {"E": (rank + 2) % world_size} fn = resolve_topology("ring_1d", algo_module=FakeModule) assert fn(0, 4) == {"E": 2} assert fn(3, 4) == {"E": 1} def test_resolve_topology_with_neighbors_override_pattern_c_none(): """Algorithm module's neighbors() returns None → builtin used as-is.""" class FakeModule: @staticmethod def neighbors(rank, world_size, neighbor_map): return None fn = resolve_topology("ring_1d", algo_module=FakeModule) assert fn(0, 4) == {"E": 1, "W": 3} def test_resolve_topology_none_with_neighbors_override(): """topology=none + custom neighbors() builds from scratch.""" class FakeModule: @staticmethod def neighbors(rank, world_size, neighbor_map): assert neighbor_map == {} # builtin returned empty return {"E": (rank + 1) % world_size} fn = resolve_topology("none", algo_module=FakeModule) assert fn(0, 4) == {"E": 1} def test_resolve_topology_module_without_neighbors(): """Algorithm module without neighbors() function works normally.""" class FakeModule: pass # no neighbors attribute fn = resolve_topology("ring_1d", algo_module=FakeModule) assert fn(0, 4) == {"E": 1, "W": 3}