nautilus_network/ratelimiter/
mod.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Posei Systems Pty Ltd. All rights reserved.
3//  https://poseitrader.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! A rate limiter implementation heavily inspired by [governor](https://github.com/antifuchs/governor).
17//!
//! Governor does not support different quotas for different keys; this is tracked as an open [issue](https://github.com/antifuchs/governor/issues/193).
19pub mod clock;
20mod gcra;
21mod nanos;
22pub mod quota;
23
24use std::{
25    fmt::Debug,
26    hash::Hash,
27    num::NonZeroU64,
28    sync::atomic::{AtomicU64, Ordering},
29    time::Duration,
30};
31
32use dashmap::DashMap;
33use futures_util::StreamExt;
34use tokio::time::sleep;
35
36use self::{
37    clock::{Clock, FakeRelativeClock, MonotonicClock},
38    gcra::{Gcra, NotUntil},
39    nanos::Nanos,
40    quota::Quota,
41};
42
43/// An in-memory representation of a GCRA's rate-limiting state.
44///
45/// Implemented using [`AtomicU64`] operations, this state representation can be used to
46/// construct rate limiting states for other in-memory states: e.g., this crate uses
47/// `InMemoryState` as the states it tracks in the keyed rate limiters it implements.
48///
49/// Internally, the number tracked here is the theoretical arrival time (a GCRA term) in number of
50/// nanoseconds since the rate limiter was created.
51#[derive(Debug, Default)]
52pub struct InMemoryState(AtomicU64);
53
54impl InMemoryState {
55    /// Measures and updates the GCRA's state atomically, retrying on concurrent modifications.
56    ///
57    /// # Errors
58    ///
59    /// Returns an error if the provided closure returns an error.
60    pub(crate) fn measure_and_replace_one<T, F, E>(&self, mut f: F) -> Result<T, E>
61    where
62        F: FnMut(Option<Nanos>) -> Result<(T, Nanos), E>,
63    {
64        let mut prev = self.0.load(Ordering::Acquire);
65        let mut decision = f(NonZeroU64::new(prev).map(|n| n.get().into()));
66        while let Ok((result, new_data)) = decision {
67            match self.0.compare_exchange_weak(
68                prev,
69                new_data.into(),
70                Ordering::Release,
71                Ordering::Relaxed,
72            ) {
73                Ok(_) => return Ok(result),
74                Err(next_prev) => prev = next_prev,
75            }
76            decision = f(NonZeroU64::new(prev).map(|n| n.get().into()));
77        }
78        // This map shouldn't be needed, as we only get here in the error case, but the compiler
79        // can't see it.
80        decision.map(|(result, _)| result)
81    }
82}
83
/// A concurrent, thread-safe and fairly performant keyed state store based on [`DashMap`],
/// holding one [`InMemoryState`] per key.
pub type DashMapStateStore<K> = DashMap<K, InMemoryState>;
86
/// A way for rate limiters to keep state.
///
/// There are two important kinds of state stores: direct and keyed. The direct kind has only
/// one state and is useful for "global" rate limit enforcement (e.g. a process should never
/// do more than N tasks a day). The keyed kind allows one rate limit per key (e.g. an API
/// call budget per client API key).
///
/// A direct state store is expressed as [`StateStore::Key`] = `NotKeyed`.
/// Keyed state stores have a type parameter for the key and set their key to that.
pub trait StateStore {
    /// The type of key that the state store can represent.
    type Key;

    /// Updates a state store's rate limiting state for a given key, using the given closure.
    ///
    /// The closure parameter takes the old value (`None` if this is the first measurement) of
    /// the state store at the key's location, checks if the request can be accommodated and:
    ///
    /// - If the request is rate-limited, returns `Err(E)`.
    /// - If the request can make it through, returns `Ok(T)` (an arbitrary positive return
    ///   value) and the updated state.
    ///
    /// It is then `measure_and_replace`'s job to safely replace the value at the key: it must
    /// only update the value if the value hasn't changed in the meantime. The implementations
    /// in this crate use `AtomicU64` operations for this.
    ///
    /// # Errors
    ///
    /// Returns `Err(E)` if the closure returns an error or the request is rate-limited.
    fn measure_and_replace<T, F, E>(&self, key: &Self::Key, f: F) -> Result<T, E>
    where
        F: Fn(Option<Nanos>) -> Result<(T, Nanos), E>;
}
121
122impl<K: Hash + Eq + Clone> StateStore for DashMapStateStore<K> {
123    type Key = K;
124
125    fn measure_and_replace<T, F, E>(&self, key: &Self::Key, f: F) -> Result<T, E>
126    where
127        F: Fn(Option<Nanos>) -> Result<(T, Nanos), E>,
128    {
129        if let Some(v) = self.get(key) {
130            // fast path: measure existing entry
131            return v.measure_and_replace_one(f);
132        }
133        // make an entry and measure that:
134        let entry = self.entry(key.clone()).or_default();
135        (*entry).measure_and_replace_one(f)
136    }
137}
138
/// A rate limiter that enforces different quotas per key using the GCRA algorithm.
///
/// This implementation allows setting different rate limits for different keys,
/// with an optional default quota for keys that don't have specific quotas.
pub struct RateLimiter<K, C>
where
    C: Clock,
{
    // Limiter used for keys that have no entry in `gcra`. When this is `None`, such keys are
    // always allowed.
    default_gcra: Option<Gcra>,
    // Per-key rate-limiting state, shared by the keyed limiters and the default limiter.
    state: DashMapStateStore<K>,
    // Key-specific limiters; an entry here takes precedence over `default_gcra`.
    gcra: DashMap<K, Gcra>,
    // Time source queried on every check.
    clock: C,
    // Instant captured when the limiter was built; passed to every check as the reference
    // point.
    start: C::Instant,
}
153
154impl<K, C> Debug for RateLimiter<K, C>
155where
156    K: Debug,
157    C: Clock,
158{
159    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160        f.debug_struct(stringify!(RateLimiter)).finish()
161    }
162}
163
164impl<K> RateLimiter<K, MonotonicClock>
165where
166    K: Eq + Hash,
167{
168    /// Creates a new rate limiter with a base quota and keyed quotas.
169    ///
170    /// The base quota applies to all keys that don't have specific quotas.
171    /// Keyed quotas override the base quota for specific keys.
172    #[must_use]
173    pub fn new_with_quota(base_quota: Option<Quota>, keyed_quotas: Vec<(K, Quota)>) -> Self {
174        let clock = MonotonicClock {};
175        let start = MonotonicClock::now(&clock);
176        let gcra = DashMap::from_iter(keyed_quotas.into_iter().map(|(k, q)| (k, Gcra::new(q))));
177        Self {
178            default_gcra: base_quota.map(Gcra::new),
179            state: DashMapStateStore::new(),
180            gcra,
181            clock,
182            start,
183        }
184    }
185}
186
impl<K> RateLimiter<K, FakeRelativeClock>
where
    K: Hash + Eq + Clone,
{
    /// Advances the fake clock by the specified duration.
    ///
    /// This method exists only on limiters driven by a `FakeRelativeClock`, which is what the
    /// tests use to move time forward deterministically instead of sleeping.
    pub fn advance_clock(&self, by: Duration) {
        // Delegates to the clock; subsequent checks read the advanced time from it.
        self.clock.advance(by);
    }
}
198
199impl<K, C> RateLimiter<K, C>
200where
201    K: Hash + Eq + Clone,
202    C: Clock,
203{
204    /// Adds or updates a quota for a specific key.
205    pub fn add_quota_for_key(&self, key: K, value: Quota) {
206        self.gcra.insert(key, Gcra::new(value));
207    }
208
209    /// Checks if the given key is allowed under the rate limit.
210    ///
211    /// # Errors
212    ///
213    /// Returns `Err(NotUntil)` if the key is rate-limited, indicating when it will be allowed.
214    pub fn check_key(&self, key: &K) -> Result<(), NotUntil<C::Instant>> {
215        match self.gcra.get(key) {
216            Some(quota) => quota.test_and_update(self.start, key, &self.state, self.clock.now()),
217            None => self.default_gcra.as_ref().map_or(Ok(()), |gcra| {
218                gcra.test_and_update(self.start, key, &self.state, self.clock.now())
219            }),
220        }
221    }
222
223    /// Waits until the specified key is ready (not rate-limited).
224    pub async fn until_key_ready(&self, key: &K) {
225        loop {
226            match self.check_key(key) {
227                Ok(()) => {
228                    break;
229                }
230                Err(neg) => {
231                    sleep(neg.wait_time_from(self.clock.now())).await;
232                }
233            }
234        }
235    }
236
237    /// Waits until all specified keys are ready (not rate-limited).
238    ///
239    /// If no keys are provided, this function returns immediately.
240    pub async fn await_keys_ready(&self, keys: Option<Vec<K>>) {
241        let keys = keys.unwrap_or_default();
242        let tasks = keys.iter().map(|key| self.until_key_ready(key));
243
244        futures::stream::iter(tasks)
245            .for_each_concurrent(None, |key_future| async move {
246                key_future.await;
247            })
248            .await;
249    }
250}
251
252////////////////////////////////////////////////////////////////////////////////
253// Tests
254////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
mod tests {
    use std::{num::NonZeroU32, time::Duration};

    use dashmap::DashMap;
    use rstest::rstest;

    use super::{
        DashMapStateStore, RateLimiter,
        clock::{Clock, FakeRelativeClock},
        gcra::Gcra,
        quota::Quota,
    };

    /// Shorthand for a quota of `n` requests per second.
    fn per_second(n: u32) -> Quota {
        Quota::per_second(NonZeroU32::new(n).unwrap())
    }

    /// Builds a limiter on a fake clock with a default quota of two requests per second and
    /// no key-specific quotas.
    fn initialize_mock_rate_limiter() -> RateLimiter<String, FakeRelativeClock> {
        let clock = FakeRelativeClock::default();
        let start = clock.now();
        RateLimiter {
            default_gcra: Some(Gcra::new(per_second(2))),
            state: DashMapStateStore::new(),
            gcra: DashMap::new(),
            clock,
            start,
        }
    }

    #[rstest]
    fn test_default_quota() {
        let limiter = initialize_mock_rate_limiter();
        let key = "default".to_string();

        // The base quota admits two requests...
        assert!(limiter.check_key(&key).is_ok());
        assert!(limiter.check_key(&key).is_ok());

        // ...and rejects the third.
        assert!(limiter.check_key(&key).is_err());

        // After one second the key is admitted again.
        limiter.advance_clock(Duration::from_secs(1));
        assert!(limiter.check_key(&key).is_ok());
    }

    #[rstest]
    fn test_custom_key_quota() {
        let limiter = initialize_mock_rate_limiter();
        let custom = "custom".to_string();
        let other = "default".to_string();

        // Register a tighter quota for one key only.
        limiter.add_quota_for_key(custom.clone(), per_second(1));

        // The custom quota admits a single request.
        assert!(limiter.check_key(&custom).is_ok());
        assert!(limiter.check_key(&custom).is_err());

        // Every other key is still governed by the default quota.
        assert!(limiter.check_key(&other).is_ok());
        assert!(limiter.check_key(&other).is_ok());
        assert!(limiter.check_key(&other).is_err());
    }

    #[rstest]
    fn test_multiple_keys() {
        let limiter = initialize_mock_rate_limiter();
        let key1 = "key1".to_string();
        let key2 = "key2".to_string();

        limiter.add_quota_for_key(key1.clone(), per_second(1));
        limiter.add_quota_for_key(key2.clone(), per_second(3));

        // key1: one request per second.
        assert!(limiter.check_key(&key1).is_ok());
        assert!(limiter.check_key(&key1).is_err());

        // key2: three requests per second, tracked independently of key1.
        assert!(limiter.check_key(&key2).is_ok());
        assert!(limiter.check_key(&key2).is_ok());
        assert!(limiter.check_key(&key2).is_ok());
        assert!(limiter.check_key(&key2).is_err());
    }

    #[rstest]
    fn test_quota_reset() {
        let limiter = initialize_mock_rate_limiter();
        let key = "reset".to_string();

        // Use up the default quota.
        assert!(limiter.check_key(&key).is_ok());
        assert!(limiter.check_key(&key).is_ok());
        assert!(limiter.check_key(&key).is_err());

        // Just under half a second later the key is still rejected.
        limiter.advance_clock(Duration::from_millis(499));
        assert!(limiter.check_key(&key).is_err());

        // Once a full second has elapsed in total, the key is admitted again.
        limiter.advance_clock(Duration::from_millis(501));
        assert!(limiter.check_key(&key).is_ok());
    }

    #[rstest]
    fn test_different_quotas() {
        let limiter = initialize_mock_rate_limiter();
        let fast = "per_second".to_string();
        let slow = "per_minute".to_string();

        limiter.add_quota_for_key(fast.clone(), per_second(2));
        limiter.add_quota_for_key(
            slow.clone(),
            Quota::per_minute(NonZeroU32::new(3).unwrap()),
        );

        // Per-second quota: two admitted, third rejected.
        assert!(limiter.check_key(&fast).is_ok());
        assert!(limiter.check_key(&fast).is_ok());
        assert!(limiter.check_key(&fast).is_err());

        // Per-minute quota: three admitted, fourth rejected.
        assert!(limiter.check_key(&slow).is_ok());
        assert!(limiter.check_key(&slow).is_ok());
        assert!(limiter.check_key(&slow).is_ok());
        assert!(limiter.check_key(&slow).is_err());

        // One second later the per-second key is admitted, the per-minute key is not.
        limiter.advance_clock(Duration::from_secs(1));
        assert!(limiter.check_key(&fast).is_ok());
        assert!(limiter.check_key(&slow).is_err());
    }

    #[tokio::test]
    async fn test_await_keys_ready() {
        let limiter = initialize_mock_rate_limiter();
        let key = "default".to_string();

        // Use up the default quota.
        assert!(limiter.check_key(&key).is_ok());
        assert!(limiter.check_key(&key).is_ok());
        assert!(limiter.check_key(&key).is_err());

        // Move the fake clock forward first so the wait below can complete, then confirm
        // the key is admitted afterwards.
        limiter.advance_clock(Duration::from_secs(1));
        limiter.await_keys_ready(Some(vec![key.clone()])).await;
        assert!(limiter.check_key(&key).is_ok());
    }
}