diff --git a/example-service/src/client.rs b/example-service/src/client.rs index 8a4ff72eb..c73122c07 100644 --- a/example-service/src/client.rs +++ b/example-service/src/client.rs @@ -34,10 +34,13 @@ async fn main() -> anyhow::Result<()> { let client = WorldClient::new(client::Config::default(), transport.await?).spawn(); let hello = async move { + let mut context = context::current(); + let mut context2 = context::current(); + // Send the request twice, just to be safe! ;) tokio::select! { - hello1 = client.hello(context::current(), format!("{}1", flags.name)) => { hello1 } - hello2 = client.hello(context::current(), format!("{}2", flags.name)) => { hello2 } + hello1 = client.hello(&mut context, format!("{}1", flags.name)) => { hello1 } + hello2 = client.hello(&mut context2, format!("{}2", flags.name)) => { hello2 } } } .instrument(tracing::info_span!("Two Hellos")) diff --git a/example-service/src/server.rs b/example-service/src/server.rs index 67836e069..3a907795c 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -35,7 +35,7 @@ struct Flags { struct HelloServer(SocketAddr); impl World for HelloServer { - async fn hello(self, _: context::Context, name: String) -> String { + async fn hello(self, _: &mut context::Context, name: String) -> String { let sleep_time = Duration::from_millis(Uniform::new_inclusive(1, 10).unwrap().sample(&mut rng())); time::sleep(sleep_time).await; diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index da6443edf..55ec2730e 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -401,7 +401,7 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec> { /// #[derive(Clone)] /// struct CalculatorServer; /// impl Calculator for CalculatorServer { -/// async fn add(self, context: Context, a: i32, b: i32) -> i32 { +/// async fn add(self, context: &mut Context, a: i32, b: i32) -> i32 { /// a + b /// } /// } @@ -558,7 +558,7 @@ impl ServiceGenerator<'_> { )| { quote! { #( #attrs )* - async fn #ident(self, context: ::tarpc::context::Context, #( #args ),*) -> #output; + async fn #ident(self, context: &mut ::tarpc::context::Context, #( #args ),*) -> #output; } }, ); @@ -622,7 +622,7 @@ impl ServiceGenerator<'_> { type Resp = #response_ident; - async fn serve(self, ctx: ::tarpc::context::Context, req: #request_ident) + async fn serve(self, ctx: &mut ::tarpc::context::Context, req: #request_ident) -> ::core::result::Result<#response_ident, ::tarpc::ServerError> { match req { #( @@ -786,7 +786,7 @@ impl ServiceGenerator<'_> { #( #[allow(unused)] #( #method_attrs )* - #vis fn #method_idents(&self, ctx: ::tarpc::context::Context, #( #args ),*) + #vis fn #method_idents<'a>(&'a self, ctx: &'a mut ::tarpc::context::Context, #( #args ),*) -> impl ::core::future::Future> + '_ { let request = #request_ident::#camel_case_idents { #( #arg_pats ),* }; let resp = self.0.call(ctx, request); diff --git a/plugins/tests/service.rs b/plugins/tests/service.rs index 26ee1ec39..d38492bd7 100644 --- a/plugins/tests/service.rs +++ b/plugins/tests/service.rs @@ -12,15 +12,15 @@ fn att_service_trait() { } impl Foo for () { - async fn two_part(self, _: context::Context, s: String, i: i32) -> (String, i32) { + async fn two_part(self, _: &mut context::Context, s: String, i: i32) -> (String, i32) { (s, i) } - async fn bar(self, _: context::Context, s: String) -> String { + async fn bar(self, _: &mut context::Context, s: String) -> String { s } - async fn baz(self, _: context::Context) {} + async fn baz(self, _: &mut context::Context) {} } } @@ -39,18 +39,18 @@ fn raw_idents() { impl r#trait for () { async fn r#await( self, - _: context::Context, + _: &mut context::Context, r#struct: r#yield, r#enum: i32, ) -> (r#yield, i32) { (r#struct, r#enum) } - async fn r#fn(self, _: context::Context, r#impl: r#yield) -> r#yield { + async fn r#fn(self, _: &mut context::Context, r#impl: r#yield) -> r#yield { r#impl } - async fn r#async(self, _: context::Context) {} + async fn r#async(self, _: &mut context::Context) {} } } @@ -64,7 +64,7 @@ fn service_with_cfg_rpc() { } impl Foo for () { - async fn foo(self, _: context::Context) {} + async fn foo(self, _: &mut context::Context) {} } } diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs index d66261d19..e703cc676 100644 --- a/tarpc/examples/compression.rs +++ b/tarpc/examples/compression.rs @@ -108,7 +108,7 @@ pub trait World { struct HelloServer; impl World for HelloServer { - async fn hello(self, _: context::Context, name: String) -> String { + async fn hello(self, _: &mut context::Context, name: String) -> String { format!("Hey, {name}!") } } @@ -134,7 +134,9 @@ async fn main() -> anyhow::Result<()> { println!( "{}", - client.hello(context::current(), "friend".into()).await? + client + .hello(&mut context::current(), "friend".into()) + .await? ); Ok(()) } diff --git a/tarpc/examples/custom_transport.rs b/tarpc/examples/custom_transport.rs index 5f5386785..c99825d08 100644 --- a/tarpc/examples/custom_transport.rs +++ b/tarpc/examples/custom_transport.rs @@ -21,7 +21,7 @@ pub trait PingService { struct Service; impl PingService for Service { - async fn ping(self, _: Context) {} + async fn ping(self, _: &mut Context) {} } #[tokio::main] @@ -52,7 +52,7 @@ async fn main() -> anyhow::Result<()> { let transport = transport::new(codec_builder.new_framed(conn), Bincode::default()); PingServiceClient::new(Default::default(), transport) .spawn() - .ping(tarpc::context::current()) + .ping(&mut tarpc::context::current()) .await?; Ok(()) diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index d61f68c48..4e132616f 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -80,11 +80,11 @@ struct Subscriber { } impl subscriber::Subscriber for Subscriber { - async fn topics(self, _: context::Context) -> Vec { + async fn topics(self, _: &mut context::Context) -> Vec { self.topics.clone() } - async fn receive(self, _: context::Context, topic: String, message: String) { + async fn receive(self, _: &mut context::Context, topic: String, message: String) { info!(local_addr = %self.local_addr, %topic, %message, "ReceivedMessage") } } @@ -210,7 +210,7 @@ impl Publisher { subscriber: subscriber::SubscriberClient, ) { // Populate the topics - if let Ok(topics) = subscriber.topics(context::current()).await { + if let Ok(topics) = subscriber.topics(&mut context::current()).await { self.clients.lock().unwrap().insert( subscriber_addr, Subscription { @@ -263,15 +263,21 @@ impl Publisher { } impl publisher::Publisher for Publisher { - async fn publish(self, _: context::Context, topic: String, message: String) { + async fn publish(self, _: &mut context::Context, topic: String, message: String) { info!("received message to publish."); let mut subscribers = match self.subscriptions.read().unwrap().get(&topic) { None => return, Some(subscriptions) => subscriptions.clone(), }; let mut publications = Vec::new(); + for client in subscribers.values_mut() { - publications.push(client.receive(context::current(), topic.clone(), message.clone())); + publications.push(async { + let mut context = context::current(); + client + .receive(&mut context, topic.clone(), message.clone()) + .await + }); } // Ignore failing subscribers. In a real pubsub, you'd want to continually retry until // subscribers ack. Of course, a lot would be different in a real pubsub :) @@ -342,26 +348,30 @@ async fn main() -> anyhow::Result<()> { .spawn(); publisher - .publish(context::current(), "calculus".into(), "sqrt(2)".into()) + .publish(&mut context::current(), "calculus".into(), "sqrt(2)".into()) .await?; publisher .publish( - context::current(), + &mut context::current(), "cool shorts".into(), "hello to all".into(), ) .await?; publisher - .publish(context::current(), "history".into(), "napoleon".to_string()) + .publish( + &mut context::current(), + "history".into(), + "napoleon".to_string(), + ) .await?; drop(_subscriber0); publisher .publish( - context::current(), + &mut context::current(), "cool shorts".into(), "hello to who?".into(), ) diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index c328bd884..c00c270f0 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -23,7 +23,7 @@ pub trait World { struct HelloServer; impl World for HelloServer { - async fn hello(self, _: context::Context, name: String) -> String { + async fn hello(self, _: &mut context::Context, name: String) -> String { format!("Hello, {name}!") } } @@ -46,7 +46,9 @@ async fn main() -> anyhow::Result<()> { // The client has an RPC method for each RPC defined in the annotated trait. It takes the same // args as defined, with the addition of a Context, which is always the first arg. The Context // specifies a deadline and trace information which can be helpful in debugging requests. - let hello = client.hello(context::current(), "Stim".to_string()).await?; + let hello = client + .hello(&mut context::current(), "Stim".to_string()) + .await?; println!("{hello}"); diff --git a/tarpc/examples/tls_over_tcp.rs b/tarpc/examples/tls_over_tcp.rs index 968f76c17..d81ea74a1 100644 --- a/tarpc/examples/tls_over_tcp.rs +++ b/tarpc/examples/tls_over_tcp.rs @@ -33,7 +33,7 @@ pub trait PingService { struct Service; impl PingService for Service { - async fn ping(self, _: Context) -> String { + async fn ping(self, _: &mut Context) -> String { "🔒".to_owned() } } @@ -146,7 +146,7 @@ async fn main() -> anyhow::Result<()> { let transport = transport::new(codec_builder.new_framed(stream), Bincode::default()); let answer = PingServiceClient::new(Default::default(), transport) .spawn() - .ping(tarpc::context::current()) + .ping(&mut tarpc::context::current()) .await?; println!("ping answer: {answer}"); diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index 79a7026c0..1bace43ce 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -56,7 +56,7 @@ pub mod double { struct AddServer; impl AddService for AddServer { - async fn add(self, _: context::Context, x: i32, y: i32) -> i32 { + async fn add(self, _: &mut context::Context, x: i32, y: i32) -> i32 { x + y } } @@ -70,9 +70,9 @@ impl DoubleService for DoubleServer where Stub: AddStub + Clone + Send + Sync + 'static, { - async fn double(self, _: context::Context, x: i32) -> Result { + async fn double(self, _: &mut context::Context, x: i32) -> Result { self.add_client - .add(context::current(), x, x) + .add(&mut context::current(), x, x) .await .map_err(|e| e.to_string()) } @@ -193,9 +193,9 @@ async fn main() -> anyhow::Result<()> { let double_client = double::DoubleClient::new(client::Config::default(), to_double_server).spawn(); - let ctx = context::current(); + let mut ctx = context::current(); for _ in 1..=5 { - tracing::info!("{:?}", double_client.double(ctx, 1).await?); + tracing::info!("{:?}", double_client.double(&mut ctx, 1).await?); } tracer_provider.shutdown()?; diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index 3cf9ff07a..9ef7a1acb 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -128,7 +128,7 @@ where otel.kind = "client", otel.name = %request.name()) )] - pub async fn call(&self, mut ctx: context::Context, request: Req) -> Result { + pub async fn call(&self, ctx: &mut context::Context, request: Req) -> Result { let span = Span::current(); ctx.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| { tracing::trace!( @@ -153,7 +153,10 @@ where }; self.to_dispatch .send(DispatchRequest { - ctx, + ctx: context::Context { + deadline: ctx.deadline, + trace_context: ctx.trace_context, + }, span, request_id, request, @@ -881,7 +884,7 @@ mod tests { let (dispatch, channel, _server_channel) = set_up(); drop(dispatch); // error on send - let resp = channel.call(current(), "hi".to_string()).await; + let resp = channel.call(&mut current(), "hi".to_string()).await; assert_matches!(resp, Err(RpcError::Shutdown)); } diff --git a/tarpc/src/client/stub.rs b/tarpc/src/client/stub.rs index 85746b7f2..14b6edf30 100644 --- a/tarpc/src/client/stub.rs +++ b/tarpc/src/client/stub.rs @@ -24,8 +24,11 @@ pub trait Stub { type Resp; /// Calls a remote service. - async fn call(&self, ctx: context::Context, request: Self::Req) - -> Result; + async fn call( + &self, + ctx: &mut context::Context, + request: Self::Req, + ) -> Result; } impl Stub for Channel @@ -35,7 +38,7 @@ where type Req = Req; type Resp = Resp; - async fn call(&self, ctx: context::Context, request: Req) -> Result { + async fn call(&self, ctx: &mut context::Context, request: Req) -> Result { Self::call(self, ctx, request).await } } @@ -46,7 +49,11 @@ where { type Req = S::Req; type Resp = S::Resp; - async fn call(&self, ctx: context::Context, req: Self::Req) -> Result { + async fn call( + &self, + ctx: &mut context::Context, + req: Self::Req, + ) -> Result { self.clone().serve(ctx, req).await.map_err(RpcError::Server) } } diff --git a/tarpc/src/client/stub/load_balance.rs b/tarpc/src/client/stub/load_balance.rs index d28a3c137..6c0f7b0df 100644 --- a/tarpc/src/client/stub/load_balance.rs +++ b/tarpc/src/client/stub/load_balance.rs @@ -20,7 +20,7 @@ mod round_robin { async fn call( &self, - ctx: context::Context, + ctx: &mut context::Context, request: Self::Req, ) -> Result { let next = self.stubs.next(); @@ -119,7 +119,7 @@ mod consistent_hash { async fn call( &self, - ctx: context::Context, + ctx: &mut context::Context, request: Self::Req, ) -> Result { let index = usize::try_from(self.hasher.hash_one(&request) % self.stubs_len).expect( @@ -200,13 +200,13 @@ mod consistent_hash { )?; for _ in 0..2 { - let resp = stub.call(context::current(), 'a').await?; + let resp = stub.call(&mut context::current(), 'a').await?; assert_eq!(resp, 1); - let resp = stub.call(context::current(), 'b').await?; + let resp = stub.call(&mut context::current(), 'b').await?; assert_eq!(resp, 2); - let resp = stub.call(context::current(), 'c').await?; + let resp = stub.call(&mut context::current(), 'c').await?; assert_eq!(resp, 3); } diff --git a/tarpc/src/client/stub/mock.rs b/tarpc/src/client/stub/mock.rs index 145c14c1f..6f0540797 100644 --- a/tarpc/src/client/stub/mock.rs +++ b/tarpc/src/client/stub/mock.rs @@ -30,7 +30,7 @@ where type Req = Req; type Resp = Resp; - async fn call(&self, _: context::Context, request: Self::Req) -> Result { + async fn call(&self, _: &mut context::Context, request: Self::Req) -> Result { self.responses .get(&request) .cloned() diff --git a/tarpc/src/client/stub/retry.rs b/tarpc/src/client/stub/retry.rs index a07b05fc5..18c84f25f 100644 --- a/tarpc/src/client/stub/retry.rs +++ b/tarpc/src/client/stub/retry.rs @@ -18,7 +18,7 @@ where async fn call( &self, - ctx: context::Context, + ctx: &mut context::Context, request: Self::Req, ) -> Result { let request = Arc::new(request); diff --git a/tarpc/src/context.rs b/tarpc/src/context.rs index fec15ac14..8f8ff128d 100644 --- a/tarpc/src/context.rs +++ b/tarpc/src/context.rs @@ -21,7 +21,7 @@ use tracing_opentelemetry::OpenTelemetrySpanExt; /// /// The context should not be stored directly in a server implementation, because the context will /// be different for each request in scope. -#[derive(Clone, Copy, Debug)] +#[derive(Debug)] #[non_exhaustive] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct Context { diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index 7e1944305..17a06ec57 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -125,7 +125,7 @@ //! //! impl World for HelloServer { //! // Each defined rpc generates an async fn that serves the RPC -//! async fn hello(self, _: context::Context, name: String) -> String { +//! async fn hello(self, _: &mut context::Context, name: String) -> String { //! format!("Hello, {name}!") //! } //! } @@ -158,7 +158,7 @@ //! # struct HelloServer; //! # impl World for HelloServer { //! // Each defined rpc generates an async fn that serves the RPC -//! # async fn hello(self, _: context::Context, name: String) -> String { +//! # async fn hello(self, _: &mut context::Context, name: String) -> String { //! # format!("Hello, {name}!") //! # } //! # } @@ -184,7 +184,8 @@ //! // The client has an RPC method for each RPC defined in the annotated trait. It takes the same //! // args as defined, with the addition of a Context, which is always the first arg. The Context //! // specifies a deadline and trace information which can be helpful in debugging requests. -//! let hello = client.hello(context::current(), "Stim".to_string()).await?; +//! let mut context = context::current(); +//! let hello = client.hello(&mut context, "Stim".to_string()).await?; //! //! println!("{hello}"); //! @@ -279,7 +280,7 @@ pub enum ClientMessage { } /// A request from a client to a server. -#[derive(Clone, Copy, Debug)] +#[derive(Debug)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct Request { /// Trace context, deadline, and other cross-cutting concerns. diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index da3b3ae21..e08365964 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -76,7 +76,11 @@ pub trait Serve { type Resp; /// Responds to a single request. - async fn serve(self, ctx: context::Context, req: Self::Req) -> Result; + async fn serve( + self, + ctx: &mut context::Context, + req: Self::Req, + ) -> Result; } /// A Serve wrapper around a Fn. @@ -102,10 +106,12 @@ impl Copy for ServeFn where F: Copy {} /// Creates a [`Serve`] wrapper around a `FnOnce(context::Context, Req) -> impl Future>`. -pub fn serve(f: F) -> ServeFn +pub fn serve(f: F) -> ServeFn where - F: FnOnce(context::Context, Req) -> Fut, - Fut: Future>, + for<'a> F: FnOnce( + &'a mut context::Context, + Req, + ) -> Pin> + 'a + Send>>, { ServeFn { f, @@ -113,16 +119,18 @@ where } } -impl Serve for ServeFn +impl Serve for ServeFn where Req: RequestName, - F: FnOnce(context::Context, Req) -> Fut, - Fut: Future>, + for<'a> F: FnOnce( + &'a mut context::Context, + Req, + ) -> Pin> + 'a + Send>>, { type Req = Req; type Resp = Resp; - async fn serve(self, ctx: context::Context, req: Req) -> Result { + async fn serve(self, ctx: &mut context::Context, req: Req) -> Result { (self.f)(ctx, req).await } } @@ -360,10 +368,11 @@ where /// let mut requests = server.requests(); /// tokio::spawn(async move { /// while let Some(Ok(request)) = requests.next().await { - /// tokio::spawn(request.execute(serve(|_, i| async move { Ok(i + 1) }))); + /// tokio::spawn(request.execute(serve(|_, i| async move { Ok(i + 1) }.boxed()))); /// } /// }); - /// assert_eq!(client.call(context::current(), 1).await.unwrap(), 2); + /// let mut context = context::current(); + /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` fn requests(self) -> Requests @@ -399,12 +408,13 @@ where /// let client = client::new(client::Config::default(), tx).spawn(); /// let channel = BaseChannel::with_defaults(rx); /// tokio::spawn( - /// channel.execute(serve(|_, i: i32| async move { Ok(i + 1) })) + /// channel.execute(serve(|_, i: i32| async move { Ok(i + 1) }.boxed())) /// .for_each(|response| async move { /// tokio::spawn(response); - /// })); + /// }.boxed())); + /// let mut context = context::current(); /// assert_eq!( - /// client.call(context::current(), 1).await.unwrap(), + /// client.call(&mut context, 1).await.unwrap(), /// 2); /// } /// ``` @@ -748,11 +758,12 @@ where /// let requests = BaseChannel::new(server::Config::default(), rx).requests(); /// let client = client::new(client::Config::default(), tx).spawn(); /// tokio::spawn( - /// requests.execute(serve(|_, i| async move { Ok(i + 1) })) + /// requests.execute(serve(|_, i| async move { Ok(i + 1) }.boxed())) /// .for_each(|response| async move { /// tokio::spawn(response); - /// })); - /// assert_eq!(client.call(context::current(), 1).await.unwrap(), 2); + /// }.boxed())); + /// let mut context = context::current(); + /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` pub fn execute(self, serve: S) -> impl Stream> @@ -855,11 +866,11 @@ impl InFlightRequest { /// tokio::spawn(async move { /// let mut requests = server.requests(); /// while let Some(Ok(in_flight_request)) = requests.next().await { - /// in_flight_request.execute(serve(|_, i| async move { Ok(i + 1) })).await; + /// in_flight_request.execute(serve(|_, i| async move { Ok(i + 1) }.boxed())).await; /// } - /// /// }); - /// assert_eq!(client.call(context::current(), 1).await.unwrap(), 2); + /// let mut context = context::current(); + /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` /// @@ -875,7 +886,7 @@ impl InFlightRequest { span, request: Request { - context, + mut context, message, id: request_id, }, @@ -883,7 +894,7 @@ impl InFlightRequest { span.record("otel.name", message.name()); let _ = Abortable::new( async move { - let message = serve.serve(context, message).await; + let message = serve.serve(&mut context, message).await; tracing::debug!("CompleteRequest"); let response = Response { request_id, @@ -977,6 +988,7 @@ mod tests { task::Poll, time::{Duration, Instant}, }; + use tracing_subscriber::filter::FilterExt; fn test_channel() -> ( Pin, Response>>>>, @@ -1039,8 +1051,8 @@ mod tests { #[tokio::test] async fn test_serve() { - let serve = serve(|_, i| async move { Ok(i) }); - assert_matches!(serve.serve(context::current(), 7).await, Ok(7)); + let serve = serve(|_, i| async move { Ok(i) }.boxed()); + assert_matches!(serve.serve(&mut context::current(), 7).await, Ok(7)); } #[tokio::test] @@ -1060,14 +1072,17 @@ mod tests { let some_time = Instant::now() + Duration::from_secs(37); let some_other_time = Instant::now() + Duration::from_secs(83); - let serve = serve(move |ctx: context::Context, i| async move { - assert_eq!(ctx.deadline, some_time); - Ok(i) + let serve = serve(move |ctx: &mut context::Context, i| { + async move { + assert_eq!(ctx.deadline, some_time); + Ok(i) + } + .boxed() }); let deadline_hook = serve.before(SetDeadline(some_time)); let mut ctx = context::current(); ctx.deadline = some_other_time; - deadline_hook.serve(ctx, 7).await?; + deadline_hook.serve(&mut ctx, 7).await?; Ok(()) } @@ -1101,21 +1116,21 @@ mod tests { } } - let serve = serve(move |_: context::Context, i| async move { Ok(i) }); + let serve = serve(move |_: &mut context::Context, i| async move { Ok(i) }.boxed()); serve .before_and_after(PrintLatency::new()) - .serve(context::current(), 7) + .serve(&mut context::current(), 7) .await?; Ok(()) } #[tokio::test] async fn serve_before_error_aborts_request() -> anyhow::Result<()> { - let serve = serve(|_, _| async { panic!("Shouldn't get here") }); + let serve = serve(|_, _| async { panic!("Shouldn't get here") }.boxed()); let deadline_hook = serve.before(|_: &mut context::Context, _: &i32| async { Err(ServerError::new(io::ErrorKind::Other, "oops".into())) }); - let resp: Result = deadline_hook.serve(context::current(), 7).await; + let resp: Result = deadline_hook.serve(&mut context::current(), 7).await; assert_matches!(resp, Err(_)); Ok(()) } @@ -1320,7 +1335,9 @@ mod tests { Poll::Ready(Some(Ok(request))) => request, result => panic!("Unexpected result: {result:?}"), }; - request.execute(serve(|_, _| async { Ok(()) })).await; + request + .execute(serve(|_, _| async { Ok(()) }.boxed())) + .await; assert!( requests .as_mut() diff --git a/tarpc/src/server/incoming.rs b/tarpc/src/server/incoming.rs index 428eb1a7d..eddf3794e 100644 --- a/tarpc/src/server/incoming.rs +++ b/tarpc/src/server/incoming.rs @@ -63,9 +63,10 @@ where /// /// let incoming = stream::once(async move { /// BaseChannel::new(server::Config::default(), rx) -/// }).execute(serve(|_, i| async move { Ok(i + 1) })); +/// }).execute(serve(|_, i| async move { Ok(i + 1) }.boxed())); /// tokio::spawn(spawn_incoming(incoming)); -/// assert_eq!(client.call(context::current(), 1).await.unwrap(), 2); +/// let mut context = context::current(); +/// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` pub async fn spawn_incoming( diff --git a/tarpc/src/server/request_hook.rs b/tarpc/src/server/request_hook.rs index 66cf2878c..64b97453a 100644 --- a/tarpc/src/server/request_hook.rs +++ b/tarpc/src/server/request_hook.rs @@ -43,11 +43,11 @@ pub trait RequestHook: Serve { /// # Example /// /// ```rust - /// use futures::{executor::block_on, future}; + /// use futures::{executor::block_on, future, FutureExt}; /// use tarpc::{context, ServerError, server::{Serve, request_hook::RequestHook, serve}}; /// use std::io; /// - /// let serve = serve(|_ctx, i| async move { Ok(i + 1) }) + /// let serve = serve(|_ctx, i| async move { Ok(i + 1) }.boxed()) /// .before(|_ctx: &mut context::Context, req: &i32| { /// future::ready( /// if *req == 1 { @@ -58,7 +58,8 @@ pub trait RequestHook: Serve { /// Ok(()) /// }) /// }); - /// let response = serve.serve(context::current(), 1); + /// let mut context = context::current(); + /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_err()); /// ``` fn before(self, hook: Hook) -> HookThenServe @@ -80,7 +81,7 @@ pub trait RequestHook: Serve { /// # Example /// /// ```rust - /// use futures::{executor::block_on, future}; + /// use futures::{executor::block_on, future, FutureExt}; /// use tarpc::{context, ServerError, server::{Serve, request_hook::RequestHook, serve}}; /// use std::io; /// @@ -93,15 +94,15 @@ pub trait RequestHook: Serve { /// } else { /// Ok(i + 1) /// } - /// }) + /// }.boxed()) /// .after(|_ctx: &mut context::Context, resp: &mut Result| { /// if let Err(e) = resp { /// eprintln!("server error: {e:?}"); /// } /// future::ready(()) /// }); - /// - /// let response = serve.serve(context::current(), 1); + /// let mut context = context::current(); + /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_err()); /// ``` fn after(self, hook: Hook) -> ServeThenHook @@ -123,7 +124,7 @@ pub trait RequestHook: Serve { /// # Example /// /// ```rust - /// use futures::{executor::block_on, future}; + /// use futures::{executor::block_on, future, FutureExt}; /// use tarpc::{ /// context, ServerError, /// server::{Serve, serve, request_hook::{BeforeRequest, AfterRequest, RequestHook}} @@ -151,8 +152,9 @@ pub trait RequestHook: Serve { /// /// let serve = serve(|_ctx, i| async move { /// Ok(i + 1) - /// }).before_and_after(PrintLatency(Instant::now())); - /// let response = serve.serve(context::current(), 1); + /// }.boxed()).before_and_after(PrintLatency(Instant::now())); + /// let mut context = context::current(); + /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_ok()); /// ``` fn before_and_after( diff --git a/tarpc/src/server/request_hook/after.rs b/tarpc/src/server/request_hook/after.rs index b2ef9ccbd..e2c49b2f1 100644 --- a/tarpc/src/server/request_hook/after.rs +++ b/tarpc/src/server/request_hook/after.rs @@ -59,14 +59,14 @@ where async fn serve( self, - mut ctx: context::Context, + ctx: &mut context::Context, req: Serv::Req, ) -> Result { let ServeThenHook { serve, mut hook, .. } = self; let mut resp = serve.serve(ctx, req).await; - hook.after(&mut ctx, &mut resp).await; + hook.after(ctx, &mut resp).await; resp } } diff --git a/tarpc/src/server/request_hook/before.rs b/tarpc/src/server/request_hook/before.rs index e72e28a42..ad04cc784 100644 --- a/tarpc/src/server/request_hook/before.rs +++ b/tarpc/src/server/request_hook/before.rs @@ -87,13 +87,13 @@ where async fn serve( self, - mut ctx: context::Context, + ctx: &mut context::Context, req: Self::Req, ) -> Result { let HookThenServe { serve, mut hook, .. } = self; - hook.before(&mut ctx, &req).await?; + hook.before(ctx, &req).await?; serve.serve(ctx, req).await } } @@ -103,7 +103,7 @@ where /// Example /// /// ```rust -/// use futures::{executor::block_on, future}; +/// use futures::{executor::block_on, future, FutureExt}; /// use tarpc::{context, ServerError, server::{Serve, serve, request_hook::{self, /// BeforeRequest, BeforeRequestList}}}; /// use std::{cell::Cell, io}; @@ -120,8 +120,9 @@ where /// i.set(2); /// Ok(()) /// }) -/// .serving(serve(|_ctx, i| async move { Ok(i + 1) })); -/// let response = serve.clone().serve(context::current(), 1); +/// .serving(serve(|_ctx, i| async move { Ok(i + 1) }.boxed())); +/// let mut context = context::current(); +/// let response = serve.clone().serve(&mut context, 1); /// assert!(block_on(response).is_ok()); /// assert!(i.get() == 2); /// ``` @@ -209,8 +210,9 @@ fn before_request_list() { i.set(2); Ok(()) }) - .serving(serve(|_ctx, i| async move { Ok(i + 1) })); - let response = serve.clone().serve(context::current(), 1); + .serving(serve(|_ctx, i| async move { Ok(i + 1) }.boxed())); + let mut context = context::current(); + let response = serve.clone().serve(&mut context, 1); assert!(block_on(response).is_ok()); assert!(i.get() == 2); } diff --git a/tarpc/src/server/request_hook/before_and_after.rs b/tarpc/src/server/request_hook/before_and_after.rs index 0761a7df3..e06f34113 100644 --- a/tarpc/src/server/request_hook/before_and_after.rs +++ b/tarpc/src/server/request_hook/before_and_after.rs @@ -46,13 +46,13 @@ where type Req = Req; type Resp = Resp; - async fn serve(self, mut ctx: context::Context, req: Req) -> Result { + async fn serve(self, ctx: &mut context::Context, req: Req) -> Result { let HookThenServeThenHook { serve, mut hook, .. } = self; - hook.before(&mut ctx, &req).await?; + hook.before(ctx, &req).await?; let mut resp = serve.serve(ctx, req).await; - hook.after(&mut ctx, &mut resp).await; + hook.after(ctx, &mut resp).await; resp } } diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index 0268300dc..e064e6813 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -191,13 +191,16 @@ mod tests { tokio::spawn( stream::once(future::ready(server_channel)) .map(BaseChannel::with_defaults) - .execute(serve(|_ctx, request: String| async move { - request.parse::().map_err(|_| { - ServerError::new( - io::ErrorKind::InvalidInput, - format!("{request:?} is not an int"), - ) - }) + .execute(serve(|_ctx, request: String| { + async move { + request.parse::().map_err(|_| { + ServerError::new( + io::ErrorKind::InvalidInput, + format!("{request:?} is not an int"), + ) + }) + } + .boxed() })) .for_each(|channel| async move { tokio::spawn(channel.for_each(|response| response)); @@ -206,8 +209,8 @@ mod tests { let client = client::new(client::Config::default(), client_channel).spawn(); - let response1 = client.call(context::current(), "123".into()).await; - let response2 = client.call(context::current(), "abc".into()).await; + let response1 = client.call(&mut context::current(), "123".into()).await; + let response2 = client.call(&mut context::current(), "abc".into()).await; trace!("response1: {:?}, response2: {:?}", response1, response2); diff --git a/tarpc/tests/dataservice.rs b/tarpc/tests/dataservice.rs index 18bb3a997..e051b434e 100644 --- a/tarpc/tests/dataservice.rs +++ b/tarpc/tests/dataservice.rs @@ -22,7 +22,7 @@ pub trait ColorProtocol { struct ColorServer; impl ColorProtocol for ColorServer { - async fn get_opposite_color(self, _: context::Context, color: TestData) -> TestData { + async fn get_opposite_color(self, _: &mut context::Context, color: TestData) -> TestData { match color { TestData::White => TestData::Black, TestData::Black => TestData::White, @@ -53,7 +53,7 @@ async fn test_call() -> anyhow::Result<()> { let client = ColorProtocolClient::new(client::Config::default(), transport).spawn(); let color = client - .get_opposite_color(context::current(), TestData::White) + .get_opposite_color(&mut context::current(), TestData::White) .await?; assert_eq!(color, TestData::Black); diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index 06542b43b..f3adda2fb 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -22,11 +22,11 @@ trait Service { struct Server; impl Service for Server { - async fn add(self, _: context::Context, x: i32, y: i32) -> i32 { + async fn add(self, _: &mut context::Context, x: i32, y: i32) -> i32 { x + y } - async fn hey(self, _: context::Context, name: String) -> String { + async fn hey(self, _: &mut context::Context, name: String) -> String { format!("Hey, {name}.") } } @@ -38,10 +38,12 @@ async fn sequential() { let channel = BaseChannel::with_defaults(rx); tokio::spawn( channel - .execute(tarpc::server::serve(|_, i: u32| async move { Ok(i + 1) })) + .execute(tarpc::server::serve(|_, i: u32| { + async move { Ok(i + 1) }.boxed() + })) .for_each(|response| response), ); - assert_eq!(client.call(context::current(), 1).await.unwrap(), 2); + assert_eq!(client.call(&mut context::current(), 1).await.unwrap(), 2); } #[tokio::test] @@ -55,7 +57,7 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { struct LoopServer; impl Loop for LoopServer { - async fn r#loop(self, _: context::Context) { + async fn r#loop(self, _: &mut context::Context) { loop { futures::pending!(); } @@ -73,7 +75,7 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { let mut ctx = context::current(); ctx.deadline = Instant::now() + Duration::from_secs(60 * 60); - let _ = client.r#loop(ctx).await; + let _ = client.r#loop(&mut ctx).await; }); let mut requests = BaseChannel::with_defaults(rx).requests(); @@ -112,9 +114,9 @@ async fn serde_tcp() -> anyhow::Result<()> { let transport = serde_transport::tcp::connect(addr, Json::default).await?; let client = ServiceClient::new(client::Config::default(), transport).spawn(); - assert_matches!(client.add(context::current(), 1, 2).await, Ok(3)); + assert_matches!(client.add(&mut context::current(), 1, 2).await, Ok(3)); assert_matches!( - client.hey(context::current(), "Tim".to_string()).await, + client.hey(&mut context::current(), "Tim".to_string()).await, Ok(ref s) if s == "Hey, Tim." ); @@ -145,8 +147,8 @@ async fn serde_uds() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), transport).spawn(); // Save results using socket so we can clean the socket even if our test assertions fail - let res1 = client.add(context::current(), 1, 2).await; - let res2 = client.hey(context::current(), "Tim".to_string()).await; + let res1 = client.add(&mut context::current(), 1, 2).await; + let res2 = client.hey(&mut context::current(), "Tim".to_string()).await; assert_matches!(res1, Ok(3)); assert_matches!(res2, Ok(ref s) if s == "Hey, Tim."); @@ -169,12 +171,15 @@ async fn concurrent() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let req1 = client.add(context::current(), 1, 2); - let req2 = client.add(context::current(), 3, 4); - let req3 = client.hey(context::current(), "Tim".to_string()); + let mut context = context::current(); + let req1 = client.add(&mut context, 1, 2); assert_matches!(req1.await, Ok(3)); + + let req2 = client.add(&mut context, 3, 4); assert_matches!(req2.await, Ok(7)); + + let req3 = client.hey(&mut context, "Tim".to_string()); assert_matches!(req3.await, Ok(ref s) if s == "Hey, Tim."); Ok(()) @@ -195,9 +200,13 @@ async fn concurrent_join() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let req1 = client.add(context::current(), 1, 2); - let req2 = client.add(context::current(), 3, 4); - let req3 = client.hey(context::current(), "Tim".to_string()); + let mut context1 = context::current(); + let mut context2 = context::current(); + let mut context3 = context::current(); + + let req1 = client.add(&mut context1, 1, 2); + let req2 = client.add(&mut context2, 3, 4); + let req3 = client.hey(&mut context3, "Tim".to_string()); let (resp1, resp2, resp3) = join!(req1, req2, req3); assert_matches!(resp1, Ok(3)); @@ -225,8 +234,11 @@ async fn concurrent_join_all() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let req1 = client.add(context::current(), 1, 2); - let req2 = client.add(context::current(), 3, 4); + let mut context1 = context::current(); + let mut context2 = context::current(); + + let req1 = client.add(&mut context1, 1, 2); + let req2 = client.add(&mut context2, 3, 4); let responses = join_all(vec![req1, req2]).await; assert_matches!(responses[0], Ok(3)); @@ -245,7 +257,7 @@ async fn counter() -> anyhow::Result<()> { struct CountService(u32); impl Counter for &mut CountService { - async fn count(self, _: context::Context) -> u32 { + async fn count(self, _: &mut context::Context) -> u32 { self.0 += 1; self.0 } @@ -262,8 +274,8 @@ async fn counter() -> anyhow::Result<()> { }); let client = CounterClient::new(client::Config::default(), tx).spawn(); - assert_matches!(client.count(context::current()).await, Ok(1)); - assert_matches!(client.count(context::current()).await, Ok(2)); + assert_matches!(client.count(&mut context::current()).await, Ok(1)); + assert_matches!(client.count(&mut context::current()).await, Ok(2)); Ok(()) }