1use std::{
19 collections::HashMap,
20 fmt::Display,
21 ops::{Deref, DerefMut},
22};
23
24use rust_decimal::{Decimal, prelude::ToPrimitive};
25use serde::{Deserialize, Serialize};
26
27use crate::{
28 accounts::{Account, base::BaseAccount},
29 enums::{AccountType, LiquiditySide, OrderSide},
30 events::{AccountState, OrderFilled},
31 identifiers::AccountId,
32 instruments::InstrumentAny,
33 position::Position,
34 types::{AccountBalance, Currency, Money, Price, Quantity},
35};
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
38#[cfg_attr(
39 feature = "python",
40 pyo3::pyclass(module = "posei_trader.core.nautilus_pyo3.model")
41)]
42pub struct CashAccount {
43 pub base: BaseAccount,
44}
45
46impl CashAccount {
47 pub fn new(event: AccountState, calculate_account_state: bool) -> Self {
49 Self {
50 base: BaseAccount::new(event, calculate_account_state),
51 }
52 }
53
54 #[must_use]
55 pub fn is_cash_account(&self) -> bool {
56 self.account_type == AccountType::Cash
57 }
58 #[must_use]
59 pub fn is_margin_account(&self) -> bool {
60 self.account_type == AccountType::Margin
61 }
62
63 #[must_use]
64 pub const fn is_unleveraged(&self) -> bool {
65 false
66 }
67
68 pub fn recalculate_balance(&mut self, currency: Currency) {
74 let current_balance = match self.balances.get(¤cy) {
75 Some(balance) => *balance,
76 None => {
77 return;
78 }
79 };
80
81 let total_locked = self
82 .balances
83 .values()
84 .filter(|balance| balance.currency == currency)
85 .fold(Decimal::ZERO, |acc, balance| {
86 acc + balance.locked.as_decimal()
87 });
88
89 let new_balance = AccountBalance::new(
90 current_balance.total,
91 Money::new(total_locked.to_f64().unwrap(), currency),
92 Money::new(
93 (current_balance.total.as_decimal() - total_locked)
94 .to_f64()
95 .unwrap(),
96 currency,
97 ),
98 );
99
100 self.balances.insert(currency, new_balance);
101 }
102}
103
104impl Account for CashAccount {
105 fn id(&self) -> AccountId {
106 self.id
107 }
108
109 fn account_type(&self) -> AccountType {
110 self.account_type
111 }
112
113 fn base_currency(&self) -> Option<Currency> {
114 self.base_currency
115 }
116
117 fn is_cash_account(&self) -> bool {
118 self.account_type == AccountType::Cash
119 }
120
121 fn is_margin_account(&self) -> bool {
122 self.account_type == AccountType::Margin
123 }
124
125 fn calculated_account_state(&self) -> bool {
126 false }
128
129 fn balance_total(&self, currency: Option<Currency>) -> Option<Money> {
130 self.base_balance_total(currency)
131 }
132
133 fn balances_total(&self) -> HashMap<Currency, Money> {
134 self.base_balances_total()
135 }
136
137 fn balance_free(&self, currency: Option<Currency>) -> Option<Money> {
138 self.base_balance_free(currency)
139 }
140
141 fn balances_free(&self) -> HashMap<Currency, Money> {
142 self.base_balances_free()
143 }
144
145 fn balance_locked(&self, currency: Option<Currency>) -> Option<Money> {
146 self.base_balance_locked(currency)
147 }
148
149 fn balances_locked(&self) -> HashMap<Currency, Money> {
150 self.base_balances_locked()
151 }
152
153 fn balance(&self, currency: Option<Currency>) -> Option<&AccountBalance> {
154 self.base_balance(currency)
155 }
156
157 fn last_event(&self) -> Option<AccountState> {
158 self.base_last_event()
159 }
160
161 fn events(&self) -> Vec<AccountState> {
162 self.events.clone()
163 }
164
165 fn event_count(&self) -> usize {
166 self.events.len()
167 }
168
169 fn currencies(&self) -> Vec<Currency> {
170 self.balances.keys().copied().collect()
171 }
172
173 fn starting_balances(&self) -> HashMap<Currency, Money> {
174 self.balances_starting.clone()
175 }
176
177 fn balances(&self) -> HashMap<Currency, AccountBalance> {
178 self.balances.clone()
179 }
180
181 fn apply(&mut self, event: AccountState) {
182 self.base_apply(event);
183 }
184
185 fn purge_account_events(&mut self, ts_now: nautilus_core::UnixNanos, lookback_secs: u64) {
186 self.base.base_purge_account_events(ts_now, lookback_secs);
187 }
188
189 fn calculate_balance_locked(
190 &mut self,
191 instrument: InstrumentAny,
192 side: OrderSide,
193 quantity: Quantity,
194 price: Price,
195 use_quote_for_inverse: Option<bool>,
196 ) -> anyhow::Result<Money> {
197 self.base_calculate_balance_locked(instrument, side, quantity, price, use_quote_for_inverse)
198 }
199
200 fn calculate_pnls(
201 &self,
202 instrument: InstrumentAny, fill: OrderFilled, position: Option<Position>,
205 ) -> anyhow::Result<Vec<Money>> {
206 self.base_calculate_pnls(instrument, fill, position)
207 }
208
209 fn calculate_commission(
210 &self,
211 instrument: InstrumentAny,
212 last_qty: Quantity,
213 last_px: Price,
214 liquidity_side: LiquiditySide,
215 use_quote_for_inverse: Option<bool>,
216 ) -> anyhow::Result<Money> {
217 self.base_calculate_commission(
218 instrument,
219 last_qty,
220 last_px,
221 liquidity_side,
222 use_quote_for_inverse,
223 )
224 }
225}
226
227impl Deref for CashAccount {
228 type Target = BaseAccount;
229
230 fn deref(&self) -> &Self::Target {
231 &self.base
232 }
233}
234
235impl DerefMut for CashAccount {
236 fn deref_mut(&mut self) -> &mut Self::Target {
237 &mut self.base
238 }
239}
240
241impl PartialEq for CashAccount {
242 fn eq(&self, other: &Self) -> bool {
243 self.id == other.id
244 }
245}
246
247impl Eq for CashAccount {}
248
249impl Display for CashAccount {
250 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
251 write!(
252 f,
253 "CashAccount(id={}, type={}, base={})",
254 self.id,
255 self.account_type,
256 self.base_currency.map_or_else(
257 || "None".to_string(),
258 |base_currency| format!("{}", base_currency.code)
259 ),
260 )
261 }
262}
263
#[cfg(test)]
mod tests {
    use std::collections::{HashMap, HashSet};

    use rstest::rstest;

    use crate::{
        accounts::{Account, CashAccount, stubs::*},
        enums::{AccountType, LiquiditySide, OrderSide, OrderType},
        events::{AccountState, account::stubs::*},
        identifiers::{AccountId, position_id::PositionId},
        instruments::{CryptoPerpetual, CurrencyPair, Equity, Instrument, InstrumentAny, stubs::*},
        orders::{builder::OrderTestBuilder, stubs::TestOrderEventStubs},
        position::Position,
        types::{Currency, Money, Price, Quantity},
    };

    #[rstest]
    fn test_display(cash_account: CashAccount) {
        assert_eq!(
            format!("{cash_account}"),
            "CashAccount(id=SIM-001, type=CASH, base=USD)"
        );
    }

    #[rstest]
    fn test_instantiate_single_asset_cash_account(
        cash_account: CashAccount,
        cash_account_state: AccountState,
    ) {
        let usd = Currency::from("USD");
        assert_eq!(cash_account.id, AccountId::from("SIM-001"));
        assert_eq!(cash_account.account_type, AccountType::Cash);
        assert_eq!(cash_account.base_currency, Some(usd));
        assert_eq!(cash_account.last_event(), Some(cash_account_state.clone()));
        assert_eq!(cash_account.events(), vec![cash_account_state]);
        assert_eq!(cash_account.event_count(), 1);
        assert_eq!(
            cash_account.balance_total(None),
            Some(Money::from("1525000 USD"))
        );
        assert_eq!(
            cash_account.balance_free(None),
            Some(Money::from("1500000 USD"))
        );
        assert_eq!(
            cash_account.balance_locked(None),
            Some(Money::from("25000 USD"))
        );
        assert_eq!(
            cash_account.balances_total(),
            HashMap::from([(usd, Money::from("1525000 USD"))])
        );
        assert_eq!(
            cash_account.balances_free(),
            HashMap::from([(usd, Money::from("1500000 USD"))])
        );
        assert_eq!(
            cash_account.balances_locked(),
            HashMap::from([(usd, Money::from("25000 USD"))])
        );
    }

    #[rstest]
    fn test_instantiate_multi_asset_cash_account(
        cash_account_multi: CashAccount,
        cash_account_state_multi: AccountState,
    ) {
        let btc = Currency::from("BTC");
        let eth = Currency::from("ETH");
        assert_eq!(cash_account_multi.id, AccountId::from("SIM-001"));
        assert_eq!(cash_account_multi.account_type, AccountType::Cash);
        assert_eq!(
            cash_account_multi.last_event(),
            Some(cash_account_state_multi.clone())
        );
        // Check the account itself (not just the fixture state) has no base currency.
        assert_eq!(cash_account_multi.base_currency, None);
        assert_eq!(cash_account_multi.events(), vec![cash_account_state_multi]);
        assert_eq!(cash_account_multi.event_count(), 1);
        assert_eq!(
            cash_account_multi.balance_total(Some(Currency::BTC())),
            Some(Money::from("10 BTC"))
        );
        assert_eq!(
            cash_account_multi.balance_total(Some(Currency::ETH())),
            Some(Money::from("20 ETH"))
        );
        assert_eq!(
            cash_account_multi.balance_free(Some(Currency::BTC())),
            Some(Money::from("10 BTC"))
        );
        assert_eq!(
            cash_account_multi.balance_free(Some(Currency::ETH())),
            Some(Money::from("20 ETH"))
        );
        assert_eq!(
            cash_account_multi.balance_locked(Some(Currency::BTC())),
            Some(Money::from("0 BTC"))
        );
        assert_eq!(
            cash_account_multi.balance_locked(Some(Currency::ETH())),
            Some(Money::from("0 ETH"))
        );
        assert_eq!(
            cash_account_multi.balances_total(),
            HashMap::from([
                (btc, Money::from("10 BTC")),
                (eth, Money::from("20 ETH")),
            ])
        );
        assert_eq!(
            cash_account_multi.balances_free(),
            HashMap::from([
                (btc, Money::from("10 BTC")),
                (eth, Money::from("20 ETH")),
            ])
        );
        assert_eq!(
            cash_account_multi.balances_locked(),
            HashMap::from([(btc, Money::from("0 BTC")), (eth, Money::from("0 ETH"))])
        );
    }

    #[rstest]
    fn test_apply_given_new_state_event_updates_correctly(
        mut cash_account_multi: CashAccount,
        cash_account_state_multi: AccountState,
        cash_account_state_multi_changed_btc: AccountState,
    ) {
        cash_account_multi.apply(cash_account_state_multi_changed_btc.clone());

        assert_eq!(
            cash_account_multi.last_event(),
            Some(cash_account_state_multi_changed_btc.clone())
        );
        assert_eq!(
            cash_account_multi.events,
            vec![
                cash_account_state_multi,
                cash_account_state_multi_changed_btc
            ]
        );
        assert_eq!(cash_account_multi.event_count(), 2);
        assert_eq!(
            cash_account_multi.balance_total(Some(Currency::BTC())),
            Some(Money::from("9 BTC"))
        );
        assert_eq!(
            cash_account_multi.balance_free(Some(Currency::BTC())),
            Some(Money::from("8.5 BTC"))
        );
        assert_eq!(
            cash_account_multi.balance_locked(Some(Currency::BTC())),
            Some(Money::from("0.5 BTC"))
        );
        assert_eq!(
            cash_account_multi.balance_total(Some(Currency::ETH())),
            Some(Money::from("20 ETH"))
        );
        assert_eq!(
            cash_account_multi.balance_free(Some(Currency::ETH())),
            Some(Money::from("20 ETH"))
        );
        assert_eq!(
            cash_account_multi.balance_locked(Some(Currency::ETH())),
            Some(Money::from("0 ETH"))
        );
    }

    #[rstest]
    fn test_recalculate_balance_sets_free_to_total_minus_locked(mut cash_account: CashAccount) {
        let usd = Currency::from("USD");
        cash_account.recalculate_balance(usd);
        assert_eq!(
            cash_account.balance_total(Some(usd)),
            Some(Money::from("1525000 USD"))
        );
        assert_eq!(
            cash_account.balance_locked(Some(usd)),
            Some(Money::from("25000 USD"))
        );
        assert_eq!(
            cash_account.balance_free(Some(usd)),
            Some(Money::from("1500000 USD"))
        );
    }

    #[rstest]
    fn test_recalculate_balance_unknown_currency_is_noop(mut cash_account: CashAccount) {
        let totals_before = cash_account.balances_total();
        let locked_before = cash_account.balances_locked();
        cash_account.recalculate_balance(Currency::from("EUR"));
        assert_eq!(cash_account.balances_total(), totals_before);
        assert_eq!(cash_account.balances_locked(), locked_before);
    }

    #[rstest]
    fn test_calculate_balance_locked_buy(
        mut cash_account_million_usd: CashAccount,
        audusd_sim: CurrencyPair,
    ) {
        let balance_locked = cash_account_million_usd
            .calculate_balance_locked(
                audusd_sim.into_any(),
                OrderSide::Buy,
                Quantity::from("1000000"),
                Price::from("0.8"),
                None,
            )
            .unwrap();
        assert_eq!(balance_locked, Money::from("800000 USD"));
    }

    #[rstest]
    fn test_calculate_balance_locked_sell(
        mut cash_account_million_usd: CashAccount,
        audusd_sim: CurrencyPair,
    ) {
        let balance_locked = cash_account_million_usd
            .calculate_balance_locked(
                audusd_sim.into_any(),
                OrderSide::Sell,
                Quantity::from("1000000"),
                Price::from("0.8"),
                None,
            )
            .unwrap();
        assert_eq!(balance_locked, Money::from("1000000 AUD"));
    }

    #[rstest]
    fn test_calculate_balance_locked_sell_no_base_currency(
        mut cash_account_million_usd: CashAccount,
        equity_aapl: Equity,
    ) {
        let balance_locked = cash_account_million_usd
            .calculate_balance_locked(
                equity_aapl.into_any(),
                OrderSide::Sell,
                Quantity::from("100"),
                Price::from("1500.0"),
                None,
            )
            .unwrap();
        assert_eq!(balance_locked, Money::from("100 USD"));
    }

    #[rstest]
    fn test_calculate_pnls_for_single_currency_cash_account(
        cash_account_million_usd: CashAccount,
        audusd_sim: CurrencyPair,
    ) {
        let audusd_sim = InstrumentAny::CurrencyPair(audusd_sim);
        let order = OrderTestBuilder::new(OrderType::Market)
            .instrument_id(audusd_sim.id())
            .side(OrderSide::Buy)
            .quantity(Quantity::from("1000000"))
            .build();
        let fill = TestOrderEventStubs::filled(
            &order,
            &audusd_sim,
            None,
            Some(PositionId::new("P-123456")),
            Some(Price::from("0.8")),
            None,
            None,
            None,
            None,
            Some(AccountId::from("SIM-001")),
        );
        let position = Position::new(&audusd_sim, fill.clone().into());
        let pnls = cash_account_million_usd
            .calculate_pnls(audusd_sim, fill.into(), Some(position))
            .unwrap();
        assert_eq!(pnls, vec![Money::from("-800000 USD")]);
    }

    #[rstest]
    fn test_calculate_pnls_for_multi_currency_cash_account_btcusdt(
        cash_account_multi: CashAccount,
        currency_pair_btcusdt: CurrencyPair,
    ) {
        let btcusdt = InstrumentAny::CurrencyPair(currency_pair_btcusdt);
        let order1 = OrderTestBuilder::new(OrderType::Market)
            .instrument_id(currency_pair_btcusdt.id)
            .side(OrderSide::Sell)
            .quantity(Quantity::from("0.5"))
            .build();
        let fill1 = TestOrderEventStubs::filled(
            &order1,
            &btcusdt,
            None,
            Some(PositionId::new("P-123456")),
            Some(Price::from("45500.00")),
            None,
            None,
            None,
            None,
            Some(AccountId::from("SIM-001")),
        );
        let position = Position::new(&btcusdt, fill1.clone().into());
        let result1 = cash_account_multi
            .calculate_pnls(
                currency_pair_btcusdt.into_any(),
                fill1.into(),
                Some(position.clone()),
            )
            .unwrap();
        let order2 = OrderTestBuilder::new(OrderType::Market)
            .instrument_id(currency_pair_btcusdt.id)
            .side(OrderSide::Buy)
            .quantity(Quantity::from("0.5"))
            .build();
        let fill2 = TestOrderEventStubs::filled(
            &order2,
            &btcusdt,
            None,
            Some(PositionId::new("P-123456")),
            Some(Price::from("45500.00")),
            None,
            None,
            None,
            None,
            Some(AccountId::from("SIM-001")),
        );
        let result2 = cash_account_multi
            .calculate_pnls(
                currency_pair_btcusdt.into_any(),
                fill2.into(),
                Some(position),
            )
            .unwrap();

        // PnL order is not part of the contract, so compare as sets.
        let result1_set: HashSet<Money> = result1.into_iter().collect();
        let result1_expected =
            HashSet::from([Money::from("22750 USDT"), Money::from("-0.5 BTC")]);
        let result2_set: HashSet<Money> = result2.into_iter().collect();
        let result2_expected =
            HashSet::from([Money::from("-22750 USDT"), Money::from("0.5 BTC")]);
        assert_eq!(result1_set, result1_expected);
        assert_eq!(result2_set, result2_expected);
    }

    #[rstest]
    #[case(false, Money::from("-0.00218331 BTC"))]
    #[case(true, Money::from("-25.0 USD"))]
    fn test_calculate_commission_for_inverse_maker_crypto(
        #[case] use_quote_for_inverse: bool,
        #[case] expected: Money,
        cash_account_million_usd: CashAccount,
        xbtusd_bitmex: CryptoPerpetual,
    ) {
        let result = cash_account_million_usd
            .calculate_commission(
                xbtusd_bitmex.into_any(),
                Quantity::from("100000"),
                Price::from("11450.50"),
                LiquiditySide::Maker,
                Some(use_quote_for_inverse),
            )
            .unwrap();
        assert_eq!(result, expected);
    }

    #[rstest]
    fn test_calculate_commission_for_taker_fx(
        cash_account_million_usd: CashAccount,
        audusd_sim: CurrencyPair,
    ) {
        let result = cash_account_million_usd
            .calculate_commission(
                audusd_sim.into_any(),
                Quantity::from("1500000"),
                Price::from("0.8005"),
                LiquiditySide::Taker,
                None,
            )
            .unwrap();
        assert_eq!(result, Money::from("24.02 USD"));
    }

    #[rstest]
    fn test_calculate_commission_crypto_taker(
        cash_account_million_usd: CashAccount,
        xbtusd_bitmex: CryptoPerpetual,
    ) {
        let result = cash_account_million_usd
            .calculate_commission(
                xbtusd_bitmex.into_any(),
                Quantity::from("100000"),
                Price::from("11450.50"),
                LiquiditySide::Taker,
                None,
            )
            .unwrap();
        assert_eq!(result, Money::from("0.00654993 BTC"));
    }

    #[rstest]
    fn test_calculate_commission_fx_taker(cash_account_million_usd: CashAccount) {
        let instrument = usdjpy_idealpro();
        let result = cash_account_million_usd
            .calculate_commission(
                instrument.into_any(),
                Quantity::from("2200000"),
                Price::from("120.310"),
                LiquiditySide::Taker,
                None,
            )
            .unwrap();
        assert_eq!(result, Money::from("5294 JPY"));
    }
}
647}