diff --git a/resolver/cel.go b/resolver/cel.go index 92e85eb4..3c1fe7ca 100644 --- a/resolver/cel.go +++ b/resolver/cel.go @@ -8,6 +8,7 @@ import ( celtypes "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" "google.golang.org/protobuf/reflect/protodesc" + "google.golang.org/protobuf/reflect/protoregistry" "google.golang.org/protobuf/types/descriptorpb" grpcfedcel "github.com/mercari/grpc-federation/grpc/federation/cel" @@ -16,11 +17,13 @@ import ( type CELRegistry struct { *celtypes.Registry - messageMap map[string]*Message - enumTypeMap map[*celtypes.Type]*Enum - enumValueMap map[string]*EnumValue - usedEnumValueMap map[*EnumValue]struct{} - errs []error + registryFiles *protoregistry.Files + registeredFileMap map[string]struct{} + messageMap map[string]*Message + enumTypeMap map[*celtypes.Type]*Enum + enumValueMap map[string]*EnumValue + usedEnumValueMap map[*EnumValue]struct{} + errs []error } func (r *CELRegistry) clear() { @@ -127,30 +130,66 @@ func toEnumSelectorCELType(sel *Message) *cel.Type { func newCELRegistry(messageMap map[string]*Message, enumValueMap map[string]*EnumValue) *CELRegistry { return &CELRegistry{ - Registry: celtypes.NewEmptyRegistry(), - messageMap: messageMap, - enumTypeMap: make(map[*celtypes.Type]*Enum), - enumValueMap: enumValueMap, - usedEnumValueMap: make(map[*EnumValue]struct{}), + Registry: celtypes.NewEmptyRegistry(), + registryFiles: new(protoregistry.Files), + registeredFileMap: make(map[string]struct{}), + messageMap: messageMap, + enumTypeMap: make(map[*celtypes.Type]*Enum), + enumValueMap: enumValueMap, + usedEnumValueMap: make(map[*EnumValue]struct{}), } } -func (r *CELRegistry) RegisterFiles(files ...*descriptorpb.FileDescriptorProto) error { - registryFiles, err := protodesc.NewFiles(&descriptorpb.FileDescriptorSet{ - File: files, - }) - if err != nil { - return err +func (r *CELRegistry) RegisterFiles(fds ...*descriptorpb.FileDescriptorProto) error { + fileMap := make(map[string]*descriptorpb.FileDescriptorProto) + for _, fd := range fds { + fileName := fd.GetName() + if _, ok := fileMap[fileName]; ok { + return fmt.Errorf("file appears multiple times: %q", fileName) + } + fileMap[fileName] = fd } - for _, file := range files { - rf, err := protodesc.NewFile(file, registryFiles) - if err != nil { + for _, fd := range fileMap { + if err := r.registerFileDeps(fd, fileMap); err != nil { return err } - if err := r.Registry.RegisterDescriptor(rf); err != nil { + } + return nil +} + +func (r *CELRegistry) registerFileDeps(fd *descriptorpb.FileDescriptorProto, fileMap map[string]*descriptorpb.FileDescriptorProto) error { + // set the entry to nil while descending into a file's dependencies to detect cycles. + fileName := fd.GetName() + fileMap[fileName] = nil + for _, dep := range fd.GetDependency() { + depFD, ok := fileMap[dep] + if depFD == nil { + if ok { + return fmt.Errorf("import cycle in file: %q", dep) + } + continue + } + if err := r.registerFileDeps(depFD, fileMap); err != nil { return err } } + // delete the entry once dependencies are processed. + delete(fileMap, fileName) + if _, exists := r.registeredFileMap[fileName]; exists { + return nil + } + + f, err := protodesc.NewFile(fd, r.registryFiles) + if err != nil { + return err + } + if err := r.registryFiles.RegisterFile(f); err != nil { + return err + } + if err := r.Registry.RegisterDescriptor(f); err != nil { + return err + } + r.registeredFileMap[fileName] = struct{}{} return nil } diff --git a/resolver/resolver.go b/resolver/resolver.go index d9bcbd05..3f08d02d 100644 --- a/resolver/resolver.go +++ b/resolver/resolver.go @@ -43,6 +43,7 @@ type Resolver struct { protoPackageNameToFileDefs map[string][]*descriptorpb.FileDescriptorProto protoPackageNameToPackage map[string]*Package celPluginMap map[string]*CELPlugin + enumAccessors []cel.EnvOption serviceToRuleMap map[*Service]*federation.ServiceRule methodToRuleMap map[*Method]*federation.MethodRule @@ -146,9 +147,6 @@ func (r *Resolver) ResolveWellknownFiles() (Files, error) { } func (r *Resolver) Resolve() (*Result, error) { - if err := r.celRegistry.RegisterFiles(r.files...); err != nil { - return nil, err - } // In order to return multiple errors with source code location information, // we add all errors to the context when they occur. // Therefore, functions called from Resolve() do not return errors directly. @@ -161,6 +159,10 @@ func (r *Resolver) Resolve() (*Result, error) { return nil, err } + if err := r.celRegistry.RegisterFiles(r.files...); err != nil { + return nil, err + } + files := r.resolveFiles(ctx) r.resolveRule(ctx, files) @@ -339,6 +341,10 @@ func (r *Resolver) resolveFiles(ctx *context) []*File { for _, fileDef := range r.files { files = append(files, r.resolveFile(ctx, fileDef, source.NewLocationBuilder(fileDef.GetName()))) } + + // After resolving all file references, the enum references will also be resolved, + // allowing the construction of enumAccessors at this point. + r.resolveEnumAccessors() return files } @@ -3661,7 +3667,7 @@ func (r *Resolver) resolveMessageArgumentRecursive( builder := newMessageBuilderFromMessage(msg) arg := msg.Rule.MessageArgument fileDesc := messageArgumentFileDescriptor(arg) - if err := r.celRegistry.RegisterFiles(append(r.files, fileDesc)...); err != nil { + if err := r.celRegistry.RegisterFiles(fileDesc); err != nil { ctx.addError( ErrWithLocation( err.Error(), @@ -4472,7 +4478,7 @@ func (r *Resolver) createServiceCELEnv(svc *Service, env *Env) (*cel.Env, error) if env != nil { envMsg := envVarsToMessage(svc.File, svc.Name, env.Vars) fileDesc := dynamicMsgFileDescriptor(envMsg, strings.Replace(svc.FQDN()+"Env", ".", "_", -1)) - if err := r.celRegistry.RegisterFiles(append(r.files, fileDesc)...); err != nil { + if err := r.celRegistry.RegisterFiles(fileDesc); err != nil { return nil, err } envOpts = append(envOpts, cel.Variable("grpc.federation.env", cel.ObjectType(envMsg.FQDN()))) @@ -4487,7 +4493,7 @@ func (r *Resolver) createMessageCELEnv(ctx *context, msg *Message, svcMsgSet map envMsg := r.buildEnvMessage(ctx, msg, svcMsgSet, builder) if envMsg != nil { fileDesc := dynamicMsgFileDescriptor(envMsg, strings.Replace(msg.FQDN()+"Env", ".", "_", -1)) - if err := r.celRegistry.RegisterFiles(append(r.files, fileDesc)...); err != nil { + if err := r.celRegistry.RegisterFiles(fileDesc); err != nil { return nil, err } envOpts = append(envOpts, cel.Variable("grpc.federation.env", cel.ObjectType(envMsg.FQDN()))) @@ -4495,7 +4501,7 @@ func (r *Resolver) createMessageCELEnv(ctx *context, msg *Message, svcMsgSet map svcVarsMsg := r.buildServiceVariablesMessage(ctx, msg, svcMsgSet, builder) if svcVarsMsg != nil { fileDesc := dynamicMsgFileDescriptor(svcVarsMsg, strings.Replace(msg.FQDN()+"Variable", ".", "_", -1)) - if err := r.celRegistry.RegisterFiles(append(r.files, fileDesc)...); err != nil { + if err := r.celRegistry.RegisterFiles(fileDesc); err != nil { return nil, err } envOpts = append(envOpts, cel.Variable("grpc.federation.var", cel.ObjectType(svcVarsMsg.FQDN()))) @@ -4518,7 +4524,7 @@ func (r *Resolver) createCELEnv(envOpts ...cel.EnvOption) (*cel.Env, error) { cel.ASTValidators(grpcfedcel.NewASTValidators()...), cel.Variable(federation.ContextVariableName, cel.ObjectType(federation.ContextTypeName)), }...) - envOpts = append(envOpts, r.enumAccessors()...) + envOpts = append(envOpts, r.enumAccessors...) envOpts = append(envOpts, r.enumOperators()...) for _, plugin := range r.celPluginMap { envOpts = append(envOpts, cel.Lib(plugin)) @@ -4719,10 +4725,9 @@ func (r *Resolver) buildServiceVariablesMessage(ctx *context, msg *Message, svcM return svcVarsToMessage(msg.File, msg.Name, svcVars) } -func (r *Resolver) enumAccessors() []cel.EnvOption { - var ret []cel.EnvOption +func (r *Resolver) resolveEnumAccessors() { for _, enum := range r.cachedEnumMap { - ret = append(ret, + r.enumAccessors = append(r.enumAccessors, []cel.EnvOption{ cel.Function( fmt.Sprintf("%s.name", enum.FQDN()), cel.Overload(fmt.Sprintf("%s_name_int_string", enum.FQDN()), []*cel.Type{cel.IntType}, cel.StringType, @@ -4750,12 +4755,11 @@ func (r *Resolver) enumAccessors() []cel.EnvOption { cel.BinaryBinding(func(enumValue, key ref.Val) ref.Val { return nil }), ), ), - ) + }...) for _, value := range enum.Values { r.cachedEnumValueMap[value.FQDN()] = value } } - return ret } // enumOperators an enum may be treated as an `opaque` or as an `int`.