diff --git a/atrium-api/src/agent.rs b/atrium-api/src/agent.rs index 83679b57..e7320481 100644 --- a/atrium-api/src/agent.rs +++ b/atrium-api/src/agent.rs @@ -234,13 +234,13 @@ where S: Store<(), U>, U: Clone, { - async fn get(&self) -> Result, S::Error> { + pub async fn get(&self) -> Result, S::Error> { self.inner.get(&()).await } - async fn set(&self, value: U) -> Result<(), S::Error> { + pub async fn set(&self, value: U) -> Result<(), S::Error> { self.inner.set((), value).await } - async fn clear(&self) -> Result<(), S::Error> { + pub async fn clear(&self) -> Result<(), S::Error> { self.inner.clear().await } } diff --git a/atrium-oauth/oauth-client/src/oauth_session.rs b/atrium-oauth/oauth-client/src/oauth_session.rs index c92ab6c4..7403b2de 100644 --- a/atrium-oauth/oauth-client/src/oauth_session.rs +++ b/atrium-oauth/oauth-client/src/oauth_session.rs @@ -9,12 +9,21 @@ use atrium_api::{ use atrium_common::store::{memory::MemoryStore, Store}; use atrium_xrpc::{ http::{Request, Response}, - Error, HttpClient, OutputDataOrBytes, XrpcClient, XrpcRequest, + HttpClient, OutputDataOrBytes, XrpcClient, XrpcRequest, }; use jose_jwk::Key; use serde::{de::DeserializeOwned, Serialize}; use std::{fmt::Debug, sync::Arc}; use store::MemorySessionStore; +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum Error { + #[error(transparent)] + Dpop(#[from] dpop::Error), + #[error(transparent)] + Store(#[from] atrium_common::store::memory::Error), +} pub struct OAuthSession> where @@ -31,13 +40,14 @@ impl OAuthSession where T: HttpClient + Send + Sync, { - pub(crate) fn new( + pub(crate) async fn new( server_agent: OAuthServerAgent, dpop_key: Key, http_client: Arc, token_set: TokenSet, - ) -> Result { + ) -> Result { let store = Arc::new(InnerStore::new(MemorySessionStore::default(), token_set.aud.clone())); + store.set(token_set.access_token.clone()).await?; let inner = inner::Client::new( Arc::clone(&store), DpopClient::new( @@ -81,7 +91,7 @@ where async fn send_xrpc( &self, request: &XrpcRequest, - ) -> Result, Error> + ) -> Result, atrium_xrpc::Error> where P: Serialize + Send + Sync, I: Serialize + Send + Sync, @@ -147,7 +157,7 @@ mod tests { client::Service, did_doc::DidDocument, types::string::Handle, - xrpc::http::{header::CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue}, + xrpc::http::{header::CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue, StatusCode}, }; use atrium_common::resolver::Resolver; use atrium_identity::{did::DidResolver, handle::HandleResolver}; @@ -170,6 +180,17 @@ mod tests { request: Request>, ) -> Result>, Box> { let mut headers = request.headers().clone(); + let Some(authorization) = headers + .remove("authorization") + .and_then(|value| value.to_str().map(String::from).ok()) + else { + return Ok(Response::builder().status(StatusCode::UNAUTHORIZED).body(Vec::new())?); + }; + let Some(_token) = authorization.strip_prefix("DPoP ") else { + panic!("authorization header should start with DPoP"); + }; + // TODO: verify token + let dpop_jwt = headers.remove("dpop").expect("dpop header should be present"); let payload = dpop_jwt .to_str() @@ -227,9 +248,14 @@ mod tests { impl HandleResolver for NoopHandleResolver {} - fn oauth_agent( + async fn oauth_session( data: Arc>>, - ) -> Agent { + ) -> OAuthSession< + MockHttpClient, + NoopDidResolver, + NoopHandleResolver, + MemoryStore, + > { let dpop_key = serde_json::from_str::( r#"{ "kty": "EC", @@ -270,14 +296,21 @@ mod tests { token_type: OAuthTokenType::DPoP, expires_at: None, }; - let oauth_session = OAuthSession::new(server_agent, dpop_key, http_client, token_set) - .expect("failed to create oauth session"); - Agent::new(oauth_session) + OAuthSession::new(server_agent, dpop_key, http_client, token_set) + .await + .expect("failed to create oauth session") + } + + async fn oauth_agent( + data: Arc>>, + ) -> Agent { + Agent::new(oauth_session(data).await) } async fn call_service( service: &Service, - ) -> Result<(), Error> { + ) -> Result<(), atrium_xrpc::Error> + { let output = service .com .atproto @@ -298,7 +331,7 @@ mod tests { #[tokio::test] async fn test_new() -> Result<(), Box> { - let agent = oauth_agent(Arc::new(Mutex::new(Default::default()))); + let agent = oauth_agent(Arc::new(Mutex::new(Default::default()))).await; assert_eq!(agent.did().await.as_deref(), Some("did:fake:sub.test")); Ok(()) } @@ -306,7 +339,7 @@ mod tests { #[tokio::test] async fn test_configure_endpoint() -> Result<(), Box> { let data = Arc::new(Mutex::new(Default::default())); - let agent = oauth_agent(Arc::clone(&data)); + let agent = oauth_agent(Arc::clone(&data)).await; call_service(&agent.api).await?; assert_eq!( data.lock().await.as_ref().expect("data should be recorded").host.as_deref(), @@ -324,7 +357,7 @@ mod tests { #[tokio::test] async fn test_configure_labelers_header() -> Result<(), Box> { let data = Arc::new(Mutex::new(Default::default())); - let agent = oauth_agent(Arc::clone(&data)); + let agent = oauth_agent(Arc::clone(&data)).await; // not configured { call_service(&agent.api).await?; @@ -371,7 +404,7 @@ mod tests { #[tokio::test] async fn test_configure_proxy_header() -> Result<(), Box> { let data = Arc::new(Mutex::new(Default::default())); - let agent = oauth_agent(data.clone()); + let agent = oauth_agent(data.clone()).await; // not configured { call_service(&agent.api).await?; @@ -437,4 +470,33 @@ mod tests { } Ok(()) } + + #[tokio::test] + async fn test_xrpc_without_token() -> Result<(), Box> { + let oauth_session = oauth_session(Arc::new(Mutex::new(Default::default()))).await; + oauth_session.store.clear().await?; + let agent = Agent::new(oauth_session); + let result = agent + .api + .com + .atproto + .server + .get_service_auth( + atrium_api::com::atproto::server::get_service_auth::ParametersData { + aud: Did::new(String::from("did:fake:handle.test")) + .expect("did should be valid"), + exp: None, + lxm: None, + } + .into(), + ) + .await; + match result.expect_err("should fail without token") { + atrium_xrpc::Error::XrpcResponse(err) => { + assert_eq!(err.status, StatusCode::UNAUTHORIZED); + } + _ => panic!("unexpected error"), + } + Ok(()) + } } diff --git a/atrium-oauth/oauth-client/src/oauth_session/store.rs b/atrium-oauth/oauth-client/src/oauth_session/store.rs index 7a7f2312..eed70d55 100644 --- a/atrium-oauth/oauth-client/src/oauth_session/store.rs +++ b/atrium-oauth/oauth-client/src/oauth_session/store.rs @@ -14,16 +14,16 @@ impl Store<(), String> for MemorySessionStore { type Error = store::memory::Error; async fn get(&self, key: &()) -> Result, Self::Error> { - todo!() + self.0.get(key).await } async fn set(&self, key: (), value: String) -> Result<(), Self::Error> { - todo!() + self.0.set(key, value).await } async fn del(&self, key: &()) -> Result<(), Self::Error> { - todo!() + self.0.del(key).await } async fn clear(&self) -> Result<(), Self::Error> { - todo!() + self.0.clear().await } } diff --git a/atrium-oauth/oauth-client/src/server_agent.rs b/atrium-oauth/oauth-client/src/server_agent.rs index 6045358c..ea03ccd3 100644 --- a/atrium-oauth/oauth-client/src/server_agent.rs +++ b/atrium-oauth/oauth-client/src/server_agent.rs @@ -44,6 +44,8 @@ pub enum Error { #[error(transparent)] DpopClient(#[from] crate::http_client::dpop::Error), #[error(transparent)] + OAuthSession(#[from] crate::oauth_session::Error), + #[error(transparent)] Http(#[from] atrium_xrpc::http::Error), #[error("http client error: {0}")] HttpClient(Box), @@ -317,7 +319,7 @@ where let dpop_key = self.dpop_client.key.clone(); // TODO let session = session_getter.get(&sub).await.expect("").unwrap(); - Ok(OAuthSession::new(self, dpop_key, http_client, session.token_set)?) + Ok(OAuthSession::new(self, dpop_key, http_client, session.token_set).await?) } }