Skip to content

Commit b187a7b

Browse files
valkumdjc
authored andcommitted
implement retries for fetching tokens from the various endpoints
The official SDKs implement a retry mechanism for fetching the tokens from the metadata server in the case of I/O errors, etc. This adds a similar mechanism to the provided ServiceAccount implementations
1 parent 75e1e93 commit b187a7b

File tree

3 files changed

+78
-34
lines changed

3 files changed

+78
-34
lines changed

src/custom_service_account.rs

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -104,18 +104,30 @@ impl ServiceAccount for CustomServiceAccount {
104104
.extend_pairs(&[("grant_type", GRANT_TYPE), ("assertion", jwt.as_str())])
105105
.finish();
106106

107-
let request = hyper::Request::post(&self.credentials.token_uri)
108-
.header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
109-
.body(hyper::Body::from(rqbody))
110-
.unwrap();
111-
112-
tracing::debug!("requesting token from service account: {:?}", request);
113-
let token = client
114-
.request(request)
115-
.await
116-
.map_err(Error::OAuthConnectionError)?
117-
.deserialize::<Token>()
118-
.await?;
107+
let mut retries = 0;
108+
let response = loop {
109+
let request = hyper::Request::post(&self.credentials.token_uri)
110+
.header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
111+
.body(hyper::Body::from(rqbody.clone()))
112+
.unwrap();
113+
114+
tracing::debug!("requesting token from service account: {request:?}");
115+
let err = match client.request(request).await {
116+
// Early return when the request succeeds
117+
Ok(response) => break response,
118+
Err(err) => err,
119+
};
120+
121+
tracing::warn!(
122+
"Failed to refresh token with GCP oauth2 token endpoint: {err}, trying again..."
123+
);
124+
retries += 1;
125+
if retries >= RETRY_COUNT {
126+
return Err(Error::OAuthConnectionError(err));
127+
}
128+
};
129+
130+
let token = response.deserialize::<Token>().await?;
119131

120132
let key = scopes.iter().map(|x| (*x).to_string()).collect();
121133
self.tokens.write().unwrap().insert(key, token.clone());
@@ -154,3 +166,6 @@ impl fmt::Debug for ApplicationCredentials {
154166
.finish()
155167
}
156168
}
169+
170+
/// How many times to attempt to fetch a token from the set credentials token endpoint.
171+
const RETRY_COUNT: u8 = 5;

src/default_authorized_user.rs

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -48,20 +48,31 @@ impl DefaultAuthorizedUser {
4848

4949
#[tracing::instrument]
5050
async fn get_token(cred: &UserCredentials, client: &HyperClient) -> Result<Token, Error> {
51-
let req = Self::build_token_request(&RefreshRequest {
52-
client_id: &cred.client_id,
53-
client_secret: &cred.client_secret,
54-
grant_type: "refresh_token",
55-
refresh_token: &cred.refresh_token,
56-
});
57-
58-
let token = client
59-
.request(req)
60-
.await
61-
.map_err(Error::OAuthConnectionError)?
62-
.deserialize()
63-
.await?;
64-
Ok(token)
51+
let mut retries = 0;
52+
let response = loop {
53+
let req = Self::build_token_request(&RefreshRequest {
54+
client_id: &cred.client_id,
55+
client_secret: &cred.client_secret,
56+
grant_type: "refresh_token",
57+
refresh_token: &cred.refresh_token,
58+
});
59+
60+
let err = match client.request(req).await {
61+
// Early return when the request succeeds
62+
Ok(response) => break response,
63+
Err(err) => err,
64+
};
65+
66+
tracing::warn!(
67+
"Failed to get token from GCP oauth2 token endpoint: {err}, trying again..."
68+
);
69+
retries += 1;
70+
if retries >= RETRY_COUNT {
71+
return Err(Error::OAuthConnectionError(err));
72+
}
73+
};
74+
75+
response.deserialize().await.map_err(Into::into)
6576
}
6677
}
6778

@@ -106,3 +117,6 @@ struct UserCredentials {
106117
/// Type
107118
pub(crate) r#type: String,
108119
}
120+
121+
/// How many times to attempt to fetch a token from the GCP token endpoint.
122+
const RETRY_COUNT: u8 = 5;

src/default_service_account.rs

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,27 @@ impl DefaultServiceAccount {
3636

3737
#[tracing::instrument]
3838
async fn get_token(client: &HyperClient) -> Result<Token, Error> {
39+
let mut retries = 0;
3940
tracing::debug!("Getting token from GCP instance metadata server");
40-
let req = Self::build_token_request(Self::DEFAULT_TOKEN_GCP_URI);
41-
let token = client
42-
.request(req)
43-
.await
44-
.map_err(Error::ConnectionError)?
45-
.deserialize()
46-
.await?;
47-
Ok(token)
41+
let response = loop {
42+
let req = Self::build_token_request(Self::DEFAULT_TOKEN_GCP_URI);
43+
44+
let err = match client.request(req).await {
45+
// Early return when the request succeeds
46+
Ok(response) => break response,
47+
Err(err) => err,
48+
};
49+
50+
tracing::warn!(
51+
"Failed to get token from GCP instance metadata server: {err}, trying again..."
52+
);
53+
retries += 1;
54+
if retries >= RETRY_COUNT {
55+
return Err(Error::ConnectionError(err));
56+
}
57+
};
58+
59+
response.deserialize().await.map_err(Into::into)
4860
}
4961
}
5062

@@ -75,3 +87,6 @@ impl ServiceAccount for DefaultServiceAccount {
7587
Ok(token)
7688
}
7789
}
90+
91+
/// How many times to attempt to fetch a token from the GCP metadata server.
92+
const RETRY_COUNT: u8 = 5;

0 commit comments

Comments
 (0)