"""Tests for the SVG diagrams emitted from the repository's ``topology.yaml``."""

import xml.etree.ElementTree as ET
from pathlib import Path

from kernbench.topology.builder import load_topology
from kernbench.topology.visualizer import emit_diagrams

TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
VIEW_FILES = ["system_view.svg", "sip_view.svg", "cube_view.svg", "pe_view.svg"]


def _emit(tmp_path: Path) -> list[Path]:
    """Load the topology at ``TOPOLOGY_PATH`` and emit its diagrams into *tmp_path*.

    Returns whatever ``emit_diagrams`` returns (annotated as a list of paths).
    """
    graph = load_topology(TOPOLOGY_PATH)
    return emit_diagrams(graph, tmp_path)


def _emit_and_read(tmp_path: Path, name: str) -> str:
    """Emit all diagrams into *tmp_path* and return the text of the view *name*."""
    _emit(tmp_path)
    return (tmp_path / name).read_text()


def test_emit_creates_all_svg_files(tmp_path):
    """Every expected view file is reported, exists on disk and is non-empty."""
    created = _emit(tmp_path)
    assert len(created) == len(VIEW_FILES)
    for name in VIEW_FILES:
        target = tmp_path / name
        assert target.exists(), f"{name} was not created"
        assert target.stat().st_size > 0, f"{name} is empty"


def test_svg_output_is_deterministic(tmp_path):
    """Emitting the same graph twice yields identical file contents."""
    graph = load_topology(TOPOLOGY_PATH)

    def snapshot() -> dict[str, str]:
        # Re-emit into the same directory and capture every view's text.
        emit_diagrams(graph, tmp_path)
        return {name: (tmp_path / name).read_text() for name in VIEW_FILES}

    first = snapshot()
    second = snapshot()
    for name in VIEW_FILES:
        assert first[name] == second[name], f"{name} is not deterministic"


def test_cube_svg_contains_hbm_ctrl(tmp_path):
    """The cube view mentions the HBM controller label."""
    svg = _emit_and_read(tmp_path, "cube_view.svg")
    assert "HBM CTRL" in svg


def test_cube_svg_contains_ucie_ports(tmp_path):
    """The cube view mentions all four UCIe port labels."""
    svg = _emit_and_read(tmp_path, "cube_view.svg")
    for port in ("UCIe-N", "UCIe-S", "UCIe-W", "UCIe-E"):
        assert port in svg, f"{port} missing from cube_view.svg"


def test_cube_svg_contains_pe_nodes(tmp_path):
    """The cube view mentions labels PE0 through PE7."""
    svg = _emit_and_read(tmp_path, "cube_view.svg")
    for i in range(8):
        assert f"PE{i}" in svg, f"PE{i} missing from cube_view.svg"


def test_pe_svg_contains_all_components(tmp_path):
    """The PE view mentions every expected PE component label."""
    svg = _emit_and_read(tmp_path, "pe_view.svg")
    components = (
        "PE CPU",
        "PE SCHEDULER",
        "PE DMA",
        "PE GEMM",
        "PE MATH",
        "PE TCM",
    )
    for comp in components:
        assert comp in svg, f"{comp} missing from pe_view.svg"


def test_sip_svg_contains_cubes(tmp_path):
    """The SIP view mentions the cube labels at (0,0) and (3,3)."""
    svg = _emit_and_read(tmp_path, "sip_view.svg")
    assert "CUBE (0,0)" in svg
    assert "CUBE (3,3)" in svg


def test_system_svg_contains_switch_and_sips(tmp_path):
    """The system view mentions the fabric switch and SIP 0 / SIP 1."""
    svg = _emit_and_read(tmp_path, "system_view.svg")
    assert "Fabric Switch" in svg
    assert "SIP 0" in svg
    assert "SIP 1" in svg


def test_svg_is_valid_xml(tmp_path):
    """Every emitted view parses as well-formed XML with an ``svg`` root element.

    The previous check, ``svg.startswith("")``, is true for any string and so
    could never fail; this version actually parses each document.
    """
    _emit(tmp_path)
    for name in VIEW_FILES:
        # Parse bytes so that any XML encoding declaration is honoured by the
        # parser instead of depending on the platform's default text encoding.
        raw = (tmp_path / name).read_bytes()
        try:
            root = ET.fromstring(raw)
        except ET.ParseError as exc:
            raise AssertionError(f"{name} is not well-formed XML: {exc}") from exc
        # ElementTree spells namespaced tags as "{namespace-uri}local"; compare
        # only the local part so both namespaced and plain roots are accepted.
        local_name = root.tag.rpartition("}")[2]
        assert local_name == "svg", (
            f"{name} root element is {root.tag!r}, expected svg"
        )