1use std::{collections::HashMap, hash::Hash, str::FromStr, sync::Arc, time::Duration};
19
20use bytes::Bytes;
21use http::{HeaderValue, StatusCode, status::InvalidStatusCode};
22use reqwest::{
23 Method, Response, Url,
24 header::{HeaderMap, HeaderName},
25};
26
27use crate::ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota};
28
29#[derive(Clone, Debug)]
34#[cfg_attr(
35 feature = "python",
36 pyo3::pyclass(module = "posei_trader.core.nautilus_pyo3.network")
37)]
38pub struct HttpStatus {
39 inner: StatusCode,
40}
41
42impl HttpStatus {
43 #[must_use]
45 pub const fn new(code: StatusCode) -> Self {
46 Self { inner: code }
47 }
48
49 pub fn from(code: u16) -> Result<Self, InvalidStatusCode> {
55 Ok(Self {
56 inner: StatusCode::from_u16(code)?,
57 })
58 }
59
60 #[inline]
62 #[must_use]
63 pub const fn as_u16(&self) -> u16 {
64 self.inner.as_u16()
65 }
66
67 #[inline]
69 #[must_use]
70 pub fn as_str(&self) -> &str {
71 self.inner.as_str()
72 }
73
74 #[inline]
76 #[must_use]
77 pub fn is_informational(&self) -> bool {
78 self.inner.is_informational()
79 }
80
81 #[inline]
83 #[must_use]
84 pub fn is_success(&self) -> bool {
85 self.inner.is_success()
86 }
87
88 #[inline]
90 #[must_use]
91 pub fn is_redirection(&self) -> bool {
92 self.inner.is_redirection()
93 }
94
95 #[inline]
97 #[must_use]
98 pub fn is_client_error(&self) -> bool {
99 self.inner.is_client_error()
100 }
101
102 #[inline]
104 #[must_use]
105 pub fn is_server_error(&self) -> bool {
106 self.inner.is_server_error()
107 }
108}
109
110#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
112#[cfg_attr(
113 feature = "python",
114 pyo3::pyclass(eq, eq_int, module = "posei_trader.core.nautilus_pyo3.network")
115)]
116pub enum HttpMethod {
117 GET,
118 POST,
119 PUT,
120 DELETE,
121 PATCH,
122}
123
124#[allow(clippy::from_over_into)]
125impl Into<Method> for HttpMethod {
126 fn into(self) -> Method {
127 match self {
128 Self::GET => Method::GET,
129 Self::POST => Method::POST,
130 Self::PUT => Method::PUT,
131 Self::DELETE => Method::DELETE,
132 Self::PATCH => Method::PATCH,
133 }
134 }
135}
136
/// The outcome of an HTTP request: status, selected headers, and raw body.
#[derive(Clone, Debug)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "posei_trader.core.nautilus_pyo3.network")
)]
pub struct HttpResponse {
    /// The HTTP status returned by the server.
    pub status: HttpStatus,
    /// Response headers. When built by `InnerHttpClient::to_response`, this holds only
    /// the client's configured `header_keys` whose values could be converted to `&str`.
    pub headers: HashMap<String, String>,
    /// The raw (undecoded) response body bytes.
    pub body: Bytes,
}
154
155#[derive(thiserror::Error, Debug)]
159pub enum HttpClientError {
160 #[error("HTTP error occurred: {0}")]
161 Error(String),
162
163 #[error("HTTP request timed out: {0}")]
164 TimeoutError(String),
165}
166
167impl From<reqwest::Error> for HttpClientError {
168 fn from(source: reqwest::Error) -> Self {
169 if source.is_timeout() {
170 Self::TimeoutError(source.to_string())
171 } else {
172 Self::Error(source.to_string())
173 }
174 }
175}
176
177impl From<String> for HttpClientError {
178 fn from(value: String) -> Self {
179 Self::Error(value)
180 }
181}
182
183#[derive(Clone, Debug)]
193#[cfg_attr(
194 feature = "python",
195 pyo3::pyclass(module = "posei_trader.core.nautilus_pyo3.network")
196)]
197pub struct HttpClient {
198 pub(crate) client: InnerHttpClient,
200 pub(crate) rate_limiter: Arc<RateLimiter<String, MonotonicClock>>,
202}
203
204impl HttpClient {
205 #[must_use]
211 pub fn new(
212 headers: HashMap<String, String>,
213 header_keys: Vec<String>,
214 keyed_quotas: Vec<(String, Quota)>,
215 default_quota: Option<Quota>,
216 timeout_secs: Option<u64>,
217 ) -> Self {
218 let mut header_map = HeaderMap::new();
220 for (key, value) in headers {
221 let header_name = HeaderName::from_str(&key).expect("Invalid header name");
222 let header_value = HeaderValue::from_str(&value).expect("Invalid header value");
223 header_map.insert(header_name, header_value);
224 }
225
226 let mut client_builder = reqwest::Client::builder().default_headers(header_map);
227 if let Some(timeout_secs) = timeout_secs {
228 client_builder = client_builder.timeout(Duration::from_secs(timeout_secs));
229 }
230
231 let client = client_builder
232 .build()
233 .expect("Failed to build reqwest client");
234
235 let client = InnerHttpClient {
236 client,
237 header_keys: Arc::new(header_keys),
238 };
239
240 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
241
242 Self {
243 client,
244 rate_limiter,
245 }
246 }
247
248 #[allow(clippy::too_many_arguments)]
265 pub async fn request(
266 &self,
267 method: Method,
268 url: String,
269 headers: Option<HashMap<String, String>>,
270 body: Option<Vec<u8>>,
271 timeout_secs: Option<u64>,
272 keys: Option<Vec<String>>,
273 ) -> Result<HttpResponse, HttpClientError> {
274 let rate_limiter = self.rate_limiter.clone();
275 rate_limiter.await_keys_ready(keys).await;
276
277 self.client
278 .send_request(method, url, headers, body, timeout_secs)
279 .await
280 }
281}
282
283#[derive(Clone, Debug)]
292pub struct InnerHttpClient {
293 pub(crate) client: reqwest::Client,
294 pub(crate) header_keys: Arc<Vec<String>>,
295}
296
297impl InnerHttpClient {
298 pub async fn send_request(
310 &self,
311 method: Method,
312 url: String,
313 headers: Option<HashMap<String, String>>,
314 body: Option<Vec<u8>>,
315 timeout_secs: Option<u64>,
316 ) -> Result<HttpResponse, HttpClientError> {
317 let headers = headers.unwrap_or_default();
318 let reqwest_url = Url::parse(url.as_str())
319 .map_err(|e| HttpClientError::from(format!("URL parse error: {e}")))?;
320
321 let mut header_map = HeaderMap::new();
322 for (header_key, header_value) in &headers {
323 let key = HeaderName::from_bytes(header_key.as_bytes())
324 .map_err(|e| HttpClientError::from(format!("Invalid header name: {e}")))?;
325 let _ = header_map.insert(
326 key,
327 header_value
328 .parse()
329 .map_err(|e| HttpClientError::from(format!("Invalid header value: {e}")))?,
330 );
331 }
332
333 let mut request_builder = self.client.request(method, reqwest_url).headers(header_map);
334
335 if let Some(timeout_secs) = timeout_secs {
336 request_builder = request_builder.timeout(Duration::new(timeout_secs, 0));
337 }
338
339 let request = match body {
340 Some(b) => request_builder
341 .body(b)
342 .build()
343 .map_err(HttpClientError::from)?,
344 None => request_builder.build().map_err(HttpClientError::from)?,
345 };
346
347 tracing::trace!("{request:?}");
348
349 let response = self
350 .client
351 .execute(request)
352 .await
353 .map_err(HttpClientError::from)?;
354
355 self.to_response(response).await
356 }
357
358 pub async fn to_response(&self, response: Response) -> Result<HttpResponse, HttpClientError> {
364 tracing::trace!("{response:?}");
365
366 let headers: HashMap<String, String> = self
367 .header_keys
368 .iter()
369 .filter_map(|key| response.headers().get(key).map(|val| (key, val)))
370 .filter_map(|(key, val)| val.to_str().map(|v| (key, v)).ok())
371 .map(|(k, v)| (k.clone(), v.to_owned()))
372 .collect();
373 let status = HttpStatus::new(response.status());
374 let body = response.bytes().await.map_err(HttpClientError::from)?;
375
376 Ok(HttpResponse {
377 status,
378 headers,
379 body,
380 })
381 }
382}
383
384impl Default for InnerHttpClient {
385 fn default() -> Self {
389 let client = reqwest::Client::new();
390 Self {
391 client,
392 header_keys: Default::default(),
393 }
394 }
395}
396
#[cfg(test)]
#[cfg(target_os = "linux")]
mod tests {
    use std::net::SocketAddr;

    use axum::{
        Router,
        routing::{delete, get, patch, post},
        serve,
    };
    use http::status::StatusCode;

    use super::*;

    /// Builds the router served by the local test server.
    fn create_router() -> Router {
        Router::new()
            .route("/get", get(|| async { "hello-world!" }))
            .route("/post", post(|| async { StatusCode::OK }))
            .route("/patch", patch(|| async { StatusCode::OK }))
            .route("/delete", delete(|| async { StatusCode::OK }))
            .route("/notfound", get(|| async { StatusCode::NOT_FOUND }))
            .route(
                "/slow",
                get(|| async {
                    tokio::time::sleep(Duration::from_secs(2)).await;
                    "Eventually responded"
                }),
            )
    }

    /// Starts the test server in a background task and returns its bound address.
    ///
    /// The listener binds port 0 directly, so the OS assigns a free port atomically.
    /// The previous approach probed a port with a throwaway listener and re-bound it
    /// later, which is racy when tests run in parallel.
    async fn start_test_server() -> Result<SocketAddr, Box<dyn std::error::Error + Send + Sync>> {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;

        tokio::spawn(async move {
            serve(listener, create_router()).await.unwrap();
        });

        Ok(addr)
    }

    #[tokio::test]
    async fn test_get() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(reqwest::Method::GET, format!("{url}/get"), None, None, None)
            .await
            .unwrap();

        assert!(response.status.is_success());
        assert_eq!(String::from_utf8_lossy(&response.body), "hello-world!");
    }

    #[tokio::test]
    async fn test_post() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::POST,
                format!("{url}/post"),
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_post_with_body() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();

        let mut body = HashMap::new();
        body.insert(
            "key1".to_string(),
            serde_json::Value::String("value1".to_string()),
        );
        body.insert(
            "key2".to_string(),
            serde_json::Value::String("value2".to_string()),
        );

        let body_string = serde_json::to_string(&body).unwrap();
        let body_bytes = body_string.into_bytes();

        let response = client
            .send_request(
                reqwest::Method::POST,
                format!("{url}/post"),
                None,
                Some(body_bytes),
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_patch() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::PATCH,
                format!("{url}/patch"),
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_delete() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::DELETE,
                format!("{url}/delete"),
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_not_found() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}/notfound");
        let client = InnerHttpClient::default();

        let response = client
            .send_request(reqwest::Method::GET, url, None, None, None)
            .await
            .unwrap();

        assert!(response.status.is_client_error());
        assert_eq!(response.status.as_u16(), 404);
    }

    #[tokio::test]
    async fn test_timeout() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}/slow");
        let client = InnerHttpClient::default();

        // The `/slow` handler sleeps for 2s, so a 1s per-request timeout must fire.
        let result = client
            .send_request(reqwest::Method::GET, url, None, None, Some(1))
            .await;

        match result {
            Err(HttpClientError::TimeoutError(msg)) => {
                println!("Got expected timeout error: {msg}");
            }
            Err(other) => panic!("Expected a timeout error, got: {other:?}"),
            Ok(resp) => panic!("Expected a timeout error, but got a successful response: {resp:?}"),
        }
    }

    #[tokio::test]
    async fn test_invalid_url() {
        let client = InnerHttpClient::default();

        let result = client
            .send_request(
                reqwest::Method::GET,
                "not a url".to_string(),
                None,
                None,
                None,
            )
            .await;

        match result {
            Err(HttpClientError::Error(msg)) => assert!(msg.starts_with("URL parse error")),
            other => panic!("Expected a URL parse error, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_header_keys_extracted() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}/get");

        let client = InnerHttpClient {
            client: reqwest::Client::new(),
            header_keys: Arc::new(vec!["content-type".to_string()]),
        };
        let response = client
            .send_request(reqwest::Method::GET, url, None, None, None)
            .await
            .unwrap();

        let content_type = response
            .headers
            .get("content-type")
            .expect("content-type header should be extracted");
        assert!(content_type.starts_with("text/plain"));
    }
}