dc3fb02aed
- CLI: --verify-data flag enables Phase 2 data verification (ADR-0020) - Tensor.data: returns actual numpy values (verify-data) or zeros placeholder - Tensor.__repr__: shows value summary or data=N/A (placeholder) - DataExecutor: ThreadPoolExecutor for same-timestamp parallel op execution - BenchResult.engine: exposes op_log/memory_store for Phase 2 access Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
101 lines
3.5 KiB
Python
101 lines
3.5 KiB
Python
import argparse
|
|
import sys
|
|
|
|
from benches.loader import resolve_bench
|
|
from kernbench.cli.probe import cmd_probe
|
|
from kernbench.cli.report import format_report
|
|
from kernbench.common.types import SimEngine
|
|
from kernbench.runtime_api.bench_runner import run_bench
|
|
from kernbench.runtime_api.types import DeviceSelector, resolve_device
|
|
from kernbench.sim_engine.engine import GraphEngine
|
|
from kernbench.topology.builder import resolve_topology
|
|
|
|
|
|
def build_parser() -> argparse.ArgumentParser:
    """Build the ``kernbench`` command-line parser.

    Three subcommands are registered, in this order: ``run``, ``probe`` and
    ``web``.  Choosing one is mandatory.  Each subcommand stores its handler
    function under ``_handler`` in the parsed namespace so the caller can
    dispatch on it.

    Returns:
        The fully configured top-level parser.
    """
    parser = argparse.ArgumentParser(prog="kernbench")
    commands = parser.add_subparsers(dest="cmd", required=True)

    # ``run``: execute one benchmark against one topology.
    run_cmd = commands.add_parser("run", help="Run a benchmark")
    run_cmd.add_argument("--topology", required=True)
    run_cmd.add_argument("--bench", required=True)
    run_cmd.add_argument(
        "--device",
        default=None,
        help="Target device: 'all' or 'sip:<N>' (default: all)",
    )
    run_cmd.add_argument(
        "--verify-data",
        action="store_true",
        default=False,
        help="Enable Phase 2 data verification (ADR-0020)",
    )
    run_cmd.set_defaults(_handler=cmd_run)

    # ``probe``: predefined traffic-pattern measurements.
    probe_cmd = commands.add_parser(
        "probe",
        help="Probe latency and BW for predefined traffic patterns",
    )
    probe_cmd.add_argument("--topology", required=True)
    probe_cmd.add_argument(
        "--case",
        default="all",
        help="Case name or 'all' (default: all)",
    )
    probe_cmd.set_defaults(_handler=cmd_probe)

    # ``web``: browser-based topology viewer.
    web_cmd = commands.add_parser("web", help="Launch topology viewer in browser")
    web_cmd.add_argument(
        "--port",
        type=int,
        default=8765,
        help="HTTP port (default: 8765)",
    )
    web_cmd.add_argument(
        "--no-open",
        action="store_true",
        help="Don't auto-open browser",
    )
    web_cmd.set_defaults(_handler=cmd_web)

    return parser
|
|
|
|
|
|
def engine_factory(
    topology: object, device: DeviceSelector, *, enable_data: bool = False,
) -> SimEngine:
    """Create the simulation engine for *topology*.

    Args:
        topology: Either an object exposing a ``topology_obj`` attribute, in
            which case that attribute is used, or the topology itself.
        device: Accepted as part of the factory signature; this function
            does not read it.
        enable_data: Forwarded unchanged to ``GraphEngine``.

    Returns:
        A new ``GraphEngine`` built on the unwrapped topology.
    """
    # Fall back to the argument itself when there is no ``topology_obj``.
    return GraphEngine(
        getattr(topology, "topology_obj", topology),
        enable_data=enable_data,
    )
|
|
|
|
|
|
def cmd_web(args) -> int:
    """Handle ``kernbench web``: start the topology viewer server.

    Args:
        args: Parsed namespace; ``port`` and ``no_open`` are read.

    Returns:
        Always 0, once ``serve`` returns.
    """
    # Imported here rather than at module level, so the web package is only
    # loaded when this subcommand actually runs.
    from kernbench.web.server import serve

    open_browser = not args.no_open
    serve(port=args.port, open_browser=open_browser)
    return 0
|
|
|
|
|
|
def cmd_run(args) -> int:
    """Handle ``kernbench run``: run one benchmark and print its results.

    The topology, benchmark and device are resolved from *args*, the
    benchmark is run through ``run_bench``, and the trace report (when
    traces exist) plus the summary text are printed.  When ``--verify-data``
    was given and the result exposes an engine, Phase 2 data execution
    (ADR-0020) is then run over the engine's recorded op log.

    Args:
        args: Parsed namespace; ``topology``, ``bench`` and ``device`` are
            read, and ``verify_data`` if present.

    Returns:
        0 when ``result.completion.ok`` is truthy, otherwise 1.  The outcome
        of Phase 2 does not influence the return value.
    """
    print("> Running benchmark with:", args)

    topo = resolve_topology(args.topology)
    bench = resolve_bench(args.bench)
    device = resolve_device(args.device)
    # getattr with a default tolerates namespaces that lack the attribute.
    verify_data = getattr(args, "verify_data", False)

    def _factory(topology, device):
        # Bind the CLI flag into the two-argument factory passed to run_bench.
        return engine_factory(topology, device, enable_data=verify_data)

    result = run_bench(topology=topo, bench_fn=bench, device=device, engine_factory=_factory)

    topo_obj = getattr(topo, "topology_obj", topo)
    spec = getattr(topo_obj, "spec", None)
    if result.traces:
        print(format_report(result.traces, title=args.bench, spec=spec))
    print(result.summary_text())

    # Phase 2: data execution (ADR-0020)
    if verify_data and result.engine is not None:
        from kernbench.sim_engine.data_executor import DataExecutor

        op_log = result.engine.op_log
        store = result.engine.memory_store
        if not op_log:
            print("[data] No op_log recorded — skipping Phase 2")
        elif store is None:
            # Report the real cause instead of claiming the op log is empty.
            print("[data] No memory_store available — skipping Phase 2")
        else:
            executor = DataExecutor(op_log, store)
            executor.run()
            n_gemm = sum(1 for r in op_log if r.op_kind == "gemm")
            n_math = sum(1 for r in op_log if r.op_kind == "math")
            print(f"[data] Phase 2 complete: {len(op_log)} ops ({n_gemm} gemm, {n_math} math)")

    return 0 if result.completion.ok else 1
|
|
|
|
|
|
def main(argv=None) -> int:
    """Parse the command line and dispatch to the chosen subcommand.

    Args:
        argv: Argument list handed to ``parse_args``; ``None`` lets argparse
            fall back to the process arguments.

    Returns:
        The selected handler's result, converted to ``int``.
    """
    namespace = build_parser().parse_args(argv)
    # Each subparser stored its handler under ``_handler`` via set_defaults.
    handler = namespace._handler
    status = handler(namespace)
    return int(status)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: use main()'s return value as the process exit status.
    raise SystemExit(main())
|