diff --git a/object_store/src/aws/client.rs b/object_store/src/aws/client.rs index 9634c740d01d..fd784df17206 100644 --- a/object_store/src/aws/client.rs +++ b/object_store/src/aws/client.rs @@ -19,13 +19,13 @@ use crate::aws::checksum::Checksum; use crate::aws::credential::{AwsCredential, CredentialExt, CredentialProvider}; use crate::aws::STRICT_PATH_ENCODE_SET; use crate::client::pagination::stream_paginated; -use crate::client::retry::RetryExt; +use crate::client::retry::{ClientConfig, RetryExt}; use crate::multipart::UploadPart; use crate::path::DELIMITER; use crate::util::{format_http_range, format_prefix}; use crate::{ BoxStream, ClientOptions, ListResult, MultipartId, ObjectMeta, Path, Result, - RetryConfig, StreamExt, + StreamExt, }; use base64::prelude::BASE64_STANDARD; use base64::Engine; @@ -208,7 +208,7 @@ pub struct S3Config { pub bucket: String, pub bucket_endpoint: String, pub credentials: Box, - pub retry_config: RetryConfig, + pub client_config: ClientConfig, pub client_options: ClientOptions, pub sign_payload: bool, pub checksum: Option, @@ -271,7 +271,7 @@ impl S3Client { self.config.sign_payload, None, ) - .send_retry(&self.config.retry_config) + .send_retry(&self.config.client_config) .await .context(GetRequestSnafu { path: path.as_ref(), @@ -317,7 +317,7 @@ impl S3Client { self.config.sign_payload, payload_sha256, ) - .send_retry(&self.config.retry_config) + .send_retry(&self.config.client_config) .await .context(PutRequestSnafu { path: path.as_ref(), @@ -345,7 +345,7 @@ impl S3Client { self.config.sign_payload, None, ) - .send_retry(&self.config.retry_config) + .send_retry(&self.config.client_config) .await .context(DeleteRequestSnafu { path: path.as_ref(), @@ -370,7 +370,7 @@ impl S3Client { self.config.sign_payload, None, ) - .send_retry(&self.config.retry_config) + .send_retry(&self.config.client_config) .await .context(CopyRequestSnafu { path: from.as_ref(), @@ -422,7 +422,7 @@ impl S3Client { self.config.sign_payload, None, ) - 
.send_retry(&self.config.retry_config) + .send_retry(&self.config.client_config) .await .context(ListRequestSnafu)? .bytes() @@ -476,7 +476,7 @@ impl S3Client { self.config.sign_payload, None, ) - .send_retry(&self.config.retry_config) + .send_retry(&self.config.client_config) .await .context(CreateMultipartRequestSnafu)? .bytes() @@ -521,7 +521,7 @@ impl S3Client { self.config.sign_payload, None, ) - .send_retry(&self.config.retry_config) + .send_retry(&self.config.client_config) .await .context(CompleteMultipartRequestSnafu)?; diff --git a/object_store/src/aws/credential.rs b/object_store/src/aws/credential.rs index c4cb7cfe1a01..5b34c94ffe21 100644 --- a/object_store/src/aws/credential.rs +++ b/object_store/src/aws/credential.rs @@ -16,10 +16,10 @@ // under the License. use crate::aws::STRICT_ENCODE_SET; -use crate::client::retry::RetryExt; +use crate::client::retry::{ClientConfig, RetryExt}; use crate::client::token::{TemporaryToken, TokenCache}; use crate::util::hmac_sha256; -use crate::{Result, RetryConfig}; +use crate::Result; use bytes::Buf; use chrono::{DateTime, Utc}; use futures::future::BoxFuture; @@ -315,7 +315,7 @@ impl CredentialProvider for StaticCredentialProvider { pub struct InstanceCredentialProvider { pub cache: TokenCache>, pub client: Client, - pub retry_config: RetryConfig, + pub client_config: ClientConfig, pub imdsv1_fallback: bool, pub metadata_endpoint: String, } @@ -325,7 +325,7 @@ impl CredentialProvider for InstanceCredentialProvider { Box::pin(self.cache.get_or_insert_with(|| { instance_creds( &self.client, - &self.retry_config, + &self.client_config, &self.metadata_endpoint, self.imdsv1_fallback, ) @@ -348,7 +348,7 @@ pub struct WebIdentityProvider { pub session_name: String, pub endpoint: String, pub client: Client, - pub retry_config: RetryConfig, + pub client_config: ClientConfig, } impl CredentialProvider for WebIdentityProvider { @@ -356,7 +356,7 @@ impl CredentialProvider for WebIdentityProvider { 
Box::pin(self.cache.get_or_insert_with(|| { web_identity( &self.client, - &self.retry_config, + &self.client_config, &self.token_path, &self.role_arn, &self.session_name, @@ -392,7 +392,7 @@ impl From for AwsCredential { /// async fn instance_creds( client: &Client, - retry_config: &RetryConfig, + config: &ClientConfig, endpoint: &str, imdsv1_fallback: bool, ) -> Result>, StdError> { @@ -404,7 +404,7 @@ async fn instance_creds( let token_result = client .request(Method::PUT, token_url) .header("X-aws-ec2-metadata-token-ttl-seconds", "600") // 10 minute TTL - .send_retry(retry_config) + .send_retry(config) .await; let token = match token_result { @@ -425,7 +425,7 @@ async fn instance_creds( role_request = role_request.header(AWS_EC2_METADATA_TOKEN_HEADER, token); } - let role = role_request.send_retry(retry_config).await?.text().await?; + let role = role_request.send_retry(config).await?.text().await?; let creds_url = format!("{endpoint}/{CREDENTIALS_PATH}/{role}"); let mut creds_request = client.request(Method::GET, creds_url); @@ -434,7 +434,7 @@ async fn instance_creds( } let creds: InstanceCredentials = - creds_request.send_retry(retry_config).await?.json().await?; + creds_request.send_retry(config).await?.json().await?; let now = Utc::now(); let ttl = (creds.expiration - now).to_std().unwrap_or_default(); @@ -478,7 +478,7 @@ impl From for AwsCredential { /// async fn web_identity( client: &Client, - retry_config: &RetryConfig, + config: &ClientConfig, token_path: &str, role_arn: &str, session_name: &str, @@ -497,7 +497,7 @@ async fn web_identity( ("Version", "2011-06-15"), ("WebIdentityToken", &token), ]) - .send_retry(retry_config) + .send_retry(config) .await? 
.bytes() .await?; @@ -709,7 +709,7 @@ mod tests { // For example https://github.com/aws/amazon-ec2-metadata-mock let endpoint = env::var("EC2_METADATA_ENDPOINT").unwrap(); let client = Client::new(); - let retry_config = RetryConfig::default(); + let config = ClientConfig::default(); // Verify only allows IMDSv2 let resp = client @@ -724,7 +724,7 @@ mod tests { "Ensure metadata endpoint is set to only allow IMDSv2" ); - let creds = instance_creds(&client, &retry_config, &endpoint, false) + let creds = instance_creds(&client, &config, &endpoint, false) .await .unwrap(); @@ -749,7 +749,7 @@ mod tests { let endpoint = server.url(); let client = Client::new(); - let retry_config = RetryConfig::default(); + let config = ClientConfig::default(); // Test IMDSv2 server.push_fn(|req| { @@ -775,7 +775,7 @@ mod tests { Response::new(Body::from(r#"{"AccessKeyId":"KEYID","Code":"Success","Expiration":"2022-08-30T10:51:04Z","LastUpdated":"2022-08-30T10:21:04Z","SecretAccessKey":"SECRET","Token":"TOKEN","Type":"AWS-HMAC"}"#)) }); - let creds = instance_creds(&client, &retry_config, endpoint, true) + let creds = instance_creds(&client, &config, endpoint, true) .await .unwrap(); @@ -808,7 +808,7 @@ mod tests { Response::new(Body::from(r#"{"AccessKeyId":"KEYID","Code":"Success","Expiration":"2022-08-30T10:51:04Z","LastUpdated":"2022-08-30T10:21:04Z","SecretAccessKey":"SECRET","Token":"TOKEN","Type":"AWS-HMAC"}"#)) }); - let creds = instance_creds(&client, &retry_config, endpoint, true) + let creds = instance_creds(&client, &config, endpoint, true) .await .unwrap(); @@ -825,7 +825,7 @@ mod tests { ); // Should fail - instance_creds(&client, &retry_config, endpoint, false) + instance_creds(&client, &config, endpoint, false) .await .unwrap_err(); } diff --git a/object_store/src/aws/mod.rs b/object_store/src/aws/mod.rs index de62360d0522..901d915f9a7c 100644 --- a/object_store/src/aws/mod.rs +++ b/object_store/src/aws/mod.rs @@ -44,6 +44,7 @@ use std::ops::Range; use std::str::FromStr; 
use std::sync::Arc; use tokio::io::AsyncWrite; +use tokio::runtime::Handle; use tracing::info; use url::Url; @@ -53,6 +54,7 @@ use crate::aws::credential::{ AwsCredential, CredentialProvider, InstanceCredentialProvider, StaticCredentialProvider, WebIdentityProvider, }; +use crate::client::retry::ClientConfig; use crate::multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl, UploadPart}; use crate::util::str_is_truthy; use crate::{ @@ -414,8 +416,8 @@ pub struct AmazonS3Builder { token: Option, /// Url url: Option, - /// Retry config - retry_config: RetryConfig, + /// Client config + client_config: ClientConfig, /// When set to true, fallback to IMDSv1 imdsv1_fallback: bool, /// When set to true, virtual hosted style request has to be used @@ -897,7 +899,17 @@ impl AmazonS3Builder { /// Set the retry configuration pub fn with_retry(mut self, retry_config: RetryConfig) -> Self { - self.retry_config = retry_config; + self.client_config.retry = retry_config; + self + } + + /// Set the tokio runtime to use to perform IO + /// + /// This allows isolating IO into a dedicated [`Runtime`](tokio::runtime::Runtime) either + /// to ensure acceptable scheduling jitter in the presence of CPU-bound tasks, or to allow + /// using `object_store` outside of a tokio context + pub fn with_tokio_runtime(mut self, runtime: Handle) -> Self { + self.client_config.runtime = Some(runtime); self } @@ -1025,7 +1037,7 @@ impl AmazonS3Builder { role_arn, endpoint, client, - retry_config: self.retry_config.clone(), + client_config: self.client_config.clone(), }) as _ } _ => match self.profile { @@ -1043,7 +1055,7 @@ impl AmazonS3Builder { Box::new(InstanceCredentialProvider { cache: Default::default(), client: client_options.client()?, - retry_config: self.retry_config.clone(), + client_config: self.client_config.clone(), imdsv1_fallback: self.imdsv1_fallback, metadata_endpoint: self .metadata_endpoint @@ -1078,7 +1090,7 @@ impl AmazonS3Builder { bucket, bucket_endpoint, credentials, - 
retry_config: self.retry_config, + client_config: self.client_config, client_options: self.client_options, sign_payload: !self.unsigned_payload, checksum: self.checksum_algorithm, @@ -1110,8 +1122,8 @@ fn profile_credentials( mod tests { use super::*; use crate::tests::{ - get_nonexistent_object, list_uses_directories_correctly, list_with_delimiter, - put_get_delete_list_opts, rename_and_copy, stream_get, + dedicated_tokio, get_nonexistent_object, list_uses_directories_correctly, + list_with_delimiter, put_get_delete_list_opts, rename_and_copy, stream_get, }; use bytes::Bytes; use std::collections::HashMap; @@ -1388,30 +1400,42 @@ mod tests { assert!(builder.is_err()); } - #[tokio::test] - async fn s3_test() { - let config = maybe_skip_integration!(); - let is_local = matches!(&config.endpoint, Some(e) if e.starts_with("http://")); - let integration = config.build().unwrap(); + #[test] + fn s3_test() { + let builder = maybe_skip_integration!(); + let is_local = matches!(&builder.endpoint, Some(e) if e.starts_with("http://")); + + let test = |integration| async move { + // Localstack doesn't support listing with spaces https://github.com/localstack/localstack/issues/6328 + put_get_delete_list_opts(&integration, is_local).await; + list_uses_directories_correctly(&integration).await; + list_with_delimiter(&integration).await; + rename_and_copy(&integration).await; + stream_get(&integration).await; + }; + + let (handle, shutdown) = dedicated_tokio(); - // Localstack doesn't support listing with spaces https://github.com/localstack/localstack/issues/6328 - put_get_delete_list_opts(&integration, is_local).await; - list_uses_directories_correctly(&integration).await; - list_with_delimiter(&integration).await; - rename_and_copy(&integration).await; - stream_get(&integration).await; + let integration = builder.clone().build().unwrap(); + handle.block_on(test(integration)); // run integration test with unsigned payload enabled - let config = 
maybe_skip_integration!().with_unsigned_payload(true); - let is_local = matches!(&config.endpoint, Some(e) if e.starts_with("http://")); - let integration = config.build().unwrap(); - put_get_delete_list_opts(&integration, is_local).await; + let integration = builder.clone().with_unsigned_payload(true).build().unwrap(); + handle.block_on(test(integration)); // run integration test with checksum set to sha256 - let config = maybe_skip_integration!().with_checksum_algorithm(Checksum::SHA256); - let is_local = matches!(&config.endpoint, Some(e) if e.starts_with("http://")); - let integration = config.build().unwrap(); - put_get_delete_list_opts(&integration, is_local).await; + let integration = builder + .clone() + .with_checksum_algorithm(Checksum::SHA256) + .build() + .unwrap(); + handle.block_on(test(integration)); + + // run integration test without tokio runtime + let integration = builder.with_tokio_runtime(handle).build().unwrap(); + futures::executor::block_on(test(integration)); + + shutdown(); } #[tokio::test] diff --git a/object_store/src/azure/client.rs b/object_store/src/azure/client.rs index 87432f62b5cd..14e71ee84a60 100644 --- a/object_store/src/azure/client.rs +++ b/object_store/src/azure/client.rs @@ -18,13 +18,10 @@ use super::credential::{AzureCredential, CredentialProvider}; use crate::azure::credential::*; use crate::client::pagination::stream_paginated; -use crate::client::retry::RetryExt; +use crate::client::retry::{ClientConfig, RetryExt}; use crate::path::DELIMITER; use crate::util::{deserialize_rfc1123, format_http_range, format_prefix}; -use crate::{ - BoxStream, ClientOptions, ListResult, ObjectMeta, Path, Result, RetryConfig, - StreamExt, -}; +use crate::{BoxStream, ClientOptions, ListResult, ObjectMeta, Path, Result, StreamExt}; use base64::prelude::BASE64_STANDARD; use base64::Engine; use bytes::{Buf, Bytes}; @@ -126,7 +123,7 @@ pub struct AzureConfig { pub account: String, pub container: String, pub credentials: CredentialProvider, - 
pub retry_config: RetryConfig, + pub client_config: ClientConfig, pub service: Url, pub is_emulator: bool, pub client_options: ClientOptions, @@ -184,7 +181,7 @@ impl AzureClient { CredentialProvider::TokenCredential(cache, cred) => { let token = cache .get_or_insert_with(|| { - cred.fetch_token(&self.client, &self.config.retry_config) + cred.fetch_token(&self.client, &self.config.client_config) }) .await .context(AuthorizationSnafu)?; @@ -238,7 +235,7 @@ impl AzureClient { let response = builder .with_azure_authorization(&credential, &self.config.account) - .send_retry(&self.config.retry_config) + .send_retry(&self.config.client_config) .await .context(PutRequestSnafu { path: path.as_ref(), @@ -275,7 +272,7 @@ impl AzureClient { let response = builder .with_azure_authorization(&credential, &self.config.account) - .send_retry(&self.config.retry_config) + .send_retry(&self.config.client_config) .await .context(GetRequestSnafu { path: path.as_ref(), @@ -298,7 +295,7 @@ impl AzureClient { .query(query) .header(&DELETE_SNAPSHOTS, "include") .with_azure_authorization(&credential, &self.config.account) - .send_retry(&self.config.retry_config) + .send_retry(&self.config.client_config) .await .context(DeleteRequestSnafu { path: path.as_ref(), @@ -336,7 +333,7 @@ impl AzureClient { builder .with_azure_authorization(&credential, &self.config.account) - .send_retry(&self.config.retry_config) + .send_retry(&self.config.client_config) .await .context(CopyRequestSnafu { path: from.as_ref(), @@ -376,7 +373,7 @@ impl AzureClient { .request(Method::GET, url) .query(&query) .with_azure_authorization(&credential, &self.config.account) - .send_retry(&self.config.retry_config) + .send_retry(&self.config.client_config) .await .context(ListRequestSnafu)? 
.bytes() diff --git a/object_store/src/azure/credential.rs b/object_store/src/azure/credential.rs index 0196d93d8d2a..a5c4bd47c594 100644 --- a/object_store/src/azure/credential.rs +++ b/object_store/src/azure/credential.rs @@ -15,10 +15,9 @@ // specific language governing permissions and limitations // under the License. -use crate::client::retry::RetryExt; +use crate::client::retry::{ClientConfig, RetryExt}; use crate::client::token::{TemporaryToken, TokenCache}; use crate::util::hmac_sha256; -use crate::RetryConfig; use base64::prelude::BASE64_STANDARD; use base64::Engine; use chrono::{DateTime, Utc}; @@ -287,7 +286,7 @@ pub trait TokenCredential: std::fmt::Debug + Send + Sync + 'static { async fn fetch_token( &self, client: &Client, - retry: &RetryConfig, + config: &ClientConfig, ) -> Result>; } @@ -334,7 +333,7 @@ impl TokenCredential for ClientSecretOAuthProvider { async fn fetch_token( &self, client: &Client, - retry: &RetryConfig, + config: &ClientConfig, ) -> Result> { let response: TokenResponse = client .request(Method::POST, &self.token_url) @@ -345,7 +344,7 @@ impl TokenCredential for ClientSecretOAuthProvider { ("scope", AZURE_STORAGE_SCOPE), ("grant_type", "client_credentials"), ]) - .send_retry(retry) + .send_retry(config) .await .context(TokenRequestSnafu)? .json() @@ -420,7 +419,7 @@ impl TokenCredential for ImdsManagedIdentityOAuthProvider { async fn fetch_token( &self, _client: &Client, - retry: &RetryConfig, + config: &ClientConfig, ) -> Result> { let mut query_items = vec![ ("api-version", MSI_API_VERSION), @@ -452,7 +451,7 @@ impl TokenCredential for ImdsManagedIdentityOAuthProvider { }; let response: MsiTokenResponse = builder - .send_retry(retry) + .send_retry(config) .await .context(TokenRequestSnafu)? 
.json() @@ -507,7 +506,7 @@ impl TokenCredential for WorkloadIdentityOAuthProvider { async fn fetch_token( &self, client: &Client, - retry: &RetryConfig, + config: &ClientConfig, ) -> Result> { let token_str = std::fs::read_to_string(&self.federated_token_file) .map_err(|_| Error::FederatedTokenFile)?; @@ -526,7 +525,7 @@ impl TokenCredential for WorkloadIdentityOAuthProvider { ("scope", AZURE_STORAGE_SCOPE), ("grant_type", "client_credentials"), ]) - .send_retry(retry) + .send_retry(config) .await .context(TokenRequestSnafu)? .json() @@ -591,7 +590,7 @@ impl TokenCredential for AzureCliCredential { async fn fetch_token( &self, _client: &Client, - _retry: &RetryConfig, + _config: &ClientConfig, ) -> Result> { // on window az is a cmd and it should be called like this // see https://doc.rust-lang.org/nightly/std/process/struct.Command.html @@ -678,7 +677,7 @@ mod tests { let endpoint = server.url(); let client = Client::new(); - let retry_config = RetryConfig::default(); + let client_config = ClientConfig::default(); // Test IMDS server.push_fn(|req| { @@ -718,7 +717,7 @@ mod tests { ); let token = credential - .fetch_token(&client, &retry_config) + .fetch_token(&client, &client_config) .await .unwrap(); @@ -734,7 +733,7 @@ mod tests { let endpoint = server.url(); let client = Client::new(); - let retry_config = RetryConfig::default(); + let config = ClientConfig::default(); // Test IMDS server.push_fn(move |req| { @@ -765,10 +764,7 @@ mod tests { Some(endpoint.to_string()), ); - let token = credential - .fetch_token(&client, &retry_config) - .await - .unwrap(); + let token = credential.fetch_token(&client, &config).await.unwrap(); assert_eq!(&token.token, "TOKEN"); } diff --git a/object_store/src/azure/mod.rs b/object_store/src/azure/mod.rs index 11350a202c72..7b28bf36cc82 100644 --- a/object_store/src/azure/mod.rs +++ b/object_store/src/azure/mod.rs @@ -49,8 +49,10 @@ use std::ops::Range; use std::sync::Arc; use std::{collections::BTreeSet, str::FromStr}; use 
tokio::io::AsyncWrite; +use tokio::runtime::Handle; use url::Url; +use crate::client::retry::ClientConfig; use crate::util::{str_is_truthy, RFC1123_FMT}; pub use credential::authority_hosts; @@ -428,8 +430,8 @@ pub struct MicrosoftAzureBuilder { federated_token_file: Option, /// When set to true, azure cli has to be used for acquiring access token use_azure_cli: bool, - /// Retry config - retry_config: RetryConfig, + /// Client config + client_config: ClientConfig, /// Client options client_options: ClientOptions, } @@ -925,7 +927,17 @@ impl MicrosoftAzureBuilder { /// Set the retry configuration pub fn with_retry(mut self, retry_config: RetryConfig) -> Self { - self.retry_config = retry_config; + self.client_config.retry = retry_config; + self + } + + /// Set the tokio runtime to use to perform IO + /// + /// This allows isolating IO into a dedicated [`Runtime`](tokio::runtime::Runtime) either + /// to ensure acceptable scheduling jitter in the presence of CPU-bound tasks, or to allow + /// using `object_store` outside of a tokio context + pub fn with_tokio_runtime(mut self, runtime: Handle) -> Self { + self.client_config.runtime = Some(runtime); self } @@ -1054,7 +1066,7 @@ impl MicrosoftAzureBuilder { account, is_emulator, container, - retry_config: self.retry_config, + client_config: self.client_config, client_options: self.client_options, service: storage_url, credentials: auth, @@ -1104,8 +1116,9 @@ fn split_sas(sas: &str) -> Result, Error> { mod tests { use super::*; use crate::tests::{ - copy_if_not_exists, list_uses_directories_correctly, list_with_delimiter, - put_get_delete_list, put_get_delete_list_opts, rename_and_copy, stream_get, + copy_if_not_exists, dedicated_tokio, list_uses_directories_correctly, + list_with_delimiter, put_get_delete_list, put_get_delete_list_opts, + rename_and_copy, stream_get, }; use std::collections::HashMap; use std::env; @@ -1172,15 +1185,27 @@ mod tests { }}; } - #[tokio::test] - async fn azure_blob_test() { - let 
integration = maybe_skip_integration!().build().unwrap(); - put_get_delete_list_opts(&integration, false).await; - list_uses_directories_correctly(&integration).await; - list_with_delimiter(&integration).await; - rename_and_copy(&integration).await; - copy_if_not_exists(&integration).await; - stream_get(&integration).await; + #[test] + fn azure_blob_test() { + let builder = maybe_skip_integration!(); + let test = |integration| async move { + put_get_delete_list_opts(&integration, false).await; + list_uses_directories_correctly(&integration).await; + list_with_delimiter(&integration).await; + rename_and_copy(&integration).await; + copy_if_not_exists(&integration).await; + stream_get(&integration).await; + }; + + let (handle, shutdown) = dedicated_tokio(); + + let integration = builder.clone().build().unwrap(); + handle.block_on(test(integration)); + + let integration = builder.with_tokio_runtime(handle).build().unwrap(); + futures::executor::block_on(test(integration)); + + shutdown(); } // test for running integration test against actual blob service with service principal diff --git a/object_store/src/client/retry.rs b/object_store/src/client/retry.rs index e6dd2eb8174b..c0b477046245 100644 --- a/object_store/src/client/retry.rs +++ b/object_store/src/client/retry.rs @@ -23,17 +23,62 @@ use futures::FutureExt; use reqwest::header::LOCATION; use reqwest::{Response, StatusCode}; use std::time::{Duration, Instant}; +use tokio::runtime::Handle; +use tokio::task::JoinError; use tracing::info; +#[derive(Debug)] +pub enum Error { + Retry(RetryError), + Spawn(JoinError), +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Retry(r) => r.fmt(f), + Self::Spawn(e) => write!(f, "failed to join spawned task: {e}"), + } + } +} + +impl std::error::Error for Error { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::Retry(retry) => retry.source(), + 
Self::Spawn(e) => Some(e), + } + } +} + +impl Error { + /// Returns the status code associated with this error if any + pub fn status(&self) -> Option { + match self { + Self::Retry(e) => e.status(), + Self::Spawn(_) => None, + } + } +} + +impl From for std::io::Error { + fn from(err: Error) -> Self { + match err { + Error::Retry(e) => e.into(), + Error::Spawn(e) => Self::new(std::io::ErrorKind::Other, e), + } + } +} + /// Retry request error #[derive(Debug)] -pub struct Error { +pub struct RetryError { retries: usize, message: String, source: Option, } -impl std::fmt::Display for Error { +impl std::fmt::Display for RetryError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, @@ -47,21 +92,21 @@ impl std::fmt::Display for Error { } } -impl std::error::Error for Error { +impl std::error::Error for RetryError { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { self.source.as_ref().map(|e| e as _) } } -impl Error { +impl RetryError { /// Returns the status code associated with this error if any pub fn status(&self) -> Option { self.source.as_ref().and_then(|e| e.status()) } } -impl From for std::io::Error { - fn from(err: Error) -> Self { +impl From for std::io::Error { + fn from(err: RetryError) -> Self { use std::io::ErrorKind; match (&err.source, err.status()) { (Some(source), _) if source.is_builder() || source.is_request() => { @@ -121,22 +166,37 @@ impl Default for RetryConfig { } } +/// Crate-private client configuration +/// +/// Specifically this is the config passed to [`RetryExt::send_retry`] +/// +/// This is unlike the public [`ClientOptions`](crate::ClientOptions) which contains just +/// the properties used to construct [`Client`](reqwest::Client) +#[derive(Debug, Clone, Default)] +pub struct ClientConfig { + /// The retry configuration + pub retry: RetryConfig, + + /// Optional tokio runtime to perform IO + pub runtime: Option, +} + pub trait RetryExt { - /// Dispatch a request with the given retry 
configuration + /// Dispatch a request with the given [`ClientConfig`] /// /// # Panic /// /// This will panic if the request body is a stream - fn send_retry(self, config: &RetryConfig) -> BoxFuture<'static, Result>; + fn send_retry(self, config: &ClientConfig) -> BoxFuture<'static, Result>; } impl RetryExt for reqwest::RequestBuilder { - fn send_retry(self, config: &RetryConfig) -> BoxFuture<'static, Result> { - let mut backoff = Backoff::new(&config.backoff); - let max_retries = config.max_retries; - let retry_timeout = config.retry_timeout; + fn send_retry(self, config: &ClientConfig) -> BoxFuture<'static, Result> { + let mut backoff = Backoff::new(&config.retry.backoff); + let max_retries = config.retry.max_retries; + let retry_timeout = config.retry.retry_timeout; - async move { + let fut = async move { let mut retries = 0; let now = Instant::now(); @@ -146,42 +206,44 @@ impl RetryExt for reqwest::RequestBuilder { Ok(r) => match r.error_for_status_ref() { Ok(_) if r.status().is_success() => return Ok(r), Ok(r) => { - let is_bare_redirect = r.status().is_redirection() && !r.headers().contains_key(LOCATION); + let is_bare_redirect = r.status().is_redirection() + && !r.headers().contains_key(LOCATION); let message = match is_bare_redirect { true => "Received redirect without LOCATION, this normally indicates an incorrectly configured region".to_string(), // Not actually sure if this is reachable, but here for completeness false => format!("request unsuccessful: {}", r.status()), }; - return Err(Error{ + return Err(Error::Retry(RetryError { message, retries, source: None, - }) + })); } Err(e) => { let status = r.status(); if retries == max_retries || now.elapsed() > retry_timeout - || !status.is_server_error() { - + || !status.is_server_error() + { // Get the response message if returned a client error let message = match status.is_client_error() { true => match r.text().await { Ok(message) if !message.is_empty() => message, Ok(_) => "No Body".to_string(), - 
Err(e) => format!("error getting response body: {e}") - } + Err(e) => { + format!("error getting response body: {e}") + } + }, false => status.to_string(), }; - return Err(Error{ + return Err(Error::Retry(RetryError { message, retries, source: Some(e), - }) - + })); } let sleep = backoff.next(); @@ -190,26 +252,36 @@ impl RetryExt for reqwest::RequestBuilder { tokio::time::sleep(sleep).await; } }, - Err(e) => - { - return Err(Error{ + Err(e) => { + return Err(Error::Retry(RetryError { retries, message: "request error".to_string(), - source: Some(e) - }) + source: Some(e), + })) } } } + }; + + match config.runtime.as_ref() { + Some(handle) => handle + .spawn(fut) + .map(|x| match x { + Ok(r) => r, + Err(e) => Err(Error::Spawn(e)), + }) + .boxed(), + None => fut.boxed(), } - .boxed() } } #[cfg(test)] mod tests { use crate::client::mock_server::MockServer; - use crate::client::retry::RetryExt; + use crate::client::retry::{ClientConfig, Error, RetryExt}; use crate::RetryConfig; + use futures::TryFutureExt; use hyper::header::LOCATION; use hyper::{Body, Response}; use reqwest::{Client, Method, StatusCode}; @@ -224,9 +296,21 @@ mod tests { max_retries: 2, retry_timeout: Duration::from_secs(1000), }; + let config = ClientConfig { + retry, + runtime: None, + }; let client = Client::new(); - let do_request = || client.request(Method::GET, mock.url()).send_retry(&retry); + let do_request = || { + client + .request(Method::GET, mock.url()) + .send_retry(&config) + .map_err(|e| match e { + Error::Retry(e) => e, + Error::Spawn(e) => unreachable!("spawn error {e}"), + }) + }; // Simple request should work let r = do_request().await.unwrap(); @@ -332,7 +416,7 @@ mod tests { assert_eq!(e.message, "Received redirect without LOCATION, this normally indicates an incorrectly configured region"); // Gives up after the retrying the specified number of times - for _ in 0..=retry.max_retries { + for _ in 0..=config.retry.max_retries { mock.push( Response::builder() 
.status(StatusCode::BAD_GATEWAY) @@ -342,7 +426,7 @@ mod tests { } let e = do_request().await.unwrap_err(); - assert_eq!(e.retries, retry.max_retries); + assert_eq!(e.retries, config.retry.max_retries); assert_eq!(e.message, "502 Bad Gateway"); // Shutdown diff --git a/object_store/src/gcp/credential.rs b/object_store/src/gcp/credential.rs index a8dce7132755..ce513cd40040 100644 --- a/object_store/src/gcp/credential.rs +++ b/object_store/src/gcp/credential.rs @@ -15,10 +15,9 @@ // specific language governing permissions and limitations // under the License. -use crate::client::retry::RetryExt; +use crate::client::retry::{ClientConfig, RetryExt}; use crate::client::token::TemporaryToken; use crate::ClientOptions; -use crate::RetryConfig; use async_trait::async_trait; use base64::prelude::BASE64_URL_SAFE_NO_PAD; use base64::Engine; @@ -129,7 +128,7 @@ pub trait TokenProvider: std::fmt::Debug + Send + Sync { async fn fetch_token( &self, client: &Client, - retry: &RetryConfig, + config: &ClientConfig, ) -> Result>; } @@ -175,7 +174,7 @@ impl TokenProvider for OAuthProvider { async fn fetch_token( &self, client: &Client, - retry: &RetryConfig, + config: &ClientConfig, ) -> Result> { let now = seconds_since_epoch(); let exp = now + 3600; @@ -211,7 +210,7 @@ impl TokenProvider for OAuthProvider { let response: TokenResponse = client .request(Method::POST, &self.audience) .form(&body) - .send_retry(retry) + .send_retry(config) .await .context(TokenRequestSnafu)? .json() @@ -348,7 +347,7 @@ impl InstanceCredentialProvider { async fn make_metadata_request( client: &Client, hostname: &str, - retry: &RetryConfig, + config: &ClientConfig, audience: &str, ) -> Result { let url = format!( @@ -358,7 +357,7 @@ async fn make_metadata_request( .request(Method::GET, url) .header("Metadata-Flavor", "Google") .query(&[("audience", audience)]) - .send_retry(retry) + .send_retry(config) .await .context(TokenRequestSnafu)? 
.json() @@ -374,19 +373,19 @@ impl TokenProvider for InstanceCredentialProvider { async fn fetch_token( &self, _client: &Client, - retry: &RetryConfig, + config: &ClientConfig, ) -> Result> { const METADATA_IP: &str = "169.254.169.254"; const METADATA_HOST: &str = "metadata"; info!("fetching token from metadata server"); let response = - make_metadata_request(&self.client, METADATA_HOST, retry, &self.audience) + make_metadata_request(&self.client, METADATA_HOST, config, &self.audience) .or_else(|_| { make_metadata_request( &self.client, METADATA_IP, - retry, + config, &self.audience, ) }) @@ -473,7 +472,7 @@ impl TokenProvider for ApplicationDefaultCredentials { async fn fetch_token( &self, client: &Client, - retry: &RetryConfig, + config: &ClientConfig, ) -> Result, Error> { let builder = client.request(Method::POST, DEFAULT_TOKEN_GCP_URI); let builder = match self { @@ -493,7 +492,7 @@ impl TokenProvider for ApplicationDefaultCredentials { }; let response = builder - .send_retry(retry) + .send_retry(config) .await .context(TokenRequestSnafu)? 
.json::() diff --git a/object_store/src/gcp/mod.rs b/object_store/src/gcp/mod.rs index a6cf660220bd..458915c7454d 100644 --- a/object_store/src/gcp/mod.rs +++ b/object_store/src/gcp/mod.rs @@ -45,10 +45,11 @@ use reqwest::{header, Client, Method, Response, StatusCode}; use serde::{Deserialize, Serialize}; use snafu::{OptionExt, ResultExt, Snafu}; use tokio::io::AsyncWrite; +use tokio::runtime::Handle; use url::Url; use crate::client::pagination::stream_paginated; -use crate::client::retry::RetryExt; +use crate::client::retry::{ClientConfig, RetryExt}; use crate::{ client::token::TokenCache, multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl, UploadPart}, @@ -244,7 +245,7 @@ struct GoogleCloudStorageClient { bucket_name: String, bucket_name_encoded: String, - retry_config: RetryConfig, + client_config: ClientConfig, client_options: ClientOptions, // TODO: Hook this up in tests @@ -257,7 +258,7 @@ impl GoogleCloudStorageClient { Ok(self .token_cache .get_or_insert_with(|| { - token_provider.fetch_token(&self.client, &self.retry_config) + token_provider.fetch_token(&self.client, &self.client_config) }) .await .context(CredentialSnafu)?) 
@@ -299,7 +300,7 @@ impl GoogleCloudStorageClient { let response = builder .bearer_auth(token) .query(&[("alt", alt)]) - .send_retry(&self.retry_config) + .send_retry(&self.client_config) .await .context(GetRequestSnafu { path: path.as_ref(), @@ -328,7 +329,7 @@ impl GoogleCloudStorageClient { .header(header::CONTENT_LENGTH, payload.len()) .query(&[("uploadType", "media"), ("name", path.as_ref())]) .body(payload) - .send_retry(&self.retry_config) + .send_retry(&self.client_config) .await .context(PutRequestSnafu)?; @@ -352,7 +353,7 @@ impl GoogleCloudStorageClient { .header(header::CONTENT_TYPE, content_type) .header(header::CONTENT_LENGTH, "0") .query(&[("uploads", "")]) - .send_retry(&self.retry_config) + .send_retry(&self.client_config) .await .context(PutRequestSnafu)?; @@ -384,7 +385,7 @@ impl GoogleCloudStorageClient { .header(header::CONTENT_TYPE, "application/octet-stream") .header(header::CONTENT_LENGTH, "0") .query(&[("uploadId", multipart_id)]) - .send_retry(&self.retry_config) + .send_retry(&self.client_config) .await .context(PutRequestSnafu)?; @@ -399,7 +400,7 @@ impl GoogleCloudStorageClient { let builder = self.client.request(Method::DELETE, url); builder .bearer_auth(token) - .send_retry(&self.retry_config) + .send_retry(&self.client_config) .await .context(DeleteRequestSnafu { path: path.as_ref(), @@ -441,7 +442,7 @@ impl GoogleCloudStorageClient { // Needed if reqwest is compiled with native-tls instead of rustls-tls // See https://github.com/apache/arrow-rs/pull/3921 .header(header::CONTENT_LENGTH, 0) - .send_retry(&self.retry_config) + .send_retry(&self.client_config) .await .map_err(|err| { if err @@ -500,7 +501,7 @@ impl GoogleCloudStorageClient { .request(Method::GET, url) .query(&query) .bearer_auth(token) - .send_retry(&self.retry_config) + .send_retry(&self.client_config) .await .context(ListRequestSnafu)? 
.json() @@ -566,7 +567,7 @@ impl CloudMultiPartUploadImpl for GCSMultipartUpload { .header(header::CONTENT_TYPE, "application/octet-stream") .header(header::CONTENT_LENGTH, format!("{}", buf.len())) .body(buf) - .send_retry(&self.client.retry_config) + .send_retry(&self.client.client_config) .await?; let content_id = response @@ -623,7 +624,7 @@ impl CloudMultiPartUploadImpl for GCSMultipartUpload { .bearer_auth(token) .query(&[("uploadId", upload_id)]) .body(data) - .send_retry(&self.client.retry_config) + .send_retry(&self.client.client_config) .await?; Ok(()) @@ -778,8 +779,8 @@ pub struct GoogleCloudStorageBuilder { service_account_key: Option, /// Path to the application credentials file. application_credentials_path: Option, - /// Retry config - retry_config: RetryConfig, + /// Client config + client_config: ClientConfig, /// Client options client_options: ClientOptions, } @@ -882,7 +883,7 @@ impl Default for GoogleCloudStorageBuilder { service_account_path: None, service_account_key: None, application_credentials_path: None, - retry_config: Default::default(), + client_config: Default::default(), client_options: ClientOptions::new().with_allow_http(true), url: None, } @@ -1089,7 +1090,17 @@ impl GoogleCloudStorageBuilder { /// Set the retry configuration pub fn with_retry(mut self, retry_config: RetryConfig) -> Self { - self.retry_config = retry_config; + self.client_config.retry = retry_config; + self + } + + /// Set the tokio runtime to use to perform IO + /// + /// This allows isolating IO into a dedicated [`Runtime`](tokio::runtime::Runtime) either + /// to ensure acceptable scheduling jitter in the presence of CPU-bound tasks, or to allow + /// using `object_store` outside of a tokio context + pub fn with_tokio_runtime(mut self, runtime: Handle) -> Self { + self.client_config.runtime = Some(runtime); self } @@ -1188,7 +1199,7 @@ impl GoogleCloudStorageBuilder { token_cache: Default::default(), bucket_name, bucket_name_encoded: encoded_bucket_name, - 
retry_config: self.retry_config, + client_config: self.client_config, client_options: self.client_options, max_list_results: None, }), @@ -1218,6 +1229,7 @@ mod test { use std::io::Write; use tempfile::NamedTempFile; + use crate::tests::dedicated_tokio; use crate::{ tests::{ copy_if_not_exists, get_nonexistent_object, list_uses_directories_correctly, @@ -1278,22 +1290,33 @@ mod test { }}; } - #[tokio::test] - async fn gcs_test() { - let integration = maybe_skip_integration!().build().unwrap(); + #[test] + fn gcs_test() { + let builder = maybe_skip_integration!(); + let test = |integration: GoogleCloudStorage| async move { + put_get_delete_list(&integration).await; + list_uses_directories_correctly(&integration).await; + list_with_delimiter(&integration).await; + rename_and_copy(&integration).await; + if integration.client.base_url == default_gcs_base_url() { + // Fake GCS server doesn't currently honor ifGenerationMatch + // https://github.com/fsouza/fake-gcs-server/issues/994 + copy_if_not_exists(&integration).await; + // Fake GCS server does not yet implement XML Multipart uploads + // https://github.com/fsouza/fake-gcs-server/issues/852 + stream_get(&integration).await; + } + }; - put_get_delete_list(&integration).await; - list_uses_directories_correctly(&integration).await; - list_with_delimiter(&integration).await; - rename_and_copy(&integration).await; - if integration.client.base_url == default_gcs_base_url() { - // Fake GCS server doesn't currently honor ifGenerationMatch - // https://github.com/fsouza/fake-gcs-server/issues/994 - copy_if_not_exists(&integration).await; - // Fake GCS server does not yet implement XML Multipart uploads - // https://github.com/fsouza/fake-gcs-server/issues/852 - stream_get(&integration).await; - } + let (handle, shutdown) = dedicated_tokio(); + + let integration = builder.clone().build().unwrap(); + handle.block_on(test(integration)); + + let integration = builder.with_tokio_runtime(handle).build().unwrap(); + 
futures::executor::block_on(test(integration)); + + shutdown(); } #[tokio::test] diff --git a/object_store/src/http/client.rs b/object_store/src/http/client.rs index 5ef272180abc..327171953ac9 100644 --- a/object_store/src/http/client.rs +++ b/object_store/src/http/client.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::client::retry::{self, RetryConfig, RetryExt}; +use crate::client::retry::{self, ClientConfig, RetryExt}; use crate::path::{Path, DELIMITER}; use crate::util::{deserialize_rfc1123, format_http_range}; use crate::{ClientOptions, ObjectMeta, Result}; @@ -79,7 +79,7 @@ impl From for crate::Error { pub struct Client { url: Url, client: reqwest::Client, - retry_config: RetryConfig, + client_config: ClientConfig, client_options: ClientOptions, } @@ -87,12 +87,12 @@ impl Client { pub fn new( url: Url, client_options: ClientOptions, - retry_config: RetryConfig, + client_config: ClientConfig, ) -> Result { let client = client_options.client()?; Ok(Self { url, - retry_config, + client_config, client_options, client, }) @@ -118,7 +118,7 @@ impl Client { self.client .request(method, url) - .send_retry(&self.retry_config) + .send_retry(&self.client_config) .await .context(RequestSnafu)?; @@ -163,7 +163,7 @@ impl Client { builder = builder.header(CONTENT_TYPE, value); } - match builder.send_retry(&self.retry_config).await { + match builder.send_retry(&self.client_config).await { Ok(_) => return Ok(()), Err(source) => match source.status() { // Some implementations return 404 instead of 409 @@ -191,7 +191,7 @@ impl Client { .client .request(method, url) .header("Depth", depth) - .send_retry(&self.retry_config) + .send_retry(&self.client_config) .await; let response = match result { @@ -223,7 +223,7 @@ impl Client { let url = self.path_url(path); self.client .delete(url) - .send_retry(&self.retry_config) + .send_retry(&self.client_config) .await .context(RequestSnafu)?; Ok(()) @@ -242,7 +242,7 @@ impl 
Client { } builder - .send_retry(&self.retry_config) + .send_retry(&self.client_config) .await .map_err(|source| match source.status() { Some(StatusCode::NOT_FOUND) => crate::Error::NotFound { @@ -267,7 +267,7 @@ impl Client { builder = builder.header("Overwrite", "F"); } - match builder.send_retry(&self.retry_config).await { + match builder.send_retry(&self.client_config).await { Ok(_) => Ok(()), Err(e) if !overwrite diff --git a/object_store/src/http/mod.rs b/object_store/src/http/mod.rs index c91faa2358ac..b14f1a054c30 100644 --- a/object_store/src/http/mod.rs +++ b/object_store/src/http/mod.rs @@ -40,8 +40,10 @@ use futures::{StreamExt, TryStreamExt}; use itertools::Itertools; use snafu::{OptionExt, ResultExt, Snafu}; use tokio::io::AsyncWrite; +use tokio::runtime::Handle; use url::Url; +use crate::client::retry::ClientConfig; use crate::http::client::Client; use crate::path::Path; use crate::{ @@ -227,7 +229,7 @@ impl ObjectStore for HttpStore { pub struct HttpBuilder { url: Option, client_options: ClientOptions, - retry_config: RetryConfig, + client_config: ClientConfig, } impl HttpBuilder { @@ -244,7 +246,17 @@ impl HttpBuilder { /// Set the retry configuration pub fn with_retry(mut self, retry_config: RetryConfig) -> Self { - self.retry_config = retry_config; + self.client_config.retry = retry_config; + self + } + + /// Set the tokio runtime to use to perform IO + /// + /// This allows isolating IO into a dedicated [`Runtime`](tokio::runtime::Runtime) either + /// to ensure acceptable scheduling jitter in the presence of CPU-bound tasks, or to allow + /// using `object_store` outside of a tokio context + pub fn with_tokio_runtime(mut self, runtime: Handle) -> Self { + self.client_config.runtime = Some(runtime); self } @@ -260,7 +272,7 @@ impl HttpBuilder { let parsed = Url::parse(&url).context(UnableToParseUrlSnafu { url })?; Ok(HttpStore { - client: Client::new(parsed, self.client_options, self.retry_config)?, + client: Client::new(parsed, 
self.client_options, self.client_config)?, }) } } @@ -271,8 +283,16 @@ mod tests { use super::*; - #[tokio::test] - async fn http_test() { + /// Deletes any directories left behind from previous tests + async fn cleanup_directories(integration: &HttpStore) { + let result = integration.list_with_delimiter(None).await.unwrap(); + for r in result.common_prefixes { + integration.delete(&r).await.unwrap(); + } + } + + #[test] + fn http_test() { dotenv::dotenv().ok(); let force = std::env::var("TEST_INTEGRATION"); if force.is_err() { @@ -281,16 +301,27 @@ mod tests { } let url = std::env::var("HTTP_URL").expect("HTTP_URL must be set"); let options = ClientOptions::new().with_allow_http(true); - let integration = HttpBuilder::new() + let builder = HttpBuilder::new() .with_url(url) - .with_client_options(options) - .build() - .unwrap(); - - put_get_delete_list_opts(&integration, false).await; - list_uses_directories_correctly(&integration).await; - list_with_delimiter(&integration).await; - rename_and_copy(&integration).await; - copy_if_not_exists(&integration).await; + .with_client_options(options); + + let (handle, shutdown) = dedicated_tokio(); + + let test = |integration| async move { + cleanup_directories(&integration).await; + put_get_delete_list_opts(&integration, false).await; + list_uses_directories_correctly(&integration).await; + list_with_delimiter(&integration).await; + rename_and_copy(&integration).await; + copy_if_not_exists(&integration).await; + }; + + let integration = builder.clone().build().unwrap(); + handle.block_on(test(integration)); + + let integration = builder.with_tokio_runtime(handle).build().unwrap(); + futures::executor::block_on(test(integration)); + + shutdown(); } } diff --git a/object_store/src/lib.rs b/object_store/src/lib.rs index c31027c0715c..f381b54cf9b8 100644 --- a/object_store/src/lib.rs +++ b/object_store/src/lib.rs @@ -740,11 +740,34 @@ mod tests { use super::*; use crate::test_util::flatten_list_stream; use 
tokio::io::AsyncWriteExt; + use tokio::runtime::Handle; pub(crate) async fn put_get_delete_list(storage: &DynObjectStore) { put_get_delete_list_opts(storage, false).await } + /// Create a tokio pool in a dedicated thread, returning the [`Handle`] and a shutdown function + pub(crate) fn dedicated_tokio() -> (Handle, impl FnOnce()) { + let (sender, receiver) = futures::channel::oneshot::channel(); + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_time() + .enable_io() + .build() + .unwrap(); + let handle = runtime.handle().clone(); + + let join = std::thread::spawn(move || { + runtime.block_on(async move { receiver.await.unwrap() }) + }); + + let shutdown = move || { + let _ = sender.send(()); + join.join().unwrap(); + }; + + (handle, shutdown) + } + pub(crate) async fn put_get_delete_list_opts( storage: &DynObjectStore, skip_list_with_spaces: bool, diff --git a/object_store/src/local.rs b/object_store/src/local.rs index d2553d46f244..1d33d6af0eee 100644 --- a/object_store/src/local.rs +++ b/object_store/src/local.rs @@ -981,7 +981,7 @@ mod tests { } #[test] - fn test_non_tokio() { + fn local_non_tokio() { let root = TempDir::new().unwrap(); let integration = LocalFileSystem::new_with_prefix(root.path()).unwrap(); futures::executor::block_on(async move {