1use std::{
32 fmt::Debug,
33 sync::{
34 Arc,
35 atomic::{AtomicU8, Ordering},
36 },
37 time::Duration,
38};
39
40use futures_util::{
41 SinkExt, StreamExt,
42 stream::{SplitSink, SplitStream},
43};
44use http::HeaderName;
45use nautilus_cryptography::providers::install_cryptographic_provider;
46#[cfg(feature = "python")]
47use pyo3::{prelude::*, types::PyBytes};
48use tokio::{
49 net::TcpStream,
50 sync::mpsc::{self, Receiver, Sender},
51};
52use tokio_tungstenite::{
53 MaybeTlsStream, WebSocketStream, connect_async,
54 tungstenite::{Error, Message, client::IntoClientRequest, http::HeaderValue},
55};
56
57use crate::{
58 backoff::ExponentialBackoff,
59 error::SendError,
60 logging::{log_task_aborted, log_task_started, log_task_stopped},
61 mode::ConnectionMode,
62 ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota},
63};
64
/// Write half of a split WebSocket stream over a (possibly TLS-wrapped) TCP socket.
type MessageWriter = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
/// Read half of a split WebSocket stream over a (possibly TLS-wrapped) TCP socket.
pub type MessageReader = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
67
/// Destination for messages read from the WebSocket connection.
#[derive(Debug, Clone)]
pub enum Consumer {
    /// Python callable invoked with each text/binary payload as `bytes`.
    /// `None` means no read task is spawned for the connection.
    #[cfg(feature = "python")]
    Python(Option<Arc<PyObject>>),
    /// Channel sender that receives every message read from the socket.
    Rust(Sender<Message>),
}
77
78impl Consumer {
79 #[must_use]
83 pub fn rust_consumer() -> (Self, Receiver<Message>) {
84 let (tx, rx) = mpsc::channel(100);
85 (Self::Rust(tx), rx)
86 }
87}
88
/// Configuration for a managed WebSocket connection.
#[derive(Debug, Clone)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "posei_trader.core.nautilus_pyo3.network")
)]
pub struct WebSocketConfig {
    /// URL to connect to (also used for every reconnect).
    pub url: String,
    /// Extra `(name, value)` headers added to the handshake request; a repeated name keeps
    /// only the last value.
    pub headers: Vec<(String, String)>,
    /// Consumer of messages read from the socket.
    pub handler: Consumer,
    /// Heartbeat interval in seconds; `None` disables the heartbeat task.
    pub heartbeat: Option<u64>,
    /// Text sent as the heartbeat; when `None` an empty Ping frame is sent instead.
    pub heartbeat_msg: Option<String>,
    /// Python callable invoked with the payload of each received Ping frame.
    #[cfg(feature = "python")]
    pub ping_handler: Option<Arc<PyObject>>,
    /// Timeout for a single reconnect attempt in milliseconds (default 10_000).
    pub reconnect_timeout_ms: Option<u64>,
    /// Initial delay between failed reconnect attempts in milliseconds (default 2_000).
    pub reconnect_delay_initial_ms: Option<u64>,
    /// Maximum delay between failed reconnect attempts in milliseconds (default 30_000).
    pub reconnect_delay_max_ms: Option<u64>,
    /// Multiplier applied to the reconnect delay after each failure (default 1.5).
    pub reconnect_backoff_factor: Option<f64>,
    /// Jitter added to reconnect delays in milliseconds (default 100).
    pub reconnect_jitter_ms: Option<u64>,
}
119
/// Commands processed by the write task.
#[derive(Debug)]
pub(crate) enum WriterCommand {
    /// Replace the active writer (used after reconnecting); the previous writer is closed.
    Update(MessageWriter),
    /// Send a message through the active writer.
    Send(Message),
}
128
/// Connection state moved into the controller task: config, task handles and reconnect policy.
struct WebSocketClientInner {
    /// Configuration reused for every reconnect.
    config: WebSocketConfig,
    /// Read task; `None` when the consumer needs none (`Consumer::Python(None)`).
    read_task: Option<tokio::task::JoinHandle<()>>,
    /// Task owning the active writer and processing `WriterCommand`s.
    write_task: tokio::task::JoinHandle<()>,
    /// Command channel into the write task.
    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Heartbeat task, present only when a heartbeat interval is configured.
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    /// Shared `ConnectionMode`, stored as its `u8` encoding, observed by every task.
    connection_mode: Arc<AtomicU8>,
    /// Upper bound on a single reconnect attempt.
    reconnect_timeout: Duration,
    /// Delay policy applied between failed reconnect attempts.
    backoff: ExponentialBackoff,
}
154
155impl WebSocketClientInner {
156 pub async fn connect_url(config: WebSocketConfig) -> Result<Self, Error> {
158 install_cryptographic_provider();
159
160 #[allow(unused_variables)]
161 let WebSocketConfig {
162 url,
163 handler,
164 heartbeat,
165 headers,
166 heartbeat_msg,
167 #[cfg(feature = "python")]
168 ping_handler,
169 reconnect_timeout_ms,
170 reconnect_delay_initial_ms,
171 reconnect_delay_max_ms,
172 reconnect_backoff_factor,
173 reconnect_jitter_ms,
174 } = &config;
175 let (writer, reader) = Self::connect_with_server(url, headers.clone()).await?;
176
177 let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
178
179 let read_task = match &handler {
180 #[cfg(feature = "python")]
181 Consumer::Python(handler) => handler.as_ref().map(|handler| {
182 Self::spawn_python_callback_task(
183 connection_mode.clone(),
184 reader,
185 handler.clone(),
186 ping_handler.clone(),
187 )
188 }),
189 Consumer::Rust(sender) => Some(Self::spawn_rust_streaming_task(
190 connection_mode.clone(),
191 reader,
192 sender.clone(),
193 )),
194 };
195
196 let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
197 let write_task = Self::spawn_write_task(connection_mode.clone(), writer, writer_rx);
198
199 let heartbeat_task = heartbeat.as_ref().map(|heartbeat_secs| {
201 Self::spawn_heartbeat_task(
202 connection_mode.clone(),
203 *heartbeat_secs,
204 heartbeat_msg.clone(),
205 writer_tx.clone(),
206 )
207 });
208
209 let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
210 let backoff = ExponentialBackoff::new(
211 Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
212 Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
213 reconnect_backoff_factor.unwrap_or(1.5),
214 reconnect_jitter_ms.unwrap_or(100),
215 true, )
217 .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
218
219 Ok(Self {
220 config,
221 read_task,
222 write_task,
223 writer_tx,
224 heartbeat_task,
225 connection_mode,
226 reconnect_timeout,
227 backoff,
228 })
229 }
230
231 #[inline]
233 pub async fn connect_with_server(
234 url: &str,
235 headers: Vec<(String, String)>,
236 ) -> Result<(MessageWriter, MessageReader), Error> {
237 let mut request = url.into_client_request()?;
238 let req_headers = request.headers_mut();
239
240 let mut header_names: Vec<HeaderName> = Vec::new();
241 for (key, val) in headers {
242 let header_value = HeaderValue::from_str(&val)?;
243 let header_name: HeaderName = key.parse()?;
244 header_names.push(header_name.clone());
245 req_headers.insert(header_name, header_value);
246 }
247
248 connect_async(request).await.map(|resp| resp.0.split())
249 }
250
251 pub async fn reconnect(&mut self) -> Result<(), Error> {
256 tracing::debug!("Reconnecting");
257
258 if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
259 tracing::debug!("Reconnect aborted due to disconnect state");
260 return Ok(());
261 }
262
263 tokio::time::timeout(self.reconnect_timeout, async {
264 let (new_writer, reader) =
266 Self::connect_with_server(&self.config.url, self.config.headers.clone()).await?;
267
268 if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
269 tracing::debug!("Reconnect aborted mid-flight (after connect)");
270 return Ok(());
271 }
272
273 if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer)) {
274 tracing::error!("{e}");
275 }
276
277 tokio::time::sleep(Duration::from_millis(100)).await;
279
280 if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
281 tracing::debug!("Reconnect aborted mid-flight (after delay)");
282 return Ok(());
283 }
284
285 if let Some(ref read_task) = self.read_task.take() {
286 if !read_task.is_finished() {
287 read_task.abort();
288 log_task_aborted("read");
289 }
290 }
291
292 if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
294 tracing::debug!("Reconnect aborted mid-flight (before spawn read)");
295 return Ok(());
296 }
297
298 self.connection_mode
300 .store(ConnectionMode::Active.as_u8(), Ordering::SeqCst);
301
302 self.read_task = match &self.config.handler {
303 #[cfg(feature = "python")]
304 Consumer::Python(handler) => handler.as_ref().map(|handler| {
305 Self::spawn_python_callback_task(
306 self.connection_mode.clone(),
307 reader,
308 handler.clone(),
309 self.config.ping_handler.clone(),
310 )
311 }),
312 Consumer::Rust(sender) => Some(Self::spawn_rust_streaming_task(
313 self.connection_mode.clone(),
314 reader,
315 sender.clone(),
316 )),
317 };
318
319 tracing::debug!("Reconnect succeeded");
320 Ok(())
321 })
322 .await
323 .map_err(|_| {
324 Error::Io(std::io::Error::new(
325 std::io::ErrorKind::TimedOut,
326 format!(
327 "reconnection timed out after {}s",
328 self.reconnect_timeout.as_secs_f64()
329 ),
330 ))
331 })?
332 }
333
334 #[inline]
342 #[must_use]
343 pub fn is_alive(&self) -> bool {
344 match &self.read_task {
345 Some(read_task) => !read_task.is_finished(),
346 None => true, }
348 }
349
350 fn spawn_rust_streaming_task(
351 connection_state: Arc<AtomicU8>,
352 mut reader: MessageReader,
353 sender: Sender<Message>,
354 ) -> tokio::task::JoinHandle<()> {
355 tracing::debug!("Started streaming task 'read'");
356
357 let check_interval = Duration::from_millis(10);
358
359 tokio::task::spawn(async move {
360 loop {
361 if !ConnectionMode::from_atomic(&connection_state).is_active() {
362 break;
363 }
364
365 match tokio::time::timeout(check_interval, reader.next()).await {
366 Ok(Some(Ok(message))) => {
367 if let Err(e) = sender.send(message).await {
368 tracing::error!("Failed to send message: {e}");
369 }
370 }
371 Ok(Some(Err(e))) => {
372 tracing::error!("Received error message - terminating: {e}");
373 break;
374 }
375 Ok(None) => {
376 tracing::debug!("No message received - terminating");
377 break;
378 }
379 Err(_) => {
380 continue;
382 }
383 }
384 }
385 })
386 }
387
388 #[cfg(feature = "python")]
389 fn spawn_python_callback_task(
390 connection_state: Arc<AtomicU8>,
391 mut reader: MessageReader,
392 handler: Arc<PyObject>,
393 ping_handler: Option<Arc<PyObject>>,
394 ) -> tokio::task::JoinHandle<()> {
395 log_task_started("read");
396
397 let check_interval = Duration::from_millis(10);
399
400 tokio::task::spawn(async move {
401 loop {
402 if !ConnectionMode::from_atomic(&connection_state).is_active() {
403 break;
404 }
405
406 match tokio::time::timeout(check_interval, reader.next()).await {
407 Ok(Some(Ok(Message::Binary(data)))) => {
408 tracing::trace!("Received message <binary> {} bytes", data.len());
409 if let Err(e) =
410 Python::with_gil(|py| handler.call1(py, (PyBytes::new(py, &data),)))
411 {
412 tracing::error!("Error calling handler: {e}");
413 break;
414 }
415 continue;
416 }
417 Ok(Some(Ok(Message::Text(data)))) => {
418 tracing::trace!("Received message: {data}");
419 if let Err(e) = Python::with_gil(|py| {
420 handler.call1(py, (PyBytes::new(py, data.as_bytes()),))
421 }) {
422 tracing::error!("Error calling handler: {e}");
423 break;
424 }
425 continue;
426 }
427 Ok(Some(Ok(Message::Ping(ping)))) => {
428 tracing::trace!("Received ping: {ping:?}");
429 if let Some(ref handler) = ping_handler {
430 if let Err(e) =
431 Python::with_gil(|py| handler.call1(py, (PyBytes::new(py, &ping),)))
432 {
433 tracing::error!("Error calling handler: {e}");
434 break;
435 }
436 }
437 continue;
438 }
439 Ok(Some(Ok(Message::Pong(_)))) => {
440 tracing::trace!("Received pong");
441 }
442 Ok(Some(Ok(Message::Close(_)))) => {
443 tracing::debug!("Received close message - terminating");
444 break;
445 }
446 Ok(Some(Ok(_))) => (),
447 Ok(Some(Err(e))) => {
448 tracing::error!("Received error message - terminating: {e}");
449 break;
450 }
451 Ok(None) => {
454 tracing::debug!("No message received - terminating");
455 break;
456 }
457 Err(_) => {
458 continue;
460 }
461 }
462 }
463 })
464 }
465
466 fn spawn_write_task(
467 connection_state: Arc<AtomicU8>,
468 writer: MessageWriter,
469 mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
470 ) -> tokio::task::JoinHandle<()> {
471 log_task_started("write");
472
473 let check_interval = Duration::from_millis(10);
475
476 tokio::task::spawn(async move {
477 let mut active_writer = writer;
478
479 loop {
480 match ConnectionMode::from_atomic(&connection_state) {
481 ConnectionMode::Disconnect => {
482 _ = active_writer.close().await;
485 break;
486 }
487 ConnectionMode::Closed => break,
488 _ => {}
489 }
490
491 match tokio::time::timeout(check_interval, writer_rx.recv()).await {
492 Ok(Some(msg)) => {
493 let mode = ConnectionMode::from_atomic(&connection_state);
495 if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
496 break;
497 }
498
499 match msg {
500 WriterCommand::Update(new_writer) => {
501 tracing::debug!("Received new writer");
502
503 tokio::time::sleep(Duration::from_millis(100)).await;
505
506 _ = active_writer.close().await;
509
510 active_writer = new_writer;
511 tracing::debug!("Updated writer");
512 }
513 _ if mode.is_reconnect() => {
514 tracing::warn!("Skipping message while reconnecting, {msg:?}");
515 continue;
516 }
517 WriterCommand::Send(msg) => {
518 if let Err(e) = active_writer.send(msg).await {
519 tracing::error!("Failed to send message: {e}");
520 tracing::warn!("Writer triggering reconnect");
522 connection_state
523 .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
524 }
525 }
526 }
527 }
528 Ok(None) => {
529 tracing::debug!("Writer channel closed, terminating writer task");
531 break;
532 }
533 Err(_) => {
534 continue;
536 }
537 }
538 }
539
540 _ = active_writer.close().await;
543
544 log_task_stopped("write");
545 })
546 }
547
548 fn spawn_heartbeat_task(
549 connection_state: Arc<AtomicU8>,
550 heartbeat_secs: u64,
551 message: Option<String>,
552 writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
553 ) -> tokio::task::JoinHandle<()> {
554 log_task_started("heartbeat");
555
556 tokio::task::spawn(async move {
557 let interval = Duration::from_secs(heartbeat_secs);
558
559 loop {
560 tokio::time::sleep(interval).await;
561
562 match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
563 ConnectionMode::Active => {
564 let msg = match &message {
565 Some(text) => WriterCommand::Send(Message::Text(text.clone().into())),
566 None => WriterCommand::Send(Message::Ping(vec![].into())),
567 };
568
569 match writer_tx.send(msg) {
570 Ok(()) => tracing::trace!("Sent heartbeat to writer task"),
571 Err(e) => {
572 tracing::error!("Failed to send heartbeat to writer task: {e}");
573 }
574 }
575 }
576 ConnectionMode::Reconnect => continue,
577 ConnectionMode::Disconnect | ConnectionMode::Closed => break,
578 }
579 }
580
581 log_task_stopped("heartbeat");
582 })
583 }
584}
585
586impl Drop for WebSocketClientInner {
587 fn drop(&mut self) {
588 if let Some(ref read_task) = self.read_task.take() {
589 if !read_task.is_finished() {
590 read_task.abort();
591 log_task_aborted("read");
592 }
593 }
594
595 if !self.write_task.is_finished() {
596 self.write_task.abort();
597 log_task_aborted("write");
598 }
599
600 if let Some(ref handle) = self.heartbeat_task.take() {
601 if !handle.is_finished() {
602 handle.abort();
603 log_task_aborted("heartbeat");
604 }
605 }
606 }
607}
608
/// Handle to a managed WebSocket connection with rate-limited sending and automatic
/// reconnection.
///
/// The connection state is owned by a background controller task; this handle shares the
/// mode flag and the write-command channel with it.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "posei_trader.core.nautilus_pyo3.network")
)]
pub struct WebSocketClient {
    /// Task supervising reconnection and shutdown.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Shared `ConnectionMode`, stored as its `u8` encoding.
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// Command channel into the write task.
    pub(crate) writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Rate limiter awaited before each text/bytes send.
    pub(crate) rate_limiter: Arc<RateLimiter<String, MonotonicClock>>,
}
623
624impl Debug for WebSocketClient {
625 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
626 f.debug_struct(stringify!(WebSocketClient)).finish()
627 }
628}
629
630impl WebSocketClient {
631 #[allow(clippy::too_many_arguments)]
637 pub async fn connect_stream(
638 config: WebSocketConfig,
639 keyed_quotas: Vec<(String, Quota)>,
640 default_quota: Option<Quota>,
641 post_reconnect: Option<Arc<dyn Fn() + Send + Sync>>,
642 ) -> Result<(MessageReader, Self), Error> {
643 install_cryptographic_provider();
644 let (ws_stream, _) = connect_async(config.url.clone().into_client_request()?).await?;
645 let (writer, reader) = ws_stream.split();
646 let inner = WebSocketClientInner::connect_url(config).await?;
647
648 let connection_mode = inner.connection_mode.clone();
649
650 let writer_tx = inner.writer_tx.clone();
651 if let Err(e) = writer_tx.send(WriterCommand::Update(writer)) {
652 tracing::error!("{e}");
653 }
654
655 let controller_task = Self::spawn_controller_task(
656 inner,
657 connection_mode.clone(),
658 post_reconnect,
659 #[cfg(feature = "python")]
660 None, #[cfg(feature = "python")]
662 None, );
664
665 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
666
667 Ok((
668 reader,
669 Self {
670 controller_task,
671 connection_mode,
672 writer_tx,
673 rate_limiter,
674 },
675 ))
676 }
677
678 pub async fn connect(
687 config: WebSocketConfig,
688 #[cfg(feature = "python")] post_connection: Option<PyObject>,
689 #[cfg(feature = "python")] post_reconnection: Option<PyObject>,
690 #[cfg(feature = "python")] post_disconnection: Option<PyObject>,
691 keyed_quotas: Vec<(String, Quota)>,
692 default_quota: Option<Quota>,
693 ) -> Result<Self, Error> {
694 tracing::debug!("Connecting");
695 let inner = WebSocketClientInner::connect_url(config.clone()).await?;
696 let connection_mode = inner.connection_mode.clone();
697 let writer_tx = inner.writer_tx.clone();
698
699 let controller_task = Self::spawn_controller_task(
700 inner,
701 connection_mode.clone(),
702 None, #[cfg(feature = "python")]
704 post_reconnection, #[cfg(feature = "python")]
706 post_disconnection, );
708
709 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
710
711 #[cfg(feature = "python")]
712 if let Some(handler) = post_connection {
713 Python::with_gil(|py| match handler.call0(py) {
714 Ok(_) => tracing::debug!("Called `post_connection` handler"),
715 Err(e) => tracing::error!("Error calling `post_connection` handler: {e}"),
716 });
717 }
718
719 Ok(Self {
720 controller_task,
721 connection_mode,
722 writer_tx,
723 rate_limiter,
724 })
725 }
726
727 #[must_use]
729 pub fn connection_mode(&self) -> ConnectionMode {
730 ConnectionMode::from_atomic(&self.connection_mode)
731 }
732
733 #[inline]
738 #[must_use]
739 pub fn is_active(&self) -> bool {
740 self.connection_mode().is_active()
741 }
742
743 #[must_use]
745 pub fn is_disconnected(&self) -> bool {
746 self.controller_task.is_finished()
747 }
748
749 #[inline]
754 #[must_use]
755 pub fn is_reconnecting(&self) -> bool {
756 self.connection_mode().is_reconnect()
757 }
758
759 #[inline]
763 #[must_use]
764 pub fn is_disconnecting(&self) -> bool {
765 self.connection_mode().is_disconnect()
766 }
767
768 #[inline]
774 #[must_use]
775 pub fn is_closed(&self) -> bool {
776 self.connection_mode().is_closed()
777 }
778
779 pub async fn disconnect(&self) {
784 tracing::debug!("Disconnecting");
785 self.connection_mode
786 .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
787
788 match tokio::time::timeout(Duration::from_secs(5), async {
789 while !self.is_disconnected() {
790 tokio::time::sleep(Duration::from_millis(10)).await;
791 }
792
793 if !self.controller_task.is_finished() {
794 self.controller_task.abort();
795 log_task_aborted("controller");
796 }
797 })
798 .await
799 {
800 Ok(()) => {
801 tracing::debug!("Controller task finished");
802 }
803 Err(_) => {
804 tracing::error!("Timeout waiting for controller task to finish");
805 }
806 }
807 }
808
809 #[allow(unused_variables)]
815 pub async fn send_text(
816 &self,
817 data: String,
818 keys: Option<Vec<String>>,
819 ) -> std::result::Result<(), SendError> {
820 self.rate_limiter.await_keys_ready(keys).await;
821
822 if !self.is_active() {
823 return Err(SendError::Closed);
824 }
825
826 tracing::trace!("Sending text: {data:?}");
827
828 let msg = Message::Text(data.into());
829 self.writer_tx
830 .send(WriterCommand::Send(msg))
831 .map_err(|e| SendError::BrokenPipe(e.to_string()))
832 }
833
834 #[allow(unused_variables)]
840 pub async fn send_bytes(
841 &self,
842 data: Vec<u8>,
843 keys: Option<Vec<String>>,
844 ) -> std::result::Result<(), SendError> {
845 self.rate_limiter.await_keys_ready(keys).await;
846
847 if !self.is_active() {
848 return Err(SendError::Closed);
849 }
850
851 tracing::trace!("Sending bytes: {data:?}");
852
853 let msg = Message::Binary(data.into());
854 self.writer_tx
855 .send(WriterCommand::Send(msg))
856 .map_err(|e| SendError::BrokenPipe(e.to_string()))
857 }
858
859 pub async fn send_close_message(&self) -> std::result::Result<(), SendError> {
865 if !self.is_active() {
866 return Err(SendError::Closed);
867 }
868
869 let msg = Message::Close(None);
870 self.writer_tx
871 .send(WriterCommand::Send(msg))
872 .map_err(|e| SendError::BrokenPipe(e.to_string()))
873 }
874
875 fn spawn_controller_task(
876 mut inner: WebSocketClientInner,
877 connection_mode: Arc<AtomicU8>,
878 post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
879 #[cfg(feature = "python")] py_post_reconnection: Option<PyObject>, #[cfg(feature = "python")] py_post_disconnection: Option<PyObject>, ) -> tokio::task::JoinHandle<()> {
882 tokio::task::spawn(async move {
883 log_task_started("controller");
884
885 let check_interval = Duration::from_millis(10);
886
887 loop {
888 tokio::time::sleep(check_interval).await;
889 let mode = ConnectionMode::from_atomic(&connection_mode);
890
891 if mode.is_disconnect() {
892 tracing::debug!("Disconnecting");
893
894 let timeout = Duration::from_secs(5);
895 if tokio::time::timeout(timeout, async {
896 tokio::time::sleep(Duration::from_millis(100)).await;
898
899 if let Some(task) = &inner.read_task {
900 if !task.is_finished() {
901 task.abort();
902 log_task_aborted("read");
903 }
904 }
905
906 if let Some(task) = &inner.heartbeat_task {
907 if !task.is_finished() {
908 task.abort();
909 log_task_aborted("heartbeat");
910 }
911 }
912 })
913 .await
914 .is_err()
915 {
916 tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
917 }
918
919 tracing::debug!("Closed");
920
921 #[cfg(feature = "python")]
922 if let Some(ref handler) = py_post_disconnection {
923 Python::with_gil(|py| match handler.call0(py) {
924 Ok(_) => tracing::debug!("Called `post_disconnection` handler"),
925 Err(e) => {
926 tracing::error!("Error calling `post_disconnection` handler: {e}");
927 }
928 });
929 }
930 break; }
932
933 if mode.is_reconnect() || (mode.is_active() && !inner.is_alive()) {
934 match inner.reconnect().await {
935 Ok(()) => {
936 inner.backoff.reset();
937
938 if ConnectionMode::from_atomic(&connection_mode).is_active() {
940 if let Some(ref callback) = post_reconnection {
941 callback();
942 }
943
944 #[cfg(feature = "python")]
946 if let Some(ref callback) = py_post_reconnection {
947 Python::with_gil(|py| match callback.call0(py) {
948 Ok(_) => {
949 tracing::debug!("Called `post_reconnection` handler");
950 }
951 Err(e) => tracing::error!(
952 "Error calling `post_reconnection` handler: {e}"
953 ),
954 });
955 }
956
957 tracing::debug!("Reconnected successfully");
958 } else {
959 tracing::debug!(
960 "Skipping post_reconnection handlers due to disconnect state"
961 );
962 }
963 }
964 Err(e) => {
965 let duration = inner.backoff.next_duration();
966 tracing::warn!("Reconnect attempt failed: {e}");
967 if !duration.is_zero() {
968 tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
969 }
970 tokio::time::sleep(duration).await;
971 }
972 }
973 }
974 }
975 inner
976 .connection_mode
977 .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
978
979 log_task_stopped("controller");
980 })
981 }
982}
983
984impl Drop for WebSocketClient {
986 fn drop(&mut self) {
987 if !self.controller_task.is_finished() {
988 self.controller_task.abort();
989 log_task_aborted("controller");
990 }
991 }
992}
993
#[cfg(feature = "python")]
#[cfg(test)]
#[cfg(target_os = "linux")]
mod tests {
    use std::{num::NonZeroU32, sync::Arc};

    use futures_util::{SinkExt, StreamExt};
    use tokio::{
        net::TcpListener,
        task::{self, JoinHandle},
    };
    use tokio_tungstenite::{
        accept_hdr_async,
        tungstenite::{
            handshake::server::{self, Callback},
            http::HeaderValue,
        },
    };

    use crate::{
        ratelimiter::quota::Quota,
        websocket::{Consumer, WebSocketClient, WebSocketConfig},
    };

    /// Local echo server; aborted when dropped.
    struct TestServer {
        task: JoinHandle<()>,
        port: u16,
    }

    /// Handshake callback asserting that header `key` is present with `value`.
    #[derive(Debug, Clone)]
    struct TestCallback {
        key: String,
        value: HeaderValue,
    }

    impl Callback for TestCallback {
        fn on_request(
            self,
            request: &server::Request,
            response: server::Response,
        ) -> Result<server::Response, server::ErrorResponse> {
            // A single lookup covers both presence and value.
            let value = request.headers().get(&self.key);
            assert_eq!(
                value,
                Some(&self.value),
                "missing or unexpected `{}` header",
                self.key
            );
            Ok(response)
        }
    }

    impl TestServer {
        /// Binds an ephemeral port and echoes text/binary frames.
        ///
        /// Text `close-now` makes the server close that connection.
        async fn setup() -> Self {
            let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let port = server.local_addr().unwrap().port();

            let test_call_back = TestCallback {
                key: "test".to_string(),
                value: HeaderValue::from_str("test").unwrap(),
            };

            let task = task::spawn(async move {
                loop {
                    let (conn, _) = server.accept().await.unwrap();
                    let mut websocket = accept_hdr_async(conn, test_call_back.clone())
                        .await
                        .unwrap();

                    task::spawn(async move {
                        while let Some(Ok(msg)) = websocket.next().await {
                            match msg {
                                tokio_tungstenite::tungstenite::protocol::Message::Text(txt)
                                    if txt == "close-now" =>
                                {
                                    tracing::debug!("Forcibly closing from server side");
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                tokio_tungstenite::tungstenite::protocol::Message::Text(_)
                                | tokio_tungstenite::tungstenite::protocol::Message::Binary(_) => {
                                    if websocket.send(msg).await.is_err() {
                                        break;
                                    }
                                }
                                tokio_tungstenite::tungstenite::protocol::Message::Close(
                                    _frame,
                                ) => {
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                _ => {}
                            }
                        }
                    });
                }
            });

            Self { task, port }
        }
    }

    impl Drop for TestServer {
        fn drop(&mut self) {
            self.task.abort();
        }
    }

    /// Baseline client config for `port`: header `test: test`, no handler, no heartbeat.
    fn test_config(port: u16) -> WebSocketConfig {
        WebSocketConfig {
            url: format!("ws://127.0.0.1:{port}"),
            headers: vec![("test".into(), "test".into())],
            handler: Consumer::Python(None),
            heartbeat: None,
            heartbeat_msg: None,
            ping_handler: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
        }
    }

    /// Connects without Python callbacks or rate limits, panicking on failure.
    async fn connect_client(config: WebSocketConfig) -> WebSocketClient {
        WebSocketClient::connect(config, None, None, None, vec![], None)
            .await
            .expect("Failed to connect")
    }

    async fn setup_test_client(port: u16) -> WebSocketClient {
        connect_client(test_config(port)).await
    }

    #[tokio::test]
    async fn test_websocket_basic() {
        let server = TestServer::setup().await;
        let client = setup_test_client(server.port).await;

        assert!(!client.is_disconnected());

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_websocket_heartbeat() {
        let server = TestServer::setup().await;
        let mut config = test_config(server.port);
        // Without this the heartbeat task is never spawned and the test exercises nothing.
        config.heartbeat = Some(1);
        let client = connect_client(config).await;

        // Let several heartbeats go out.
        tokio::time::sleep(std::time::Duration::from_secs(3)).await;

        // A failed heartbeat send would have moved the client into reconnect.
        assert!(client.is_active());

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_websocket_reconnect_exhausted() {
        // Reserve an ephemeral port and release it so nothing is listening there. A fixed
        // port could be in use on the test host.
        let port = {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            listener.local_addr().unwrap().port()
        };
        let mut config = test_config(port);
        config.headers = vec![];

        let res = WebSocketClient::connect(config, None, None, None, vec![], None).await;
        assert!(res.is_err(), "Should fail quickly with no server");
    }

    #[tokio::test]
    async fn test_websocket_forced_close_reconnect() {
        let server = TestServer::setup().await;
        let client = setup_test_client(server.port).await;

        client.send_text("Hello".into(), None).await.unwrap();

        // The server closes this connection; the client should keep running.
        client.send_text("close-now".into(), None).await.unwrap();

        tokio::time::sleep(std::time::Duration::from_secs(1)).await;

        assert!(!client.is_disconnected());

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_rate_limiter() {
        let server = TestServer::setup().await;
        let quota = Quota::per_second(NonZeroU32::new(2).unwrap());

        let client = WebSocketClient::connect(
            test_config(server.port),
            None,
            None,
            None,
            vec![("default".into(), quota)],
            None,
        )
        .await
        .unwrap();

        client.send_text("test1".into(), None).await.unwrap();
        client.send_text("test2".into(), None).await.unwrap();
        client.send_text("test3".into(), None).await.unwrap();

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_concurrent_writers() {
        let server = TestServer::setup().await;
        let client = Arc::new(setup_test_client(server.port).await);

        let mut handles = vec![];
        for i in 0..10 {
            let client = client.clone();
            handles.push(task::spawn(async move {
                client.send_text(format!("test{i}"), None).await.unwrap();
            }));
        }

        for handle in handles {
            handle.await.unwrap();
        }

        client.disconnect().await;
        assert!(client.is_disconnected());
    }
}
1262
#[cfg(test)]
mod rust_tests {
    use tokio::{
        net::TcpListener,
        task,
        time::{Duration, sleep},
    };
    use tokio_tungstenite::accept_async;

    use super::*;

    /// A client whose connection was dropped by the server must still disconnect cleanly.
    #[tokio::test]
    async fn test_reconnect_then_disconnect() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();

        // Accept exactly one handshake, drop it immediately, then keep the listener bound
        // for a second before the task ends.
        let server = task::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            let websocket = accept_async(socket).await.unwrap();
            drop(websocket);
            sleep(Duration::from_secs(1)).await;
        });

        // The named binding keeps the receiver alive for the whole test.
        let (handler, _receiver) = Consumer::rust_consumer();

        // Short reconnect settings keep the test quick.
        let config = WebSocketConfig {
            url: format!("ws://127.0.0.1:{port}"),
            headers: vec![],
            handler,
            heartbeat: None,
            heartbeat_msg: None,
            #[cfg(feature = "python")]
            ping_handler: None,
            reconnect_timeout_ms: Some(1_000),
            reconnect_delay_initial_ms: Some(50),
            reconnect_delay_max_ms: Some(100),
            reconnect_backoff_factor: Some(1.0),
            reconnect_jitter_ms: Some(0),
        };

        #[cfg(feature = "python")]
        let client = WebSocketClient::connect(config.clone(), None, None, None, vec![], None)
            .await
            .unwrap();
        #[cfg(not(feature = "python"))]
        let client = WebSocketClient::connect(config.clone(), vec![], None)
            .await
            .unwrap();

        sleep(Duration::from_millis(100)).await;
        client.disconnect().await;
        assert!(client.is_disconnected());

        server.abort();
    }
}