Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 5 additions & 27 deletions tarpc/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ use futures::{prelude::*, ready, stream::Fuse, task::*};
use in_flight_requests::InFlightRequests;
use pin_project::pin_project;
use std::{
any::Any,
convert::TryFrom,
fmt,
pin::Pin,
Expand Down Expand Up @@ -269,7 +268,6 @@ where
transport: transport.fuse(),
in_flight_requests: InFlightRequests::default(),
pending_requests,
terminal_error: None,
},
}
}
Expand All @@ -291,11 +289,6 @@ pub struct RequestDispatch<Req, Resp, C> {
in_flight_requests: InFlightRequests<Result<Resp, RpcError>>,
/// Configures limits to prevent unlimited resource usage.
config: Config,
/// Produces errors that can be sent in response to any unprocessed requests at the time
/// RequestDispatch is dropped. Correctness note: this field should only be populated by
/// RequestDispatch::poll, which relies on downcasting the Any to a concrete error type
/// determined within the poll function.
terminal_error: Option<ChannelError<dyn Any + Send + Sync + 'static>>,
}

impl<Req, Resp, C> RequestDispatch<Req, Resp, C>
Expand Down Expand Up @@ -353,12 +346,6 @@ where
self.as_mut().project().pending_requests
}

fn terminal_error_mut<'a>(
self: &'a mut Pin<&mut Self>,
) -> &'a mut Option<ChannelError<dyn Any + Send + Sync + 'static>> {
self.as_mut().project().terminal_error
}

fn pump_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
Expand Down Expand Up @@ -659,20 +646,13 @@ where
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), ChannelError<C::Error>>> {
loop {
if let Some(e) = self.terminal_error_mut() {
let result = ready!(self.run(cx));
match result {
Ok(()) => Poll::Ready(Ok(())),
Err(e) => {
tracing::debug!("RpcError::Channel");
let e: ChannelError<C::Error> = e
.clone()
.downcast()
.expect("Invariant: ChannelError must store a C::Error");
ready!(self.shut_down_with_terminal_error(cx, e.clone().upcast_error()));
return Poll::Ready(Err(e));
}
let result = ready!(self.run(cx));
match result {
Ok(()) => return Poll::Ready(Ok(())),
Err(e) => *self.terminal_error_mut() = Some(e.upcast_any()),
Poll::Ready(Err(e))
}
}
}
Expand Down Expand Up @@ -986,7 +966,6 @@ mod tests {
canceled_requests,
in_flight_requests: InFlightRequests::default(),
config: Config::default(),
terminal_error: None,
});
let channel = Channel {
to_dispatch,
Expand Down Expand Up @@ -1082,7 +1061,6 @@ mod tests {
canceled_requests,
in_flight_requests: InFlightRequests::default(),
config: Config::default(),
terminal_error: None,
};

let channel = Channel {
Expand Down
40 changes: 1 addition & 39 deletions tarpc/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ pub(crate) mod util;

pub use crate::transport::sealed::Transport;

use std::{any::Any, error::Error, io, sync::Arc, time::Instant};
use std::{error::Error, io, sync::Arc, time::Instant};

/// A message from a client to a server.
#[derive(Debug)]
Expand Down Expand Up @@ -445,44 +445,6 @@ where
}
}

impl<E> ChannelError<E>
where
E: Send + Sync + 'static,
{
/// Converts the ChannelError's source error type to a dyn Any. This is useful in type-erased
/// contexts, for example, storing a ChannelError in a non-generic type like
/// [`client::RpcError`].
fn upcast_any(self) -> ChannelError<dyn Any + Send + Sync + 'static> {
use ChannelError::*;
match self {
Read(e) => Read(e),
Ready(e) => Ready(e),
Write(e) => Write(e),
Flush(e) => Flush(e),
Close(e) => Close(e),
}
}
}

impl ChannelError<dyn Any + Send + Sync + 'static> {
/// Converts the ChannelError's source error type to a concrete type. This is useful in
/// type-erased contexts, for example, storing a ChannelError in a non-generic type like
/// [`Client::RpcError`].
fn downcast<E>(self) -> Result<ChannelError<E>, Self>
where
E: Any + Send + Sync,
{
use ChannelError::*;
match self {
Read(e) => e.downcast::<E>().map(Read).map_err(Read),
Ready(e) => e.downcast::<E>().map(Ready).map_err(Ready),
Write(e) => e.downcast::<E>().map(Write).map_err(Write),
Flush(e) => e.downcast::<E>().map(Flush).map_err(Flush),
Close(e) => e.downcast::<E>().map(Close).map_err(Close),
}
}
}

impl ServerError {
/// Returns a new server error with `kind` and `detail`.
pub fn new(kind: io::ErrorKind, detail: String) -> ServerError {
Expand Down