5
5
use base64:: { prelude:: BASE64_URL_SAFE_NO_PAD , Engine } ;
6
6
use chrono:: { TimeDelta , Utc } ;
7
7
use dropshot:: {
8
- http_response_temporary_redirect, Body , ClientErrorStatusCode , HttpError , HttpResponseOk ,
9
- HttpResponseTemporaryRedirect , Path , Query , RequestContext , RequestInfo , TypedBody ,
8
+ http_response_temporary_redirect, Body , ClientErrorStatusCode , ExclusiveExtractor , HttpError ,
9
+ HttpResponseOk , HttpResponseTemporaryRedirect , Path , Query , RequestContext , RequestInfo ,
10
+ SharedExtractor , TypedBody ,
10
11
} ;
12
+ use dropshot_authorization_header:: basic:: BasicAuth ;
11
13
use http:: {
12
14
header:: { LOCATION , SET_COOKIE } ,
13
15
HeaderValue , StatusCode ,
@@ -18,7 +20,7 @@ use oauth2::{
18
20
AuthorizationCode , CsrfToken , PkceCodeChallenge , PkceCodeVerifier , Scope , TokenResponse ,
19
21
} ;
20
22
use schemars:: JsonSchema ;
21
- use secrecy:: SecretString ;
23
+ use secrecy:: { ExposeSecret , SecretString } ;
22
24
use serde:: { Deserialize , Serialize } ;
23
25
use sha2:: { Digest , Sha256 } ;
24
26
use std:: { fmt:: Debug , ops:: Add } ;
@@ -121,7 +123,7 @@ where
121
123
{
122
124
let client = ctx
123
125
. oauth
124
- . get_oauth_client ( & ctx. builtin_registration_user ( ) , & client_id)
126
+ . get_oauth_client ( & ctx. builtin_registration_user ( ) , client_id)
125
127
. await
126
128
. map_err ( |err| {
127
129
tracing:: error!( ?err, "Failed to lookup OAuth client" ) ;
@@ -437,8 +439,8 @@ where
437
439
438
440
#[ derive( Debug , Deserialize , JsonSchema ) ]
439
441
pub struct OAuthAuthzCodeExchangeBody {
440
- pub client_id : TypedUuid < OAuthClientId > ,
441
- pub client_secret : OpenApiSecretString ,
442
+ pub client_id : Option < TypedUuid < OAuthClientId > > ,
443
+ pub client_secret : Option < OpenApiSecretString > ,
442
444
pub redirect_uri : String ,
443
445
pub grant_type : String ,
444
446
pub code : String ,
@@ -464,6 +466,34 @@ where
464
466
let ctx = rqctx. v_ctx ( ) ;
465
467
let path = path. into_inner ( ) ;
466
468
let body = body. into_inner ( ) ;
469
+
470
+ let ( client_id, client_secret) =
471
+ if let ( Some ( client_id) , Some ( client_secret) ) = ( body. client_id , body. client_secret ) {
472
+ Ok :: < _ , HttpError > ( ( client_id, client_secret) )
473
+ } else {
474
+ // Attempt to extract basic authorization credentials from the request if they were not
475
+ // present in the request body
476
+ let auth = <BasicAuth as SharedExtractor >:: from_request ( rqctx)
477
+ . await
478
+ . tap_err ( |err| {
479
+ tracing:: warn!( ?err, "Failed to extract basic authentication values" ) ;
480
+ } ) ;
481
+ let ( client_id, client_secret) = match auth {
482
+ Ok ( auth) if auth. username ( ) . is_some ( ) && auth. password ( ) . is_some ( ) => Ok ( (
483
+ auth. username ( ) . unwrap ( ) . to_string ( ) ,
484
+ auth. password ( ) . unwrap ( ) . to_string ( ) ,
485
+ ) ) ,
486
+ _ => Err ( internal_error (
487
+ "Missing client id and client secret from authz code exchange" ,
488
+ ) ) ,
489
+ } ?;
490
+
491
+ Ok ( (
492
+ client_id. parse ( ) . map_err ( to_internal_error) ?,
493
+ OpenApiSecretString ( client_secret. into ( ) ) ,
494
+ ) )
495
+ } ?;
496
+
467
497
let provider = ctx
468
498
. get_oauth_provider ( & path. provider )
469
499
. await
@@ -475,8 +505,8 @@ where
475
505
authorize_code_exchange (
476
506
& ctx,
477
507
& body. grant_type ,
478
- & body . client_id ,
479
- & body . client_secret . 0 ,
508
+ client_id,
509
+ & client_secret. 0 ,
480
510
& body. redirect_uri ,
481
511
)
482
512
. await ?;
@@ -499,7 +529,7 @@ where
499
529
// Verify that the login attempt is valid and matches the submitted client credentials
500
530
verify_login_attempt (
501
531
& attempt,
502
- & body . client_id ,
532
+ client_id,
503
533
& body. redirect_uri ,
504
534
body. pkce_verifier . as_deref ( ) ,
505
535
) ?;
@@ -544,7 +574,7 @@ where
544
574
async fn authorize_code_exchange < T > (
545
575
ctx : & VContext < T > ,
546
576
grant_type : & str ,
547
- client_id : & TypedUuid < OAuthClientId > ,
577
+ client_id : TypedUuid < OAuthClientId > ,
548
578
client_secret : & SecretString ,
549
579
redirect_uri : & str ,
550
580
) -> Result < ( ) , OAuthError >
@@ -594,11 +624,11 @@ where
594
624
595
625
fn verify_login_attempt (
596
626
attempt : & LoginAttempt ,
597
- client_id : & TypedUuid < OAuthClientId > ,
627
+ client_id : TypedUuid < OAuthClientId > ,
598
628
redirect_uri : & str ,
599
629
pkce_verifier : Option < & str > ,
600
630
) -> Result < ( ) , OAuthError > {
601
- if attempt. client_id != * client_id {
631
+ if attempt. client_id != client_id {
602
632
Err ( OAuthError {
603
633
error : OAuthErrorCode :: InvalidGrant ,
604
634
error_description : Some ( "Invalid client id" . to_string ( ) ) ,
@@ -1238,7 +1268,7 @@ mod tests {
1238
1268
authorize_code_exchange(
1239
1269
& ctx,
1240
1270
"authorization_code" ,
1241
- & wrong_client_id,
1271
+ wrong_client_id,
1242
1272
& client_secret,
1243
1273
& redirect_uri,
1244
1274
)
@@ -1253,7 +1283,7 @@ mod tests {
1253
1283
authorize_code_exchange(
1254
1284
& ctx,
1255
1285
"authorization_code" ,
1256
- & client_id,
1286
+ client_id,
1257
1287
& client_secret,
1258
1288
"wrong-callback-destination" ,
1259
1289
)
@@ -1268,7 +1298,7 @@ mod tests {
1268
1298
authorize_code_exchange(
1269
1299
& ctx,
1270
1300
"authorization_code" ,
1271
- & client_id,
1301
+ client_id,
1272
1302
& client_secret,
1273
1303
& redirect_uri,
1274
1304
)
@@ -1299,7 +1329,7 @@ mod tests {
1299
1329
authorize_code_exchange(
1300
1330
& ctx,
1301
1331
"not_authorization_code" ,
1302
- & client_id,
1332
+ client_id,
1303
1333
& client_secret,
1304
1334
& redirect_uri
1305
1335
)
@@ -1313,7 +1343,7 @@ mod tests {
1313
1343
authorize_code_exchange(
1314
1344
& ctx,
1315
1345
"authorization_code" ,
1316
- & client_id,
1346
+ client_id,
1317
1347
& client_secret,
1318
1348
& redirect_uri
1319
1349
)
@@ -1351,7 +1381,7 @@ mod tests {
1351
1381
authorize_code_exchange(
1352
1382
& ctx,
1353
1383
"authorization_code" ,
1354
- & client_id,
1384
+ client_id,
1355
1385
& "too-short" . to_string( ) . into( ) ,
1356
1386
& redirect_uri
1357
1387
)
@@ -1365,7 +1395,7 @@ mod tests {
1365
1395
authorize_code_exchange(
1366
1396
& ctx,
1367
1397
"authorization_code" ,
1368
- & client_id,
1398
+ client_id,
1369
1399
& invalid_secret. into( ) ,
1370
1400
& redirect_uri
1371
1401
)
@@ -1379,7 +1409,7 @@ mod tests {
1379
1409
authorize_code_exchange(
1380
1410
& ctx,
1381
1411
"authorization_code" ,
1382
- & client_id,
1412
+ client_id,
1383
1413
& client_secret,
1384
1414
& redirect_uri
1385
1415
)
@@ -1425,7 +1455,7 @@ mod tests {
1425
1455
} ,
1426
1456
verify_login_attempt(
1427
1457
& bad_client_id,
1428
- & attempt. client_id,
1458
+ attempt. client_id,
1429
1459
& attempt. redirect_uri,
1430
1460
Some ( verifier. secret( ) . as_str( ) ) ,
1431
1461
)
@@ -1446,7 +1476,7 @@ mod tests {
1446
1476
} ,
1447
1477
verify_login_attempt(
1448
1478
& bad_redirect_uri,
1449
- & attempt. client_id,
1479
+ attempt. client_id,
1450
1480
& attempt. redirect_uri,
1451
1481
Some ( verifier. secret( ) . as_str( ) ) ,
1452
1482
)
@@ -1467,7 +1497,7 @@ mod tests {
1467
1497
} ,
1468
1498
verify_login_attempt(
1469
1499
& unconfirmed_state,
1470
- & attempt. client_id,
1500
+ attempt. client_id,
1471
1501
& attempt. redirect_uri,
1472
1502
Some ( verifier. secret( ) . as_str( ) ) ,
1473
1503
)
@@ -1488,7 +1518,7 @@ mod tests {
1488
1518
} ,
1489
1519
verify_login_attempt(
1490
1520
& already_used_state,
1491
- & attempt. client_id,
1521
+ attempt. client_id,
1492
1522
& attempt. redirect_uri,
1493
1523
Some ( verifier. secret( ) . as_str( ) ) ,
1494
1524
)
@@ -1509,7 +1539,7 @@ mod tests {
1509
1539
} ,
1510
1540
verify_login_attempt(
1511
1541
& failed_state,
1512
- & attempt. client_id,
1542
+ attempt. client_id,
1513
1543
& attempt. redirect_uri,
1514
1544
Some ( verifier. secret( ) . as_str( ) ) ,
1515
1545
)
@@ -1530,7 +1560,7 @@ mod tests {
1530
1560
} ,
1531
1561
verify_login_attempt(
1532
1562
& expired,
1533
- & attempt. client_id,
1563
+ attempt. client_id,
1534
1564
& attempt. redirect_uri,
1535
1565
Some ( verifier. secret( ) . as_str( ) ) ,
1536
1566
)
@@ -1548,7 +1578,7 @@ mod tests {
1548
1578
} ,
1549
1579
verify_login_attempt(
1550
1580
& missing_pkce,
1551
- & attempt. client_id,
1581
+ attempt. client_id,
1552
1582
& attempt. redirect_uri,
1553
1583
None ,
1554
1584
)
@@ -1569,7 +1599,7 @@ mod tests {
1569
1599
} ,
1570
1600
verify_login_attempt(
1571
1601
& invalid_pkce,
1572
- & attempt. client_id,
1602
+ attempt. client_id,
1573
1603
& attempt. redirect_uri,
1574
1604
Some ( verifier. secret( ) . as_str( ) ) ,
1575
1605
)
@@ -1580,7 +1610,7 @@ mod tests {
1580
1610
( ) ,
1581
1611
verify_login_attempt(
1582
1612
& attempt,
1583
- & attempt. client_id,
1613
+ attempt. client_id,
1584
1614
& attempt. redirect_uri,
1585
1615
Some ( verifier. secret( ) . as_str( ) ) ,
1586
1616
)
0 commit comments