38 lines
1.0 KiB
Python
38 lines
1.0 KiB
Python
from __future__ import annotations
|
|
|
|
import importlib
|
|
from collections.abc import Callable
|
|
from typing import Any
|
|
|
|
from kernbench.runtime_api.context import RuntimeContext
|
|
|
|
# Type of a bench entry point: a callable that takes the RuntimeContext and
# may return any value (the result type is not constrained here).
BenchFn = Callable[[RuntimeContext], Any]
|
|
|
|
|
|
def resolve_bench(bench_id: str) -> BenchFn:
    """
    Resolve a bench id into a callable bench function.

    Expected layout (repo root):

        benches/<bench_id>.py

            def run(ctx: RuntimeContext) -> Any

    The id has surrounding whitespace stripped and is then imported as the
    module ``benches.<bench_id>``. A dotted id therefore addresses a nested
    module.

    Args:
        bench_id: Name of the bench module inside the ``benches`` package.

    Returns:
        The module's ``run`` attribute.

    Raises:
        ValueError: If the id is empty or whitespace-only; if no
            ``benches.<bench_id>`` module exists; if the bench module exists
            but one of its own imports is missing; or if the module has no
            callable ``run``.
    """
    bench_id = bench_id.strip()
    if not bench_id:
        raise ValueError("Bench id is empty.")

    module_path = f"benches.{bench_id}"

    try:
        mod = importlib.import_module(module_path)
    except ModuleNotFoundError as e:
        # ModuleNotFoundError is also raised when the bench module exists but
        # one of its own imports is missing. Treat it as an unknown bench only
        # when the missing module is the bench module itself or one of its
        # parent packages (e.g. "benches").
        missing = e.name
        if missing is not None and (
            module_path == missing or module_path.startswith(missing + ".")
        ):
            raise ValueError(f"Unknown bench '{bench_id}'. Expected module {module_path}.py") from e
        # Otherwise the bench module was found but could not be imported.
        # Keep raising ValueError for caller compatibility, with an accurate
        # message.
        raise ValueError(f"Bench '{bench_id}' ({module_path}) failed to import: {e}") from e

    run_fn = getattr(mod, "run", None)
    if run_fn is None:
        raise ValueError(f"Bench module {module_path} must define a 'run(ctx)' function.")
    if not callable(run_fn):
        raise ValueError(f"'run' in {module_path} is not callable.")

    return run_fn
|