"""Phase 1 spec test for the math-input snapshot race (IPCQ slot wrap). Context (sub-cycle 4c.0 diagnostic): The mesh decode kernel (_attention_mesh_mlo.py) issues many tl.recv() calls against an IPCQ ring of ~8 slots. With n_ranks=8 and bidirectional fan-out, each PE issues 3 recvs per step × 7 steps × 2 directions = 42 recvs per panel. The IPCQ slot index is ``my_tail % n_slots``, so the ring wraps and a fresh recv overwrites a slot whose data a prior math op had not yet snapshotted. OpLogger.record_end currently snapshots math inputs by re-reading MemoryStore at record_end time (op_log.py:97-113). When a later recv has overwritten the input addr with a DIFFERENT-shape array between record_start and record_end, MemoryStore.read raises ``Shape mismatch: stored (16, 64) vs requested (16, 1)`` and the snapshot becomes None (or, in Phase 2 replay, surfaces the same exception in DataExecutor). Phase 1 expectation: this test currently fails. It asserts the *desired* behavior: when the math input TensorHandle carries a .data snapshot (captured at recv time before the slot was wrapped), OpLogger MUST prefer that snapshot over MemoryStore.read. After Phase 2 (snapshot propagation fix), this test passes — and the sub-cycle 4c.0 mesh decode end-to-end (test_attention_mesh_decode_diag and test_milestone_gqa_llama70b) passes for the same reason. See: docs/adr/ADR-0020 (two-phase execution), docs/adr/ADR-0023 (IPCQ ring slots), docs/adr/ADR-0027 (snapshot discipline for dma_write). """ from __future__ import annotations import numpy as np from kernbench.common.pe_commands import MathCmd, TensorHandle from kernbench.sim_engine.memory_store import MemoryStore from kernbench.sim_engine.op_log import OpLogger # ── Helpers ────────────────────────────────────────────────────── def _slot_handle(addr: int, shape: tuple[int, ...], dtype: str, data: np.ndarray | None) -> TensorHandle: """Build a TensorHandle as tl.recv() would: addr=slot, .data=snapshot.""" nbytes = int(np.prod(shape)) * np.dtype( {"f16": np.float16, "f32": np.float32}[dtype] ).itemsize return TensorHandle( id=f"slot_{addr:x}", addr=addr, shape=shape, dtype=dtype, nbytes=nbytes, data=data, space="tcm", ) def _out_handle(addr: int, shape: tuple[int, ...], dtype: str) -> TensorHandle: nbytes = int(np.prod(shape)) * np.dtype( {"f16": np.float16, "f32": np.float32}[dtype] ).itemsize return TensorHandle( id=f"out_{addr:x}", addr=addr, shape=shape, dtype=dtype, nbytes=nbytes, data=None, space="tcm", ) # ── Tests ───────────────────────────────────────────────────────── def test_math_snapshot_lost_when_input_slot_overwritten_with_same_nbytes(): """Baseline (passes today): if a later write at the input addr has the SAME nbytes as the math input's expected shape, MemoryStore.read returns the LATER data — the snapshot is silently wrong. This is the quiet variant of the bug; it does not raise, it just produces incorrect numerical output in Phase 2. This test documents that the current OpLogger behavior is wrong even when shapes coincidentally match. The Phase 2 fix removes this silent-corruption mode by preferring handle.data. """ store = MemoryStore() slot_addr = 0x3000 # Original at recv time: filled with 7s. original = np.full((16, 1), 7.0, dtype=np.float16) store.write("tcm", slot_addr, original) inp = _slot_handle(slot_addr, (16, 1), "f16", data=original.copy()) out = _out_handle(0x4000, (16, 1), "f16") cmd = MathCmd(op="maximum", inputs=(inp,), out=out) logger = OpLogger(memory_store=store) logger.record_start(10.0, "sip0.cube0.pe0.pe_math", cmd) # SIMULATE: a later recv writes a DIFFERENT array at the same slot # (same nbytes as (16,1), so MemoryStore.read does not raise). later = np.full((16, 1), 99.0, dtype=np.float16) store.write("tcm", slot_addr, later) logger.record_end(15.0, "sip0.cube0.pe0.pe_math", cmd) snap = logger.records[0].params["input_snapshots"][0] assert snap is not None # Desired post-fix behavior: snapshot equals ``original``. # Today: snapshot equals ``later`` — silent corruption. np.testing.assert_array_equal(snap, original) def test_math_snapshot_survives_input_slot_wrap_with_different_shape(): """The hard-failure variant: a later recv overwrites the input slot with a DIFFERENT-shape array (different nbytes), so MemoryStore.read at record_end raises and the snapshot becomes None. Phase 2 replay then surfaces this as the (16, 64) vs (16, 1) crash seen in test_attention_mesh_decode_diag. Desired behavior: handle.data carries the recv-time snapshot, so OpLogger never has to look at MemoryStore for this input → no race, snapshot is correct. """ store = MemoryStore() slot_addr = 0x3000 # Original at recv time: an (m, ℓ) reduction result, shape (16, 1). original = np.full((16, 1), 7.0, dtype=np.float16) store.write("tcm", slot_addr, original) inp = _slot_handle(slot_addr, (16, 1), "f16", data=original.copy()) out = _out_handle(0x4000, (16, 1), "f16") cmd = MathCmd(op="maximum", inputs=(inp,), out=out) logger = OpLogger(memory_store=store) logger.record_start(10.0, "sip0.cube0.pe0.pe_math", cmd) # SIMULATE the slot-wrap race: a later recv (an o triplet, shape # (16, 64)) writes the same TCM slot. MemoryStore.read for shape # (16, 1) now raises ValueError("Shape mismatch ..."). overwrite = np.full((16, 64), 99.0, dtype=np.float16) store.write("tcm", slot_addr, overwrite) logger.record_end(15.0, "sip0.cube0.pe0.pe_math", cmd) snap = logger.records[0].params["input_snapshots"][0] # Today: snap is None (read raised, except branch returned None). # Post-fix: handle.data preferred → snap is original. assert snap is not None, ( "input snapshot was lost when the recv slot was wrapped — " "OpLogger must prefer handle.data over MemoryStore.read for " "math inputs whose handle carries a .data snapshot" ) assert snap.shape == (16, 1) np.testing.assert_array_equal(snap, original) def test_math_snapshot_handle_data_with_multiple_inputs(): """maximum/binary math has 2 inputs; both must use their carried snapshots independently (e.g. m_running merged with m_from_W where only m_from_W came from a recv slot).""" store = MemoryStore() # Input 0: a running m value held in PE scratch (no .data; OpLogger # falls back to MemoryStore.read as today). Its addr is stable — # not subject to the slot-wrap race. scratch_addr = 0x5000 m_running = np.full((16, 1), 3.0, dtype=np.float16) store.write("tcm", scratch_addr, m_running) inp0 = _slot_handle(scratch_addr, (16, 1), "f16", data=None) # Input 1: m_from_W via tl.recv — carries snapshot in .data, addr # is the recv slot which WILL be wrapped before record_end. slot_addr = 0x3000 m_from_W = np.full((16, 1), 7.0, dtype=np.float16) store.write("tcm", slot_addr, m_from_W) inp1 = _slot_handle(slot_addr, (16, 1), "f16", data=m_from_W.copy()) out = _out_handle(0x4000, (16, 1), "f16") cmd = MathCmd(op="maximum", inputs=(inp0, inp1), out=out) logger = OpLogger(memory_store=store) logger.record_start(10.0, "sip0.cube0.pe0.pe_math", cmd) # Slot 0x3000 gets wrapped by a later recv with a different shape. overwrite = np.full((16, 64), 99.0, dtype=np.float16) store.write("tcm", slot_addr, overwrite) logger.record_end(15.0, "sip0.cube0.pe0.pe_math", cmd) snaps = logger.records[0].params["input_snapshots"] assert len(snaps) == 2 # Input 0 (no carried snapshot, addr stable): MemoryStore read still # works. This must keep working post-fix. assert snaps[0] is not None np.testing.assert_array_equal(snaps[0], m_running) # Input 1 (carried snapshot, slot wrapped): must come from .data. assert snaps[1] is not None assert snaps[1].shape == (16, 1) np.testing.assert_array_equal(snaps[1], m_from_W) def test_math_snapshot_falls_back_to_memory_store_when_handle_data_is_none(): """Backward-compat: handles with .data=None must continue to use MemoryStore.read as today. Most math inputs (intermediate results from local tl.dot / tl.exp etc.) have data=None and their TCM addrs are stable for the kernel's lifetime.""" store = MemoryStore() addr = 0x6000 arr = np.full((8, 8), 2.0, dtype=np.float16) store.write("tcm", addr, arr) inp = _slot_handle(addr, (8, 8), "f16", data=None) out = _out_handle(0x7000, (8, 8), "f16") cmd = MathCmd(op="exp", inputs=(inp,), out=out) logger = OpLogger(memory_store=store) logger.record_start(10.0, "sip0.cube0.pe0.pe_math", cmd) logger.record_end(15.0, "sip0.cube0.pe0.pe_math", cmd) snap = logger.records[0].params["input_snapshots"][0] assert snap is not None np.testing.assert_array_equal(snap, arr)