use std::{
    cell::RefCell,
    collections::{HashMap, HashSet},
    rc::Rc,
    sync::Arc,
};

use datafusion::arrow::{
    datatypes::Schema, error::ArrowError, ipc::writer::StreamWriter, record_batch::RecordBatch,
};
use nautilus_common::clock::Clock;
use nautilus_core::UnixNanos;
use nautilus_serialization::arrow::{EncodeToRecordBatch, KEY_INSTRUMENT_ID};
use object_store::{ObjectStore, path::Path};

use super::catalog::CatalogPathPrefix;

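/// Key identifying a single output file: the object store path plus the data
/// type and, for per-instrument types, the instrument ID it holds.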
#[derive(Debug, Default, PartialEq, PartialOrd, Hash, Eq, Clone)]
pub struct FileWriterPath {
    path: Path,
    type_str: String,
    instrument_id: Option<String>,
}

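/// An in-memory Arrow IPC stream buffer for a single output file, tracking the
/// bytes written against the configured rotation policy.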
pub struct FeatherBuffer {
    /// Arrow IPC stream writer backed by an in-memory byte buffer.
    writer: StreamWriter<Vec<u8>>,
    /// Approximate number of bytes written so far.
    size: u64,
    /// Schema of the record batches written to this buffer.
    schema: Schema,
    /// Size threshold (in bytes) that triggers rotation.
    max_buffer_size: u64,
    /// Rotation policy for this buffer.
    rotation_config: RotationConfig,
}

impl FeatherBuffer {
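    /// Creates a new `FeatherBuffer` for the given schema and rotation configuration.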
    pub fn new(schema: &Schema, rotation_config: RotationConfig) -> Result<Self, ArrowError> {
        let writer = StreamWriter::try_new(Vec::new(), schema)?;

        // Effectively unbounded unless size-based rotation is configured
        let mut max_buffer_size = 1_000_000_000_000;
        if let RotationConfig::Size { max_size } = &rotation_config {
            max_buffer_size = *max_size;
        }

        Ok(Self {
            writer,
            size: 0,
            max_buffer_size,
            schema: schema.clone(),
            rotation_config,
        })
    }

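    /// Writes a record batch to the buffer and returns `true` if the buffer has
    /// reached its maximum size and should be rotated.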
    pub fn write_record_batch(&mut self, batch: &RecordBatch) -> Result<bool, ArrowError> {
        self.writer.write(batch)?;
        self.size += batch.get_array_memory_size() as u64;
        Ok(self.size >= self.max_buffer_size)
    }

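    /// Takes the encoded bytes out of the buffer, replacing the writer with a
    /// fresh one and resetting the size counter.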
    pub fn take_buffer(&mut self) -> Result<Vec<u8>, ArrowError> {
        let mut writer = StreamWriter::try_new(Vec::new(), &self.schema)?;
        std::mem::swap(&mut self.writer, &mut writer);
        let buffer = writer.into_inner()?;
        self.size = 0;

        Ok(buffer)
    }

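    /// Returns `true` if size-based rotation is configured and the buffer has
    /// reached `max_size`.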
    #[must_use]
    pub const fn should_rotate(&self) -> bool {
        match &self.rotation_config {
            RotationConfig::Size { max_size } => self.size >= *max_size,
            _ => false,
        }
    }
}

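/// Policy determining when a [`FeatherBuffer`] is flushed to the object store
/// and a new file is started.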
#[derive(Debug, Clone)]
pub enum RotationConfig {
    /// Rotate when the buffer reaches `max_size` bytes.
    Size {
        max_size: u64,
    },
    /// Rotate on a fixed interval of `interval_ns` nanoseconds.
    Interval {
        interval_ns: u64,
    },
    /// Rotate on a schedule defined by an interval and a start time.
    ScheduledDates {
        interval_ns: u64,
        schedule_ns: UnixNanos,
    },
    /// Never rotate; buffers are only persisted on flush.
    NoRotation,
}

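/// Writes data as Feather (Arrow IPC stream) files to an object store, keeping
/// one in-memory buffer per data type, or per instrument for configured types.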
pub struct FeatherWriter {
    /// Base path prefix under which all files are written.
    base_path: String,
    /// Object store backend receiving the encoded files.
    store: Arc<dyn ObjectStore>,
    /// Clock used to timestamp file names.
    clock: Rc<RefCell<dyn Clock>>,
    /// Rotation policy applied to every buffer.
    rotation_config: RotationConfig,
    /// If set, only these data types are written; `None` writes all types.
    included_types: Option<HashSet<String>>,
    /// Data types that get one file per instrument ID.
    per_instrument_types: HashSet<String>,
    /// Active buffers keyed by their output path.
    writers: HashMap<FileWriterPath, FeatherBuffer>,
}

impl FeatherWriter {
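    /// Creates a new `FeatherWriter` with the given store, clock, and configuration.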
    pub fn new(
        base_path: String,
        store: Arc<dyn ObjectStore>,
        clock: Rc<RefCell<dyn Clock>>,
        rotation_config: RotationConfig,
        included_types: Option<HashSet<String>>,
        per_instrument_types: Option<HashSet<String>>,
    ) -> Self {
        Self {
            base_path,
            store,
            clock,
            rotation_config,
            included_types,
            per_instrument_types: per_instrument_types.unwrap_or_default(),
            writers: HashMap::new(),
        }
    }

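    /// Encodes `data` into a record batch and appends it to the appropriate
    /// buffer, creating the buffer on first use and rotating it to the object
    /// store when the rotation threshold is reached.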
    pub async fn write<T>(&mut self, data: T) -> Result<(), Box<dyn std::error::Error>>
    where
        T: EncodeToRecordBatch + CatalogPathPrefix + 'static,
    {
        if !self.should_write::<T>() {
            return Ok(());
        }

        let path = self.get_writer_path(&data)?;

        if !self.writers.contains_key(&path) {
            self.create_writer::<T>(path.clone(), &data)?;
        }

        let batch = T::encode_batch(&T::metadata(&data), &[data])?;

        if let Some(writer) = self.writers.get_mut(&path) {
            let should_rotate = writer.write_record_batch(&batch)?;
            if should_rotate {
                self.rotate_writer(&path).await?;
            }
        }

        Ok(())
    }

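    /// Uploads the buffer at `path` to the object store and re-registers the
    /// writer under a freshly timestamped path.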
    async fn rotate_writer(
        &mut self,
        path: &FileWriterPath,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let mut writer = self.writers.remove(path).unwrap();
        let bytes = writer.take_buffer()?;
        self.store.put(&path.path, bytes.into()).await?;
        let new_path = self.regen_writer_path(path)?;
        self.writers.insert(new_path, writer);
        Ok(())
    }

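    /// Creates a new buffer for `path`, using a per-instrument schema when the
    /// data type is configured for per-instrument files.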
    fn create_writer<T>(&mut self, path: FileWriterPath, data: &T) -> Result<(), ArrowError>
    where
        T: EncodeToRecordBatch + CatalogPathPrefix + 'static,
    {
        let schema = if self.per_instrument_types.contains(T::path_prefix()) {
            let metadata = T::metadata(data);
            T::get_schema(Some(metadata))
        } else {
            T::get_schema(None)
        };

        let writer = FeatherBuffer::new(&schema, self.rotation_config.clone())?;
        self.writers.insert(path, writer);
        Ok(())
    }

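    /// Flushes every buffer to the object store, draining the writer map.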
    pub async fn flush(&mut self) -> Result<(), Box<dyn std::error::Error>> {
        for (path, mut writer) in self.writers.drain() {
            let bytes = writer.take_buffer()?;
            self.store.put(&path.path, bytes.into()).await?;
        }
        Ok(())
    }

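    /// Returns `true` if the data type `T` passes the inclusion filter (or no
    /// filter is configured).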
    fn should_write<T: CatalogPathPrefix>(&self) -> bool {
        self.included_types.as_ref().is_none_or(|included| {
            let path = T::path_prefix();
            included.contains(path)
        })
    }

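    /// Rebuilds a writer path with a new timestamp, preserving its data type
    /// and instrument ID.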
    fn regen_writer_path(
        &self,
        path: &FileWriterPath,
    ) -> Result<FileWriterPath, Box<dyn std::error::Error>> {
        let type_str = path.type_str.clone();
        let instrument_id = path.instrument_id.clone();
        let timestamp = self.clock.borrow().timestamp_ns();
        let mut path = Path::from(self.base_path.clone());

        if let Some(ref instrument_id) = instrument_id {
            path = path.child(type_str.clone());
            path = path.child(format!("{instrument_id}_{timestamp}.feather"));
        } else {
            path = path.child(format!("{type_str}_{timestamp}.feather"));
        }

        Ok(FileWriterPath {
            path,
            type_str,
            instrument_id,
        })
    }

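    /// Builds the timestamped output path for `data`, nesting per-instrument
    /// types under a directory named after the data type.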
    fn get_writer_path<T>(&self, data: &T) -> Result<FileWriterPath, Box<dyn std::error::Error>>
    where
        T: EncodeToRecordBatch + CatalogPathPrefix,
    {
        let type_str = T::path_prefix();
        let instrument_id = self.per_instrument_types.contains(type_str).then(|| {
            let metadata = T::metadata(data);
            metadata.get(KEY_INSTRUMENT_ID).cloned().unwrap_or_else(|| {
                panic!("Data {type_str} expected instrument_id metadata for per-instrument writer")
            })
        });

        let timestamp = self.clock.borrow().timestamp_ns();
        let mut path = Path::from(self.base_path.clone());
        if let Some(ref instrument_id) = instrument_id {
            path = path.child(type_str);
            path = path.child(format!("{instrument_id}_{timestamp}.feather"));
        } else {
            path = path.child(format!("{type_str}_{timestamp}.feather"));
        }

        Ok(FileWriterPath {
            path,
            type_str: type_str.to_string(),
            instrument_id,
        })
    }
}

#[cfg(test)]
mod tests {
    use std::{io::Cursor, sync::Arc};

    use datafusion::arrow::ipc::reader::StreamReader;
    use nautilus_common::clock::TestClock;
    use nautilus_model::{
        data::{Data, QuoteTick, TradeTick},
        enums::AggressorSide,
        identifiers::{InstrumentId, TradeId},
        types::{Price, Quantity},
    };
    use nautilus_serialization::arrow::{
        ArrowSchemaProvider, DecodeDataFromRecordBatch, EncodeToRecordBatch,
    };
    use object_store::{ObjectStore, local::LocalFileSystem};
    use rstest::rstest;
    use tempfile::TempDir;

    use super::*;

    #[tokio::test]
    async fn test_writer_manager_keys() {
        let temp_dir = TempDir::new().unwrap();
        let base_path = temp_dir.path().to_str().unwrap().to_string();

        let local_fs = LocalFileSystem::new_with_prefix(temp_dir.path()).unwrap();
        let store: Arc<dyn ObjectStore> = Arc::new(local_fs);

        let clock: Rc<RefCell<dyn Clock>> = Rc::new(RefCell::new(TestClock::new()));
        let timestamp = clock.borrow().timestamp_ns();

        let quote_type_str = QuoteTick::path_prefix();

        let mut per_instrument = HashSet::new();
        per_instrument.insert(quote_type_str.to_string());

        let mut manager = FeatherWriter::new(
            base_path.clone(),
            store,
            clock,
            RotationConfig::NoRotation,
            None,
            Some(per_instrument),
        );

        let instrument_id = "AAPL.AAPL";
        let quote = QuoteTick::new(
            InstrumentId::from(instrument_id),
            Price::from("100.0"),
            Price::from("100.0"),
            Quantity::from("100.0"),
            Quantity::from("100.0"),
            UnixNanos::from(1000000000000000000),
            UnixNanos::from(1000000000000000000),
        );

        let trade = TradeTick::new(
            InstrumentId::from(instrument_id),
            Price::from("100.0"),
            Quantity::from("100.0"),
            AggressorSide::Buyer,
            TradeId::from("1"),
            UnixNanos::from(1000000000000000000),
            UnixNanos::from(1000000000000000000),
        );

        manager.write(quote).await.unwrap();
        manager.write(trade).await.unwrap();

        // Quotes are configured per instrument, so the path nests under the type directory
        let path = manager.get_writer_path(&quote).unwrap();
        let expected_path = Path::from(format!(
            "{base_path}/quotes/{instrument_id}_{timestamp}.feather"
        ));
        assert_eq!(path.path, expected_path);
        assert!(manager.writers.contains_key(&path));
        let writer = manager.writers.get(&path).unwrap();
        assert!(writer.size > 0);

        // Trades are not per instrument, so the file sits directly under the base path
        let path = manager.get_writer_path(&trade).unwrap();
        let expected_path = Path::from(format!("{base_path}/trades_{timestamp}.feather"));
        assert_eq!(path.path, expected_path);
        assert!(manager.writers.contains_key(&path));
        let writer = manager.writers.get(&path).unwrap();
        assert!(writer.size > 0);
    }

    #[rstest]
    fn test_file_writer_round_trip() {
        let instrument_id = "AAPL.AAPL";
        let quote = QuoteTick::new(
            InstrumentId::from(instrument_id),
            Price::from("100.0"),
            Price::from("100.0"),
            Quantity::from("100.0"),
            Quantity::from("100.0"),
            UnixNanos::from(100),
            UnixNanos::from(100),
        );
        let metadata = QuoteTick::metadata(&quote);
        let schema = QuoteTick::get_schema(Some(metadata.clone()));
        let batch = QuoteTick::encode_batch(&QuoteTick::metadata(&quote), &[quote]).unwrap();

        let mut writer = FeatherBuffer::new(&schema, RotationConfig::NoRotation).unwrap();
        writer.write_record_batch(&batch).unwrap();

        let buffer = writer.take_buffer().unwrap();
        let mut reader = StreamReader::try_new(Cursor::new(buffer.as_slice()), None).unwrap();

        let read_metadata = reader.schema().metadata().clone();
        assert_eq!(read_metadata, metadata);

        let read_batch = reader.next().unwrap().unwrap();
        assert_eq!(read_batch.column(0), batch.column(0));

        let decoded = QuoteTick::decode_data_batch(&metadata, batch).unwrap();
        assert_eq!(decoded[0], Data::from(quote));
    }

    #[tokio::test]
    async fn test_round_trip() {
        let temp_dir = TempDir::new_in(".").unwrap();
        let base_path = temp_dir.path().to_str().unwrap().to_string();

        let local_fs = LocalFileSystem::new_with_prefix(&base_path).unwrap();
        let store: Arc<dyn ObjectStore> = Arc::new(local_fs);

        let clock: Rc<RefCell<dyn Clock>> = Rc::new(RefCell::new(TestClock::new()));

        let quote_type_str = QuoteTick::path_prefix();
        let trade_type_str = TradeTick::path_prefix();

        let mut per_instrument = HashSet::new();
        per_instrument.insert(quote_type_str.to_string());
        per_instrument.insert(trade_type_str.to_string());

        let mut manager = FeatherWriter::new(
            base_path.clone(),
            store,
            clock,
            RotationConfig::NoRotation,
            None,
            Some(per_instrument),
        );

        let instrument_id = "AAPL.AAPL";
        let quote = QuoteTick::new(
            InstrumentId::from(instrument_id),
            Price::from("100.0"),
            Price::from("100.0"),
            Quantity::from("100.0"),
            Quantity::from("100.0"),
            UnixNanos::from(100),
            UnixNanos::from(100),
        );

        let trade = TradeTick::new(
            InstrumentId::from(instrument_id),
            Price::from("100.0"),
            Quantity::from("100.0"),
            AggressorSide::Buyer,
            TradeId::from("1"),
            UnixNanos::from(100),
            UnixNanos::from(100),
        );

        manager.write(quote).await.unwrap();
        manager.write(trade).await.unwrap();

        let paths = manager.writers.keys().cloned().collect::<Vec<_>>();
        assert_eq!(paths.len(), 2);

        manager.flush().await.unwrap();

        let mut recovered_quotes = Vec::new();
        let mut recovered_trades = Vec::new();
        let local_fs = LocalFileSystem::new_with_prefix(&base_path).unwrap();
        for path in paths {
            let path_str = local_fs.path_to_filesystem(&path.path).unwrap();
            let buffer = std::fs::File::open(&path_str).unwrap();
            let reader = StreamReader::try_new(buffer, None).unwrap();
            let metadata = reader.schema().metadata().clone();
            for batch in reader {
                let batch = batch.unwrap();
                if path_str.to_str().unwrap().contains("quotes") {
                    let decoded = QuoteTick::decode_data_batch(&metadata, batch).unwrap();
                    recovered_quotes.extend(decoded);
                } else if path_str.to_str().unwrap().contains("trades") {
                    let decoded = TradeTick::decode_data_batch(&metadata, batch).unwrap();
                    recovered_trades.extend(decoded);
                }
            }
        }

        assert_eq!(recovered_quotes.len(), 1, "Expected one QuoteTick record");
        assert_eq!(recovered_trades.len(), 1, "Expected one TradeTick record");

        assert_eq!(recovered_quotes[0], Data::from(quote));
        assert_eq!(recovered_trades[0], Data::from(trade));
    }
}