nautilus_infrastructure/redis/
mod.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Posei Systems Pty Ltd. All rights reserved.
3//  https://poseitrader.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! Provides a Redis backed `CacheDatabase` and `MessageBusDatabase` implementation.
17
18pub mod cache;
19pub mod msgbus;
20pub mod queries;
21
22use std::time::Duration;
23
24use nautilus_common::{
25    logging::log_task_awaiting,
26    msgbus::database::{DatabaseConfig, MessageBusConfig},
27};
28use nautilus_core::UUID4;
29use nautilus_model::identifiers::TraderId;
30use redis::RedisError;
31use semver::Version;
32
/// Oldest Redis server version considered supported; older servers are reported
/// with an error log when a connection is created.
const REDIS_MIN_VERSION: &str = "6.2.0";

/// Separator placed between the components of a Redis key.
const REDIS_DELIMITER: char = ':';

// Redis command / argument names. `FLUSHDB` is issued by `flush_redis`;
// `XTRIM` and its `MINID` trimming strategy are presumably used by the submodules.
const REDIS_FLUSHDB: &str = "FLUSHDB";
const REDIS_XTRIM: &str = "XTRIM";
const REDIS_MINID: &str = "MINID";
38
39async fn await_handle(handle: Option<tokio::task::JoinHandle<()>>, task_name: &str) {
40    if let Some(handle) = handle {
41        log_task_awaiting(task_name);
42
43        let timeout = Duration::from_secs(2);
44        match tokio::time::timeout(timeout, handle).await {
45            Ok(result) => {
46                if let Err(e) = result {
47                    log::error!("Error awaiting task '{task_name}': {e:?}");
48                }
49            }
50            Err(_) => {
51                log::error!("Timeout {timeout:?} awaiting task '{task_name}'");
52            }
53        }
54    }
55}
56
57/// Parses a Redis connection URL from the given database config, returning the
58/// full URL and a redacted version with the password obfuscated.
59///
60/// Authentication matrix handled:
61/// ┌───────────┬───────────┬────────────────────────────┐
62/// │ Username  │ Password  │ Resulting user-info part   │
63/// ├───────────┼───────────┼────────────────────────────┤
64/// │ non-empty │ non-empty │ user:pass@                 │
65/// │ empty     │ non-empty │ :pass@                     │
66/// │ empty     │ empty     │ (omitted)                  │
67/// └───────────┴───────────┴────────────────────────────┘
68///
69/// # Panics
70///
71/// Panics if a username is provided without a corresponding password.
72#[must_use]
73pub fn get_redis_url(config: DatabaseConfig) -> (String, String) {
74    let host = config.host.unwrap_or("127.0.0.1".to_string());
75    let port = config.port.unwrap_or(6379);
76    let username = config.username.unwrap_or_default();
77    let password = config.password.unwrap_or_default();
78    let ssl = config.ssl;
79
80    // Redact the password for logging/metrics: keep the first & last two chars.
81    let redact_pw = |pw: &str| {
82        if pw.len() > 4 {
83            format!("{}...{}", &pw[..2], &pw[pw.len() - 2..])
84        } else {
85            pw.to_owned()
86        }
87    };
88
89    // Build the `userinfo@` portion for both the real and redacted URLs.
90    let (auth, auth_redacted) = match (username.is_empty(), password.is_empty()) {
91        // user:pass@
92        (false, false) => (
93            format!("{username}:{password}@"),
94            format!("{username}:{}@", redact_pw(&password)),
95        ),
96        // :pass@
97        (true, false) => (
98            format!(":{password}@"),
99            format!(":{}@", redact_pw(&password)),
100        ),
101        // username but no password ⇒  configuration error
102        (false, true) => panic!(
103            "Redis config error: username supplied without password. \
104            Either supply a password or omit the username."
105        ),
106        // no credentials
107        (true, true) => (String::new(), String::new()),
108    };
109
110    let scheme = if ssl { "rediss" } else { "redis" };
111
112    let url = format!("{scheme}://{auth}{host}:{port}");
113    let redacted_url = format!("{scheme}://{auth_redacted}{host}:{port}");
114
115    (url, redacted_url)
116}
117/// Creates a new Redis connection manager based on the provided database `config` and connection name.
118///
119/// # Errors
120///
121/// Returns an error if:
122/// - Constructing the Redis client fails.
123/// - Establishing or configuring the connection manager fails.
124///
125/// In case of reconnection issues, the connection will retry reconnection
126/// `number_of_retries` times, with an exponentially increasing delay, calculated as
127/// `rand(0 .. factor * (exponent_base ^ current-try))`.
128///
129/// The new connection will time out operations after `response_timeout` has passed.
130/// Each connection attempt to the server will time out after `connection_timeout`.
131pub async fn create_redis_connection(
132    con_name: &str,
133    config: DatabaseConfig,
134) -> anyhow::Result<redis::aio::ConnectionManager> {
135    tracing::debug!("Creating {con_name} redis connection");
136    let (redis_url, redacted_url) = get_redis_url(config.clone());
137    tracing::debug!("Connecting to {redacted_url}");
138
139    let connection_timeout = Duration::from_secs(u64::from(config.connection_timeout));
140    let response_timeout = Duration::from_secs(u64::from(config.response_timeout));
141    let number_of_retries = config.number_of_retries;
142    let exponent_base = config.exponent_base;
143    let factor = config.factor;
144
145    // into milliseconds
146    let max_delay = config.max_delay * 1000;
147
148    let client = redis::Client::open(redis_url)?;
149
150    let connection_manager_config = redis::aio::ConnectionManagerConfig::new()
151        .set_exponent_base(exponent_base)
152        .set_factor(factor)
153        .set_number_of_retries(number_of_retries)
154        .set_response_timeout(response_timeout)
155        .set_connection_timeout(connection_timeout)
156        .set_max_delay(max_delay);
157
158    let mut con = client
159        .get_connection_manager_with_config(connection_manager_config)
160        .await?;
161
162    let version = get_redis_version(&mut con).await?;
163    let min_version = Version::parse(REDIS_MIN_VERSION)?;
164    let con_msg = format!("Connected to redis v{version}");
165
166    if version >= min_version {
167        tracing::info!(con_msg);
168    } else {
169        // TODO: Using `log` error here so that the message is displayed regardless of whether
170        // the logging config has pyo3 enabled. Later we can standardize this to `tracing`.
171        log::error!("{con_msg}, but minimum supported version is {REDIS_MIN_VERSION}");
172    }
173
174    Ok(con)
175}
176
177/// Flushes the entire Redis database for the specified connection.
178///
179/// # Errors
180///
181/// Returns an error if the FLUSHDB command fails.
182pub async fn flush_redis(
183    con: &mut redis::aio::ConnectionManager,
184) -> anyhow::Result<(), RedisError> {
185    redis::cmd(REDIS_FLUSHDB).exec_async(con).await
186}
187
188/// Parse the stream key from the given identifiers and config.
189#[must_use]
190pub fn get_stream_key(
191    trader_id: TraderId,
192    instance_id: UUID4,
193    config: &MessageBusConfig,
194) -> String {
195    let mut stream_key = String::new();
196
197    if config.use_trader_prefix {
198        stream_key.push_str("trader-");
199    }
200
201    if config.use_trader_id {
202        stream_key.push_str(trader_id.as_str());
203        stream_key.push(REDIS_DELIMITER);
204    }
205
206    if config.use_instance_id {
207        stream_key.push_str(&format!("{instance_id}"));
208        stream_key.push(REDIS_DELIMITER);
209    }
210
211    stream_key.push_str(&config.streams_prefix);
212    stream_key
213}
214
215/// Retrieves and parses the Redis server version via the INFO command.
216///
217/// # Errors
218///
219/// Returns an error if the INFO command fails or version parsing fails.
220pub async fn get_redis_version(
221    conn: &mut redis::aio::ConnectionManager,
222) -> anyhow::Result<Version> {
223    let info: String = redis::cmd("INFO").query_async(conn).await?;
224    let version_str = match info.lines().find_map(|line| {
225        if line.starts_with("redis_version:") {
226            line.split(':').nth(1).map(|s| s.trim().to_string())
227        } else {
228            None
229        }
230    }) {
231        Some(info) => info,
232        None => {
233            anyhow::bail!("Redis version not available");
234        }
235    };
236
237    parse_redis_version(&version_str)
238}
239
240fn parse_redis_version(version_str: &str) -> anyhow::Result<Version> {
241    let mut components = version_str.split('.').map(str::parse::<u64>);
242
243    let major = components.next().unwrap_or(Ok(0))?;
244    let minor = components.next().unwrap_or(Ok(0))?;
245    let patch = components.next().unwrap_or(Ok(0))?;
246
247    Ok(Version::new(major, minor, patch))
248}
249
250////////////////////////////////////////////////////////////////////////////////
251// Tests
252////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use serde_json::json;

    use super::*;

    #[rstest]
    fn test_get_redis_url_default_values() {
        let config: DatabaseConfig = serde_json::from_value(json!({})).unwrap();
        let (url, redacted_url) = get_redis_url(config);
        assert_eq!(url, "redis://127.0.0.1:6379");
        assert_eq!(redacted_url, "redis://127.0.0.1:6379");
    }

    #[rstest]
    fn test_get_redis_url_password_only() {
        // Username omitted, but password present
        let config_json = json!({
            "host": "example.com",
            "port": 6380,
            "password": "secretpw",   // >4 chars ⇒ will be redacted
        });
        let config: DatabaseConfig = serde_json::from_value(config_json).unwrap();
        let (url, redacted_url) = get_redis_url(config);
        assert_eq!(url, "redis://:secretpw@example.com:6380");
        assert_eq!(redacted_url, "redis://:se...pw@example.com:6380");
    }

    #[rstest]
    #[should_panic(expected = "username supplied without password")]
    fn test_get_redis_url_username_without_password_panics() {
        let config_json = json!({
            "host": "example.com",
            "port": 6380,
            "username": "user",
        });
        let config: DatabaseConfig = serde_json::from_value(config_json).unwrap();
        let _ = get_redis_url(config);
    }

    #[rstest]
    fn test_get_redis_url_full_config_with_ssl() {
        let config_json = json!({
            "host": "example.com",
            "port": 6380,
            "username": "user",
            "password": "pass",
            "ssl": true,
        });
        let config: DatabaseConfig = serde_json::from_value(config_json).unwrap();
        let (url, redacted_url) = get_redis_url(config);
        assert_eq!(url, "rediss://user:pass@example.com:6380");
        assert_eq!(redacted_url, "rediss://user:pass@example.com:6380");
    }

    #[rstest]
    fn test_get_redis_url_full_config_without_ssl() {
        let config_json = json!({
            "host": "example.com",
            "port": 6380,
            "username": "username",
            "password": "password",
            "ssl": false,
        });
        let config: DatabaseConfig = serde_json::from_value(config_json).unwrap();
        let (url, redacted_url) = get_redis_url(config);
        assert_eq!(url, "redis://username:password@example.com:6380");
        assert_eq!(redacted_url, "redis://username:pa...rd@example.com:6380");
    }

    #[rstest]
    fn test_get_redis_url_missing_username_and_password() {
        let config_json = json!({
            "host": "example.com",
            "port": 6380,
            "ssl": false,
        });
        let config: DatabaseConfig = serde_json::from_value(config_json).unwrap();
        let (url, redacted_url) = get_redis_url(config);
        assert_eq!(url, "redis://example.com:6380");
        assert_eq!(redacted_url, "redis://example.com:6380");
    }

    #[rstest]
    fn test_get_redis_url_ssl_default_false() {
        let config_json = json!({
            "host": "example.com",
            "port": 6380,
            "username": "username",
            "password": "password",
            // "ssl" is intentionally omitted to test default behavior
        });
        let config: DatabaseConfig = serde_json::from_value(config_json).unwrap();
        let (url, redacted_url) = get_redis_url(config);
        assert_eq!(url, "redis://username:password@example.com:6380");
        assert_eq!(redacted_url, "redis://username:pa...rd@example.com:6380");
    }

    #[rstest]
    fn test_get_stream_key_default_config() {
        let trader_id = TraderId::from("tester-123");
        let instance_id = UUID4::new();
        let config = MessageBusConfig::default();

        let key = get_stream_key(trader_id, instance_id, &config);
        assert_eq!(key, "trader-tester-123:stream");
    }

    #[rstest]
    fn test_get_stream_key_with_trader_prefix_and_instance_id() {
        let trader_id = TraderId::from("tester-123");
        let instance_id = UUID4::new();
        let mut config = MessageBusConfig::default();
        config.use_instance_id = true;

        let key = get_stream_key(trader_id, instance_id, &config);
        assert_eq!(key, format!("trader-tester-123:{instance_id}:stream"));
    }

    #[rstest]
    fn test_get_stream_key_without_trader_prefix_or_instance_id() {
        let trader_id = TraderId::from("tester-123");
        let instance_id = UUID4::new();
        let mut config = MessageBusConfig::default();
        config.use_trader_prefix = false;
        config.use_trader_id = false;

        let key = get_stream_key(trader_id, instance_id, &config);
        assert_eq!(key, "stream");
    }

    #[rstest]
    #[case("7.2.4", Version::new(7, 2, 4))]
    #[case("6.2", Version::new(6, 2, 0))]
    #[case("7", Version::new(7, 0, 0))]
    #[case("7.2.4.1", Version::new(7, 2, 4))]
    fn test_parse_redis_version(#[case] input: &str, #[case] expected: Version) {
        assert_eq!(parse_redis_version(input).unwrap(), expected);
    }

    #[rstest]
    #[case("")]
    #[case("seven.2.4")]
    #[case("7.x.4")]
    fn test_parse_redis_version_invalid(#[case] input: &str) {
        assert!(parse_redis_version(input).is_err());
    }
}
362}