nautilus_indicators/momentum/
cmo.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2023 Posei Systems Pty Ltd. All rights reserved.
3//  https://poseitrader.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::fmt::Display;
17
18use nautilus_model::data::{Bar, QuoteTick, TradeTick};
19
20use crate::{
21    average::{MovingAverageFactory, MovingAverageType},
22    indicator::{Indicator, MovingAverage},
23};
24
/// Chande Momentum Oscillator (CMO).
///
/// Smooths close-to-close price changes with two moving averages: one over up-moves
/// (`average_gain`) and one over down-move magnitudes (`average_loss`). Once both are
/// initialized, it reports `100 * (gain - loss) / (gain + loss)`, or `0.0` when that
/// denominator is exactly zero.
///
/// Both averages are only ever fed non-negative inputs, so the value is expected to lie
/// in `[-100, 100]`. This assumes the configured moving average preserves
/// non-negativity — TODO confirm for every [`MovingAverageType`].
// NOTE: with `#[repr(C)]` the field order below defines the memory layout; do not
// reorder fields casually.
#[repr(C)]
#[derive(Debug)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "posei_trader.core.nautilus_pyo3.indicators", unsendable)
)]
pub struct ChandeMomentumOscillator {
    /// Lookback period passed to both the gain and loss moving averages.
    pub period: usize,
    /// Moving-average type used for both averages (`new` defaults it to Wilder).
    pub ma_type: MovingAverageType,
    /// Latest oscillator value; stays `0.0` until the indicator is initialized.
    pub value: f64,
    /// Number of values passed to `update_raw` since construction or the last reset.
    pub count: usize,
    /// Latched to `true` the first time both inner moving averages report initialized.
    pub initialized: bool,
    /// Close from the previous update, used to compute the next price change.
    previous_close: f64,
    /// Moving average of positive price changes (`0.0` is fed on non-up moves).
    average_gain: Box<dyn MovingAverage + Send + 'static>,
    /// Moving average of the magnitude of negative price changes (`0.0` on non-down moves).
    average_loss: Box<dyn MovingAverage + Send + 'static>,
    /// Set once the first value has been received by `update_raw`.
    has_inputs: bool,
}
42
43impl Display for ChandeMomentumOscillator {
44    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45        write!(f, "{}({})", self.name(), self.period)
46    }
47}
48
49impl Indicator for ChandeMomentumOscillator {
50    fn name(&self) -> String {
51        stringify!(ChandeMomentumOscillator).to_string()
52    }
53
54    fn has_inputs(&self) -> bool {
55        self.has_inputs
56    }
57
58    fn initialized(&self) -> bool {
59        self.initialized
60    }
61
62    fn handle_quote(&mut self, _quote: &QuoteTick) {}
63
64    fn handle_trade(&mut self, _trade: &TradeTick) {}
65
66    fn handle_bar(&mut self, bar: &Bar) {
67        self.update_raw((&bar.close).into());
68    }
69
70    fn reset(&mut self) {
71        self.value = 0.0;
72        self.count = 0;
73        self.has_inputs = false;
74        self.initialized = false;
75        self.previous_close = 0.0;
76        self.average_gain.reset();
77        self.average_loss.reset();
78    }
79}
80
81impl ChandeMomentumOscillator {
82    /// Creates a new [`ChandeMomentumOscillator`] instance.
83    ///
84    /// # Panics
85    ///
86    /// Panics if `period` is not positive (> 0).
87    #[must_use]
88    pub fn new(period: usize, ma_type: Option<MovingAverageType>) -> Self {
89        assert!(period > 0, "ChandeMomentumOscillator: period must be > 0");
90        let ma_type = ma_type.unwrap_or(MovingAverageType::Wilder);
91        Self {
92            period,
93            ma_type,
94            average_gain: MovingAverageFactory::create(ma_type, period),
95            average_loss: MovingAverageFactory::create(ma_type, period),
96            previous_close: 0.0,
97            value: 0.0,
98            count: 0,
99            initialized: false,
100            has_inputs: false,
101        }
102    }
103
104    pub fn update_raw(&mut self, close: f64) {
105        self.count += 1;
106        if !self.has_inputs {
107            self.previous_close = close;
108            self.has_inputs = true;
109        }
110
111        let gain: f64 = close - self.previous_close;
112        if gain > 0.0 {
113            self.average_gain.update_raw(gain);
114            self.average_loss.update_raw(0.0);
115        } else if gain < 0.0 {
116            self.average_gain.update_raw(0.0);
117            self.average_loss.update_raw(-gain);
118        } else {
119            self.average_gain.update_raw(0.0);
120            self.average_loss.update_raw(0.0);
121        }
122
123        if !self.initialized && self.average_gain.initialized() && self.average_loss.initialized() {
124            self.initialized = true;
125        }
126        if self.initialized {
127            let divisor = self.average_gain.value() + self.average_loss.value();
128            if divisor == 0.0 {
129                self.value = 0.0;
130            } else {
131                self.value =
132                    100.0 * (self.average_gain.value() - self.average_loss.value()) / divisor;
133            }
134        }
135        self.previous_close = close;
136    }
137}
138
139////////////////////////////////////////////////////////////////////////////////
140// Tests
141////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
mod tests {
    use nautilus_model::data::{Bar, QuoteTick};
    use rstest::rstest;

    use crate::{
        average::MovingAverageType, indicator::Indicator, momentum::cmo::ChandeMomentumOscillator,
        stubs::*,
    };

    /// Absolute tolerance for oscillator values produced by a floating-point division.
    const EPS: f64 = 1e-9;

    #[rstest]
    fn test_cmo_initialized(cmo_10: ChandeMomentumOscillator) {
        let display_str = format!("{cmo_10}");
        assert_eq!(display_str, "ChandeMomentumOscillator(10)");
        assert_eq!(cmo_10.period, 10);
        assert!(!cmo_10.initialized);
    }

    #[rstest]
    fn test_initialized_with_required_inputs_returns_true(mut cmo_10: ChandeMomentumOscillator) {
        for i in 0..12 {
            cmo_10.update_raw(f64::from(i));
        }
        assert!(cmo_10.initialized);
    }

    #[rstest]
    fn test_value_all_higher_inputs_returns_expected_value(mut cmo_10: ChandeMomentumOscillator) {
        cmo_10.update_raw(109.93);
        cmo_10.update_raw(110.0);
        cmo_10.update_raw(109.77);
        cmo_10.update_raw(109.96);
        cmo_10.update_raw(110.29);
        cmo_10.update_raw(110.53);
        cmo_10.update_raw(110.27);
        cmo_10.update_raw(110.21);
        cmo_10.update_raw(110.06);
        cmo_10.update_raw(110.19);
        cmo_10.update_raw(109.83);
        cmo_10.update_raw(109.9);
        cmo_10.update_raw(110.0);
        cmo_10.update_raw(110.03);
        cmo_10.update_raw(110.13);
        cmo_10.update_raw(109.95);
        cmo_10.update_raw(109.75);
        cmo_10.update_raw(110.15);
        cmo_10.update_raw(109.9);
        cmo_10.update_raw(110.04);
        assert_eq!(cmo_10.value, 2.089_629_456_238_705_4);
    }

    #[rstest]
    fn test_value_with_one_input_returns_expected_value(mut cmo_10: ChandeMomentumOscillator) {
        cmo_10.update_raw(1.00000);
        assert_eq!(cmo_10.value, 0.0);
    }

    #[rstest]
    fn test_reset(mut cmo_10: ChandeMomentumOscillator) {
        cmo_10.update_raw(1.00020);
        cmo_10.update_raw(1.00030);
        cmo_10.update_raw(1.00050);
        cmo_10.reset();
        assert!(!cmo_10.initialized());
        assert_eq!(cmo_10.count, 0);
        assert_eq!(cmo_10.value, 0.0);
        assert_eq!(cmo_10.previous_close, 0.0);
    }

    #[rstest]
    fn test_handle_quote_tick(mut cmo_10: ChandeMomentumOscillator, stub_quote: QuoteTick) {
        cmo_10.handle_quote(&stub_quote);
        assert_eq!(cmo_10.count, 0);
        assert_eq!(cmo_10.value, 0.0);
    }

    #[rstest]
    fn test_handle_bar(mut cmo_10: ChandeMomentumOscillator, bar_ethusdt_binance_minute_bid: Bar) {
        cmo_10.handle_bar(&bar_ethusdt_binance_minute_bid);
        assert_eq!(cmo_10.count, 1);
        assert_eq!(cmo_10.value, 0.0);
    }

    #[rstest]
    fn test_ma_type_affects_value() {
        let mut cmo_sma = ChandeMomentumOscillator::new(3, Some(MovingAverageType::Simple));
        let mut cmo_wilder = ChandeMomentumOscillator::new(3, Some(MovingAverageType::Wilder));
        let prices = [1.0, 2.0, 3.0, 2.5, 3.5];
        for price in prices {
            cmo_sma.update_raw(price);
            cmo_wilder.update_raw(price);
        }
        assert_ne!(cmo_sma.value, cmo_wilder.value);
    }

    #[rstest]
    fn test_count_increments(mut cmo_10: ChandeMomentumOscillator) {
        for i in 0..5 {
            cmo_10.update_raw(f64::from(i));
        }
        assert_eq!(cmo_10.count, 5);
    }

    #[rstest]
    fn test_reset_resets_inner_mas() {
        let mut cmo = ChandeMomentumOscillator::new(3, None);
        for price in [1.0, 2.0, 3.0] {
            cmo.update_raw(price);
        }
        assert!(cmo.average_gain.initialized());
        assert!(cmo.average_loss.initialized());
        assert_ne!(cmo.average_gain.value(), 0.0);
        cmo.reset();
        assert!(!cmo.average_gain.initialized());
        assert!(!cmo.average_loss.initialized());
        assert_eq!(cmo.average_gain.value(), 0.0);
        assert_eq!(cmo.average_loss.value(), 0.0);
    }

    #[rstest]
    #[should_panic]
    fn test_invalid_period_panics() {
        let _ = ChandeMomentumOscillator::new(0, None);
    }

    #[rstest]
    fn test_ma_type_propagation() {
        let cmo = ChandeMomentumOscillator::new(5, Some(MovingAverageType::Simple));
        assert_eq!(cmo.ma_type, MovingAverageType::Simple);
    }

    #[rstest]
    fn test_zero_divisor_returns_zero() {
        let mut cmo = ChandeMomentumOscillator::new(3, None);
        for _ in 0..5 {
            cmo.update_raw(100.0);
        }
        assert!(cmo.initialized);
        assert_eq!(cmo.value, 0.0);
    }

    #[rstest]
    fn test_random_walk_values_within_bounds() {
        let prices = [
            100.0, 100.5, 99.8, 100.3, 101.0, 100.7, 101.5, 101.2, 100.6, 101.1, 100.9, 101.4,
            100.8, 101.2, 100.6, 100.9, 101.3, 101.0, 100.5, 101.1, 100.7, 101.4, 100.9, 100.8,
            101.2, 100.6, 100.9, 101.3, 101.0, 100.5,
        ];
        let mut cmo = ChandeMomentumOscillator::new(10, None);
        for price in prices {
            cmo.update_raw(price);
        }
        assert!(cmo.initialized);
        assert!((-100.0..=100.0).contains(&cmo.value));
    }

    #[rstest]
    fn test_strictly_rising_prices_give_plus_100() {
        let mut cmo = ChandeMomentumOscillator::new(3, None);
        for i in 1..=10 {
            cmo.update_raw(f64::from(i));
        }
        assert!(cmo.initialized);
        // Every change after the seed is an up-move, so the loss average stays at
        // exactly zero and the ratio reduces to gain / gain.
        assert!((cmo.value - 100.0).abs() < EPS);
    }

    #[rstest]
    fn test_strictly_falling_prices_give_minus_100() {
        let mut cmo = ChandeMomentumOscillator::new(3, None);
        for i in (1..=10).rev() {
            cmo.update_raw(f64::from(i));
        }
        assert!(cmo.initialized);
        // Mirror of the rising case: the gain average stays at exactly zero.
        assert!((cmo.value + 100.0).abs() < EPS);
    }
}