nautilus_network/
websocket.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Posei Systems Pty Ltd. All rights reserved.
3//  https://poseitrader.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
//! High-performance WebSocket client implementation with automatic reconnection,
//! exponential backoff, and connection-state management.
18
19//! **Key features**:
20//! - Connection state tracking (ACTIVE/RECONNECTING/DISCONNECTING/CLOSED)
21//! - Synchronized reconnection with backoff
22//! - Split read/write architecture
23//! - Python callback integration
24//!
25//! **Design**:
26//! - Single reader, multiple writer model
27//! - Read half runs in dedicated task
28//! - Write half runs in dedicated task connected with channel
29//! - Controller task manages lifecycle
30
31use std::{
32    fmt::Debug,
33    sync::{
34        Arc,
35        atomic::{AtomicU8, Ordering},
36    },
37    time::Duration,
38};
39
40use futures_util::{
41    SinkExt, StreamExt,
42    stream::{SplitSink, SplitStream},
43};
44use http::HeaderName;
45use nautilus_cryptography::providers::install_cryptographic_provider;
46#[cfg(feature = "python")]
47use pyo3::{prelude::*, types::PyBytes};
48use tokio::{
49    net::TcpStream,
50    sync::mpsc::{self, Receiver, Sender},
51};
52use tokio_tungstenite::{
53    MaybeTlsStream, WebSocketStream, connect_async,
54    tungstenite::{Error, Message, client::IntoClientRequest, http::HeaderValue},
55};
56
57use crate::{
58    backoff::ExponentialBackoff,
59    error::SendError,
60    logging::{log_task_aborted, log_task_started, log_task_stopped},
61    mode::ConnectionMode,
62    ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota},
63};
64
/// Write half of a split WebSocket connection, driven by the writer task
/// (replaced in place on [`WriterCommand::Update`]).
type MessageWriter = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
/// Read half of a split WebSocket connection, consumed by a read task or handed
/// directly to callers of `WebSocketClient::connect_stream`.
pub type MessageReader = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
67
68/// Defines how WebSocket messages are consumed.
69#[derive(Debug, Clone)]
70pub enum Consumer {
71    /// Python-based message handler.
72    #[cfg(feature = "python")]
73    Python(Option<Arc<PyObject>>),
74    /// Rust-based message handler using a channel sender.
75    Rust(Sender<Message>),
76}
77
78impl Consumer {
79    /// Creates a Rust-based consumer with a channel for receiving messages.
80    ///
81    /// Returns a tuple containing the consumer and a receiver for messages.
82    #[must_use]
83    pub fn rust_consumer() -> (Self, Receiver<Message>) {
84        let (tx, rx) = mpsc::channel(100);
85        (Self::Rust(tx), rx)
86    }
87}
88
89#[derive(Debug, Clone)]
90#[cfg_attr(
91    feature = "python",
92    pyo3::pyclass(module = "posei_trader.core.nautilus_pyo3.network")
93)]
94pub struct WebSocketConfig {
95    /// The URL to connect to.
96    pub url: String,
97    /// The default headers.
98    pub headers: Vec<(String, String)>,
99    /// The Python function to handle incoming messages.
100    pub handler: Consumer,
101    /// The optional heartbeat interval (seconds).
102    pub heartbeat: Option<u64>,
103    /// The optional heartbeat message.
104    pub heartbeat_msg: Option<String>,
105    /// The handler for incoming pings.
106    #[cfg(feature = "python")]
107    pub ping_handler: Option<Arc<PyObject>>,
108    /// The timeout (milliseconds) for reconnection attempts.
109    pub reconnect_timeout_ms: Option<u64>,
110    /// The initial reconnection delay (milliseconds) for reconnects.
111    pub reconnect_delay_initial_ms: Option<u64>,
112    /// The maximum reconnect delay (milliseconds) for exponential backoff.
113    pub reconnect_delay_max_ms: Option<u64>,
114    /// The exponential backoff factor for reconnection delays.
115    pub reconnect_backoff_factor: Option<f64>,
116    /// The maximum jitter (milliseconds) added to reconnection delays.
117    pub reconnect_jitter_ms: Option<u64>,
118}
119
/// Represents a command for the writer task.
#[derive(Debug)]
pub(crate) enum WriterCommand {
    /// Update the writer reference with a new one after reconnection.
    ///
    /// The writer task closes the previous writer before switching to this one.
    Update(MessageWriter),
    /// Send message to the server.
    ///
    /// Dropped with a warning if it arrives while the connection is reconnecting.
    Send(Message),
}
128
129/// `WebSocketClient` connects to a websocket server to read and send messages.
130///
131/// The client is opinionated about how messages are read and written. It
132/// assumes that data can only have one reader but multiple writers.
133///
134/// The client splits the connection into read and write halves. It moves
135/// the read half into a tokio task which keeps receiving messages from the
136/// server and calls a handler - a Python function that takes the data
137/// as its parameter. It stores the write half in the struct wrapped
138/// with an Arc Mutex. This way the client struct can be used to write
139/// data to the server from multiple scopes/tasks.
140///
141/// The client also maintains a heartbeat if given a duration in seconds.
142/// It's preferable to set the duration slightly lower - heartbeat more
143/// frequently - than the required amount.
144struct WebSocketClientInner {
145    config: WebSocketConfig,
146    read_task: Option<tokio::task::JoinHandle<()>>,
147    write_task: tokio::task::JoinHandle<()>,
148    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
149    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
150    connection_mode: Arc<AtomicU8>,
151    reconnect_timeout: Duration,
152    backoff: ExponentialBackoff,
153}
154
155impl WebSocketClientInner {
156    /// Create an inner websocket client.
157    pub async fn connect_url(config: WebSocketConfig) -> Result<Self, Error> {
158        install_cryptographic_provider();
159
160        #[allow(unused_variables)]
161        let WebSocketConfig {
162            url,
163            handler,
164            heartbeat,
165            headers,
166            heartbeat_msg,
167            #[cfg(feature = "python")]
168            ping_handler,
169            reconnect_timeout_ms,
170            reconnect_delay_initial_ms,
171            reconnect_delay_max_ms,
172            reconnect_backoff_factor,
173            reconnect_jitter_ms,
174        } = &config;
175        let (writer, reader) = Self::connect_with_server(url, headers.clone()).await?;
176
177        let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
178
179        let read_task = match &handler {
180            #[cfg(feature = "python")]
181            Consumer::Python(handler) => handler.as_ref().map(|handler| {
182                Self::spawn_python_callback_task(
183                    connection_mode.clone(),
184                    reader,
185                    handler.clone(),
186                    ping_handler.clone(),
187                )
188            }),
189            Consumer::Rust(sender) => Some(Self::spawn_rust_streaming_task(
190                connection_mode.clone(),
191                reader,
192                sender.clone(),
193            )),
194        };
195
196        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
197        let write_task = Self::spawn_write_task(connection_mode.clone(), writer, writer_rx);
198
199        // Optionally spawn a heartbeat task to periodically ping server
200        let heartbeat_task = heartbeat.as_ref().map(|heartbeat_secs| {
201            Self::spawn_heartbeat_task(
202                connection_mode.clone(),
203                *heartbeat_secs,
204                heartbeat_msg.clone(),
205                writer_tx.clone(),
206            )
207        });
208
209        let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
210        let backoff = ExponentialBackoff::new(
211            Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
212            Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
213            reconnect_backoff_factor.unwrap_or(1.5),
214            reconnect_jitter_ms.unwrap_or(100),
215            true, // immediate-first
216        )
217        .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
218
219        Ok(Self {
220            config,
221            read_task,
222            write_task,
223            writer_tx,
224            heartbeat_task,
225            connection_mode,
226            reconnect_timeout,
227            backoff,
228        })
229    }
230
231    /// Connects with the server creating a tokio-tungstenite websocket stream.
232    #[inline]
233    pub async fn connect_with_server(
234        url: &str,
235        headers: Vec<(String, String)>,
236    ) -> Result<(MessageWriter, MessageReader), Error> {
237        let mut request = url.into_client_request()?;
238        let req_headers = request.headers_mut();
239
240        let mut header_names: Vec<HeaderName> = Vec::new();
241        for (key, val) in headers {
242            let header_value = HeaderValue::from_str(&val)?;
243            let header_name: HeaderName = key.parse()?;
244            header_names.push(header_name.clone());
245            req_headers.insert(header_name, header_value);
246        }
247
248        connect_async(request).await.map(|resp| resp.0.split())
249    }
250
251    /// Reconnect with server.
252    ///
253    /// Make a new connection with server. Use the new read and write halves
254    /// to update self writer and read and heartbeat tasks.
255    pub async fn reconnect(&mut self) -> Result<(), Error> {
256        tracing::debug!("Reconnecting");
257
258        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
259            tracing::debug!("Reconnect aborted due to disconnect state");
260            return Ok(());
261        }
262
263        tokio::time::timeout(self.reconnect_timeout, async {
264            // Attempt to connect; abort early if a disconnect was requested
265            let (new_writer, reader) =
266                Self::connect_with_server(&self.config.url, self.config.headers.clone()).await?;
267
268            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
269                tracing::debug!("Reconnect aborted mid-flight (after connect)");
270                return Ok(());
271            }
272
273            if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer)) {
274                tracing::error!("{e}");
275            }
276
277            // Delay before closing connection
278            tokio::time::sleep(Duration::from_millis(100)).await;
279
280            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
281                tracing::debug!("Reconnect aborted mid-flight (after delay)");
282                return Ok(());
283            }
284
285            if let Some(ref read_task) = self.read_task.take() {
286                if !read_task.is_finished() {
287                    read_task.abort();
288                    log_task_aborted("read");
289                }
290            }
291
292            // If a disconnect was requested during reconnect, do not proceed to reactivate
293            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
294                tracing::debug!("Reconnect aborted mid-flight (before spawn read)");
295                return Ok(());
296            }
297
298            // Mark as active only if not disconnecting
299            self.connection_mode
300                .store(ConnectionMode::Active.as_u8(), Ordering::SeqCst);
301
302            self.read_task = match &self.config.handler {
303                #[cfg(feature = "python")]
304                Consumer::Python(handler) => handler.as_ref().map(|handler| {
305                    Self::spawn_python_callback_task(
306                        self.connection_mode.clone(),
307                        reader,
308                        handler.clone(),
309                        self.config.ping_handler.clone(),
310                    )
311                }),
312                Consumer::Rust(sender) => Some(Self::spawn_rust_streaming_task(
313                    self.connection_mode.clone(),
314                    reader,
315                    sender.clone(),
316                )),
317            };
318
319            tracing::debug!("Reconnect succeeded");
320            Ok(())
321        })
322        .await
323        .map_err(|_| {
324            Error::Io(std::io::Error::new(
325                std::io::ErrorKind::TimedOut,
326                format!(
327                    "reconnection timed out after {}s",
328                    self.reconnect_timeout.as_secs_f64()
329                ),
330            ))
331        })?
332    }
333
334    /// Check if the client is still connected.
335    ///
336    /// The client is connected if the read task has not finished. It is expected
337    /// that in case of any failure client or server side. The read task will be
338    /// shutdown or will receive a `Close` frame which will finish it. There
339    /// might be some delay between the connection being closed and the client
340    /// detecting.
341    #[inline]
342    #[must_use]
343    pub fn is_alive(&self) -> bool {
344        match &self.read_task {
345            Some(read_task) => !read_task.is_finished(),
346            None => true, // Stream is being used directly
347        }
348    }
349
350    fn spawn_rust_streaming_task(
351        connection_state: Arc<AtomicU8>,
352        mut reader: MessageReader,
353        sender: Sender<Message>,
354    ) -> tokio::task::JoinHandle<()> {
355        tracing::debug!("Started streaming task 'read'");
356
357        let check_interval = Duration::from_millis(10);
358
359        tokio::task::spawn(async move {
360            loop {
361                if !ConnectionMode::from_atomic(&connection_state).is_active() {
362                    break;
363                }
364
365                match tokio::time::timeout(check_interval, reader.next()).await {
366                    Ok(Some(Ok(message))) => {
367                        if let Err(e) = sender.send(message).await {
368                            tracing::error!("Failed to send message: {e}");
369                        }
370                    }
371                    Ok(Some(Err(e))) => {
372                        tracing::error!("Received error message - terminating: {e}");
373                        break;
374                    }
375                    Ok(None) => {
376                        tracing::debug!("No message received - terminating");
377                        break;
378                    }
379                    Err(_) => {
380                        // Timeout - continue loop and check connection mode
381                        continue;
382                    }
383                }
384            }
385        })
386    }
387
388    #[cfg(feature = "python")]
389    fn spawn_python_callback_task(
390        connection_state: Arc<AtomicU8>,
391        mut reader: MessageReader,
392        handler: Arc<PyObject>,
393        ping_handler: Option<Arc<PyObject>>,
394    ) -> tokio::task::JoinHandle<()> {
395        log_task_started("read");
396
397        // Interval between checking the connection mode
398        let check_interval = Duration::from_millis(10);
399
400        tokio::task::spawn(async move {
401            loop {
402                if !ConnectionMode::from_atomic(&connection_state).is_active() {
403                    break;
404                }
405
406                match tokio::time::timeout(check_interval, reader.next()).await {
407                    Ok(Some(Ok(Message::Binary(data)))) => {
408                        tracing::trace!("Received message <binary> {} bytes", data.len());
409                        if let Err(e) =
410                            Python::with_gil(|py| handler.call1(py, (PyBytes::new(py, &data),)))
411                        {
412                            tracing::error!("Error calling handler: {e}");
413                            break;
414                        }
415                        continue;
416                    }
417                    Ok(Some(Ok(Message::Text(data)))) => {
418                        tracing::trace!("Received message: {data}");
419                        if let Err(e) = Python::with_gil(|py| {
420                            handler.call1(py, (PyBytes::new(py, data.as_bytes()),))
421                        }) {
422                            tracing::error!("Error calling handler: {e}");
423                            break;
424                        }
425                        continue;
426                    }
427                    Ok(Some(Ok(Message::Ping(ping)))) => {
428                        tracing::trace!("Received ping: {ping:?}");
429                        if let Some(ref handler) = ping_handler {
430                            if let Err(e) =
431                                Python::with_gil(|py| handler.call1(py, (PyBytes::new(py, &ping),)))
432                            {
433                                tracing::error!("Error calling handler: {e}");
434                                break;
435                            }
436                        }
437                        continue;
438                    }
439                    Ok(Some(Ok(Message::Pong(_)))) => {
440                        tracing::trace!("Received pong");
441                    }
442                    Ok(Some(Ok(Message::Close(_)))) => {
443                        tracing::debug!("Received close message - terminating");
444                        break;
445                    }
446                    Ok(Some(Ok(_))) => (),
447                    Ok(Some(Err(e))) => {
448                        tracing::error!("Received error message - terminating: {e}");
449                        break;
450                    }
451                    // Internally tungstenite considers the connection closed when polling
452                    // for the next message in the stream returns None.
453                    Ok(None) => {
454                        tracing::debug!("No message received - terminating");
455                        break;
456                    }
457                    Err(_) => {
458                        // Timeout - continue loop and check connection mode
459                        continue;
460                    }
461                }
462            }
463        })
464    }
465
466    fn spawn_write_task(
467        connection_state: Arc<AtomicU8>,
468        writer: MessageWriter,
469        mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
470    ) -> tokio::task::JoinHandle<()> {
471        log_task_started("write");
472
473        // Interval between checking the connection mode
474        let check_interval = Duration::from_millis(10);
475
476        tokio::task::spawn(async move {
477            let mut active_writer = writer;
478
479            loop {
480                match ConnectionMode::from_atomic(&connection_state) {
481                    ConnectionMode::Disconnect => {
482                        // Attempt to close the writer gracefully before exiting,
483                        // we ignore any error as the writer may already be closed.
484                        _ = active_writer.close().await;
485                        break;
486                    }
487                    ConnectionMode::Closed => break,
488                    _ => {}
489                }
490
491                match tokio::time::timeout(check_interval, writer_rx.recv()).await {
492                    Ok(Some(msg)) => {
493                        // Re-check connection mode after receiving a message
494                        let mode = ConnectionMode::from_atomic(&connection_state);
495                        if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
496                            break;
497                        }
498
499                        match msg {
500                            WriterCommand::Update(new_writer) => {
501                                tracing::debug!("Received new writer");
502
503                                // Delay before closing connection
504                                tokio::time::sleep(Duration::from_millis(100)).await;
505
506                                // Attempt to close the writer gracefully on update,
507                                // we ignore any error as the writer may already be closed.
508                                _ = active_writer.close().await;
509
510                                active_writer = new_writer;
511                                tracing::debug!("Updated writer");
512                            }
513                            _ if mode.is_reconnect() => {
514                                tracing::warn!("Skipping message while reconnecting, {msg:?}");
515                                continue;
516                            }
517                            WriterCommand::Send(msg) => {
518                                if let Err(e) = active_writer.send(msg).await {
519                                    tracing::error!("Failed to send message: {e}");
520                                    // Mode is active so trigger reconnection
521                                    tracing::warn!("Writer triggering reconnect");
522                                    connection_state
523                                        .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
524                                }
525                            }
526                        }
527                    }
528                    Ok(None) => {
529                        // Channel closed - writer task should terminate
530                        tracing::debug!("Writer channel closed, terminating writer task");
531                        break;
532                    }
533                    Err(_) => {
534                        // Timeout - just continue the loop
535                        continue;
536                    }
537                }
538            }
539
540            // Attempt to close the writer gracefully before exiting,
541            // we ignore any error as the writer may already be closed.
542            _ = active_writer.close().await;
543
544            log_task_stopped("write");
545        })
546    }
547
548    fn spawn_heartbeat_task(
549        connection_state: Arc<AtomicU8>,
550        heartbeat_secs: u64,
551        message: Option<String>,
552        writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
553    ) -> tokio::task::JoinHandle<()> {
554        log_task_started("heartbeat");
555
556        tokio::task::spawn(async move {
557            let interval = Duration::from_secs(heartbeat_secs);
558
559            loop {
560                tokio::time::sleep(interval).await;
561
562                match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
563                    ConnectionMode::Active => {
564                        let msg = match &message {
565                            Some(text) => WriterCommand::Send(Message::Text(text.clone().into())),
566                            None => WriterCommand::Send(Message::Ping(vec![].into())),
567                        };
568
569                        match writer_tx.send(msg) {
570                            Ok(()) => tracing::trace!("Sent heartbeat to writer task"),
571                            Err(e) => {
572                                tracing::error!("Failed to send heartbeat to writer task: {e}");
573                            }
574                        }
575                    }
576                    ConnectionMode::Reconnect => continue,
577                    ConnectionMode::Disconnect | ConnectionMode::Closed => break,
578                }
579            }
580
581            log_task_stopped("heartbeat");
582        })
583    }
584}
585
586impl Drop for WebSocketClientInner {
587    fn drop(&mut self) {
588        if let Some(ref read_task) = self.read_task.take() {
589            if !read_task.is_finished() {
590                read_task.abort();
591                log_task_aborted("read");
592            }
593        }
594
595        if !self.write_task.is_finished() {
596            self.write_task.abort();
597            log_task_aborted("write");
598        }
599
600        if let Some(ref handle) = self.heartbeat_task.take() {
601            if !handle.is_finished() {
602                handle.abort();
603                log_task_aborted("heartbeat");
604            }
605        }
606    }
607}
608
/// WebSocket client with automatic reconnection.
///
/// Handles connection state, Python callbacks, and rate limiting.
/// See module docs for architecture details.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "posei_trader.core.nautilus_pyo3.network")
)]
pub struct WebSocketClient {
    /// Task that owns the inner client and reconnects or disconnects it.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Connection state shared with the inner client's tasks.
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// Command channel into the inner client's writer task.
    pub(crate) writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Rate limiter awaited before each send, using the optional per-call keys.
    pub(crate) rate_limiter: Arc<RateLimiter<String, MonotonicClock>>,
}
623
624impl Debug for WebSocketClient {
625    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
626        f.debug_struct(stringify!(WebSocketClient)).finish()
627    }
628}
629
630impl WebSocketClient {
631    /// Creates a websocket client that returns a stream for reading messages.
632    ///
633    /// # Errors
634    ///
635    /// Returns any error connecting to the server.
636    #[allow(clippy::too_many_arguments)]
637    pub async fn connect_stream(
638        config: WebSocketConfig,
639        keyed_quotas: Vec<(String, Quota)>,
640        default_quota: Option<Quota>,
641        post_reconnect: Option<Arc<dyn Fn() + Send + Sync>>,
642    ) -> Result<(MessageReader, Self), Error> {
643        install_cryptographic_provider();
644        let (ws_stream, _) = connect_async(config.url.clone().into_client_request()?).await?;
645        let (writer, reader) = ws_stream.split();
646        let inner = WebSocketClientInner::connect_url(config).await?;
647
648        let connection_mode = inner.connection_mode.clone();
649
650        let writer_tx = inner.writer_tx.clone();
651        if let Err(e) = writer_tx.send(WriterCommand::Update(writer)) {
652            tracing::error!("{e}");
653        }
654
655        let controller_task = Self::spawn_controller_task(
656            inner,
657            connection_mode.clone(),
658            post_reconnect,
659            #[cfg(feature = "python")]
660            None, // no post_reconnection
661            #[cfg(feature = "python")]
662            None, // no post_disconnection
663        );
664
665        let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
666
667        Ok((
668            reader,
669            Self {
670                controller_task,
671                connection_mode,
672                writer_tx,
673                rate_limiter,
674            },
675        ))
676    }
677
678    /// Creates a websocket client.
679    ///
680    /// Creates an inner client and controller task to reconnect or disconnect
681    /// the client. Also assumes ownership of writer from inner client.
682    ///
683    /// # Errors
684    ///
685    /// Returns any websocket error.
686    pub async fn connect(
687        config: WebSocketConfig,
688        #[cfg(feature = "python")] post_connection: Option<PyObject>,
689        #[cfg(feature = "python")] post_reconnection: Option<PyObject>,
690        #[cfg(feature = "python")] post_disconnection: Option<PyObject>,
691        keyed_quotas: Vec<(String, Quota)>,
692        default_quota: Option<Quota>,
693    ) -> Result<Self, Error> {
694        tracing::debug!("Connecting");
695        let inner = WebSocketClientInner::connect_url(config.clone()).await?;
696        let connection_mode = inner.connection_mode.clone();
697        let writer_tx = inner.writer_tx.clone();
698
699        let controller_task = Self::spawn_controller_task(
700            inner,
701            connection_mode.clone(),
702            None, // Rust handler
703            #[cfg(feature = "python")]
704            post_reconnection, // TODO: Deprecated
705            #[cfg(feature = "python")]
706            post_disconnection, // TODO: Deprecated
707        );
708
709        let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
710
711        #[cfg(feature = "python")]
712        if let Some(handler) = post_connection {
713            Python::with_gil(|py| match handler.call0(py) {
714                Ok(_) => tracing::debug!("Called `post_connection` handler"),
715                Err(e) => tracing::error!("Error calling `post_connection` handler: {e}"),
716            });
717        }
718
719        Ok(Self {
720            controller_task,
721            connection_mode,
722            writer_tx,
723            rate_limiter,
724        })
725    }
726
727    /// Returns the current connection mode.
728    #[must_use]
729    pub fn connection_mode(&self) -> ConnectionMode {
730        ConnectionMode::from_atomic(&self.connection_mode)
731    }
732
733    /// Check if the client connection is active.
734    ///
735    /// Returns `true` if the client is connected and has not been signalled to disconnect.
736    /// The client will automatically retry connection based on its configuration.
737    #[inline]
738    #[must_use]
739    pub fn is_active(&self) -> bool {
740        self.connection_mode().is_active()
741    }
742
743    /// Check if the client is disconnected.
744    #[must_use]
745    pub fn is_disconnected(&self) -> bool {
746        self.controller_task.is_finished()
747    }
748
749    /// Check if the client is reconnecting.
750    ///
751    /// Returns `true` if the client lost connection and is attempting to reestablish it.
752    /// The client will automatically retry connection based on its configuration.
753    #[inline]
754    #[must_use]
755    pub fn is_reconnecting(&self) -> bool {
756        self.connection_mode().is_reconnect()
757    }
758
759    /// Check if the client is disconnecting.
760    ///
761    /// Returns `true` if the client is in disconnect mode.
762    #[inline]
763    #[must_use]
764    pub fn is_disconnecting(&self) -> bool {
765        self.connection_mode().is_disconnect()
766    }
767
768    /// Check if the client is closed.
769    ///
770    /// Returns `true` if the client has been explicitly disconnected or reached
771    /// maximum reconnection attempts. In this state, the client cannot be reused
772    /// and a new client must be created for further connections.
773    #[inline]
774    #[must_use]
775    pub fn is_closed(&self) -> bool {
776        self.connection_mode().is_closed()
777    }
778
779    /// Set disconnect mode to true.
780    ///
781    /// Controller task will periodically check the disconnect mode
782    /// and shutdown the client if it is alive
783    pub async fn disconnect(&self) {
784        tracing::debug!("Disconnecting");
785        self.connection_mode
786            .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
787
788        match tokio::time::timeout(Duration::from_secs(5), async {
789            while !self.is_disconnected() {
790                tokio::time::sleep(Duration::from_millis(10)).await;
791            }
792
793            if !self.controller_task.is_finished() {
794                self.controller_task.abort();
795                log_task_aborted("controller");
796            }
797        })
798        .await
799        {
800            Ok(()) => {
801                tracing::debug!("Controller task finished");
802            }
803            Err(_) => {
804                tracing::error!("Timeout waiting for controller task to finish");
805            }
806        }
807    }
808
809    /// Sends the given text `data` to the server.
810    ///
811    /// # Errors
812    ///
813    /// Returns a websocket error if unable to send.
814    #[allow(unused_variables)]
815    pub async fn send_text(
816        &self,
817        data: String,
818        keys: Option<Vec<String>>,
819    ) -> std::result::Result<(), SendError> {
820        self.rate_limiter.await_keys_ready(keys).await;
821
822        if !self.is_active() {
823            return Err(SendError::Closed);
824        }
825
826        tracing::trace!("Sending text: {data:?}");
827
828        let msg = Message::Text(data.into());
829        self.writer_tx
830            .send(WriterCommand::Send(msg))
831            .map_err(|e| SendError::BrokenPipe(e.to_string()))
832    }
833
834    /// Sends the given bytes `data` to the server.
835    ///
836    /// # Errors
837    ///
838    /// Returns a websocket error if unable to send.
839    #[allow(unused_variables)]
840    pub async fn send_bytes(
841        &self,
842        data: Vec<u8>,
843        keys: Option<Vec<String>>,
844    ) -> std::result::Result<(), SendError> {
845        self.rate_limiter.await_keys_ready(keys).await;
846
847        if !self.is_active() {
848            return Err(SendError::Closed);
849        }
850
851        tracing::trace!("Sending bytes: {data:?}");
852
853        let msg = Message::Binary(data.into());
854        self.writer_tx
855            .send(WriterCommand::Send(msg))
856            .map_err(|e| SendError::BrokenPipe(e.to_string()))
857    }
858
859    /// Sends a close message to the server.
860    ///
861    /// # Errors
862    ///
863    /// Returns a websocket error if unable to send.
864    pub async fn send_close_message(&self) -> std::result::Result<(), SendError> {
865        if !self.is_active() {
866            return Err(SendError::Closed);
867        }
868
869        let msg = Message::Close(None);
870        self.writer_tx
871            .send(WriterCommand::Send(msg))
872            .map_err(|e| SendError::BrokenPipe(e.to_string()))
873    }
874
    /// Spawns the controller task that owns `inner` and drives the client lifecycle.
    ///
    /// Every 10ms the task reads `connection_mode` and:
    /// - on `Disconnect`: aborts the read and heartbeat tasks (if still running), invokes
    ///   the Python `post_disconnection` handler (python feature only) and exits the loop;
    /// - on `Reconnect`, or on `Active` while `inner.is_alive()` is false: calls
    ///   `inner.reconnect()`. On success it resets the backoff and runs the post-reconnection
    ///   callbacks (only if the mode is still `Active`). On failure it sleeps for the next
    ///   backoff duration.
    ///
    /// After leaving the loop the task stores `Closed` into `inner.connection_mode`.
    ///
    /// NOTE(review): a disconnect request is only observed between loop iterations. While
    /// `inner.reconnect().await` or a backoff sleep is pending it is not seen, so shutdown
    /// can be delayed by up to the reconnect timeout plus the backoff duration.
    ///
    /// # Parameters
    /// - `inner`: connection state; moved into and owned by the spawned task.
    /// - `connection_mode`: shared atomic mode polled by the task.
    /// - `post_reconnection`: optional Rust callback run after a successful reconnect.
    /// - `py_post_reconnection` / `py_post_disconnection`: deprecated Python handlers
    ///   (python feature only).
    ///
    /// Returns the `JoinHandle` of the spawned controller task.
    fn spawn_controller_task(
        mut inner: WebSocketClientInner,
        connection_mode: Arc<AtomicU8>,
        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
        #[cfg(feature = "python")] py_post_reconnection: Option<PyObject>, // TODO: Deprecated
        #[cfg(feature = "python")] py_post_disconnection: Option<PyObject>, // TODO: Deprecated
    ) -> tokio::task::JoinHandle<()> {
        tokio::task::spawn(async move {
            log_task_started("controller");

            // Polling period for connection-mode changes.
            let check_interval = Duration::from_millis(10);

            loop {
                tokio::time::sleep(check_interval).await;
                let mode = ConnectionMode::from_atomic(&connection_mode);

                if mode.is_disconnect() {
                    tracing::debug!("Disconnecting");

                    // Safety net: the shutdown future below only sleeps 100ms and issues
                    // aborts, so in practice it completes well within this timeout.
                    let timeout = Duration::from_secs(5);
                    if tokio::time::timeout(timeout, async {
                        // Delay awaiting graceful shutdown
                        tokio::time::sleep(Duration::from_millis(100)).await;

                        if let Some(task) = &inner.read_task {
                            if !task.is_finished() {
                                task.abort();
                                log_task_aborted("read");
                            }
                        }

                        if let Some(task) = &inner.heartbeat_task {
                            if !task.is_finished() {
                                task.abort();
                                log_task_aborted("heartbeat");
                            }
                        }
                        // NOTE(review): only the read and heartbeat tasks are stopped
                        // here; confirm any other tasks held by `inner` shut down elsewhere.
                    })
                    .await
                    .is_err()
                    {
                        tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
                    }

                    tracing::debug!("Closed");

                    #[cfg(feature = "python")]
                    if let Some(ref handler) = py_post_disconnection {
                        Python::with_gil(|py| match handler.call0(py) {
                            Ok(_) => tracing::debug!("Called `post_disconnection` handler"),
                            Err(e) => {
                                tracing::error!("Error calling `post_disconnection` handler: {e}");
                            }
                        });
                    }
                    break; // Controller finished
                }

                // Reconnect when explicitly requested, or when nominally active but the
                // underlying connection is no longer alive.
                if mode.is_reconnect() || (mode.is_active() && !inner.is_alive()) {
                    match inner.reconnect().await {
                        Ok(()) => {
                            inner.backoff.reset();

                            // Only invoke callbacks if not in disconnect state
                            // (re-read the mode: it may have changed during `reconnect`).
                            if ConnectionMode::from_atomic(&connection_mode).is_active() {
                                if let Some(ref callback) = post_reconnection {
                                    callback();
                                }

                                // TODO: Python based websocket handlers deprecated (will be removed)
                                #[cfg(feature = "python")]
                                if let Some(ref callback) = py_post_reconnection {
                                    Python::with_gil(|py| match callback.call0(py) {
                                        Ok(_) => {
                                            tracing::debug!("Called `post_reconnection` handler");
                                        }
                                        Err(e) => tracing::error!(
                                            "Error calling `post_reconnection` handler: {e}"
                                        ),
                                    });
                                }

                                tracing::debug!("Reconnected successfully");
                            } else {
                                tracing::debug!(
                                    "Skipping post_reconnection handlers due to disconnect state"
                                );
                            }
                        }
                        Err(e) => {
                            // Back off before the next attempt; a disconnect request made
                            // during this sleep is only seen on the next loop iteration.
                            let duration = inner.backoff.next_duration();
                            tracing::warn!("Reconnect attempt failed: {e}");
                            if !duration.is_zero() {
                                tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
                            }
                            tokio::time::sleep(duration).await;
                        }
                    }
                }
            }
            // Record the terminal state. NOTE(review): presumably `inner.connection_mode`
            // is the same atomic as `connection_mode` — confirm at the construction site.
            inner
                .connection_mode
                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);

            log_task_stopped("controller");
        })
    }
982}
983
984// Abort controller task on drop to clean up background tasks
985impl Drop for WebSocketClient {
986    fn drop(&mut self) {
987        if !self.controller_task.is_finished() {
988            self.controller_task.abort();
989            log_task_aborted("controller");
990        }
991    }
992}
993
994////////////////////////////////////////////////////////////////////////////////
995// Tests
996////////////////////////////////////////////////////////////////////////////////
997#[cfg(feature = "python")]
998#[cfg(test)]
999#[cfg(target_os = "linux")] // Only run network tests on Linux (CI stability)
1000mod tests {
1001    use std::{num::NonZeroU32, sync::Arc};
1002
1003    use futures_util::{SinkExt, StreamExt};
1004    use tokio::{
1005        net::TcpListener,
1006        task::{self, JoinHandle},
1007    };
1008    use tokio_tungstenite::{
1009        accept_hdr_async,
1010        tungstenite::{
1011            handshake::server::{self, Callback},
1012            http::HeaderValue,
1013        },
1014    };
1015
1016    use crate::{
1017        ratelimiter::quota::Quota,
1018        websocket::{Consumer, WebSocketClient, WebSocketConfig},
1019    };
1020
1021    struct TestServer {
1022        task: JoinHandle<()>,
1023        port: u16,
1024    }
1025
1026    #[derive(Debug, Clone)]
1027    struct TestCallback {
1028        key: String,
1029        value: HeaderValue,
1030    }
1031
1032    impl Callback for TestCallback {
1033        fn on_request(
1034            self,
1035            request: &server::Request,
1036            response: server::Response,
1037        ) -> Result<server::Response, server::ErrorResponse> {
1038            let _ = response;
1039            let value = request.headers().get(&self.key);
1040            assert!(value.is_some());
1041
1042            if let Some(value) = request.headers().get(&self.key) {
1043                assert_eq!(value, self.value);
1044            }
1045
1046            Ok(response)
1047        }
1048    }
1049
1050    impl TestServer {
1051        async fn setup() -> Self {
1052            let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
1053            let port = TcpListener::local_addr(&server).unwrap().port();
1054
1055            let header_key = "test".to_string();
1056            let header_value = "test".to_string();
1057
1058            let test_call_back = TestCallback {
1059                key: header_key,
1060                value: HeaderValue::from_str(&header_value).unwrap(),
1061            };
1062
1063            let task = task::spawn(async move {
1064                // Keep accepting connections
1065                loop {
1066                    let (conn, _) = server.accept().await.unwrap();
1067                    let mut websocket = accept_hdr_async(conn, test_call_back.clone())
1068                        .await
1069                        .unwrap();
1070
1071                    task::spawn(async move {
1072                        while let Some(Ok(msg)) = websocket.next().await {
1073                            match msg {
1074                                tokio_tungstenite::tungstenite::protocol::Message::Text(txt)
1075                                    if txt == "close-now" =>
1076                                {
1077                                    tracing::debug!("Forcibly closing from server side");
1078                                    // This sends a close frame, then stops reading
1079                                    let _ = websocket.close(None).await;
1080                                    break;
1081                                }
1082                                // Echo text/binary frames
1083                                tokio_tungstenite::tungstenite::protocol::Message::Text(_)
1084                                | tokio_tungstenite::tungstenite::protocol::Message::Binary(_) => {
1085                                    if websocket.send(msg).await.is_err() {
1086                                        break;
1087                                    }
1088                                }
1089                                // If the client closes, we also break
1090                                tokio_tungstenite::tungstenite::protocol::Message::Close(
1091                                    _frame,
1092                                ) => {
1093                                    let _ = websocket.close(None).await;
1094                                    break;
1095                                }
1096                                // Ignore pings/pongs
1097                                _ => {}
1098                            }
1099                        }
1100                    });
1101                }
1102            });
1103
1104            Self { task, port }
1105        }
1106    }
1107
1108    impl Drop for TestServer {
1109        fn drop(&mut self) {
1110            self.task.abort();
1111        }
1112    }
1113
1114    async fn setup_test_client(port: u16) -> WebSocketClient {
1115        let config = WebSocketConfig {
1116            url: format!("ws://127.0.0.1:{port}"),
1117            headers: vec![("test".into(), "test".into())],
1118            handler: Consumer::Python(None),
1119            heartbeat: None,
1120            heartbeat_msg: None,
1121            ping_handler: None,
1122            reconnect_timeout_ms: None,
1123            reconnect_delay_initial_ms: None,
1124            reconnect_backoff_factor: None,
1125            reconnect_delay_max_ms: None,
1126            reconnect_jitter_ms: None,
1127        };
1128        WebSocketClient::connect(config, None, None, None, vec![], None)
1129            .await
1130            .expect("Failed to connect")
1131    }
1132
1133    #[tokio::test]
1134    async fn test_websocket_basic() {
1135        let server = TestServer::setup().await;
1136        let client = setup_test_client(server.port).await;
1137
1138        assert!(!client.is_disconnected());
1139
1140        client.disconnect().await;
1141        assert!(client.is_disconnected());
1142    }
1143
1144    #[tokio::test]
1145    async fn test_websocket_heartbeat() {
1146        let server = TestServer::setup().await;
1147        let client = setup_test_client(server.port).await;
1148
1149        // Wait ~3s => server should see multiple "ping"
1150        tokio::time::sleep(std::time::Duration::from_secs(3)).await;
1151
1152        // Cleanup
1153        client.disconnect().await;
1154        assert!(client.is_disconnected());
1155    }
1156
1157    #[tokio::test]
1158    async fn test_websocket_reconnect_exhausted() {
1159        let config = WebSocketConfig {
1160            url: "ws://127.0.0.1:9997".into(), // <-- No server
1161            headers: vec![],
1162            handler: Consumer::Python(None),
1163            heartbeat: None,
1164            heartbeat_msg: None,
1165            ping_handler: None,
1166            reconnect_timeout_ms: None,
1167            reconnect_delay_initial_ms: None,
1168            reconnect_backoff_factor: None,
1169            reconnect_delay_max_ms: None,
1170            reconnect_jitter_ms: None,
1171        };
1172        let res = WebSocketClient::connect(config, None, None, None, vec![], None).await;
1173        assert!(res.is_err(), "Should fail quickly with no server");
1174    }
1175
1176    #[tokio::test]
1177    async fn test_websocket_forced_close_reconnect() {
1178        let server = TestServer::setup().await;
1179        let client = setup_test_client(server.port).await;
1180
1181        // 1) Send normal message
1182        client.send_text("Hello".into(), None).await.unwrap();
1183
1184        // 2) Trigger forced close from server
1185        client.send_text("close-now".into(), None).await.unwrap();
1186
1187        // 3) Wait a bit => read loop sees close => reconnect
1188        tokio::time::sleep(std::time::Duration::from_secs(1)).await;
1189
1190        // Confirm not disconnected
1191        assert!(!client.is_disconnected());
1192
1193        // Cleanup
1194        client.disconnect().await;
1195        assert!(client.is_disconnected());
1196    }
1197
1198    #[tokio::test]
1199    async fn test_rate_limiter() {
1200        let server = TestServer::setup().await;
1201        let quota = Quota::per_second(NonZeroU32::new(2).unwrap());
1202
1203        let config = WebSocketConfig {
1204            url: format!("ws://127.0.0.1:{}", server.port),
1205            headers: vec![("test".into(), "test".into())],
1206            handler: Consumer::Python(None),
1207            heartbeat: None,
1208            heartbeat_msg: None,
1209            ping_handler: None,
1210            reconnect_timeout_ms: None,
1211            reconnect_delay_initial_ms: None,
1212            reconnect_backoff_factor: None,
1213            reconnect_delay_max_ms: None,
1214            reconnect_jitter_ms: None,
1215        };
1216
1217        let client = WebSocketClient::connect(
1218            config,
1219            None,
1220            None,
1221            None,
1222            vec![("default".into(), quota)],
1223            None,
1224        )
1225        .await
1226        .unwrap();
1227
1228        // First 2 should succeed
1229        client.send_text("test1".into(), None).await.unwrap();
1230        client.send_text("test2".into(), None).await.unwrap();
1231
1232        // Third should error
1233        client.send_text("test3".into(), None).await.unwrap();
1234
1235        // Cleanup
1236        client.disconnect().await;
1237        assert!(client.is_disconnected());
1238    }
1239
1240    #[tokio::test]
1241    async fn test_concurrent_writers() {
1242        let server = TestServer::setup().await;
1243        let client = Arc::new(setup_test_client(server.port).await);
1244
1245        let mut handles = vec![];
1246        for i in 0..10 {
1247            let client = client.clone();
1248            handles.push(task::spawn(async move {
1249                client.send_text(format!("test{i}"), None).await.unwrap();
1250            }));
1251        }
1252
1253        for handle in handles {
1254            handle.await.unwrap();
1255        }
1256
1257        // Cleanup
1258        client.disconnect().await;
1259        assert!(client.is_disconnected());
1260    }
1261}
1262
#[cfg(test)]
mod rust_tests {
    use tokio::{
        net::TcpListener,
        task,
        time::{Duration, sleep},
    };
    use tokio_tungstenite::accept_async;

    use super::*;

    /// The server drops the only connection it accepts, so the client must start
    /// reconnecting; an explicit `disconnect` must then stop the controller task.
    #[tokio::test]
    async fn test_reconnect_then_disconnect() {
        // Bind an ephemeral port
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();

        // Server task: accept one ws connection then close it
        let server = task::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let ws = accept_async(stream).await.unwrap();
            drop(ws);
            // Keep the listener alive briefly (further connections are never accepted)
            sleep(Duration::from_secs(1)).await;
        });

        // Build a rust consumer for incoming messages (unused here)
        let (consumer, _rx) = Consumer::rust_consumer();

        // Configure client with short reconnect backoff
        let config = WebSocketConfig {
            url: format!("ws://127.0.0.1:{port}"),
            headers: vec![],
            handler: consumer,
            heartbeat: None,
            heartbeat_msg: None,
            #[cfg(feature = "python")]
            ping_handler: None,
            reconnect_timeout_ms: Some(1_000),
            reconnect_delay_initial_ms: Some(50),
            reconnect_delay_max_ms: Some(100),
            reconnect_backoff_factor: Some(1.0),
            reconnect_jitter_ms: Some(0),
        };

        // Connect the client (`config` is not used afterwards, so move it)
        let client = {
            #[cfg(feature = "python")]
            {
                WebSocketClient::connect(config, None, None, None, vec![], None)
                    .await
                    .unwrap()
            }
            #[cfg(not(feature = "python"))]
            {
                WebSocketClient::connect(config, vec![], None)
                    .await
                    .unwrap()
            }
        };

        // Allow server to drop connection and client to detect
        sleep(Duration::from_millis(100)).await;

        // The controller only exits on a disconnect request, so it must still be running
        // (reconnecting) at this point
        assert!(
            !client.is_disconnected(),
            "controller task should still be running before disconnect"
        );

        // Now immediately disconnect the client
        client.disconnect().await;
        assert!(client.is_disconnected());
        server.abort();
    }
}