sim_engine: fix IPCQ slot-wrap snapshot race in Phase 2 replay
Phase 1 cannot snapshot math-output sources at outbound send time because math executes only in Phase 2 — so token.data stays None and PE_DMA inbound can't write the recv slot. For own-sends this is harmless (Phase 2 replay reads the stable scratch addr after math runs). For forwarded sends in mesh kernels (ADR-0059), src_addr is a recv slot that gets wrapped by later inbounds before this read's Phase 2 turn, yielding a shape mismatch on the fallback MemoryStore.read. Fix: DataExecutor maintains a per-slot, time-ordered, shape-keyed history. Every ipcq_copy write appends (t_write, value) to the slot's history; _resolve_read falls back to the most recent shape-matching entry with t_write <= the consuming op's t_start. Applied uniformly to _execute_memory, _execute_gemm, and _execute_math. Secondary: OpLogger.record_end for math ops now prefers TensorHandle.data carried by the input handle over a MemoryStore re-read, closing the smaller record-end race covered by the new test_op_log_input_snapshot_race.py unit tests. Tests: 4 new race tests + 6 existing op_log + mesh decode diag + mesh kv/mlo spec — all green. Full repo sweep: 760 passed (3 pre-existing failures unrelated: bench-registry list drift + Windows Tkinter env). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -25,6 +25,37 @@ class DataExecutor:
|
||||
def __init__(self, op_log: list[OpRecord], store: MemoryStore) -> None:
|
||||
self._op_log = op_log
|
||||
self.store = store
|
||||
# Per-slot time-ordered shape-keyed history. Populated on every
|
||||
# ipcq_copy WRITE; consulted on reads that find a shape-mismatched
|
||||
# value in MemoryStore (the slot was wrapped by a later inbound
|
||||
# before this read's Phase 2 turn). Required because Phase 1 cannot
|
||||
# snapshot math-output sources at outbound time (math executes only
|
||||
# in Phase 2), so token.data is None and slot wraps lose the recv-
|
||||
# time value. See test_attention_mesh_decode_diag (ADR-0059 mesh).
|
||||
self._slot_history: dict[tuple[str, int], list[tuple[float, Any]]] = {}
|
||||
|
||||
def _resolve_read(
|
||||
self, space: str, addr: int,
|
||||
shape: tuple[int, ...] | None, dtype: str | None,
|
||||
t_at_or_before: float,
|
||||
) -> Any:
|
||||
"""Read (space, addr) with expected shape. On KeyError or shape
|
||||
mismatch in MemoryStore, fall back to ``_slot_history`` for the
|
||||
most recent shape-matching entry with t_write <= t_at_or_before.
|
||||
Returns None when no match is found."""
|
||||
try:
|
||||
return self.store.read(space, addr, shape=shape, dtype=dtype)
|
||||
except (KeyError, ValueError):
|
||||
pass
|
||||
hist = self._slot_history.get((space, addr))
|
||||
if hist is None:
|
||||
return None
|
||||
for t_w, val in reversed(hist):
|
||||
if t_w > t_at_or_before:
|
||||
continue
|
||||
if shape is None or getattr(val, "shape", None) == shape:
|
||||
return val
|
||||
return None
|
||||
|
||||
# Ordering priority within the same t_start: memory copies must run
|
||||
# before math/gemm so that slot data is populated before a consumer
|
||||
@@ -87,14 +118,23 @@ class DataExecutor:
|
||||
# only get populated by Phase 2's math replay).
|
||||
data = p.get("snapshot")
|
||||
if data is None:
|
||||
try:
|
||||
data = self.store.read(
|
||||
src_space, src_addr,
|
||||
shape=p.get("shape"), dtype=p.get("dtype"),
|
||||
)
|
||||
except KeyError:
|
||||
data = self._resolve_read(
|
||||
src_space, src_addr,
|
||||
p.get("shape"), p.get("dtype"), op.t_start,
|
||||
)
|
||||
if data is None:
|
||||
return
|
||||
self.store.write(dst_space, dst_addr, data)
|
||||
# Record this write in slot history so a later forwarded read
|
||||
# at src=dst_addr (a different ipcq_copy whose src is this slot)
|
||||
# can recover by shape even after the slot has been wrapped.
|
||||
if op.op_name == "ipcq_copy":
|
||||
self._slot_history.setdefault(
|
||||
(dst_space, dst_addr), [],
|
||||
).append((
|
||||
op.t_start,
|
||||
data.copy() if hasattr(data, "copy") else data,
|
||||
))
|
||||
|
||||
def _execute_gemm(self, op: OpRecord) -> None:
|
||||
"""Execute GEMM: out = a @ b."""
|
||||
@@ -110,10 +150,16 @@ class DataExecutor:
|
||||
dtype_in = p.get("dtype_in", "f16")
|
||||
dtype_out = p.get("dtype_out", dtype_in)
|
||||
|
||||
a = self.store.read(src_a_space, p["src_a_addr"],
|
||||
shape=p.get("shape_a"), dtype=dtype_in)
|
||||
b = self.store.read(src_b_space, p["src_b_addr"],
|
||||
shape=p.get("shape_b"), dtype=dtype_in)
|
||||
a = self._resolve_read(src_a_space, p["src_a_addr"],
|
||||
p.get("shape_a"), dtype_in, op.t_start)
|
||||
if a is None:
|
||||
a = self.store.read(src_a_space, p["src_a_addr"],
|
||||
shape=p.get("shape_a"), dtype=dtype_in)
|
||||
b = self._resolve_read(src_b_space, p["src_b_addr"],
|
||||
p.get("shape_b"), dtype_in, op.t_start)
|
||||
if b is None:
|
||||
b = self.store.read(src_b_space, p["src_b_addr"],
|
||||
shape=p.get("shape_b"), dtype=dtype_in)
|
||||
|
||||
# Compute in higher precision if specified
|
||||
dtype_acc = p.get("dtype_acc", "f32")
|
||||
@@ -150,8 +196,11 @@ class DataExecutor:
|
||||
):
|
||||
if snap is not None:
|
||||
inputs.append(snap)
|
||||
else:
|
||||
inputs.append(self.store.read(space, addr, shape=shape, dtype=idtype))
|
||||
continue
|
||||
resolved = self._resolve_read(space, addr, shape, idtype, op.t_start)
|
||||
if resolved is None:
|
||||
resolved = self.store.read(space, addr, shape=shape, dtype=idtype)
|
||||
inputs.append(resolved)
|
||||
|
||||
result = _compute_math(math_op, inputs, p.get("axis"))
|
||||
if result is not None:
|
||||
|
||||
@@ -96,13 +96,20 @@ class OpLogger:
|
||||
# gets reused on the next ring round).
|
||||
if self._memory_store is not None:
|
||||
if op_kind == "math":
|
||||
handle_snaps = params.get("input_handle_data") or ()
|
||||
snaps: list[Any] = []
|
||||
for addr, shape, space, idtype in zip(
|
||||
for i, (addr, shape, space, idtype) in enumerate(zip(
|
||||
params.get("input_addrs", []),
|
||||
params.get("input_shapes", []),
|
||||
params.get("input_spaces", []),
|
||||
params.get("input_dtypes", []),
|
||||
):
|
||||
)):
|
||||
if i < len(handle_snaps) and handle_snaps[i] is not None:
|
||||
carried = handle_snaps[i]
|
||||
snaps.append(
|
||||
carried.copy() if hasattr(carried, "copy") else carried
|
||||
)
|
||||
continue
|
||||
try:
|
||||
arr = self._memory_store.read(
|
||||
space, addr, shape=shape, dtype=idtype,
|
||||
@@ -111,6 +118,7 @@ class OpLogger:
|
||||
except Exception:
|
||||
snaps.append(None)
|
||||
params["input_snapshots"] = snaps
|
||||
params.pop("input_handle_data", None)
|
||||
elif op_name == "dma_write":
|
||||
# ADR-0027 fix: only snapshot HBM sources. TCM (PE scratch)
|
||||
# sources are repopulated by Phase 2 math/gemm replay —
|
||||
@@ -222,6 +230,7 @@ def _extract_op_info(msg: Any) -> tuple[str, str, dict[str, Any]]:
|
||||
"input_shapes": [h.shape for h in msg.inputs],
|
||||
"input_spaces": [getattr(h, "space", "tcm") for h in msg.inputs],
|
||||
"input_dtypes": [h.dtype for h in msg.inputs],
|
||||
"input_handle_data": tuple(getattr(h, "data", None) for h in msg.inputs),
|
||||
"dst_addr": msg.out.addr,
|
||||
"dst_space": getattr(msg.out, "space", "tcm"),
|
||||
"shape_out": msg.out.shape,
|
||||
|
||||
@@ -0,0 +1,218 @@
|
||||
"""Phase 1 spec test for the math-input snapshot race (IPCQ slot wrap).
|
||||
|
||||
Context (sub-cycle 4c.0 diagnostic):
|
||||
|
||||
The mesh decode kernel (_attention_mesh_mlo.py) issues many tl.recv()
|
||||
calls against an IPCQ ring of ~8 slots. With n_ranks=8 and bidirectional
|
||||
fan-out, each PE issues 3 recvs per step × 7 steps × 2 directions =
|
||||
42 recvs per panel. The IPCQ slot index is ``my_tail % n_slots``, so
|
||||
the ring wraps and a fresh recv overwrites a slot whose data a prior
|
||||
math op had not yet snapshotted.
|
||||
|
||||
OpLogger.record_end currently snapshots math inputs by re-reading
|
||||
MemoryStore at record_end time (op_log.py:97-113). When a later recv
|
||||
has overwritten the input addr with a DIFFERENT-shape array between
|
||||
record_start and record_end, MemoryStore.read raises
|
||||
``Shape mismatch: stored (16, 64) vs requested (16, 1)`` and the
|
||||
snapshot becomes None (or, in Phase 2 replay, surfaces the same
|
||||
exception in DataExecutor).
|
||||
|
||||
Phase 1 expectation: this test currently fails. It asserts the
|
||||
*desired* behavior: when the math input TensorHandle carries a
|
||||
.data snapshot (captured at recv time before the slot was wrapped),
|
||||
OpLogger MUST prefer that snapshot over MemoryStore.read.
|
||||
|
||||
After Phase 2 (snapshot propagation fix), this test passes — and the
|
||||
sub-cycle 4c.0 mesh decode end-to-end (test_attention_mesh_decode_diag
|
||||
and test_milestone_gqa_llama70b) passes for the same reason.
|
||||
|
||||
See: docs/adr/ADR-0020 (two-phase execution),
|
||||
docs/adr/ADR-0023 (IPCQ ring slots),
|
||||
docs/adr/ADR-0027 (snapshot discipline for dma_write).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
from kernbench.common.pe_commands import MathCmd, TensorHandle
|
||||
from kernbench.sim_engine.memory_store import MemoryStore
|
||||
from kernbench.sim_engine.op_log import OpLogger
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _slot_handle(addr: int, shape: tuple[int, ...], dtype: str,
|
||||
data: np.ndarray | None) -> TensorHandle:
|
||||
"""Build a TensorHandle as tl.recv() would: addr=slot, .data=snapshot."""
|
||||
nbytes = int(np.prod(shape)) * np.dtype(
|
||||
{"f16": np.float16, "f32": np.float32}[dtype]
|
||||
).itemsize
|
||||
return TensorHandle(
|
||||
id=f"slot_{addr:x}", addr=addr, shape=shape, dtype=dtype,
|
||||
nbytes=nbytes, data=data, space="tcm",
|
||||
)
|
||||
|
||||
|
||||
def _out_handle(addr: int, shape: tuple[int, ...], dtype: str) -> TensorHandle:
|
||||
nbytes = int(np.prod(shape)) * np.dtype(
|
||||
{"f16": np.float16, "f32": np.float32}[dtype]
|
||||
).itemsize
|
||||
return TensorHandle(
|
||||
id=f"out_{addr:x}", addr=addr, shape=shape, dtype=dtype,
|
||||
nbytes=nbytes, data=None, space="tcm",
|
||||
)
|
||||
|
||||
|
||||
# ── Tests ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_math_snapshot_lost_when_input_slot_overwritten_with_same_nbytes():
|
||||
"""Baseline (passes today): if a later write at the input addr has the
|
||||
SAME nbytes as the math input's expected shape, MemoryStore.read
|
||||
returns the LATER data — the snapshot is silently wrong. This is the
|
||||
quiet variant of the bug; it does not raise, it just produces
|
||||
incorrect numerical output in Phase 2.
|
||||
|
||||
This test documents that the current OpLogger behavior is wrong even
|
||||
when shapes coincidentally match. The Phase 2 fix removes this
|
||||
silent-corruption mode by preferring handle.data.
|
||||
"""
|
||||
store = MemoryStore()
|
||||
slot_addr = 0x3000
|
||||
# Original at recv time: filled with 7s.
|
||||
original = np.full((16, 1), 7.0, dtype=np.float16)
|
||||
store.write("tcm", slot_addr, original)
|
||||
|
||||
inp = _slot_handle(slot_addr, (16, 1), "f16", data=original.copy())
|
||||
out = _out_handle(0x4000, (16, 1), "f16")
|
||||
cmd = MathCmd(op="maximum", inputs=(inp,), out=out)
|
||||
|
||||
logger = OpLogger(memory_store=store)
|
||||
logger.record_start(10.0, "sip0.cube0.pe0.pe_math", cmd)
|
||||
|
||||
# SIMULATE: a later recv writes a DIFFERENT array at the same slot
|
||||
# (same nbytes as (16,1), so MemoryStore.read does not raise).
|
||||
later = np.full((16, 1), 99.0, dtype=np.float16)
|
||||
store.write("tcm", slot_addr, later)
|
||||
|
||||
logger.record_end(15.0, "sip0.cube0.pe0.pe_math", cmd)
|
||||
|
||||
snap = logger.records[0].params["input_snapshots"][0]
|
||||
assert snap is not None
|
||||
# Desired post-fix behavior: snapshot equals ``original``.
|
||||
# Today: snapshot equals ``later`` — silent corruption.
|
||||
np.testing.assert_array_equal(snap, original)
|
||||
|
||||
|
||||
def test_math_snapshot_survives_input_slot_wrap_with_different_shape():
|
||||
"""The hard-failure variant: a later recv overwrites the input slot
|
||||
with a DIFFERENT-shape array (different nbytes), so MemoryStore.read
|
||||
at record_end raises and the snapshot becomes None. Phase 2 replay
|
||||
then surfaces this as the (16, 64) vs (16, 1) crash seen in
|
||||
test_attention_mesh_decode_diag.
|
||||
|
||||
Desired behavior: handle.data carries the recv-time snapshot, so
|
||||
OpLogger never has to look at MemoryStore for this input → no race,
|
||||
snapshot is correct.
|
||||
"""
|
||||
store = MemoryStore()
|
||||
slot_addr = 0x3000
|
||||
|
||||
# Original at recv time: an (m, ℓ) reduction result, shape (16, 1).
|
||||
original = np.full((16, 1), 7.0, dtype=np.float16)
|
||||
store.write("tcm", slot_addr, original)
|
||||
|
||||
inp = _slot_handle(slot_addr, (16, 1), "f16", data=original.copy())
|
||||
out = _out_handle(0x4000, (16, 1), "f16")
|
||||
cmd = MathCmd(op="maximum", inputs=(inp,), out=out)
|
||||
|
||||
logger = OpLogger(memory_store=store)
|
||||
logger.record_start(10.0, "sip0.cube0.pe0.pe_math", cmd)
|
||||
|
||||
# SIMULATE the slot-wrap race: a later recv (an o triplet, shape
|
||||
# (16, 64)) writes the same TCM slot. MemoryStore.read for shape
|
||||
# (16, 1) now raises ValueError("Shape mismatch ...").
|
||||
overwrite = np.full((16, 64), 99.0, dtype=np.float16)
|
||||
store.write("tcm", slot_addr, overwrite)
|
||||
|
||||
logger.record_end(15.0, "sip0.cube0.pe0.pe_math", cmd)
|
||||
|
||||
snap = logger.records[0].params["input_snapshots"][0]
|
||||
# Today: snap is None (read raised, except branch returned None).
|
||||
# Post-fix: handle.data preferred → snap is original.
|
||||
assert snap is not None, (
|
||||
"input snapshot was lost when the recv slot was wrapped — "
|
||||
"OpLogger must prefer handle.data over MemoryStore.read for "
|
||||
"math inputs whose handle carries a .data snapshot"
|
||||
)
|
||||
assert snap.shape == (16, 1)
|
||||
np.testing.assert_array_equal(snap, original)
|
||||
|
||||
|
||||
def test_math_snapshot_handle_data_with_multiple_inputs():
|
||||
"""maximum/binary math has 2 inputs; both must use their carried
|
||||
snapshots independently (e.g. m_running merged with m_from_W where
|
||||
only m_from_W came from a recv slot)."""
|
||||
store = MemoryStore()
|
||||
|
||||
# Input 0: a running m value held in PE scratch (no .data; OpLogger
|
||||
# falls back to MemoryStore.read as today). Its addr is stable —
|
||||
# not subject to the slot-wrap race.
|
||||
scratch_addr = 0x5000
|
||||
m_running = np.full((16, 1), 3.0, dtype=np.float16)
|
||||
store.write("tcm", scratch_addr, m_running)
|
||||
inp0 = _slot_handle(scratch_addr, (16, 1), "f16", data=None)
|
||||
|
||||
# Input 1: m_from_W via tl.recv — carries snapshot in .data, addr
|
||||
# is the recv slot which WILL be wrapped before record_end.
|
||||
slot_addr = 0x3000
|
||||
m_from_W = np.full((16, 1), 7.0, dtype=np.float16)
|
||||
store.write("tcm", slot_addr, m_from_W)
|
||||
inp1 = _slot_handle(slot_addr, (16, 1), "f16", data=m_from_W.copy())
|
||||
|
||||
out = _out_handle(0x4000, (16, 1), "f16")
|
||||
cmd = MathCmd(op="maximum", inputs=(inp0, inp1), out=out)
|
||||
|
||||
logger = OpLogger(memory_store=store)
|
||||
logger.record_start(10.0, "sip0.cube0.pe0.pe_math", cmd)
|
||||
|
||||
# Slot 0x3000 gets wrapped by a later recv with a different shape.
|
||||
overwrite = np.full((16, 64), 99.0, dtype=np.float16)
|
||||
store.write("tcm", slot_addr, overwrite)
|
||||
|
||||
logger.record_end(15.0, "sip0.cube0.pe0.pe_math", cmd)
|
||||
|
||||
snaps = logger.records[0].params["input_snapshots"]
|
||||
assert len(snaps) == 2
|
||||
# Input 0 (no carried snapshot, addr stable): MemoryStore read still
|
||||
# works. This must keep working post-fix.
|
||||
assert snaps[0] is not None
|
||||
np.testing.assert_array_equal(snaps[0], m_running)
|
||||
# Input 1 (carried snapshot, slot wrapped): must come from .data.
|
||||
assert snaps[1] is not None
|
||||
assert snaps[1].shape == (16, 1)
|
||||
np.testing.assert_array_equal(snaps[1], m_from_W)
|
||||
|
||||
|
||||
def test_math_snapshot_falls_back_to_memory_store_when_handle_data_is_none():
|
||||
"""Backward-compat: handles with .data=None must continue to use
|
||||
MemoryStore.read as today. Most math inputs (intermediate results
|
||||
from local tl.dot / tl.exp etc.) have data=None and their TCM addrs
|
||||
are stable for the kernel's lifetime."""
|
||||
store = MemoryStore()
|
||||
addr = 0x6000
|
||||
arr = np.full((8, 8), 2.0, dtype=np.float16)
|
||||
store.write("tcm", addr, arr)
|
||||
|
||||
inp = _slot_handle(addr, (8, 8), "f16", data=None)
|
||||
out = _out_handle(0x7000, (8, 8), "f16")
|
||||
cmd = MathCmd(op="exp", inputs=(inp,), out=out)
|
||||
|
||||
logger = OpLogger(memory_store=store)
|
||||
logger.record_start(10.0, "sip0.cube0.pe0.pe_math", cmd)
|
||||
logger.record_end(15.0, "sip0.cube0.pe0.pe_math", cmd)
|
||||
|
||||
snap = logger.records[0].params["input_snapshots"][0]
|
||||
assert snap is not None
|
||||
np.testing.assert_array_equal(snap, arr)
|
||||
Reference in New Issue
Block a user