diff --git a/src/kernbench/cli/main.py b/src/kernbench/cli/main.py index 9163c08..63b0811 100644 --- a/src/kernbench/cli/main.py +++ b/src/kernbench/cli/main.py @@ -21,6 +21,10 @@ def build_parser() -> argparse.ArgumentParser: runp.add_argument( "--device", default=None, help="Target device: 'all' or 'sip:' (default: all)" ) + runp.add_argument( + "--verify-data", action="store_true", default=False, + help="Enable Phase 2 data verification (ADR-0020)", + ) runp.set_defaults(_handler=cmd_run) probep = sub.add_parser("probe", help="Probe latency and BW for predefined traffic patterns") @@ -36,9 +40,11 @@ def build_parser() -> argparse.ArgumentParser: return p -def engine_factory(topology: object, device: DeviceSelector) -> SimEngine: +def engine_factory( + topology: object, device: DeviceSelector, *, enable_data: bool = False, +) -> SimEngine: topo_obj = getattr(topology, "topology_obj", topology) - return GraphEngine(topo_obj) + return GraphEngine(topo_obj, enable_data=enable_data) def cmd_web(args) -> int: @@ -53,8 +59,12 @@ def cmd_run(args) -> int: topo = resolve_topology(args.topology) bench = resolve_bench(args.bench) device = resolve_device(args.device) + verify_data = getattr(args, "verify_data", False) - result = run_bench(topology=topo, bench_fn=bench, device=device, engine_factory=engine_factory) + def _factory(topology, device): + return engine_factory(topology, device, enable_data=verify_data) + + result = run_bench(topology=topo, bench_fn=bench, device=device, engine_factory=_factory) topo_obj = getattr(topo, "topology_obj", topo) spec = getattr(topo_obj, "spec", None) @@ -62,6 +72,21 @@ def cmd_run(args) -> int: print(format_report(result.traces, title=args.bench, spec=spec)) print(result.summary_text()) + # Phase 2: data execution (ADR-0020) + if verify_data and result.engine is not None: + from kernbench.sim_engine.data_executor import DataExecutor + + op_log = result.engine.op_log + store = result.engine.memory_store + if op_log and store is not None: + executor = DataExecutor(op_log, store) + executor.run() + n_gemm = sum(1 for r in op_log if r.op_kind == "gemm") + n_math = sum(1 for r in op_log if r.op_kind == "math") + print(f"[data] Phase 2 complete: {len(op_log)} ops ({n_gemm} gemm, {n_math} math)") + else: + print("[data] No op_log recorded — skipping Phase 2") + return 0 if result.completion.ok else 1 diff --git a/src/kernbench/runtime_api/bench_runner.py b/src/kernbench/runtime_api/bench_runner.py index 164a6a7..4f478f7 100644 --- a/src/kernbench/runtime_api/bench_runner.py +++ b/src/kernbench/runtime_api/bench_runner.py @@ -62,6 +62,7 @@ def run_bench( correlation_id=correlation_id, trace=None, traces=collected_traces, + engine=engine, ) if completion_policy == CompletionPolicy.LAST_SUBMITTED: @@ -69,7 +70,7 @@ def run_bench( completion, trace = engine.get_completion(last) return BenchResult( completion=completion, correlation_id=correlation_id, - trace=trace, traces=collected_traces, + trace=trace, traces=collected_traces, engine=engine, ) if completion_policy == CompletionPolicy.ALL_OK_FAIL_FAST: @@ -80,11 +81,11 @@ def run_bench( if not c.ok: return BenchResult( completion=c, correlation_id=correlation_id, - trace=last_trace, traces=collected_traces, + trace=last_trace, traces=collected_traces, engine=engine, ) return BenchResult( completion=Completion(ok=True), correlation_id=correlation_id, - trace=last_trace, traces=collected_traces, + trace=last_trace, traces=collected_traces, engine=engine, ) # LAST_COMPLETED placeholder (needs engine support for timing). Fall back. @@ -92,5 +93,5 @@ def run_bench( completion, trace = engine.get_completion(last) return BenchResult( completion=completion, correlation_id=correlation_id, - trace=trace, traces=collected_traces, + trace=trace, traces=collected_traces, engine=engine, ) diff --git a/src/kernbench/runtime_api/context.py b/src/kernbench/runtime_api/context.py index 1934152..b522ed5 100644 --- a/src/kernbench/runtime_api/context.py +++ b/src/kernbench/runtime_api/context.py @@ -314,6 +314,7 @@ class RuntimeContext: t._handle = handle import weakref t._ctx_ref = weakref.ref(self) + t._memory_store = getattr(self.engine, "_memory_store", None) self._tensors.append(weakref.ref(t)) # Install VA→PA mappings via fabric MmuMapMsg diff --git a/src/kernbench/runtime_api/tensor.py b/src/kernbench/runtime_api/tensor.py index 88ff5a3..7fa40c6 100644 --- a/src/kernbench/runtime_api/tensor.py +++ b/src/kernbench/runtime_api/tensor.py @@ -5,6 +5,8 @@ import weakref from dataclasses import dataclass from typing import Literal +import numpy as np + from kernbench.policy.address.allocator import PEMemAllocator from kernbench.policy.placement.dp import DPPolicy, ShardSpec from kernbench.runtime_api.kernel import TensorArg, TensorArgShard @@ -50,6 +52,20 @@ def dtype_itemsize(dtype: str) -> int: return _DTYPE_ITEMSIZE[dtype] +_NUMPY_DTYPE = { + "f16": np.float16, "fp16": np.float16, "float16": np.float16, + "f32": np.float32, "fp32": np.float32, "float32": np.float32, + "bf16": np.float16, + "i8": np.int8, "int8": np.int8, + "i16": np.int16, "int16": np.int16, + "i32": np.int32, "int32": np.int32, +} + + +def _numpy_dtype(dtype: str) -> np.dtype: + return np.dtype(_NUMPY_DTYPE.get(dtype, np.float16)) + + def deploy_tensor( *, name: str, @@ -129,6 +145,7 @@ class Tensor: self._dp_metadata: DPMetadata | None = None self._handle: TensorHandle | None = None self._ctx_ref: weakref.ref | None = None # set by RuntimeContext + self._memory_store = None # set by RuntimeContext when enable_data=True def __del__(self) -> None: if self._ctx_ref is None or self._handle is None: @@ -137,6 +154,28 @@ class Tensor: if ctx is not None: ctx._free_tensor(self) + def __repr__(self) -> str: + parts = [f"tensor(name={self.name}, shape={self.shape}, dtype={self.dtype}"] + if self._memory_store is not None and self._handle is not None: + arr = self.data + parts.append(f", mean={float(arr.mean()):.4g}, norm={float(np.linalg.norm(arr)):.4g}") + else: + parts.append(", data=N/A (placeholder)") + parts.append(")") + return "".join(parts) + + @property + def data(self) -> np.ndarray: + """Tensor data as numpy array. Returns actual values when enable_data=True, + zeros placeholder otherwise (like an uninitialized tensor).""" + if self._memory_store is not None and self._handle is not None: + shard = self._handle.shards[0] + try: + return self._memory_store.read("hbm", shard.pa, shape=self.shape, dtype=self.dtype) + except KeyError: + pass + return np.zeros(self.shape, dtype=_numpy_dtype(self.dtype)) + @property def itemsize(self) -> int: return dtype_itemsize(self.dtype) diff --git a/src/kernbench/runtime_api/types.py b/src/kernbench/runtime_api/types.py index a484ac1..3654a82 100644 --- a/src/kernbench/runtime_api/types.py +++ b/src/kernbench/runtime_api/types.py @@ -12,6 +12,7 @@ class BenchResult: correlation_id: str trace: Trace | None = None traces: list[dict] | None = None + engine: object | None = None # GraphEngine ref for Phase 2 data access def summary_text(self) -> str: if self.completion.ok: diff --git a/src/kernbench/sim_engine/data_executor.py b/src/kernbench/sim_engine/data_executor.py index e546ffb..fe0be02 100644 --- a/src/kernbench/sim_engine/data_executor.py +++ b/src/kernbench/sim_engine/data_executor.py @@ -6,6 +6,7 @@ Same-timestamp independent ops can be batched for efficiency. """ from __future__ import annotations +from concurrent.futures import ThreadPoolExecutor from itertools import groupby from typing import Any @@ -28,11 +29,18 @@ class DataExecutor: self.store = store def run(self) -> None: - """Execute all ops in op_log order, grouped by t_start.""" - for _t, ops_iter in groupby(self._op_log, key=lambda r: r.t_start): - ops = list(ops_iter) - for op in ops: - self._execute_op(op) + """Execute all ops in op_log order, grouped by t_start. + + Same-timestamp ops are independent and executed in parallel + via ThreadPoolExecutor (numpy releases the GIL for BLAS ops). + """ + with ThreadPoolExecutor() as pool: + for _t, ops_iter in groupby(self._op_log, key=lambda r: r.t_start): + ops = list(ops_iter) + if len(ops) == 1: + self._execute_op(ops[0]) + else: + list(pool.map(self._execute_op, ops)) def _execute_op(self, op: OpRecord) -> None: if op.op_kind == "memory": diff --git a/tests/test_cli_verify_data.py b/tests/test_cli_verify_data.py new file mode 100644 index 0000000..9c71f79 --- /dev/null +++ b/tests/test_cli_verify_data.py @@ -0,0 +1,49 @@ +"""Tests for --verify-data CLI flag (Phase 1 verification).""" +import kernbench.cli.main as cli_main + + +def test_cli_verify_data_flag_parsed(monkeypatch): + """--verify-data flag is parsed and stored as True.""" + + def fake_cmd_run(args) -> int: + assert args.verify_data is True + return 0 + + monkeypatch.setattr(cli_main, "cmd_run", fake_cmd_run) + rc = cli_main.main([ + "run", "--topology", "topology.yaml", "--bench", "qkv_gemm", + "--verify-data", + ]) + assert rc == 0 + + +def test_cli_verify_data_flag_default(monkeypatch): + """Without --verify-data, flag defaults to False.""" + + def fake_cmd_run(args) -> int: + assert args.verify_data is False + return 0 + + monkeypatch.setattr(cli_main, "cmd_run", fake_cmd_run) + rc = cli_main.main([ + "run", "--topology", "topology.yaml", "--bench", "qkv_gemm", + ]) + assert rc == 0 + + +def test_cmd_run_verify_data_enables_engine(): + """--verify-data runs full pipeline with enable_data=True and DataExecutor.""" + rc = cli_main.main([ + "run", "--topology", "topology.yaml", "--bench", "qkv_gemm", + "--device", "sip:0", "--verify-data", + ]) + assert rc == 0 + + +def test_cmd_run_without_verify_data_no_op_log(): + """Without --verify-data, engine runs in timing-only mode (no op_log).""" + rc = cli_main.main([ + "run", "--topology", "topology.yaml", "--bench", "qkv_gemm", + "--device", "sip:0", + ]) + assert rc == 0 diff --git a/tests/test_data_executor.py b/tests/test_data_executor.py index af1aa97..0204ba8 100644 --- a/tests/test_data_executor.py +++ b/tests/test_data_executor.py @@ -186,3 +186,41 @@ def test_sequential_gemm_then_math(): exp_result = store.read("tcm", 0x300) assert np.allclose(exp_result, np.exp(expected_gemm)) + + +def test_parallel_same_timestamp_ops(): + """Multiple independent ops at the same t_start produce correct results + when executed in parallel (ThreadPoolExecutor).""" + store = MemoryStore() + n_ops = 8 + # Each op: independent GEMM writing to a different address + for i in range(n_ops): + a = np.full((4, 4), float(i + 1), dtype=np.float16) + b = np.eye(4, dtype=np.float16) + store.write("tcm", 0x1000 * i, a) + store.write("tcm", 0x1000 * i + 0x800, b) + + ops = [ + OpRecord( + t_start=0.0, t_end=100.0, + component_id=f"pe{i}.pe_gemm", + op_kind="gemm", op_name="gemm_f16", + params={ + "src_a_addr": 0x1000 * i, + "src_b_addr": 0x1000 * i + 0x800, + "dst_addr": 0x80000 + 0x1000 * i, + "shape_a": (4, 4), "shape_b": (4, 4), "shape_out": (4, 4), + "dtype_in": "f16", "dtype_acc": "f32", "dtype_out": "f16", + "addr_space": "tcm", + }, + ) + for i in range(n_ops) + ] + + executor = DataExecutor(ops, store) + executor.run() + + for i in range(n_ops): + result = store.read("tcm", 0x80000 + 0x1000 * i) + expected = np.full((4, 4), float(i + 1), dtype=np.float16) + assert np.allclose(result, expected), f"op {i}: expected {expected}, got {result}"