nautilus_indicators/momentum/
swings.rs

// -------------------------------------------------------------------------------------------------
//  Copyright (C) 2015-2025 Posei Systems Pty Ltd. All rights reserved.
//  https://poseitrader.io
//
//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
//  You may not use this file except in compliance with the License.
//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
//
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.
// -------------------------------------------------------------------------------------------------
15
16use std::fmt::Display;
17
18use arraydeque::{ArrayDeque, Wrapping};
19use nautilus_model::data::Bar;
20
21use crate::indicator::Indicator;
22
/// Largest `period` a [`Swings`] instance accepts; also the fixed capacity of its
/// high/low input buffers.
const MAX_PERIOD: usize = 1_024;
24
25#[repr(C)]
26#[derive(Debug)]
27#[cfg_attr(
28    feature = "python",
29    pyo3::pyclass(module = "posei_trader.core.nautilus_pyo3.indicators")
30)]
31pub struct Swings {
32    pub period: usize,
33    pub direction: i64,
34    pub changed: bool,
35    pub high_datetime: f64,
36    pub low_datetime: f64,
37    pub high_price: f64,
38    pub low_price: f64,
39    pub length: usize,
40    pub duration: usize,
41    pub since_high: usize,
42    pub since_low: usize,
43    high_inputs: ArrayDeque<f64, MAX_PERIOD, Wrapping>,
44    low_inputs: ArrayDeque<f64, MAX_PERIOD, Wrapping>,
45    has_inputs: bool,
46    initialized: bool,
47}
48
49impl Display for Swings {
50    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51        write!(f, "{}({})", self.name(), self.period,)
52    }
53}
54
55impl Indicator for Swings {
56    fn name(&self) -> String {
57        stringify!(Swings).to_string()
58    }
59
60    fn has_inputs(&self) -> bool {
61        self.has_inputs
62    }
63
64    fn initialized(&self) -> bool {
65        self.initialized
66    }
67
68    fn handle_bar(&mut self, bar: &Bar) {
69        self.update_raw((&bar.high).into(), (&bar.low).into(), bar.ts_init.as_f64());
70    }
71
72    fn reset(&mut self) {
73        self.high_inputs.clear();
74        self.low_inputs.clear();
75        self.has_inputs = false;
76        self.initialized = false;
77        self.direction = 0;
78        self.changed = false;
79        self.high_datetime = 0.0;
80        self.low_datetime = 0.0;
81        self.high_price = 0.0;
82        self.low_price = 0.0;
83        self.length = 0;
84        self.duration = 0;
85        self.since_high = 0;
86        self.since_low = 0;
87    }
88}
89
90impl Swings {
91    /// Creates a new [`Swings`] instance.
92    ///
93    /// # Panics
94    ///
95    /// This function panics if:
96    /// - `period` is less than or equal to 0.
97    /// - `period` exceeds the maximum allowed value of `MAX_PERIOD`.
98    #[must_use]
99    pub fn new(period: usize) -> Self {
100        assert!(
101            period > 0 && period <= MAX_PERIOD,
102            "Swings: period {period} exceeds MAX_PERIOD ({MAX_PERIOD})"
103        );
104
105        Self {
106            period,
107            high_inputs: ArrayDeque::new(),
108            low_inputs: ArrayDeque::new(),
109            has_inputs: false,
110            initialized: false,
111            direction: 0,
112            changed: false,
113            high_datetime: 0.0,
114            low_datetime: 0.0,
115            high_price: 0.0,
116            low_price: 0.0,
117            length: 0,
118            duration: 0,
119            since_high: 0,
120            since_low: 0,
121        }
122    }
123
124    pub fn update_raw(&mut self, high: f64, low: f64, timestamp: f64) {
125        self.changed = false;
126
127        if self.high_inputs.len() == self.period {
128            self.high_inputs.pop_front();
129        }
130        if self.low_inputs.len() == self.period {
131            self.low_inputs.pop_front();
132        }
133        let _ = self.high_inputs.push_back(high);
134        let _ = self.low_inputs.push_back(low);
135
136        let max_high = self.high_inputs.iter().fold(f64::MIN, |a, &b| a.max(b));
137        let min_low = self.low_inputs.iter().fold(f64::MAX, |a, &b| a.min(b));
138
139        let is_swing_high = high >= max_high && low >= min_low;
140        let is_swing_low = high <= max_high && low <= min_low;
141
142        if is_swing_high && is_swing_low {
143            if self.high_price == 0.0 {
144                self.high_price = high;
145                self.high_datetime = timestamp;
146            }
147            self.since_high += 1;
148            self.since_low += 1;
149        } else if is_swing_high {
150            if self.direction == -1 {
151                self.changed = true;
152            }
153            if high > self.high_price {
154                self.high_price = high;
155                self.high_datetime = timestamp;
156            }
157            self.direction = 1;
158            self.since_high = 0;
159            self.since_low += 1;
160        } else if is_swing_low {
161            if self.direction == 1 {
162                self.changed = true;
163            }
164            if self.high_price == 0.0 {
165                self.high_price = max_high;
166                self.high_datetime = timestamp;
167            }
168            if low < self.low_price || self.low_price == 0.0 {
169                self.low_price = low;
170                self.low_datetime = timestamp;
171            }
172            self.direction = -1;
173            self.since_high += 1;
174            self.since_low = 0;
175        } else {
176            self.since_high += 1;
177            self.since_low += 1;
178        }
179
180        self.has_inputs = true;
181
182        if self.high_price != 0.0 && self.low_price != 0.0 {
183            self.initialized = true;
184            self.length = ((self.high_price - self.low_price).abs().round()) as usize;
185            if self.direction == 1 {
186                self.duration = self.since_low;
187            } else if self.direction == -1 {
188                self.duration = self.since_high;
189            } else {
190                self.duration = 0;
191            }
192        }
193    }
194}
195
////////////////////////////////////////////////////////////////////////////////
// Tests
////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;
    use crate::stubs::swings_10;

    /// Feeds parallel high/low/timestamp series into `swings`, one update per element.
    fn feed(swings: &mut Swings, highs: &[f64], lows: &[f64], times: &[f64]) {
        for ((&high, &low), &ts) in highs.iter().zip(lows).zip(times) {
            swings.update_raw(high, low, ts);
        }
    }

    /// Applies each `(high, low, timestamp, expected_changed)` step in order and
    /// checks the `changed` flag right after it.
    fn assert_changed_sequence(swings: &mut Swings, steps: &[(f64, f64, f64, bool)]) {
        for (idx, &(high, low, ts, expected)) in steps.iter().enumerate() {
            swings.update_raw(high, low, ts);
            assert_eq!(
                swings.changed, expected,
                "unexpected `changed` value after step {idx}"
            );
        }
    }

    #[rstest]
    fn test_name_returns_expected_string(swings_10: Swings) {
        assert_eq!(swings_10.name(), "Swings");
    }

    #[rstest]
    fn test_str_repr_returns_expected_string(swings_10: Swings) {
        assert_eq!(format!("{swings_10}"), "Swings(10)");
    }

    #[rstest]
    fn test_period_returns_expected_value(swings_10: Swings) {
        assert_eq!(swings_10.period, 10);
    }

    #[rstest]
    fn test_initialized_without_inputs_returns_false(swings_10: Swings) {
        assert!(!swings_10.initialized());
    }

    #[rstest]
    fn test_value_with_all_higher_inputs_returns_expected_value(mut swings_10: Swings) {
        let high = [
            0.9, 1.9, 2.9, 3.9, 4.9, 3.2, 6.9, 7.9, 8.9, 9.9, 1.1, 3.2, 10.3, 11.1, 11.4,
        ];
        let low = [
            0.8, 1.8, 2.8, 3.8, 4.8, 3.1, 6.8, 7.8, 0.8, 9.8, 1.0, 3.1, 10.2, 11.0, 11.3,
        ];
        let time = [
            1_643_723_400.0,
            1_643_723_410.0,
            1_643_723_420.0,
            1_643_723_430.0,
            1_643_723_440.0,
            1_643_723_450.0,
            1_643_723_460.0,
            1_643_723_470.0,
            1_643_723_480.0,
            1_643_723_490.0,
            1_643_723_500.0,
            1_643_723_510.0,
            1_643_723_520.0,
            1_643_723_530.0,
            1_643_723_540.0,
        ];

        feed(&mut swings_10, &high, &low, &time);

        assert_eq!(swings_10.direction, 1);
        assert_eq!(swings_10.high_price, 11.4);
        assert_eq!(swings_10.low_price, 0.0);
        assert_eq!(swings_10.high_datetime, time[14]);
        assert_eq!(swings_10.low_datetime, 0.0);
        assert_eq!(swings_10.length, 0);
        assert_eq!(swings_10.duration, 0);
        assert_eq!(swings_10.since_high, 0);
        assert_eq!(swings_10.since_low, 15);
    }

    #[rstest]
    fn test_reset_successfully_returns_indicator_to_fresh_state(mut swings_10: Swings) {
        let high = [1.0, 2.0, 3.0, 4.0, 5.0];
        let low = [0.9, 1.9, 2.9, 3.9, 4.9];
        let time = [
            1_643_723_400.0,
            1_643_723_410.0,
            1_643_723_420.0,
            1_643_723_430.0,
            1_643_723_440.0,
        ];

        feed(&mut swings_10, &high, &low, &time);

        swings_10.reset();

        assert!(!swings_10.initialized());
        assert_eq!(swings_10.direction, 0);
        assert_eq!(swings_10.high_price, 0.0);
        assert_eq!(swings_10.low_price, 0.0);
        assert_eq!(swings_10.high_datetime, 0.0);
        assert_eq!(swings_10.low_datetime, 0.0);
        assert_eq!(swings_10.length, 0);
        assert_eq!(swings_10.duration, 0);
        assert_eq!(swings_10.since_high, 0);
        assert_eq!(swings_10.since_low, 0);
        assert!(swings_10.high_inputs.is_empty());
        assert!(swings_10.low_inputs.is_empty());
    }

    #[rstest]
    fn test_changed_flag_flips() {
        let mut swings = Swings::new(2);

        assert_changed_sequence(
            &mut swings,
            &[
                (1.0, 0.5, 1.0, false),
                (2.0, 1.5, 2.0, false),
                (0.0, -1.0, 3.0, true),
                (-0.5, -1.5, 4.0, false),
            ],
        );
    }

    #[rstest]
    fn test_length_computation_after_initialization() {
        let mut swings = Swings::new(2);
        swings.update_raw(10.0, 9.0, 1.0);
        swings.update_raw(8.0, 7.0, 2.0);
        swings.update_raw(8.0, 7.5, 3.0);
        assert_eq!(swings.length, 3);
    }

    #[rstest]
    fn test_length_rounds_fractional_difference() {
        let mut swings = Swings::new(2);
        swings.update_raw(10.9, 10.7, 1.0);
        swings.update_raw(9.7, 9.4, 2.0);
        swings.update_raw(9.7, 9.4, 3.0);
        assert_eq!(swings.length, 2);
    }

    #[rstest]
    fn test_queue_eviction_does_not_exceed_capacity() {
        let period = 3;
        let mut swings = Swings::new(period);

        let highs = [1.0, 2.0, 3.0, 4.0, 5.0];
        let lows = [0.5, 1.5, 2.5, 3.5, 4.5];

        for (i, (&high, &low)) in highs.iter().zip(&lows).enumerate() {
            swings.update_raw(high, low, (i + 1) as f64);

            assert!(swings.high_inputs.len() <= period);
            assert!(swings.low_inputs.len() <= period);
        }

        assert_eq!(swings.high_inputs.len(), period);
        assert_eq!(swings.low_inputs.len(), period);
        assert_eq!(swings.high_inputs.front().copied(), Some(3.0));
        assert_eq!(swings.low_inputs.front().copied(), Some(2.5));
    }

    #[rstest]
    fn test_changed_flag_toggles_on_every_direction_flip() {
        let mut swings = Swings::new(2);

        assert_changed_sequence(
            &mut swings,
            &[
                (1.0, 0.7, 1.0, false),
                (2.0, 1.7, 2.0, false),
                (0.0, -1.0, 3.0, true),
                (-0.5, -1.5, 4.0, false),
                (2.5, 1.5, 5.0, true),
                (3.0, 2.0, 6.0, false),
            ],
        );
    }

    #[rstest]
    fn test_length_precision_rounding() {
        let mut swings = Swings::new(3);
        swings.update_raw(10.49, 9.9, 1.0);
        swings.update_raw(9.00, 8.0, 2.0);
        swings.update_raw(9.00, 8.0, 3.0);
        assert_eq!(swings.length, 2);

        swings.reset();
        swings.update_raw(10.5, 10.4, 10.0);
        swings.update_raw(8.0, 7.5, 20.0);
        swings.update_raw(8.0, 7.5, 30.0);
        assert_eq!(swings.length, 3);

        swings.reset();
        swings.update_raw(10.8, 10.6, 40.0);
        swings.update_raw(8.2, 7.4, 50.0);
        swings.update_raw(8.2, 7.4, 60.0);
        assert_eq!(swings.length, 3);
    }
}