nautilus_network/
websocket.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Posei Systems Pty Ltd. All rights reserved.
3//  https://poseitrader.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
//! High-performance WebSocket client implementation with automatic reconnection
//! (using exponential backoff) and connection state management.
18
19//! **Key features**:
20//! - Connection state tracking (ACTIVE/RECONNECTING/DISCONNECTING/CLOSED)
21//! - Synchronized reconnection with backoff
22//! - Split read/write architecture
23//! - Python callback integration
24//!
25//! **Design**:
//! - Single-reader, multiple-writer model
27//! - Read half runs in dedicated task
28//! - Write half runs in dedicated task connected with channel
29//! - Controller task manages lifecycle
30
31use std::{
32    fmt::Debug,
33    sync::{
34        Arc,
35        atomic::{AtomicU8, Ordering},
36    },
37    time::Duration,
38};
39
40use futures_util::{
41    SinkExt, StreamExt,
42    stream::{SplitSink, SplitStream},
43};
44use http::HeaderName;
45use nautilus_cryptography::providers::install_cryptographic_provider;
46#[cfg(feature = "python")]
47use pyo3::{prelude::*, types::PyBytes};
48use tokio::{
49    net::TcpStream,
50    sync::mpsc::{self, Receiver, Sender},
51};
52use tokio_tungstenite::{
53    MaybeTlsStream, WebSocketStream, connect_async,
54    tungstenite::{Error, Message, client::IntoClientRequest, http::HeaderValue},
55};
56
57use crate::{
58    backoff::ExponentialBackoff,
59    error::SendError,
60    logging::{log_task_aborted, log_task_started, log_task_stopped},
61    mode::ConnectionMode,
62    ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota},
63};
64
/// Write half of a (possibly TLS-wrapped) websocket stream over TCP.
type MessageWriter = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
/// Read half of a (possibly TLS-wrapped) websocket stream over TCP.
pub type MessageReader = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
67
/// Destination for messages received from the server.
#[derive(Debug, Clone)]
pub enum Consumer {
    /// A Python callable invoked with each text/binary payload as `bytes`.
    /// When `None`, no read task is spawned for the connection.
    #[cfg(feature = "python")]
    Python(Option<Arc<PyObject>>),
    /// A channel sender that receives every incoming [`Message`] unchanged.
    Rust(Sender<Message>),
}
74
75impl Consumer {
76    #[must_use]
77    pub fn rust_consumer() -> (Self, Receiver<Message>) {
78        let (tx, rx) = mpsc::channel(100);
79        (Self::Rust(tx), rx)
80    }
81}
82
/// Configuration used to connect (and reconnect) a websocket client.
#[derive(Debug, Clone)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "posei_trader.core.nautilus_pyo3.network")
)]
pub struct WebSocketConfig {
    /// The URL to connect to.
    pub url: String,
    /// Headers added to the connection request, as `(name, value)` pairs.
    pub headers: Vec<(String, String)>,
    /// The consumer that receives incoming messages (a Python callback or a
    /// Rust channel sender).
    pub handler: Consumer,
    /// The optional heartbeat interval (seconds); no heartbeat is sent when `None`.
    pub heartbeat: Option<u64>,
    /// The optional heartbeat text message; an empty ping frame is sent when `None`.
    pub heartbeat_msg: Option<String>,
    /// The optional Python handler invoked with the payload of incoming pings.
    #[cfg(feature = "python")]
    pub ping_handler: Option<Arc<PyObject>>,
    /// The timeout (milliseconds) for each reconnection attempt (default 10_000).
    pub reconnect_timeout_ms: Option<u64>,
    /// The initial reconnection delay (milliseconds) for reconnects (default 2_000).
    pub reconnect_delay_initial_ms: Option<u64>,
    /// The maximum reconnect delay (milliseconds) for exponential backoff (default 30_000).
    pub reconnect_delay_max_ms: Option<u64>,
    /// The exponential backoff factor for reconnection delays (default 1.5).
    pub reconnect_backoff_factor: Option<f64>,
    /// The maximum jitter (milliseconds) added to reconnection delays (default 100).
    pub reconnect_jitter_ms: Option<u64>,
}
113
/// Represents a command for the writer task.
#[derive(Debug)]
pub(crate) enum WriterCommand {
    /// Replace the active writer with a new one after reconnection; the previous
    /// writer is closed by the write task.
    Update(MessageWriter),
    /// Send a message to the server over the active writer.
    Send(Message),
}
122
/// Connection state of a websocket client together with its background tasks.
///
/// The client is opinionated about how messages are read and written: data has
/// exactly one reader but may have multiple writers.
///
/// The connection is split into read and write halves. The read half is moved
/// into a tokio task that keeps receiving messages from the server and forwards
/// them to the configured [`Consumer`] (a Python callback or a Rust channel).
/// The write half is owned by a dedicated write task that receives
/// [`WriterCommand`]s over an unbounded channel, so any holder of a clone of
/// `writer_tx` can write to the server from other scopes/tasks.
///
/// The client also maintains a heartbeat if given a duration in seconds.
/// It's preferable to set the duration slightly lower - heartbeat more
/// frequently - than the required amount.
struct WebSocketClientInner {
    /// Configuration used to connect; reused by `reconnect`.
    config: WebSocketConfig,
    /// Read task; may be `None` when the consumer has no handler or after an
    /// aborted reconnect attempt.
    read_task: Option<tokio::task::JoinHandle<()>>,
    /// Write task owning the write half of the connection.
    write_task: tokio::task::JoinHandle<()>,
    /// Command channel into the write task.
    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Heartbeat task, present only when a heartbeat interval is configured.
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    /// Shared connection state (a `ConnectionMode` encoded as `u8`), cloned into
    /// every spawned task.
    connection_mode: Arc<AtomicU8>,
    /// Upper bound for a single `reconnect` attempt.
    reconnect_timeout: Duration,
    /// Backoff built from the `reconnect_*` config values; presumably consumed by
    /// the controller task (not visible here) — TODO confirm.
    backoff: ExponentialBackoff,
}
148
149impl WebSocketClientInner {
150    /// Create an inner websocket client.
151    pub async fn connect_url(config: WebSocketConfig) -> Result<Self, Error> {
152        install_cryptographic_provider();
153
154        #[allow(unused_variables)]
155        let WebSocketConfig {
156            url,
157            handler,
158            heartbeat,
159            headers,
160            heartbeat_msg,
161            #[cfg(feature = "python")]
162            ping_handler,
163            reconnect_timeout_ms,
164            reconnect_delay_initial_ms,
165            reconnect_delay_max_ms,
166            reconnect_backoff_factor,
167            reconnect_jitter_ms,
168        } = &config;
169        let (writer, reader) = Self::connect_with_server(url, headers.clone()).await?;
170
171        let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
172
173        let read_task = match &handler {
174            #[cfg(feature = "python")]
175            Consumer::Python(handler) => handler.as_ref().map(|handler| {
176                Self::spawn_python_callback_task(
177                    connection_mode.clone(),
178                    reader,
179                    handler.clone(),
180                    ping_handler.clone(),
181                )
182            }),
183            Consumer::Rust(sender) => Some(Self::spawn_rust_streaming_task(
184                connection_mode.clone(),
185                reader,
186                sender.clone(),
187            )),
188        };
189
190        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
191        let write_task = Self::spawn_write_task(connection_mode.clone(), writer, writer_rx);
192
193        // Optionally spawn a heartbeat task to periodically ping server
194        let heartbeat_task = heartbeat.as_ref().map(|heartbeat_secs| {
195            Self::spawn_heartbeat_task(
196                connection_mode.clone(),
197                *heartbeat_secs,
198                heartbeat_msg.clone(),
199                writer_tx.clone(),
200            )
201        });
202
203        let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
204        let backoff = ExponentialBackoff::new(
205            Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
206            Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
207            reconnect_backoff_factor.unwrap_or(1.5),
208            reconnect_jitter_ms.unwrap_or(100),
209            true, // immediate-first
210        )
211        .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
212
213        Ok(Self {
214            config,
215            read_task,
216            write_task,
217            writer_tx,
218            heartbeat_task,
219            connection_mode,
220            reconnect_timeout,
221            backoff,
222        })
223    }
224
225    /// Connects with the server creating a tokio-tungstenite websocket stream.
226    #[inline]
227    pub async fn connect_with_server(
228        url: &str,
229        headers: Vec<(String, String)>,
230    ) -> Result<(MessageWriter, MessageReader), Error> {
231        let mut request = url.into_client_request()?;
232        let req_headers = request.headers_mut();
233
234        let mut header_names: Vec<HeaderName> = Vec::new();
235        for (key, val) in headers {
236            let header_value = HeaderValue::from_str(&val)?;
237            let header_name: HeaderName = key.parse()?;
238            header_names.push(header_name.clone());
239            req_headers.insert(header_name, header_value);
240        }
241
242        connect_async(request).await.map(|resp| resp.0.split())
243    }
244
245    /// Reconnect with server.
246    ///
247    /// Make a new connection with server. Use the new read and write halves
248    /// to update self writer and read and heartbeat tasks.
249    pub async fn reconnect(&mut self) -> Result<(), Error> {
250        tracing::debug!("Reconnecting");
251
252        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
253            tracing::debug!("Reconnect aborted due to disconnect state");
254            return Ok(());
255        }
256
257        tokio::time::timeout(self.reconnect_timeout, async {
258            // Attempt to connect; abort early if a disconnect was requested
259            let (new_writer, reader) =
260                Self::connect_with_server(&self.config.url, self.config.headers.clone()).await?;
261
262            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
263                tracing::debug!("Reconnect aborted mid-flight (after connect)");
264                return Ok(());
265            }
266
267            if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer)) {
268                tracing::error!("{e}");
269            }
270
271            // Delay before closing connection
272            tokio::time::sleep(Duration::from_millis(100)).await;
273
274            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
275                tracing::debug!("Reconnect aborted mid-flight (after delay)");
276                return Ok(());
277            }
278
279            if let Some(ref read_task) = self.read_task.take() {
280                if !read_task.is_finished() {
281                    read_task.abort();
282                    log_task_aborted("read");
283                }
284            }
285
286            // If a disconnect was requested during reconnect, do not proceed to reactivate
287            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
288                tracing::debug!("Reconnect aborted mid-flight (before spawn read)");
289                return Ok(());
290            }
291
292            // Mark as active only if not disconnecting
293            self.connection_mode
294                .store(ConnectionMode::Active.as_u8(), Ordering::SeqCst);
295
296            self.read_task = match &self.config.handler {
297                #[cfg(feature = "python")]
298                Consumer::Python(handler) => handler.as_ref().map(|handler| {
299                    Self::spawn_python_callback_task(
300                        self.connection_mode.clone(),
301                        reader,
302                        handler.clone(),
303                        self.config.ping_handler.clone(),
304                    )
305                }),
306                Consumer::Rust(sender) => Some(Self::spawn_rust_streaming_task(
307                    self.connection_mode.clone(),
308                    reader,
309                    sender.clone(),
310                )),
311            };
312
313            tracing::debug!("Reconnect succeeded");
314            Ok(())
315        })
316        .await
317        .map_err(|_| {
318            Error::Io(std::io::Error::new(
319                std::io::ErrorKind::TimedOut,
320                format!(
321                    "reconnection timed out after {}s",
322                    self.reconnect_timeout.as_secs_f64()
323                ),
324            ))
325        })?
326    }
327
328    /// Check if the client is still connected.
329    ///
330    /// The client is connected if the read task has not finished. It is expected
331    /// that in case of any failure client or server side. The read task will be
332    /// shutdown or will receive a `Close` frame which will finish it. There
333    /// might be some delay between the connection being closed and the client
334    /// detecting.
335    #[inline]
336    #[must_use]
337    pub fn is_alive(&self) -> bool {
338        match &self.read_task {
339            Some(read_task) => !read_task.is_finished(),
340            None => true, // Stream is being used directly
341        }
342    }
343
344    fn spawn_rust_streaming_task(
345        connection_state: Arc<AtomicU8>,
346        mut reader: MessageReader,
347        sender: Sender<Message>,
348    ) -> tokio::task::JoinHandle<()> {
349        tracing::debug!("Started streaming task 'read'");
350
351        let check_interval = Duration::from_millis(10);
352
353        tokio::task::spawn(async move {
354            loop {
355                if !ConnectionMode::from_atomic(&connection_state).is_active() {
356                    break;
357                }
358
359                match tokio::time::timeout(check_interval, reader.next()).await {
360                    Ok(Some(Ok(message))) => {
361                        if let Err(e) = sender.send(message).await {
362                            tracing::error!("Failed to send message: {e}");
363                        }
364                    }
365                    Ok(Some(Err(e))) => {
366                        tracing::error!("Received error message - terminating: {e}");
367                        break;
368                    }
369                    Ok(None) => {
370                        tracing::debug!("No message received - terminating");
371                        break;
372                    }
373                    Err(_) => {
374                        // Timeout - continue loop and check connection mode
375                        continue;
376                    }
377                }
378            }
379        })
380    }
381
382    #[cfg(feature = "python")]
383    fn spawn_python_callback_task(
384        connection_state: Arc<AtomicU8>,
385        mut reader: MessageReader,
386        handler: Arc<PyObject>,
387        ping_handler: Option<Arc<PyObject>>,
388    ) -> tokio::task::JoinHandle<()> {
389        log_task_started("read");
390
391        // Interval between checking the connection mode
392        let check_interval = Duration::from_millis(10);
393
394        tokio::task::spawn(async move {
395            loop {
396                if !ConnectionMode::from_atomic(&connection_state).is_active() {
397                    break;
398                }
399
400                match tokio::time::timeout(check_interval, reader.next()).await {
401                    Ok(Some(Ok(Message::Binary(data)))) => {
402                        tracing::trace!("Received message <binary> {} bytes", data.len());
403                        if let Err(e) =
404                            Python::with_gil(|py| handler.call1(py, (PyBytes::new(py, &data),)))
405                        {
406                            tracing::error!("Error calling handler: {e}");
407                            break;
408                        }
409                        continue;
410                    }
411                    Ok(Some(Ok(Message::Text(data)))) => {
412                        tracing::trace!("Received message: {data}");
413                        if let Err(e) = Python::with_gil(|py| {
414                            handler.call1(py, (PyBytes::new(py, data.as_bytes()),))
415                        }) {
416                            tracing::error!("Error calling handler: {e}");
417                            break;
418                        }
419                        continue;
420                    }
421                    Ok(Some(Ok(Message::Ping(ping)))) => {
422                        tracing::trace!("Received ping: {ping:?}");
423                        if let Some(ref handler) = ping_handler {
424                            if let Err(e) =
425                                Python::with_gil(|py| handler.call1(py, (PyBytes::new(py, &ping),)))
426                            {
427                                tracing::error!("Error calling handler: {e}");
428                                break;
429                            }
430                        }
431                        continue;
432                    }
433                    Ok(Some(Ok(Message::Pong(_)))) => {
434                        tracing::trace!("Received pong");
435                    }
436                    Ok(Some(Ok(Message::Close(_)))) => {
437                        tracing::debug!("Received close message - terminating");
438                        break;
439                    }
440                    Ok(Some(Ok(_))) => (),
441                    Ok(Some(Err(e))) => {
442                        tracing::error!("Received error message - terminating: {e}");
443                        break;
444                    }
445                    // Internally tungstenite considers the connection closed when polling
446                    // for the next message in the stream returns None.
447                    Ok(None) => {
448                        tracing::debug!("No message received - terminating");
449                        break;
450                    }
451                    Err(_) => {
452                        // Timeout - continue loop and check connection mode
453                        continue;
454                    }
455                }
456            }
457        })
458    }
459
460    fn spawn_write_task(
461        connection_state: Arc<AtomicU8>,
462        writer: MessageWriter,
463        mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
464    ) -> tokio::task::JoinHandle<()> {
465        log_task_started("write");
466
467        // Interval between checking the connection mode
468        let check_interval = Duration::from_millis(10);
469
470        tokio::task::spawn(async move {
471            let mut active_writer = writer;
472
473            loop {
474                match ConnectionMode::from_atomic(&connection_state) {
475                    ConnectionMode::Disconnect => {
476                        // Attempt to close the writer gracefully before exiting,
477                        // we ignore any error as the writer may already be closed.
478                        _ = active_writer.close().await;
479                        break;
480                    }
481                    ConnectionMode::Closed => break,
482                    _ => {}
483                }
484
485                match tokio::time::timeout(check_interval, writer_rx.recv()).await {
486                    Ok(Some(msg)) => {
487                        // Re-check connection mode after receiving a message
488                        let mode = ConnectionMode::from_atomic(&connection_state);
489                        if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
490                            break;
491                        }
492
493                        match msg {
494                            WriterCommand::Update(new_writer) => {
495                                tracing::debug!("Received new writer");
496
497                                // Delay before closing connection
498                                tokio::time::sleep(Duration::from_millis(100)).await;
499
500                                // Attempt to close the writer gracefully on update,
501                                // we ignore any error as the writer may already be closed.
502                                _ = active_writer.close().await;
503
504                                active_writer = new_writer;
505                                tracing::debug!("Updated writer");
506                            }
507                            _ if mode.is_reconnect() => {
508                                tracing::warn!("Skipping message while reconnecting, {msg:?}");
509                                continue;
510                            }
511                            WriterCommand::Send(msg) => {
512                                if let Err(e) = active_writer.send(msg).await {
513                                    tracing::error!("Failed to send message: {e}");
514                                    // Mode is active so trigger reconnection
515                                    tracing::warn!("Writer triggering reconnect");
516                                    connection_state
517                                        .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
518                                }
519                            }
520                        }
521                    }
522                    Ok(None) => {
523                        // Channel closed - writer task should terminate
524                        tracing::debug!("Writer channel closed, terminating writer task");
525                        break;
526                    }
527                    Err(_) => {
528                        // Timeout - just continue the loop
529                        continue;
530                    }
531                }
532            }
533
534            // Attempt to close the writer gracefully before exiting,
535            // we ignore any error as the writer may already be closed.
536            _ = active_writer.close().await;
537
538            log_task_stopped("write");
539        })
540    }
541
542    fn spawn_heartbeat_task(
543        connection_state: Arc<AtomicU8>,
544        heartbeat_secs: u64,
545        message: Option<String>,
546        writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
547    ) -> tokio::task::JoinHandle<()> {
548        log_task_started("heartbeat");
549
550        tokio::task::spawn(async move {
551            let interval = Duration::from_secs(heartbeat_secs);
552
553            loop {
554                tokio::time::sleep(interval).await;
555
556                match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
557                    ConnectionMode::Active => {
558                        let msg = match &message {
559                            Some(text) => WriterCommand::Send(Message::Text(text.clone().into())),
560                            None => WriterCommand::Send(Message::Ping(vec![].into())),
561                        };
562
563                        match writer_tx.send(msg) {
564                            Ok(()) => tracing::trace!("Sent heartbeat to writer task"),
565                            Err(e) => {
566                                tracing::error!("Failed to send heartbeat to writer task: {e}");
567                            }
568                        }
569                    }
570                    ConnectionMode::Reconnect => continue,
571                    ConnectionMode::Disconnect | ConnectionMode::Closed => break,
572                }
573            }
574
575            log_task_stopped("heartbeat");
576        })
577    }
578}
579
580impl Drop for WebSocketClientInner {
581    fn drop(&mut self) {
582        if let Some(ref read_task) = self.read_task.take() {
583            if !read_task.is_finished() {
584                read_task.abort();
585                log_task_aborted("read");
586            }
587        }
588
589        if !self.write_task.is_finished() {
590            self.write_task.abort();
591            log_task_aborted("write");
592        }
593
594        if let Some(ref handle) = self.heartbeat_task.take() {
595            if !handle.is_finished() {
596                handle.abort();
597                log_task_aborted("heartbeat");
598            }
599        }
600    }
601}
602
/// WebSocket client with automatic reconnection.
///
/// Handles connection state, Python callbacks, and rate limiting.
/// See module docs for architecture details.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "posei_trader.core.nautilus_pyo3.network")
)]
pub struct WebSocketClient {
    /// Controller task handle; `is_disconnected` reports whether it has finished.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Connection state shared with the inner client and its tasks.
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// Command channel into the write task, used by the `send_*` methods.
    pub(crate) writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Rate limiter awaited (per key) before each send.
    pub(crate) rate_limiter: Arc<RateLimiter<String, MonotonicClock>>,
}
617
618impl Debug for WebSocketClient {
619    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
620        f.debug_struct(stringify!(WebSocketClient)).finish()
621    }
622}
623
624impl WebSocketClient {
625    /// Creates a websocket client that returns a stream for reading messages.
626    ///
627    /// # Errors
628    ///
629    /// Returns any error connecting to the server.
630    #[allow(clippy::too_many_arguments)]
631    pub async fn connect_stream(
632        config: WebSocketConfig,
633        keyed_quotas: Vec<(String, Quota)>,
634        default_quota: Option<Quota>,
635        post_reconnect: Option<Arc<dyn Fn() + Send + Sync>>,
636    ) -> Result<(MessageReader, Self), Error> {
637        install_cryptographic_provider();
638        let (ws_stream, _) = connect_async(config.url.clone().into_client_request()?).await?;
639        let (writer, reader) = ws_stream.split();
640        let inner = WebSocketClientInner::connect_url(config).await?;
641
642        let connection_mode = inner.connection_mode.clone();
643
644        let writer_tx = inner.writer_tx.clone();
645        if let Err(e) = writer_tx.send(WriterCommand::Update(writer)) {
646            tracing::error!("{e}");
647        }
648
649        let controller_task = Self::spawn_controller_task(
650            inner,
651            connection_mode.clone(),
652            post_reconnect,
653            #[cfg(feature = "python")]
654            None, // no post_reconnection
655            #[cfg(feature = "python")]
656            None, // no post_disconnection
657        );
658
659        let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
660
661        Ok((
662            reader,
663            Self {
664                controller_task,
665                connection_mode,
666                writer_tx,
667                rate_limiter,
668            },
669        ))
670    }
671
672    /// Creates a websocket client.
673    ///
674    /// Creates an inner client and controller task to reconnect or disconnect
675    /// the client. Also assumes ownership of writer from inner client.
676    ///
677    /// # Errors
678    ///
679    /// Returns any websocket error.
680    pub async fn connect(
681        config: WebSocketConfig,
682        #[cfg(feature = "python")] post_connection: Option<PyObject>,
683        #[cfg(feature = "python")] post_reconnection: Option<PyObject>,
684        #[cfg(feature = "python")] post_disconnection: Option<PyObject>,
685        keyed_quotas: Vec<(String, Quota)>,
686        default_quota: Option<Quota>,
687    ) -> Result<Self, Error> {
688        tracing::debug!("Connecting");
689        let inner = WebSocketClientInner::connect_url(config.clone()).await?;
690        let connection_mode = inner.connection_mode.clone();
691        let writer_tx = inner.writer_tx.clone();
692
693        let controller_task = Self::spawn_controller_task(
694            inner,
695            connection_mode.clone(),
696            None, // Rust handler
697            #[cfg(feature = "python")]
698            post_reconnection, // TODO: Deprecated
699            #[cfg(feature = "python")]
700            post_disconnection, // TODO: Deprecated
701        );
702
703        let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
704
705        #[cfg(feature = "python")]
706        if let Some(handler) = post_connection {
707            Python::with_gil(|py| match handler.call0(py) {
708                Ok(_) => tracing::debug!("Called `post_connection` handler"),
709                Err(e) => tracing::error!("Error calling `post_connection` handler: {e}"),
710            });
711        }
712
713        Ok(Self {
714            controller_task,
715            connection_mode,
716            writer_tx,
717            rate_limiter,
718        })
719    }
720
721    /// Returns the current connection mode.
722    #[must_use]
723    pub fn connection_mode(&self) -> ConnectionMode {
724        ConnectionMode::from_atomic(&self.connection_mode)
725    }
726
727    /// Check if the client connection is active.
728    ///
729    /// Returns `true` if the client is connected and has not been signalled to disconnect.
730    /// The client will automatically retry connection based on its configuration.
731    #[inline]
732    #[must_use]
733    pub fn is_active(&self) -> bool {
734        self.connection_mode().is_active()
735    }
736
737    /// Check if the client is disconnected.
738    #[must_use]
739    pub fn is_disconnected(&self) -> bool {
740        self.controller_task.is_finished()
741    }
742
743    /// Check if the client is reconnecting.
744    ///
745    /// Returns `true` if the client lost connection and is attempting to reestablish it.
746    /// The client will automatically retry connection based on its configuration.
747    #[inline]
748    #[must_use]
749    pub fn is_reconnecting(&self) -> bool {
750        self.connection_mode().is_reconnect()
751    }
752
753    /// Check if the client is disconnecting.
754    ///
755    /// Returns `true` if the client is in disconnect mode.
756    #[inline]
757    #[must_use]
758    pub fn is_disconnecting(&self) -> bool {
759        self.connection_mode().is_disconnect()
760    }
761
762    /// Check if the client is closed.
763    ///
764    /// Returns `true` if the client has been explicitly disconnected or reached
765    /// maximum reconnection attempts. In this state, the client cannot be reused
766    /// and a new client must be created for further connections.
767    #[inline]
768    #[must_use]
769    pub fn is_closed(&self) -> bool {
770        self.connection_mode().is_closed()
771    }
772
773    /// Set disconnect mode to true.
774    ///
775    /// Controller task will periodically check the disconnect mode
776    /// and shutdown the client if it is alive
777    pub async fn disconnect(&self) {
778        tracing::debug!("Disconnecting");
779        self.connection_mode
780            .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
781
782        match tokio::time::timeout(Duration::from_secs(5), async {
783            while !self.is_disconnected() {
784                tokio::time::sleep(Duration::from_millis(10)).await;
785            }
786
787            if !self.controller_task.is_finished() {
788                self.controller_task.abort();
789                log_task_aborted("controller");
790            }
791        })
792        .await
793        {
794            Ok(()) => {
795                tracing::debug!("Controller task finished");
796            }
797            Err(_) => {
798                tracing::error!("Timeout waiting for controller task to finish");
799            }
800        }
801    }
802
803    /// Sends the given text `data` to the server.
804    ///
805    /// # Errors
806    ///
807    /// Returns a websocket error if unable to send.
808    #[allow(unused_variables)]
809    pub async fn send_text(
810        &self,
811        data: String,
812        keys: Option<Vec<String>>,
813    ) -> std::result::Result<(), SendError> {
814        self.rate_limiter.await_keys_ready(keys).await;
815
816        if !self.is_active() {
817            return Err(SendError::Closed);
818        }
819
820        tracing::trace!("Sending text: {data:?}");
821
822        let msg = Message::Text(data.into());
823        self.writer_tx
824            .send(WriterCommand::Send(msg))
825            .map_err(|e| SendError::BrokenPipe(e.to_string()))
826    }
827
828    /// Sends the given bytes `data` to the server.
829    ///
830    /// # Errors
831    ///
832    /// Returns a websocket error if unable to send.
833    #[allow(unused_variables)]
834    pub async fn send_bytes(
835        &self,
836        data: Vec<u8>,
837        keys: Option<Vec<String>>,
838    ) -> std::result::Result<(), SendError> {
839        self.rate_limiter.await_keys_ready(keys).await;
840
841        if !self.is_active() {
842            return Err(SendError::Closed);
843        }
844
845        tracing::trace!("Sending bytes: {data:?}");
846
847        let msg = Message::Binary(data.into());
848        self.writer_tx
849            .send(WriterCommand::Send(msg))
850            .map_err(|e| SendError::BrokenPipe(e.to_string()))
851    }
852
853    /// Sends a close message to the server.
854    ///
855    /// # Errors
856    ///
857    /// Returns a websocket error if unable to send.
858    pub async fn send_close_message(&self) -> std::result::Result<(), SendError> {
859        if !self.is_active() {
860            return Err(SendError::Closed);
861        }
862
863        let msg = Message::Close(None);
864        self.writer_tx
865            .send(WriterCommand::Send(msg))
866            .map_err(|e| SendError::BrokenPipe(e.to_string()))
867    }
868
    /// Spawns the controller task, which owns `inner` and drives the connection lifecycle.
    ///
    /// Every 10ms the task reads `connection_mode` and acts on it:
    /// - `is_disconnect()`: waits 100ms, aborts the read and heartbeat tasks if they are
    ///   still running, calls the (deprecated) Python `post_disconnection` handler if set,
    ///   then leaves the loop.
    /// - `is_reconnect()`, or `is_active()` while `inner.is_alive()` is false: calls
    ///   `inner.reconnect()`.
    ///   - On success the backoff is reset. If the mode is still active, `post_reconnection`
    ///     and the deprecated Python `post_reconnection` handler are invoked.
    ///   - On failure the task sleeps for the next backoff duration.
    ///
    /// After leaving the loop it stores `Closed` into `inner.connection_mode`. The returned
    /// handle is presumably kept as `controller_task`, whose completion is what
    /// `is_disconnected` reports.
    ///
    /// NOTE(review): a mode matching none of the predicates above (e.g. `Closed` stored
    /// elsewhere) neither triggers work nor exits the loop, so the task keeps polling.
    /// Confirm this is intended.
    fn spawn_controller_task(
        mut inner: WebSocketClientInner,
        connection_mode: Arc<AtomicU8>,
        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
        #[cfg(feature = "python")] py_post_reconnection: Option<PyObject>, // TODO: Deprecated
        #[cfg(feature = "python")] py_post_disconnection: Option<PyObject>, // TODO: Deprecated
    ) -> tokio::task::JoinHandle<()> {
        tokio::task::spawn(async move {
            log_task_started("controller");

            // Polling period for connection-mode changes
            let check_interval = Duration::from_millis(10);

            loop {
                tokio::time::sleep(check_interval).await;
                let mode = ConnectionMode::from_atomic(&connection_mode);

                if mode.is_disconnect() {
                    tracing::debug!("Disconnecting");

                    // NOTE(review): the guarded future only sleeps 100ms and issues
                    // non-blocking aborts, so this 5s timeout should not realistically elapse
                    let timeout = Duration::from_secs(5);
                    if tokio::time::timeout(timeout, async {
                        // Delay awaiting graceful shutdown
                        tokio::time::sleep(Duration::from_millis(100)).await;

                        if let Some(task) = &inner.read_task {
                            if !task.is_finished() {
                                task.abort();
                                log_task_aborted("read");
                            }
                        }

                        if let Some(task) = &inner.heartbeat_task {
                            if !task.is_finished() {
                                task.abort();
                                log_task_aborted("heartbeat");
                            }
                        }
                    })
                    .await
                    .is_err()
                    {
                        tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
                    }

                    tracing::debug!("Closed");

                    #[cfg(feature = "python")]
                    if let Some(ref handler) = py_post_disconnection {
                        Python::with_gil(|py| match handler.call0(py) {
                            Ok(_) => tracing::debug!("Called `post_disconnection` handler"),
                            Err(e) => {
                                tracing::error!("Error calling `post_disconnection` handler: {e}");
                            }
                        });
                    }
                    break; // Controller finished
                }

                // Reconnect when explicitly requested, or when the mode says active but
                // the underlying connection is no longer alive
                if mode.is_reconnect() || (mode.is_active() && !inner.is_alive()) {
                    match inner.reconnect().await {
                        Ok(()) => {
                            inner.backoff.reset();

                            // Only invoke callbacks if not in disconnect state
                            // (the mode is re-read because it may have changed during reconnect)
                            if ConnectionMode::from_atomic(&connection_mode).is_active() {
                                if let Some(ref callback) = post_reconnection {
                                    callback();
                                }

                                // TODO: Python based websocket handlers deprecated (will be removed)
                                #[cfg(feature = "python")]
                                if let Some(ref callback) = py_post_reconnection {
                                    Python::with_gil(|py| match callback.call0(py) {
                                        Ok(_) => {
                                            tracing::debug!("Called `post_reconnection` handler");
                                        }
                                        Err(e) => tracing::error!(
                                            "Error calling `post_reconnection` handler: {e}"
                                        ),
                                    });
                                }

                                tracing::debug!("Reconnected successfully");
                            } else {
                                tracing::debug!(
                                    "Skipping post_reconnection handlers due to disconnect state"
                                );
                            }
                        }
                        Err(e) => {
                            let duration = inner.backoff.next_duration();
                            tracing::warn!("Reconnect attempt failed: {e}");
                            if !duration.is_zero() {
                                tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
                            }
                            // NOTE(review): this sleep does not observe disconnect requests,
                            // so a disconnect can be delayed by up to one backoff interval
                            tokio::time::sleep(duration).await;
                        }
                    }
                }
            }
            // Terminal state: the controller no longer manages this connection
            inner
                .connection_mode
                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);

            log_task_stopped("controller");
        })
    }
976}
977
978// Abort controller task on drop to clean up background tasks
979impl Drop for WebSocketClient {
980    fn drop(&mut self) {
981        if !self.controller_task.is_finished() {
982            self.controller_task.abort();
983            log_task_aborted("controller");
984        }
985    }
986}
987
988////////////////////////////////////////////////////////////////////////////////
989// Tests
990////////////////////////////////////////////////////////////////////////////////
#[cfg(feature = "python")]
#[cfg(test)]
#[cfg(target_os = "linux")] // Only run network tests on Linux (CI stability)
mod tests {
    use std::{num::NonZeroU32, sync::Arc};

    use futures_util::{SinkExt, StreamExt};
    use tokio::{
        net::TcpListener,
        task::{self, JoinHandle},
    };
    use tokio_tungstenite::{
        accept_hdr_async,
        tungstenite::{
            handshake::server::{self, Callback},
            http::HeaderValue,
        },
    };

    use crate::{
        ratelimiter::quota::Quota,
        websocket::{Consumer, WebSocketClient, WebSocketConfig},
    };

    /// Local echo server bound to an ephemeral port; its accept loop is aborted on drop.
    struct TestServer {
        task: JoinHandle<()>,
        port: u16,
    }

    /// Handshake callback asserting that the client sent header `key` with `value`.
    #[derive(Debug, Clone)]
    struct TestCallback {
        key: String,
        value: HeaderValue,
    }

    impl Callback for TestCallback {
        fn on_request(
            self,
            request: &server::Request,
            response: server::Response,
        ) -> Result<server::Response, server::ErrorResponse> {
            let value = request
                .headers()
                .get(&self.key)
                .expect("handshake request is missing the expected test header");
            assert_eq!(value, self.value);

            Ok(response)
        }
    }

    impl TestServer {
        /// Starts an echo server that requires the `test: test` handshake header.
        ///
        /// Text and binary frames are echoed back. The text frame `close-now` makes the
        /// server close that connection, which lets tests exercise client reconnects.
        async fn setup() -> Self {
            let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let port = TcpListener::local_addr(&server).unwrap().port();

            let header_key = "test".to_string();
            let header_value = "test".to_string();

            let test_call_back = TestCallback {
                key: header_key,
                value: HeaderValue::from_str(&header_value).unwrap(),
            };

            let task = task::spawn(async move {
                // Keep accepting connections
                loop {
                    let (conn, _) = server.accept().await.unwrap();
                    let mut websocket = accept_hdr_async(conn, test_call_back.clone())
                        .await
                        .unwrap();

                    task::spawn(async move {
                        while let Some(Ok(msg)) = websocket.next().await {
                            match msg {
                                tokio_tungstenite::tungstenite::protocol::Message::Text(txt)
                                    if txt == "close-now" =>
                                {
                                    tracing::debug!("Forcibly closing from server side");
                                    // This sends a close frame, then stops reading
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                // Echo text/binary frames
                                tokio_tungstenite::tungstenite::protocol::Message::Text(_)
                                | tokio_tungstenite::tungstenite::protocol::Message::Binary(_) => {
                                    if websocket.send(msg).await.is_err() {
                                        break;
                                    }
                                }
                                // If the client closes, we also break
                                tokio_tungstenite::tungstenite::protocol::Message::Close(
                                    _frame,
                                ) => {
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                // Ignore pings/pongs
                                _ => {}
                            }
                        }
                    });
                }
            });

            Self { task, port }
        }
    }

    impl Drop for TestServer {
        fn drop(&mut self) {
            self.task.abort();
        }
    }

    /// Connects a client to the local test server with heartbeats disabled.
    async fn setup_test_client(port: u16) -> WebSocketClient {
        let config = WebSocketConfig {
            url: format!("ws://127.0.0.1:{port}"),
            headers: vec![("test".into(), "test".into())],
            handler: Consumer::Python(None),
            heartbeat: None,
            heartbeat_msg: None,
            ping_handler: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
        };
        WebSocketClient::connect(config, None, None, None, vec![], None)
            .await
            .expect("Failed to connect")
    }

    #[tokio::test]
    async fn test_websocket_basic() {
        let server = TestServer::setup().await;
        let client = setup_test_client(server.port).await;

        assert!(!client.is_disconnected());

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_websocket_heartbeat() {
        let server = TestServer::setup().await;
        let client = setup_test_client(server.port).await;

        // `setup_test_client` disables heartbeats, so this only checks that an idle
        // connection survives ~3s and can then be shut down cleanly
        tokio::time::sleep(std::time::Duration::from_secs(3)).await;

        // Cleanup
        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_websocket_reconnect_exhausted() {
        // Reserve an ephemeral port, then release it so nothing is listening there.
        // A hard-coded port could be in use by an unrelated process and flake the test.
        let port = {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            listener.local_addr().unwrap().port()
        };

        let config = WebSocketConfig {
            url: format!("ws://127.0.0.1:{port}"), // <-- No server
            headers: vec![],
            handler: Consumer::Python(None),
            heartbeat: None,
            heartbeat_msg: None,
            ping_handler: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
        };
        let res = WebSocketClient::connect(config, None, None, None, vec![], None).await;
        assert!(res.is_err(), "Should fail quickly with no server");
    }

    #[tokio::test]
    async fn test_websocket_forced_close_reconnect() {
        let server = TestServer::setup().await;
        let client = setup_test_client(server.port).await;

        // 1) Send normal message
        client.send_text("Hello".into(), None).await.unwrap();

        // 2) Trigger forced close from server
        client.send_text("close-now".into(), None).await.unwrap();

        // 3) Wait a bit => read loop sees close => reconnect
        tokio::time::sleep(std::time::Duration::from_secs(1)).await;

        // Confirm not disconnected
        assert!(!client.is_disconnected());

        // Cleanup
        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_rate_limiter() {
        let server = TestServer::setup().await;
        let quota = Quota::per_second(NonZeroU32::new(2).unwrap());

        let config = WebSocketConfig {
            url: format!("ws://127.0.0.1:{}", server.port),
            headers: vec![("test".into(), "test".into())],
            handler: Consumer::Python(None),
            heartbeat: None,
            heartbeat_msg: None,
            ping_handler: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
        };

        let client = WebSocketClient::connect(
            config,
            None,
            None,
            None,
            vec![("default".into(), quota)],
            None,
        )
        .await
        .unwrap();

        // First 2 fit within the quota
        client.send_text("test1".into(), None).await.unwrap();
        client.send_text("test2".into(), None).await.unwrap();

        // The rate limiter never rejects a send; `send_text` awaits it instead, so a
        // send beyond the quota must still succeed
        client.send_text("test3".into(), None).await.unwrap();

        // Cleanup
        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_concurrent_writers() {
        let server = TestServer::setup().await;
        let client = Arc::new(setup_test_client(server.port).await);

        let mut handles = vec![];
        for i in 0..10 {
            let client = client.clone();
            handles.push(task::spawn(async move {
                client.send_text(format!("test{i}"), None).await.unwrap();
            }));
        }

        for handle in handles {
            handle.await.unwrap();
        }

        // Cleanup
        client.disconnect().await;
        assert!(client.is_disconnected());
    }
}
1256
#[cfg(test)]
mod rust_tests {
    use tokio::{
        net::TcpListener,
        task,
        time::{Duration, sleep},
    };
    use tokio_tungstenite::accept_async;

    use super::*;

    /// Disconnecting must complete cleanly right after the server has dropped the
    /// connection, while the client may be busy reconnecting.
    #[tokio::test]
    async fn test_reconnect_then_disconnect() {
        // Ephemeral port so concurrent test runs cannot collide
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();

        // Server: complete exactly one handshake, drop the socket at once, then linger
        let server = task::spawn(async move {
            let (tcp, _) = listener.accept().await.unwrap();
            let ws_stream = accept_async(tcp).await.unwrap();
            drop(ws_stream);
            sleep(Duration::from_secs(1)).await;
        });

        // Incoming messages are irrelevant here; the receiver is merely kept alive
        let (handler, _rx) = Consumer::rust_consumer();

        // Short, jitter-free backoff keeps reconnect attempts quick and predictable
        let config = WebSocketConfig {
            url: format!("ws://127.0.0.1:{port}"),
            headers: vec![],
            handler,
            heartbeat: None,
            heartbeat_msg: None,
            #[cfg(feature = "python")]
            ping_handler: None,
            reconnect_timeout_ms: Some(1_000),
            reconnect_delay_initial_ms: Some(50),
            reconnect_delay_max_ms: Some(100),
            reconnect_backoff_factor: Some(1.0),
            reconnect_jitter_ms: Some(0),
        };

        #[cfg(feature = "python")]
        let client = WebSocketClient::connect(config, None, None, None, vec![], None)
            .await
            .unwrap();
        #[cfg(not(feature = "python"))]
        let client = WebSocketClient::connect(config, vec![], None)
            .await
            .unwrap();

        // Give the client time to notice the server-side drop before disconnecting
        sleep(Duration::from_millis(100)).await;
        client.disconnect().await;
        assert!(client.is_disconnected());

        server.abort();
    }
}