Skip to content

Commit

Permalink
rebase fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
avdb13 committed Nov 21, 2024
1 parent 5d90387 commit 98ff557
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 78 deletions.
5 changes: 0 additions & 5 deletions atrium-oauth/oauth-client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ keywords = ["atproto", "bluesky", "oauth"]

[dependencies]
atrium-api = { workspace = true, features = ["agent"] }
# atrium-api = { workspace = true, default-features = false }
atrium-common.workspace = true
atrium-identity.workspace = true
atrium-xrpc.workspace = true
Expand Down Expand Up @@ -45,7 +44,3 @@ tokio = { workspace = true, features = ["macros", "rt-multi-thread"] }
[features]
default = ["default-client"]
default-client = ["reqwest/default-tls"]

[[bin]]
name = "example"
path = "examples/main.rs"
1 change: 1 addition & 0 deletions atrium-oauth/oauth-client/examples/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

let uri = url.trim().parse::<Uri>()?;
let params = serde_html_form::from_str(uri.query().unwrap())?;

let session_manager = client.callback::<MemoryMapStore<(), Session>>(params).await?;
let session = session_manager.get_session(false).await?;
println!("{}", serde_json::to_string_pretty(&session)?);
Expand Down
28 changes: 15 additions & 13 deletions atrium-oauth/oauth-client/src/http_client/dpop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ pub enum Error {
JwkCrypto(crypto::Error),
#[error("key does not match any alg supported by the server")]
UnsupportedKey,
#[error("nonce store error: {0}")]
Nonces(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error(transparent)]
SerdeJson(#[from] serde_json::Error),
}
Expand Down Expand Up @@ -103,16 +105,16 @@ where
_ => unimplemented!(),
}
}
fn is_use_dpop_nonce_error(&self, response: &Response<Vec<u8>>) -> bool {
fn is_use_dpop_nonce_error(&self, response: &Response<Vec<u8>>, is_auth_server: bool) -> bool {
// https://datatracker.ietf.org/doc/html/rfc9449#name-authorization-server-provid
if response.status() == 400 {
if is_auth_server && response.status() == 400 {
if let Ok(res) = serde_json::from_slice::<ErrorResponse>(response.body()) {
return res.error == "use_dpop_nonce";
};
}
// https://datatracker.ietf.org/doc/html/rfc6750#section-3
// https://datatracker.ietf.org/doc/html/rfc9449#name-resource-server-provided-no
if response.status() == 401 {
if !is_auth_server && response.status() == 401 {
// https://datatracker.ietf.org/doc/html/rfc6750#section-3
if let Some(www_auth) =
response.headers().get("WWW-Authenticate").and_then(|v| v.to_str().ok())
{
Expand All @@ -135,6 +137,7 @@ impl<T, S> HttpClient for DpopClient<T, S>
where
T: HttpClient + Send + Sync + 'static,
S: MapStore<String, String> + Send + Sync + 'static,
S::Error: Send + Sync + 'static,
{
async fn send_http(
&self,
Expand All @@ -145,21 +148,17 @@ where
let nonce_key = uri.authority().unwrap().to_string();
let htm = request.method().to_string();
let htu = uri.to_string();
// https://datatracker.ietf.org/doc/html/rfc9449#section-4.2
let ath = request
.headers()
.get("Authorization")
.filter(|v| v.to_str().map_or(false, |s| s.starts_with("DPoP ")))
.map(|auth| URL_SAFE_NO_PAD.encode(Sha256::digest(&auth.as_bytes()[5..])));

let is_auth_server = uri.path().starts_with("/oauth");
let ath = match request.headers().get("Authorization").and_then(|v| v.to_str().ok()) {
Some(s) if s.starts_with("DPoP ") => {
Some(URL_SAFE_NO_PAD.encode(Sha256::digest(s.strip_prefix("DPoP ").unwrap())))
}
_ => None,
};

let init_nonce = self.nonces.get(&nonce_key).await?;
let init_nonce =
self.nonces.get(&nonce_key).await.map_err(|e| Error::Nonces(Box::new(e)))?;
let init_proof =
self.build_proof(htm.clone(), htu.clone(), ath.clone(), init_nonce.clone())?;
request.headers_mut().insert("DPoP", init_proof.parse()?);
Expand All @@ -170,7 +169,10 @@ where
match &next_nonce {
Some(s) if next_nonce != init_nonce => {
// Store the fresh nonce for future requests
self.nonces.set(nonce_key, s.clone()).await?;
self.nonces
.set(nonce_key, s.clone())
.await
.map_err(|e| Error::Nonces(Box::new(e)))?;
}
_ => {
// No nonce was returned or it is the same as the one we sent. No need to
Expand All @@ -179,7 +181,7 @@ where
}
}

if !self.is_use_dpop_nonce_error(&response) {
if !self.is_use_dpop_nonce_error(&response, is_auth_server) {
return Ok(response);
}
let next_proof = self.build_proof(htm, htu, ath, next_nonce)?;
Expand Down
20 changes: 2 additions & 18 deletions atrium-oauth/oauth-client/src/oauth_client.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::constants::FALLBACK_ALG;
use crate::error::{Error, Result};
use crate::http_client::dpop::{DpopClient, Error as DpopError};
use crate::keyset::Keyset;
use crate::oauth_session::OAuthSession;
use crate::resolver::{OAuthResolver, OAuthResolverConfig};
Expand Down Expand Up @@ -223,7 +222,7 @@ where
}
pub async fn callback<S>(&self, params: CallbackParams) -> Result<OAuthSession<S, T, D, H>>
where
S: MapStore<(), Session> + Default,
S: MapStore<(), Session> + Default + Send + Sync + 'static,
{
let Some(state_key) = params.state else {
return Err(Error::Callback("missing `state` parameter".into()));
Expand Down Expand Up @@ -258,6 +257,7 @@ where
let token_set = server.exchange_code(&params.code, &state.verifier).await?;
// TODO: store token_set to session store

let session = Session { dpop_key: state.dpop_key.clone(), token_set: token_set.clone() };
self.session_store.set(token_set.sub.clone(), session.clone()).await.unwrap();

let session_store = S::default();
Expand All @@ -280,22 +280,6 @@ where
URL_SAFE_NO_PAD.encode(get_random_values::<_, 32>(&mut ThreadRng::default()));
(URL_SAFE_NO_PAD.encode(Sha256::digest(&verifier)), verifier)
}
fn create_session(
&self,
dpop_key: Key,
server_metadata: &OAuthAuthorizationServerMetadata,
client_metadata: &OAuthClientMetadata,
token_set: TokenSet,
) -> core::result::Result<OAuthSession<T>, DpopError> {
let dpop_client = DpopClient::new(
dpop_key,
client_metadata.client_id.clone(),
self.http_client.clone(),
false,
&server_metadata.token_endpoint_auth_signing_alg_values_supported,
)?;
Ok(OAuthSession::new(dpop_client, token_set))
}
pub async fn server_from_issuer(
&self,
issuer: &str,
Expand Down
78 changes: 37 additions & 41 deletions atrium-oauth/oauth-client/src/oauth_session.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::{store::session::Session, DpopClient, TokenSet};
use std::fmt::Debug;

use atrium_api::{agent::SessionManager, types::string::Did};
use atrium_common::store::MapStore;
use atrium_identity::{did::DidResolver, handle::HandleResolver};
Expand All @@ -7,21 +8,28 @@ use atrium_xrpc::{
types::AuthorizationToken,
HttpClient, XrpcClient,
};
use chrono::TimeDelta;
use thiserror::Error;

use crate::{server_agent::OAuthServerAgent, store::session::Session};

#[derive(Clone, Debug, Error)]
pub enum Error {}

pub struct OAuthSession<T, S = MemoryMapStore<String, String>>
pub struct OAuthSession<S, T, D, H>
where
S: MapStore<(), Session>,
S: MapStore<(), Session> + Default,
T: HttpClient + Send + Sync + 'static,
D: DidResolver + Send + Sync + 'static,
H: HandleResolver + Send + Sync + 'static,
{
inner: DpopClient<T, S>,
token_set: TokenSet, // TODO: replace with a session store?
session_store: S,
server: OAuthServerAgent<T, D, H>,
}

impl<T, S> OAuthSession<T, S>
impl<S, T, D, H> OAuthSession<S, T, D, H>
where
S: MapStore<(), Session>,
S: MapStore<(), Session> + Default,
T: HttpClient + Send + Sync + 'static,
D: DidResolver + Send + Sync + 'static,
H: HandleResolver + Send + Sync + 'static,
Expand Down Expand Up @@ -50,57 +58,45 @@ where
}
}

impl<T, S> HttpClient for OAuthSession<T, S>
impl<S, T, D, H> HttpClient for OAuthSession<S, T, D, H>
where
S: MapStore<(), Session> + Default + Sync,
T: HttpClient + Send + Sync + 'static,
D: DidResolver + Send + Sync + 'static,
H: HandleResolver + Send + Sync + 'static,
{
async fn send_http(
&self,
request: Request<Vec<u8>>,
) -> Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>> {
self.inner.send_http(request).await
self.server.send_http(request).await
}
}

impl<T, S> XrpcClient for OAuthSession<T, S>
impl<S, T, D, H> XrpcClient for OAuthSession<S, T, D, H>
where
S: MapStore<(), Session> + Default + Sync,
T: HttpClient + Send + Sync + 'static,
D: DidResolver + Send + Sync + 'static,
H: HandleResolver + Send + Sync + 'static,
{
fn base_uri(&self) -> String {
self.token_set.aud.clone()
}
async fn authorization_token(&self, _is_refresh: bool) -> Option<AuthorizationToken> {
Some(AuthorizationToken::Dpop(self.token_set.access_token.clone()))
let Ok(Some(Session { dpop_key: _, token_set })) =
futures::FutureExt::now_or_never(self.get_session(false)).transpose()
else {
panic!("session, now or never");
};
dbg!(&token_set);
token_set.aud
}
// async fn atproto_proxy_header(&self) -> Option<String> {
// todo!()
// }
// async fn atproto_accept_labelers_header(&self) -> Option<Vec<String>> {
// todo!()
// }
// async fn send_xrpc<P, I, O, E>(
// &self,
// request: &XrpcRequest<P, I>,
// ) -> Result<OutputDataOrBytes<O>, Error<E>>
// where
// P: Serialize + Send + Sync,
// I: Serialize + Send + Sync,
// O: DeserializeOwned + Send + Sync,
// E: DeserializeOwned + Send + Sync + Debug,
// {
// todo!()
// }
}

impl<T, S> SessionManager for OAuthSession<T, S>
where
T: HttpClient + Send + Sync + 'static,
S: MapStore<String, String> + Send + Sync + 'static,
{
async fn did(&self) -> Option<Did> {
todo!()
async fn authorization_token(&self, is_refresh: bool) -> Option<AuthorizationToken> {
let Session { dpop_key: _, token_set } = self.get_session(false).await.ok()?;
dbg!(&token_set);
if is_refresh {
token_set.refresh_token.as_ref().cloned().map(AuthorizationToken::Dpop)
} else {
Some(AuthorizationToken::Bearer(token_set.access_token.clone()))
}
}
}

Expand Down
1 change: 0 additions & 1 deletion atrium-oauth/oauth-client/src/server_agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,6 @@ where
dpop_key,
client_metadata.client_id.clone(),
http_client,
true,
&server_metadata.token_endpoint_auth_signing_alg_values_supported,
)?;
Ok(Self { server_metadata, client_metadata, dpop_client, resolver, keyset })
Expand Down

0 comments on commit 98ff557

Please sign in to comment.