Add --verify-data CLI flag, Tensor.data property, parallel DataExecutor

- CLI: --verify-data flag enables Phase 2 data verification (ADR-0020)
- Tensor.data: returns actual numpy values (verify-data) or zeros placeholder
- Tensor.__repr__: shows value summary or data=N/A (placeholder)
- DataExecutor: ThreadPoolExecutor for same-timestamp parallel op execution
- BenchResult.engine: exposes op_log/memory_store for Phase 2 access

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-04-09 09:34:01 -07:00
parent 59e36f0c34
commit dc3fb02aed
8 changed files with 174 additions and 12 deletions
+28 -3
View File
@@ -21,6 +21,10 @@ def build_parser() -> argparse.ArgumentParser:
runp.add_argument(
"--device", default=None, help="Target device: 'all' or 'sip:<N>' (default: all)"
)
runp.add_argument(
"--verify-data", action="store_true", default=False,
help="Enable Phase 2 data verification (ADR-0020)",
)
runp.set_defaults(_handler=cmd_run)
probep = sub.add_parser("probe", help="Probe latency and BW for predefined traffic patterns")
@@ -36,9 +40,11 @@ def build_parser() -> argparse.ArgumentParser:
return p
def engine_factory(topology: object, device: DeviceSelector) -> SimEngine:
def engine_factory(
topology: object, device: DeviceSelector, *, enable_data: bool = False,
) -> SimEngine:
topo_obj = getattr(topology, "topology_obj", topology)
return GraphEngine(topo_obj)
return GraphEngine(topo_obj, enable_data=enable_data)
def cmd_web(args) -> int:
@@ -53,8 +59,12 @@ def cmd_run(args) -> int:
topo = resolve_topology(args.topology)
bench = resolve_bench(args.bench)
device = resolve_device(args.device)
verify_data = getattr(args, "verify_data", False)
result = run_bench(topology=topo, bench_fn=bench, device=device, engine_factory=engine_factory)
def _factory(topology, device):
return engine_factory(topology, device, enable_data=verify_data)
result = run_bench(topology=topo, bench_fn=bench, device=device, engine_factory=_factory)
topo_obj = getattr(topo, "topology_obj", topo)
spec = getattr(topo_obj, "spec", None)
@@ -62,6 +72,21 @@ def cmd_run(args) -> int:
print(format_report(result.traces, title=args.bench, spec=spec))
print(result.summary_text())
# Phase 2: data execution (ADR-0020)
if verify_data and result.engine is not None:
from kernbench.sim_engine.data_executor import DataExecutor
op_log = result.engine.op_log
store = result.engine.memory_store
if op_log and store is not None:
executor = DataExecutor(op_log, store)
executor.run()
n_gemm = sum(1 for r in op_log if r.op_kind == "gemm")
n_math = sum(1 for r in op_log if r.op_kind == "math")
print(f"[data] Phase 2 complete: {len(op_log)} ops ({n_gemm} gemm, {n_math} math)")
else:
print("[data] No op_log recorded — skipping Phase 2")
return 0 if result.completion.ok else 1
+5 -4
View File
@@ -62,6 +62,7 @@ def run_bench(
correlation_id=correlation_id,
trace=None,
traces=collected_traces,
engine=engine,
)
if completion_policy == CompletionPolicy.LAST_SUBMITTED:
@@ -69,7 +70,7 @@ def run_bench(
completion, trace = engine.get_completion(last)
return BenchResult(
completion=completion, correlation_id=correlation_id,
trace=trace, traces=collected_traces,
trace=trace, traces=collected_traces, engine=engine,
)
if completion_policy == CompletionPolicy.ALL_OK_FAIL_FAST:
@@ -80,11 +81,11 @@ def run_bench(
if not c.ok:
return BenchResult(
completion=c, correlation_id=correlation_id,
trace=last_trace, traces=collected_traces,
trace=last_trace, traces=collected_traces, engine=engine,
)
return BenchResult(
completion=Completion(ok=True), correlation_id=correlation_id,
trace=last_trace, traces=collected_traces,
trace=last_trace, traces=collected_traces, engine=engine,
)
# LAST_COMPLETED placeholder (needs engine support for timing). Fall back.
@@ -92,5 +93,5 @@ def run_bench(
completion, trace = engine.get_completion(last)
return BenchResult(
completion=completion, correlation_id=correlation_id,
trace=trace, traces=collected_traces,
trace=trace, traces=collected_traces, engine=engine,
)
+1
View File
@@ -314,6 +314,7 @@ class RuntimeContext:
t._handle = handle
import weakref
t._ctx_ref = weakref.ref(self)
t._memory_store = getattr(self.engine, "_memory_store", None)
self._tensors.append(weakref.ref(t))
# Install VA→PA mappings via fabric MmuMapMsg
+39
View File
@@ -5,6 +5,8 @@ import weakref
from dataclasses import dataclass
from typing import Literal
import numpy as np
from kernbench.policy.address.allocator import PEMemAllocator
from kernbench.policy.placement.dp import DPPolicy, ShardSpec
from kernbench.runtime_api.kernel import TensorArg, TensorArgShard
@@ -50,6 +52,20 @@ def dtype_itemsize(dtype: str) -> int:
return _DTYPE_ITEMSIZE[dtype]
_NUMPY_DTYPE = {
"f16": np.float16, "fp16": np.float16, "float16": np.float16,
"f32": np.float32, "fp32": np.float32, "float32": np.float32,
"bf16": np.float16,
"i8": np.int8, "int8": np.int8,
"i16": np.int16, "int16": np.int16,
"i32": np.int32, "int32": np.int32,
}
def _numpy_dtype(dtype: str) -> np.dtype:
return np.dtype(_NUMPY_DTYPE.get(dtype, np.float16))
def deploy_tensor(
*,
name: str,
@@ -129,6 +145,7 @@ class Tensor:
self._dp_metadata: DPMetadata | None = None
self._handle: TensorHandle | None = None
self._ctx_ref: weakref.ref | None = None # set by RuntimeContext
self._memory_store = None # set by RuntimeContext when enable_data=True
def __del__(self) -> None:
if self._ctx_ref is None or self._handle is None:
@@ -137,6 +154,28 @@ class Tensor:
if ctx is not None:
ctx._free_tensor(self)
def __repr__(self) -> str:
parts = [f"tensor(name={self.name}, shape={self.shape}, dtype={self.dtype}"]
if self._memory_store is not None and self._handle is not None:
arr = self.data
parts.append(f", mean={float(arr.mean()):.4g}, norm={float(np.linalg.norm(arr)):.4g}")
else:
parts.append(", data=N/A (placeholder)")
parts.append(")")
return "".join(parts)
@property
def data(self) -> np.ndarray:
"""Tensor data as numpy array. Returns actual values when enable_data=True,
zeros placeholder otherwise (like an uninitialized tensor)."""
if self._memory_store is not None and self._handle is not None:
shard = self._handle.shards[0]
try:
return self._memory_store.read("hbm", shard.pa, shape=self.shape, dtype=self.dtype)
except KeyError:
pass
return np.zeros(self.shape, dtype=_numpy_dtype(self.dtype))
@property
def itemsize(self) -> int:
return dtype_itemsize(self.dtype)
+1
View File
@@ -12,6 +12,7 @@ class BenchResult:
correlation_id: str
trace: Trace | None = None
traces: list[dict] | None = None
engine: object | None = None # GraphEngine ref for Phase 2 data access
def summary_text(self) -> str:
if self.completion.ok:
+13 -5
View File
@@ -6,6 +6,7 @@ Same-timestamp independent ops can be batched for efficiency.
"""
from __future__ import annotations
from concurrent.futures import ThreadPoolExecutor
from itertools import groupby
from typing import Any
@@ -28,11 +29,18 @@ class DataExecutor:
self.store = store
def run(self) -> None:
"""Execute all ops in op_log order, grouped by t_start."""
for _t, ops_iter in groupby(self._op_log, key=lambda r: r.t_start):
ops = list(ops_iter)
for op in ops:
self._execute_op(op)
"""Execute all ops in op_log order, grouped by t_start.
Same-timestamp ops are independent and executed in parallel
via ThreadPoolExecutor (numpy releases the GIL for BLAS ops).
"""
with ThreadPoolExecutor() as pool:
for _t, ops_iter in groupby(self._op_log, key=lambda r: r.t_start):
ops = list(ops_iter)
if len(ops) == 1:
self._execute_op(ops[0])
else:
list(pool.map(self._execute_op, ops))
def _execute_op(self, op: OpRecord) -> None:
if op.op_kind == "memory":