Skip to content

Commit 870f85a

Browse files
committed
Add tests for oauth_session, implement oauth_session::store
1 parent 26f1d04 commit 870f85a

File tree

4 files changed

+87
-23
lines changed

4 files changed

+87
-23
lines changed

atrium-api/src/agent.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -234,13 +234,13 @@ where
234234
S: Store<(), U>,
235235
U: Clone,
236236
{
237-
async fn get(&self) -> Result<Option<U>, S::Error> {
237+
pub async fn get(&self) -> Result<Option<U>, S::Error> {
238238
self.inner.get(&()).await
239239
}
240-
async fn set(&self, value: U) -> Result<(), S::Error> {
240+
pub async fn set(&self, value: U) -> Result<(), S::Error> {
241241
self.inner.set((), value).await
242242
}
243-
async fn clear(&self) -> Result<(), S::Error> {
243+
pub async fn clear(&self) -> Result<(), S::Error> {
244244
self.inner.clear().await
245245
}
246246
}

atrium-oauth/oauth-client/src/oauth_session.rs

Lines changed: 77 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,21 @@ use atrium_api::{
99
use atrium_common::store::{memory::MemoryStore, Store};
1010
use atrium_xrpc::{
1111
http::{Request, Response},
12-
Error, HttpClient, OutputDataOrBytes, XrpcClient, XrpcRequest,
12+
HttpClient, OutputDataOrBytes, XrpcClient, XrpcRequest,
1313
};
1414
use jose_jwk::Key;
1515
use serde::{de::DeserializeOwned, Serialize};
1616
use std::{fmt::Debug, sync::Arc};
1717
use store::MemorySessionStore;
18+
use thiserror::Error;
19+
20+
#[derive(Error, Debug)]
21+
pub enum Error {
22+
#[error(transparent)]
23+
Dpop(#[from] dpop::Error),
24+
#[error(transparent)]
25+
Store(#[from] atrium_common::store::memory::Error),
26+
}
1827

1928
pub struct OAuthSession<T, D, H, S = MemoryStore<String, String>>
2029
where
@@ -31,13 +40,14 @@ impl<T, D, H> OAuthSession<T, D, H>
3140
where
3241
T: HttpClient + Send + Sync,
3342
{
34-
pub(crate) fn new(
43+
pub(crate) async fn new(
3544
server_agent: OAuthServerAgent<T, D, H>,
3645
dpop_key: Key,
3746
http_client: Arc<T>,
3847
token_set: TokenSet,
39-
) -> Result<Self, dpop::Error> {
48+
) -> Result<Self, Error> {
4049
let store = Arc::new(InnerStore::new(MemorySessionStore::default(), token_set.aud.clone()));
50+
store.set(token_set.access_token.clone()).await?;
4151
let inner = inner::Client::new(
4252
Arc::clone(&store),
4353
DpopClient::new(
@@ -81,7 +91,7 @@ where
8191
async fn send_xrpc<P, I, O, E>(
8292
&self,
8393
request: &XrpcRequest<P, I>,
84-
) -> Result<OutputDataOrBytes<O>, Error<E>>
94+
) -> Result<OutputDataOrBytes<O>, atrium_xrpc::Error<E>>
8595
where
8696
P: Serialize + Send + Sync,
8797
I: Serialize + Send + Sync,
@@ -147,7 +157,7 @@ mod tests {
147157
client::Service,
148158
did_doc::DidDocument,
149159
types::string::Handle,
150-
xrpc::http::{header::CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue},
160+
xrpc::http::{header::CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue, StatusCode},
151161
};
152162
use atrium_common::resolver::Resolver;
153163
use atrium_identity::{did::DidResolver, handle::HandleResolver};
@@ -170,6 +180,17 @@ mod tests {
170180
request: Request<Vec<u8>>,
171181
) -> Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>> {
172182
let mut headers = request.headers().clone();
183+
let Some(authorization) = headers
184+
.remove("authorization")
185+
.and_then(|value| value.to_str().map(String::from).ok())
186+
else {
187+
return Ok(Response::builder().status(StatusCode::UNAUTHORIZED).body(Vec::new())?);
188+
};
189+
let Some(_token) = authorization.strip_prefix("DPoP ") else {
190+
panic!("authorization header should start with DPoP");
191+
};
192+
// TODO: verify token
193+
173194
let dpop_jwt = headers.remove("dpop").expect("dpop header should be present");
174195
let payload = dpop_jwt
175196
.to_str()
@@ -227,9 +248,14 @@ mod tests {
227248

228249
impl HandleResolver for NoopHandleResolver {}
229250

230-
fn oauth_agent(
251+
async fn oauth_session(
231252
data: Arc<Mutex<Option<RecordData>>>,
232-
) -> Agent<impl SessionManager + Configure + CloneWithProxy> {
253+
) -> OAuthSession<
254+
MockHttpClient,
255+
NoopDidResolver,
256+
NoopHandleResolver,
257+
MemoryStore<String, String>,
258+
> {
233259
let dpop_key = serde_json::from_str::<Key>(
234260
r#"{
235261
"kty": "EC",
@@ -270,14 +296,21 @@ mod tests {
270296
token_type: OAuthTokenType::DPoP,
271297
expires_at: None,
272298
};
273-
let oauth_session = OAuthSession::new(server_agent, dpop_key, http_client, token_set)
274-
.expect("failed to create oauth session");
275-
Agent::new(oauth_session)
299+
OAuthSession::new(server_agent, dpop_key, http_client, token_set)
300+
.await
301+
.expect("failed to create oauth session")
302+
}
303+
304+
async fn oauth_agent(
305+
data: Arc<Mutex<Option<RecordData>>>,
306+
) -> Agent<impl SessionManager + Configure + CloneWithProxy> {
307+
Agent::new(oauth_session(data).await)
276308
}
277309

278310
async fn call_service(
279311
service: &Service<impl SessionManager + Send + Sync>,
280-
) -> Result<(), Error<atrium_api::com::atproto::server::get_service_auth::Error>> {
312+
) -> Result<(), atrium_xrpc::Error<atrium_api::com::atproto::server::get_service_auth::Error>>
313+
{
281314
let output = service
282315
.com
283316
.atproto
@@ -298,15 +331,15 @@ mod tests {
298331

299332
#[tokio::test]
300333
async fn test_new() -> Result<(), Box<dyn std::error::Error>> {
301-
let agent = oauth_agent(Arc::new(Mutex::new(Default::default())));
334+
let agent = oauth_agent(Arc::new(Mutex::new(Default::default()))).await;
302335
assert_eq!(agent.did().await.as_deref(), Some("did:fake:sub.test"));
303336
Ok(())
304337
}
305338

306339
#[tokio::test]
307340
async fn test_configure_endpoint() -> Result<(), Box<dyn std::error::Error>> {
308341
let data = Arc::new(Mutex::new(Default::default()));
309-
let agent = oauth_agent(Arc::clone(&data));
342+
let agent = oauth_agent(Arc::clone(&data)).await;
310343
call_service(&agent.api).await?;
311344
assert_eq!(
312345
data.lock().await.as_ref().expect("data should be recorded").host.as_deref(),
@@ -324,7 +357,7 @@ mod tests {
324357
#[tokio::test]
325358
async fn test_configure_labelers_header() -> Result<(), Box<dyn std::error::Error>> {
326359
let data = Arc::new(Mutex::new(Default::default()));
327-
let agent = oauth_agent(Arc::clone(&data));
360+
let agent = oauth_agent(Arc::clone(&data)).await;
328361
// not configured
329362
{
330363
call_service(&agent.api).await?;
@@ -371,7 +404,7 @@ mod tests {
371404
#[tokio::test]
372405
async fn test_configure_proxy_header() -> Result<(), Box<dyn std::error::Error>> {
373406
let data = Arc::new(Mutex::new(Default::default()));
374-
let agent = oauth_agent(data.clone());
407+
let agent = oauth_agent(data.clone()).await;
375408
// not configured
376409
{
377410
call_service(&agent.api).await?;
@@ -437,4 +470,33 @@ mod tests {
437470
}
438471
Ok(())
439472
}
473+
474+
#[tokio::test]
475+
async fn test_xrpc_without_token() -> Result<(), Box<dyn std::error::Error>> {
476+
let oauth_session = oauth_session(Arc::new(Mutex::new(Default::default()))).await;
477+
oauth_session.store.clear().await?;
478+
let agent = Agent::new(oauth_session);
479+
let result = agent
480+
.api
481+
.com
482+
.atproto
483+
.server
484+
.get_service_auth(
485+
atrium_api::com::atproto::server::get_service_auth::ParametersData {
486+
aud: Did::new(String::from("did:fake:handle.test"))
487+
.expect("did should be valid"),
488+
exp: None,
489+
lxm: None,
490+
}
491+
.into(),
492+
)
493+
.await;
494+
match result.expect_err("should fail without token") {
495+
atrium_xrpc::Error::XrpcResponse(err) => {
496+
assert_eq!(err.status, StatusCode::UNAUTHORIZED);
497+
}
498+
_ => panic!("unexpected error"),
499+
}
500+
Ok(())
501+
}
440502
}

atrium-oauth/oauth-client/src/oauth_session/store.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,16 @@ impl Store<(), String> for MemorySessionStore {
1414
type Error = store::memory::Error;
1515

1616
async fn get(&self, key: &()) -> Result<Option<String>, Self::Error> {
17-
todo!()
17+
self.0.get(key).await
1818
}
1919
async fn set(&self, key: (), value: String) -> Result<(), Self::Error> {
20-
todo!()
20+
self.0.set(key, value).await
2121
}
2222
async fn del(&self, key: &()) -> Result<(), Self::Error> {
23-
todo!()
23+
self.0.del(key).await
2424
}
2525
async fn clear(&self) -> Result<(), Self::Error> {
26-
todo!()
26+
self.0.clear().await
2727
}
2828
}
2929

atrium-oauth/oauth-client/src/server_agent.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ pub enum Error {
4444
#[error(transparent)]
4545
DpopClient(#[from] crate::http_client::dpop::Error),
4646
#[error(transparent)]
47+
OAuthSession(#[from] crate::oauth_session::Error),
48+
#[error(transparent)]
4749
Http(#[from] atrium_xrpc::http::Error),
4850
#[error("http client error: {0}")]
4951
HttpClient(Box<dyn std::error::Error + Send + Sync + 'static>),
@@ -317,7 +319,7 @@ where
317319
let dpop_key = self.dpop_client.key.clone();
318320
// TODO
319321
let session = session_getter.get(&sub).await.expect("").unwrap();
320-
Ok(OAuthSession::new(self, dpop_key, http_client, session.token_set)?)
322+
Ok(OAuthSession::new(self, dpop_key, http_client, session.token_set).await?)
321323
}
322324
}
323325

0 commit comments

Comments
 (0)