nautilus_indicators/average/
lr.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Posei Systems Pty Ltd. All rights reserved.
3//  https://poseitrader.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::fmt::{Debug, Display};
17
18use arraydeque::{ArrayDeque, Wrapping};
19use nautilus_model::data::Bar;
20
21use crate::indicator::Indicator;
22
23const MAX_PERIOD: usize = 16_384;
24
25#[repr(C)]
26#[derive(Debug)]
27#[cfg_attr(
28    feature = "python",
29    pyo3::pyclass(module = "posei_trader.core.nautilus_pyo3.indicators")
30)]
31pub struct LinearRegression {
32    pub period: usize,
33    pub slope: f64,
34    pub intercept: f64,
35    pub degree: f64,
36    pub cfo: f64,
37    pub r2: f64,
38    pub value: f64,
39    pub initialized: bool,
40    has_inputs: bool,
41    inputs: ArrayDeque<f64, MAX_PERIOD, Wrapping>,
42    x_sum: f64,
43    x_mul_sum: f64,
44    divisor: f64,
45}
46
47impl Display for LinearRegression {
48    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49        write!(f, "{}({})", self.name(), self.period)
50    }
51}
52
53impl Indicator for LinearRegression {
54    fn name(&self) -> String {
55        stringify!(LinearRegression).into()
56    }
57
58    fn has_inputs(&self) -> bool {
59        self.has_inputs
60    }
61
62    fn initialized(&self) -> bool {
63        self.initialized
64    }
65
66    fn handle_bar(&mut self, bar: &Bar) {
67        self.update_raw(bar.close.into());
68    }
69
70    fn reset(&mut self) {
71        self.slope = 0.0;
72        self.intercept = 0.0;
73        self.degree = 0.0;
74        self.cfo = 0.0;
75        self.r2 = 0.0;
76        self.value = 0.0;
77        self.inputs.clear();
78        self.has_inputs = false;
79        self.initialized = false;
80    }
81}
82
83impl LinearRegression {
84    /// Creates a new [`LinearRegression`] instance.
85    ///
86    /// # Panics
87    ///
88    /// This function panics if:
89    /// `period` is zero.
90    /// `period` exceeds [`MAX_PERIOD`].
91    #[must_use]
92    pub fn new(period: usize) -> Self {
93        assert!(
94            period > 0,
95            "LinearRegression: period must be > 0 (received {period})"
96        );
97        assert!(
98            period <= MAX_PERIOD,
99            "LinearRegression: period {period} exceeds MAX_PERIOD ({MAX_PERIOD})"
100        );
101
102        let n = period as f64;
103        let x_sum = 0.5 * n * (n + 1.0);
104        let x_mul_sum = x_sum * 2.0f64.mul_add(n, 1.0) / 3.0;
105        let divisor = n.mul_add(x_mul_sum, -(x_sum * x_sum));
106
107        Self {
108            period,
109            slope: 0.0,
110            intercept: 0.0,
111            degree: 0.0,
112            cfo: 0.0,
113            r2: 0.0,
114            value: 0.0,
115            initialized: false,
116            has_inputs: false,
117            inputs: ArrayDeque::new(),
118            x_sum,
119            x_mul_sum,
120            divisor,
121        }
122    }
123
124    /// Updates the linear regression with a new data point.
125    ///
126    /// # Panics
127    ///
128    /// Panics if called with an empty window – this is protected against by the logic
129    /// that returns early until enough samples have been collected.
130    pub fn update_raw(&mut self, close: f64) {
131        if self.inputs.len() == self.period {
132            let _ = self.inputs.pop_front();
133        }
134        let _ = self.inputs.push_back(close);
135
136        self.has_inputs = true;
137        if self.inputs.len() < self.period {
138            return;
139        }
140        self.initialized = true;
141
142        let n = self.period as f64;
143        let x_sum = self.x_sum;
144        let x_mul_sum = self.x_mul_sum;
145        let divisor = self.divisor;
146
147        let (mut y_sum, mut xy_sum) = (0.0, 0.0);
148        for (i, &y) in self.inputs.iter().enumerate() {
149            let x = (i + 1) as f64;
150            y_sum += y;
151            xy_sum += x * y;
152        }
153
154        self.slope = n.mul_add(xy_sum, -(x_sum * y_sum)) / divisor;
155        self.intercept = y_sum.mul_add(x_mul_sum, -(x_sum * xy_sum)) / divisor;
156
157        let (mut sse, mut y_last, mut e_last) = (0.0, 0.0, 0.0);
158        for (i, &y) in self.inputs.iter().enumerate() {
159            let x = (i + 1) as f64;
160            let y_hat = self.slope.mul_add(x, self.intercept);
161            let resid = y_hat - y;
162            sse += resid * resid;
163            y_last = y;
164            e_last = resid;
165        }
166
167        self.value = y_last + e_last;
168        self.degree = self.slope.atan().to_degrees();
169        self.cfo = if y_last == 0.0 {
170            f64::NAN
171        } else {
172            100.0 * e_last / y_last
173        };
174
175        let mean = y_sum / n;
176        let sst: f64 = self
177            .inputs
178            .iter()
179            .map(|&y| {
180                let d = y - mean;
181                d * d
182            })
183            .sum();
184
185        self.r2 = if sst.abs() < f64::EPSILON {
186            f64::NAN
187        } else {
188            1.0 - sse / sst
189        };
190    }
191}
192
193////////////////////////////////////////////////////////////////////////////////
194// Tests
195////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
mod tests {
    use nautilus_model::data::Bar;
    use rstest::rstest;

    use super::*;
    use crate::{
        average::lr::LinearRegression,
        indicator::Indicator,
        stubs::{bar_ethusdt_binance_minute_bid, indicator_lr_10},
    };

    /// Absolute tolerance used for floating-point comparisons.
    const EPS: f64 = 1e-12;

    // ---------------------------------------------------------------------------------------------
    // Construction
    // ---------------------------------------------------------------------------------------------

    #[rstest]
    fn test_lr_initialized(indicator_lr_10: LinearRegression) {
        let display_str = format!("{indicator_lr_10}");
        assert_eq!(display_str, "LinearRegression(10)");
        assert_eq!(indicator_lr_10.period, 10);
        assert!(!indicator_lr_10.initialized);
        assert!(!indicator_lr_10.has_inputs);
    }

    #[rstest]
    #[should_panic(expected = "LinearRegression: period must be > 0")]
    fn test_new_with_zero_period_panics() {
        let _ = LinearRegression::new(0);
    }

    #[rstest]
    #[should_panic(expected = "exceeds MAX_PERIOD")]
    fn new_period_exceeds_max_panics() {
        let _ = LinearRegression::new(MAX_PERIOD + 1);
    }

    #[rstest]
    fn cached_constants_correct() {
        let period = 10;
        let lr = LinearRegression::new(period);

        let n = period as f64;
        let expected_x_sum = 0.5 * n * (n + 1.0);
        let expected_x_mul_sum = expected_x_sum * 2.0f64.mul_add(n, 1.0) / 3.0;
        let expected_divisor = n.mul_add(expected_x_mul_sum, -(expected_x_sum * expected_x_sum));

        assert!((lr.x_sum - expected_x_sum).abs() < EPS, "x_sum mismatch");
        assert!(
            (lr.x_mul_sum - expected_x_mul_sum).abs() < EPS,
            "x_mul_sum mismatch"
        );
        assert!(
            (lr.divisor - expected_divisor).abs() < EPS,
            "divisor mismatch"
        );
    }

    // ---------------------------------------------------------------------------------------------
    // Warm-up behaviour
    // ---------------------------------------------------------------------------------------------

    #[rstest]
    fn test_value_with_one_input(mut indicator_lr_10: LinearRegression) {
        indicator_lr_10.update_raw(1.0);
        assert_eq!(indicator_lr_10.value, 0.0);
    }

    #[rstest]
    fn test_value_with_three_inputs(mut indicator_lr_10: LinearRegression) {
        indicator_lr_10.update_raw(1.0);
        indicator_lr_10.update_raw(2.0);
        indicator_lr_10.update_raw(3.0);
        assert_eq!(indicator_lr_10.value, 0.0);
    }

    #[rstest]
    fn test_initialized_with_required_input(mut indicator_lr_10: LinearRegression) {
        for i in 1..10 {
            indicator_lr_10.update_raw(f64::from(i));
        }
        assert!(!indicator_lr_10.initialized);
        indicator_lr_10.update_raw(10.0);
        assert!(indicator_lr_10.initialized);
    }

    #[rstest]
    fn not_initialized_until_enough_samples() {
        let mut lr = LinearRegression::new(6);
        for v in 0..5 {
            lr.update_raw(f64::from(v));
        }
        assert!(
            !lr.initialized,
            "indicator should remain uninitialised with fewer than `period` inputs"
        );
    }

    #[rstest]
    fn test_handle_bar(mut indicator_lr_10: LinearRegression, bar_ethusdt_binance_minute_bid: Bar) {
        indicator_lr_10.handle_bar(&bar_ethusdt_binance_minute_bid);
        assert_eq!(indicator_lr_10.value, 0.0);
        assert!(indicator_lr_10.has_inputs);
        assert!(!indicator_lr_10.initialized);
    }

    #[rstest]
    #[case(128)]
    #[case(1_024)]
    #[case(16_384)]
    fn large_period_initialisation_and_window_size(#[case] period: usize) {
        let mut lr = LinearRegression::new(period);
        for v in 0..period {
            lr.update_raw(v as f64);
        }
        assert!(
            lr.initialized,
            "indicator should initialise after exactly `period` samples"
        );
        assert_eq!(
            lr.inputs.len(),
            period,
            "internal window length must equal the configured period"
        );
    }

    // ---------------------------------------------------------------------------------------------
    // Rolling window
    // ---------------------------------------------------------------------------------------------

    #[rstest]
    fn test_inputs_len_never_exceeds_period() {
        let mut lr = LinearRegression::new(3);
        for i in 0..10 {
            lr.update_raw(f64::from(i));
        }
        assert_eq!(lr.inputs.len(), lr.period);
    }

    #[rstest]
    fn test_oldest_element_evicted() {
        let mut lr = LinearRegression::new(4);
        for v in 1..=5 {
            lr.update_raw(f64::from(v));
        }
        assert!(!lr.inputs.contains(&1.0));
        assert_eq!(lr.inputs.front(), Some(&2.0));
    }

    #[rstest]
    fn test_recent_elements_preserved() {
        let mut lr = LinearRegression::new(5);
        for v in 0..5 {
            lr.update_raw(f64::from(v));
        }
        lr.update_raw(99.0);
        let expected = vec![1.0, 2.0, 3.0, 4.0, 99.0];
        assert_eq!(lr.inputs.iter().copied().collect::<Vec<_>>(), expected);
    }

    #[rstest]
    fn test_multiple_evictions() {
        let mut lr = LinearRegression::new(2);
        lr.update_raw(10.0);
        lr.update_raw(20.0);
        lr.update_raw(30.0);
        lr.update_raw(40.0);
        assert_eq!(
            lr.inputs.iter().copied().collect::<Vec<_>>(),
            vec![30.0, 40.0]
        );
    }

    #[rstest]
    fn sliding_window_keeps_last_period() {
        const P: usize = 4;
        let mut lr = LinearRegression::new(P);
        for i in 1..=P {
            lr.update_raw(i as f64);
        }
        let slope_first_window = lr.slope;

        lr.update_raw(-100.0);
        assert!(lr.slope < slope_first_window);
        assert_eq!(lr.inputs.len(), P);
        assert_eq!(lr.inputs.front(), Some(&2.0));
    }

    #[rstest]
    fn test_value_finite_and_updated_after_eviction() {
        let mut lr = LinearRegression::new(3);
        lr.update_raw(1.0);
        lr.update_raw(2.0);
        lr.update_raw(3.0);
        let before = lr.value;
        lr.update_raw(4.0);
        let after = lr.value;
        assert!(after.is_finite());
        assert_ne!(before, after);
    }

    // ---------------------------------------------------------------------------------------------
    // Regression outputs
    // ---------------------------------------------------------------------------------------------

    #[rstest]
    fn test_value_with_full_window_after_eviction(mut indicator_lr_10: LinearRegression) {
        // Eleven samples go into a window of ten, so the first sample is evicted before
        // the final fit is computed.
        indicator_lr_10.update_raw(1.00000);
        indicator_lr_10.update_raw(1.00010);
        indicator_lr_10.update_raw(1.00030);
        indicator_lr_10.update_raw(1.00040);
        indicator_lr_10.update_raw(1.00050);
        indicator_lr_10.update_raw(1.00060);
        indicator_lr_10.update_raw(1.00050);
        indicator_lr_10.update_raw(1.00040);
        indicator_lr_10.update_raw(1.00030);
        indicator_lr_10.update_raw(1.00010);
        indicator_lr_10.update_raw(1.00000);

        assert!((indicator_lr_10.value - 1.000_232_727_272_727_6).abs() < EPS);
    }

    #[rstest]
    fn r2_nan_for_constant_series() {
        let mut lr = LinearRegression::new(5);
        for _ in 0..5 {
            lr.update_raw(42.0);
        }
        assert!(lr.initialized);
        assert!(
            lr.r2.is_nan(),
            "R² should be NaN for a constant-value input series"
        );
    }

    #[rstest]
    fn cfo_nan_when_last_price_zero() {
        let mut lr = LinearRegression::new(3);
        lr.update_raw(1.0);
        lr.update_raw(2.0);
        lr.update_raw(0.0);
        assert!(lr.initialized);
        assert!(
            lr.cfo.is_nan(),
            "CFO should be NaN when the most-recent price equals zero"
        );
    }

    #[rstest]
    fn positive_slope_and_degree_for_uptrend() {
        let mut lr = LinearRegression::new(4);
        for v in 1..=4 {
            lr.update_raw(f64::from(v));
        }
        assert!(lr.slope > 0.0, "slope expected positive for up-trend");
        assert!(lr.degree > 0.0, "degree expected positive for up-trend");
    }

    #[rstest]
    fn negative_slope_and_degree_for_downtrend() {
        let mut lr = LinearRegression::new(4);
        for v in (1..=4).rev() {
            lr.update_raw(f64::from(v));
        }
        assert!(lr.slope < 0.0, "slope expected negative for down-trend");
        assert!(lr.degree < 0.0, "degree expected negative for down-trend");
    }

    #[rstest]
    #[case(8, 5.0)]
    #[case(16, -3.1415)]
    fn constant_non_zero_series(#[case] period: usize, #[case] value: f64) {
        let mut lr = LinearRegression::new(period);

        for _ in 0..period {
            lr.update_raw(value);
        }

        assert!(lr.initialized());
        assert!(lr.slope.abs() < EPS);
        assert!((lr.intercept - value).abs() < EPS);
        assert_eq!(lr.degree, 0.0);
        assert!(lr.r2.is_nan());
        assert!((lr.cfo).abs() < EPS);
        assert!((lr.value - value).abs() < EPS);
    }

    #[rstest]
    #[case(4)]
    #[case(32)]
    fn constant_zero_series_cfo_nan(#[case] period: usize) {
        let mut lr = LinearRegression::new(period);

        for _ in 0..period {
            lr.update_raw(0.0);
        }

        assert!(lr.initialized());
        assert!(lr.cfo.is_nan());
    }

    #[rstest]
    #[case(5)]
    #[case(31)]
    fn perfect_linear_series(#[case] period: usize) {
        const A: f64 = 2.0;
        const B: f64 = -3.0;
        let mut lr = LinearRegression::new(period);

        for x in 1..=period {
            lr.update_raw(A.mul_add(x as f64, B));
        }

        assert!(lr.initialized());
        assert!((lr.slope - A).abs() < EPS);
        assert!((lr.intercept - B).abs() < EPS);
        assert!((lr.r2 - 1.0).abs() < EPS);
        assert!((lr.degree.to_radians().tan() - A).abs() < EPS);
    }

    #[rstest]
    fn r2_between_zero_and_one() {
        const P: usize = 32;
        let mut lr = LinearRegression::new(P);
        for x in 1..=P {
            let noise = if x % 2 == 0 { 0.5 } else { -0.5 };
            lr.update_raw(3.0f64.mul_add(x as f64, noise));
        }
        assert!(lr.r2 > 0.0 && lr.r2 < 1.0);
    }

    // ---------------------------------------------------------------------------------------------
    // Reset and cached constants
    // ---------------------------------------------------------------------------------------------

    #[rstest]
    fn test_reset(mut indicator_lr_10: LinearRegression) {
        indicator_lr_10.update_raw(1.0);
        indicator_lr_10.reset();
        assert_eq!(indicator_lr_10.value, 0.0);
        assert_eq!(indicator_lr_10.inputs.len(), 0);
        assert_eq!(indicator_lr_10.slope, 0.0);
        assert_eq!(indicator_lr_10.intercept, 0.0);
        assert_eq!(indicator_lr_10.degree, 0.0);
        assert_eq!(indicator_lr_10.cfo, 0.0);
        assert_eq!(indicator_lr_10.r2, 0.0);
        assert!(!indicator_lr_10.has_inputs);
        assert!(!indicator_lr_10.initialized);
    }

    #[rstest]
    fn reset_before_initialized() {
        let mut lr = LinearRegression::new(10);
        lr.update_raw(1.0);
        lr.reset();

        assert!(!lr.initialized());
        assert!(!lr.has_inputs());
        assert_eq!(lr.inputs.len(), 0);
    }

    #[rstest]
    #[case(6)]
    #[case(13)]
    fn reset_clears_state_but_keeps_constants(#[case] period: usize) {
        let mut lr = LinearRegression::new(period);

        for i in 1..=period {
            lr.update_raw(i as f64);
        }

        let x_sum_before = lr.x_sum;
        let x_mul_sum_before = lr.x_mul_sum;
        let divisor_before = lr.divisor;

        lr.reset();

        assert!(!lr.initialized());
        assert!(!lr.has_inputs());

        assert!(lr.slope.abs() < EPS);
        assert!(lr.intercept.abs() < EPS);
        assert!(lr.degree.abs() < EPS);
        assert!(lr.cfo.abs() < EPS);
        assert!(lr.r2.abs() < EPS);
        assert!(lr.value.abs() < EPS);

        assert_eq!(lr.x_sum, x_sum_before);
        assert_eq!(lr.x_mul_sum, x_mul_sum_before);
        assert_eq!(lr.divisor, divisor_before);
    }

    #[rstest]
    fn cached_constants_immutable_through_updates() {
        let mut lr = LinearRegression::new(5);

        let (x_sum, x_mul_sum, divisor) = (lr.x_sum, lr.x_mul_sum, lr.divisor);

        for v in 0..20 {
            lr.update_raw(f64::from(v));
        }

        assert_eq!(lr.x_sum, x_sum, "x_sum must remain unchanged after updates");
        assert_eq!(
            lr.x_mul_sum, x_mul_sum,
            "x_mul_sum must remain unchanged after updates"
        );
        assert_eq!(
            lr.divisor, divisor,
            "divisor must remain unchanged after updates"
        );
    }

    #[rstest]
    fn cached_constants_immutable_after_reset() {
        let mut lr = LinearRegression::new(8);

        let (x_sum, x_mul_sum, divisor) = (lr.x_sum, lr.x_mul_sum, lr.divisor);

        for v in 0..8 {
            lr.update_raw(f64::from(v));
        }
        lr.reset();

        assert_eq!(lr.x_sum, x_sum, "x_sum must survive reset()");
        assert_eq!(lr.x_mul_sum, x_mul_sum, "x_mul_sum must survive reset()");
        assert_eq!(lr.divisor, divisor, "divisor must survive reset()");
    }
}