nautilus_common/msgbus/
stubs.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Posei Systems Pty Ltd. All rights reserved.
3//  https://poseitrader.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{
17    any::Any,
18    cell::RefCell,
19    fmt::Debug,
20    rc::Rc,
21    sync::{
22        Arc,
23        atomic::{AtomicBool, Ordering},
24    },
25};
26
27use nautilus_core::message::Message;
28use ustr::Ustr;
29use uuid::Uuid;
30
31use crate::msgbus::{ShareableMessageHandler, handler::MessageHandler};
32
33// Stub message handler which logs the data it receives
34pub struct StubMessageHandler {
35    id: Ustr,
36    callback: Arc<dyn Fn(Message) + Send>,
37}
38
39impl Debug for StubMessageHandler {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        f.debug_struct(stringify!(StubMessageHandler))
42            .field("id", &self.id)
43            .finish()
44    }
45}
46
47impl MessageHandler for StubMessageHandler {
48    fn id(&self) -> Ustr {
49        self.id
50    }
51
52    fn handle(&self, message: &dyn Any) {
53        (self.callback)(message.downcast_ref::<Message>().unwrap().clone());
54    }
55
56    fn as_any(&self) -> &dyn Any {
57        self
58    }
59}
60
61#[must_use]
62#[allow(unused_must_use)] // TODO: Temporary to fix docs build
63pub fn get_stub_shareable_handler(id: Option<Ustr>) -> ShareableMessageHandler {
64    // TODO: This reduces the need to come up with ID strings in tests.
65    // In Python we do something like `hash((self.topic, str(self.handler)))` for the hash
66    // which includes the memory address, just went with a UUID4 here.
67    let unique_id = id.unwrap_or_else(|| Ustr::from(&Uuid::new_v4().to_string()));
68    ShareableMessageHandler(Rc::new(StubMessageHandler {
69        id: unique_id,
70        callback: Arc::new(|m: Message| {
71            format!("{m:?}");
72        }),
73    }))
74}
75
76// Stub message handler which checks if handle was called
77#[derive(Debug)]
78pub struct CallCheckMessageHandler {
79    id: Ustr,
80    called: Arc<AtomicBool>,
81}
82
83impl CallCheckMessageHandler {
84    #[must_use]
85    pub fn was_called(&self) -> bool {
86        self.called.load(Ordering::SeqCst)
87    }
88}
89
90impl MessageHandler for CallCheckMessageHandler {
91    fn id(&self) -> Ustr {
92        self.id
93    }
94
95    fn handle(&self, _message: &dyn Any) {
96        self.called.store(true, Ordering::SeqCst);
97    }
98
99    fn as_any(&self) -> &dyn Any {
100        self
101    }
102}
103
104#[must_use]
105pub fn get_call_check_shareable_handler(id: Option<Ustr>) -> ShareableMessageHandler {
106    // TODO: This reduces the need to come up with ID strings in tests.
107    // In Python we do something like `hash((self.topic, str(self.handler)))` for the hash
108    // which includes the memory address, just went with a UUID4 here.
109    let unique_id = id.unwrap_or_else(|| Ustr::from(&Uuid::new_v4().to_string()));
110    ShareableMessageHandler(Rc::new(CallCheckMessageHandler {
111        id: unique_id,
112        called: Arc::new(AtomicBool::new(false)),
113    }))
114}
115
116/// Returns whether the given `CallCheckMessageHandler` has been invoked at least once.
117///
118/// # Panics
119///
120/// Panics if the provided `handler` is not a `CallCheckMessageHandler`.
121#[must_use]
122pub fn check_handler_was_called(call_check_handler: ShareableMessageHandler) -> bool {
123    call_check_handler
124        .0
125        .as_ref()
126        .as_any()
127        .downcast_ref::<CallCheckMessageHandler>()
128        .unwrap()
129        .was_called()
130}
131
132// Handler which saves the messages it receives
133#[derive(Debug, Clone)]
134pub struct MessageSavingHandler<T> {
135    id: Ustr,
136    messages: Rc<RefCell<Vec<T>>>,
137}
138
139impl<T: Clone + 'static> MessageSavingHandler<T> {
140    #[must_use]
141    pub fn get_messages(&self) -> Vec<T> {
142        self.messages.borrow().clone()
143    }
144}
145
146impl<T: Clone + 'static> MessageHandler for MessageSavingHandler<T> {
147    fn id(&self) -> Ustr {
148        self.id
149    }
150
151    /// Handles an incoming message by saving it.
152    ///
153    /// # Panics
154    ///
155    /// Panics if the provided `message` is not of the expected type `T`.
156    fn handle(&self, message: &dyn Any) {
157        let mut messages = self.messages.borrow_mut();
158        match message.downcast_ref::<T>() {
159            Some(m) => messages.push(m.clone()),
160            None => panic!("MessageSavingHandler: message type mismatch {message:?}"),
161        }
162    }
163
164    fn as_any(&self) -> &dyn Any {
165        self
166    }
167}
168
169#[must_use]
170pub fn get_message_saving_handler<T: Clone + 'static>(id: Option<Ustr>) -> ShareableMessageHandler {
171    // TODO: This reduces the need to come up with ID strings in tests.
172    // In Python we do something like `hash((self.topic, str(self.handler)))` for the hash
173    // which includes the memory address, just went with a UUID4 here.
174    let unique_id = id.unwrap_or_else(|| Ustr::from(&Uuid::new_v4().to_string()));
175    ShareableMessageHandler(Rc::new(MessageSavingHandler::<T> {
176        id: unique_id,
177        messages: Rc::new(RefCell::new(Vec::new())),
178    }))
179}
180
181/// Retrieves the messages saved by a [`MessageSavingHandler`].
182///
183/// # Panics
184///
185/// Panics if the provided `handler` is not a `MessageSavingHandler<T>`.
186#[must_use]
187pub fn get_saved_messages<T: Clone + 'static>(handler: ShareableMessageHandler) -> Vec<T> {
188    handler
189        .0
190        .as_ref()
191        .as_any()
192        .downcast_ref::<MessageSavingHandler<T>>()
193        .unwrap()
194        .get_messages()
195}