"""Sweep GEMM shapes through kernbench and dump PE_accelerator engine times. For each shape: - run benches.matmul_composite via the same run_bench path the CLI uses - read result.engine.op_log - filter to per-PE engines: pe_dma, pe_fetch_store, pe_gemm, pe_math - record sum-of-durations (engine occupancy) AND wall-clock active interval Output: docs/diagrams/gemm_sweep.json """ from __future__ import annotations import json import os import sys import time from pathlib import Path # Default sweep covering under-tile, single-tile, multi-tile, and asymmetric regimes. # Each entry is either a single integer (square M=K=N=S) or "MxKxN". # Override via env: SWEEP_SHAPES="16,32,16x2048x16,..." DEFAULT_SHAPES = [ "32x32x32", # 1 tile, K=32 < TILE_K=64 → under-tile in K "32x64x32", # 1 tile, exact single-tile fit "32x128x32", # 2 tiles, aligned "32x128x128", # 8 tiles, aligned "32x3072x32", # 48 tiles, all K-axis (tall-skinny) "8x128x128", # 8 tiles, but M=8 < TILE_M=32 → MAC array under-fed "128x8x128", # 16 tiles, but K=8 < TILE_K=64 → MAC array under-fed "512", # 2048 tiles, fully aligned — "well-pipelined" reference ] # Operand-staging variants exercised per shape. VARIANTS = ["ref_ref", "load_ref", "load_load"] # Engines whose timings we collect (component_id suffix match). ENGINES = ["pe_dma", "pe_fetch_store", "pe_gemm", "pe_math"] # Per-stage breakdown labels (StageType enum names from pe_types.py). STAGES = ["DMA_READ", "DMA_WRITE", "FETCH", "STORE", "GEMM", "MATH"] # Scheduler tile sizes (mirror of PeSchedulerComponent.TILE_M/K/N). TILE_M, TILE_K, TILE_N = 32, 64, 32 OUT_PATH = Path(__file__).parent.parent / "docs" / "diagrams" / "gemm_sweep.json" def _engine_wall_ns(records, suffix: str) -> float: """Wall-clock interval the engine was active (union of overlapping ops).""" intervals = [(r.t_start, r.t_end) for r in records if r.component_id.endswith("." + suffix)] if not intervals: return 0.0 intervals.sort() merged_end = intervals[0][1] merged_start = intervals[0][0] total = 0.0 for s, e in intervals[1:]: if s <= merged_end: merged_end = max(merged_end, e) else: total += merged_end - merged_start merged_start, merged_end = s, e total += merged_end - merged_start return total def _engine_occupancy_ns(records, suffix: str) -> float: return sum(r.t_end - r.t_start for r in records if r.component_id.endswith("." + suffix)) def _engine_count(records, suffix: str) -> int: return sum(1 for r in records if r.component_id.endswith("." + suffix)) def _stage_occupancy_ns(records, stage_type: str) -> float: """Sum t_end - t_start over op_log records whose params.stage_type matches. Requires op_log records produced post the TileToken stage_type capture (sim_engine/op_log.py). """ return sum( r.t_end - r.t_start for r in records if r.params.get("stage_type") == stage_type ) def _stage_wall_ns(records, stage_type: str) -> float: """Interval-union wall-clock for records whose stage_type matches.""" intervals = sorted( (r.t_start, r.t_end) for r in records if r.params.get("stage_type") == stage_type ) if not intervals: return 0.0 total = 0.0 cs, ce = intervals[0] for s, e in intervals[1:]: if s <= ce: ce = max(ce, e) else: total += ce - cs cs, ce = s, e total += ce - cs return total def _stage_count(records, stage_type: str) -> int: return sum(1 for r in records if r.params.get("stage_type") == stage_type) def _run_one(M: int, K: int, N: int, topology: str, variant: str = "ref_ref") -> dict: os.environ["MATMUL_M"] = str(M) os.environ["MATMUL_K"] = str(K) os.environ["MATMUL_N"] = str(N) os.environ["MATMUL_VARIANT"] = variant # Late imports so env vars are read by benches/matmul_composite at module load. # Force re-import to pick up new env values. for mod_name in [m for m in list(sys.modules) if m.startswith("benches.matmul_composite")]: del sys.modules[mod_name] from benches.loader import resolve_bench from kernbench.runtime_api.bench_runner import run_bench from kernbench.runtime_api.types import resolve_device from kernbench.sim_engine.engine import GraphEngine from kernbench.topology.builder import resolve_topology topo = resolve_topology(topology) bench = resolve_bench("matmul_composite") device = resolve_device(None) t0 = time.time() result = run_bench( topology=topo, bench_fn=bench, device=device, engine_factory=lambda t, d: GraphEngine( getattr(t, "topology_obj", t), enable_data=True, ), ) wall = time.time() - t0 op_log = result.engine.op_log if not result.completion.ok: raise RuntimeError(f"bench failed at M={M},K={K},N={N}: {result.completion}") # Bytes touched at f16 (2 B): full A + full B + full out (each operand # streamed once through HBM by the composite plan). bytes_total = (M * K + K * N + M * N) * 2 row = { "M": M, "K": K, "N": N, "variant": variant, "flops": 2 * M * K * N, "bytes_hbm": bytes_total, "arith_intensity": (2 * M * K * N) / bytes_total, # flops/byte "tile_count_expected": _ceil(M, TILE_M) * _ceil(N, TILE_N) * _ceil(K, TILE_K), "sim_wall_clock_s": round(wall, 3), "engines": {}, } for eng in ENGINES: row["engines"][eng] = { "occupancy_ns": _engine_occupancy_ns(op_log, eng), "wall_ns": _engine_wall_ns(op_log, eng), "record_count": _engine_count(op_log, eng), } row["stages"] = {} for stage in STAGES: row["stages"][stage] = { "occupancy_ns": _stage_occupancy_ns(op_log, stage), "wall_ns": _stage_wall_ns(op_log, stage), "record_count": _stage_count(op_log, stage), } # Kernel-window wall-clock = max t_end - min t_start over PE engine records. pe_records = [r for r in op_log if any(r.component_id.endswith("." + e) for e in ENGINES)] if pe_records: row["pe_window_ns"] = max(r.t_end for r in pe_records) \ - min(r.t_start for r in pe_records) else: row["pe_window_ns"] = 0.0 stage_records = [r for r in op_log if r.params.get("stage_type") in STAGES] if stage_records: row["composite_window_ns"] = max(r.t_end for r in stage_records) \ - min(r.t_start for r in stage_records) else: row["composite_window_ns"] = 0.0 return row def _ceil(a: int, b: int) -> int: return (a + b - 1) // b def main() -> int: shapes_env = os.environ.get("SWEEP_SHAPES") raw = (shapes_env.split(",") if shapes_env else DEFAULT_SHAPES) shapes: list[tuple[int, int, int]] = [] for s in raw: s = s.strip() if not s: continue if "x" in s.lower(): parts = s.lower().split("x") shapes.append((int(parts[0]), int(parts[1]), int(parts[2]))) else: v = int(s) shapes.append((v, v, v)) topology = os.environ.get("SWEEP_TOPOLOGY", "topology.yaml") rows = [] for M, K, N in shapes: for variant in VARIANTS: print(f"[sweep] M={M} K={K} N={N} variant={variant} ...", flush=True) row = _run_one(M, K, N, topology, variant=variant) rows.append(row) eng_dma = row["engines"]["pe_dma"] eng_gem = row["engines"]["pe_gemm"] print(f" tiles={row['tile_count_expected']:>6} " f"pe_window={row['pe_window_ns']:8.1f}ns " f"dma_occ={eng_dma['occupancy_ns']:9.1f} " f"gemm_occ={eng_gem['occupancy_ns']:8.1f} " f"(sim {row['sim_wall_clock_s']:.1f}s)") OUT_PATH.parent.mkdir(parents=True, exist_ok=True) OUT_PATH.write_text(json.dumps({ "tile_sizes": {"M": TILE_M, "K": TILE_K, "N": TILE_N}, "engines": ENGINES, "stages": STAGES, "variants": VARIANTS, "rows": rows, }, indent=2)) print(f"\n[sweep] wrote {OUT_PATH}") return 0 if __name__ == "__main__": raise SystemExit(main())