nautilus_network/
socket.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Posei Systems Pty Ltd. All rights reserved.
3//  https://poseitrader.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! High-performance raw TCP client implementation with TLS capability, automatic reconnection
17//! with exponential backoff and state management.
18
19//! **Key features**:
20//! - Connection state tracking (ACTIVE/RECONNECTING/DISCONNECTING/CLOSED)
21//! - Synchronized reconnection with backoff
22//! - Split read/write architecture
23//! - Python callback integration
24//!
25//! **Design**:
26//! - Single reader, multiple writer model
27//! - Read half runs in dedicated task
28//! - Write half runs in dedicated task connected with channel
29//! - Controller task manages lifecycle
30
31use std::{
32    fmt::Debug,
33    path::Path,
34    sync::{
35        Arc,
36        atomic::{AtomicU8, Ordering},
37    },
38    time::Duration,
39};
40
41use bytes::Bytes;
42use nautilus_cryptography::providers::install_cryptographic_provider;
43#[cfg(feature = "python")]
44use pyo3::prelude::*;
45use tokio::{
46    io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
47    net::TcpStream,
48};
49use tokio_tungstenite::{
50    MaybeTlsStream,
51    tungstenite::{Error, client::IntoClientRequest, stream::Mode},
52};
53
54use crate::{
55    backoff::ExponentialBackoff,
56    error::SendError,
57    fix::process_fix_buffer,
58    logging::{log_task_aborted, log_task_started, log_task_stopped},
59    mode::ConnectionMode,
60    tls::{Connector, create_tls_config_from_certs_dir, tcp_tls},
61};
62
63type TcpWriter = WriteHalf<MaybeTlsStream<TcpStream>>;
64type TcpReader = ReadHalf<MaybeTlsStream<TcpStream>>;
65pub type TcpMessageHandler = dyn Fn(&[u8]) + Send + Sync;
66
67/// Configuration for TCP socket connection.
68#[derive(Debug, Clone)]
69#[cfg_attr(
70    feature = "python",
71    pyo3::pyclass(module = "posei_trader.core.nautilus_pyo3.network")
72)]
73pub struct SocketConfig {
74    /// The URL to connect to.
75    pub url: String,
76    /// The connection mode {Plain, TLS}.
77    pub mode: Mode,
78    /// The sequence of bytes which separates lines.
79    pub suffix: Vec<u8>,
80    #[cfg(feature = "python")]
81    /// The optional Python function to handle incoming messages.
82    pub py_handler: Option<Arc<PyObject>>,
83    /// The optional heartbeat with period and beat message.
84    pub heartbeat: Option<(u64, Vec<u8>)>,
85    /// The timeout (milliseconds) for reconnection attempts.
86    pub reconnect_timeout_ms: Option<u64>,
87    /// The initial reconnection delay (milliseconds) for reconnects.
88    pub reconnect_delay_initial_ms: Option<u64>,
89    /// The maximum reconnect delay (milliseconds) for exponential backoff.
90    pub reconnect_delay_max_ms: Option<u64>,
91    /// The exponential backoff factor for reconnection delays.
92    pub reconnect_backoff_factor: Option<f64>,
93    /// The maximum jitter (milliseconds) added to reconnection delays.
94    pub reconnect_jitter_ms: Option<u64>,
95    /// The path to the certificates directory.
96    pub certs_dir: Option<String>,
97}
98
99/// Represents a command for the writer task.
100#[derive(Debug)]
101pub enum WriterCommand {
102    /// Update the writer reference with a new one after reconnection.
103    Update(TcpWriter),
104    /// Send data to the server.
105    Send(Bytes),
106}
107
108/// Creates a `TcpStream` with the server.
109///
110/// The stream can be encrypted with TLS or Plain. The stream is split into
111/// read and write ends:
112/// - The read end is passed to the task that keeps receiving
113///   messages from the server and passing them to a handler.
114/// - The write end is passed to a task which receives messages over a channel
115///   to send to the server.
116///
117/// The heartbeat is optional and can be configured with an interval and data to
118/// send.
119///
120/// The client uses a suffix to separate messages on the byte stream. It is
121/// appended to all sent messages and heartbeats. It is also used to split
122/// the received byte stream.
123#[cfg_attr(
124    feature = "python",
125    pyo3::pyclass(module = "posei_trader.core.nautilus_pyo3.network")
126)]
127struct SocketClientInner {
128    config: SocketConfig,
129    connector: Option<Connector>,
130    read_task: Arc<tokio::task::JoinHandle<()>>,
131    write_task: tokio::task::JoinHandle<()>,
132    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
133    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
134    connection_mode: Arc<AtomicU8>,
135    reconnect_timeout: Duration,
136    backoff: ExponentialBackoff,
137    handler: Option<Arc<TcpMessageHandler>>,
138}
139
140impl SocketClientInner {
141    /// Connect to a URL with the specified configuration.
142    ///
143    /// # Errors
144    ///
145    /// Returns an error if connection fails or configuration is invalid.
146    pub async fn connect_url(
147        config: SocketConfig,
148        handler: Option<Arc<TcpMessageHandler>>,
149    ) -> anyhow::Result<Self> {
150        install_cryptographic_provider();
151
152        let SocketConfig {
153            url,
154            mode,
155            heartbeat,
156            suffix,
157            #[cfg(feature = "python")]
158            py_handler,
159            reconnect_timeout_ms,
160            reconnect_delay_initial_ms,
161            reconnect_delay_max_ms,
162            reconnect_backoff_factor,
163            reconnect_jitter_ms,
164            certs_dir,
165        } = &config;
166        let connector = if let Some(dir) = certs_dir {
167            let config = create_tls_config_from_certs_dir(Path::new(dir))?;
168            Some(Connector::Rustls(Arc::new(config)))
169        } else {
170            None
171        };
172
173        let (reader, writer) = Self::tls_connect_with_server(url, *mode, connector.clone()).await?;
174        tracing::debug!("Connected");
175
176        let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
177
178        let read_task = Arc::new(Self::spawn_read_task(
179            connection_mode.clone(),
180            reader,
181            handler.clone(),
182            #[cfg(feature = "python")]
183            py_handler.clone(),
184            suffix.clone(),
185        ));
186
187        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
188
189        let write_task =
190            Self::spawn_write_task(connection_mode.clone(), writer, writer_rx, suffix.clone());
191
192        // Optionally spawn a heartbeat task to periodically ping server
193        let heartbeat_task = heartbeat.as_ref().map(|heartbeat| {
194            Self::spawn_heartbeat_task(
195                connection_mode.clone(),
196                heartbeat.clone(),
197                writer_tx.clone(),
198            )
199        });
200
201        let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
202        let backoff = ExponentialBackoff::new(
203            Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
204            Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
205            reconnect_backoff_factor.unwrap_or(1.5),
206            reconnect_jitter_ms.unwrap_or(100),
207            true, // immediate-first
208        )?;
209
210        Ok(Self {
211            config,
212            connector,
213            read_task,
214            write_task,
215            writer_tx,
216            heartbeat_task,
217            connection_mode,
218            reconnect_timeout,
219            backoff,
220            handler,
221        })
222    }
223
224    /// Establish a TLS or plain TCP connection with the server.
225    ///
226    /// # Errors
227    ///
228    /// Returns an error if the connection cannot be established.
229    pub async fn tls_connect_with_server(
230        url: &str,
231        mode: Mode,
232        connector: Option<Connector>,
233    ) -> Result<(TcpReader, TcpWriter), Error> {
234        tracing::debug!("Connecting to {url}");
235        let tcp_result = TcpStream::connect(url).await;
236
237        match tcp_result {
238            Ok(stream) => {
239                tracing::debug!("TCP connection established, proceeding with TLS");
240                let request = url.into_client_request()?;
241                tcp_tls(&request, mode, stream, connector)
242                    .await
243                    .map(tokio::io::split)
244            }
245            Err(e) => {
246                tracing::error!("TCP connection failed: {e:?}");
247                Err(Error::Io(e))
248            }
249        }
250    }
251
252    /// Reconnect with server.
253    ///
254    /// Makes a new connection with server, uses the new read and write halves
255    /// to update the reader and writer.
256    async fn reconnect(&mut self) -> Result<(), Error> {
257        tracing::debug!("Reconnecting");
258
259        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
260            tracing::debug!("Reconnect aborted due to disconnect state");
261            return Ok(());
262        }
263
264        tokio::time::timeout(self.reconnect_timeout, async {
265            let SocketConfig {
266                url,
267                mode,
268                heartbeat: _,
269                suffix,
270                #[cfg(feature = "python")]
271                py_handler,
272                reconnect_timeout_ms: _,
273                reconnect_delay_initial_ms: _,
274                reconnect_backoff_factor: _,
275                reconnect_delay_max_ms: _,
276                reconnect_jitter_ms: _,
277                certs_dir: _,
278            } = &self.config;
279            // Create a fresh connection
280            let connector = self.connector.clone();
281            // Attempt to connect; abort early if a disconnect was requested
282            let (reader, new_writer) = Self::tls_connect_with_server(url, *mode, connector).await?;
283
284            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
285                tracing::debug!("Reconnect aborted mid-flight (after connect)");
286                return Ok(());
287            }
288            tracing::debug!("Connected");
289
290            if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer)) {
291                tracing::error!("{e}");
292            }
293
294            // Delay before closing connection
295            tokio::time::sleep(Duration::from_millis(100)).await;
296
297            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
298                tracing::debug!("Reconnect aborted mid-flight (after delay)");
299                return Ok(());
300            }
301
302            if !self.read_task.is_finished() {
303                self.read_task.abort();
304                log_task_aborted("read");
305            }
306
307            // If a disconnect was requested during reconnect, do not proceed to reactivate
308            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
309                tracing::debug!("Reconnect aborted mid-flight (before spawn read)");
310                return Ok(());
311            }
312
313            // Mark as active only if not disconnecting
314            self.connection_mode
315                .store(ConnectionMode::Active.as_u8(), Ordering::SeqCst);
316
317            // Spawn new read task
318            self.read_task = Arc::new(Self::spawn_read_task(
319                self.connection_mode.clone(),
320                reader,
321                self.handler.clone(),
322                #[cfg(feature = "python")]
323                py_handler.clone(),
324                suffix.clone(),
325            ));
326
327            tracing::debug!("Reconnect succeeded");
328            Ok(())
329        })
330        .await
331        .map_err(|_| {
332            Error::Io(std::io::Error::new(
333                std::io::ErrorKind::TimedOut,
334                format!(
335                    "reconnection timed out after {}s",
336                    self.reconnect_timeout.as_secs_f64()
337                ),
338            ))
339        })?
340    }
341
342    /// Check if the client is still alive.
343    ///
344    /// The client is connected if the read task has not finished. It is expected
345    /// that in case of any failure client or server side. The read task will be
346    /// shutdown. There might be some delay between the connection being closed
347    /// and the client detecting it.
348    #[inline]
349    #[must_use]
350    pub fn is_alive(&self) -> bool {
351        !self.read_task.is_finished()
352    }
353
354    #[must_use]
355    fn spawn_read_task(
356        connection_state: Arc<AtomicU8>,
357        mut reader: TcpReader,
358        handler: Option<Arc<TcpMessageHandler>>,
359        #[cfg(feature = "python")] py_handler: Option<Arc<PyObject>>,
360        suffix: Vec<u8>,
361    ) -> tokio::task::JoinHandle<()> {
362        log_task_started("read");
363
364        // Interval between checking the connection mode
365        let check_interval = Duration::from_millis(10);
366
367        tokio::task::spawn(async move {
368            let mut buf = Vec::new();
369
370            loop {
371                if !ConnectionMode::from_atomic(&connection_state).is_active() {
372                    break;
373                }
374
375                match tokio::time::timeout(check_interval, reader.read_buf(&mut buf)).await {
376                    // Connection has been terminated or vector buffer is complete
377                    Ok(Ok(0)) => {
378                        tracing::debug!("Connection closed by server");
379                        break;
380                    }
381                    Ok(Err(e)) => {
382                        tracing::debug!("Connection ended: {e}");
383                        break;
384                    }
385                    // Received bytes of data
386                    Ok(Ok(bytes)) => {
387                        tracing::trace!("Received <binary> {bytes} bytes");
388
389                        if let Some(handler) = &handler {
390                            process_fix_buffer(&mut buf, handler);
391                        } else {
392                            while let Some((i, _)) = &buf
393                                .windows(suffix.len())
394                                .enumerate()
395                                .find(|(_, pair)| pair.eq(&suffix))
396                            {
397                                let mut data: Vec<u8> = buf.drain(0..i + suffix.len()).collect();
398                                data.truncate(data.len() - suffix.len());
399
400                                if let Some(handler) = &handler {
401                                    handler(&data);
402                                }
403
404                                #[cfg(feature = "python")]
405                                if let Some(py_handler) = &py_handler {
406                                    if let Err(e) = Python::with_gil(|py| {
407                                        py_handler.call1(py, (data.as_slice(),))
408                                    }) {
409                                        tracing::error!("Call to handler failed: {e}");
410                                        break;
411                                    }
412                                }
413                            }
414                        }
415                    }
416                    Err(_) => {
417                        // Timeout - continue loop and check connection mode
418                        continue;
419                    }
420                }
421            }
422
423            log_task_stopped("read");
424        })
425    }
426
427    fn spawn_write_task(
428        connection_state: Arc<AtomicU8>,
429        writer: TcpWriter,
430        mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
431        suffix: Vec<u8>,
432    ) -> tokio::task::JoinHandle<()> {
433        log_task_started("write");
434
435        // Interval between checking the connection mode
436        let check_interval = Duration::from_millis(10);
437
438        tokio::task::spawn(async move {
439            let mut active_writer = writer;
440
441            loop {
442                if matches!(
443                    ConnectionMode::from_atomic(&connection_state),
444                    ConnectionMode::Disconnect | ConnectionMode::Closed
445                ) {
446                    break;
447                }
448
449                match tokio::time::timeout(check_interval, writer_rx.recv()).await {
450                    Ok(Some(msg)) => {
451                        // Re-check connection mode after receiving a message
452                        let mode = ConnectionMode::from_atomic(&connection_state);
453                        if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
454                            break;
455                        }
456
457                        match msg {
458                            WriterCommand::Update(new_writer) => {
459                                tracing::debug!("Received new writer");
460
461                                // Delay before closing connection
462                                tokio::time::sleep(Duration::from_millis(100)).await;
463
464                                // Attempt to shutdown the writer gracefully before updating,
465                                // we ignore any error as the writer may already be closed.
466                                _ = active_writer.shutdown().await;
467
468                                active_writer = new_writer;
469                                tracing::debug!("Updated writer");
470                            }
471                            _ if mode.is_reconnect() => {
472                                tracing::warn!("Skipping message while reconnecting, {msg:?}");
473                                continue;
474                            }
475                            WriterCommand::Send(msg) => {
476                                if let Err(e) = active_writer.write_all(&msg).await {
477                                    tracing::error!("Failed to send message: {e}");
478                                    // Mode is active so trigger reconnection
479                                    tracing::warn!("Writer triggering reconnect");
480                                    connection_state
481                                        .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
482                                    continue;
483                                }
484                                if let Err(e) = active_writer.write_all(&suffix).await {
485                                    tracing::error!("Failed to send message: {e}");
486                                }
487                            }
488                        }
489                    }
490                    Ok(None) => {
491                        // Channel closed - writer task should terminate
492                        tracing::debug!("Writer channel closed, terminating writer task");
493                        break;
494                    }
495                    Err(_) => {
496                        // Timeout - just continue the loop
497                        continue;
498                    }
499                }
500            }
501
502            // Attempt to shutdown the writer gracefully before exiting,
503            // we ignore any error as the writer may already be closed.
504            _ = active_writer.shutdown().await;
505
506            log_task_stopped("write");
507        })
508    }
509
510    fn spawn_heartbeat_task(
511        connection_state: Arc<AtomicU8>,
512        heartbeat: (u64, Vec<u8>),
513        writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
514    ) -> tokio::task::JoinHandle<()> {
515        log_task_started("heartbeat");
516        let (interval_secs, message) = heartbeat;
517
518        tokio::task::spawn(async move {
519            let interval = Duration::from_secs(interval_secs);
520
521            loop {
522                tokio::time::sleep(interval).await;
523
524                match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
525                    ConnectionMode::Active => {
526                        let msg = WriterCommand::Send(message.clone().into());
527
528                        match writer_tx.send(msg) {
529                            Ok(()) => tracing::trace!("Sent heartbeat to writer task"),
530                            Err(e) => {
531                                tracing::error!("Failed to send heartbeat to writer task: {e}");
532                            }
533                        }
534                    }
535                    ConnectionMode::Reconnect => continue,
536                    ConnectionMode::Disconnect | ConnectionMode::Closed => break,
537                }
538            }
539
540            log_task_stopped("heartbeat");
541        })
542    }
543}
544
545impl Drop for SocketClientInner {
546    fn drop(&mut self) {
547        if !self.read_task.is_finished() {
548            self.read_task.abort();
549            log_task_aborted("read");
550        }
551
552        if !self.write_task.is_finished() {
553            self.write_task.abort();
554            log_task_aborted("write");
555        }
556
557        if let Some(ref handle) = self.heartbeat_task.take() {
558            if !handle.is_finished() {
559                handle.abort();
560                log_task_aborted("heartbeat");
561            }
562        }
563    }
564}
565
566#[cfg_attr(
567    feature = "python",
568    pyo3::pyclass(module = "posei_trader.core.nautilus_pyo3.network")
569)]
570pub struct SocketClient {
571    pub(crate) controller_task: tokio::task::JoinHandle<()>,
572    pub(crate) connection_mode: Arc<AtomicU8>,
573    pub writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
574}
575
576impl Debug for SocketClient {
577    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
578        f.debug_struct(stringify!(SocketClient)).finish()
579    }
580}
581
582impl SocketClient {
583    /// Connect to the server.
584    ///
585    /// # Errors
586    ///
587    /// Returns any error connecting to the server.
588    pub async fn connect(
589        config: SocketConfig,
590        handler: Option<Arc<TcpMessageHandler>>,
591        #[cfg(feature = "python")] post_connection: Option<PyObject>,
592        #[cfg(feature = "python")] post_reconnection: Option<PyObject>,
593        #[cfg(feature = "python")] post_disconnection: Option<PyObject>,
594    ) -> anyhow::Result<Self> {
595        let inner = SocketClientInner::connect_url(config, handler).await?;
596        let writer_tx = inner.writer_tx.clone();
597        let connection_mode = inner.connection_mode.clone();
598
599        let controller_task = Self::spawn_controller_task(
600            inner,
601            connection_mode.clone(),
602            #[cfg(feature = "python")]
603            post_reconnection,
604            #[cfg(feature = "python")]
605            post_disconnection,
606        );
607
608        #[cfg(feature = "python")]
609        if let Some(handler) = post_connection {
610            Python::with_gil(|py| match handler.call0(py) {
611                Ok(_) => tracing::debug!("Called `post_connection` handler"),
612                Err(e) => tracing::error!("Error calling `post_connection` handler: {e}"),
613            });
614        }
615
616        Ok(Self {
617            controller_task,
618            connection_mode,
619            writer_tx,
620        })
621    }
622
623    /// Returns the current connection mode.
624    #[must_use]
625    pub fn connection_mode(&self) -> ConnectionMode {
626        ConnectionMode::from_atomic(&self.connection_mode)
627    }
628
629    /// Check if the client connection is active.
630    ///
631    /// Returns `true` if the client is connected and has not been signalled to disconnect.
632    /// The client will automatically retry connection based on its configuration.
633    #[inline]
634    #[must_use]
635    pub fn is_active(&self) -> bool {
636        self.connection_mode().is_active()
637    }
638
639    /// Check if the client is reconnecting.
640    ///
641    /// Returns `true` if the client lost connection and is attempting to reestablish it.
642    /// The client will automatically retry connection based on its configuration.
643    #[inline]
644    #[must_use]
645    pub fn is_reconnecting(&self) -> bool {
646        self.connection_mode().is_reconnect()
647    }
648
649    /// Check if the client is disconnecting.
650    ///
651    /// Returns `true` if the client is in disconnect mode.
652    #[inline]
653    #[must_use]
654    pub fn is_disconnecting(&self) -> bool {
655        self.connection_mode().is_disconnect()
656    }
657
658    /// Check if the client is closed.
659    ///
660    /// Returns `true` if the client has been explicitly disconnected or reached
661    /// maximum reconnection attempts. In this state, the client cannot be reused
662    /// and a new client must be created for further connections.
663    #[inline]
664    #[must_use]
665    pub fn is_closed(&self) -> bool {
666        self.connection_mode().is_closed()
667    }
668
669    /// Close the client.
670    ///
671    /// Controller task will periodically check the disconnect mode
672    /// and shutdown the client if it is not alive.
673    pub async fn close(&self) {
674        self.connection_mode
675            .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
676
677        match tokio::time::timeout(Duration::from_secs(5), async {
678            while !self.is_closed() {
679                tokio::time::sleep(Duration::from_millis(10)).await;
680            }
681
682            if !self.controller_task.is_finished() {
683                self.controller_task.abort();
684                log_task_aborted("controller");
685            }
686        })
687        .await
688        {
689            Ok(()) => {
690                log_task_stopped("controller");
691            }
692            Err(_) => {
693                tracing::error!("Timeout waiting for controller task to finish");
694            }
695        }
696    }
697
698    /// Sends a message of the given `data`.
699    ///
700    /// # Errors
701    ///
702    /// Returns an error if sending fails.
703    pub async fn send_bytes(&self, data: Vec<u8>) -> Result<(), SendError> {
704        if self.is_closed() {
705            return Err(SendError::Closed);
706        }
707
708        let timeout = Duration::from_secs(2);
709        let check_interval = Duration::from_millis(1);
710
711        if !self.is_active() {
712            tracing::debug!("Waiting for client to become ACTIVE before sending...");
713
714            let inner = tokio::time::timeout(timeout, async {
715                loop {
716                    if self.is_active() {
717                        return Ok(());
718                    }
719                    if matches!(
720                        self.connection_mode(),
721                        ConnectionMode::Disconnect | ConnectionMode::Closed
722                    ) {
723                        return Err(());
724                    }
725                    tokio::time::sleep(check_interval).await;
726                }
727            })
728            .await
729            .map_err(|_| SendError::Timeout)?;
730            inner.map_err(|()| SendError::Closed)?;
731        }
732
733        let msg = WriterCommand::Send(data.into());
734        self.writer_tx
735            .send(msg)
736            .map_err(|e| SendError::BrokenPipe(e.to_string()))
737    }
738
739    fn spawn_controller_task(
740        mut inner: SocketClientInner,
741        connection_mode: Arc<AtomicU8>,
742        #[cfg(feature = "python")] post_reconnection: Option<PyObject>,
743        #[cfg(feature = "python")] post_disconnection: Option<PyObject>,
744    ) -> tokio::task::JoinHandle<()> {
745        tokio::task::spawn(async move {
746            log_task_started("controller");
747
748            let check_interval = Duration::from_millis(10);
749
750            loop {
751                tokio::time::sleep(check_interval).await;
752                let mode = ConnectionMode::from_atomic(&connection_mode);
753
754                if mode.is_disconnect() {
755                    tracing::debug!("Disconnecting");
756
757                    let timeout = Duration::from_secs(5);
758                    if tokio::time::timeout(timeout, async {
759                        // Delay awaiting graceful shutdown
760                        tokio::time::sleep(Duration::from_millis(100)).await;
761
762                        if !inner.read_task.is_finished() {
763                            inner.read_task.abort();
764                            log_task_aborted("read");
765                        }
766
767                        if let Some(task) = &inner.heartbeat_task {
768                            if !task.is_finished() {
769                                task.abort();
770                                log_task_aborted("heartbeat");
771                            }
772                        }
773                    })
774                    .await
775                    .is_err()
776                    {
777                        tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
778                    }
779
780                    tracing::debug!("Closed");
781
782                    #[cfg(feature = "python")]
783                    if let Some(ref handler) = post_disconnection {
784                        Python::with_gil(|py| match handler.call0(py) {
785                            Ok(_) => tracing::debug!("Called `post_disconnection` handler"),
786                            Err(e) => {
787                                tracing::error!("Error calling `post_disconnection` handler: {e}");
788                            }
789                        });
790                    }
791                    break; // Controller finished
792                }
793
794                if mode.is_reconnect() || (mode.is_active() && !inner.is_alive()) {
795                    match inner.reconnect().await {
796                        Ok(()) => {
797                            tracing::debug!("Reconnected successfully");
798                            inner.backoff.reset();
799                            // Only invoke Python reconnect handler if still active
800                            #[cfg(feature = "python")]
801                            {
802                                if ConnectionMode::from_atomic(&connection_mode).is_active() {
803                                    if let Some(ref handler) = post_reconnection {
804                                        Python::with_gil(|py| match handler.call0(py) {
805                                            Ok(_) => tracing::debug!(
806                                                "Called `post_reconnection` handler"
807                                            ),
808                                            Err(e) => tracing::error!(
809                                                "Error calling `post_reconnection` handler: {e}"
810                                            ),
811                                        });
812                                    }
813                                } else {
814                                    tracing::debug!(
815                                        "Skipping post_reconnection handlers due to disconnect state"
816                                    );
817                                }
818                            }
819                        }
820                        Err(e) => {
821                            let duration = inner.backoff.next_duration();
822                            tracing::warn!("Reconnect attempt failed: {e}");
823                            if !duration.is_zero() {
824                                tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
825                            }
826                            tokio::time::sleep(duration).await;
827                        }
828                    }
829                }
830            }
831            inner
832                .connection_mode
833                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
834
835            log_task_stopped("controller");
836        })
837    }
838}
839
840// Abort controller task on drop to clean up background tasks
841impl Drop for SocketClient {
842    fn drop(&mut self) {
843        if !self.controller_task.is_finished() {
844            self.controller_task.abort();
845            log_task_aborted("controller");
846        }
847    }
848}
849
850////////////////////////////////////////////////////////////////////////////////
851// Tests
852////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
#[cfg(feature = "python")]
#[cfg(target_os = "linux")] // Only run network tests on Linux (CI stability)
mod tests {
    use std::ffi::CString;

    use nautilus_common::testing::wait_until_async;
    use nautilus_core::python::IntoPyObjectPoseiExt;
    use pyo3::prepare_freethreaded_python;
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        net::{TcpListener, TcpStream},
        sync::Mutex,
        task,
        time::{Duration, sleep},
    };

    use super::*;

    /// Builds the bound Python method `counter.handler` used as the client's message handler.
    ///
    /// The handler decodes each message; it increments `count` on `ping` and sets
    /// `check` on `heartbeat message`.
    fn create_handler() -> PyObject {
        let code_raw = r"
class Counter:
    def __init__(self):
        self.count = 0
        self.check = False

    def handler(self, bytes):
        msg = bytes.decode()
        if msg == 'ping':
            self.count += 1
        elif msg == 'heartbeat message':
            self.check = True

    def get_check(self):
        return self.check

    def get_count(self):
        return self.count

counter = Counter()
";
        let code = CString::new(code_raw).unwrap();
        let filename = CString::new("test".to_string()).unwrap();
        let module = CString::new("test".to_string()).unwrap();
        Python::with_gil(|py| {
            let pymod = PyModule::from_code(py, &code, &filename, &module).unwrap();
            let counter = pymod.getattr("counter").unwrap().into_py_any_unwrap(py);

            counter
                .getattr(py, "handler")
                .unwrap()
                .into_py_any_unwrap(py)
        })
    }

    /// Binds a listener on an ephemeral localhost port and returns `(port, listener)`.
    async fn bind_test_server() -> (u16, TcpListener) {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("Failed to bind ephemeral port");
        let port = listener.local_addr().unwrap().port();
        (port, listener)
    }

    /// Echoes back every `\r\n`-terminated line received on `socket`.
    ///
    /// A `close` line shuts down the write side and returns. The loop also ends on EOF,
    /// on a read error, or when a write fails.
    async fn run_echo_server(mut socket: TcpStream) {
        let mut buf = Vec::new();
        loop {
            match socket.read_buf(&mut buf).await {
                Ok(0) => {
                    break;
                }
                Ok(_n) => {
                    while let Some(idx) = buf.windows(2).position(|w| w == b"\r\n") {
                        let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
                        // Remove trailing \r\n
                        line.truncate(line.len() - 2);

                        if line == b"close" {
                            let _ = socket.shutdown().await;
                            return;
                        }

                        let mut echo_data = line;
                        echo_data.extend_from_slice(b"\r\n");
                        if socket.write_all(&echo_data).await.is_err() {
                            break;
                        }
                    }
                }
                Err(e) => {
                    eprintln!("Server read error: {e}");
                    break;
                }
            }
        }
    }

    #[tokio::test]
    async fn test_basic_send_receive() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server().await;
        let server_task = task::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            run_echo_server(socket).await;
        });

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            py_handler: Some(Arc::new(create_handler())),
            heartbeat: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            certs_dir: None,
        };

        let client = SocketClient::connect(config, None, None, None, None)
            .await
            .expect("Client connect failed unexpectedly");

        client.send_bytes(b"Hello".into()).await.unwrap();
        client.send_bytes(b"World".into()).await.unwrap();

        // Wait a bit for the server to echo them back
        sleep(Duration::from_millis(100)).await;

        client.send_bytes(b"close".into()).await.unwrap();
        server_task.await.unwrap();
        assert!(!client.is_closed());
    }

    #[tokio::test]
    async fn test_reconnect_fail_exhausted() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server().await;
        drop(listener); // We drop it immediately -> no server is listening

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            py_handler: Some(Arc::new(create_handler())),
            heartbeat: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            certs_dir: None,
        };

        let client_res = SocketClient::connect(config, None, None, None, None).await;
        assert!(
            client_res.is_err(),
            "Should fail quickly with no server listening"
        );
    }

    #[tokio::test]
    async fn test_user_disconnect() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server().await;
        let server_task = task::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            let mut buf = [0u8; 1024];
            // Non-blocking, best-effort read; the result is irrelevant to this test
            let _ = socket.try_read(&mut buf);

            // Keep the connection open until the task is aborted
            loop {
                sleep(Duration::from_secs(1)).await;
            }
        });

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            py_handler: Some(Arc::new(create_handler())),
            heartbeat: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            certs_dir: None,
        };

        let client = SocketClient::connect(config, None, None, None, None)
            .await
            .unwrap();

        client.close().await;
        assert!(client.is_closed());
        server_task.abort();
    }

    #[tokio::test]
    async fn test_heartbeat() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server().await;
        let received = Arc::new(Mutex::new(Vec::new()));
        let received2 = received.clone();

        let server_task = task::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();

            // Await readiness instead of busy-polling `try_read_buf`, and stop on EOF or a
            // real I/O error rather than retrying it forever.
            let mut buf = Vec::new();
            loop {
                match socket.read_buf(&mut buf).await {
                    Ok(0) | Err(_) => break,
                    Ok(_) => {
                        while let Some(idx) = buf.windows(2).position(|w| w == b"\r\n") {
                            let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
                            line.truncate(line.len() - 2);
                            received2.lock().await.push(line);
                        }
                    }
                }
            }
        });

        // Heartbeat every 1 second
        let heartbeat = Some((1, b"ping".to_vec()));

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            py_handler: Some(Arc::new(create_handler())),
            heartbeat,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            certs_dir: None,
        };

        let client = SocketClient::connect(config, None, None, None, None)
            .await
            .unwrap();

        // Wait ~3 seconds to collect some heartbeats
        sleep(Duration::from_secs(3)).await;

        {
            let lock = received.lock().await;
            let pings = lock
                .iter()
                .filter(|line| line.as_slice() == b"ping")
                .count();
            assert!(
                pings >= 2,
                "Expected at least 2 heartbeat pings; got {pings}"
            );
        }

        client.close().await;
        server_task.abort();
    }

    #[tokio::test]
    async fn test_python_handler_error() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server().await;
        let server_task = task::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            run_echo_server(socket).await;
        });

        let code_raw = r#"
def handler(bytes_data):
    txt = bytes_data.decode()
    if "ERR" in txt:
        raise ValueError("Simulated error in handler")
    return
"#;
        let code = CString::new(code_raw).unwrap();
        let filename = CString::new("test".to_string()).unwrap();
        let module = CString::new("test".to_string()).unwrap();

        let py_handler = Some(Python::with_gil(|py| {
            let pymod = PyModule::from_code(py, &code, &filename, &module).unwrap();
            let func = pymod.getattr("handler").unwrap();
            Arc::new(func.into_py_any_unwrap(py))
        }));

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            py_handler,
            heartbeat: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            certs_dir: None,
        };

        let client = SocketClient::connect(config, None, None, None, None)
            .await
            .expect("Client connect failed unexpectedly");

        client.send_bytes(b"hello".into()).await.unwrap();
        sleep(Duration::from_millis(100)).await;

        client.send_bytes(b"ERR".into()).await.unwrap();
        sleep(Duration::from_secs(1)).await;

        // A raising Python handler must not take the client down
        assert!(client.is_active());

        client.close().await;

        assert!(client.is_closed());
        server_task.abort();
    }

    #[tokio::test]
    async fn test_reconnect_success() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server().await;

        // Spawn a server task that:
        // 1. Accepts the first connection and then drops it after a short delay (simulate disconnect)
        // 2. Waits a bit and then accepts a new connection and runs the echo server
        let server_task = task::spawn(async move {
            // Accept first connection
            let (mut socket, _) = listener.accept().await.expect("First accept failed");

            // Wait briefly and then force-close the connection
            sleep(Duration::from_millis(500)).await;
            let _ = socket.shutdown().await;

            // Wait for the client's reconnect attempt
            sleep(Duration::from_millis(500)).await;

            // Run the echo server on the new connection
            let (socket, _) = listener.accept().await.expect("Second accept failed");
            run_echo_server(socket).await;
        });

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            py_handler: Some(Arc::new(create_handler())),
            heartbeat: None,
            reconnect_timeout_ms: Some(5_000),
            reconnect_delay_initial_ms: Some(500),
            reconnect_delay_max_ms: Some(5_000),
            reconnect_backoff_factor: Some(2.0),
            reconnect_jitter_ms: Some(50),
            certs_dir: None,
        };

        let client = SocketClient::connect(config, None, None, None, None)
            .await
            .expect("Client connect failed unexpectedly");

        // Initially, the client should be active
        assert!(client.is_active(), "Client should start as active");

        // Outlast the server's forced disconnect (~500ms after accept) so the client has to
        // re-establish the connection; polling `is_active()` alone would succeed immediately
        // on the first connection and never exercise the reconnect path.
        sleep(Duration::from_millis(1_000)).await;

        // Wait until the reconnected client reports active again
        wait_until_async(|| async { client.is_active() }, Duration::from_secs(10)).await;

        client
            .send_bytes(b"TestReconnect".into())
            .await
            .expect("Send failed");

        client.close().await;
        server_task.abort();
    }
}
1241
#[cfg(test)]
mod rust_tests {
    use tokio::{
        io::AsyncWriteExt,
        net::TcpListener,
        task,
        time::{Duration, sleep},
    };

    use super::*;

    /// Verifies that a client whose server drops the connection can still be closed cleanly
    /// and ends up in the closed state.
    #[tokio::test]
    async fn test_reconnect_then_close() {
        // Bind an ephemeral port
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();

        // Server task: accept one connection and then drop it
        let server = task::spawn(async move {
            if let Ok((mut sock, _)) = listener.accept().await {
                // `shutdown` returns a future; without `.await` it never runs
                let _ = sock.shutdown().await;
            }
            // Keep listener alive briefly to avoid premature exit
            sleep(Duration::from_secs(1)).await;
        });

        // Configure client with a short reconnect backoff
        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            #[cfg(feature = "python")]
            py_handler: None,
            heartbeat: None,
            reconnect_timeout_ms: Some(1_000),
            reconnect_delay_initial_ms: Some(50),
            reconnect_delay_max_ms: Some(100),
            reconnect_backoff_factor: Some(1.0),
            reconnect_jitter_ms: Some(0),
            certs_dir: None,
        };

        // Connect client (handler=None); `config` is not used afterwards, so move it
        let client = {
            #[cfg(feature = "python")]
            {
                SocketClient::connect(config, None, None, None, None)
                    .await
                    .unwrap()
            }
            #[cfg(not(feature = "python"))]
            {
                SocketClient::connect(config, None).await.unwrap()
            }
        };

        // Allow server to drop connection and client to notice
        sleep(Duration::from_millis(100)).await;

        // Now close the client
        client.close().await;
        assert!(client.is_closed());
        server.abort();
    }
}