nautilus_network/python/
websocket.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Posei Systems Pty Ltd. All rights reserved.
3//  https://poseitrader.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{
17    sync::{Arc, atomic::Ordering},
18    time::Duration,
19};
20
21use nautilus_core::python::{to_pyruntime_err, to_pyvalue_err};
22use pyo3::{create_exception, exceptions::PyException, prelude::*};
23use tokio_tungstenite::tungstenite::{Message, Utf8Bytes};
24
25use crate::{
26    mode::ConnectionMode,
27    ratelimiter::quota::Quota,
28    websocket::{Consumer, WebSocketClient, WebSocketConfig, WriterCommand},
29};
30
// Python exception type `WebSocketClientError` (subclass of `Exception`), declared in the
// `network` Python module. Websocket errors from the Rust client are converted into this
// type before being raised on the Python side (see `to_websocket_pyerr`).
create_exception!(network, WebSocketClientError, PyException);
33
34fn to_websocket_pyerr(e: tokio_tungstenite::tungstenite::Error) -> PyErr {
35    PyErr::new::<WebSocketClientError, _>(e.to_string())
36}
37
38#[pymethods]
39impl WebSocketConfig {
40    #[new]
41    #[allow(clippy::too_many_arguments)]
42    #[pyo3(signature = (url, handler, headers, heartbeat=None, heartbeat_msg=None, ping_handler=None, reconnect_timeout_ms=10_000, reconnect_delay_initial_ms=2_000, reconnect_delay_max_ms=30_000, reconnect_backoff_factor=1.5, reconnect_jitter_ms=100))]
43    fn py_new(
44        url: String,
45        handler: PyObject,
46        headers: Vec<(String, String)>,
47        heartbeat: Option<u64>,
48        heartbeat_msg: Option<String>,
49        ping_handler: Option<PyObject>,
50        reconnect_timeout_ms: Option<u64>,
51        reconnect_delay_initial_ms: Option<u64>,
52        reconnect_delay_max_ms: Option<u64>,
53        reconnect_backoff_factor: Option<f64>,
54        reconnect_jitter_ms: Option<u64>,
55    ) -> Self {
56        Self {
57            url,
58            handler: Consumer::Python(Some(Arc::new(handler))),
59            headers,
60            heartbeat,
61            heartbeat_msg,
62            ping_handler: ping_handler.map(Arc::new),
63            reconnect_timeout_ms,
64            reconnect_delay_initial_ms,
65            reconnect_delay_max_ms,
66            reconnect_backoff_factor,
67            reconnect_jitter_ms,
68        }
69    }
70}
71
72#[pymethods]
73impl WebSocketClient {
74    /// Create a websocket client.
75    ///
76    /// # Safety
77    ///
78    /// - Throws an Exception if it is unable to make websocket connection.
79    #[staticmethod]
80    #[pyo3(name = "connect", signature = (config, post_connection= None, post_reconnection= None, post_disconnection= None, keyed_quotas = Vec::new(), default_quota = None))]
81    fn py_connect(
82        config: WebSocketConfig,
83        post_connection: Option<PyObject>,
84        post_reconnection: Option<PyObject>,
85        post_disconnection: Option<PyObject>,
86        keyed_quotas: Vec<(String, Quota)>,
87        default_quota: Option<Quota>,
88        py: Python<'_>,
89    ) -> PyResult<Bound<'_, PyAny>> {
90        pyo3_async_runtimes::tokio::future_into_py(py, async move {
91            Self::connect(
92                config,
93                post_connection,
94                post_reconnection,
95                post_disconnection,
96                keyed_quotas,
97                default_quota,
98            )
99            .await
100            .map_err(to_websocket_pyerr)
101        })
102    }
103
104    /// Closes the client heart beat and reader task.
105    ///
106    /// The connection is not completely closed the till all references
107    /// to the client are gone and the client is dropped.
108    ///
109    /// # Safety
110    ///
111    /// - The client should not be used after closing it.
112    /// - Any auto-reconnect job should be aborted before closing the client.
113    #[pyo3(name = "disconnect")]
114    fn py_disconnect<'py>(slf: PyRef<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
115        let connection_mode = slf.connection_mode.clone();
116        let mode = ConnectionMode::from_atomic(&connection_mode);
117        tracing::debug!("Close from mode {mode}");
118
119        pyo3_async_runtimes::tokio::future_into_py(py, async move {
120            match ConnectionMode::from_atomic(&connection_mode) {
121                ConnectionMode::Closed => {
122                    tracing::warn!("WebSocket already closed");
123                }
124                ConnectionMode::Disconnect => {
125                    tracing::warn!("WebSocket already disconnecting");
126                }
127                _ => {
128                    connection_mode.store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
129                    while !ConnectionMode::from_atomic(&connection_mode).is_closed() {
130                        tokio::time::sleep(Duration::from_millis(10)).await;
131                    }
132                }
133            }
134
135            Ok(())
136        })
137    }
138
139    /// Check if the client is still alive.
140    ///
141    /// Even if the connection is disconnected the client will still be alive
142    /// and trying to reconnect.
143    ///
144    /// This is particularly useful for checking why a `send` failed. It could
145    /// be because the connection disconnected and the client is still alive
146    /// and reconnecting. In such cases the send can be retried after some
147    /// delay.
148    #[pyo3(name = "is_active")]
149    fn py_is_active(slf: PyRef<'_, Self>) -> bool {
150        !slf.controller_task.is_finished()
151    }
152
153    #[pyo3(name = "is_reconnecting")]
154    fn py_is_reconnecting(slf: PyRef<'_, Self>) -> bool {
155        slf.is_reconnecting()
156    }
157
158    #[pyo3(name = "is_disconnecting")]
159    fn py_is_disconnecting(slf: PyRef<'_, Self>) -> bool {
160        slf.is_disconnecting()
161    }
162
163    #[pyo3(name = "is_closed")]
164    fn py_is_closed(slf: PyRef<'_, Self>) -> bool {
165        slf.is_closed()
166    }
167
168    /// Send bytes data to the server.
169    ///
170    /// # Errors
171    ///
172    /// - Raises `PyRuntimeError` if not able to send data.
173    #[pyo3(name = "send")]
174    #[pyo3(signature = (data, keys=None))]
175    fn py_send<'py>(
176        slf: PyRef<'_, Self>,
177        data: Vec<u8>,
178        py: Python<'py>,
179        keys: Option<Vec<String>>,
180    ) -> PyResult<Bound<'py, PyAny>> {
181        let rate_limiter = slf.rate_limiter.clone();
182        let writer_tx = slf.writer_tx.clone();
183        let mode = slf.connection_mode.clone();
184
185        pyo3_async_runtimes::tokio::future_into_py(py, async move {
186            if !ConnectionMode::from_atomic(&mode).is_active() {
187                let msg = "Cannot send data: connection not active".to_string();
188                tracing::error!("{msg}");
189                return Err(to_pyruntime_err(std::io::Error::new(
190                    std::io::ErrorKind::NotConnected,
191                    msg,
192                )));
193            }
194            rate_limiter.await_keys_ready(keys).await;
195            tracing::trace!("Sending binary: {data:?}");
196
197            let msg = Message::Binary(data.into());
198            writer_tx
199                .send(WriterCommand::Send(msg))
200                .map_err(to_pyruntime_err)
201        })
202    }
203
204    /// Send UTF-8 encoded bytes as text data to the server, respecting rate limits.
205    ///
206    /// `data`: The byte data to be sent, which will be converted to a UTF-8 string.
207    /// `keys`: Optional list of rate limit keys. If provided, the function will wait for rate limits to be met for each key before sending the data.
208    ///
209    /// # Errors
210    /// - Raises `PyRuntimeError` if unable to send the data.
211    ///
212    /// # Example
213    ///
214    /// When a request is made the URL should be split into all relevant keys within it.
215    ///
216    /// For request /foo/bar, should pass keys ["foo/bar", "foo"] for rate limiting.
217    #[pyo3(name = "send_text")]
218    #[pyo3(signature = (data, keys=None))]
219    fn py_send_text<'py>(
220        slf: PyRef<'_, Self>,
221        data: Vec<u8>,
222        py: Python<'py>,
223        keys: Option<Vec<String>>,
224    ) -> PyResult<Bound<'py, PyAny>> {
225        let data_str = String::from_utf8(data).map_err(to_pyvalue_err)?;
226        let data = Utf8Bytes::from(data_str);
227        let rate_limiter = slf.rate_limiter.clone();
228        let writer_tx = slf.writer_tx.clone();
229        let mode = slf.connection_mode.clone();
230
231        pyo3_async_runtimes::tokio::future_into_py(py, async move {
232            if !ConnectionMode::from_atomic(&mode).is_active() {
233                let err = std::io::Error::new(
234                    std::io::ErrorKind::NotConnected,
235                    "Cannot send text: connection not active",
236                );
237                return Err(to_pyruntime_err(err));
238            }
239            rate_limiter.await_keys_ready(keys).await;
240            tracing::trace!("Sending text: {data}");
241
242            let msg = Message::Text(data);
243            writer_tx
244                .send(WriterCommand::Send(msg))
245                .map_err(to_pyruntime_err)
246        })
247    }
248
249    /// Send pong bytes data to the server.
250    ///
251    /// # Errors
252    ///
253    /// - Raises `PyRuntimeError` if not able to send data.
254    #[pyo3(name = "send_pong")]
255    fn py_send_pong<'py>(
256        slf: PyRef<'_, Self>,
257        data: Vec<u8>,
258        py: Python<'py>,
259    ) -> PyResult<Bound<'py, PyAny>> {
260        let data_str = String::from_utf8(data.clone()).map_err(to_pyvalue_err)?;
261        let writer_tx = slf.writer_tx.clone();
262        let mode = slf.connection_mode.clone();
263
264        pyo3_async_runtimes::tokio::future_into_py(py, async move {
265            if !ConnectionMode::from_atomic(&mode).is_active() {
266                let err = std::io::Error::new(
267                    std::io::ErrorKind::NotConnected,
268                    "Cannot send pong: connection not active",
269                );
270                return Err(to_pyruntime_err(err));
271            }
272            tracing::trace!("Sending pong: {data_str}");
273
274            let msg = Message::Pong(data.into());
275            writer_tx
276                .send(WriterCommand::Send(msg))
277                .map_err(to_pyruntime_err)
278        })
279    }
280}
281
////////////////////////////////////////////////////////////////////////////////
// Tests
////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
#[cfg(target_os = "linux")] // Only run network tests on Linux (CI stability)
mod tests {
    use std::ffi::CString;

    use futures_util::{SinkExt, StreamExt};
    use nautilus_core::python::IntoPyObjectPoseiExt;
    use pyo3::{prelude::*, prepare_freethreaded_python};
    use tokio::{
        net::TcpListener,
        task::{self, JoinHandle},
        time::{Duration, sleep},
    };
    use tokio_tungstenite::{
        accept_hdr_async,
        tungstenite::{
            Message,
            handshake::server::{self, Callback},
            http::HeaderValue,
        },
    };
    use tracing_test::traced_test;

    use crate::websocket::{WebSocketClient, WebSocketConfig};

    /// Local echo server; its accept loop is aborted when the value is dropped.
    struct TestServer {
        task: JoinHandle<()>,
        port: u16,
    }

    /// Handshake callback asserting that the client sent header `key` with value `value`.
    #[derive(Debug, Clone)]
    struct TestCallback {
        key: String,
        value: HeaderValue,
    }

    impl Callback for TestCallback {
        fn on_request(
            self,
            request: &server::Request,
            response: server::Response,
        ) -> Result<server::Response, server::ErrorResponse> {
            // The client must forward the custom header configured in `WebSocketConfig`.
            let Some(value) = request.headers().get(&self.key) else {
                panic!("missing expected header `{}`", self.key);
            };
            assert_eq!(value, self.value);

            Ok(response)
        }
    }

    impl TestServer {
        /// Binds an ephemeral port and echoes text/binary frames on every accepted connection.
        ///
        /// A text frame of `"close-now"` makes the server close that connection.
        async fn setup(key: String, value: String) -> Self {
            let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let port = server.local_addr().unwrap().port();

            let test_call_back = TestCallback {
                key,
                value: HeaderValue::from_str(&value).unwrap(),
            };

            // Set up test server
            let task = task::spawn(async move {
                // Keep accepting connections
                loop {
                    let (conn, _) = server.accept().await.unwrap();
                    let mut websocket = accept_hdr_async(conn, test_call_back.clone())
                        .await
                        .unwrap();

                    task::spawn(async move {
                        while let Some(Ok(msg)) = websocket.next().await {
                            match msg {
                                Message::Text(txt) if txt == "close-now" => {
                                    tracing::debug!("Forcibly closing from server side");
                                    // This sends a close frame, then stops reading
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                // Echo text/binary frames
                                Message::Text(_) | Message::Binary(_) => {
                                    if websocket.send(msg).await.is_err() {
                                        break;
                                    }
                                }
                                // If the client closes, we also break
                                Message::Close(_frame) => {
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                // Ignore pings/pongs
                                _ => {}
                            }
                        }
                    });
                }
            });

            Self { task, port }
        }
    }

    impl Drop for TestServer {
        fn drop(&mut self) {
            self.task.abort();
        }
    }

    /// Builds a Python `Counter` and returns `(counter, counter.handler)`.
    ///
    /// The handler counts `b"ping"` messages and flags receipt of `b"heartbeat message"`.
    fn create_test_handler() -> (PyObject, PyObject) {
        let code_raw = r"
class Counter:
    def __init__(self):
        self.count = 0
        self.check = False

    def handler(self, bytes):
        msg = bytes.decode()
        if msg == 'ping':
            self.count += 1
        elif msg == 'heartbeat message':
            self.check = True

    def get_check(self):
        return self.check

    def get_count(self):
        return self.count

counter = Counter()
";

        let code = CString::new(code_raw).unwrap();
        let filename = CString::new("test").unwrap();
        let module = CString::new("test").unwrap();
        Python::with_gil(|py| {
            let pymod = PyModule::from_code(py, &code, &filename, &module).unwrap();

            let counter = pymod.getattr("counter").unwrap().into_py_any_unwrap(py);
            let handler = counter
                .getattr(py, "handler")
                .unwrap()
                .into_py_any_unwrap(py);

            (counter, handler)
        })
    }

    /// Calls the zero-argument Python method `name` on `obj` and extracts the result as `T`.
    fn call_getter<T>(obj: &PyObject, name: &str) -> T
    where
        T: for<'py> FromPyObject<'py>,
    {
        Python::with_gil(|py| {
            obj.getattr(py, name)
                .unwrap()
                .call0(py)
                .unwrap()
                .extract(py)
                .unwrap()
        })
    }

    #[tokio::test]
    #[traced_test]
    async fn basic_client_test() {
        prepare_freethreaded_python();

        const N: usize = 10;
        let mut success_count = 0;
        let header_key = "hello-custom-key".to_string();
        let header_value = "hello-custom-value".to_string();

        let server = TestServer::setup(header_key.clone(), header_value.clone()).await;
        let (counter, handler) = create_test_handler();

        let config = WebSocketConfig::py_new(
            format!("ws://127.0.0.1:{}", server.port),
            handler,
            vec![(header_key, header_value)],
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        );
        let client = WebSocketClient::connect(config, None, None, None, Vec::new(), None)
            .await
            .unwrap();

        // Send messages; the server echoes them back and the handler counts them
        for _ in 0..N {
            client.send_bytes(b"ping".to_vec(), None).await.unwrap();
            success_count += 1;
        }

        // Check count is same as number messages sent
        sleep(Duration::from_secs(1)).await;
        let count_value: usize = call_getter(&counter, "get_count");
        assert_eq!(count_value, success_count);

        // Close the connection => client should reconnect automatically
        client.send_close_message().await.unwrap();

        // Send messages that increment the count
        sleep(Duration::from_secs(2)).await;
        for _ in 0..N {
            client.send_bytes(b"ping".to_vec(), None).await.unwrap();
            success_count += 1;
        }

        // Check count is same as number messages sent
        sleep(Duration::from_secs(1)).await;
        let count_value: usize = call_getter(&counter, "get_count");
        assert_eq!(count_value, success_count);
        assert_eq!(success_count, N + N);

        // Cleanup
        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    #[traced_test]
    async fn message_ping_test() {
        prepare_freethreaded_python();

        let header_key = "hello-custom-key".to_string();
        let header_value = "hello-custom-value".to_string();

        let (checker, handler) = create_test_handler();

        // Initialize test server and config
        let server = TestServer::setup(header_key.clone(), header_value.clone()).await;
        let config = WebSocketConfig::py_new(
            format!("ws://127.0.0.1:{}", server.port),
            handler,
            vec![(header_key, header_value)],
            Some(1),
            Some("heartbeat message".to_string()),
            None,
            None,
            None,
            None,
            None,
            None,
        );
        let client = WebSocketClient::connect(config, None, None, None, Vec::new(), None)
            .await
            .unwrap();

        // Check if ping message has the correct message
        sleep(Duration::from_secs(2)).await;
        let check_value: bool = call_getter(&checker, "get_check");
        assert!(check_value);

        // Cleanup
        client.disconnect().await;
        assert!(client.is_disconnected());
    }
}