nautilus_serialization/arrow/
mark_price.rs1use std::{collections::HashMap, str::FromStr, sync::Arc};
17
18use arrow::{
19 array::{FixedSizeBinaryArray, FixedSizeBinaryBuilder, UInt64Array},
20 datatypes::{DataType, Field, Schema},
21 error::ArrowError,
22 record_batch::RecordBatch,
23};
24use nautilus_model::{
25 data::prices::MarkPriceUpdate,
26 identifiers::InstrumentId,
27 types::{Price, fixed::PRECISION_BYTES},
28};
29
30use super::{
31 DecodeDataFromRecordBatch, EncodingError, KEY_INSTRUMENT_ID, KEY_PRICE_PRECISION,
32 extract_column,
33};
34use crate::arrow::{
35 ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch, get_raw_price,
36};
37
38impl ArrowSchemaProvider for MarkPriceUpdate {
39 fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
40 let fields = vec![
41 Field::new("value", DataType::FixedSizeBinary(PRECISION_BYTES), false),
42 Field::new("ts_event", DataType::UInt64, false),
43 Field::new("ts_init", DataType::UInt64, false),
44 ];
45
46 match metadata {
47 Some(metadata) => Schema::new_with_metadata(fields, metadata),
48 None => Schema::new(fields),
49 }
50 }
51}
52
53fn parse_metadata(metadata: &HashMap<String, String>) -> Result<(InstrumentId, u8), EncodingError> {
54 let instrument_id_str = metadata
55 .get(KEY_INSTRUMENT_ID)
56 .ok_or_else(|| EncodingError::MissingMetadata(KEY_INSTRUMENT_ID))?;
57 let instrument_id = InstrumentId::from_str(instrument_id_str)
58 .map_err(|e| EncodingError::ParseError(KEY_INSTRUMENT_ID, e.to_string()))?;
59
60 let price_precision = metadata
61 .get(KEY_PRICE_PRECISION)
62 .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
63 .parse::<u8>()
64 .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
65
66 Ok((instrument_id, price_precision))
67}
68
69impl EncodeToRecordBatch for MarkPriceUpdate {
70 fn encode_batch(
71 metadata: &HashMap<String, String>,
72 data: &[Self],
73 ) -> Result<RecordBatch, ArrowError> {
74 let mut value_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
75 let mut ts_event_builder = UInt64Array::builder(data.len());
76 let mut ts_init_builder = UInt64Array::builder(data.len());
77
78 for update in data {
79 value_builder
80 .append_value(update.value.raw.to_le_bytes())
81 .unwrap();
82 ts_event_builder.append_value(update.ts_event.as_u64());
83 ts_init_builder.append_value(update.ts_init.as_u64());
84 }
85
86 RecordBatch::try_new(
87 Self::get_schema(Some(metadata.clone())).into(),
88 vec![
89 Arc::new(value_builder.finish()),
90 Arc::new(ts_event_builder.finish()),
91 Arc::new(ts_init_builder.finish()),
92 ],
93 )
94 }
95
96 fn metadata(&self) -> HashMap<String, String> {
97 let mut metadata = HashMap::new();
98 metadata.insert(
99 KEY_INSTRUMENT_ID.to_string(),
100 self.instrument_id.to_string(),
101 );
102 metadata.insert(
103 KEY_PRICE_PRECISION.to_string(),
104 self.value.precision.to_string(),
105 );
106 metadata
107 }
108}
109
110impl DecodeFromRecordBatch for MarkPriceUpdate {
111 fn decode_batch(
112 metadata: &HashMap<String, String>,
113 record_batch: RecordBatch,
114 ) -> Result<Vec<Self>, EncodingError> {
115 let (instrument_id, price_precision) = parse_metadata(metadata)?;
116 let cols = record_batch.columns();
117
118 let value_values = extract_column::<FixedSizeBinaryArray>(
119 cols,
120 "value",
121 0,
122 DataType::FixedSizeBinary(PRECISION_BYTES),
123 )?;
124 let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 1, DataType::UInt64)?;
125 let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 2, DataType::UInt64)?;
126
127 assert_eq!(
128 value_values.value_length(),
129 PRECISION_BYTES,
130 "Price precision uses {PRECISION_BYTES} byte value"
131 );
132
133 let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
134 .map(|row| {
135 Ok(Self {
136 instrument_id,
137 value: Price::from_raw(get_raw_price(value_values.value(row)), price_precision),
138 ts_event: ts_event_values.value(row).into(),
139 ts_init: ts_init_values.value(row).into(),
140 })
141 })
142 .collect();
143
144 result
145 }
146}
147
148impl DecodeDataFromRecordBatch for MarkPriceUpdate {
149 fn decode_data_batch(
150 metadata: &HashMap<String, String>,
151 record_batch: RecordBatch,
152 ) -> Result<Vec<Data>, EncodingError> {
153 let updates: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
154 Ok(updates.into_iter().map(Data::from).collect())
155 }
156}
157
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{array::Array, record_batch::RecordBatch};
    use nautilus_model::types::price::PriceRaw;
    use rstest::rstest;
    use rust_decimal_macros::dec;

    use super::*;
    use crate::arrow::get_raw_price;

    /// Builds a valid single-row batch (schema without metadata) for error-path tests.
    fn single_row_batch() -> RecordBatch {
        RecordBatch::try_new(
            MarkPriceUpdate::get_schema(None).into(),
            vec![
                Arc::new(FixedSizeBinaryArray::from(vec![
                    &(5020000 as PriceRaw).to_le_bytes(),
                ])),
                Arc::new(UInt64Array::from(vec![1])),
                Arc::new(UInt64Array::from(vec![3])),
            ],
        )
        .unwrap()
    }

    #[rstest]
    fn test_get_schema() {
        let instrument_id = InstrumentId::from("BTC-USDT.BINANCE");
        let metadata = HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ]);
        let schema = MarkPriceUpdate::get_schema(Some(metadata.clone()));

        let expected_fields = vec![
            Field::new("value", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("ts_event", DataType::UInt64, false),
            Field::new("ts_init", DataType::UInt64, false),
        ];

        let expected_schema = Schema::new_with_metadata(expected_fields, metadata);
        assert_eq!(schema, expected_schema);
    }

    #[rstest]
    fn test_get_schema_map() {
        let schema_map = MarkPriceUpdate::get_schema_map();
        let mut expected_map = HashMap::new();

        let fixed_size_binary = format!("FixedSizeBinary({PRECISION_BYTES})");
        expected_map.insert("value".to_string(), fixed_size_binary);
        expected_map.insert("ts_event".to_string(), "UInt64".to_string());
        expected_map.insert("ts_init".to_string(), "UInt64".to_string());
        assert_eq!(schema_map, expected_map);
    }

    #[rstest]
    fn test_metadata() {
        let instrument_id = InstrumentId::from("BTC-USDT.BINANCE");
        let update = MarkPriceUpdate {
            instrument_id,
            value: Price::from("50200.00"),
            ts_event: 1.into(),
            ts_init: 3.into(),
        };

        let metadata = update.metadata();

        assert_eq!(metadata.len(), 2);
        assert_eq!(metadata[KEY_INSTRUMENT_ID], instrument_id.to_string());
        assert_eq!(metadata[KEY_PRICE_PRECISION], "2");
    }

    #[rstest]
    fn test_encode_batch() {
        let instrument_id = InstrumentId::from("BTC-USDT.BINANCE");
        let metadata = HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ]);

        let update1 = MarkPriceUpdate {
            instrument_id,
            value: Price::from("50200.00"),
            ts_event: 1.into(),
            ts_init: 3.into(),
        };

        let update2 = MarkPriceUpdate {
            instrument_id,
            value: Price::from("50300.00"),
            ts_event: 2.into(),
            ts_init: 4.into(),
        };

        let data = vec![update1, update2];
        let record_batch = MarkPriceUpdate::encode_batch(&metadata, &data).unwrap();

        let columns = record_batch.columns();
        let value_values = columns[0]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let ts_event_values = columns[1].as_any().downcast_ref::<UInt64Array>().unwrap();
        let ts_init_values = columns[2].as_any().downcast_ref::<UInt64Array>().unwrap();

        assert_eq!(columns.len(), 3);
        assert_eq!(value_values.len(), 2);
        assert_eq!(
            get_raw_price(value_values.value(0)),
            Price::from(dec!(50200.00).to_string()).raw
        );
        assert_eq!(
            get_raw_price(value_values.value(1)),
            Price::from(dec!(50300.00).to_string()).raw
        );
        assert_eq!(ts_event_values.len(), 2);
        assert_eq!(ts_event_values.value(0), 1);
        assert_eq!(ts_event_values.value(1), 2);
        assert_eq!(ts_init_values.len(), 2);
        assert_eq!(ts_init_values.value(0), 3);
        assert_eq!(ts_init_values.value(1), 4);
    }

    #[rstest]
    fn test_encode_batch_empty() {
        let metadata = HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), "BTC-USDT.BINANCE".to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ]);

        let record_batch = MarkPriceUpdate::encode_batch(&metadata, &[]).unwrap();

        assert_eq!(record_batch.num_columns(), 3);
        assert_eq!(record_batch.num_rows(), 0);
    }

    #[rstest]
    fn test_decode_batch() {
        let instrument_id = InstrumentId::from("BTC-USDT.BINANCE");
        let metadata = HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ]);

        let value = FixedSizeBinaryArray::from(vec![
            &(5020000 as PriceRaw).to_le_bytes(),
            &(5030000 as PriceRaw).to_le_bytes(),
        ]);
        let ts_event = UInt64Array::from(vec![1, 2]);
        let ts_init = UInt64Array::from(vec![3, 4]);

        let record_batch = RecordBatch::try_new(
            MarkPriceUpdate::get_schema(Some(metadata.clone())).into(),
            vec![Arc::new(value), Arc::new(ts_event), Arc::new(ts_init)],
        )
        .unwrap();

        let decoded_data = MarkPriceUpdate::decode_batch(&metadata, record_batch).unwrap();

        assert_eq!(decoded_data.len(), 2);
        assert_eq!(decoded_data[0].instrument_id, instrument_id);
        assert_eq!(decoded_data[0].value, Price::from_raw(5020000, 2));
        assert_eq!(decoded_data[0].ts_event.as_u64(), 1);
        assert_eq!(decoded_data[0].ts_init.as_u64(), 3);

        assert_eq!(decoded_data[1].instrument_id, instrument_id);
        assert_eq!(decoded_data[1].value, Price::from_raw(5030000, 2));
        assert_eq!(decoded_data[1].ts_event.as_u64(), 2);
        assert_eq!(decoded_data[1].ts_init.as_u64(), 4);
    }

    #[rstest]
    fn test_encode_decode_round_trip() {
        let instrument_id = InstrumentId::from("BTC-USDT.BINANCE");
        let updates = vec![
            MarkPriceUpdate {
                instrument_id,
                value: Price::from("50200.00"),
                ts_event: 1.into(),
                ts_init: 3.into(),
            },
            MarkPriceUpdate {
                instrument_id,
                value: Price::from("50300.00"),
                ts_event: 2.into(),
                ts_init: 4.into(),
            },
        ];
        let metadata = updates[0].metadata();

        let record_batch = MarkPriceUpdate::encode_batch(&metadata, &updates).unwrap();
        let decoded = MarkPriceUpdate::decode_batch(&metadata, record_batch).unwrap();

        assert_eq!(decoded.len(), updates.len());
        for (decoded, original) in decoded.iter().zip(&updates) {
            assert_eq!(decoded.instrument_id, original.instrument_id);
            assert_eq!(decoded.value, original.value);
            assert_eq!(decoded.ts_event, original.ts_event);
            assert_eq!(decoded.ts_init, original.ts_init);
        }
    }

    #[rstest]
    fn test_decode_batch_missing_instrument_id() {
        let metadata = HashMap::from([(KEY_PRICE_PRECISION.to_string(), "2".to_string())]);

        let result = MarkPriceUpdate::decode_batch(&metadata, single_row_batch());

        assert!(matches!(result, Err(EncodingError::MissingMetadata(_))));
    }

    #[rstest]
    fn test_decode_batch_missing_price_precision() {
        let metadata = HashMap::from([(
            KEY_INSTRUMENT_ID.to_string(),
            "BTC-USDT.BINANCE".to_string(),
        )]);

        let result = MarkPriceUpdate::decode_batch(&metadata, single_row_batch());

        assert!(matches!(result, Err(EncodingError::MissingMetadata(_))));
    }

    #[rstest]
    fn test_decode_batch_invalid_price_precision() {
        let metadata = HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), "BTC-USDT.BINANCE".to_string()),
            (KEY_PRICE_PRECISION.to_string(), "not-a-number".to_string()),
        ]);

        let result = MarkPriceUpdate::decode_batch(&metadata, single_row_batch());

        assert!(matches!(result, Err(EncodingError::ParseError(_, _))));
    }
}