Add --verify-data CLI flag, Tensor.data property, parallel DataExecutor

- CLI: --verify-data flag enables Phase 2 data verification (ADR-0020)
- Tensor.data: returns actual numpy values (verify-data) or zeros placeholder
- Tensor.__repr__: shows value summary or data=N/A (placeholder)
- DataExecutor: ThreadPoolExecutor for same-timestamp parallel op execution
- BenchResult.engine: exposes op_log/memory_store for Phase 2 access

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-04-09 09:34:01 -07:00
parent 59e36f0c34
commit dc3fb02aed
8 changed files with 174 additions and 12 deletions
+28 -3
View File
@@ -21,6 +21,10 @@ def build_parser() -> argparse.ArgumentParser:
runp.add_argument( runp.add_argument(
"--device", default=None, help="Target device: 'all' or 'sip:<N>' (default: all)" "--device", default=None, help="Target device: 'all' or 'sip:<N>' (default: all)"
) )
runp.add_argument(
"--verify-data", action="store_true", default=False,
help="Enable Phase 2 data verification (ADR-0020)",
)
runp.set_defaults(_handler=cmd_run) runp.set_defaults(_handler=cmd_run)
probep = sub.add_parser("probe", help="Probe latency and BW for predefined traffic patterns") probep = sub.add_parser("probe", help="Probe latency and BW for predefined traffic patterns")
@@ -36,9 +40,11 @@ def build_parser() -> argparse.ArgumentParser:
return p return p
def engine_factory(topology: object, device: DeviceSelector) -> SimEngine: def engine_factory(
topology: object, device: DeviceSelector, *, enable_data: bool = False,
) -> SimEngine:
topo_obj = getattr(topology, "topology_obj", topology) topo_obj = getattr(topology, "topology_obj", topology)
return GraphEngine(topo_obj) return GraphEngine(topo_obj, enable_data=enable_data)
def cmd_web(args) -> int: def cmd_web(args) -> int:
@@ -53,8 +59,12 @@ def cmd_run(args) -> int:
topo = resolve_topology(args.topology) topo = resolve_topology(args.topology)
bench = resolve_bench(args.bench) bench = resolve_bench(args.bench)
device = resolve_device(args.device) device = resolve_device(args.device)
verify_data = getattr(args, "verify_data", False)
result = run_bench(topology=topo, bench_fn=bench, device=device, engine_factory=engine_factory) def _factory(topology, device):
return engine_factory(topology, device, enable_data=verify_data)
result = run_bench(topology=topo, bench_fn=bench, device=device, engine_factory=_factory)
topo_obj = getattr(topo, "topology_obj", topo) topo_obj = getattr(topo, "topology_obj", topo)
spec = getattr(topo_obj, "spec", None) spec = getattr(topo_obj, "spec", None)
@@ -62,6 +72,21 @@ def cmd_run(args) -> int:
print(format_report(result.traces, title=args.bench, spec=spec)) print(format_report(result.traces, title=args.bench, spec=spec))
print(result.summary_text()) print(result.summary_text())
# Phase 2: data execution (ADR-0020)
if verify_data and result.engine is not None:
from kernbench.sim_engine.data_executor import DataExecutor
op_log = result.engine.op_log
store = result.engine.memory_store
if op_log and store is not None:
executor = DataExecutor(op_log, store)
executor.run()
n_gemm = sum(1 for r in op_log if r.op_kind == "gemm")
n_math = sum(1 for r in op_log if r.op_kind == "math")
print(f"[data] Phase 2 complete: {len(op_log)} ops ({n_gemm} gemm, {n_math} math)")
else:
print("[data] No op_log recorded — skipping Phase 2")
return 0 if result.completion.ok else 1 return 0 if result.completion.ok else 1
+5 -4
View File
@@ -62,6 +62,7 @@ def run_bench(
correlation_id=correlation_id, correlation_id=correlation_id,
trace=None, trace=None,
traces=collected_traces, traces=collected_traces,
engine=engine,
) )
if completion_policy == CompletionPolicy.LAST_SUBMITTED: if completion_policy == CompletionPolicy.LAST_SUBMITTED:
@@ -69,7 +70,7 @@ def run_bench(
completion, trace = engine.get_completion(last) completion, trace = engine.get_completion(last)
return BenchResult( return BenchResult(
completion=completion, correlation_id=correlation_id, completion=completion, correlation_id=correlation_id,
trace=trace, traces=collected_traces, trace=trace, traces=collected_traces, engine=engine,
) )
if completion_policy == CompletionPolicy.ALL_OK_FAIL_FAST: if completion_policy == CompletionPolicy.ALL_OK_FAIL_FAST:
@@ -80,11 +81,11 @@ def run_bench(
if not c.ok: if not c.ok:
return BenchResult( return BenchResult(
completion=c, correlation_id=correlation_id, completion=c, correlation_id=correlation_id,
trace=last_trace, traces=collected_traces, trace=last_trace, traces=collected_traces, engine=engine,
) )
return BenchResult( return BenchResult(
completion=Completion(ok=True), correlation_id=correlation_id, completion=Completion(ok=True), correlation_id=correlation_id,
trace=last_trace, traces=collected_traces, trace=last_trace, traces=collected_traces, engine=engine,
) )
# LAST_COMPLETED placeholder (needs engine support for timing). Fall back. # LAST_COMPLETED placeholder (needs engine support for timing). Fall back.
@@ -92,5 +93,5 @@ def run_bench(
completion, trace = engine.get_completion(last) completion, trace = engine.get_completion(last)
return BenchResult( return BenchResult(
completion=completion, correlation_id=correlation_id, completion=completion, correlation_id=correlation_id,
trace=trace, traces=collected_traces, trace=trace, traces=collected_traces, engine=engine,
) )
+1
View File
@@ -314,6 +314,7 @@ class RuntimeContext:
t._handle = handle t._handle = handle
import weakref import weakref
t._ctx_ref = weakref.ref(self) t._ctx_ref = weakref.ref(self)
t._memory_store = getattr(self.engine, "_memory_store", None)
self._tensors.append(weakref.ref(t)) self._tensors.append(weakref.ref(t))
# Install VA→PA mappings via fabric MmuMapMsg # Install VA→PA mappings via fabric MmuMapMsg
+39
View File
@@ -5,6 +5,8 @@ import weakref
from dataclasses import dataclass from dataclasses import dataclass
from typing import Literal from typing import Literal
import numpy as np
from kernbench.policy.address.allocator import PEMemAllocator from kernbench.policy.address.allocator import PEMemAllocator
from kernbench.policy.placement.dp import DPPolicy, ShardSpec from kernbench.policy.placement.dp import DPPolicy, ShardSpec
from kernbench.runtime_api.kernel import TensorArg, TensorArgShard from kernbench.runtime_api.kernel import TensorArg, TensorArgShard
@@ -50,6 +52,20 @@ def dtype_itemsize(dtype: str) -> int:
return _DTYPE_ITEMSIZE[dtype] return _DTYPE_ITEMSIZE[dtype]
_NUMPY_DTYPE = {
"f16": np.float16, "fp16": np.float16, "float16": np.float16,
"f32": np.float32, "fp32": np.float32, "float32": np.float32,
"bf16": np.float16,
"i8": np.int8, "int8": np.int8,
"i16": np.int16, "int16": np.int16,
"i32": np.int32, "int32": np.int32,
}
def _numpy_dtype(dtype: str) -> np.dtype:
return np.dtype(_NUMPY_DTYPE.get(dtype, np.float16))
def deploy_tensor( def deploy_tensor(
*, *,
name: str, name: str,
@@ -129,6 +145,7 @@ class Tensor:
self._dp_metadata: DPMetadata | None = None self._dp_metadata: DPMetadata | None = None
self._handle: TensorHandle | None = None self._handle: TensorHandle | None = None
self._ctx_ref: weakref.ref | None = None # set by RuntimeContext self._ctx_ref: weakref.ref | None = None # set by RuntimeContext
self._memory_store = None # set by RuntimeContext when enable_data=True
def __del__(self) -> None: def __del__(self) -> None:
if self._ctx_ref is None or self._handle is None: if self._ctx_ref is None or self._handle is None:
@@ -137,6 +154,28 @@ class Tensor:
if ctx is not None: if ctx is not None:
ctx._free_tensor(self) ctx._free_tensor(self)
def __repr__(self) -> str:
parts = [f"tensor(name={self.name}, shape={self.shape}, dtype={self.dtype}"]
if self._memory_store is not None and self._handle is not None:
arr = self.data
parts.append(f", mean={float(arr.mean()):.4g}, norm={float(np.linalg.norm(arr)):.4g}")
else:
parts.append(", data=N/A (placeholder)")
parts.append(")")
return "".join(parts)
@property
def data(self) -> np.ndarray:
"""Tensor data as numpy array. Returns actual values when enable_data=True,
zeros placeholder otherwise (like an uninitialized tensor)."""
if self._memory_store is not None and self._handle is not None:
shard = self._handle.shards[0]
try:
return self._memory_store.read("hbm", shard.pa, shape=self.shape, dtype=self.dtype)
except KeyError:
pass
return np.zeros(self.shape, dtype=_numpy_dtype(self.dtype))
@property @property
def itemsize(self) -> int: def itemsize(self) -> int:
return dtype_itemsize(self.dtype) return dtype_itemsize(self.dtype)
+1
View File
@@ -12,6 +12,7 @@ class BenchResult:
correlation_id: str correlation_id: str
trace: Trace | None = None trace: Trace | None = None
traces: list[dict] | None = None traces: list[dict] | None = None
engine: object | None = None # GraphEngine ref for Phase 2 data access
def summary_text(self) -> str: def summary_text(self) -> str:
if self.completion.ok: if self.completion.ok:
+13 -5
View File
@@ -6,6 +6,7 @@ Same-timestamp independent ops can be batched for efficiency.
""" """
from __future__ import annotations from __future__ import annotations
from concurrent.futures import ThreadPoolExecutor
from itertools import groupby from itertools import groupby
from typing import Any from typing import Any
@@ -28,11 +29,18 @@ class DataExecutor:
self.store = store self.store = store
def run(self) -> None: def run(self) -> None:
"""Execute all ops in op_log order, grouped by t_start.""" """Execute all ops in op_log order, grouped by t_start.
for _t, ops_iter in groupby(self._op_log, key=lambda r: r.t_start):
ops = list(ops_iter) Same-timestamp ops are independent and executed in parallel
for op in ops: via ThreadPoolExecutor (numpy releases the GIL for BLAS ops).
self._execute_op(op) """
with ThreadPoolExecutor() as pool:
for _t, ops_iter in groupby(self._op_log, key=lambda r: r.t_start):
ops = list(ops_iter)
if len(ops) == 1:
self._execute_op(ops[0])
else:
list(pool.map(self._execute_op, ops))
def _execute_op(self, op: OpRecord) -> None: def _execute_op(self, op: OpRecord) -> None:
if op.op_kind == "memory": if op.op_kind == "memory":
+49
View File
@@ -0,0 +1,49 @@
"""Tests for --verify-data CLI flag (Phase 1 verification)."""
import kernbench.cli.main as cli_main
def test_cli_verify_data_flag_parsed(monkeypatch):
"""--verify-data flag is parsed and stored as True."""
def fake_cmd_run(args) -> int:
assert args.verify_data is True
return 0
monkeypatch.setattr(cli_main, "cmd_run", fake_cmd_run)
rc = cli_main.main([
"run", "--topology", "topology.yaml", "--bench", "qkv_gemm",
"--verify-data",
])
assert rc == 0
def test_cli_verify_data_flag_default(monkeypatch):
"""Without --verify-data, flag defaults to False."""
def fake_cmd_run(args) -> int:
assert args.verify_data is False
return 0
monkeypatch.setattr(cli_main, "cmd_run", fake_cmd_run)
rc = cli_main.main([
"run", "--topology", "topology.yaml", "--bench", "qkv_gemm",
])
assert rc == 0
def test_cmd_run_verify_data_enables_engine():
"""--verify-data runs full pipeline with enable_data=True and DataExecutor."""
rc = cli_main.main([
"run", "--topology", "topology.yaml", "--bench", "qkv_gemm",
"--device", "sip:0", "--verify-data",
])
assert rc == 0
def test_cmd_run_without_verify_data_no_op_log():
"""Without --verify-data, engine runs in timing-only mode (no op_log)."""
rc = cli_main.main([
"run", "--topology", "topology.yaml", "--bench", "qkv_gemm",
"--device", "sip:0",
])
assert rc == 0
+38
View File
@@ -186,3 +186,41 @@ def test_sequential_gemm_then_math():
exp_result = store.read("tcm", 0x300) exp_result = store.read("tcm", 0x300)
assert np.allclose(exp_result, np.exp(expected_gemm)) assert np.allclose(exp_result, np.exp(expected_gemm))
def test_parallel_same_timestamp_ops():
"""Multiple independent ops at the same t_start produce correct results
when executed in parallel (ThreadPoolExecutor)."""
store = MemoryStore()
n_ops = 8
# Each op: independent GEMM writing to a different address
for i in range(n_ops):
a = np.full((4, 4), float(i + 1), dtype=np.float16)
b = np.eye(4, dtype=np.float16)
store.write("tcm", 0x1000 * i, a)
store.write("tcm", 0x1000 * i + 0x800, b)
ops = [
OpRecord(
t_start=0.0, t_end=100.0,
component_id=f"pe{i}.pe_gemm",
op_kind="gemm", op_name="gemm_f16",
params={
"src_a_addr": 0x1000 * i,
"src_b_addr": 0x1000 * i + 0x800,
"dst_addr": 0x80000 + 0x1000 * i,
"shape_a": (4, 4), "shape_b": (4, 4), "shape_out": (4, 4),
"dtype_in": "f16", "dtype_acc": "f32", "dtype_out": "f16",
"addr_space": "tcm",
},
)
for i in range(n_ops)
]
executor = DataExecutor(ops, store)
executor.run()
for i in range(n_ops):
result = store.read("tcm", 0x80000 + 0x1000 * i)
expected = np.full((4, 4), float(i + 1), dtype=np.float16)
assert np.allclose(result, expected), f"op {i}: expected {expected}, got {result}"