nautilus_infrastructure/redis/
msgbus.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Posei Systems Pty Ltd. All rights reserved.
3//  https://poseitrader.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{
17    collections::{HashMap, VecDeque},
18    fmt::Debug,
19    sync::{
20        Arc,
21        atomic::{AtomicBool, Ordering},
22    },
23    time::{Duration, Instant},
24};
25
26use bytes::Bytes;
27use futures::stream::Stream;
28use nautilus_common::{
29    logging::{log_task_error, log_task_started, log_task_stopped},
30    msgbus::{
31        BusMessage,
32        database::{DatabaseConfig, MessageBusConfig, MessageBusDatabaseAdapter},
33        switchboard::CLOSE_TOPIC,
34    },
35    runtime::get_runtime,
36};
37use nautilus_core::{
38    UUID4,
39    time::{duration_since_unix_epoch, get_atomic_clock_realtime},
40};
41use nautilus_cryptography::providers::install_cryptographic_provider;
42use nautilus_model::identifiers::TraderId;
43use redis::{AsyncCommands, streams};
44use streams::StreamReadOptions;
45use ustr::Ustr;
46
47use super::{REDIS_MINID, REDIS_XTRIM, await_handle};
48use crate::redis::{create_redis_connection, get_stream_key};
49
/// Task label for the publishing task, passed to the task logging helpers,
/// `create_redis_connection` and `await_handle`.
const MSGBUS_PUBLISH: &str = "msgbus-publish";
/// Task label for the external-stream reading task.
const MSGBUS_STREAM: &str = "msgbus-stream";
/// Task label used when awaiting the heartbeat task handle on close.
const MSGBUS_HEARTBEAT: &str = "msgbus-heartbeat";
/// Topic under which heartbeat messages are published.
const HEARTBEAT_TOPIC: &str = "health:heartbeat";
/// Minimum number of seconds between successive auto-trims of the same stream.
const TRIM_BUFFER_SECS: u64 = 60;

/// Deserialized `XREAD` reply: one map per stream (stream key -> entries), where each entry
/// maps its stream entry ID to the raw field/value array of that entry.
type RedisStreamBulk = Vec<HashMap<String, Vec<HashMap<String, redis::Value>>>>;
57
/// Redis-backed message bus database adapter.
///
/// Owns background tasks spawned on the shared runtime: a publishing task that writes bus
/// messages to Redis streams, an optional task reading configured external streams, and an
/// optional heartbeat task.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "posei_trader.core.nautilus_pyo3.infrastructure")
)]
pub struct RedisMessageBusDatabase {
    /// The trader ID for this message bus database.
    pub trader_id: TraderId,
    /// The instance ID for this message bus database.
    pub instance_id: UUID4,
    /// Sending half of the channel consumed by the publishing task (also cloned by the heartbeat task).
    pub_tx: tokio::sync::mpsc::UnboundedSender<BusMessage>,
    /// Handle of the publishing task; taken when awaited during close.
    pub_handle: Option<tokio::task::JoinHandle<()>>,
    /// Receiver of messages read from external streams; `None` if none are configured or once taken.
    stream_rx: Option<tokio::sync::mpsc::Receiver<BusMessage>>,
    /// Handle of the external-stream task; `None` when no external streams are configured.
    stream_handle: Option<tokio::task::JoinHandle<()>>,
    /// Set on close to ask the external-stream task to stop.
    stream_signal: Arc<AtomicBool>,
    /// Handle of the heartbeat task; `None` when no heartbeat interval is configured.
    heartbeat_handle: Option<tokio::task::JoinHandle<()>>,
    /// Set on close to ask the heartbeat task to stop.
    heartbeat_signal: Arc<AtomicBool>,
}
75
76impl Debug for RedisMessageBusDatabase {
77    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78        f.debug_struct(stringify!(RedisMessageBusDatabase))
79            .field("trader_id", &self.trader_id)
80            .field("instance_id", &self.instance_id)
81            .finish()
82    }
83}
84
85impl MessageBusDatabaseAdapter for RedisMessageBusDatabase {
86    type DatabaseType = Self;
87
88    /// Creates a new [`RedisMessageBusDatabase`] instance for the given `trader_id`, `instance_id`, and `config`.
89    ///
90    /// # Errors
91    ///
92    /// Returns an error if:
93    /// - The database configuration is missing in `config`.
94    /// - Establishing the Redis connection for publishing fails.
95    fn new(
96        trader_id: TraderId,
97        instance_id: UUID4,
98        config: MessageBusConfig,
99    ) -> anyhow::Result<Self> {
100        install_cryptographic_provider();
101
102        let config_clone = config.clone();
103        let db_config = config
104            .database
105            .clone()
106            .ok_or_else(|| anyhow::anyhow!("No database config"))?;
107
108        let (pub_tx, pub_rx) = tokio::sync::mpsc::unbounded_channel::<BusMessage>();
109
110        // Create publish task (start the runtime here for now)
111        let pub_handle = Some(get_runtime().spawn(async move {
112            if let Err(e) = publish_messages(pub_rx, trader_id, instance_id, config_clone).await {
113                log_task_error(MSGBUS_PUBLISH, &e);
114            }
115        }));
116
117        // Conditionally create stream task and channel if external streams configured
118        let external_streams = config.external_streams.clone().unwrap_or_default();
119        let stream_signal = Arc::new(AtomicBool::new(false));
120        let (stream_rx, stream_handle) = if external_streams.is_empty() {
121            (None, None)
122        } else {
123            let stream_signal_clone = stream_signal.clone();
124            let (stream_tx, stream_rx) = tokio::sync::mpsc::channel::<BusMessage>(100_000);
125            (
126                Some(stream_rx),
127                Some(get_runtime().spawn(async move {
128                    if let Err(e) =
129                        stream_messages(stream_tx, db_config, external_streams, stream_signal_clone)
130                            .await
131                    {
132                        log_task_error(MSGBUS_STREAM, &e);
133                    }
134                })),
135            )
136        };
137
138        // Create heartbeat task
139        let heartbeat_signal = Arc::new(AtomicBool::new(false));
140        let heartbeat_handle = if let Some(heartbeat_interval_secs) = config.heartbeat_interval_secs
141        {
142            let signal = heartbeat_signal.clone();
143            let pub_tx_clone = pub_tx.clone();
144
145            Some(get_runtime().spawn(async move {
146                run_heartbeat(heartbeat_interval_secs, signal, pub_tx_clone).await;
147            }))
148        } else {
149            None
150        };
151
152        Ok(Self {
153            trader_id,
154            instance_id,
155            pub_tx,
156            pub_handle,
157            stream_rx,
158            stream_handle,
159            stream_signal,
160            heartbeat_handle,
161            heartbeat_signal,
162        })
163    }
164
165    /// Returns whether the message bus database adapter publishing channel is closed.
166    fn is_closed(&self) -> bool {
167        self.pub_tx.is_closed()
168    }
169
170    /// Publishes a message with the given `topic` and `payload`.
171    fn publish(&self, topic: Ustr, payload: Bytes) {
172        let msg = BusMessage::new(topic, payload);
173        if let Err(e) = self.pub_tx.send(msg) {
174            log::error!("Failed to send message: {e}");
175        }
176    }
177
178    /// Closes the message bus database adapter.
179    fn close(&mut self) {
180        log::debug!("Closing");
181
182        self.stream_signal.store(true, Ordering::Relaxed);
183        self.heartbeat_signal.store(true, Ordering::Relaxed);
184
185        if !self.pub_tx.is_closed() {
186            let msg = BusMessage::new_close();
187
188            if let Err(e) = self.pub_tx.send(msg) {
189                log::error!("Failed to send close message: {e:?}");
190            }
191        }
192
193        // Keep close sync for now to avoid async trait method
194        tokio::task::block_in_place(|| {
195            get_runtime().block_on(async {
196                self.close_async().await;
197            });
198        });
199
200        log::debug!("Closed");
201    }
202}
203
204impl RedisMessageBusDatabase {
205    /// Retrieves the Redis stream receiver for this message bus instance.
206    ///
207    /// # Errors
208    ///
209    /// Returns an error if the stream receiver has already been taken.
210    pub fn get_stream_receiver(
211        &mut self,
212    ) -> anyhow::Result<tokio::sync::mpsc::Receiver<BusMessage>> {
213        self.stream_rx
214            .take()
215            .ok_or_else(|| anyhow::anyhow!("Stream receiver already taken"))
216    }
217
218    /// Streams messages arriving on the stream receiver channel.
219    pub fn stream(
220        mut stream_rx: tokio::sync::mpsc::Receiver<BusMessage>,
221    ) -> impl Stream<Item = BusMessage> + 'static {
222        async_stream::stream! {
223            while let Some(msg) = stream_rx.recv().await {
224                yield msg;
225            }
226        }
227    }
228
229    pub async fn close_async(&mut self) {
230        await_handle(self.pub_handle.take(), MSGBUS_PUBLISH).await;
231        await_handle(self.stream_handle.take(), MSGBUS_STREAM).await;
232        await_handle(self.heartbeat_handle.take(), MSGBUS_HEARTBEAT).await;
233    }
234}
235
236/// Publishes messages received on `rx` to Redis streams for the given `trader_id` and `instance_id`, using `config`.
237///
238/// # Errors
239///
240/// Returns an error if:
241/// - The database configuration is missing in `config`.
242/// - Establishing the Redis connection fails.
243/// - Any Redis command fails during publishing.
244pub async fn publish_messages(
245    mut rx: tokio::sync::mpsc::UnboundedReceiver<BusMessage>,
246    trader_id: TraderId,
247    instance_id: UUID4,
248    config: MessageBusConfig,
249) -> anyhow::Result<()> {
250    log_task_started(MSGBUS_PUBLISH);
251
252    let db_config = config
253        .database
254        .as_ref()
255        .ok_or_else(|| anyhow::anyhow!("No database config"))?;
256    let mut con = create_redis_connection(MSGBUS_PUBLISH, db_config.clone()).await?;
257    let stream_key = get_stream_key(trader_id, instance_id, &config);
258
259    // Auto-trimming
260    let autotrim_duration = config
261        .autotrim_mins
262        .filter(|&mins| mins > 0)
263        .map(|mins| Duration::from_secs(u64::from(mins) * 60));
264    let mut last_trim_index: HashMap<String, usize> = HashMap::new();
265
266    // Buffering
267    let mut buffer: VecDeque<BusMessage> = VecDeque::new();
268    let mut last_drain = Instant::now();
269    let buffer_interval = Duration::from_millis(u64::from(config.buffer_interval_ms.unwrap_or(0)));
270
271    loop {
272        if last_drain.elapsed() >= buffer_interval && !buffer.is_empty() {
273            drain_buffer(
274                &mut con,
275                &stream_key,
276                config.stream_per_topic,
277                autotrim_duration,
278                &mut last_trim_index,
279                &mut buffer,
280            )
281            .await?;
282            last_drain = Instant::now();
283        } else if let Some(msg) = rx.recv().await {
284            if msg.topic == CLOSE_TOPIC {
285                tracing::debug!("Received close message");
286                drop(rx);
287                break;
288            }
289            buffer.push_back(msg);
290        } else {
291            tracing::debug!("Channel hung up");
292            break;
293        }
294    }
295
296    // Drain any remaining messages
297    if !buffer.is_empty() {
298        drain_buffer(
299            &mut con,
300            &stream_key,
301            config.stream_per_topic,
302            autotrim_duration,
303            &mut last_trim_index,
304            &mut buffer,
305        )
306        .await?;
307    }
308
309    log_task_stopped(MSGBUS_PUBLISH);
310    Ok(())
311}
312
313async fn drain_buffer(
314    conn: &mut redis::aio::ConnectionManager,
315    stream_key: &str,
316    stream_per_topic: bool,
317    autotrim_duration: Option<Duration>,
318    last_trim_index: &mut HashMap<String, usize>,
319    buffer: &mut VecDeque<BusMessage>,
320) -> anyhow::Result<()> {
321    let mut pipe = redis::pipe();
322    pipe.atomic();
323
324    for msg in buffer.drain(..) {
325        let items: Vec<(&str, &[u8])> = vec![
326            ("topic", msg.topic.as_ref()),
327            ("payload", msg.payload.as_ref()),
328        ];
329        let stream_key = if stream_per_topic {
330            format!("{stream_key}:{}", &msg.topic)
331        } else {
332            stream_key.to_string()
333        };
334        pipe.xadd(&stream_key, "*", &items);
335
336        if autotrim_duration.is_none() {
337            continue; // Nothing else to do
338        }
339
340        // Autotrim stream
341        let last_trim_ms = last_trim_index.entry(stream_key.clone()).or_insert(0); // Remove clone
342        let unix_duration_now = duration_since_unix_epoch();
343        let trim_buffer = Duration::from_secs(TRIM_BUFFER_SECS);
344
345        // Improve efficiency of this by batching
346        if *last_trim_ms < (unix_duration_now - trim_buffer).as_millis() as usize {
347            let min_timestamp_ms =
348                (unix_duration_now - autotrim_duration.unwrap()).as_millis() as usize;
349            let result: Result<(), redis::RedisError> = redis::cmd(REDIS_XTRIM)
350                .arg(stream_key.clone())
351                .arg(REDIS_MINID)
352                .arg(min_timestamp_ms)
353                .query_async(conn)
354                .await;
355
356            if let Err(e) = result {
357                tracing::error!("Error trimming stream '{stream_key}': {e}");
358            } else {
359                last_trim_index.insert(
360                    stream_key.to_string(),
361                    unix_duration_now.as_millis() as usize,
362                );
363            }
364        }
365    }
366
367    pipe.query_async(conn).await.map_err(anyhow::Error::from)
368}
369
370/// Streams messages from Redis streams and sends them over the provided `tx` channel.
371///
372/// # Errors
373///
374/// Returns an error if:
375/// - Establishing the Redis connection fails.
376/// - Any Redis read operation fails.
377pub async fn stream_messages(
378    tx: tokio::sync::mpsc::Sender<BusMessage>,
379    config: DatabaseConfig,
380    stream_keys: Vec<String>,
381    stream_signal: Arc<AtomicBool>,
382) -> anyhow::Result<()> {
383    log_task_started(MSGBUS_STREAM);
384
385    let mut con = create_redis_connection(MSGBUS_STREAM, config).await?;
386
387    let stream_keys = &stream_keys
388        .iter()
389        .map(String::as_str)
390        .collect::<Vec<&str>>();
391
392    tracing::debug!("Listening to streams: [{}]", stream_keys.join(", "));
393
394    // Start streaming from current timestamp
395    let clock = get_atomic_clock_realtime();
396    let timestamp_ms = clock.get_time_ms();
397    let mut last_id = timestamp_ms.to_string();
398
399    let opts = StreamReadOptions::default().block(100);
400
401    'outer: loop {
402        if stream_signal.load(Ordering::Relaxed) {
403            tracing::debug!("Received streaming terminate signal");
404            break;
405        }
406        let result: Result<RedisStreamBulk, _> =
407            con.xread_options(&[&stream_keys], &[&last_id], &opts).await;
408        match result {
409            Ok(stream_bulk) => {
410                if stream_bulk.is_empty() {
411                    // Timeout occurred: no messages received
412                    continue;
413                }
414                for entry in &stream_bulk {
415                    for stream_msgs in entry.values() {
416                        for stream_msg in stream_msgs {
417                            for (id, array) in stream_msg {
418                                last_id.clear();
419                                last_id.push_str(id);
420                                match decode_bus_message(array) {
421                                    Ok(msg) => {
422                                        if let Err(e) = tx.send(msg).await {
423                                            tracing::debug!("Channel closed: {e:?}");
424                                            break 'outer; // End streaming
425                                        }
426                                    }
427                                    Err(e) => {
428                                        tracing::error!("{e:?}");
429                                        continue;
430                                    }
431                                }
432                            }
433                        }
434                    }
435                }
436            }
437            Err(e) => {
438                anyhow::bail!("Error reading from stream: {e:?}");
439            }
440        }
441    }
442
443    log_task_stopped(MSGBUS_STREAM);
444    Ok(())
445}
446
447/// Decodes a Redis stream message value into a `BusMessage`.
448///
449/// # Errors
450///
451/// Returns an error if:
452/// - The incoming `stream_msg` is not an array.
453/// - The array has fewer than four elements (invalid format).
454/// - Parsing the topic or payload fails.
455fn decode_bus_message(stream_msg: &redis::Value) -> anyhow::Result<BusMessage> {
456    if let redis::Value::Array(stream_msg) = stream_msg {
457        if stream_msg.len() < 4 {
458            anyhow::bail!("Invalid stream message format: {stream_msg:?}");
459        }
460
461        let topic = match &stream_msg[1] {
462            redis::Value::BulkString(bytes) => match String::from_utf8(bytes.clone()) {
463                Ok(topic) => topic,
464                Err(e) => anyhow::bail!("Error parsing topic: {e}"),
465            },
466            _ => {
467                anyhow::bail!("Invalid topic format: {stream_msg:?}");
468            }
469        };
470
471        let payload = match &stream_msg[3] {
472            redis::Value::BulkString(bytes) => Bytes::copy_from_slice(bytes),
473            _ => {
474                anyhow::bail!("Invalid payload format: {stream_msg:?}");
475            }
476        };
477
478        Ok(BusMessage::with_str_topic(topic, payload))
479    } else {
480        anyhow::bail!("Invalid stream message format: {stream_msg:?}")
481    }
482}
483
484async fn run_heartbeat(
485    heartbeat_interval_secs: u16,
486    signal: Arc<AtomicBool>,
487    pub_tx: tokio::sync::mpsc::UnboundedSender<BusMessage>,
488) {
489    log_task_started("heartbeat");
490    tracing::debug!("Heartbeat at {heartbeat_interval_secs} second intervals");
491
492    let heartbeat_interval = Duration::from_secs(u64::from(heartbeat_interval_secs));
493    let heartbeat_timer = tokio::time::interval(heartbeat_interval);
494
495    let check_interval = Duration::from_millis(100);
496    let check_timer = tokio::time::interval(check_interval);
497
498    tokio::pin!(heartbeat_timer);
499    tokio::pin!(check_timer);
500
501    loop {
502        if signal.load(Ordering::Relaxed) {
503            tracing::debug!("Received heartbeat terminate signal");
504            break;
505        }
506
507        tokio::select! {
508            _ = heartbeat_timer.tick() => {
509                let heartbeat = create_heartbeat_msg();
510                if let Err(e) = pub_tx.send(heartbeat) {
511                    // We expect an error if the channel is closed during shutdown
512                    tracing::debug!("Error sending heartbeat: {e}");
513                }
514            },
515            _ = check_timer.tick() => {}
516        }
517    }
518
519    log_task_stopped("heartbeat");
520}
521
522fn create_heartbeat_msg() -> BusMessage {
523    let payload = Bytes::from(chrono::Utc::now().to_rfc3339().into_bytes());
524    BusMessage::with_str_topic(HEARTBEAT_TOPIC, payload)
525}
526
527////////////////////////////////////////////////////////////////////////////////
528// Tests
529////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
mod tests {
    use redis::Value;
    use rstest::*;

    use super::*;

    // A well-formed entry decodes to its topic (index 1) and payload (index 3)
    #[rstest]
    fn test_decode_bus_message_valid() {
        let stream_msg = Value::Array(vec![
            Value::BulkString(b"0".to_vec()),
            Value::BulkString(b"topic1".to_vec()),
            Value::BulkString(b"unused".to_vec()),
            Value::BulkString(b"data1".to_vec()),
        ]);

        let result = decode_bus_message(&stream_msg);
        assert!(result.is_ok());
        let msg = result.unwrap();
        assert_eq!(msg.topic, "topic1");
        assert_eq!(msg.payload, Bytes::from("data1"));
    }

    // An empty payload bulk string is valid and yields empty bytes
    #[rstest]
    fn test_decode_bus_message_empty_payload() {
        let stream_msg = Value::Array(vec![
            Value::BulkString(b"0".to_vec()),
            Value::BulkString(b"topic1".to_vec()),
            Value::BulkString(b"unused".to_vec()),
            Value::BulkString(Vec::new()),
        ]);

        let msg = decode_bus_message(&stream_msg).unwrap();
        assert_eq!(msg.topic, "topic1");
        assert!(msg.payload.is_empty());
    }

    #[rstest]
    fn test_decode_bus_message_missing_fields() {
        let stream_msg = Value::Array(vec![
            Value::BulkString(b"0".to_vec()),
            Value::BulkString(b"topic1".to_vec()),
        ]);

        let result = decode_bus_message(&stream_msg);
        assert!(result.is_err());
        assert_eq!(
            format!("{}", result.unwrap_err()),
            "Invalid stream message format: [bulk-string('\"0\"'), bulk-string('\"topic1\"')]"
        );
    }

    #[rstest]
    fn test_decode_bus_message_invalid_topic_format() {
        let stream_msg = Value::Array(vec![
            Value::BulkString(b"0".to_vec()),
            Value::Int(42), // Invalid topic format
            Value::BulkString(b"unused".to_vec()),
            Value::BulkString(b"data1".to_vec()),
        ]);

        let result = decode_bus_message(&stream_msg);
        assert!(result.is_err());
        assert_eq!(
            format!("{}", result.unwrap_err()),
            "Invalid topic format: [bulk-string('\"0\"'), int(42), bulk-string('\"unused\"'), bulk-string('\"data1\"')]"
        );
    }

    // A topic that is a bulk string but not valid UTF-8 is rejected
    #[rstest]
    fn test_decode_bus_message_invalid_topic_utf8() {
        let stream_msg = Value::Array(vec![
            Value::BulkString(b"0".to_vec()),
            Value::BulkString(vec![0xff, 0xfe]), // Not valid UTF-8
            Value::BulkString(b"unused".to_vec()),
            Value::BulkString(b"data1".to_vec()),
        ]);

        let result = decode_bus_message(&stream_msg);
        assert!(result.is_err());
        assert!(format!("{}", result.unwrap_err()).starts_with("Error parsing topic: "));
    }

    #[rstest]
    fn test_decode_bus_message_invalid_payload_format() {
        let stream_msg = Value::Array(vec![
            Value::BulkString(b"0".to_vec()),
            Value::BulkString(b"topic1".to_vec()),
            Value::BulkString(b"unused".to_vec()),
            Value::Int(42), // Invalid payload format
        ]);

        let result = decode_bus_message(&stream_msg);
        assert!(result.is_err());
        assert_eq!(
            format!("{}", result.unwrap_err()),
            "Invalid payload format: [bulk-string('\"0\"'), bulk-string('\"topic1\"'), bulk-string('\"unused\"'), int(42)]"
        );
    }

    #[rstest]
    fn test_decode_bus_message_invalid_stream_msg_format() {
        let stream_msg = Value::BulkString(b"not an array".to_vec());

        let result = decode_bus_message(&stream_msg);
        assert!(result.is_err());
        assert_eq!(
            format!("{}", result.unwrap_err()),
            "Invalid stream message format: bulk-string('\"not an array\"')"
        );
    }
}
614
// Integration tests that talk to a live Redis server reached via `DatabaseConfig::default()`
// (presumably a local instance - confirm against the test environment).
#[cfg(target_os = "linux")] // Run Redis tests on Linux platforms only
#[cfg(test)]
mod serial_tests {
    use nautilus_common::testing::wait_until_async;
    use redis::aio::ConnectionManager;
    use rstest::*;

    use super::*;
    use crate::redis::flush_redis;

    // Provides a connection to a freshly flushed Redis database
    #[fixture]
    async fn redis_connection() -> ConnectionManager {
        let config = DatabaseConfig::default();
        let mut con = create_redis_connection(MSGBUS_STREAM, config)
            .await
            .unwrap();
        flush_redis(&mut con).await.unwrap();
        con
    }

    // Setting the terminate signal ends `stream_messages`, which drops `tx` and unblocks `rx`
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_stream_messages_terminate_signal(#[future] redis_connection: ConnectionManager) {
        let mut con = redis_connection.await;
        let (tx, mut rx) = tokio::sync::mpsc::channel::<BusMessage>(100);

        let trader_id = TraderId::from("tester-001");
        let instance_id = UUID4::new();
        let mut config = MessageBusConfig::default();
        config.database = Some(DatabaseConfig::default());

        let stream_key = get_stream_key(trader_id, instance_id, &config);
        let external_streams = vec![stream_key.clone()];
        let stream_signal = Arc::new(AtomicBool::new(false));
        let stream_signal_clone = stream_signal.clone();

        // Start the message streaming task
        let handle = tokio::spawn(async move {
            stream_messages(
                tx,
                DatabaseConfig::default(),
                external_streams,
                stream_signal_clone,
            )
            .await
            .unwrap();
        });

        stream_signal.store(true, Ordering::Relaxed);
        let _ = rx.recv().await; // Wait for the tx to close

        // Shutdown and cleanup
        rx.close();
        handle.await.unwrap();
        flush_redis(&mut con).await.unwrap();
    }

    // With the receiver already closed, the first attempted send ends the streaming task
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_stream_messages_when_receiver_closed(
        #[future] redis_connection: ConnectionManager,
    ) {
        let mut con = redis_connection.await;
        let (tx, mut rx) = tokio::sync::mpsc::channel::<BusMessage>(100);

        let trader_id = TraderId::from("tester-001");
        let instance_id = UUID4::new();
        let mut config = MessageBusConfig::default();
        config.database = Some(DatabaseConfig::default());

        let stream_key = get_stream_key(trader_id, instance_id, &config);
        let external_streams = vec![stream_key.clone()];
        let stream_signal = Arc::new(AtomicBool::new(false));
        let stream_signal_clone = stream_signal.clone();

        // Use a message ID in the future, as streaming begins
        // around the timestamp the task is spawned.
        let clock = get_atomic_clock_realtime();
        let future_id = (clock.get_time_ms() + 1_000_000).to_string();

        // Publish test message
        let _: () = con
            .xadd(
                stream_key,
                future_id,
                &[("topic", "topic1"), ("payload", "data1")],
            )
            .await
            .unwrap();

        // Immediately close channel
        rx.close();

        // Start the message streaming task
        let handle = tokio::spawn(async move {
            stream_messages(
                tx,
                DatabaseConfig::default(),
                external_streams,
                stream_signal_clone,
            )
            .await
            .unwrap();
        });

        // Shutdown and cleanup
        handle.await.unwrap();
        flush_redis(&mut con).await.unwrap();
    }

    // An entry added to the watched stream is decoded and forwarded over the channel
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_stream_messages(#[future] redis_connection: ConnectionManager) {
        let mut con = redis_connection.await;
        let (tx, mut rx) = tokio::sync::mpsc::channel::<BusMessage>(100);

        let trader_id = TraderId::from("tester-001");
        let instance_id = UUID4::new();
        let mut config = MessageBusConfig::default();
        config.database = Some(DatabaseConfig::default());

        let stream_key = get_stream_key(trader_id, instance_id, &config);
        let external_streams = vec![stream_key.clone()];
        let stream_signal = Arc::new(AtomicBool::new(false));
        let stream_signal_clone = stream_signal.clone();

        // Use a message ID in the future, as streaming begins
        // around the timestamp the task is spawned.
        let clock = get_atomic_clock_realtime();
        let future_id = (clock.get_time_ms() + 1_000_000).to_string();

        // Publish test message
        let _: () = con
            .xadd(
                stream_key,
                future_id,
                &[("topic", "topic1"), ("payload", "data1")],
            )
            .await
            .unwrap();

        // Start the message streaming task
        let handle = tokio::spawn(async move {
            stream_messages(
                tx,
                DatabaseConfig::default(),
                external_streams,
                stream_signal_clone,
            )
            .await
            .unwrap();
        });

        // Receive and verify the message
        let msg = rx.recv().await.unwrap();
        assert_eq!(msg.topic, "topic1");
        assert_eq!(msg.payload, Bytes::from("data1"));

        // Shutdown and cleanup
        rx.close();
        stream_signal.store(true, Ordering::Relaxed);
        handle.await.unwrap();
        flush_redis(&mut con).await.unwrap();
    }

    // A message sent to `publish_messages` lands in the single shared stream and round-trips
    // through `decode_bus_message`
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_publish_messages(#[future] redis_connection: ConnectionManager) {
        let mut con = redis_connection.await;
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<BusMessage>();

        let trader_id = TraderId::from("tester-001");
        let instance_id = UUID4::new();
        let mut config = MessageBusConfig::default();
        config.database = Some(DatabaseConfig::default());
        config.stream_per_topic = false;
        let stream_key = get_stream_key(trader_id, instance_id, &config);

        // Start the publish_messages task
        let handle = tokio::spawn(async move {
            publish_messages(rx, trader_id, instance_id, config)
                .await
                .unwrap();
        });

        // Send a test message
        let msg = BusMessage::with_str_topic("test_topic", Bytes::from("test_payload"));
        tx.send(msg).unwrap();

        // Wait until the message is published to Redis
        wait_until_async(
            || {
                let mut con = con.clone();
                let stream_key = stream_key.clone();
                async move {
                    let messages: RedisStreamBulk =
                        con.xread(&[&stream_key], &["0"]).await.unwrap();
                    !messages.is_empty()
                }
            },
            Duration::from_secs(3),
        )
        .await;

        // Verify the message was published to Redis
        let messages: RedisStreamBulk = con.xread(&[&stream_key], &["0"]).await.unwrap();
        assert_eq!(messages.len(), 1);
        let stream_msgs = messages[0].get(&stream_key).unwrap();
        let stream_msg_array = &stream_msgs[0].values().next().unwrap();
        let decoded_message = decode_bus_message(stream_msg_array).unwrap();
        assert_eq!(decoded_message.topic, "test_topic");
        assert_eq!(decoded_message.payload, Bytes::from("test_payload"));

        // Stop publishing task
        let msg = BusMessage::new_close();
        tx.send(msg).unwrap();

        // Shutdown and cleanup
        handle.await.unwrap();
        flush_redis(&mut con).await.unwrap();
    }

    // `close` uses `block_in_place`, hence the multi-threaded runtime flavor
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_close() {
        let trader_id = TraderId::from("tester-001");
        let instance_id = UUID4::new();
        let mut config = MessageBusConfig::default();
        config.database = Some(DatabaseConfig::default());

        let mut db = RedisMessageBusDatabase::new(trader_id, instance_id, config).unwrap();

        // Close the message bus database (test should not hang)
        db.close();
    }

    // Runs the heartbeat task for about two seconds at a one-second interval and checks
    // every emitted message is on the heartbeat topic
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_heartbeat_task() {
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<BusMessage>();
        let signal = Arc::new(AtomicBool::new(false));

        // Start the heartbeat task with a short interval
        let handle = tokio::spawn(run_heartbeat(1, signal.clone(), tx));

        // Wait for a couple of heartbeats
        tokio::time::sleep(Duration::from_secs(2)).await;

        // Stop the heartbeat task
        signal.store(true, Ordering::Relaxed);
        handle.await.unwrap();

        // Ensure heartbeats were sent
        let mut heartbeats: Vec<BusMessage> = Vec::new();
        while let Ok(hb) = rx.try_recv() {
            heartbeats.push(hb);
        }

        assert!(!heartbeats.is_empty());

        for hb in heartbeats {
            assert_eq!(hb.topic, HEARTBEAT_TOPIC);
        }
    }
}