Implement ADR-0020: 2-pass data execution with greenlet kernel runner
Step 1 — Foundation: - OpRecord/OpLogger: op log infrastructure with t_start stable ordering - MemoryStore: numpy ndarray tensor-granular storage (reference semantics) - data_op=True flag on DmaReadCmd, DmaWriteCmd, GemmCmd, MathCmd, CompositeCmd - numpy/greenlet dependencies added to pyproject.toml Step 2 — ComponentBase hooks: - _on_process_start/end hooks in _forward_txn (fabric messages) - _handle_with_hooks in PeEngineBase (PE-internal commands) - op_logger optional — zero overhead when disabled Step 3 — KernelRunner + greenlet: - KernelRunner: greenlet ↔ SimPy bridge in triton_emu/kernel_runner.py - TLContext: _emit() method routes to greenlet switch or command list - tl.load() returns real numpy data in greenlet mode - Dynamic control flow supported (memory-read based branching) Step 4 — PE_CPU integration: - Greenlet mode when ctx.memory_store is set, legacy fallback otherwise - Refactored into _execute_greenlet/_execute_legacy/_send_response - ComponentContext gains memory_store and op_logger fields Step 5 — DataExecutor: - Phase 2 numpy execution for GEMM/Math ops from op_log - _compute_math: all unary/binary/reduction ops - verify(): compare MemoryStore against expected with dtype tolerance 28 new tests, 366 total passing. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,199 @@
|
||||
"""KernelRunner: greenlet-based kernel ↔ SimPy bridge (ADR-0020 D3).
|
||||
|
||||
Replaces Phase 0 (static command list) with interleaved execution:
|
||||
- tl.load() → SimPy DMA timing + MemoryStore read → real data to kernel
|
||||
- tl.store() → MemoryStore write + SimPy DMA timing
|
||||
- tl.composite(gemm) → SimPy timing + op_log (actual compute in Phase 2)
|
||||
|
||||
The kernel runs as a child greenlet; SimPy loop is the parent.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
from greenlet import greenlet
|
||||
|
||||
from kernbench.common.pe_commands import (
|
||||
CompletionHandle,
|
||||
CompositeCmd,
|
||||
DmaReadCmd,
|
||||
DmaWriteCmd,
|
||||
GemmCmd,
|
||||
MathCmd,
|
||||
PeCommand,
|
||||
PeCpuOverheadCmd,
|
||||
PeInternalTxn,
|
||||
TensorHandle,
|
||||
WaitCmd,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.sim_engine.memory_store import MemoryStore
|
||||
|
||||
|
||||
class KernelRunner:
|
||||
"""Greenlet ↔ SimPy bridge for kernel execution (ADR-0020 D3).
|
||||
|
||||
PE_CPU creates a KernelRunner and yields from its run() method.
|
||||
The kernel function executes as a plain Python function inside a
|
||||
child greenlet, using TLContext methods that switch back to SimPy.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pe_prefix: str,
|
||||
pe_idx: int,
|
||||
sip_idx: int,
|
||||
cube_idx: int,
|
||||
scheduler_id: str,
|
||||
out_ports: dict[str, simpy.Store],
|
||||
store: MemoryStore | None = None,
|
||||
) -> None:
|
||||
self._pe_prefix = pe_prefix
|
||||
self._pe_idx = pe_idx
|
||||
self._sip_idx = sip_idx
|
||||
self._cube_idx = cube_idx
|
||||
self._scheduler_id = scheduler_id
|
||||
self._out_ports = out_ports
|
||||
self._store = store
|
||||
self._parent: greenlet | None = None
|
||||
|
||||
def run(
|
||||
self,
|
||||
env: simpy.Environment,
|
||||
kernel_fn: Any,
|
||||
kernel_args: list,
|
||||
num_programs: int,
|
||||
) -> Generator:
|
||||
"""SimPy generator: run kernel with greenlet interleaving.
|
||||
|
||||
This is the SimPy-side loop. It:
|
||||
1. Creates a TLContext connected to this runner
|
||||
2. Spawns the kernel in a child greenlet
|
||||
3. Receives commands via greenlet.switch
|
||||
4. Dispatches each command through SimPy components
|
||||
5. Returns results to the kernel
|
||||
"""
|
||||
from kernbench.triton_emu.tl_context import TLContext
|
||||
|
||||
self._parent = greenlet.getcurrent()
|
||||
|
||||
tl = TLContext(
|
||||
pe_id=self._pe_idx,
|
||||
num_programs=num_programs,
|
||||
dispatch_cycles=0,
|
||||
runner=self,
|
||||
)
|
||||
|
||||
def _kernel_entry():
|
||||
TLContext._set_active(tl) # type: ignore[attr-defined]
|
||||
try:
|
||||
kernel_fn(*kernel_args, tl=tl)
|
||||
finally:
|
||||
TLContext._set_active(None) # type: ignore[attr-defined]
|
||||
return None # signal kernel completion
|
||||
|
||||
g = greenlet(_kernel_entry)
|
||||
pending: dict[str, simpy.Event] = {}
|
||||
composite_results: list[dict] = []
|
||||
|
||||
# Start kernel — first switch returns first command (or None if kernel is done)
|
||||
cmd = g.switch()
|
||||
|
||||
while cmd is not None:
|
||||
if isinstance(cmd, PeCpuOverheadCmd):
|
||||
yield env.timeout(cmd.cycles)
|
||||
cmd = g.switch()
|
||||
|
||||
elif isinstance(cmd, WaitCmd):
|
||||
if cmd.handle is not None:
|
||||
evt = pending.pop(cmd.handle.id, None)
|
||||
if evt:
|
||||
yield evt
|
||||
else:
|
||||
for evt in pending.values():
|
||||
yield evt
|
||||
pending.clear()
|
||||
cmd = g.switch()
|
||||
|
||||
elif isinstance(cmd, DmaReadCmd):
|
||||
# Dispatch DMA through SimPy components
|
||||
done_evt = env.event()
|
||||
pe_txn = PeInternalTxn(
|
||||
command=cmd, done=done_evt, pe_prefix=self._pe_prefix,
|
||||
)
|
||||
yield self._out_ports[self._scheduler_id].put(pe_txn)
|
||||
yield done_evt
|
||||
|
||||
# Read actual data from MemoryStore (if available)
|
||||
data = None
|
||||
if self._store is not None:
|
||||
try:
|
||||
data = self._store.read(
|
||||
"hbm", cmd.src_addr,
|
||||
shape=cmd.handle.shape, dtype=cmd.handle.dtype,
|
||||
)
|
||||
except KeyError:
|
||||
pass
|
||||
cmd = g.switch(data)
|
||||
|
||||
elif isinstance(cmd, DmaWriteCmd):
|
||||
# Write to MemoryStore first (visibility = issue, ADR-0020 D3)
|
||||
if self._store is not None and cmd.handle.data is not None:
|
||||
self._store.write("hbm", cmd.dst_addr, cmd.handle.data)
|
||||
|
||||
done_evt = env.event()
|
||||
pe_txn = PeInternalTxn(
|
||||
command=cmd, done=done_evt, pe_prefix=self._pe_prefix,
|
||||
)
|
||||
yield self._out_ports[self._scheduler_id].put(pe_txn)
|
||||
yield done_evt
|
||||
cmd = g.switch()
|
||||
|
||||
elif isinstance(cmd, CompositeCmd):
|
||||
# Non-blocking composite
|
||||
done_evt = env.event()
|
||||
pe_txn = PeInternalTxn(
|
||||
command=cmd, done=done_evt, pe_prefix=self._pe_prefix,
|
||||
)
|
||||
composite_results.append(pe_txn.result_data)
|
||||
yield self._out_ports[self._scheduler_id].put(pe_txn)
|
||||
pending[cmd.completion.id] = done_evt
|
||||
cmd = g.switch()
|
||||
|
||||
elif isinstance(cmd, (GemmCmd, MathCmd)):
|
||||
# Blocking compute command
|
||||
done_evt = env.event()
|
||||
pe_txn = PeInternalTxn(
|
||||
command=cmd, done=done_evt, pe_prefix=self._pe_prefix,
|
||||
)
|
||||
yield self._out_ports[self._scheduler_id].put(pe_txn)
|
||||
yield done_evt
|
||||
cmd = g.switch()
|
||||
|
||||
else:
|
||||
# Unknown command — pass through as blocking
|
||||
done_evt = env.event()
|
||||
pe_txn = PeInternalTxn(
|
||||
command=cmd, done=done_evt, pe_prefix=self._pe_prefix,
|
||||
)
|
||||
yield self._out_ports[self._scheduler_id].put(pe_txn)
|
||||
yield done_evt
|
||||
cmd = g.switch()
|
||||
|
||||
# Wait remaining pending composites
|
||||
for evt in pending.values():
|
||||
yield evt
|
||||
|
||||
# Return composite results for PE_CPU aggregation
|
||||
self._composite_results = composite_results
|
||||
|
||||
def switch_to_simpy(self, cmd: PeCommand) -> Any:
|
||||
"""Called from TLContext (child greenlet) to send command to SimPy.
|
||||
|
||||
Returns the result from SimPy (e.g., numpy array for DMA read).
|
||||
"""
|
||||
assert self._parent is not None
|
||||
return self._parent.switch(cmd)
|
||||
@@ -52,6 +52,7 @@ class TLContext:
|
||||
pe_id: int = 0,
|
||||
num_programs: int = 1,
|
||||
dispatch_cycles: int = 1,
|
||||
runner: Any = None,
|
||||
) -> None:
|
||||
self._pe_id = pe_id
|
||||
self._num_programs = num_programs
|
||||
@@ -59,6 +60,7 @@ class TLContext:
|
||||
self._commands: list[PeCommand] = []
|
||||
self._handle_counter = 0
|
||||
self._completion_counter = 0
|
||||
self._runner = runner # KernelRunner for greenlet mode (ADR-0020 D3)
|
||||
|
||||
@property
|
||||
def commands(self) -> list[PeCommand]:
|
||||
@@ -83,7 +85,7 @@ class TLContext:
|
||||
|
||||
def _emit_dispatch_overhead(self) -> None:
|
||||
if self._dispatch_cycles > 0:
|
||||
self._commands.append(PeCpuOverheadCmd(cycles=self._dispatch_cycles))
|
||||
self._emit(PeCpuOverheadCmd(cycles=self._dispatch_cycles))
|
||||
|
||||
def _make_handle(
|
||||
self, addr: int, shape: tuple[int, ...], dtype: str,
|
||||
@@ -108,23 +110,38 @@ class TLContext:
|
||||
|
||||
# ── Data Movement (blocking, DMA engine) ──────────────────────
|
||||
|
||||
def _emit(self, cmd: PeCommand) -> Any:
|
||||
"""Emit command: greenlet switch if runner available, else append to list."""
|
||||
if self._runner is not None:
|
||||
return self._runner.switch_to_simpy(cmd)
|
||||
self._commands.append(cmd)
|
||||
return None
|
||||
|
||||
def load(
|
||||
self, ptr: int, shape: tuple[int, ...], dtype: str = "f16",
|
||||
) -> TensorHandle:
|
||||
"""Load tensor from HBM to TCM. Returns TensorHandle."""
|
||||
"""Load tensor from HBM to TCM. Returns TensorHandle.
|
||||
|
||||
In greenlet mode: returns TensorHandle with actual numpy data.
|
||||
In command-list mode: returns TensorHandle with data=None.
|
||||
"""
|
||||
self._emit_dispatch_overhead()
|
||||
handle = self._make_handle(addr=ptr, shape=shape, dtype=dtype)
|
||||
self._commands.append(DmaReadCmd(
|
||||
handle=handle, src_addr=ptr, nbytes=handle.nbytes,
|
||||
))
|
||||
cmd = DmaReadCmd(handle=handle, src_addr=ptr, nbytes=handle.nbytes)
|
||||
data = self._emit(cmd)
|
||||
if data is not None:
|
||||
# Greenlet mode: attach real data to handle
|
||||
return TensorHandle(
|
||||
id=handle.id, addr=handle.addr, shape=handle.shape,
|
||||
dtype=handle.dtype, nbytes=handle.nbytes, data=data,
|
||||
)
|
||||
return handle
|
||||
|
||||
def store(self, ptr: int, handle: TensorHandle) -> None:
|
||||
"""Store tensor from TCM to HBM."""
|
||||
self._emit_dispatch_overhead()
|
||||
self._commands.append(DmaWriteCmd(
|
||||
handle=handle, dst_addr=ptr, nbytes=handle.nbytes,
|
||||
))
|
||||
cmd = DmaWriteCmd(handle=handle, dst_addr=ptr, nbytes=handle.nbytes)
|
||||
self._emit(cmd)
|
||||
|
||||
# ── GEMM Engine (blocking) ────────────────────────────────────
|
||||
|
||||
@@ -143,7 +160,7 @@ class TLContext:
|
||||
out_dtype = a.dtype
|
||||
out = self._make_handle(addr=0, shape=out_shape, dtype=out_dtype)
|
||||
self._emit_dispatch_overhead()
|
||||
self._commands.append(GemmCmd(a=a, b=b, out=out, m=m, k=k, n=n))
|
||||
self._emit(GemmCmd(a=a, b=b, out=out, m=m, k=k, n=n))
|
||||
return out
|
||||
|
||||
# ── MATH Engine: unary (blocking) ─────────────────────────────
|
||||
@@ -151,7 +168,7 @@ class TLContext:
|
||||
def _unary_math(self, op: str, x: TensorHandle) -> TensorHandle:
|
||||
out = self._make_handle(addr=0, shape=x.shape, dtype=x.dtype)
|
||||
self._emit_dispatch_overhead()
|
||||
self._commands.append(MathCmd(op=op, inputs=(x,), out=out))
|
||||
self._emit(MathCmd(op=op, inputs=(x,), out=out))
|
||||
return out
|
||||
|
||||
def exp(self, x: TensorHandle) -> TensorHandle:
|
||||
@@ -184,7 +201,7 @@ class TLContext:
|
||||
out_shape[axis] = 1
|
||||
out = self._make_handle(addr=0, shape=tuple(out_shape), dtype=x.dtype)
|
||||
self._emit_dispatch_overhead()
|
||||
self._commands.append(MathCmd(op=op, inputs=(x,), out=out, axis=axis))
|
||||
self._emit(MathCmd(op=op, inputs=(x,), out=out, axis=axis))
|
||||
return out
|
||||
|
||||
def sum(self, x: TensorHandle, axis: int) -> TensorHandle:
|
||||
@@ -203,7 +220,7 @@ class TLContext:
|
||||
) -> TensorHandle:
|
||||
out = self._make_handle(addr=0, shape=a.shape, dtype=a.dtype)
|
||||
self._emit_dispatch_overhead()
|
||||
self._commands.append(MathCmd(op=op, inputs=(a, b), out=out))
|
||||
self._emit(MathCmd(op=op, inputs=(a, b), out=out))
|
||||
return out
|
||||
|
||||
def where(
|
||||
@@ -211,7 +228,7 @@ class TLContext:
|
||||
) -> TensorHandle:
|
||||
out = self._make_handle(addr=0, shape=a.shape, dtype=a.dtype)
|
||||
self._emit_dispatch_overhead()
|
||||
self._commands.append(MathCmd(op="where", inputs=(cond, a, b), out=out))
|
||||
self._emit(MathCmd(op="where", inputs=(cond, a, b), out=out))
|
||||
return out
|
||||
|
||||
# ── Index / Scalar (PE_CPU, no engine) ────────────────────────
|
||||
@@ -276,7 +293,7 @@ class TLContext:
|
||||
|
||||
completion = CompletionHandle(id=self._next_completion_id())
|
||||
self._emit_dispatch_overhead()
|
||||
self._commands.append(CompositeCmd(
|
||||
self._emit(CompositeCmd(
|
||||
completion=completion, op=op,
|
||||
a=a, b=b, out_addr=out_ptr, out_nbytes=out_nbytes,
|
||||
math_op=math_op,
|
||||
@@ -285,11 +302,11 @@ class TLContext:
|
||||
|
||||
def wait(self, handle: CompletionHandle | None = None) -> None:
|
||||
"""Wait for a specific composite or all pending composites."""
|
||||
self._commands.append(WaitCmd(handle=handle))
|
||||
self._emit(WaitCmd(handle=handle))
|
||||
|
||||
def cycles(self, n: int) -> None:
|
||||
"""Declare PE_CPU scalar execution overhead (cycles)."""
|
||||
self._commands.append(PeCpuOverheadCmd(cycles=n))
|
||||
self._emit(PeCpuOverheadCmd(cycles=n))
|
||||
|
||||
|
||||
# ── TensorHandle arithmetic operators ─────────────────────────────
|
||||
|
||||
Reference in New Issue
Block a user