Skip to content

Commit a4b0863

Browse files
authored
First pass at improving generated code (#1)
* swap over imports to the fork, use go modules, remove dependency on envconfig * update client.go to create arguments for each command based on the fields of the request object, add some protos to test against as well as new servers for those protos. TODO: clean up how STDIN is handled so you don't need flags and also objects; clean up generated code; clean up generated command line help text - make it use proto comments; handle enum values * make flags the default method of getting request data and force users to opt-in to stdin with a flag (--stdin or -f -); fix how flags are attached to commands so that each command's request-specific args/flags are grouped together and separate from the configuration for calling the endpoint (those flags are now attached at the top level command for the service and inherited by the commands for each method) * Shuffle around how commands are created so there's no need for init functions - instead you call `pb_pkg.ServiceClientCommand()` and the function returns the command you need. At creation time we wire up the flags and subcommands which we were previously initializing in init functions. * remove debug output from generated files, but leave in place the infra to generate it again if we need it to debugw * skip initializing map fields (it's not needed) and correctly initialize list fields; do not create flags for them yet * go mod tidy * fix some copy/paste errors, move around flag declarations so similar declaraions are next to each other; remove imports on request messages in our client handlers, since we're putting the clients into the same package as the protobufs themselves
1 parent dffa0bf commit a4b0863

33 files changed

+6299
-758
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
.idea
2+
vendor
3+
protoc-gen-cobra

client/args.go

Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
package client
2+
3+
import (
4+
"bytes"
5+
"fmt"
6+
"io"
7+
"strings"
8+
9+
pb "github.com/golang/protobuf/protoc-gen-go/descriptor"
10+
11+
"github.com/tetratelabs/protoc-gen-cobra/generator"
12+
)
13+
14+
type protoTypeCache map[string]entry
15+
type entry struct {
16+
d *pb.DescriptorProto
17+
f, n bool
18+
}
19+
20+
func (p protoTypeCache) byName(desc []*pb.DescriptorProto, name string, log func(...interface{})) (*pb.DescriptorProto, bool, bool) {
21+
return byName(p, desc, name, false, log)
22+
}
23+
24+
func byName(p protoTypeCache, desc []*pb.DescriptorProto, name string, nested bool, log func(...interface{})) (*pb.DescriptorProto, bool, bool) {
25+
log("searching for ", name)
26+
if entry, found := p[name]; found {
27+
log("* found ", entry.d.GetName(), "in cache ", fmt.Sprintf("%v", p))
28+
return entry.d, entry.f, entry.n
29+
}
30+
31+
for _, d := range desc {
32+
if d.GetName() == name {
33+
p[name] = entry{d, true, nested}
34+
log("* comparing against ", d.GetName(), " inserting into cache: \n// ", fmt.Sprintf("%v", p))
35+
return d, true, nested
36+
} else {
37+
log(" comparing against ", d.GetName())
38+
}
39+
if desc, found, _ := byName(p, d.NestedType, name, true, prefix(" ", log)); found {
40+
return desc, found, true
41+
}
42+
}
43+
return nil, false, false
44+
}
45+
46+
func prefix(pre string, l func(...interface{})) func(...interface{}) {
47+
return func(i ...interface{}) { l(append([]interface{}{pre}, i...)...) }
48+
}
49+
50+
func noop(...interface{}) {}
51+
52+
// first return is the instantiation of the struct and fields that are messages; second is the set of
53+
// flag declarations using the fields of the struct to receive values
54+
func (c *client) generateRequestFlags(file *generator.FileDescriptor, d *pb.DescriptorProto, types protoTypeCache) (string, []string) {
55+
if d == nil {
56+
return "", []string{}
57+
}
58+
flags := c.generateSubMessageRequestFlags("reqArgs", "", d, file, types)
59+
initialize := c.generateRequestInitialization(d, file, types)
60+
return initialize, flags
61+
}
62+
63+
func (c *client) generateSubMessageRequestFlags(objectName, flagPrefix string, d *pb.DescriptorProto, file *generator.FileDescriptor, types protoTypeCache) []string {
64+
out := make([]string, 0, len(d.Field))
65+
66+
for _, f := range d.Field {
67+
fieldName := goFieldName(f)
68+
fieldFlagName := strings.ToLower(fieldName)
69+
if f.GetLabel() == pb.FieldDescriptorProto_LABEL_REPEATED {
70+
// TODO
71+
out = append(out, fmt.Sprintf(`.PersistentFlags() // Warning: list flags are not yet supported (field %q)`, fieldName))
72+
continue
73+
}
74+
75+
switch f.GetType() {
76+
// Field is a complex type (another message, or an enum)
77+
case pb.FieldDescriptorProto_TYPE_MESSAGE:
78+
// if both type and name are set, descriptor must be either a message or enum
79+
_, _, ttype := inputNames(f.GetTypeName())
80+
if fdesc, found, _ := types.byName(file.MessageType, ttype, noop /*prefix("// ", c.P)*/); found {
81+
if fdesc.GetOptions().GetMapEntry() {
82+
// TODO
83+
return []string{fmt.Sprintf(`.PersistentFlags() // Warning: map flags are not yet supported (message %q)`, d.GetName())}
84+
}
85+
86+
flags := c.generateSubMessageRequestFlags(objectName+"."+fieldName, flagPrefix+fieldFlagName+"-", fdesc, file, types)
87+
out = append(out, flags...)
88+
}
89+
case pb.FieldDescriptorProto_TYPE_ENUM:
90+
// TODO
91+
case pb.FieldDescriptorProto_TYPE_STRING:
92+
out = append(out, fmt.Sprintf(`.PersistentFlags().StringVar(&%s.%s, "%s%s", "", "%s")`,
93+
objectName, fieldName, flagPrefix, fieldFlagName, "get-comment-from-proto"))
94+
case pb.FieldDescriptorProto_TYPE_BYTES:
95+
out = append(out, fmt.Sprintf(`.PersistentFlags().BytesBase64Var(&%s.%s, "%s%s", []byte{}, "%s")`,
96+
objectName, fieldName, flagPrefix, fieldFlagName, "get-comment-from-proto"))
97+
case pb.FieldDescriptorProto_TYPE_BOOL:
98+
out = append(out, fmt.Sprintf(`.PersistentFlags().BoolVar(&%s.%s, "%s%s", false, "%s")`,
99+
objectName, fieldName, flagPrefix, fieldFlagName, "get-comment-from-proto"))
100+
case pb.FieldDescriptorProto_TYPE_FLOAT:
101+
out = append(out, fmt.Sprintf(`.PersistentFlags().Float32Var(&%s.%s, "%s%s", 0, "%s")`,
102+
objectName, fieldName, flagPrefix, fieldFlagName, "get-comment-from-proto"))
103+
case pb.FieldDescriptorProto_TYPE_DOUBLE:
104+
out = append(out, fmt.Sprintf(`.PersistentFlags().Float64Var(&%s.%s, "%s%s", 0, "%s")`,
105+
objectName, fieldName, flagPrefix, fieldFlagName, "get-comment-from-proto"))
106+
case pb.FieldDescriptorProto_TYPE_INT32:
107+
out = append(out, fmt.Sprintf(`.PersistentFlags().Int32Var(&%s.%s, "%s%s", 0, "%s")`,
108+
objectName, fieldName, flagPrefix, fieldFlagName, "get-comment-from-proto"))
109+
case pb.FieldDescriptorProto_TYPE_FIXED32:
110+
out = append(out, fmt.Sprintf(`.PersistentFlags().Int32Var(&%s.%s, "%s%s", 0, "%s")`,
111+
objectName, fieldName, flagPrefix, fieldFlagName, "get-comment-from-proto"))
112+
case pb.FieldDescriptorProto_TYPE_SFIXED32:
113+
out = append(out, fmt.Sprintf(`.PersistentFlags().Int32Var(&%s.%s, "%s%s", 0, "%s")`,
114+
objectName, fieldName, flagPrefix, fieldFlagName, "get-comment-from-proto"))
115+
case pb.FieldDescriptorProto_TYPE_SINT32:
116+
out = append(out, fmt.Sprintf(`.PersistentFlags().Int32Var(&%s.%s, "%s%s", 0, "%s")`,
117+
objectName, fieldName, flagPrefix, fieldFlagName, "get-comment-from-proto"))
118+
case pb.FieldDescriptorProto_TYPE_UINT32:
119+
out = append(out, fmt.Sprintf(`.PersistentFlags().Uint32Var(&%s.%s, "%s%s", 0, "%s")`,
120+
objectName, fieldName, flagPrefix, fieldFlagName, "get-comment-from-proto"))
121+
case pb.FieldDescriptorProto_TYPE_INT64:
122+
out = append(out, fmt.Sprintf(`.PersistentFlags().Int64Var(&%s.%s, "%s%s", 0, "%s")`,
123+
objectName, fieldName, flagPrefix, fieldFlagName, "get-comment-from-proto"))
124+
case pb.FieldDescriptorProto_TYPE_FIXED64:
125+
out = append(out, fmt.Sprintf(`.PersistentFlags().Int64Var(&%s.%s, "%s%s", 0, "%s")`,
126+
objectName, fieldName, flagPrefix, fieldFlagName, "get-comment-from-proto"))
127+
case pb.FieldDescriptorProto_TYPE_SFIXED64:
128+
out = append(out, fmt.Sprintf(`.PersistentFlags().Int64Var(&%s.%s, "%s%s", 0, "%s")`,
129+
objectName, fieldName, flagPrefix, fieldFlagName, "get-comment-from-proto"))
130+
case pb.FieldDescriptorProto_TYPE_SINT64:
131+
out = append(out, fmt.Sprintf(`.PersistentFlags().Int64Var(&%s.%s, "%s%s", 0, "%s")`,
132+
objectName, fieldName, flagPrefix, fieldFlagName, "get-comment-from-proto"))
133+
case pb.FieldDescriptorProto_TYPE_UINT64:
134+
out = append(out, fmt.Sprintf(`.PersistentFlags().Uint64Var(&%s.%s, "%s%s", 0, "%s")`,
135+
objectName, fieldName, flagPrefix, fieldFlagName, "get-comment-from-proto"))
136+
137+
case pb.FieldDescriptorProto_TYPE_GROUP:
138+
default:
139+
}
140+
}
141+
return out
142+
}
143+
144+
func goFieldName(f *pb.FieldDescriptorProto) string {
145+
fieldName := f.GetJsonName()
146+
if fieldName != "" {
147+
fieldName = strings.ToUpper(string(fieldName[0])) + fieldName[1:]
148+
}
149+
return fieldName
150+
}
151+
152+
func (c *client) generateRequestInitialization(d *pb.DescriptorProto, file *generator.FileDescriptor, types protoTypeCache) string {
153+
debug := &bytes.Buffer{}
154+
initialize := genReqInit(d, file, types, "", false, debug, noop /*prefix("// ", c.P)*/)
155+
// c.P(debug.String())
156+
return initialize
157+
}
158+
159+
func genReqInit(d *pb.DescriptorProto, file *generator.FileDescriptor, types protoTypeCache, typePrefix string, repeated bool, w io.Writer, log func(...interface{})) string {
160+
if repeated {
161+
// if we're repeated, we only want to compute the type then bail, we won't figure out if we're trying to create an instance
162+
out := fmt.Sprintf("[]*%s%s{}", typePrefix, d.GetName())
163+
fmt.Fprintf(w, "// computed %q\n", out)
164+
return out
165+
}
166+
167+
fields := make(map[string]string)
168+
fmt.Fprintf(w, "// generating initialization for %s with prefix %q which has %d fields\n", d.GetName(), typePrefix, len(d.Field))
169+
for _, f := range d.Field {
170+
switch f.GetType() {
171+
case pb.FieldDescriptorProto_TYPE_MESSAGE:
172+
_, _, ttype := inputNames(f.GetTypeName())
173+
desc, found, nested := types.byName(file.MessageType, ttype, log)
174+
fmt.Fprintf(w, "// searching for type %q with ttype %q for field %q\n", f.GetTypeName(), ttype, f.GetName())
175+
if !found {
176+
fmt.Fprint(w, "// not found, skipping\n")
177+
continue
178+
}
179+
180+
if desc.GetOptions().GetMapEntry() {
181+
fmt.Fprintf(w, "// skipping map fields, which do not need to be initialized")
182+
continue
183+
}
184+
185+
prefix := typePrefix
186+
if nested {
187+
prefix += d.GetName() + "_"
188+
}
189+
190+
fmt.Fprintf(w, "// found, recursing with %q\n", desc.GetName())
191+
m := genReqInit(desc, file, types, prefix, listField(f), w, log)
192+
fmt.Fprintf(w, "// found field %q which we'll initialize with %q\n", goFieldName(f), m)
193+
fields[goFieldName(f)] = m
194+
default:
195+
fmt.Fprintf(w, "// found non-message field %q\n", f.GetName())
196+
}
197+
}
198+
199+
vals := make([]string, 0, len(fields))
200+
for n, v := range fields {
201+
vals = append(vals, n+": "+v)
202+
}
203+
values := "{}"
204+
if len(vals) > 0 {
205+
values = fmt.Sprintf("{\n%s,\n}", strings.Join(vals, ",\n"))
206+
}
207+
208+
prefix := fmt.Sprintf("&%s%s", typePrefix, d.GetName())
209+
210+
out := prefix + values
211+
fmt.Fprintf(w, "// computed %q\n", out)
212+
return out
213+
}
214+
215+
func listField(d *pb.FieldDescriptorProto) bool {
216+
return d.GetLabel() == pb.FieldDescriptorProto_LABEL_REPEATED
217+
}

0 commit comments

Comments
 (0)