diff --git a/common/src/typed_socket/mod.rs b/common/src/typed_socket/mod.rs index d2dbe79d..3b747e2b 100644 --- a/common/src/typed_socket/mod.rs +++ b/common/src/typed_socket/mod.rs @@ -26,35 +26,12 @@ pub struct TypedSocket { } #[derive(Clone)] -pub struct TypedSocketSender { - sender: Sender>, +pub struct TypedSocketSender { + inner_send: + Arc) -> Result<(), TypedSocketError> + 'static + Send + Sync>, } -#[derive(Clone)] -pub struct WrappedTypedSocketSender { - send: Arc Result<(), TypedSocketError> + 'static + Send + Sync>, -} - -impl WrappedTypedSocketSender { - pub fn new(sender: Sender>, transform: F) -> Self - where - F: (Fn(K) -> T) + 'static + Send + Sync, - { - Self { - send: Arc::new(move |message| { - sender - .try_send(SocketAction::Send(transform(message))) - .map_err(TypedSocketError::from) - }), - } - } - - pub fn send(&self, message: K) -> Result<(), TypedSocketError> { - (self.send)(message) - } -} - -impl Debug for TypedSocketSender { +impl Debug for TypedSocketSender { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.write_str("typed socket sender") } @@ -77,24 +54,16 @@ impl From> for TypedSocketError { } } -impl TypedSocketSender { - pub fn send(&self, message: T) -> Result<(), TypedSocketError> { - self.sender.try_send(SocketAction::Send(message))?; +impl TypedSocketSender { + pub fn send(&self, message: A) -> Result<(), TypedSocketError> { + (self.inner_send)(SocketAction::Send(message))?; Ok(()) } pub fn close(&mut self) -> Result<(), TypedSocketError> { - self.sender.try_send(SocketAction::Close)?; + (self.inner_send)(SocketAction::Close)?; Ok(()) } - - /// Wrap the sender with a transform function. - pub fn wrap(&self, transform: F) -> WrappedTypedSocketSender - where - F: (Fn(K) -> T) + 'static + Send + Sync, - { - WrappedTypedSocketSender::new(self.sender.clone(), transform) - } } impl TypedSocket { @@ -109,10 +78,22 @@ impl TypedSocket { self.recv.recv().await } - pub fn sender(&self) -> TypedSocketSender { + pub fn sender(&self, transform: F) -> TypedSocketSender + where + F: (Fn(A) -> T) + 'static + Send + Sync, + { let sender = self.send.clone(); - - TypedSocketSender { sender } + let inner_send = move |message: SocketAction| { + let message = match message { + SocketAction::Close => SocketAction::Close, + SocketAction::Send(message) => SocketAction::Send(transform(message)), + }; + sender.try_send(message).map_err(|e| e.into()) + }; + + TypedSocketSender { + inner_send: Arc::new(inner_send), + } } pub async fn close(&mut self) { diff --git a/plane/plane-tests/tests/cert_manager.rs b/plane/plane-tests/tests/cert_manager.rs index 7c8b4346..d8d79b34 100644 --- a/plane/plane-tests/tests/cert_manager.rs +++ b/plane/plane-tests/tests/cert_manager.rs @@ -11,7 +11,6 @@ use std::sync::Arc; mod common; #[plane_test] -#[ignore = "Doesn't work"] async fn cert_manager_does_refresh(env: TestEnvironment) { let controller = env.controller().await; @@ -57,7 +56,6 @@ async fn cert_manager_does_refresh(env: TestEnvironment) { } #[plane_test(500)] -#[ignore = "Doesn't work"] async fn cert_manager_does_refresh_eab(env: TestEnvironment) { let certs_dir = env.scratch_dir.join("certs"); diff --git a/plane/plane-tests/tests/proxy_cors.rs b/plane/plane-tests/tests/proxy_cors.rs index 4dcdb0ce..836bc3f2 100644 --- a/plane/plane-tests/tests/proxy_cors.rs +++ b/plane/plane-tests/tests/proxy_cors.rs @@ -126,6 +126,6 @@ async fn proxy_valid_request_has_cors_headers(env: TestEnvironment) { .unwrap() .to_str() .unwrap(), - "*, Authorization" + "authorization, accept, content-type" ); } diff --git a/plane/src/controller/drone.rs b/plane/src/controller/drone.rs index bd8ffc6d..93d1676b 100644 --- a/plane/src/controller/drone.rs +++ b/plane/src/controller/drone.rs @@ -9,7 +9,7 @@ use plane_common::{ ApiErrorKind, BackendAction, BackendActionMessage, Heartbeat, KeyDeadlines, MessageFromDrone, MessageToDrone, RenewKeyResponse, }, - typed_socket::{server::new_server, TypedSocketSender}, + typed_socket::{server::new_server, TypedSocket}, types::{ backend_state::TerminationReason, ClusterName, DronePoolName, NodeId, TerminationKind, }, @@ -22,15 +22,12 @@ use std::{ }; use valuable::Valuable; -use crate::{ - database::{ - backend_key::{ - KEY_LEASE_HARD_TERMINATE_AFTER, KEY_LEASE_RENEW_AFTER, KEY_LEASE_SOFT_TERMINATE_AFTER, - }, - subscribe::Subscription, - PlaneDatabase, +use crate::database::{ + backend_key::{ + KEY_LEASE_HARD_TERMINATE_AFTER, KEY_LEASE_RENEW_AFTER, KEY_LEASE_SOFT_TERMINATE_AFTER, }, - util::GuardHandle, + subscribe::Subscription, + PlaneDatabase, }; use super::{core::Controller, error::IntoApiError}; @@ -44,7 +41,7 @@ pub async fn handle_message_from_drone( msg: MessageFromDrone, drone_id: NodeId, controller: &Controller, - sender: TypedSocketSender, + sender: &mut TypedSocket, ) -> anyhow::Result<()> { match msg { MessageFromDrone::BackendMetrics(metrics_msg) => { @@ -149,7 +146,7 @@ pub async fn sweep_loop(db: PlaneDatabase, drone_id: NodeId) { pub async fn process_pending_actions( db: &PlaneDatabase, - socket: &mut TypedSocketSender, + socket: &mut TypedSocket, drone_id: &NodeId, ) -> Result<(), anyhow::Error> { let mut count = 0; @@ -203,25 +200,20 @@ pub async fn drone_socket_inner( let mut backend_actions: Subscription = controller.db.subscribe_with_key(&drone_id.to_string()); - process_pending_actions(&controller.db, &mut socket.sender(), &drone_id).await?; + process_pending_actions(&controller.db, &mut socket, &drone_id).await?; + + let mut interval = tokio::time::interval(Duration::from_secs(5)); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); let mut log_interval = tokio::time::interval(Duration::from_secs(60)); log_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); let mut message_counts: HashMap<&'static str, u64> = HashMap::new(); - let mut sender = socket.sender(); - let db = controller.db.clone(); - let _pending_actions_handle = GuardHandle::new(async move { - loop { - if let Err(err) = process_pending_actions(&db, &mut sender, &drone_id).await { - tracing::error!(?err, "Error processing pending actions"); - } - tokio::time::sleep(Duration::from_secs(5)).await; - } - }); - loop { tokio::select! { + _ = interval.tick() => { + process_pending_actions(&controller.db, &mut socket, &drone_id).await?; + } _ = log_interval.tick() => { let (outgoing, incoming) = socket.channel_depths(); tracing::info!( @@ -271,14 +263,9 @@ pub async fn drone_socket_inner( *message_counts.entry("backend_metrics").or_insert(0) += 1; } } - - let sender = socket.sender(); - let controller = controller.clone(); - tokio::spawn(async move { - if let Err(err) = handle_message_from_drone(message_from_drone, drone_id, &controller, sender).await { - tracing::error!(?err, "Error handling message from drone"); - } - }); + if let Err(err) = handle_message_from_drone(message_from_drone, drone_id, &controller, &mut socket).await { + tracing::error!(?err, "Error handling message from drone"); + } } None => { tracing::info!("Drone socket closed"); diff --git a/plane/src/controller/proxy.rs b/plane/src/controller/proxy.rs index 485e27ab..bf4e7a3f 100644 --- a/plane/src/controller/proxy.rs +++ b/plane/src/controller/proxy.rs @@ -14,7 +14,7 @@ use plane_common::{ ApiErrorKind, CertManagerRequest, CertManagerResponse, MessageFromProxy, MessageToProxy, RouteInfoRequest, RouteInfoResponse, }, - typed_socket::{server::new_server, TypedSocketSender}, + typed_socket::{server::new_server, TypedSocket}, types::{BackendState, BearerToken, ClusterName, NodeId}, }; use std::{ @@ -28,7 +28,7 @@ use valuable::Valuable; pub async fn handle_route_info_request( token: BearerToken, controller: &Controller, - socket: TypedSocketSender, + socket: &mut TypedSocket, ) -> anyhow::Result<()> { match controller.db.backend().route_info_for_token(&token).await { // When a proxy requests a route, either: @@ -79,58 +79,65 @@ pub async fn handle_route_info_request( } } - let socket = socket.wrap(MessageToProxy::RouteInfoResponse); - - loop { - // Note: this timeout is arbitrary to avoid a memory leak. Under normal system operation, the critical - // timeout will be that of the backend failing to start. We use a large timeout to avoid it becoming - // the critical timeout when the system is functioning. - let result = match tokio::time::timeout( - std::time::Duration::from_secs(30 * 60 /* 30 minutes */), - sub.next(), - ) - .await - { - Ok(Some(result)) => result, - Ok(None) => { - tracing::error!("Event subscription closed!"); - break; - } - Err(_) => { - tracing::error!("Timeout waiting for backend state"); - break; - } - }; + let socket = socket.sender(MessageToProxy::RouteInfoResponse); + tokio::spawn(async move { + loop { + // Note: this timeout is arbitrary to avoid a memory leak. Under normal system operation, the critical + // timeout will be that of the backend failing to start. We use a large timeout to avoid it becoming + // the critical timeout when the system is functioning. + let result = match tokio::time::timeout( + std::time::Duration::from_secs(30 * 60 /* 30 minutes */), + sub.next(), + ) + .await + { + Ok(Some(result)) => result, + Ok(None) => { + tracing::error!("Event subscription closed!"); + break; + } + Err(_) => { + tracing::error!("Timeout waiting for backend state"); + break; + } + }; - let Notification { payload, .. } = result; + let Notification { payload, .. } = result; - match payload { - BackendState::Ready { address } => { - let route_info = partial_route_info.set_address(address); - let response = RouteInfoResponse { - token, - route_info: Some(route_info), - }; - if let Err(err) = socket.send(response) { - tracing::error!(?err, "Error sending route info response to proxy."); + match payload { + BackendState::Ready { address } => { + let route_info = partial_route_info.set_address(address); + let response = RouteInfoResponse { + token, + route_info: Some(route_info), + }; + if let Err(err) = socket.send(response) { + tracing::error!( + ?err, + "Error sending route info response to proxy." + ); + } + break; } - break; - } - BackendState::Terminated { .. } - | BackendState::Terminating { .. } - | BackendState::HardTerminating { .. } => { - let response = RouteInfoResponse { - token, - route_info: None, - }; - if let Err(err) = socket.send(response) { - tracing::error!(?err, "Error sending route info response to proxy."); + BackendState::Terminated { .. } + | BackendState::Terminating { .. } + | BackendState::HardTerminating { .. } => { + let response = RouteInfoResponse { + token, + route_info: None, + }; + if let Err(err) = socket.send(response) { + tracing::error!( + ?err, + "Error sending route info response to proxy." + ); + } + break; } - break; + _ => {} } - _ => {} } - } + }); } Ok(RouteInfoResult::NotFound) => { let response = RouteInfoResponse { @@ -152,7 +159,7 @@ pub async fn handle_route_info_request( pub async fn handle_message_from_proxy( message: MessageFromProxy, controller: &Controller, - socket: TypedSocketSender, + socket: &mut TypedSocket, cluster: &ClusterName, node_id: NodeId, ) -> anyhow::Result<()> { @@ -294,15 +301,7 @@ pub async fn proxy_socket_inner( *message_counts.entry("cert_manager_request").or_insert(0) += 1; } } - - let sender = socket.sender(); - let controller = controller.clone(); - let cluster = cluster.clone(); - tokio::spawn(async move { - if let Err(err) = handle_message_from_proxy(message, &controller, sender, &cluster, node_guard.id).await { - tracing::error!(?err, "Error handling message from proxy"); - } - }); + handle_message_from_proxy(message, &controller, &mut socket, &cluster, node_guard.id).await? } None => { tracing::info!("Proxy socket closed"); diff --git a/plane/src/drone/heartbeat.rs b/plane/src/drone/heartbeat.rs index 4bae39a8..54d52c24 100644 --- a/plane/src/drone/heartbeat.rs +++ b/plane/src/drone/heartbeat.rs @@ -1,8 +1,6 @@ use crate::heartbeat_consts::HEARTBEAT_INTERVAL; use chrono::Utc; -use plane_common::{ - log_types::LoggableTime, protocol::Heartbeat, typed_socket::WrappedTypedSocketSender, -}; +use plane_common::{log_types::LoggableTime, protocol::Heartbeat, typed_socket::TypedSocketSender}; use tokio::task::JoinHandle; /// A background task that sends heartbeats to the server. @@ -11,7 +9,7 @@ pub struct HeartbeatLoop { } impl HeartbeatLoop { - pub fn start(sender: WrappedTypedSocketSender) -> Self { + pub fn start(sender: TypedSocketSender) -> Self { let handle = tokio::spawn(async move { loop { let local_time = LoggableTime(Utc::now()); diff --git a/plane/src/drone/key_manager.rs b/plane/src/drone/key_manager.rs index 1c9995b6..82344cd8 100644 --- a/plane/src/drone/key_manager.rs +++ b/plane/src/drone/key_manager.rs @@ -5,7 +5,7 @@ use plane_common::{ log_types::LoggableTime, names::BackendName, protocol::{AcquiredKey, BackendAction, KeyDeadlines, RenewKeyRequest}, - typed_socket::WrappedTypedSocketSender, + typed_socket::TypedSocketSender, types::{backend_state::TerminationReason, TerminationKind}, }; use std::{collections::HashMap, sync::Arc, time::Duration}; @@ -20,13 +20,13 @@ pub struct KeyManager { /// and terminating the backend if the key cannot be renewed. handles: HashMap, - sender: Option>, + sender: Option>, } async fn renew_key_loop( key: AcquiredKey, backend: BackendName, - sender: Option>, + sender: Option>, executor: Arc, ) { loop { @@ -120,7 +120,7 @@ impl KeyManager { } } - pub fn set_sender(&mut self, sender: WrappedTypedSocketSender) { + pub fn set_sender(&mut self, sender: TypedSocketSender) { self.sender.replace(sender); for (backend, (acquired_key, handle)) in self.handles.iter_mut() { diff --git a/plane/src/drone/mod.rs b/plane/src/drone/mod.rs index 22eb55c0..f768d111 100644 --- a/plane/src/drone/mod.rs +++ b/plane/src/drone/mod.rs @@ -55,11 +55,10 @@ pub async fn drone_loop( loop { let mut socket = connection.connect_with_retry(&name).await; - let _heartbeat_guard = - HeartbeatLoop::start(socket.sender().wrap(MessageFromDrone::Heartbeat)); + let _heartbeat_guard = HeartbeatLoop::start(socket.sender(MessageFromDrone::Heartbeat)); { - let socket = socket.sender().wrap(MessageFromDrone::BackendMetrics); + let socket = socket.sender(MessageFromDrone::BackendMetrics); executor .runtime .metrics_callback(Box::new(move |metrics_message| { @@ -72,12 +71,12 @@ pub async fn drone_loop( key_manager .lock() .expect("Key manager lock poisoned") - .set_sender(socket.sender().wrap(MessageFromDrone::RenewKey)); + .set_sender(socket.sender(MessageFromDrone::RenewKey)); { // Forward state changes to the socket. // This will start by sending any existing unacked events. - let sender = socket.sender().wrap(MessageFromDrone::BackendEvent); + let sender = socket.sender(MessageFromDrone::BackendEvent); let key_manager = key_manager.clone(); if let Err(err) = executor.register_listener(move |message| { if matches!(message.state, BackendState::Terminated { .. }) { @@ -142,7 +141,7 @@ pub async fn drone_loop( tokio::spawn(handle_message( message, key_manager, - socket.sender(), + socket.sender(|x| x), executor.clone(), )); } diff --git a/plane/src/proxy/proxy_connection.rs b/plane/src/proxy/proxy_connection.rs index a84e1fba..cdda9cc7 100644 --- a/plane/src/proxy/proxy_connection.rs +++ b/plane/src/proxy/proxy_connection.rs @@ -36,14 +36,14 @@ impl ProxyConnection { let mut conn = proxy_connection.connect_with_retry(&name).await; state.set_ready(true); - let sender = conn.sender().wrap(MessageFromProxy::CertManagerRequest); + let sender = conn.sender(MessageFromProxy::CertManagerRequest); cert_manager.set_request_sender(move |m| { if let Err(e) = sender.send(m) { tracing::error!(?e, "Error sending cert manager request."); } }); - let sender = conn.sender().wrap(MessageFromProxy::RouteInfoRequest); + let sender = conn.sender(MessageFromProxy::RouteInfoRequest); state .inner .route_map @@ -52,7 +52,7 @@ impl ProxyConnection { tracing::error!(?e, "Error sending route info request."); } }); - let sender = conn.sender().wrap(MessageFromProxy::KeepAlive); + let sender = conn.sender(MessageFromProxy::KeepAlive); state.inner.monitor.set_listener(move |backend| { if let Err(err) = sender.send(backend.clone()) { tracing::error!(?err, "Error sending keepalive.");