1use std::{
32 fmt::Debug,
33 sync::{
34 Arc,
35 atomic::{AtomicU8, Ordering},
36 },
37 time::Duration,
38};
39
40use futures_util::{
41 SinkExt, StreamExt,
42 stream::{SplitSink, SplitStream},
43};
44use http::HeaderName;
45use nautilus_cryptography::providers::install_cryptographic_provider;
46#[cfg(feature = "python")]
47use pyo3::{prelude::*, types::PyBytes};
48use tokio::{
49 net::TcpStream,
50 sync::mpsc::{self, Receiver, Sender},
51};
52use tokio_tungstenite::{
53 MaybeTlsStream, WebSocketStream, connect_async,
54 tungstenite::{Error, Message, client::IntoClientRequest, http::HeaderValue},
55};
56
57use crate::{
58 backoff::ExponentialBackoff,
59 error::SendError,
60 logging::{log_task_aborted, log_task_started, log_task_stopped},
61 mode::ConnectionMode,
62 ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota},
63};
64
65type MessageWriter = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
66pub type MessageReader = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
67
68#[derive(Debug, Clone)]
69pub enum Consumer {
70 #[cfg(feature = "python")]
71 Python(Option<Arc<PyObject>>),
72 Rust(Sender<Message>),
73}
74
75impl Consumer {
76 #[must_use]
77 pub fn rust_consumer() -> (Self, Receiver<Message>) {
78 let (tx, rx) = mpsc::channel(100);
79 (Self::Rust(tx), rx)
80 }
81}
82
83#[derive(Debug, Clone)]
84#[cfg_attr(
85 feature = "python",
86 pyo3::pyclass(module = "posei_trader.core.nautilus_pyo3.network")
87)]
88pub struct WebSocketConfig {
89 pub url: String,
91 pub headers: Vec<(String, String)>,
93 pub handler: Consumer,
95 pub heartbeat: Option<u64>,
97 pub heartbeat_msg: Option<String>,
99 #[cfg(feature = "python")]
101 pub ping_handler: Option<Arc<PyObject>>,
102 pub reconnect_timeout_ms: Option<u64>,
104 pub reconnect_delay_initial_ms: Option<u64>,
106 pub reconnect_delay_max_ms: Option<u64>,
108 pub reconnect_backoff_factor: Option<f64>,
110 pub reconnect_jitter_ms: Option<u64>,
112}
113
114#[derive(Debug)]
116pub(crate) enum WriterCommand {
117 Update(MessageWriter),
119 Send(Message),
121}
122
123struct WebSocketClientInner {
139 config: WebSocketConfig,
140 read_task: Option<tokio::task::JoinHandle<()>>,
141 write_task: tokio::task::JoinHandle<()>,
142 writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
143 heartbeat_task: Option<tokio::task::JoinHandle<()>>,
144 connection_mode: Arc<AtomicU8>,
145 reconnect_timeout: Duration,
146 backoff: ExponentialBackoff,
147}
148
149impl WebSocketClientInner {
150 pub async fn connect_url(config: WebSocketConfig) -> Result<Self, Error> {
152 install_cryptographic_provider();
153
154 #[allow(unused_variables)]
155 let WebSocketConfig {
156 url,
157 handler,
158 heartbeat,
159 headers,
160 heartbeat_msg,
161 #[cfg(feature = "python")]
162 ping_handler,
163 reconnect_timeout_ms,
164 reconnect_delay_initial_ms,
165 reconnect_delay_max_ms,
166 reconnect_backoff_factor,
167 reconnect_jitter_ms,
168 } = &config;
169 let (writer, reader) = Self::connect_with_server(url, headers.clone()).await?;
170
171 let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
172
173 let read_task = match &handler {
174 #[cfg(feature = "python")]
175 Consumer::Python(handler) => handler.as_ref().map(|handler| {
176 Self::spawn_python_callback_task(
177 connection_mode.clone(),
178 reader,
179 handler.clone(),
180 ping_handler.clone(),
181 )
182 }),
183 Consumer::Rust(sender) => Some(Self::spawn_rust_streaming_task(
184 connection_mode.clone(),
185 reader,
186 sender.clone(),
187 )),
188 };
189
190 let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
191 let write_task = Self::spawn_write_task(connection_mode.clone(), writer, writer_rx);
192
193 let heartbeat_task = heartbeat.as_ref().map(|heartbeat_secs| {
195 Self::spawn_heartbeat_task(
196 connection_mode.clone(),
197 *heartbeat_secs,
198 heartbeat_msg.clone(),
199 writer_tx.clone(),
200 )
201 });
202
203 let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
204 let backoff = ExponentialBackoff::new(
205 Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
206 Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
207 reconnect_backoff_factor.unwrap_or(1.5),
208 reconnect_jitter_ms.unwrap_or(100),
209 true, )
211 .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
212
213 Ok(Self {
214 config,
215 read_task,
216 write_task,
217 writer_tx,
218 heartbeat_task,
219 connection_mode,
220 reconnect_timeout,
221 backoff,
222 })
223 }
224
225 #[inline]
227 pub async fn connect_with_server(
228 url: &str,
229 headers: Vec<(String, String)>,
230 ) -> Result<(MessageWriter, MessageReader), Error> {
231 let mut request = url.into_client_request()?;
232 let req_headers = request.headers_mut();
233
234 let mut header_names: Vec<HeaderName> = Vec::new();
235 for (key, val) in headers {
236 let header_value = HeaderValue::from_str(&val)?;
237 let header_name: HeaderName = key.parse()?;
238 header_names.push(header_name.clone());
239 req_headers.insert(header_name, header_value);
240 }
241
242 connect_async(request).await.map(|resp| resp.0.split())
243 }
244
245 pub async fn reconnect(&mut self) -> Result<(), Error> {
250 tracing::debug!("Reconnecting");
251
252 if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
253 tracing::debug!("Reconnect aborted due to disconnect state");
254 return Ok(());
255 }
256
257 tokio::time::timeout(self.reconnect_timeout, async {
258 let (new_writer, reader) =
260 Self::connect_with_server(&self.config.url, self.config.headers.clone()).await?;
261
262 if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
263 tracing::debug!("Reconnect aborted mid-flight (after connect)");
264 return Ok(());
265 }
266
267 if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer)) {
268 tracing::error!("{e}");
269 }
270
271 tokio::time::sleep(Duration::from_millis(100)).await;
273
274 if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
275 tracing::debug!("Reconnect aborted mid-flight (after delay)");
276 return Ok(());
277 }
278
279 if let Some(ref read_task) = self.read_task.take() {
280 if !read_task.is_finished() {
281 read_task.abort();
282 log_task_aborted("read");
283 }
284 }
285
286 if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
288 tracing::debug!("Reconnect aborted mid-flight (before spawn read)");
289 return Ok(());
290 }
291
292 self.connection_mode
294 .store(ConnectionMode::Active.as_u8(), Ordering::SeqCst);
295
296 self.read_task = match &self.config.handler {
297 #[cfg(feature = "python")]
298 Consumer::Python(handler) => handler.as_ref().map(|handler| {
299 Self::spawn_python_callback_task(
300 self.connection_mode.clone(),
301 reader,
302 handler.clone(),
303 self.config.ping_handler.clone(),
304 )
305 }),
306 Consumer::Rust(sender) => Some(Self::spawn_rust_streaming_task(
307 self.connection_mode.clone(),
308 reader,
309 sender.clone(),
310 )),
311 };
312
313 tracing::debug!("Reconnect succeeded");
314 Ok(())
315 })
316 .await
317 .map_err(|_| {
318 Error::Io(std::io::Error::new(
319 std::io::ErrorKind::TimedOut,
320 format!(
321 "reconnection timed out after {}s",
322 self.reconnect_timeout.as_secs_f64()
323 ),
324 ))
325 })?
326 }
327
328 #[inline]
336 #[must_use]
337 pub fn is_alive(&self) -> bool {
338 match &self.read_task {
339 Some(read_task) => !read_task.is_finished(),
340 None => true, }
342 }
343
344 fn spawn_rust_streaming_task(
345 connection_state: Arc<AtomicU8>,
346 mut reader: MessageReader,
347 sender: Sender<Message>,
348 ) -> tokio::task::JoinHandle<()> {
349 tracing::debug!("Started streaming task 'read'");
350
351 let check_interval = Duration::from_millis(10);
352
353 tokio::task::spawn(async move {
354 loop {
355 if !ConnectionMode::from_atomic(&connection_state).is_active() {
356 break;
357 }
358
359 match tokio::time::timeout(check_interval, reader.next()).await {
360 Ok(Some(Ok(message))) => {
361 if let Err(e) = sender.send(message).await {
362 tracing::error!("Failed to send message: {e}");
363 }
364 }
365 Ok(Some(Err(e))) => {
366 tracing::error!("Received error message - terminating: {e}");
367 break;
368 }
369 Ok(None) => {
370 tracing::debug!("No message received - terminating");
371 break;
372 }
373 Err(_) => {
374 continue;
376 }
377 }
378 }
379 })
380 }
381
382 #[cfg(feature = "python")]
383 fn spawn_python_callback_task(
384 connection_state: Arc<AtomicU8>,
385 mut reader: MessageReader,
386 handler: Arc<PyObject>,
387 ping_handler: Option<Arc<PyObject>>,
388 ) -> tokio::task::JoinHandle<()> {
389 log_task_started("read");
390
391 let check_interval = Duration::from_millis(10);
393
394 tokio::task::spawn(async move {
395 loop {
396 if !ConnectionMode::from_atomic(&connection_state).is_active() {
397 break;
398 }
399
400 match tokio::time::timeout(check_interval, reader.next()).await {
401 Ok(Some(Ok(Message::Binary(data)))) => {
402 tracing::trace!("Received message <binary> {} bytes", data.len());
403 if let Err(e) =
404 Python::with_gil(|py| handler.call1(py, (PyBytes::new(py, &data),)))
405 {
406 tracing::error!("Error calling handler: {e}");
407 break;
408 }
409 continue;
410 }
411 Ok(Some(Ok(Message::Text(data)))) => {
412 tracing::trace!("Received message: {data}");
413 if let Err(e) = Python::with_gil(|py| {
414 handler.call1(py, (PyBytes::new(py, data.as_bytes()),))
415 }) {
416 tracing::error!("Error calling handler: {e}");
417 break;
418 }
419 continue;
420 }
421 Ok(Some(Ok(Message::Ping(ping)))) => {
422 tracing::trace!("Received ping: {ping:?}");
423 if let Some(ref handler) = ping_handler {
424 if let Err(e) =
425 Python::with_gil(|py| handler.call1(py, (PyBytes::new(py, &ping),)))
426 {
427 tracing::error!("Error calling handler: {e}");
428 break;
429 }
430 }
431 continue;
432 }
433 Ok(Some(Ok(Message::Pong(_)))) => {
434 tracing::trace!("Received pong");
435 }
436 Ok(Some(Ok(Message::Close(_)))) => {
437 tracing::debug!("Received close message - terminating");
438 break;
439 }
440 Ok(Some(Ok(_))) => (),
441 Ok(Some(Err(e))) => {
442 tracing::error!("Received error message - terminating: {e}");
443 break;
444 }
445 Ok(None) => {
448 tracing::debug!("No message received - terminating");
449 break;
450 }
451 Err(_) => {
452 continue;
454 }
455 }
456 }
457 })
458 }
459
460 fn spawn_write_task(
461 connection_state: Arc<AtomicU8>,
462 writer: MessageWriter,
463 mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
464 ) -> tokio::task::JoinHandle<()> {
465 log_task_started("write");
466
467 let check_interval = Duration::from_millis(10);
469
470 tokio::task::spawn(async move {
471 let mut active_writer = writer;
472
473 loop {
474 match ConnectionMode::from_atomic(&connection_state) {
475 ConnectionMode::Disconnect => {
476 _ = active_writer.close().await;
479 break;
480 }
481 ConnectionMode::Closed => break,
482 _ => {}
483 }
484
485 match tokio::time::timeout(check_interval, writer_rx.recv()).await {
486 Ok(Some(msg)) => {
487 let mode = ConnectionMode::from_atomic(&connection_state);
489 if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
490 break;
491 }
492
493 match msg {
494 WriterCommand::Update(new_writer) => {
495 tracing::debug!("Received new writer");
496
497 tokio::time::sleep(Duration::from_millis(100)).await;
499
500 _ = active_writer.close().await;
503
504 active_writer = new_writer;
505 tracing::debug!("Updated writer");
506 }
507 _ if mode.is_reconnect() => {
508 tracing::warn!("Skipping message while reconnecting, {msg:?}");
509 continue;
510 }
511 WriterCommand::Send(msg) => {
512 if let Err(e) = active_writer.send(msg).await {
513 tracing::error!("Failed to send message: {e}");
514 tracing::warn!("Writer triggering reconnect");
516 connection_state
517 .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
518 }
519 }
520 }
521 }
522 Ok(None) => {
523 tracing::debug!("Writer channel closed, terminating writer task");
525 break;
526 }
527 Err(_) => {
528 continue;
530 }
531 }
532 }
533
534 _ = active_writer.close().await;
537
538 log_task_stopped("write");
539 })
540 }
541
542 fn spawn_heartbeat_task(
543 connection_state: Arc<AtomicU8>,
544 heartbeat_secs: u64,
545 message: Option<String>,
546 writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
547 ) -> tokio::task::JoinHandle<()> {
548 log_task_started("heartbeat");
549
550 tokio::task::spawn(async move {
551 let interval = Duration::from_secs(heartbeat_secs);
552
553 loop {
554 tokio::time::sleep(interval).await;
555
556 match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
557 ConnectionMode::Active => {
558 let msg = match &message {
559 Some(text) => WriterCommand::Send(Message::Text(text.clone().into())),
560 None => WriterCommand::Send(Message::Ping(vec![].into())),
561 };
562
563 match writer_tx.send(msg) {
564 Ok(()) => tracing::trace!("Sent heartbeat to writer task"),
565 Err(e) => {
566 tracing::error!("Failed to send heartbeat to writer task: {e}");
567 }
568 }
569 }
570 ConnectionMode::Reconnect => continue,
571 ConnectionMode::Disconnect | ConnectionMode::Closed => break,
572 }
573 }
574
575 log_task_stopped("heartbeat");
576 })
577 }
578}
579
580impl Drop for WebSocketClientInner {
581 fn drop(&mut self) {
582 if let Some(ref read_task) = self.read_task.take() {
583 if !read_task.is_finished() {
584 read_task.abort();
585 log_task_aborted("read");
586 }
587 }
588
589 if !self.write_task.is_finished() {
590 self.write_task.abort();
591 log_task_aborted("write");
592 }
593
594 if let Some(ref handle) = self.heartbeat_task.take() {
595 if !handle.is_finished() {
596 handle.abort();
597 log_task_aborted("heartbeat");
598 }
599 }
600 }
601}
602
603#[cfg_attr(
608 feature = "python",
609 pyo3::pyclass(module = "posei_trader.core.nautilus_pyo3.network")
610)]
611pub struct WebSocketClient {
612 pub(crate) controller_task: tokio::task::JoinHandle<()>,
613 pub(crate) connection_mode: Arc<AtomicU8>,
614 pub(crate) writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
615 pub(crate) rate_limiter: Arc<RateLimiter<String, MonotonicClock>>,
616}
617
618impl Debug for WebSocketClient {
619 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
620 f.debug_struct(stringify!(WebSocketClient)).finish()
621 }
622}
623
624impl WebSocketClient {
625 #[allow(clippy::too_many_arguments)]
631 pub async fn connect_stream(
632 config: WebSocketConfig,
633 keyed_quotas: Vec<(String, Quota)>,
634 default_quota: Option<Quota>,
635 post_reconnect: Option<Arc<dyn Fn() + Send + Sync>>,
636 ) -> Result<(MessageReader, Self), Error> {
637 install_cryptographic_provider();
638 let (ws_stream, _) = connect_async(config.url.clone().into_client_request()?).await?;
639 let (writer, reader) = ws_stream.split();
640 let inner = WebSocketClientInner::connect_url(config).await?;
641
642 let connection_mode = inner.connection_mode.clone();
643
644 let writer_tx = inner.writer_tx.clone();
645 if let Err(e) = writer_tx.send(WriterCommand::Update(writer)) {
646 tracing::error!("{e}");
647 }
648
649 let controller_task = Self::spawn_controller_task(
650 inner,
651 connection_mode.clone(),
652 post_reconnect,
653 #[cfg(feature = "python")]
654 None, #[cfg(feature = "python")]
656 None, );
658
659 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
660
661 Ok((
662 reader,
663 Self {
664 controller_task,
665 connection_mode,
666 writer_tx,
667 rate_limiter,
668 },
669 ))
670 }
671
672 pub async fn connect(
681 config: WebSocketConfig,
682 #[cfg(feature = "python")] post_connection: Option<PyObject>,
683 #[cfg(feature = "python")] post_reconnection: Option<PyObject>,
684 #[cfg(feature = "python")] post_disconnection: Option<PyObject>,
685 keyed_quotas: Vec<(String, Quota)>,
686 default_quota: Option<Quota>,
687 ) -> Result<Self, Error> {
688 tracing::debug!("Connecting");
689 let inner = WebSocketClientInner::connect_url(config.clone()).await?;
690 let connection_mode = inner.connection_mode.clone();
691 let writer_tx = inner.writer_tx.clone();
692
693 let controller_task = Self::spawn_controller_task(
694 inner,
695 connection_mode.clone(),
696 None, #[cfg(feature = "python")]
698 post_reconnection, #[cfg(feature = "python")]
700 post_disconnection, );
702
703 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
704
705 #[cfg(feature = "python")]
706 if let Some(handler) = post_connection {
707 Python::with_gil(|py| match handler.call0(py) {
708 Ok(_) => tracing::debug!("Called `post_connection` handler"),
709 Err(e) => tracing::error!("Error calling `post_connection` handler: {e}"),
710 });
711 }
712
713 Ok(Self {
714 controller_task,
715 connection_mode,
716 writer_tx,
717 rate_limiter,
718 })
719 }
720
721 #[must_use]
723 pub fn connection_mode(&self) -> ConnectionMode {
724 ConnectionMode::from_atomic(&self.connection_mode)
725 }
726
727 #[inline]
732 #[must_use]
733 pub fn is_active(&self) -> bool {
734 self.connection_mode().is_active()
735 }
736
737 #[must_use]
739 pub fn is_disconnected(&self) -> bool {
740 self.controller_task.is_finished()
741 }
742
743 #[inline]
748 #[must_use]
749 pub fn is_reconnecting(&self) -> bool {
750 self.connection_mode().is_reconnect()
751 }
752
753 #[inline]
757 #[must_use]
758 pub fn is_disconnecting(&self) -> bool {
759 self.connection_mode().is_disconnect()
760 }
761
762 #[inline]
768 #[must_use]
769 pub fn is_closed(&self) -> bool {
770 self.connection_mode().is_closed()
771 }
772
773 pub async fn disconnect(&self) {
778 tracing::debug!("Disconnecting");
779 self.connection_mode
780 .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
781
782 match tokio::time::timeout(Duration::from_secs(5), async {
783 while !self.is_disconnected() {
784 tokio::time::sleep(Duration::from_millis(10)).await;
785 }
786
787 if !self.controller_task.is_finished() {
788 self.controller_task.abort();
789 log_task_aborted("controller");
790 }
791 })
792 .await
793 {
794 Ok(()) => {
795 tracing::debug!("Controller task finished");
796 }
797 Err(_) => {
798 tracing::error!("Timeout waiting for controller task to finish");
799 }
800 }
801 }
802
803 #[allow(unused_variables)]
809 pub async fn send_text(
810 &self,
811 data: String,
812 keys: Option<Vec<String>>,
813 ) -> std::result::Result<(), SendError> {
814 self.rate_limiter.await_keys_ready(keys).await;
815
816 if !self.is_active() {
817 return Err(SendError::Closed);
818 }
819
820 tracing::trace!("Sending text: {data:?}");
821
822 let msg = Message::Text(data.into());
823 self.writer_tx
824 .send(WriterCommand::Send(msg))
825 .map_err(|e| SendError::BrokenPipe(e.to_string()))
826 }
827
828 #[allow(unused_variables)]
834 pub async fn send_bytes(
835 &self,
836 data: Vec<u8>,
837 keys: Option<Vec<String>>,
838 ) -> std::result::Result<(), SendError> {
839 self.rate_limiter.await_keys_ready(keys).await;
840
841 if !self.is_active() {
842 return Err(SendError::Closed);
843 }
844
845 tracing::trace!("Sending bytes: {data:?}");
846
847 let msg = Message::Binary(data.into());
848 self.writer_tx
849 .send(WriterCommand::Send(msg))
850 .map_err(|e| SendError::BrokenPipe(e.to_string()))
851 }
852
853 pub async fn send_close_message(&self) -> std::result::Result<(), SendError> {
859 if !self.is_active() {
860 return Err(SendError::Closed);
861 }
862
863 let msg = Message::Close(None);
864 self.writer_tx
865 .send(WriterCommand::Send(msg))
866 .map_err(|e| SendError::BrokenPipe(e.to_string()))
867 }
868
869 fn spawn_controller_task(
870 mut inner: WebSocketClientInner,
871 connection_mode: Arc<AtomicU8>,
872 post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
873 #[cfg(feature = "python")] py_post_reconnection: Option<PyObject>, #[cfg(feature = "python")] py_post_disconnection: Option<PyObject>, ) -> tokio::task::JoinHandle<()> {
876 tokio::task::spawn(async move {
877 log_task_started("controller");
878
879 let check_interval = Duration::from_millis(10);
880
881 loop {
882 tokio::time::sleep(check_interval).await;
883 let mode = ConnectionMode::from_atomic(&connection_mode);
884
885 if mode.is_disconnect() {
886 tracing::debug!("Disconnecting");
887
888 let timeout = Duration::from_secs(5);
889 if tokio::time::timeout(timeout, async {
890 tokio::time::sleep(Duration::from_millis(100)).await;
892
893 if let Some(task) = &inner.read_task {
894 if !task.is_finished() {
895 task.abort();
896 log_task_aborted("read");
897 }
898 }
899
900 if let Some(task) = &inner.heartbeat_task {
901 if !task.is_finished() {
902 task.abort();
903 log_task_aborted("heartbeat");
904 }
905 }
906 })
907 .await
908 .is_err()
909 {
910 tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
911 }
912
913 tracing::debug!("Closed");
914
915 #[cfg(feature = "python")]
916 if let Some(ref handler) = py_post_disconnection {
917 Python::with_gil(|py| match handler.call0(py) {
918 Ok(_) => tracing::debug!("Called `post_disconnection` handler"),
919 Err(e) => {
920 tracing::error!("Error calling `post_disconnection` handler: {e}");
921 }
922 });
923 }
924 break; }
926
927 if mode.is_reconnect() || (mode.is_active() && !inner.is_alive()) {
928 match inner.reconnect().await {
929 Ok(()) => {
930 inner.backoff.reset();
931
932 if ConnectionMode::from_atomic(&connection_mode).is_active() {
934 if let Some(ref callback) = post_reconnection {
935 callback();
936 }
937
938 #[cfg(feature = "python")]
940 if let Some(ref callback) = py_post_reconnection {
941 Python::with_gil(|py| match callback.call0(py) {
942 Ok(_) => {
943 tracing::debug!("Called `post_reconnection` handler");
944 }
945 Err(e) => tracing::error!(
946 "Error calling `post_reconnection` handler: {e}"
947 ),
948 });
949 }
950
951 tracing::debug!("Reconnected successfully");
952 } else {
953 tracing::debug!(
954 "Skipping post_reconnection handlers due to disconnect state"
955 );
956 }
957 }
958 Err(e) => {
959 let duration = inner.backoff.next_duration();
960 tracing::warn!("Reconnect attempt failed: {e}");
961 if !duration.is_zero() {
962 tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
963 }
964 tokio::time::sleep(duration).await;
965 }
966 }
967 }
968 }
969 inner
970 .connection_mode
971 .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
972
973 log_task_stopped("controller");
974 })
975 }
976}
977
978impl Drop for WebSocketClient {
980 fn drop(&mut self) {
981 if !self.controller_task.is_finished() {
982 self.controller_task.abort();
983 log_task_aborted("controller");
984 }
985 }
986}
987
#[cfg(feature = "python")]
#[cfg(test)]
#[cfg(target_os = "linux")]
mod tests {
    use std::{num::NonZeroU32, sync::Arc};

    use futures_util::{SinkExt, StreamExt};
    use tokio::{
        net::TcpListener,
        task::{self, JoinHandle},
    };
    use tokio_tungstenite::{
        accept_hdr_async,
        tungstenite::{
            handshake::server::{self, Callback},
            http::HeaderValue,
            protocol::Message,
        },
    };

    use crate::{
        ratelimiter::quota::Quota,
        websocket::{Consumer, WebSocketClient, WebSocketConfig},
    };

    struct TestServer {
        task: JoinHandle<()>,
        port: u16,
    }

    /// Handshake callback asserting that the request carries `key: value`.
    #[derive(Debug, Clone)]
    struct TestCallback {
        key: String,
        value: HeaderValue,
    }

    impl Callback for TestCallback {
        fn on_request(
            self,
            request: &server::Request,
            response: server::Response,
        ) -> Result<server::Response, server::ErrorResponse> {
            let value = request.headers().get(&self.key);
            assert_eq!(
                value,
                Some(&self.value),
                "missing or unexpected `{}` header",
                self.key
            );
            Ok(response)
        }
    }

    impl TestServer {
        /// Starts an echo server on an ephemeral port.
        ///
        /// Every handshake must carry the `test: test` header. A `close-now` text
        /// message makes the server close that connection.
        async fn setup() -> Self {
            let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let port = server.local_addr().unwrap().port();

            let test_call_back = TestCallback {
                key: "test".to_string(),
                value: HeaderValue::from_static("test"),
            };

            let task = task::spawn(async move {
                loop {
                    let (conn, _) = server.accept().await.unwrap();
                    let mut websocket = accept_hdr_async(conn, test_call_back.clone())
                        .await
                        .unwrap();

                    task::spawn(async move {
                        while let Some(Ok(msg)) = websocket.next().await {
                            match msg {
                                Message::Text(txt) if txt == "close-now" => {
                                    tracing::debug!("Forcibly closing from server side");
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                Message::Text(_) | Message::Binary(_) => {
                                    if websocket.send(msg).await.is_err() {
                                        break;
                                    }
                                }
                                Message::Close(_frame) => {
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                _ => {}
                            }
                        }
                    });
                }
            });

            Self { task, port }
        }
    }

    impl Drop for TestServer {
        fn drop(&mut self) {
            self.task.abort();
        }
    }

    /// Baseline config for a client talking to a [`TestServer`] on `port`.
    fn test_config(port: u16) -> WebSocketConfig {
        WebSocketConfig {
            url: format!("ws://127.0.0.1:{port}"),
            headers: vec![("test".into(), "test".into())],
            handler: Consumer::Python(None),
            heartbeat: None,
            heartbeat_msg: None,
            ping_handler: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
        }
    }

    async fn connect_client(config: WebSocketConfig) -> WebSocketClient {
        WebSocketClient::connect(config, None, None, None, vec![], None)
            .await
            .expect("Failed to connect")
    }

    async fn setup_test_client(port: u16) -> WebSocketClient {
        connect_client(test_config(port)).await
    }

    #[tokio::test]
    async fn test_websocket_basic() {
        let server = TestServer::setup().await;
        let client = setup_test_client(server.port).await;

        assert!(!client.is_disconnected());

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_websocket_heartbeat() {
        let server = TestServer::setup().await;
        // Actually enable the heartbeat: a 1s interval fires several times below
        let config = WebSocketConfig {
            heartbeat: Some(1),
            ..test_config(server.port)
        };
        let client = connect_client(config).await;

        tokio::time::sleep(std::time::Duration::from_secs(3)).await;

        assert!(client.is_active(), "heartbeats should not break the connection");

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_websocket_reconnect_exhausted() {
        // Reserve an ephemeral port and release it so nothing is listening there,
        // instead of assuming a fixed port is free.
        let port = {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            listener.local_addr().unwrap().port()
        };
        let config = WebSocketConfig {
            headers: vec![],
            ..test_config(port)
        };
        let res = WebSocketClient::connect(config, None, None, None, vec![], None).await;
        assert!(res.is_err(), "Should fail quickly with no server");
    }

    #[tokio::test]
    async fn test_websocket_forced_close_reconnect() {
        let server = TestServer::setup().await;
        let client = setup_test_client(server.port).await;

        client.send_text("Hello".into(), None).await.unwrap();

        // Ask the server to drop the connection, then give the client time to recover
        client.send_text("close-now".into(), None).await.unwrap();

        tokio::time::sleep(std::time::Duration::from_secs(1)).await;

        assert!(!client.is_disconnected());

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_rate_limiter() {
        let server = TestServer::setup().await;
        let quota = Quota::per_second(NonZeroU32::new(2).unwrap());

        let client = WebSocketClient::connect(
            test_config(server.port),
            None,
            None,
            None,
            vec![("default".into(), quota)],
            None,
        )
        .await
        .unwrap();

        // Route sends through the keyed quota; with `None` keys and no default quota
        // the limiter would never be exercised.
        let keys = || Some(vec!["default".to_string()]);

        client.send_text("test1".into(), keys()).await.unwrap();
        client.send_text("test2".into(), keys()).await.unwrap();

        client.send_text("test3".into(), keys()).await.unwrap();

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_concurrent_writers() {
        let server = TestServer::setup().await;
        let client = Arc::new(setup_test_client(server.port).await);

        let mut handles = vec![];
        for i in 0..10 {
            let client = client.clone();
            handles.push(task::spawn(async move {
                client.send_text(format!("test{i}"), None).await.unwrap();
            }));
        }

        for handle in handles {
            handle.await.unwrap();
        }

        client.disconnect().await;
        assert!(client.is_disconnected());
    }
}
1256
#[cfg(test)]
mod rust_tests {
    use tokio::{
        net::TcpListener,
        task,
        time::{Duration, sleep},
    };
    use tokio_tungstenite::accept_async;

    use super::*;

    /// The server drops the socket right after the handshake and accepts nothing more,
    /// so the client is expected to enter its reconnect path. `disconnect` must still
    /// complete and leave the client disconnected.
    #[tokio::test]
    async fn test_reconnect_then_disconnect() {
        // One-shot server on an ephemeral port
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();

        let server = task::spawn(async move {
            let (tcp, _) = listener.accept().await.unwrap();
            let socket = accept_async(tcp).await.unwrap();
            // Close the connection immediately, keeping the listener alive a while
            drop(socket);
            sleep(Duration::from_secs(1)).await;
        });

        let (handler, _rx) = Consumer::rust_consumer();

        // Short timeouts and a flat backoff keep reconnect attempts quick
        let config = WebSocketConfig {
            url: format!("ws://127.0.0.1:{port}"),
            headers: vec![],
            handler,
            heartbeat: None,
            heartbeat_msg: None,
            #[cfg(feature = "python")]
            ping_handler: None,
            reconnect_timeout_ms: Some(1_000),
            reconnect_delay_initial_ms: Some(50),
            reconnect_delay_max_ms: Some(100),
            reconnect_backoff_factor: Some(1.0),
            reconnect_jitter_ms: Some(0),
        };

        #[cfg(feature = "python")]
        let client = WebSocketClient::connect(config, None, None, None, vec![], None)
            .await
            .unwrap();
        #[cfg(not(feature = "python"))]
        let client = WebSocketClient::connect(config, vec![], None).await.unwrap();

        sleep(Duration::from_millis(100)).await;

        client.disconnect().await;
        assert!(client.is_disconnected());

        server.abort();
    }
}