Skip to content

Commit ec16362

Browse files
committed
Added more subscription methods.
1 parent fa0b9e0 commit ec16362

File tree

4 files changed

+76
-43
lines changed

4 files changed

+76
-43
lines changed

python/natsrpy/_natsrpy_rs/__init__.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ from natsrpy._natsrpy_rs.message import Message
77
class Subscription:
88
def __aiter__(self) -> Subscription: ...
99
async def __anext__(self) -> Message: ...
10+
async def unsubscribe(self, limit: int | None = None) -> None: ...
11+
async def drain(self) -> None: ...
1012

1113
class Nats:
1214
def __init__(

src/exceptions/rust_err.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ pub enum NatsrpyError {
4141
#[error(transparent)]
4242
SubscribeError(#[from] async_nats::SubscribeError),
4343
#[error(transparent)]
44+
UnsubscribeError(#[from] async_nats::UnsubscribeError),
45+
#[error(transparent)]
4446
KeyValueError(#[from] async_nats::jetstream::context::KeyValueError),
4547
#[error(transparent)]
4648
CreateKeyValueError(#[from] async_nats::jetstream::context::CreateKeyValueError),

src/nats_cls.rs

Lines changed: 41 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
use async_nats::{Subject, client::traits::Publisher, message::OutboundMessage};
22
use pyo3::{
3-
Bound, PyAny, PyResult, Python,
3+
Bound, PyAny, Python,
44
types::{PyBytes, PyBytesMethods, PyDict},
55
};
66
use std::{sync::Arc, time::Duration};
77
use tokio::sync::RwLock;
88

99
use crate::{
10-
exceptions::rust_err::NatsrpyError,
10+
exceptions::rust_err::{NatsrpyError, NatsrpyResult},
1111
subscription::Subscription,
1212
utils::{
1313
futures::natsrpy_future_with_timeout,
@@ -75,7 +75,7 @@ impl NatsCls {
7575
}
7676
}
7777

78-
pub fn startup<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
78+
pub fn startup<'py>(&self, py: Python<'py>) -> NatsrpyResult<Bound<'py, PyAny>> {
7979
let mut conn_opts = async_nats::ConnectOptions::new();
8080
if let Some((username, passwd)) = &self.user_and_pass {
8181
conn_opts = conn_opts.user_and_password(username.clone(), passwd.clone());
@@ -100,23 +100,19 @@ impl NatsCls {
100100
let session = self.nats_session.clone();
101101
let address = self.addr.clone();
102102
let timeout = self.connection_timeout;
103-
return Ok(natsrpy_future_with_timeout(
104-
py,
105-
Some(timeout),
106-
async move {
107-
if session.read().await.is_some() {
108-
return Err(NatsrpyError::SessionError(
109-
"NATS session already exists".to_string(),
110-
));
111-
}
112-
// Scoping for early-dropping of a guard.
113-
{
114-
let mut sesion_guard = session.write().await;
115-
*sesion_guard = Some(conn_opts.connect(address).await?);
116-
}
117-
Ok(())
118-
},
119-
)?);
103+
natsrpy_future_with_timeout(py, Some(timeout), async move {
104+
if session.read().await.is_some() {
105+
return Err(NatsrpyError::SessionError(
106+
"NATS session already exists".to_string(),
107+
));
108+
}
109+
// Scoping for early-dropping of a guard.
110+
{
111+
let mut sesion_guard = session.write().await;
112+
*sesion_guard = Some(conn_opts.connect(address).await?);
113+
}
114+
Ok(())
115+
})
120116
}
121117

122118
#[pyo3(signature = (subject, payload, *, headers=None, reply=None, err_on_disconnect = false))]
@@ -128,14 +124,14 @@ impl NatsCls {
128124
headers: Option<Bound<PyDict>>,
129125
reply: Option<String>,
130126
err_on_disconnect: bool,
131-
) -> PyResult<Bound<'py, PyAny>> {
127+
) -> NatsrpyResult<Bound<'py, PyAny>> {
132128
let session = self.nats_session.clone();
133129
log::info!("Payload: {payload:?}");
134130
let data = payload.into();
135131
let headermap = headers
136132
.map(async_nats::HeaderMap::from_pydict)
137133
.transpose()?;
138-
Ok(natsrpy_future(py, async move {
134+
natsrpy_future(py, async move {
139135
if let Some(session) = session.read().await.as_ref() {
140136
if err_on_disconnect
141137
&& session.connection_state() == async_nats::connection::State::Disconnected
@@ -154,7 +150,7 @@ impl NatsCls {
154150
} else {
155151
Err(NatsrpyError::NotInitialized)
156152
}
157-
})?)
153+
})
158154
}
159155

160156
#[pyo3(signature = (subject, payload, *, headers=None, inbox = None, timeout=None))]
@@ -166,13 +162,13 @@ impl NatsCls {
166162
headers: Option<Bound<PyDict>>,
167163
inbox: Option<String>,
168164
timeout: Option<Duration>,
169-
) -> PyResult<Bound<'py, PyAny>> {
165+
) -> NatsrpyResult<Bound<'py, PyAny>> {
170166
let session = self.nats_session.clone();
171167
let data = payload.map(|inner| bytes::Bytes::from(inner.as_bytes().to_vec()));
172168
let headermap = headers
173169
.map(async_nats::HeaderMap::from_pydict)
174170
.transpose()?;
175-
Ok(natsrpy_future(py, async move {
171+
natsrpy_future(py, async move {
176172
if let Some(session) = session.read().await.as_ref() {
177173
let request = async_nats::Request {
178174
payload: data,
@@ -185,32 +181,36 @@ impl NatsCls {
185181
} else {
186182
Err(NatsrpyError::NotInitialized)
187183
}
188-
})?)
184+
})
189185
}
190186

191-
pub fn drain<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
187+
pub fn drain<'py>(&self, py: Python<'py>) -> NatsrpyResult<Bound<'py, PyAny>> {
192188
log::debug!("Draining NATS session");
193189
let session = self.nats_session.clone();
194-
Ok(natsrpy_future(py, async move {
190+
natsrpy_future(py, async move {
195191
if let Some(session) = session.write().await.as_ref() {
196192
session.drain().await?;
197193
Ok(())
198194
} else {
199195
Err(NatsrpyError::NotInitialized)
200196
}
201-
})?)
197+
})
202198
}
203199

204-
pub fn subscribe<'py>(&self, py: Python<'py>, subject: String) -> PyResult<Bound<'py, PyAny>> {
200+
pub fn subscribe<'py>(
201+
&self,
202+
py: Python<'py>,
203+
subject: String,
204+
) -> NatsrpyResult<Bound<'py, PyAny>> {
205205
log::debug!("Subscribing to '{subject}'");
206206
let session = self.nats_session.clone();
207-
Ok(natsrpy_future(py, async move {
207+
natsrpy_future(py, async move {
208208
if let Some(session) = session.read().await.as_ref() {
209209
Ok(Subscription::new(session.subscribe(subject).await?))
210210
} else {
211211
Err(NatsrpyError::NotInitialized)
212212
}
213-
})?)
213+
})
214214
}
215215

216216
#[pyo3(signature = (
@@ -233,10 +233,10 @@ impl NatsCls {
233233
concurrency_limit: Option<usize>,
234234
max_ack_inflight: Option<usize>,
235235
backpressure_on_inflight: Option<bool>,
236-
) -> PyResult<Bound<'py, PyAny>> {
236+
) -> NatsrpyResult<Bound<'py, PyAny>> {
237237
log::debug!("Creating JetStream context");
238238
let session = self.nats_session.clone();
239-
Ok(natsrpy_future(py, async move {
239+
natsrpy_future(py, async move {
240240
let mut builder =
241241
async_nats::jetstream::ContextBuilder::new().concurrency_limit(concurrency_limit);
242242
if let Some(timeout) = ack_timeout {
@@ -269,13 +269,13 @@ impl NatsCls {
269269
Ok(crate::js::jetstream::JetStream::new(js))
270270
},
271271
)
272-
})?)
272+
})
273273
}
274274

275-
pub fn shutdown<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
275+
pub fn shutdown<'py>(&self, py: Python<'py>) -> NatsrpyResult<Bound<'py, PyAny>> {
276276
log::debug!("Closing nats session");
277277
let session = self.nats_session.clone();
278-
Ok(natsrpy_future(py, async move {
278+
natsrpy_future(py, async move {
279279
let mut write_guard = session.write().await;
280280
let Some(session) = write_guard.as_ref() else {
281281
return Err(NatsrpyError::NotInitialized);
@@ -284,20 +284,20 @@ impl NatsCls {
284284
*write_guard = None;
285285
drop(write_guard);
286286
Ok(())
287-
})?)
287+
})
288288
}
289289

290-
pub fn flush<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
290+
pub fn flush<'py>(&self, py: Python<'py>) -> NatsrpyResult<Bound<'py, PyAny>> {
291291
log::debug!("Flushing streams");
292292
let session = self.nats_session.clone();
293-
Ok(natsrpy_future(py, async move {
293+
natsrpy_future(py, async move {
294294
if let Some(session) = session.write().await.as_ref() {
295295
session.flush().await?;
296296
Ok(())
297297
} else {
298298
Err(NatsrpyError::NotInitialized)
299299
}
300-
})?)
300+
})
301301
}
302302
}
303303

src/subscription.rs

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use tokio::sync::Mutex;
66

77
use crate::{
88
exceptions::rust_err::{NatsrpyError, NatsrpyResult},
9-
utils::{futures::natsrpy_future_with_timeout, py_types::TimeValue},
9+
utils::{futures::natsrpy_future_with_timeout, natsrpy_future, py_types::TimeValue},
1010
};
1111

1212
#[pyo3::pyclass]
@@ -36,7 +36,7 @@ impl Subscription {
3636
timeout: Option<TimeValue>,
3737
) -> NatsrpyResult<Bound<'py, PyAny>> {
3838
let Some(inner) = self.inner.clone() else {
39-
return Err(NatsrpyError::NotInitialized);
39+
unreachable!("Subscription used after del")
4040
};
4141
natsrpy_future_with_timeout(py, timeout, async move {
4242
let Some(message) = inner.lock().await.next().await else {
@@ -50,6 +50,35 @@ impl Subscription {
5050
pub fn __anext__<'py>(&self, py: Python<'py>) -> NatsrpyResult<Bound<'py, PyAny>> {
5151
self.next(py, None)
5252
}
53+
54+
#[pyo3(signature=(limit=None))]
55+
pub fn unsubscribe<'py>(
56+
&self,
57+
py: Python<'py>,
58+
limit: Option<u64>,
59+
) -> NatsrpyResult<Bound<'py, PyAny>> {
60+
let Some(inner) = self.inner.clone() else {
61+
unreachable!("Subscription used after del")
62+
};
63+
natsrpy_future(py, async move {
64+
if let Some(limit) = limit {
65+
inner.lock().await.unsubscribe_after(limit).await?;
66+
} else {
67+
inner.lock().await.unsubscribe().await?;
68+
}
69+
Ok(())
70+
})
71+
}
72+
73+
pub fn drain<'py>(&self, py: Python<'py>) -> NatsrpyResult<Bound<'py, PyAny>> {
74+
let Some(inner) = self.inner.clone() else {
75+
unreachable!("Subscription used after del")
76+
};
77+
natsrpy_future(py, async move {
78+
inner.lock().await.drain().await?;
79+
Ok(())
80+
})
81+
}
5382
}
5483

5584
/// This is required only because

0 commit comments

Comments
 (0)