nautilus_indicators/average/
lr.rs1use std::fmt::{Debug, Display};
17
18use arraydeque::{ArrayDeque, Wrapping};
19use nautilus_model::data::Bar;
20
21use crate::indicator::Indicator;
22
/// Largest accepted `period`; also the fixed capacity of the `inputs` `ArrayDeque`
/// (2^14 = 16_384 samples).
const MAX_PERIOD: usize = 1 << 14;
24
/// Rolling least-squares linear regression over the last `period` inputs.
///
/// The window's inputs are placed at x = 1..=period (oldest first) and, once the window is
/// full, refitted on every update to `y = slope * x + intercept`. All outputs start at `0.0`
/// and are only written once `period` inputs have been received.
#[repr(C)]
#[derive(Debug)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "posei_trader.core.nautilus_pyo3.indicators")
)]
pub struct LinearRegression {
    /// Number of inputs in the regression window (`1..=MAX_PERIOD`).
    pub period: usize,
    /// Slope of the fitted line (change in y per input step).
    pub slope: f64,
    /// Intercept of the fitted line at x = 0.
    pub intercept: f64,
    /// Slope expressed as an angle in degrees, i.e. `atan(slope)` converted to degrees.
    pub degree: f64,
    /// `100 * (fitted - actual) / actual` for the most recent input; NaN when that input is `0.0`.
    pub cfo: f64,
    /// Coefficient of determination `1 - SSE / SST` of the fit; NaN when the window has
    /// (near-)zero variance.
    pub r2: f64,
    /// Fitted value of the regression line at the most recent input (x = period).
    pub value: f64,
    /// Whether `period` inputs have been received since construction or the last reset.
    pub initialized: bool,
    /// Whether at least one input has been received since construction or the last reset.
    has_inputs: bool,
    /// Rolling window of the most recent inputs, oldest at the front.
    inputs: ArrayDeque<f64, MAX_PERIOD, Wrapping>,
    /// Cached sum of x for x = 1..=period.
    x_sum: f64,
    /// Cached sum of x squared for x = 1..=period.
    x_mul_sum: f64,
    /// Cached least-squares denominator `period * sum(x^2) - sum(x)^2`; zero when `period == 1`.
    divisor: f64,
}
46
47impl Display for LinearRegression {
48 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49 write!(f, "{}({})", self.name(), self.period)
50 }
51}
52
53impl Indicator for LinearRegression {
54 fn name(&self) -> String {
55 stringify!(LinearRegression).into()
56 }
57
58 fn has_inputs(&self) -> bool {
59 self.has_inputs
60 }
61
62 fn initialized(&self) -> bool {
63 self.initialized
64 }
65
66 fn handle_bar(&mut self, bar: &Bar) {
67 self.update_raw(bar.close.into());
68 }
69
70 fn reset(&mut self) {
71 self.slope = 0.0;
72 self.intercept = 0.0;
73 self.degree = 0.0;
74 self.cfo = 0.0;
75 self.r2 = 0.0;
76 self.value = 0.0;
77 self.inputs.clear();
78 self.has_inputs = false;
79 self.initialized = false;
80 }
81}
82
83impl LinearRegression {
84 #[must_use]
92 pub fn new(period: usize) -> Self {
93 assert!(
94 period > 0,
95 "LinearRegression: period must be > 0 (received {period})"
96 );
97 assert!(
98 period <= MAX_PERIOD,
99 "LinearRegression: period {period} exceeds MAX_PERIOD ({MAX_PERIOD})"
100 );
101
102 let n = period as f64;
103 let x_sum = 0.5 * n * (n + 1.0);
104 let x_mul_sum = x_sum * 2.0f64.mul_add(n, 1.0) / 3.0;
105 let divisor = n.mul_add(x_mul_sum, -(x_sum * x_sum));
106
107 Self {
108 period,
109 slope: 0.0,
110 intercept: 0.0,
111 degree: 0.0,
112 cfo: 0.0,
113 r2: 0.0,
114 value: 0.0,
115 initialized: false,
116 has_inputs: false,
117 inputs: ArrayDeque::new(),
118 x_sum,
119 x_mul_sum,
120 divisor,
121 }
122 }
123
124 pub fn update_raw(&mut self, close: f64) {
131 if self.inputs.len() == self.period {
132 let _ = self.inputs.pop_front();
133 }
134 let _ = self.inputs.push_back(close);
135
136 self.has_inputs = true;
137 if self.inputs.len() < self.period {
138 return;
139 }
140 self.initialized = true;
141
142 let n = self.period as f64;
143 let x_sum = self.x_sum;
144 let x_mul_sum = self.x_mul_sum;
145 let divisor = self.divisor;
146
147 let (mut y_sum, mut xy_sum) = (0.0, 0.0);
148 for (i, &y) in self.inputs.iter().enumerate() {
149 let x = (i + 1) as f64;
150 y_sum += y;
151 xy_sum += x * y;
152 }
153
154 self.slope = n.mul_add(xy_sum, -(x_sum * y_sum)) / divisor;
155 self.intercept = y_sum.mul_add(x_mul_sum, -(x_sum * xy_sum)) / divisor;
156
157 let (mut sse, mut y_last, mut e_last) = (0.0, 0.0, 0.0);
158 for (i, &y) in self.inputs.iter().enumerate() {
159 let x = (i + 1) as f64;
160 let y_hat = self.slope.mul_add(x, self.intercept);
161 let resid = y_hat - y;
162 sse += resid * resid;
163 y_last = y;
164 e_last = resid;
165 }
166
167 self.value = y_last + e_last;
168 self.degree = self.slope.atan().to_degrees();
169 self.cfo = if y_last == 0.0 {
170 f64::NAN
171 } else {
172 100.0 * e_last / y_last
173 };
174
175 let mean = y_sum / n;
176 let sst: f64 = self
177 .inputs
178 .iter()
179 .map(|&y| {
180 let d = y - mean;
181 d * d
182 })
183 .sum();
184
185 self.r2 = if sst.abs() < f64::EPSILON {
186 f64::NAN
187 } else {
188 1.0 - sse / sst
189 };
190 }
191}
192
#[cfg(test)]
mod tests {
    use nautilus_model::data::Bar;
    use rstest::rstest;

    use super::*;
    use crate::{
        average::lr::LinearRegression,
        indicator::Indicator,
        stubs::{bar_ethusdt_binance_minute_bid, indicator_lr_10},
    };

    /// Absolute tolerance for floating-point comparisons in these tests.
    const EPS: f64 = 1e-12;

    #[rstest]
    fn test_lr_initialized(indicator_lr_10: LinearRegression) {
        let display_str = format!("{indicator_lr_10}");
        assert_eq!(display_str, "LinearRegression(10)");
        assert_eq!(indicator_lr_10.period, 10);
        assert!(!indicator_lr_10.initialized);
        assert!(!indicator_lr_10.has_inputs);
    }

    #[rstest]
    #[should_panic(expected = "LinearRegression: period must be > 0")]
    fn test_new_with_zero_period_panics() {
        let _ = LinearRegression::new(0);
    }

    #[rstest]
    #[should_panic(expected = "exceeds MAX_PERIOD")]
    fn new_period_exceeds_max_panics() {
        let _ = LinearRegression::new(MAX_PERIOD + 1);
    }

    #[rstest]
    fn test_value_with_one_input(mut indicator_lr_10: LinearRegression) {
        indicator_lr_10.update_raw(1.0);
        assert_eq!(indicator_lr_10.value, 0.0);
    }

    #[rstest]
    fn test_value_with_three_inputs(mut indicator_lr_10: LinearRegression) {
        indicator_lr_10.update_raw(1.0);
        indicator_lr_10.update_raw(2.0);
        indicator_lr_10.update_raw(3.0);
        assert_eq!(indicator_lr_10.value, 0.0);
    }

    #[rstest]
    fn test_initialized_with_required_input(mut indicator_lr_10: LinearRegression) {
        for i in 1..10 {
            indicator_lr_10.update_raw(f64::from(i));
        }
        assert!(!indicator_lr_10.initialized);
        indicator_lr_10.update_raw(10.0);
        assert!(indicator_lr_10.initialized);
    }

    #[rstest]
    fn test_handle_bar(mut indicator_lr_10: LinearRegression, bar_ethusdt_binance_minute_bid: Bar) {
        indicator_lr_10.handle_bar(&bar_ethusdt_binance_minute_bid);
        assert_eq!(indicator_lr_10.value, 0.0);
        assert!(indicator_lr_10.has_inputs);
        assert!(!indicator_lr_10.initialized);
    }

    #[rstest]
    fn test_reset(mut indicator_lr_10: LinearRegression) {
        indicator_lr_10.update_raw(1.0);
        indicator_lr_10.reset();
        assert_eq!(indicator_lr_10.value, 0.0);
        assert_eq!(indicator_lr_10.inputs.len(), 0);
        assert_eq!(indicator_lr_10.slope, 0.0);
        assert_eq!(indicator_lr_10.intercept, 0.0);
        assert_eq!(indicator_lr_10.degree, 0.0);
        assert_eq!(indicator_lr_10.cfo, 0.0);
        assert_eq!(indicator_lr_10.r2, 0.0);
        assert!(!indicator_lr_10.has_inputs);
        assert!(!indicator_lr_10.initialized);
    }

    #[rstest]
    fn test_inputs_len_never_exceeds_period() {
        let mut lr = LinearRegression::new(3);
        for i in 0..10 {
            lr.update_raw(f64::from(i));
        }
        assert_eq!(lr.inputs.len(), lr.period);
    }

    #[rstest]
    fn test_oldest_element_evicted() {
        let mut lr = LinearRegression::new(4);
        for v in 1..=5 {
            lr.update_raw(f64::from(v));
        }
        assert!(!lr.inputs.contains(&1.0));
        assert_eq!(lr.inputs.front(), Some(&2.0));
    }

    #[rstest]
    fn test_recent_elements_preserved() {
        let mut lr = LinearRegression::new(5);
        for v in 0..5 {
            lr.update_raw(f64::from(v));
        }
        lr.update_raw(99.0);
        let expected = vec![1.0, 2.0, 3.0, 4.0, 99.0];
        assert_eq!(lr.inputs.iter().copied().collect::<Vec<_>>(), expected);
    }

    #[rstest]
    fn test_multiple_evictions() {
        let mut lr = LinearRegression::new(2);
        lr.update_raw(10.0);
        lr.update_raw(20.0);
        lr.update_raw(30.0);
        lr.update_raw(40.0);
        assert_eq!(
            lr.inputs.iter().copied().collect::<Vec<_>>(),
            vec![30.0, 40.0]
        );
    }

    #[rstest]
    fn test_value_changes_after_eviction() {
        let mut lr = LinearRegression::new(3);
        lr.update_raw(1.0);
        lr.update_raw(2.0);
        lr.update_raw(3.0);
        let before = lr.value;
        lr.update_raw(4.0);
        let after = lr.value;
        assert!(after.is_finite());
        assert_ne!(before, after);
    }

    #[rstest]
    fn test_value_with_ten_inputs(mut indicator_lr_10: LinearRegression) {
        indicator_lr_10.update_raw(1.00000);
        indicator_lr_10.update_raw(1.00010);
        indicator_lr_10.update_raw(1.00030);
        indicator_lr_10.update_raw(1.00040);
        indicator_lr_10.update_raw(1.00050);
        indicator_lr_10.update_raw(1.00060);
        indicator_lr_10.update_raw(1.00050);
        indicator_lr_10.update_raw(1.00040);
        indicator_lr_10.update_raw(1.00030);
        indicator_lr_10.update_raw(1.00010);
        indicator_lr_10.update_raw(1.00000);

        assert!((indicator_lr_10.value - 1.000_232_727_272_727_6).abs() < 1e-12);
    }

    #[rstest]
    fn r2_nan_for_constant_series() {
        let mut lr = LinearRegression::new(5);
        for _ in 0..5 {
            lr.update_raw(42.0);
        }
        assert!(lr.initialized);
        assert!(
            lr.r2.is_nan(),
            "R² should be NaN for a constant-value input series"
        );
    }

    #[rstest]
    fn cfo_nan_when_last_price_zero() {
        let mut lr = LinearRegression::new(3);
        lr.update_raw(1.0);
        lr.update_raw(2.0);
        lr.update_raw(0.0);
        assert!(lr.initialized);
        assert!(
            lr.cfo.is_nan(),
            "CFO should be NaN when the most-recent price equals zero"
        );
    }

    #[rstest]
    fn positive_slope_and_degree_for_uptrend() {
        let mut lr = LinearRegression::new(4);
        for v in 1..=4 {
            lr.update_raw(f64::from(v));
        }
        assert!(lr.slope > 0.0, "slope expected positive for up-trend");
        assert!(lr.degree > 0.0, "degree expected positive for up-trend");
    }

    #[rstest]
    fn negative_slope_and_degree_for_downtrend() {
        let mut lr = LinearRegression::new(4);
        for v in (1..=4).rev() {
            lr.update_raw(f64::from(v));
        }
        assert!(lr.slope < 0.0, "slope expected negative for down-trend");
        assert!(lr.degree < 0.0, "degree expected negative for down-trend");
    }

    #[rstest]
    fn not_initialized_until_enough_samples() {
        let mut lr = LinearRegression::new(6);
        for v in 0..5 {
            lr.update_raw(f64::from(v));
        }
        assert!(
            !lr.initialized,
            "indicator should remain uninitialised with fewer than `period` inputs"
        );
    }

    #[rstest]
    #[case(128)]
    #[case(1_024)]
    #[case(16_384)]
    fn large_period_initialisation_and_window_size(#[case] period: usize) {
        let mut lr = LinearRegression::new(period);
        for v in 0..period {
            lr.update_raw(v as f64);
        }
        assert!(
            lr.initialized,
            "indicator should initialise after exactly `period` samples"
        );
        assert_eq!(
            lr.inputs.len(),
            period,
            "internal window length must equal the configured period"
        );
    }

    #[rstest]
    fn cached_constants_correct() {
        let period = 10;
        let lr = LinearRegression::new(period);

        let n = period as f64;
        let expected_x_sum = 0.5 * n * (n + 1.0);
        let expected_x_mul_sum = expected_x_sum * 2.0f64.mul_add(n, 1.0) / 3.0;
        let expected_divisor = n.mul_add(expected_x_mul_sum, -(expected_x_sum * expected_x_sum));

        assert!((lr.x_sum - expected_x_sum).abs() < 1e-12, "x_sum mismatch");
        assert!(
            (lr.x_mul_sum - expected_x_mul_sum).abs() < 1e-12,
            "x_mul_sum mismatch"
        );
        assert!(
            (lr.divisor - expected_divisor).abs() < 1e-12,
            "divisor mismatch"
        );
    }

    #[rstest]
    fn cached_constants_immutable_through_updates() {
        let mut lr = LinearRegression::new(5);

        let (x_sum, x_mul_sum, divisor) = (lr.x_sum, lr.x_mul_sum, lr.divisor);

        for v in 0..20 {
            lr.update_raw(f64::from(v));
        }

        assert_eq!(lr.x_sum, x_sum, "x_sum must remain unchanged after updates");
        assert_eq!(
            lr.x_mul_sum, x_mul_sum,
            "x_mul_sum must remain unchanged after updates"
        );
        assert_eq!(
            lr.divisor, divisor,
            "divisor must remain unchanged after updates"
        );
    }

    #[rstest]
    fn cached_constants_immutable_after_reset() {
        let mut lr = LinearRegression::new(8);

        let (x_sum, x_mul_sum, divisor) = (lr.x_sum, lr.x_mul_sum, lr.divisor);

        for v in 0..8 {
            lr.update_raw(f64::from(v));
        }
        lr.reset();

        assert_eq!(lr.x_sum, x_sum, "x_sum must survive reset()");
        assert_eq!(lr.x_mul_sum, x_mul_sum, "x_mul_sum must survive reset()");
        assert_eq!(lr.divisor, divisor, "divisor must survive reset()");
    }

    #[rstest]
    #[case(8, 5.0)]
    #[case(16, -3.1415)]
    fn constant_non_zero_series(#[case] period: usize, #[case] value: f64) {
        let mut lr = LinearRegression::new(period);

        for _ in 0..period {
            lr.update_raw(value);
        }

        assert!(lr.initialized());
        assert!(lr.slope.abs() < EPS);
        assert!((lr.intercept - value).abs() < EPS);
        assert_eq!(lr.degree, 0.0);
        assert!(lr.r2.is_nan());
        assert!((lr.cfo).abs() < EPS);
        assert!((lr.value - value).abs() < EPS);
    }

    #[rstest]
    #[case(4)]
    #[case(32)]
    fn constant_zero_series_cfo_nan(#[case] period: usize) {
        let mut lr = LinearRegression::new(period);

        for _ in 0..period {
            lr.update_raw(0.0);
        }

        assert!(lr.initialized());
        assert!(lr.cfo.is_nan());
    }

    #[rstest]
    #[case(6)]
    #[case(13)]
    fn reset_clears_state_but_keeps_constants(#[case] period: usize) {
        let mut lr = LinearRegression::new(period);

        for i in 1..=period {
            lr.update_raw(i as f64);
        }

        let x_sum_before = lr.x_sum;
        let x_mul_sum_before = lr.x_mul_sum;
        let divisor_before = lr.divisor;

        lr.reset();

        assert!(!lr.initialized());
        assert!(!lr.has_inputs());

        assert!(lr.slope.abs() < EPS);
        assert!(lr.intercept.abs() < EPS);
        assert!(lr.degree.abs() < EPS);
        assert!(lr.cfo.abs() < EPS);
        assert!(lr.r2.abs() < EPS);
        assert!(lr.value.abs() < EPS);

        assert_eq!(lr.x_sum, x_sum_before);
        assert_eq!(lr.x_mul_sum, x_mul_sum_before);
        assert_eq!(lr.divisor, divisor_before);
    }

    #[rstest]
    #[case(5)]
    #[case(31)]
    fn perfect_linear_series(#[case] period: usize) {
        const A: f64 = 2.0;
        const B: f64 = -3.0;
        let mut lr = LinearRegression::new(period);

        for x in 1..=period {
            lr.update_raw(A.mul_add(x as f64, B));
        }

        assert!(lr.initialized());
        assert!((lr.slope - A).abs() < EPS);
        assert!((lr.intercept - B).abs() < EPS);
        assert!((lr.r2 - 1.0).abs() < EPS);
        assert!((lr.degree.to_radians().tan() - A).abs() < EPS);
    }

    #[rstest]
    fn sliding_window_keeps_last_period() {
        const P: usize = 4;
        let mut lr = LinearRegression::new(P);
        for i in 1..=P {
            lr.update_raw(i as f64);
        }
        let slope_first_window = lr.slope;

        lr.update_raw(-100.0);
        assert!(lr.slope < slope_first_window);
        assert_eq!(lr.inputs.len(), P);
        assert_eq!(lr.inputs.front(), Some(&2.0));
    }

    #[rstest]
    fn r2_between_zero_and_one() {
        const P: usize = 32;
        let mut lr = LinearRegression::new(P);
        for x in 1..=P {
            let noise = if x % 2 == 0 { 0.5 } else { -0.5 };
            lr.update_raw(3.0f64.mul_add(x as f64, noise));
        }
        assert!(lr.r2 > 0.0 && lr.r2 < 1.0);
    }

    #[rstest]
    fn reset_before_initialized() {
        let mut lr = LinearRegression::new(10);
        lr.update_raw(1.0);
        lr.reset();

        assert!(!lr.initialized());
        assert!(!lr.has_inputs());
        assert_eq!(lr.inputs.len(), 0);
    }
}