@@ -25471,8 +25471,8 @@ namespace ts {
25471
25471
25472
25472
if (functionFlags & FunctionFlags.Async) { // Async function or AsyncGenerator function
25473
25473
// Get the awaited type without the `Awaited<T>` alias
25474
- const contextualAwaitedType = mapType(contextualReturnType, getAwaitedType );
25475
- return contextualAwaitedType && getUnionType([unwrapAwaitedType( contextualAwaitedType) , createPromiseLikeType(contextualAwaitedType)]);
25474
+ const contextualAwaitedType = mapType(contextualReturnType, getAwaitedTypeNoAlias );
25475
+ return contextualAwaitedType && getUnionType([contextualAwaitedType, createPromiseLikeType(contextualAwaitedType)]);
25476
25476
}
25477
25477
25478
25478
return contextualReturnType; // Regular function or Generator function
@@ -25484,8 +25484,8 @@ namespace ts {
25484
25484
function getContextualTypeForAwaitOperand(node: AwaitExpression, contextFlags?: ContextFlags): Type | undefined {
25485
25485
const contextualType = getContextualType(node, contextFlags);
25486
25486
if (contextualType) {
25487
- const contextualAwaitedType = getAwaitedType (contextualType);
25488
- return contextualAwaitedType && getUnionType([unwrapAwaitedType( contextualAwaitedType) , createPromiseLikeType(contextualAwaitedType)]);
25487
+ const contextualAwaitedType = getAwaitedTypeNoAlias (contextualType);
25488
+ return contextualAwaitedType && getUnionType([contextualAwaitedType, createPromiseLikeType(contextualAwaitedType)]);
25489
25489
}
25490
25490
return undefined;
25491
25491
}
@@ -31158,7 +31158,8 @@ namespace ts {
31158
31158
const globalPromiseType = getGlobalPromiseType(/*reportErrors*/ true);
31159
31159
if (globalPromiseType !== emptyGenericType) {
31160
31160
// if the promised type is itself a promise, get the underlying type; otherwise, fallback to the promised type
31161
- promisedType = getAwaitedType(promisedType) || unknownType;
31161
+ // Unwrap an `Awaited<T>` to `T` to improve inference.
31162
+ promisedType = getAwaitedTypeNoAlias(unwrapAwaitedType(promisedType)) || unknownType;
31162
31163
return createTypeReference(globalPromiseType, [promisedType]);
31163
31164
}
31164
31165
@@ -31170,7 +31171,8 @@ namespace ts {
31170
31171
const globalPromiseLikeType = getGlobalPromiseLikeType(/*reportErrors*/ true);
31171
31172
if (globalPromiseLikeType !== emptyGenericType) {
31172
31173
// if the promised type is itself a promise, get the underlying type; otherwise, fallback to the promised type
31173
- promisedType = getAwaitedType(promisedType) || unknownType;
31174
+ // Unwrap an `Awaited<T>` to `T` to improve inference.
31175
+ promisedType = getAwaitedTypeNoAlias(unwrapAwaitedType(promisedType)) || unknownType;
31174
31176
return createTypeReference(globalPromiseLikeType, [promisedType]);
31175
31177
}
31176
31178
@@ -31227,7 +31229,7 @@ namespace ts {
31227
31229
// Promise/A+ compatible implementation will always assimilate any foreign promise, so the
31228
31230
// return type of the body should be unwrapped to its awaited type, which we will wrap in
31229
31231
// the native Promise<T> type later in this function.
31230
- returnType = checkAwaitedType(returnType, /*errorNode*/ func, Diagnostics.The_return_type_of_an_async_function_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member);
31232
+ returnType = unwrapAwaitedType( checkAwaitedType(returnType, /*withAlias*/ false, /* errorNode*/ func, Diagnostics.The_return_type_of_an_async_function_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member) );
31231
31233
}
31232
31234
}
31233
31235
else if (isGenerator) { // Generator or AsyncGenerator function
@@ -31460,7 +31462,7 @@ namespace ts {
31460
31462
// Promise/A+ compatible implementation will always assimilate any foreign promise, so the
31461
31463
// return type of the body should be unwrapped to its awaited type, which should be wrapped in
31462
31464
// the native Promise<T> type by the caller.
31463
- type = checkAwaitedType(type, func, Diagnostics.The_return_type_of_an_async_function_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member);
31465
+ type = unwrapAwaitedType( checkAwaitedType(type, /*withAlias*/ false, func, Diagnostics.The_return_type_of_an_async_function_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member) );
31464
31466
}
31465
31467
if (type.flags & TypeFlags.Never) {
31466
31468
hasReturnOfTypeNever = true;
@@ -31662,7 +31664,7 @@ namespace ts {
31662
31664
const returnOrPromisedType = returnType && unwrapReturnType(returnType, functionFlags);
31663
31665
if (returnOrPromisedType) {
31664
31666
if ((functionFlags & FunctionFlags.AsyncGenerator) === FunctionFlags.Async) { // Async function
31665
- const awaitedType = checkAwaitedType(exprType, node.body, Diagnostics.The_return_type_of_an_async_function_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member);
31667
+ const awaitedType = checkAwaitedType(exprType, /*withAlias*/ false, node.body, Diagnostics.The_return_type_of_an_async_function_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member);
31666
31668
checkTypeAssignableToAndOptionallyElaborate(awaitedType, returnOrPromisedType, node.body, node.body);
31667
31669
}
31668
31670
else { // Normal function
@@ -31879,7 +31881,7 @@ namespace ts {
31879
31881
}
31880
31882
31881
31883
const operandType = checkExpression(node.expression);
31882
- const awaitedType = checkAwaitedType(operandType, node, Diagnostics.Type_of_await_operand_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member);
31884
+ const awaitedType = checkAwaitedType(operandType, /*withAlias*/ true, node, Diagnostics.Type_of_await_operand_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member);
31883
31885
if (awaitedType === operandType && awaitedType !== errorType && !(operandType.flags & TypeFlags.AnyOrUnknown)) {
31884
31886
addErrorOrSuggestion(/*isError*/ false, createDiagnosticForNode(node, Diagnostics.await_has_no_effect_on_the_type_of_this_expression));
31885
31887
}
@@ -32831,8 +32833,8 @@ namespace ts {
32831
32833
let wouldWorkWithAwait = false;
32832
32834
const errNode = errorNode || operatorToken;
32833
32835
if (isRelated) {
32834
- const awaitedLeftType = unwrapAwaitedType(getAwaitedType( leftType) );
32835
- const awaitedRightType = unwrapAwaitedType(getAwaitedType( rightType) );
32836
+ const awaitedLeftType = getAwaitedTypeNoAlias( leftType);
32837
+ const awaitedRightType = getAwaitedTypeNoAlias( rightType);
32836
32838
wouldWorkWithAwait = !(awaitedLeftType === leftType && awaitedRightType === rightType)
32837
32839
&& !!(awaitedLeftType && awaitedRightType)
32838
32840
&& isRelated(awaitedLeftType, awaitedRightType);
@@ -34914,12 +34916,15 @@ namespace ts {
34914
34916
/**
34915
34917
* Gets the "awaited type" of a type.
34916
34918
* @param type The type to await.
34919
+ * @param withAlias When `true`, wraps the "awaited type" in `Awaited<T>` if needed.
34917
34920
* @remarks The "awaited type" of an expression is its "promised type" if the expression is a
34918
34921
* Promise-like type; otherwise, it is the type of the expression. This is used to reflect
34919
34922
* The runtime behavior of the `await` keyword.
34920
34923
*/
34921
- function checkAwaitedType(type: Type, errorNode: Node, diagnosticMessage: DiagnosticMessage, arg0?: string | number): Type {
34922
- const awaitedType = getAwaitedType(type, errorNode, diagnosticMessage, arg0);
34924
+ function checkAwaitedType(type: Type, withAlias: boolean, errorNode: Node, diagnosticMessage: DiagnosticMessage, arg0?: string | number): Type {
34925
+ const awaitedType = withAlias ?
34926
+ getAwaitedType(type, errorNode, diagnosticMessage, arg0) :
34927
+ getAwaitedTypeNoAlias(type, errorNode, diagnosticMessage, arg0);
34923
34928
return awaitedType || errorType;
34924
34929
}
34925
34930
@@ -34953,10 +34958,7 @@ namespace ts {
34953
34958
/**
34954
34959
* For a generic `Awaited<T>`, gets `T`.
34955
34960
*/
34956
- function unwrapAwaitedType(type: Type): Type;
34957
- function unwrapAwaitedType(type: Type | undefined): Type | undefined;
34958
- function unwrapAwaitedType(type: Type | undefined) {
34959
- if (!type) return undefined;
34961
+ function unwrapAwaitedType(type: Type) {
34960
34962
return type.flags & TypeFlags.Union ? mapType(type, unwrapAwaitedType) :
34961
34963
isAwaitedTypeInstantiation(type) ? type.aliasTypeArguments[0] :
34962
34964
type;
@@ -35011,6 +35013,16 @@ namespace ts {
35011
35013
* This is used to reflect the runtime behavior of the `await` keyword.
35012
35014
*/
35013
35015
function getAwaitedType(type: Type, errorNode?: Node, diagnosticMessage?: DiagnosticMessage, arg0?: string | number): Type | undefined {
35016
+ const awaitedType = getAwaitedTypeNoAlias(type, errorNode, diagnosticMessage, arg0);
35017
+ return awaitedType && createAwaitedTypeIfNeeded(awaitedType);
35018
+ }
35019
+
35020
+ /**
35021
+ * Gets the "awaited type" of a type without introducing an `Awaited<T>` wrapper.
35022
+ *
35023
+ * @see {@link getAwaitedType}
35024
+ */
35025
+ function getAwaitedTypeNoAlias(type: Type, errorNode?: Node, diagnosticMessage?: DiagnosticMessage, arg0?: string | number): Type | undefined {
35014
35026
if (isTypeAny(type)) {
35015
35027
return type;
35016
35028
}
@@ -35023,14 +35035,13 @@ namespace ts {
35023
35035
// If we've already cached an awaited type, return a possible `Awaited<T>` for it.
35024
35036
const typeAsAwaitable = type as PromiseOrAwaitableType;
35025
35037
if (typeAsAwaitable.awaitedTypeOfType) {
35026
- return createAwaitedTypeIfNeeded( typeAsAwaitable.awaitedTypeOfType) ;
35038
+ return typeAsAwaitable.awaitedTypeOfType;
35027
35039
}
35028
35040
35029
35041
// For a union, get a union of the awaited types of each constituent.
35030
35042
if (type.flags & TypeFlags.Union) {
35031
- const mapper = errorNode ? (constituentType: Type) => getAwaitedType(constituentType, errorNode, diagnosticMessage, arg0) : getAwaitedType;
35032
- typeAsAwaitable.awaitedTypeOfType = mapType(type, mapper);
35033
- return typeAsAwaitable.awaitedTypeOfType && createAwaitedTypeIfNeeded(typeAsAwaitable.awaitedTypeOfType);
35043
+ const mapper = errorNode ? (constituentType: Type) => getAwaitedTypeNoAlias(constituentType, errorNode, diagnosticMessage, arg0) : getAwaitedTypeNoAlias;
35044
+ return typeAsAwaitable.awaitedTypeOfType = mapType(type, mapper);
35034
35045
}
35035
35046
35036
35047
const promisedType = getPromisedTypeOfPromise(type);
@@ -35078,14 +35089,14 @@ namespace ts {
35078
35089
// Keep track of the type we're about to unwrap to avoid bad recursive promise types.
35079
35090
// See the comments above for more information.
35080
35091
awaitedTypeStack.push(type.id);
35081
- const awaitedType = getAwaitedType (promisedType, errorNode, diagnosticMessage, arg0);
35092
+ const awaitedType = getAwaitedTypeNoAlias (promisedType, errorNode, diagnosticMessage, arg0);
35082
35093
awaitedTypeStack.pop();
35083
35094
35084
35095
if (!awaitedType) {
35085
35096
return undefined;
35086
35097
}
35087
35098
35088
- return createAwaitedTypeIfNeeded( typeAsAwaitable.awaitedTypeOfType = awaitedType) ;
35099
+ return typeAsAwaitable.awaitedTypeOfType = awaitedType;
35089
35100
}
35090
35101
35091
35102
// The type was not a promise, so it could not be unwrapped any further.
@@ -35111,7 +35122,7 @@ namespace ts {
35111
35122
return undefined;
35112
35123
}
35113
35124
35114
- return createAwaitedTypeIfNeeded( typeAsAwaitable.awaitedTypeOfType = type) ;
35125
+ return typeAsAwaitable.awaitedTypeOfType = type;
35115
35126
}
35116
35127
35117
35128
/**
@@ -35161,7 +35172,7 @@ namespace ts {
35161
35172
if (globalPromiseType !== emptyGenericType && !isReferenceToType(returnType, globalPromiseType)) {
35162
35173
// The promise type was not a valid type reference to the global promise type, so we
35163
35174
// report an error and return the unknown type.
35164
- error(returnTypeNode, Diagnostics.The_return_type_of_an_async_function_or_method_must_be_the_global_Promise_T_type_Did_you_mean_to_write_Promise_0, typeToString(unwrapAwaitedType(getAwaitedType( returnType) ) || voidType));
35175
+ error(returnTypeNode, Diagnostics.The_return_type_of_an_async_function_or_method_must_be_the_global_Promise_T_type_Did_you_mean_to_write_Promise_0, typeToString(getAwaitedTypeNoAlias( returnType) || voidType));
35165
35176
return;
35166
35177
}
35167
35178
}
@@ -35214,7 +35225,7 @@ namespace ts {
35214
35225
return;
35215
35226
}
35216
35227
}
35217
- checkAwaitedType(returnType, node, Diagnostics.The_return_type_of_an_async_function_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member);
35228
+ checkAwaitedType(returnType, /*withAlias*/ false, node, Diagnostics.The_return_type_of_an_async_function_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member);
35218
35229
}
35219
35230
35220
35231
/** Check a decorator */
@@ -37495,7 +37506,7 @@ namespace ts {
37495
37506
const isGenerator = !!(functionFlags & FunctionFlags.Generator);
37496
37507
const isAsync = !!(functionFlags & FunctionFlags.Async);
37497
37508
return isGenerator ? getIterationTypeOfGeneratorFunctionReturnType(IterationTypeKind.Return, returnType, isAsync) || errorType :
37498
- isAsync ? unwrapAwaitedType(getAwaitedType( returnType) ) || errorType :
37509
+ isAsync ? getAwaitedTypeNoAlias( returnType) || errorType :
37499
37510
returnType;
37500
37511
}
37501
37512
@@ -37539,7 +37550,7 @@ namespace ts {
37539
37550
else if (getReturnTypeFromAnnotation(container)) {
37540
37551
const unwrappedReturnType = unwrapReturnType(returnType, functionFlags) ?? returnType;
37541
37552
const unwrappedExprType = functionFlags & FunctionFlags.Async
37542
- ? checkAwaitedType(exprType, node, Diagnostics.The_return_type_of_an_async_function_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member)
37553
+ ? checkAwaitedType(exprType, /*withAlias*/ false, node, Diagnostics.The_return_type_of_an_async_function_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member)
37543
37554
: exprType;
37544
37555
if (unwrappedReturnType) {
37545
37556
// If the function has a return type, but promisedType is
0 commit comments