Skip to content

Commit 480c8e8

Browse files
authored
Fix token refresh racing with other requests and not using new token (#4798)
* Fix token refresh racing with other requests and not using new token Signed-off-by: Michael Telatynski <[email protected]> * Iterate Signed-off-by: Michael Telatynski <[email protected]> --------- Signed-off-by: Michael Telatynski <[email protected]>
1 parent 1ba4412 commit 480c8e8

File tree

2 files changed

+88
-14
lines changed

2 files changed

+88
-14
lines changed

spec/unit/http-api/fetch.spec.ts

Lines changed: 69 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,18 @@ import {
2929
Method,
3030
} from "../../../src";
3131
import { emitPromise } from "../../test-utils/test-utils";
32-
import { defer, type QueryDict } from "../../../src/utils";
32+
import { defer, type QueryDict, sleep } from "../../../src/utils";
3333
import { type Logger } from "../../../src/logger";
3434

3535
describe("FetchHttpApi", () => {
3636
const baseUrl = "http://baseUrl";
3737
const idBaseUrl = "http://idBaseUrl";
3838
const prefix = ClientPrefix.V3;
39+
const tokenInactiveError = new MatrixError({ errcode: "M_UNKNOWN_TOKEN", error: "Token is not active" }, 401);
40+
41+
beforeEach(() => {
42+
jest.useRealTimers();
43+
});
3944

4045
it("should support aborting multiple times", () => {
4146
const fetchFn = jest.fn().mockResolvedValue({ ok: true });
@@ -492,8 +497,6 @@ describe("FetchHttpApi", () => {
492497
});
493498

494499
it("should not make multiple concurrent refresh token requests", async () => {
495-
const tokenInactiveError = new MatrixError({ errcode: "M_UNKNOWN_TOKEN", error: "Token is not active" }, 401);
496-
497500
const deferredTokenRefresh = defer<{ accessToken: string; refreshToken: string }>();
498501
const fetchFn = jest.fn().mockResolvedValue({
499502
ok: false,
@@ -523,7 +526,7 @@ describe("FetchHttpApi", () => {
523526
const prom1 = api.authedRequest(Method.Get, "/path1");
524527
const prom2 = api.authedRequest(Method.Get, "/path2");
525528

526-
await jest.advanceTimersByTimeAsync(10); // wait for requests to fire
529+
await sleep(0); // wait for requests to fire
527530
expect(fetchFn).toHaveBeenCalledTimes(2);
528531
fetchFn.mockResolvedValue({
529532
ok: true,
@@ -547,4 +550,66 @@ describe("FetchHttpApi", () => {
547550
expect(api.opts.accessToken).toBe("NEW_ACCESS_TOKEN");
548551
expect(api.opts.refreshToken).toBe("NEW_REFRESH_TOKEN");
549552
});
553+
554+
it("should use newly refreshed token if request starts mid-refresh", async () => {
555+
const deferredTokenRefresh = defer<{ accessToken: string; refreshToken: string }>();
556+
const fetchFn = jest.fn().mockResolvedValue({
557+
ok: false,
558+
status: tokenInactiveError.httpStatus,
559+
async text() {
560+
return JSON.stringify(tokenInactiveError.data);
561+
},
562+
async json() {
563+
return tokenInactiveError.data;
564+
},
565+
headers: {
566+
get: jest.fn().mockReturnValue("application/json"),
567+
},
568+
});
569+
const tokenRefreshFunction = jest.fn().mockReturnValue(deferredTokenRefresh.promise);
570+
571+
const api = new FetchHttpApi(new TypedEventEmitter<any, any>(), {
572+
baseUrl,
573+
prefix,
574+
fetchFn,
575+
doNotAttemptTokenRefresh: false,
576+
tokenRefreshFunction,
577+
accessToken: "ACCESS_TOKEN",
578+
refreshToken: "REFRESH_TOKEN",
579+
});
580+
581+
const prom1 = api.authedRequest(Method.Get, "/path1");
582+
await sleep(0); // wait for request to fire
583+
584+
const prom2 = api.authedRequest(Method.Get, "/path2");
585+
await sleep(0); // wait for request to fire
586+
587+
deferredTokenRefresh.resolve({ accessToken: "NEW_ACCESS_TOKEN", refreshToken: "NEW_REFRESH_TOKEN" });
588+
fetchFn.mockResolvedValue({
589+
ok: true,
590+
status: 200,
591+
async text() {
592+
return "{}";
593+
},
594+
async json() {
595+
return {};
596+
},
597+
headers: {
598+
get: jest.fn().mockReturnValue("application/json"),
599+
},
600+
});
601+
602+
await prom1;
603+
await prom2;
604+
expect(fetchFn).toHaveBeenCalledTimes(3); // 2 original calls + 1 retry
605+
expect(fetchFn.mock.calls[0][1]).toEqual(
606+
expect.objectContaining({ headers: expect.objectContaining({ Authorization: "Bearer ACCESS_TOKEN" }) }),
607+
);
608+
expect(fetchFn.mock.calls[2][1]).toEqual(
609+
expect.objectContaining({ headers: expect.objectContaining({ Authorization: "Bearer NEW_ACCESS_TOKEN" }) }),
610+
);
611+
expect(tokenRefreshFunction).toHaveBeenCalledTimes(1);
612+
expect(api.opts.accessToken).toBe("NEW_ACCESS_TOKEN");
613+
expect(api.opts.refreshToken).toBe("NEW_REFRESH_TOKEN");
614+
});
550615
});

src/http-api/fetch.ts

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -158,25 +158,28 @@ export class FetchHttpApi<O extends IHttpOpts> {
158158
// avoid mutating paramOpts so they can be used on retry
159159
const opts = { ...paramOpts };
160160

161-
if (this.opts.accessToken) {
161+
// Await any ongoing token refresh before we build the headers/params
162+
await this.tokenRefreshPromise;
163+
164+
// Take a copy of the access token so we have a record of the token we used for this request if it fails
165+
const accessToken = this.opts.accessToken;
166+
if (accessToken) {
162167
if (this.opts.useAuthorizationHeader) {
163168
if (!opts.headers) {
164169
opts.headers = {};
165170
}
166171
if (!opts.headers.Authorization) {
167-
opts.headers.Authorization = "Bearer " + this.opts.accessToken;
172+
opts.headers.Authorization = `Bearer ${accessToken}`;
168173
}
169174
if (queryParams.access_token) {
170175
delete queryParams.access_token;
171176
}
172177
} else if (!queryParams.access_token) {
173-
queryParams.access_token = this.opts.accessToken;
178+
queryParams.access_token = accessToken;
174179
}
175180
}
176181

177182
try {
178-
// Await any ongoing token refresh
179-
await this.tokenRefreshPromise;
180183
const response = await this.request<T>(method, path, queryParams, body, opts);
181184
return response;
182185
} catch (error) {
@@ -185,15 +188,21 @@ export class FetchHttpApi<O extends IHttpOpts> {
185188
}
186189

187190
if (error.errcode === "M_UNKNOWN_TOKEN" && !opts.doNotAttemptTokenRefresh) {
188-
const tokenRefreshPromise = this.tryRefreshToken();
189-
this.tokenRefreshPromise = Promise.allSettled([tokenRefreshPromise]);
190-
const outcome = await tokenRefreshPromise;
191+
// If the access token has changed since we started the request, but before we refreshed it,
192+
// then it was refreshed due to another request failing, so retry before refreshing again.
193+
let outcome: TokenRefreshOutcome | null = null;
194+
if (accessToken === this.opts.accessToken) {
195+
const tokenRefreshPromise = this.tryRefreshToken();
196+
this.tokenRefreshPromise = tokenRefreshPromise;
197+
outcome = await tokenRefreshPromise;
198+
}
191199

192-
if (outcome === TokenRefreshOutcome.Success) {
200+
if (outcome === TokenRefreshOutcome.Success || outcome === null) {
193201
// if we got a new token retry the request
194202
return this.authedRequest(method, path, queryParams, body, {
195203
...paramOpts,
196-
doNotAttemptTokenRefresh: true,
204+
// Only attempt token refresh once for each failed request
205+
doNotAttemptTokenRefresh: outcome !== null,
197206
});
198207
}
199208
if (outcome === TokenRefreshOutcome.Failure) {

0 commit comments

Comments
 (0)