Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Update oauth_client #254

Merged
merged 2 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions atrium-oauth/oauth-client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ trait-variant.workspace = true

[dev-dependencies]
hickory-resolver.workspace = true
p256 = { workspace = true, features = ["pem"] }
tokio = { workspace = true, features = ["macros", "rt-multi-thread"] }

[features]
Expand Down
15 changes: 11 additions & 4 deletions atrium-oauth/oauth-client/examples/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use atrium_identity::did::{CommonDidResolver, CommonDidResolverConfig, DEFAULT_P
use atrium_identity::handle::{AtprotoHandleResolver, AtprotoHandleResolverConfig, DnsTxtResolver};
use atrium_oauth_client::store::state::MemoryStateStore;
use atrium_oauth_client::{
AtprotoLocalhostClientMetadata, AuthorizeOptions, DefaultHttpClient, OAuthClient,
OAuthClientConfig, OAuthResolverConfig,
AtprotoLocalhostClientMetadata, AuthorizeOptions, DefaultHttpClient, KnownScope, OAuthClient,
OAuthClientConfig, OAuthResolverConfig, Scope,
};
use atrium_xrpc::http::Uri;
use hickory_resolver::TokioAsyncResolver;
Expand Down Expand Up @@ -37,7 +37,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let http_client = Arc::new(DefaultHttpClient::default());
let config = OAuthClientConfig {
client_metadata: AtprotoLocalhostClientMetadata {
redirect_uris: vec!["http://127.0.0.1".to_string()],
redirect_uris: Some(vec![String::from("http://127.0.0.1/callback")]),
scopes: Some(vec![
Scope::Known(KnownScope::Atproto),
Scope::Known(KnownScope::TransitionGeneric),
]),
},
keys: None,
resolver: OAuthResolverConfig {
Expand All @@ -61,7 +65,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.authorize(
std::env::var("HANDLE").unwrap_or(String::from("https://bsky.social")),
AuthorizeOptions {
scopes: Some(vec![String::from("atproto")]),
scopes: vec![
Scope::Known(KnownScope::Atproto),
Scope::Known(KnownScope::TransitionGeneric)
],
..Default::default()
}
)
Expand Down
266 changes: 250 additions & 16 deletions atrium-oauth/oauth-client/src/atproto.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::keyset::Keyset;
use crate::types::{OAuthClientMetadata, TryIntoOAuthClientMetadata};
use atrium_xrpc::http::Uri;
use atrium_xrpc::http::uri::{InvalidUri, Scheme, Uri};
use serde::{Deserialize, Serialize};
use thiserror::Error;

Expand All @@ -18,6 +18,22 @@ pub enum Error {
EmptyJwks,
#[error("`private_key_jwt` auth method requires `token_endpoint_auth_signing_alg`, otherwise must not be provided")]
AuthSigningAlg,
#[error(transparent)]
SerdeHtmlForm(#[from] serde_html_form::ser::Error),
#[error(transparent)]
LocalhostClient(#[from] LocalhostClientError),
}

#[derive(Error, Debug)]
pub enum LocalhostClientError {
#[error("invalid redirect_uri: {0}")]
Invalid(#[from] InvalidUri),
#[error("loopback client_id must use `http:` redirect_uri")]
NotHttpScheme,
#[error("loopback client_id must not use `localhost` as redirect_uri hostname")]
Localhost,
#[error("loopback client_id must not use loopback addresses as redirect_uri")]
NotLoopbackHost,
}

pub type Result<T> = core::result::Result<T, Error>;
Expand Down Expand Up @@ -56,22 +72,37 @@ impl From<GrantType> for String {
}

#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[serde(untagged)]
pub enum Scope {
Known(KnownScope),
Unknown(String),
}

#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum KnownScope {
#[serde(rename = "atproto")]
Atproto,
#[serde(rename = "transition:generic")]
TransitionGeneric,
#[serde(rename = "transition:chat.bsky")]
TransitionChatBsky,
}

impl From<Scope> for String {
fn from(value: Scope) -> Self {
match value {
Scope::Atproto => String::from("atproto"),
impl AsRef<str> for Scope {
fn as_ref(&self) -> &str {
match self {
Self::Known(KnownScope::Atproto) => "atproto",
Self::Known(KnownScope::TransitionGeneric) => "transition:generic",
Self::Known(KnownScope::TransitionChatBsky) => "transition:chat.bsky",
Self::Unknown(value) => value,
}
}
}

#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct AtprotoLocalhostClientMetadata {
pub redirect_uris: Vec<String>,
pub redirect_uris: Option<Vec<String>>,
pub scopes: Option<Vec<Scope>>,
}

#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
Expand All @@ -90,14 +121,46 @@ impl TryIntoOAuthClientMetadata for AtprotoLocalhostClientMetadata {
type Error = Error;

fn try_into_client_metadata(self, _: &Option<Keyset>) -> Result<OAuthClientMetadata> {
if self.redirect_uris.is_empty() {
return Err(Error::EmptyRedirectUris);
// validate redirect_uris
if let Some(redirect_uris) = &self.redirect_uris {
for redirect_uri in redirect_uris {
let uri = redirect_uri.parse::<Uri>().map_err(LocalhostClientError::Invalid)?;
if uri.scheme() != Some(&Scheme::HTTP) {
return Err(Error::LocalhostClient(LocalhostClientError::NotHttpScheme));
}
if uri.host() == Some("localhost") {
return Err(Error::LocalhostClient(LocalhostClientError::Localhost));
}
if uri.host().map_or(true, |host| host != "127.0.0.1" && host != "[::1]") {
return Err(Error::LocalhostClient(LocalhostClientError::NotLoopbackHost));
}
}
}
// determine client_id
#[derive(serde::Serialize)]
struct Parameters {
#[serde(skip_serializing_if = "Option::is_none")]
redirect_uri: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
scope: Option<String>,
}
let query = serde_html_form::to_string(Parameters {
redirect_uri: self.redirect_uris.clone(),
scope: self
.scopes
.map(|scopes| scopes.iter().map(AsRef::as_ref).collect::<Vec<_>>().join(" ")),
})?;
let mut client_id = String::from("http://localhost");
if !query.is_empty() {
client_id.push_str(&format!("?{query}"));
}
Ok(OAuthClientMetadata {
client_id: String::from("http://localhost"),
client_id,
client_uri: None,
redirect_uris: self.redirect_uris,
scope: None, // will be set to `atproto`
redirect_uris: self
.redirect_uris
.unwrap_or(vec![String::from("http://127.0.0.1/"), String::from("http://[::1]/")]),
scope: None,
grant_types: None, // will be set to `authorization_code` and `refresh_token`
token_endpoint_auth_method: Some(String::from("none")),
dpop_bound_access_tokens: None, // will be set to `true`
Expand All @@ -121,7 +184,7 @@ impl TryIntoOAuthClientMetadata for AtprotoClientMetadata {
if !self.grant_types.contains(&GrantType::AuthorizationCode) {
return Err(Error::InvalidGrantTypes);
}
if !self.scopes.contains(&Scope::Atproto) {
if !self.scopes.contains(&Scope::Known(KnownScope::Atproto)) {
return Err(Error::InvalidScope);
}
let (jwks_uri, mut jwks) = (self.jwks_uri, None);
Expand Down Expand Up @@ -150,13 +213,184 @@ impl TryIntoOAuthClientMetadata for AtprotoClientMetadata {
redirect_uris: self.redirect_uris,
token_endpoint_auth_method: Some(self.token_endpoint_auth_method.into()),
grant_types: Some(self.grant_types.into_iter().map(|v| v.into()).collect()),
scope: Some(
self.scopes.into_iter().map(|v| v.into()).collect::<Vec<String>>().join(" "),
),
scope: Some(self.scopes.iter().map(AsRef::as_ref).collect::<Vec<_>>().join(" ")),
dpop_bound_access_tokens: Some(true),
jwks_uri,
jwks,
token_endpoint_auth_signing_alg: self.token_endpoint_auth_signing_alg,
})
}
}

#[cfg(test)]
mod tests {
use super::*;
use elliptic_curve::SecretKey;
use jose_jwk::{Jwk, Key, Parameters};
use p256::pkcs8::DecodePrivateKey;

const PRIVATE_KEY: &str = r#"-----BEGIN PRIVATE KEY-----
MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgED1AAgC7Fc9kPh5T
4i4Tn+z+tc47W1zYgzXtyjJtD92hRANCAAT80DqC+Z/JpTO7/pkPBmWqIV1IGh1P
gbGGr0pN+oSing7cZ0169JaRHTNh+0LNQXrFobInX6cj95FzEdRyT4T3
-----END PRIVATE KEY-----"#;

#[test]
fn test_localhost_client_metadata_default() {
let metadata = AtprotoLocalhostClientMetadata::default();
assert_eq!(
metadata.try_into_client_metadata(&None).expect("failed to convert metadata"),
OAuthClientMetadata {
client_id: String::from("http://localhost"),
client_uri: None,
redirect_uris: vec![
String::from("http://127.0.0.1/"),
String::from("http://[::1]/"),
],
scope: None,
grant_types: None,
token_endpoint_auth_method: Some(AuthMethod::None.into()),
dpop_bound_access_tokens: None,
jwks_uri: None,
jwks: None,
token_endpoint_auth_signing_alg: None,
}
);
}

#[test]
fn test_localhost_client_metadata_custom() {
let metadata = AtprotoLocalhostClientMetadata {
redirect_uris: Some(vec![
String::from("http://127.0.0.1/callback"),
String::from("http://[::1]/callback"),
]),
scopes: Some(vec![
Scope::Known(KnownScope::Atproto),
Scope::Known(KnownScope::TransitionGeneric),
Scope::Unknown(String::from("unknown")),
]),
};
assert_eq!(
metadata.try_into_client_metadata(&None).expect("failed to convert metadata"),
OAuthClientMetadata {
client_id: String::from("http://localhost?redirect_uri=http%3A%2F%2F127.0.0.1%2Fcallback&redirect_uri=http%3A%2F%2F%5B%3A%3A1%5D%2Fcallback&scope=atproto+transition%3Ageneric+unknown"),
client_uri: None,
redirect_uris: vec![
String::from("http://127.0.0.1/callback"),
String::from("http://[::1]/callback"),
],
scope: None,
grant_types: None,
token_endpoint_auth_method: Some(AuthMethod::None.into()),
dpop_bound_access_tokens: None,
jwks_uri: None,
jwks: None,
token_endpoint_auth_signing_alg: None,
}
);
}

#[test]
fn test_localhost_client_metadata_invalid() {
{
let metadata = AtprotoLocalhostClientMetadata {
redirect_uris: Some(vec![String::from("http://")]),
..Default::default()
};
let err = metadata.try_into_client_metadata(&None).expect_err("expected to fail");
assert!(matches!(err, Error::LocalhostClient(LocalhostClientError::Invalid(_))));
}
{
let metadata = AtprotoLocalhostClientMetadata {
redirect_uris: Some(vec![String::from("https://127.0.0.1/")]),
..Default::default()
};
let err = metadata.try_into_client_metadata(&None).expect_err("expected to fail");
assert!(matches!(err, Error::LocalhostClient(LocalhostClientError::NotHttpScheme)));
}
{
let metadata = AtprotoLocalhostClientMetadata {
redirect_uris: Some(vec![String::from("http://localhost:8000/")]),
..Default::default()
};
let err = metadata.try_into_client_metadata(&None).expect_err("expected to fail");
assert!(matches!(err, Error::LocalhostClient(LocalhostClientError::Localhost)));
}
{
let metadata = AtprotoLocalhostClientMetadata {
redirect_uris: Some(vec![String::from("http://192.168.0.0/")]),
..Default::default()
};
let err = metadata.try_into_client_metadata(&None).expect_err("expected to fail");
assert!(matches!(err, Error::LocalhostClient(LocalhostClientError::NotLoopbackHost)));
}
}

#[test]
fn test_client_metadata() {
let metadata = AtprotoClientMetadata {
client_id: String::from("https://example.com/client_metadata.json"),
client_uri: String::from("https://example.com"),
redirect_uris: vec![String::from("https://example.com/callback")],
token_endpoint_auth_method: AuthMethod::PrivateKeyJwt,
grant_types: vec![GrantType::AuthorizationCode],
scopes: vec![Scope::Known(KnownScope::Atproto)],
jwks_uri: None,
token_endpoint_auth_signing_alg: Some(String::from("ES256")),
};
{
let metadata = metadata.clone();
let err = metadata.try_into_client_metadata(&None).expect_err("expected to fail");
assert!(matches!(err, Error::EmptyJwks));
}
{
let metadata = metadata.clone();
let secret_key = SecretKey::<p256::NistP256>::from_pkcs8_pem(PRIVATE_KEY)
.expect("failed to parse private key");
let keys = vec![Jwk {
key: Key::from(&secret_key.into()),
prm: Parameters { kid: Some(String::from("kid00")), ..Default::default() },
}];
let keyset = Keyset::try_from(keys.clone()).expect("failed to create keyset");
assert_eq!(
metadata
.try_into_client_metadata(&Some(keyset.clone()))
.expect("failed to convert metadata"),
OAuthClientMetadata {
client_id: String::from("https://example.com/client_metadata.json"),
client_uri: Some(String::from("https://example.com")),
redirect_uris: vec![String::from("https://example.com/callback"),],
scope: Some(String::from("atproto")),
grant_types: Some(vec![String::from("authorization_code")]),
token_endpoint_auth_method: Some(AuthMethod::PrivateKeyJwt.into()),
dpop_bound_access_tokens: Some(true),
jwks_uri: None,
jwks: Some(keyset.public_jwks()),
token_endpoint_auth_signing_alg: Some(String::from("ES256")),
}
);
}
}

#[test]
fn test_scope_serde() {
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
struct Scopes {
scopes: Vec<Scope>,
}

let scopes = Scopes {
scopes: vec![
Scope::Known(KnownScope::Atproto),
Scope::Known(KnownScope::TransitionGeneric),
Scope::Unknown(String::from("unknown")),
],
};
let json = serde_json::to_string(&scopes).expect("failed to serialize scopes");
assert_eq!(json, r#"{"scopes":["atproto","transition:generic","unknown"]}"#);
let deserialized =
serde_json::from_str::<Scopes>(&json).expect("failed to deserialize scopes");
assert_eq!(deserialized, scopes);
}
}
Loading
Loading