diff --git a/pyproject.toml b/pyproject.toml index 93ee45e..9f762ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta" name = "kernbench" version = "0.1.0" requires-python = ">=3.10" -dependencies = ["pytest", "simpy", "pyyaml", "fastapi>=0.110", "uvicorn[standard]>=0.29", "websockets>=12"] +dependencies = ["pytest", "simpy", "pyyaml", "fastapi>=0.110", "uvicorn[standard]>=0.29", "websockets>=12", "numpy>=1.24", "greenlet>=3.0"] [project.scripts] kernbench = "kernbench.cli.main:main" diff --git a/src/kernbench/common/pe_commands.py b/src/kernbench/common/pe_commands.py index c6bf991..ed0dc8b 100644 --- a/src/kernbench/common/pe_commands.py +++ b/src/kernbench/common/pe_commands.py @@ -55,6 +55,7 @@ class DmaReadCmd: handle: TensorHandle src_addr: int nbytes: int + data_op: bool = True @dataclass(frozen=True) @@ -64,6 +65,7 @@ class DmaWriteCmd: handle: TensorHandle dst_addr: int nbytes: int + data_op: bool = True @dataclass(frozen=True) @@ -79,6 +81,7 @@ class GemmCmd: m: int k: int n: int + data_op: bool = True @dataclass(frozen=True) @@ -94,6 +97,7 @@ class MathCmd: inputs: tuple[TensorHandle, ...] out: TensorHandle axis: int | None = None # for reductions + data_op: bool = True @dataclass(frozen=True) @@ -111,6 +115,7 @@ class CompositeCmd: out_addr: int out_nbytes: int math_op: str | None = None # for op="math": which math operation + data_op: bool = True @dataclass(frozen=True) diff --git a/src/kernbench/components/base.py b/src/kernbench/components/base.py index 58ec12c..336b631 100644 --- a/src/kernbench/components/base.py +++ b/src/kernbench/components/base.py @@ -33,6 +33,7 @@ class ComponentBase(ABC): self.ctx = ctx self.in_ports: dict[str, simpy.Store] = {} self.out_ports: dict[str, simpy.Store] = {} + self._op_logger: Any | None = None # OpLogger, set by GraphEngine if enabled def start(self, env: simpy.Environment) -> None: """Called once after all ports are wired. @@ -64,9 +65,21 @@ class ComponentBase(ABC): txn: Any = yield self._inbox.get() env.process(self._forward_txn(env, txn)) + def _on_process_start(self, env: simpy.Environment, msg: Any) -> None: + """Op log hook: record service start for data_op messages (ADR-0020 D2).""" + if self._op_logger and getattr(msg, "data_op", False): + self._op_logger.record_start(env.now, self.node.id, msg) + + def _on_process_end(self, env: simpy.Environment, msg: Any) -> None: + """Op log hook: record service end for data_op messages (ADR-0020 D2).""" + if self._op_logger and getattr(msg, "data_op", False): + self._op_logger.record_end(env.now, self.node.id, msg) + def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator: """Apply run() latency, then forward to next hop or drain at terminal.""" + self._on_process_start(env, txn) yield from self.run(env, txn.nbytes) + self._on_process_end(env, txn) next_hop = txn.next_hop # duck-typed: Transaction.next_hop if next_hop: yield self.out_ports[next_hop].put(txn.advance()) @@ -120,10 +133,16 @@ class PeEngineBase(ComponentBase): while True: msg: Any = yield self._inbox.get() if isinstance(msg, PeInternalTxn): - env.process(self.handle_command(env, msg)) + env.process(self._handle_with_hooks(env, msg)) else: env.process(self._forward_txn(env, msg)) + def _handle_with_hooks(self, env: simpy.Environment, pe_txn: Any) -> Generator: + """Wrap handle_command with op log hooks on the inner command.""" + self._on_process_start(env, pe_txn.command) + yield from self.handle_command(env, pe_txn) + self._on_process_end(env, pe_txn.command) + @abstractmethod def handle_command(self, env: simpy.Environment, pe_txn: Any) -> Generator: """Process a PE-internal command (PeInternalTxn). diff --git a/src/kernbench/components/builtin/pe_cpu.py b/src/kernbench/components/builtin/pe_cpu.py index f2e3c7b..4947b9d 100644 --- a/src/kernbench/components/builtin/pe_cpu.py +++ b/src/kernbench/components/builtin/pe_cpu.py @@ -65,24 +65,45 @@ class PeCpuComponent(ComponentBase): yield from self._forward_txn(env, txn) def _execute_kernel(self, env: simpy.Environment, txn: Any) -> Generator: - """Compile kernel function and replay command trace.""" - from kernbench.common.pe_commands import ( - CompositeCmd, - PeCpuOverheadCmd, - PeInternalTxn, - WaitCmd, - ) + """Execute kernel: greenlet mode (ADR-0020) or legacy Phase 0 + replay.""" from kernbench.triton_emu.registry import get_kernel - from kernbench.triton_emu.tl_context import TLContext, run_kernel request = txn.request - - # Phase 1: Compile — apply PE_CPU setup overhead, then run kernel yield from self.run(env, 0) kernel_fn = get_kernel(request.kernel_ref.name) + num_programs = self._derive_num_programs(request) + kernel_args = self._unpack_kernel_args(request) - # Derive num_programs from the number of PE shards in this cube + pe_exec_start = env.now + scheduler_id = f"{self._pe_prefix}.pe_scheduler" + + # Choose execution mode: greenlet (ADR-0020) or legacy command-list + store = getattr(self.ctx, "memory_store", None) if self.ctx else None + + if store is not None: + composite_results = yield from self._execute_greenlet( + env, kernel_fn, kernel_args, num_programs, scheduler_id, store, + ) + else: + composite_results = yield from self._execute_legacy( + env, kernel_fn, kernel_args, num_programs, scheduler_id, + ) + + # Record PE-internal execution time + txn.result_data["pe_exec_ns"] = env.now - pe_exec_start + total_dma_ns = 0.0 + total_compute_ns = 0.0 + for rd in composite_results: + total_dma_ns += rd.get("dma_ns", 0.0) + total_compute_ns += rd.get("compute_ns", 0.0) + txn.result_data["dma_ns"] = total_dma_ns + txn.result_data["compute_ns"] = total_compute_ns + + # Send ResponseMsg on reverse path + yield from self._send_response(env, txn, request) + + def _derive_num_programs(self, request: Any) -> int: num_programs = 1 for arg in request.args: if arg.arg_kind == "tensor": @@ -92,11 +113,9 @@ class PeCpuComponent(ComponentBase): ) if cube_pe_count > num_programs: num_programs = cube_pe_count + return num_programs - tl = TLContext(pe_id=self._pe_idx, num_programs=num_programs, dispatch_cycles=0) - - # Unpack KernelLaunchMsg.args into positional args for kernel function - # TensorArg → va_base (already local, set by runtime) or PA fallback + def _unpack_kernel_args(self, request: Any) -> list: kernel_args: list = [] for arg in request.args: if arg.arg_kind == "tensor": @@ -107,15 +126,41 @@ class PeCpuComponent(ComponentBase): kernel_args.append(shard.pa) elif arg.arg_kind == "scalar": kernel_args.append(arg.value) + return kernel_args + def _execute_greenlet( + self, env, kernel_fn, kernel_args, num_programs, scheduler_id, store, + ) -> Generator: + """Greenlet-based execution (ADR-0020 D3): kernel ↔ SimPy interleaved.""" + from kernbench.triton_emu.kernel_runner import KernelRunner + + runner = KernelRunner( + pe_prefix=self._pe_prefix, + pe_idx=self._pe_idx, + sip_idx=self._sip_idx, + cube_idx=self._cube_idx, + scheduler_id=scheduler_id, + out_ports=self.out_ports, + store=store, + ) + yield from runner.run(env, kernel_fn, kernel_args, num_programs) + return getattr(runner, "_composite_results", []) + + def _execute_legacy( + self, env, kernel_fn, kernel_args, num_programs, scheduler_id, + ) -> Generator: + """Legacy Phase 0 + replay: generate command list, then dispatch.""" + from kernbench.common.pe_commands import ( + CompositeCmd, PeCpuOverheadCmd, PeInternalTxn, WaitCmd, + ) + from kernbench.triton_emu.tl_context import TLContext, run_kernel + + tl = TLContext(pe_id=self._pe_idx, num_programs=num_programs, dispatch_cycles=0) run_kernel(kernel_fn, tl, *kernel_args) commands = tl.commands - # Phase 2: Replay — dispatch commands to PE_SCHEDULER - pe_exec_start = env.now - scheduler_id = f"{self._pe_prefix}.pe_scheduler" - pending: dict[str, simpy.Event] = {} # completion_id → done event - composite_results: list[dict] = [] # collect result_data from CompositeCmd txns + pending: dict[str, simpy.Event] = {} + composite_results: list[dict] = [] for cmd in commands: if isinstance(cmd, PeCpuOverheadCmd): @@ -126,47 +171,30 @@ class PeCpuComponent(ComponentBase): if evt: yield evt else: - # Wait all pending completions for evt in pending.values(): yield evt pending.clear() elif isinstance(cmd, CompositeCmd): - # Non-blocking: dispatch to scheduler, track completion done_evt = env.event() pe_txn = PeInternalTxn( - command=cmd, done=done_evt, - pe_prefix=self._pe_prefix, + command=cmd, done=done_evt, pe_prefix=self._pe_prefix, ) composite_results.append(pe_txn.result_data) yield self.out_ports[scheduler_id].put(pe_txn) pending[cmd.completion.id] = done_evt else: - # Blocking: dispatch and wait for completion done_evt = env.event() pe_txn = PeInternalTxn( - command=cmd, done=done_evt, - pe_prefix=self._pe_prefix, + command=cmd, done=done_evt, pe_prefix=self._pe_prefix, ) yield self.out_ports[scheduler_id].put(pe_txn) yield done_evt - # Wait for any remaining pending completions for evt in pending.values(): yield evt + return composite_results - # Record PE-internal execution time - txn.result_data["pe_exec_ns"] = env.now - pe_exec_start - - # Aggregate dma_ns / compute_ns from CompositeCmd results - total_dma_ns = 0.0 - total_compute_ns = 0.0 - for rd in composite_results: - total_dma_ns += rd.get("dma_ns", 0.0) - total_compute_ns += rd.get("compute_ns", 0.0) - txn.result_data["dma_ns"] = total_dma_ns - txn.result_data["compute_ns"] = total_compute_ns - - # Send ResponseMsg on reverse path (PE_CPU → NOC → M_CPU) + def _send_response(self, env, txn, request) -> Generator: reverse_path = list(reversed(txn.path)) if len(reverse_path) >= 2: from kernbench.runtime_api.kernel import ResponseMsg diff --git a/src/kernbench/components/context.py b/src/kernbench/components/context.py index 98a6f93..14c84c3 100644 --- a/src/kernbench/components/context.py +++ b/src/kernbench/components/context.py @@ -24,6 +24,8 @@ class ComponentContext: ns_per_mm: float # wire propagation constant (from topology spec) edge_map: dict[tuple[str, str], Any] = field(default_factory=dict) spec: dict = field(default_factory=dict) # topology spec (cube layout, PE count, etc.) + memory_store: Any = None # MemoryStore for Phase 1 data-aware execution (ADR-0020) + op_logger: Any = None # OpLogger for Phase 1 op recording (ADR-0020) def get_shared_resource( self, env: simpy.Environment, key: str, capacity: int = 1, diff --git a/src/kernbench/sim_engine/data_executor.py b/src/kernbench/sim_engine/data_executor.py new file mode 100644 index 0000000..e546ffb --- /dev/null +++ b/src/kernbench/sim_engine/data_executor.py @@ -0,0 +1,157 @@ +"""DataExecutor: Phase 2 op_log-based data execution (ADR-0020 D6). + +Executes GEMM/Math operations from the op_log using numpy. +Memory ops are skipped (already handled in Phase 1 via MemoryStore). +Same-timestamp independent ops can be batched for efficiency. +""" +from __future__ import annotations + +from itertools import groupby +from typing import Any + +import numpy as np + +from kernbench.sim_engine.memory_store import MemoryStore, _resolve_dtype +from kernbench.sim_engine.op_log import OpRecord + + +class DataExecutor: + """Phase 2 executor: replay op_log with actual numpy computation. + + Args: + op_log: list of OpRecords from Phase 1. + store: MemoryStore snapshot from Phase 1 (contains tensor data). + """ + + def __init__(self, op_log: list[OpRecord], store: MemoryStore) -> None: + self._op_log = op_log + self.store = store + + def run(self) -> None: + """Execute all ops in op_log order, grouped by t_start.""" + for _t, ops_iter in groupby(self._op_log, key=lambda r: r.t_start): + ops = list(ops_iter) + for op in ops: + self._execute_op(op) + + def _execute_op(self, op: OpRecord) -> None: + if op.op_kind == "memory": + self._execute_memory(op) + elif op.op_kind == "gemm": + self._execute_gemm(op) + elif op.op_kind == "math": + self._execute_math(op) + + def _execute_memory(self, op: OpRecord) -> None: + """Memory ops are already handled by Phase 1 MemoryStore. Skip.""" + + def _execute_gemm(self, op: OpRecord) -> None: + """Execute GEMM: out = a @ b.""" + p = op.params + if "src_a_addr" not in p: + return # composite record without full params + space = p.get("addr_space", "tcm") + dtype_in = p.get("dtype_in", "f16") + dtype_out = p.get("dtype_out", dtype_in) + + a = self.store.read(space, p["src_a_addr"], shape=p.get("shape_a"), dtype=dtype_in) + b = self.store.read(space, p["src_b_addr"], shape=p.get("shape_b"), dtype=dtype_in) + + # Compute in higher precision if specified + dtype_acc = p.get("dtype_acc", "f32") + a_f = a.astype(_resolve_dtype(dtype_acc)) + b_f = b.astype(_resolve_dtype(dtype_acc)) + result = np.matmul(a_f, b_f).astype(_resolve_dtype(dtype_out)) + + self.store.write(space, p["dst_addr"], result) + + def _execute_math(self, op: OpRecord) -> None: + """Execute math op: unary, binary, or reduction.""" + p = op.params + math_op = p.get("op", op.op_name) + space = p.get("addr_space", "tcm") + dtype = p.get("dtype", "f32") + input_addrs = p.get("input_addrs", []) + input_shapes = p.get("input_shapes", []) + + inputs = [] + for addr, shape in zip(input_addrs, input_shapes): + inputs.append(self.store.read(space, addr, shape=shape, dtype=dtype)) + + result = _compute_math(math_op, inputs, p.get("axis")) + if result is not None: + self.store.write(space, p["dst_addr"], result) + + def verify(self, expected: dict[tuple[str, int], np.ndarray], + rtol: float = 1e-3, atol: float = 1e-3) -> dict[str, bool]: + """Compare MemoryStore contents against expected tensors. + + Args: + expected: {(space, addr): expected_ndarray} + rtol, atol: tolerance for floating-point comparison. + + Returns: + {key_str: passed} dict. + """ + results = {} + for (space, addr), exp in expected.items(): + key = f"{space}:0x{addr:x}" + try: + actual = self.store.read(space, addr) + if np.issubdtype(actual.dtype, np.integer): + results[key] = bool(np.array_equal(actual, exp)) + else: + results[key] = bool(np.allclose(actual, exp, rtol=rtol, atol=atol)) + except KeyError: + results[key] = False + return results + + +def _compute_math(op: str, inputs: list[np.ndarray], axis: int | None) -> np.ndarray | None: + """Execute a math operation on numpy arrays.""" + if not inputs: + return None + + x = inputs[0] + + # Unary + if op == "exp": + return np.exp(x) + if op == "log": + return np.log(x) + if op == "sqrt": + return np.sqrt(x) + if op == "abs": + return np.abs(x) + if op == "sigmoid": + return 1.0 / (1.0 + np.exp(-x)) + if op == "cos": + return np.cos(x) + if op == "sin": + return np.sin(x) + + # Reduction + if op == "sum": + return np.sum(x, axis=axis, keepdims=True) + if op == "max": + return np.max(x, axis=axis, keepdims=True) + if op == "min": + return np.min(x, axis=axis, keepdims=True) + + # Binary + if len(inputs) >= 2: + y = inputs[1] + if op == "add": + return x + y + if op == "sub": + return x - y + if op == "mul": + return x * y + if op == "div": + return x / y + + # Ternary + if op == "where" and len(inputs) >= 3: + return np.where(inputs[0], inputs[1], inputs[2]) + + return None diff --git a/src/kernbench/sim_engine/memory_store.py b/src/kernbench/sim_engine/memory_store.py new file mode 100644 index 0000000..44b80e1 --- /dev/null +++ b/src/kernbench/sim_engine/memory_store.py @@ -0,0 +1,84 @@ +"""MemoryStore: tensor-granular storage for Phase 1 and Phase 2 (ADR-0020 D7). + +Logically byte-addressable, implemented as addr → numpy ndarray mapping. +Read/write are reference-based (no copy) for Phase 1 performance. +""" +from __future__ import annotations + +import numpy as np + + +# numpy dtype string → numpy dtype mapping +_DTYPE_MAP = { + "f16": np.float16, + "f32": np.float32, + "f64": np.float64, + "bf16": np.float16, # numpy has no bfloat16; use float16 as proxy + "i8": np.int8, + "i16": np.int16, + "i32": np.int32, + "i64": np.int64, + "u8": np.uint8, + "u16": np.uint16, + "u32": np.uint32, +} + + +def _resolve_dtype(dtype: str) -> np.dtype: + if dtype in _DTYPE_MAP: + return np.dtype(_DTYPE_MAP[dtype]) + return np.dtype(dtype) + + +class MemoryStore: + """Tensor-granular memory storage (ADR-0020 D7). + + Stores numpy ndarrays by (space, addr) key. + Write = reference store (no copy), read = reference return (no copy). + Overwrite at same addr replaces the entire tensor. + """ + + def __init__(self) -> None: + # {space: {addr: ndarray}} + self._storage: dict[str, dict[int, np.ndarray]] = {} + + def write(self, space: str, addr: int, data: np.ndarray) -> None: + """Store tensor at (space, addr). Reference-only, no copy.""" + if space not in self._storage: + self._storage[space] = {} + self._storage[space][addr] = data + + def read(self, space: str, addr: int, shape: tuple[int, ...] | None = None, + dtype: str | None = None) -> np.ndarray: + """Read tensor from (space, addr). Returns reference, no copy. + + If shape/dtype match stored tensor, returns as-is. + If dtype differs, performs reinterpret cast (view). + If shape differs but nbytes match, reshapes. + """ + store = self._storage.get(space) + if store is None or addr not in store: + raise KeyError(f"No data at ({space}, 0x{addr:x})") + arr = store[addr] + if dtype is not None: + np_dtype = _resolve_dtype(dtype) + if arr.dtype != np_dtype: + arr = arr.view(np_dtype) + if shape is not None and arr.shape != shape: + if arr.nbytes != np.prod(shape) * arr.dtype.itemsize: + raise ValueError( + f"Shape mismatch: stored {arr.shape} ({arr.nbytes}B) " + f"vs requested {shape} ({np.prod(shape) * arr.dtype.itemsize}B)" + ) + arr = arr.reshape(shape) + return arr + + def has(self, space: str, addr: int) -> bool: + return addr in self._storage.get(space, {}) + + def snapshot(self) -> MemoryStore: + """Create a shallow copy for Phase 2 initialization.""" + new = MemoryStore() + for space, addrs in self._storage.items(): + new._storage[space] = dict(addrs) # shallow copy of addr→ndarray map + return new diff --git a/src/kernbench/sim_engine/op_log.py b/src/kernbench/sim_engine/op_log.py new file mode 100644 index 0000000..bf0f5de --- /dev/null +++ b/src/kernbench/sim_engine/op_log.py @@ -0,0 +1,111 @@ +"""Op log infrastructure for 2-pass data execution (ADR-0020 D2, D5). + +OpRecord: single data operation with timing, params, and dependencies. +OpLogger: collects OpRecords from ComponentBase hooks during Phase 1. +""" +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + + +@dataclass +class OpRecord: + """Single data operation record (ADR-0020 D5).""" + + t_start: float + t_end: float + component_id: str + op_kind: str # "memory" | "gemm" | "math" + op_name: str # e.g. "dma_read", "gemm_f16", "exp" + params: dict[str, Any] + dependency_ids: list[int] = field(default_factory=list) + + +class OpLogger: + """Collects OpRecords during Phase 1 simulation (ADR-0020 D2). + + Thread-safe is not required — SimPy is single-threaded. + Records are maintained in t_start stable ordering (insertion order). + """ + + def __init__(self) -> None: + self._records: list[OpRecord] = [] + self._pending: dict[int, dict[str, Any]] = {} # msg id → partial record + + @property + def records(self) -> list[OpRecord]: + """Records sorted by t_start (stable ordering per ADR-0020 D5).""" + self._records.sort(key=lambda r: r.t_start) + return self._records + + def record_start(self, t: float, component_id: str, msg: Any) -> None: + """Called by ComponentBase._on_process_start.""" + self._pending[id(msg)] = { + "t_start": t, + "component_id": component_id, + "msg": msg, + } + + def record_end(self, t: float, component_id: str, msg: Any) -> None: + """Called by ComponentBase._on_process_end.""" + pending = self._pending.pop(id(msg), None) + if pending is None: + return + op_kind, op_name, params = _extract_op_info(msg) + self._records.append(OpRecord( + t_start=pending["t_start"], + t_end=t, + component_id=pending["component_id"], + op_kind=op_kind, + op_name=op_name, + params=params, + )) + + +def _extract_op_info(msg: Any) -> tuple[str, str, dict[str, Any]]: + """Extract op_kind, op_name, params from a data_op message.""" + from kernbench.common.pe_commands import ( + DmaReadCmd, DmaWriteCmd, GemmCmd, MathCmd, CompositeCmd, + ) + if isinstance(msg, DmaReadCmd): + return "memory", "dma_read", { + "src_addr": msg.src_addr, + "nbytes": msg.nbytes, + "handle_id": msg.handle.id, + } + if isinstance(msg, DmaWriteCmd): + return "memory", "dma_write", { + "dst_addr": msg.dst_addr, + "nbytes": msg.nbytes, + "handle_id": msg.handle.id, + } + if isinstance(msg, GemmCmd): + return "gemm", f"gemm_{msg.a.dtype}", { + "src_a_addr": msg.a.addr, + "src_b_addr": msg.b.addr, + "dst_addr": msg.out.addr, + "shape_a": msg.a.shape, + "shape_b": msg.b.shape, + "shape_out": msg.out.shape, + "dtype_in": msg.a.dtype, + "dtype_out": msg.out.dtype, + "m": msg.m, "k": msg.k, "n": msg.n, + } + if isinstance(msg, MathCmd): + return "math", msg.op, { + "input_addrs": [h.addr for h in msg.inputs], + "input_shapes": [h.shape for h in msg.inputs], + "dst_addr": msg.out.addr, + "shape_out": msg.out.shape, + "dtype": msg.out.dtype, + "axis": msg.axis, + } + if isinstance(msg, CompositeCmd): + return "gemm" if msg.op == "gemm" else "math", f"composite_{msg.op}", { + "op": msg.op, + "out_addr": msg.out_addr, + "out_nbytes": msg.out_nbytes, + } + # Fallback for unknown data_op messages + return "unknown", type(msg).__name__, {} diff --git a/src/kernbench/triton_emu/kernel_runner.py b/src/kernbench/triton_emu/kernel_runner.py new file mode 100644 index 0000000..afc75d3 --- /dev/null +++ b/src/kernbench/triton_emu/kernel_runner.py @@ -0,0 +1,199 @@ +"""KernelRunner: greenlet-based kernel ↔ SimPy bridge (ADR-0020 D3). + +Replaces Phase 0 (static command list) with interleaved execution: + - tl.load() → SimPy DMA timing + MemoryStore read → real data to kernel + - tl.store() → MemoryStore write + SimPy DMA timing + - tl.composite(gemm) → SimPy timing + op_log (actual compute in Phase 2) + +The kernel runs as a child greenlet; SimPy loop is the parent. +""" +from __future__ import annotations + +from collections.abc import Generator +from typing import TYPE_CHECKING, Any + +import simpy +from greenlet import greenlet + +from kernbench.common.pe_commands import ( + CompletionHandle, + CompositeCmd, + DmaReadCmd, + DmaWriteCmd, + GemmCmd, + MathCmd, + PeCommand, + PeCpuOverheadCmd, + PeInternalTxn, + TensorHandle, + WaitCmd, +) + +if TYPE_CHECKING: + from kernbench.sim_engine.memory_store import MemoryStore + + +class KernelRunner: + """Greenlet ↔ SimPy bridge for kernel execution (ADR-0020 D3). + + PE_CPU creates a KernelRunner and yields from its run() method. + The kernel function executes as a plain Python function inside a + child greenlet, using TLContext methods that switch back to SimPy. + """ + + def __init__( + self, + pe_prefix: str, + pe_idx: int, + sip_idx: int, + cube_idx: int, + scheduler_id: str, + out_ports: dict[str, simpy.Store], + store: MemoryStore | None = None, + ) -> None: + self._pe_prefix = pe_prefix + self._pe_idx = pe_idx + self._sip_idx = sip_idx + self._cube_idx = cube_idx + self._scheduler_id = scheduler_id + self._out_ports = out_ports + self._store = store + self._parent: greenlet | None = None + + def run( + self, + env: simpy.Environment, + kernel_fn: Any, + kernel_args: list, + num_programs: int, + ) -> Generator: + """SimPy generator: run kernel with greenlet interleaving. + + This is the SimPy-side loop. It: + 1. Creates a TLContext connected to this runner + 2. Spawns the kernel in a child greenlet + 3. Receives commands via greenlet.switch + 4. Dispatches each command through SimPy components + 5. Returns results to the kernel + """ + from kernbench.triton_emu.tl_context import TLContext + + self._parent = greenlet.getcurrent() + + tl = TLContext( + pe_id=self._pe_idx, + num_programs=num_programs, + dispatch_cycles=0, + runner=self, + ) + + def _kernel_entry(): + TLContext._set_active(tl) # type: ignore[attr-defined] + try: + kernel_fn(*kernel_args, tl=tl) + finally: + TLContext._set_active(None) # type: ignore[attr-defined] + return None # signal kernel completion + + g = greenlet(_kernel_entry) + pending: dict[str, simpy.Event] = {} + composite_results: list[dict] = [] + + # Start kernel — first switch returns first command (or None if kernel is done) + cmd = g.switch() + + while cmd is not None: + if isinstance(cmd, PeCpuOverheadCmd): + yield env.timeout(cmd.cycles) + cmd = g.switch() + + elif isinstance(cmd, WaitCmd): + if cmd.handle is not None: + evt = pending.pop(cmd.handle.id, None) + if evt: + yield evt + else: + for evt in pending.values(): + yield evt + pending.clear() + cmd = g.switch() + + elif isinstance(cmd, DmaReadCmd): + # Dispatch DMA through SimPy components + done_evt = env.event() + pe_txn = PeInternalTxn( + command=cmd, done=done_evt, pe_prefix=self._pe_prefix, + ) + yield self._out_ports[self._scheduler_id].put(pe_txn) + yield done_evt + + # Read actual data from MemoryStore (if available) + data = None + if self._store is not None: + try: + data = self._store.read( + "hbm", cmd.src_addr, + shape=cmd.handle.shape, dtype=cmd.handle.dtype, + ) + except KeyError: + pass + cmd = g.switch(data) + + elif isinstance(cmd, DmaWriteCmd): + # Write to MemoryStore first (visibility = issue, ADR-0020 D3) + if self._store is not None and cmd.handle.data is not None: + self._store.write("hbm", cmd.dst_addr, cmd.handle.data) + + done_evt = env.event() + pe_txn = PeInternalTxn( + command=cmd, done=done_evt, pe_prefix=self._pe_prefix, + ) + yield self._out_ports[self._scheduler_id].put(pe_txn) + yield done_evt + cmd = g.switch() + + elif isinstance(cmd, CompositeCmd): + # Non-blocking composite + done_evt = env.event() + pe_txn = PeInternalTxn( + command=cmd, done=done_evt, pe_prefix=self._pe_prefix, + ) + composite_results.append(pe_txn.result_data) + yield self._out_ports[self._scheduler_id].put(pe_txn) + pending[cmd.completion.id] = done_evt + cmd = g.switch() + + elif isinstance(cmd, (GemmCmd, MathCmd)): + # Blocking compute command + done_evt = env.event() + pe_txn = PeInternalTxn( + command=cmd, done=done_evt, pe_prefix=self._pe_prefix, + ) + yield self._out_ports[self._scheduler_id].put(pe_txn) + yield done_evt + cmd = g.switch() + + else: + # Unknown command — pass through as blocking + done_evt = env.event() + pe_txn = PeInternalTxn( + command=cmd, done=done_evt, pe_prefix=self._pe_prefix, + ) + yield self._out_ports[self._scheduler_id].put(pe_txn) + yield done_evt + cmd = g.switch() + + # Wait remaining pending composites + for evt in pending.values(): + yield evt + + # Return composite results for PE_CPU aggregation + self._composite_results = composite_results + + def switch_to_simpy(self, cmd: PeCommand) -> Any: + """Called from TLContext (child greenlet) to send command to SimPy. + + Returns the result from SimPy (e.g., numpy array for DMA read). + """ + assert self._parent is not None + return self._parent.switch(cmd) diff --git a/src/kernbench/triton_emu/tl_context.py b/src/kernbench/triton_emu/tl_context.py index 63c867b..4f9732d 100644 --- a/src/kernbench/triton_emu/tl_context.py +++ b/src/kernbench/triton_emu/tl_context.py @@ -52,6 +52,7 @@ class TLContext: pe_id: int = 0, num_programs: int = 1, dispatch_cycles: int = 1, + runner: Any = None, ) -> None: self._pe_id = pe_id self._num_programs = num_programs @@ -59,6 +60,7 @@ class TLContext: self._commands: list[PeCommand] = [] self._handle_counter = 0 self._completion_counter = 0 + self._runner = runner # KernelRunner for greenlet mode (ADR-0020 D3) @property def commands(self) -> list[PeCommand]: @@ -83,7 +85,7 @@ class TLContext: def _emit_dispatch_overhead(self) -> None: if self._dispatch_cycles > 0: - self._commands.append(PeCpuOverheadCmd(cycles=self._dispatch_cycles)) + self._emit(PeCpuOverheadCmd(cycles=self._dispatch_cycles)) def _make_handle( self, addr: int, shape: tuple[int, ...], dtype: str, @@ -108,23 +110,38 @@ class TLContext: # ── Data Movement (blocking, DMA engine) ────────────────────── + def _emit(self, cmd: PeCommand) -> Any: + """Emit command: greenlet switch if runner available, else append to list.""" + if self._runner is not None: + return self._runner.switch_to_simpy(cmd) + self._commands.append(cmd) + return None + def load( self, ptr: int, shape: tuple[int, ...], dtype: str = "f16", ) -> TensorHandle: - """Load tensor from HBM to TCM. Returns TensorHandle.""" + """Load tensor from HBM to TCM. Returns TensorHandle. + + In greenlet mode: returns TensorHandle with actual numpy data. + In command-list mode: returns TensorHandle with data=None. + """ self._emit_dispatch_overhead() handle = self._make_handle(addr=ptr, shape=shape, dtype=dtype) - self._commands.append(DmaReadCmd( - handle=handle, src_addr=ptr, nbytes=handle.nbytes, - )) + cmd = DmaReadCmd(handle=handle, src_addr=ptr, nbytes=handle.nbytes) + data = self._emit(cmd) + if data is not None: + # Greenlet mode: attach real data to handle + return TensorHandle( + id=handle.id, addr=handle.addr, shape=handle.shape, + dtype=handle.dtype, nbytes=handle.nbytes, data=data, + ) return handle def store(self, ptr: int, handle: TensorHandle) -> None: """Store tensor from TCM to HBM.""" self._emit_dispatch_overhead() - self._commands.append(DmaWriteCmd( - handle=handle, dst_addr=ptr, nbytes=handle.nbytes, - )) + cmd = DmaWriteCmd(handle=handle, dst_addr=ptr, nbytes=handle.nbytes) + self._emit(cmd) # ── GEMM Engine (blocking) ──────────────────────────────────── @@ -143,7 +160,7 @@ class TLContext: out_dtype = a.dtype out = self._make_handle(addr=0, shape=out_shape, dtype=out_dtype) self._emit_dispatch_overhead() - self._commands.append(GemmCmd(a=a, b=b, out=out, m=m, k=k, n=n)) + self._emit(GemmCmd(a=a, b=b, out=out, m=m, k=k, n=n)) return out # ── MATH Engine: unary (blocking) ───────────────────────────── @@ -151,7 +168,7 @@ class TLContext: def _unary_math(self, op: str, x: TensorHandle) -> TensorHandle: out = self._make_handle(addr=0, shape=x.shape, dtype=x.dtype) self._emit_dispatch_overhead() - self._commands.append(MathCmd(op=op, inputs=(x,), out=out)) + self._emit(MathCmd(op=op, inputs=(x,), out=out)) return out def exp(self, x: TensorHandle) -> TensorHandle: @@ -184,7 +201,7 @@ class TLContext: out_shape[axis] = 1 out = self._make_handle(addr=0, shape=tuple(out_shape), dtype=x.dtype) self._emit_dispatch_overhead() - self._commands.append(MathCmd(op=op, inputs=(x,), out=out, axis=axis)) + self._emit(MathCmd(op=op, inputs=(x,), out=out, axis=axis)) return out def sum(self, x: TensorHandle, axis: int) -> TensorHandle: @@ -203,7 +220,7 @@ class TLContext: ) -> TensorHandle: out = self._make_handle(addr=0, shape=a.shape, dtype=a.dtype) self._emit_dispatch_overhead() - self._commands.append(MathCmd(op=op, inputs=(a, b), out=out)) + self._emit(MathCmd(op=op, inputs=(a, b), out=out)) return out def where( @@ -211,7 +228,7 @@ class TLContext: ) -> TensorHandle: out = self._make_handle(addr=0, shape=a.shape, dtype=a.dtype) self._emit_dispatch_overhead() - self._commands.append(MathCmd(op="where", inputs=(cond, a, b), out=out)) + self._emit(MathCmd(op="where", inputs=(cond, a, b), out=out)) return out # ── Index / Scalar (PE_CPU, no engine) ──────────────────────── @@ -276,7 +293,7 @@ class TLContext: completion = CompletionHandle(id=self._next_completion_id()) self._emit_dispatch_overhead() - self._commands.append(CompositeCmd( + self._emit(CompositeCmd( completion=completion, op=op, a=a, b=b, out_addr=out_ptr, out_nbytes=out_nbytes, math_op=math_op, @@ -285,11 +302,11 @@ class TLContext: def wait(self, handle: CompletionHandle | None = None) -> None: """Wait for a specific composite or all pending composites.""" - self._commands.append(WaitCmd(handle=handle)) + self._emit(WaitCmd(handle=handle)) def cycles(self, n: int) -> None: """Declare PE_CPU scalar execution overhead (cycles).""" - self._commands.append(PeCpuOverheadCmd(cycles=n)) + self._emit(PeCpuOverheadCmd(cycles=n)) # ── TensorHandle arithmetic operators ───────────────────────────── diff --git a/tests/test_data_executor.py b/tests/test_data_executor.py new file mode 100644 index 0000000..af1aa97 --- /dev/null +++ b/tests/test_data_executor.py @@ -0,0 +1,188 @@ +"""Tests for DataExecutor Phase 2 execution (ADR-0020 D6).""" +import numpy as np + +from kernbench.sim_engine.data_executor import DataExecutor +from kernbench.sim_engine.memory_store import MemoryStore +from kernbench.sim_engine.op_log import OpRecord + + +def test_gemm_execution(): + """Phase 2 GEMM: out = a @ b with f32 accumulation.""" + store = MemoryStore() + a = np.ones((4, 8), dtype=np.float16) + b = np.ones((8, 4), dtype=np.float16) * 2.0 + store.write("tcm", 0x0, a) + store.write("tcm", 0x100, b) + + op = OpRecord( + t_start=0.0, t_end=100.0, + component_id="pe_gemm", + op_kind="gemm", op_name="gemm_f16", + params={ + "src_a_addr": 0x0, "src_b_addr": 0x100, "dst_addr": 0x200, + "shape_a": (4, 8), "shape_b": (8, 4), "shape_out": (4, 4), + "dtype_in": "f16", "dtype_acc": "f32", "dtype_out": "f16", + "addr_space": "tcm", + }, + ) + + executor = DataExecutor([op], store) + executor.run() + + result = store.read("tcm", 0x200) + expected = (a.astype(np.float32) @ b.astype(np.float32)).astype(np.float16) + assert np.allclose(result, expected) + + +def test_math_exp(): + store = MemoryStore() + x = np.array([0.0, 1.0, 2.0], dtype=np.float32) + store.write("tcm", 0x0, x) + + op = OpRecord( + t_start=0.0, t_end=10.0, + component_id="pe_math", + op_kind="math", op_name="exp", + params={ + "op": "exp", + "input_addrs": [0x0], "input_shapes": [(3,)], + "dst_addr": 0x100, "shape_out": (3,), + "dtype": "f32", "axis": None, "addr_space": "tcm", + }, + ) + + executor = DataExecutor([op], store) + executor.run() + + result = store.read("tcm", 0x100) + assert np.allclose(result, np.exp(x)) + + +def test_math_add(): + store = MemoryStore() + a = np.array([1.0, 2.0], dtype=np.float32) + b = np.array([3.0, 4.0], dtype=np.float32) + store.write("tcm", 0x0, a) + store.write("tcm", 0x100, b) + + op = OpRecord( + t_start=0.0, t_end=5.0, + component_id="pe_math", + op_kind="math", op_name="add", + params={ + "op": "add", + "input_addrs": [0x0, 0x100], "input_shapes": [(2,), (2,)], + "dst_addr": 0x200, "shape_out": (2,), + "dtype": "f32", "axis": None, "addr_space": "tcm", + }, + ) + + executor = DataExecutor([op], store) + executor.run() + + result = store.read("tcm", 0x200) + assert np.array_equal(result, np.array([4.0, 6.0], dtype=np.float32)) + + +def test_math_sum_reduction(): + store = MemoryStore() + x = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) + store.write("tcm", 0x0, x) + + op = OpRecord( + t_start=0.0, t_end=5.0, + component_id="pe_math", + op_kind="math", op_name="sum", + params={ + "op": "sum", + "input_addrs": [0x0], "input_shapes": [(2, 2)], + "dst_addr": 0x100, "shape_out": (1, 2), + "dtype": "f32", "axis": 0, "addr_space": "tcm", + }, + ) + + executor = DataExecutor([op], store) + executor.run() + + result = store.read("tcm", 0x100) + assert np.array_equal(result, np.array([[4.0, 6.0]], dtype=np.float32)) + + +def test_verify_pass(): + store = MemoryStore() + store.write("hbm", 0x0, np.array([1.0, 2.0], dtype=np.float32)) + + executor = DataExecutor([], store) + results = executor.verify({ + ("hbm", 0x0): np.array([1.0, 2.0], dtype=np.float32), + }) + assert results["hbm:0x0"] is True + + +def test_verify_fail(): + store = MemoryStore() + store.write("hbm", 0x0, np.array([1.0, 2.0], dtype=np.float32)) + + executor = DataExecutor([], store) + results = executor.verify({ + ("hbm", 0x0): np.array([9.0, 9.0], dtype=np.float32), + }) + assert results["hbm:0x0"] is False + + +def test_memory_ops_skipped(): + """Memory ops in op_log should be skipped (handled in Phase 1).""" + store = MemoryStore() + op = OpRecord( + t_start=0.0, t_end=5.0, + component_id="pe_dma", + op_kind="memory", op_name="dma_read", + params={"src_addr": 0x0, "nbytes": 64, "handle_id": "t0"}, + ) + # Should not raise + executor = DataExecutor([op], store) + executor.run() + + +def test_sequential_gemm_then_math(): + """GEMM output feeds into math op.""" + store = MemoryStore() + a = np.eye(2, dtype=np.float16) + b = np.ones((2, 2), dtype=np.float16) + store.write("tcm", 0x0, a) + store.write("tcm", 0x100, b) + + ops = [ + OpRecord( + t_start=0.0, t_end=50.0, + component_id="pe_gemm", + op_kind="gemm", op_name="gemm_f16", + params={ + "src_a_addr": 0x0, "src_b_addr": 0x100, "dst_addr": 0x200, + "shape_a": (2, 2), "shape_b": (2, 2), "shape_out": (2, 2), + "dtype_in": "f16", "dtype_acc": "f32", "dtype_out": "f32", + "addr_space": "tcm", + }, + ), + OpRecord( + t_start=50.0, t_end=55.0, + component_id="pe_math", + op_kind="math", op_name="exp", + params={ + "op": "exp", + "input_addrs": [0x200], "input_shapes": [(2, 2)], + "dst_addr": 0x300, "shape_out": (2, 2), + "dtype": "f32", "axis": None, "addr_space": "tcm", + }, + ), + ] + + executor = DataExecutor(ops, store) + executor.run() + + gemm_result = store.read("tcm", 0x200) + expected_gemm = (a.astype(np.float32) @ b.astype(np.float32)).astype(np.float32) + assert np.allclose(gemm_result, expected_gemm) + + exp_result = store.read("tcm", 0x300) + assert np.allclose(exp_result, np.exp(expected_gemm)) diff --git a/tests/test_kernel_runner.py b/tests/test_kernel_runner.py new file mode 100644 index 0000000..3cb2c64 --- /dev/null +++ b/tests/test_kernel_runner.py @@ -0,0 +1,140 @@ +"""Tests for KernelRunner greenlet-based execution (ADR-0020 D3).""" +import numpy as np +import simpy + +from kernbench.sim_engine.memory_store import MemoryStore +from kernbench.triton_emu.kernel_runner import KernelRunner + + +def _make_runner(env, store=None): + """Create a minimal KernelRunner with mock scheduler port.""" + scheduler_id = "sip0.cube0.pe0.pe_scheduler" + out_ports = {scheduler_id: simpy.Store(env)} + runner = KernelRunner( + pe_prefix="sip0.cube0.pe0", + pe_idx=0, sip_idx=0, cube_idx=0, + scheduler_id=scheduler_id, + out_ports=out_ports, + store=store, + ) + return runner, out_ports[scheduler_id] + + +def _mock_scheduler(env, inbox): + """Consume PeInternalTxn from inbox and immediately succeed.""" + while True: + pe_txn = yield inbox.get() + pe_txn.done.succeed() + + +def test_kernel_runner_basic_load(): + """Kernel with tl.load runs through greenlet without hanging.""" + env = simpy.Environment() + store = MemoryStore() + data = np.ones((4, 4), dtype=np.float16) + store.write("hbm", 0x1000, data) + + runner, sched_port = _make_runner(env, store) + env.process(_mock_scheduler(env, sched_port)) + + def kernel(a_ptr, tl): + a = tl.load(a_ptr, (4, 4), "f16") + assert a.data is not None + assert a.data.shape == (4, 4) + + def run(): + yield from runner.run(env, kernel, [0x1000], num_programs=1) + + env.process(run()) + env.run() + + +def test_kernel_runner_load_returns_data(): + """tl.load returns actual numpy data from MemoryStore.""" + env = simpy.Environment() + store = MemoryStore() + data = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float16) + store.write("hbm", 0x2000, data) + + runner, sched_port = _make_runner(env, store) + env.process(_mock_scheduler(env, sched_port)) + + results = {} + + def kernel(ptr, tl): + a = tl.load(ptr, (2, 2), "f16") + results["data"] = a.data + + def run(): + yield from runner.run(env, kernel, [0x2000], num_programs=1) + + env.process(run()) + env.run() + assert results["data"] is data # reference equality + + +def test_kernel_runner_composite(): + """Composite commands pass through without blocking kernel.""" + env = simpy.Environment() + runner, sched_port = _make_runner(env) + env.process(_mock_scheduler(env, sched_port)) + + def kernel(a_ptr, b_ptr, out_ptr, tl): + a = tl.ref(a_ptr, (4, 8), "f16") + b = tl.ref(b_ptr, (8, 4), "f16") + h = tl.composite(op="gemm", a=a, b=b, out_ptr=out_ptr) + tl.wait(h) + + def run(): + yield from runner.run(env, kernel, [0, 64, 128], num_programs=1) + + env.process(run()) + env.run() + + +def test_kernel_runner_dynamic_branch(): + """Kernel can branch based on loaded data (ADR-0020 D3).""" + env = simpy.Environment() + store = MemoryStore() + store.write("hbm", 0x100, np.array([1.0], dtype=np.float32)) + store.write("hbm", 0x200, np.array([0.0], dtype=np.float32)) + + runner, sched_port = _make_runner(env, store) + env.process(_mock_scheduler(env, sched_port)) + + results = {"branch": None} + + def kernel(flag_ptr, tl): + flag = tl.load(flag_ptr, (1,), "f32") + if flag.data is not None and flag.data[0] > 0.5: + results["branch"] = "taken" + else: + results["branch"] = "not_taken" + + # Test with flag=1.0 → branch taken + def run(): + yield from runner.run(env, kernel, [0x100], num_programs=1) + + env.process(run()) + env.run() + assert results["branch"] == "taken" + + +def test_kernel_runner_no_store(): + """Without MemoryStore, tl.load returns handle with data=None.""" + env = simpy.Environment() + runner, sched_port = _make_runner(env, store=None) + env.process(_mock_scheduler(env, sched_port)) + + results = {} + + def kernel(ptr, tl): + a = tl.load(ptr, (4,), "f16") + results["data"] = a.data + + def run(): + yield from runner.run(env, kernel, [0], num_programs=1) + + env.process(run()) + env.run() + assert results["data"] is None diff --git a/tests/test_memory_store.py b/tests/test_memory_store.py new file mode 100644 index 0000000..45dc64b --- /dev/null +++ b/tests/test_memory_store.py @@ -0,0 +1,85 @@ +"""Tests for MemoryStore (ADR-0020 D7).""" +import numpy as np +import pytest + +from kernbench.sim_engine.memory_store import MemoryStore + + +def test_write_read_reference(): + """Write and read return the same numpy array (no copy).""" + store = MemoryStore() + data = np.ones((4, 4), dtype=np.float16) + store.write("tcm", 0x0, data) + result = store.read("tcm", 0x0) + assert result is data + + +def test_overwrite_replaces(): + """Same addr write replaces the previous tensor.""" + store = MemoryStore() + data1 = np.zeros((4,), dtype=np.float32) + data2 = np.ones((4,), dtype=np.float32) + store.write("hbm", 0x100, data1) + store.write("hbm", 0x100, data2) + result = store.read("hbm", 0x100) + assert result is data2 + + +def test_read_missing_raises(): + store = MemoryStore() + with pytest.raises(KeyError): + store.read("hbm", 0x999) + + +def test_read_different_space(): + store = MemoryStore() + data = np.array([1, 2, 3], dtype=np.int32) + store.write("tcm", 0x0, data) + with pytest.raises(KeyError): + store.read("hbm", 0x0) # different space + + +def test_dtype_reinterpret(): + """Read with different dtype does view cast.""" + store = MemoryStore() + data = np.array([1.0, 2.0], dtype=np.float32) # 8 bytes + store.write("tcm", 0x0, data) + result = store.read("tcm", 0x0, dtype="u8") + assert result.dtype == np.uint8 + assert result.nbytes == data.nbytes + + +def test_reshape(): + store = MemoryStore() + data = np.arange(12, dtype=np.float32) + store.write("tcm", 0x0, data) + result = store.read("tcm", 0x0, shape=(3, 4)) + assert result.shape == (3, 4) + + +def test_shape_mismatch_raises(): + store = MemoryStore() + data = np.arange(12, dtype=np.float32) + store.write("tcm", 0x0, data) + with pytest.raises(ValueError, match="Shape mismatch"): + store.read("tcm", 0x0, shape=(5, 5)) + + +def test_has(): + store = MemoryStore() + assert not store.has("tcm", 0x0) + store.write("tcm", 0x0, np.array([1])) + assert store.has("tcm", 0x0) + + +def test_snapshot(): + store = MemoryStore() + data = np.ones((4,), dtype=np.float16) + store.write("hbm", 0x0, data) + + snap = store.snapshot() + assert snap.read("hbm", 0x0) is data # same reference + + # Modifying snap doesn't affect original + snap.write("hbm", 0x0, np.zeros((4,), dtype=np.float16)) + assert store.read("hbm", 0x0) is data diff --git a/tests/test_op_log.py b/tests/test_op_log.py new file mode 100644 index 0000000..0d77a8d --- /dev/null +++ b/tests/test_op_log.py @@ -0,0 +1,87 @@ +"""Tests for OpLogger and OpRecord (ADR-0020 D2/D5).""" +import numpy as np + +from kernbench.common.pe_commands import ( + DmaReadCmd, DmaWriteCmd, GemmCmd, MathCmd, TensorHandle, +) +from kernbench.sim_engine.op_log import OpLogger, OpRecord + + +def _th(name="t0", addr=0, shape=(4, 4), dtype="f16"): + return TensorHandle(id=name, addr=addr, shape=shape, dtype=dtype, nbytes=32) + + +def test_op_logger_record_start_end(): + logger = OpLogger() + cmd = DmaReadCmd(handle=_th(), src_addr=0x1000, nbytes=64) + logger.record_start(10.0, "sip0.cube0.pe0.pe_dma", cmd) + logger.record_end(15.0, "sip0.cube0.pe0.pe_dma", cmd) + assert len(logger.records) == 1 + r = logger.records[0] + assert r.t_start == 10.0 + assert r.t_end == 15.0 + assert r.op_kind == "memory" + assert r.op_name == "dma_read" + assert r.params["src_addr"] == 0x1000 + + +def test_op_logger_gemm(): + logger = OpLogger() + a = _th("a", 0, (128, 256), "f16") + b = _th("b", 1024, (256, 128), "f16") + out = _th("out", 2048, (128, 128), "f16") + cmd = GemmCmd(a=a, b=b, out=out, m=128, k=256, n=128) + logger.record_start(0.0, "pe_gemm", cmd) + logger.record_end(100.0, "pe_gemm", cmd) + r = logger.records[0] + assert r.op_kind == "gemm" + assert r.op_name == "gemm_f16" + assert r.params["m"] == 128 + + +def test_op_logger_math(): + logger = OpLogger() + x = _th("x", 0, (32,), "f32") + out = _th("out", 128, (32,), "f32") + cmd = MathCmd(op="exp", inputs=(x,), out=out) + logger.record_start(5.0, "pe_math", cmd) + logger.record_end(6.0, "pe_math", cmd) + r = logger.records[0] + assert r.op_kind == "math" + assert r.op_name == "exp" + + +def test_op_logger_stable_ordering(): + logger = OpLogger() + cmds = [ + DmaReadCmd(handle=_th(f"t{i}"), src_addr=i * 100, nbytes=64) + for i in range(5) + ] + for i, cmd in enumerate(cmds): + logger.record_start(float(i % 3), f"comp{i}", cmd) # some share t_start + for i, cmd in enumerate(cmds): + logger.record_end(float(i % 3) + 1.0, f"comp{i}", cmd) + + # Verify insertion order preserved for same t_start + for i in range(len(logger.records) - 1): + assert logger.records[i].t_start <= logger.records[i + 1].t_start + + +def test_op_logger_unmatched_end_ignored(): + logger = OpLogger() + cmd = DmaReadCmd(handle=_th(), src_addr=0, nbytes=32) + logger.record_end(5.0, "comp", cmd) # no matching start + assert len(logger.records) == 0 + + +def test_data_op_flag(): + """DmaReadCmd, GemmCmd, MathCmd have data_op=True; others don't.""" + assert getattr(DmaReadCmd(handle=_th(), src_addr=0, nbytes=32), "data_op", False) + assert getattr(DmaWriteCmd(handle=_th(), dst_addr=0, nbytes=32), "data_op", False) + a, b, out = _th("a"), _th("b"), _th("out") + assert getattr(GemmCmd(a=a, b=b, out=out, m=4, k=4, n=4), "data_op", False) + assert getattr(MathCmd(op="exp", inputs=(a,), out=out), "data_op", False) + # WaitCmd and PeCpuOverheadCmd should not have data_op + from kernbench.common.pe_commands import WaitCmd, PeCpuOverheadCmd + assert not getattr(WaitCmd(), "data_op", False) + assert not getattr(PeCpuOverheadCmd(cycles=10), "data_op", False)