1use std::{
17 any::Any,
18 cell::{RefCell, UnsafeCell},
19 collections::VecDeque,
20 fmt::Debug,
21 marker::PhantomData,
22 rc::Rc,
23};
24
25use nautilus_core::{UnixNanos, correctness::FAILED};
26use ustr::Ustr;
27
28use crate::{
29 actor::{
30 Actor,
31 registry::{get_actor_unchecked, register_actor},
32 },
33 clock::Clock,
34 msgbus::{
35 self,
36 handler::{MessageHandler, ShareableMessageHandler},
37 },
38 timer::{TimeEvent, TimeEventCallback},
39};
40
/// A rate limit: at most `limit` messages per `interval_ns` nanoseconds.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RateLimit {
    /// Maximum number of messages allowed within one interval.
    pub limit: usize,
    /// Length of the rate-limiting interval in nanoseconds.
    pub interval_ns: u64,
}

impl RateLimit {
    /// Creates a new [`RateLimit`] from a message `limit` and an interval in nanoseconds.
    #[must_use]
    pub const fn new(limit: usize, interval_ns: u64) -> Self {
        Self { interval_ns, limit }
    }
}
55
/// Rate-limits messages of type `T` to at most `limit` per `interval` nanoseconds.
///
/// Messages within the limit are passed to `output_send`. Messages over the limit are either
/// passed to `output_drop` (when it is `Some`) or buffered and released later by a timer set
/// on `clock`.
pub struct Throttler<T, F> {
    /// Number of messages received via `send`.
    pub recv_count: usize,
    /// Number of messages passed to `output_send`.
    pub sent_count: usize,
    /// Whether the throttler is currently limiting (buffering or dropping) messages.
    pub is_limiting: bool,
    /// Maximum number of messages sent per interval.
    pub limit: usize,
    /// Messages awaiting release: pushed at the front and released from the back (FIFO).
    pub buffer: VecDeque<T>,
    /// Timestamps of the most recent sends, newest first; `send_msg` trims the oldest entry
    /// once `limit` entries are held.
    pub timestamps: VecDeque<UnixNanos>,
    /// Clock used for send timestamps and for the limiting timer.
    pub clock: Rc<RefCell<dyn Clock>>,
    /// Identifier under which the throttler is registered and looked up as an actor.
    pub actor_id: Ustr,
    /// Length of the rate-limiting interval in nanoseconds.
    interval: u64,
    /// Name of the clock timer that ends a limiting period.
    timer_name: String,
    /// Called with each message that is sent.
    output_send: F,
    /// If set, called with each message rejected while limiting (instead of buffering it).
    output_drop: Option<F>,
}
86
87impl<T, F> Actor for Throttler<T, F>
88where
89 T: 'static,
90 F: Fn(T) + 'static,
91{
92 fn id(&self) -> Ustr {
93 self.actor_id
94 }
95
96 fn handle(&mut self, _msg: &dyn Any) {}
97
98 fn as_any(&self) -> &dyn Any {
99 self
100 }
101}
102
103impl<T, F> Debug for Throttler<T, F>
104where
105 T: Debug,
106{
107 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108 f.debug_struct(stringify!(InnerThrottler))
109 .field("recv_count", &self.recv_count)
110 .field("sent_count", &self.sent_count)
111 .field("is_limiting", &self.is_limiting)
112 .field("limit", &self.limit)
113 .field("buffer", &self.buffer)
114 .field("timestamps", &self.timestamps)
115 .field("interval", &self.interval)
116 .field("timer_name", &self.timer_name)
117 .finish()
118 }
119}
120
121impl<T, F> Throttler<T, F> {
122 #[inline]
123 pub fn new(
124 limit: usize,
125 interval: u64,
126 clock: Rc<RefCell<dyn Clock>>,
127 timer_name: String,
128 output_send: F,
129 output_drop: Option<F>,
130 actor_id: Ustr,
131 ) -> Self {
132 Self {
133 recv_count: 0,
134 sent_count: 0,
135 is_limiting: false,
136 limit,
137 buffer: VecDeque::new(),
138 timestamps: VecDeque::with_capacity(limit),
139 clock,
140 interval,
141 timer_name,
142 output_send,
143 output_drop,
144 actor_id,
145 }
146 }
147
148 #[inline]
158 pub fn set_timer(&mut self, callback: Option<TimeEventCallback>) {
159 let delta = self.delta_next();
160 let mut clock = self.clock.borrow_mut();
161 if clock.timer_names().contains(&self.timer_name.as_str()) {
162 clock.cancel_timer(&self.timer_name);
163 }
164 let alert_ts = clock.timestamp_ns() + delta;
165
166 clock
167 .set_time_alert_ns(&self.timer_name, alert_ts, callback, None)
168 .expect(FAILED);
169 }
170
171 #[inline]
173 pub fn delta_next(&mut self) -> u64 {
174 match self.timestamps.get(self.limit - 1) {
175 Some(ts) => {
176 let diff = self.clock.borrow().timestamp_ns().as_u64() - ts.as_u64();
177 self.interval.saturating_sub(diff)
178 }
179 None => 0,
180 }
181 }
182
183 #[inline]
185 pub fn reset(&mut self) {
186 self.buffer.clear();
187 self.recv_count = 0;
188 self.sent_count = 0;
189 self.is_limiting = false;
190 self.timestamps.clear();
191 }
192
193 #[inline]
195 pub fn used(&self) -> f64 {
196 if self.timestamps.is_empty() {
197 return 0.0;
198 }
199
200 let now = self.clock.borrow().timestamp_ns().as_i64();
201 let interval_start = now - self.interval as i64;
202
203 let messages_in_current_interval = self
204 .timestamps
205 .iter()
206 .take_while(|&&ts| ts.as_i64() > interval_start)
207 .count();
208
209 (messages_in_current_interval as f64) / (self.limit as f64)
210 }
211
212 #[inline]
214 pub fn qsize(&self) -> usize {
215 self.buffer.len()
216 }
217}
218
219impl<T, F> Throttler<T, F>
220where
221 T: 'static,
222 F: Fn(T) + 'static,
223{
224 pub fn to_actor(self) -> Rc<UnsafeCell<Self>> {
225 let process_handler = ThrottlerProcess::<T, F>::new(self.actor_id);
227 msgbus::register(
228 process_handler.id().as_str().into(),
229 ShareableMessageHandler::from(Rc::new(process_handler) as Rc<dyn MessageHandler>),
230 );
231
232 let actor = Rc::new(UnsafeCell::new(self));
234 register_actor(actor.clone());
235
236 actor
237 }
238
239 #[inline]
240 pub fn send_msg(&mut self, msg: T) {
241 let now = self.clock.borrow().timestamp_ns();
242
243 if self.timestamps.len() >= self.limit {
244 self.timestamps.pop_back();
245 }
246 self.timestamps.push_front(now);
247
248 self.sent_count += 1;
249 (self.output_send)(msg);
250 }
251
252 #[inline]
253 pub fn limit_msg(&mut self, msg: T) {
254 let callback = if self.output_drop.is_none() {
255 self.buffer.push_front(msg);
256 log::debug!("Buffering {}", self.buffer.len());
257 Some(ThrottlerProcess::<T, F>::new(self.actor_id).get_timer_callback())
258 } else {
259 log::debug!("Dropping");
260 if let Some(drop) = &self.output_drop {
261 drop(msg);
262 }
263 Some(throttler_resume::<T, F>(self.actor_id))
264 };
265 if !self.is_limiting {
266 log::debug!("Limiting");
267 self.set_timer(callback);
268 self.is_limiting = true;
269 }
270 }
271
272 #[inline]
273 pub fn send(&mut self, msg: T)
274 where
275 T: 'static,
276 F: Fn(T) + 'static,
277 {
278 self.recv_count += 1;
279
280 if self.is_limiting || self.delta_next() > 0 {
281 self.limit_msg(msg);
282 } else {
283 self.send_msg(msg);
284 }
285 }
286}
287
/// Message handler that drains a [`Throttler`]'s buffer when its timer fires.
///
/// It is registered on the message bus under `endpoint` and finds its throttler in the actor
/// registry by `actor_id`.
struct ThrottlerProcess<T, F> {
    /// Identifier of the throttler this handler drains.
    actor_id: Ustr,
    /// Message-bus endpoint of this handler (`"{actor_id}_process"`).
    endpoint: Ustr,
    /// Carries the throttler's `T` so `handle` can look up the concrete `Throttler<T, F>`.
    phantom_t: PhantomData<T>,
    /// Carries the throttler's `F` for the same reason.
    phantom_f: PhantomData<F>,
}
298
299impl<T, F> ThrottlerProcess<T, F> {
300 pub fn new(actor_id: Ustr) -> Self {
301 let endpoint = Ustr::from(&format!("{}_process", actor_id));
302 Self {
303 actor_id,
304 endpoint,
305 phantom_t: PhantomData,
306 phantom_f: PhantomData,
307 }
308 }
309
310 pub fn get_timer_callback(&self) -> TimeEventCallback {
311 let endpoint = self.endpoint.into(); let process_callback = Rc::new(move |_event: TimeEvent| {
313 msgbus::send(endpoint, &());
314 });
315 TimeEventCallback::Rust(process_callback)
316 }
317}
318
319impl<T, F> MessageHandler for ThrottlerProcess<T, F>
320where
321 T: 'static,
322 F: Fn(T) + 'static,
323{
324 fn id(&self) -> Ustr {
325 self.endpoint
326 }
327
328 fn handle(&self, _message: &dyn Any) {
329 let throttler = get_actor_unchecked::<Throttler<T, F>>(&self.actor_id);
330 while let Some(msg) = throttler.buffer.pop_back() {
331 throttler.send_msg(msg);
332
333 if !throttler.buffer.is_empty() && throttler.delta_next() > 0 {
337 throttler.is_limiting = true;
338
339 let endpoint = self.endpoint.into(); let process_callback = Rc::new(move |_event: TimeEvent| {
343 msgbus::send(endpoint, &());
344 });
345 throttler.set_timer(Some(TimeEventCallback::Rust(process_callback)));
346 return;
347 }
348 }
349
350 throttler.is_limiting = false;
351 }
352
353 fn as_any(&self) -> &dyn Any {
354 self
355 }
356}
357
358pub fn throttler_resume<T, F>(actor_id: Ustr) -> TimeEventCallback
360where
361 T: 'static,
362 F: Fn(T) + 'static,
363{
364 let callback = Rc::new(move |_event: TimeEvent| {
365 let throttler = get_actor_unchecked::<Throttler<T, F>>(&actor_id);
366 throttler.is_limiting = false;
367 });
368
369 TimeEventCallback::Rust(callback)
370}
371
#[cfg(test)]
mod tests {
    use std::{
        cell::{RefCell, UnsafeCell},
        rc::Rc,
    };

    use nautilus_core::{UUID4, UnixNanos};
    use proptest::prelude::*;
    use rstest::{fixture, rstest};
    use ustr::Ustr;

    use super::{RateLimit, Throttler};
    use crate::clock::TestClock;
    type SharedThrottler = Rc<UnsafeCell<Throttler<u64, Box<dyn Fn(u64)>>>>;

    /// Test harness: a registered throttler actor plus the test clock that drives its timers.
    #[derive(Clone)]
    struct TestThrottler {
        throttler: SharedThrottler,
        clock: Rc<RefCell<TestClock>>,
        interval: u64,
    }

    #[allow(unsafe_code)]
    impl TestThrottler {
        /// Returns a mutable reference to the shared throttler.
        #[allow(clippy::mut_from_ref)] // Mutation through `UnsafeCell` is the intent here.
        pub fn get_throttler(&self) -> &mut Throttler<u64, Box<dyn Fn(u64)>> {
            // SAFETY: tests are single-threaded. NOTE(review): timer callbacks obtain another
            // `&mut` to the same throttler through the actor registry while this one is alive,
            // so this relies on the two never being used at the same time.
            unsafe { &mut *self.throttler.get() }
        }
    }

    /// Advances `clock` to the absolute time `to_time_ns` and runs every timer callback that
    /// fires.
    ///
    /// No clock borrow is held while a callback runs, because the throttler's callbacks
    /// borrow the clock themselves (e.g. to set the next timer).
    fn advance_and_fire(clock: &Rc<RefCell<TestClock>>, to_time_ns: UnixNanos) {
        let time_events = clock.borrow_mut().advance_time(to_time_ns, true);
        let handlers = clock.borrow_mut().match_handlers(time_events);
        for handler in handlers {
            handler.callback.call(handler.event);
        }
    }

    /// Throttler (5 messages per 10ns) that buffers messages over the limit.
    #[fixture]
    pub fn test_throttler_buffered() -> TestThrottler {
        let output_send: Box<dyn Fn(u64)> = Box::new(|msg: u64| {
            log::debug!("Sent: {msg}");
        });
        let clock = Rc::new(RefCell::new(TestClock::new()));
        let inner_clock = Rc::clone(&clock);
        let rate_limit = RateLimit::new(5, 10);
        let interval = rate_limit.interval_ns;
        let actor_id = Ustr::from(&UUID4::new().to_string());

        TestThrottler {
            throttler: Throttler::new(
                rate_limit.limit,
                rate_limit.interval_ns,
                clock,
                "buffer_timer".to_string(),
                output_send,
                None,
                actor_id,
            )
            .to_actor(),
            clock: inner_clock,
            interval,
        }
    }

    /// Throttler (5 messages per 10ns) that drops messages over the limit.
    #[fixture]
    pub fn test_throttler_unbuffered() -> TestThrottler {
        let output_send: Box<dyn Fn(u64)> = Box::new(|msg: u64| {
            log::debug!("Sent: {msg}");
        });
        let output_drop: Box<dyn Fn(u64)> = Box::new(|msg: u64| {
            log::debug!("Dropped: {msg}");
        });
        let clock = Rc::new(RefCell::new(TestClock::new()));
        let inner_clock = Rc::clone(&clock);
        let rate_limit = RateLimit::new(5, 10);
        let interval = rate_limit.interval_ns;
        let actor_id = Ustr::from(&UUID4::new().to_string());

        TestThrottler {
            throttler: Throttler::new(
                rate_limit.limit,
                rate_limit.interval_ns,
                clock,
                "dropper_timer".to_string(),
                output_send,
                Some(output_drop),
                actor_id,
            )
            .to_actor(),
            clock: inner_clock,
            interval,
        }
    }

    #[rstest]
    fn test_buffering_send_to_limit_becomes_throttled(test_throttler_buffered: TestThrottler) {
        let throttler = test_throttler_buffered.get_throttler();
        for _ in 0..6 {
            throttler.send(42);
        }
        assert_eq!(throttler.qsize(), 1);

        assert!(throttler.is_limiting);
        assert_eq!(throttler.recv_count, 6);
        assert_eq!(throttler.sent_count, 5);
        assert_eq!(throttler.clock.borrow().timer_names(), vec!["buffer_timer"]);
    }

    #[rstest]
    fn test_buffering_used_when_sent_to_limit_returns_one(test_throttler_buffered: TestThrottler) {
        let throttler = test_throttler_buffered.get_throttler();

        for _ in 0..5 {
            throttler.send(42);
        }

        assert_eq!(throttler.used(), 1.0);
        assert_eq!(throttler.recv_count, 5);
        assert_eq!(throttler.sent_count, 5);
    }

    #[rstest]
    fn test_buffering_used_when_half_interval_from_limit_returns_one(
        test_throttler_buffered: TestThrottler,
    ) {
        let throttler = test_throttler_buffered.get_throttler();

        for _ in 0..5 {
            throttler.send(42);
        }

        let half_interval = test_throttler_buffered.interval / 2;
        {
            let mut clock = test_throttler_buffered.clock.borrow_mut();
            clock.advance_time(half_interval.into(), true);
        }

        assert_eq!(throttler.used(), 1.0);
        assert_eq!(throttler.recv_count, 5);
        assert_eq!(throttler.sent_count, 5);
    }

    #[rstest]
    fn test_buffering_used_before_limit_when_halfway_returns_half(
        test_throttler_buffered: TestThrottler,
    ) {
        let throttler = test_throttler_buffered.get_throttler();

        for _ in 0..3 {
            throttler.send(42);
        }

        assert_eq!(throttler.used(), 0.6);
        assert_eq!(throttler.recv_count, 3);
        assert_eq!(throttler.sent_count, 3);
    }

    #[rstest]
    fn test_buffering_refresh_when_at_limit_sends_remaining_items(
        test_throttler_buffered: TestThrottler,
    ) {
        let throttler = test_throttler_buffered.get_throttler();

        for _ in 0..6 {
            throttler.send(42);
        }

        advance_and_fire(
            &test_throttler_buffered.clock,
            test_throttler_buffered.interval.into(),
        );

        assert_eq!(throttler.used(), 0.2);
        assert_eq!(throttler.recv_count, 6);
        assert_eq!(throttler.sent_count, 6);
        assert_eq!(throttler.qsize(), 0);
    }

    #[rstest]
    fn test_buffering_send_message_after_buffering_message(test_throttler_buffered: TestThrottler) {
        let throttler = test_throttler_buffered.get_throttler();

        for _ in 0..6 {
            throttler.send(43);
        }

        advance_and_fire(
            &test_throttler_buffered.clock,
            test_throttler_buffered.interval.into(),
        );

        for _ in 0..6 {
            throttler.send(42);
        }

        assert_eq!(throttler.used(), 1.0);
        assert_eq!(throttler.recv_count, 12);
        assert_eq!(throttler.sent_count, 10);
        assert_eq!(throttler.qsize(), 2);
    }

    #[rstest]
    fn test_buffering_send_message_after_halfway_after_buffering_message(
        test_throttler_buffered: TestThrottler,
    ) {
        let throttler = test_throttler_buffered.get_throttler();

        for _ in 0..6 {
            throttler.send(42);
        }

        advance_and_fire(
            &test_throttler_buffered.clock,
            test_throttler_buffered.interval.into(),
        );

        for _ in 0..3 {
            throttler.send(42);
        }

        assert_eq!(throttler.used(), 0.8);
        assert_eq!(throttler.recv_count, 9);
        assert_eq!(throttler.sent_count, 9);
        assert_eq!(throttler.qsize(), 0);
    }

    #[rstest]
    fn test_dropping_send_sends_message_to_handler(test_throttler_unbuffered: TestThrottler) {
        let throttler = test_throttler_unbuffered.get_throttler();
        throttler.send(42);

        assert!(!throttler.is_limiting);
        assert_eq!(throttler.recv_count, 1);
        assert_eq!(throttler.sent_count, 1);
    }

    #[rstest]
    fn test_dropping_send_to_limit_drops_message(test_throttler_unbuffered: TestThrottler) {
        let throttler = test_throttler_unbuffered.get_throttler();
        for _ in 0..6 {
            throttler.send(42);
        }
        assert_eq!(throttler.qsize(), 0);

        assert!(throttler.is_limiting);
        assert_eq!(throttler.used(), 1.0);
        assert_eq!(throttler.clock.borrow().timer_count(), 1);
        assert_eq!(
            throttler.clock.borrow().timer_names(),
            vec!["dropper_timer"]
        );
        assert_eq!(throttler.recv_count, 6);
        assert_eq!(throttler.sent_count, 5);
    }

    #[rstest]
    fn test_dropping_advance_time_when_at_limit_dropped_message(
        test_throttler_unbuffered: TestThrottler,
    ) {
        let throttler = test_throttler_unbuffered.get_throttler();
        for _ in 0..6 {
            throttler.send(42);
        }

        advance_and_fire(
            &test_throttler_unbuffered.clock,
            test_throttler_unbuffered.interval.into(),
        );

        assert_eq!(throttler.clock.borrow().timer_count(), 0);
        assert!(!throttler.is_limiting);
        assert_eq!(throttler.used(), 0.0);
        assert_eq!(throttler.recv_count, 6);
        assert_eq!(throttler.sent_count, 5);
    }

    #[rstest]
    fn test_dropping_send_message_after_dropping_message(test_throttler_unbuffered: TestThrottler) {
        let throttler = test_throttler_unbuffered.get_throttler();
        for _ in 0..6 {
            throttler.send(42);
        }

        advance_and_fire(
            &test_throttler_unbuffered.clock,
            test_throttler_unbuffered.interval.into(),
        );

        throttler.send(42);

        assert_eq!(throttler.used(), 0.2);
        assert_eq!(throttler.clock.borrow().timer_count(), 0);
        assert!(!throttler.is_limiting);
        assert_eq!(throttler.recv_count, 7);
        assert_eq!(throttler.sent_count, 6);
    }

    /// A single step of a generated throttler scenario.
    #[derive(Clone, Debug)]
    enum ThrottlerInput {
        SendMessage(u64),
        AdvanceClock(u8),
    }

    /// Generates sends (weight 2) and clock advances of 5..=9ns (weight 8).
    fn throttler_input_strategy() -> impl Strategy<Value = ThrottlerInput> {
        prop_oneof![
            2 => prop::bool::ANY.prop_map(|_| ThrottlerInput::SendMessage(42)),
            8 => prop::num::u8::ANY.prop_map(|v| ThrottlerInput::AdvanceClock(v % 5 + 5)),
        ]
    }

    /// Generates scenarios of 10 to 150 steps.
    fn throttler_test_strategy() -> impl Strategy<Value = Vec<ThrottlerInput>> {
        prop::collection::vec(throttler_input_strategy(), 10..=150)
    }

    /// Replays `inputs` against the throttler, checking its invariants after every step and
    /// that the buffer fully drains at the end.
    fn test_throttler_with_inputs(inputs: Vec<ThrottlerInput>, test_throttler: TestThrottler) {
        let test_clock = test_throttler.clock.clone();
        let interval = test_throttler.interval;
        let throttler = test_throttler.get_throttler();
        let mut recv_count = 0;

        for input in inputs {
            match input {
                ThrottlerInput::SendMessage(msg) => {
                    throttler.send(msg);
                    recv_count += 1;
                }
                ThrottlerInput::AdvanceClock(duration) => {
                    let current_time = test_clock.borrow().get_time_ns();
                    advance_and_fire(&test_clock, current_time + u64::from(duration));
                }
            }

            // The throttler must be limiting exactly when messages are still buffered and the
            // last `limit` sends all fall within the current interval.
            let buffered_messages = throttler.qsize() > 0;
            let now = throttler.clock.borrow().timestamp_ns().as_u64();
            let limit_filled_within_interval = throttler
                .timestamps
                .get(throttler.limit - 1)
                .is_some_and(|&ts| (now - ts.as_u64()) < interval);
            let expected_limiting = buffered_messages && limit_filled_within_interval;
            assert_eq!(throttler.is_limiting, expected_limiting);

            // Every received message has either been sent or is still buffered.
            assert_eq!(recv_count, throttler.sent_count + throttler.qsize());
        }

        // Jump far past the current time, so every pending timer fires and the buffer drains.
        // This is relative to now, because long scenarios can already be beyond any fixed time.
        let current_time = test_clock.borrow().get_time_ns();
        advance_and_fire(&test_clock, current_time + interval * 100);
        assert_eq!(throttler.qsize(), 0);
    }

    #[ignore = "Used for manually testing failing cases"]
    #[rstest]
    fn test_case() {
        let inputs = vec![
            ThrottlerInput::SendMessage(42),
            ThrottlerInput::AdvanceClock(5),
            ThrottlerInput::SendMessage(42),
            ThrottlerInput::SendMessage(42),
            ThrottlerInput::SendMessage(42),
            ThrottlerInput::SendMessage(42),
            ThrottlerInput::SendMessage(42),
            ThrottlerInput::AdvanceClock(5),
            ThrottlerInput::SendMessage(42),
            ThrottlerInput::SendMessage(42),
        ];

        let test_throttler = test_throttler_buffered();
        test_throttler_with_inputs(inputs, test_throttler);
    }

    #[rstest]
    fn prop_test() {
        let test_throttler = test_throttler_buffered();

        proptest!(move |(inputs in throttler_test_strategy())| {
            test_throttler_with_inputs(inputs, test_throttler.clone());
            // Reset the shared throttler and clock so each generated case starts clean.
            let throttler = test_throttler.get_throttler();
            throttler.reset();
            throttler.clock.borrow_mut().reset();
        });
    }
}