Skip to content

Commit 4c0a51e

Browse files
authored
Avoid Promise<Awaited<T>> in return type inference (#45925)
1 parent 12f72ca commit 4c0a51e

File tree

6 files changed

+192
-37
lines changed

6 files changed

+192
-37
lines changed

src/compiler/checker.ts

+40-29
Original file line numberDiff line numberDiff line change
@@ -25471,8 +25471,8 @@ namespace ts {
2547125471

2547225472
if (functionFlags & FunctionFlags.Async) { // Async function or AsyncGenerator function
2547325473
// Get the awaited type without the `Awaited<T>` alias
25474-
const contextualAwaitedType = mapType(contextualReturnType, getAwaitedType);
25475-
return contextualAwaitedType && getUnionType([unwrapAwaitedType(contextualAwaitedType), createPromiseLikeType(contextualAwaitedType)]);
25474+
const contextualAwaitedType = mapType(contextualReturnType, getAwaitedTypeNoAlias);
25475+
return contextualAwaitedType && getUnionType([contextualAwaitedType, createPromiseLikeType(contextualAwaitedType)]);
2547625476
}
2547725477

2547825478
return contextualReturnType; // Regular function or Generator function
@@ -25484,8 +25484,8 @@ namespace ts {
2548425484
function getContextualTypeForAwaitOperand(node: AwaitExpression, contextFlags?: ContextFlags): Type | undefined {
2548525485
const contextualType = getContextualType(node, contextFlags);
2548625486
if (contextualType) {
25487-
const contextualAwaitedType = getAwaitedType(contextualType);
25488-
return contextualAwaitedType && getUnionType([unwrapAwaitedType(contextualAwaitedType), createPromiseLikeType(contextualAwaitedType)]);
25487+
const contextualAwaitedType = getAwaitedTypeNoAlias(contextualType);
25488+
return contextualAwaitedType && getUnionType([contextualAwaitedType, createPromiseLikeType(contextualAwaitedType)]);
2548925489
}
2549025490
return undefined;
2549125491
}
@@ -31158,7 +31158,8 @@ namespace ts {
3115831158
const globalPromiseType = getGlobalPromiseType(/*reportErrors*/ true);
3115931159
if (globalPromiseType !== emptyGenericType) {
3116031160
// if the promised type is itself a promise, get the underlying type; otherwise, fallback to the promised type
31161-
promisedType = getAwaitedType(promisedType) || unknownType;
31161+
// Unwrap an `Awaited<T>` to `T` to improve inference.
31162+
promisedType = getAwaitedTypeNoAlias(unwrapAwaitedType(promisedType)) || unknownType;
3116231163
return createTypeReference(globalPromiseType, [promisedType]);
3116331164
}
3116431165

@@ -31170,7 +31171,8 @@ namespace ts {
3117031171
const globalPromiseLikeType = getGlobalPromiseLikeType(/*reportErrors*/ true);
3117131172
if (globalPromiseLikeType !== emptyGenericType) {
3117231173
// if the promised type is itself a promise, get the underlying type; otherwise, fallback to the promised type
31173-
promisedType = getAwaitedType(promisedType) || unknownType;
31174+
// Unwrap an `Awaited<T>` to `T` to improve inference.
31175+
promisedType = getAwaitedTypeNoAlias(unwrapAwaitedType(promisedType)) || unknownType;
3117431176
return createTypeReference(globalPromiseLikeType, [promisedType]);
3117531177
}
3117631178

@@ -31227,7 +31229,7 @@ namespace ts {
3122731229
// Promise/A+ compatible implementation will always assimilate any foreign promise, so the
3122831230
// return type of the body should be unwrapped to its awaited type, which we will wrap in
3122931231
// the native Promise<T> type later in this function.
31230-
returnType = checkAwaitedType(returnType, /*errorNode*/ func, Diagnostics.The_return_type_of_an_async_function_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member);
31232+
returnType = unwrapAwaitedType(checkAwaitedType(returnType, /*withAlias*/ false, /*errorNode*/ func, Diagnostics.The_return_type_of_an_async_function_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member));
3123131233
}
3123231234
}
3123331235
else if (isGenerator) { // Generator or AsyncGenerator function
@@ -31460,7 +31462,7 @@ namespace ts {
3146031462
// Promise/A+ compatible implementation will always assimilate any foreign promise, so the
3146131463
// return type of the body should be unwrapped to its awaited type, which should be wrapped in
3146231464
// the native Promise<T> type by the caller.
31463-
type = checkAwaitedType(type, func, Diagnostics.The_return_type_of_an_async_function_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member);
31465+
type = unwrapAwaitedType(checkAwaitedType(type, /*withAlias*/ false, func, Diagnostics.The_return_type_of_an_async_function_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member));
3146431466
}
3146531467
if (type.flags & TypeFlags.Never) {
3146631468
hasReturnOfTypeNever = true;
@@ -31662,7 +31664,7 @@ namespace ts {
3166231664
const returnOrPromisedType = returnType && unwrapReturnType(returnType, functionFlags);
3166331665
if (returnOrPromisedType) {
3166431666
if ((functionFlags & FunctionFlags.AsyncGenerator) === FunctionFlags.Async) { // Async function
31665-
const awaitedType = checkAwaitedType(exprType, node.body, Diagnostics.The_return_type_of_an_async_function_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member);
31667+
const awaitedType = checkAwaitedType(exprType, /*withAlias*/ false, node.body, Diagnostics.The_return_type_of_an_async_function_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member);
3166631668
checkTypeAssignableToAndOptionallyElaborate(awaitedType, returnOrPromisedType, node.body, node.body);
3166731669
}
3166831670
else { // Normal function
@@ -31879,7 +31881,7 @@ namespace ts {
3187931881
}
3188031882

3188131883
const operandType = checkExpression(node.expression);
31882-
const awaitedType = checkAwaitedType(operandType, node, Diagnostics.Type_of_await_operand_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member);
31884+
const awaitedType = checkAwaitedType(operandType, /*withAlias*/ true, node, Diagnostics.Type_of_await_operand_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member);
3188331885
if (awaitedType === operandType && awaitedType !== errorType && !(operandType.flags & TypeFlags.AnyOrUnknown)) {
3188431886
addErrorOrSuggestion(/*isError*/ false, createDiagnosticForNode(node, Diagnostics.await_has_no_effect_on_the_type_of_this_expression));
3188531887
}
@@ -32831,8 +32833,8 @@ namespace ts {
3283132833
let wouldWorkWithAwait = false;
3283232834
const errNode = errorNode || operatorToken;
3283332835
if (isRelated) {
32834-
const awaitedLeftType = unwrapAwaitedType(getAwaitedType(leftType));
32835-
const awaitedRightType = unwrapAwaitedType(getAwaitedType(rightType));
32836+
const awaitedLeftType = getAwaitedTypeNoAlias(leftType);
32837+
const awaitedRightType = getAwaitedTypeNoAlias(rightType);
3283632838
wouldWorkWithAwait = !(awaitedLeftType === leftType && awaitedRightType === rightType)
3283732839
&& !!(awaitedLeftType && awaitedRightType)
3283832840
&& isRelated(awaitedLeftType, awaitedRightType);
@@ -34914,12 +34916,15 @@ namespace ts {
3491434916
/**
3491534917
* Gets the "awaited type" of a type.
3491634918
* @param type The type to await.
34919+
* @param withAlias When `true`, wraps the "awaited type" in `Awaited<T>` if needed.
3491734920
* @remarks The "awaited type" of an expression is its "promised type" if the expression is a
3491834921
* Promise-like type; otherwise, it is the type of the expression. This is used to reflect
3491934922
* The runtime behavior of the `await` keyword.
3492034923
*/
34921-
function checkAwaitedType(type: Type, errorNode: Node, diagnosticMessage: DiagnosticMessage, arg0?: string | number): Type {
34922-
const awaitedType = getAwaitedType(type, errorNode, diagnosticMessage, arg0);
34924+
function checkAwaitedType(type: Type, withAlias: boolean, errorNode: Node, diagnosticMessage: DiagnosticMessage, arg0?: string | number): Type {
34925+
const awaitedType = withAlias ?
34926+
getAwaitedType(type, errorNode, diagnosticMessage, arg0) :
34927+
getAwaitedTypeNoAlias(type, errorNode, diagnosticMessage, arg0);
3492334928
return awaitedType || errorType;
3492434929
}
3492534930

@@ -34953,10 +34958,7 @@ namespace ts {
3495334958
/**
3495434959
* For a generic `Awaited<T>`, gets `T`.
3495534960
*/
34956-
function unwrapAwaitedType(type: Type): Type;
34957-
function unwrapAwaitedType(type: Type | undefined): Type | undefined;
34958-
function unwrapAwaitedType(type: Type | undefined) {
34959-
if (!type) return undefined;
34961+
function unwrapAwaitedType(type: Type) {
3496034962
return type.flags & TypeFlags.Union ? mapType(type, unwrapAwaitedType) :
3496134963
isAwaitedTypeInstantiation(type) ? type.aliasTypeArguments[0] :
3496234964
type;
@@ -35011,6 +35013,16 @@ namespace ts {
3501135013
* This is used to reflect the runtime behavior of the `await` keyword.
3501235014
*/
3501335015
function getAwaitedType(type: Type, errorNode?: Node, diagnosticMessage?: DiagnosticMessage, arg0?: string | number): Type | undefined {
35016+
const awaitedType = getAwaitedTypeNoAlias(type, errorNode, diagnosticMessage, arg0);
35017+
return awaitedType && createAwaitedTypeIfNeeded(awaitedType);
35018+
}
35019+
35020+
/**
35021+
* Gets the "awaited type" of a type without introducing an `Awaited<T>` wrapper.
35022+
*
35023+
* @see {@link getAwaitedType}
35024+
*/
35025+
function getAwaitedTypeNoAlias(type: Type, errorNode?: Node, diagnosticMessage?: DiagnosticMessage, arg0?: string | number): Type | undefined {
3501435026
if (isTypeAny(type)) {
3501535027
return type;
3501635028
}
@@ -35023,14 +35035,13 @@ namespace ts {
3502335035
// If we've already cached an awaited type, return a possible `Awaited<T>` for it.
3502435036
const typeAsAwaitable = type as PromiseOrAwaitableType;
3502535037
if (typeAsAwaitable.awaitedTypeOfType) {
35026-
return createAwaitedTypeIfNeeded(typeAsAwaitable.awaitedTypeOfType);
35038+
return typeAsAwaitable.awaitedTypeOfType;
3502735039
}
3502835040

3502935041
// For a union, get a union of the awaited types of each constituent.
3503035042
if (type.flags & TypeFlags.Union) {
35031-
const mapper = errorNode ? (constituentType: Type) => getAwaitedType(constituentType, errorNode, diagnosticMessage, arg0) : getAwaitedType;
35032-
typeAsAwaitable.awaitedTypeOfType = mapType(type, mapper);
35033-
return typeAsAwaitable.awaitedTypeOfType && createAwaitedTypeIfNeeded(typeAsAwaitable.awaitedTypeOfType);
35043+
const mapper = errorNode ? (constituentType: Type) => getAwaitedTypeNoAlias(constituentType, errorNode, diagnosticMessage, arg0) : getAwaitedTypeNoAlias;
35044+
return typeAsAwaitable.awaitedTypeOfType = mapType(type, mapper);
3503435045
}
3503535046

3503635047
const promisedType = getPromisedTypeOfPromise(type);
@@ -35078,14 +35089,14 @@ namespace ts {
3507835089
// Keep track of the type we're about to unwrap to avoid bad recursive promise types.
3507935090
// See the comments above for more information.
3508035091
awaitedTypeStack.push(type.id);
35081-
const awaitedType = getAwaitedType(promisedType, errorNode, diagnosticMessage, arg0);
35092+
const awaitedType = getAwaitedTypeNoAlias(promisedType, errorNode, diagnosticMessage, arg0);
3508235093
awaitedTypeStack.pop();
3508335094

3508435095
if (!awaitedType) {
3508535096
return undefined;
3508635097
}
3508735098

35088-
return createAwaitedTypeIfNeeded(typeAsAwaitable.awaitedTypeOfType = awaitedType);
35099+
return typeAsAwaitable.awaitedTypeOfType = awaitedType;
3508935100
}
3509035101

3509135102
// The type was not a promise, so it could not be unwrapped any further.
@@ -35111,7 +35122,7 @@ namespace ts {
3511135122
return undefined;
3511235123
}
3511335124

35114-
return createAwaitedTypeIfNeeded(typeAsAwaitable.awaitedTypeOfType = type);
35125+
return typeAsAwaitable.awaitedTypeOfType = type;
3511535126
}
3511635127

3511735128
/**
@@ -35161,7 +35172,7 @@ namespace ts {
3516135172
if (globalPromiseType !== emptyGenericType && !isReferenceToType(returnType, globalPromiseType)) {
3516235173
// The promise type was not a valid type reference to the global promise type, so we
3516335174
// report an error and return the unknown type.
35164-
error(returnTypeNode, Diagnostics.The_return_type_of_an_async_function_or_method_must_be_the_global_Promise_T_type_Did_you_mean_to_write_Promise_0, typeToString(unwrapAwaitedType(getAwaitedType(returnType)) || voidType));
35175+
error(returnTypeNode, Diagnostics.The_return_type_of_an_async_function_or_method_must_be_the_global_Promise_T_type_Did_you_mean_to_write_Promise_0, typeToString(getAwaitedTypeNoAlias(returnType) || voidType));
3516535176
return;
3516635177
}
3516735178
}
@@ -35214,7 +35225,7 @@ namespace ts {
3521435225
return;
3521535226
}
3521635227
}
35217-
checkAwaitedType(returnType, node, Diagnostics.The_return_type_of_an_async_function_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member);
35228+
checkAwaitedType(returnType, /*withAlias*/ false, node, Diagnostics.The_return_type_of_an_async_function_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member);
3521835229
}
3521935230

3522035231
/** Check a decorator */
@@ -37495,7 +37506,7 @@ namespace ts {
3749537506
const isGenerator = !!(functionFlags & FunctionFlags.Generator);
3749637507
const isAsync = !!(functionFlags & FunctionFlags.Async);
3749737508
return isGenerator ? getIterationTypeOfGeneratorFunctionReturnType(IterationTypeKind.Return, returnType, isAsync) || errorType :
37498-
isAsync ? unwrapAwaitedType(getAwaitedType(returnType)) || errorType :
37509+
isAsync ? getAwaitedTypeNoAlias(returnType) || errorType :
3749937510
returnType;
3750037511
}
3750137512

@@ -37539,7 +37550,7 @@ namespace ts {
3753937550
else if (getReturnTypeFromAnnotation(container)) {
3754037551
const unwrappedReturnType = unwrapReturnType(returnType, functionFlags) ?? returnType;
3754137552
const unwrappedExprType = functionFlags & FunctionFlags.Async
37542-
? checkAwaitedType(exprType, node, Diagnostics.The_return_type_of_an_async_function_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member)
37553+
? checkAwaitedType(exprType, /*withAlias*/ false, node, Diagnostics.The_return_type_of_an_async_function_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member)
3754337554
: exprType;
3754437555
if (unwrappedReturnType) {
3754537556
// If the function has a return type, but promisedType is

tests/baselines/reference/awaitedTypeStrictNull.errors.txt

+16
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,22 @@ tests/cases/compiler/awaitedTypeStrictNull.ts(22,12): error TS2589: Type instant
4747
])
4848
}
4949

50+
// https://github.com/microsoft/TypeScript/issues/45924
51+
class Api<D = {}> {
52+
// Should result in `Promise<T>` instead of `Promise<Awaited<T>>`.
53+
async post<T = D>() { return this.request<T>(); }
54+
async request<D>(): Promise<D> { throw new Error(); }
55+
}
56+
57+
declare const api: Api;
58+
interface Obj { x: number }
59+
60+
async function fn<T>(): Promise<T extends object ? { [K in keyof T]: Obj } : Obj> {
61+
// Per #45924, this was failing due to incorrect inference both above and here.
62+
// Should not error.
63+
return api.post();
64+
}
65+
5066
// helps with tests where '.types' just prints out the type alias name
5167
type _Expect<TActual extends TExpected, TExpected> = TActual;
5268

tests/baselines/reference/awaitedTypeStrictNull.js

+27
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,22 @@ async function main() {
3939
])
4040
}
4141

42+
// https://github.com/microsoft/TypeScript/issues/45924
43+
class Api<D = {}> {
44+
// Should result in `Promise<T>` instead of `Promise<Awaited<T>>`.
45+
async post<T = D>() { return this.request<T>(); }
46+
async request<D>(): Promise<D> { throw new Error(); }
47+
}
48+
49+
declare const api: Api;
50+
interface Obj { x: number }
51+
52+
async function fn<T>(): Promise<T extends object ? { [K in keyof T]: Obj } : Obj> {
53+
// Per #45924, this was failing due to incorrect inference both above and here.
54+
// Should not error.
55+
return api.post();
56+
}
57+
4258
// helps with tests where '.types' just prints out the type alias name
4359
type _Expect<TActual extends TExpected, TExpected> = TActual;
4460

@@ -56,3 +72,14 @@ async function main() {
5672
MaybePromise(true),
5773
]);
5874
}
75+
// https://github.com/microsoft/TypeScript/issues/45924
76+
class Api {
77+
// Should result in `Promise<T>` instead of `Promise<Awaited<T>>`.
78+
async post() { return this.request(); }
79+
async request() { throw new Error(); }
80+
}
81+
async function fn() {
82+
// Per #45924, this was failing due to incorrect inference both above and here.
83+
// Should not error.
84+
return api.post();
85+
}

0 commit comments

Comments
 (0)