@@ -14,6 +14,7 @@ import (
14
14
"golang.org/x/tools/go/analysis"
15
15
"golang.org/x/tools/go/analysis/passes/inspect"
16
16
"golang.org/x/tools/go/ast/inspector"
17
+ "golang.org/x/tools/internal/typeparams"
17
18
)
18
19
19
20
const Doc = `check for unkeyed composite literals
@@ -67,41 +68,52 @@ func run(pass *analysis.Pass) (interface{}, error) {
67
68
// skip whitelisted types
68
69
return
69
70
}
70
- under := typ .Underlying ()
71
- for {
72
- ptr , ok := under .(* types.Pointer )
73
- if ! ok {
74
- break
75
- }
76
- under = ptr .Elem ().Underlying ()
77
- }
78
- if _ , ok := under .(* types.Struct ); ! ok {
79
- // skip non-struct composite literals
80
- return
81
- }
82
- if isLocalType (pass , typ ) {
83
- // allow unkeyed locally defined composite literal
84
- return
71
+ terms , err := typeparams .StructuralTerms (typ )
72
+ if err != nil {
73
+ return // invalid type
85
74
}
75
+ for _ , term := range terms {
76
+ under := deref (term .Type ().Underlying ())
77
+ if _ , ok := under .(* types.Struct ); ! ok {
78
+ // skip non-struct composite literals
79
+ continue
80
+ }
81
+ if isLocalType (pass , term .Type ()) {
82
+ // allow unkeyed locally defined composite literal
83
+ continue
84
+ }
86
85
87
- // check if the CompositeLit contains an unkeyed field
88
- allKeyValue := true
89
- for _ , e := range cl .Elts {
90
- if _ , ok := e .(* ast.KeyValueExpr ); ! ok {
91
- allKeyValue = false
92
- break
86
+ // check if the CompositeLit contains an unkeyed field
87
+ allKeyValue := true
88
+ for _ , e := range cl .Elts {
89
+ if _ , ok := e .(* ast.KeyValueExpr ); ! ok {
90
+ allKeyValue = false
91
+ break
92
+ }
93
93
}
94
- }
95
- if allKeyValue {
96
- // all the composite literal fields are keyed
94
+ if allKeyValue {
95
+ // all the composite literal fields are keyed
96
+ continue
97
+ }
98
+
99
+ pass .ReportRangef (cl , "%s composite literal uses unkeyed fields" , typeName )
97
100
return
98
101
}
99
-
100
- pass .ReportRangef (cl , "%s composite literal uses unkeyed fields" , typeName )
101
102
})
102
103
return nil , nil
103
104
}
104
105
106
+ func deref (typ types.Type ) types.Type {
107
+ for {
108
+ ptr , ok := typ .(* types.Pointer )
109
+ if ! ok {
110
+ break
111
+ }
112
+ typ = ptr .Elem ().Underlying ()
113
+ }
114
+ return typ
115
+ }
116
+
105
117
func isLocalType (pass * analysis.Pass , typ types.Type ) bool {
106
118
switch x := typ .(type ) {
107
119
case * types.Struct :
@@ -112,6 +124,8 @@ func isLocalType(pass *analysis.Pass, typ types.Type) bool {
112
124
case * types.Named :
113
125
// names in package foo are local to foo_test too
114
126
return strings .TrimSuffix (x .Obj ().Pkg ().Path (), "_test" ) == strings .TrimSuffix (pass .Pkg .Path (), "_test" )
127
+ case * typeparams.TypeParam :
128
+ return strings .TrimSuffix (x .Obj ().Pkg ().Path (), "_test" ) == strings .TrimSuffix (pass .Pkg .Path (), "_test" )
115
129
}
116
130
return false
117
131
}
0 commit comments