types: narrow BenchResult.engine to GraphEngine, cast topology in engine_factory
Replace BenchResult.engine: object | None with GraphEngine | None via TYPE_CHECKING import (avoids circular import at runtime). Cast the topology argument to TopologyGraph at the GraphEngine call site for the duck-typed engine_factory. Fixes Pylance reportAttributeAccessIssue warnings on op_log and topology arg. Type annotations only; no runtime behavior change. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import argparse
|
||||
import sys
|
||||
from typing import cast
|
||||
|
||||
from kernbench.benches.registry import list_all, resolve
|
||||
from kernbench.cli.report import format_report
|
||||
@@ -9,6 +10,7 @@ from kernbench.runtime_api.bench_runner import run_bench
|
||||
from kernbench.runtime_api.types import DeviceSelector, resolve_device
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
from kernbench.topology.builder import resolve_topology
|
||||
from kernbench.topology.types import TopologyGraph
|
||||
|
||||
|
||||
def build_parser() -> argparse.ArgumentParser:
|
||||
@@ -49,7 +51,7 @@ def build_parser() -> argparse.ArgumentParser:
|
||||
def engine_factory(
|
||||
topology: object, device: DeviceSelector, *, enable_data: bool = False,
|
||||
) -> SimEngine:
|
||||
topo_obj = getattr(topology, "topology_obj", topology)
|
||||
topo_obj = cast(TopologyGraph, getattr(topology, "topology_obj", topology))
|
||||
return GraphEngine(topo_obj, enable_data=enable_data)
|
||||
|
||||
|
||||
|
||||
@@ -2,9 +2,13 @@ from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from kernbench.common.types import Completion, Trace
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BenchResult:
|
||||
@@ -12,7 +16,7 @@ class BenchResult:
|
||||
correlation_id: str
|
||||
trace: Trace | None = None
|
||||
traces: list[dict] | None = None
|
||||
engine: object | None = None # GraphEngine ref for Phase 2 data access
|
||||
engine: GraphEngine | None = None
|
||||
|
||||
def summary_text(self) -> str:
|
||||
if self.completion.ok:
|
||||
|
||||
Reference in New Issue
Block a user