Skip to content

Commit

Permalink
Add tests for oauth_session, implement oauth_session::store
Browse files Browse the repository at this point in the history
  • Loading branch information
sugyan committed Jan 16, 2025
1 parent 26f1d04 commit 870f85a
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 23 deletions.
6 changes: 3 additions & 3 deletions atrium-api/src/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,13 +234,13 @@ where
S: Store<(), U>,
U: Clone,
{
async fn get(&self) -> Result<Option<U>, S::Error> {
pub async fn get(&self) -> Result<Option<U>, S::Error> {
self.inner.get(&()).await
}
async fn set(&self, value: U) -> Result<(), S::Error> {
pub async fn set(&self, value: U) -> Result<(), S::Error> {
self.inner.set((), value).await
}
async fn clear(&self) -> Result<(), S::Error> {
pub async fn clear(&self) -> Result<(), S::Error> {
self.inner.clear().await
}
}
Expand Down
92 changes: 77 additions & 15 deletions atrium-oauth/oauth-client/src/oauth_session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,21 @@ use atrium_api::{
use atrium_common::store::{memory::MemoryStore, Store};
use atrium_xrpc::{
http::{Request, Response},
Error, HttpClient, OutputDataOrBytes, XrpcClient, XrpcRequest,
HttpClient, OutputDataOrBytes, XrpcClient, XrpcRequest,
};
use jose_jwk::Key;
use serde::{de::DeserializeOwned, Serialize};
use std::{fmt::Debug, sync::Arc};
use store::MemorySessionStore;
use thiserror::Error;

#[derive(Error, Debug)]
pub enum Error {
#[error(transparent)]
Dpop(#[from] dpop::Error),
#[error(transparent)]
Store(#[from] atrium_common::store::memory::Error),
}

pub struct OAuthSession<T, D, H, S = MemoryStore<String, String>>
where
Expand All @@ -31,13 +40,14 @@ impl<T, D, H> OAuthSession<T, D, H>
where
T: HttpClient + Send + Sync,
{
pub(crate) fn new(
pub(crate) async fn new(
server_agent: OAuthServerAgent<T, D, H>,
dpop_key: Key,
http_client: Arc<T>,
token_set: TokenSet,
) -> Result<Self, dpop::Error> {
) -> Result<Self, Error> {
let store = Arc::new(InnerStore::new(MemorySessionStore::default(), token_set.aud.clone()));
store.set(token_set.access_token.clone()).await?;
let inner = inner::Client::new(
Arc::clone(&store),
DpopClient::new(
Expand Down Expand Up @@ -81,7 +91,7 @@ where
async fn send_xrpc<P, I, O, E>(
&self,
request: &XrpcRequest<P, I>,
) -> Result<OutputDataOrBytes<O>, Error<E>>
) -> Result<OutputDataOrBytes<O>, atrium_xrpc::Error<E>>
where
P: Serialize + Send + Sync,
I: Serialize + Send + Sync,
Expand Down Expand Up @@ -147,7 +157,7 @@ mod tests {
client::Service,
did_doc::DidDocument,
types::string::Handle,
xrpc::http::{header::CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue},
xrpc::http::{header::CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue, StatusCode},
};
use atrium_common::resolver::Resolver;
use atrium_identity::{did::DidResolver, handle::HandleResolver};
Expand All @@ -170,6 +180,17 @@ mod tests {
request: Request<Vec<u8>>,
) -> Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>> {
let mut headers = request.headers().clone();
let Some(authorization) = headers
.remove("authorization")
.and_then(|value| value.to_str().map(String::from).ok())
else {
return Ok(Response::builder().status(StatusCode::UNAUTHORIZED).body(Vec::new())?);
};
let Some(_token) = authorization.strip_prefix("DPoP ") else {
panic!("authorization header should start with DPoP");
};
// TODO: verify token

let dpop_jwt = headers.remove("dpop").expect("dpop header should be present");
let payload = dpop_jwt
.to_str()
Expand Down Expand Up @@ -227,9 +248,14 @@ mod tests {

impl HandleResolver for NoopHandleResolver {}

fn oauth_agent(
async fn oauth_session(
data: Arc<Mutex<Option<RecordData>>>,
) -> Agent<impl SessionManager + Configure + CloneWithProxy> {
) -> OAuthSession<
MockHttpClient,
NoopDidResolver,
NoopHandleResolver,
MemoryStore<String, String>,
> {
let dpop_key = serde_json::from_str::<Key>(
r#"{
"kty": "EC",
Expand Down Expand Up @@ -270,14 +296,21 @@ mod tests {
token_type: OAuthTokenType::DPoP,
expires_at: None,
};
let oauth_session = OAuthSession::new(server_agent, dpop_key, http_client, token_set)
.expect("failed to create oauth session");
Agent::new(oauth_session)
OAuthSession::new(server_agent, dpop_key, http_client, token_set)
.await
.expect("failed to create oauth session")
}

async fn oauth_agent(
data: Arc<Mutex<Option<RecordData>>>,
) -> Agent<impl SessionManager + Configure + CloneWithProxy> {
Agent::new(oauth_session(data).await)
}

async fn call_service(
service: &Service<impl SessionManager + Send + Sync>,
) -> Result<(), Error<atrium_api::com::atproto::server::get_service_auth::Error>> {
) -> Result<(), atrium_xrpc::Error<atrium_api::com::atproto::server::get_service_auth::Error>>
{
let output = service
.com
.atproto
Expand All @@ -298,15 +331,15 @@ mod tests {

#[tokio::test]
async fn test_new() -> Result<(), Box<dyn std::error::Error>> {
let agent = oauth_agent(Arc::new(Mutex::new(Default::default())));
let agent = oauth_agent(Arc::new(Mutex::new(Default::default()))).await;
assert_eq!(agent.did().await.as_deref(), Some("did:fake:sub.test"));
Ok(())
}

#[tokio::test]
async fn test_configure_endpoint() -> Result<(), Box<dyn std::error::Error>> {
let data = Arc::new(Mutex::new(Default::default()));
let agent = oauth_agent(Arc::clone(&data));
let agent = oauth_agent(Arc::clone(&data)).await;
call_service(&agent.api).await?;
assert_eq!(
data.lock().await.as_ref().expect("data should be recorded").host.as_deref(),
Expand All @@ -324,7 +357,7 @@ mod tests {
#[tokio::test]
async fn test_configure_labelers_header() -> Result<(), Box<dyn std::error::Error>> {
let data = Arc::new(Mutex::new(Default::default()));
let agent = oauth_agent(Arc::clone(&data));
let agent = oauth_agent(Arc::clone(&data)).await;
// not configured
{
call_service(&agent.api).await?;
Expand Down Expand Up @@ -371,7 +404,7 @@ mod tests {
#[tokio::test]
async fn test_configure_proxy_header() -> Result<(), Box<dyn std::error::Error>> {
let data = Arc::new(Mutex::new(Default::default()));
let agent = oauth_agent(data.clone());
let agent = oauth_agent(data.clone()).await;
// not configured
{
call_service(&agent.api).await?;
Expand Down Expand Up @@ -437,4 +470,33 @@ mod tests {
}
Ok(())
}

#[tokio::test]
async fn test_xrpc_without_token() -> Result<(), Box<dyn std::error::Error>> {
let oauth_session = oauth_session(Arc::new(Mutex::new(Default::default()))).await;
oauth_session.store.clear().await?;
let agent = Agent::new(oauth_session);
let result = agent
.api
.com
.atproto
.server
.get_service_auth(
atrium_api::com::atproto::server::get_service_auth::ParametersData {
aud: Did::new(String::from("did:fake:handle.test"))
.expect("did should be valid"),
exp: None,
lxm: None,
}
.into(),
)
.await;
match result.expect_err("should fail without token") {
atrium_xrpc::Error::XrpcResponse(err) => {
assert_eq!(err.status, StatusCode::UNAUTHORIZED);
}
_ => panic!("unexpected error"),
}
Ok(())
}
}
8 changes: 4 additions & 4 deletions atrium-oauth/oauth-client/src/oauth_session/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,16 @@ impl Store<(), String> for MemorySessionStore {
type Error = store::memory::Error;

async fn get(&self, key: &()) -> Result<Option<String>, Self::Error> {
todo!()
self.0.get(key).await
}
async fn set(&self, key: (), value: String) -> Result<(), Self::Error> {
todo!()
self.0.set(key, value).await
}
async fn del(&self, key: &()) -> Result<(), Self::Error> {
todo!()
self.0.del(key).await
}
async fn clear(&self) -> Result<(), Self::Error> {
todo!()
self.0.clear().await
}
}

Expand Down
4 changes: 3 additions & 1 deletion atrium-oauth/oauth-client/src/server_agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ pub enum Error {
#[error(transparent)]
DpopClient(#[from] crate::http_client::dpop::Error),
#[error(transparent)]
OAuthSession(#[from] crate::oauth_session::Error),
#[error(transparent)]
Http(#[from] atrium_xrpc::http::Error),
#[error("http client error: {0}")]
HttpClient(Box<dyn std::error::Error + Send + Sync + 'static>),
Expand Down Expand Up @@ -317,7 +319,7 @@ where
let dpop_key = self.dpop_client.key.clone();
// TODO
let session = session_getter.get(&sub).await.expect("").unwrap();
Ok(OAuthSession::new(self, dpop_key, http_client, session.token_set)?)
Ok(OAuthSession::new(self, dpop_key, http_client, session.token_set).await?)
}
}

Expand Down

0 comments on commit 870f85a

Please sign in to comment.