From 14b4ea1b80bbfd2bebcb2a5bb77c0c97d38c200c Mon Sep 17 00:00:00 2001 From: Tim Vilgot Mikael Fredenberg Date: Sun, 8 Mar 2026 12:14:59 +0100 Subject: [PATCH 1/2] rename tlru module --- src/main.rs | 2 +- src/ratelimiter_map.rs | 4 ++-- src/{expiring_lru.rs => tlru.rs} | 28 ++++++++++++++-------------- 3 files changed, 17 insertions(+), 17 deletions(-) rename src/{expiring_lru.rs => tlru.rs} (92%) diff --git a/src/main.rs b/src/main.rs index d32af9b..de6fa74 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,6 @@ mod error; -mod expiring_lru; mod ratelimiter_map; +mod tlru; use error::RequestError; use http::{HeaderMap, HeaderValue, Method as HttpMethod, Uri, header}; diff --git a/src/ratelimiter_map.rs b/src/ratelimiter_map.rs index 7c82f21..6fe4510 100644 --- a/src/ratelimiter_map.rs +++ b/src/ratelimiter_map.rs @@ -1,11 +1,11 @@ -use crate::expiring_lru::{Builder, ExpiringLru}; +use crate::tlru::{Builder, Tlru}; use tokio::time::Duration; use twilight_http_ratelimiting::RateLimiter; pub struct RatelimiterMap { default: RateLimiter, default_token: String, - inner: ExpiringLru, + inner: Tlru, } impl RatelimiterMap { diff --git a/src/expiring_lru.rs b/src/tlru.rs similarity index 92% rename from src/expiring_lru.rs rename to src/tlru.rs index c97229d..6bdad6b 100644 --- a/src/expiring_lru.rs +++ b/src/tlru.rs @@ -86,7 +86,7 @@ async fn decay_task( } }, else => { - // Channel has been closed by the other end, i.e. the ExpiringLru has + // Channel has been closed by the other end, i.e. the Tlru has // been dropped. break; } @@ -100,13 +100,13 @@ enum TimerUpdate { RemoveLru, } -pub struct ExpiringLru { +pub struct Tlru { inner: Arc>>, decay_tx: UnboundedSender>, max_size: Option, } -impl ExpiringLru +impl Tlru where K: Eq + Hash + Clone + Send + Sync + 'static, V: Send + Sync + 'static, @@ -194,8 +194,8 @@ where self } - pub fn build(self) -> ExpiringLru { - ExpiringLru::new(self.expiration, self.max_size) + pub fn build(self) -> Tlru { + Tlru::new(self.expiration, self.max_size) } } @@ -205,13 +205,13 @@ mod tests { use tokio::time::{Duration, sleep}; #[tokio::test(start_paused = true)] - async fn test_lru() { - let lru = Builder::new() + async fn tlru() { + let tlru = Builder::new() .expiration(Duration::from_secs(1)) .max_size(2) .build(); - lru.insert(1, 2); + tlru.insert(1, 2); // The actual LRU cache insert is performed in a different // task and insert will return pre-emptively after notifying @@ -223,15 +223,15 @@ mod tests { tokio::task::yield_now().await; { - let entry = lru.get(&1).unwrap(); + let entry = tlru.get(&1).unwrap(); assert_eq!(entry.value(), &2); } sleep(Duration::from_secs(2)).await; - assert!(lru.get(&1).is_none()); + assert!(tlru.get(&1).is_none()); for i in 2..5 { - lru.insert(i, 0); + tlru.insert(i, 0); // If we insert instantly after another, // upon inserting 4 it will remove either 2 or 3, @@ -241,8 +241,8 @@ mod tests { sleep(Duration::from_millis(50)).await; } - assert_eq!(lru.len(), 2); - assert!(lru.get(&2).is_none()); - assert!(lru.get(&4).is_some()); + assert_eq!(tlru.len(), 2); + assert!(tlru.get(&2).is_none()); + assert!(tlru.get(&4).is_some()); } } From 74dc6f5796ea95a21580df49b1b4618268c3d5f1 Mon Sep 17 00:00:00 2001 From: Tim Vilgot Mikael Fredenberg Date: Tue, 19 May 2026 12:17:24 +0200 Subject: [PATCH 2/2] fix tlru race condition The previous implementation would crash when accessing entries that were about to be removed. --- Cargo.lock | 64 +--------- Cargo.toml | 1 - src/main.rs | 3 +- src/ratelimiter_map.rs | 19 ++- src/tlru.rs | 263 ++++++++++++++--------------------------- 5 files changed, 101 insertions(+), 249 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 395b844..c333523 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -123,20 +123,6 @@ version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" -[[package]] -name = "dashmap" -version = "6.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" -dependencies = [ - "cfg-if", - "crossbeam-utils", - "hashbrown 0.14.5", - "lock_api", - "once_cell", - "parking_lot_core", -] - [[package]] name = "dunce" version = "1.0.5" @@ -273,12 +259,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "hashbrown" -version = "0.14.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" - [[package]] name = "hashbrown" version = "0.16.1" @@ -401,7 +381,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ad4bb2b565bca0645f4d68c5c9af97fba094e9791da685bf83cb5f3ce74acf2" dependencies = [ "equivalent", - "hashbrown 0.16.1", + "hashbrown", ] [[package]] @@ -442,15 +422,6 @@ version = "0.2.178" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091" -[[package]] -name = "lock_api" -version = "0.4.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" -dependencies = [ - "scopeguard", -] - [[package]] name = "log" version = "0.4.29" @@ -496,7 +467,7 @@ dependencies = [ "aho-corasick", "crossbeam-epoch", "crossbeam-utils", - "hashbrown 0.16.1", + "hashbrown", "indexmap", "metrics", "ordered-float", @@ -560,19 +531,6 @@ dependencies = [ "num-traits", ] -[[package]] -name = "parking_lot_core" -version = "0.9.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" -dependencies = [ - "cfg-if", - "libc", - "redox_syscall", - "smallvec", - "windows-link", -] - [[package]] name = "pin-project-lite" version = "0.2.16" @@ -696,15 +654,6 @@ dependencies = [ "bitflags", ] -[[package]] -name = "redox_syscall" -version = "0.5.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" -dependencies = [ - "bitflags", -] - [[package]] name = "ring" version = "0.17.14" @@ -760,12 +709,6 @@ version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" -[[package]] -name = "scopeguard" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" - [[package]] name = "sharded-slab" version = "0.1.7" @@ -989,7 +932,6 @@ checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" name = "twilight-http-proxy" version = "0.1.0" dependencies = [ - "dashmap", "http", "http-body-util", "hyper", @@ -1011,7 +953,7 @@ version = "0.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0515b0c30814068a7540fcb5f58b634259ca453fa335d42c3b2c8f2b06ac6a59" dependencies = [ - "hashbrown 0.16.1", + "hashbrown", "tokio", "tokio-util", "tracing", diff --git a/Cargo.toml b/Cargo.toml index 2117e7c..53dc748 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,6 @@ name = "twilight-http-proxy" version = "0.1.0" [dependencies] -dashmap = "6" http = "1" http-body-util = "0.1" hyper = { version = "1", default-features = false } diff --git a/src/main.rs b/src/main.rs index de6fa74..1006373 100644 --- a/src/main.rs +++ b/src/main.rs @@ -22,6 +22,7 @@ use std::{ env, error::Error, net::{Ipv4Addr, SocketAddrV4}, + num::NonZero, pin::pin, str::FromStr, sync::Arc, @@ -87,7 +88,7 @@ async fn main() -> Result<(), Box> { let ratelimiter_map = Arc::new(RatelimiterMap::new( env::var("DISCORD_TOKEN")?, Duration::from_secs(parse_env("CLIENT_DECAY_TIMEOUT")?.unwrap_or(3600)), - parse_env("CLIENT_CACHE_MAX_SIZE")?, + parse_env("CLIENT_CACHE_MAX_SIZE")?.unwrap_or(NonZero::::MAX), )); let host = parse_env("HOST")?.unwrap_or(Ipv4Addr::UNSPECIFIED); diff --git a/src/ratelimiter_map.rs b/src/ratelimiter_map.rs index 6fe4510..a2f7fa1 100644 --- a/src/ratelimiter_map.rs +++ b/src/ratelimiter_map.rs @@ -1,4 +1,5 @@ -use crate::tlru::{Builder, Tlru}; +use crate::tlru::Tlru; +use std::num::NonZero; use tokio::time::Duration; use twilight_http_ratelimiting::RateLimiter; @@ -9,7 +10,7 @@ pub struct RatelimiterMap { } impl RatelimiterMap { - pub fn new(mut default_token: String, timeout: Duration, max_size: Option) -> Self { + pub fn new(mut default_token: String, timeout: Duration, cap: NonZero) -> Self { let is_bot = default_token.starts_with("Bot "); let is_bearer = default_token.starts_with("Bearer "); @@ -19,13 +20,7 @@ impl RatelimiterMap { default_token.insert_str(0, "Bot "); } - let mut builder = Builder::new().expiration(timeout); - - if let Some(max_size) = max_size { - builder = builder.max_size(max_size); - } - - let inner = builder.build(); + let inner = Tlru::new(cap, timeout); let default = RateLimiter::default(); @@ -41,13 +36,13 @@ impl RatelimiterMap { if token == self.default_token { (self.default.clone(), self.default_token.clone()) } else if let Some(entry) = self.inner.get(token) { - (entry.value().clone(), token.to_string()) + (entry, token.to_owned()) } else { let ratelimiter = RateLimiter::default(); - self.inner.insert(token.to_string(), ratelimiter.clone()); + self.inner.insert(token.to_owned(), ratelimiter.clone()); - (ratelimiter, token.to_string()) + (ratelimiter, token.to_owned()) } } else { (self.default.clone(), self.default_token.clone()) diff --git a/src/tlru.rs b/src/tlru.rs index 6bdad6b..b4763a0 100644 --- a/src/tlru.rs +++ b/src/tlru.rs @@ -1,230 +1,145 @@ -use dashmap::{DashMap, mapref::one::Ref}; use std::{ - borrow::Borrow, future::poll_fn, hash::Hash, marker::PhantomData, ops::Deref, sync::Arc, + borrow::Borrow, + collections::HashMap, + future::poll_fn, + hash::Hash, + num::NonZero, + sync::{Arc, Mutex}, + task::{Context, Poll, ready}, time::Duration, }; -use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel}; -use tokio_util::time::{DelayQueue, delay_queue::Key}; -use tracing::debug; +use tokio::sync::Notify; +use tokio_util::time::{DelayQueue, delay_queue}; -pub struct Entry { - inner: V, - decay_key: Key, -} - -pub struct EntryRef<'a, K, V>(Ref<'a, K, Entry>); - -impl EntryRef<'_, K, V> -where - K: Eq + Hash, -{ - pub fn value(&self) -> &V { - &self.0.value().inner - } -} - -impl AsRef for EntryRef<'_, K, V> -where - K: Eq + Hash, -{ - fn as_ref(&self) -> &V { - &self.0.value().inner - } -} - -impl Deref for EntryRef<'_, K, V> -where - K: Eq + Hash, -{ - type Target = V; - - fn deref(&self) -> &Self::Target { - &self.0.value().inner - } -} - -async fn decay_task( - map: Arc>>, - expiration: Duration, - mut rx: UnboundedReceiver>, -) where - K: Eq + Hash + Clone + Send + Sync + 'static, - V: Send + Sync + 'static, -{ - let mut queue = DelayQueue::new(); - - loop { - tokio::select! { - Some(key) = poll_fn(|cx| queue.poll_expired(cx)), if !queue.is_empty() => { - // An item expired in the queue, remove it from the map - debug!("Removing expired entry from ratelimiter decay queue"); - map.remove(key.get_ref()); - } - Some(msg) = rx.recv() => { - match msg { - TimerUpdate::Add { key, value } => { - debug!("Adding entry to ratelimiter decay queue"); - let decay_key = queue.insert(key.clone(), expiration); - let entry = Entry { - inner: value, - decay_key, - }; - map.insert(key, entry); - }, - TimerUpdate::Refresh { key } => { - debug!("Refreshing entry in ratelimiter decay queue"); - // This will panic if the key is not present, therefore - // we check that in the calling end - queue.reset(&key, expiration); - }, - TimerUpdate::RemoveLru => { - debug!("Removing least recently used item from ratelimiter decay queue"); - if let Some(expired) = queue.peek().and_then(|key| queue.try_remove(&key)) { - map.remove(expired.get_ref()); - } - } - } - }, - else => { - // Channel has been closed by the other end, i.e. the Tlru has - // been dropped. - break; - } - }; - } +/// A time-aware least recently used cache. +/// +/// Entries are removed once their time-to-use (TTU) expires. +pub struct Tlru { + cap: NonZero, + inner: Arc>>, + notify: Arc, + ttu: Duration, } -enum TimerUpdate { - Add { key: K, value: V }, - Refresh { key: Key }, - RemoveLru, +// INVARIANT: both collections contain the same keys. +struct TlruInner { + entries: HashMap>, + expirations: DelayQueue, } -pub struct Tlru { - inner: Arc>>, - decay_tx: UnboundedSender>, - max_size: Option, +#[derive(Clone)] +struct Entry { + expiration: delay_queue::Key, + value: V, } impl Tlru where - K: Eq + Hash + Clone + Send + Sync + 'static, - V: Send + Sync + 'static, + K: Clone + Eq + Hash + Send + Sync + 'static, + V: Clone + Send + Sync + 'static, { - fn new(expiration: Duration, max_size: Option) -> Self { - let inner = Arc::new(DashMap::new()); - let (decay_tx, decay_rx) = unbounded_channel(); - - let this = Self { - inner: inner.clone(), - decay_tx, - max_size, + /// Creates an empty `Tlru`. + pub fn new(cap: NonZero, ttu: Duration) -> Self { + let inner = TlruInner { + entries: HashMap::new(), + expirations: DelayQueue::new(), }; + let inner = Arc::new(Mutex::new(inner)); + let notify = Arc::new(Notify::new()); - tokio::spawn(decay_task(inner, expiration, decay_rx)); + tokio::spawn(reaper(Arc::clone(¬ify), Arc::clone(&inner))); - this + Self { + cap, + inner, + notify, + ttu, + } } + /// Inserts a key-value pair into the cache. + /// + /// If the cache is full, the least recently used entry is replaced. pub fn insert(&self, key: K, value: V) { - match self.max_size { - Some(0) => return, - Some(max_size) if self.len() >= max_size => { - self.remove_lru(); - } - _ => {} + let mut guard = self.inner.lock().unwrap(); + // Notify `reaper` that the cache is non-empty. + self.notify.notify_waiters(); + + if self.cap.get() == guard.len() { + guard.remove_lru(); } - _ = self.decay_tx.send(TimerUpdate::Add { key, value }); + let expiration = guard.expirations.insert(key.clone(), self.ttu); + guard.entries.insert(key, Entry { expiration, value }); } - pub fn get(&self, key: &Q) -> Option> + /// Returns the value corresponding to the key. + pub fn get(&self, key: &Q) -> Option where K: Borrow, - Q: Hash + Eq + ?Sized, + Q: Eq + Hash, { - let entry = self.inner.get(key)?; - _ = self.decay_tx.send(TimerUpdate::Refresh { - key: entry.decay_key, - }); - - Some(EntryRef(entry)) - } + let mut guard = self.inner.lock().unwrap(); - fn remove_lru(&self) { - _ = self.decay_tx.send(TimerUpdate::RemoveLru); - } + let Entry { expiration, value } = guard.entries.get(key)?.clone(); + guard.expirations.reset(&expiration, self.ttu); - pub fn len(&self) -> usize { - self.inner.len() + Some(value) } } -pub struct Builder { - expiration: Duration, - max_size: Option, - - _marker: PhantomData<(K, V)>, -} - -const DEFAULT_EXPIRATION: Duration = Duration::from_secs(3600); - -impl Builder +impl TlruInner where - K: Eq + Hash + Clone + Send + Sync + 'static, - V: Send + Sync + 'static, + K: Eq + Hash, { - pub const fn new() -> Self { - Self { - expiration: DEFAULT_EXPIRATION, - max_size: None, - _marker: PhantomData, - } + /// Returns the number of entries in the cache. + fn len(&self) -> usize { + debug_assert_eq!(self.entries.len(), self.expirations.len()); + self.entries.len() } - pub const fn expiration(mut self, expiration: Duration) -> Self { - self.expiration = expiration; + /// Attempts to remove an expired entry. + fn poll_expired(&mut self, cx: &mut Context<'_>) -> Poll> { + let expired = ready!(self.expirations.poll_expired(cx)); + let entry = expired.map(|expired| self.entries.remove(expired.get_ref()).unwrap().value); - self + Poll::Ready(entry) } - pub const fn max_size(mut self, size: usize) -> Self { - self.max_size = Some(size); + /// Removes the least recently used entry. + fn remove_lru(&mut self) -> Option { + let lru = self.expirations.remove(&self.expirations.peek()?); + let entry = self.entries.remove(lru.get_ref()).unwrap().value; - self + Some(entry) } +} + +async fn reaper(notify: Arc, tlru: Arc>>) { + loop { + let expired = poll_fn(|cx| tlru.lock().unwrap().poll_expired(cx)).await; - pub fn build(self) -> Tlru { - Tlru::new(self.expiration, self.max_size) + if expired.is_none() { + // Wait until the cache is non-empty. + notify.notified().await; + } } } #[cfg(test)] mod tests { - use super::Builder; - use tokio::time::{Duration, sleep}; + use super::*; + use tokio::time::sleep; #[tokio::test(start_paused = true)] async fn tlru() { - let tlru = Builder::new() - .expiration(Duration::from_secs(1)) - .max_size(2) - .build(); + let tlru = Tlru::new(NonZero::new(2).unwrap(), Duration::from_secs(1)); tlru.insert(1, 2); - // The actual LRU cache insert is performed in a different - // task and insert will return pre-emptively after notifying - // the task of the insertion. In order to allow the task to run - // and receive the insertion message, we have to yield back to the - // runtime. The alternative would be making insert asynchronous and - // wait on a oneshot channel, but there is no benefit to that - // for our usecase. - tokio::task::yield_now().await; - { let entry = tlru.get(&1).unwrap(); - assert_eq!(entry.value(), &2); + assert_eq!(entry, 2); } sleep(Duration::from_secs(2)).await; @@ -241,7 +156,7 @@ mod tests { sleep(Duration::from_millis(50)).await; } - assert_eq!(tlru.len(), 2); + assert_eq!(tlru.inner.lock().unwrap().len(), 2); assert!(tlru.get(&2).is_none()); assert!(tlru.get(&4).is_some()); }