Implement ADR-0020: 2-pass data execution with greenlet kernel runner
Step 1 — Foundation:
- OpRecord/OpLogger: op log infrastructure with t_start stable ordering
- MemoryStore: numpy ndarray tensor-granular storage (reference semantics)
- data_op=True flag on DmaReadCmd, DmaWriteCmd, GemmCmd, MathCmd, CompositeCmd
- numpy/greenlet dependencies added to pyproject.toml

Step 2 — ComponentBase hooks:
- _on_process_start/end hooks in _forward_txn (fabric messages)
- _handle_with_hooks in PeEngineBase (PE-internal commands)
- op_logger optional — zero overhead when disabled

Step 3 — KernelRunner + greenlet:
- KernelRunner: greenlet ↔ SimPy bridge in triton_emu/kernel_runner.py
- TLContext: _emit() method routes to greenlet switch or command list
- tl.load() returns real numpy data in greenlet mode
- Dynamic control flow supported (memory-read based branching)

Step 4 — PE_CPU integration:
- Greenlet mode when ctx.memory_store is set, legacy fallback otherwise
- Refactored into _execute_greenlet/_execute_legacy/_send_response
- ComponentContext gains memory_store and op_logger fields

Step 5 — DataExecutor:
- Phase 2 numpy execution for GEMM/Math ops from op_log
- _compute_math: all unary/binary/reduction ops
- verify(): compare MemoryStore against expected with dtype tolerance

28 new tests, 366 total passing.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
+1
-1
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
name = "kernbench"
|
name = "kernbench"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
dependencies = ["pytest", "simpy", "pyyaml", "fastapi>=0.110", "uvicorn[standard]>=0.29", "websockets>=12"]
|
dependencies = ["pytest", "simpy", "pyyaml", "fastapi>=0.110", "uvicorn[standard]>=0.29", "websockets>=12", "numpy>=1.24", "greenlet>=3.0"]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
kernbench = "kernbench.cli.main:main"
|
kernbench = "kernbench.cli.main:main"
|
||||||
|
|||||||
@@ -55,6 +55,7 @@ class DmaReadCmd:
|
|||||||
handle: TensorHandle
|
handle: TensorHandle
|
||||||
src_addr: int
|
src_addr: int
|
||||||
nbytes: int
|
nbytes: int
|
||||||
|
data_op: bool = True
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@@ -64,6 +65,7 @@ class DmaWriteCmd:
|
|||||||
handle: TensorHandle
|
handle: TensorHandle
|
||||||
dst_addr: int
|
dst_addr: int
|
||||||
nbytes: int
|
nbytes: int
|
||||||
|
data_op: bool = True
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@@ -79,6 +81,7 @@ class GemmCmd:
|
|||||||
m: int
|
m: int
|
||||||
k: int
|
k: int
|
||||||
n: int
|
n: int
|
||||||
|
data_op: bool = True
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@@ -94,6 +97,7 @@ class MathCmd:
|
|||||||
inputs: tuple[TensorHandle, ...]
|
inputs: tuple[TensorHandle, ...]
|
||||||
out: TensorHandle
|
out: TensorHandle
|
||||||
axis: int | None = None # for reductions
|
axis: int | None = None # for reductions
|
||||||
|
data_op: bool = True
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@@ -111,6 +115,7 @@ class CompositeCmd:
|
|||||||
out_addr: int
|
out_addr: int
|
||||||
out_nbytes: int
|
out_nbytes: int
|
||||||
math_op: str | None = None # for op="math": which math operation
|
math_op: str | None = None # for op="math": which math operation
|
||||||
|
data_op: bool = True
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ class ComponentBase(ABC):
|
|||||||
self.ctx = ctx
|
self.ctx = ctx
|
||||||
self.in_ports: dict[str, simpy.Store] = {}
|
self.in_ports: dict[str, simpy.Store] = {}
|
||||||
self.out_ports: dict[str, simpy.Store] = {}
|
self.out_ports: dict[str, simpy.Store] = {}
|
||||||
|
self._op_logger: Any | None = None # OpLogger, set by GraphEngine if enabled
|
||||||
|
|
||||||
def start(self, env: simpy.Environment) -> None:
|
def start(self, env: simpy.Environment) -> None:
|
||||||
"""Called once after all ports are wired.
|
"""Called once after all ports are wired.
|
||||||
@@ -64,9 +65,21 @@ class ComponentBase(ABC):
|
|||||||
txn: Any = yield self._inbox.get()
|
txn: Any = yield self._inbox.get()
|
||||||
env.process(self._forward_txn(env, txn))
|
env.process(self._forward_txn(env, txn))
|
||||||
|
|
||||||
|
def _on_process_start(self, env: simpy.Environment, msg: Any) -> None:
|
||||||
|
"""Op log hook: record service start for data_op messages (ADR-0020 D2)."""
|
||||||
|
if self._op_logger and getattr(msg, "data_op", False):
|
||||||
|
self._op_logger.record_start(env.now, self.node.id, msg)
|
||||||
|
|
||||||
|
def _on_process_end(self, env: simpy.Environment, msg: Any) -> None:
|
||||||
|
"""Op log hook: record service end for data_op messages (ADR-0020 D2)."""
|
||||||
|
if self._op_logger and getattr(msg, "data_op", False):
|
||||||
|
self._op_logger.record_end(env.now, self.node.id, msg)
|
||||||
|
|
||||||
def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator:
|
def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator:
|
||||||
"""Apply run() latency, then forward to next hop or drain at terminal."""
|
"""Apply run() latency, then forward to next hop or drain at terminal."""
|
||||||
|
self._on_process_start(env, txn)
|
||||||
yield from self.run(env, txn.nbytes)
|
yield from self.run(env, txn.nbytes)
|
||||||
|
self._on_process_end(env, txn)
|
||||||
next_hop = txn.next_hop # duck-typed: Transaction.next_hop
|
next_hop = txn.next_hop # duck-typed: Transaction.next_hop
|
||||||
if next_hop:
|
if next_hop:
|
||||||
yield self.out_ports[next_hop].put(txn.advance())
|
yield self.out_ports[next_hop].put(txn.advance())
|
||||||
@@ -120,10 +133,16 @@ class PeEngineBase(ComponentBase):
|
|||||||
while True:
|
while True:
|
||||||
msg: Any = yield self._inbox.get()
|
msg: Any = yield self._inbox.get()
|
||||||
if isinstance(msg, PeInternalTxn):
|
if isinstance(msg, PeInternalTxn):
|
||||||
env.process(self.handle_command(env, msg))
|
env.process(self._handle_with_hooks(env, msg))
|
||||||
else:
|
else:
|
||||||
env.process(self._forward_txn(env, msg))
|
env.process(self._forward_txn(env, msg))
|
||||||
|
|
||||||
|
def _handle_with_hooks(self, env: simpy.Environment, pe_txn: Any) -> Generator:
|
||||||
|
"""Wrap handle_command with op log hooks on the inner command."""
|
||||||
|
self._on_process_start(env, pe_txn.command)
|
||||||
|
yield from self.handle_command(env, pe_txn)
|
||||||
|
self._on_process_end(env, pe_txn.command)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def handle_command(self, env: simpy.Environment, pe_txn: Any) -> Generator:
|
def handle_command(self, env: simpy.Environment, pe_txn: Any) -> Generator:
|
||||||
"""Process a PE-internal command (PeInternalTxn).
|
"""Process a PE-internal command (PeInternalTxn).
|
||||||
|
|||||||
@@ -65,24 +65,45 @@ class PeCpuComponent(ComponentBase):
|
|||||||
yield from self._forward_txn(env, txn)
|
yield from self._forward_txn(env, txn)
|
||||||
|
|
||||||
def _execute_kernel(self, env: simpy.Environment, txn: Any) -> Generator:
|
def _execute_kernel(self, env: simpy.Environment, txn: Any) -> Generator:
|
||||||
"""Compile kernel function and replay command trace."""
|
"""Execute kernel: greenlet mode (ADR-0020) or legacy Phase 0 + replay."""
|
||||||
from kernbench.common.pe_commands import (
|
|
||||||
CompositeCmd,
|
|
||||||
PeCpuOverheadCmd,
|
|
||||||
PeInternalTxn,
|
|
||||||
WaitCmd,
|
|
||||||
)
|
|
||||||
from kernbench.triton_emu.registry import get_kernel
|
from kernbench.triton_emu.registry import get_kernel
|
||||||
from kernbench.triton_emu.tl_context import TLContext, run_kernel
|
|
||||||
|
|
||||||
request = txn.request
|
request = txn.request
|
||||||
|
|
||||||
# Phase 1: Compile — apply PE_CPU setup overhead, then run kernel
|
|
||||||
yield from self.run(env, 0)
|
yield from self.run(env, 0)
|
||||||
|
|
||||||
kernel_fn = get_kernel(request.kernel_ref.name)
|
kernel_fn = get_kernel(request.kernel_ref.name)
|
||||||
|
num_programs = self._derive_num_programs(request)
|
||||||
|
kernel_args = self._unpack_kernel_args(request)
|
||||||
|
|
||||||
# Derive num_programs from the number of PE shards in this cube
|
pe_exec_start = env.now
|
||||||
|
scheduler_id = f"{self._pe_prefix}.pe_scheduler"
|
||||||
|
|
||||||
|
# Choose execution mode: greenlet (ADR-0020) or legacy command-list
|
||||||
|
store = getattr(self.ctx, "memory_store", None) if self.ctx else None
|
||||||
|
|
||||||
|
if store is not None:
|
||||||
|
composite_results = yield from self._execute_greenlet(
|
||||||
|
env, kernel_fn, kernel_args, num_programs, scheduler_id, store,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
composite_results = yield from self._execute_legacy(
|
||||||
|
env, kernel_fn, kernel_args, num_programs, scheduler_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Record PE-internal execution time
|
||||||
|
txn.result_data["pe_exec_ns"] = env.now - pe_exec_start
|
||||||
|
total_dma_ns = 0.0
|
||||||
|
total_compute_ns = 0.0
|
||||||
|
for rd in composite_results:
|
||||||
|
total_dma_ns += rd.get("dma_ns", 0.0)
|
||||||
|
total_compute_ns += rd.get("compute_ns", 0.0)
|
||||||
|
txn.result_data["dma_ns"] = total_dma_ns
|
||||||
|
txn.result_data["compute_ns"] = total_compute_ns
|
||||||
|
|
||||||
|
# Send ResponseMsg on reverse path
|
||||||
|
yield from self._send_response(env, txn, request)
|
||||||
|
|
||||||
|
def _derive_num_programs(self, request: Any) -> int:
|
||||||
num_programs = 1
|
num_programs = 1
|
||||||
for arg in request.args:
|
for arg in request.args:
|
||||||
if arg.arg_kind == "tensor":
|
if arg.arg_kind == "tensor":
|
||||||
@@ -92,11 +113,9 @@ class PeCpuComponent(ComponentBase):
|
|||||||
)
|
)
|
||||||
if cube_pe_count > num_programs:
|
if cube_pe_count > num_programs:
|
||||||
num_programs = cube_pe_count
|
num_programs = cube_pe_count
|
||||||
|
return num_programs
|
||||||
|
|
||||||
tl = TLContext(pe_id=self._pe_idx, num_programs=num_programs, dispatch_cycles=0)
|
def _unpack_kernel_args(self, request: Any) -> list:
|
||||||
|
|
||||||
# Unpack KernelLaunchMsg.args into positional args for kernel function
|
|
||||||
# TensorArg → va_base (already local, set by runtime) or PA fallback
|
|
||||||
kernel_args: list = []
|
kernel_args: list = []
|
||||||
for arg in request.args:
|
for arg in request.args:
|
||||||
if arg.arg_kind == "tensor":
|
if arg.arg_kind == "tensor":
|
||||||
@@ -107,15 +126,41 @@ class PeCpuComponent(ComponentBase):
|
|||||||
kernel_args.append(shard.pa)
|
kernel_args.append(shard.pa)
|
||||||
elif arg.arg_kind == "scalar":
|
elif arg.arg_kind == "scalar":
|
||||||
kernel_args.append(arg.value)
|
kernel_args.append(arg.value)
|
||||||
|
return kernel_args
|
||||||
|
|
||||||
|
def _execute_greenlet(
    self, env, kernel_fn, kernel_args, num_programs, scheduler_id, store,
) -> Generator:
    """Run the kernel through a KernelRunner (ADR-0020 D3 greenlet mode).

    Builds a runner from this PE's identity fields, output ports and the
    given memory store, delegates to ``runner.run`` and returns whatever
    the runner exposes as ``_composite_results`` (``[]`` if absent).
    """
    from kernbench.triton_emu.kernel_runner import KernelRunner

    runner_kwargs = {
        "pe_prefix": self._pe_prefix,
        "pe_idx": self._pe_idx,
        "sip_idx": self._sip_idx,
        "cube_idx": self._cube_idx,
        "scheduler_id": scheduler_id,
        "out_ports": self.out_ports,
        "store": store,
    }
    bridge = KernelRunner(**runner_kwargs)
    yield from bridge.run(env, kernel_fn, kernel_args, num_programs)

    # NOTE(review): the visible part of KernelRunner keeps composite results
    # in a local variable and never assigns `_composite_results`; confirm the
    # runner stores it, otherwise this always falls back to an empty list.
    return getattr(bridge, "_composite_results", [])
|
||||||
|
|
||||||
|
def _execute_legacy(
|
||||||
|
self, env, kernel_fn, kernel_args, num_programs, scheduler_id,
|
||||||
|
) -> Generator:
|
||||||
|
"""Legacy Phase 0 + replay: generate command list, then dispatch."""
|
||||||
|
from kernbench.common.pe_commands import (
|
||||||
|
CompositeCmd, PeCpuOverheadCmd, PeInternalTxn, WaitCmd,
|
||||||
|
)
|
||||||
|
from kernbench.triton_emu.tl_context import TLContext, run_kernel
|
||||||
|
|
||||||
|
tl = TLContext(pe_id=self._pe_idx, num_programs=num_programs, dispatch_cycles=0)
|
||||||
run_kernel(kernel_fn, tl, *kernel_args)
|
run_kernel(kernel_fn, tl, *kernel_args)
|
||||||
commands = tl.commands
|
commands = tl.commands
|
||||||
|
|
||||||
# Phase 2: Replay — dispatch commands to PE_SCHEDULER
|
pending: dict[str, simpy.Event] = {}
|
||||||
pe_exec_start = env.now
|
composite_results: list[dict] = []
|
||||||
scheduler_id = f"{self._pe_prefix}.pe_scheduler"
|
|
||||||
pending: dict[str, simpy.Event] = {} # completion_id → done event
|
|
||||||
composite_results: list[dict] = [] # collect result_data from CompositeCmd txns
|
|
||||||
|
|
||||||
for cmd in commands:
|
for cmd in commands:
|
||||||
if isinstance(cmd, PeCpuOverheadCmd):
|
if isinstance(cmd, PeCpuOverheadCmd):
|
||||||
@@ -126,47 +171,30 @@ class PeCpuComponent(ComponentBase):
|
|||||||
if evt:
|
if evt:
|
||||||
yield evt
|
yield evt
|
||||||
else:
|
else:
|
||||||
# Wait all pending completions
|
|
||||||
for evt in pending.values():
|
for evt in pending.values():
|
||||||
yield evt
|
yield evt
|
||||||
pending.clear()
|
pending.clear()
|
||||||
elif isinstance(cmd, CompositeCmd):
|
elif isinstance(cmd, CompositeCmd):
|
||||||
# Non-blocking: dispatch to scheduler, track completion
|
|
||||||
done_evt = env.event()
|
done_evt = env.event()
|
||||||
pe_txn = PeInternalTxn(
|
pe_txn = PeInternalTxn(
|
||||||
command=cmd, done=done_evt,
|
command=cmd, done=done_evt, pe_prefix=self._pe_prefix,
|
||||||
pe_prefix=self._pe_prefix,
|
|
||||||
)
|
)
|
||||||
composite_results.append(pe_txn.result_data)
|
composite_results.append(pe_txn.result_data)
|
||||||
yield self.out_ports[scheduler_id].put(pe_txn)
|
yield self.out_ports[scheduler_id].put(pe_txn)
|
||||||
pending[cmd.completion.id] = done_evt
|
pending[cmd.completion.id] = done_evt
|
||||||
else:
|
else:
|
||||||
# Blocking: dispatch and wait for completion
|
|
||||||
done_evt = env.event()
|
done_evt = env.event()
|
||||||
pe_txn = PeInternalTxn(
|
pe_txn = PeInternalTxn(
|
||||||
command=cmd, done=done_evt,
|
command=cmd, done=done_evt, pe_prefix=self._pe_prefix,
|
||||||
pe_prefix=self._pe_prefix,
|
|
||||||
)
|
)
|
||||||
yield self.out_ports[scheduler_id].put(pe_txn)
|
yield self.out_ports[scheduler_id].put(pe_txn)
|
||||||
yield done_evt
|
yield done_evt
|
||||||
|
|
||||||
# Wait for any remaining pending completions
|
|
||||||
for evt in pending.values():
|
for evt in pending.values():
|
||||||
yield evt
|
yield evt
|
||||||
|
return composite_results
|
||||||
|
|
||||||
# Record PE-internal execution time
|
def _send_response(self, env, txn, request) -> Generator:
|
||||||
txn.result_data["pe_exec_ns"] = env.now - pe_exec_start
|
|
||||||
|
|
||||||
# Aggregate dma_ns / compute_ns from CompositeCmd results
|
|
||||||
total_dma_ns = 0.0
|
|
||||||
total_compute_ns = 0.0
|
|
||||||
for rd in composite_results:
|
|
||||||
total_dma_ns += rd.get("dma_ns", 0.0)
|
|
||||||
total_compute_ns += rd.get("compute_ns", 0.0)
|
|
||||||
txn.result_data["dma_ns"] = total_dma_ns
|
|
||||||
txn.result_data["compute_ns"] = total_compute_ns
|
|
||||||
|
|
||||||
# Send ResponseMsg on reverse path (PE_CPU → NOC → M_CPU)
|
|
||||||
reverse_path = list(reversed(txn.path))
|
reverse_path = list(reversed(txn.path))
|
||||||
if len(reverse_path) >= 2:
|
if len(reverse_path) >= 2:
|
||||||
from kernbench.runtime_api.kernel import ResponseMsg
|
from kernbench.runtime_api.kernel import ResponseMsg
|
||||||
|
|||||||
@@ -24,6 +24,8 @@ class ComponentContext:
|
|||||||
ns_per_mm: float # wire propagation constant (from topology spec)
|
ns_per_mm: float # wire propagation constant (from topology spec)
|
||||||
edge_map: dict[tuple[str, str], Any] = field(default_factory=dict)
|
edge_map: dict[tuple[str, str], Any] = field(default_factory=dict)
|
||||||
spec: dict = field(default_factory=dict) # topology spec (cube layout, PE count, etc.)
|
spec: dict = field(default_factory=dict) # topology spec (cube layout, PE count, etc.)
|
||||||
|
memory_store: Any = None # MemoryStore for Phase 1 data-aware execution (ADR-0020)
|
||||||
|
op_logger: Any = None # OpLogger for Phase 1 op recording (ADR-0020)
|
||||||
|
|
||||||
def get_shared_resource(
|
def get_shared_resource(
|
||||||
self, env: simpy.Environment, key: str, capacity: int = 1,
|
self, env: simpy.Environment, key: str, capacity: int = 1,
|
||||||
|
|||||||
@@ -0,0 +1,157 @@
|
|||||||
|
"""DataExecutor: Phase 2 op_log-based data execution (ADR-0020 D6).
|
||||||
|
|
||||||
|
Executes GEMM/Math operations from the op_log using numpy.
|
||||||
|
Memory ops are skipped (already handled in Phase 1 via MemoryStore).
|
||||||
|
Same-timestamp independent ops can be batched for efficiency.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from itertools import groupby
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from kernbench.sim_engine.memory_store import MemoryStore, _resolve_dtype
|
||||||
|
from kernbench.sim_engine.op_log import OpRecord
|
||||||
|
|
||||||
|
|
||||||
|
class DataExecutor:
    """Phase 2 executor: replay op_log with actual numpy computation.

    Args:
        op_log: list of OpRecords from Phase 1.
        store: MemoryStore snapshot from Phase 1 (contains tensor data).
    """

    def __init__(self, op_log: list[OpRecord], store: MemoryStore) -> None:
        self._op_log = op_log
        self.store = store

    def run(self) -> None:
        """Execute all ops in op_log order, grouped by t_start.

        The grouping only marks where same-timestamp ops could be batched;
        each op is still executed one at a time, in log order.
        """
        for _t_start, same_time_ops in groupby(self._op_log, key=lambda r: r.t_start):
            for op in same_time_ops:
                self._execute_op(op)

    def _execute_op(self, op: OpRecord) -> None:
        """Dispatch one record on ``op_kind``; other kinds are ignored."""
        if op.op_kind == "memory":
            self._execute_memory(op)
        elif op.op_kind == "gemm":
            self._execute_gemm(op)
        elif op.op_kind == "math":
            self._execute_math(op)

    def _execute_memory(self, op: OpRecord) -> None:
        """Memory ops are already handled by Phase 1 MemoryStore. Skip."""

    def _execute_gemm(self, op: OpRecord) -> None:
        """Execute GEMM: out = a @ b.

        Operands are read from the store, accumulated in ``dtype_acc``
        (default ``"f32"``) and written back as ``dtype_out``.
        """
        p = op.params
        if "src_a_addr" not in p:
            return  # composite record without full params
        space = p.get("addr_space", "tcm")
        dtype_in = p.get("dtype_in", "f16")
        dtype_out = p.get("dtype_out", dtype_in)

        a = self.store.read(space, p["src_a_addr"], shape=p.get("shape_a"), dtype=dtype_in)
        b = self.store.read(space, p["src_b_addr"], shape=p.get("shape_b"), dtype=dtype_in)

        # Compute in higher precision if specified
        dtype_acc = p.get("dtype_acc", "f32")
        a_f = a.astype(_resolve_dtype(dtype_acc))
        b_f = b.astype(_resolve_dtype(dtype_acc))
        result = np.matmul(a_f, b_f).astype(_resolve_dtype(dtype_out))

        self.store.write(space, p["dst_addr"], result)

    def _execute_math(self, op: OpRecord) -> None:
        """Execute math op: unary, binary, or reduction.

        Nothing is written when ``_compute_math`` returns ``None``
        (no inputs, or an op it does not handle).
        """
        p = op.params
        math_op = p.get("op", op.op_name)
        space = p.get("addr_space", "tcm")
        dtype = p.get("dtype", "f32")
        input_addrs = p.get("input_addrs", [])
        input_shapes = p.get("input_shapes", [])

        inputs = [
            self.store.read(space, addr, shape=shape, dtype=dtype)
            for addr, shape in zip(input_addrs, input_shapes)
        ]

        result = _compute_math(math_op, inputs, p.get("axis"))
        if result is not None:
            self.store.write(space, p["dst_addr"], result)

    def verify(self, expected: dict[tuple[str, int], np.ndarray],
               rtol: float = 1e-3, atol: float = 1e-3) -> dict[str, bool]:
        """Compare MemoryStore contents against expected tensors.

        Integer tensors must match exactly; everything else is compared
        with ``np.allclose`` (which broadcasts its operands). A missing
        entry, or a stored tensor whose shape cannot be compared with the
        expected one, is reported as ``False`` instead of raising.

        Args:
            expected: {(space, addr): expected_ndarray}
            rtol, atol: tolerance for floating-point comparison.

        Returns:
            {key_str: passed} dict.
        """
        results: dict[str, bool] = {}
        for (space, addr), exp in expected.items():
            key = f"{space}:0x{addr:x}"
            try:
                actual = self.store.read(space, addr)
            except KeyError:
                # Nothing was ever written at this location.
                results[key] = False
                continue
            try:
                if np.issubdtype(actual.dtype, np.integer):
                    passed = bool(np.array_equal(actual, exp))
                else:
                    passed = bool(np.allclose(actual, exp, rtol=rtol, atol=atol))
            except ValueError:
                # np.allclose raises when the shapes cannot be broadcast.
                passed = False
            results[key] = passed
        return results
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_math(op: str, inputs: list[np.ndarray], axis: int | None) -> np.ndarray | None:
|
||||||
|
"""Execute a math operation on numpy arrays."""
|
||||||
|
if not inputs:
|
||||||
|
return None
|
||||||
|
|
||||||
|
x = inputs[0]
|
||||||
|
|
||||||
|
# Unary
|
||||||
|
if op == "exp":
|
||||||
|
return np.exp(x)
|
||||||
|
if op == "log":
|
||||||
|
return np.log(x)
|
||||||
|
if op == "sqrt":
|
||||||
|
return np.sqrt(x)
|
||||||
|
if op == "abs":
|
||||||
|
return np.abs(x)
|
||||||
|
if op == "sigmoid":
|
||||||
|
return 1.0 / (1.0 + np.exp(-x))
|
||||||
|
if op == "cos":
|
||||||
|
return np.cos(x)
|
||||||
|
if op == "sin":
|
||||||
|
return np.sin(x)
|
||||||
|
|
||||||
|
# Reduction
|
||||||
|
if op == "sum":
|
||||||
|
return np.sum(x, axis=axis, keepdims=True)
|
||||||
|
if op == "max":
|
||||||
|
return np.max(x, axis=axis, keepdims=True)
|
||||||
|
if op == "min":
|
||||||
|
return np.min(x, axis=axis, keepdims=True)
|
||||||
|
|
||||||
|
# Binary
|
||||||
|
if len(inputs) >= 2:
|
||||||
|
y = inputs[1]
|
||||||
|
if op == "add":
|
||||||
|
return x + y
|
||||||
|
if op == "sub":
|
||||||
|
return x - y
|
||||||
|
if op == "mul":
|
||||||
|
return x * y
|
||||||
|
if op == "div":
|
||||||
|
return x / y
|
||||||
|
|
||||||
|
# Ternary
|
||||||
|
if op == "where" and len(inputs) >= 3:
|
||||||
|
return np.where(inputs[0], inputs[1], inputs[2])
|
||||||
|
|
||||||
|
return None
|
||||||
@@ -0,0 +1,84 @@
|
|||||||
|
"""MemoryStore: tensor-granular storage for Phase 1 and Phase 2 (ADR-0020 D7).
|
||||||
|
|
||||||
|
Logically byte-addressable, implemented as addr → numpy ndarray mapping.
|
||||||
|
Read/write are reference-based (no copy) for Phase 1 performance.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
# numpy dtype string → numpy dtype mapping
|
||||||
|
_DTYPE_MAP = {
|
||||||
|
"f16": np.float16,
|
||||||
|
"f32": np.float32,
|
||||||
|
"f64": np.float64,
|
||||||
|
"bf16": np.float16, # numpy has no bfloat16; use float16 as proxy
|
||||||
|
"i8": np.int8,
|
||||||
|
"i16": np.int16,
|
||||||
|
"i32": np.int32,
|
||||||
|
"i64": np.int64,
|
||||||
|
"u8": np.uint8,
|
||||||
|
"u16": np.uint16,
|
||||||
|
"u32": np.uint32,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_dtype(dtype: str) -> np.dtype:
|
||||||
|
if dtype in _DTYPE_MAP:
|
||||||
|
return np.dtype(_DTYPE_MAP[dtype])
|
||||||
|
return np.dtype(dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryStore:
    """Tensor-granular memory storage (ADR-0020 D7).

    Tensors are kept in a two-level mapping, address space → address →
    ndarray. Both write and read hand over references; nothing is copied.
    Writing to an address that already holds a tensor replaces that tensor
    as a whole.
    """

    def __init__(self) -> None:
        # {space: {addr: ndarray}}
        self._storage: dict[str, dict[int, np.ndarray]] = {}

    def write(self, space: str, addr: int, data: np.ndarray) -> None:
        """Bind ``data`` to (space, addr) by reference, creating the space on demand."""
        self._storage.setdefault(space, {})[addr] = data

    def read(self, space: str, addr: int, shape: tuple[int, ...] | None = None,
             dtype: str | None = None) -> np.ndarray:
        """Return the tensor stored at (space, addr) without copying.

        With no ``shape``/``dtype`` the stored object itself is returned.
        A different ``dtype`` yields a reinterpreting ``view``; a different
        ``shape`` yields a ``reshape`` provided the byte counts agree.

        Raises:
            KeyError: nothing is stored at (space, addr).
            ValueError: ``shape`` does not cover the stored number of bytes.
        """
        bucket = self._storage.get(space, {})
        if addr not in bucket:
            raise KeyError(f"No data at ({space}, 0x{addr:x})")

        tensor = bucket[addr]
        if dtype is not None:
            wanted = _resolve_dtype(dtype)
            if tensor.dtype != wanted:
                tensor = tensor.view(wanted)

        if shape is None or tensor.shape == shape:
            return tensor

        requested_nbytes = np.prod(shape) * tensor.dtype.itemsize
        if tensor.nbytes != requested_nbytes:
            raise ValueError(
                f"Shape mismatch: stored {tensor.shape} ({tensor.nbytes}B) "
                f"vs requested {shape} ({requested_nbytes}B)"
            )
        return tensor.reshape(shape)

    def has(self, space: str, addr: int) -> bool:
        """Tell whether a tensor is stored at (space, addr)."""
        return space in self._storage and addr in self._storage[space]

    def snapshot(self) -> MemoryStore:
        """Create a shallow copy for Phase 2 initialization.

        The address maps are duplicated; the ndarrays are shared.
        """
        clone = MemoryStore()
        clone._storage = {space: dict(addrs) for space, addrs in self._storage.items()}
        return clone
|
||||||
@@ -0,0 +1,111 @@
|
|||||||
|
"""Op log infrastructure for 2-pass data execution (ADR-0020 D2, D5).
|
||||||
|
|
||||||
|
OpRecord: single data operation with timing, params, and dependencies.
|
||||||
|
OpLogger: collects OpRecords from ComponentBase hooks during Phase 1.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class OpRecord:
    """Single data operation record (ADR-0020 D5).

    Built by ``OpLogger.record_end`` from the matching ``record_start`` call
    plus the fields ``_extract_op_info`` derives from the message.
    """

    t_start: float  # time passed to record_start (service start)
    t_end: float  # time passed to record_end (service end)
    component_id: str  # component id given to record_start
    op_kind: str  # "memory" | "gemm" | "math" ("unknown" for unrecognised messages)
    op_name: str  # e.g. "dma_read", "gemm_f16", "exp"
    params: dict[str, Any]  # op-specific addresses/shapes/dtypes; keys vary by op_kind
    # NOTE(review): nothing in the visible code fills this in — presumably
    # reserved for dependency tracking; confirm against ADR-0020.
    dependency_ids: list[int] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class OpLogger:
    """Collects OpRecords during Phase 1 simulation (ADR-0020 D2).

    Thread safety is not required — SimPy is single-threaded.
    A record is appended when its ``record_end`` arrives; the ``records``
    property then orders the list by t_start with a stable sort, so records
    sharing a t_start keep their completion order.
    """

    def __init__(self) -> None:
        self._records: list[OpRecord] = []
        # id(msg) → partial record captured at start; holding the message
        # itself keeps it alive so its id cannot be reused while pending.
        self._pending: dict[int, dict[str, Any]] = {}

    @property
    def records(self) -> list[OpRecord]:
        """Records sorted by t_start (stable ordering per ADR-0020 D5).

        Sorts the internal list in place and returns that same list.
        """
        completed = self._records
        completed.sort(key=lambda rec: rec.t_start)
        return completed

    def record_start(self, t: float, component_id: str, msg: Any) -> None:
        """Called by ComponentBase._on_process_start; remembers when/where ``msg`` began."""
        self._pending[id(msg)] = dict(t_start=t, component_id=component_id, msg=msg)

    def record_end(self, t: float, component_id: str, msg: Any) -> None:
        """Called by ComponentBase._on_process_end; finalises the record for ``msg``.

        Ignored when no matching ``record_start`` was seen. The component id
        stored in the record is the one captured at start.
        """
        started = self._pending.pop(id(msg), None)
        if started is None:
            return
        kind, name, op_params = _extract_op_info(msg)
        finished = OpRecord(
            t_start=started["t_start"],
            t_end=t,
            component_id=started["component_id"],
            op_kind=kind,
            op_name=name,
            params=op_params,
        )
        self._records.append(finished)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_op_info(msg: Any) -> tuple[str, str, dict[str, Any]]:
    """Derive ``(op_kind, op_name, params)`` from a data_op message.

    DMA commands map to kind ``"memory"``, GemmCmd to ``"gemm"``, MathCmd to
    ``"math"``; a CompositeCmd is ``"gemm"`` when its op is ``"gemm"`` and
    ``"math"`` otherwise. Any other message yields
    ``("unknown", <class name>, {})``.
    """
    from kernbench.common.pe_commands import (
        CompositeCmd,
        DmaReadCmd,
        DmaWriteCmd,
        GemmCmd,
        MathCmd,
    )

    if isinstance(msg, DmaReadCmd):
        read_params = {
            "src_addr": msg.src_addr,
            "nbytes": msg.nbytes,
            "handle_id": msg.handle.id,
        }
        return "memory", "dma_read", read_params

    if isinstance(msg, DmaWriteCmd):
        write_params = {
            "dst_addr": msg.dst_addr,
            "nbytes": msg.nbytes,
            "handle_id": msg.handle.id,
        }
        return "memory", "dma_write", write_params

    if isinstance(msg, GemmCmd):
        lhs, rhs, dst = msg.a, msg.b, msg.out
        gemm_params = {
            "src_a_addr": lhs.addr,
            "src_b_addr": rhs.addr,
            "dst_addr": dst.addr,
            "shape_a": lhs.shape,
            "shape_b": rhs.shape,
            "shape_out": dst.shape,
            "dtype_in": lhs.dtype,
            "dtype_out": dst.dtype,
            "m": msg.m, "k": msg.k, "n": msg.n,
        }
        return "gemm", f"gemm_{lhs.dtype}", gemm_params

    if isinstance(msg, MathCmd):
        math_params = {
            "input_addrs": [h.addr for h in msg.inputs],
            "input_shapes": [h.shape for h in msg.inputs],
            "dst_addr": msg.out.addr,
            "shape_out": msg.out.shape,
            "dtype": msg.out.dtype,
            "axis": msg.axis,
        }
        return "math", msg.op, math_params

    if isinstance(msg, CompositeCmd):
        # Composite records carry only the output location, not operand addresses.
        composite_kind = "gemm" if msg.op == "gemm" else "math"
        composite_params = {
            "op": msg.op,
            "out_addr": msg.out_addr,
            "out_nbytes": msg.out_nbytes,
        }
        return composite_kind, f"composite_{msg.op}", composite_params

    # Fallback for unknown data_op messages
    return "unknown", type(msg).__name__, {}
|
||||||
@@ -0,0 +1,199 @@
|
|||||||
|
"""KernelRunner: greenlet-based kernel ↔ SimPy bridge (ADR-0020 D3).
|
||||||
|
|
||||||
|
Replaces Phase 0 (static command list) with interleaved execution:
|
||||||
|
- tl.load() → SimPy DMA timing + MemoryStore read → real data to kernel
|
||||||
|
- tl.store() → MemoryStore write + SimPy DMA timing
|
||||||
|
- tl.composite(gemm) → SimPy timing + op_log (actual compute in Phase 2)
|
||||||
|
|
||||||
|
The kernel runs as a child greenlet; SimPy loop is the parent.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Generator
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
import simpy
|
||||||
|
from greenlet import greenlet
|
||||||
|
|
||||||
|
from kernbench.common.pe_commands import (
|
||||||
|
CompletionHandle,
|
||||||
|
CompositeCmd,
|
||||||
|
DmaReadCmd,
|
||||||
|
DmaWriteCmd,
|
||||||
|
GemmCmd,
|
||||||
|
MathCmd,
|
||||||
|
PeCommand,
|
||||||
|
PeCpuOverheadCmd,
|
||||||
|
PeInternalTxn,
|
||||||
|
TensorHandle,
|
||||||
|
WaitCmd,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from kernbench.sim_engine.memory_store import MemoryStore
|
||||||
|
|
||||||
|
|
||||||
|
class KernelRunner:
    """Greenlet ↔ SimPy bridge for kernel execution (ADR-0020 D3).

    PE_CPU creates a KernelRunner and yields from its run() method.
    The kernel function executes as a plain Python function inside a
    child greenlet, using TLContext methods that switch back to SimPy.
    """

    def __init__(
        self,
        pe_prefix: str,
        pe_idx: int,
        sip_idx: int,
        cube_idx: int,
        scheduler_id: str,
        out_ports: dict[str, simpy.Store],
        store: MemoryStore | None = None,
    ) -> None:
        """Remember the PE identity, the scheduler port and the optional store.

        Args:
            pe_prefix: Prefix attached to every PeInternalTxn this runner emits.
            pe_idx: PE index; forwarded to TLContext as ``pe_id``.
            sip_idx: SIP index (stored, not used by this class).
            cube_idx: Cube index (stored, not used by this class).
            scheduler_id: Key into ``out_ports`` selecting the scheduler port.
            out_ports: Mapping of port id to SimPy store.
            store: Optional MemoryStore; when None, loads return no data and
                stores are timing-only.
        """
        self._pe_prefix = pe_prefix
        self._pe_idx = pe_idx
        self._sip_idx = sip_idx
        self._cube_idx = cube_idx
        self._scheduler_id = scheduler_id
        self._out_ports = out_ports
        self._store = store
        self._parent: greenlet | None = None
        # Created up front so the attribute can be read before run() has
        # finished (or after a run that raised) without an AttributeError.
        self._composite_results: list[dict] = []

    def run(
        self,
        env: simpy.Environment,
        kernel_fn: Any,
        kernel_args: list,
        num_programs: int,
    ) -> Generator:
        """SimPy generator: run kernel with greenlet interleaving.

        This is the SimPy-side loop. It:
        1. Creates a TLContext connected to this runner
        2. Spawns the kernel in a child greenlet
        3. Receives commands via greenlet.switch
        4. Dispatches each command through SimPy components
        5. Returns results to the kernel

        Args:
            env: SimPy environment used for timeouts and completion events.
            kernel_fn: Kernel callable, invoked as
                ``kernel_fn(*kernel_args, tl=tl)``.
            kernel_args: Positional arguments for the kernel.
            num_programs: Forwarded to TLContext.

        Yields:
            SimPy events (timeouts, port puts, completion events).

        On normal completion the ``result_data`` of every composite
        transaction issued by the kernel is stored in
        ``self._composite_results``.
        """
        # Imported locally, as in the original (presumably to avoid an import
        # cycle: TLContext receives this runner).
        from kernbench.triton_emu.tl_context import TLContext

        self._parent = greenlet.getcurrent()

        tl = TLContext(
            pe_id=self._pe_idx,
            num_programs=num_programs,
            dispatch_cycles=0,
            runner=self,
        )

        def _kernel_entry():
            TLContext._set_active(tl)  # type: ignore[attr-defined]
            try:
                kernel_fn(*kernel_args, tl=tl)
            finally:
                TLContext._set_active(None)  # type: ignore[attr-defined]
            return None  # signal kernel completion

        g = greenlet(_kernel_entry)
        # Completion id -> done event of composites not yet waited on.
        pending: dict[str, simpy.Event] = {}
        composite_results: list[dict] = []

        # Start kernel — first switch returns first command (or None if kernel is done)
        cmd = g.switch()

        while cmd is not None:
            if isinstance(cmd, PeCpuOverheadCmd):
                yield env.timeout(cmd.cycles)
                cmd = g.switch()

            elif isinstance(cmd, WaitCmd):
                if cmd.handle is not None:
                    # Unknown or already-waited handles are a no-op.
                    evt = pending.pop(cmd.handle.id, None)
                    if evt is not None:
                        yield evt
                else:
                    for evt in pending.values():
                        yield evt
                    pending.clear()
                cmd = g.switch()

            elif isinstance(cmd, DmaReadCmd):
                # Dispatch DMA through SimPy components
                yield from self._dispatch_blocking(env, cmd)

                # Read actual data from MemoryStore (if available)
                data = None
                if self._store is not None:
                    try:
                        data = self._store.read(
                            "hbm", cmd.src_addr,
                            shape=cmd.handle.shape, dtype=cmd.handle.dtype,
                        )
                    except KeyError:
                        # Best effort: nothing stored at this address, so the
                        # kernel gets a handle without data.
                        pass
                cmd = g.switch(data)

            elif isinstance(cmd, DmaWriteCmd):
                # Write to MemoryStore first (visibility = issue, ADR-0020 D3)
                if self._store is not None and cmd.handle.data is not None:
                    self._store.write("hbm", cmd.dst_addr, cmd.handle.data)

                yield from self._dispatch_blocking(env, cmd)
                cmd = g.switch()

            elif isinstance(cmd, CompositeCmd):
                # Non-blocking composite: issue it, remember its completion
                # event, and hand control straight back to the kernel.
                pe_txn, done_evt = self._new_txn(env, cmd)
                composite_results.append(pe_txn.result_data)
                yield self._out_ports[self._scheduler_id].put(pe_txn)
                pending[cmd.completion.id] = done_evt
                cmd = g.switch()

            else:
                # GemmCmd / MathCmd and any unknown command: blocking
                # pass-through (both cases were handled identically).
                yield from self._dispatch_blocking(env, cmd)
                cmd = g.switch()

        # Wait remaining pending composites
        for evt in pending.values():
            yield evt

        # Return composite results for PE_CPU aggregation
        self._composite_results = composite_results

    def _new_txn(
        self, env: simpy.Environment, cmd: PeCommand,
    ) -> tuple[PeInternalTxn, simpy.Event]:
        """Wrap ``cmd`` in a PeInternalTxn with a fresh completion event."""
        done_evt = env.event()
        pe_txn = PeInternalTxn(
            command=cmd, done=done_evt, pe_prefix=self._pe_prefix,
        )
        return pe_txn, done_evt

    def _dispatch_blocking(
        self, env: simpy.Environment, cmd: PeCommand,
    ) -> Generator:
        """Send ``cmd`` to the scheduler port and wait for its completion."""
        pe_txn, done_evt = self._new_txn(env, cmd)
        yield self._out_ports[self._scheduler_id].put(pe_txn)
        yield done_evt

    def switch_to_simpy(self, cmd: PeCommand) -> Any:
        """Called from TLContext (child greenlet) to send command to SimPy.

        Returns the result from SimPy (e.g., numpy array for DMA read).
        """
        # run() records the parent greenlet before the kernel can call this.
        assert self._parent is not None
        return self._parent.switch(cmd)
|
||||||
@@ -52,6 +52,7 @@ class TLContext:
|
|||||||
pe_id: int = 0,
|
pe_id: int = 0,
|
||||||
num_programs: int = 1,
|
num_programs: int = 1,
|
||||||
dispatch_cycles: int = 1,
|
dispatch_cycles: int = 1,
|
||||||
|
runner: Any = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._pe_id = pe_id
|
self._pe_id = pe_id
|
||||||
self._num_programs = num_programs
|
self._num_programs = num_programs
|
||||||
@@ -59,6 +60,7 @@ class TLContext:
|
|||||||
self._commands: list[PeCommand] = []
|
self._commands: list[PeCommand] = []
|
||||||
self._handle_counter = 0
|
self._handle_counter = 0
|
||||||
self._completion_counter = 0
|
self._completion_counter = 0
|
||||||
|
self._runner = runner # KernelRunner for greenlet mode (ADR-0020 D3)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def commands(self) -> list[PeCommand]:
|
def commands(self) -> list[PeCommand]:
|
||||||
@@ -83,7 +85,7 @@ class TLContext:
|
|||||||
|
|
||||||
def _emit_dispatch_overhead(self) -> None:
|
def _emit_dispatch_overhead(self) -> None:
|
||||||
if self._dispatch_cycles > 0:
|
if self._dispatch_cycles > 0:
|
||||||
self._commands.append(PeCpuOverheadCmd(cycles=self._dispatch_cycles))
|
self._emit(PeCpuOverheadCmd(cycles=self._dispatch_cycles))
|
||||||
|
|
||||||
def _make_handle(
|
def _make_handle(
|
||||||
self, addr: int, shape: tuple[int, ...], dtype: str,
|
self, addr: int, shape: tuple[int, ...], dtype: str,
|
||||||
@@ -108,23 +110,38 @@ class TLContext:
|
|||||||
|
|
||||||
# ── Data Movement (blocking, DMA engine) ──────────────────────
|
# ── Data Movement (blocking, DMA engine) ──────────────────────
|
||||||
|
|
||||||
|
def _emit(self, cmd: PeCommand) -> Any:
    """Route a command to the runner if one is attached, else record it.

    With a runner, the command is handed to ``switch_to_simpy`` and its
    result is returned. Without one, the command is appended to the
    command list and None is returned.
    """
    runner = self._runner
    if runner is None:
        self._commands.append(cmd)
        return None
    return runner.switch_to_simpy(cmd)
|
||||||
|
|
||||||
def load(
|
def load(
|
||||||
self, ptr: int, shape: tuple[int, ...], dtype: str = "f16",
|
self, ptr: int, shape: tuple[int, ...], dtype: str = "f16",
|
||||||
) -> TensorHandle:
|
) -> TensorHandle:
|
||||||
"""Load tensor from HBM to TCM. Returns TensorHandle."""
|
"""Load tensor from HBM to TCM. Returns TensorHandle.
|
||||||
|
|
||||||
|
In greenlet mode: returns TensorHandle with actual numpy data.
|
||||||
|
In command-list mode: returns TensorHandle with data=None.
|
||||||
|
"""
|
||||||
self._emit_dispatch_overhead()
|
self._emit_dispatch_overhead()
|
||||||
handle = self._make_handle(addr=ptr, shape=shape, dtype=dtype)
|
handle = self._make_handle(addr=ptr, shape=shape, dtype=dtype)
|
||||||
self._commands.append(DmaReadCmd(
|
cmd = DmaReadCmd(handle=handle, src_addr=ptr, nbytes=handle.nbytes)
|
||||||
handle=handle, src_addr=ptr, nbytes=handle.nbytes,
|
data = self._emit(cmd)
|
||||||
))
|
if data is not None:
|
||||||
|
# Greenlet mode: attach real data to handle
|
||||||
|
return TensorHandle(
|
||||||
|
id=handle.id, addr=handle.addr, shape=handle.shape,
|
||||||
|
dtype=handle.dtype, nbytes=handle.nbytes, data=data,
|
||||||
|
)
|
||||||
return handle
|
return handle
|
||||||
|
|
||||||
def store(self, ptr: int, handle: TensorHandle) -> None:
|
def store(self, ptr: int, handle: TensorHandle) -> None:
|
||||||
"""Store tensor from TCM to HBM."""
|
"""Store tensor from TCM to HBM."""
|
||||||
self._emit_dispatch_overhead()
|
self._emit_dispatch_overhead()
|
||||||
self._commands.append(DmaWriteCmd(
|
cmd = DmaWriteCmd(handle=handle, dst_addr=ptr, nbytes=handle.nbytes)
|
||||||
handle=handle, dst_addr=ptr, nbytes=handle.nbytes,
|
self._emit(cmd)
|
||||||
))
|
|
||||||
|
|
||||||
# ── GEMM Engine (blocking) ────────────────────────────────────
|
# ── GEMM Engine (blocking) ────────────────────────────────────
|
||||||
|
|
||||||
@@ -143,7 +160,7 @@ class TLContext:
|
|||||||
out_dtype = a.dtype
|
out_dtype = a.dtype
|
||||||
out = self._make_handle(addr=0, shape=out_shape, dtype=out_dtype)
|
out = self._make_handle(addr=0, shape=out_shape, dtype=out_dtype)
|
||||||
self._emit_dispatch_overhead()
|
self._emit_dispatch_overhead()
|
||||||
self._commands.append(GemmCmd(a=a, b=b, out=out, m=m, k=k, n=n))
|
self._emit(GemmCmd(a=a, b=b, out=out, m=m, k=k, n=n))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
# ── MATH Engine: unary (blocking) ─────────────────────────────
|
# ── MATH Engine: unary (blocking) ─────────────────────────────
|
||||||
@@ -151,7 +168,7 @@ class TLContext:
|
|||||||
def _unary_math(self, op: str, x: TensorHandle) -> TensorHandle:
|
def _unary_math(self, op: str, x: TensorHandle) -> TensorHandle:
|
||||||
out = self._make_handle(addr=0, shape=x.shape, dtype=x.dtype)
|
out = self._make_handle(addr=0, shape=x.shape, dtype=x.dtype)
|
||||||
self._emit_dispatch_overhead()
|
self._emit_dispatch_overhead()
|
||||||
self._commands.append(MathCmd(op=op, inputs=(x,), out=out))
|
self._emit(MathCmd(op=op, inputs=(x,), out=out))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def exp(self, x: TensorHandle) -> TensorHandle:
|
def exp(self, x: TensorHandle) -> TensorHandle:
|
||||||
@@ -184,7 +201,7 @@ class TLContext:
|
|||||||
out_shape[axis] = 1
|
out_shape[axis] = 1
|
||||||
out = self._make_handle(addr=0, shape=tuple(out_shape), dtype=x.dtype)
|
out = self._make_handle(addr=0, shape=tuple(out_shape), dtype=x.dtype)
|
||||||
self._emit_dispatch_overhead()
|
self._emit_dispatch_overhead()
|
||||||
self._commands.append(MathCmd(op=op, inputs=(x,), out=out, axis=axis))
|
self._emit(MathCmd(op=op, inputs=(x,), out=out, axis=axis))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def sum(self, x: TensorHandle, axis: int) -> TensorHandle:
|
def sum(self, x: TensorHandle, axis: int) -> TensorHandle:
|
||||||
@@ -203,7 +220,7 @@ class TLContext:
|
|||||||
) -> TensorHandle:
|
) -> TensorHandle:
|
||||||
out = self._make_handle(addr=0, shape=a.shape, dtype=a.dtype)
|
out = self._make_handle(addr=0, shape=a.shape, dtype=a.dtype)
|
||||||
self._emit_dispatch_overhead()
|
self._emit_dispatch_overhead()
|
||||||
self._commands.append(MathCmd(op=op, inputs=(a, b), out=out))
|
self._emit(MathCmd(op=op, inputs=(a, b), out=out))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def where(
|
def where(
|
||||||
@@ -211,7 +228,7 @@ class TLContext:
|
|||||||
) -> TensorHandle:
|
) -> TensorHandle:
|
||||||
out = self._make_handle(addr=0, shape=a.shape, dtype=a.dtype)
|
out = self._make_handle(addr=0, shape=a.shape, dtype=a.dtype)
|
||||||
self._emit_dispatch_overhead()
|
self._emit_dispatch_overhead()
|
||||||
self._commands.append(MathCmd(op="where", inputs=(cond, a, b), out=out))
|
self._emit(MathCmd(op="where", inputs=(cond, a, b), out=out))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
# ── Index / Scalar (PE_CPU, no engine) ────────────────────────
|
# ── Index / Scalar (PE_CPU, no engine) ────────────────────────
|
||||||
@@ -276,7 +293,7 @@ class TLContext:
|
|||||||
|
|
||||||
completion = CompletionHandle(id=self._next_completion_id())
|
completion = CompletionHandle(id=self._next_completion_id())
|
||||||
self._emit_dispatch_overhead()
|
self._emit_dispatch_overhead()
|
||||||
self._commands.append(CompositeCmd(
|
self._emit(CompositeCmd(
|
||||||
completion=completion, op=op,
|
completion=completion, op=op,
|
||||||
a=a, b=b, out_addr=out_ptr, out_nbytes=out_nbytes,
|
a=a, b=b, out_addr=out_ptr, out_nbytes=out_nbytes,
|
||||||
math_op=math_op,
|
math_op=math_op,
|
||||||
@@ -285,11 +302,11 @@ class TLContext:
|
|||||||
|
|
||||||
def wait(self, handle: CompletionHandle | None = None) -> None:
|
def wait(self, handle: CompletionHandle | None = None) -> None:
|
||||||
"""Wait for a specific composite or all pending composites."""
|
"""Wait for a specific composite or all pending composites."""
|
||||||
self._commands.append(WaitCmd(handle=handle))
|
self._emit(WaitCmd(handle=handle))
|
||||||
|
|
||||||
def cycles(self, n: int) -> None:
|
def cycles(self, n: int) -> None:
|
||||||
"""Declare PE_CPU scalar execution overhead (cycles)."""
|
"""Declare PE_CPU scalar execution overhead (cycles)."""
|
||||||
self._commands.append(PeCpuOverheadCmd(cycles=n))
|
self._emit(PeCpuOverheadCmd(cycles=n))
|
||||||
|
|
||||||
|
|
||||||
# ── TensorHandle arithmetic operators ─────────────────────────────
|
# ── TensorHandle arithmetic operators ─────────────────────────────
|
||||||
|
|||||||
@@ -0,0 +1,188 @@
|
|||||||
|
"""Tests for DataExecutor Phase 2 execution (ADR-0020 D6)."""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from kernbench.sim_engine.data_executor import DataExecutor
|
||||||
|
from kernbench.sim_engine.memory_store import MemoryStore
|
||||||
|
from kernbench.sim_engine.op_log import OpRecord
|
||||||
|
|
||||||
|
|
||||||
|
def test_gemm_execution():
    """Phase 2 GEMM: out = a @ b with f32 accumulation."""
    lhs = np.ones((4, 8), dtype=np.float16)
    rhs = np.ones((8, 4), dtype=np.float16) * 2.0

    mem = MemoryStore()
    mem.write("tcm", 0x0, lhs)
    mem.write("tcm", 0x100, rhs)

    gemm_params = {
        "src_a_addr": 0x0, "src_b_addr": 0x100, "dst_addr": 0x200,
        "shape_a": (4, 8), "shape_b": (8, 4), "shape_out": (4, 4),
        "dtype_in": "f16", "dtype_acc": "f32", "dtype_out": "f16",
        "addr_space": "tcm",
    }
    record = OpRecord(
        t_start=0.0, t_end=100.0,
        component_id="pe_gemm",
        op_kind="gemm", op_name="gemm_f16",
        params=gemm_params,
    )

    DataExecutor([record], mem).run()

    # Reference: accumulate in f32, then narrow to the f16 output dtype.
    want = (lhs.astype(np.float32) @ rhs.astype(np.float32)).astype(np.float16)
    got = mem.read("tcm", 0x200)
    assert np.allclose(got, want)
|
||||||
|
|
||||||
|
|
||||||
|
def test_math_exp():
    """A unary 'exp' math op writes exp(x) to its destination address."""
    inputs = np.array([0.0, 1.0, 2.0], dtype=np.float32)
    mem = MemoryStore()
    mem.write("tcm", 0x0, inputs)

    exp_params = {
        "op": "exp",
        "input_addrs": [0x0], "input_shapes": [(3,)],
        "dst_addr": 0x100, "shape_out": (3,),
        "dtype": "f32", "axis": None, "addr_space": "tcm",
    }
    record = OpRecord(
        t_start=0.0, t_end=10.0,
        component_id="pe_math",
        op_kind="math", op_name="exp",
        params=exp_params,
    )

    DataExecutor([record], mem).run()

    assert np.allclose(mem.read("tcm", 0x100), np.exp(inputs))
|
||||||
|
|
||||||
|
|
||||||
|
def test_math_add():
    """A binary 'add' math op writes the element-wise sum of its inputs."""
    lhs = np.array([1.0, 2.0], dtype=np.float32)
    rhs = np.array([3.0, 4.0], dtype=np.float32)
    mem = MemoryStore()
    mem.write("tcm", 0x0, lhs)
    mem.write("tcm", 0x100, rhs)

    add_params = {
        "op": "add",
        "input_addrs": [0x0, 0x100], "input_shapes": [(2,), (2,)],
        "dst_addr": 0x200, "shape_out": (2,),
        "dtype": "f32", "axis": None, "addr_space": "tcm",
    }
    record = OpRecord(
        t_start=0.0, t_end=5.0,
        component_id="pe_math",
        op_kind="math", op_name="add",
        params=add_params,
    )

    DataExecutor([record], mem).run()

    want = np.array([4.0, 6.0], dtype=np.float32)
    assert np.array_equal(mem.read("tcm", 0x200), want)
|
||||||
|
|
||||||
|
|
||||||
|
def test_math_sum_reduction():
    """A 'sum' over axis 0 of a 2x2 input yields a (1, 2) result."""
    matrix = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
    mem = MemoryStore()
    mem.write("tcm", 0x0, matrix)

    sum_params = {
        "op": "sum",
        "input_addrs": [0x0], "input_shapes": [(2, 2)],
        "dst_addr": 0x100, "shape_out": (1, 2),
        "dtype": "f32", "axis": 0, "addr_space": "tcm",
    }
    record = OpRecord(
        t_start=0.0, t_end=5.0,
        component_id="pe_math",
        op_kind="math", op_name="sum",
        params=sum_params,
    )

    DataExecutor([record], mem).run()

    want = np.array([[4.0, 6.0]], dtype=np.float32)
    assert np.array_equal(mem.read("tcm", 0x100), want)
|
||||||
|
|
||||||
|
|
||||||
|
def test_verify_pass():
    """verify() reports True when stored data matches the expectation."""
    mem = MemoryStore()
    mem.write("hbm", 0x0, np.array([1.0, 2.0], dtype=np.float32))

    expected = {("hbm", 0x0): np.array([1.0, 2.0], dtype=np.float32)}
    report = DataExecutor([], mem).verify(expected)

    assert report["hbm:0x0"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_verify_fail():
    """verify() reports False when stored data differs from the expectation."""
    mem = MemoryStore()
    mem.write("hbm", 0x0, np.array([1.0, 2.0], dtype=np.float32))

    expected = {("hbm", 0x0): np.array([9.0, 9.0], dtype=np.float32)}
    report = DataExecutor([], mem).verify(expected)

    assert report["hbm:0x0"] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_memory_ops_skipped():
    """Memory ops in op_log should be skipped (handled in Phase 1)."""
    mem = MemoryStore()
    dma_record = OpRecord(
        t_start=0.0, t_end=5.0,
        component_id="pe_dma",
        op_kind="memory", op_name="dma_read",
        params={"src_addr": 0x0, "nbytes": 64, "handle_id": "t0"},
    )

    # The test passes as long as running over a memory-only log does not raise.
    DataExecutor([dma_record], mem).run()
|
||||||
|
|
||||||
|
|
||||||
|
def test_sequential_gemm_then_math():
    """GEMM output feeds into math op."""
    lhs = np.eye(2, dtype=np.float16)
    rhs = np.ones((2, 2), dtype=np.float16)
    mem = MemoryStore()
    mem.write("tcm", 0x0, lhs)
    mem.write("tcm", 0x100, rhs)

    gemm_record = OpRecord(
        t_start=0.0, t_end=50.0,
        component_id="pe_gemm",
        op_kind="gemm", op_name="gemm_f16",
        params={
            "src_a_addr": 0x0, "src_b_addr": 0x100, "dst_addr": 0x200,
            "shape_a": (2, 2), "shape_b": (2, 2), "shape_out": (2, 2),
            "dtype_in": "f16", "dtype_acc": "f32", "dtype_out": "f32",
            "addr_space": "tcm",
        },
    )
    # The exp op reads the address the GEMM writes to (0x200).
    exp_record = OpRecord(
        t_start=50.0, t_end=55.0,
        component_id="pe_math",
        op_kind="math", op_name="exp",
        params={
            "op": "exp",
            "input_addrs": [0x200], "input_shapes": [(2, 2)],
            "dst_addr": 0x300, "shape_out": (2, 2),
            "dtype": "f32", "axis": None, "addr_space": "tcm",
        },
    )

    DataExecutor([gemm_record, exp_record], mem).run()

    want_gemm = (lhs.astype(np.float32) @ rhs.astype(np.float32)).astype(np.float32)
    assert np.allclose(mem.read("tcm", 0x200), want_gemm)
    assert np.allclose(mem.read("tcm", 0x300), np.exp(want_gemm))
|
||||||
@@ -0,0 +1,140 @@
|
|||||||
|
"""Tests for KernelRunner greenlet-based execution (ADR-0020 D3)."""
|
||||||
|
import numpy as np
|
||||||
|
import simpy
|
||||||
|
|
||||||
|
from kernbench.sim_engine.memory_store import MemoryStore
|
||||||
|
from kernbench.triton_emu.kernel_runner import KernelRunner
|
||||||
|
|
||||||
|
|
||||||
|
def _make_runner(env, store=None):
    """Build a KernelRunner wired to a single stand-in scheduler port.

    Returns the runner together with the SimPy store acting as the
    scheduler's inbox.
    """
    sched_name = "sip0.cube0.pe0.pe_scheduler"
    sched_inbox = simpy.Store(env)
    runner = KernelRunner(
        pe_prefix="sip0.cube0.pe0",
        pe_idx=0,
        sip_idx=0,
        cube_idx=0,
        scheduler_id=sched_name,
        out_ports={sched_name: sched_inbox},
        store=store,
    )
    return runner, sched_inbox
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_scheduler(env, inbox):
    """Endlessly take transactions from ``inbox`` and complete each at once."""
    while True:
        txn = yield inbox.get()
        completion = txn.done
        completion.succeed()
|
||||||
|
|
||||||
|
|
||||||
|
def test_kernel_runner_basic_load():
    """Kernel with tl.load runs through greenlet without hanging."""
    env = simpy.Environment()
    store = MemoryStore()
    data = np.ones((4, 4), dtype=np.float16)
    store.write("hbm", 0x1000, data)

    runner, sched_port = _make_runner(env, store)
    env.process(_mock_scheduler(env, sched_port))

    # Set by the kernel after its in-kernel assertions; without this the test
    # would pass vacuously if the kernel were never executed.
    progress = {"finished": False}

    def kernel(a_ptr, tl):
        a = tl.load(a_ptr, (4, 4), "f16")
        assert a.data is not None
        assert a.data.shape == (4, 4)
        progress["finished"] = True

    def run():
        yield from runner.run(env, kernel, [0x1000], num_programs=1)

    env.process(run())
    env.run()
    assert progress["finished"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_kernel_runner_load_returns_data():
    """tl.load returns actual numpy data from MemoryStore."""
    env = simpy.Environment()
    mem = MemoryStore()
    payload = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float16)
    mem.write("hbm", 0x2000, payload)

    runner, sched_inbox = _make_runner(env, mem)
    env.process(_mock_scheduler(env, sched_inbox))

    seen = {}

    def kernel(ptr, tl):
        loaded = tl.load(ptr, (2, 2), "f16")
        seen["data"] = loaded.data

    def driver():
        yield from runner.run(env, kernel, [0x2000], num_programs=1)

    env.process(driver())
    env.run()

    # The kernel must receive the very array that was stored, not a copy.
    assert seen["data"] is payload  # reference equality
|
||||||
|
|
||||||
|
|
||||||
|
def test_kernel_runner_composite():
    """Composite commands pass through without blocking kernel."""
    env = simpy.Environment()
    runner, sched_port = _make_runner(env)
    env.process(_mock_scheduler(env, sched_port))

    # The original test had no assertions; record that the kernel actually
    # got past tl.wait() so a hang or an un-run kernel fails the test.
    progress = {"finished": False}

    def kernel(a_ptr, b_ptr, out_ptr, tl):
        a = tl.ref(a_ptr, (4, 8), "f16")
        b = tl.ref(b_ptr, (8, 4), "f16")
        h = tl.composite(op="gemm", a=a, b=b, out_ptr=out_ptr)
        tl.wait(h)
        progress["finished"] = True

    def run():
        yield from runner.run(env, kernel, [0, 64, 128], num_programs=1)

    env.process(run())
    env.run()
    assert progress["finished"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_kernel_runner_dynamic_branch():
    """Kernel can branch based on loaded data (ADR-0020 D3)."""
    store = MemoryStore()
    store.write("hbm", 0x100, np.array([1.0], dtype=np.float32))
    store.write("hbm", 0x200, np.array([0.0], dtype=np.float32))

    def branch_for(flag_addr):
        """Run the kernel once in a fresh simulation; return the branch taken."""
        env = simpy.Environment()
        runner, sched_port = _make_runner(env, store)
        env.process(_mock_scheduler(env, sched_port))

        results = {"branch": None}

        def kernel(flag_ptr, tl):
            flag = tl.load(flag_ptr, (1,), "f32")
            if flag.data is not None and flag.data[0] > 0.5:
                results["branch"] = "taken"
            else:
                results["branch"] = "not_taken"

        def run():
            yield from runner.run(env, kernel, [flag_addr], num_programs=1)

        env.process(run())
        env.run()
        return results["branch"]

    # flag=1.0 → branch taken
    assert branch_for(0x100) == "taken"
    # flag=0.0 → branch not taken (this fixture was written but never used)
    assert branch_for(0x200) == "not_taken"
|
||||||
|
|
||||||
|
|
||||||
|
def test_kernel_runner_no_store():
    """Without MemoryStore, tl.load returns handle with data=None."""
    env = simpy.Environment()
    runner, sched_inbox = _make_runner(env, store=None)
    env.process(_mock_scheduler(env, sched_inbox))

    seen = {}

    def kernel(ptr, tl):
        loaded = tl.load(ptr, (4,), "f16")
        seen["data"] = loaded.data

    def driver():
        yield from runner.run(env, kernel, [0], num_programs=1)

    env.process(driver())
    env.run()
    assert seen["data"] is None
|
||||||
@@ -0,0 +1,85 @@
|
|||||||
|
"""Tests for MemoryStore (ADR-0020 D7)."""
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from kernbench.sim_engine.memory_store import MemoryStore
|
||||||
|
|
||||||
|
|
||||||
|
def test_write_read_reference():
    """Write and read return the same numpy array (no copy)."""
    tensor = np.ones((4, 4), dtype=np.float16)
    mem = MemoryStore()
    mem.write("tcm", 0x0, tensor)
    assert mem.read("tcm", 0x0) is tensor
|
||||||
|
|
||||||
|
|
||||||
|
def test_overwrite_replaces():
    """Same addr write replaces the previous tensor."""
    first = np.zeros((4,), dtype=np.float32)
    second = np.ones((4,), dtype=np.float32)
    mem = MemoryStore()
    for tensor in (first, second):
        mem.write("hbm", 0x100, tensor)
    assert mem.read("hbm", 0x100) is second
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_missing_raises():
    """Reading an address that was never written raises KeyError."""
    empty = MemoryStore()
    with pytest.raises(KeyError):
        empty.read("hbm", 0x999)
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_different_space():
    """The same address in another address space is a separate, empty slot."""
    mem = MemoryStore()
    mem.write("tcm", 0x0, np.array([1, 2, 3], dtype=np.int32))
    with pytest.raises(KeyError):
        mem.read("hbm", 0x0)  # different space
|
||||||
|
|
||||||
|
|
||||||
|
def test_dtype_reinterpret():
    """Read with different dtype does view cast."""
    floats = np.array([1.0, 2.0], dtype=np.float32)  # 8 bytes
    mem = MemoryStore()
    mem.write("tcm", 0x0, floats)

    as_bytes = mem.read("tcm", 0x0, dtype="u8")

    # Same byte count, different element type.
    assert as_bytes.dtype == np.uint8
    assert as_bytes.nbytes == floats.nbytes
|
||||||
|
|
||||||
|
|
||||||
|
def test_reshape():
    """Reading with a compatible shape returns the data in that shape."""
    mem = MemoryStore()
    mem.write("tcm", 0x0, np.arange(12, dtype=np.float32))
    reshaped = mem.read("tcm", 0x0, shape=(3, 4))
    assert reshaped.shape == (3, 4)
|
||||||
|
|
||||||
|
|
||||||
|
def test_shape_mismatch_raises():
    """Requesting a shape with a different element count raises ValueError."""
    mem = MemoryStore()
    mem.write("tcm", 0x0, np.arange(12, dtype=np.float32))
    # 12 stored elements cannot be viewed as 5x5.
    with pytest.raises(ValueError, match="Shape mismatch"):
        mem.read("tcm", 0x0, shape=(5, 5))
|
||||||
|
|
||||||
|
|
||||||
|
def test_has():
    """has() is False before a write to the address and True afterwards."""
    mem = MemoryStore()
    present_before = mem.has("tcm", 0x0)
    assert not present_before
    mem.write("tcm", 0x0, np.array([1]))
    present_after = mem.has("tcm", 0x0)
    assert present_after
|
||||||
|
|
||||||
|
|
||||||
|
def test_snapshot():
    """A snapshot shares tensor references but not later writes."""
    original = np.ones((4,), dtype=np.float16)
    mem = MemoryStore()
    mem.write("hbm", 0x0, original)

    frozen = mem.snapshot()
    assert frozen.read("hbm", 0x0) is original  # same reference

    # Modifying snap doesn't affect original
    replacement = np.zeros((4,), dtype=np.float16)
    frozen.write("hbm", 0x0, replacement)
    assert mem.read("hbm", 0x0) is original
|
||||||
@@ -0,0 +1,87 @@
|
|||||||
|
"""Tests for OpLogger and OpRecord (ADR-0020 D2/D5)."""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from kernbench.common.pe_commands import (
|
||||||
|
DmaReadCmd, DmaWriteCmd, GemmCmd, MathCmd, TensorHandle,
|
||||||
|
)
|
||||||
|
from kernbench.sim_engine.op_log import OpLogger, OpRecord
|
||||||
|
|
||||||
|
|
||||||
|
def _th(name="t0", addr=0, shape=(4, 4), dtype="f16"):
    """Build a TensorHandle for tests; nbytes is fixed at 32 regardless of shape."""
    handle = TensorHandle(
        id=name,
        addr=addr,
        shape=shape,
        dtype=dtype,
        nbytes=32,
    )
    return handle
|
||||||
|
|
||||||
|
|
||||||
|
def test_op_logger_record_start_end():
    """A matched start/end pair for a DMA read yields one 'memory' record."""
    component = "sip0.cube0.pe0.pe_dma"
    read_cmd = DmaReadCmd(handle=_th(), src_addr=0x1000, nbytes=64)

    op_logger = OpLogger()
    op_logger.record_start(10.0, component, read_cmd)
    op_logger.record_end(15.0, component, read_cmd)

    assert len(op_logger.records) == 1
    record = op_logger.records[0]
    assert record.t_start == 10.0
    assert record.t_end == 15.0
    assert record.op_kind == "memory"
    assert record.op_name == "dma_read"
    assert record.params["src_addr"] == 0x1000
|
||||||
|
|
||||||
|
|
||||||
|
def test_op_logger_gemm():
    """A GemmCmd is logged as kind 'gemm', named by dtype, with m in params."""
    lhs = _th("a", 0, (128, 256), "f16")
    rhs = _th("b", 1024, (256, 128), "f16")
    result = _th("out", 2048, (128, 128), "f16")
    gemm_cmd = GemmCmd(a=lhs, b=rhs, out=result, m=128, k=256, n=128)

    op_logger = OpLogger()
    op_logger.record_start(0.0, "pe_gemm", gemm_cmd)
    op_logger.record_end(100.0, "pe_gemm", gemm_cmd)

    record = op_logger.records[0]
    assert record.op_kind == "gemm"
    assert record.op_name == "gemm_f16"
    assert record.params["m"] == 128
|
||||||
|
|
||||||
|
|
||||||
|
def test_op_logger_math():
    """A MathCmd is logged as kind 'math' with the op as its name."""
    source = _th("x", 0, (32,), "f32")
    target = _th("out", 128, (32,), "f32")
    exp_cmd = MathCmd(op="exp", inputs=(source,), out=target)

    op_logger = OpLogger()
    op_logger.record_start(5.0, "pe_math", exp_cmd)
    op_logger.record_end(6.0, "pe_math", exp_cmd)

    record = op_logger.records[0]
    assert record.op_kind == "math"
    assert record.op_name == "exp"
|
||||||
|
|
||||||
|
|
||||||
|
def test_op_logger_stable_ordering():
    """Records are ordered by t_start; ties keep insertion order."""
    logger = OpLogger()
    cmds = [
        DmaReadCmd(handle=_th(f"t{i}"), src_addr=i * 100, nbytes=64)
        for i in range(5)
    ]
    # t_start sequence is 0, 1, 2, 0, 1 — comp0/comp3 and comp1/comp4 tie.
    for i, cmd in enumerate(cmds):
        logger.record_start(float(i % 3), f"comp{i}", cmd)  # some share t_start
    for i, cmd in enumerate(cmds):
        logger.record_end(float(i % 3) + 1.0, f"comp{i}", cmd)

    records = logger.records
    assert len(records) == len(cmds)

    # Records must be non-decreasing in t_start.
    for earlier, later in zip(records, records[1:]):
        assert earlier.t_start <= later.t_start

    # Verify insertion order preserved for same t_start (the original loop
    # only checked monotonic t_start, not tie order).
    # NOTE(review): assumes OpRecord.component_id carries the id passed to
    # record_start — confirm against OpLogger.
    assert [r.component_id for r in records] == [
        "comp0", "comp3", "comp1", "comp4", "comp2",
    ]
|
||||||
|
|
||||||
|
|
||||||
|
def test_op_logger_unmatched_end_ignored():
    """A record_end with no preceding record_start produces no record."""
    op_logger = OpLogger()
    orphan = DmaReadCmd(handle=_th(), src_addr=0, nbytes=32)
    op_logger.record_end(5.0, "comp", orphan)  # no matching start
    assert len(op_logger.records) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_data_op_flag():
    """DmaReadCmd, DmaWriteCmd, GemmCmd, MathCmd have data_op=True; others don't."""
    from kernbench.common.pe_commands import WaitCmd, PeCpuOverheadCmd

    read_cmd = DmaReadCmd(handle=_th(), src_addr=0, nbytes=32)
    write_cmd = DmaWriteCmd(handle=_th(), dst_addr=0, nbytes=32)
    a, b, out = _th("a"), _th("b"), _th("out")
    gemm_cmd = GemmCmd(a=a, b=b, out=out, m=4, k=4, n=4)
    math_cmd = MathCmd(op="exp", inputs=(a,), out=out)

    for flagged in (read_cmd, write_cmd, gemm_cmd, math_cmd):
        assert getattr(flagged, "data_op", False)

    # WaitCmd and PeCpuOverheadCmd should not have data_op
    for unflagged in (WaitCmd(), PeCpuOverheadCmd(cycles=10)):
        assert not getattr(unflagged, "data_op", False)
|
||||||
Reference in New Issue
Block a user