1use std::time::Duration;
25
26use anyhow;
27use nautilus_core::correctness::{check_in_range_inclusive_f64, check_predicate_true};
28use rand::Rng;
29
/// State for an exponential backoff sequence with optional random jitter.
///
/// The stored delay starts at `delay_initial`, is multiplied by `factor` each time a delay is
/// handed out, and is bounded by `delay_max`.
#[derive(Debug, Clone)]
pub struct ExponentialBackoff {
    /// Delay the sequence starts from, and returns to when the backoff is reset.
    delay_initial: Duration,
    /// Upper bound applied to `delay_current` whenever it is advanced.
    delay_max: Duration,
    /// Base delay (before jitter) that the next call will hand out.
    delay_current: Duration,
    /// Multiplier applied to `delay_current` after each delay is handed out.
    factor: f64,
    /// Inclusive upper bound, in milliseconds, of the random jitter added to each returned delay.
    jitter_ms: u64,
    /// Whether the next call may return a zero delay; cleared once that has happened.
    immediate_reconnect: bool,
    /// Value `immediate_reconnect` was constructed with, restored when the backoff is reset.
    immediate_reconnect_original: bool,
}
47
48impl ExponentialBackoff {
56 pub fn new(
65 delay_initial: Duration,
66 delay_max: Duration,
67 factor: f64,
68 jitter_ms: u64,
69 immediate_first: bool,
70 ) -> anyhow::Result<Self> {
71 check_predicate_true(!delay_initial.is_zero(), "delay_initial must be non-zero")?;
72 check_predicate_true(
73 delay_max >= delay_initial,
74 "delay_max must be >= delay_initial",
75 )?;
76 check_in_range_inclusive_f64(factor, 1.0, 100.0, "factor")?;
77
78 Ok(Self {
79 delay_initial,
80 delay_max,
81 delay_current: delay_initial,
82 factor,
83 jitter_ms,
84 immediate_reconnect: immediate_first,
85 immediate_reconnect_original: immediate_first,
86 })
87 }
88
89 pub fn next_duration(&mut self) -> Duration {
95 if self.immediate_reconnect && self.delay_current == self.delay_initial {
96 self.immediate_reconnect = false;
97 return Duration::ZERO;
98 }
99
100 let jitter = rand::rng().random_range(0..=self.jitter_ms);
102 let delay_with_jitter = self.delay_current + Duration::from_millis(jitter);
103
104 let current_nanos = self.delay_current.as_nanos();
106 let max_nanos = self.delay_max.as_nanos() as u64;
107
108 let next_nanos = if current_nanos > u128::from(u64::MAX) {
110 max_nanos
112 } else {
113 let current_u64 = current_nanos as u64;
114 let next_f64 = current_u64 as f64 * self.factor;
115
116 if next_f64 > u64::MAX as f64 {
118 u64::MAX
119 } else {
120 next_f64 as u64
121 }
122 };
123
124 self.delay_current = Duration::from_nanos(std::cmp::min(next_nanos, max_nanos));
125
126 delay_with_jitter
127 }
128
129 pub const fn reset(&mut self) {
131 self.delay_current = self.delay_initial;
132 self.immediate_reconnect = self.immediate_reconnect_original;
133 }
134
135 #[must_use]
139 pub const fn current_delay(&self) -> Duration {
140 self.delay_current
141 }
142}
143
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use rstest::rstest;

    use super::*;

    /// Shorthand for a [`Duration`] of `millis` milliseconds.
    const fn ms(millis: u64) -> Duration {
        Duration::from_millis(millis)
    }

    /// Builds a backoff from millisecond values, panicking if the configuration is rejected.
    fn backoff_ms(
        initial: u64,
        max: u64,
        factor: f64,
        jitter: u64,
        immediate_first: bool,
    ) -> ExponentialBackoff {
        ExponentialBackoff::new(ms(initial), ms(max), factor, jitter, immediate_first).unwrap()
    }

    /// Asserts that construction failed and returns the rendered error message.
    fn rejection_message(result: anyhow::Result<ExponentialBackoff>) -> String {
        assert!(result.is_err());
        result.unwrap_err().to_string()
    }

    #[rstest]
    fn test_no_jitter_exponential_growth() {
        let mut backoff = backoff_ms(100, 1600, 2.0, 0, false);

        // Doubles on every call until it is pinned at the 1600 ms cap.
        for expected in [100, 200, 400, 800, 1600, 1600] {
            assert_eq!(backoff.next_duration(), ms(expected));
        }
    }

    #[rstest]
    fn test_reset() {
        let mut backoff = backoff_ms(100, 1600, 2.0, 0, false);

        // Advance once, then reset: the next delay must be the initial one again.
        let _ = backoff.next_duration();
        backoff.reset();

        assert_eq!(backoff.next_duration(), ms(100));
    }

    #[rstest]
    fn test_jitter_within_bounds() {
        let jitter = 50;

        // Jitter is random, so sample several fresh instances.
        for _ in 0..10 {
            let mut backoff = backoff_ms(100, 1000, 2.0, jitter, false);
            let min_expected = backoff.delay_current;
            let max_expected = min_expected + ms(jitter);

            let delay = backoff.next_duration();

            assert!(
                delay >= min_expected,
                "Delay {delay:?} is less than expected minimum {min_expected:?}"
            );
            assert!(
                delay <= max_expected,
                "Delay {delay:?} exceeds expected maximum {max_expected:?}"
            );
        }
    }

    #[rstest]
    fn test_factor_less_than_two() {
        let mut backoff = backoff_ms(100, 200, 1.5, 0, false);

        // 100 -> 150 -> 225 capped to 200 -> stays at 200.
        for expected in [100, 150, 200, 200] {
            assert_eq!(backoff.next_duration(), ms(expected));
        }
    }

    #[rstest]
    fn test_max_delay_is_respected() {
        let mut backoff = backoff_ms(500, 1000, 3.0, 0, false);

        // 500 -> 1500 capped to 1000 -> stays at 1000.
        for expected in [500, 1000, 1000] {
            assert_eq!(backoff.next_duration(), ms(expected));
        }
    }

    #[rstest]
    fn test_current_delay_getter() {
        let initial = ms(100);
        let mut backoff = backoff_ms(100, 1600, 2.0, 0, false);

        assert_eq!(backoff.current_delay(), initial);

        // Each call advances the stored delay that the getter exposes.
        for expected in [200, 400] {
            let _ = backoff.next_duration();
            assert_eq!(backoff.current_delay(), ms(expected));
        }

        backoff.reset();
        assert_eq!(backoff.current_delay(), initial);
    }

    #[rstest]
    fn test_validation_zero_initial_delay() {
        let message =
            rejection_message(ExponentialBackoff::new(Duration::ZERO, ms(1000), 2.0, 0, false));

        assert!(message.contains("delay_initial must be non-zero"));
    }

    #[rstest]
    fn test_validation_max_less_than_initial() {
        let message =
            rejection_message(ExponentialBackoff::new(ms(1000), ms(500), 2.0, 0, false));

        assert!(message.contains("delay_max must be >= delay_initial"));
    }

    #[rstest]
    fn test_validation_factor_too_small() {
        let message =
            rejection_message(ExponentialBackoff::new(ms(100), ms(1000), 0.5, 0, false));

        assert!(message.contains("factor"));
    }

    #[rstest]
    fn test_validation_factor_too_large() {
        let message =
            rejection_message(ExponentialBackoff::new(ms(100), ms(1000), 150.0, 0, false));

        assert!(message.contains("factor"));
    }

    #[rstest]
    fn test_immediate_first() {
        let initial = ms(100);
        let mut backoff = backoff_ms(100, 1600, 2.0, 0, true);

        assert_eq!(
            backoff.next_duration(),
            Duration::ZERO,
            "Expected immediate reconnect (zero delay) on first call"
        );
        assert_eq!(
            backoff.next_duration(),
            initial,
            "Expected the delay to be the initial delay after immediate reconnect"
        );
        assert_eq!(
            backoff.next_duration(),
            initial * 2,
            "Expected exponential growth from the initial delay"
        );
    }

    #[rstest]
    fn test_reset_restores_immediate_first() {
        let initial = ms(100);
        let mut backoff = backoff_ms(100, 1600, 2.0, 0, true);

        assert_eq!(backoff.next_duration(), Duration::ZERO);
        assert_eq!(backoff.next_duration(), initial);

        backoff.reset();

        assert_eq!(
            backoff.next_duration(),
            Duration::ZERO,
            "Reset should restore immediate_first behavior"
        );
    }
}