diff --git a/CHANGES b/CHANGES index 86547ee9..35798ba1 100644 --- a/CHANGES +++ b/CHANGES @@ -15,6 +15,14 @@ PyVISA-py Changelog A second addressed issue is that timeout values never decrement to 0. A timeout value of 0 is undefined in VXI-11 standard. It can mean "timeout immediately if no data is in buffer" or "block permanently until transfer is finished". +- Implement the VISA event subsystem for VXI-11 (TCPIP::INSTR) resources: + `viEnableEvent`, `viDisableEvent`, `viDiscardEvents`, `viWaitOnEvent`, + `viInstallHandler`, and `viUninstallHandler`. SRQ (service request) events + are now supported via both queue-based (`wait_on_event`) and handler-based + (`install_handler`) delivery. A daemon thread runs an ONC RPC TCP interrupt + server to receive VXI-11 `DEVICE_INTR_SRQ` callbacks. This also fixes the + `create_intr_chan` XDR packer in `protocols/vxi11.py`. + Other transports (GPIB, USBTMC, HiSLIP, Serial) remain unsupported for now. PR #577 0.8.1 (04-09-2025) ------------------ diff --git a/docs/index.rst b/docs/index.rst index d5eceb21..4d7569f2 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -94,6 +94,9 @@ No. We have implemented those attributes and methods that are most commonly needed. We would like to reach feature parity. If there is something that you need, let us know. +Event handling (``wait_on_event``, ``install_handler``, etc.) is currently +supported for **TCPIP INSTR** (VXI-11) resources only. + Why are you developing this? ---------------------------- diff --git a/pyvisa_py/events.py b/pyvisa_py/events.py new file mode 100644 index 00000000..3d63b3cd --- /dev/null +++ b/pyvisa_py/events.py @@ -0,0 +1,312 @@ +# -*- coding: utf-8 -*- +"""Event handling primitives for pyvisa-py. + +This module provides the thread-safe building blocks used by the VISA event +subsystem: event contexts, queues, handler registries, and per-session state. + +""" + +import collections +import enum +import random +import threading +import time +import warnings +from dataclasses import dataclass, field +from typing import Any, Callable + +from pyvisa import constants +from pyvisa.typing import VISASession + + +class EventMechanismFlag(enum.Flag): + """Internal Flag enum mirroring VISA event-delivery mechanisms. + + ``ALL`` is a convenience alias for ``QUEUE | HANDLER | SUSPEND``. + The `from_int` classmethod canonicalises the VISA sentinel + ``0xFFFF`` (``constants.EventMechanism.all``) to this composite + value so that bitwise ``~`` works correctly. + """ + + NONE = 0 + QUEUE = 1 # VI_QUEUE (1) + HANDLER = 2 # VI_HNDLR (2) + SUSPEND = 4 # VI_SUSPEND_HNDLR (4) + ALL = QUEUE | HANDLER | SUSPEND # = 7, not VI_ALL_MECH (0xFFFF) + + @classmethod + def from_int(cls, value: int) -> "EventMechanismFlag": + if value == int(constants.EventMechanism.all): # 0xFFFF + return cls.ALL + return cls(value & (cls.QUEUE | cls.HANDLER | cls.SUSPEND).value) + + +@dataclass(frozen=True, slots=True) +class EventContext: + """Immutable description of a single VISA event occurrence.""" + + event_type: constants.EventType + status_byte: int = 0 + timestamp: float = field(default_factory=time.time) + context_id: int = field(default_factory=lambda: random.getrandbits(32)) + + +class EventQueue: + """Thread-safe FIFO queue for :class:`EventContext` objects.""" + + def __init__(self) -> None: + self._deque: collections.deque[EventContext] = collections.deque() + self._cond = threading.Condition() + + def put(self, ctx: EventContext) -> None: + """Add an event context to the queue (non-blocking).""" + with self._cond: + self._deque.append(ctx) + self._cond.notify_all() + + def get(self, timeout_ms: int | None) -> EventContext | None: + """Retrieve an event context. + + Parameters + ---------- + timeout_ms : + ``None`` blocks forever, ``0`` returns immediately if empty, + and a positive value blocks up to that many milliseconds. + + Returns + ------- + EventContext or None + The retrieved context, or ``None`` if the queue was empty. + + """ + if timeout_ms is None: + with self._cond: + while not self._deque: + self._cond.wait() + return self._deque.popleft() + if timeout_ms == 0: + with self._cond: + if self._deque: + return self._deque.popleft() + return None + deadline = time.time() + timeout_ms / 1000.0 + with self._cond: + while not self._deque: + remaining = deadline - time.time() + if remaining <= 0: + return None + self._cond.wait(remaining) + return self._deque.popleft() + + def get_matching( + self, + event_type: constants.EventType | None, + timeout_ms: int | None, + ) -> EventContext | None: + """Retrieve the first event matching *event_type*. + + If *event_type* is ``None``, matches any event. + ``timeout_ms`` semantics are the same as :meth:`get`. + """ + if timeout_ms is None: + with self._cond: + while True: + for idx, ctx in enumerate(self._deque): + if event_type is None or ctx.event_type == event_type: + del self._deque[idx] + return ctx + self._cond.wait() + if timeout_ms == 0: + with self._cond: + for idx, ctx in enumerate(self._deque): + if event_type is None or ctx.event_type == event_type: + del self._deque[idx] + return ctx + return None + deadline = time.time() + timeout_ms / 1000.0 + with self._cond: + while True: + for idx, ctx in enumerate(self._deque): + if event_type is None or ctx.event_type == event_type: + del self._deque[idx] + return ctx + remaining = deadline - time.time() + if remaining <= 0: + return None + self._cond.wait(remaining) + + def discard_all(self, event_type: constants.EventType | None = None) -> None: + """Remove items from the queue. + + If *event_type* is ``None``, the entire queue is cleared. + Otherwise only contexts whose ``event_type`` matches are removed. + + """ + with self._cond: + if event_type is None: + self._deque.clear() + else: + kept = [ctx for ctx in self._deque if ctx.event_type != event_type] + self._deque.clear() + self._deque.extend(kept) + + +HandlerCallback = Callable[[VISASession, constants.EventType, int, Any], None] +"""Callable invoked when a VISA event fires. + +Parameters +---------- +session : VISASession + Session handle (vi). +event_type : constants.EventType + The event type that fired. +context_id : int + Event context id. +user_handle : Any + User-supplied handle passed at install_handler time. +""" + + +class HandlerRegistry: + """Thread-safe registry of user-installed event handlers.""" + + def __init__(self) -> None: + self._lock = threading.RLock() + # event_type -> list of (handler, user_handle) + self._handlers: collections.defaultdict[ + constants.EventType, list[tuple[HandlerCallback, Any]] + ] = collections.defaultdict(list) + + def install( + self, + event_type: constants.EventType, + handler: HandlerCallback, + user_handle: Any, + ) -> None: + """Register a handler for the given event type.""" + with self._lock: + self._handlers[event_type].append((handler, user_handle)) + + def uninstall( + self, + event_type: constants.EventType, + handler: HandlerCallback, + user_handle: Any = None, + ) -> bool: + """Remove a previously installed handler. + + If *user_handle* is ``None``, the first entry matching *handler* + identity is removed regardless of its user handle. + + Returns ``True`` if a handler was removed, ``False`` otherwise. + + """ + with self._lock: + entries = self._handlers.get(event_type, []) + for idx, (h, uh) in enumerate(entries): + if h is handler and (user_handle is None or uh == user_handle): + entries.pop(idx) + return True + return False + + def fire( + self, + event_type: constants.EventType, + session: VISASession, + context_id: int, + ) -> None: + """Invoke all handlers registered for *event_type*. + + Each handler is called as ``handler(session, event_type, context_id, + user_handle)`` where *user_handle* is the value supplied at + installation. Exceptions raised by a handler are warned via + ``warnings.warn`` and do not prevent subsequent handlers from running. + + """ + with self._lock: + handlers = list(self._handlers.get(event_type, [])) + + for handler, user_handle in handlers: + try: + handler(session, event_type, context_id, user_handle) + except Exception as exc: + warnings.warn( + f"Event handler {handler!r} raised an exception: {exc!r}", + stacklevel=2, + ) + + +class EventState: + """Per-session container for event enablement, queuing, and handlers.""" + + def __init__(self) -> None: + # {event_type: EventMechanismFlag} + self._lock = threading.RLock() + self.enabled: dict[constants.EventType, EventMechanismFlag] = {} + self.queue = EventQueue() + self.registry = HandlerRegistry() + self.monitor_thread: threading.Thread | None = None + self.stop_flag: threading.Event = threading.Event() + + def enable( + self, + event_type: constants.EventType, + mechanism: constants.EventMechanism, + ) -> None: + """Enable delivery of *event_type* via *mechanism_flag*.""" + m = EventMechanismFlag.from_int(int(mechanism)) + with self._lock: + self.enabled[event_type] = ( + self.enabled.get(event_type, EventMechanismFlag.NONE) | m + ) + + def disable( + self, + event_type: constants.EventType, + mechanism: constants.EventMechanism, + ) -> None: + """Disable delivery of *event_type* via *mechanism_flag*.""" + m = EventMechanismFlag.from_int(int(mechanism)) + with self._lock: + if event_type not in self.enabled: + return + new = self.enabled[event_type] & ~m + if new is EventMechanismFlag.NONE: + del self.enabled[event_type] + else: + self.enabled[event_type] = new + + def is_queue_enabled(self, event_type: constants.EventType) -> bool: + """Return whether queue delivery is enabled for *event_type*.""" + with self._lock: + return bool( + self.enabled.get(event_type, EventMechanismFlag.NONE) + & EventMechanismFlag.QUEUE + ) + + def is_handler_enabled(self, event_type: constants.EventType) -> bool: + """Return whether handler (callback) delivery is enabled for *event_type*.""" + with self._lock: + return bool( + self.enabled.get(event_type, EventMechanismFlag.NONE) + & EventMechanismFlag.HANDLER + ) + + def get_delivery_mechanisms( + self, event_type: constants.EventType + ) -> tuple[bool, bool]: + """Return (queue_enabled, handler_enabled) for *event_type*. + + The check is performed atomically under the state lock. + """ + with self._lock: + mech = self.enabled.get(event_type, EventMechanismFlag.NONE) + return ( + bool(mech & EventMechanismFlag.QUEUE), + bool(mech & EventMechanismFlag.HANDLER), + ) + + def any_enabled(self) -> bool: + """Return ``True`` if any event type has any mechanism enabled.""" + with self._lock: + return any(m is not EventMechanismFlag.NONE for m in self.enabled.values()) diff --git a/pyvisa_py/highlevel.py b/pyvisa_py/highlevel.py index 208b538e..9e54335a 100644 --- a/pyvisa_py/highlevel.py +++ b/pyvisa_py/highlevel.py @@ -27,6 +27,7 @@ from pyvisa.util import DebugInfo, LibraryPath from .common import LOGGER +from .events import EventMechanismFlag from .sessions import OpenError, Session @@ -197,7 +198,9 @@ def open( except OpenError as e: return VISASession(0), self.handle_return_value(None, e.error_code) - return self._register(sess), StatusCode.success + visa_session = self._register(sess) + sess._session_handle = visa_session + return visa_session, StatusCode.success def clear(self, session: VISASession) -> StatusCode: """Clears a device. @@ -790,6 +793,50 @@ def unlock(self, session: VISASession) -> StatusCode: return self.handle_return_value(session, StatusCode.error_invalid_object) return self.handle_return_value(session, sess.unlock()) + def enable_event( + self, + session: VISASession, + event_type: constants.EventType, + mechanism: constants.EventMechanism, + context: None = None, + ) -> StatusCode: + """Enable notification for an event type via the specified mechanism. + + Corresponds to viEnableEvent function of the VISA library. + + Parameters + ---------- + session : VISASession + Unique logical identifier to a session. + event_type : constants.EventType + Event type. + mechanism : constants.EventMechanism + Event handling mechanisms to be enabled. + context : None, optional + Not used in pyvisa-py. + + Returns + ------- + StatusCode + Return value of the library call. + + """ + try: + sess = self.sessions[session] + except KeyError: + return self.handle_return_value(session, StatusCode.error_invalid_object) + + if event_type not in sess._supported_event_types: + return self.handle_return_value(session, StatusCode.error_invalid_event) + + sess._event_state.enable(event_type, mechanism) + status = sess._start_event_monitor() + if status != StatusCode.success: + sess._event_state.disable(event_type, mechanism) + return self.handle_return_value(session, status) + + return self.handle_return_value(session, StatusCode.success) + def disable_event( self, session: VISASession, @@ -815,7 +862,21 @@ def disable_event( Return value of the library call. """ - return StatusCode.error_nonimplemented_operation + try: + sess = self.sessions[session] + except KeyError: + return self.handle_return_value(session, StatusCode.error_invalid_object) + + if event_type == constants.EventType.all_enabled: + for et in list(sess._event_state.enabled.keys()): + sess._event_state.disable(et, mechanism) + else: + sess._event_state.disable(event_type, mechanism) + + if not sess._event_state.any_enabled(): + sess._stop_event_monitor() + + return self.handle_return_value(session, StatusCode.success) def discard_events( self, @@ -831,7 +892,7 @@ def discard_events( ---------- session : VISASession Unique logical identifier to a session. - event_type : constans.EventType + event_type : constants.EventType Logical event identifier. mechanism : constants.EventMechanism Specifies event handling mechanisms to be discarded. @@ -842,4 +903,146 @@ def discard_events( Return value of the library call. """ - return StatusCode.error_nonimplemented_operation + try: + sess = self.sessions[session] + except KeyError: + return self.handle_return_value(session, StatusCode.error_invalid_object) + + mech = EventMechanismFlag.from_int(int(mechanism)) + if mech & EventMechanismFlag.QUEUE: + et = None if event_type == constants.EventType.all_enabled else event_type + sess._event_state.queue.discard_all(et) + + return self.handle_return_value(session, StatusCode.success) + + def wait_on_event( + self, session: VISASession, in_event_type: constants.EventType, timeout: int + ) -> Tuple[constants.EventType, VISAEventContext, StatusCode]: + """Wait for an event occurrence for a given type. + + Corresponds to viWaitOnEvent function of the VISA library. + + Parameters + ---------- + session : VISASession + Unique logical identifier to a session. + in_event_type : constants.EventType + Event type to wait for. + timeout : int + Timeout in milliseconds. + + Returns + ------- + constants.EventType + Type of the event that occurred. + VISAEventContext + Context identifier for the event. + StatusCode + Return value of the library call. + + """ + try: + sess = self.sessions[session] + except KeyError: + return ( + in_event_type, + VISAEventContext(0), + self.handle_return_value(session, StatusCode.error_invalid_object), + ) + + et = None if in_event_type == constants.EventType.all_enabled else in_event_type + ctx = sess._event_state.queue.get_matching(et, timeout) + if ctx is None: + return ( + in_event_type, + VISAEventContext(0), + self.handle_return_value(session, StatusCode.error_timeout), + ) + return ctx.event_type, VISAEventContext(ctx.context_id), StatusCode.success + + def install_handler( + self, + session: VISASession, + event_type: constants.EventType, + handler: Any, + user_handle: Any, + ) -> Tuple[Any, Any, Any, StatusCode]: + """Install a handler for an event type. + + Corresponds to viInstallHandler function of the VISA library. + + Parameters + ---------- + session : VISASession + Unique logical identifier to a session. + event_type : constants.EventType + Event type. + handler : Any + Handler function to install. + user_handle : Any + User handle passed to the handler. + + Returns + ------- + Any + The handler that was installed. + Any + The user handle. + Any + The handler that was installed. + StatusCode + Return value of the library call. + + """ + try: + sess = self.sessions[session] + except KeyError: + return ( + handler, + user_handle, + handler, + self.handle_return_value(session, StatusCode.error_invalid_object), + ) + + sess._event_state.registry.install(event_type, handler, user_handle) + return (handler, user_handle, handler, StatusCode.success) + + def uninstall_handler( + self, + session: VISASession, + event_type: constants.EventType, + handler: Any, + user_handle: Any = None, + ) -> StatusCode: + """Uninstall a handler for an event type. + + Corresponds to viUninstallHandler function of the VISA library. + + Parameters + ---------- + session : VISASession + Unique logical identifier to a session. + event_type : constants.EventType + Event type. + handler : Any + Handler function to uninstall. + user_handle : Any, optional + User handle associated with the handler. + + Returns + ------- + StatusCode + Return value of the library call. + + """ + try: + sess = self.sessions[session] + except KeyError: + return self.handle_return_value(session, StatusCode.error_invalid_object) + + found = sess._event_state.registry.uninstall(event_type, handler, user_handle) + if not found: + return self.handle_return_value( + session, StatusCode.error_handler_not_installed + ) + return self.handle_return_value(session, StatusCode.success) diff --git a/pyvisa_py/protocols/rpc.py b/pyvisa_py/protocols/rpc.py index e8bd07ad..a642775f 100644 --- a/pyvisa_py/protocols/rpc.py +++ b/pyvisa_py/protocols/rpc.py @@ -968,7 +968,7 @@ def handle(self, call): def turn_around(self): try: self.unpacker.done() - except RuntimeError: + except (RuntimeError, xdrlib.Error): raise RPCGarbageArgs self.packer.pack_uint(AcceptStatus.success) diff --git a/pyvisa_py/protocols/vxi11.py b/pyvisa_py/protocols/vxi11.py index b3d22f98..6e00a656 100644 --- a/pyvisa_py/protocols/vxi11.py +++ b/pyvisa_py/protocols/vxi11.py @@ -11,8 +11,16 @@ """ import enum +import queue import socket +import struct +import threading +from pyvisa import constants +from pyvisa.constants import StatusCode + +from ..common import LOGGER +from ..events import EventContext from . import rpc # fmt: off @@ -42,11 +50,18 @@ CREATE_INTR_CHAN = 25 DESTROY_INTR_CHAN = 26 +# Status byte bit masks +STB_RQS_BIT = 0x40 # Request Service bit in serial poll status byte + # Device intr DEVICE_INTR_PROG = 0x0607B1 DEVICE_INTR_VERS = 1 DEVICE_INTR_SRQ = 30 +# Device address family for create_intr_chan (NOT IPPROTO_TCP/IPPROTO_UDP) +DEVICE_TCP = 0 +DEVICE_UDP = 1 + # Error states class ErrorCodes(enum.IntEnum): @@ -354,7 +369,7 @@ def create_intr_chan(self, host_addr, host_port, prog_num, prog_vers, prog_famil return self.make_call( CREATE_INTR_CHAN, params, - self.packer.pack_device_docmd_parms, + self.packer.pack_device_remote_func_parms, self.unpacker.unpack_device_error, ) @@ -362,3 +377,133 @@ def destroy_intr_chan(self): return self.make_call( DESTROY_INTR_CHAN, None, None, self.unpacker.unpack_device_error ) + + +class SrqInterruptTCPServer(rpc.TCPServer): + """TCP RPC server that receives VXI-11 DEVICE_INTR_SRQ (proc 30) interrupts.""" + + def __init__(self, host, prog, vers, port, session): + super().__init__(host, prog, vers, port) + self.session = session + self._srq_queue = queue.Queue() + self._srq_worker_thread = threading.Thread(target=self._srq_worker, daemon=True) + self._srq_worker_thread.start() + + def _srq_worker(self): + while True: + try: + item = self._srq_queue.get(timeout=1.0) + except queue.Empty: + continue + if item is None: + break + self._fire_srq() + + def stop(self): + self._srq_queue.put(None) + self._srq_worker_thread.join(timeout=2.0) + + def connect(self): + super().connect() + self.sock.listen(1) + + def loop(self): + """Accept connections from the instrument and handle SRQ until stopped.""" + self.sock.settimeout(1.0) + stop_flag = self.session._event_state.stop_flag + while not stop_flag.is_set(): + try: + conn, _addr = self.sock.accept() + except socket.timeout: + continue + except OSError: + break + try: + self._handle_connection(conn) + finally: + try: + conn.close() + except Exception: + pass + + def _handle_connection(self, conn): + """Read RPC calls from the instrument, send replies, and fire events. + + The VXI-11 interrupt channel is a persistent TCP connection. The + instrument may send multiple DEVICE_INTR_SRQ calls over the same + connection, so we keep reading until the connection is closed or + the session stop flag is set. + """ + stop_flag = self.session._event_state.stop_flag + try: + conn.settimeout(1.0) + while not stop_flag.is_set(): + try: + # Read record marker (4 bytes) + marker = self._recv_all(conn, 4) + if marker is None: + return # Connection closed by peer + except socket.timeout: + continue + + frag = struct.unpack(">I", marker)[0] + last_frag = frag >> 31 + frag_len = frag & 0x7FFFFFFF + + # Read the fragment payload + call = self._recv_all(conn, frag_len) + if call is None: + return + + # If there are more fragments, consume and append them + while not last_frag: + marker = self._recv_all(conn, 4) + if marker is None: + return + frag = struct.unpack(">I", marker)[0] + last_frag = frag >> 31 + frag_len = frag & 0x7FFFFFFF + leftover = self._recv_all(conn, frag_len) + if leftover is None: + return + call += leftover + + reply = self.handle(call) + if reply is not None: + reply_frag = struct.pack(">I", 0x80000000 | len(reply)) + reply + conn.sendall(reply_frag) + except Exception: + LOGGER.exception("Error handling TCP SRQ connection") + + def _recv_all(self, sock, n): + data = b"" + while len(data) < n: + chunk = sock.recv(n - len(data)) + if not chunk: + return None + data += chunk + return data + + def handle_30(self): + """Handle DEVICE_INTR_SRQ (procedure 30).""" + handle = self.unpacker.unpack_opaque() + self.turn_around() + if handle != b"srq": + LOGGER.warning("Ignoring VXI-11 SRQ with unexpected handle: %r", handle) + return + self._srq_queue.put(True) + + def _fire_srq(self): + try: + # Defensive: session may have been closed while we were spawned + if self.session.interface is None or self.session.link == 0: + return + stb, status = self.session.read_stb() + if status == StatusCode.success and (stb & STB_RQS_BIT): + ctx = EventContext( + event_type=constants.EventType.service_request, + status_byte=stb, + ) + self.session._fire_event(constants.EventType.service_request, ctx) + except Exception: + LOGGER.exception("Error handling VXI-11 SRQ interrupt") diff --git a/pyvisa_py/sessions.py b/pyvisa_py/sessions.py index e3224556..54e10530 100644 --- a/pyvisa_py/sessions.py +++ b/pyvisa_py/sessions.py @@ -24,9 +24,10 @@ from pyvisa import attributes, constants, rname from pyvisa.constants import ResourceAttribute, StatusCode -from pyvisa.typing import VISAJobID, VISARMSession +from pyvisa.typing import VISAJobID, VISARMSession, VISASession from .common import LOGGER, BytesBuffer, int_to_byte +from .events import EventContext, EventState #: Type var used when typing register. T = TypeVar("T", bound=Type["Session"]) @@ -141,9 +142,15 @@ def close(self) -> StatusCode: #: Session type as (Interface Type, Resource Class) session_type: Tuple[constants.InterfaceType, str] + #: Event types supported by this session class. + _supported_event_types: ClassVar[set[constants.EventType]] = set() + #: Timeout in milliseconds to use when opening the resource. open_timeout: Optional[int] + #: VISA session handle assigned by the library after registration. + _session_handle: VISASession + #: Value of the timeout in seconds used for general operation timeout: Optional[float] @@ -328,6 +335,9 @@ def __init__( self.after_parsing() + self._event_state = EventState() + self._session_handle = VISASession(0) + def after_parsing(self) -> None: """Override this method to provide custom initialization code, to be called after the resource name is properly parsed @@ -366,6 +376,43 @@ def after_parsing(self) -> None: """ pass + def _fire_event(self, event_type: constants.EventType, ctx: EventContext) -> None: + """Dispatch an event occurrence to the queue and/or handlers. + + This method is called by transport-specific monitor threads when an + SRQ (or other asynchronous signal) is detected. + """ + queue_enabled, handler_enabled = self._event_state.get_delivery_mechanisms( + event_type + ) + if queue_enabled: + self._event_state.queue.put(ctx) + if handler_enabled: + session_handle = self._session_handle + self._event_state.registry.fire(event_type, session_handle, ctx.context_id) + + def _start_event_monitor(self) -> StatusCode: + """Start a background thread to watch for event assertions. + + Transports that support asynchronous events (VXI-11, GPIB, USBTMC) + should override this method. The base implementation is a no-op. + + Returns + ------- + StatusCode + Return value of the library call. + """ + return StatusCode.success + + def _stop_event_monitor(self) -> None: + """Stop the event monitor thread. + + Transports should override this to signal their monitor thread + (via ``self._event_state.stop_flag.set()``) and join it. + The base implementation is a no-op. + """ + pass + def write(self, data: bytes) -> Tuple[int, StatusCode]: """Writes data to device or interface synchronously. diff --git a/pyvisa_py/tcpip.py b/pyvisa_py/tcpip.py index f3a0f3f7..c58fe9bb 100644 --- a/pyvisa_py/tcpip.py +++ b/pyvisa_py/tcpip.py @@ -7,10 +7,13 @@ """ +from __future__ import annotations + import ipaddress import random import select import socket +import threading import time import warnings from typing import Any, Dict, List, Optional, Tuple, Type, cast @@ -431,6 +434,7 @@ class Vxi11CoreClient(vxi11.CoreClient): def __init__( self, host: str, port: Optional[int], open_timeout: Optional[int] = 5000 ) -> None: + self._lock = threading.Lock() self.packer = vxi11.Vxi11Packer() self.unpacker = vxi11.Vxi11Unpacker(b"") prog, vers = vxi11.DEVICE_CORE_PROG, vxi11.DEVICE_CORE_VERS @@ -441,6 +445,10 @@ def __init__( # bypass the portmapper lookup and use the specified port instead rpc.RawTCPClient.__init__(self, host, prog, vers, port, open_timeout) + def make_call(self, proc, args, pack_func, unpack_func): + with self._lock: + return super().make_call(proc, args, pack_func, unpack_func) + class TCPIPInstrVxi11(Session): """A TCPIP Session built on socket standard library using VXI-11 protocol.""" @@ -450,6 +458,8 @@ class TCPIPInstrVxi11(Session): # need to define session_type to make the set_attribute machinery work. session_type = (constants.InterfaceType.tcpip, "INSTR") + _supported_event_types = {constants.EventType.service_request} + #: Maximum size of a chunk of data in bytes. max_recv_size: int @@ -541,6 +551,8 @@ def after_parsing(self) -> None: self.client_id = random.getrandbits(31) self.keepalive = False + self._srq_server: vxi11.SrqInterruptTCPServer | None = None + self._srq_lifecycle_lock = threading.Lock() error, link, _abort_port, max_recv_size = self.interface.create_link( self.client_id, 0, self.lock_timeout, self.parsed.lan_device_name @@ -561,6 +573,7 @@ def after_parsing(self) -> None: self.attrs[attribute] = attributes.AttributesByID[attribute].default def close(self) -> StatusCode: + self._stop_event_monitor() try: self.interface.destroy_link(self.link) except (errors.VisaIOError, socket.error, rpc.RPCError) as e: @@ -572,6 +585,106 @@ def close(self) -> StatusCode: return StatusCode.success + def _start_event_monitor(self) -> StatusCode: + """Start the VXI-11 interrupt server and enable events on the device.""" + with self._srq_lifecycle_lock: + with self._event_state._lock: + if ( + self._event_state.monitor_thread is not None + and self._event_state.monitor_thread.is_alive() + ): + return StatusCode.success + if not self._event_state.any_enabled(): + return StatusCode.success + + self._event_state.stop_flag.clear() + + server = vxi11.SrqInterruptTCPServer( + "", + vxi11.DEVICE_INTR_PROG, + vxi11.DEVICE_INTR_VERS, + 0, + self, + ) + port = server.sock.getsockname()[1] + + local_ip_str = self.interface.sock.getsockname()[0] + host_addr = int(ipaddress.IPv4Address(local_ip_str)) + + error = self.interface.create_intr_chan( + host_addr, + port, + vxi11.DEVICE_INTR_PROG, + vxi11.DEVICE_INTR_VERS, + vxi11.DEVICE_TCP, + ) + if error: + LOGGER.error("create_intr_chan failed with error %d", error) + try: + server.sock.close() + except Exception: + pass + return StatusCode.error_nonsupported_operation + + error = self.interface.device_enable_srq(self.link, True, b"srq") + if error: + LOGGER.error("device_enable_srq failed with error %d", error) + try: + self.interface.destroy_intr_chan() + except Exception: + pass + try: + server.sock.close() + except Exception: + pass + return StatusCode.error_io + + with self._event_state._lock: + if ( + self._event_state.monitor_thread is not None + and self._event_state.monitor_thread.is_alive() + ): + try: + server.sock.close() + except Exception: + pass + return StatusCode.success + thread = threading.Thread(target=server.loop, daemon=True) + self._event_state.monitor_thread = thread + self._srq_server = server + thread.start() + return StatusCode.success + + def _stop_event_monitor(self) -> None: + """Disable events and stop the interrupt server thread.""" + with self._srq_lifecycle_lock: + self._event_state.stop_flag.set() + try: + self.interface.device_enable_srq(self.link, False, b"") + except Exception: + LOGGER.exception("Error disabling VXI-11 SRQ") + try: + self.interface.destroy_intr_chan() + except Exception: + LOGGER.exception("Error destroying VXI-11 interrupt channel") + with self._event_state._lock: + thread = self._event_state.monitor_thread + self._event_state.monitor_thread = None + server = self._srq_server + self._srq_server = None + + if thread is not None: + thread.join(timeout=1.0) + if server is not None: + try: + server.stop() + except Exception: + pass + try: + server.sock.close() + except Exception: + pass + def read(self, count: int) -> Tuple[bytes, StatusCode]: """Reads data from device or interface synchronously. diff --git a/pyvisa_py/testsuite/keysight_assisted_tests/test_tcpip_resources.py b/pyvisa_py/testsuite/keysight_assisted_tests/test_tcpip_resources.py index 51615e03..a359f508 100644 --- a/pyvisa_py/testsuite/keysight_assisted_tests/test_tcpip_resources.py +++ b/pyvisa_py/testsuite/keysight_assisted_tests/test_tcpip_resources.py @@ -30,20 +30,16 @@ class TestTCPIPInstr(TCPIPInstrBaseTest): # XXX Skip test clear to see if it has some bad side effect test_clear = pytest.mark.skip(copy_func(TCPIPInstrBaseTest.test_clear)) - test_wrapping_handler = pytest.mark.xfail( - copy_func(TCPIPInstrBaseTest.test_wrapping_handler) - ) + test_wrapping_handler = copy_func(TCPIPInstrBaseTest.test_wrapping_handler) - test_managing_visa_handler = pytest.mark.xfail( - copy_func(TCPIPInstrBaseTest.test_managing_visa_handler) + test_managing_visa_handler = copy_func( + TCPIPInstrBaseTest.test_managing_visa_handler ) - test_wait_on_event = pytest.mark.xfail( - copy_func(TCPIPInstrBaseTest.test_wait_on_event) - ) + test_wait_on_event = copy_func(TCPIPInstrBaseTest.test_wait_on_event) - test_wait_on_event_timeout = pytest.mark.xfail( - copy_func(TCPIPInstrBaseTest.test_wait_on_event_timeout) + test_wait_on_event_timeout = copy_func( + TCPIPInstrBaseTest.test_wait_on_event_timeout ) test_getting_unknown_buffer = pytest.mark.xfail( diff --git a/pyvisa_py/testsuite/test_events.py b/pyvisa_py/testsuite/test_events.py new file mode 100644 index 00000000..895940b8 --- /dev/null +++ b/pyvisa_py/testsuite/test_events.py @@ -0,0 +1,725 @@ +"""Unit tests for the pyvisa-py event handling subsystem. + +These tests cover the core event primitives in ``events.py``, the high-level +library methods in ``highlevel.py``, and transport-specific SRQ logic for +VXI-11 (mocked). + +""" + +from __future__ import annotations + +import ipaddress +import threading +import time +from unittest.mock import MagicMock, patch + +import pytest + +from pyvisa import constants, errors +from pyvisa.constants import StatusCode +from pyvisa.typing import VISASession +from pyvisa_py.events import ( + EventContext, + EventMechanismFlag, + EventQueue, + EventState, + HandlerRegistry, +) +from pyvisa_py.highlevel import PyVisaLibrary +from pyvisa_py.protocols import vxi11 + +# --------------------------------------------------------------------------- +# EventContext +# --------------------------------------------------------------------------- + + +class TestEventContext: + def test_defaults(self): + ctx = EventContext(event_type=constants.EventType.service_request) + assert ctx.event_type == constants.EventType.service_request + assert ctx.status_byte == 0 + assert ctx.timestamp <= time.time() + assert isinstance(ctx.context_id, int) + assert 0 <= ctx.context_id < 2**32 + + def test_context_id_randomness(self): + ctx1 = EventContext(event_type=constants.EventType.service_request) + ctx2 = EventContext(event_type=constants.EventType.service_request) + # Extremely unlikely to collide on 32-bit random space + assert ctx1.context_id != ctx2.context_id + + def test_explicit_values(self): + ctx = EventContext( + event_type=constants.EventType.io_completion, + status_byte=0x42, + timestamp=1234.5, + context_id=99, + ) + assert ctx.event_type == constants.EventType.io_completion + assert ctx.status_byte == 0x42 + assert ctx.timestamp == 1234.5 + assert ctx.context_id == 99 + + +# --------------------------------------------------------------------------- +# EventQueue +# --------------------------------------------------------------------------- + + +class TestEventQueue: + def test_put_get_roundtrip(self): + q = EventQueue() + ctx = EventContext(event_type=constants.EventType.service_request) + q.put(ctx) + assert q.get(timeout_ms=None) is ctx + + def test_get_zero_timeout_empty(self): + q = EventQueue() + assert q.get(timeout_ms=0) is None + + def test_get_positive_timeout_returns_item(self): + q = EventQueue() + ctx = EventContext(event_type=constants.EventType.service_request) + q.put(ctx) + assert q.get(timeout_ms=100) is ctx + + def test_get_positive_timeout_blocks_then_none(self): + q = EventQueue() + start = time.time() + result = q.get(timeout_ms=50) + elapsed = time.time() - start + assert result is None + assert elapsed >= 0.04 # generous tolerance + + def test_get_none_blocks_forever(self): + q = EventQueue() + ctx = EventContext(event_type=constants.EventType.service_request) + + def delayed_put(): + time.sleep(0.05) + q.put(ctx) + + t = threading.Thread(target=delayed_put) + t.start() + assert q.get(timeout_ms=None) is ctx + t.join() + + def test_discard_all_matching_event_type(self): + q = EventQueue() + ctx_srq = EventContext(event_type=constants.EventType.service_request) + ctx_io = EventContext(event_type=constants.EventType.io_completion) + q.put(ctx_srq) + q.put(ctx_io) + q.discard_all(constants.EventType.service_request) + assert q.get(timeout_ms=0) is ctx_io + assert q.get(timeout_ms=0) is None + + def test_discard_all_none_clears_everything(self): + q = EventQueue() + q.put(EventContext(event_type=constants.EventType.service_request)) + q.put(EventContext(event_type=constants.EventType.io_completion)) + q.discard_all(None) + assert q.get(timeout_ms=0) is None + + def test_get_matching_returns_matching_event(self): + q = EventQueue() + ctx_srq = EventContext(event_type=constants.EventType.service_request) + ctx_io = EventContext(event_type=constants.EventType.io_completion) + q.put(ctx_io) + q.put(ctx_srq) + assert ( + q.get_matching(constants.EventType.service_request, timeout_ms=0) is ctx_srq + ) + assert q.get_matching(constants.EventType.io_completion, timeout_ms=0) is ctx_io + + def test_get_matching_non_matching_returns_none(self): + q = EventQueue() + q.put(EventContext(event_type=constants.EventType.io_completion)) + assert q.get_matching(constants.EventType.service_request, timeout_ms=0) is None + + def test_get_matching_blocks_until_match(self): + q = EventQueue() + ctx = EventContext(event_type=constants.EventType.service_request) + + def delayed_put(): + time.sleep(0.05) + q.put(ctx) + + t = threading.Thread(target=delayed_put) + t.start() + assert ( + q.get_matching(constants.EventType.service_request, timeout_ms=None) is ctx + ) + t.join() + + def test_get_matching_positive_timeout(self): + q = EventQueue() + start = time.time() + result = q.get_matching(constants.EventType.service_request, timeout_ms=50) + elapsed = time.time() - start + assert result is None + assert elapsed >= 0.04 + + def test_get_matching_positive_timeout_event_arrives(self): + q = EventQueue() + ctx = EventContext(event_type=constants.EventType.service_request) + + def delayed_put(): + time.sleep(0.02) + q.put(ctx) + + t = threading.Thread(target=delayed_put) + t.start() + assert ( + q.get_matching(constants.EventType.service_request, timeout_ms=200) is ctx + ) + t.join() + + def test_get_matching_none_event_type_returns_any(self): + q = EventQueue() + ctx = EventContext(event_type=constants.EventType.io_completion) + q.put(ctx) + assert q.get_matching(None, timeout_ms=0) is ctx + + +# --------------------------------------------------------------------------- +# HandlerRegistry +# --------------------------------------------------------------------------- + + +class TestHandlerRegistry: + def test_install_and_fire(self): + reg = HandlerRegistry() + calls = [] + + def handler(sess, etype, cid, uhandle): + calls.append((sess, etype, cid, uhandle)) + + reg.install(constants.EventType.service_request, handler, "h1") + reg.fire(constants.EventType.service_request, VISASession(42), 42) + assert calls == [ + (VISASession(42), constants.EventType.service_request, 42, "h1") + ] + + def test_multiple_handlers_fire(self): + reg = HandlerRegistry() + calls = [] + + def h1(sess, etype, cid, uhandle): + calls.append("h1") + + def h2(sess, etype, cid, uhandle): + calls.append("h2") + + reg.install(constants.EventType.service_request, h1, None) + reg.install(constants.EventType.service_request, h2, None) + reg.fire(constants.EventType.service_request, VISASession(1), 1) + assert set(calls) == {"h1", "h2"} + + def test_uninstall_by_identity_and_handle(self): + reg = HandlerRegistry() + + def h1(*_): + pass + + def h2(*_): + pass + + reg.install(constants.EventType.service_request, h1, "a") + reg.install(constants.EventType.service_request, h2, "b") + assert reg.uninstall(constants.EventType.service_request, h1, "a") is True + assert reg.uninstall(constants.EventType.service_request, h2, "wrong") is False + assert reg.uninstall(constants.EventType.service_request, h2, "b") is True + assert reg.uninstall(constants.EventType.service_request, h1, "a") is False + + def test_uninstall_with_none_user_handle(self): + reg = HandlerRegistry() + + def h1(*_): + pass + + reg.install(constants.EventType.service_request, h1, "any") + assert reg.uninstall(constants.EventType.service_request, h1, None) is True + assert reg.uninstall(constants.EventType.service_request, h1, None) is False + + def test_fire_catches_exceptions(self): + reg = HandlerRegistry() + calls = [] + + def bad(*_): + raise RuntimeError("boom") + + def good(*_): + calls.append("good") + + reg.install(constants.EventType.service_request, bad, None) + reg.install(constants.EventType.service_request, good, None) + with pytest.warns(UserWarning, match="boom"): + reg.fire(constants.EventType.service_request, VISASession(1), 1) + assert calls == ["good"] + + def test_fire_no_handlers_noop(self): + reg = HandlerRegistry() + # Should not raise + reg.fire(constants.EventType.service_request, VISASession(1), 1) + + +# --------------------------------------------------------------------------- +# EventState +# --------------------------------------------------------------------------- + + +class TestEventState: + def test_enable_disable(self): + st = EventState() + st.enable(constants.EventType.service_request, constants.EventMechanism.queue) + assert ( + st.enabled[constants.EventType.service_request] is EventMechanismFlag.QUEUE + ) + assert st.is_queue_enabled(constants.EventType.service_request) is True + assert st.is_handler_enabled(constants.EventType.service_request) is False + st.enable(constants.EventType.service_request, constants.EventMechanism.handler) + assert st.enabled[constants.EventType.service_request] is ( + EventMechanismFlag.QUEUE | EventMechanismFlag.HANDLER + ) + assert st.is_handler_enabled(constants.EventType.service_request) is True + st.disable(constants.EventType.service_request, constants.EventMechanism.queue) + assert ( + st.enabled[constants.EventType.service_request] + is EventMechanismFlag.HANDLER + ) + assert st.is_queue_enabled(constants.EventType.service_request) is False + assert st.is_handler_enabled(constants.EventType.service_request) is True + st.disable( + constants.EventType.service_request, constants.EventMechanism.handler + ) + assert constants.EventType.service_request not in st.enabled + assert st.any_enabled() is False + + def test_any_enabled(self): + st = EventState() + assert st.any_enabled() is False + st.enable(constants.EventType.io_completion, constants.EventMechanism.queue) + assert st.enabled[constants.EventType.io_completion] is EventMechanismFlag.QUEUE + assert st.any_enabled() is True + + def test_disable_removes_empty_event_type(self): + st = EventState() + st.enable(constants.EventType.service_request, constants.EventMechanism.queue) + assert ( + st.enabled[constants.EventType.service_request] is EventMechanismFlag.QUEUE + ) + st.disable(constants.EventType.service_request, constants.EventMechanism.queue) + # Internal dict should be clean + assert constants.EventType.service_request not in st.enabled + + def test_enable_combined_bitmask(self): + st = EventState() + combined = constants.EventMechanism.queue | constants.EventMechanism.handler + st.enable(constants.EventType.service_request, combined) + assert st.enabled[constants.EventType.service_request] is ( + EventMechanismFlag.QUEUE | EventMechanismFlag.HANDLER + ) + assert st.is_queue_enabled(constants.EventType.service_request) is True + assert st.is_handler_enabled(constants.EventType.service_request) is True + + def test_disable_combined_bitmask(self): + st = EventState() + combined = constants.EventMechanism.queue | constants.EventMechanism.handler + st.enable(constants.EventType.service_request, combined) + st.disable(constants.EventType.service_request, combined) + assert constants.EventType.service_request not in st.enabled + assert st.is_queue_enabled(constants.EventType.service_request) is False + assert st.is_handler_enabled(constants.EventType.service_request) is False + assert st.any_enabled() is False + + def test_disable_all_clears_everything(self): + st = EventState() + st.enable(constants.EventType.service_request, constants.EventMechanism.queue) + st.enable(constants.EventType.service_request, constants.EventMechanism.handler) + assert st.enabled[constants.EventType.service_request] is ( + EventMechanismFlag.QUEUE | EventMechanismFlag.HANDLER + ) + st.disable(constants.EventType.service_request, constants.EventMechanism.all) + assert st.is_queue_enabled(constants.EventType.service_request) is False + assert st.is_handler_enabled(constants.EventType.service_request) is False + assert constants.EventType.service_request not in st.enabled + + +# --------------------------------------------------------------------------- +# highlevel.py (mocked session) +# --------------------------------------------------------------------------- + + +@pytest.fixture +def lib_and_session(): + lib = PyVisaLibrary() + sess = MagicMock() + sess._event_state = EventState() + sess._supported_event_types = {constants.EventType.service_request} + sess._start_event_monitor.return_value = StatusCode.success + session_id = lib._register(sess) + sess._session_handle = session_id + return lib, sess, session_id + + +class TestHighlevelEventMethods: + def test_enable_event_delegates_and_starts_monitor(self, lib_and_session): + lib, sess, sid = lib_and_session + result = lib.enable_event( + sid, + constants.EventType.service_request, + constants.EventMechanism.queue, + ) + assert result == StatusCode.success + assert sess._event_state.is_queue_enabled(constants.EventType.service_request) + sess._start_event_monitor.assert_called_once() + + def test_disable_event_delegates_and_stops_monitor(self, lib_and_session): + lib, sess, sid = lib_and_session + # First enable + lib.enable_event( + sid, + constants.EventType.service_request, + constants.EventMechanism.queue, + ) + # Then disable + result = lib.disable_event( + sid, + constants.EventType.service_request, + constants.EventMechanism.queue, + ) + assert result == StatusCode.success + assert not sess._event_state.is_queue_enabled( + constants.EventType.service_request + ) + sess._stop_event_monitor.assert_called_once() + + def test_disable_event_does_not_stop_when_other_enabled(self, lib_and_session): + lib, sess, sid = lib_and_session + lib.enable_event( + sid, + constants.EventType.service_request, + constants.EventMechanism.queue, + ) + lib.enable_event( + sid, + constants.EventType.service_request, + constants.EventMechanism.handler, + ) + sess._start_event_monitor.reset_mock() + sess._stop_event_monitor.reset_mock() + lib.disable_event( + sid, + constants.EventType.service_request, + constants.EventMechanism.queue, + ) + # Handler still enabled -> monitor should NOT be stopped + sess._stop_event_monitor.assert_not_called() + + def test_discard_events_queue(self, lib_and_session): + lib, sess, sid = lib_and_session + sess._event_state.queue.put( + EventContext(event_type=constants.EventType.service_request) + ) + result = lib.discard_events( + sid, + constants.EventType.service_request, + constants.EventMechanism.queue, + ) + assert result == StatusCode.success + assert sess._event_state.queue.get(timeout_ms=0) is None + + def test_discard_events_all_mechanism(self, lib_and_session): + lib, sess, sid = lib_and_session + sess._event_state.queue.put( + EventContext(event_type=constants.EventType.service_request) + ) + result = lib.discard_events( + sid, + constants.EventType.service_request, + constants.EventMechanism.all, + ) + assert result == StatusCode.success + assert sess._event_state.queue.get(timeout_ms=0) is None + + def test_install_handler(self, lib_and_session): + lib, sess, sid = lib_and_session + + def my_handler(*_): + pass + + result = lib.install_handler( + sid, + constants.EventType.service_request, + my_handler, + "uh", + ) + assert result == (my_handler, "uh", my_handler, StatusCode.success) + handlers = sess._event_state.registry._handlers[ + constants.EventType.service_request + ] + assert handlers == [(my_handler, "uh")] + + def test_uninstall_handler_success(self, lib_and_session): + lib, sess, sid = lib_and_session + + def my_handler(*_): + pass + + sess._event_state.registry.install( + constants.EventType.service_request, my_handler, "uh" + ) + result = lib.uninstall_handler( + sid, + constants.EventType.service_request, + my_handler, + "uh", + ) + assert result == StatusCode.success + + def test_uninstall_handler_not_installed_raises(self, lib_and_session): + lib, _sess, sid = lib_and_session + + def my_handler(*_): + pass + + with pytest.raises(errors.VisaIOError) as exc_info: + lib.uninstall_handler( + sid, + constants.EventType.service_request, + my_handler, + ) + assert exc_info.value.error_code == StatusCode.error_handler_not_installed + + def test_wait_on_event_success(self, lib_and_session): + lib, sess, sid = lib_and_session + ctx = EventContext( + event_type=constants.EventType.service_request, context_id=123 + ) + sess._event_state.queue.put(ctx) + etype, ectx, status = lib.wait_on_event( + sid, constants.EventType.service_request, 1000 + ) + assert etype == constants.EventType.service_request + assert ectx == 123 + assert status == StatusCode.success + + def test_wait_on_event_timeout_raises(self, lib_and_session): + lib, _sess, sid = lib_and_session + with pytest.raises(errors.VisaIOError) as exc_info: + lib.wait_on_event(sid, constants.EventType.service_request, 50) + assert exc_info.value.error_code == StatusCode.error_timeout + + def test_wait_on_event_zero_timeout_raises(self, lib_and_session): + lib, _sess, sid = lib_and_session + with pytest.raises(errors.VisaIOError) as exc_info: + lib.wait_on_event(sid, constants.EventType.service_request, 0) + assert exc_info.value.error_code == StatusCode.error_timeout + + def test_wait_on_event_invalid_session(self, lib_and_session): + lib, _, _ = lib_and_session + with pytest.raises(errors.VisaIOError) as exc_info: + lib.wait_on_event(999999, constants.EventType.service_request, 0) + assert exc_info.value.error_code == StatusCode.error_invalid_object + + def test_enable_event_invalid_session(self, lib_and_session): + lib, _, _ = lib_and_session + with pytest.raises(errors.VisaIOError) as exc_info: + lib.enable_event( + 999999, + constants.EventType.service_request, + constants.EventMechanism.queue, + ) + assert exc_info.value.error_code == StatusCode.error_invalid_object + + def test_enable_event_unsupported_returns_error(self, lib_and_session): + lib, sess, sid = lib_and_session + sess._supported_event_types = set() + with pytest.raises(errors.VisaIOError) as exc_info: + lib.enable_event( + sid, + constants.EventType.service_request, + constants.EventMechanism.queue, + ) + assert exc_info.value.error_code == StatusCode.error_invalid_event + + def test_enable_event_supported_returns_success(self, lib_and_session): + lib, sess, sid = lib_and_session + sess._supported_event_types = {constants.EventType.service_request} + result = lib.enable_event( + sid, + constants.EventType.service_request, + constants.EventMechanism.queue, + ) + assert result == StatusCode.success + + def test_enable_event_rollback_on_monitor_failure(self, lib_and_session): + lib, sess, sid = lib_and_session + sess._supported_event_types = {constants.EventType.service_request} + sess._start_event_monitor.return_value = StatusCode.error_io + with pytest.raises(errors.VisaIOError) as exc_info: + lib.enable_event( + sid, + constants.EventType.service_request, + constants.EventMechanism.queue, + ) + assert exc_info.value.error_code == StatusCode.error_io + assert not sess._event_state.is_queue_enabled( + constants.EventType.service_request + ) + + def test_discard_events_queue_and_handler_discards_queue(self, lib_and_session): + lib, sess, sid = lib_and_session + sess._event_state.queue.put( + EventContext(event_type=constants.EventType.service_request) + ) + combined = constants.EventMechanism.queue | constants.EventMechanism.handler + result = lib.discard_events( + sid, + constants.EventType.service_request, + combined, + ) + assert result == StatusCode.success + assert sess._event_state.queue.get(timeout_ms=0) is None + + def test_discard_events_handler_alone_does_not_discard_queue(self, lib_and_session): + lib, sess, sid = lib_and_session + ctx = EventContext(event_type=constants.EventType.service_request) + sess._event_state.queue.put(ctx) + result = lib.discard_events( + sid, + constants.EventType.service_request, + constants.EventMechanism.handler, + ) + assert result == StatusCode.success + assert sess._event_state.queue.get(timeout_ms=0) is ctx + + +# --------------------------------------------------------------------------- +# VXI-11 SRQ flow (mocked transport) +# --------------------------------------------------------------------------- + + +class TestVxi11SrqFlow: + @pytest.fixture + def mock_vxi11_session(self): + """Return a partially-initialised TCPIPInstrVxi11 with a mocked iface.""" + from pyvisa_py.tcpip import TCPIPInstrVxi11 + + sess = MagicMock(spec=TCPIPInstrVxi11) + sess._event_state = EventState() + sess.link = 1 + sess.interface = MagicMock() + sess.interface.create_intr_chan.return_value = 0 + sess.interface.device_enable_srq.return_value = 0 + sess.interface.destroy_intr_chan.return_value = 0 + sess._srq_server = None + sess._srq_lifecycle_lock = threading.Lock() + return sess + + def test_start_event_monitor_calls_enable(self, mock_vxi11_session): + from pyvisa_py.tcpip import TCPIPInstrVxi11 + + sess = mock_vxi11_session + sess._event_state.enable( + constants.EventType.service_request, constants.EventMechanism.queue + ) + sess.interface.sock.getsockname.return_value = ("192.168.1.2", 12345) + + # Patch SrqInterruptTCPServer so we don't bind a real TCP socket + with patch("pyvisa_py.tcpip.vxi11.SrqInterruptTCPServer") as MockServer: + mock_sock = MagicMock() + mock_sock.getsockname.return_value = ("127.0.0.1", 65432) + MockServer.return_value.sock = mock_sock + + result = TCPIPInstrVxi11._start_event_monitor(sess) + + assert result == StatusCode.success + sess.interface.create_intr_chan.assert_called_once_with( + int(ipaddress.IPv4Address("192.168.1.2")), + 65432, + vxi11.DEVICE_INTR_PROG, + vxi11.DEVICE_INTR_VERS, + vxi11.DEVICE_TCP, + ) + sess.interface.device_enable_srq.assert_called_once_with( + sess.link, True, b"srq" + ) + assert sess._event_state.monitor_thread is not None + # Clean up + sess._event_state.stop_flag.set() + if sess._event_state.monitor_thread is not None: + sess._event_state.monitor_thread.join(timeout=0.5) + + def test_start_event_monitor_create_intr_chan_error(self, mock_vxi11_session): + from pyvisa_py.tcpip import TCPIPInstrVxi11 + + sess = mock_vxi11_session + sess._event_state.enable( + constants.EventType.service_request, constants.EventMechanism.queue + ) + sess.interface.sock.getsockname.return_value = ("192.168.1.2", 12345) + sess.interface.create_intr_chan.return_value = 8 + + with patch("pyvisa_py.tcpip.vxi11.SrqInterruptTCPServer") as MockServer: + mock_sock = MagicMock() + mock_sock.getsockname.return_value = ("127.0.0.1", 65432) + MockServer.return_value.sock = mock_sock + + result = TCPIPInstrVxi11._start_event_monitor(sess) + + assert result == StatusCode.error_nonsupported_operation + assert sess._event_state.monitor_thread is None + + def test_stop_event_monitor_calls_disable(self, mock_vxi11_session): + from pyvisa_py.tcpip import TCPIPInstrVxi11 + + sess = mock_vxi11_session + sess._event_state.monitor_thread = None + TCPIPInstrVxi11._stop_event_monitor(sess) + sess.interface.device_enable_srq.assert_called_once_with(sess.link, False, b"") + sess.interface.destroy_intr_chan.assert_called_once() + + def test_fire_event_then_wait_on_event(self, lib_and_session): + """Simulate an SRQ by calling _fire_event on a mocked session.""" + lib, sess, sid = lib_and_session + sess._event_state.enable( + constants.EventType.service_request, constants.EventMechanism.queue + ) + ctx = EventContext( + event_type=constants.EventType.service_request, + status_byte=0x50, + context_id=9876, + ) + # Use the real Session._fire_event logic via a partial call + from pyvisa_py.sessions import Session + + Session._fire_event(sess, constants.EventType.service_request, ctx) + etype, ectx, status = lib.wait_on_event( + sid, constants.EventType.service_request, 1000 + ) + assert etype == constants.EventType.service_request + assert ectx == 9876 + assert status == StatusCode.success + + def test_vxi11_fire_event_handler_mechanism(self, lib_and_session): + _lib, sess, sid = lib_and_session + calls = [] + + def my_handler(session, event_type, context_id, user_handle): + calls.append((session, event_type, context_id, user_handle)) + + sess._event_state.registry.install( + constants.EventType.service_request, my_handler, "uh" + ) + sess._event_state.enable( + constants.EventType.service_request, constants.EventMechanism.handler + ) + ctx = EventContext( + event_type=constants.EventType.service_request, context_id=5555 + ) + from pyvisa_py.sessions import Session + + Session._fire_event(sess, constants.EventType.service_request, ctx) + assert calls == [(sid, constants.EventType.service_request, 5555, "uh")]