diff --git a/python/natsrpy/__init__.py b/python/natsrpy/__init__.py index 624ff88..1a19879 100644 --- a/python/natsrpy/__init__.py +++ b/python/natsrpy/__init__.py @@ -1,7 +1,8 @@ -from ._natsrpy_rs import Message, Nats, Subscription +from ._natsrpy_rs import CallbackSubscription, IteratorSubscription, Message, Nats __all__ = [ + "CallbackSubscription", + "IteratorSubscription", "Message", "Nats", - "Subscription", ] diff --git a/python/natsrpy/_natsrpy_rs/__init__.pyi b/python/natsrpy/_natsrpy_rs/__init__.pyi index ff8402d..ec386be 100644 --- a/python/natsrpy/_natsrpy_rs/__init__.pyi +++ b/python/natsrpy/_natsrpy_rs/__init__.pyi @@ -1,12 +1,19 @@ +from collections.abc import Awaitable, Callable from datetime import timedelta -from typing import Any +from typing import Any, overload from natsrpy._natsrpy_rs.js import JetStream from natsrpy._natsrpy_rs.message import Message -class Subscription: - def __aiter__(self) -> Subscription: ... +class IteratorSubscription: + def __aiter__(self) -> IteratorSubscription: ... async def __anext__(self) -> Message: ... + async def unsubscribe(self, limit: int | None = None) -> None: ... + async def drain(self) -> None: ... + +class CallbackSubscription: + async def unsubscribe(self, limit: int | None = None) -> None: ... + async def drain(self) -> None: ... class Nats: def __init__( @@ -37,7 +44,18 @@ class Nats: async def request(self, subject: str, payload: bytes) -> None: ... async def drain(self) -> None: ... async def flush(self) -> None: ... - async def subscribe(self, subject: str) -> Subscription: ... + @overload + async def subscribe( + self, + subject: str, + callback: Callable[[Message], Awaitable[None]], + ) -> CallbackSubscription: ... + @overload + async def subscribe( + self, + subject: str, + callback: None = None, + ) -> IteratorSubscription: ... async def jetstream(self) -> JetStream: ... -__all__ = ["Message", "Nats", "Subscription"] +__all__ = ["CallbackSubscription", "IteratorSubscription", "Message", "Nats"] diff --git a/src/exceptions/rust_err.rs b/src/exceptions/rust_err.rs index e005a29..236ff94 100644 --- a/src/exceptions/rust_err.rs +++ b/src/exceptions/rust_err.rs @@ -41,6 +41,8 @@ pub enum NatsrpyError { #[error(transparent)] SubscribeError(#[from] async_nats::SubscribeError), #[error(transparent)] + UnsubscribeError(#[from] async_nats::UnsubscribeError), + #[error(transparent)] KeyValueError(#[from] async_nats::jetstream::context::KeyValueError), #[error(transparent)] CreateKeyValueError(#[from] async_nats::jetstream::context::CreateKeyValueError), diff --git a/src/lib.rs b/src/lib.rs index e6abd67..5747164 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -24,7 +24,7 @@ pub mod exceptions; pub mod js; pub mod message; pub mod nats_cls; -pub mod subscription; +pub mod subscriptions; pub mod utils; #[pyo3::pymodule] @@ -38,7 +38,7 @@ pub mod _natsrpy_rs { #[pymodule_export] use super::nats_cls::NatsCls; #[pymodule_export] - use super::subscription::Subscription; + use super::subscriptions::{callback::CallbackSubscription, iterator::IteratorSubscription}; #[pymodule_export] use super::js::pymod as js; diff --git a/src/message.rs b/src/message.rs index bedbcf9..082e0fd 100644 --- a/src/message.rs +++ b/src/message.rs @@ -17,12 +17,12 @@ pub struct Message { pub length: usize, } -impl TryFrom for Message { +impl TryFrom<&async_nats::Message> for Message { type Error = NatsrpyError; - fn try_from(value: async_nats::Message) -> Result { + fn try_from(value: &async_nats::Message) -> Result { Python::attach(move |gil| { - let headers = match value.headers { + let headers = match &value.headers { Some(headermap) => headermap.to_pydict(gil)?.unbind(), None => PyDict::new(gil).unbind(), }; @@ -32,13 +32,21 @@ impl TryFrom for Message { payload: PyBytes::new(gil, &value.payload).unbind(), headers, status: value.status.map(Into::::into), - description: value.description, + description: value.description.clone(), length: value.length, }) }) } } +impl TryFrom for Message { + type Error = NatsrpyError; + + fn try_from(value: async_nats::Message) -> Result { + Self::try_from(&value) + } +} + #[pyo3::pymethods] impl Message { #[must_use] diff --git a/src/nats_cls.rs b/src/nats_cls.rs index a42ea7b..a2b0d52 100644 --- a/src/nats_cls.rs +++ b/src/nats_cls.rs @@ -1,14 +1,14 @@ use async_nats::{Subject, client::traits::Publisher, message::OutboundMessage}; use pyo3::{ - Bound, PyAny, PyResult, Python, + Bound, IntoPyObjectExt, Py, PyAny, Python, types::{PyBytes, PyBytesMethods, PyDict}, }; use std::{sync::Arc, time::Duration}; use tokio::sync::RwLock; use crate::{ - exceptions::rust_err::NatsrpyError, - subscription::Subscription, + exceptions::rust_err::{NatsrpyError, NatsrpyResult}, + subscriptions::{callback::CallbackSubscription, iterator::IteratorSubscription}, utils::{ futures::natsrpy_future_with_timeout, headers::NatsrpyHeadermapExt, @@ -75,7 +75,7 @@ impl NatsCls { } } - pub fn startup<'py>(&self, py: Python<'py>) -> PyResult> { + pub fn startup<'py>(&self, py: Python<'py>) -> NatsrpyResult> { let mut conn_opts = async_nats::ConnectOptions::new(); if let Some((username, passwd)) = &self.user_and_pass { conn_opts = conn_opts.user_and_password(username.clone(), passwd.clone()); @@ -100,23 +100,19 @@ impl NatsCls { let session = self.nats_session.clone(); let address = self.addr.clone(); let timeout = self.connection_timeout; - return Ok(natsrpy_future_with_timeout( - py, - Some(timeout), - async move { - if session.read().await.is_some() { - return Err(NatsrpyError::SessionError( - "NATS session already exists".to_string(), - )); - } - // Scoping for early-dropping of a guard. - { - let mut sesion_guard = session.write().await; - *sesion_guard = Some(conn_opts.connect(address).await?); - } - Ok(()) - }, - )?); + natsrpy_future_with_timeout(py, Some(timeout), async move { + if session.read().await.is_some() { + return Err(NatsrpyError::SessionError( + "NATS session already exists".to_string(), + )); + } + // Scoping for early-dropping of a guard. + { + let mut sesion_guard = session.write().await; + *sesion_guard = Some(conn_opts.connect(address).await?); + } + Ok(()) + }) } #[pyo3(signature = (subject, payload, *, headers=None, reply=None, err_on_disconnect = false))] @@ -128,14 +124,14 @@ impl NatsCls { headers: Option>, reply: Option, err_on_disconnect: bool, - ) -> PyResult> { + ) -> NatsrpyResult> { let session = self.nats_session.clone(); log::info!("Payload: {payload:?}"); let data = payload.into(); let headermap = headers .map(async_nats::HeaderMap::from_pydict) .transpose()?; - Ok(natsrpy_future(py, async move { + natsrpy_future(py, async move { if let Some(session) = session.read().await.as_ref() { if err_on_disconnect && session.connection_state() == async_nats::connection::State::Disconnected @@ -154,7 +150,7 @@ impl NatsCls { } else { Err(NatsrpyError::NotInitialized) } - })?) + }) } #[pyo3(signature = (subject, payload, *, headers=None, inbox = None, timeout=None))] @@ -166,13 +162,13 @@ impl NatsCls { headers: Option>, inbox: Option, timeout: Option, - ) -> PyResult> { + ) -> NatsrpyResult> { let session = self.nats_session.clone(); let data = payload.map(|inner| bytes::Bytes::from(inner.as_bytes().to_vec())); let headermap = headers .map(async_nats::HeaderMap::from_pydict) .transpose()?; - Ok(natsrpy_future(py, async move { + natsrpy_future(py, async move { if let Some(session) = session.read().await.as_ref() { let request = async_nats::Request { payload: data, @@ -185,32 +181,44 @@ impl NatsCls { } else { Err(NatsrpyError::NotInitialized) } - })?) + }) } - pub fn drain<'py>(&self, py: Python<'py>) -> PyResult> { + pub fn drain<'py>(&self, py: Python<'py>) -> NatsrpyResult> { log::debug!("Draining NATS session"); let session = self.nats_session.clone(); - Ok(natsrpy_future(py, async move { + natsrpy_future(py, async move { if let Some(session) = session.write().await.as_ref() { session.drain().await?; Ok(()) } else { Err(NatsrpyError::NotInitialized) } - })?) + }) } - pub fn subscribe<'py>(&self, py: Python<'py>, subject: String) -> PyResult> { + #[pyo3(signature=(subject, callback=None))] + pub fn subscribe<'py>( + &self, + py: Python<'py>, + subject: String, + callback: Option>, + ) -> NatsrpyResult> { log::debug!("Subscribing to '{subject}'"); let session = self.nats_session.clone(); - Ok(natsrpy_future(py, async move { + natsrpy_future(py, async move { if let Some(session) = session.read().await.as_ref() { - Ok(Subscription::new(session.subscribe(subject).await?)) + if let Some(cb) = callback { + let sub = CallbackSubscription::new(session.subscribe(subject).await?, cb)?; + Ok(Python::attach(|gil| sub.into_py_any(gil))?) + } else { + let sub = IteratorSubscription::new(session.subscribe(subject).await?); + Ok(Python::attach(|gil| sub.into_py_any(gil))?) + } } else { Err(NatsrpyError::NotInitialized) } - })?) + }) } #[pyo3(signature = ( @@ -233,10 +241,10 @@ impl NatsCls { concurrency_limit: Option, max_ack_inflight: Option, backpressure_on_inflight: Option, - ) -> PyResult> { + ) -> NatsrpyResult> { log::debug!("Creating JetStream context"); let session = self.nats_session.clone(); - Ok(natsrpy_future(py, async move { + natsrpy_future(py, async move { let mut builder = async_nats::jetstream::ContextBuilder::new().concurrency_limit(concurrency_limit); if let Some(timeout) = ack_timeout { @@ -269,13 +277,13 @@ impl NatsCls { Ok(crate::js::jetstream::JetStream::new(js)) }, ) - })?) + }) } - pub fn shutdown<'py>(&self, py: Python<'py>) -> PyResult> { + pub fn shutdown<'py>(&self, py: Python<'py>) -> NatsrpyResult> { log::debug!("Closing nats session"); let session = self.nats_session.clone(); - Ok(natsrpy_future(py, async move { + natsrpy_future(py, async move { let mut write_guard = session.write().await; let Some(session) = write_guard.as_ref() else { return Err(NatsrpyError::NotInitialized); @@ -284,20 +292,20 @@ impl NatsCls { *write_guard = None; drop(write_guard); Ok(()) - })?) + }) } - pub fn flush<'py>(&self, py: Python<'py>) -> PyResult> { + pub fn flush<'py>(&self, py: Python<'py>) -> NatsrpyResult> { log::debug!("Flushing streams"); let session = self.nats_session.clone(); - Ok(natsrpy_future(py, async move { + natsrpy_future(py, async move { if let Some(session) = session.write().await.as_ref() { session.flush().await?; Ok(()) } else { Err(NatsrpyError::NotInitialized) } - })?) + }) } } diff --git a/src/subscription.rs b/src/subscription.rs index 47e8aae..13cec8d 100644 --- a/src/subscription.rs +++ b/src/subscription.rs @@ -1,25 +1,67 @@ use futures_util::StreamExt; use std::sync::Arc; -use pyo3::{Bound, PyAny, PyRef, Python}; +use pyo3::{Bound, Py, PyAny, PyRef, Python}; use tokio::sync::Mutex; use crate::{ exceptions::rust_err::{NatsrpyError, NatsrpyResult}, - utils::{futures::natsrpy_future_with_timeout, py_types::TimeValue}, + utils::{futures::natsrpy_future_with_timeout, natsrpy_future, py_types::TimeValue}, }; #[pyo3::pyclass] pub struct Subscription { inner: Option>>, + reading_task: Option, +} + +async fn process_message(message: async_nats::message::Message, py_callback: Py) { + let task = async || -> NatsrpyResult<()> { + let message = crate::message::Message::try_from(&message)?; + let awaitable = Python::attach(|gil| -> NatsrpyResult<_> { + let res = py_callback.call1(gil, (message,))?; + let rust_task = pyo3_async_runtimes::tokio::into_future(res.into_bound(gil))?; + Ok(rust_task) + })?; + awaitable.await?; + Ok(()) + }; + if let Err(err) = task().await { + log::error!("Cannot process message {message:?}. Error: {err}"); + } +} + +async fn start_py_sub( + sub: Arc>, + py_callback: Py, + locals: pyo3_async_runtimes::TaskLocals, +) { + while let Some(message) = sub.lock().await.next().await { + let py_cb = Python::attach(|py| py_callback.clone_ref(py)); + tokio::spawn(pyo3_async_runtimes::tokio::scope( + locals.clone(), + process_message(message, py_cb), + )); + } } impl Subscription { - #[must_use] - pub fn new(sub: async_nats::Subscriber) -> Self { - Self { - inner: Some(Arc::new(Mutex::new(sub))), - } + pub fn new(sub: async_nats::Subscriber, callback: Option>) -> NatsrpyResult { + let sub = Arc::new(Mutex::new(sub)); + let cb_sub = sub.clone(); + let task_locals = Python::attach(pyo3_async_runtimes::tokio::get_current_locals)?; + let task_handle = callback.map(move |cb| { + tokio::task::spawn(pyo3_async_runtimes::tokio::scope( + task_locals.clone(), + start_py_sub(cb_sub, cb, task_locals), + )) + .abort_handle() + }); + + Ok(Self { + inner: Some(sub), + reading_task: task_handle, + }) } } @@ -36,13 +78,17 @@ impl Subscription { timeout: Option, ) -> NatsrpyResult> { let Some(inner) = self.inner.clone() else { - return Err(NatsrpyError::NotInitialized); + unreachable!("Subscription used after del") }; + if self.reading_task.is_some() { + log::warn!( + "Callback is set. Getting messages from this subscription might produce unpredictable results." + ); + } natsrpy_future_with_timeout(py, timeout, async move { let Some(message) = inner.lock().await.next().await else { return Err(NatsrpyError::AsyncStopIteration); }; - crate::message::Message::try_from(message) }) } @@ -50,6 +96,35 @@ impl Subscription { pub fn __anext__<'py>(&self, py: Python<'py>) -> NatsrpyResult> { self.next(py, None) } + + #[pyo3(signature=(limit=None))] + pub fn unsubscribe<'py>( + &self, + py: Python<'py>, + limit: Option, + ) -> NatsrpyResult> { + let Some(inner) = self.inner.clone() else { + unreachable!("Subscription used after del") + }; + natsrpy_future(py, async move { + if let Some(limit) = limit { + inner.lock().await.unsubscribe_after(limit).await?; + } else { + inner.lock().await.unsubscribe().await?; + } + Ok(()) + }) + } + + pub fn drain<'py>(&self, py: Python<'py>) -> NatsrpyResult> { + let Some(inner) = self.inner.clone() else { + unreachable!("Subscription used after del") + }; + natsrpy_future(py, async move { + inner.lock().await.drain().await?; + Ok(()) + }) + } } /// This is required only because @@ -67,6 +142,9 @@ impl Drop for Subscription { fn drop(&mut self) { pyo3_async_runtimes::tokio::get_runtime().block_on(async move { self.inner = None; + if let Some(reading) = self.reading_task.take() { + reading.abort(); + } }); } } diff --git a/src/subscriptions/callback.rs b/src/subscriptions/callback.rs new file mode 100644 index 0000000..0bc5cbe --- /dev/null +++ b/src/subscriptions/callback.rs @@ -0,0 +1,112 @@ +use std::sync::Arc; + +use futures_util::StreamExt; +use pyo3::{Bound, Py, PyAny, Python}; +use tokio::sync::Mutex; + +use crate::{exceptions::rust_err::NatsrpyResult, utils::natsrpy_future}; + +#[pyo3::pyclass] +pub struct CallbackSubscription { + inner: Option>>, + reading_task: tokio::task::AbortHandle, +} + +async fn process_message(message: async_nats::message::Message, py_callback: Py) { + let task = async || -> NatsrpyResult<()> { + let message = crate::message::Message::try_from(&message)?; + let awaitable = Python::attach(|gil| -> NatsrpyResult<_> { + let res = py_callback.call1(gil, (message,))?; + let rust_task = pyo3_async_runtimes::tokio::into_future(res.into_bound(gil))?; + Ok(rust_task) + })?; + awaitable.await?; + Ok(()) + }; + if let Err(err) = task().await { + log::error!("Cannot process message {message:?}. Error: {err}"); + } +} + +async fn start_py_sub( + sub: Arc>, + py_callback: Py, + locals: pyo3_async_runtimes::TaskLocals, +) { + while let Some(message) = sub.lock().await.next().await { + let py_cb = Python::attach(|py| py_callback.clone_ref(py)); + tokio::spawn(pyo3_async_runtimes::tokio::scope( + locals.clone(), + process_message(message, py_cb), + )); + } +} + +impl CallbackSubscription { + pub fn new(sub: async_nats::Subscriber, callback: Py) -> NatsrpyResult { + let sub = Arc::new(Mutex::new(sub)); + let cb_sub = sub.clone(); + let task_locals = Python::attach(pyo3_async_runtimes::tokio::get_current_locals)?; + let task_handle = tokio::task::spawn(pyo3_async_runtimes::tokio::scope( + task_locals.clone(), + start_py_sub(cb_sub, callback, task_locals), + )) + .abort_handle(); + Ok(Self { + inner: Some(sub), + reading_task: task_handle, + }) + } +} + +#[pyo3::pymethods] +impl CallbackSubscription { + #[pyo3(signature=(limit=None))] + pub fn unsubscribe<'py>( + &self, + py: Python<'py>, + limit: Option, + ) -> NatsrpyResult> { + let Some(inner) = self.inner.clone() else { + unreachable!("Subscription used after del") + }; + natsrpy_future(py, async move { + if let Some(limit) = limit { + inner.lock().await.unsubscribe_after(limit).await?; + } else { + inner.lock().await.unsubscribe().await?; + } + Ok(()) + }) + } + + pub fn drain<'py>(&self, py: Python<'py>) -> NatsrpyResult> { + let Some(inner) = self.inner.clone() else { + unreachable!("Subscription used after del") + }; + natsrpy_future(py, async move { + inner.lock().await.drain().await?; + Ok(()) + }) + } +} + +/// This is required only because +/// in nats library they run async operation on Drop. +/// +/// Because of that we need to execute drop in async +/// runtime's context. +/// +/// And because we want to perform a drop, +/// we need somehow drop the inner variable, +/// but leave self intouch. That is exactly why we have +/// Option>. So we can just assign it to None +/// and it will perform a drop. +impl Drop for CallbackSubscription { + fn drop(&mut self) { + pyo3_async_runtimes::tokio::get_runtime().block_on(async move { + self.inner = None; + self.reading_task.abort(); + }); + } +} diff --git a/src/subscriptions/iterator.rs b/src/subscriptions/iterator.rs new file mode 100644 index 0000000..e94f919 --- /dev/null +++ b/src/subscriptions/iterator.rs @@ -0,0 +1,100 @@ +use std::sync::Arc; + +use futures_util::StreamExt; +use pyo3::{Bound, PyAny, PyRef, Python}; +use tokio::sync::Mutex; + +use crate::exceptions::rust_err::{NatsrpyError, NatsrpyResult}; +use crate::utils::futures::natsrpy_future_with_timeout; +use crate::utils::natsrpy_future; +use crate::utils::py_types::TimeValue; + +#[pyo3::pyclass] +pub struct IteratorSubscription { + inner: Option>>, +} + +impl IteratorSubscription { + #[must_use] + pub fn new(sub: async_nats::Subscriber) -> Self { + Self { + inner: Some(Arc::new(Mutex::new(sub))), + } + } +} + +#[pyo3::pymethods] +impl IteratorSubscription { + #[must_use] + pub const fn __aiter__(slf: PyRef) -> PyRef { + slf + } + + pub fn next<'py>( + &self, + py: Python<'py>, + timeout: Option, + ) -> NatsrpyResult> { + let Some(inner) = self.inner.clone() else { + unreachable!("Subscription used after del") + }; + natsrpy_future_with_timeout(py, timeout, async move { + let Some(message) = inner.lock().await.next().await else { + return Err(NatsrpyError::AsyncStopIteration); + }; + crate::message::Message::try_from(message) + }) + } + + pub fn __anext__<'py>(&self, py: Python<'py>) -> NatsrpyResult> { + self.next(py, None) + } + + #[pyo3(signature=(limit=None))] + pub fn unsubscribe<'py>( + &self, + py: Python<'py>, + limit: Option, + ) -> NatsrpyResult> { + let Some(inner) = self.inner.clone() else { + unreachable!("Subscription used after del") + }; + natsrpy_future(py, async move { + if let Some(limit) = limit { + inner.lock().await.unsubscribe_after(limit).await?; + } else { + inner.lock().await.unsubscribe().await?; + } + Ok(()) + }) + } + + pub fn drain<'py>(&self, py: Python<'py>) -> NatsrpyResult> { + let Some(inner) = self.inner.clone() else { + unreachable!("Subscription used after del") + }; + natsrpy_future(py, async move { + inner.lock().await.drain().await?; + Ok(()) + }) + } +} + +/// This is required only because +/// in nats library they run async operation on Drop. +/// +/// Because of that we need to execute drop in async +/// runtime's context. +/// +/// And because we want to perform a drop, +/// we need somehow drop the inner variable, +/// but leave self intouch. That is exactly why we have +/// Option>. So we can just assign it to None +/// and it will perform a drop. +impl Drop for IteratorSubscription { + fn drop(&mut self) { + pyo3_async_runtimes::tokio::get_runtime().block_on(async move { + self.inner = None; + }); + } +} diff --git a/src/subscriptions/mod.rs b/src/subscriptions/mod.rs new file mode 100644 index 0000000..5c11901 --- /dev/null +++ b/src/subscriptions/mod.rs @@ -0,0 +1,2 @@ +pub mod callback; +pub mod iterator;