nautilus_network/ratelimiter/
mod.rs1pub mod clock;
20mod gcra;
21mod nanos;
22pub mod quota;
23
24use std::{
25 fmt::Debug,
26 hash::Hash,
27 num::NonZeroU64,
28 sync::atomic::{AtomicU64, Ordering},
29 time::Duration,
30};
31
32use dashmap::DashMap;
33use futures_util::StreamExt;
34use tokio::time::sleep;
35
36use self::{
37 clock::{Clock, FakeRelativeClock, MonotonicClock},
38 gcra::{Gcra, NotUntil},
39 nanos::Nanos,
40 quota::Quota,
41};
42
43#[derive(Debug, Default)]
52pub struct InMemoryState(AtomicU64);
53
54impl InMemoryState {
55 pub(crate) fn measure_and_replace_one<T, F, E>(&self, mut f: F) -> Result<T, E>
61 where
62 F: FnMut(Option<Nanos>) -> Result<(T, Nanos), E>,
63 {
64 let mut prev = self.0.load(Ordering::Acquire);
65 let mut decision = f(NonZeroU64::new(prev).map(|n| n.get().into()));
66 while let Ok((result, new_data)) = decision {
67 match self.0.compare_exchange_weak(
68 prev,
69 new_data.into(),
70 Ordering::Release,
71 Ordering::Relaxed,
72 ) {
73 Ok(_) => return Ok(result),
74 Err(next_prev) => prev = next_prev,
75 }
76 decision = f(NonZeroU64::new(prev).map(|n| n.get().into()));
77 }
78 decision.map(|(result, _)| result)
81 }
82}
83
84pub type DashMapStateStore<K> = DashMap<K, InMemoryState>;
86
87pub trait StateStore {
98 type Key;
100
101 fn measure_and_replace<T, F, E>(&self, key: &Self::Key, f: F) -> Result<T, E>
118 where
119 F: Fn(Option<Nanos>) -> Result<(T, Nanos), E>;
120}
121
122impl<K: Hash + Eq + Clone> StateStore for DashMapStateStore<K> {
123 type Key = K;
124
125 fn measure_and_replace<T, F, E>(&self, key: &Self::Key, f: F) -> Result<T, E>
126 where
127 F: Fn(Option<Nanos>) -> Result<(T, Nanos), E>,
128 {
129 if let Some(v) = self.get(key) {
130 return v.measure_and_replace_one(f);
132 }
133 let entry = self.entry(key.clone()).or_default();
135 (*entry).measure_and_replace_one(f)
136 }
137}
138
139pub struct RateLimiter<K, C>
144where
145 C: Clock,
146{
147 default_gcra: Option<Gcra>,
148 state: DashMapStateStore<K>,
149 gcra: DashMap<K, Gcra>,
150 clock: C,
151 start: C::Instant,
152}
153
154impl<K, C> Debug for RateLimiter<K, C>
155where
156 K: Debug,
157 C: Clock,
158{
159 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160 f.debug_struct(stringify!(RateLimiter)).finish()
161 }
162}
163
164impl<K> RateLimiter<K, MonotonicClock>
165where
166 K: Eq + Hash,
167{
168 #[must_use]
173 pub fn new_with_quota(base_quota: Option<Quota>, keyed_quotas: Vec<(K, Quota)>) -> Self {
174 let clock = MonotonicClock {};
175 let start = MonotonicClock::now(&clock);
176 let gcra = DashMap::from_iter(keyed_quotas.into_iter().map(|(k, q)| (k, Gcra::new(q))));
177 Self {
178 default_gcra: base_quota.map(Gcra::new),
179 state: DashMapStateStore::new(),
180 gcra,
181 clock,
182 start,
183 }
184 }
185}
186
187impl<K> RateLimiter<K, FakeRelativeClock>
188where
189 K: Hash + Eq + Clone,
190{
191 pub fn advance_clock(&self, by: Duration) {
195 self.clock.advance(by);
196 }
197}
198
199impl<K, C> RateLimiter<K, C>
200where
201 K: Hash + Eq + Clone,
202 C: Clock,
203{
204 pub fn add_quota_for_key(&self, key: K, value: Quota) {
206 self.gcra.insert(key, Gcra::new(value));
207 }
208
209 pub fn check_key(&self, key: &K) -> Result<(), NotUntil<C::Instant>> {
215 match self.gcra.get(key) {
216 Some(quota) => quota.test_and_update(self.start, key, &self.state, self.clock.now()),
217 None => self.default_gcra.as_ref().map_or(Ok(()), |gcra| {
218 gcra.test_and_update(self.start, key, &self.state, self.clock.now())
219 }),
220 }
221 }
222
223 pub async fn until_key_ready(&self, key: &K) {
225 loop {
226 match self.check_key(key) {
227 Ok(()) => {
228 break;
229 }
230 Err(neg) => {
231 sleep(neg.wait_time_from(self.clock.now())).await;
232 }
233 }
234 }
235 }
236
237 pub async fn await_keys_ready(&self, keys: Option<Vec<K>>) {
241 let keys = keys.unwrap_or_default();
242 let tasks = keys.iter().map(|key| self.until_key_ready(key));
243
244 futures::stream::iter(tasks)
245 .for_each_concurrent(None, |key_future| async move {
246 key_future.await;
247 })
248 .await;
249 }
250}
251
#[cfg(test)]
mod tests {
    use std::{num::NonZeroU32, time::Duration};

    use dashmap::DashMap;
    use rstest::rstest;

    use super::{
        DashMapStateStore, RateLimiter,
        clock::{Clock, FakeRelativeClock},
        gcra::Gcra,
        quota::Quota,
    };

    /// Limiter type used throughout these tests: string keys and a manually advanced clock.
    type TestLimiter = RateLimiter<String, FakeRelativeClock>;

    /// Builds a limiter on a [`FakeRelativeClock`] with a default quota of 2 requests per second
    /// and no per-key quotas.
    fn initialize_mock_rate_limiter() -> TestLimiter {
        let clock = FakeRelativeClock::default();
        let start = clock.now();
        let base_quota = Quota::per_second(NonZeroU32::new(2).unwrap());
        RateLimiter {
            default_gcra: Some(Gcra::new(base_quota)),
            state: DashMapStateStore::new(),
            gcra: DashMap::new(),
            clock,
            start,
        }
    }

    /// Returns `true` when `check_key` allows a request for `key`.
    fn admits(limiter: &TestLimiter, key: &str) -> bool {
        limiter.check_key(&key.to_string()).is_ok()
    }

    #[rstest]
    fn test_default_quota() {
        let limiter = initialize_mock_rate_limiter();

        // The default quota allows two requests; the third is rejected.
        assert!(admits(&limiter, "default"));
        assert!(admits(&limiter, "default"));
        assert!(!admits(&limiter, "default"));

        // One second later the key is allowed again.
        limiter.advance_clock(Duration::from_secs(1));
        assert!(admits(&limiter, "default"));
    }

    #[rstest]
    fn test_custom_key_quota() {
        let limiter = initialize_mock_rate_limiter();
        limiter.add_quota_for_key(
            "custom".to_string(),
            Quota::per_second(NonZeroU32::new(1).unwrap()),
        );

        assert!(admits(&limiter, "custom"));
        assert!(!admits(&limiter, "custom"));

        // Keys without a dedicated quota still use the default (2 per second).
        assert!(admits(&limiter, "default"));
        assert!(admits(&limiter, "default"));
        assert!(!admits(&limiter, "default"));
    }

    #[rstest]
    fn test_multiple_keys() {
        let limiter = initialize_mock_rate_limiter();
        limiter.add_quota_for_key(
            "key1".to_string(),
            Quota::per_second(NonZeroU32::new(1).unwrap()),
        );
        limiter.add_quota_for_key(
            "key2".to_string(),
            Quota::per_second(NonZeroU32::new(3).unwrap()),
        );

        assert!(admits(&limiter, "key1"));
        assert!(!admits(&limiter, "key1"));

        for _ in 0..3 {
            assert!(admits(&limiter, "key2"));
        }
        assert!(!admits(&limiter, "key2"));
    }

    #[rstest]
    fn test_quota_reset() {
        let limiter = initialize_mock_rate_limiter();

        assert!(admits(&limiter, "reset"));
        assert!(admits(&limiter, "reset"));
        assert!(!admits(&limiter, "reset"));

        // Not enough time has passed yet.
        limiter.advance_clock(Duration::from_millis(499));
        assert!(!admits(&limiter, "reset"));

        // 1000 ms have elapsed in total.
        limiter.advance_clock(Duration::from_millis(501));
        assert!(admits(&limiter, "reset"));
    }

    #[rstest]
    fn test_different_quotas() {
        let limiter = initialize_mock_rate_limiter();
        limiter.add_quota_for_key(
            "per_second".to_string(),
            Quota::per_second(NonZeroU32::new(2).unwrap()),
        );
        limiter.add_quota_for_key(
            "per_minute".to_string(),
            Quota::per_minute(NonZeroU32::new(3).unwrap()),
        );

        assert!(admits(&limiter, "per_second"));
        assert!(admits(&limiter, "per_second"));
        assert!(!admits(&limiter, "per_second"));

        assert!(admits(&limiter, "per_minute"));
        assert!(admits(&limiter, "per_minute"));
        assert!(admits(&limiter, "per_minute"));
        assert!(!admits(&limiter, "per_minute"));

        // One second replenishes the per-second key but not the per-minute key.
        limiter.advance_clock(Duration::from_secs(1));
        assert!(admits(&limiter, "per_second"));
        assert!(!admits(&limiter, "per_minute"));
    }

    #[tokio::test]
    async fn test_await_keys_ready() {
        let limiter = initialize_mock_rate_limiter();

        assert!(admits(&limiter, "default"));
        assert!(admits(&limiter, "default"));
        assert!(!admits(&limiter, "default"));

        limiter.advance_clock(Duration::from_secs(1));
        limiter
            .await_keys_ready(Some(vec!["default".to_string()]))
            .await;
        assert!(admits(&limiter, "default"));
    }
}