nautilus_testkit/
files.rs

// -------------------------------------------------------------------------------------------------
//  Copyright (C) 2015-2025 Posei Systems Pty Ltd. All rights reserved.
//  https://poseitrader.io
//
//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
//  You may not use this file except in compliance with the License.
//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
//
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.
// -------------------------------------------------------------------------------------------------
15
16use std::{
17    fs::{File, OpenOptions},
18    io::{BufReader, BufWriter, Read, copy},
19    path::Path,
20};
21
22use aws_lc_rs::digest::{self, Context};
23use reqwest::blocking::Client;
24use serde_json::Value;
25
26/// Ensures that a file exists at the specified path by downloading it if necessary.
27///
28/// If the file already exists, it checks the integrity of the file using a SHA-256 checksum
29/// from the optional `checksums` file. If the checksum is valid, the function exits early. If
30/// the checksum is invalid or missing, the function updates the checksums file with the correct
31/// hash for the existing file without redownloading it.
32///
33/// If the file does not exist, it downloads the file from the specified `url` and updates the
34/// checksums file (if provided) with the calculated SHA-256 checksum of the downloaded file.
35///
36/// # Errors
37///
38/// Returns an error if:
39/// - The HTTP request cannot be sent or returns a non-success status code.
40/// - Any I/O operation fails during file creation, reading, or writing.
41/// - Checksum verification or JSON parsing fails.
42pub fn ensure_file_exists_or_download_http(
43    filepath: &Path,
44    url: &str,
45    checksums: Option<&Path>,
46) -> anyhow::Result<()> {
47    if filepath.exists() {
48        println!("File already exists: {filepath:?}");
49
50        if let Some(checksums_file) = checksums {
51            if verify_sha256_checksum(filepath, checksums_file)? {
52                println!("File is valid");
53                return Ok(());
54            } else {
55                let new_checksum = calculate_sha256(filepath)?;
56                println!("Adding checksum for existing file: {new_checksum}");
57                update_sha256_checksums(filepath, checksums_file, &new_checksum)?;
58                return Ok(());
59            }
60        }
61        return Ok(());
62    }
63
64    download_file(filepath, url)?;
65
66    if let Some(checksums_file) = checksums {
67        let new_checksum = calculate_sha256(filepath)?;
68        update_sha256_checksums(filepath, checksums_file, &new_checksum)?;
69    }
70
71    Ok(())
72}
73
74fn download_file(filepath: &Path, url: &str) -> anyhow::Result<()> {
75    println!("Downloading file from {url} to {filepath:?}");
76
77    if let Some(parent) = filepath.parent() {
78        std::fs::create_dir_all(parent)?;
79    }
80
81    let mut response = Client::new().get(url).send()?;
82    if !response.status().is_success() {
83        anyhow::bail!("Failed to download file: HTTP {}", response.status());
84    }
85
86    let mut out = File::create(filepath)?;
87    copy(&mut response, &mut out)?;
88
89    println!("File downloaded to {filepath:?}");
90    Ok(())
91}
92
93fn calculate_sha256(filepath: &Path) -> anyhow::Result<String> {
94    let mut file = File::open(filepath)?;
95    let mut ctx = Context::new(&digest::SHA256);
96    let mut buffer = [0u8; 4096];
97
98    loop {
99        let count = file.read(&mut buffer)?;
100        if count == 0 {
101            break;
102        }
103        ctx.update(&buffer[..count]);
104    }
105
106    let digest = ctx.finish();
107    Ok(hex::encode(digest.as_ref()))
108}
109
110fn verify_sha256_checksum(filepath: &Path, checksums: &Path) -> anyhow::Result<bool> {
111    let file = File::open(checksums)?;
112    let reader = BufReader::new(file);
113    let checksums: Value = serde_json::from_reader(reader)?;
114
115    let filename = filepath.file_name().unwrap().to_str().unwrap();
116    if let Some(expected_checksum) = checksums.get(filename) {
117        let expected_checksum_str = expected_checksum.as_str().unwrap();
118        let expected_hash = expected_checksum_str
119            .strip_prefix("sha256:")
120            .unwrap_or(expected_checksum_str);
121        let calculated_checksum = calculate_sha256(filepath)?;
122        if expected_hash == calculated_checksum {
123            return Ok(true);
124        }
125    }
126
127    Ok(false)
128}
129
130fn update_sha256_checksums(
131    filepath: &Path,
132    checksums_file: &Path,
133    new_checksum: &str,
134) -> anyhow::Result<()> {
135    let checksums: Value = if checksums_file.exists() {
136        let file = File::open(checksums_file)?;
137        let reader = BufReader::new(file);
138        serde_json::from_reader(reader)?
139    } else {
140        serde_json::json!({})
141    };
142
143    let mut checksums_map = checksums.as_object().unwrap().clone();
144
145    // Add or update the checksum
146    let filename = filepath.file_name().unwrap().to_str().unwrap().to_string();
147    let prefixed_checksum = format!("sha256:{new_checksum}");
148    checksums_map.insert(filename, Value::String(prefixed_checksum));
149
150    let file = OpenOptions::new()
151        .write(true)
152        .create(true)
153        .truncate(true)
154        .open(checksums_file)?;
155    let writer = BufWriter::new(file);
156    serde_json::to_writer_pretty(writer, &serde_json::Value::Object(checksums_map))?;
157
158    Ok(())
159}
160
////////////////////////////////////////////////////////////////////////////////
// Tests
////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
mod tests {
    use std::{
        fs,
        io::{BufWriter, Write},
        net::SocketAddr,
    };

    use axum::{Router, http::StatusCode, routing::get, serve};
    use rstest::*;
    use serde_json::{json, to_writer};
    use tempfile::TempDir;
    use tokio::{
        net::TcpListener,
        task,
        time::{Duration, sleep},
    };

    use super::*;

    /// Starts a local HTTP server on an OS-assigned port and returns its address.
    ///
    /// The server answers `GET /testfile.txt` with `status_code` and a body of
    /// `server_content`, or `"File not found"` when `server_content` is `None`.
    async fn setup_test_server(
        server_content: Option<String>,
        status_code: StatusCode,
    ) -> SocketAddr {
        let response_body = server_content.unwrap_or_else(|| "File not found".to_string());
        let app = Router::new().route(
            "/testfile.txt",
            get(move || {
                // Each request gets its own copy of the body.
                let response_body = response_body.clone();
                async move { (status_code, response_body) }
            }),
        );

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server = serve(listener, app);

        task::spawn(async move {
            if let Err(e) = server.await {
                eprintln!("server error: {e}");
            }
        });

        // Give the spawned server task a moment to start accepting connections.
        sleep(Duration::from_millis(100)).await;

        addr
    }

    #[tokio::test]
    async fn test_file_already_exists() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("testfile.txt");
        fs::write(&file_path, "Existing file content").unwrap();

        // The URL is never contacted because the file is already present.
        let url = "http://example.com/testfile.txt".to_string();
        let result = ensure_file_exists_or_download_http(&file_path, &url, None);

        assert!(result.is_ok());
        let content = fs::read_to_string(&file_path).unwrap();
        assert_eq!(content, "Existing file content");
    }

    #[tokio::test]
    async fn test_download_file_success() {
        let temp_dir = TempDir::new().unwrap();
        let filepath = temp_dir.path().join("testfile.txt");
        let filepath_clone = filepath.clone();

        let server_content = Some("Server file content".to_string());
        let status_code = StatusCode::OK;
        let addr = setup_test_server(server_content.clone(), status_code).await;
        let url = format!("http://{addr}/testfile.txt");

        // The function under test uses a blocking HTTP client, so run it off the async runtime.
        let result = tokio::task::spawn_blocking(move || {
            ensure_file_exists_or_download_http(&filepath_clone, &url, None)
        })
        .await
        .unwrap();

        assert!(result.is_ok());
        let content = fs::read_to_string(&filepath).unwrap();
        assert_eq!(content, server_content.unwrap());
    }

    #[tokio::test]
    async fn test_download_file_not_found() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("testfile.txt");

        let server_content = None;
        let status_code = StatusCode::NOT_FOUND;
        let addr = setup_test_server(server_content, status_code).await;
        let url = format!("http://{addr}/testfile.txt");

        let result = tokio::task::spawn_blocking(move || {
            ensure_file_exists_or_download_http(&file_path, &url, None)
        })
        .await
        .unwrap();

        assert!(result.is_err());
        let err_msg = format!("{}", result.unwrap_err());
        assert!(
            err_msg.contains("Failed to download file"),
            "Unexpected error message: {err_msg}"
        );
    }

    #[tokio::test]
    async fn test_network_error() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("testfile.txt");

        // Use an unreachable address to simulate a network error
        let url = "http://127.0.0.1:0/testfile.txt".to_string();

        let result = tokio::task::spawn_blocking(move || {
            ensure_file_exists_or_download_http(&file_path, &url, None)
        })
        .await
        .unwrap();

        assert!(result.is_err());
        let err_msg = format!("{}", result.unwrap_err());
        assert!(
            err_msg.contains("error"),
            "Unexpected error message: {err_msg}"
        );
    }

    #[rstest]
    fn test_existing_file_refreshes_stale_checksum() -> anyhow::Result<()> {
        let temp_dir = TempDir::new()?;
        let file_path = temp_dir.path().join("testfile.txt");
        fs::write(&file_path, "Existing file content")?;

        // Record a digest that cannot match the file on disk.
        let checksums_path = temp_dir.path().join("checksums.json");
        let stale = json!({ "testfile.txt": format!("sha256:{}", "0".repeat(64)) });
        fs::write(&checksums_path, stale.to_string())?;

        // The URL is never contacted because the file is already present.
        ensure_file_exists_or_download_http(
            &file_path,
            "http://example.com/testfile.txt",
            Some(&checksums_path),
        )?;

        assert!(
            verify_sha256_checksum(&file_path, &checksums_path)?,
            "The stale checksum should have been replaced"
        );
        assert_eq!(fs::read_to_string(&file_path)?, "Existing file content");
        Ok(())
    }

    #[rstest]
    fn test_calculate_sha256() -> anyhow::Result<()> {
        let temp_dir = TempDir::new()?;
        let test_file_path = temp_dir.path().join("test_file.txt");
        let mut test_file = File::create(&test_file_path)?;
        let content = b"Hello, world!";
        test_file.write_all(content)?;

        let expected_hash = "315f5bdb76d078c43b8ac0064e4a0164612b1fce77c869345bfc94c75894edd3";
        let calculated_hash = calculate_sha256(&test_file_path)?;

        assert_eq!(calculated_hash, expected_hash);
        Ok(())
    }

    #[rstest]
    fn test_verify_sha256_checksum() -> anyhow::Result<()> {
        let temp_dir = TempDir::new()?;
        let test_file_path = temp_dir.path().join("test_file.txt");
        let mut test_file = File::create(&test_file_path)?;
        let content = b"Hello, world!";
        test_file.write_all(content)?;

        let calculated_checksum = calculate_sha256(&test_file_path)?;

        // Create checksums.json containing the checksum
        let checksums_path = temp_dir.path().join("checksums.json");
        let checksums_data = json!({
            "test_file.txt": format!("sha256:{}", calculated_checksum)
        });
        let checksums_file = File::create(&checksums_path)?;
        let writer = BufWriter::new(checksums_file);
        to_writer(writer, &checksums_data)?;

        let is_valid = verify_sha256_checksum(&test_file_path, &checksums_path)?;
        assert!(is_valid, "The checksum should be valid");
        Ok(())
    }

    #[rstest]
    fn test_verify_sha256_checksum_mismatch() -> anyhow::Result<()> {
        let temp_dir = TempDir::new()?;
        let test_file_path = temp_dir.path().join("test_file.txt");
        fs::write(&test_file_path, b"Hello, world!")?;

        let checksums_path = temp_dir.path().join("checksums.json");
        let checksums_data = json!({
            "test_file.txt": format!("sha256:{}", "0".repeat(64))
        });
        fs::write(&checksums_path, checksums_data.to_string())?;

        let is_valid = verify_sha256_checksum(&test_file_path, &checksums_path)?;
        assert!(!is_valid, "A differing checksum should not be valid");
        Ok(())
    }

    #[rstest]
    fn test_update_sha256_checksums_creates_and_preserves_entries() -> anyhow::Result<()> {
        let temp_dir = TempDir::new()?;
        let checksums_path = temp_dir.path().join("checksums.json");

        // First call creates the file, second call must keep the first entry.
        update_sha256_checksums(Path::new("first.txt"), &checksums_path, "aaaa")?;
        update_sha256_checksums(Path::new("second.txt"), &checksums_path, "bbbb")?;

        let stored: Value = serde_json::from_str(&fs::read_to_string(&checksums_path)?)?;
        assert_eq!(
            stored,
            json!({ "first.txt": "sha256:aaaa", "second.txt": "sha256:bbbb" })
        );
        Ok(())
    }
}