1use std::{
17 sync::{Arc, atomic::Ordering},
18 time::Duration,
19};
20
21use nautilus_core::python::{to_pyruntime_err, to_pyvalue_err};
22use pyo3::{create_exception, exceptions::PyException, prelude::*};
23use tokio_tungstenite::tungstenite::{Message, Utf8Bytes};
24
25use crate::{
26 mode::ConnectionMode,
27 ratelimiter::quota::Quota,
28 websocket::{Consumer, WebSocketClient, WebSocketConfig, WriterCommand},
29};
30
31create_exception!(network, WebSocketClientError, PyException);
33
34fn to_websocket_pyerr(e: tokio_tungstenite::tungstenite::Error) -> PyErr {
35 PyErr::new::<WebSocketClientError, _>(e.to_string())
36}
37
38#[pymethods]
39impl WebSocketConfig {
40 #[new]
41 #[allow(clippy::too_many_arguments)]
42 #[pyo3(signature = (url, handler, headers, heartbeat=None, heartbeat_msg=None, ping_handler=None, reconnect_timeout_ms=10_000, reconnect_delay_initial_ms=2_000, reconnect_delay_max_ms=30_000, reconnect_backoff_factor=1.5, reconnect_jitter_ms=100))]
43 fn py_new(
44 url: String,
45 handler: PyObject,
46 headers: Vec<(String, String)>,
47 heartbeat: Option<u64>,
48 heartbeat_msg: Option<String>,
49 ping_handler: Option<PyObject>,
50 reconnect_timeout_ms: Option<u64>,
51 reconnect_delay_initial_ms: Option<u64>,
52 reconnect_delay_max_ms: Option<u64>,
53 reconnect_backoff_factor: Option<f64>,
54 reconnect_jitter_ms: Option<u64>,
55 ) -> Self {
56 Self {
57 url,
58 handler: Consumer::Python(Some(Arc::new(handler))),
59 headers,
60 heartbeat,
61 heartbeat_msg,
62 ping_handler: ping_handler.map(Arc::new),
63 reconnect_timeout_ms,
64 reconnect_delay_initial_ms,
65 reconnect_delay_max_ms,
66 reconnect_backoff_factor,
67 reconnect_jitter_ms,
68 }
69 }
70}
71
72#[pymethods]
73impl WebSocketClient {
74 #[staticmethod]
80 #[pyo3(name = "connect", signature = (config, post_connection= None, post_reconnection= None, post_disconnection= None, keyed_quotas = Vec::new(), default_quota = None))]
81 fn py_connect(
82 config: WebSocketConfig,
83 post_connection: Option<PyObject>,
84 post_reconnection: Option<PyObject>,
85 post_disconnection: Option<PyObject>,
86 keyed_quotas: Vec<(String, Quota)>,
87 default_quota: Option<Quota>,
88 py: Python<'_>,
89 ) -> PyResult<Bound<'_, PyAny>> {
90 pyo3_async_runtimes::tokio::future_into_py(py, async move {
91 Self::connect(
92 config,
93 post_connection,
94 post_reconnection,
95 post_disconnection,
96 keyed_quotas,
97 default_quota,
98 )
99 .await
100 .map_err(to_websocket_pyerr)
101 })
102 }
103
104 #[pyo3(name = "disconnect")]
114 fn py_disconnect<'py>(slf: PyRef<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
115 let connection_mode = slf.connection_mode.clone();
116 let mode = ConnectionMode::from_atomic(&connection_mode);
117 tracing::debug!("Close from mode {mode}");
118
119 pyo3_async_runtimes::tokio::future_into_py(py, async move {
120 match ConnectionMode::from_atomic(&connection_mode) {
121 ConnectionMode::Closed => {
122 tracing::warn!("WebSocket already closed");
123 }
124 ConnectionMode::Disconnect => {
125 tracing::warn!("WebSocket already disconnecting");
126 }
127 _ => {
128 connection_mode.store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
129 while !ConnectionMode::from_atomic(&connection_mode).is_closed() {
130 tokio::time::sleep(Duration::from_millis(10)).await;
131 }
132 }
133 }
134
135 Ok(())
136 })
137 }
138
139 #[pyo3(name = "is_active")]
149 fn py_is_active(slf: PyRef<'_, Self>) -> bool {
150 !slf.controller_task.is_finished()
151 }
152
153 #[pyo3(name = "is_reconnecting")]
154 fn py_is_reconnecting(slf: PyRef<'_, Self>) -> bool {
155 slf.is_reconnecting()
156 }
157
158 #[pyo3(name = "is_disconnecting")]
159 fn py_is_disconnecting(slf: PyRef<'_, Self>) -> bool {
160 slf.is_disconnecting()
161 }
162
163 #[pyo3(name = "is_closed")]
164 fn py_is_closed(slf: PyRef<'_, Self>) -> bool {
165 slf.is_closed()
166 }
167
168 #[pyo3(name = "send")]
174 #[pyo3(signature = (data, keys=None))]
175 fn py_send<'py>(
176 slf: PyRef<'_, Self>,
177 data: Vec<u8>,
178 py: Python<'py>,
179 keys: Option<Vec<String>>,
180 ) -> PyResult<Bound<'py, PyAny>> {
181 let rate_limiter = slf.rate_limiter.clone();
182 let writer_tx = slf.writer_tx.clone();
183 let mode = slf.connection_mode.clone();
184
185 pyo3_async_runtimes::tokio::future_into_py(py, async move {
186 if !ConnectionMode::from_atomic(&mode).is_active() {
187 let msg = "Cannot send data: connection not active".to_string();
188 tracing::error!("{msg}");
189 return Err(to_pyruntime_err(std::io::Error::new(
190 std::io::ErrorKind::NotConnected,
191 msg,
192 )));
193 }
194 rate_limiter.await_keys_ready(keys).await;
195 tracing::trace!("Sending binary: {data:?}");
196
197 let msg = Message::Binary(data.into());
198 writer_tx
199 .send(WriterCommand::Send(msg))
200 .map_err(to_pyruntime_err)
201 })
202 }
203
204 #[pyo3(name = "send_text")]
218 #[pyo3(signature = (data, keys=None))]
219 fn py_send_text<'py>(
220 slf: PyRef<'_, Self>,
221 data: Vec<u8>,
222 py: Python<'py>,
223 keys: Option<Vec<String>>,
224 ) -> PyResult<Bound<'py, PyAny>> {
225 let data_str = String::from_utf8(data).map_err(to_pyvalue_err)?;
226 let data = Utf8Bytes::from(data_str);
227 let rate_limiter = slf.rate_limiter.clone();
228 let writer_tx = slf.writer_tx.clone();
229 let mode = slf.connection_mode.clone();
230
231 pyo3_async_runtimes::tokio::future_into_py(py, async move {
232 if !ConnectionMode::from_atomic(&mode).is_active() {
233 let err = std::io::Error::new(
234 std::io::ErrorKind::NotConnected,
235 "Cannot send text: connection not active",
236 );
237 return Err(to_pyruntime_err(err));
238 }
239 rate_limiter.await_keys_ready(keys).await;
240 tracing::trace!("Sending text: {data}");
241
242 let msg = Message::Text(data);
243 writer_tx
244 .send(WriterCommand::Send(msg))
245 .map_err(to_pyruntime_err)
246 })
247 }
248
249 #[pyo3(name = "send_pong")]
255 fn py_send_pong<'py>(
256 slf: PyRef<'_, Self>,
257 data: Vec<u8>,
258 py: Python<'py>,
259 ) -> PyResult<Bound<'py, PyAny>> {
260 let data_str = String::from_utf8(data.clone()).map_err(to_pyvalue_err)?;
261 let writer_tx = slf.writer_tx.clone();
262 let mode = slf.connection_mode.clone();
263
264 pyo3_async_runtimes::tokio::future_into_py(py, async move {
265 if !ConnectionMode::from_atomic(&mode).is_active() {
266 let err = std::io::Error::new(
267 std::io::ErrorKind::NotConnected,
268 "Cannot send pong: connection not active",
269 );
270 return Err(to_pyruntime_err(err));
271 }
272 tracing::trace!("Sending pong: {data_str}");
273
274 let msg = Message::Pong(data.into());
275 writer_tx
276 .send(WriterCommand::Send(msg))
277 .map_err(to_pyruntime_err)
278 })
279 }
280}
281
#[cfg(test)]
#[cfg(target_os = "linux")]
mod tests {
    use std::ffi::CString;

    use futures_util::{SinkExt, StreamExt};
    use nautilus_core::python::IntoPyObjectPoseiExt;
    use pyo3::{prelude::*, prepare_freethreaded_python};
    use tokio::{
        net::TcpListener,
        task::{self, JoinHandle},
        time::{Duration, sleep},
    };
    use tokio_tungstenite::{
        accept_hdr_async,
        tungstenite::{
            handshake::server::{self, Callback},
            http::HeaderValue,
            protocol::Message,
        },
    };
    use tracing_test::traced_test;

    use crate::websocket::{WebSocketClient, WebSocketConfig};

    /// Local echo server for the tests. Its accept loop is aborted on drop.
    struct TestServer {
        task: JoinHandle<()>,
        port: u16,
    }

    /// Handshake callback that asserts the client sent header `key` with `value`.
    #[derive(Debug, Clone)]
    struct TestCallback {
        key: String,
        value: HeaderValue,
    }

    impl Callback for TestCallback {
        fn on_request(
            self,
            request: &server::Request,
            response: server::Response,
        ) -> Result<server::Response, server::ErrorResponse> {
            let value = request
                .headers()
                .get(&self.key)
                .unwrap_or_else(|| panic!("missing expected header `{}`", self.key));
            assert_eq!(value, self.value);

            Ok(response)
        }
    }

    impl TestServer {
        /// Binds an ephemeral port on 127.0.0.1 and echoes text and binary frames
        /// on every accepted connection. A `close-now` text frame makes the
        /// server close that connection, and a client `Close` frame is answered
        /// with a close.
        async fn setup(key: String, value: String) -> Self {
            let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let port = server.local_addr().unwrap().port();

            let callback = TestCallback {
                key,
                value: HeaderValue::from_str(&value).unwrap(),
            };

            let task = task::spawn(async move {
                loop {
                    let (conn, _) = server.accept().await.unwrap();
                    let mut websocket = accept_hdr_async(conn, callback.clone())
                        .await
                        .unwrap();

                    // One task per connection, so the accept loop keeps serving
                    // reconnects.
                    task::spawn(async move {
                        while let Some(Ok(msg)) = websocket.next().await {
                            match msg {
                                Message::Text(txt) if txt == "close-now" => {
                                    tracing::debug!("Forcibly closing from server side");
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                Message::Text(_) | Message::Binary(_) => {
                                    if websocket.send(msg).await.is_err() {
                                        break;
                                    }
                                }
                                Message::Close(_frame) => {
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                _ => {}
                            }
                        }
                    });
                }
            });

            Self { task, port }
        }
    }

    impl Drop for TestServer {
        fn drop(&mut self) {
            self.task.abort();
        }
    }

    /// Creates a Python `Counter` and returns `(counter, counter.handler)`.
    ///
    /// The handler increments `count` for `ping` messages and sets `check` on
    /// `heartbeat message`.
    fn create_test_handler() -> (PyObject, PyObject) {
        let code_raw = r"
class Counter:
    def __init__(self):
        self.count = 0
        self.check = False

    def handler(self, bytes):
        msg = bytes.decode()
        if msg == 'ping':
            self.count += 1
        elif msg == 'heartbeat message':
            self.check = True

    def get_check(self):
        return self.check

    def get_count(self):
        return self.count

counter = Counter()
";

        let code = CString::new(code_raw).unwrap();
        let filename = CString::new("test").unwrap();
        let module = CString::new("test").unwrap();
        Python::with_gil(|py| {
            let pymod = PyModule::from_code(py, &code, &filename, &module).unwrap();

            let counter = pymod.getattr("counter").unwrap().into_py_any_unwrap(py);
            let handler = counter
                .getattr(py, "handler")
                .unwrap()
                .into_py_any_unwrap(py);

            (counter, handler)
        })
    }

    /// Returns `counter.get_count()` from the Python side.
    fn get_count(counter: &PyObject) -> usize {
        Python::with_gil(|py| {
            counter
                .getattr(py, "get_count")
                .unwrap()
                .call0(py)
                .unwrap()
                .extract(py)
                .unwrap()
        })
    }

    /// Returns `counter.get_check()` from the Python side.
    fn get_check(counter: &PyObject) -> bool {
        Python::with_gil(|py| {
            counter
                .getattr(py, "get_check")
                .unwrap()
                .call0(py)
                .unwrap()
                .extract(py)
                .unwrap()
        })
    }

    /// Builds a config for the local test server on `port`. Every ping-handler
    /// and reconnect setting is `None`.
    fn test_config(
        port: u16,
        handler: &PyObject,
        headers: Vec<(String, String)>,
        heartbeat: Option<u64>,
        heartbeat_msg: Option<String>,
    ) -> WebSocketConfig {
        WebSocketConfig::py_new(
            format!("ws://127.0.0.1:{port}"),
            Python::with_gil(|py| handler.clone_ref(py)),
            headers,
            heartbeat,
            heartbeat_msg,
            None,
            None,
            None,
            None,
            None,
            None,
        )
    }

    #[tokio::test]
    #[traced_test]
    async fn basic_client_test() {
        prepare_freethreaded_python();

        const N: usize = 10;
        let mut success_count = 0;
        let header_key = "hello-custom-key".to_string();
        let header_value = "hello-custom-value".to_string();

        let server = TestServer::setup(header_key.clone(), header_value.clone()).await;
        let (counter, handler) = create_test_handler();

        let config = test_config(
            server.port,
            &handler,
            vec![(header_key, header_value)],
            None,
            None,
        );
        let client = WebSocketClient::connect(config, None, None, None, Vec::new(), None)
            .await
            .unwrap();

        // The server echoes each "ping" back, and the Python handler counts it.
        for _ in 0..N {
            client.send_bytes(b"ping".to_vec(), None).await.unwrap();
            success_count += 1;
        }

        // Give the echoes time to arrive before checking the count.
        sleep(Duration::from_secs(1)).await;
        assert_eq!(get_count(&counter), success_count);

        // Close from the client side, then wait before sending again
        // (presumably long enough for the client to reconnect).
        client.send_close_message().await.unwrap();
        sleep(Duration::from_secs(2)).await;

        for _ in 0..N {
            client.send_bytes(b"ping".to_vec(), None).await.unwrap();
            success_count += 1;
        }

        sleep(Duration::from_secs(1)).await;
        assert_eq!(get_count(&counter), success_count);
        assert_eq!(success_count, N + N);

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    #[traced_test]
    async fn message_ping_test() {
        prepare_freethreaded_python();

        let header_key = "hello-custom-key".to_string();
        let header_value = "hello-custom-value".to_string();

        let (checker, handler) = create_test_handler();

        let server = TestServer::setup(header_key.clone(), header_value.clone()).await;
        let config = test_config(
            server.port,
            &handler,
            vec![(header_key, header_value)],
            Some(1),
            Some("heartbeat message".to_string()),
        );
        let client = WebSocketClient::connect(config, None, None, None, Vec::new(), None)
            .await
            .unwrap();

        // Wait long enough for at least one heartbeat to be sent and echoed back.
        sleep(Duration::from_secs(2)).await;
        assert!(get_check(&checker));

        client.disconnect().await;
        assert!(client.is_disconnected());
    }
}