sim_engine: fix IPCQ slot-wrap snapshot race in Phase 2 replay

Phase 1 cannot snapshot math-output sources at outbound send time
because math executes only in Phase 2 — so token.data stays None and
PE_DMA inbound can't write the recv slot. For own-sends this is harmless
(Phase 2 replay reads the stable scratch addr after math runs). For
forwarded sends in mesh kernels (ADR-0059), src_addr is a recv slot
that gets wrapped by later inbounds before this read's Phase 2 turn,
yielding a shape mismatch on the fallback MemoryStore.read.

Fix: DataExecutor maintains a per-slot, time-ordered, shape-keyed
history. Every ipcq_copy write appends (t_write, value) to the slot's
history; _resolve_read falls back to the most recent shape-matching
entry with t_write <= the consuming op's t_start. Applied uniformly
to _execute_memory, _execute_gemm, and _execute_math.

Secondary: OpLogger.record_end for math ops now prefers
TensorHandle.data carried by the input handle over a MemoryStore
re-read, closing the smaller record-end race covered by the new
test_op_log_input_snapshot_race.py unit tests.

Tests: 4 new race tests + 6 existing op_log + mesh decode diag +
mesh kv/mlo spec — all green. Full repo sweep: 760 passed (3
pre-existing failures unrelated: bench-registry list drift +
Windows Tkinter env).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-06-01 19:14:09 -07:00
parent b1d6fafd3a
commit 313dee503c
3 changed files with 290 additions and 14 deletions
+61 -12
View File
@@ -25,6 +25,37 @@ class DataExecutor:
def __init__(self, op_log: list[OpRecord], store: MemoryStore) -> None:
self._op_log = op_log
self.store = store
# Per-slot time-ordered shape-keyed history. Populated on every
# ipcq_copy WRITE; consulted on reads that find a shape-mismatched
# value in MemoryStore (the slot was wrapped by a later inbound
# before this read's Phase 2 turn). Required because Phase 1 cannot
# snapshot math-output sources at outbound time (math executes only
# in Phase 2), so token.data is None and slot wraps lose the recv-
# time value. See test_attention_mesh_decode_diag (ADR-0059 mesh).
self._slot_history: dict[tuple[str, int], list[tuple[float, Any]]] = {}
def _resolve_read(
self, space: str, addr: int,
shape: tuple[int, ...] | None, dtype: str | None,
t_at_or_before: float,
) -> Any:
"""Read (space, addr) with expected shape. On KeyError or shape
mismatch in MemoryStore, fall back to ``_slot_history`` for the
most recent shape-matching entry with t_write <= t_at_or_before.
Returns None when no match is found."""
try:
return self.store.read(space, addr, shape=shape, dtype=dtype)
except (KeyError, ValueError):
pass
hist = self._slot_history.get((space, addr))
if hist is None:
return None
for t_w, val in reversed(hist):
if t_w > t_at_or_before:
continue
if shape is None or getattr(val, "shape", None) == shape:
return val
return None
# Ordering priority within the same t_start: memory copies must run
# before math/gemm so that slot data is populated before a consumer
@@ -87,14 +118,23 @@ class DataExecutor:
# only get populated by Phase 2's math replay).
data = p.get("snapshot")
if data is None:
try:
data = self.store.read(
src_space, src_addr,
shape=p.get("shape"), dtype=p.get("dtype"),
)
except KeyError:
data = self._resolve_read(
src_space, src_addr,
p.get("shape"), p.get("dtype"), op.t_start,
)
if data is None:
return
self.store.write(dst_space, dst_addr, data)
# Record this write in slot history so a later forwarded read
# at src=dst_addr (a different ipcq_copy whose src is this slot)
# can recover by shape even after the slot has been wrapped.
if op.op_name == "ipcq_copy":
self._slot_history.setdefault(
(dst_space, dst_addr), [],
).append((
op.t_start,
data.copy() if hasattr(data, "copy") else data,
))
def _execute_gemm(self, op: OpRecord) -> None:
"""Execute GEMM: out = a @ b."""
@@ -110,10 +150,16 @@ class DataExecutor:
dtype_in = p.get("dtype_in", "f16")
dtype_out = p.get("dtype_out", dtype_in)
a = self.store.read(src_a_space, p["src_a_addr"],
shape=p.get("shape_a"), dtype=dtype_in)
b = self.store.read(src_b_space, p["src_b_addr"],
shape=p.get("shape_b"), dtype=dtype_in)
a = self._resolve_read(src_a_space, p["src_a_addr"],
p.get("shape_a"), dtype_in, op.t_start)
if a is None:
a = self.store.read(src_a_space, p["src_a_addr"],
shape=p.get("shape_a"), dtype=dtype_in)
b = self._resolve_read(src_b_space, p["src_b_addr"],
p.get("shape_b"), dtype_in, op.t_start)
if b is None:
b = self.store.read(src_b_space, p["src_b_addr"],
shape=p.get("shape_b"), dtype=dtype_in)
# Compute in higher precision if specified
dtype_acc = p.get("dtype_acc", "f32")
@@ -150,8 +196,11 @@ class DataExecutor:
):
if snap is not None:
inputs.append(snap)
else:
inputs.append(self.store.read(space, addr, shape=shape, dtype=idtype))
continue
resolved = self._resolve_read(space, addr, shape, idtype, op.t_start)
if resolved is None:
resolved = self.store.read(space, addr, shape=shape, dtype=idtype)
inputs.append(resolved)
result = _compute_math(math_op, inputs, p.get("axis"))
if result is not None:
+11 -2
View File
@@ -96,13 +96,20 @@ class OpLogger:
# gets reused on the next ring round).
if self._memory_store is not None:
if op_kind == "math":
handle_snaps = params.get("input_handle_data") or ()
snaps: list[Any] = []
for addr, shape, space, idtype in zip(
for i, (addr, shape, space, idtype) in enumerate(zip(
params.get("input_addrs", []),
params.get("input_shapes", []),
params.get("input_spaces", []),
params.get("input_dtypes", []),
):
)):
if i < len(handle_snaps) and handle_snaps[i] is not None:
carried = handle_snaps[i]
snaps.append(
carried.copy() if hasattr(carried, "copy") else carried
)
continue
try:
arr = self._memory_store.read(
space, addr, shape=shape, dtype=idtype,
@@ -111,6 +118,7 @@ class OpLogger:
except Exception:
snaps.append(None)
params["input_snapshots"] = snaps
params.pop("input_handle_data", None)
elif op_name == "dma_write":
# ADR-0027 fix: only snapshot HBM sources. TCM (PE scratch)
# sources are repopulated by Phase 2 math/gemm replay —
@@ -222,6 +230,7 @@ def _extract_op_info(msg: Any) -> tuple[str, str, dict[str, Any]]:
"input_shapes": [h.shape for h in msg.inputs],
"input_spaces": [getattr(h, "space", "tcm") for h in msg.inputs],
"input_dtypes": [h.dtype for h in msg.inputs],
"input_handle_data": tuple(getattr(h, "data", None) for h in msg.inputs),
"dst_addr": msg.out.addr,
"dst_space": getattr(msg.out, "space", "tcm"),
"shape_out": msg.out.shape,