nautilus_indicators/momentum/
amat.rs1use std::fmt::{Debug, Display};
17
18use arraydeque::{ArrayDeque, Wrapping};
19use nautilus_model::data::Bar;
20
21use crate::{
22 average::{MovingAverageFactory, MovingAverageType},
23 indicator::{Indicator, MovingAverage},
24};
25
26const DEFAULT_MA_TYPE: MovingAverageType = MovingAverageType::Exponential;
27const MAX_SIGNAL: usize = 1_024;
28
29type SignalBuf = ArrayDeque<f64, { MAX_SIGNAL + 1 }, Wrapping>;
30
31#[repr(C)]
32#[derive(Debug)]
33#[cfg_attr(
34 feature = "python",
35 pyo3::pyclass(module = "posei_trader.core.nautilus_pyo3.indicators", unsendable)
36)]
37pub struct ArcherMovingAveragesTrends {
38 pub fast_period: usize,
39 pub slow_period: usize,
40 pub signal_period: usize,
41 pub ma_type: MovingAverageType,
42 pub long_run: bool,
43 pub short_run: bool,
44 pub initialized: bool,
45 fast_ma: Box<dyn MovingAverage + Send + 'static>,
46 slow_ma: Box<dyn MovingAverage + Send + 'static>,
47 fast_ma_price: SignalBuf,
48 slow_ma_price: SignalBuf,
49 has_inputs: bool,
50}
51
52impl Display for ArcherMovingAveragesTrends {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 write!(
55 f,
56 "{}({},{},{},{})",
57 self.name(),
58 self.fast_period,
59 self.slow_period,
60 self.signal_period,
61 self.ma_type,
62 )
63 }
64}
65
66impl Indicator for ArcherMovingAveragesTrends {
67 fn name(&self) -> String {
68 stringify!(ArcherMovingAveragesTrends).into()
69 }
70
71 fn has_inputs(&self) -> bool {
72 self.has_inputs
73 }
74
75 fn initialized(&self) -> bool {
76 self.initialized
77 }
78
79 fn handle_bar(&mut self, bar: &Bar) {
80 self.update_raw(bar.close.into());
81 }
82
83 fn reset(&mut self) {
84 self.fast_ma.reset();
85 self.slow_ma.reset();
86 self.long_run = false;
87 self.short_run = false;
88 self.fast_ma_price.clear();
89 self.slow_ma_price.clear();
90 self.has_inputs = false;
91 self.initialized = false;
92 }
93}
94
95impl ArcherMovingAveragesTrends {
96 #[must_use]
105 pub fn new(
106 fast_period: usize,
107 slow_period: usize,
108 signal_period: usize,
109 ma_type: Option<MovingAverageType>,
110 ) -> Self {
111 assert!(
112 fast_period > 0,
113 "fast_period must be positive (got {fast_period})"
114 );
115 assert!(
116 slow_period > 0,
117 "slow_period must be positive (got {slow_period})"
118 );
119 assert!(
120 signal_period > 0,
121 "signal_period must be positive (got {signal_period})"
122 );
123 assert!(
124 slow_period > fast_period,
125 "slow_period ({slow_period}) must be greater than fast_period ({fast_period})"
126 );
127 assert!(
128 signal_period <= MAX_SIGNAL,
129 "signal_period ({signal_period}) must not exceed MAX_SIGNAL ({MAX_SIGNAL})"
130 );
131
132 let ma_type = ma_type.unwrap_or(DEFAULT_MA_TYPE);
133
134 Self {
135 fast_period,
136 slow_period,
137 signal_period,
138 ma_type,
139 long_run: false,
140 short_run: false,
141 fast_ma: MovingAverageFactory::create(ma_type, fast_period),
142 slow_ma: MovingAverageFactory::create(ma_type, slow_period),
143 fast_ma_price: SignalBuf::new(),
144 slow_ma_price: SignalBuf::new(),
145 has_inputs: false,
146 initialized: false,
147 }
148 }
149
150 pub fn update_raw(&mut self, close: f64) {
155 self.fast_ma.update_raw(close);
156 self.slow_ma.update_raw(close);
157
158 if self.slow_ma.initialized() {
159 self.fast_ma_price.push_back(self.fast_ma.value());
160 self.slow_ma_price.push_back(self.slow_ma.value());
161
162 let max_len = self.signal_period + 1;
163 if self.fast_ma_price.len() > max_len {
164 self.fast_ma_price.pop_front();
165 self.slow_ma_price.pop_front();
166 }
167
168 let fast_back = self.fast_ma.value();
169 let fast_front = *self
170 .fast_ma_price
171 .front()
172 .expect("buffer has at least one element");
173
174 let fast_diff = fast_back - fast_front;
175 self.long_run = fast_diff > 0.0 || self.long_run;
176 self.short_run = fast_diff < 0.0 || self.short_run;
177 }
178
179 if !self.initialized {
180 self.has_inputs = true;
181 let max_len = self.signal_period + 1;
182 if self.slow_ma_price.len() == max_len && self.slow_ma.initialized() {
183 self.initialized = true;
184 }
185 }
186 }
187}
188
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;
    use crate::stubs::amat_345;

    /// Builds an indicator with the default MA type and immediately discards it;
    /// used by the constructor panic tests.
    fn make(fast: usize, slow: usize, signal: usize) {
        drop(ArcherMovingAveragesTrends::new(fast, slow, signal, None));
    }

    /// Feeds `count` prices into `ind`, starting at `start` and advancing by `step`.
    fn feed_sequence(ind: &mut ArcherMovingAveragesTrends, start: i64, count: usize, step: i64) {
        for i in 0..count {
            let price = start + i as i64 * step;
            ind.update_raw(price as f64);
        }
    }

    #[rstest]
    fn default_ma_type_is_exponential() {
        let ind = ArcherMovingAveragesTrends::new(3, 4, 5, None);
        assert_eq!(ind.ma_type, MovingAverageType::Exponential);
    }

    #[rstest]
    fn ma_type_override_is_respected() {
        let ind = ArcherMovingAveragesTrends::new(3, 4, 5, Some(MovingAverageType::Simple));
        assert_eq!(ind.ma_type, MovingAverageType::Simple);
    }

    #[rstest]
    fn test_name_returns_expected_string(amat_345: ArcherMovingAveragesTrends) {
        assert_eq!(amat_345.name(), "ArcherMovingAveragesTrends");
    }

    #[rstest]
    fn test_str_repr_returns_expected_string(amat_345: ArcherMovingAveragesTrends) {
        let rendered = amat_345.to_string();
        assert_eq!(rendered, "ArcherMovingAveragesTrends(3,4,5,SIMPLE)");
    }

    #[rstest]
    fn test_period_returns_expected_value(amat_345: ArcherMovingAveragesTrends) {
        assert_eq!(amat_345.fast_period, 3);
        assert_eq!(amat_345.slow_period, 4);
        assert_eq!(amat_345.signal_period, 5);
    }

    #[rstest]
    fn test_initialized_without_inputs_returns_false(amat_345: ArcherMovingAveragesTrends) {
        assert!(!amat_345.initialized());
    }

    // --- constructor preconditions -------------------------------------------------

    #[rstest]
    #[should_panic(expected = "fast_period must be positive")]
    fn new_panics_on_zero_fast_period() {
        make(0, 4, 5);
    }

    #[rstest]
    #[should_panic(expected = "slow_period must be positive")]
    fn new_panics_on_zero_slow_period() {
        make(3, 0, 5);
    }

    #[rstest]
    #[should_panic(expected = "signal_period must be positive")]
    fn new_panics_on_zero_signal_period() {
        make(3, 5, 0);
    }

    #[rstest]
    #[should_panic(expected = "slow_period (3) must be greater than fast_period (3)")]
    fn new_panics_when_slow_not_greater_than_fast() {
        make(3, 3, 5);
    }

    #[rstest]
    #[should_panic(expected = "slow_period (2) must be greater than fast_period (3)")]
    fn new_panics_when_slow_less_than_fast() {
        make(3, 2, 5);
    }

    #[rstest]
    #[should_panic(expected = "signal_period (1025) must not exceed MAX_SIGNAL (1024)")]
    fn new_panics_when_signal_exceeds_max() {
        make(3, 4, MAX_SIGNAL + 1);
    }

    // --- update behaviour ----------------------------------------------------------

    #[rstest]
    fn buffer_len_never_exceeds_signal_plus_one() {
        let mut ind = ArcherMovingAveragesTrends::new(3, 4, 5, None);
        feed_sequence(&mut ind, 0, 100, 1);

        let expected_len = ind.signal_period + 1;
        assert_eq!(ind.fast_ma_price.len(), expected_len);
        assert_eq!(ind.slow_ma_price.len(), expected_len);
    }

    #[rstest]
    fn initialized_becomes_true_after_slow_ready_and_buffer_full() {
        let mut ind = ArcherMovingAveragesTrends::new(3, 4, 5, None);
        feed_sequence(&mut ind, 0, 11, 1);
        assert!(ind.initialized());
    }

    #[rstest]
    fn long_run_flag_sets_on_bullish_trend() {
        let mut ind = ArcherMovingAveragesTrends::new(3, 4, 5, None);
        feed_sequence(&mut ind, 0, 60, 1);
        assert!(ind.long_run, "Expected long_run=TRUE on up-trend");
        assert!(!ind.short_run, "short_run should remain FALSE here");
    }

    #[rstest]
    fn short_run_flag_sets_on_bearish_trend() {
        let mut ind = ArcherMovingAveragesTrends::new(3, 4, 5, None);
        feed_sequence(&mut ind, 100, 60, -1);
        assert!(ind.short_run, "Expected short_run=TRUE on down-trend");
        assert!(!ind.long_run, "long_run should remain FALSE here");
    }

    #[rstest]
    fn reset_clears_internal_state() {
        let mut ind = ArcherMovingAveragesTrends::new(3, 4, 5, None);
        feed_sequence(&mut ind, 0, 50, 1);
        assert!(ind.long_run || ind.short_run);
        assert!(!ind.fast_ma_price.is_empty());

        ind.reset();

        assert!(!ind.long_run);
        assert!(!ind.short_run);
        assert!(ind.fast_ma_price.is_empty());
        assert!(ind.slow_ma_price.is_empty());
        assert!(!ind.initialized());
        assert!(!ind.has_inputs());
    }
}