Add --verify-data CLI flag, Tensor.data property, parallel DataExecutor
- CLI: --verify-data flag enables Phase 2 data verification (ADR-0020) - Tensor.data: returns actual numpy values (verify-data) or zeros placeholder - Tensor.__repr__: shows value summary or data=N/A (placeholder) - DataExecutor: ThreadPoolExecutor for same-timestamp parallel op execution - BenchResult.engine: exposes op_log/memory_store for Phase 2 access Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -21,6 +21,10 @@ def build_parser() -> argparse.ArgumentParser:
|
||||
runp.add_argument(
|
||||
"--device", default=None, help="Target device: 'all' or 'sip:<N>' (default: all)"
|
||||
)
|
||||
runp.add_argument(
|
||||
"--verify-data", action="store_true", default=False,
|
||||
help="Enable Phase 2 data verification (ADR-0020)",
|
||||
)
|
||||
runp.set_defaults(_handler=cmd_run)
|
||||
|
||||
probep = sub.add_parser("probe", help="Probe latency and BW for predefined traffic patterns")
|
||||
@@ -36,9 +40,11 @@ def build_parser() -> argparse.ArgumentParser:
|
||||
return p
|
||||
|
||||
|
||||
def engine_factory(topology: object, device: DeviceSelector) -> SimEngine:
|
||||
def engine_factory(
|
||||
topology: object, device: DeviceSelector, *, enable_data: bool = False,
|
||||
) -> SimEngine:
|
||||
topo_obj = getattr(topology, "topology_obj", topology)
|
||||
return GraphEngine(topo_obj)
|
||||
return GraphEngine(topo_obj, enable_data=enable_data)
|
||||
|
||||
|
||||
def cmd_web(args) -> int:
|
||||
@@ -53,8 +59,12 @@ def cmd_run(args) -> int:
|
||||
topo = resolve_topology(args.topology)
|
||||
bench = resolve_bench(args.bench)
|
||||
device = resolve_device(args.device)
|
||||
verify_data = getattr(args, "verify_data", False)
|
||||
|
||||
result = run_bench(topology=topo, bench_fn=bench, device=device, engine_factory=engine_factory)
|
||||
def _factory(topology, device):
|
||||
return engine_factory(topology, device, enable_data=verify_data)
|
||||
|
||||
result = run_bench(topology=topo, bench_fn=bench, device=device, engine_factory=_factory)
|
||||
|
||||
topo_obj = getattr(topo, "topology_obj", topo)
|
||||
spec = getattr(topo_obj, "spec", None)
|
||||
@@ -62,6 +72,21 @@ def cmd_run(args) -> int:
|
||||
print(format_report(result.traces, title=args.bench, spec=spec))
|
||||
print(result.summary_text())
|
||||
|
||||
# Phase 2: data execution (ADR-0020)
|
||||
if verify_data and result.engine is not None:
|
||||
from kernbench.sim_engine.data_executor import DataExecutor
|
||||
|
||||
op_log = result.engine.op_log
|
||||
store = result.engine.memory_store
|
||||
if op_log and store is not None:
|
||||
executor = DataExecutor(op_log, store)
|
||||
executor.run()
|
||||
n_gemm = sum(1 for r in op_log if r.op_kind == "gemm")
|
||||
n_math = sum(1 for r in op_log if r.op_kind == "math")
|
||||
print(f"[data] Phase 2 complete: {len(op_log)} ops ({n_gemm} gemm, {n_math} math)")
|
||||
else:
|
||||
print("[data] No op_log recorded — skipping Phase 2")
|
||||
|
||||
return 0 if result.completion.ok else 1
|
||||
|
||||
|
||||
|
||||
@@ -62,6 +62,7 @@ def run_bench(
|
||||
correlation_id=correlation_id,
|
||||
trace=None,
|
||||
traces=collected_traces,
|
||||
engine=engine,
|
||||
)
|
||||
|
||||
if completion_policy == CompletionPolicy.LAST_SUBMITTED:
|
||||
@@ -69,7 +70,7 @@ def run_bench(
|
||||
completion, trace = engine.get_completion(last)
|
||||
return BenchResult(
|
||||
completion=completion, correlation_id=correlation_id,
|
||||
trace=trace, traces=collected_traces,
|
||||
trace=trace, traces=collected_traces, engine=engine,
|
||||
)
|
||||
|
||||
if completion_policy == CompletionPolicy.ALL_OK_FAIL_FAST:
|
||||
@@ -80,11 +81,11 @@ def run_bench(
|
||||
if not c.ok:
|
||||
return BenchResult(
|
||||
completion=c, correlation_id=correlation_id,
|
||||
trace=last_trace, traces=collected_traces,
|
||||
trace=last_trace, traces=collected_traces, engine=engine,
|
||||
)
|
||||
return BenchResult(
|
||||
completion=Completion(ok=True), correlation_id=correlation_id,
|
||||
trace=last_trace, traces=collected_traces,
|
||||
trace=last_trace, traces=collected_traces, engine=engine,
|
||||
)
|
||||
|
||||
# LAST_COMPLETED placeholder (needs engine support for timing). Fall back.
|
||||
@@ -92,5 +93,5 @@ def run_bench(
|
||||
completion, trace = engine.get_completion(last)
|
||||
return BenchResult(
|
||||
completion=completion, correlation_id=correlation_id,
|
||||
trace=trace, traces=collected_traces,
|
||||
trace=trace, traces=collected_traces, engine=engine,
|
||||
)
|
||||
|
||||
@@ -314,6 +314,7 @@ class RuntimeContext:
|
||||
t._handle = handle
|
||||
import weakref
|
||||
t._ctx_ref = weakref.ref(self)
|
||||
t._memory_store = getattr(self.engine, "_memory_store", None)
|
||||
self._tensors.append(weakref.ref(t))
|
||||
|
||||
# Install VA→PA mappings via fabric MmuMapMsg
|
||||
|
||||
@@ -5,6 +5,8 @@ import weakref
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
|
||||
from kernbench.policy.address.allocator import PEMemAllocator
|
||||
from kernbench.policy.placement.dp import DPPolicy, ShardSpec
|
||||
from kernbench.runtime_api.kernel import TensorArg, TensorArgShard
|
||||
@@ -50,6 +52,20 @@ def dtype_itemsize(dtype: str) -> int:
|
||||
return _DTYPE_ITEMSIZE[dtype]
|
||||
|
||||
|
||||
_NUMPY_DTYPE = {
|
||||
"f16": np.float16, "fp16": np.float16, "float16": np.float16,
|
||||
"f32": np.float32, "fp32": np.float32, "float32": np.float32,
|
||||
"bf16": np.float16,
|
||||
"i8": np.int8, "int8": np.int8,
|
||||
"i16": np.int16, "int16": np.int16,
|
||||
"i32": np.int32, "int32": np.int32,
|
||||
}
|
||||
|
||||
|
||||
def _numpy_dtype(dtype: str) -> np.dtype:
|
||||
return np.dtype(_NUMPY_DTYPE.get(dtype, np.float16))
|
||||
|
||||
|
||||
def deploy_tensor(
|
||||
*,
|
||||
name: str,
|
||||
@@ -129,6 +145,7 @@ class Tensor:
|
||||
self._dp_metadata: DPMetadata | None = None
|
||||
self._handle: TensorHandle | None = None
|
||||
self._ctx_ref: weakref.ref | None = None # set by RuntimeContext
|
||||
self._memory_store = None # set by RuntimeContext when enable_data=True
|
||||
|
||||
def __del__(self) -> None:
|
||||
if self._ctx_ref is None or self._handle is None:
|
||||
@@ -137,6 +154,28 @@ class Tensor:
|
||||
if ctx is not None:
|
||||
ctx._free_tensor(self)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
parts = [f"tensor(name={self.name}, shape={self.shape}, dtype={self.dtype}"]
|
||||
if self._memory_store is not None and self._handle is not None:
|
||||
arr = self.data
|
||||
parts.append(f", mean={float(arr.mean()):.4g}, norm={float(np.linalg.norm(arr)):.4g}")
|
||||
else:
|
||||
parts.append(", data=N/A (placeholder)")
|
||||
parts.append(")")
|
||||
return "".join(parts)
|
||||
|
||||
@property
|
||||
def data(self) -> np.ndarray:
|
||||
"""Tensor data as numpy array. Returns actual values when enable_data=True,
|
||||
zeros placeholder otherwise (like an uninitialized tensor)."""
|
||||
if self._memory_store is not None and self._handle is not None:
|
||||
shard = self._handle.shards[0]
|
||||
try:
|
||||
return self._memory_store.read("hbm", shard.pa, shape=self.shape, dtype=self.dtype)
|
||||
except KeyError:
|
||||
pass
|
||||
return np.zeros(self.shape, dtype=_numpy_dtype(self.dtype))
|
||||
|
||||
@property
|
||||
def itemsize(self) -> int:
|
||||
return dtype_itemsize(self.dtype)
|
||||
|
||||
@@ -12,6 +12,7 @@ class BenchResult:
|
||||
correlation_id: str
|
||||
trace: Trace | None = None
|
||||
traces: list[dict] | None = None
|
||||
engine: object | None = None # GraphEngine ref for Phase 2 data access
|
||||
|
||||
def summary_text(self) -> str:
|
||||
if self.completion.ok:
|
||||
|
||||
@@ -6,6 +6,7 @@ Same-timestamp independent ops can be batched for efficiency.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from itertools import groupby
|
||||
from typing import Any
|
||||
|
||||
@@ -28,11 +29,18 @@ class DataExecutor:
|
||||
self.store = store
|
||||
|
||||
def run(self) -> None:
|
||||
"""Execute all ops in op_log order, grouped by t_start."""
|
||||
for _t, ops_iter in groupby(self._op_log, key=lambda r: r.t_start):
|
||||
ops = list(ops_iter)
|
||||
for op in ops:
|
||||
self._execute_op(op)
|
||||
"""Execute all ops in op_log order, grouped by t_start.
|
||||
|
||||
Same-timestamp ops are independent and executed in parallel
|
||||
via ThreadPoolExecutor (numpy releases the GIL for BLAS ops).
|
||||
"""
|
||||
with ThreadPoolExecutor() as pool:
|
||||
for _t, ops_iter in groupby(self._op_log, key=lambda r: r.t_start):
|
||||
ops = list(ops_iter)
|
||||
if len(ops) == 1:
|
||||
self._execute_op(ops[0])
|
||||
else:
|
||||
list(pool.map(self._execute_op, ops))
|
||||
|
||||
def _execute_op(self, op: OpRecord) -> None:
|
||||
if op.op_kind == "memory":
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
"""Tests for --verify-data CLI flag (Phase 1 verification)."""
|
||||
import kernbench.cli.main as cli_main
|
||||
|
||||
|
||||
def test_cli_verify_data_flag_parsed(monkeypatch):
|
||||
"""--verify-data flag is parsed and stored as True."""
|
||||
|
||||
def fake_cmd_run(args) -> int:
|
||||
assert args.verify_data is True
|
||||
return 0
|
||||
|
||||
monkeypatch.setattr(cli_main, "cmd_run", fake_cmd_run)
|
||||
rc = cli_main.main([
|
||||
"run", "--topology", "topology.yaml", "--bench", "qkv_gemm",
|
||||
"--verify-data",
|
||||
])
|
||||
assert rc == 0
|
||||
|
||||
|
||||
def test_cli_verify_data_flag_default(monkeypatch):
|
||||
"""Without --verify-data, flag defaults to False."""
|
||||
|
||||
def fake_cmd_run(args) -> int:
|
||||
assert args.verify_data is False
|
||||
return 0
|
||||
|
||||
monkeypatch.setattr(cli_main, "cmd_run", fake_cmd_run)
|
||||
rc = cli_main.main([
|
||||
"run", "--topology", "topology.yaml", "--bench", "qkv_gemm",
|
||||
])
|
||||
assert rc == 0
|
||||
|
||||
|
||||
def test_cmd_run_verify_data_enables_engine():
|
||||
"""--verify-data runs full pipeline with enable_data=True and DataExecutor."""
|
||||
rc = cli_main.main([
|
||||
"run", "--topology", "topology.yaml", "--bench", "qkv_gemm",
|
||||
"--device", "sip:0", "--verify-data",
|
||||
])
|
||||
assert rc == 0
|
||||
|
||||
|
||||
def test_cmd_run_without_verify_data_no_op_log():
|
||||
"""Without --verify-data, engine runs in timing-only mode (no op_log)."""
|
||||
rc = cli_main.main([
|
||||
"run", "--topology", "topology.yaml", "--bench", "qkv_gemm",
|
||||
"--device", "sip:0",
|
||||
])
|
||||
assert rc == 0
|
||||
@@ -186,3 +186,41 @@ def test_sequential_gemm_then_math():
|
||||
|
||||
exp_result = store.read("tcm", 0x300)
|
||||
assert np.allclose(exp_result, np.exp(expected_gemm))
|
||||
|
||||
|
||||
def test_parallel_same_timestamp_ops():
|
||||
"""Multiple independent ops at the same t_start produce correct results
|
||||
when executed in parallel (ThreadPoolExecutor)."""
|
||||
store = MemoryStore()
|
||||
n_ops = 8
|
||||
# Each op: independent GEMM writing to a different address
|
||||
for i in range(n_ops):
|
||||
a = np.full((4, 4), float(i + 1), dtype=np.float16)
|
||||
b = np.eye(4, dtype=np.float16)
|
||||
store.write("tcm", 0x1000 * i, a)
|
||||
store.write("tcm", 0x1000 * i + 0x800, b)
|
||||
|
||||
ops = [
|
||||
OpRecord(
|
||||
t_start=0.0, t_end=100.0,
|
||||
component_id=f"pe{i}.pe_gemm",
|
||||
op_kind="gemm", op_name="gemm_f16",
|
||||
params={
|
||||
"src_a_addr": 0x1000 * i,
|
||||
"src_b_addr": 0x1000 * i + 0x800,
|
||||
"dst_addr": 0x80000 + 0x1000 * i,
|
||||
"shape_a": (4, 4), "shape_b": (4, 4), "shape_out": (4, 4),
|
||||
"dtype_in": "f16", "dtype_acc": "f32", "dtype_out": "f16",
|
||||
"addr_space": "tcm",
|
||||
},
|
||||
)
|
||||
for i in range(n_ops)
|
||||
]
|
||||
|
||||
executor = DataExecutor(ops, store)
|
||||
executor.run()
|
||||
|
||||
for i in range(n_ops):
|
||||
result = store.read("tcm", 0x80000 + 0x1000 * i)
|
||||
expected = np.full((4, 4), float(i + 1), dtype=np.float16)
|
||||
assert np.allclose(result, expected), f"op {i}: expected {expected}, got {result}"
|
||||
|
||||
Reference in New Issue
Block a user