Skip to content
This repository was archived by the owner on Oct 6, 2020. It is now read-only.

Commit 9926697

Browse files
author
Francesco Cogno
authored
Azure AAD: OAuth 2.0 device code flow implementation (#301)
2 parents ffc4f79 + d6a75b8 commit 9926697

File tree

6 files changed

+371
-17
lines changed

6 files changed

+371
-17
lines changed

azure_sdk_auth_aad/Cargo.toml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "azure_sdk_auth_aad"
3-
version = "0.45.0"
3+
version = "0.45.1"
44
description = "Rust wrappers around Microsoft Azure REST APIs - Azure OAuth2 helper crate"
55
readme = "README.md"
66
authors = ["Francesco Cogno <francesco.cogno@outlook.com>"]
@@ -18,17 +18,20 @@ edition = "2018"
1818
azure_sdk_core = { path = "../azure_sdk_core", version = "0.43.3" }
1919
oauth2 = { version = "3.0.0-alpha.9", features = ["reqwest-010", "futures-03"], default-features = false}
2020
url = "2.1"
21-
failure = "0.1"
2221
futures = "0.3"
2322
serde = "1.0"
2423
serde_derive = "1.0"
2524
chrono = "0.4"
2625
serde_json = "1.0"
2726
log = "0.4"
2827
reqwest = { version = "0.10", features = ["json"] }
28+
async-timer = { version = "1.0.0-beta.3" }
29+
thiserror = "1.0"
2930

3031
[dev-dependencies]
31-
tokio = { version = "0.2", features = ["macros"] }
32+
tokio = { version = "0.2", features = ["macros"] }
33+
azure_sdk_storage_core = { version = "0.44" }
34+
azure_sdk_storage_blob = { version = "0.44" }
3235

3336
[features]
3437
test_e2e = []
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
use azure_sdk_auth_aad::*;
2+
use azure_sdk_storage_blob::prelude::*;
3+
use azure_sdk_storage_core::prelude::*;
4+
use futures::stream::StreamExt;
5+
use oauth2::ClientId;
6+
use std::env;
7+
use std::error::Error;
8+
use std::sync::Arc;
9+
10+
#[tokio::main]
11+
async fn main() -> Result<(), Box<dyn Error>> {
12+
let client_id =
13+
ClientId::new(env::var("CLIENT_ID").expect("Missing CLIENT_ID environment variable."));
14+
let tenant_id = env::var("TENANT_ID").expect("Missing TENANT_ID environment variable.");
15+
16+
let storage_account_name = std::env::args()
17+
.nth(1)
18+
.expect("please specify the storage account name as first command line parameter");
19+
20+
let client = Arc::new(reqwest::Client::new());
21+
22+
// the process requires two steps. The first is to ask for
23+
// the code to show to the user. This is done with the following
24+
// function. Notice you can pass as many scopes as you want.
25+
// Since we are asking for the "offline_access" scope we will
26+
// receive the refresh token as well.
27+
// We are requesting access to the storage account passed as parameter.
28+
let device_code_flow = begin_authorize_device_code_flow(
29+
client.clone(),
30+
&tenant_id,
31+
&client_id,
32+
&[
33+
&format!(
34+
"https://{}.blob.core.windows.net/.default",
35+
storage_account_name
36+
),
37+
"offline_access",
38+
],
39+
)
40+
.await?;
41+
42+
// now we must show the user the authentication message. It
43+
// will point the user to the login page and show the code
44+
// they have to specify.
45+
println!("{}", device_code_flow.message());
46+
47+
// now we poll the auth endpoint until the user
48+
// completes the authentication. The following stream can
49+
// return, besides errors, a success meaning either
50+
// Success or Pending. The loop will continue until we
51+
// get either a Success or an error.
52+
let mut stream = Box::pin(device_code_flow.stream());
53+
let mut authorization = None;
54+
while let Some(resp) = stream.next().await {
55+
println!("{:?}", resp);
56+
57+
// if we have the authorization, let's store it for later use.
58+
if let DeviceCodeResponse::AuthorizationSucceded(auth) = resp? {
59+
authorization = Some(auth);
60+
}
61+
}
62+
63+
// remove the option (this is safe since we
64+
// unwrapped the errors before).
65+
let authorization = authorization.unwrap();
66+
67+
println!(
68+
"\nReceived valid bearer token: {}",
69+
&authorization.access_token.secret()
70+
);
71+
72+
if let Some(refresh_token) = authorization.refresh_token.as_ref() {
73+
println!("Received valid refresh token: {}", &refresh_token.secret());
74+
}
75+
76+
// we can now spend the access token in other crates. In
77+
// this example we are creating an Azure Storage client
78+
// using the access token.
79+
let client = client::with_bearer_token(
80+
&storage_account_name,
81+
&authorization.access_token.secret() as &str,
82+
);
83+
84+
// now we enumerate the containers in the
85+
// specified storage account.
86+
let containers = client.list_containers().finalize().await?;
87+
println!("\nList containers completed succesfully: {:?}", containers);
88+
89+
Ok(())
90+
}
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
use crate::device_code_responses::*;
2+
use async_timer::timer::new_timer;
3+
use azure_sdk_core::errors::AzureError;
4+
use futures::stream::unfold;
5+
use log::debug;
6+
pub use oauth2::{ClientId, ClientSecret};
7+
use std::convert::TryInto;
8+
use std::sync::Arc;
9+
use std::time::Duration;
10+
use url::form_urlencoded;
11+
12+
#[derive(Debug, Clone, Deserialize)]
13+
pub struct DeviceCodePhaseOneResponse<'a> {
14+
device_code: String,
15+
user_code: String,
16+
verification_uri: String,
17+
expires_in: u64,
18+
interval: u64,
19+
message: String,
20+
// the skipped fields below do not come
21+
// from the Azure answer. They will be added
22+
// manually after deserialization
23+
#[serde(skip)]
24+
client: Arc<reqwest::Client>,
25+
#[serde(skip)]
26+
tenant_id: &'a str,
27+
// we store the ClientId as string instead of
28+
// the original type because it does not
29+
// implement Default and it's in another
30+
// create
31+
#[serde(skip)]
32+
client_id: String,
33+
}
34+
35+
pub async fn begin_authorize_device_code_flow<'a, 'b>(
36+
client: Arc<reqwest::Client>,
37+
tenant_id: &'a str,
38+
client_id: &'a ClientId,
39+
scopes: &'b [&'b str],
40+
) -> Result<DeviceCodePhaseOneResponse<'a>, AzureError> {
41+
let mut encoded = form_urlencoded::Serializer::new(String::new());
42+
let encoded = encoded.append_pair("client_id", client_id.as_str());
43+
let encoded = encoded.append_pair("scope", &scopes.join(" "));
44+
let encoded = encoded.finish();
45+
46+
debug!("encoded ==> {}", encoded);
47+
48+
let url = url::Url::parse(&format!(
49+
"https://login.microsoftonline.com/{}/oauth2/v2.0/devicecode",
50+
tenant_id
51+
))?;
52+
53+
client
54+
.post(url)
55+
.header("ContentType", "application/x-www-form-urlencoded")
56+
.body(encoded)
57+
.send()
58+
.await
59+
.map_err(|e| AzureError::GenericErrorWithText(e.to_string()))?
60+
.text()
61+
.await
62+
.map_err(|e| AzureError::GenericErrorWithText(e.to_string()))
63+
.and_then(|s| {
64+
serde_json::from_str::<DeviceCodePhaseOneResponse>(&s)
65+
// we need to capture some variables that will be useful in
66+
// the second phase (the client, the tenant_id and the client_id)
67+
.map(|device_code_reponse| DeviceCodePhaseOneResponse {
68+
device_code: device_code_reponse.device_code,
69+
user_code: device_code_reponse.user_code,
70+
verification_uri: device_code_reponse.verification_uri,
71+
expires_in: device_code_reponse.expires_in,
72+
interval: device_code_reponse.interval,
73+
message: device_code_reponse.message,
74+
client: client.clone(),
75+
tenant_id,
76+
client_id: client_id.as_str().to_string(),
77+
})
78+
.map_err(|e| {
79+
serde_json::from_str::<crate::errors::ErrorResponse>(&s)
80+
.map(|er| AzureError::GenericErrorWithText(er.to_string()))
81+
.unwrap_or_else(|_| {
82+
AzureError::GenericErrorWithText(format!(
83+
"Failed to parse Azure response: {}",
84+
e.to_string()
85+
))
86+
})
87+
})
88+
})
89+
}
90+
91+
impl<'a> DeviceCodePhaseOneResponse<'a> {
92+
pub fn message(&self) -> &str {
93+
&self.message
94+
}
95+
96+
pub fn stream(
97+
&self,
98+
) -> impl futures::Stream<Item = Result<DeviceCodeResponse, DeviceCodeError>> + '_ {
99+
#[derive(Debug, Clone, PartialEq)]
100+
enum NextState {
101+
Continue,
102+
Finish,
103+
}
104+
105+
unfold(
106+
NextState::Continue,
107+
async move |state: NextState| match state {
108+
NextState::Continue => {
109+
let uri = format!(
110+
"https://login.microsoftonline.com/{}/oauth2/v2.0/token",
111+
self.tenant_id,
112+
);
113+
114+
// throttle down as specified by Azure. This could be
115+
// smarter: we could calculate the elapsed time since the
116+
// last poll and wait only the delta. For now we do not
117+
// need such precision.
118+
new_timer(Duration::from_secs(self.interval)).await;
119+
debug!("posting to {}", &uri);
120+
121+
let mut encoded = form_urlencoded::Serializer::new(String::new());
122+
let encoded = encoded
123+
.append_pair("grant_type", "urn:ietf:params:oauth:grant-type:device_code");
124+
let encoded = encoded.append_pair("client_id", self.client_id.as_str());
125+
let encoded = encoded.append_pair("device_code", &self.device_code);
126+
let encoded = encoded.finish();
127+
128+
let result = match self
129+
.client
130+
.post(&uri)
131+
.header("ContentType", "application/x-www-form-urlencoded")
132+
.body(encoded)
133+
.send()
134+
.await
135+
.map_err(DeviceCodeError::ReqwestError)
136+
{
137+
Ok(result) => result,
138+
Err(error) => return Some((Err(error), NextState::Finish)),
139+
};
140+
debug!("result (raw) ==> {:?}", result);
141+
142+
let result = match result.text().await.map_err(DeviceCodeError::ReqwestError) {
143+
Ok(result) => result,
144+
Err(error) => return Some((Err(error), NextState::Finish)),
145+
};
146+
debug!("result (as text) ==> {}", result);
147+
148+
// here either we get an error response from Azure
149+
// or we get a success. A success can be either "Pending" or
150+
// "Completed". We finish the loop only on "Completed" (ie Success)
151+
match result.try_into() {
152+
Ok(device_code_response) => {
153+
let next_state = match &device_code_response {
154+
DeviceCodeResponse::AuthorizationSucceded(_) => NextState::Finish,
155+
DeviceCodeResponse::AuthorizationPending(_) => NextState::Continue,
156+
};
157+
158+
Some((Ok(device_code_response), next_state))
159+
}
160+
Err(error) => Some((Err(error), NextState::Finish)),
161+
}
162+
}
163+
NextState::Finish => None,
164+
},
165+
)
166+
}
167+
}
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
use oauth2::AccessToken;
2+
use std::convert::TryInto;
3+
use std::fmt;
4+
use thiserror::Error;
5+
6+
#[derive(Debug, Clone, PartialEq, Deserialize)]
7+
pub struct DeviceCodeErrorResponse {
8+
pub error: String,
9+
pub error_description: String,
10+
pub error_uri: String,
11+
}
12+
13+
impl fmt::Display for DeviceCodeErrorResponse {
14+
// This trait requires `fmt` with this exact signature.
15+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
16+
write!(f, "{}. {}", self.error, self.error_description)
17+
}
18+
}
19+
20+
#[derive(Debug, Clone, Deserialize)]
21+
pub struct DeviceCodeAuthorization {
22+
pub token_type: String,
23+
pub scope: String,
24+
pub expires_in: u64,
25+
pub access_token: AccessToken,
26+
pub refresh_token: Option<AccessToken>,
27+
pub id_token: Option<AccessToken>,
28+
}
29+
30+
#[derive(Error, Debug)]
31+
pub enum DeviceCodeError {
32+
#[error("Authorization declined")]
33+
AuthorizationDeclined(DeviceCodeErrorResponse),
34+
#[error("Bad verification code")]
35+
BadVerificationCode(DeviceCodeErrorResponse),
36+
#[error("Expired token")]
37+
ExpiredToken(DeviceCodeErrorResponse),
38+
#[error("Unrecognized error: {0}")]
39+
UnrecognizedError(DeviceCodeErrorResponse),
40+
#[error("Unhandled error: {0}. {1}")]
41+
UnhandledError(String, String),
42+
#[error("Reqwest error: {0}")]
43+
ReqwestError(reqwest::Error),
44+
}
45+
46+
#[derive(Debug, Clone)]
47+
pub enum DeviceCodeResponse {
48+
AuthorizationSucceded(DeviceCodeAuthorization),
49+
AuthorizationPending(DeviceCodeErrorResponse),
50+
}
51+
52+
impl TryInto<DeviceCodeResponse> for String {
53+
type Error = DeviceCodeError;
54+
55+
fn try_into(self) -> Result<DeviceCodeResponse, Self::Error> {
56+
// first we try to deserialize as DeviceCodeAuthorization (success)
57+
match serde_json::from_str::<DeviceCodeAuthorization>(&self) {
58+
Ok(device_code_authorization) => Ok(DeviceCodeResponse::AuthorizationSucceded(
59+
device_code_authorization,
60+
)),
61+
Err(_) => {
62+
// now we try to map it to a DeviceCodeErrorResponse
63+
match serde_json::from_str::<DeviceCodeErrorResponse>(&self) {
64+
Ok(device_code_error_response) => {
65+
match &device_code_error_response.error as &str {
66+
"authorization_pending" => {
67+
Ok(DeviceCodeResponse::AuthorizationPending(
68+
device_code_error_response,
69+
))
70+
}
71+
"authorization_declined" => Err(
72+
DeviceCodeError::AuthorizationDeclined(device_code_error_response),
73+
),
74+
75+
"bad_verification_code" => Err(DeviceCodeError::BadVerificationCode(
76+
device_code_error_response,
77+
)),
78+
"expired_token" => {
79+
Err(DeviceCodeError::ExpiredToken(device_code_error_response))
80+
}
81+
_ => Err(DeviceCodeError::UnrecognizedError(
82+
device_code_error_response,
83+
)),
84+
}
85+
}
86+
// If we cannot, we bail out giving the full error as string
87+
Err(error) => Err(DeviceCodeError::UnhandledError(error.to_string(), self)),
88+
}
89+
}
90+
}
91+
}
92+
}

0 commit comments

Comments
 (0)