"""Tests for MemoryStore (ADR-0020 D7)."""

import numpy as np
import pytest

from kernbench.sim_engine.memory_store import MemoryStore


def test_write_read_reference():
    """Write and read return the same numpy array (no copy)."""
    store = MemoryStore()
    data = np.ones((4, 4), dtype=np.float16)
    store.write("tcm", 0x0, data)
    result = store.read("tcm", 0x0)
    assert result is data


def test_overwrite_replaces():
    """Same addr write replaces the previous tensor."""
    store = MemoryStore()
    data1 = np.zeros((4,), dtype=np.float32)
    data2 = np.ones((4,), dtype=np.float32)
    store.write("hbm", 0x100, data1)
    store.write("hbm", 0x100, data2)
    result = store.read("hbm", 0x100)
    assert result is data2


def test_read_missing_raises():
    """Reading an address that was never written raises KeyError."""
    store = MemoryStore()
    with pytest.raises(KeyError):
        store.read("hbm", 0x999)


def test_read_different_space():
    """Addresses are scoped per memory space: a tcm write is not visible in hbm."""
    store = MemoryStore()
    data = np.array([1, 2, 3], dtype=np.int32)
    store.write("tcm", 0x0, data)
    with pytest.raises(KeyError):
        store.read("hbm", 0x0)  # different space


def test_dtype_reinterpret():
    """Read with different dtype does view cast (same bytes, new dtype)."""
    store = MemoryStore()
    data = np.array([1.0, 2.0], dtype=np.float32)  # 8 bytes
    store.write("tcm", 0x0, data)
    result = store.read("tcm", 0x0, dtype="u8")
    assert result.dtype == np.uint8
    assert result.nbytes == data.nbytes
    # A reinterpretation (not a value conversion) must keep the raw bytes intact.
    assert result.tobytes() == data.tobytes()


def test_reshape():
    """Read with a compatible shape returns the data reshaped, values preserved."""
    store = MemoryStore()
    data = np.arange(12, dtype=np.float32)
    store.write("tcm", 0x0, data)
    result = store.read("tcm", 0x0, shape=(3, 4))
    assert result.shape == (3, 4)
    # Reshape must keep element order and values, not just the shape.
    np.testing.assert_array_equal(result.ravel(), data)


def test_shape_mismatch_raises():
    """Read with a shape whose element count differs raises ValueError."""
    store = MemoryStore()
    data = np.arange(12, dtype=np.float32)
    store.write("tcm", 0x0, data)
    with pytest.raises(ValueError, match="Shape mismatch"):
        store.read("tcm", 0x0, shape=(5, 5))


def test_has():
    """has() reports False before a write and True after it."""
    store = MemoryStore()
    assert not store.has("tcm", 0x0)
    store.write("tcm", 0x0, np.array([1]))
    assert store.has("tcm", 0x0)


def test_snapshot():
    """snapshot() shares array references but writes to it stay independent."""
    store = MemoryStore()
    data = np.ones((4,), dtype=np.float16)
    store.write("hbm", 0x0, data)
    snap = store.snapshot()
    assert snap.read("hbm", 0x0) is data  # same reference

    # Modifying snap doesn't affect original
    replacement = np.zeros((4,), dtype=np.float16)
    snap.write("hbm", 0x0, replacement)
    assert snap.read("hbm", 0x0) is replacement
    assert store.read("hbm", 0x0) is data