1use std::{
32 fmt::Debug,
33 path::Path,
34 sync::{
35 Arc,
36 atomic::{AtomicU8, Ordering},
37 },
38 time::Duration,
39};
40
41use bytes::Bytes;
42use nautilus_cryptography::providers::install_cryptographic_provider;
43#[cfg(feature = "python")]
44use pyo3::prelude::*;
45use tokio::{
46 io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
47 net::TcpStream,
48};
49use tokio_tungstenite::{
50 MaybeTlsStream,
51 tungstenite::{Error, client::IntoClientRequest, stream::Mode},
52};
53
54use crate::{
55 backoff::ExponentialBackoff,
56 error::SendError,
57 fix::process_fix_buffer,
58 logging::{log_task_aborted, log_task_started, log_task_stopped},
59 mode::ConnectionMode,
60 tls::{Connector, create_tls_config_from_certs_dir, tcp_tls},
61};
62
63type TcpWriter = WriteHalf<MaybeTlsStream<TcpStream>>;
64type TcpReader = ReadHalf<MaybeTlsStream<TcpStream>>;
65pub type TcpMessageHandler = dyn Fn(&[u8]) + Send + Sync;
66
67#[derive(Debug, Clone)]
69#[cfg_attr(
70 feature = "python",
71 pyo3::pyclass(module = "posei_trader.core.nautilus_pyo3.network")
72)]
73pub struct SocketConfig {
74 pub url: String,
76 pub mode: Mode,
78 pub suffix: Vec<u8>,
80 #[cfg(feature = "python")]
81 pub py_handler: Option<Arc<PyObject>>,
83 pub heartbeat: Option<(u64, Vec<u8>)>,
85 pub reconnect_timeout_ms: Option<u64>,
87 pub reconnect_delay_initial_ms: Option<u64>,
89 pub reconnect_delay_max_ms: Option<u64>,
91 pub reconnect_backoff_factor: Option<f64>,
93 pub reconnect_jitter_ms: Option<u64>,
95 pub certs_dir: Option<String>,
97}
98
99#[derive(Debug)]
101pub enum WriterCommand {
102 Update(TcpWriter),
104 Send(Bytes),
106}
107
108#[cfg_attr(
124 feature = "python",
125 pyo3::pyclass(module = "posei_trader.core.nautilus_pyo3.network")
126)]
127struct SocketClientInner {
128 config: SocketConfig,
129 connector: Option<Connector>,
130 read_task: Arc<tokio::task::JoinHandle<()>>,
131 write_task: tokio::task::JoinHandle<()>,
132 writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
133 heartbeat_task: Option<tokio::task::JoinHandle<()>>,
134 connection_mode: Arc<AtomicU8>,
135 reconnect_timeout: Duration,
136 backoff: ExponentialBackoff,
137 handler: Option<Arc<TcpMessageHandler>>,
138}
139
140impl SocketClientInner {
141 pub async fn connect_url(
147 config: SocketConfig,
148 handler: Option<Arc<TcpMessageHandler>>,
149 ) -> anyhow::Result<Self> {
150 install_cryptographic_provider();
151
152 let SocketConfig {
153 url,
154 mode,
155 heartbeat,
156 suffix,
157 #[cfg(feature = "python")]
158 py_handler,
159 reconnect_timeout_ms,
160 reconnect_delay_initial_ms,
161 reconnect_delay_max_ms,
162 reconnect_backoff_factor,
163 reconnect_jitter_ms,
164 certs_dir,
165 } = &config;
166 let connector = if let Some(dir) = certs_dir {
167 let config = create_tls_config_from_certs_dir(Path::new(dir))?;
168 Some(Connector::Rustls(Arc::new(config)))
169 } else {
170 None
171 };
172
173 let (reader, writer) = Self::tls_connect_with_server(url, *mode, connector.clone()).await?;
174 tracing::debug!("Connected");
175
176 let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
177
178 let read_task = Arc::new(Self::spawn_read_task(
179 connection_mode.clone(),
180 reader,
181 handler.clone(),
182 #[cfg(feature = "python")]
183 py_handler.clone(),
184 suffix.clone(),
185 ));
186
187 let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
188
189 let write_task =
190 Self::spawn_write_task(connection_mode.clone(), writer, writer_rx, suffix.clone());
191
192 let heartbeat_task = heartbeat.as_ref().map(|heartbeat| {
194 Self::spawn_heartbeat_task(
195 connection_mode.clone(),
196 heartbeat.clone(),
197 writer_tx.clone(),
198 )
199 });
200
201 let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
202 let backoff = ExponentialBackoff::new(
203 Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
204 Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
205 reconnect_backoff_factor.unwrap_or(1.5),
206 reconnect_jitter_ms.unwrap_or(100),
207 true, )?;
209
210 Ok(Self {
211 config,
212 connector,
213 read_task,
214 write_task,
215 writer_tx,
216 heartbeat_task,
217 connection_mode,
218 reconnect_timeout,
219 backoff,
220 handler,
221 })
222 }
223
224 pub async fn tls_connect_with_server(
230 url: &str,
231 mode: Mode,
232 connector: Option<Connector>,
233 ) -> Result<(TcpReader, TcpWriter), Error> {
234 tracing::debug!("Connecting to {url}");
235 let tcp_result = TcpStream::connect(url).await;
236
237 match tcp_result {
238 Ok(stream) => {
239 tracing::debug!("TCP connection established, proceeding with TLS");
240 let request = url.into_client_request()?;
241 tcp_tls(&request, mode, stream, connector)
242 .await
243 .map(tokio::io::split)
244 }
245 Err(e) => {
246 tracing::error!("TCP connection failed: {e:?}");
247 Err(Error::Io(e))
248 }
249 }
250 }
251
252 async fn reconnect(&mut self) -> Result<(), Error> {
257 tracing::debug!("Reconnecting");
258
259 if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
260 tracing::debug!("Reconnect aborted due to disconnect state");
261 return Ok(());
262 }
263
264 tokio::time::timeout(self.reconnect_timeout, async {
265 let SocketConfig {
266 url,
267 mode,
268 heartbeat: _,
269 suffix,
270 #[cfg(feature = "python")]
271 py_handler,
272 reconnect_timeout_ms: _,
273 reconnect_delay_initial_ms: _,
274 reconnect_backoff_factor: _,
275 reconnect_delay_max_ms: _,
276 reconnect_jitter_ms: _,
277 certs_dir: _,
278 } = &self.config;
279 let connector = self.connector.clone();
281 let (reader, new_writer) = Self::tls_connect_with_server(url, *mode, connector).await?;
283
284 if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
285 tracing::debug!("Reconnect aborted mid-flight (after connect)");
286 return Ok(());
287 }
288 tracing::debug!("Connected");
289
290 if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer)) {
291 tracing::error!("{e}");
292 }
293
294 tokio::time::sleep(Duration::from_millis(100)).await;
296
297 if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
298 tracing::debug!("Reconnect aborted mid-flight (after delay)");
299 return Ok(());
300 }
301
302 if !self.read_task.is_finished() {
303 self.read_task.abort();
304 log_task_aborted("read");
305 }
306
307 if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
309 tracing::debug!("Reconnect aborted mid-flight (before spawn read)");
310 return Ok(());
311 }
312
313 self.connection_mode
315 .store(ConnectionMode::Active.as_u8(), Ordering::SeqCst);
316
317 self.read_task = Arc::new(Self::spawn_read_task(
319 self.connection_mode.clone(),
320 reader,
321 self.handler.clone(),
322 #[cfg(feature = "python")]
323 py_handler.clone(),
324 suffix.clone(),
325 ));
326
327 tracing::debug!("Reconnect succeeded");
328 Ok(())
329 })
330 .await
331 .map_err(|_| {
332 Error::Io(std::io::Error::new(
333 std::io::ErrorKind::TimedOut,
334 format!(
335 "reconnection timed out after {}s",
336 self.reconnect_timeout.as_secs_f64()
337 ),
338 ))
339 })?
340 }
341
342 #[inline]
349 #[must_use]
350 pub fn is_alive(&self) -> bool {
351 !self.read_task.is_finished()
352 }
353
354 #[must_use]
355 fn spawn_read_task(
356 connection_state: Arc<AtomicU8>,
357 mut reader: TcpReader,
358 handler: Option<Arc<TcpMessageHandler>>,
359 #[cfg(feature = "python")] py_handler: Option<Arc<PyObject>>,
360 suffix: Vec<u8>,
361 ) -> tokio::task::JoinHandle<()> {
362 log_task_started("read");
363
364 let check_interval = Duration::from_millis(10);
366
367 tokio::task::spawn(async move {
368 let mut buf = Vec::new();
369
370 loop {
371 if !ConnectionMode::from_atomic(&connection_state).is_active() {
372 break;
373 }
374
375 match tokio::time::timeout(check_interval, reader.read_buf(&mut buf)).await {
376 Ok(Ok(0)) => {
378 tracing::debug!("Connection closed by server");
379 break;
380 }
381 Ok(Err(e)) => {
382 tracing::debug!("Connection ended: {e}");
383 break;
384 }
385 Ok(Ok(bytes)) => {
387 tracing::trace!("Received <binary> {bytes} bytes");
388
389 if let Some(handler) = &handler {
390 process_fix_buffer(&mut buf, handler);
391 } else {
392 while let Some((i, _)) = &buf
393 .windows(suffix.len())
394 .enumerate()
395 .find(|(_, pair)| pair.eq(&suffix))
396 {
397 let mut data: Vec<u8> = buf.drain(0..i + suffix.len()).collect();
398 data.truncate(data.len() - suffix.len());
399
400 if let Some(handler) = &handler {
401 handler(&data);
402 }
403
404 #[cfg(feature = "python")]
405 if let Some(py_handler) = &py_handler {
406 if let Err(e) = Python::with_gil(|py| {
407 py_handler.call1(py, (data.as_slice(),))
408 }) {
409 tracing::error!("Call to handler failed: {e}");
410 break;
411 }
412 }
413 }
414 }
415 }
416 Err(_) => {
417 continue;
419 }
420 }
421 }
422
423 log_task_stopped("read");
424 })
425 }
426
427 fn spawn_write_task(
428 connection_state: Arc<AtomicU8>,
429 writer: TcpWriter,
430 mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
431 suffix: Vec<u8>,
432 ) -> tokio::task::JoinHandle<()> {
433 log_task_started("write");
434
435 let check_interval = Duration::from_millis(10);
437
438 tokio::task::spawn(async move {
439 let mut active_writer = writer;
440
441 loop {
442 if matches!(
443 ConnectionMode::from_atomic(&connection_state),
444 ConnectionMode::Disconnect | ConnectionMode::Closed
445 ) {
446 break;
447 }
448
449 match tokio::time::timeout(check_interval, writer_rx.recv()).await {
450 Ok(Some(msg)) => {
451 let mode = ConnectionMode::from_atomic(&connection_state);
453 if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
454 break;
455 }
456
457 match msg {
458 WriterCommand::Update(new_writer) => {
459 tracing::debug!("Received new writer");
460
461 tokio::time::sleep(Duration::from_millis(100)).await;
463
464 _ = active_writer.shutdown().await;
467
468 active_writer = new_writer;
469 tracing::debug!("Updated writer");
470 }
471 _ if mode.is_reconnect() => {
472 tracing::warn!("Skipping message while reconnecting, {msg:?}");
473 continue;
474 }
475 WriterCommand::Send(msg) => {
476 if let Err(e) = active_writer.write_all(&msg).await {
477 tracing::error!("Failed to send message: {e}");
478 tracing::warn!("Writer triggering reconnect");
480 connection_state
481 .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
482 continue;
483 }
484 if let Err(e) = active_writer.write_all(&suffix).await {
485 tracing::error!("Failed to send message: {e}");
486 }
487 }
488 }
489 }
490 Ok(None) => {
491 tracing::debug!("Writer channel closed, terminating writer task");
493 break;
494 }
495 Err(_) => {
496 continue;
498 }
499 }
500 }
501
502 _ = active_writer.shutdown().await;
505
506 log_task_stopped("write");
507 })
508 }
509
510 fn spawn_heartbeat_task(
511 connection_state: Arc<AtomicU8>,
512 heartbeat: (u64, Vec<u8>),
513 writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
514 ) -> tokio::task::JoinHandle<()> {
515 log_task_started("heartbeat");
516 let (interval_secs, message) = heartbeat;
517
518 tokio::task::spawn(async move {
519 let interval = Duration::from_secs(interval_secs);
520
521 loop {
522 tokio::time::sleep(interval).await;
523
524 match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
525 ConnectionMode::Active => {
526 let msg = WriterCommand::Send(message.clone().into());
527
528 match writer_tx.send(msg) {
529 Ok(()) => tracing::trace!("Sent heartbeat to writer task"),
530 Err(e) => {
531 tracing::error!("Failed to send heartbeat to writer task: {e}");
532 }
533 }
534 }
535 ConnectionMode::Reconnect => continue,
536 ConnectionMode::Disconnect | ConnectionMode::Closed => break,
537 }
538 }
539
540 log_task_stopped("heartbeat");
541 })
542 }
543}
544
545impl Drop for SocketClientInner {
546 fn drop(&mut self) {
547 if !self.read_task.is_finished() {
548 self.read_task.abort();
549 log_task_aborted("read");
550 }
551
552 if !self.write_task.is_finished() {
553 self.write_task.abort();
554 log_task_aborted("write");
555 }
556
557 if let Some(ref handle) = self.heartbeat_task.take() {
558 if !handle.is_finished() {
559 handle.abort();
560 log_task_aborted("heartbeat");
561 }
562 }
563 }
564}
565
566#[cfg_attr(
567 feature = "python",
568 pyo3::pyclass(module = "posei_trader.core.nautilus_pyo3.network")
569)]
570pub struct SocketClient {
571 pub(crate) controller_task: tokio::task::JoinHandle<()>,
572 pub(crate) connection_mode: Arc<AtomicU8>,
573 pub writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
574}
575
576impl Debug for SocketClient {
577 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
578 f.debug_struct(stringify!(SocketClient)).finish()
579 }
580}
581
582impl SocketClient {
583 pub async fn connect(
589 config: SocketConfig,
590 handler: Option<Arc<TcpMessageHandler>>,
591 #[cfg(feature = "python")] post_connection: Option<PyObject>,
592 #[cfg(feature = "python")] post_reconnection: Option<PyObject>,
593 #[cfg(feature = "python")] post_disconnection: Option<PyObject>,
594 ) -> anyhow::Result<Self> {
595 let inner = SocketClientInner::connect_url(config, handler).await?;
596 let writer_tx = inner.writer_tx.clone();
597 let connection_mode = inner.connection_mode.clone();
598
599 let controller_task = Self::spawn_controller_task(
600 inner,
601 connection_mode.clone(),
602 #[cfg(feature = "python")]
603 post_reconnection,
604 #[cfg(feature = "python")]
605 post_disconnection,
606 );
607
608 #[cfg(feature = "python")]
609 if let Some(handler) = post_connection {
610 Python::with_gil(|py| match handler.call0(py) {
611 Ok(_) => tracing::debug!("Called `post_connection` handler"),
612 Err(e) => tracing::error!("Error calling `post_connection` handler: {e}"),
613 });
614 }
615
616 Ok(Self {
617 controller_task,
618 connection_mode,
619 writer_tx,
620 })
621 }
622
623 #[must_use]
625 pub fn connection_mode(&self) -> ConnectionMode {
626 ConnectionMode::from_atomic(&self.connection_mode)
627 }
628
629 #[inline]
634 #[must_use]
635 pub fn is_active(&self) -> bool {
636 self.connection_mode().is_active()
637 }
638
639 #[inline]
644 #[must_use]
645 pub fn is_reconnecting(&self) -> bool {
646 self.connection_mode().is_reconnect()
647 }
648
649 #[inline]
653 #[must_use]
654 pub fn is_disconnecting(&self) -> bool {
655 self.connection_mode().is_disconnect()
656 }
657
658 #[inline]
664 #[must_use]
665 pub fn is_closed(&self) -> bool {
666 self.connection_mode().is_closed()
667 }
668
669 pub async fn close(&self) {
674 self.connection_mode
675 .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
676
677 match tokio::time::timeout(Duration::from_secs(5), async {
678 while !self.is_closed() {
679 tokio::time::sleep(Duration::from_millis(10)).await;
680 }
681
682 if !self.controller_task.is_finished() {
683 self.controller_task.abort();
684 log_task_aborted("controller");
685 }
686 })
687 .await
688 {
689 Ok(()) => {
690 log_task_stopped("controller");
691 }
692 Err(_) => {
693 tracing::error!("Timeout waiting for controller task to finish");
694 }
695 }
696 }
697
698 pub async fn send_bytes(&self, data: Vec<u8>) -> Result<(), SendError> {
704 if self.is_closed() {
705 return Err(SendError::Closed);
706 }
707
708 let timeout = Duration::from_secs(2);
709 let check_interval = Duration::from_millis(1);
710
711 if !self.is_active() {
712 tracing::debug!("Waiting for client to become ACTIVE before sending...");
713
714 let inner = tokio::time::timeout(timeout, async {
715 loop {
716 if self.is_active() {
717 return Ok(());
718 }
719 if matches!(
720 self.connection_mode(),
721 ConnectionMode::Disconnect | ConnectionMode::Closed
722 ) {
723 return Err(());
724 }
725 tokio::time::sleep(check_interval).await;
726 }
727 })
728 .await
729 .map_err(|_| SendError::Timeout)?;
730 inner.map_err(|()| SendError::Closed)?;
731 }
732
733 let msg = WriterCommand::Send(data.into());
734 self.writer_tx
735 .send(msg)
736 .map_err(|e| SendError::BrokenPipe(e.to_string()))
737 }
738
739 fn spawn_controller_task(
740 mut inner: SocketClientInner,
741 connection_mode: Arc<AtomicU8>,
742 #[cfg(feature = "python")] post_reconnection: Option<PyObject>,
743 #[cfg(feature = "python")] post_disconnection: Option<PyObject>,
744 ) -> tokio::task::JoinHandle<()> {
745 tokio::task::spawn(async move {
746 log_task_started("controller");
747
748 let check_interval = Duration::from_millis(10);
749
750 loop {
751 tokio::time::sleep(check_interval).await;
752 let mode = ConnectionMode::from_atomic(&connection_mode);
753
754 if mode.is_disconnect() {
755 tracing::debug!("Disconnecting");
756
757 let timeout = Duration::from_secs(5);
758 if tokio::time::timeout(timeout, async {
759 tokio::time::sleep(Duration::from_millis(100)).await;
761
762 if !inner.read_task.is_finished() {
763 inner.read_task.abort();
764 log_task_aborted("read");
765 }
766
767 if let Some(task) = &inner.heartbeat_task {
768 if !task.is_finished() {
769 task.abort();
770 log_task_aborted("heartbeat");
771 }
772 }
773 })
774 .await
775 .is_err()
776 {
777 tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
778 }
779
780 tracing::debug!("Closed");
781
782 #[cfg(feature = "python")]
783 if let Some(ref handler) = post_disconnection {
784 Python::with_gil(|py| match handler.call0(py) {
785 Ok(_) => tracing::debug!("Called `post_disconnection` handler"),
786 Err(e) => {
787 tracing::error!("Error calling `post_disconnection` handler: {e}");
788 }
789 });
790 }
791 break; }
793
794 if mode.is_reconnect() || (mode.is_active() && !inner.is_alive()) {
795 match inner.reconnect().await {
796 Ok(()) => {
797 tracing::debug!("Reconnected successfully");
798 inner.backoff.reset();
799 #[cfg(feature = "python")]
801 {
802 if ConnectionMode::from_atomic(&connection_mode).is_active() {
803 if let Some(ref handler) = post_reconnection {
804 Python::with_gil(|py| match handler.call0(py) {
805 Ok(_) => tracing::debug!(
806 "Called `post_reconnection` handler"
807 ),
808 Err(e) => tracing::error!(
809 "Error calling `post_reconnection` handler: {e}"
810 ),
811 });
812 }
813 } else {
814 tracing::debug!(
815 "Skipping post_reconnection handlers due to disconnect state"
816 );
817 }
818 }
819 }
820 Err(e) => {
821 let duration = inner.backoff.next_duration();
822 tracing::warn!("Reconnect attempt failed: {e}");
823 if !duration.is_zero() {
824 tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
825 }
826 tokio::time::sleep(duration).await;
827 }
828 }
829 }
830 }
831 inner
832 .connection_mode
833 .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
834
835 log_task_stopped("controller");
836 })
837 }
838}
839
840impl Drop for SocketClient {
842 fn drop(&mut self) {
843 if !self.controller_task.is_finished() {
844 self.controller_task.abort();
845 log_task_aborted("controller");
846 }
847 }
848}
849
#[cfg(test)]
#[cfg(feature = "python")]
#[cfg(target_os = "linux")]
mod tests {
    use std::ffi::CString;

    use nautilus_common::testing::wait_until_async;
    use nautilus_core::python::IntoPyObjectPoseiExt;
    use pyo3::prepare_freethreaded_python;
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        net::{TcpListener, TcpStream},
        sync::Mutex,
        task,
        time::{Duration, sleep},
    };

    use super::*;

    /// Frame delimiter shared by the test clients and the test servers.
    const CRLF: &[u8] = b"\r\n";

    /// Python source for a stateful handler that counts `ping` frames.
    const COUNTER_SOURCE: &str = r"
class Counter:
    def __init__(self):
        self.count = 0
        self.check = False

    def handler(self, bytes):
        msg = bytes.decode()
        if msg == 'ping':
            self.count += 1
        elif msg == 'heartbeat message':
            self.check = True

    def get_check(self):
        return self.check

    def get_count(self):
        return self.count

counter = Counter()
";

    /// Python source for a handler that raises on any frame containing `ERR`.
    const FAILING_HANDLER_SOURCE: &str = r#"
def handler(bytes_data):
    txt = bytes_data.decode()
    if "ERR" in txt:
        raise ValueError("Simulated error in handler")
    return
"#;

    /// Executes `source` as a Python module named `test` and returns its attribute `name`.
    fn load_python_attr(source: &str, name: &str) -> PyObject {
        let code = CString::new(source).unwrap();
        let filename = CString::new("test").unwrap();
        let module = CString::new("test").unwrap();
        Python::with_gil(|py| {
            let pymod = PyModule::from_code(py, &code, &filename, &module).unwrap();
            pymod.getattr(name).unwrap().into_py_any_unwrap(py)
        })
    }

    /// Returns the bound `handler` method of a fresh `Counter` instance.
    fn create_handler() -> PyObject {
        let counter = load_python_attr(COUNTER_SOURCE, "counter");
        Python::with_gil(|py| {
            counter
                .getattr(py, "handler")
                .unwrap()
                .into_py_any_unwrap(py)
        })
    }

    /// Plain (non-TLS), `\r\n`-framed config for a local test server with default reconnect
    /// settings and no heartbeat.
    fn local_config(port: u16, py_handler: Option<Arc<PyObject>>) -> SocketConfig {
        SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: CRLF.to_vec(),
            py_handler,
            heartbeat: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            certs_dir: None,
        }
    }

    /// Binds a listener on an ephemeral localhost port.
    async fn bind_test_server() -> (u16, TcpListener) {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("Failed to bind ephemeral port");
        let port = listener.local_addr().unwrap().port();
        (port, listener)
    }

    /// Echoes every `\r\n`-terminated line back to the peer; a `close` line shuts the
    /// connection down and ends the server.
    async fn run_echo_server(mut socket: TcpStream) {
        let mut buf = Vec::new();
        loop {
            let read = match socket.read_buf(&mut buf).await {
                Ok(n) => n,
                Err(e) => {
                    eprintln!("Server read error: {e}");
                    break;
                }
            };
            if read == 0 {
                break;
            }

            while let Some(idx) = buf.windows(2).position(|w| w == CRLF) {
                let mut line: Vec<u8> = buf.drain(..idx + 2).collect();
                line.truncate(idx);

                if line == b"close" {
                    let _ = socket.shutdown().await;
                    return;
                }

                line.extend_from_slice(CRLF);
                if socket.write_all(&line).await.is_err() {
                    break;
                }
            }
        }
    }

    #[tokio::test]
    async fn test_basic_send_receive() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server().await;
        let server_task = task::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            run_echo_server(socket).await;
        });

        let config = local_config(port, Some(Arc::new(create_handler())));
        let client = SocketClient::connect(config, None, None, None, None)
            .await
            .expect("Client connect failed unexpectedly");

        for payload in [b"Hello", b"World"] {
            client.send_bytes(payload.to_vec()).await.unwrap();
        }

        // Give the echoes time to come back before asking the server to close.
        sleep(Duration::from_millis(100)).await;

        client.send_bytes(b"close".to_vec()).await.unwrap();
        server_task.await.unwrap();
        assert!(!client.is_closed());
    }

    #[tokio::test]
    async fn test_reconnect_fail_exhausted() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server().await;
        // Release the port so nothing is listening on it.
        drop(listener);

        let config = local_config(port, Some(Arc::new(create_handler())));
        let client_res = SocketClient::connect(config, None, None, None, None).await;
        assert!(
            client_res.is_err(),
            "Should fail quickly with no server listening"
        );
    }

    #[tokio::test]
    async fn test_user_disconnect() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server().await;
        let server_task = task::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            let mut buf = [0u8; 1024];
            let _ = socket.try_read(&mut buf);

            loop {
                sleep(Duration::from_secs(1)).await;
            }
        });

        let config = local_config(port, Some(Arc::new(create_handler())));
        let client = SocketClient::connect(config, None, None, None, None)
            .await
            .unwrap();

        client.close().await;
        assert!(client.is_closed());
        server_task.abort();
    }

    #[tokio::test]
    async fn test_heartbeat() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server().await;
        let received = Arc::new(Mutex::new(Vec::new()));
        let server_received = received.clone();

        // Records every `\r\n`-terminated line the client sends.
        let server_task = task::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();

            let mut buf = Vec::new();
            loop {
                match socket.try_read_buf(&mut buf) {
                    Ok(0) => break,
                    Ok(_) => {
                        while let Some(idx) = buf.windows(2).position(|w| w == CRLF) {
                            let mut line: Vec<u8> = buf.drain(..idx + 2).collect();
                            line.truncate(idx);
                            server_received.lock().await.push(line);
                        }
                    }
                    Err(_) => {
                        tokio::time::sleep(Duration::from_millis(10)).await;
                    }
                }
            }
        });

        let config = SocketConfig {
            heartbeat: Some((1, b"ping".to_vec())),
            ..local_config(port, Some(Arc::new(create_handler())))
        };
        let client = SocketClient::connect(config, None, None, None, None)
            .await
            .unwrap();

        // With a 1s interval, at least two heartbeats should arrive within 3s.
        sleep(Duration::from_secs(3)).await;

        let pings = received
            .lock()
            .await
            .iter()
            .filter(|line| line.as_slice() == b"ping")
            .count();
        assert!(
            pings >= 2,
            "Expected at least 2 heartbeat pings; got {pings}"
        );

        client.close().await;
        server_task.abort();
    }

    #[tokio::test]
    async fn test_python_handler_error() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server().await;
        let server_task = task::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            run_echo_server(socket).await;
        });

        let py_handler = Some(Arc::new(load_python_attr(
            FAILING_HANDLER_SOURCE,
            "handler",
        )));
        let config = local_config(port, py_handler);
        let client = SocketClient::connect(config, None, None, None, None)
            .await
            .expect("Client connect failed unexpectedly");

        client.send_bytes(b"hello".to_vec()).await.unwrap();
        sleep(Duration::from_millis(100)).await;

        // The echoed `ERR` frame makes the Python handler raise; the client must stay up.
        client.send_bytes(b"ERR".to_vec()).await.unwrap();
        sleep(Duration::from_secs(1)).await;

        assert!(client.is_active());

        client.close().await;

        assert!(client.is_closed());
        server_task.abort();
    }

    #[tokio::test]
    async fn test_reconnect_success() {
        prepare_freethreaded_python();

        let (port, listener) = bind_test_server().await;

        // Accept once, drop the connection after 500ms, then accept again and echo.
        let server_task = task::spawn(async move {
            let (mut socket, _) = listener.accept().await.expect("First accept failed");

            sleep(Duration::from_millis(500)).await;
            let _ = socket.shutdown().await;

            sleep(Duration::from_millis(500)).await;

            let (socket, _) = listener.accept().await.expect("Second accept failed");
            run_echo_server(socket).await;
        });

        let config = SocketConfig {
            reconnect_timeout_ms: Some(5_000),
            reconnect_delay_initial_ms: Some(500),
            reconnect_delay_max_ms: Some(5_000),
            reconnect_backoff_factor: Some(2.0),
            reconnect_jitter_ms: Some(50),
            ..local_config(port, Some(Arc::new(create_handler())))
        };
        let client = SocketClient::connect(config, None, None, None, None)
            .await
            .expect("Client connect failed unexpectedly");

        assert!(client.is_active(), "Client should start as active");

        // NOTE(review): the client is already active here, so this wait returns at once and the
        // send below likely happens before the server drops the first connection; the test
        // does not observe an actual reconnect.
        wait_until_async(|| async { client.is_active() }, Duration::from_secs(10)).await;

        client
            .send_bytes(b"TestReconnect".to_vec())
            .await
            .expect("Send failed");

        client.close().await;
        server_task.abort();
    }
}
1241
#[cfg(test)]
mod rust_tests {
    use tokio::{
        io::AsyncWriteExt,
        net::TcpListener,
        task,
        time::{Duration, sleep},
    };

    use super::*;

    /// A client whose server drops the connection right after accepting it must still close
    /// cleanly while it is reconnecting.
    #[tokio::test]
    async fn test_reconnect_then_close() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();

        // Accept one connection and shut it down immediately to force a reconnect.
        let server = task::spawn(async move {
            if let Ok((mut sock, _)) = listener.accept().await {
                // `shutdown` returns a future; it must be awaited or nothing is sent.
                let _ = sock.shutdown().await;
            }
            sleep(Duration::from_secs(1)).await;
        });

        let config = SocketConfig {
            url: format!("127.0.0.1:{port}"),
            mode: Mode::Plain,
            suffix: b"\r\n".to_vec(),
            #[cfg(feature = "python")]
            py_handler: None,
            heartbeat: None,
            reconnect_timeout_ms: Some(1_000),
            reconnect_delay_initial_ms: Some(50),
            reconnect_delay_max_ms: Some(100),
            reconnect_backoff_factor: Some(1.0),
            reconnect_jitter_ms: Some(0),
            certs_dir: None,
        };

        let client = {
            #[cfg(feature = "python")]
            {
                SocketClient::connect(config, None, None, None, None)
                    .await
                    .unwrap()
            }
            #[cfg(not(feature = "python"))]
            {
                SocketClient::connect(config, None).await.unwrap()
            }
        };

        // Let the client notice the dropped connection and start reconnecting.
        sleep(Duration::from_millis(100)).await;

        client.close().await;
        assert!(client.is_closed());
        server.abort();
    }
}