nautilus_network/
socket.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Posei Systems Pty Ltd. All rights reserved.
3//  https://poseitrader.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! High-performance raw TCP client implementation with TLS capability, automatic reconnection
17//! with exponential backoff and state management.
18
19//! **Key features**:
20//! - Connection state tracking (ACTIVE/RECONNECTING/DISCONNECTING/CLOSED)
21//! - Synchronized reconnection with backoff
22//! - Split read/write architecture
23//! - Python callback integration
24//!
25//! **Design**:
26//! - Single reader, multiple writer model
27//! - Read half runs in dedicated task
28//! - Write half runs in dedicated task connected with channel
29//! - Controller task manages lifecycle
30
31use std::{
32    fmt::Debug,
33    path::Path,
34    sync::{
35        Arc,
36        atomic::{AtomicU8, Ordering},
37    },
38    time::Duration,
39};
40
41use bytes::Bytes;
42use nautilus_cryptography::providers::install_cryptographic_provider;
43#[cfg(feature = "python")]
44use pyo3::prelude::*;
45use tokio::{
46    io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
47    net::TcpStream,
48};
49use tokio_tungstenite::{
50    MaybeTlsStream,
51    tungstenite::{Error, client::IntoClientRequest, stream::Mode},
52};
53
54use crate::{
55    backoff::ExponentialBackoff,
56    error::SendError,
57    fix::process_fix_buffer,
58    logging::{log_task_aborted, log_task_started, log_task_stopped},
59    mode::ConnectionMode,
60    tls::{Connector, create_tls_config_from_certs_dir, tcp_tls},
61};
62
/// Write half of a (plain or TLS-wrapped) TCP stream, owned by the writer task.
type TcpWriter = WriteHalf<MaybeTlsStream<TcpStream>>;
/// Read half of a (plain or TLS-wrapped) TCP stream, owned by the read task.
type TcpReader = ReadHalf<MaybeTlsStream<TcpStream>>;
/// Rust callback for incoming bytes.
///
/// When configured, the read task hands its receive buffer to `process_fix_buffer` together
/// with this handler instead of splitting on the configured suffix.
/// NOTE(review): presumably invoked once per complete FIX message — confirm in `crate::fix`.
pub type TcpMessageHandler = dyn Fn(&[u8]) + Send + Sync;
66
/// Configuration for TCP socket connection.
///
/// The `reconnect_*` options are optional; when `None` the defaults documented on each field
/// are used.
#[derive(Debug, Clone)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "posei_trader.core.nautilus_pyo3.network")
)]
pub struct SocketConfig {
    /// The URL to connect to.
    ///
    /// Used both as the `TcpStream::connect` address and to build the client request handed
    /// to the TLS layer.
    pub url: String,
    /// The connection mode {Plain, TLS}.
    pub mode: Mode,
    /// The sequence of bytes which separates lines.
    ///
    /// Appended after every outgoing message (heartbeats included). When no Rust handler is
    /// supplied, it is also used to split the incoming byte stream into frames.
    pub suffix: Vec<u8>,
    #[cfg(feature = "python")]
    /// The optional Python function to handle incoming messages.
    ///
    /// Called with each suffix-delimited frame; only used when no Rust handler is supplied.
    pub py_handler: Option<Arc<PyObject>>,
    /// The optional heartbeat as `(period_secs, message)`.
    ///
    /// `message` is sent every `period_secs` **seconds** while the connection is ACTIVE.
    pub heartbeat: Option<(u64, Vec<u8>)>,
    /// The timeout (milliseconds) for reconnection attempts (default 10_000).
    pub reconnect_timeout_ms: Option<u64>,
    /// The initial reconnection delay (milliseconds) for reconnects (default 2_000).
    pub reconnect_delay_initial_ms: Option<u64>,
    /// The maximum reconnect delay (milliseconds) for exponential backoff (default 30_000).
    pub reconnect_delay_max_ms: Option<u64>,
    /// The exponential backoff factor for reconnection delays (default 1.5).
    pub reconnect_backoff_factor: Option<f64>,
    /// The maximum jitter (milliseconds) added to reconnection delays (default 100).
    pub reconnect_jitter_ms: Option<u64>,
    /// The path to the certificates directory.
    ///
    /// When set, a rustls connector is built from the certificates found there and reused for
    /// every reconnect.
    pub certs_dir: Option<String>,
}
98
/// Represents a command for the writer task.
///
/// Commands travel over an unbounded channel to the single task that owns the write half of
/// the connection.
#[derive(Debug)]
pub enum WriterCommand {
    /// Update the writer reference with a new one after reconnection.
    ///
    /// The writer task attempts a graceful shutdown of the previous write half before
    /// switching to this one.
    Update(TcpWriter),
    /// Send data to the server.
    ///
    /// The configured suffix is written after the payload. The command is dropped (with a
    /// warning) while the connection is reconnecting.
    Send(Bytes),
}
107
/// Owns one TCP connection (plain or TLS) together with the background tasks serving it.
///
/// The stream can be encrypted with TLS or Plain. The stream is split into
/// read and write ends:
/// - The read end is passed to the task that keeps receiving
///   messages from the server and passing them to a handler.
/// - The write end is passed to a task which receives messages over a channel
///   to send to the server.
///
/// The heartbeat is optional and can be configured with an interval and data to
/// send.
///
/// The client uses a suffix to separate messages on the byte stream. It is
/// appended to all sent messages and heartbeats. It is also used to split
/// the received byte stream when no Rust `handler` is configured.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "posei_trader.core.nautilus_pyo3.network")
)]
struct SocketClientInner {
    /// Configuration used for the initial connection and for every reconnect.
    config: SocketConfig,
    /// Optional rustls connector built from `config.certs_dir`.
    connector: Option<Connector>,
    /// Task reading from the current read half; replaced on each successful reconnect.
    read_task: Arc<tokio::task::JoinHandle<()>>,
    /// Task owning the write half and draining `writer_tx`.
    write_task: tokio::task::JoinHandle<()>,
    /// Sender feeding the writer task with payloads and writer swaps.
    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Heartbeat task; present only when `config.heartbeat` is set.
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    /// Shared `ConnectionMode` (encoded via `as_u8`) observed by every task.
    connection_mode: Arc<AtomicU8>,
    /// Upper bound on a single reconnect attempt.
    reconnect_timeout: Duration,
    /// Delay policy applied between failed reconnect attempts.
    backoff: ExponentialBackoff,
    /// Optional Rust handler for incoming data.
    handler: Option<Arc<TcpMessageHandler>>,
}
139
140impl SocketClientInner {
141    pub async fn connect_url(
142        config: SocketConfig,
143        handler: Option<Arc<TcpMessageHandler>>,
144    ) -> anyhow::Result<Self> {
145        install_cryptographic_provider();
146
147        let SocketConfig {
148            url,
149            mode,
150            heartbeat,
151            suffix,
152            #[cfg(feature = "python")]
153            py_handler,
154            reconnect_timeout_ms,
155            reconnect_delay_initial_ms,
156            reconnect_delay_max_ms,
157            reconnect_backoff_factor,
158            reconnect_jitter_ms,
159            certs_dir,
160        } = &config;
161        let connector = if let Some(dir) = certs_dir {
162            let config = create_tls_config_from_certs_dir(Path::new(dir))?;
163            Some(Connector::Rustls(Arc::new(config)))
164        } else {
165            None
166        };
167
168        let (reader, writer) = Self::tls_connect_with_server(url, *mode, connector.clone()).await?;
169        tracing::debug!("Connected");
170
171        let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
172
173        let read_task = Arc::new(Self::spawn_read_task(
174            connection_mode.clone(),
175            reader,
176            handler.clone(),
177            #[cfg(feature = "python")]
178            py_handler.clone(),
179            suffix.clone(),
180        ));
181
182        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
183
184        let write_task =
185            Self::spawn_write_task(connection_mode.clone(), writer, writer_rx, suffix.clone());
186
187        // Optionally spawn a heartbeat task to periodically ping server
188        let heartbeat_task = heartbeat.as_ref().map(|heartbeat| {
189            Self::spawn_heartbeat_task(
190                connection_mode.clone(),
191                heartbeat.clone(),
192                writer_tx.clone(),
193            )
194        });
195
196        let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
197        let backoff = ExponentialBackoff::new(
198            Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
199            Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
200            reconnect_backoff_factor.unwrap_or(1.5),
201            reconnect_jitter_ms.unwrap_or(100),
202            true, // immediate-first
203        )?;
204
205        Ok(Self {
206            config,
207            connector,
208            read_task,
209            write_task,
210            writer_tx,
211            heartbeat_task,
212            connection_mode,
213            reconnect_timeout,
214            backoff,
215            handler,
216        })
217    }
218
219    pub async fn tls_connect_with_server(
220        url: &str,
221        mode: Mode,
222        connector: Option<Connector>,
223    ) -> Result<(TcpReader, TcpWriter), Error> {
224        tracing::debug!("Connecting to {url}");
225        let tcp_result = TcpStream::connect(url).await;
226
227        match tcp_result {
228            Ok(stream) => {
229                tracing::debug!("TCP connection established, proceeding with TLS");
230                let request = url.into_client_request()?;
231                tcp_tls(&request, mode, stream, connector)
232                    .await
233                    .map(tokio::io::split)
234            }
235            Err(e) => {
236                tracing::error!("TCP connection failed: {e:?}");
237                Err(Error::Io(e))
238            }
239        }
240    }
241
242    /// Reconnect with server.
243    ///
244    /// Makes a new connection with server, uses the new read and write halves
245    /// to update the reader and writer.
246    async fn reconnect(&mut self) -> Result<(), Error> {
247        tracing::debug!("Reconnecting");
248
249        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
250            tracing::debug!("Reconnect aborted due to disconnect state");
251            return Ok(());
252        }
253
254        tokio::time::timeout(self.reconnect_timeout, async {
255            let SocketConfig {
256                url,
257                mode,
258                heartbeat: _,
259                suffix,
260                #[cfg(feature = "python")]
261                py_handler,
262                reconnect_timeout_ms: _,
263                reconnect_delay_initial_ms: _,
264                reconnect_backoff_factor: _,
265                reconnect_delay_max_ms: _,
266                reconnect_jitter_ms: _,
267                certs_dir: _,
268            } = &self.config;
269            // Create a fresh connection
270            let connector = self.connector.clone();
271            // Attempt to connect; abort early if a disconnect was requested
272            let (reader, new_writer) = Self::tls_connect_with_server(url, *mode, connector).await?;
273
274            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
275                tracing::debug!("Reconnect aborted mid-flight (after connect)");
276                return Ok(());
277            }
278            tracing::debug!("Connected");
279
280            if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer)) {
281                tracing::error!("{e}");
282            }
283
284            // Delay before closing connection
285            tokio::time::sleep(Duration::from_millis(100)).await;
286
287            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
288                tracing::debug!("Reconnect aborted mid-flight (after delay)");
289                return Ok(());
290            }
291
292            if !self.read_task.is_finished() {
293                self.read_task.abort();
294                log_task_aborted("read");
295            }
296
297            // If a disconnect was requested during reconnect, do not proceed to reactivate
298            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
299                tracing::debug!("Reconnect aborted mid-flight (before spawn read)");
300                return Ok(());
301            }
302
303            // Mark as active only if not disconnecting
304            self.connection_mode
305                .store(ConnectionMode::Active.as_u8(), Ordering::SeqCst);
306
307            // Spawn new read task
308            self.read_task = Arc::new(Self::spawn_read_task(
309                self.connection_mode.clone(),
310                reader,
311                self.handler.clone(),
312                #[cfg(feature = "python")]
313                py_handler.clone(),
314                suffix.clone(),
315            ));
316
317            tracing::debug!("Reconnect succeeded");
318            Ok(())
319        })
320        .await
321        .map_err(|_| {
322            Error::Io(std::io::Error::new(
323                std::io::ErrorKind::TimedOut,
324                format!(
325                    "reconnection timed out after {}s",
326                    self.reconnect_timeout.as_secs_f64()
327                ),
328            ))
329        })?
330    }
331
332    /// Check if the client is still alive.
333    ///
334    /// The client is connected if the read task has not finished. It is expected
335    /// that in case of any failure client or server side. The read task will be
336    /// shutdown. There might be some delay between the connection being closed
337    /// and the client detecting it.
338    #[inline]
339    #[must_use]
340    pub fn is_alive(&self) -> bool {
341        !self.read_task.is_finished()
342    }
343
344    #[must_use]
345    fn spawn_read_task(
346        connection_state: Arc<AtomicU8>,
347        mut reader: TcpReader,
348        handler: Option<Arc<TcpMessageHandler>>,
349        #[cfg(feature = "python")] py_handler: Option<Arc<PyObject>>,
350        suffix: Vec<u8>,
351    ) -> tokio::task::JoinHandle<()> {
352        log_task_started("read");
353
354        // Interval between checking the connection mode
355        let check_interval = Duration::from_millis(10);
356
357        tokio::task::spawn(async move {
358            let mut buf = Vec::new();
359
360            loop {
361                if !ConnectionMode::from_atomic(&connection_state).is_active() {
362                    break;
363                }
364
365                match tokio::time::timeout(check_interval, reader.read_buf(&mut buf)).await {
366                    // Connection has been terminated or vector buffer is complete
367                    Ok(Ok(0)) => {
368                        tracing::debug!("Connection closed by server");
369                        break;
370                    }
371                    Ok(Err(e)) => {
372                        tracing::debug!("Connection ended: {e}");
373                        break;
374                    }
375                    // Received bytes of data
376                    Ok(Ok(bytes)) => {
377                        tracing::trace!("Received <binary> {bytes} bytes");
378
379                        if let Some(handler) = &handler {
380                            process_fix_buffer(&mut buf, handler);
381                        } else {
382                            while let Some((i, _)) = &buf
383                                .windows(suffix.len())
384                                .enumerate()
385                                .find(|(_, pair)| pair.eq(&suffix))
386                            {
387                                let mut data: Vec<u8> = buf.drain(0..i + suffix.len()).collect();
388                                data.truncate(data.len() - suffix.len());
389
390                                if let Some(handler) = &handler {
391                                    handler(&data);
392                                }
393
394                                #[cfg(feature = "python")]
395                                if let Some(py_handler) = &py_handler {
396                                    if let Err(e) = Python::with_gil(|py| {
397                                        py_handler.call1(py, (data.as_slice(),))
398                                    }) {
399                                        tracing::error!("Call to handler failed: {e}");
400                                        break;
401                                    }
402                                }
403                            }
404                        }
405                    }
406                    Err(_) => {
407                        // Timeout - continue loop and check connection mode
408                        continue;
409                    }
410                }
411            }
412
413            log_task_stopped("read");
414        })
415    }
416
417    fn spawn_write_task(
418        connection_state: Arc<AtomicU8>,
419        writer: TcpWriter,
420        mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
421        suffix: Vec<u8>,
422    ) -> tokio::task::JoinHandle<()> {
423        log_task_started("write");
424
425        // Interval between checking the connection mode
426        let check_interval = Duration::from_millis(10);
427
428        tokio::task::spawn(async move {
429            let mut active_writer = writer;
430
431            loop {
432                if matches!(
433                    ConnectionMode::from_atomic(&connection_state),
434                    ConnectionMode::Disconnect | ConnectionMode::Closed
435                ) {
436                    break;
437                }
438
439                match tokio::time::timeout(check_interval, writer_rx.recv()).await {
440                    Ok(Some(msg)) => {
441                        // Re-check connection mode after receiving a message
442                        let mode = ConnectionMode::from_atomic(&connection_state);
443                        if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
444                            break;
445                        }
446
447                        match msg {
448                            WriterCommand::Update(new_writer) => {
449                                tracing::debug!("Received new writer");
450
451                                // Delay before closing connection
452                                tokio::time::sleep(Duration::from_millis(100)).await;
453
454                                // Attempt to shutdown the writer gracefully before updating,
455                                // we ignore any error as the writer may already be closed.
456                                _ = active_writer.shutdown().await;
457
458                                active_writer = new_writer;
459                                tracing::debug!("Updated writer");
460                            }
461                            _ if mode.is_reconnect() => {
462                                tracing::warn!("Skipping message while reconnecting, {msg:?}");
463                                continue;
464                            }
465                            WriterCommand::Send(msg) => {
466                                if let Err(e) = active_writer.write_all(&msg).await {
467                                    tracing::error!("Failed to send message: {e}");
468                                    // Mode is active so trigger reconnection
469                                    tracing::warn!("Writer triggering reconnect");
470                                    connection_state
471                                        .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
472                                    continue;
473                                }
474                                if let Err(e) = active_writer.write_all(&suffix).await {
475                                    tracing::error!("Failed to send message: {e}");
476                                }
477                            }
478                        }
479                    }
480                    Ok(None) => {
481                        // Channel closed - writer task should terminate
482                        tracing::debug!("Writer channel closed, terminating writer task");
483                        break;
484                    }
485                    Err(_) => {
486                        // Timeout - just continue the loop
487                        continue;
488                    }
489                }
490            }
491
492            // Attempt to shutdown the writer gracefully before exiting,
493            // we ignore any error as the writer may already be closed.
494            _ = active_writer.shutdown().await;
495
496            log_task_stopped("write");
497        })
498    }
499
500    fn spawn_heartbeat_task(
501        connection_state: Arc<AtomicU8>,
502        heartbeat: (u64, Vec<u8>),
503        writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
504    ) -> tokio::task::JoinHandle<()> {
505        log_task_started("heartbeat");
506        let (interval_secs, message) = heartbeat;
507
508        tokio::task::spawn(async move {
509            let interval = Duration::from_secs(interval_secs);
510
511            loop {
512                tokio::time::sleep(interval).await;
513
514                match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
515                    ConnectionMode::Active => {
516                        let msg = WriterCommand::Send(message.clone().into());
517
518                        match writer_tx.send(msg) {
519                            Ok(()) => tracing::trace!("Sent heartbeat to writer task"),
520                            Err(e) => {
521                                tracing::error!("Failed to send heartbeat to writer task: {e}");
522                            }
523                        }
524                    }
525                    ConnectionMode::Reconnect => continue,
526                    ConnectionMode::Disconnect | ConnectionMode::Closed => break,
527                }
528            }
529
530            log_task_stopped("heartbeat");
531        })
532    }
533}
534
535impl Drop for SocketClientInner {
536    fn drop(&mut self) {
537        if !self.read_task.is_finished() {
538            self.read_task.abort();
539            log_task_aborted("read");
540        }
541
542        if !self.write_task.is_finished() {
543            self.write_task.abort();
544            log_task_aborted("write");
545        }
546
547        if let Some(ref handle) = self.heartbeat_task.take() {
548            if !handle.is_finished() {
549                handle.abort();
550                log_task_aborted("heartbeat");
551            }
552        }
553    }
554}
555
/// TCP client whose connection is supervised by a background controller task.
///
/// The controller reconnects (with backoff) when the connection drops and performs the
/// shutdown sequence once `close()` requests a disconnect.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "posei_trader.core.nautilus_pyo3.network")
)]
pub struct SocketClient {
    /// Task supervising reconnection and shutdown of the inner connection.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Shared connection mode, also observed by the inner connection's tasks.
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// Channel to the writer task. `send_bytes` additionally waits for ACTIVE mode first.
    pub writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
}
565
566impl Debug for SocketClient {
567    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
568        f.debug_struct(stringify!(SocketClient)).finish()
569    }
570}
571
572impl SocketClient {
573    /// Connect to the server.
574    ///
575    /// # Errors
576    ///
577    /// Returns any error connecting to the server.
578    pub async fn connect(
579        config: SocketConfig,
580        handler: Option<Arc<TcpMessageHandler>>,
581        #[cfg(feature = "python")] post_connection: Option<PyObject>,
582        #[cfg(feature = "python")] post_reconnection: Option<PyObject>,
583        #[cfg(feature = "python")] post_disconnection: Option<PyObject>,
584    ) -> anyhow::Result<Self> {
585        let inner = SocketClientInner::connect_url(config, handler).await?;
586        let writer_tx = inner.writer_tx.clone();
587        let connection_mode = inner.connection_mode.clone();
588
589        let controller_task = Self::spawn_controller_task(
590            inner,
591            connection_mode.clone(),
592            #[cfg(feature = "python")]
593            post_reconnection,
594            #[cfg(feature = "python")]
595            post_disconnection,
596        );
597
598        #[cfg(feature = "python")]
599        if let Some(handler) = post_connection {
600            Python::with_gil(|py| match handler.call0(py) {
601                Ok(_) => tracing::debug!("Called `post_connection` handler"),
602                Err(e) => tracing::error!("Error calling `post_connection` handler: {e}"),
603            });
604        }
605
606        Ok(Self {
607            controller_task,
608            connection_mode,
609            writer_tx,
610        })
611    }
612
613    /// Returns the current connection mode.
614    #[must_use]
615    pub fn connection_mode(&self) -> ConnectionMode {
616        ConnectionMode::from_atomic(&self.connection_mode)
617    }
618
619    /// Check if the client connection is active.
620    ///
621    /// Returns `true` if the client is connected and has not been signalled to disconnect.
622    /// The client will automatically retry connection based on its configuration.
623    #[inline]
624    #[must_use]
625    pub fn is_active(&self) -> bool {
626        self.connection_mode().is_active()
627    }
628
629    /// Check if the client is reconnecting.
630    ///
631    /// Returns `true` if the client lost connection and is attempting to reestablish it.
632    /// The client will automatically retry connection based on its configuration.
633    #[inline]
634    #[must_use]
635    pub fn is_reconnecting(&self) -> bool {
636        self.connection_mode().is_reconnect()
637    }
638
639    /// Check if the client is disconnecting.
640    ///
641    /// Returns `true` if the client is in disconnect mode.
642    #[inline]
643    #[must_use]
644    pub fn is_disconnecting(&self) -> bool {
645        self.connection_mode().is_disconnect()
646    }
647
648    /// Check if the client is closed.
649    ///
650    /// Returns `true` if the client has been explicitly disconnected or reached
651    /// maximum reconnection attempts. In this state, the client cannot be reused
652    /// and a new client must be created for further connections.
653    #[inline]
654    #[must_use]
655    pub fn is_closed(&self) -> bool {
656        self.connection_mode().is_closed()
657    }
658
659    /// Close the client.
660    ///
661    /// Controller task will periodically check the disconnect mode
662    /// and shutdown the client if it is not alive.
663    pub async fn close(&self) {
664        self.connection_mode
665            .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
666
667        match tokio::time::timeout(Duration::from_secs(5), async {
668            while !self.is_closed() {
669                tokio::time::sleep(Duration::from_millis(10)).await;
670            }
671
672            if !self.controller_task.is_finished() {
673                self.controller_task.abort();
674                log_task_aborted("controller");
675            }
676        })
677        .await
678        {
679            Ok(()) => {
680                log_task_stopped("controller");
681            }
682            Err(_) => {
683                tracing::error!("Timeout waiting for controller task to finish");
684            }
685        }
686    }
687
688    /// Sends a message of the given `data`.
689    ///
690    /// # Errors
691    ///
692    /// Returns an error if sending fails.
693    pub async fn send_bytes(&self, data: Vec<u8>) -> Result<(), SendError> {
694        if self.is_closed() {
695            return Err(SendError::Closed);
696        }
697
698        let timeout = Duration::from_secs(2);
699        let check_interval = Duration::from_millis(1);
700
701        if !self.is_active() {
702            tracing::debug!("Waiting for client to become ACTIVE before sending...");
703
704            let inner = tokio::time::timeout(timeout, async {
705                loop {
706                    if self.is_active() {
707                        return Ok(());
708                    }
709                    if matches!(
710                        self.connection_mode(),
711                        ConnectionMode::Disconnect | ConnectionMode::Closed
712                    ) {
713                        return Err(());
714                    }
715                    tokio::time::sleep(check_interval).await;
716                }
717            })
718            .await
719            .map_err(|_| SendError::Timeout)?;
720            inner.map_err(|()| SendError::Closed)?;
721        }
722
723        let msg = WriterCommand::Send(data.into());
724        self.writer_tx
725            .send(msg)
726            .map_err(|e| SendError::BrokenPipe(e.to_string()))
727    }
728
729    fn spawn_controller_task(
730        mut inner: SocketClientInner,
731        connection_mode: Arc<AtomicU8>,
732        #[cfg(feature = "python")] post_reconnection: Option<PyObject>,
733        #[cfg(feature = "python")] post_disconnection: Option<PyObject>,
734    ) -> tokio::task::JoinHandle<()> {
735        tokio::task::spawn(async move {
736            log_task_started("controller");
737
738            let check_interval = Duration::from_millis(10);
739
740            loop {
741                tokio::time::sleep(check_interval).await;
742                let mode = ConnectionMode::from_atomic(&connection_mode);
743
744                if mode.is_disconnect() {
745                    tracing::debug!("Disconnecting");
746
747                    let timeout = Duration::from_secs(5);
748                    if tokio::time::timeout(timeout, async {
749                        // Delay awaiting graceful shutdown
750                        tokio::time::sleep(Duration::from_millis(100)).await;
751
752                        if !inner.read_task.is_finished() {
753                            inner.read_task.abort();
754                            log_task_aborted("read");
755                        }
756
757                        if let Some(task) = &inner.heartbeat_task {
758                            if !task.is_finished() {
759                                task.abort();
760                                log_task_aborted("heartbeat");
761                            }
762                        }
763                    })
764                    .await
765                    .is_err()
766                    {
767                        tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
768                    }
769
770                    tracing::debug!("Closed");
771
772                    #[cfg(feature = "python")]
773                    if let Some(ref handler) = post_disconnection {
774                        Python::with_gil(|py| match handler.call0(py) {
775                            Ok(_) => tracing::debug!("Called `post_disconnection` handler"),
776                            Err(e) => {
777                                tracing::error!("Error calling `post_disconnection` handler: {e}");
778                            }
779                        });
780                    }
781                    break; // Controller finished
782                }
783
784                if mode.is_reconnect() || (mode.is_active() && !inner.is_alive()) {
785                    match inner.reconnect().await {
786                        Ok(()) => {
787                            tracing::debug!("Reconnected successfully");
788                            inner.backoff.reset();
789                            // Only invoke Python reconnect handler if still active
790                            #[cfg(feature = "python")]
791                            {
792                                if ConnectionMode::from_atomic(&connection_mode).is_active() {
793                                    if let Some(ref handler) = post_reconnection {
794                                        Python::with_gil(|py| match handler.call0(py) {
795                                            Ok(_) => tracing::debug!(
796                                                "Called `post_reconnection` handler"
797                                            ),
798                                            Err(e) => tracing::error!(
799                                                "Error calling `post_reconnection` handler: {e}"
800                                            ),
801                                        });
802                                    }
803                                } else {
804                                    tracing::debug!(
805                                        "Skipping post_reconnection handlers due to disconnect state"
806                                    );
807                                }
808                            }
809                        }
810                        Err(e) => {
811                            let duration = inner.backoff.next_duration();
812                            tracing::warn!("Reconnect attempt failed: {e}");
813                            if !duration.is_zero() {
814                                tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
815                            }
816                            tokio::time::sleep(duration).await;
817                        }
818                    }
819                }
820            }
821            inner
822                .connection_mode
823                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
824
825            log_task_stopped("controller");
826        })
827    }
828}
829
830// Abort controller task on drop to clean up background tasks
831impl Drop for SocketClient {
832    fn drop(&mut self) {
833        if !self.controller_task.is_finished() {
834            self.controller_task.abort();
835            log_task_aborted("controller");
836        }
837    }
838}
839
840////////////////////////////////////////////////////////////////////////////////
841// Tests
842////////////////////////////////////////////////////////////////////////////////
843#[cfg(test)]
844#[cfg(feature = "python")]
845#[cfg(target_os = "linux")] // Only run network tests on Linux (CI stability)
846mod tests {
847    use std::ffi::CString;
848
849    use nautilus_common::testing::wait_until_async;
850    use nautilus_core::python::IntoPyObjectPoseiExt;
851    use pyo3::prepare_freethreaded_python;
852    use tokio::{
853        io::{AsyncReadExt, AsyncWriteExt},
854        net::{TcpListener, TcpStream},
855        sync::Mutex,
856        task,
857        time::{Duration, sleep},
858    };
859
860    use super::*;
861
862    fn create_handler() -> PyObject {
863        let code_raw = r"
864class Counter:
865    def __init__(self):
866        self.count = 0
867        self.check = False
868
869    def handler(self, bytes):
870        msg = bytes.decode()
871        if msg == 'ping':
872            self.count += 1
873        elif msg == 'heartbeat message':
874            self.check = True
875
876    def get_check(self):
877        return self.check
878
879    def get_count(self):
880        return self.count
881
882counter = Counter()
883";
884        let code = CString::new(code_raw).unwrap();
885        let filename = CString::new("test".to_string()).unwrap();
886        let module = CString::new("test".to_string()).unwrap();
887        Python::with_gil(|py| {
888            let pymod = PyModule::from_code(py, &code, &filename, &module).unwrap();
889            let counter = pymod.getattr("counter").unwrap().into_py_any_unwrap(py);
890
891            counter
892                .getattr(py, "handler")
893                .unwrap()
894                .into_py_any_unwrap(py)
895        })
896    }
897
898    async fn bind_test_server() -> (u16, TcpListener) {
899        let listener = TcpListener::bind("127.0.0.1:0")
900            .await
901            .expect("Failed to bind ephemeral port");
902        let port = listener.local_addr().unwrap().port();
903        (port, listener)
904    }
905
906    async fn run_echo_server(mut socket: TcpStream) {
907        let mut buf = Vec::new();
908        loop {
909            match socket.read_buf(&mut buf).await {
910                Ok(0) => {
911                    break;
912                }
913                Ok(_n) => {
914                    while let Some(idx) = buf.windows(2).position(|w| w == b"\r\n") {
915                        let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
916                        // Remove trailing \r\n
917                        line.truncate(line.len() - 2);
918
919                        if line == b"close" {
920                            let _ = socket.shutdown().await;
921                            return;
922                        }
923
924                        let mut echo_data = line;
925                        echo_data.extend_from_slice(b"\r\n");
926                        if socket.write_all(&echo_data).await.is_err() {
927                            break;
928                        }
929                    }
930                }
931                Err(e) => {
932                    eprintln!("Server read error: {e}");
933                    break;
934                }
935            }
936        }
937    }
938
939    #[tokio::test]
940    async fn test_basic_send_receive() {
941        prepare_freethreaded_python();
942
943        let (port, listener) = bind_test_server().await;
944        let server_task = task::spawn(async move {
945            let (socket, _) = listener.accept().await.unwrap();
946            run_echo_server(socket).await;
947        });
948
949        let config = SocketConfig {
950            url: format!("127.0.0.1:{port}"),
951            mode: Mode::Plain,
952            suffix: b"\r\n".to_vec(),
953            py_handler: Some(Arc::new(create_handler())),
954            heartbeat: None,
955            reconnect_timeout_ms: None,
956            reconnect_delay_initial_ms: None,
957            reconnect_backoff_factor: None,
958            reconnect_delay_max_ms: None,
959            reconnect_jitter_ms: None,
960            certs_dir: None,
961        };
962
963        let client = SocketClient::connect(config, None, None, None, None)
964            .await
965            .expect("Client connect failed unexpectedly");
966
967        client.send_bytes(b"Hello".into()).await.unwrap();
968        client.send_bytes(b"World".into()).await.unwrap();
969
970        // Wait a bit for the server to echo them back
971        sleep(Duration::from_millis(100)).await;
972
973        client.send_bytes(b"close".into()).await.unwrap();
974        server_task.await.unwrap();
975        assert!(!client.is_closed());
976    }
977
978    #[tokio::test]
979    async fn test_reconnect_fail_exhausted() {
980        prepare_freethreaded_python();
981
982        let (port, listener) = bind_test_server().await;
983        drop(listener); // We drop it immediately -> no server is listening
984
985        let config = SocketConfig {
986            url: format!("127.0.0.1:{port}"),
987            mode: Mode::Plain,
988            suffix: b"\r\n".to_vec(),
989            py_handler: Some(Arc::new(create_handler())),
990            heartbeat: None,
991            reconnect_timeout_ms: None,
992            reconnect_delay_initial_ms: None,
993            reconnect_backoff_factor: None,
994            reconnect_delay_max_ms: None,
995            reconnect_jitter_ms: None,
996            certs_dir: None,
997        };
998
999        let client_res = SocketClient::connect(config, None, None, None, None).await;
1000        assert!(
1001            client_res.is_err(),
1002            "Should fail quickly with no server listening"
1003        );
1004    }
1005
1006    #[tokio::test]
1007    async fn test_user_disconnect() {
1008        prepare_freethreaded_python();
1009
1010        let (port, listener) = bind_test_server().await;
1011        let server_task = task::spawn(async move {
1012            let (socket, _) = listener.accept().await.unwrap();
1013            let mut buf = [0u8; 1024];
1014            let _ = socket.try_read(&mut buf);
1015
1016            loop {
1017                sleep(Duration::from_secs(1)).await;
1018            }
1019        });
1020
1021        let config = SocketConfig {
1022            url: format!("127.0.0.1:{port}"),
1023            mode: Mode::Plain,
1024            suffix: b"\r\n".to_vec(),
1025            py_handler: Some(Arc::new(create_handler())),
1026            heartbeat: None,
1027            reconnect_timeout_ms: None,
1028            reconnect_delay_initial_ms: None,
1029            reconnect_backoff_factor: None,
1030            reconnect_delay_max_ms: None,
1031            reconnect_jitter_ms: None,
1032            certs_dir: None,
1033        };
1034
1035        let client = SocketClient::connect(config, None, None, None, None)
1036            .await
1037            .unwrap();
1038
1039        client.close().await;
1040        assert!(client.is_closed());
1041        server_task.abort();
1042    }
1043
1044    #[tokio::test]
1045    async fn test_heartbeat() {
1046        prepare_freethreaded_python();
1047
1048        let (port, listener) = bind_test_server().await;
1049        let received = Arc::new(Mutex::new(Vec::new()));
1050        let received2 = received.clone();
1051
1052        let server_task = task::spawn(async move {
1053            let (socket, _) = listener.accept().await.unwrap();
1054
1055            let mut buf = Vec::new();
1056            loop {
1057                match socket.try_read_buf(&mut buf) {
1058                    Ok(0) => break,
1059                    Ok(_) => {
1060                        while let Some(idx) = buf.windows(2).position(|w| w == b"\r\n") {
1061                            let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
1062                            line.truncate(line.len() - 2);
1063                            received2.lock().await.push(line);
1064                        }
1065                    }
1066                    Err(_) => {
1067                        tokio::time::sleep(Duration::from_millis(10)).await;
1068                    }
1069                }
1070            }
1071        });
1072
1073        // Heartbeat every 1 second
1074        let heartbeat = Some((1, b"ping".to_vec()));
1075
1076        let config = SocketConfig {
1077            url: format!("127.0.0.1:{port}"),
1078            mode: Mode::Plain,
1079            suffix: b"\r\n".to_vec(),
1080            py_handler: Some(Arc::new(create_handler())),
1081            heartbeat,
1082            reconnect_timeout_ms: None,
1083            reconnect_delay_initial_ms: None,
1084            reconnect_backoff_factor: None,
1085            reconnect_delay_max_ms: None,
1086            reconnect_jitter_ms: None,
1087            certs_dir: None,
1088        };
1089
1090        let client = SocketClient::connect(config, None, None, None, None)
1091            .await
1092            .unwrap();
1093
1094        // Wait ~3 seconds to collect some heartbeats
1095        sleep(Duration::from_secs(3)).await;
1096
1097        {
1098            let lock = received.lock().await;
1099            let pings = lock
1100                .iter()
1101                .filter(|line| line == &&b"ping".to_vec())
1102                .count();
1103            assert!(
1104                pings >= 2,
1105                "Expected at least 2 heartbeat pings; got {pings}"
1106            );
1107        }
1108
1109        client.close().await;
1110        server_task.abort();
1111    }
1112
1113    #[tokio::test]
1114    async fn test_python_handler_error() {
1115        prepare_freethreaded_python();
1116
1117        let (port, listener) = bind_test_server().await;
1118        let server_task = task::spawn(async move {
1119            let (socket, _) = listener.accept().await.unwrap();
1120            run_echo_server(socket).await;
1121        });
1122
1123        let code_raw = r#"
1124def handler(bytes_data):
1125    txt = bytes_data.decode()
1126    if "ERR" in txt:
1127        raise ValueError("Simulated error in handler")
1128    return
1129"#;
1130        let code = CString::new(code_raw).unwrap();
1131        let filename = CString::new("test".to_string()).unwrap();
1132        let module = CString::new("test".to_string()).unwrap();
1133
1134        let py_handler = Some(Python::with_gil(|py| {
1135            let pymod = PyModule::from_code(py, &code, &filename, &module).unwrap();
1136            let func = pymod.getattr("handler").unwrap();
1137            Arc::new(func.into_py_any_unwrap(py))
1138        }));
1139
1140        let config = SocketConfig {
1141            url: format!("127.0.0.1:{port}"),
1142            mode: Mode::Plain,
1143            suffix: b"\r\n".to_vec(),
1144            py_handler,
1145            heartbeat: None,
1146            reconnect_timeout_ms: None,
1147            reconnect_delay_initial_ms: None,
1148            reconnect_backoff_factor: None,
1149            reconnect_delay_max_ms: None,
1150            reconnect_jitter_ms: None,
1151            certs_dir: None,
1152        };
1153
1154        let client = SocketClient::connect(config, None, None, None, None)
1155            .await
1156            .expect("Client connect failed unexpectedly");
1157
1158        client.send_bytes(b"hello".into()).await.unwrap();
1159        sleep(Duration::from_millis(100)).await;
1160
1161        client.send_bytes(b"ERR".into()).await.unwrap();
1162        sleep(Duration::from_secs(1)).await;
1163
1164        assert!(client.is_active());
1165
1166        client.close().await;
1167
1168        assert!(client.is_closed());
1169        server_task.abort();
1170    }
1171
1172    #[tokio::test]
1173    async fn test_reconnect_success() {
1174        prepare_freethreaded_python();
1175
1176        let (port, listener) = bind_test_server().await;
1177
1178        // Spawn a server task that:
1179        // 1. Accepts the first connection and then drops it after a short delay (simulate disconnect)
1180        // 2. Waits a bit and then accepts a new connection and runs the echo server
1181        let server_task = task::spawn(async move {
1182            // Accept first connection
1183            let (mut socket, _) = listener.accept().await.expect("First accept failed");
1184
1185            // Wait briefly and then force-close the connection
1186            sleep(Duration::from_millis(500)).await;
1187            let _ = socket.shutdown().await;
1188
1189            // Wait for the client's reconnect attempt
1190            sleep(Duration::from_millis(500)).await;
1191
1192            // Run the echo server on the new connection
1193            let (socket, _) = listener.accept().await.expect("Second accept failed");
1194            run_echo_server(socket).await;
1195        });
1196
1197        let config = SocketConfig {
1198            url: format!("127.0.0.1:{port}"),
1199            mode: Mode::Plain,
1200            suffix: b"\r\n".to_vec(),
1201            py_handler: Some(Arc::new(create_handler())),
1202            heartbeat: None,
1203            reconnect_timeout_ms: Some(5_000),
1204            reconnect_delay_initial_ms: Some(500),
1205            reconnect_delay_max_ms: Some(5_000),
1206            reconnect_backoff_factor: Some(2.0),
1207            reconnect_jitter_ms: Some(50),
1208            certs_dir: None,
1209        };
1210
1211        let client = SocketClient::connect(config, None, None, None, None)
1212            .await
1213            .expect("Client connect failed unexpectedly");
1214
1215        // Initially, the client should be active
1216        assert!(client.is_active(), "Client should start as active");
1217
1218        // Wait until the client loses connection (i.e. not active),
1219        // then wait until it reconnects (active again).
1220        wait_until_async(|| async { client.is_active() }, Duration::from_secs(10)).await;
1221
1222        client
1223            .send_bytes(b"TestReconnect".into())
1224            .await
1225            .expect("Send failed");
1226
1227        client.close().await;
1228        server_task.abort();
1229    }
1230}
1231
1232#[cfg(test)]
1233mod rust_tests {
1234    use tokio::{
1235        net::TcpListener,
1236        task,
1237        time::{Duration, sleep},
1238    };
1239
1240    use super::*;
1241
1242    #[tokio::test]
1243    async fn test_reconnect_then_close() {
1244        // Bind an ephemeral port
1245        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1246        let port = listener.local_addr().unwrap().port();
1247
1248        // Server task: accept one connection and then drop it
1249        let server = task::spawn(async move {
1250            if let Ok((mut sock, _)) = listener.accept().await {
1251                let _ = sock.shutdown();
1252            }
1253            // Keep listener alive briefly to avoid premature exit
1254            sleep(Duration::from_secs(1)).await;
1255        });
1256
1257        // Configure client with a short reconnect backoff
1258        let config = SocketConfig {
1259            url: format!("127.0.0.1:{port}"),
1260            mode: Mode::Plain,
1261            suffix: b"\r\n".to_vec(),
1262            #[cfg(feature = "python")]
1263            py_handler: None,
1264            heartbeat: None,
1265            reconnect_timeout_ms: Some(1_000),
1266            reconnect_delay_initial_ms: Some(50),
1267            reconnect_delay_max_ms: Some(100),
1268            reconnect_backoff_factor: Some(1.0),
1269            reconnect_jitter_ms: Some(0),
1270            certs_dir: None,
1271        };
1272
1273        // Connect client (handler=None)
1274        let client = {
1275            #[cfg(feature = "python")]
1276            {
1277                SocketClient::connect(config.clone(), None, None, None, None)
1278                    .await
1279                    .unwrap()
1280            }
1281            #[cfg(not(feature = "python"))]
1282            {
1283                SocketClient::connect(config.clone(), None).await.unwrap()
1284            }
1285        };
1286
1287        // Allow server to drop connection and client to notice
1288        sleep(Duration::from_millis(100)).await;
1289
1290        // Now close the client
1291        client.close().await;
1292        assert!(client.is_closed());
1293        server.abort();
1294    }
1295}