049e3d8bb3
Move benches/ -> src/kernbench/benches/ and src/kernbench/cli/probe.py -> src/kernbench/probes/probe.py. Each bench self-registers via @bench(name=..., description=...); kernbench list enumerates benches with auto-assigned indices, --bench accepts kebab-case name or numeric index. Audit at package-import time fails if any non-underscore module forgets the decorator. ADR-0010 (EN + KO) updated to reflect the new resolver path, list subcommand, and probes package separation. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
240 lines
8.4 KiB
Python
240 lines
8.4 KiB
Python
"""Sweep GEMM shapes through kernbench and dump PE_accelerator engine times.
|
|
|
|
For each shape:
|
|
- run kernbench.benches.matmul_composite (once per operand-staging variant) via the same run_bench path the CLI uses
|
|
- read result.engine.op_log
|
|
- filter to per-PE engines: pe_dma, pe_fetch_store, pe_gemm, pe_math
|
|
- record sum-of-durations (occupancy), interval-union wall-clock time, and record count, both per engine and per pipeline stage (params["stage_type"])
|
|
|
|
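Output: docs/diagrams/gemm_sweep.json (resolved relative to the parent of this script's directory). Environment overrides: SWEEP_SHAPES (comma-separated shape list) and SWEEP_TOPOLOGY (default "topology.yaml").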
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import os
|
|
import sys
|
|
import time
|
|
from pathlib import Path
|
|
|
|
# Sweep points spanning under-tile, single-tile, multi-tile and asymmetric regimes.
# An entry is either "S" (square problem, M=K=N=S) or "MxKxN".
# Override at run time with SWEEP_SHAPES="16,32,16x2048x16,...".
DEFAULT_SHAPES = [
    "32x32x32",    # 1 tile; K=32 < TILE_K=64, so K is under-tiled
    "32x64x32",    # 1 tile; exactly one tile
    "32x128x32",   # 2 tiles, aligned
    "32x128x128",  # 8 tiles, aligned
    "32x3072x32",  # 48 tiles, all along K (tall-skinny)
    "8x128x128",   # 8 tiles, but M=8 < TILE_M=32 leaves the MAC array under-fed
    "128x8x128",   # 16 tiles, but K=8 < TILE_K=64 leaves the MAC array under-fed
    "512",         # 2048 tiles, fully aligned -- "well-pipelined" reference
]

# Operand-staging variants; every shape is run once per variant
# (exported to the bench as MATMUL_VARIANT).
VARIANTS = "ref_ref load_ref load_load".split()

# PE engines to report on; matched as a ".<name>" suffix of component_id.
ENGINES = "pe_dma pe_fetch_store pe_gemm pe_math".split()

# Labels for the per-stage breakdown (StageType enum names from pe_types.py),
# matched against params["stage_type"] on op_log records.
STAGES = "DMA_READ DMA_WRITE FETCH STORE GEMM MATH".split()

# Scheduler tile sizes; must mirror PeSchedulerComponent.TILE_M/K/N.
TILE_M = 32
TILE_K = 64
TILE_N = 32

# JSON report destination, anchored two directory levels above this file.
OUT_PATH = Path(__file__).parent.parent.joinpath("docs", "diagrams", "gemm_sweep.json")
|
|
|
|
|
|
def _engine_wall_ns(records, suffix: str) -> float:
|
|
"""Wall-clock interval the engine was active (union of overlapping ops)."""
|
|
intervals = [(r.t_start, r.t_end) for r in records
|
|
if r.component_id.endswith("." + suffix)]
|
|
if not intervals:
|
|
return 0.0
|
|
intervals.sort()
|
|
merged_end = intervals[0][1]
|
|
merged_start = intervals[0][0]
|
|
total = 0.0
|
|
for s, e in intervals[1:]:
|
|
if s <= merged_end:
|
|
merged_end = max(merged_end, e)
|
|
else:
|
|
total += merged_end - merged_start
|
|
merged_start, merged_end = s, e
|
|
total += merged_end - merged_start
|
|
return total
|
|
|
|
|
|
def _engine_occupancy_ns(records, suffix: str) -> float:
|
|
return sum(r.t_end - r.t_start for r in records
|
|
if r.component_id.endswith("." + suffix))
|
|
|
|
|
|
def _engine_count(records, suffix: str) -> int:
|
|
return sum(1 for r in records if r.component_id.endswith("." + suffix))
|
|
|
|
|
|
def _stage_occupancy_ns(records, stage_type: str) -> float:
|
|
"""Sum t_end - t_start over op_log records whose params.stage_type matches.
|
|
|
|
Requires op_log records produced post the TileToken stage_type capture
|
|
(sim_engine/op_log.py).
|
|
"""
|
|
return sum(
|
|
r.t_end - r.t_start
|
|
for r in records
|
|
if r.params.get("stage_type") == stage_type
|
|
)
|
|
|
|
|
|
def _stage_wall_ns(records, stage_type: str) -> float:
|
|
"""Interval-union wall-clock for records whose stage_type matches."""
|
|
intervals = sorted(
|
|
(r.t_start, r.t_end) for r in records
|
|
if r.params.get("stage_type") == stage_type
|
|
)
|
|
if not intervals:
|
|
return 0.0
|
|
total = 0.0
|
|
cs, ce = intervals[0]
|
|
for s, e in intervals[1:]:
|
|
if s <= ce:
|
|
ce = max(ce, e)
|
|
else:
|
|
total += ce - cs
|
|
cs, ce = s, e
|
|
total += ce - cs
|
|
return total
|
|
|
|
|
|
def _stage_count(records, stage_type: str) -> int:
|
|
return sum(1 for r in records if r.params.get("stage_type") == stage_type)
|
|
|
|
|
|
def _window_ns(records) -> float:
    """Return ``max(t_end) - min(t_start)`` over *records*, or ``0.0`` if empty."""
    if not records:
        return 0.0
    return max(r.t_end for r in records) - min(r.t_start for r in records)


def _run_one(M: int, K: int, N: int, topology: str, variant: str = "ref_ref") -> dict:
    """Simulate one matmul-composite run and summarize its op_log.

    The shape and staging variant reach the bench through the MATMUL_M,
    MATMUL_K, MATMUL_N and MATMUL_VARIANT environment variables, which are
    left set after this call returns.

    Args:
        M, K, N: GEMM dimensions.
        topology: topology spec passed to ``resolve_topology``.
        variant: operand-staging variant (one of ``VARIANTS``).

    Returns:
        A JSON-serializable dict containing:
        - the problem size and variant;
        - analytic flop and HBM-byte figures;
        - the expected tile count;
        - the host wall time of the simulation;
        - per-engine and per-stage occupancy, wall-clock and record counts;
        - ``pe_window_ns`` and ``composite_window_ns``.

    Raises:
        RuntimeError: if the bench does not complete successfully.
    """
    os.environ["MATMUL_M"] = str(M)
    os.environ["MATMUL_K"] = str(K)
    os.environ["MATMUL_N"] = str(N)
    os.environ["MATMUL_VARIANT"] = variant

    # Late imports so env vars are read by matmul_composite at module load.
    # Force re-import to pick up new env values.
    # NOTE(review): dropping the module from sys.modules only helps if
    # resolve_bench() re-imports it; if the registry keeps the function
    # registered by the first import, later runs may see the first run's
    # env values -- confirm against kernbench.benches.registry.
    stale = [m for m in list(sys.modules)
             if m.startswith("kernbench.benches.matmul_composite")]
    for mod_name in stale:
        del sys.modules[mod_name]

    from kernbench.benches.registry import resolve as resolve_bench
    from kernbench.runtime_api.bench_runner import run_bench
    from kernbench.runtime_api.types import resolve_device
    from kernbench.sim_engine.engine import GraphEngine
    from kernbench.topology.builder import resolve_topology

    topo = resolve_topology(topology)
    bench = resolve_bench("matmul-composite").run
    device = resolve_device(None)

    # perf_counter is monotonic; time.time() can jump with clock adjustments.
    t0 = time.perf_counter()
    result = run_bench(
        topology=topo, bench_fn=bench, device=device,
        engine_factory=lambda t, d: GraphEngine(
            getattr(t, "topology_obj", t), enable_data=True,
        ),
    )
    wall = time.perf_counter() - t0

    # Check completion before touching the engine's op_log.
    if not result.completion.ok:
        raise RuntimeError(
            f"bench failed at M={M},K={K},N={N},variant={variant}: {result.completion}"
        )
    op_log = result.engine.op_log

    flops = 2 * M * K * N
    # Bytes touched at f16 (2 B): full A + full B + full out (each operand
    # streamed once through HBM by the composite plan).
    bytes_total = (M * K + K * N + M * N) * 2
    row = {
        "M": M, "K": K, "N": N,
        "variant": variant,
        "flops": flops,
        "bytes_hbm": bytes_total,
        # flops/byte; guard the degenerate zero-byte case against ZeroDivisionError.
        "arith_intensity": flops / bytes_total if bytes_total else 0.0,
        "tile_count_expected": _ceil(M, TILE_M) * _ceil(N, TILE_N) * _ceil(K, TILE_K),
        "sim_wall_clock_s": round(wall, 3),
        "engines": {
            eng: {
                "occupancy_ns": _engine_occupancy_ns(op_log, eng),
                "wall_ns": _engine_wall_ns(op_log, eng),
                "record_count": _engine_count(op_log, eng),
            }
            for eng in ENGINES
        },
        "stages": {
            stage: {
                "occupancy_ns": _stage_occupancy_ns(op_log, stage),
                "wall_ns": _stage_wall_ns(op_log, stage),
                "record_count": _stage_count(op_log, stage),
            }
            for stage in STAGES
        },
    }

    # Kernel-window wall-clock = max t_end - min t_start over PE engine records.
    engine_tails = tuple("." + e for e in ENGINES)
    pe_records = [r for r in op_log if r.component_id.endswith(engine_tails)]
    row["pe_window_ns"] = _window_ns(pe_records)

    # Same window, restricted to records tagged with a known stage_type.
    stage_records = [r for r in op_log if r.params.get("stage_type") in STAGES]
    row["composite_window_ns"] = _window_ns(stage_records)
    return row
|
|
|
|
|
|
def _ceil(a: int, b: int) -> int:
|
|
return (a + b - 1) // b
|
|
|
|
|
|
def _parse_shapes(specs) -> list[tuple[int, int, int]]:
    """Parse shape specs into ``(M, K, N)`` triples.

    Each spec is a single integer ``"S"`` (square: M=K=N=S) or ``"MxKxN"``;
    the ``x`` separator is case-insensitive. Surrounding whitespace is
    ignored and blank specs are skipped.

    Raises:
        ValueError: if a spec is not an integer or has a field count other
            than 1 or 3 (previously "32x64" raised a bare IndexError and
            "1x2x3x4" silently dropped the extra field).
    """
    shapes: list[tuple[int, int, int]] = []
    for spec in specs:
        text = spec.strip()
        if not text:
            continue
        parts = text.lower().split("x")
        if len(parts) == 1:
            side = int(parts[0])
            shapes.append((side, side, side))
        elif len(parts) == 3:
            m, k, n = (int(p) for p in parts)
            shapes.append((m, k, n))
        else:
            raise ValueError(f"invalid shape {text!r}: expected 'S' or 'MxKxN'")
    return shapes


def main() -> int:
    """Run the GEMM sweep and write the JSON report to ``OUT_PATH``.

    Shapes come from ``SWEEP_SHAPES`` (comma-separated, see ``_parse_shapes``)
    or ``DEFAULT_SHAPES``. The topology comes from ``SWEEP_TOPOLOGY``
    (default ``"topology.yaml"``). Every shape is run once per entry in
    ``VARIANTS``, and progress is printed to stdout.

    Returns:
        Process exit code (always 0; failures propagate as exceptions).
    """
    shapes_env = os.environ.get("SWEEP_SHAPES")
    # Parse everything up front so a malformed entry fails before any simulation.
    shapes = _parse_shapes(shapes_env.split(",") if shapes_env else DEFAULT_SHAPES)
    topology = os.environ.get("SWEEP_TOPOLOGY", "topology.yaml")

    rows = []
    for M, K, N in shapes:
        for variant in VARIANTS:
            print(f"[sweep] M={M} K={K} N={N} variant={variant} ...", flush=True)
            row = _run_one(M, K, N, topology, variant=variant)
            rows.append(row)
            eng_dma = row["engines"]["pe_dma"]
            eng_gem = row["engines"]["pe_gemm"]
            print(f" tiles={row['tile_count_expected']:>6} "
                  f"pe_window={row['pe_window_ns']:8.1f}ns "
                  f"dma_occ={eng_dma['occupancy_ns']:9.1f} "
                  f"gemm_occ={eng_gem['occupancy_ns']:8.1f} "
                  f"(sim {row['sim_wall_clock_s']:.1f}s)")

    report = {
        "tile_sizes": {"M": TILE_M, "K": TILE_K, "N": TILE_N},
        "engines": ENGINES,
        "stages": STAGES,
        "variants": VARIANTS,
        "rows": rows,
    }
    OUT_PATH.parent.mkdir(parents=True, exist_ok=True)
    OUT_PATH.write_text(json.dumps(report, indent=2), encoding="utf-8")
    print(f"\n[sweep] wrote {OUT_PATH}")
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())
|