"""Tests for IPCQ strict shape/dtype validation (ADR-0023 D14 F2)."""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any

import pytest
import simpy

from kernbench.common.ipcq_types import (
    IpcqDmaToken,
    IpcqEndpoint,
    IpcqInitEntry,
    IpcqInvalidDirection,
    IpcqMetaArrival,
    IpcqRecvCmd,
    IpcqRequest,
    IpcqSendCmd,
)
from kernbench.components.builtin.pe_ipcq import PeIpcqComponent
from kernbench.runtime_api.kernel import IpcqInitMsg
from kernbench.topology.types import Node


# ── helpers (smaller copy of test_pe_ipcq fixtures) ────────────────


@dataclass
class _FakeTxn:
    """Stand-in host transaction: a request plus the event signalled on completion."""

    request: Any
    done: simpy.Event
    result_data: dict[str, Any] = field(default_factory=dict)


def _make(env: simpy.Environment, strict: bool = True) -> PeIpcqComponent:
    """Build, start and initialise a PE IPCQ component for sip0.cube0.pe0.

    The component is configured with a single ``"W"`` queue whose peer is
    PE 1 on the same sip/cube (TCM buffer, 4 slots of 4096 bytes), and whose
    local receive base PA is ``0x40_000``. The init message is delivered on
    the ``"host"`` port and the simulation is run until it completes.

    Args:
        env: SimPy environment that drives the component.
        strict: Value passed as the node's ``strict_validation`` attribute.

    Returns:
        The started and initialised component.
    """
    node = Node(
        id="sip0.cube0.pe0.pe_ipcq",
        kind="pe_ipcq",
        impl="builtin.pe_ipcq",
        attrs={"strict_validation": strict},
        pos_mm=None,
    )
    comp = PeIpcqComponent(node, ctx=None)
    comp.in_ports["host"] = simpy.Store(env)
    comp.out_ports["sip0.cube0.pe0.pe_dma"] = simpy.Store(env)
    comp.start(env)

    peer_credit = simpy.Store(env)
    ep = IpcqEndpoint(
        sip=0,
        cube=0,
        pe=1,
        buffer_kind="tcm",
        rx_base_pa=0x10_000,
        rx_base_va=0,
        n_slots=4,
        slot_size=4096,
    )
    init_msg = IpcqInitMsg(
        correlation_id="t",
        request_id="t",
        target_sips=(0,),
        target_cubes=(0,),
        target_pe=0,
        entries=(
            IpcqInitEntry(
                direction="W",
                peer=ep,
                my_rx_base_pa=0x40_000,
                my_rx_base_va=0,
                n_slots=4,
                slot_size=4096,
                peer_credit_store=peer_credit,
            ),
        ),
        backpressure_mode="sleep",
        buffer_kind="tcm",
        credit_size_bytes=16,
    )
    done = env.event()
    comp.in_ports["host"].put(_FakeTxn(request=init_msg, done=done))
    env.run(until=done)
    return comp


def _arrive_meta(
    env: simpy.Environment,
    comp: PeIpcqComponent,
    *,
    shape: tuple[int, ...],
    dtype: str,
) -> None:
    """Deliver sender metadata for the ``"W"`` queue and let it settle.

    The token claims to come from PE 1's ``"E"`` side and targets
    ``0x40_000``, the receive base configured for ``"W"`` in ``_make``.
    Only ``shape`` and ``dtype`` vary between tests.
    """
    token = IpcqDmaToken(
        src_addr=0,
        src_space="tcm",
        dst_addr=0x40_000,
        dst_endpoint=comp._queue_pairs["W"]["peer"],
        nbytes=64,
        handle_id="x",
        shape=shape,
        dtype=dtype,
        sender_seq=0,
        src_sip=0,
        src_cube=0,
        src_pe=1,
        src_direction="E",
    )
    comp.in_ports["host"].put(IpcqMetaArrival(token=token))
    # Advance relative to "now": _make already moved the clock, and SimPy
    # rejects an absolute ``until`` that is not strictly in the future.
    env.run(until=env.now + 5)


def _submit_recv(
    env: simpy.Environment,
    comp: PeIpcqComponent,
    *,
    shape: tuple[int, ...] = (8,),
    dtype: str = "f16",
) -> IpcqRequest:
    """Queue a ``"W"`` recv expecting ``shape``/``dtype``; return the request."""
    recv_cmd = IpcqRecvCmd(direction="W", shape=shape, dtype=dtype, handle_id="r")
    req = IpcqRequest(command=recv_cmd, done=env.event())
    comp.in_ports["host"].put(req)
    return req


# ── F2 tests ─────────────────────────────────────────────────────────


def test_strict_mode_dtype_mismatch_raises():
    env = simpy.Environment()
    comp = _make(env, strict=True)
    # Pre-arrive metadata with matching shape but f32 dtype (mismatched).
    _arrive_meta(env, comp, shape=(8,), dtype="f32")
    # recv expecting f16 → should raise on strict
    req = _submit_recv(env, comp, shape=(8,), dtype="f16")
    with pytest.raises(ValueError, match="dtype"):
        env.run(until=req.done)


def test_strict_mode_shape_mismatch_raises():
    env = simpy.Environment()
    comp = _make(env, strict=True)
    # Matching dtype, wrong shape.
    _arrive_meta(env, comp, shape=(16,), dtype="f16")
    req = _submit_recv(env, comp, shape=(8,), dtype="f16")
    with pytest.raises(ValueError, match="shape"):
        env.run(until=req.done)


def test_non_strict_mode_silently_accepts():
    env = simpy.Environment()
    comp = _make(env, strict=False)
    # Both shape and dtype wrong; non-strict mode must still complete the recv.
    _arrive_meta(env, comp, shape=(16,), dtype="f32")
    req = _submit_recv(env, comp, shape=(8,), dtype="f16")
    env.run(until=req.done)
    # ``triggered`` alone is implied by run(until=...) returning; ``ok``
    # additionally asserts the request completed successfully.
    assert req.done.triggered
    assert req.done.ok