types: narrow BenchResult.engine to GraphEngine, cast topology in engine_factory
Replace BenchResult.engine: object | None with GraphEngine | None via TYPE_CHECKING import (avoids circular import at runtime). Cast the topology argument to TopologyGraph at the GraphEngine call site for the duck-typed engine_factory. Fixes Pylance reportAttributeAccessIssue warnings on op_log and topology arg. Type annotations only; no runtime behavior change. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import sys
|
import sys
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
from kernbench.benches.registry import list_all, resolve
|
from kernbench.benches.registry import list_all, resolve
|
||||||
from kernbench.cli.report import format_report
|
from kernbench.cli.report import format_report
|
||||||
@@ -9,6 +10,7 @@ from kernbench.runtime_api.bench_runner import run_bench
|
|||||||
from kernbench.runtime_api.types import DeviceSelector, resolve_device
|
from kernbench.runtime_api.types import DeviceSelector, resolve_device
|
||||||
from kernbench.sim_engine.engine import GraphEngine
|
from kernbench.sim_engine.engine import GraphEngine
|
||||||
from kernbench.topology.builder import resolve_topology
|
from kernbench.topology.builder import resolve_topology
|
||||||
|
from kernbench.topology.types import TopologyGraph
|
||||||
|
|
||||||
|
|
||||||
def build_parser() -> argparse.ArgumentParser:
|
def build_parser() -> argparse.ArgumentParser:
|
||||||
@@ -49,7 +51,7 @@ def build_parser() -> argparse.ArgumentParser:
|
|||||||
def engine_factory(
|
def engine_factory(
|
||||||
topology: object, device: DeviceSelector, *, enable_data: bool = False,
|
topology: object, device: DeviceSelector, *, enable_data: bool = False,
|
||||||
) -> SimEngine:
|
) -> SimEngine:
|
||||||
topo_obj = getattr(topology, "topology_obj", topology)
|
topo_obj = cast(TopologyGraph, getattr(topology, "topology_obj", topology))
|
||||||
return GraphEngine(topo_obj, enable_data=enable_data)
|
return GraphEngine(topo_obj, enable_data=enable_data)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,9 +2,13 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import re
|
import re
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from kernbench.common.types import Completion, Trace
|
from kernbench.common.types import Completion, Trace
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from kernbench.sim_engine.engine import GraphEngine
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class BenchResult:
|
class BenchResult:
|
||||||
@@ -12,7 +16,7 @@ class BenchResult:
|
|||||||
correlation_id: str
|
correlation_id: str
|
||||||
trace: Trace | None = None
|
trace: Trace | None = None
|
||||||
traces: list[dict] | None = None
|
traces: list[dict] | None = None
|
||||||
engine: object | None = None # GraphEngine ref for Phase 2 data access
|
engine: GraphEngine | None = None
|
||||||
|
|
||||||
def summary_text(self) -> str:
|
def summary_text(self) -> str:
|
||||||
if self.completion.ok:
|
if self.completion.ok:
|
||||||
|
|||||||
Reference in New Issue
Block a user