benches: package as kernbench.benches, add @bench registry + list subcommand

Move benches/ -> src/kernbench/benches/ and src/kernbench/cli/probe.py ->
src/kernbench/probes/probe.py. Each bench self-registers via
@bench(name=..., description=...). kernbench list enumerates benches
with auto-assigned indices, and --bench accepts either a kebab-case name
or a numeric index. An audit at package-import time raises RuntimeError
if any non-underscore module lacks the decorator. ADR-0010 (EN + KO) is
updated to reflect the new resolver path, the list subcommand, and the
separation of the probes package.
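
A minimal standalone sketch of the register-then-resolve flow described above. It follows the shape of the registry added in this commit but is deliberately simplified (no name validation, no duplicate check, no import audit). `Spec`, `build_registry`, and the two toy bench functions are illustrative names, not the project's API.

```python
from collections.abc import Callable
from dataclasses import dataclass

BenchFn = Callable[..., object]


@dataclass(frozen=True)
class Spec:
    index: int
    name: str
    description: str
    run: BenchFn


# Decorated functions are collected here until indices are assigned.
_pending: list[tuple[str, str, BenchFn]] = []


def bench(*, name: str, description: str) -> Callable[[BenchFn], BenchFn]:
    """Record the decorated function under `name` and return it unchanged."""
    def deco(fn: BenchFn) -> BenchFn:
        _pending.append((name, description, fn))
        return fn
    return deco


def build_registry() -> dict[str, Spec]:
    """Assign 1-based indices in alphabetical order of bench name."""
    ordered = sorted(_pending, key=lambda t: t[0])
    return {n: Spec(i, n, d, f) for i, (n, d, f) in enumerate(ordered, start=1)}


def resolve(registry: dict[str, Spec], identifier: str) -> Spec:
    """Accept either a registered name or a numeric index given as a string."""
    ident = identifier.strip()
    if ident.isdigit():
        for spec in registry.values():
            if spec.index == int(ident):
                return spec
        raise ValueError(f"No bench with index {ident}.")
    if ident in registry:
        return registry[ident]
    raise ValueError(f"Unknown bench {ident!r}.")


# Toy benches (hypothetical bodies), registered out of alphabetical order.
@bench(name="qkv-gemm", description="toy QKV GEMM")
def run_qkv(torch):
    return "qkv"


@bench(name="gemm-single-pe", description="toy single-PE GEMM")
def run_gemm(torch):
    return "gemm"


registry = build_registry()
print(resolve(registry, "1").name)          # lookup by index
print(resolve(registry, "qkv-gemm").index)  # lookup by name
```

Registration order does not matter in this sketch: indices come from the sorted names, so the same identifier resolves to the same callable whether it is given as a name or as an index.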

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-20 14:42:10 -07:00
parent 168b0c89f0
commit 049e3d8bb3
28 changed files with 398 additions and 79 deletions
-2
@@ -1,2 +0,0 @@
def run(torch):
print("IPCQ all reduce kernel bench")
-40
@@ -1,40 +0,0 @@
from __future__ import annotations
import importlib
from collections.abc import Callable
from typing import Any
from kernbench.runtime_api.context import RuntimeContext
BenchFn = Callable[[RuntimeContext], Any]
def _load_module(bench_id: str):
bench_id = bench_id.strip()
if not bench_id:
raise ValueError("Bench id is empty.")
module_path = f"benches.{bench_id}"
try:
return importlib.import_module(module_path)
except ModuleNotFoundError as e:
raise ValueError(
f"Unknown bench '{bench_id}'. Expected module {module_path}.py"
) from e
def resolve_bench(bench_id: str) -> BenchFn:
"""Resolve a bench id into its ``run(torch)`` callable.
Expected layout (repo root):
benches/<bench_id>.py
def run(torch: RuntimeContext) -> Any
"""
mod = _load_module(bench_id)
run_fn = getattr(mod, "run", None)
if run_fn is None:
raise ValueError(
f"Bench module benches.{bench_id} must define 'run(torch)'."
)
if not callable(run_fn):
raise ValueError(f"'run' in benches.{bench_id} is not callable.")
return run_fn
@@ -6,10 +6,11 @@ Accepted
## Context ## Context
`kernbench` CLI는 시뮬레이터의 사용자 대면 진입점이다. 세 개의 서브명령을 `kernbench` CLI는 시뮬레이터의 사용자 대면 진입점이다. 네 개의 서브명령을
노출한다: 노출한다:
- `run` — 토폴로지에 대해 벤치마크를 실행한다. - `run` — 토폴로지에 대해 벤치마크를 실행한다.
- `list` — 등록된 벤치마크 목록을 출력한다.
- `probe` — 레이턴시 / 대역폭 측정을 위한 진단 유틸리티. - `probe` — 레이턴시 / 대역폭 측정을 위한 진단 유틸리티.
- `web` — 인터랙티브 토폴로지 뷰어. - `web` — 인터랙티브 토폴로지 뷰어.
@@ -33,8 +34,9 @@ Accepted
- `--topology <path>`: 토폴로지 YAML 파일 경로. `resolve_topology()` - `--topology <path>`: 토폴로지 YAML 파일 경로. `resolve_topology()`
통해 로드된다. 통해 로드된다.
- `--bench <name>`: 벤치마크 이름. `benches.loader.resolve_bench()` - `--bench <identifier>`: 벤치마크 식별자. `kernbench.benches.registry.resolve()`
통해 해석된다. 통해 해석되며, 등록된 kebab-case 이름(예: `gemm-single-pe`) 또는
`kernbench list` 의 숫자 인덱스를 모두 받는다.
선택 인자: 선택 인자:
@@ -61,7 +63,22 @@ Accepted
CLI는 여러 OS 프로세스나 독립된 시뮬레이션 실행을 생성하지 **않는다** CLI는 여러 OS 프로세스나 독립된 시뮬레이션 실행을 생성하지 **않는다**
병렬성은 단일 시뮬레이션 인스턴스 내부에서 일어난다. 병렬성은 단일 시뮬레이션 인스턴스 내부에서 일어난다.
### D4. `kernbench probe` — 레이턴시 / 대역폭 진단 유틸리티 ### D4. `kernbench list` — 등록된 벤치마크 목록 출력
인자 없음. 각 등록된 벤치의 자동 부여된 인덱스, 등록된 이름,
한 줄 설명을 출력한다.
벤치는 `@bench(name=..., description=...)` 데코레이터
(`kernbench.benches.registry`)를 통해 자기 자신을 등록한다.
`kernbench.benches/` 아래의 언더스코어로 시작하지 않는 모든 모듈은
반드시 최소 하나의 벤치를 등록해야 한다; 데코레이터가 누락되면
패키지 import 시점에 `RuntimeError`가 발생한다.
인덱스는 import 시점에 이름의 알파벳 순으로 부여된다. 인덱스는
`--bench` 의 축약 표기를 위한 CLI 편의 기능이며 안정적인 API가
아니다 — 알파벳 순으로 새 벤치가 끼면 이후 인덱스가 밀린다.
### D5. `kernbench probe` — 레이턴시 / 대역폭 진단 유틸리티
필수 인자: 필수 인자:
@@ -85,7 +102,7 @@ Probe는 추가로 단조성 불변식을 검증한다 — 예를 들어 local-H
레이턴시 / 대역폭 모델을 검증하기 위한 개발자 도구이다; 벤치마크가 레이턴시 / 대역폭 모델을 검증하기 위한 개발자 도구이다; 벤치마크가
아니다. 아니다.
### D5. `kernbench web` — 토폴로지 뷰어 ### D6. `kernbench web` — 토폴로지 뷰어
선택 인자: 선택 인자:
@@ -99,7 +116,7 @@ Probe는 추가로 단조성 불변식을 검증한다 — 예를 들어 local-H
- `kernbench web`은 인터랙티브이다 — 팬/줌, 컴포넌트 속성 호버, - `kernbench web`은 인터랙티브이다 — 팬/줌, 컴포넌트 속성 호버,
SIP / CUBE / PE 뷰 간 전환. SIP / CUBE / PE 뷰 간 전환.
### D6. runtime API와 시뮬레이션 엔진은 디바이스 스코프를 유지한다 ### D7. runtime API와 시뮬레이션 엔진은 디바이스 스코프를 유지한다
- runtime API 호출은 호출당 하나의 디바이스에서 동작한다. - runtime API 호출은 호출당 하나의 디바이스에서 동작한다.
- 시뮬레이션 엔진은 모든 요청을 결정론적으로 스케줄링한다. - 시뮬레이션 엔진은 모든 요청을 결정론적으로 스케줄링한다.
@@ -108,6 +125,9 @@ Probe는 추가로 단조성 불변식을 검증한다 — 예를 들어 local-H
이 불변식은 각 레이어를 독립적으로 테스트 가능하게 유지한다; 디바이스 이 불변식은 각 레이어를 독립적으로 테스트 가능하게 유지한다; 디바이스
열거와 다중 디바이스 팬아웃은 오직 CLI의 `run` 명령에만 존재한다(D3). 열거와 다중 디바이스 팬아웃은 오직 CLI의 `run` 명령에만 존재한다(D3).
`probe` 구현은 `kernbench.probes` 아래에 있다 (`kernbench.benches`
분리됨). 이는 probe가 등록된 벤치가 아니라 진단 유틸리티임을 반영한다.
## Consequences ## Consequences
- 벤치마크 작성자는 단일 디바이스 로직을 작성한다; 다중 디바이스 동작은 - 벤치마크 작성자는 단일 디바이스 로직을 작성한다; 다중 디바이스 동작은
@@ -7,9 +7,10 @@ Accepted
## Context ## Context
The `kernbench` CLI is the user-facing entry point of the simulator. It The `kernbench` CLI is the user-facing entry point of the simulator. It
exposes three subcommands: exposes four subcommands:
- `run` — execute a benchmark against a topology. - `run` — execute a benchmark against a topology.
- `list` — enumerate registered benches.
- `probe` — diagnostic utility for latency / BW measurement. - `probe` — diagnostic utility for latency / BW measurement.
- `web` — interactive topology viewer. - `web` — interactive topology viewer.
@@ -33,8 +34,10 @@ Required arguments:
- `--topology <path>`: topology YAML file path. Loaded via - `--topology <path>`: topology YAML file path. Loaded via
`resolve_topology()`. `resolve_topology()`.
- `--bench <name>`: benchmark name. Resolved via - `--bench <identifier>`: benchmark identifier. Resolved via
`benches.loader.resolve_bench()`. `kernbench.benches.registry.resolve()`, which accepts either the
registered kebab-case name (e.g., `gemm-single-pe`) or a numeric
index from `kernbench list`.
Optional arguments: Optional arguments:
@@ -63,7 +66,21 @@ When `--device all` (or omitted) and the topology has multiple SIPs:
The CLI does NOT spawn multiple OS processes or independent The CLI does NOT spawn multiple OS processes or independent
simulation runs — parallelism is internal to one simulation instance. simulation runs — parallelism is internal to one simulation instance.
### D4. `kernbench probe` — latency / BW diagnostic utility ### D4. `kernbench list` — enumerate registered benches
No arguments. Prints each registered bench's auto-assigned index,
registered name, and one-line description.
Benches register themselves via the `@bench(name=..., description=...)`
decorator (`kernbench.benches.registry`). Every non-underscore module
under `kernbench.benches/` MUST register at least one bench; a missing
decorator raises `RuntimeError` at package import time.
Indices are assigned alphabetically by name at import time. They are a
CLI convenience (shorthand for `--bench`), not a stable API — a new
bench inserted alphabetically will shift later indices.
### D5. `kernbench probe` — latency / BW diagnostic utility
Required argument: Required argument:
@@ -87,7 +104,7 @@ that local-HBM access ≤ cross-PE-within-cube ≤ cross-cube ≤
cross-SIP — and reports violations. Probe is a developer tool for cross-SIP — and reports violations. Probe is a developer tool for
verifying the latency / BW model; it is not a benchmark. verifying the latency / BW model; it is not a benchmark.
### D5. `kernbench web` — topology viewer ### D6. `kernbench web` — topology viewer
Optional arguments: Optional arguments:
@@ -102,7 +119,7 @@ the browser. Distinct from the static `docs/diagrams/` artifacts:
- `kernbench web` is interactive — pan/zoom, hover for component - `kernbench web` is interactive — pan/zoom, hover for component
attributes, switch between SIP / CUBE / PE views. attributes, switch between SIP / CUBE / PE views.
### D6. Runtime API and simulation engine remain device-scoped ### D7. Runtime API and simulation engine remain device-scoped
- Runtime API calls operate on one device per invocation. - Runtime API calls operate on one device per invocation.
- The simulation engine schedules all requests deterministically. - The simulation engine schedules all requests deterministically.
@@ -112,6 +129,10 @@ This invariant keeps each layer testable in isolation; device
enumeration and multi-device fan-out live only in the CLI's `run` enumeration and multi-device fan-out live only in the CLI's `run`
command (D3). command (D3).
The `probe` implementation lives under `kernbench.probes` (separate
from `kernbench.benches`), reflecting that probes are diagnostic
utilities, not registered benches.
## Consequences ## Consequences
- Benchmark authors write single-device logic; multi-device behavior - Benchmark authors write single-device logic; multi-device behavior
+2 -2
@@ -12,8 +12,8 @@ dependencies = ["pytest", "simpy", "pyyaml", "fastapi>=0.110", "uvicorn[standard
kernbench = "kernbench.cli.main:main" kernbench = "kernbench.cli.main:main"
[tool.setuptools.packages.find] [tool.setuptools.packages.find]
where = ["src", "."] where = ["src"]
include = ["kernbench*", "benches*"] include = ["kernbench*"]
[project.optional-dependencies] [project.optional-dependencies]
dev = [ dev = [
+4 -4
@@ -117,19 +117,19 @@ def _run_one(M: int, K: int, N: int, topology: str, variant: str = "ref_ref") ->
os.environ["MATMUL_N"] = str(N) os.environ["MATMUL_N"] = str(N)
os.environ["MATMUL_VARIANT"] = variant os.environ["MATMUL_VARIANT"] = variant
# Late imports so env vars are read by benches/matmul_composite at module load. # Late imports so env vars are read by matmul_composite at module load.
# Force re-import to pick up new env values. # Force re-import to pick up new env values.
for mod_name in [m for m in list(sys.modules) if m.startswith("benches.matmul_composite")]: for mod_name in [m for m in list(sys.modules) if m.startswith("kernbench.benches.matmul_composite")]:
del sys.modules[mod_name] del sys.modules[mod_name]
from benches.loader import resolve_bench from kernbench.benches.registry import resolve as resolve_bench
from kernbench.runtime_api.bench_runner import run_bench from kernbench.runtime_api.bench_runner import run_bench
from kernbench.runtime_api.types import resolve_device from kernbench.runtime_api.types import resolve_device
from kernbench.sim_engine.engine import GraphEngine from kernbench.sim_engine.engine import GraphEngine
from kernbench.topology.builder import resolve_topology from kernbench.topology.builder import resolve_topology
topo = resolve_topology(topology) topo = resolve_topology(topology)
bench = resolve_bench("matmul_composite") bench = resolve_bench("matmul-composite").run
device = resolve_device(None) device = resolve_device(None)
t0 = time.time() t0 = time.time()
+9
@@ -0,0 +1,9 @@
"""kernbench.benches: eager-import sibling modules so @bench fires.
Underscore-prefixed modules are treated as helpers and skipped.
After import, every imported module must have registered at least one
bench, or a RuntimeError is raised by the audit.
"""
from kernbench.benches.registry import _eager_import_and_audit
_eager_import_and_audit(__path__, __name__)
@@ -14,6 +14,7 @@ from dataclasses import dataclass
import numpy as np import numpy as np
from kernbench.benches.registry import bench
from kernbench.ccl.install import load_ccl_config, resolve_algorithm_config from kernbench.ccl.install import load_ccl_config, resolve_algorithm_config
from kernbench.policy.placement.dp import DPPolicy from kernbench.policy.placement.dp import DPPolicy
@@ -95,6 +96,10 @@ def _worker(rank: int, cfg: _BenchCfg, torch) -> None:
_report(tensor.numpy(), cfg) _report(tensor.numpy(), cfg)
@bench(
name="ccl-allreduce",
description="CCL all-reduce bench (TP launcher; rank = SIP).",
)
def run(torch) -> None: def run(torch) -> None:
torch.distributed.init_process_group(backend="ahbm") torch.distributed.init_process_group(backend="ahbm")
cfg = _resolve_cfg(torch) cfg = _resolve_cfg(torch)
@@ -10,6 +10,7 @@ per-tile DMA internally.
Run: Run:
kernbench run gemm_single_pe kernbench run gemm_single_pe
""" """
from kernbench.benches.registry import bench
from kernbench.policy.placement.dp import DPPolicy from kernbench.policy.placement.dp import DPPolicy
# GEMM dimensions: (M, K) x (K, N) → (M, N) # GEMM dimensions: (M, K) x (K, N) → (M, N)
@@ -27,6 +28,10 @@ def _gemm_kernel(a_ptr, b_ptr, out_ptr, M, K, N, tl, DTYPE="f16"):
tl.wait(h) tl.wait(h)
@bench(
name="gemm-single-pe",
description="Single-PE GEMM via scheduler_v2 (pe_accel).",
)
def run(torch): def run(torch):
"""Run the single-PE GEMM benchmark.""" """Run the single-PE GEMM benchmark."""
dp = DPPolicy(cube="replicate", pe="replicate", dp = DPPolicy(cube="replicate", pe="replicate",
@@ -20,6 +20,7 @@ topology.yaml is unchanged.
Run: Run:
kernbench run gpt3_qkv kernbench run gpt3_qkv
""" """
from kernbench.benches.registry import bench
from kernbench.policy.placement.dp import DPPolicy from kernbench.policy.placement.dp import DPPolicy
# -- PE configuration (DPPolicy overrides — does not change topology.yaml) ----- # -- PE configuration (DPPolicy overrides — does not change topology.yaml) -----
@@ -66,6 +67,10 @@ def _gpt3_qkv_kernel(x_ptr, wq_ptr, wk_ptr, wv_ptr,
tl.wait(hv) tl.wait(hv)
@bench(
name="gpt3-qkv",
description="GPT-3 QKV projection sharded column-wise across all PEs.",
)
def run(torch): def run(torch):
"""Run the GPT-3 QKV benchmark.""" """Run the GPT-3 QKV benchmark."""
M = SEQ_LEN M = SEQ_LEN
+9
@@ -0,0 +1,9 @@
from kernbench.benches.registry import bench
@bench(
name="ipcq-allreduce",
description="IPCQ all-reduce kernel bench (placeholder).",
)
def run(torch):
print("IPCQ all reduce kernel bench")
@@ -17,6 +17,7 @@ Run:
""" """
import os import os
from kernbench.benches.registry import bench
from kernbench.policy.placement.dp import DPPolicy from kernbench.policy.placement.dp import DPPolicy
M = int(os.environ.get("MATMUL_M", "256")) M = int(os.environ.get("MATMUL_M", "256"))
@@ -57,6 +58,10 @@ _KERNELS = {
} }
@bench(
name="matmul-composite",
description="Single-PE composite GEMM with ref/load variants for perf characterization.",
)
def run(torch): def run(torch):
if VARIANT not in _KERNELS: if VARIANT not in _KERNELS:
raise ValueError(f"unknown MATMUL_VARIANT={VARIANT!r}; " raise ValueError(f"unknown MATMUL_VARIANT={VARIANT!r}; "
@@ -7,6 +7,7 @@ Kernel: tl.load(a) + tl.ref(b) + tl.composite(gemm) + tl.wait()
- Tensor a is loaded into TCM via DMA - Tensor a is loaded into TCM via DMA
- Tensor b stays in HBM; PE_SCHEDULER streams it per-tile (32x64x32) - Tensor b stays in HBM; PE_SCHEDULER streams it per-tile (32x64x32)
""" """
from kernbench.benches.registry import bench
from kernbench.policy.placement.dp import DPPolicy from kernbench.policy.placement.dp import DPPolicy
# GEMM dimensions: (M, K) x (K, N) → (M, N) # GEMM dimensions: (M, K) x (K, N) → (M, N)
@@ -28,6 +29,10 @@ def _gemm_kernel(a_ptr, b_ptr, out_ptr, M, K, N, tl, DTYPE="f16"):
tl.wait(handle) tl.wait(handle)
@bench(
name="qkv-gemm",
description="QKV GEMM (Q*K^T) on a single PE — full host-to-PE pipeline.",
)
def run(torch): def run(torch):
"""Run the QKV GEMM benchmark.""" """Run the QKV GEMM benchmark."""
# DP placement: a=replicate (cube-level), b/out=column_wise (N-axis, single PE) # DP placement: a=replicate (cube-level), b/out=column_wise (N-axis, single PE)
@@ -7,6 +7,7 @@ Kernel: tl.load(a) + tl.ref(b) + tl.composite(gemm) + tl.wait()
- Tensor a is loaded into TCM via DMA - Tensor a is loaded into TCM via DMA
- Tensor b stays in HBM; PE_SCHEDULER streams it per-tile (32x64x32) - Tensor b stays in HBM; PE_SCHEDULER streams it per-tile (32x64x32)
""" """
from kernbench.benches.registry import bench
from kernbench.policy.placement.dp import DPPolicy from kernbench.policy.placement.dp import DPPolicy
# GEMM dimensions: (M, K) x (K, N) -> (M, N) # GEMM dimensions: (M, K) x (K, N) -> (M, N)
@@ -28,6 +29,10 @@ def _gemm_kernel(a_ptr, b_ptr, out_ptr, M, K, N, tl, DTYPE="f16"):
tl.wait(handle) tl.wait(handle)
@bench(
name="qkv-gemm-multi-pe",
description="Column-parallel QKV GEMM across all PEs in a cube (multi-PE).",
)
def run(torch): def run(torch):
"""Run the multi-PE QKV GEMM benchmark.""" """Run the multi-PE QKV GEMM benchmark."""
# DP placement: a=replicate (cube-level), b/out=column_wise (N-axis split) # DP placement: a=replicate (cube-level), b/out=column_wise (N-axis split)
+106
@@ -0,0 +1,106 @@
"""Bench registry: @bench decorator + name/index resolution.
Each bench module under ``kernbench.benches`` MUST register its callable
via ``@bench(name=..., description=...)``. Indices are assigned
alphabetically by name after eager import; they are a CLI convenience,
not a stable API.
"""
from __future__ import annotations
import re
from collections.abc import Callable
from dataclasses import dataclass
from importlib import import_module
from pkgutil import iter_modules
BenchFn = Callable[..., object]
_NAME_RE = re.compile(r"^[a-z][a-z0-9]*(-[a-z0-9]+)*$")
@dataclass(frozen=True)
class BenchSpec:
index: int
name: str
description: str
run: BenchFn
_PENDING: list[tuple[str, str, BenchFn]] = []
_REGISTERED_MODULES: set[str] = set()
_REGISTRY: dict[str, BenchSpec] = {}
def bench(*, name: str, description: str) -> Callable[[BenchFn], BenchFn]:
if not isinstance(name, str) or not _NAME_RE.match(name):
raise ValueError(
f"bench name {name!r} must be kebab-case (lowercase, digits, dashes; "
f"starts with a letter)."
)
if not isinstance(description, str) or not description.strip():
raise ValueError(f"bench {name!r}: description must be a non-empty string.")
def deco(fn: BenchFn) -> BenchFn:
_PENDING.append((name, description, fn))
_REGISTERED_MODULES.add(fn.__module__)
return fn
return deco
def _finalize() -> None:
if _REGISTRY:
return
seen: set[str] = set()
for n, _, _ in _PENDING:
if n in seen:
raise RuntimeError(f"duplicate bench name: {n!r}")
seen.add(n)
for i, (n, d, f) in enumerate(sorted(_PENDING, key=lambda t: t[0]), start=1):
_REGISTRY[n] = BenchSpec(index=i, name=n, description=d, run=f)
def list_all() -> list[BenchSpec]:
_finalize()
return sorted(_REGISTRY.values(), key=lambda s: s.index)
def resolve(identifier: str) -> BenchSpec:
_finalize()
if not isinstance(identifier, str) or not identifier.strip():
raise ValueError("bench identifier must be a non-empty string.")
ident = identifier.strip()
if ident.isdigit():
idx = int(ident)
for s in _REGISTRY.values():
if s.index == idx:
return s
raise ValueError(
f"No bench with index {idx}. Use 'kernbench list' to see options."
)
if ident in _REGISTRY:
return _REGISTRY[ident]
raise ValueError(
f"Unknown bench {ident!r}. Use 'kernbench list' to see options."
)
def _audit_modules(imported: list[str], registered: set[str]) -> None:
missing = sorted(m for m in imported if m not in registered)
if missing:
raise RuntimeError(
f"Bench module(s) missing @bench decorator: {missing}. "
f"Each file under kernbench.benches/ must register at least one bench "
f"via @bench(...), or be renamed with a leading underscore if it is a "
f"helper."
)
def _eager_import_and_audit(pkg_path: list[str], pkg_name: str) -> None:
imported: list[str] = []
for m in iter_modules(pkg_path):
if m.name == "registry" or m.name.startswith("_"):
continue
mod = import_module(f"{pkg_name}.{m.name}")
imported.append(mod.__name__)
_audit_modules(imported, _REGISTERED_MODULES)
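
A short sketch of which names the kebab-case pattern above accepts. The regular expression is copied from `_NAME_RE`; the helper `is_valid_bench_name` and the candidate list are illustrative and not part of the module.

```python
import re

# Same pattern as _NAME_RE in the registry module.
NAME_RE = re.compile(r"^[a-z][a-z0-9]*(-[a-z0-9]+)*$")


def is_valid_bench_name(name: str) -> bool:
    """Return True if `name` is lowercase kebab-case starting with a letter."""
    return bool(NAME_RE.match(name))


candidates = [
    "gemm-single-pe",  # plain kebab-case
    "gpt3-qkv",        # digits are allowed after the first letter
    "qkv_gemm",        # underscore is rejected (the old module-style id)
    "Invalid_Name",    # uppercase is rejected
    "3d-conv",         # must start with a letter
    "gemm--pe",        # empty segment between dashes
    "gemm-",           # trailing dash
]

results = {c: is_valid_bench_name(c) for c in candidates}
for name, ok in results.items():
    print(f"{name:<16} {'ok' if ok else 'rejected'}")
```

Under this pattern the old underscore-style ids such as `qkv_gemm` are not valid bench names, which is why the CLI tests in this commit switch to `qkv-gemm`.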
@@ -9,6 +9,7 @@ The kernel uses standard Triton patterns:
- tl.num_programs(0) for PE count within cube - tl.num_programs(0) for PE count within cube
- Shape args are automatically localized by launch() - Shape args are automatically localized by launch()
""" """
from kernbench.benches.registry import bench
from kernbench.policy.placement.dp import DPPolicy from kernbench.policy.placement.dp import DPPolicy
M, K = 128, 256 M, K = 128, 256
@@ -26,6 +27,10 @@ def _copy_kernel(src_ptr, dst_ptr, M, K, tl, DTYPE="f16"):
tl.store(dst_ptr + offset, data) tl.store(dst_ptr + offset, data)
@bench(
name="va-offset-verify",
description="Triton base_ptr + pid * stride VA addressing verification (TP sharded).",
)
def run(torch): def run(torch):
"""Run the VA offset verification benchmark with full TP sharding.""" """Run the VA offset verification benchmark with full TP sharding."""
dp = DPPolicy(cube="column_wise", pe="column_wise") dp = DPPolicy(cube="column_wise", pe="column_wise")
+23 -6
@@ -1,10 +1,10 @@
import argparse import argparse
import sys import sys
from benches.loader import resolve_bench from kernbench.benches.registry import list_all, resolve
from kernbench.cli.probe import cmd_probe
from kernbench.cli.report import format_report from kernbench.cli.report import format_report
from kernbench.common.types import SimEngine from kernbench.common.types import SimEngine
from kernbench.probes.probe import cmd_probe
from kernbench.runtime_api.bench_runner import run_bench from kernbench.runtime_api.bench_runner import run_bench
from kernbench.runtime_api.types import DeviceSelector, resolve_device from kernbench.runtime_api.types import DeviceSelector, resolve_device
from kernbench.sim_engine.engine import GraphEngine from kernbench.sim_engine.engine import GraphEngine
@@ -17,7 +17,10 @@ def build_parser() -> argparse.ArgumentParser:
runp = sub.add_parser("run", help="Run a benchmark") runp = sub.add_parser("run", help="Run a benchmark")
runp.add_argument("--topology", required=True) runp.add_argument("--topology", required=True)
runp.add_argument("--bench", required=True) runp.add_argument(
"--bench", required=True,
help="Bench name (kebab-case) or numeric index from 'kernbench list'",
)
runp.add_argument( runp.add_argument(
"--device", default=None, help="Target device: 'all' or 'sip:<N>' (default: all)" "--device", default=None, help="Target device: 'all' or 'sip:<N>' (default: all)"
) )
@@ -27,6 +30,9 @@ def build_parser() -> argparse.ArgumentParser:
) )
runp.set_defaults(_handler=cmd_run) runp.set_defaults(_handler=cmd_run)
listp = sub.add_parser("list", help="List registered benches")
listp.set_defaults(_handler=cmd_list)
probep = sub.add_parser("probe", help="Probe latency and BW for predefined traffic patterns") probep = sub.add_parser("probe", help="Probe latency and BW for predefined traffic patterns")
probep.add_argument("--topology", required=True) probep.add_argument("--topology", required=True)
probep.add_argument("--case", default="all", help="Case name or 'all' (default: all)") probep.add_argument("--case", default="all", help="Case name or 'all' (default: all)")
@@ -53,23 +59,34 @@ def cmd_web(args) -> int:
return 0 return 0
def cmd_list(args) -> int:
specs = list_all()
print(f"{'#':>3} {'NAME':<22} DESCRIPTION")
print("-" * 80)
for s in specs:
print(f"{s.index:>3} {s.name:<22} {s.description}")
return 0
def cmd_run(args) -> int: def cmd_run(args) -> int:
print("> Running benchmark with:", args) print("> Running benchmark with:", args)
topo = resolve_topology(args.topology) topo = resolve_topology(args.topology)
bench = resolve_bench(args.bench) spec_entry = resolve(args.bench)
device = resolve_device(args.device) device = resolve_device(args.device)
verify_data = getattr(args, "verify_data", False) verify_data = getattr(args, "verify_data", False)
def _factory(topology, device): def _factory(topology, device):
return engine_factory(topology, device, enable_data=verify_data) return engine_factory(topology, device, enable_data=verify_data)
result = run_bench(topology=topo, bench_fn=bench, device=device, engine_factory=_factory) result = run_bench(
topology=topo, bench_fn=spec_entry.run, device=device, engine_factory=_factory,
)
topo_obj = getattr(topo, "topology_obj", topo) topo_obj = getattr(topo, "topology_obj", topo)
spec = getattr(topo_obj, "spec", None) spec = getattr(topo_obj, "spec", None)
if result.traces: if result.traces:
print(format_report(result.traces, title=args.bench, spec=spec)) print(format_report(result.traces, title=spec_entry.name, spec=spec))
print(result.summary_text()) print(result.summary_text())
# Phase 2 diagnostic summary (ADR-0020). The actual Phase 2 replay # Phase 2 diagnostic summary (ADR-0020). The actual Phase 2 replay
+5
@@ -0,0 +1,5 @@
"""kernbench.probes: latency/BW diagnostic utilities (not benchmarks).
See ADR-0010 D5. Probe is a developer tool for verifying the latency/BW
model; it bypasses the bench registry.
"""
+95
@@ -0,0 +1,95 @@
"""Tests for kernbench.benches.registry — @bench decorator + resolve/list."""
from __future__ import annotations
import pytest
from kernbench.benches import registry
EXPECTED_NAMES = [
"ccl-allreduce",
"gemm-single-pe",
"gpt3-qkv",
"ipcq-allreduce",
"matmul-composite",
"qkv-gemm",
"qkv-gemm-multi-pe",
"va-offset-verify",
]
def test_registry_lists_all_benches():
specs = registry.list_all()
names = [s.name for s in specs]
assert names == EXPECTED_NAMES
def test_registry_indices_are_1_based_sorted_by_name():
specs = registry.list_all()
assert [s.index for s in specs] == list(range(1, len(EXPECTED_NAMES) + 1))
assert sorted(s.name for s in specs) == [s.name for s in specs]
def test_resolve_by_name_returns_spec():
spec = registry.resolve("gemm-single-pe")
assert spec.name == "gemm-single-pe"
assert callable(spec.run)
assert spec.description.strip()
def test_resolve_by_index_string_matches_list_order():
specs = registry.list_all()
third = specs[2]
resolved = registry.resolve(str(third.index))
assert resolved is third
def test_resolve_unknown_name_raises():
with pytest.raises(ValueError, match="kernbench list"):
registry.resolve("does-not-exist")
def test_resolve_unknown_index_raises():
with pytest.raises(ValueError, match="kernbench list"):
registry.resolve("99")
def test_resolve_empty_identifier_raises():
with pytest.raises(ValueError):
registry.resolve("")
def test_bench_decorator_rejects_invalid_name():
with pytest.raises(ValueError, match="kebab-case"):
registry.bench(name="Invalid_Name", description="x")
def test_bench_decorator_rejects_empty_description():
with pytest.raises(ValueError, match="non-empty"):
registry.bench(name="ok-name", description=" ")
def test_audit_raises_on_missing_decorator():
with pytest.raises(RuntimeError, match="missing @bench decorator"):
registry._audit_modules(
imported=["kernbench.benches.fake_no_dec", "kernbench.benches.real"],
registered={"kernbench.benches.real"},
)
def test_audit_passes_when_all_registered():
registry._audit_modules(
imported=["kernbench.benches.a", "kernbench.benches.b"],
registered={"kernbench.benches.a", "kernbench.benches.b"},
)
def test_duplicate_name_at_finalize_fails(monkeypatch):
"""_finalize() rejects two pending entries with the same name."""
monkeypatch.setattr(registry, "_PENDING", [
("dup", "d1", lambda: None),
("dup", "d2", lambda: None),
])
monkeypatch.setattr(registry, "_REGISTRY", {})
with pytest.raises(RuntimeError, match="duplicate bench name"):
registry._finalize()
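
A small illustration of why the ADR calls indices a CLI convenience rather than a stable API. It reuses the names from `EXPECTED_NAMES` and adds one hypothetical bench, `conv2d-tiled`, which does not exist in this commit. `assign_indices` is an illustrative helper that reproduces only the alphabetical 1-based numbering.

```python
def assign_indices(names: list[str]) -> dict[str, int]:
    """Map each name to its 1-based position in alphabetical order."""
    return {name: i for i, name in enumerate(sorted(names), start=1)}


current = [
    "ccl-allreduce",
    "gemm-single-pe",
    "gpt3-qkv",
    "ipcq-allreduce",
    "matmul-composite",
    "qkv-gemm",
    "qkv-gemm-multi-pe",
    "va-offset-verify",
]

before = assign_indices(current)
# "conv2d-tiled" is a made-up bench name used only to show the shift.
after = assign_indices(current + ["conv2d-tiled"])

# Every bench that sorts after the new name moves down by one.
shifted = sorted(n for n in current if before[n] != after[n])

print(before["qkv-gemm"], "->", after["qkv-gemm"])
print("unchanged:", [n for n in current if before[n] == after[n]])
```

Scripts that need a stable reference should therefore pass the bench name to `--bench` and treat the numeric index as an interactive shorthand.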
+3 -3
@@ -6,17 +6,17 @@ def test_cli_main_arg_parsing(monkeypatch):
def fake_cmd_run(args) -> int: def fake_cmd_run(args) -> int:
assert args.cmd == "run" assert args.cmd == "run"
assert args.topology == "topology.yaml" assert args.topology == "topology.yaml"
assert args.bench == "qkv_gemm" assert args.bench == "qkv-gemm"
assert args.device == None assert args.device == None
return 0 return 0
# monkey patch the handler to test arg parsing without running the actual bench # monkey patch the handler to test arg parsing without running the actual bench
monkeypatch.setattr(cli_main, "cmd_run", fake_cmd_run) monkeypatch.setattr(cli_main, "cmd_run", fake_cmd_run)
rc = cli_main.main(["run", "--topology", "topology.yaml", "--bench", "qkv_gemm"]) rc = cli_main.main(["run", "--topology", "topology.yaml", "--bench", "qkv-gemm"])
assert rc == 0 assert rc == 0
def test_cli_main(): def test_cli_main():
"""CLI bench run on single SIP device.""" """CLI bench run on single SIP device."""
rc = cli_main.main(["run", "--topology", "topology.yaml", "--bench", "qkv_gemm", "--device", "sip:0"]) rc = cli_main.main(["run", "--topology", "topology.yaml", "--bench", "qkv-gemm", "--device", "sip:0"])
assert rc == 0 assert rc == 0
+44
@@ -0,0 +1,44 @@
"""Tests for `kernbench list` subcommand and `--bench <index>` resolution."""
from __future__ import annotations
import kernbench.cli.main as cli_main
from kernbench.benches import registry
def test_cli_list_outputs_all_benches(capsys):
rc = cli_main.main(["list"])
assert rc == 0
out = capsys.readouterr().out
for spec in registry.list_all():
assert spec.name in out
assert "DESCRIPTION" in out
def test_cli_run_by_index(monkeypatch):
"""CLI accepts numeric index for --bench; same callable as the name."""
qkv_spec = registry.resolve("qkv-gemm")
captured = {}
def fake_run_bench(*, topology, bench_fn, device, engine_factory):
captured["bench_fn"] = bench_fn
class _R:
traces = []
engine = None
class completion:
ok = True
def summary_text(self):
return ""
return _R()
monkeypatch.setattr(cli_main, "run_bench", fake_run_bench)
rc = cli_main.main([
"run", "--topology", "topology.yaml",
"--bench", str(qkv_spec.index),
"--device", "sip:0",
])
assert rc == 0
assert captured["bench_fn"] is qkv_spec.run
+4 -4
@@ -11,7 +11,7 @@ def test_cli_verify_data_flag_parsed(monkeypatch):
monkeypatch.setattr(cli_main, "cmd_run", fake_cmd_run) monkeypatch.setattr(cli_main, "cmd_run", fake_cmd_run)
rc = cli_main.main([ rc = cli_main.main([
"run", "--topology", "topology.yaml", "--bench", "qkv_gemm", "run", "--topology", "topology.yaml", "--bench", "qkv-gemm",
"--verify-data", "--verify-data",
]) ])
assert rc == 0 assert rc == 0
@@ -26,7 +26,7 @@ def test_cli_verify_data_flag_default(monkeypatch):
monkeypatch.setattr(cli_main, "cmd_run", fake_cmd_run) monkeypatch.setattr(cli_main, "cmd_run", fake_cmd_run)
rc = cli_main.main([ rc = cli_main.main([
"run", "--topology", "topology.yaml", "--bench", "qkv_gemm", "run", "--topology", "topology.yaml", "--bench", "qkv-gemm",
]) ])
assert rc == 0 assert rc == 0
@@ -34,7 +34,7 @@ def test_cli_verify_data_flag_default(monkeypatch):
def test_cmd_run_verify_data_enables_engine(): def test_cmd_run_verify_data_enables_engine():
"""--verify-data runs full pipeline with enable_data=True and DataExecutor.""" """--verify-data runs full pipeline with enable_data=True and DataExecutor."""
rc = cli_main.main([ rc = cli_main.main([
"run", "--topology", "topology.yaml", "--bench", "qkv_gemm", "run", "--topology", "topology.yaml", "--bench", "qkv-gemm",
"--device", "sip:0", "--verify-data", "--device", "sip:0", "--verify-data",
]) ])
assert rc == 0 assert rc == 0
@@ -43,7 +43,7 @@ def test_cmd_run_verify_data_enables_engine():
def test_cmd_run_without_verify_data_no_op_log(): def test_cmd_run_without_verify_data_no_op_log():
"""Without --verify-data, engine runs in timing-only mode (no op_log).""" """Without --verify-data, engine runs in timing-only mode (no op_log)."""
rc = cli_main.main([ rc = cli_main.main([
"run", "--topology", "topology.yaml", "--bench", "qkv_gemm", "run", "--topology", "topology.yaml", "--bench", "qkv-gemm",
"--device", "sip:0", "--device", "sip:0",
]) ])
assert rc == 0 assert rc == 0
+1 -1
@@ -235,7 +235,7 @@ def test_qkv_gemm_still_passes():
correlation_id="test_regression", correlation_id="test_regression",
spec=graph.spec, spec=graph.spec,
) )
from benches.qkv_gemm import run as bench_run from kernbench.benches.qkv_gemm import run as bench_run
bench_run(ctx) bench_run(ctx)
ctx.wait_all() ctx.wait_all()
# If we get here without exception, the benchmark succeeded # If we get here without exception, the benchmark succeeded
+2 -2
@@ -864,7 +864,7 @@ def test_mcpu_kernel_launch_composite():
def test_qkv_gemm_bench_completes(): def test_qkv_gemm_bench_completes():
"""The qkv_gemm benchmark runs to completion without error.""" """The qkv_gemm benchmark runs to completion without error."""
clear_registry() clear_registry()
from benches.qkv_gemm import run as bench_run from kernbench.benches.qkv_gemm import run as bench_run
from kernbench.runtime_api.context import RuntimeContext from kernbench.runtime_api.context import RuntimeContext
graph = load_topology(TOPOLOGY_PATH) graph = load_topology(TOPOLOGY_PATH)
@@ -958,7 +958,7 @@ def test_mcpu_multi_pe_kernel_launch():
def test_qkv_gemm_bench_multi_pe_completes(): def test_qkv_gemm_bench_multi_pe_completes():
"""The qkv_gemm_multi_pe benchmark runs to completion without error.""" """The qkv_gemm_multi_pe benchmark runs to completion without error."""
clear_registry() clear_registry()
from benches.qkv_gemm_multi_pe import run as bench_run from kernbench.benches.qkv_gemm_multi_pe import run as bench_run
from kernbench.runtime_api.context import RuntimeContext from kernbench.runtime_api.context import RuntimeContext
graph = load_topology(TOPOLOGY_PATH) graph = load_topology(TOPOLOGY_PATH)
+2 -2
@@ -263,7 +263,7 @@ def test_pe_cross_cube_best_worst():
def test_probe_timestamp_trace(): def test_probe_timestamp_trace():
"""_hop_timestamps must return monotonically increasing cumulative timestamps.""" """_hop_timestamps must return monotonically increasing cumulative timestamps."""
from kernbench.cli.probe import _hop_timestamps, _build_edge_map from kernbench.probes.probe import _hop_timestamps, _build_edge_map
graph = _graph() graph = _graph()
edge_map = _build_edge_map(graph) edge_map = _build_edge_map(graph)
resolver = AddressResolver(graph) resolver = AddressResolver(graph)
@@ -341,7 +341,7 @@ def test_hbm_efficiency_applied():
def test_probe_sweep_saturation(): def test_probe_sweep_saturation():
"""Utilization at 1MB must exceed utilization at 4KB for pe-local-hbm.""" """Utilization at 1MB must exceed utilization at 4KB for pe-local-hbm."""
from kernbench.cli.probe import _sweep_util from kernbench.probes.probe import _sweep_util
# pe-local-hbm: ovhd=2ns (router), wire~0.03ns, bn from topology # pe-local-hbm: ovhd=2ns (router), wire~0.03ns, bn from topology
bn = _hbm_effective_bw() bn = _hbm_effective_bw()
u = _sweep_util(2.0, 0.03, bn) u = _sweep_util(2.0, 0.03, bn)
+1 -1
@@ -143,7 +143,7 @@ def test_2d_bench_completes():
engine=engine, target_device=DeviceSelector("sip:0"), engine=engine, target_device=DeviceSelector("sip:0"),
correlation_id="vo3", spec=graph.spec, correlation_id="vo3", spec=graph.spec,
) )
from benches.va_offset_verify import run as bench_run from kernbench.benches.va_offset_verify import run as bench_run
bench_run(ctx) bench_run(ctx)
ctx.wait_all() ctx.wait_all()