@@ -29,13 +29,18 @@ import {
29
29
Method ,
30
30
} from "../../../src" ;
31
31
import { emitPromise } from "../../test-utils/test-utils" ;
32
- import { defer , type QueryDict } from "../../../src/utils" ;
32
+ import { defer , type QueryDict , sleep } from "../../../src/utils" ;
33
33
import { type Logger } from "../../../src/logger" ;
34
34
35
35
describe ( "FetchHttpApi" , ( ) => {
36
36
const baseUrl = "http://baseUrl" ;
37
37
const idBaseUrl = "http://idBaseUrl" ;
38
38
const prefix = ClientPrefix . V3 ;
39
+ const tokenInactiveError = new MatrixError ( { errcode : "M_UNKNOWN_TOKEN" , error : "Token is not active" } , 401 ) ;
40
+
41
+ beforeEach ( ( ) => {
42
+ jest . useRealTimers ( ) ;
43
+ } ) ;
39
44
40
45
it ( "should support aborting multiple times" , ( ) => {
41
46
const fetchFn = jest . fn ( ) . mockResolvedValue ( { ok : true } ) ;
@@ -492,8 +497,6 @@ describe("FetchHttpApi", () => {
492
497
} ) ;
493
498
494
499
it ( "should not make multiple concurrent refresh token requests" , async ( ) => {
495
- const tokenInactiveError = new MatrixError ( { errcode : "M_UNKNOWN_TOKEN" , error : "Token is not active" } , 401 ) ;
496
-
497
500
const deferredTokenRefresh = defer < { accessToken : string ; refreshToken : string } > ( ) ;
498
501
const fetchFn = jest . fn ( ) . mockResolvedValue ( {
499
502
ok : false ,
@@ -523,7 +526,7 @@ describe("FetchHttpApi", () => {
523
526
const prom1 = api . authedRequest ( Method . Get , "/path1" ) ;
524
527
const prom2 = api . authedRequest ( Method . Get , "/path2" ) ;
525
528
526
- await jest . advanceTimersByTimeAsync ( 10 ) ; // wait for requests to fire
529
+ await sleep ( 0 ) ; // wait for requests to fire
527
530
expect ( fetchFn ) . toHaveBeenCalledTimes ( 2 ) ;
528
531
fetchFn . mockResolvedValue ( {
529
532
ok : true ,
@@ -547,4 +550,66 @@ describe("FetchHttpApi", () => {
547
550
expect ( api . opts . accessToken ) . toBe ( "NEW_ACCESS_TOKEN" ) ;
548
551
expect ( api . opts . refreshToken ) . toBe ( "NEW_REFRESH_TOKEN" ) ;
549
552
} ) ;
553
+
554
+ it ( "should use newly refreshed token if request starts mid-refresh" , async ( ) => {
555
+ const deferredTokenRefresh = defer < { accessToken : string ; refreshToken : string } > ( ) ;
556
+ const fetchFn = jest . fn ( ) . mockResolvedValue ( {
557
+ ok : false ,
558
+ status : tokenInactiveError . httpStatus ,
559
+ async text ( ) {
560
+ return JSON . stringify ( tokenInactiveError . data ) ;
561
+ } ,
562
+ async json ( ) {
563
+ return tokenInactiveError . data ;
564
+ } ,
565
+ headers : {
566
+ get : jest . fn ( ) . mockReturnValue ( "application/json" ) ,
567
+ } ,
568
+ } ) ;
569
+ const tokenRefreshFunction = jest . fn ( ) . mockReturnValue ( deferredTokenRefresh . promise ) ;
570
+
571
+ const api = new FetchHttpApi ( new TypedEventEmitter < any , any > ( ) , {
572
+ baseUrl,
573
+ prefix,
574
+ fetchFn,
575
+ doNotAttemptTokenRefresh : false ,
576
+ tokenRefreshFunction,
577
+ accessToken : "ACCESS_TOKEN" ,
578
+ refreshToken : "REFRESH_TOKEN" ,
579
+ } ) ;
580
+
581
+ const prom1 = api . authedRequest ( Method . Get , "/path1" ) ;
582
+ await sleep ( 0 ) ; // wait for request to fire
583
+
584
+ const prom2 = api . authedRequest ( Method . Get , "/path2" ) ;
585
+ await sleep ( 0 ) ; // wait for request to fire
586
+
587
+ deferredTokenRefresh . resolve ( { accessToken : "NEW_ACCESS_TOKEN" , refreshToken : "NEW_REFRESH_TOKEN" } ) ;
588
+ fetchFn . mockResolvedValue ( {
589
+ ok : true ,
590
+ status : 200 ,
591
+ async text ( ) {
592
+ return "{}" ;
593
+ } ,
594
+ async json ( ) {
595
+ return { } ;
596
+ } ,
597
+ headers : {
598
+ get : jest . fn ( ) . mockReturnValue ( "application/json" ) ,
599
+ } ,
600
+ } ) ;
601
+
602
+ await prom1 ;
603
+ await prom2 ;
604
+ expect ( fetchFn ) . toHaveBeenCalledTimes ( 3 ) ; // 2 original calls + 1 retry
605
+ expect ( fetchFn . mock . calls [ 0 ] [ 1 ] ) . toEqual (
606
+ expect . objectContaining ( { headers : expect . objectContaining ( { Authorization : "Bearer ACCESS_TOKEN" } ) } ) ,
607
+ ) ;
608
+ expect ( fetchFn . mock . calls [ 2 ] [ 1 ] ) . toEqual (
609
+ expect . objectContaining ( { headers : expect . objectContaining ( { Authorization : "Bearer NEW_ACCESS_TOKEN" } ) } ) ,
610
+ ) ;
611
+ expect ( tokenRefreshFunction ) . toHaveBeenCalledTimes ( 1 ) ;
612
+ expect ( api . opts . accessToken ) . toBe ( "NEW_ACCESS_TOKEN" ) ;
613
+ expect ( api . opts . refreshToken ) . toBe ( "NEW_REFRESH_TOKEN" ) ;
614
+ } ) ;
550
615
} ) ;
0 commit comments