diff --git a/Cargo.lock b/Cargo.lock index 395b844..c333523 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -123,20 +123,6 @@ version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" -[[package]] -name = "dashmap" -version = "6.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" -dependencies = [ - "cfg-if", - "crossbeam-utils", - "hashbrown 0.14.5", - "lock_api", - "once_cell", - "parking_lot_core", -] - [[package]] name = "dunce" version = "1.0.5" @@ -273,12 +259,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "hashbrown" -version = "0.14.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" - [[package]] name = "hashbrown" version = "0.16.1" @@ -401,7 +381,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ad4bb2b565bca0645f4d68c5c9af97fba094e9791da685bf83cb5f3ce74acf2" dependencies = [ "equivalent", - "hashbrown 0.16.1", + "hashbrown", ] [[package]] @@ -442,15 +422,6 @@ version = "0.2.178" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091" -[[package]] -name = "lock_api" -version = "0.4.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" -dependencies = [ - "scopeguard", -] - [[package]] name = "log" version = "0.4.29" @@ -496,7 +467,7 @@ dependencies = [ "aho-corasick", "crossbeam-epoch", "crossbeam-utils", - "hashbrown 0.16.1", + "hashbrown", "indexmap", "metrics", "ordered-float", @@ -560,19 +531,6 @@ dependencies = [ "num-traits", ] -[[package]] -name = "parking_lot_core" -version = "0.9.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" -dependencies = [ - "cfg-if", - "libc", - "redox_syscall", - "smallvec", - "windows-link", -] - [[package]] name = "pin-project-lite" version = "0.2.16" @@ -696,15 +654,6 @@ dependencies = [ "bitflags", ] -[[package]] -name = "redox_syscall" -version = "0.5.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" -dependencies = [ - "bitflags", -] - [[package]] name = "ring" version = "0.17.14" @@ -760,12 +709,6 @@ version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" -[[package]] -name = "scopeguard" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" - [[package]] name = "sharded-slab" version = "0.1.7" @@ -989,7 +932,6 @@ checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" name = "twilight-http-proxy" version = "0.1.0" dependencies = [ - "dashmap", "http", "http-body-util", "hyper", @@ -1011,7 +953,7 @@ version = "0.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0515b0c30814068a7540fcb5f58b634259ca453fa335d42c3b2c8f2b06ac6a59" dependencies = [ - "hashbrown 0.16.1", + "hashbrown", "tokio", "tokio-util", "tracing", diff --git a/Cargo.toml b/Cargo.toml index 2117e7c..53dc748 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,6 @@ name = "twilight-http-proxy" version = "0.1.0" [dependencies] -dashmap = "6" http = "1" http-body-util = "0.1" hyper = { version = "1", default-features = false } diff --git a/src/expiring_lru.rs b/src/expiring_lru.rs deleted file mode 100644 index c97229d..0000000 --- a/src/expiring_lru.rs +++ /dev/null @@ -1,248 +0,0 @@ -use dashmap::{DashMap, mapref::one::Ref}; -use std::{ - borrow::Borrow, future::poll_fn, hash::Hash, marker::PhantomData, ops::Deref, sync::Arc, - time::Duration, -}; -use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel}; -use tokio_util::time::{DelayQueue, delay_queue::Key}; -use tracing::debug; - -pub struct Entry { - inner: V, - decay_key: Key, -} - -pub struct EntryRef<'a, K, V>(Ref<'a, K, Entry>); - -impl EntryRef<'_, K, V> -where - K: Eq + Hash, -{ - pub fn value(&self) -> &V { - &self.0.value().inner - } -} - -impl AsRef for EntryRef<'_, K, V> -where - K: Eq + Hash, -{ - fn as_ref(&self) -> &V { - &self.0.value().inner - } -} - -impl Deref for EntryRef<'_, K, V> -where - K: Eq + Hash, -{ - type Target = V; - - fn deref(&self) -> &Self::Target { - &self.0.value().inner - } -} - -async fn decay_task( - map: Arc>>, - expiration: Duration, - mut rx: UnboundedReceiver>, -) where - K: Eq + Hash + Clone + Send + Sync + 'static, - V: Send + Sync + 'static, -{ - let mut queue = DelayQueue::new(); - - loop { - tokio::select! { - Some(key) = poll_fn(|cx| queue.poll_expired(cx)), if !queue.is_empty() => { - // An item expired in the queue, remove it from the map - debug!("Removing expired entry from ratelimiter decay queue"); - map.remove(key.get_ref()); - } - Some(msg) = rx.recv() => { - match msg { - TimerUpdate::Add { key, value } => { - debug!("Adding entry to ratelimiter decay queue"); - let decay_key = queue.insert(key.clone(), expiration); - let entry = Entry { - inner: value, - decay_key, - }; - map.insert(key, entry); - }, - TimerUpdate::Refresh { key } => { - debug!("Refreshing entry in ratelimiter decay queue"); - // This will panic if the key is not present, therefore - // we check that in the calling end - queue.reset(&key, expiration); - }, - TimerUpdate::RemoveLru => { - debug!("Removing least recently used item from ratelimiter decay queue"); - if let Some(expired) = queue.peek().and_then(|key| queue.try_remove(&key)) { - map.remove(expired.get_ref()); - } - } - } - }, - else => { - // Channel has been closed by the other end, i.e. the ExpiringLru has - // been dropped. - break; - } - }; - } -} - -enum TimerUpdate { - Add { key: K, value: V }, - Refresh { key: Key }, - RemoveLru, -} - -pub struct ExpiringLru { - inner: Arc>>, - decay_tx: UnboundedSender>, - max_size: Option, -} - -impl ExpiringLru -where - K: Eq + Hash + Clone + Send + Sync + 'static, - V: Send + Sync + 'static, -{ - fn new(expiration: Duration, max_size: Option) -> Self { - let inner = Arc::new(DashMap::new()); - let (decay_tx, decay_rx) = unbounded_channel(); - - let this = Self { - inner: inner.clone(), - decay_tx, - max_size, - }; - - tokio::spawn(decay_task(inner, expiration, decay_rx)); - - this - } - - pub fn insert(&self, key: K, value: V) { - match self.max_size { - Some(0) => return, - Some(max_size) if self.len() >= max_size => { - self.remove_lru(); - } - _ => {} - } - - _ = self.decay_tx.send(TimerUpdate::Add { key, value }); - } - - pub fn get(&self, key: &Q) -> Option> - where - K: Borrow, - Q: Hash + Eq + ?Sized, - { - let entry = self.inner.get(key)?; - _ = self.decay_tx.send(TimerUpdate::Refresh { - key: entry.decay_key, - }); - - Some(EntryRef(entry)) - } - - fn remove_lru(&self) { - _ = self.decay_tx.send(TimerUpdate::RemoveLru); - } - - pub fn len(&self) -> usize { - self.inner.len() - } -} - -pub struct Builder { - expiration: Duration, - max_size: Option, - - _marker: PhantomData<(K, V)>, -} - -const DEFAULT_EXPIRATION: Duration = Duration::from_secs(3600); - -impl Builder -where - K: Eq + Hash + Clone + Send + Sync + 'static, - V: Send + Sync + 'static, -{ - pub const fn new() -> Self { - Self { - expiration: DEFAULT_EXPIRATION, - max_size: None, - _marker: PhantomData, - } - } - - pub const fn expiration(mut self, expiration: Duration) -> Self { - self.expiration = expiration; - - self - } - - pub const fn max_size(mut self, size: usize) -> Self { - self.max_size = Some(size); - - self - } - - pub fn build(self) -> ExpiringLru { - ExpiringLru::new(self.expiration, self.max_size) - } -} - -#[cfg(test)] -mod tests { - use super::Builder; - use tokio::time::{Duration, sleep}; - - #[tokio::test(start_paused = true)] - async fn test_lru() { - let lru = Builder::new() - .expiration(Duration::from_secs(1)) - .max_size(2) - .build(); - - lru.insert(1, 2); - - // The actual LRU cache insert is performed in a different - // task and insert will return pre-emptively after notifying - // the task of the insertion. In order to allow the task to run - // and receive the insertion message, we have to yield back to the - // runtime. The alternative would be making insert asynchronous and - // wait on a oneshot channel, but there is no benefit to that - // for our usecase. - tokio::task::yield_now().await; - - { - let entry = lru.get(&1).unwrap(); - assert_eq!(entry.value(), &2); - } - - sleep(Duration::from_secs(2)).await; - assert!(lru.get(&1).is_none()); - - for i in 2..5 { - lru.insert(i, 0); - - // If we insert instantly after another, - // upon inserting 4 it will remove either 2 or 3, - // because they were inserted at the same time. - // - // For reproducibility, add a delay. - sleep(Duration::from_millis(50)).await; - } - - assert_eq!(lru.len(), 2); - assert!(lru.get(&2).is_none()); - assert!(lru.get(&4).is_some()); - } -} diff --git a/src/main.rs b/src/main.rs index d32af9b..1006373 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,6 @@ mod error; -mod expiring_lru; mod ratelimiter_map; +mod tlru; use error::RequestError; use http::{HeaderMap, HeaderValue, Method as HttpMethod, Uri, header}; @@ -22,6 +22,7 @@ use std::{ env, error::Error, net::{Ipv4Addr, SocketAddrV4}, + num::NonZero, pin::pin, str::FromStr, sync::Arc, @@ -87,7 +88,7 @@ async fn main() -> Result<(), Box> { let ratelimiter_map = Arc::new(RatelimiterMap::new( env::var("DISCORD_TOKEN")?, Duration::from_secs(parse_env("CLIENT_DECAY_TIMEOUT")?.unwrap_or(3600)), - parse_env("CLIENT_CACHE_MAX_SIZE")?, + parse_env("CLIENT_CACHE_MAX_SIZE")?.unwrap_or(NonZero::::MAX), )); let host = parse_env("HOST")?.unwrap_or(Ipv4Addr::UNSPECIFIED); diff --git a/src/ratelimiter_map.rs b/src/ratelimiter_map.rs index 7c82f21..a2f7fa1 100644 --- a/src/ratelimiter_map.rs +++ b/src/ratelimiter_map.rs @@ -1,15 +1,16 @@ -use crate::expiring_lru::{Builder, ExpiringLru}; +use crate::tlru::Tlru; +use std::num::NonZero; use tokio::time::Duration; use twilight_http_ratelimiting::RateLimiter; pub struct RatelimiterMap { default: RateLimiter, default_token: String, - inner: ExpiringLru, + inner: Tlru, } impl RatelimiterMap { - pub fn new(mut default_token: String, timeout: Duration, max_size: Option) -> Self { + pub fn new(mut default_token: String, timeout: Duration, cap: NonZero) -> Self { let is_bot = default_token.starts_with("Bot "); let is_bearer = default_token.starts_with("Bearer "); @@ -19,13 +20,7 @@ impl RatelimiterMap { default_token.insert_str(0, "Bot "); } - let mut builder = Builder::new().expiration(timeout); - - if let Some(max_size) = max_size { - builder = builder.max_size(max_size); - } - - let inner = builder.build(); + let inner = Tlru::new(cap, timeout); let default = RateLimiter::default(); @@ -41,13 +36,13 @@ impl RatelimiterMap { if token == self.default_token { (self.default.clone(), self.default_token.clone()) } else if let Some(entry) = self.inner.get(token) { - (entry.value().clone(), token.to_string()) + (entry, token.to_owned()) } else { let ratelimiter = RateLimiter::default(); - self.inner.insert(token.to_string(), ratelimiter.clone()); + self.inner.insert(token.to_owned(), ratelimiter.clone()); - (ratelimiter, token.to_string()) + (ratelimiter, token.to_owned()) } } else { (self.default.clone(), self.default_token.clone()) diff --git a/src/tlru.rs b/src/tlru.rs new file mode 100644 index 0000000..b4763a0 --- /dev/null +++ b/src/tlru.rs @@ -0,0 +1,163 @@ +use std::{ + borrow::Borrow, + collections::HashMap, + future::poll_fn, + hash::Hash, + num::NonZero, + sync::{Arc, Mutex}, + task::{Context, Poll, ready}, + time::Duration, +}; +use tokio::sync::Notify; +use tokio_util::time::{DelayQueue, delay_queue}; + +/// A time-aware least recently used cache. +/// +/// Entries are removed once their time-to-use (TTU) expires. +pub struct Tlru { + cap: NonZero, + inner: Arc>>, + notify: Arc, + ttu: Duration, +} + +// INVARIANT: both collections contain the same keys. +struct TlruInner { + entries: HashMap>, + expirations: DelayQueue, +} + +#[derive(Clone)] +struct Entry { + expiration: delay_queue::Key, + value: V, +} + +impl Tlru +where + K: Clone + Eq + Hash + Send + Sync + 'static, + V: Clone + Send + Sync + 'static, +{ + /// Creates an empty `Tlru`. + pub fn new(cap: NonZero, ttu: Duration) -> Self { + let inner = TlruInner { + entries: HashMap::new(), + expirations: DelayQueue::new(), + }; + let inner = Arc::new(Mutex::new(inner)); + let notify = Arc::new(Notify::new()); + + tokio::spawn(reaper(Arc::clone(¬ify), Arc::clone(&inner))); + + Self { + cap, + inner, + notify, + ttu, + } + } + + /// Inserts a key-value pair into the cache. + /// + /// If the cache is full, the least recently used entry is replaced. + pub fn insert(&self, key: K, value: V) { + let mut guard = self.inner.lock().unwrap(); + // Notify `reaper` that the cache is non-empty. + self.notify.notify_waiters(); + + if self.cap.get() == guard.len() { + guard.remove_lru(); + } + + let expiration = guard.expirations.insert(key.clone(), self.ttu); + guard.entries.insert(key, Entry { expiration, value }); + } + + /// Returns the value corresponding to the key. + pub fn get(&self, key: &Q) -> Option + where + K: Borrow, + Q: Eq + Hash, + { + let mut guard = self.inner.lock().unwrap(); + + let Entry { expiration, value } = guard.entries.get(key)?.clone(); + guard.expirations.reset(&expiration, self.ttu); + + Some(value) + } +} + +impl TlruInner +where + K: Eq + Hash, +{ + /// Returns the number of entries in the cache. + fn len(&self) -> usize { + debug_assert_eq!(self.entries.len(), self.expirations.len()); + self.entries.len() + } + + /// Attempts to remove an expired entry. + fn poll_expired(&mut self, cx: &mut Context<'_>) -> Poll> { + let expired = ready!(self.expirations.poll_expired(cx)); + let entry = expired.map(|expired| self.entries.remove(expired.get_ref()).unwrap().value); + + Poll::Ready(entry) + } + + /// Removes the least recently used entry. + fn remove_lru(&mut self) -> Option { + let lru = self.expirations.remove(&self.expirations.peek()?); + let entry = self.entries.remove(lru.get_ref()).unwrap().value; + + Some(entry) + } +} + +async fn reaper(notify: Arc, tlru: Arc>>) { + loop { + let expired = poll_fn(|cx| tlru.lock().unwrap().poll_expired(cx)).await; + + if expired.is_none() { + // Wait until the cache is non-empty. + notify.notified().await; + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::time::sleep; + + #[tokio::test(start_paused = true)] + async fn tlru() { + let tlru = Tlru::new(NonZero::new(2).unwrap(), Duration::from_secs(1)); + + tlru.insert(1, 2); + + { + let entry = tlru.get(&1).unwrap(); + assert_eq!(entry, 2); + } + + sleep(Duration::from_secs(2)).await; + assert!(tlru.get(&1).is_none()); + + for i in 2..5 { + tlru.insert(i, 0); + + // If we insert instantly after another, + // upon inserting 4 it will remove either 2 or 3, + // because they were inserted at the same time. + // + // For reproducibility, add a delay. + sleep(Duration::from_millis(50)).await; + } + + assert_eq!(tlru.inner.lock().unwrap().len(), 2); + assert!(tlru.get(&2).is_none()); + assert!(tlru.get(&4).is_some()); + } +}