From 313dee503cb64088a029ef7bbf30bbbe8c55d922 Mon Sep 17 00:00:00 2001
From: Mukesh Garg
Date: Mon, 1 Jun 2026 19:14:09 -0700
Subject: [PATCH] sim_engine: fix IPCQ slot-wrap snapshot race in Phase 2 replay
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Phase 1 cannot snapshot math-output sources at outbound send time
because math executes only in Phase 2 — so token.data stays None and
PE_DMA inbound can't write the recv slot. For own-sends this is
harmless (Phase 2 replay reads the stable scratch addr after math
runs). For forwarded sends in mesh kernels (ADR-0059), src_addr is a
recv slot that gets wrapped by later inbounds before this read's
Phase 2 turn, yielding a shape mismatch on the fallback
MemoryStore.read.

Fix: DataExecutor maintains a per-slot, time-ordered, shape-keyed
history. Every ipcq_copy write appends (t_write, value) to the slot's
history; _resolve_read falls back to the most recent shape-matching
entry with t_write <= the consuming op's t_start. Applied uniformly to
_execute_memory, _execute_gemm, and _execute_math.

Secondary: OpLogger.record_end for math ops now prefers
TensorHandle.data carried by the input handle over a MemoryStore
re-read, closing the smaller record-end race covered by the new
test_op_log_input_snapshot_race.py unit tests.

Tests: 4 new race tests + 6 existing op_log + mesh decode diag + mesh
kv/mlo spec — all green. Full repo sweep: 760 passed (3 pre-existing
failures unrelated: bench-registry list drift + Windows Tkinter env).

Co-Authored-By: Claude Opus 4.7 (1M context) --- src/kernbench/sim_engine/data_executor.py | 73 ++++++-- src/kernbench/sim_engine/op_log.py | 13 +- tests/test_op_log_input_snapshot_race.py | 218 ++++++++++++++++++++++ 3 files changed, 290 insertions(+), 14 deletions(-) create mode 100644 tests/test_op_log_input_snapshot_race.py diff --git a/src/kernbench/sim_engine/data_executor.py b/src/kernbench/sim_engine/data_executor.py index 9c16035..a0a429c 100644 --- a/src/kernbench/sim_engine/data_executor.py +++ b/src/kernbench/sim_engine/data_executor.py @@ -25,6 +25,37 @@ class DataExecutor: def __init__(self, op_log: list[OpRecord], store: MemoryStore) -> None: self._op_log = op_log self.store = store + # Per-slot time-ordered shape-keyed history. Populated on every + # ipcq_copy WRITE; consulted on reads that find a shape-mismatched + # value in MemoryStore (the slot was wrapped by a later inbound + # before this read's Phase 2 turn). Required because Phase 1 cannot + # snapshot math-output sources at outbound time (math executes only + # in Phase 2), so token.data is None and slot wraps lose the recv- + # time value. See test_attention_mesh_decode_diag (ADR-0059 mesh). + self._slot_history: dict[tuple[str, int], list[tuple[float, Any]]] = {} + + def _resolve_read( + self, space: str, addr: int, + shape: tuple[int, ...] | None, dtype: str | None, + t_at_or_before: float, + ) -> Any: + """Read (space, addr) with expected shape. On KeyError or shape + mismatch in MemoryStore, fall back to ``_slot_history`` for the + most recent shape-matching entry with t_write <= t_at_or_before. 
+ Returns None when no match is found.""" + try: + return self.store.read(space, addr, shape=shape, dtype=dtype) + except (KeyError, ValueError): + pass + hist = self._slot_history.get((space, addr)) + if hist is None: + return None + for t_w, val in reversed(hist): + if t_w > t_at_or_before: + continue + if shape is None or getattr(val, "shape", None) == shape: + return val + return None # Ordering priority within the same t_start: memory copies must run # before math/gemm so that slot data is populated before a consumer @@ -87,14 +118,23 @@ class DataExecutor: # only get populated by Phase 2's math replay). data = p.get("snapshot") if data is None: - try: - data = self.store.read( - src_space, src_addr, - shape=p.get("shape"), dtype=p.get("dtype"), - ) - except KeyError: + data = self._resolve_read( + src_space, src_addr, + p.get("shape"), p.get("dtype"), op.t_start, + ) + if data is None: return self.store.write(dst_space, dst_addr, data) + # Record this write in slot history so a later forwarded read + # at src=dst_addr (a different ipcq_copy whose src is this slot) + # can recover by shape even after the slot has been wrapped. 
+ if op.op_name == "ipcq_copy": + self._slot_history.setdefault( + (dst_space, dst_addr), [], + ).append(( + op.t_start, + data.copy() if hasattr(data, "copy") else data, + )) def _execute_gemm(self, op: OpRecord) -> None: """Execute GEMM: out = a @ b.""" @@ -110,10 +150,16 @@ class DataExecutor: dtype_in = p.get("dtype_in", "f16") dtype_out = p.get("dtype_out", dtype_in) - a = self.store.read(src_a_space, p["src_a_addr"], - shape=p.get("shape_a"), dtype=dtype_in) - b = self.store.read(src_b_space, p["src_b_addr"], - shape=p.get("shape_b"), dtype=dtype_in) + a = self._resolve_read(src_a_space, p["src_a_addr"], + p.get("shape_a"), dtype_in, op.t_start) + if a is None: + a = self.store.read(src_a_space, p["src_a_addr"], + shape=p.get("shape_a"), dtype=dtype_in) + b = self._resolve_read(src_b_space, p["src_b_addr"], + p.get("shape_b"), dtype_in, op.t_start) + if b is None: + b = self.store.read(src_b_space, p["src_b_addr"], + shape=p.get("shape_b"), dtype=dtype_in) # Compute in higher precision if specified dtype_acc = p.get("dtype_acc", "f32") @@ -150,8 +196,11 @@ class DataExecutor: ): if snap is not None: inputs.append(snap) - else: - inputs.append(self.store.read(space, addr, shape=shape, dtype=idtype)) + continue + resolved = self._resolve_read(space, addr, shape, idtype, op.t_start) + if resolved is None: + resolved = self.store.read(space, addr, shape=shape, dtype=idtype) + inputs.append(resolved) result = _compute_math(math_op, inputs, p.get("axis")) if result is not None: diff --git a/src/kernbench/sim_engine/op_log.py b/src/kernbench/sim_engine/op_log.py index 51d2d30..1e8bb8f 100644 --- a/src/kernbench/sim_engine/op_log.py +++ b/src/kernbench/sim_engine/op_log.py @@ -96,13 +96,20 @@ class OpLogger: # gets reused on the next ring round). 
if self._memory_store is not None: if op_kind == "math": + handle_snaps = params.get("input_handle_data") or () snaps: list[Any] = [] - for addr, shape, space, idtype in zip( + for i, (addr, shape, space, idtype) in enumerate(zip( params.get("input_addrs", []), params.get("input_shapes", []), params.get("input_spaces", []), params.get("input_dtypes", []), - ): + )): + if i < len(handle_snaps) and handle_snaps[i] is not None: + carried = handle_snaps[i] + snaps.append( + carried.copy() if hasattr(carried, "copy") else carried + ) + continue try: arr = self._memory_store.read( space, addr, shape=shape, dtype=idtype, @@ -111,6 +118,7 @@ class OpLogger: except Exception: snaps.append(None) params["input_snapshots"] = snaps + params.pop("input_handle_data", None) elif op_name == "dma_write": # ADR-0027 fix: only snapshot HBM sources. TCM (PE scratch) # sources are repopulated by Phase 2 math/gemm replay — @@ -222,6 +230,7 @@ def _extract_op_info(msg: Any) -> tuple[str, str, dict[str, Any]]: "input_shapes": [h.shape for h in msg.inputs], "input_spaces": [getattr(h, "space", "tcm") for h in msg.inputs], "input_dtypes": [h.dtype for h in msg.inputs], + "input_handle_data": tuple(getattr(h, "data", None) for h in msg.inputs), "dst_addr": msg.out.addr, "dst_space": getattr(msg.out, "space", "tcm"), "shape_out": msg.out.shape, diff --git a/tests/test_op_log_input_snapshot_race.py b/tests/test_op_log_input_snapshot_race.py new file mode 100644 index 0000000..8458e0a --- /dev/null +++ b/tests/test_op_log_input_snapshot_race.py @@ -0,0 +1,218 @@ +"""Phase 1 spec test for the math-input snapshot race (IPCQ slot wrap). + +Context (sub-cycle 4c.0 diagnostic): + + The mesh decode kernel (_attention_mesh_mlo.py) issues many tl.recv() + calls against an IPCQ ring of ~8 slots. With n_ranks=8 and bidirectional + fan-out, each PE issues 3 recvs per step × 7 steps × 2 directions = + 42 recvs per panel. 
The IPCQ slot index is ``my_tail % n_slots``, so + the ring wraps and a fresh recv overwrites a slot whose data a prior + math op had not yet snapshotted. + + OpLogger.record_end currently snapshots math inputs by re-reading + MemoryStore at record_end time (op_log.py:97-113). When a later recv + has overwritten the input addr with a DIFFERENT-shape array between + record_start and record_end, MemoryStore.read raises + ``Shape mismatch: stored (16, 64) vs requested (16, 1)`` and the + snapshot becomes None (or, in Phase 2 replay, surfaces the same + exception in DataExecutor). + +Phase 1 expectation: this test currently fails. It asserts the +*desired* behavior: when the math input TensorHandle carries a +.data snapshot (captured at recv time before the slot was wrapped), +OpLogger MUST prefer that snapshot over MemoryStore.read. + +After Phase 2 (snapshot propagation fix), this test passes — and the +sub-cycle 4c.0 mesh decode end-to-end (test_attention_mesh_decode_diag +and test_milestone_gqa_llama70b) passes for the same reason. + +See: docs/adr/ADR-0020 (two-phase execution), + docs/adr/ADR-0023 (IPCQ ring slots), + docs/adr/ADR-0027 (snapshot discipline for dma_write). 
+""" +from __future__ import annotations + +import numpy as np + +from kernbench.common.pe_commands import MathCmd, TensorHandle +from kernbench.sim_engine.memory_store import MemoryStore +from kernbench.sim_engine.op_log import OpLogger + + +# ── Helpers ────────────────────────────────────────────────────── + + +def _slot_handle(addr: int, shape: tuple[int, ...], dtype: str, + data: np.ndarray | None) -> TensorHandle: + """Build a TensorHandle as tl.recv() would: addr=slot, .data=snapshot.""" + nbytes = int(np.prod(shape)) * np.dtype( + {"f16": np.float16, "f32": np.float32}[dtype] + ).itemsize + return TensorHandle( + id=f"slot_{addr:x}", addr=addr, shape=shape, dtype=dtype, + nbytes=nbytes, data=data, space="tcm", + ) + + +def _out_handle(addr: int, shape: tuple[int, ...], dtype: str) -> TensorHandle: + nbytes = int(np.prod(shape)) * np.dtype( + {"f16": np.float16, "f32": np.float32}[dtype] + ).itemsize + return TensorHandle( + id=f"out_{addr:x}", addr=addr, shape=shape, dtype=dtype, + nbytes=nbytes, data=None, space="tcm", + ) + + +# ── Tests ───────────────────────────────────────────────────────── + + +def test_math_snapshot_lost_when_input_slot_overwritten_with_same_nbytes(): + """Baseline (passes today): if a later write at the input addr has the + SAME nbytes as the math input's expected shape, MemoryStore.read + returns the LATER data — the snapshot is silently wrong. This is the + quiet variant of the bug; it does not raise, it just produces + incorrect numerical output in Phase 2. + + This test documents that the current OpLogger behavior is wrong even + when shapes coincidentally match. The Phase 2 fix removes this + silent-corruption mode by preferring handle.data. + """ + store = MemoryStore() + slot_addr = 0x3000 + # Original at recv time: filled with 7s. 
+ original = np.full((16, 1), 7.0, dtype=np.float16) + store.write("tcm", slot_addr, original) + + inp = _slot_handle(slot_addr, (16, 1), "f16", data=original.copy()) + out = _out_handle(0x4000, (16, 1), "f16") + cmd = MathCmd(op="maximum", inputs=(inp,), out=out) + + logger = OpLogger(memory_store=store) + logger.record_start(10.0, "sip0.cube0.pe0.pe_math", cmd) + + # SIMULATE: a later recv writes a DIFFERENT array at the same slot + # (same nbytes as (16,1), so MemoryStore.read does not raise). + later = np.full((16, 1), 99.0, dtype=np.float16) + store.write("tcm", slot_addr, later) + + logger.record_end(15.0, "sip0.cube0.pe0.pe_math", cmd) + + snap = logger.records[0].params["input_snapshots"][0] + assert snap is not None + # Desired post-fix behavior: snapshot equals ``original``. + # Today: snapshot equals ``later`` — silent corruption. + np.testing.assert_array_equal(snap, original) + + +def test_math_snapshot_survives_input_slot_wrap_with_different_shape(): + """The hard-failure variant: a later recv overwrites the input slot + with a DIFFERENT-shape array (different nbytes), so MemoryStore.read + at record_end raises and the snapshot becomes None. Phase 2 replay + then surfaces this as the (16, 64) vs (16, 1) crash seen in + test_attention_mesh_decode_diag. + + Desired behavior: handle.data carries the recv-time snapshot, so + OpLogger never has to look at MemoryStore for this input → no race, + snapshot is correct. + """ + store = MemoryStore() + slot_addr = 0x3000 + + # Original at recv time: an (m, ℓ) reduction result, shape (16, 1). 
+ original = np.full((16, 1), 7.0, dtype=np.float16) + store.write("tcm", slot_addr, original) + + inp = _slot_handle(slot_addr, (16, 1), "f16", data=original.copy()) + out = _out_handle(0x4000, (16, 1), "f16") + cmd = MathCmd(op="maximum", inputs=(inp,), out=out) + + logger = OpLogger(memory_store=store) + logger.record_start(10.0, "sip0.cube0.pe0.pe_math", cmd) + + # SIMULATE the slot-wrap race: a later recv (an o triplet, shape + # (16, 64)) writes the same TCM slot. MemoryStore.read for shape + # (16, 1) now raises ValueError("Shape mismatch ..."). + overwrite = np.full((16, 64), 99.0, dtype=np.float16) + store.write("tcm", slot_addr, overwrite) + + logger.record_end(15.0, "sip0.cube0.pe0.pe_math", cmd) + + snap = logger.records[0].params["input_snapshots"][0] + # Today: snap is None (read raised, except branch returned None). + # Post-fix: handle.data preferred → snap is original. + assert snap is not None, ( + "input snapshot was lost when the recv slot was wrapped — " + "OpLogger must prefer handle.data over MemoryStore.read for " + "math inputs whose handle carries a .data snapshot" + ) + assert snap.shape == (16, 1) + np.testing.assert_array_equal(snap, original) + + +def test_math_snapshot_handle_data_with_multiple_inputs(): + """maximum/binary math has 2 inputs; both must use their carried + snapshots independently (e.g. m_running merged with m_from_W where + only m_from_W came from a recv slot).""" + store = MemoryStore() + + # Input 0: a running m value held in PE scratch (no .data; OpLogger + # falls back to MemoryStore.read as today). Its addr is stable — + # not subject to the slot-wrap race. + scratch_addr = 0x5000 + m_running = np.full((16, 1), 3.0, dtype=np.float16) + store.write("tcm", scratch_addr, m_running) + inp0 = _slot_handle(scratch_addr, (16, 1), "f16", data=None) + + # Input 1: m_from_W via tl.recv — carries snapshot in .data, addr + # is the recv slot which WILL be wrapped before record_end. 
+ slot_addr = 0x3000 + m_from_W = np.full((16, 1), 7.0, dtype=np.float16) + store.write("tcm", slot_addr, m_from_W) + inp1 = _slot_handle(slot_addr, (16, 1), "f16", data=m_from_W.copy()) + + out = _out_handle(0x4000, (16, 1), "f16") + cmd = MathCmd(op="maximum", inputs=(inp0, inp1), out=out) + + logger = OpLogger(memory_store=store) + logger.record_start(10.0, "sip0.cube0.pe0.pe_math", cmd) + + # Slot 0x3000 gets wrapped by a later recv with a different shape. + overwrite = np.full((16, 64), 99.0, dtype=np.float16) + store.write("tcm", slot_addr, overwrite) + + logger.record_end(15.0, "sip0.cube0.pe0.pe_math", cmd) + + snaps = logger.records[0].params["input_snapshots"] + assert len(snaps) == 2 + # Input 0 (no carried snapshot, addr stable): MemoryStore read still + # works. This must keep working post-fix. + assert snaps[0] is not None + np.testing.assert_array_equal(snaps[0], m_running) + # Input 1 (carried snapshot, slot wrapped): must come from .data. + assert snaps[1] is not None + assert snaps[1].shape == (16, 1) + np.testing.assert_array_equal(snaps[1], m_from_W) + + +def test_math_snapshot_falls_back_to_memory_store_when_handle_data_is_none(): + """Backward-compat: handles with .data=None must continue to use + MemoryStore.read as today. Most math inputs (intermediate results + from local tl.dot / tl.exp etc.) have data=None and their TCM addrs + are stable for the kernel's lifetime.""" + store = MemoryStore() + addr = 0x6000 + arr = np.full((8, 8), 2.0, dtype=np.float16) + store.write("tcm", addr, arr) + + inp = _slot_handle(addr, (8, 8), "f16", data=None) + out = _out_handle(0x7000, (8, 8), "f16") + cmd = MathCmd(op="exp", inputs=(inp,), out=out) + + logger = OpLogger(memory_store=store) + logger.record_start(10.0, "sip0.cube0.pe0.pe_math", cmd) + logger.record_end(15.0, "sip0.cube0.pe0.pe_math", cmd) + + snap = logger.records[0].params["input_snapshots"][0] + assert snap is not None + np.testing.assert_array_equal(snap, arr)