1use std::{
23 any::Any,
24 cell::{RefCell, UnsafeCell},
25 collections::VecDeque,
26 fmt::Debug,
27 marker::PhantomData,
28 rc::Rc,
29};
30
31use nautilus_core::{UnixNanos, correctness::FAILED};
32use ustr::Ustr;
33
34use crate::{
35 actor::{
36 Actor,
37 registry::{get_actor_unchecked, register_actor},
38 },
39 clock::Clock,
40 msgbus::{
41 self,
42 handler::{MessageHandler, ShareableMessageHandler},
43 },
44 timer::{TimeEvent, TimeEventCallback},
45};
46
/// A message-rate limit: at most `limit` messages per `interval_ns` nanoseconds.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RateLimit {
    /// Maximum number of messages allowed within one interval.
    pub limit: usize,
    /// Length of the interval in nanoseconds.
    pub interval_ns: u64,
}

impl RateLimit {
    /// Builds a [`RateLimit`] from its two components.
    #[must_use]
    pub const fn new(limit: usize, interval_ns: u64) -> Self {
        Self { limit, interval_ns }
    }
}
61
62pub struct Throttler<T, F> {
67 pub recv_count: usize,
69 pub sent_count: usize,
71 pub is_limiting: bool,
73 pub limit: usize,
75 pub buffer: VecDeque<T>,
77 pub timestamps: VecDeque<UnixNanos>,
79 pub clock: Rc<RefCell<dyn Clock>>,
81 pub actor_id: Ustr,
83 interval: u64,
85 timer_name: String,
87 output_send: F,
89 output_drop: Option<F>,
91}
92
93impl<T, F> Actor for Throttler<T, F>
94where
95 T: 'static + Debug,
96 F: Fn(T) + 'static,
97{
98 fn id(&self) -> Ustr {
99 self.actor_id
100 }
101
102 fn handle(&mut self, _msg: &dyn Any) {}
103
104 fn as_any(&self) -> &dyn Any {
105 self
106 }
107}
108
109impl<T, F> Debug for Throttler<T, F>
110where
111 T: Debug,
112{
113 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114 f.debug_struct(stringify!(InnerThrottler))
115 .field("recv_count", &self.recv_count)
116 .field("sent_count", &self.sent_count)
117 .field("is_limiting", &self.is_limiting)
118 .field("limit", &self.limit)
119 .field("buffer", &self.buffer)
120 .field("timestamps", &self.timestamps)
121 .field("interval", &self.interval)
122 .field("timer_name", &self.timer_name)
123 .finish()
124 }
125}
126
127impl<T, F> Throttler<T, F>
128where
129 T: Debug,
130{
131 #[inline]
132 pub fn new(
133 limit: usize,
134 interval: u64,
135 clock: Rc<RefCell<dyn Clock>>,
136 timer_name: String,
137 output_send: F,
138 output_drop: Option<F>,
139 actor_id: Ustr,
140 ) -> Self {
141 Self {
142 recv_count: 0,
143 sent_count: 0,
144 is_limiting: false,
145 limit,
146 buffer: VecDeque::new(),
147 timestamps: VecDeque::with_capacity(limit),
148 clock,
149 interval,
150 timer_name,
151 output_send,
152 output_drop,
153 actor_id,
154 }
155 }
156
157 #[inline]
167 pub fn set_timer(&mut self, callback: Option<TimeEventCallback>) {
168 let delta = self.delta_next();
169 let mut clock = self.clock.borrow_mut();
170 if clock.timer_names().contains(&self.timer_name.as_str()) {
171 clock.cancel_timer(&self.timer_name);
172 }
173 let alert_ts = clock.timestamp_ns() + delta;
174
175 clock
176 .set_time_alert_ns(&self.timer_name, alert_ts, callback, None)
177 .expect(FAILED);
178 }
179
180 #[inline]
182 pub fn delta_next(&mut self) -> u64 {
183 match self.timestamps.get(self.limit - 1) {
184 Some(ts) => {
185 let diff = self.clock.borrow().timestamp_ns().as_u64() - ts.as_u64();
186 self.interval.saturating_sub(diff)
187 }
188 None => 0,
189 }
190 }
191
192 #[inline]
194 pub fn reset(&mut self) {
195 self.buffer.clear();
196 self.recv_count = 0;
197 self.sent_count = 0;
198 self.is_limiting = false;
199 self.timestamps.clear();
200 }
201
202 #[inline]
204 pub fn used(&self) -> f64 {
205 if self.timestamps.is_empty() {
206 return 0.0;
207 }
208
209 let now = self.clock.borrow().timestamp_ns().as_i64();
210 let interval_start = now - self.interval as i64;
211
212 let messages_in_current_interval = self
213 .timestamps
214 .iter()
215 .take_while(|&&ts| ts.as_i64() > interval_start)
216 .count();
217
218 (messages_in_current_interval as f64) / (self.limit as f64)
219 }
220
221 #[inline]
223 pub fn qsize(&self) -> usize {
224 self.buffer.len()
225 }
226}
227
228impl<T, F> Throttler<T, F>
229where
230 T: 'static + Debug,
231 F: Fn(T) + 'static,
232{
233 pub fn to_actor(self) -> Rc<UnsafeCell<Self>> {
234 let process_handler = ThrottlerProcess::<T, F>::new(self.actor_id);
236 msgbus::register(
237 process_handler.id().as_str().into(),
238 ShareableMessageHandler::from(Rc::new(process_handler) as Rc<dyn MessageHandler>),
239 );
240
241 register_actor(self)
243 }
244
245 #[inline]
246 pub fn send_msg(&mut self, msg: T) {
247 let now = self.clock.borrow().timestamp_ns();
248
249 if self.timestamps.len() >= self.limit {
250 self.timestamps.pop_back();
251 }
252 self.timestamps.push_front(now);
253
254 self.sent_count += 1;
255 (self.output_send)(msg);
256 }
257
258 #[inline]
259 pub fn limit_msg(&mut self, msg: T) {
260 let callback = if self.output_drop.is_none() {
261 self.buffer.push_front(msg);
262 log::debug!("Buffering {}", self.buffer.len());
263 Some(ThrottlerProcess::<T, F>::new(self.actor_id).get_timer_callback())
264 } else {
265 log::debug!("Dropping");
266 if let Some(drop) = &self.output_drop {
267 drop(msg);
268 }
269 Some(throttler_resume::<T, F>(self.actor_id))
270 };
271 if !self.is_limiting {
272 log::debug!("Limiting");
273 self.set_timer(callback);
274 self.is_limiting = true;
275 }
276 }
277
278 #[inline]
279 pub fn send(&mut self, msg: T)
280 where
281 T: 'static,
282 F: Fn(T) + 'static,
283 {
284 self.recv_count += 1;
285
286 if self.is_limiting || self.delta_next() > 0 {
287 self.limit_msg(msg);
288 } else {
289 self.send_msg(msg);
290 }
291 }
292}
293
294struct ThrottlerProcess<T, F> {
299 actor_id: Ustr,
300 endpoint: Ustr,
301 phantom_t: PhantomData<T>,
302 phantom_f: PhantomData<F>,
303}
304
305impl<T, F> ThrottlerProcess<T, F>
306where
307 T: Debug,
308{
309 pub fn new(actor_id: Ustr) -> Self {
310 let endpoint = Ustr::from(&format!("{}_process", actor_id));
311 Self {
312 actor_id,
313 endpoint,
314 phantom_t: PhantomData,
315 phantom_f: PhantomData,
316 }
317 }
318
319 pub fn get_timer_callback(&self) -> TimeEventCallback {
320 let endpoint = self.endpoint.into(); let process_callback = Rc::new(move |_event: TimeEvent| {
322 msgbus::send_any(endpoint, &());
323 });
324 TimeEventCallback::Rust(process_callback)
325 }
326}
327
328impl<T, F> MessageHandler for ThrottlerProcess<T, F>
329where
330 T: 'static + Debug,
331 F: Fn(T) + 'static,
332{
333 fn id(&self) -> Ustr {
334 self.endpoint
335 }
336
337 fn handle(&self, _message: &dyn Any) {
338 let throttler = get_actor_unchecked::<Throttler<T, F>>(&self.actor_id);
339 while let Some(msg) = throttler.buffer.pop_back() {
340 throttler.send_msg(msg);
341
342 if !throttler.buffer.is_empty() && throttler.delta_next() > 0 {
346 throttler.is_limiting = true;
347
348 let endpoint = self.endpoint.into(); let process_callback = Rc::new(move |_event: TimeEvent| {
352 msgbus::send_any(endpoint, &());
353 });
354 throttler.set_timer(Some(TimeEventCallback::Rust(process_callback)));
355 return;
356 }
357 }
358
359 throttler.is_limiting = false;
360 }
361
362 fn as_any(&self) -> &dyn Any {
363 self
364 }
365}
366
367pub fn throttler_resume<T, F>(actor_id: Ustr) -> TimeEventCallback
369where
370 T: 'static + Debug,
371 F: Fn(T) + 'static,
372{
373 let callback = Rc::new(move |_event: TimeEvent| {
374 let throttler = get_actor_unchecked::<Throttler<T, F>>(&actor_id);
375 throttler.is_limiting = false;
376 });
377
378 TimeEventCallback::Rust(callback)
379}
380
#[cfg(test)]
mod tests {
    use std::{
        cell::{RefCell, UnsafeCell},
        rc::Rc,
    };

    use nautilus_core::{UUID4, UnixNanos};
    use proptest::prelude::*;
    use rstest::{fixture, rstest};
    use ustr::Ustr;

    use super::{RateLimit, Throttler};
    use crate::clock::TestClock;

    type SharedThrottler = Rc<UnsafeCell<Throttler<u64, Box<dyn Fn(u64)>>>>;

    /// A registered throttler together with the concrete `TestClock` driving it,
    /// so tests can both send messages and advance time.
    #[derive(Clone)]
    struct TestThrottler {
        throttler: SharedThrottler,
        clock: Rc<RefCell<TestClock>>,
        interval: u64,
    }

    #[allow(unsafe_code)]
    impl TestThrottler {
        /// Returns a mutable reference to the shared throttler.
        #[allow(clippy::mut_from_ref)] // Deliberate: tests need `&mut` through a shared handle.
        pub fn get_throttler(&self) -> &mut Throttler<u64, Box<dyn Fn(u64)>> {
            // SAFETY: NOTE(review): not strictly sound. Timer callbacks obtain their own
            // `&mut` to the same throttler through the actor registry while this reference
            // is alive. This is tolerated only because tests are single-threaded and never
            // use both references at the same moment.
            unsafe { &mut *self.throttler.get() }
        }
    }

    /// Builds a 5-messages-per-10ns throttler on a fresh `TestClock` and
    /// registers it (with its process handler) under a random actor id.
    fn build_test_throttler(
        timer_name: &str,
        output_drop: Option<Box<dyn Fn(u64)>>,
    ) -> TestThrottler {
        let output_send: Box<dyn Fn(u64)> = Box::new(|msg: u64| {
            log::debug!("Sent: {msg}");
        });
        let clock = Rc::new(RefCell::new(TestClock::new()));
        let inner_clock = Rc::clone(&clock);
        let rate_limit = RateLimit::new(5, 10);
        let interval = rate_limit.interval_ns;
        let actor_id = Ustr::from(&UUID4::new().to_string());

        TestThrottler {
            throttler: Throttler::new(
                rate_limit.limit,
                rate_limit.interval_ns,
                clock,
                timer_name.to_string(),
                output_send,
                output_drop,
                actor_id,
            )
            .to_actor(),
            clock: inner_clock,
            interval,
        }
    }

    #[fixture]
    pub fn test_throttler_buffered() -> TestThrottler {
        build_test_throttler("buffer_timer", None)
    }

    #[fixture]
    pub fn test_throttler_unbuffered() -> TestThrottler {
        let output_drop: Box<dyn Fn(u64)> = Box::new(|msg: u64| {
            log::debug!("Dropped: {msg}");
        });
        build_test_throttler("dropper_timer", Some(output_drop))
    }

    /// Advances `clock` to `to_time_ns` and fires every resulting timer callback.
    ///
    /// The clock borrow is released before any callback runs, because the
    /// callbacks re-enter the throttler, which borrows the same clock.
    fn advance_clock_to(clock: &Rc<RefCell<TestClock>>, to_time_ns: UnixNanos) {
        let mut clock_ref = clock.borrow_mut();
        let time_events = clock_ref.advance_time(to_time_ns, true);
        let handlers = clock_ref.match_handlers(time_events);
        drop(clock_ref);

        for handler in handlers {
            handler.callback.call(handler.event);
        }
    }

    #[rstest]
    fn test_buffering_send_to_limit_becomes_throttled(test_throttler_buffered: TestThrottler) {
        let throttler = test_throttler_buffered.get_throttler();
        for _ in 0..6 {
            throttler.send(42);
        }
        assert_eq!(throttler.qsize(), 1);

        assert!(throttler.is_limiting);
        assert_eq!(throttler.recv_count, 6);
        assert_eq!(throttler.sent_count, 5);
        assert_eq!(throttler.clock.borrow().timer_names(), vec!["buffer_timer"]);
    }

    #[rstest]
    fn test_buffering_used_when_sent_to_limit_returns_one(test_throttler_buffered: TestThrottler) {
        let throttler = test_throttler_buffered.get_throttler();

        for _ in 0..5 {
            throttler.send(42);
        }

        assert_eq!(throttler.used(), 1.0);
        assert_eq!(throttler.recv_count, 5);
        assert_eq!(throttler.sent_count, 5);
    }

    #[rstest]
    fn test_buffering_used_when_half_interval_from_limit_returns_one(
        test_throttler_buffered: TestThrottler,
    ) {
        let throttler = test_throttler_buffered.get_throttler();

        for _ in 0..5 {
            throttler.send(42);
        }

        let half_interval = test_throttler_buffered.interval / 2;
        advance_clock_to(&test_throttler_buffered.clock, half_interval.into());

        assert_eq!(throttler.used(), 1.0);
        assert_eq!(throttler.recv_count, 5);
        assert_eq!(throttler.sent_count, 5);
    }

    #[rstest]
    fn test_buffering_used_before_limit_when_halfway_returns_half(
        test_throttler_buffered: TestThrottler,
    ) {
        let throttler = test_throttler_buffered.get_throttler();

        // 3 of 5 allowed messages within the interval.
        for _ in 0..3 {
            throttler.send(42);
        }

        assert_eq!(throttler.used(), 0.6);
        assert_eq!(throttler.recv_count, 3);
        assert_eq!(throttler.sent_count, 3);
    }

    #[rstest]
    fn test_buffering_refresh_when_at_limit_sends_remaining_items(
        test_throttler_buffered: TestThrottler,
    ) {
        let throttler = test_throttler_buffered.get_throttler();

        for _ in 0..6 {
            throttler.send(42);
        }

        advance_clock_to(
            &test_throttler_buffered.clock,
            test_throttler_buffered.interval.into(),
        );

        // Only the flushed message falls inside the new window.
        assert_eq!(throttler.used(), 0.2);
        assert_eq!(throttler.recv_count, 6);
        assert_eq!(throttler.sent_count, 6);
        assert_eq!(throttler.qsize(), 0);
    }

    #[rstest]
    fn test_buffering_send_message_after_buffering_message(test_throttler_buffered: TestThrottler) {
        let throttler = test_throttler_buffered.get_throttler();

        for _ in 0..6 {
            throttler.send(43);
        }

        advance_clock_to(
            &test_throttler_buffered.clock,
            test_throttler_buffered.interval.into(),
        );

        for _ in 0..6 {
            throttler.send(42);
        }

        // The flushed message plus four new ones fill the window; two are buffered.
        assert_eq!(throttler.used(), 1.0);
        assert_eq!(throttler.recv_count, 12);
        assert_eq!(throttler.sent_count, 10);
        assert_eq!(throttler.qsize(), 2);
    }

    #[rstest]
    fn test_buffering_send_message_after_halfway_after_buffering_message(
        test_throttler_buffered: TestThrottler,
    ) {
        let throttler = test_throttler_buffered.get_throttler();

        for _ in 0..6 {
            throttler.send(42);
        }

        advance_clock_to(
            &test_throttler_buffered.clock,
            test_throttler_buffered.interval.into(),
        );

        for _ in 0..3 {
            throttler.send(42);
        }

        // The flushed message plus three new ones: 4 of 5.
        assert_eq!(throttler.used(), 0.8);
        assert_eq!(throttler.recv_count, 9);
        assert_eq!(throttler.sent_count, 9);
        assert_eq!(throttler.qsize(), 0);
    }

    #[rstest]
    fn test_dropping_send_sends_message_to_handler(test_throttler_unbuffered: TestThrottler) {
        let throttler = test_throttler_unbuffered.get_throttler();
        throttler.send(42);

        assert!(!throttler.is_limiting);
        assert_eq!(throttler.recv_count, 1);
        assert_eq!(throttler.sent_count, 1);
    }

    #[rstest]
    fn test_dropping_send_to_limit_drops_message(test_throttler_unbuffered: TestThrottler) {
        let throttler = test_throttler_unbuffered.get_throttler();
        for _ in 0..6 {
            throttler.send(42);
        }
        assert_eq!(throttler.qsize(), 0);

        assert!(throttler.is_limiting);
        assert_eq!(throttler.used(), 1.0);
        assert_eq!(throttler.clock.borrow().timer_count(), 1);
        assert_eq!(
            throttler.clock.borrow().timer_names(),
            vec!["dropper_timer"]
        );
        assert_eq!(throttler.recv_count, 6);
        assert_eq!(throttler.sent_count, 5);
    }

    #[rstest]
    fn test_dropping_advance_time_when_at_limit_dropped_message(
        test_throttler_unbuffered: TestThrottler,
    ) {
        let throttler = test_throttler_unbuffered.get_throttler();
        for _ in 0..6 {
            throttler.send(42);
        }

        advance_clock_to(
            &test_throttler_unbuffered.clock,
            test_throttler_unbuffered.interval.into(),
        );

        assert_eq!(throttler.clock.borrow().timer_count(), 0);
        assert!(!throttler.is_limiting);
        assert_eq!(throttler.used(), 0.0);
        assert_eq!(throttler.recv_count, 6);
        assert_eq!(throttler.sent_count, 5);
    }

    #[rstest]
    fn test_dropping_send_message_after_dropping_message(test_throttler_unbuffered: TestThrottler) {
        let throttler = test_throttler_unbuffered.get_throttler();
        for _ in 0..6 {
            throttler.send(42);
        }

        advance_clock_to(
            &test_throttler_unbuffered.clock,
            test_throttler_unbuffered.interval.into(),
        );

        throttler.send(42);

        assert_eq!(throttler.used(), 0.2);
        assert_eq!(throttler.clock.borrow().timer_count(), 0);
        assert!(!throttler.is_limiting);
        assert_eq!(throttler.recv_count, 7);
        assert_eq!(throttler.sent_count, 6);
    }

    #[derive(Clone, Debug)]
    enum ThrottlerInput {
        SendMessage(u64),
        AdvanceClock(u8),
    }

    /// Sends are weighted 2:8 against clock advances of 5..=9 ns.
    fn throttler_input_strategy() -> impl Strategy<Value = ThrottlerInput> {
        prop_oneof![
            2 => Just(ThrottlerInput::SendMessage(42)),
            8 => prop::num::u8::ANY.prop_map(|v| ThrottlerInput::AdvanceClock(v % 5 + 5)),
        ]
    }

    /// Input sequences of 10 to 150 steps.
    fn throttler_test_strategy() -> impl Strategy<Value = Vec<ThrottlerInput>> {
        prop::collection::vec(throttler_input_strategy(), 10..=150)
    }

    /// Replays `inputs` against a buffered throttler, checking its invariants after
    /// every step, then checks that the buffer eventually drains completely.
    fn test_throttler_with_inputs(inputs: Vec<ThrottlerInput>, test_throttler: TestThrottler) {
        let test_clock = Rc::clone(&test_throttler.clock);
        let interval = test_throttler.interval;
        let throttler = test_throttler.get_throttler();
        let mut sent_count = 0;

        for input in inputs {
            match input {
                ThrottlerInput::SendMessage(msg) => {
                    throttler.send(msg);
                    sent_count += 1;
                }
                ThrottlerInput::AdvanceClock(duration) => {
                    let current_time = test_clock.borrow().get_time_ns();
                    advance_clock_to(&test_clock, current_time + u64::from(duration));
                }
            }

            // The throttler should be limiting exactly when messages are still
            // buffered and the last `limit` sends all happened within `interval`.
            let buffered_messages = throttler.qsize() > 0;
            let now = throttler.clock.borrow().timestamp_ns().as_u64();
            let limit_filled_within_interval = throttler
                .timestamps
                .get(throttler.limit - 1)
                .is_some_and(|&ts| (now - ts.as_u64()) < interval);
            let expected_limiting = buffered_messages && limit_filled_within_interval;
            assert_eq!(throttler.is_limiting, expected_limiting);

            // Every received message is either sent or still buffered.
            assert_eq!(sent_count, throttler.sent_count + throttler.qsize());
        }

        // Drain what is left. Advance relative to the current time, because the
        // inputs may already have moved the clock past any fixed absolute target.
        // Repeat (bounded), because each timer firing releases at most `limit` messages.
        for _ in 0..100 {
            if throttler.qsize() == 0 {
                break;
            }
            let current_time = test_clock.borrow().get_time_ns();
            advance_clock_to(&test_clock, current_time + interval * 100);
        }
        assert_eq!(throttler.qsize(), 0);
    }

    #[ignore = "Used for manually testing failing cases"]
    #[rstest]
    fn test_case() {
        let inputs = [
            ThrottlerInput::SendMessage(42),
            ThrottlerInput::AdvanceClock(5),
            ThrottlerInput::SendMessage(42),
            ThrottlerInput::SendMessage(42),
            ThrottlerInput::SendMessage(42),
            ThrottlerInput::SendMessage(42),
            ThrottlerInput::SendMessage(42),
            ThrottlerInput::AdvanceClock(5),
            ThrottlerInput::SendMessage(42),
            ThrottlerInput::SendMessage(42),
        ]
        .to_vec();

        let test_throttler = test_throttler_buffered();
        test_throttler_with_inputs(inputs, test_throttler);
    }

    #[rstest]
    fn prop_test() {
        let test_throttler = test_throttler_buffered();

        proptest!(move |(inputs in throttler_test_strategy())| {
            test_throttler_with_inputs(inputs, test_throttler.clone());
            // All cases share one throttler and clock, so reset both between cases.
            let throttler = test_throttler.get_throttler();
            throttler.reset();
            throttler.clock.borrow_mut().reset();
        });
    }
}