From 8422476560a9d8cf6e77a6b66f51bac0ec32c6ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Burzy=C5=84ski?= Date: Fri, 3 Oct 2025 19:59:56 +0200 Subject: [PATCH] Port "Infer from annotated return type nodes before assigning contextual parameter types" --- internal/checker/checker.go | 13 +++- .../compiler/inferFromAnnotatedReturn1.types | 34 ++++----- .../inferFromAnnotatedReturn1.types.diff | 71 ------------------- 3 files changed, 27 insertions(+), 91 deletions(-) delete mode 100644 testdata/baselines/reference/submodule/compiler/inferFromAnnotatedReturn1.types.diff diff --git a/internal/checker/checker.go b/internal/checker/checker.go index a6b3a684c5..9d8e4c0a58 100644 --- a/internal/checker/checker.go +++ b/internal/checker/checker.go @@ -9827,7 +9827,7 @@ func (c *Checker) contextuallyCheckFunctionExpressionOrObjectLiteralMethod(node inferenceContext := c.getInferenceContext(node) var instantiatedContextualSignature *Signature if checkMode&CheckModeInferential != 0 { - c.inferFromAnnotatedParameters(signature, contextualSignature, inferenceContext) + c.inferFromAnnotatedParametersAndReturn(signature, contextualSignature, inferenceContext) restType := c.getEffectiveRestType(contextualSignature) if restType != nil && restType.flags&TypeFlagsTypeParameter != 0 { instantiatedContextualSignature = c.instantiateSignature(contextualSignature, inferenceContext.nonFixingMapper) @@ -9848,7 +9848,7 @@ func (c *Checker) contextuallyCheckFunctionExpressionOrObjectLiteralMethod(node } else if contextualSignature != nil && node.TypeParameters() == nil && len(contextualSignature.parameters) > len(node.Parameters()) { inferenceContext := c.getInferenceContext(node) if checkMode&CheckModeInferential != 0 { - c.inferFromAnnotatedParameters(signature, contextualSignature, inferenceContext) + c.inferFromAnnotatedParametersAndReturn(signature, contextualSignature, inferenceContext) } } if contextualSignature != nil && c.getReturnTypeFromAnnotation(node) == nil && signature.resolvedReturnType == nil { @@ -9895,7 +9895,7 @@ func (c *Checker) checkFunctionExpressionOrObjectLiteralMethodDeferred(node *ast } } -func (c *Checker) inferFromAnnotatedParameters(sig *Signature, context *Signature, inferenceContext *InferenceContext) { +func (c *Checker) inferFromAnnotatedParametersAndReturn(sig *Signature, context *Signature, inferenceContext *InferenceContext) { length := len(sig.parameters) - core.IfElse(signatureHasRestParameter(sig), 1, 0) for i := range length { declaration := sig.parameters[i].ValueDeclaration @@ -9906,6 +9906,13 @@ func (c *Checker) inferFromAnnotatedParameters(sig *Signature, context *Signatur c.inferTypes(inferenceContext.inferences, source, target, InferencePriorityNone, false) } } + if declaration := sig.Declaration(); declaration != nil { + if returnTypeNode := declaration.Type(); returnTypeNode != nil { + source := c.getTypeFromTypeNode(returnTypeNode) + target := c.getReturnTypeOfSignature(context) + c.inferTypes(inferenceContext.inferences, source, target, InferencePriorityNone, false) + } + } } // Return the contextual signature for a given expression node. A contextual type provides a diff --git a/testdata/baselines/reference/submodule/compiler/inferFromAnnotatedReturn1.types b/testdata/baselines/reference/submodule/compiler/inferFromAnnotatedReturn1.types index 6fa55005ef..192879a21d 100644 --- a/testdata/baselines/reference/submodule/compiler/inferFromAnnotatedReturn1.types +++ b/testdata/baselines/reference/submodule/compiler/inferFromAnnotatedReturn1.types @@ -7,19 +7,19 @@ declare function test(cb: (arg: T) => T): T; >arg : T const res1 = test((arg): number => 1); // ok ->res1 : unknown ->test((arg): number => 1) : unknown +>res1 : number +>test((arg): number => 1) : number >test : (cb: (arg: T) => T) => T ->(arg): number => 1 : (arg: unknown) => number ->arg : unknown +>(arg): number => 1 : (arg: number) => number +>arg : number >1 : 1 const res2 = test((arg): number => 'foo'); // error ->res2 : unknown ->test((arg): number => 'foo') : unknown +>res2 : number +>test((arg): number => 'foo') : number >test : (cb: (arg: T) => T) => T ->(arg): number => 'foo' : (arg: unknown) => number ->arg : unknown +>(arg): number => 'foo' : (arg: number) => number +>arg : number >'foo' : "foo" export declare function linkedSignal(options: { @@ -36,10 +36,10 @@ export declare function linkedSignal(options: { }): D; const signal = linkedSignal({ ->signal : unknown ->linkedSignal({ source: () => 3, computation: (s): number => 3,}) : unknown +>signal : number +>linkedSignal({ source: () => 3, computation: (s): number => 3,}) : number >linkedSignal : (options: { source: () => S; computation: (source: NoInfer) => D; }) => D ->{ source: () => 3, computation: (s): number => 3,} : { source: () => number; computation: (s: unknown) => number; } +>{ source: () => 3, computation: (s): number => 3,} : { source: () => number; computation: (s: number) => number; } source: () => 3, >source : () => number @@ -47,9 +47,9 @@ const signal = linkedSignal({ >3 : 3 computation: (s): number => 3, ->computation : (s: unknown) => number ->(s): number => 3 : (s: unknown) => number ->s : unknown +>computation : (s: number) => number +>(s): number => 3 : (s: number) => number +>s : number >3 : 3 }); @@ -66,10 +66,10 @@ class Foo { } const _1 = new Foo((name: string, { x }): { name: string; x: number } => ({ ->_1 : Foo ->new Foo((name: string, { x }): { name: string; x: number } => ({ name, x,})) : Foo +>_1 : Foo +>new Foo((name: string, { x }): { name: string; x: number } => ({ name, x,})) : Foo >Foo : typeof Foo ->(name: string, { x }): { name: string; x: number } => ({ name, x,}) : (name: string, { x }: { x: number; other: unknown; }) => { name: string; x: number; } +>(name: string, { x }): { name: string; x: number } => ({ name, x,}) : (name: string, { x }: { x: number; other: NoInfer<{ name: string; x: number; }>; }) => { name: string; x: number; } >name : string >x : number >name : string diff --git a/testdata/baselines/reference/submodule/compiler/inferFromAnnotatedReturn1.types.diff b/testdata/baselines/reference/submodule/compiler/inferFromAnnotatedReturn1.types.diff deleted file mode 100644 index 762072c2e8..0000000000 --- a/testdata/baselines/reference/submodule/compiler/inferFromAnnotatedReturn1.types.diff +++ /dev/null @@ -1,71 +0,0 @@ ---- old.inferFromAnnotatedReturn1.types -+++ new.inferFromAnnotatedReturn1.types -@@= skipped -6, +6 lines =@@ - >arg : T - - const res1 = test((arg): number => 1); // ok -->res1 : number -->test((arg): number => 1) : number -+>res1 : unknown -+>test((arg): number => 1) : unknown - >test : (cb: (arg: T) => T) => T -->(arg): number => 1 : (arg: number) => number -->arg : number -+>(arg): number => 1 : (arg: unknown) => number -+>arg : unknown - >1 : 1 - - const res2 = test((arg): number => 'foo'); // error -->res2 : number -->test((arg): number => 'foo') : number -+>res2 : unknown -+>test((arg): number => 'foo') : unknown - >test : (cb: (arg: T) => T) => T -->(arg): number => 'foo' : (arg: number) => number -->arg : number -+>(arg): number => 'foo' : (arg: unknown) => number -+>arg : unknown - >'foo' : "foo" - - export declare function linkedSignal(options: { -@@= skipped -29, +29 lines =@@ - }): D; - - const signal = linkedSignal({ -->signal : number -->linkedSignal({ source: () => 3, computation: (s): number => 3,}) : number -+>signal : unknown -+>linkedSignal({ source: () => 3, computation: (s): number => 3,}) : unknown - >linkedSignal : (options: { source: () => S; computation: (source: NoInfer) => D; }) => D -->{ source: () => 3, computation: (s): number => 3,} : { source: () => number; computation: (s: number) => number; } -+>{ source: () => 3, computation: (s): number => 3,} : { source: () => number; computation: (s: unknown) => number; } - - source: () => 3, - >source : () => number -@@= skipped -11, +11 lines =@@ - >3 : 3 - - computation: (s): number => 3, -->computation : (s: number) => number -->(s): number => 3 : (s: number) => number -->s : number -+>computation : (s: unknown) => number -+>(s): number => 3 : (s: unknown) => number -+>s : unknown - >3 : 3 - - }); -@@= skipped -19, +19 lines =@@ - } - - const _1 = new Foo((name: string, { x }): { name: string; x: number } => ({ -->_1 : Foo -->new Foo((name: string, { x }): { name: string; x: number } => ({ name, x,})) : Foo -+>_1 : Foo -+>new Foo((name: string, { x }): { name: string; x: number } => ({ name, x,})) : Foo - >Foo : typeof Foo -->(name: string, { x }): { name: string; x: number } => ({ name, x,}) : (name: string, { x }: { x: number; other: NoInfer<{ name: string; x: number; }>; }) => { name: string; x: number; } -+>(name: string, { x }): { name: string; x: number } => ({ name, x,}) : (name: string, { x }: { x: number; other: unknown; }) => { name: string; x: number; } - >name : string - >x : number - >name : string \ No newline at end of file