Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 23 additions & 42 deletions common/src/typed_socket/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,35 +26,12 @@
}

#[derive(Clone)]
pub struct TypedSocketSender<T: ChannelMessage> {
sender: Sender<SocketAction<T>>,
pub struct TypedSocketSender<A> {
inner_send:
Arc<dyn Fn(SocketAction<A>) -> Result<(), TypedSocketError> + 'static + Send + Sync>,
}

#[derive(Clone)]
pub struct WrappedTypedSocketSender<K> {
send: Arc<dyn Fn(K) -> Result<(), TypedSocketError> + 'static + Send + Sync>,
}

impl<K> WrappedTypedSocketSender<K> {
pub fn new<T: ChannelMessage, F>(sender: Sender<SocketAction<T>>, transform: F) -> Self
where
F: (Fn(K) -> T) + 'static + Send + Sync,
{
Self {
send: Arc::new(move |message| {
sender
.try_send(SocketAction::Send(transform(message)))
.map_err(TypedSocketError::from)
}),
}
}

pub fn send(&self, message: K) -> Result<(), TypedSocketError> {
(self.send)(message)
}
}

impl<T: ChannelMessage> Debug for TypedSocketSender<T> {
impl<T> Debug for TypedSocketSender<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("typed socket sender")
}
Expand All @@ -77,28 +54,20 @@
}
}

impl<T: ChannelMessage> TypedSocketSender<T> {
pub fn send(&self, message: T) -> Result<(), TypedSocketError> {
self.sender.try_send(SocketAction::Send(message))?;
impl<A: Debug> TypedSocketSender<A> {
pub fn send(&self, message: A) -> Result<(), TypedSocketError> {
(self.inner_send)(SocketAction::Send(message))?;
Ok(())
}

pub fn close(&mut self) -> Result<(), TypedSocketError> {
self.sender.try_send(SocketAction::Close)?;
(self.inner_send)(SocketAction::Close)?;
Ok(())
}

/// Wrap the sender with a transform function.
pub fn wrap<K, F>(&self, transform: F) -> WrappedTypedSocketSender<K>
where
F: (Fn(K) -> T) + 'static + Send + Sync,
{
WrappedTypedSocketSender::new(self.sender.clone(), transform)
}
}

impl<T: ChannelMessage> TypedSocket<T> {
pub fn send(&mut self, message: T) -> Result<(), PlaneClientError> {

Check warning on line 70 in common/src/typed_socket/mod.rs

View workflow job for this annotation

GitHub Actions / clippy

the `Err`-variant returned from this function is very large

warning: the `Err`-variant returned from this function is very large --> common/src/typed_socket/mod.rs:70:43 | 70 | pub fn send(&mut self, message: T) -> Result<(), PlaneClientError> { | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ::: common/src/lib.rs:52:5 | 52 | Tungstenite(#[from] tokio_tungstenite::tungstenite::Error), | ---------------------------------------------------------- the largest variant contains at least 136 bytes | = help: try reducing the size of `PlaneClientError`, for example by boxing large elements or replacing it with `Box<PlaneClientError>` = help: for further information visit https://rust-lang.github.io/rust-clippy/rust-1.92.0/index.html#result_large_err
self.send
.try_send(SocketAction::Send(message))
.map_err(|_| PlaneClientError::SendFailed)?;
Expand All @@ -109,10 +78,22 @@
self.recv.recv().await
}

pub fn sender(&self) -> TypedSocketSender<T> {
pub fn sender<A, F>(&self, transform: F) -> TypedSocketSender<A>
where
F: (Fn(A) -> T) + 'static + Send + Sync,
{
let sender = self.send.clone();

TypedSocketSender { sender }
let inner_send = move |message: SocketAction<A>| {
let message = match message {
SocketAction::Close => SocketAction::Close,
SocketAction::Send(message) => SocketAction::Send(transform(message)),
};
sender.try_send(message).map_err(|e| e.into())
};

TypedSocketSender {
inner_send: Arc::new(inner_send),
}
}

pub async fn close(&mut self) {
Expand Down
2 changes: 0 additions & 2 deletions plane/plane-tests/tests/cert_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ use std::sync::Arc;
mod common;

#[plane_test]
#[ignore = "Doesn't work"]
async fn cert_manager_does_refresh(env: TestEnvironment) {
let controller = env.controller().await;

Expand Down Expand Up @@ -57,7 +56,6 @@ async fn cert_manager_does_refresh(env: TestEnvironment) {
}

#[plane_test(500)]
#[ignore = "Doesn't work"]
async fn cert_manager_does_refresh_eab(env: TestEnvironment) {
let certs_dir = env.scratch_dir.join("certs");

Expand Down
2 changes: 1 addition & 1 deletion plane/plane-tests/tests/proxy_cors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,6 @@ async fn proxy_valid_request_has_cors_headers(env: TestEnvironment) {
.unwrap()
.to_str()
.unwrap(),
"*, Authorization"
"authorization, accept, content-type"
);
}
49 changes: 18 additions & 31 deletions plane/src/controller/drone.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use plane_common::{
ApiErrorKind, BackendAction, BackendActionMessage, Heartbeat, KeyDeadlines,
MessageFromDrone, MessageToDrone, RenewKeyResponse,
},
typed_socket::{server::new_server, TypedSocketSender},
typed_socket::{server::new_server, TypedSocket},
types::{
backend_state::TerminationReason, ClusterName, DronePoolName, NodeId, TerminationKind,
},
Expand All @@ -22,15 +22,12 @@ use std::{
};
use valuable::Valuable;

use crate::{
database::{
backend_key::{
KEY_LEASE_HARD_TERMINATE_AFTER, KEY_LEASE_RENEW_AFTER, KEY_LEASE_SOFT_TERMINATE_AFTER,
},
subscribe::Subscription,
PlaneDatabase,
use crate::database::{
backend_key::{
KEY_LEASE_HARD_TERMINATE_AFTER, KEY_LEASE_RENEW_AFTER, KEY_LEASE_SOFT_TERMINATE_AFTER,
},
util::GuardHandle,
subscribe::Subscription,
PlaneDatabase,
};

use super::{core::Controller, error::IntoApiError};
Expand All @@ -44,7 +41,7 @@ pub async fn handle_message_from_drone(
msg: MessageFromDrone,
drone_id: NodeId,
controller: &Controller,
sender: TypedSocketSender<MessageToDrone>,
sender: &mut TypedSocket<MessageToDrone>,
) -> anyhow::Result<()> {
match msg {
MessageFromDrone::BackendMetrics(metrics_msg) => {
Expand Down Expand Up @@ -149,7 +146,7 @@ pub async fn sweep_loop(db: PlaneDatabase, drone_id: NodeId) {

pub async fn process_pending_actions(
db: &PlaneDatabase,
socket: &mut TypedSocketSender<MessageToDrone>,
socket: &mut TypedSocket<MessageToDrone>,
drone_id: &NodeId,
) -> Result<(), anyhow::Error> {
let mut count = 0;
Expand Down Expand Up @@ -203,25 +200,20 @@ pub async fn drone_socket_inner(
let mut backend_actions: Subscription<BackendActionMessage> =
controller.db.subscribe_with_key(&drone_id.to_string());

process_pending_actions(&controller.db, &mut socket.sender(), &drone_id).await?;
process_pending_actions(&controller.db, &mut socket, &drone_id).await?;

let mut interval = tokio::time::interval(Duration::from_secs(5));
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

let mut log_interval = tokio::time::interval(Duration::from_secs(60));
log_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
let mut message_counts: HashMap<&'static str, u64> = HashMap::new();

let mut sender = socket.sender();
let db = controller.db.clone();
let _pending_actions_handle = GuardHandle::new(async move {
loop {
if let Err(err) = process_pending_actions(&db, &mut sender, &drone_id).await {
tracing::error!(?err, "Error processing pending actions");
}
tokio::time::sleep(Duration::from_secs(5)).await;
}
});

loop {
tokio::select! {
_ = interval.tick() => {
process_pending_actions(&controller.db, &mut socket, &drone_id).await?;
}
_ = log_interval.tick() => {
let (outgoing, incoming) = socket.channel_depths();
tracing::info!(
Expand Down Expand Up @@ -271,14 +263,9 @@ pub async fn drone_socket_inner(
*message_counts.entry("backend_metrics").or_insert(0) += 1;
}
}

let sender = socket.sender();
let controller = controller.clone();
tokio::spawn(async move {
if let Err(err) = handle_message_from_drone(message_from_drone, drone_id, &controller, sender).await {
tracing::error!(?err, "Error handling message from drone");
}
});
if let Err(err) = handle_message_from_drone(message_from_drone, drone_id, &controller, &mut socket).await {
tracing::error!(?err, "Error handling message from drone");
}
}
None => {
tracing::info!("Drone socket closed");
Expand Down
115 changes: 57 additions & 58 deletions plane/src/controller/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use plane_common::{
ApiErrorKind, CertManagerRequest, CertManagerResponse, MessageFromProxy, MessageToProxy,
RouteInfoRequest, RouteInfoResponse,
},
typed_socket::{server::new_server, TypedSocketSender},
typed_socket::{server::new_server, TypedSocket},
types::{BackendState, BearerToken, ClusterName, NodeId},
};
use std::{
Expand All @@ -28,7 +28,7 @@ use valuable::Valuable;
pub async fn handle_route_info_request(
token: BearerToken,
controller: &Controller,
socket: TypedSocketSender<MessageToProxy>,
socket: &mut TypedSocket<MessageToProxy>,
) -> anyhow::Result<()> {
match controller.db.backend().route_info_for_token(&token).await {
// When a proxy requests a route, either:
Expand Down Expand Up @@ -79,58 +79,65 @@ pub async fn handle_route_info_request(
}
}

let socket = socket.wrap(MessageToProxy::RouteInfoResponse);

loop {
// Note: this timeout is arbitrary to avoid a memory leak. Under normal system operation, the critical
// timeout will be that of the backend failing to start. We use a large timeout to avoid it becoming
// the critical timeout when the system is functioning.
let result = match tokio::time::timeout(
std::time::Duration::from_secs(30 * 60 /* 30 minutes */),
sub.next(),
)
.await
{
Ok(Some(result)) => result,
Ok(None) => {
tracing::error!("Event subscription closed!");
break;
}
Err(_) => {
tracing::error!("Timeout waiting for backend state");
break;
}
};
let socket = socket.sender(MessageToProxy::RouteInfoResponse);
tokio::spawn(async move {
loop {
// Note: this timeout is arbitrary to avoid a memory leak. Under normal system operation, the critical
// timeout will be that of the backend failing to start. We use a large timeout to avoid it becoming
// the critical timeout when the system is functioning.
let result = match tokio::time::timeout(
std::time::Duration::from_secs(30 * 60 /* 30 minutes */),
sub.next(),
)
.await
{
Ok(Some(result)) => result,
Ok(None) => {
tracing::error!("Event subscription closed!");
break;
}
Err(_) => {
tracing::error!("Timeout waiting for backend state");
break;
}
};

let Notification { payload, .. } = result;
let Notification { payload, .. } = result;

match payload {
BackendState::Ready { address } => {
let route_info = partial_route_info.set_address(address);
let response = RouteInfoResponse {
token,
route_info: Some(route_info),
};
if let Err(err) = socket.send(response) {
tracing::error!(?err, "Error sending route info response to proxy.");
match payload {
BackendState::Ready { address } => {
let route_info = partial_route_info.set_address(address);
let response = RouteInfoResponse {
token,
route_info: Some(route_info),
};
if let Err(err) = socket.send(response) {
tracing::error!(
?err,
"Error sending route info response to proxy."
);
}
break;
}
break;
}
BackendState::Terminated { .. }
| BackendState::Terminating { .. }
| BackendState::HardTerminating { .. } => {
let response = RouteInfoResponse {
token,
route_info: None,
};
if let Err(err) = socket.send(response) {
tracing::error!(?err, "Error sending route info response to proxy.");
BackendState::Terminated { .. }
| BackendState::Terminating { .. }
| BackendState::HardTerminating { .. } => {
let response = RouteInfoResponse {
token,
route_info: None,
};
if let Err(err) = socket.send(response) {
tracing::error!(
?err,
"Error sending route info response to proxy."
);
}
break;
}
break;
_ => {}
}
_ => {}
}
}
});
}
Ok(RouteInfoResult::NotFound) => {
let response = RouteInfoResponse {
Expand All @@ -152,7 +159,7 @@ pub async fn handle_route_info_request(
pub async fn handle_message_from_proxy(
message: MessageFromProxy,
controller: &Controller,
socket: TypedSocketSender<MessageToProxy>,
socket: &mut TypedSocket<MessageToProxy>,
cluster: &ClusterName,
node_id: NodeId,
) -> anyhow::Result<()> {
Expand Down Expand Up @@ -294,15 +301,7 @@ pub async fn proxy_socket_inner(
*message_counts.entry("cert_manager_request").or_insert(0) += 1;
}
}

let sender = socket.sender();
let controller = controller.clone();
let cluster = cluster.clone();
tokio::spawn(async move {
if let Err(err) = handle_message_from_proxy(message, &controller, sender, &cluster, node_guard.id).await {
tracing::error!(?err, "Error handling message from proxy");
}
});
handle_message_from_proxy(message, &controller, &mut socket, &cluster, node_guard.id).await?
}
None => {
tracing::info!("Proxy socket closed");
Expand Down
6 changes: 2 additions & 4 deletions plane/src/drone/heartbeat.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
use crate::heartbeat_consts::HEARTBEAT_INTERVAL;
use chrono::Utc;
use plane_common::{
log_types::LoggableTime, protocol::Heartbeat, typed_socket::WrappedTypedSocketSender,
};
use plane_common::{log_types::LoggableTime, protocol::Heartbeat, typed_socket::TypedSocketSender};
use tokio::task::JoinHandle;

/// A background task that sends heartbeats to the server.
Expand All @@ -11,7 +9,7 @@ pub struct HeartbeatLoop {
}

impl HeartbeatLoop {
pub fn start(sender: WrappedTypedSocketSender<Heartbeat>) -> Self {
pub fn start(sender: TypedSocketSender<Heartbeat>) -> Self {
let handle = tokio::spawn(async move {
loop {
let local_time = LoggableTime(Utc::now());
Expand Down
Loading
Loading