Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions python/natsrpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from ._natsrpy_rs import Message, Nats, Subscription
from ._natsrpy_rs import CallbackSubscription, IteratorSubscription, Message, Nats

__all__ = [
"CallbackSubscription",
"IteratorSubscription",
"Message",
"Nats",
"Subscription",
]
28 changes: 23 additions & 5 deletions python/natsrpy/_natsrpy_rs/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
from collections.abc import Awaitable, Callable
from datetime import timedelta
from typing import Any
from typing import Any, overload

from natsrpy._natsrpy_rs.js import JetStream
from natsrpy._natsrpy_rs.message import Message

class Subscription:
def __aiter__(self) -> Subscription: ...
class IteratorSubscription:
def __aiter__(self) -> IteratorSubscription: ...
async def __anext__(self) -> Message: ...
async def unsubscribe(self, limit: int | None = None) -> None: ...
async def drain(self) -> None: ...

class CallbackSubscription:
async def unsubscribe(self, limit: int | None = None) -> None: ...
async def drain(self) -> None: ...

class Nats:
def __init__(
Expand Down Expand Up @@ -37,7 +44,18 @@ class Nats:
async def request(self, subject: str, payload: bytes) -> None: ...
async def drain(self) -> None: ...
async def flush(self) -> None: ...
async def subscribe(self, subject: str) -> Subscription: ...
@overload
async def subscribe(
self,
subject: str,
callback: Callable[[Message], Awaitable[None]],
) -> CallbackSubscription: ...
@overload
async def subscribe(
self,
subject: str,
callback: None = None,
) -> IteratorSubscription: ...
async def jetstream(self) -> JetStream: ...

__all__ = ["Message", "Nats", "Subscription"]
__all__ = ["CallbackSubscription", "IteratorSubscription", "Message", "Nats"]
2 changes: 2 additions & 0 deletions src/exceptions/rust_err.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ pub enum NatsrpyError {
#[error(transparent)]
SubscribeError(#[from] async_nats::SubscribeError),
#[error(transparent)]
UnsubscribeError(#[from] async_nats::UnsubscribeError),
#[error(transparent)]
KeyValueError(#[from] async_nats::jetstream::context::KeyValueError),
#[error(transparent)]
CreateKeyValueError(#[from] async_nats::jetstream::context::CreateKeyValueError),
Expand Down
4 changes: 2 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub mod exceptions;
pub mod js;
pub mod message;
pub mod nats_cls;
pub mod subscription;
pub mod subscriptions;
pub mod utils;

#[pyo3::pymodule]
Expand All @@ -38,7 +38,7 @@ pub mod _natsrpy_rs {
#[pymodule_export]
use super::nats_cls::NatsCls;
#[pymodule_export]
use super::subscription::Subscription;
use super::subscriptions::{callback::CallbackSubscription, iterator::IteratorSubscription};

#[pymodule_export]
use super::js::pymod as js;
Expand Down
16 changes: 12 additions & 4 deletions src/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ pub struct Message {
pub length: usize,
}

impl TryFrom<async_nats::Message> for Message {
impl TryFrom<&async_nats::Message> for Message {
type Error = NatsrpyError;

fn try_from(value: async_nats::Message) -> Result<Self, Self::Error> {
fn try_from(value: &async_nats::Message) -> Result<Self, Self::Error> {
Python::attach(move |gil| {
let headers = match value.headers {
let headers = match &value.headers {
Some(headermap) => headermap.to_pydict(gil)?.unbind(),
None => PyDict::new(gil).unbind(),
};
Expand All @@ -32,13 +32,21 @@ impl TryFrom<async_nats::Message> for Message {
payload: PyBytes::new(gil, &value.payload).unbind(),
headers,
status: value.status.map(Into::<u16>::into),
description: value.description,
description: value.description.clone(),
length: value.length,
})
})
}
}

impl TryFrom<async_nats::Message> for Message {
type Error = NatsrpyError;

fn try_from(value: async_nats::Message) -> Result<Self, Self::Error> {
Self::try_from(&value)
}
}

#[pyo3::pymethods]
impl Message {
#[must_use]
Expand Down
94 changes: 51 additions & 43 deletions src/nats_cls.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
use async_nats::{Subject, client::traits::Publisher, message::OutboundMessage};
use pyo3::{
Bound, PyAny, PyResult, Python,
Bound, IntoPyObjectExt, Py, PyAny, Python,
types::{PyBytes, PyBytesMethods, PyDict},
};
use std::{sync::Arc, time::Duration};
use tokio::sync::RwLock;

use crate::{
exceptions::rust_err::NatsrpyError,
subscription::Subscription,
exceptions::rust_err::{NatsrpyError, NatsrpyResult},
subscriptions::{callback::CallbackSubscription, iterator::IteratorSubscription},
utils::{
futures::natsrpy_future_with_timeout,
headers::NatsrpyHeadermapExt,
Expand Down Expand Up @@ -75,7 +75,7 @@ impl NatsCls {
}
}

pub fn startup<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
pub fn startup<'py>(&self, py: Python<'py>) -> NatsrpyResult<Bound<'py, PyAny>> {
let mut conn_opts = async_nats::ConnectOptions::new();
if let Some((username, passwd)) = &self.user_and_pass {
conn_opts = conn_opts.user_and_password(username.clone(), passwd.clone());
Expand All @@ -100,23 +100,19 @@ impl NatsCls {
let session = self.nats_session.clone();
let address = self.addr.clone();
let timeout = self.connection_timeout;
return Ok(natsrpy_future_with_timeout(
py,
Some(timeout),
async move {
if session.read().await.is_some() {
return Err(NatsrpyError::SessionError(
"NATS session already exists".to_string(),
));
}
// Scoping for early-dropping of a guard.
{
let mut sesion_guard = session.write().await;
*sesion_guard = Some(conn_opts.connect(address).await?);
}
Ok(())
},
)?);
natsrpy_future_with_timeout(py, Some(timeout), async move {
if session.read().await.is_some() {
return Err(NatsrpyError::SessionError(
"NATS session already exists".to_string(),
));
}
// Scoping for early-dropping of a guard.
{
let mut sesion_guard = session.write().await;
*sesion_guard = Some(conn_opts.connect(address).await?);
}
Ok(())
})
}

#[pyo3(signature = (subject, payload, *, headers=None, reply=None, err_on_disconnect = false))]
Expand All @@ -128,14 +124,14 @@ impl NatsCls {
headers: Option<Bound<PyDict>>,
reply: Option<String>,
err_on_disconnect: bool,
) -> PyResult<Bound<'py, PyAny>> {
) -> NatsrpyResult<Bound<'py, PyAny>> {
let session = self.nats_session.clone();
log::info!("Payload: {payload:?}");
let data = payload.into();
let headermap = headers
.map(async_nats::HeaderMap::from_pydict)
.transpose()?;
Ok(natsrpy_future(py, async move {
natsrpy_future(py, async move {
if let Some(session) = session.read().await.as_ref() {
if err_on_disconnect
&& session.connection_state() == async_nats::connection::State::Disconnected
Expand All @@ -154,7 +150,7 @@ impl NatsCls {
} else {
Err(NatsrpyError::NotInitialized)
}
})?)
})
}

#[pyo3(signature = (subject, payload, *, headers=None, inbox = None, timeout=None))]
Expand All @@ -166,13 +162,13 @@ impl NatsCls {
headers: Option<Bound<PyDict>>,
inbox: Option<String>,
timeout: Option<Duration>,
) -> PyResult<Bound<'py, PyAny>> {
) -> NatsrpyResult<Bound<'py, PyAny>> {
let session = self.nats_session.clone();
let data = payload.map(|inner| bytes::Bytes::from(inner.as_bytes().to_vec()));
let headermap = headers
.map(async_nats::HeaderMap::from_pydict)
.transpose()?;
Ok(natsrpy_future(py, async move {
natsrpy_future(py, async move {
if let Some(session) = session.read().await.as_ref() {
let request = async_nats::Request {
payload: data,
Expand All @@ -185,32 +181,44 @@ impl NatsCls {
} else {
Err(NatsrpyError::NotInitialized)
}
})?)
})
}

pub fn drain<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
pub fn drain<'py>(&self, py: Python<'py>) -> NatsrpyResult<Bound<'py, PyAny>> {
log::debug!("Draining NATS session");
let session = self.nats_session.clone();
Ok(natsrpy_future(py, async move {
natsrpy_future(py, async move {
if let Some(session) = session.write().await.as_ref() {
session.drain().await?;
Ok(())
} else {
Err(NatsrpyError::NotInitialized)
}
})?)
})
}

pub fn subscribe<'py>(&self, py: Python<'py>, subject: String) -> PyResult<Bound<'py, PyAny>> {
#[pyo3(signature=(subject, callback=None))]
pub fn subscribe<'py>(
&self,
py: Python<'py>,
subject: String,
callback: Option<Py<PyAny>>,
) -> NatsrpyResult<Bound<'py, PyAny>> {
log::debug!("Subscribing to '{subject}'");
let session = self.nats_session.clone();
Ok(natsrpy_future(py, async move {
natsrpy_future(py, async move {
if let Some(session) = session.read().await.as_ref() {
Ok(Subscription::new(session.subscribe(subject).await?))
if let Some(cb) = callback {
let sub = CallbackSubscription::new(session.subscribe(subject).await?, cb)?;
Ok(Python::attach(|gil| sub.into_py_any(gil))?)
} else {
let sub = IteratorSubscription::new(session.subscribe(subject).await?);
Ok(Python::attach(|gil| sub.into_py_any(gil))?)
}
} else {
Err(NatsrpyError::NotInitialized)
}
})?)
})
}

#[pyo3(signature = (
Expand All @@ -233,10 +241,10 @@ impl NatsCls {
concurrency_limit: Option<usize>,
max_ack_inflight: Option<usize>,
backpressure_on_inflight: Option<bool>,
) -> PyResult<Bound<'py, PyAny>> {
) -> NatsrpyResult<Bound<'py, PyAny>> {
log::debug!("Creating JetStream context");
let session = self.nats_session.clone();
Ok(natsrpy_future(py, async move {
natsrpy_future(py, async move {
let mut builder =
async_nats::jetstream::ContextBuilder::new().concurrency_limit(concurrency_limit);
if let Some(timeout) = ack_timeout {
Expand Down Expand Up @@ -269,13 +277,13 @@ impl NatsCls {
Ok(crate::js::jetstream::JetStream::new(js))
},
)
})?)
})
}

pub fn shutdown<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
pub fn shutdown<'py>(&self, py: Python<'py>) -> NatsrpyResult<Bound<'py, PyAny>> {
log::debug!("Closing nats session");
let session = self.nats_session.clone();
Ok(natsrpy_future(py, async move {
natsrpy_future(py, async move {
let mut write_guard = session.write().await;
let Some(session) = write_guard.as_ref() else {
return Err(NatsrpyError::NotInitialized);
Expand All @@ -284,20 +292,20 @@ impl NatsCls {
*write_guard = None;
drop(write_guard);
Ok(())
})?)
})
}

pub fn flush<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
pub fn flush<'py>(&self, py: Python<'py>) -> NatsrpyResult<Bound<'py, PyAny>> {
log::debug!("Flushing streams");
let session = self.nats_session.clone();
Ok(natsrpy_future(py, async move {
natsrpy_future(py, async move {
if let Some(session) = session.write().await.as_ref() {
session.flush().await?;
Ok(())
} else {
Err(NatsrpyError::NotInitialized)
}
})?)
})
}
}

Expand Down
Loading
Loading