diff --git a/src/kernbench/cli/main.py b/src/kernbench/cli/main.py index 958a50f..04dc594 100644 --- a/src/kernbench/cli/main.py +++ b/src/kernbench/cli/main.py @@ -1,5 +1,6 @@ import argparse import sys +from typing import cast from kernbench.benches.registry import list_all, resolve from kernbench.cli.report import format_report @@ -9,6 +10,7 @@ from kernbench.runtime_api.bench_runner import run_bench from kernbench.runtime_api.types import DeviceSelector, resolve_device from kernbench.sim_engine.engine import GraphEngine from kernbench.topology.builder import resolve_topology +from kernbench.topology.types import TopologyGraph def build_parser() -> argparse.ArgumentParser: @@ -49,7 +51,7 @@ def build_parser() -> argparse.ArgumentParser: def engine_factory( topology: object, device: DeviceSelector, *, enable_data: bool = False, ) -> SimEngine: - topo_obj = getattr(topology, "topology_obj", topology) + topo_obj = cast(TopologyGraph, getattr(topology, "topology_obj", topology)) return GraphEngine(topo_obj, enable_data=enable_data) diff --git a/src/kernbench/runtime_api/types.py b/src/kernbench/runtime_api/types.py index 3654a82..ec0972a 100644 --- a/src/kernbench/runtime_api/types.py +++ b/src/kernbench/runtime_api/types.py @@ -2,9 +2,13 @@ from __future__ import annotations import re from dataclasses import dataclass +from typing import TYPE_CHECKING from kernbench.common.types import Completion, Trace +if TYPE_CHECKING: + from kernbench.sim_engine.engine import GraphEngine + @dataclass(frozen=True) class BenchResult: @@ -12,7 +16,7 @@ class BenchResult: correlation_id: str trace: Trace | None = None traces: list[dict] | None = None - engine: object | None = None # GraphEngine ref for Phase 2 data access + engine: GraphEngine | None = None def summary_text(self) -> str: if self.completion.ok: