@@ -9,12 +9,21 @@ use atrium_api::{
9
9
use atrium_common:: store:: { memory:: MemoryStore , Store } ;
10
10
use atrium_xrpc:: {
11
11
http:: { Request , Response } ,
12
- Error , HttpClient , OutputDataOrBytes , XrpcClient , XrpcRequest ,
12
+ HttpClient , OutputDataOrBytes , XrpcClient , XrpcRequest ,
13
13
} ;
14
14
use jose_jwk:: Key ;
15
15
use serde:: { de:: DeserializeOwned , Serialize } ;
16
16
use std:: { fmt:: Debug , sync:: Arc } ;
17
17
use store:: MemorySessionStore ;
18
+ use thiserror:: Error ;
19
+
20
+ #[ derive( Error , Debug ) ]
21
+ pub enum Error {
22
+ #[ error( transparent) ]
23
+ Dpop ( #[ from] dpop:: Error ) ,
24
+ #[ error( transparent) ]
25
+ Store ( #[ from] atrium_common:: store:: memory:: Error ) ,
26
+ }
18
27
19
28
pub struct OAuthSession < T , D , H , S = MemoryStore < String , String > >
20
29
where
@@ -31,13 +40,14 @@ impl<T, D, H> OAuthSession<T, D, H>
31
40
where
32
41
T : HttpClient + Send + Sync ,
33
42
{
34
- pub ( crate ) fn new (
43
+ pub ( crate ) async fn new (
35
44
server_agent : OAuthServerAgent < T , D , H > ,
36
45
dpop_key : Key ,
37
46
http_client : Arc < T > ,
38
47
token_set : TokenSet ,
39
- ) -> Result < Self , dpop :: Error > {
48
+ ) -> Result < Self , Error > {
40
49
let store = Arc :: new ( InnerStore :: new ( MemorySessionStore :: default ( ) , token_set. aud . clone ( ) ) ) ;
50
+ store. set ( token_set. access_token . clone ( ) ) . await ?;
41
51
let inner = inner:: Client :: new (
42
52
Arc :: clone ( & store) ,
43
53
DpopClient :: new (
81
91
async fn send_xrpc < P , I , O , E > (
82
92
& self ,
83
93
request : & XrpcRequest < P , I > ,
84
- ) -> Result < OutputDataOrBytes < O > , Error < E > >
94
+ ) -> Result < OutputDataOrBytes < O > , atrium_xrpc :: Error < E > >
85
95
where
86
96
P : Serialize + Send + Sync ,
87
97
I : Serialize + Send + Sync ,
@@ -147,7 +157,7 @@ mod tests {
147
157
client:: Service ,
148
158
did_doc:: DidDocument ,
149
159
types:: string:: Handle ,
150
- xrpc:: http:: { header:: CONTENT_TYPE , HeaderMap , HeaderName , HeaderValue } ,
160
+ xrpc:: http:: { header:: CONTENT_TYPE , HeaderMap , HeaderName , HeaderValue , StatusCode } ,
151
161
} ;
152
162
use atrium_common:: resolver:: Resolver ;
153
163
use atrium_identity:: { did:: DidResolver , handle:: HandleResolver } ;
@@ -170,6 +180,17 @@ mod tests {
170
180
request : Request < Vec < u8 > > ,
171
181
) -> Result < Response < Vec < u8 > > , Box < dyn std:: error:: Error + Send + Sync + ' static > > {
172
182
let mut headers = request. headers ( ) . clone ( ) ;
183
+ let Some ( authorization) = headers
184
+ . remove ( "authorization" )
185
+ . and_then ( |value| value. to_str ( ) . map ( String :: from) . ok ( ) )
186
+ else {
187
+ return Ok ( Response :: builder ( ) . status ( StatusCode :: UNAUTHORIZED ) . body ( Vec :: new ( ) ) ?) ;
188
+ } ;
189
+ let Some ( _token) = authorization. strip_prefix ( "DPoP " ) else {
190
+ panic ! ( "authorization header should start with DPoP" ) ;
191
+ } ;
192
+ // TODO: verify token
193
+
173
194
let dpop_jwt = headers. remove ( "dpop" ) . expect ( "dpop header should be present" ) ;
174
195
let payload = dpop_jwt
175
196
. to_str ( )
@@ -227,9 +248,14 @@ mod tests {
227
248
228
249
impl HandleResolver for NoopHandleResolver { }
229
250
230
- fn oauth_agent (
251
+ async fn oauth_session (
231
252
data : Arc < Mutex < Option < RecordData > > > ,
232
- ) -> Agent < impl SessionManager + Configure + CloneWithProxy > {
253
+ ) -> OAuthSession <
254
+ MockHttpClient ,
255
+ NoopDidResolver ,
256
+ NoopHandleResolver ,
257
+ MemoryStore < String , String > ,
258
+ > {
233
259
let dpop_key = serde_json:: from_str :: < Key > (
234
260
r#"{
235
261
"kty": "EC",
@@ -270,14 +296,21 @@ mod tests {
270
296
token_type : OAuthTokenType :: DPoP ,
271
297
expires_at : None ,
272
298
} ;
273
- let oauth_session = OAuthSession :: new ( server_agent, dpop_key, http_client, token_set)
274
- . expect ( "failed to create oauth session" ) ;
275
- Agent :: new ( oauth_session)
299
+ OAuthSession :: new ( server_agent, dpop_key, http_client, token_set)
300
+ . await
301
+ . expect ( "failed to create oauth session" )
302
+ }
303
+
304
+ async fn oauth_agent (
305
+ data : Arc < Mutex < Option < RecordData > > > ,
306
+ ) -> Agent < impl SessionManager + Configure + CloneWithProxy > {
307
+ Agent :: new ( oauth_session ( data) . await )
276
308
}
277
309
278
310
async fn call_service (
279
311
service : & Service < impl SessionManager + Send + Sync > ,
280
- ) -> Result < ( ) , Error < atrium_api:: com:: atproto:: server:: get_service_auth:: Error > > {
312
+ ) -> Result < ( ) , atrium_xrpc:: Error < atrium_api:: com:: atproto:: server:: get_service_auth:: Error > >
313
+ {
281
314
let output = service
282
315
. com
283
316
. atproto
@@ -298,15 +331,15 @@ mod tests {
298
331
299
332
#[ tokio:: test]
300
333
async fn test_new ( ) -> Result < ( ) , Box < dyn std:: error:: Error > > {
301
- let agent = oauth_agent ( Arc :: new ( Mutex :: new ( Default :: default ( ) ) ) ) ;
334
+ let agent = oauth_agent ( Arc :: new ( Mutex :: new ( Default :: default ( ) ) ) ) . await ;
302
335
assert_eq ! ( agent. did( ) . await . as_deref( ) , Some ( "did:fake:sub.test" ) ) ;
303
336
Ok ( ( ) )
304
337
}
305
338
306
339
#[ tokio:: test]
307
340
async fn test_configure_endpoint ( ) -> Result < ( ) , Box < dyn std:: error:: Error > > {
308
341
let data = Arc :: new ( Mutex :: new ( Default :: default ( ) ) ) ;
309
- let agent = oauth_agent ( Arc :: clone ( & data) ) ;
342
+ let agent = oauth_agent ( Arc :: clone ( & data) ) . await ;
310
343
call_service ( & agent. api ) . await ?;
311
344
assert_eq ! (
312
345
data. lock( ) . await . as_ref( ) . expect( "data should be recorded" ) . host. as_deref( ) ,
@@ -324,7 +357,7 @@ mod tests {
324
357
#[ tokio:: test]
325
358
async fn test_configure_labelers_header ( ) -> Result < ( ) , Box < dyn std:: error:: Error > > {
326
359
let data = Arc :: new ( Mutex :: new ( Default :: default ( ) ) ) ;
327
- let agent = oauth_agent ( Arc :: clone ( & data) ) ;
360
+ let agent = oauth_agent ( Arc :: clone ( & data) ) . await ;
328
361
// not configured
329
362
{
330
363
call_service ( & agent. api ) . await ?;
@@ -371,7 +404,7 @@ mod tests {
371
404
#[ tokio:: test]
372
405
async fn test_configure_proxy_header ( ) -> Result < ( ) , Box < dyn std:: error:: Error > > {
373
406
let data = Arc :: new ( Mutex :: new ( Default :: default ( ) ) ) ;
374
- let agent = oauth_agent ( data. clone ( ) ) ;
407
+ let agent = oauth_agent ( data. clone ( ) ) . await ;
375
408
// not configured
376
409
{
377
410
call_service ( & agent. api ) . await ?;
@@ -437,4 +470,33 @@ mod tests {
437
470
}
438
471
Ok ( ( ) )
439
472
}
473
+
474
+ #[ tokio:: test]
475
+ async fn test_xrpc_without_token ( ) -> Result < ( ) , Box < dyn std:: error:: Error > > {
476
+ let oauth_session = oauth_session ( Arc :: new ( Mutex :: new ( Default :: default ( ) ) ) ) . await ;
477
+ oauth_session. store . clear ( ) . await ?;
478
+ let agent = Agent :: new ( oauth_session) ;
479
+ let result = agent
480
+ . api
481
+ . com
482
+ . atproto
483
+ . server
484
+ . get_service_auth (
485
+ atrium_api:: com:: atproto:: server:: get_service_auth:: ParametersData {
486
+ aud : Did :: new ( String :: from ( "did:fake:handle.test" ) )
487
+ . expect ( "did should be valid" ) ,
488
+ exp : None ,
489
+ lxm : None ,
490
+ }
491
+ . into ( ) ,
492
+ )
493
+ . await ;
494
+ match result. expect_err ( "should fail without token" ) {
495
+ atrium_xrpc:: Error :: XrpcResponse ( err) => {
496
+ assert_eq ! ( err. status, StatusCode :: UNAUTHORIZED ) ;
497
+ }
498
+ _ => panic ! ( "unexpected error" ) ,
499
+ }
500
+ Ok ( ( ) )
501
+ }
440
502
}
0 commit comments