From b0a7984a507c8cce05382553f85566ac80c74c75 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Wed, 5 Feb 2025 13:24:48 -0800 Subject: [PATCH] Prototype: Refactor scope handling to allow wiring in the runtime representation of annotations. PiperOrigin-RevId: 723628841 --- checker/BUILD | 1 + checker/checker_options.h | 8 + checker/internal/BUILD | 2 + checker/internal/type_check_env.cc | 13 + checker/internal/type_check_env.h | 11 +- checker/internal/type_checker_builder_impl.cc | 11 + checker/internal/type_checker_builder_impl.h | 1 + checker/internal/type_checker_impl.cc | 152 ++++++- checker/type_check_issue.cc | 18 +- checker/type_check_issue.h | 5 +- checker/type_checker_builder.h | 3 + checker/validation_result.h | 10 + common/decl.h | 61 +++ compiler/BUILD | 2 + compiler/compiler_factory_test.cc | 269 +++++++++++ eval/compiler/BUILD | 2 + eval/compiler/flat_expr_builder.cc | 427 ++++++++++++------ eval/eval/BUILD | 2 + eval/eval/evaluator_core.h | 18 +- eval/public/cel_options.cc | 1 + eval/public/cel_options.h | 6 + internal/BUILD | 12 + internal/annotations.cc | 80 ++++ internal/annotations.h | 46 ++ parser/BUILD | 1 + parser/options.h | 25 + parser/parser.cc | 248 +++++++++- parser/parser_test.cc | 257 ++++++++++- runtime/runtime_options.h | 19 + 29 files changed, 1523 insertions(+), 188 deletions(-) create mode 100644 internal/annotations.cc create mode 100644 internal/annotations.h diff --git a/checker/BUILD b/checker/BUILD index a8ebbb653..a5fac153f 100644 --- a/checker/BUILD +++ b/checker/BUILD @@ -53,6 +53,7 @@ cc_library( "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], ) diff --git a/checker/checker_options.h b/checker/checker_options.h index 5101281a6..ebab6863b 100644 --- a/checker/checker_options.h +++ b/checker/checker_options.h @@ -17,6 +17,8 @@ namespace cel { +enum class CheckerAnnotationSupport { kStrip, kRetain, kCheck }; + // Options for enabling core type checker features. struct CheckerOptions { // Enable overloads for numeric comparisons across types. @@ -55,6 +57,12 @@ struct CheckerOptions { // If exceeded, the checker will stop processing the ast and return // the current set of issues. int max_error_issues = 20; + + // Annotation support level. + // + // Default behavior is to strip annotations. + CheckerAnnotationSupport annotation_support = + CheckerAnnotationSupport::kStrip; }; } // namespace cel diff --git a/checker/internal/BUILD b/checker/internal/BUILD index 336106073..9aa0276c1 100644 --- a/checker/internal/BUILD +++ b/checker/internal/BUILD @@ -137,6 +137,7 @@ cc_library( "//common:source", "//common:type", "//common:type_kind", + "//internal:annotations", "//internal:status_macros", "//parser:macro", "@com_google_absl//absl/algorithm:container", @@ -144,6 +145,7 @@ cc_library( "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/checker/internal/type_check_env.cc b/checker/internal/type_check_env.cc index 1ac9bd618..b38a718ac 100644 --- a/checker/internal/type_check_env.cc +++ b/checker/internal/type_check_env.cc @@ -45,6 +45,19 @@ absl::Nullable TypeCheckEnv::LookupVariable( return nullptr; } +absl::Nullable TypeCheckEnv::LookupAnnotation( + absl::string_view name) const { + const TypeCheckEnv* scope = this; + while (scope != nullptr) { + if (auto it = scope->annotations_.find(name); + it != scope->annotations_.end()) { + return &it->second; + } + scope = scope->parent_; + } + return nullptr; +} + absl::Nullable TypeCheckEnv::LookupFunction( absl::string_view name) const { const TypeCheckEnv* scope = this; diff --git a/checker/internal/type_check_env.h b/checker/internal/type_check_env.h index f42a205a9..fa63f3237 100644 --- a/checker/internal/type_check_env.h +++ b/checker/internal/type_check_env.h @@ -84,10 +84,6 @@ class VariableScope { // // This class is thread-compatible. class TypeCheckEnv { - private: - using VariableDeclPtr = absl::Nonnull; - using FunctionDeclPtr = absl::Nonnull; - public: explicit TypeCheckEnv( absl::Nonnull> @@ -130,6 +126,10 @@ class TypeCheckEnv { return variables_.insert({decl.name(), std::move(decl)}).second; } + bool InsertAnnotationIfAbsent(AnnotationDecl decl) { + return annotations_.insert({decl.name(), std::move(decl)}).second; + } + const absl::flat_hash_map& functions() const { return functions_; } @@ -158,6 +158,8 @@ class TypeCheckEnv { absl::string_view name) const; absl::Nullable LookupFunction( absl::string_view name) const; + absl::Nullable LookupAnnotation( + absl::string_view name) const; absl::StatusOr> LookupTypeName( absl::string_view name) const; @@ -195,6 +197,7 @@ class TypeCheckEnv { // Maps fully qualified names to declarations. absl::flat_hash_map variables_; absl::flat_hash_map functions_; + absl::flat_hash_map annotations_; // Type providers for custom types. std::vector> type_providers_; diff --git a/checker/internal/type_checker_builder_impl.cc b/checker/internal/type_checker_builder_impl.cc index 4897205a4..0a07d83f0 100644 --- a/checker/internal/type_checker_builder_impl.cc +++ b/checker/internal/type_checker_builder_impl.cc @@ -205,6 +205,17 @@ absl::Status TypeCheckerBuilderImpl::MergeFunction(const FunctionDecl& decl) { return absl::OkStatus(); } +absl::Status TypeCheckerBuilderImpl::AddAnnotation(const AnnotationDecl& decl) { + if (decl.name().empty()) { + return absl::InvalidArgumentError("annotation name must not be empty"); + } + if (!env_.InsertAnnotationIfAbsent(decl)) { + return absl::AlreadyExistsError( + absl::StrCat("annotation '", decl.name(), "' already exists")); + } + return absl::OkStatus(); +} + void TypeCheckerBuilderImpl::AddTypeProvider( std::unique_ptr provider) { env_.AddTypeProvider(std::move(provider)); diff --git a/checker/internal/type_checker_builder_impl.h b/checker/internal/type_checker_builder_impl.h index c9028f90b..99b4612fa 100644 --- a/checker/internal/type_checker_builder_impl.h +++ b/checker/internal/type_checker_builder_impl.h @@ -60,6 +60,7 @@ class TypeCheckerBuilderImpl : public TypeCheckerBuilder { absl::Status AddVariable(const VariableDecl& decl) override; absl::Status AddContextDeclaration(absl::string_view type) override; absl::Status AddFunction(const FunctionDecl& decl) override; + absl::Status AddAnnotation(const AnnotationDecl& decl) override; void SetExpectedType(const Type& type) override; diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index f5c8c481b..c0bb18c81 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -51,6 +51,7 @@ #include "common/source.h" #include "common/type.h" #include "common/type_kind.h" +#include "internal/annotations.h" #include "internal/status_macros.h" #include "google/protobuf/arena.h" @@ -68,6 +69,10 @@ std::string FormatCandidate(absl::Span qualifiers) { return absl::StrJoin(qualifiers, "."); } +static const TraversalOptions kTraversalOptions = { + /*.use_comprehension_callbacks=*/true, +}; + SourceLocation ComputeSourceLocation(const AstImpl& ast, int64_t expr_id) { const auto& source_info = ast.source_info(); auto iter = source_info.positions().find(expr_id); @@ -235,6 +240,8 @@ absl::StatusOr FlattenType(const Type& type) { } } +using AnnotationMap = internal::AnnotationMap; + class ResolveVisitor : public AstVisitorBase { public: struct FunctionResolution { @@ -247,13 +254,17 @@ class ResolveVisitor : public AstVisitorBase { const TypeCheckEnv& env, const AstImpl& ast, TypeInferenceContext& inference_context, std::vector& issues, + AnnotationMap& annotations, + cel::CheckerAnnotationSupport annotation_support, absl::Nonnull arena) : container_(container), + annotation_support_(annotation_support), namespace_generator_(std::move(namespace_generator)), env_(&env), inference_context_(&inference_context), issues_(&issues), ast_(&ast), + annotations_(&annotations), root_scope_(env.MakeVariableScope()), arena_(arena), current_scope_(&root_scope_) {} @@ -265,6 +276,43 @@ class ResolveVisitor : public AstVisitorBase { return; } expr_stack_.pop_back(); + if (!status_.ok()) { + return; + } + if (annotation_support_ != CheckerAnnotationSupport::kCheck) { + return; + } + auto annotations = annotations_->find(expr.id()); + if (annotations == annotations_->end()) { + return; + } + if (annotation_context_.has_value()) { + issues_->push_back( + TypeCheckIssue::CreateError(ComputeSourceLocation(*ast_, expr.id()), + "Nested annotations are not supported.")); + return; + } + auto annotation_scope = current_scope_->MakeNestedScope(); + VariableScope* annotation_scope_ptr = annotation_scope.get(); + // bit of a misuse, but annotation scope is largely the same as for + // comprehensions. + comprehension_vars_.push_back(std::move(annotation_scope)); + + annotation_context_ = {current_scope_}; + current_scope_ = annotation_scope_ptr; + Type annotated_expr_type = GetDeducedType(&expr); + annotation_scope_ptr->InsertVariableIfAbsent( + MakeVariableDecl("cel.annotated_value", annotated_expr_type)); + + // Note: this does not need to happen now during the main traversal, but + // it's a easier to reason about for me. It's equally valid to just record + // the relevant annotations and do a separate check pass later. + for (const auto& annotation : annotations->second) { + CheckAnnotation(annotation, expr, annotated_expr_type); + } + + current_scope_ = annotation_context_->parent; + annotation_context_.reset(); } void PostVisitConst(const Expr& expr, const Constant& constant) override; @@ -341,6 +389,10 @@ class ResolveVisitor : public AstVisitorBase { const FunctionDecl* decl; }; + struct AnnotationContext { + const VariableScope* parent; + }; + void ResolveSimpleIdentifier(const Expr& expr, absl::string_view name); void ResolveQualifiedIdentifier(const Expr& expr, @@ -459,12 +511,17 @@ class ResolveVisitor : public AstVisitorBase { return DynType(); } + void CheckAnnotation(const internal::AnnotationRep& annotation_expr, + const Expr& annotated_expr, const Type& annotated_type); + absl::string_view container_; + CheckerAnnotationSupport annotation_support_; NamespaceGenerator namespace_generator_; absl::Nonnull env_; absl::Nonnull inference_context_; absl::Nonnull*> issues_; absl::Nonnull ast_; + absl::Nonnull annotations_; VariableScope root_scope_; absl::Nonnull arena_; @@ -479,6 +536,7 @@ class ResolveVisitor : public AstVisitorBase { absl::flat_hash_set deferred_select_operations_; std::vector> comprehension_vars_; std::vector comprehension_scopes_; + absl::optional annotation_context_; absl::Status status_; int error_count_ = 0; @@ -535,6 +593,64 @@ void ResolveVisitor::PostVisitIdent(const Expr& expr, const IdentExpr& ident) { } } +void ResolveVisitor::CheckAnnotation(const internal::AnnotationRep& annotation, + const Expr& annotated_expr, + const Type& annotated_type) { + const auto* annotation_decl = env_->LookupAnnotation(annotation.name); + if (annotation_decl == nullptr) { + ReportIssue(TypeCheckIssue::CreateError( + ComputeSourceLocation(*ast_, annotated_expr.id()), + absl::StrCat("undefined annotation '", annotation.name, "'"))); + return; + } + + // Checking if assignable to Dyn may influence the type inference so skip + // here. + if (!annotation_decl->applicable_type().IsDyn()) { + if (!inference_context_->IsAssignable(annotated_type, + annotation_decl->applicable_type())) { + ReportIssue(TypeCheckIssue::CreateError( + ComputeSourceLocation(*ast_, annotated_expr.id()), + absl::StrCat( + "annotation '", annotation.name, "' is not applicable to type '", + inference_context_->FinalizeType(annotated_type).DebugString(), + "'"))); + return; + } + } + + if (annotation.inspect_only) { + // Nothing to do -- the value expression is not intended to be evaluated. + // Examples are for things like a pointer to another file if the + // subexpression is inlined from somewhere else. + return; + } + + // TODO - re-entrant traversal bypasses the complexity limits. + AstTraverse(*annotation.value_expr, *this, kTraversalOptions); + + if (!status_.ok()) { + return; + } + + Type value_expression_type = GetDeducedType(annotation.value_expr); + + if (!annotation_decl->expected_type().IsDyn()) { + if (!inference_context_->IsAssignable(value_expression_type, + annotation_decl->expected_type())) { + ReportIssue(TypeCheckIssue::CreateError( + ComputeSourceLocation(*ast_, annotated_expr.id()), + absl::StrCat("annotation '", annotation.name, + "' value expression type '", + inference_context_->FinalizeType(value_expression_type) + .DebugString(), + "' is not assignable to '", + annotation_decl->expected_type().DebugString(), "'"))); + return; + } + } +} + void ResolveVisitor::PostVisitConst(const Expr& expr, const Constant& constant) { switch (constant.kind().index()) { @@ -1276,13 +1392,28 @@ absl::StatusOr TypeCheckerImpl::Check( TypeInferenceContext type_inference_context( &type_arena, options_.enable_legacy_null_assignment); + + internal::AnnotationMap annotation_exprs; + Expr* root = &ast_impl.root_expr(); + if (ast_impl.root_expr().has_call_expr() && + ast_impl.root_expr().call_expr().function() == "cel.@annotated" && + ast_impl.root_expr().call_expr().args().size() == 2) { + if (options_.annotation_support == CheckerAnnotationSupport::kStrip) { + ast_impl.root_expr() = + std::move(ast_impl.root_expr().mutable_call_expr().mutable_args()[0]); + root = &ast_impl.root_expr(); + } else { + annotation_exprs = internal::BuildAnnotationMap(ast_impl); + root = &ast_impl.root_expr().mutable_call_expr().mutable_args()[0]; + } + } + ResolveVisitor visitor(env_.container(), std::move(generator), env_, ast_impl, - type_inference_context, issues, &type_arena); + type_inference_context, issues, annotation_exprs, + options_.annotation_support, &type_arena); - TraversalOptions opts; - opts.use_comprehension_callbacks = true; bool error_limit_reached = false; - auto traversal = AstTraversal::Create(ast_impl.root_expr(), opts); + auto traversal = AstTraversal::Create(*root, kTraversalOptions); for (int step = 0; step < options_.max_expression_node_count * 2; ++step) { bool has_next = traversal.Step(visitor); @@ -1300,7 +1431,7 @@ absl::StatusOr TypeCheckerImpl::Check( if (!traversal.IsDone() && !error_limit_reached) { return absl::InvalidArgumentError( - absl::StrCat("Maximum expression node count exceeded: ", + absl::StrCat("maximum expression node count exceeded: ", options_.max_expression_node_count)); } @@ -1309,7 +1440,7 @@ absl::StatusOr TypeCheckerImpl::Check( {}, absl::StrCat("maximum number of ERROR issues exceeded: ", options_.max_error_issues))); } else if (env_.expected_type().has_value()) { - visitor.AssertExpectedType(ast_impl.root_expr(), *env_.expected_type()); + visitor.AssertExpectedType(*root, *env_.expected_type()); } // If any issues are errors, return without an AST. @@ -1324,7 +1455,14 @@ absl::StatusOr TypeCheckerImpl::Check( // been invalidated by other updates. ResolveRewriter rewriter(visitor, type_inference_context, options_, ast_impl.reference_map(), ast_impl.type_map()); - AstRewrite(ast_impl.root_expr(), rewriter); + AstRewrite(*root, rewriter); + if (options_.annotation_support == CheckerAnnotationSupport::kCheck) { + for (auto& annotations : annotation_exprs) { + for (auto& annotation : annotations.second) { + AstRewrite(*annotation.value_expr, rewriter); + } + } + } CEL_RETURN_IF_ERROR(rewriter.status()); diff --git a/checker/type_check_issue.cc b/checker/type_check_issue.cc index 1f32ee54e..dfd1474da 100644 --- a/checker/type_check_issue.cc +++ b/checker/type_check_issue.cc @@ -42,15 +42,21 @@ absl::string_view SeverityString(TypeCheckIssue::Severity severity) { } // namespace -std::string TypeCheckIssue::ToDisplayString(const Source& source) const { +std::string TypeCheckIssue::ToDisplayString(const Source* source) const { int column = location_.column; // convert to 1-based if it's in range. int display_column = column >= 0 ? column + 1 : column; - return absl::StrCat( - absl::StrFormat("%s: %s:%d:%d: %s", SeverityString(severity_), - source.description(), location_.line, display_column, - message_), - source.DisplayErrorLocation(location_)); + if (source) { + return absl::StrCat( + absl::StrFormat("%s: %s:%d:%d: %s", SeverityString(severity_), + source->description(), location_.line, display_column, + message_), + source->DisplayErrorLocation(location_)); + } else { + return absl::StrFormat("%s: %s:%d:%d: %s", SeverityString(severity_), + "", location_.line, display_column, + message_); + } } } // namespace cel diff --git a/checker/type_check_issue.h b/checker/type_check_issue.h index d58f39658..2a050e6b9 100644 --- a/checker/type_check_issue.h +++ b/checker/type_check_issue.h @@ -48,7 +48,10 @@ class TypeCheckIssue { } // Format the issue highlighting the source position. - std::string ToDisplayString(const Source& source) const; + std::string ToDisplayString(const Source& source) const { + return ToDisplayString(&source); + } + std::string ToDisplayString(const Source* source) const; absl::string_view message() const { return message_; } Severity severity() const { return severity_; } diff --git a/checker/type_checker_builder.h b/checker/type_checker_builder.h index 21f3c35a5..63a7bc79c 100644 --- a/checker/type_checker_builder.h +++ b/checker/type_checker_builder.h @@ -73,6 +73,9 @@ class TypeCheckerBuilder { // with the resulting TypeChecker. virtual absl::Status AddFunction(const FunctionDecl& decl) = 0; + // Registers an annotation that may be referenced in the expression. + virtual absl::Status AddAnnotation(const AnnotationDecl& decl) = 0; + // Sets the expected type for checked expressions. // // Validation will fail with an ERROR level issue if the deduced type of the diff --git a/checker/validation_result.h b/checker/validation_result.h index c5ed50b35..06cc93000 100644 --- a/checker/validation_result.h +++ b/checker/validation_result.h @@ -16,12 +16,14 @@ #define THIRD_PARTY_CEL_CPP_CHECKER_VALIDATION_RESULT_H_ #include +#include #include #include #include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "checker/type_check_issue.h" #include "common/ast.h" @@ -68,6 +70,14 @@ class ValidationResult { return std::move(source_); } + std::string FormatError() const { + std::string out; + for (const auto& issue : issues_) { + absl::StrAppend(&out, issue.ToDisplayString(source_.get()), "\n"); + } + return out; + } + private: absl::Nullable> ast_; std::vector issues_; diff --git a/common/decl.h b/common/decl.h index d2ceaca19..ccf7fce13 100644 --- a/common/decl.h +++ b/common/decl.h @@ -108,6 +108,67 @@ class VariableDecl final { absl::optional value_; }; +// `AnnotationDecl` represents a declaration for a Annotation, composed of its +// name and applicable expressions, and optionally an expected value type. +class AnnotationDecl final { + public: + AnnotationDecl() = default; + AnnotationDecl(const AnnotationDecl&) = default; + AnnotationDecl(AnnotationDecl&&) = default; + AnnotationDecl& operator=(const AnnotationDecl&) = default; + AnnotationDecl& operator=(AnnotationDecl&&) = default; + + const std::string& name() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return name_; + } + + void set_name(std::string name) { name_ = std::move(name); } + + void set_name(absl::string_view name) { + name_.assign(name.data(), name.size()); + } + + void set_name(const char* name) { set_name(absl::NullSafeStringView(name)); } + + std::string release_name() { + std::string released; + released.swap(name_); + return released; + } + + const Type& applicable_type() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return applicable_type_; + } + + Type& mutable_applicable_type() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return applicable_type_; + } + + void set_type(Type type) { mutable_applicable_type() = std::move(type); } + + bool inspect_only() const { return inspect_only_; } + + void set_inspect_only(bool inspect_only) { inspect_only_ = inspect_only; } + + const Type& expected_type() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return expected_type_; + } + + Type& mutable_expected_type() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return expected_type_; + } + + void set_expected_type(Type type) { + mutable_expected_type() = std::move(type); + } + + private: + std::string name_; + Type applicable_type_ = DynType{}; + bool inspect_only_ = false; + Type expected_type_ = DynType{}; +}; + inline VariableDecl MakeVariableDecl(std::string name, Type type) { VariableDecl variable_decl; variable_decl.set_name(std::move(name)); diff --git a/compiler/BUILD b/compiler/BUILD index 22894ee78..e6c7240c1 100644 --- a/compiler/BUILD +++ b/compiler/BUILD @@ -61,9 +61,11 @@ cc_test( deps = [ ":compiler", ":compiler_factory", + "//checker:checker_options", "//checker:optional", "//checker:standard_library", "//checker:type_check_issue", + "//checker:type_checker_builder", "//checker:validation_result", "//common:decl", "//common:type", diff --git a/compiler/compiler_factory_test.cc b/compiler/compiler_factory_test.cc index 9d6c663b3..2b2c28cc1 100644 --- a/compiler/compiler_factory_test.cc +++ b/compiler/compiler_factory_test.cc @@ -19,9 +19,11 @@ #include "absl/status/status.h" #include "absl/status/status_matchers.h" +#include "checker/checker_options.h" #include "checker/optional.h" #include "checker/standard_library.h" #include "checker/type_check_issue.h" +#include "checker/type_checker_builder.h" #include "checker/validation_result.h" #include "common/decl.h" #include "common/type.h" @@ -31,6 +33,7 @@ #include "parser/macro.h" #include "parser/parser_interface.h" #include "testutil/baseline_tests.h" +#include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" namespace cel { @@ -113,6 +116,272 @@ TEST(CompilerFactoryTest, Works) { )~bool^logical_and)"); } +TEST(CompilerFactoryTest, AnnotationSupport) { + CompilerOptions options; + options.parser_options.enable_annotations = true; + options.parser_options.enable_hidden_accumulator_var = true; + options.checker_options.annotation_support = CheckerAnnotationSupport::kCheck; + ASSERT_OK_AND_ASSIGN( + auto builder, + NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool(), + options)); + + absl::Status s; + s.Update(builder->AddLibrary(StandardCheckerLibrary())); + s.Update(builder->AddLibrary( + CompilerLibrary("test", [](TypeCheckerBuilder& builder) -> absl::Status { + absl::Status s; + AnnotationDecl decl; + decl.set_name("Describe"); + decl.set_expected_type(StringType()); + s.Update(builder.AddAnnotation(std::move(decl))); + s.Update(builder.AddVariable(MakeVariableDecl("foo", MapType()))); + s.Update(builder.AddVariable(MakeVariableDecl("bar", StringType()))); + + return s; + }))); + + ASSERT_THAT(s, IsOk()); + ASSERT_OK_AND_ASSIGN(auto compiler, std::move(*builder).Build()); + + ASSERT_OK_AND_ASSIGN(ValidationResult result, compiler->Compile(R"cel( + cel.annotate( + ['a', 'b', 'c'] in foo, + cel.Annotation{ + name: "Describe", + value: "foo " + (cel.annotated_value ? "contains" : "does not contain") + + " something interesting" + }) || + cel.annotate( + ['d', 'e', 'f'].exists(x, x.endsWith(bar)), + cel.Annotation{ + name: "Describe", + value: "bar " + + (cel.annotated_value ? "is" : "is not" ) + + "an interesting suffix" + }) + )cel")); + + ASSERT_TRUE(result.IsValid()) << result.FormatError(); + + EXPECT_EQ(FormatBaselineAst(*result.GetAst()), + R"(cel.@annotated( + _||_( + @in( + [ + "a"~string, + "b"~string, + "c"~string + ]~list(string), + foo~map(dyn, dyn)^foo + )~bool^in_map, + __comprehension__( + // Variable + x, + // Target + [ + "d"~string, + "e"~string, + "f"~string + ]~list(string), + // Accumulator + @result, + // Init + false~bool, + // LoopCondition + @not_strictly_false( + !_( + @result~bool^@result + )~bool^logical_not + )~bool^not_strictly_false, + // LoopStep + _||_( + @result~bool^@result, + x~string^x.endsWith( + bar~string^bar + )~bool^ends_with_string + )~bool^logical_or, + // Result + @result~bool^@result)~bool + )~bool^logical_or, + { + 7:[ + cel.Annotation{ + name:"Describe", + value:_+_( + _+_( + "foo "~string, + _?_:_( + cel.annotated_value~bool^cel.annotated_value, + "contains"~string, + "does not contain"~string + )~string^conditional + )~string^add_string, + " something interesting"~string + )~string^add_string + } + ], + 40:[ + cel.Annotation{ + name:"Describe", + value:_+_( + _+_( + "bar "~string, + _?_:_( + cel.annotated_value~bool^cel.annotated_value, + "is"~string, + "is not"~string + )~string^conditional + )~string^add_string, + "an interesting suffix"~string + )~string^add_string + } + ] + } +))"); +} + +TEST(CompilerFactoryTest, AnnotationScopingRules) { + CompilerOptions options; + options.parser_options.enable_annotations = true; + options.parser_options.enable_hidden_accumulator_var = true; + options.checker_options.annotation_support = CheckerAnnotationSupport::kCheck; + ASSERT_OK_AND_ASSIGN( + auto builder, + NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool(), + options)); + google::protobuf::Arena arena; + Type map_list_string = + MapType(&arena, StringType(), ListType(&arena, StringType())); + absl::Status s; + s.Update(builder->AddLibrary(StandardCheckerLibrary())); + s.Update(builder->AddLibrary( + CompilerLibrary("test", [=](TypeCheckerBuilder& builder) -> absl::Status { + absl::Status s; + AnnotationDecl decl; + decl.set_name("Describe"); + decl.set_expected_type(StringType()); + s.Update(builder.AddAnnotation(std::move(decl))); + s.Update(builder.AddVariable( + MakeVariableDecl("memberships", map_list_string))); + s.Update(builder.AddVariable(MakeVariableDecl("user", StringType()))); + + return s; + }))); + + ASSERT_THAT(s, IsOk()); + ASSERT_OK_AND_ASSIGN(auto compiler, std::move(*builder).Build()); + + ASSERT_OK_AND_ASSIGN(ValidationResult result, compiler->Compile(R"cel( + cel.annotate( + ['g1', 'g2', 'g3'].all(g, + cel.annotate( + g in memberships[user], + cel.Annotation{ + name: "Describe", + value: "user '" + user + "' " + + (cel.annotated_value ? "is" : "is not") + + " a member of " + g + } + ) + ), + cel.Annotation{ + name: "Describe", + value: + "user '" + user + "' " + + (cel.annotated_value ? "is" : "is not") + + " a member of all required groups" + } + ) + )cel")); + + ASSERT_TRUE(result.IsValid()) << result.FormatError(); + + std::string adorned_ast = FormatBaselineAst(*result.GetAst()); + EXPECT_EQ(adorned_ast, + R"(cel.@annotated( + __comprehension__( + // Variable + g, + // Target + [ + "g1"~string, + "g2"~string, + "g3"~string + ]~list(string), + // Accumulator + @result, + // Init + true~bool, + // LoopCondition + @not_strictly_false( + @result~bool^@result + )~bool^not_strictly_false, + // LoopStep + _&&_( + @result~bool^@result, + @in( + g~string^g, + _[_]( + memberships~map(string, list(string))^memberships, + user~string^user + )~list(string)^index_map + )~bool^in_list + )~bool^logical_and, + // Result + @result~bool^@result)~bool, + { + 12:[ + cel.Annotation{ + name:"Describe", + value:_+_( + _+_( + _+_( + _+_( + _+_( + "user '"~string, + user~string^user + )~string^add_string, + "' "~string + )~string^add_string, + _?_:_( + cel.annotated_value~bool^cel.annotated_value, + "is"~string, + "is not"~string + )~string^conditional + )~string^add_string, + " a member of "~string + )~string^add_string, + g~string^g + )~string^add_string + } + ], + 41:[ + cel.Annotation{ + name:"Describe", + value:_+_( + _+_( + _+_( + _+_( + "user '"~string, + user~string^user + )~string^add_string, + "' "~string + )~string^add_string, + _?_:_( + cel.annotated_value~bool^cel.annotated_value, + "is"~string, + "is not"~string + )~string^conditional + )~string^add_string, + " a member of all required groups"~string + )~string^add_string + } + ] + } +))"); +} + TEST(CompilerFactoryTest, ParserLibrary) { ASSERT_OK_AND_ASSIGN( auto builder, diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index 8973cd67e..3012d325b 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -125,6 +125,7 @@ cc_library( "//eval/eval:shadowable_value_step", "//eval/eval:ternary_step", "//eval/eval:trace_step", + "//internal:annotations", "//internal:status_macros", "//runtime:function_registry", "//runtime:runtime_issue", @@ -140,6 +141,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/functional:overload", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index a657e9dae..984ed5528 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -29,11 +29,13 @@ #include "absl/algorithm/container.h" #include "absl/base/attributes.h" +#include "absl/base/nullability.h" #include "absl/base/optimization.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/node_hash_map.h" #include "absl/functional/any_invocable.h" +#include "absl/functional/overload.h" #include "absl/log/absl_check.h" #include "absl/log/check.h" #include "absl/status/status.h" @@ -82,6 +84,7 @@ #include "eval/eval/shadowable_value_step.h" #include "eval/eval/ternary_step.h" #include "eval/eval/trace_step.h" +#include "internal/annotations.h" #include "internal/status_macros.h" #include "runtime/internal/convert_constant.h" #include "runtime/internal/issue_collector.h" @@ -485,13 +488,13 @@ class FlatExprVisitor : public cel::AstVisitor { FlatExprVisitor( const Resolver& resolver, const cel::RuntimeOptions& options, std::vector> program_optimizers, - const absl::flat_hash_map& - reference_map, + const cel::internal::AnnotationMap& annotation_map, ValueManager& value_factory, IssueCollector& issue_collector, ProgramBuilder& program_builder, PlannerContext& extension_context, bool enable_optional_types) : resolver_(resolver), value_factory_(value_factory), + annotations_(annotation_map), progress_status_(absl::OkStatus()), resolved_select_expr_(nullptr), options_(options), @@ -619,11 +622,13 @@ class FlatExprVisitor : public cel::AstVisitor { program_builder_.ExitSubexpression(&expr); - if (!comprehension_stack_.empty() && - comprehension_stack_.back().is_optimizable_bind && - (&comprehension_stack_.back().comprehension->accu_init() == &expr)) { - SetProgressStatusError( - MaybeExtractSubexpression(&expr, comprehension_stack_.back())); + if (!scope_stack_.empty()) { + ScopeRecord& scope = scope_stack_.back(); + if (auto* bind = absl::get_if(&scope.kind); bind != nullptr) { + if (&bind->comprehension->accu_init() == &expr) { + SetProgressStatusError(MaybeExtractSubexpression(&expr, *bind)); + } + } } if (block_.has_value()) { @@ -709,34 +714,43 @@ class FlatExprVisitor : public cel::AstVisitor { } } } - if (!comprehension_stack_.empty()) { - for (int i = comprehension_stack_.size() - 1; i >= 0; i--) { - const ComprehensionStackRecord& record = comprehension_stack_[i]; - if (record.iter_var_in_scope && - record.comprehension->iter_var() == path) { - if (record.is_optimizable_bind) { - SetProgressStatusError(issue_collector_.AddIssue( - RuntimeIssue::CreateWarning(absl::InvalidArgumentError( - "Unexpected iter_var access in trivial comprehension")))); - return {-1, -1}; - } - return {static_cast(record.iter_slot), -1}; + + for (int i = scope_stack_.size() - 1; i >= 0; i--) { + ScopeRecord& scope = scope_stack_[i]; + if (auto* bind = absl::get_if(&scope.kind); bind != nullptr) { + if (bind->iter_var_in_scope && + bind->comprehension->iter_var() == path) { + SetProgressStatusError(issue_collector_.AddIssue( + RuntimeIssue::CreateWarning(absl::InvalidArgumentError( + "Unexpected iter_var access in trivial comprehension")))); + return {-1, -1}; } - if (record.iter_var2_in_scope && - record.comprehension->iter_var2() == path) { - return {static_cast(record.iter2_slot), -1}; + if (bind->accu_var_in_scope && + bind->comprehension->accu_var() == path) { + return {static_cast(bind->slot), bind->subexpression}; } - if (record.accu_var_in_scope && - record.comprehension->accu_var() == path) { - int slot = record.accu_slot; - int subexpression = -1; - if (record.is_optimizable_bind) { - subexpression = record.subexpression; - } - return {slot, subexpression}; + } else if (auto* comprehension_scope = + absl::get_if(&scope.kind); + comprehension_scope != nullptr) { + const auto* comprehension = comprehension_scope->comprehension; + if (comprehension_scope->iter_var_in_scope && + comprehension->iter_var() == path) { + return {static_cast(comprehension_scope->iter_slot), -1}; + } + if (comprehension_scope->iter_var2_in_scope && + comprehension->iter_var2() == path) { + return {static_cast(comprehension_scope->iter2_slot), -1}; + } + if (comprehension_scope->accu_var_in_scope && + comprehension->accu_var() == path) { + return {static_cast(comprehension_scope->accu_slot), -1}; } + } else { + // handle for annotations in follow up CL. + return {-1, -1}; } } + if (absl::StartsWith(path, "@it:") || absl::StartsWith(path, "@it2:") || absl::StartsWith(path, "@ac:")) { // If we see a CSE generated comprehension variable that was not @@ -1356,6 +1370,18 @@ class FlatExprVisitor : public cel::AstVisitor { slot_count = 3; } + // Account the slot such that it is not re-used for lazy evaluation. + // + // With lazy evaluation, the init expressions are effectively inlined at the + // first usage in the critical path (which is unknown at plan time). + // To account for this, we account all slots for lazy evaluation to the + // outermost scope where the lazy evaluation could occur context. + // + // For block, all slots are accounted to the block scope. (The top level + // call expression, generally). + // + // For bind, the slots are accounted to the outermost bind initializer + // scope. if (block_.has_value()) { BlockInfo& block = *block_; if (block.in) { @@ -1363,37 +1389,41 @@ class FlatExprVisitor : public cel::AstVisitor { slot_count = 0; } } - // If this is in the scope of an optimized bind accu-init, account the slots - // to the outermost bind-init scope. - // - // The init expression is effectively inlined at the first usage in the - // critical path (which is unknown at plan time), so the used slots need to - // be dedicated for the entire scope of that bind. - for (ComprehensionStackRecord& record : comprehension_stack_) { - if (record.in_accu_init && record.is_optimizable_bind) { - record.slot_count += slot_count; - slot_count = 0; - break; + + for (ScopeRecord& record : scope_stack_) { + if (auto* bind_scope = std::get_if(&record.kind)) { + if (bind_scope->in_accu_init) { + record.slot_count += slot_count; + slot_count = 0; + break; + } } // If no bind init subexpression, account normally. } - comprehension_stack_.push_back( - {&expr, &comprehension, iter_slot, iter2_slot, accu_slot, slot_count, - /*subexpression=*/-1, - /*.is_optimizable_list_append=*/ - IsOptimizableListAppend(&comprehension, - options_.enable_comprehension_list_append), - /*.is_optimizable_map_insert=*/IsOptimizableMapInsert(&comprehension), - /*.is_optimizable_bind=*/is_bind, - /*.iter_var_in_scope=*/false, - /*.iter_var2_in_scope=*/false, - /*.accu_var_in_scope=*/false, - /*.in_accu_init=*/false, - std::make_unique(this, options_.short_circuiting, - is_bind, iter_slot, iter2_slot, - accu_slot)}); - comprehension_stack_.back().visitor->PreVisit(&expr); + if (is_bind) { + scope_stack_.push_back(ScopeRecord{ + &expr, slot_count, + BindScope{&comprehension, accu_slot, false, false, false, -1, + std::make_unique( + this, options_.short_circuiting, is_bind, accu_slot, + accu_slot, accu_slot)}}); + } else { + scope_stack_.push_back(ScopeRecord{ + &expr, slot_count, + ComprehensionScope{ + &comprehension, iter_slot, iter2_slot, accu_slot, + IsOptimizableListAppend( + &comprehension, options_.enable_comprehension_list_append), + IsOptimizableMapInsert(&comprehension), + /*.iter_var_in_scope=*/false, + /*.iter_var2_in_scope=*/false, + /*.accu_var_in_scope=*/false, + std::make_unique( + this, options_.short_circuiting, is_bind, iter_slot, + iter2_slot, accu_slot)}}); + } + scope_stack_.back().comprehension_visitor()->PreVisit(&expr); } // Invoked after all child nodes are processed. @@ -1404,16 +1434,20 @@ class FlatExprVisitor : public cel::AstVisitor { return; } - ComprehensionStackRecord& record = comprehension_stack_.back(); - if (comprehension_stack_.empty() || - record.comprehension != &comprehension_expr) { + if (scope_stack_.empty()) { + return; + } + + ScopeRecord& scope = scope_stack_.back(); + + if (scope.comprehension() != &comprehension_expr) { return; } - record.visitor->PostVisit(&expr); + scope.comprehension_visitor()->PostVisit(&expr); - index_manager_.ReleaseSlots(record.slot_count); - comprehension_stack_.pop_back(); + index_manager_.ReleaseSlots(scope.slot_count); + scope_stack_.pop_back(); } void PreVisitComprehensionSubexpression( @@ -1424,48 +1458,80 @@ class FlatExprVisitor : public cel::AstVisitor { return; } - if (comprehension_stack_.empty() || - comprehension_stack_.back().comprehension != &compr) { + if (scope_stack_.empty() || scope_stack_.back().comprehension() != &compr) { return; } - ComprehensionStackRecord& record = comprehension_stack_.back(); + ScopeRecord& scope = scope_stack_.back(); - switch (comprehension_arg) { - case cel::ITER_RANGE: { - record.in_accu_init = false; - record.iter_var_in_scope = false; - record.iter_var2_in_scope = false; - record.accu_var_in_scope = false; - break; - } - case cel::ACCU_INIT: { - record.in_accu_init = true; - record.iter_var_in_scope = false; - record.iter_var2_in_scope = false; - record.accu_var_in_scope = false; - break; - } - case cel::LOOP_CONDITION: { - record.in_accu_init = false; - record.iter_var_in_scope = true; - record.iter_var2_in_scope = true; - record.accu_var_in_scope = true; - break; - } - case cel::LOOP_STEP: { - record.in_accu_init = false; - record.iter_var_in_scope = true; - record.iter_var2_in_scope = true; - record.accu_var_in_scope = true; - break; + if (auto* bind_scope = std::get_if(&scope.kind); + bind_scope != nullptr) { + switch (comprehension_arg) { + case cel::ITER_RANGE: { + bind_scope->in_accu_init = false; + bind_scope->iter_var_in_scope = false; + bind_scope->accu_var_in_scope = false; + break; + } + case cel::ACCU_INIT: { + bind_scope->in_accu_init = true; + bind_scope->iter_var_in_scope = false; + bind_scope->accu_var_in_scope = false; + break; + } + case cel::LOOP_CONDITION: { + bind_scope->in_accu_init = false; + bind_scope->iter_var_in_scope = true; + bind_scope->accu_var_in_scope = true; + break; + } + case cel::LOOP_STEP: { + bind_scope->in_accu_init = false; + bind_scope->iter_var_in_scope = true; + bind_scope->accu_var_in_scope = true; + break; + } + case cel::RESULT: { + bind_scope->in_accu_init = false; + bind_scope->iter_var_in_scope = false; + bind_scope->accu_var_in_scope = true; + break; + } } - case cel::RESULT: { - record.in_accu_init = false; - record.iter_var_in_scope = false; - record.iter_var2_in_scope = false; - record.accu_var_in_scope = true; - break; + } else if (auto* comprehension_scope = + std::get_if(&scope.kind); + comprehension_scope != nullptr) { + switch (comprehension_arg) { + case cel::ITER_RANGE: { + comprehension_scope->iter_var_in_scope = false; + comprehension_scope->iter_var2_in_scope = false; + comprehension_scope->accu_var_in_scope = false; + break; + } + case cel::ACCU_INIT: { + comprehension_scope->iter_var_in_scope = false; + comprehension_scope->iter_var2_in_scope = false; + comprehension_scope->accu_var_in_scope = false; + break; + } + case cel::LOOP_CONDITION: { + comprehension_scope->iter_var_in_scope = true; + comprehension_scope->iter_var2_in_scope = true; + comprehension_scope->accu_var_in_scope = true; + break; + } + case cel::LOOP_STEP: { + comprehension_scope->iter_var_in_scope = true; + comprehension_scope->iter_var2_in_scope = true; + comprehension_scope->accu_var_in_scope = true; + break; + } + case cel::RESULT: { + comprehension_scope->iter_var_in_scope = false; + comprehension_scope->iter_var2_in_scope = false; + comprehension_scope->accu_var_in_scope = true; + break; + } } } } @@ -1478,13 +1544,13 @@ class FlatExprVisitor : public cel::AstVisitor { return; } - if (comprehension_stack_.empty() || - comprehension_stack_.back().comprehension != &compr) { + if (scope_stack_.empty() || scope_stack_.back().comprehension() != &compr) { return; } - SetProgressStatusError(comprehension_stack_.back().visitor->PostVisitArg( - comprehension_arg, comprehension_stack_.back().expr)); + SetProgressStatusError( + scope_stack_.back().comprehension_visitor()->PostVisitArg( + comprehension_arg, scope_stack_.back().expr)); } // Invoked after each argument node processed. @@ -1524,24 +1590,26 @@ class FlatExprVisitor : public cel::AstVisitor { } } - if (!comprehension_stack_.empty()) { - const ComprehensionStackRecord& comprehension = - comprehension_stack_.back(); - if (comprehension.is_optimizable_list_append) { - if (&(comprehension.comprehension->accu_init()) == &expr) { - if (options_.max_recursion_depth != 0) { - SetRecursiveStep(CreateDirectMutableListStep(expr.id()), 1); + if (!scope_stack_.empty()) { + if (const ComprehensionScope* scope = + std::get_if(&scope_stack_.back().kind); + scope != nullptr) { + if (scope->is_optimizable_list_append) { + if (&(scope->comprehension->accu_init()) == &expr) { + if (options_.max_recursion_depth != 0) { + SetRecursiveStep(CreateDirectMutableListStep(expr.id()), 1); + return; + } + AddStep(CreateMutableListStep(expr.id())); + return; + } + if (GetOptimizableListAppendOperand(scope->comprehension) == &expr) { return; } - AddStep(CreateMutableListStep(expr.id())); - return; - } - if (GetOptimizableListAppendOperand(comprehension.comprehension) == - &expr) { - return; } } } + absl::optional depth = RecursionEligible(); if (depth.has_value()) { auto deps = ExtractRecursiveDependencies(); @@ -1567,17 +1635,19 @@ class FlatExprVisitor : public cel::AstVisitor { return; } - if (!comprehension_stack_.empty()) { - const ComprehensionStackRecord& comprehension = - comprehension_stack_.back(); - if (comprehension.is_optimizable_map_insert) { - if (&(comprehension.comprehension->accu_init()) == &expr) { - if (options_.max_recursion_depth != 0) { - SetRecursiveStep(CreateDirectMutableMapStep(expr.id()), 1); + if (!scope_stack_.empty()) { + if (const auto* scope = + absl::get_if(&scope_stack_.back().kind); + scope != nullptr) { + if (scope->is_optimizable_map_insert) { + if (&(scope->comprehension->accu_init()) == &expr) { + if (options_.max_recursion_depth != 0) { + SetRecursiveStep(CreateDirectMutableMapStep(expr.id()), 1); + return; + } + AddStep(CreateMutableMapStep(expr.id())); return; } - AddStep(CreateMutableMapStep(expr.id())); - return; } } } @@ -1782,25 +1852,62 @@ class FlatExprVisitor : public cel::AstVisitor { } private: - struct ComprehensionStackRecord { - const cel::ast_internal::Expr* expr; - const cel::ast_internal::Comprehension* comprehension; + struct ComprehensionScope { + const cel::ComprehensionExpr* comprehension; size_t iter_slot; size_t iter2_slot; size_t accu_slot; - size_t slot_count; - // -1 indicates this shouldn't be used. - int subexpression; bool is_optimizable_list_append; bool is_optimizable_map_insert; - bool is_optimizable_bind; bool iter_var_in_scope; bool iter_var2_in_scope; bool accu_var_in_scope; + std::unique_ptr visitor; + }; + + struct BindScope { + const cel::ComprehensionExpr* comprehension; + size_t slot; + bool accu_var_in_scope; + // Only used to check if the bind is malformed. + bool iter_var_in_scope; bool in_accu_init; + int subexpression; std::unique_ptr visitor; }; + struct AnnotationScope { + const cel::ast_internal::Expr* annotated_expr; + size_t slot; + int subexpression; + }; + + struct ScopeRecord { + const cel::ast_internal::Expr* expr; + size_t slot_count; + absl::variant kind; + + absl::Nullable comprehension_visitor() { + if (auto* comp = absl::get_if(&kind); + comp != nullptr) { + return comp->visitor.get(); + } else if (auto* bind = absl::get_if(&kind); bind != nullptr) { + return bind->visitor.get(); + } + return nullptr; + } + + absl::Nullable comprehension() { + if (auto* comp = absl::get_if(&kind); + comp != nullptr) { + return comp->comprehension; + } else if (auto* bind = absl::get_if(&kind); bind != nullptr) { + return bind->comprehension; + } + return nullptr; + } + }; + struct BlockInfo { // True if we are currently visiting the `cel.@block` node or any of its // children. @@ -1838,11 +1945,7 @@ class FlatExprVisitor : public cel::AstVisitor { } absl::Status MaybeExtractSubexpression(const cel::ast_internal::Expr* expr, - ComprehensionStackRecord& record) { - if (!record.is_optimizable_bind) { - return absl::OkStatus(); - } - + BindScope& record) { int index = program_builder_.ExtractSubexpression(expr); if (index == -1) { return absl::InternalError("Failed to extract subexpression"); @@ -1916,6 +2019,7 @@ class FlatExprVisitor : public cel::AstVisitor { const Resolver& resolver_; ValueManager& value_factory_; + const cel::internal::AnnotationMap& annotations_; absl::Status progress_status_; absl::flat_hash_map call_handlers_; @@ -1933,7 +2037,7 @@ class FlatExprVisitor : public cel::AstVisitor { const cel::RuntimeOptions& options_; - std::vector comprehension_stack_; + std::vector scope_stack_; absl::flat_hash_set suppressed_branches_; const cel::ast_internal::Expr* resume_from_suppressed_branch_ = nullptr; std::vector> program_optimizers_; @@ -2064,12 +2168,21 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleListAppend( // Check to see if this is a special case of add that should really be // treated as a list append - if (!comprehension_stack_.empty() && - comprehension_stack_.back().is_optimizable_list_append) { + if (scope_stack_.empty()) { + return CallHandlerResult::kNotIntercepted; + } + + const auto* comprehension_scope = + absl::get_if(&scope_stack_.back().kind); + if (comprehension_scope == nullptr) { + return CallHandlerResult::kNotIntercepted; + } + + if (comprehension_scope->is_optimizable_list_append) { // Already checked that this is an optimizeable comprehension, // check that this is the correct list append node. const cel::ast_internal::Comprehension* comprehension = - comprehension_stack_.back().comprehension; + comprehension_scope->comprehension; const cel::ast_internal::Expr& loop_step = comprehension->loop_step(); // Macro loop_step for a map() will contain a list concat operation: // accu_var + [elem] @@ -2521,6 +2634,10 @@ absl::StatusOr FlatExprBuilder::CreateExpressionImpl( absl::StrCat("Invalid expression container: '", container_, "'")); } + // Preprocess with any transforms. + // + // TODO - will need need to decide on semantics for AST rewriting + // (are annotations visible or not?) for (const std::unique_ptr& transform : ast_transforms_) { CEL_RETURN_IF_ERROR(transform->UpdateAst(extension_context, ast_impl)); } @@ -2534,18 +2651,35 @@ absl::StatusOr FlatExprBuilder::CreateExpressionImpl( } } + const cel::Expr* root = &ast_impl.root_expr(); + cel::internal::AnnotationMap annotation_exprs; + + if (ast_impl.root_expr().has_call_expr() && + ast_impl.root_expr().call_expr().function() == "cel.@annotated" && + ast_impl.root_expr().call_expr().args().size() == 2) { + if (options_.annotation_processing == + cel::AnnotationProcessingOptions::kIgnore) { + ast_impl.root_expr() = + std::move(ast_impl.root_expr().mutable_call_expr().mutable_args()[0]); + root = &ast_impl.root_expr(); + } else { + annotation_exprs = cel::internal::BuildAnnotationMap(ast_impl); + root = &ast_impl.root_expr(); + } + } + // These objects are expected to remain scoped to one build call -- references // to them shouldn't be persisted in any part of the result expression. cel::common_internal::LegacyValueManager value_factory( cel::MemoryManagerRef::ReferenceCounting(), GetTypeProvider()); FlatExprVisitor visitor(resolver, options_, std::move(optimizers), - ast_impl.reference_map(), value_factory, - issue_collector, program_builder, extension_context, + annotation_exprs, value_factory, issue_collector, + program_builder, extension_context, enable_optional_types_); cel::TraversalOptions opts; opts.use_comprehension_callbacks = true; - AstTraverse(ast_impl.root_expr(), visitor, opts); + AstTraverse(*root, visitor, opts); if (!visitor.progress_status().ok()) { return visitor.progress_status(); @@ -2566,7 +2700,8 @@ absl::StatusOr FlatExprBuilder::CreateExpressionImpl( return FlatExpression(std::move(execution_path), std::move(subexpressions), visitor.slot_count(), GetTypeProvider(), options_, - std::move(arena)); + std::move(arena), std::move(ast), + std::move(annotation_exprs)); } const cel::TypeProvider& FlatExprBuilder::GetTypeProvider() const { return use_legacy_type_provider_ diff --git a/eval/eval/BUILD b/eval/eval/BUILD index 48412601b..0d3e0416a 100644 --- a/eval/eval/BUILD +++ b/eval/eval/BUILD @@ -42,10 +42,12 @@ cc_library( ":comprehension_slots", ":evaluator_stack", "//base:data", + "//common:ast", "//common:memory", "//common:native_type", "//common:type", "//common:value", + "//internal:annotations", "//runtime", "//runtime:activation_interface", "//runtime:managed_value_factory", diff --git a/eval/eval/evaluator_core.h b/eval/eval/evaluator_core.h index 931c76651..d419fee57 100644 --- a/eval/eval/evaluator_core.h +++ b/eval/eval/evaluator_core.h @@ -30,6 +30,7 @@ #include "absl/types/optional.h" #include "absl/types/span.h" #include "base/type_provider.h" +#include "common/ast.h" #include "common/memory.h" #include "common/native_type.h" #include "common/type_factory.h" @@ -39,6 +40,7 @@ #include "eval/eval/attribute_utility.h" #include "eval/eval/comprehension_slots.h" #include "eval/eval/evaluator_stack.h" +#include "internal/annotations.h" #include "runtime/activation_interface.h" #include "runtime/managed_value_factory.h" #include "runtime/runtime.h" @@ -383,13 +385,17 @@ class FlatExpression { size_t comprehension_slots_size, const cel::TypeProvider& type_provider, const cel::RuntimeOptions& options, - absl::Nullable> arena = nullptr) + absl::Nullable> arena = nullptr, + std::unique_ptr ast = nullptr, + cel::internal::AnnotationMap annotation_map = {}) : path_(std::move(path)), subexpressions_(std::move(subexpressions)), comprehension_slots_size_(comprehension_slots_size), type_provider_(type_provider), options_(options), - arena_(std::move(arena)) {} + arena_(std::move(arena)), + ast_(std::move(ast)), + annotation_map_(std::move(annotation_map)) {} // Move-only FlatExpression(FlatExpression&&) = default; @@ -426,6 +432,10 @@ class FlatExpression { const cel::TypeProvider& type_provider() const { return type_provider_; } + const cel::internal::AnnotationMap& annotation_map() const { + return annotation_map_; + } + private: ExecutionPath path_; std::vector subexpressions_; @@ -435,6 +445,10 @@ class FlatExpression { // Arena used during planning phase, may hold constant values so should be // kept alive. absl::Nullable> arena_; + // The AST used to generate the expression, if available. + absl::Nullable> ast_; + // The annotation map used to generate the expression, if available. + cel::internal::AnnotationMap annotation_map_; }; } // namespace google::api::expr::runtime diff --git a/eval/public/cel_options.cc b/eval/public/cel_options.cc index e0c8e1a4b..2d66fda1e 100644 --- a/eval/public/cel_options.cc +++ b/eval/public/cel_options.cc @@ -41,6 +41,7 @@ cel::RuntimeOptions ConvertToRuntimeOptions(const InterpreterOptions& options) { options.max_recursion_depth, options.enable_recursive_tracing, options.enable_fast_builtins, + cel::AnnotationProcessingOptions::kIgnore, options.locale}; } diff --git a/eval/public/cel_options.h b/eval/public/cel_options.h index 9b9412eb4..dacfea35f 100644 --- a/eval/public/cel_options.h +++ b/eval/public/cel_options.h @@ -199,6 +199,12 @@ struct InterpreterOptions { // Currently applies to !_, @not_strictly_false, _==_, _!=_, @in bool enable_fast_builtins = true; + // Legacy implementation always ignores annotations. + // + // Setting this has no effect. + cel::AnnotationProcessingOptions annotation_processing = + cel::AnnotationProcessingOptions::kIgnore; + // The locale to use for string formatting. // // Default is en_US. diff --git a/internal/BUILD b/internal/BUILD index 494ac748b..cf302e3ae 100644 --- a/internal/BUILD +++ b/internal/BUILD @@ -724,6 +724,18 @@ cc_test( ], ) +cc_library( + name = "annotations", + srcs = ["annotations.cc"], + hdrs = ["annotations.h"], + deps = [ + "//base/ast_internal:ast_impl", + "//common:expr", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:absl_log", + ], +) + cc_library( name = "protobuf_runtime_version", hdrs = ["protobuf_runtime_version.h"], diff --git a/internal/annotations.cc b/internal/annotations.cc new file mode 100644 index 000000000..1d037f199 --- /dev/null +++ b/internal/annotations.cc @@ -0,0 +1,80 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/annotations.h" + +#include +#include + +#include "absl/log/absl_log.h" +#include "base/ast_internal/ast_impl.h" + +namespace cel::internal { + +using ::cel::ast_internal::AstImpl; + +AnnotationMap BuildAnnotationMap(AstImpl& ast) { + AnnotationMap annotation_exprs; + // Caller validates that this is an annotated expression. + auto& annotation_map = ast.root_expr().mutable_call_expr().mutable_args()[1]; + + if (!annotation_map.has_map_expr()) { + return annotation_exprs; + } + + for (auto& entry : annotation_map.mutable_map_expr().mutable_entries()) { + if (!entry.has_key() || !entry.key().has_const_expr() || + !entry.key().const_expr().has_int_value()) { + continue; + } + int64_t id = entry.key().const_expr().int_value(); + if (!entry.has_value() || !entry.value().has_list_expr() || + entry.value().list_expr().elements().empty()) { + continue; + } + annotation_exprs[id].reserve(entry.value().list_expr().elements().size()); + for (auto& element : + entry.mutable_value().mutable_list_expr().mutable_elements()) { + if (!element.expr().has_struct_expr() || + element.expr().struct_expr().name() != "cel.Annotation") { + continue; + } + + AnnotationRep rep{}; + + for (auto& field : + element.mutable_expr().mutable_struct_expr().mutable_fields()) { + if (field.name() == "name") { + rep.name = field.value().const_expr().string_value(); + } else if (field.name() == "inspect_only") { + rep.inspect_only = field.value().const_expr().bool_value(); + } else if (field.name() == "value") { + rep.value_expr = &field.mutable_value(); + } + } + + if (rep.name.empty() || + (!rep.inspect_only && rep.value_expr == nullptr)) { + ABSL_LOG(WARNING) << "Invalid annotation"; + // TODO - log error. + continue; + } + annotation_exprs[id].push_back(std::move(rep)); + } + } + + return annotation_exprs; +} + +} // namespace cel::internal diff --git a/internal/annotations.h b/internal/annotations.h new file mode 100644 index 000000000..576f15a91 --- /dev/null +++ b/internal/annotations.h @@ -0,0 +1,46 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Utilities for extracting annotations from CEL ASTs. +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_ANNOTATIONS_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_ANNOTATIONS_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "base/ast_internal/ast_impl.h" +#include "common/expr.h" + +namespace cel::internal { + +struct AnnotationRep { + std::string name = ""; + bool inspect_only = false; + Expr* value_expr = nullptr; + // Internal index used to identify the program associated with this + // annotation. Only used by the runtime. + int index = -1; +}; + +// Note: this returns raw ptrs to the AST nodes for each annotation. +// These may be invalidated by any change to the underlying AST. +using AnnotationMap = absl::flat_hash_map>; + +AnnotationMap BuildAnnotationMap(ast_internal::AstImpl& ast); + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_ANNOTATIONS_H_ diff --git a/parser/BUILD b/parser/BUILD index d2815af47..6c0f63bcf 100644 --- a/parser/BUILD +++ b/parser/BUILD @@ -51,6 +51,7 @@ cc_library( "//parser/internal:cel_cc_parser", "@antlr4_runtimes//:cpp", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/functional:overload", diff --git a/parser/options.h b/parser/options.h index ad03102e8..546179bf7 100644 --- a/parser/options.h +++ b/parser/options.h @@ -59,6 +59,31 @@ struct ParserOptions final { // // Limited to field specifiers in select and message creation. bool enable_quoted_identifiers = false; + + // Enables support for the cel.annotate macro. + // + // Annotations are normally injected by higher level CEL tools to provide + // additional metadata about how to interpret or analyze the expression. This + // macro is intended for adding annotations in the source expression, using + // the same internal mechanisms as annotations added by tools. + // + // The macro takes two arguments: + // + // 1. The expression to annotate. + // 2. A list of annotations to apply to the expression. + // + // example: + // cel.annotate(foo.bar in baz, + // [cel.Annotation{name: "com.example.Explain", + // inspect_only: true, + // value: "check if foo.bar is in baz"}] + // ) + // + // Permits the short hand if the annotation has no value: + // cel.annotate(foo.bar in baz, "com.example.MyAnnotation") + // + // The annotation is recorded in the source_info of the parsed expression. + bool enable_annotations = false; }; } // namespace cel diff --git a/parser/parser.cc b/parser/parser.cc index 1437f6613..7c6fadad1 100644 --- a/parser/parser.cc +++ b/parser/parser.cc @@ -32,6 +32,7 @@ #include "cel/expr/syntax.pb.h" #include "absl/base/macros.h" +#include "absl/base/nullability.h" #include "absl/base/optimization.h" #include "absl/container/btree_map.h" #include "absl/container/flat_hash_map.h" @@ -601,23 +602,151 @@ Expr ExpressionBalancer::BalancedTree(int lo, int hi) { return factory_.NewCall(ops_[mid], function_, std::move(arguments)); } +// Lightweight overlay for a registry. +// Adds stateful macros that are relevant per Parse call. +class AugmentedMacroRegistry { + public: + explicit AugmentedMacroRegistry(const cel::MacroRegistry& registry) + : base_(registry) {} + + cel::MacroRegistry& overlay() { return overlay_; } + + absl::optional FindMacro(absl::string_view name, size_t arg_count, + bool receiver_style) const; + + private: + const cel::MacroRegistry& base_; + cel::MacroRegistry overlay_; +}; + +absl::optional AugmentedMacroRegistry::FindMacro( + absl::string_view name, size_t arg_count, bool receiver_style) const { + auto result = overlay_.FindMacro(name, arg_count, receiver_style); + if (result.has_value()) { + return result; + } + + return base_.FindMacro(name, arg_count, receiver_style); +} + +bool IsSupportedAnnotation(const Expr& e) { + if (e.has_const_expr() && e.const_expr().has_string_value()) { + return true; + } else if (e.has_struct_expr() && + e.struct_expr().name() == "cel.Annotation") { + for (const auto& field : e.struct_expr().fields()) { + if (field.name() != "name" && field.name() != "inspect_only" && + field.name() != "value") { + return false; + } + } + return true; + } + return false; +} + +class AnnotationCollector { + private: + struct AnnotationRep { + Expr expr; + }; + + struct MacroImpl { + absl::Nonnull parent; + + // Record a single annotation. Returns a non-empty optional if + // an error is encountered. + absl::optional RecordAnnotation(cel::MacroExprFactory& mef, + int64_t id, Expr e) const; + + // MacroExpander for "cel.annotate" + absl::optional operator()(cel::MacroExprFactory& mef, Expr& target, + absl::Span args) const; + }; + + void Add(int64_t annotated_expr, Expr value); + + public: + const absl::btree_map>& annotations() { + return annotations_; + } + + absl::btree_map> consume_annotations() { + using std::swap; + absl::btree_map> result; + swap(result, annotations_); + return result; + } + + Macro MakeAnnotationImpl() { + auto impl = Macro::Receiver("annotate", 2, MacroImpl{this}); + ABSL_CHECK_OK(impl.status()); + return std::move(impl).value(); + } + + private: + absl::btree_map> annotations_; +}; + +absl::optional AnnotationCollector::MacroImpl::RecordAnnotation( + cel::MacroExprFactory& mef, int64_t id, Expr e) const { + if (IsSupportedAnnotation(e)) { + parent->Add(id, std::move(e)); + return absl::nullopt; + } + + return mef.ReportErrorAt( + e, + "cel.annotate argument is not a cel.Annotation{} or string expression"); +} + +absl::optional AnnotationCollector::MacroImpl::operator()( + cel::MacroExprFactory& mef, Expr& target, absl::Span args) const { + if (!target.has_ident_expr() || target.ident_expr().name() != "cel") { + return absl::nullopt; + } + + if (args.size() != 2) { + return mef.ReportErrorAt( + target, "wrong number of arguments for cel.annotate macro"); + } + + // arg0 (the annotated expression) is the expansion result. The remainder are + // annotations to record. + int64_t id = args[0].id(); + + absl::optional result; + if (args[1].has_list_expr()) { + auto list = args[1].release_list_expr(); + for (auto& e : list.mutable_elements()) { + result = RecordAnnotation(mef, id, e.release_expr()); + if (result) { + break; + } + } + } else { + result = RecordAnnotation(mef, id, std::move(args[1])); + } + + if (result) { + return result; + } + + return std::move(args[0]); +} + +void AnnotationCollector::Add(int64_t annotated_expr, Expr value) { + annotations_[annotated_expr].push_back({std::move(value)}); +} + class ParserVisitor final : public CelBaseVisitor, public antlr4::BaseErrorListener { public: ParserVisitor(const cel::Source& source, int max_recursion_depth, absl::string_view accu_var, - const cel::MacroRegistry& macro_registry, - bool add_macro_calls = false, - bool enable_optional_syntax = false, - bool enable_quoted_identifiers = false) - : source_(source), - factory_(source_, accu_var), - macro_registry_(macro_registry), - recursion_depth_(0), - max_recursion_depth_(max_recursion_depth), - add_macro_calls_(add_macro_calls), - enable_optional_syntax_(enable_optional_syntax), - enable_quoted_identifiers_(enable_quoted_identifiers) {} + const cel::MacroRegistry& macro_registry, bool add_macro_calls, + bool enable_optional_syntax, bool enable_quoted_identifiers, + bool enable_annotations); ~ParserVisitor() override = default; @@ -675,6 +804,8 @@ class ParserVisitor final : public CelBaseVisitor, std::string ErrorMessage(); + Expr PackAnnotations(Expr ast); + private: template Expr GlobalCallOrMacro(int64_t expr_id, absl::string_view function, @@ -702,14 +833,38 @@ class ParserVisitor final : public CelBaseVisitor, private: const cel::Source& source_; cel::ParserMacroExprFactory factory_; - const cel::MacroRegistry& macro_registry_; + AugmentedMacroRegistry macro_registry_; + AnnotationCollector annotations_; int recursion_depth_; const int max_recursion_depth_; const bool add_macro_calls_; const bool enable_optional_syntax_; const bool enable_quoted_identifiers_; + const bool enable_annotations_; }; +ParserVisitor::ParserVisitor(const cel::Source& source, int max_recursion_depth, + absl::string_view accu_var, + const cel::MacroRegistry& macro_registry, + bool add_macro_calls, bool enable_optional_syntax, + bool enable_quoted_identifiers, + bool enable_annotations) + : source_(source), + factory_(source_, accu_var), + macro_registry_(macro_registry), + recursion_depth_(0), + max_recursion_depth_(max_recursion_depth), + add_macro_calls_(add_macro_calls), + enable_optional_syntax_(enable_optional_syntax), + enable_quoted_identifiers_(enable_quoted_identifiers), + enable_annotations_(enable_annotations) { + if (enable_annotations_) { + macro_registry_.overlay() + .RegisterMacro(annotations_.MakeAnnotationImpl()) + .IgnoreError(); + } +} + template ::value>> T* tree_as(antlr4::tree::ParseTree* tree) { @@ -1638,6 +1793,61 @@ struct ParseResult { EnrichedSourceInfo enriched_source_info; }; +Expr NormalizeAnnotation(cel::ParserMacroExprFactory& mef, Expr expr) { + if (expr.has_struct_expr()) { + return expr; + } + + if (expr.has_const_expr()) { + std::vector fields; + fields.reserve(2); + fields.push_back( + mef.NewStructField(mef.NextId({}), "name", std::move(expr))); + auto bool_const = mef.NewBoolConst(mef.NextId({}), true); + fields.push_back(mef.NewStructField(mef.NextId({}), "inspect_only", + std::move(bool_const))); + return mef.NewStruct(mef.NextId({}), "cel.Annotation", std::move(fields)); + } + + return mef.ReportError("invalid annotation encountered finalizing AST"); +} + +Expr ParserVisitor::PackAnnotations(Expr ast) { + if (annotations_.annotations().empty()) { + return ast; + } + + auto annotations = annotations_.consume_annotations(); + std::vector entries; + entries.reserve(annotations.size()); + + for (auto& annotation : annotations) { + std::vector annotation_values; + annotation_values.reserve(annotation.second.size()); + + for (auto& annotation_value : annotation.second) { + auto annotation = + NormalizeAnnotation(factory_, std::move(annotation_value.expr)); + annotation_values.push_back( + factory_.NewListElement(std::move(annotation))); + } + auto id = factory_.NewIntConst(factory_.NextId({}), annotation.first); + auto annotation_list = + factory_.NewList(factory_.NextId({}), std::move(annotation_values)); + entries.push_back(factory_.NewMapEntry(factory_.NextId({}), std::move(id), + std::move(annotation_list))); + } + + std::vector args; + args.push_back(std::move(ast)); + args.push_back(factory_.NewMap(factory_.NextId({}), std::move(entries))); + + auto result = + factory_.NewCall(factory_.NextId({}), "cel.@annotated", std::move(args)); + + return result; +} + absl::StatusOr ParseImpl(const cel::Source& source, const cel::MacroRegistry& registry, const ParserOptions& options) { @@ -1656,10 +1866,10 @@ absl::StatusOr ParseImpl(const cel::Source& source, if (options.enable_hidden_accumulator_var) { accu_var = cel::kHiddenAccumulatorVariableName; } - ParserVisitor visitor(source, options.max_recursion_depth, accu_var, - registry, options.add_macro_calls, - options.enable_optional_syntax, - options.enable_quoted_identifiers); + ParserVisitor visitor( + source, options.max_recursion_depth, accu_var, registry, + options.add_macro_calls, options.enable_optional_syntax, + options.enable_quoted_identifiers, options.enable_annotations); lexer.removeErrorListeners(); parser.removeErrorListeners(); @@ -1686,7 +1896,9 @@ absl::StatusOr ParseImpl(const cel::Source& source, if (visitor.HasErrored()) { return absl::InvalidArgumentError(visitor.ErrorMessage()); } - + if (options.enable_annotations) { + expr = visitor.PackAnnotations(std::move(expr)); + } return { ParseResult{.expr = std::move(expr), .source_info = visitor.GetSourceInfo(), diff --git a/parser/parser_test.cc b/parser/parser_test.cc index a29c62626..ce9ce5461 100644 --- a/parser/parser_test.cc +++ b/parser/parser_test.cc @@ -16,7 +16,6 @@ #include #include -#include #include #include @@ -54,11 +53,20 @@ using ::testing::HasSubstr; using ::testing::Not; struct TestInfo { - TestInfo(const std::string& I, const std::string& P, - const std::string& E = "", const std::string& L = "", - const std::string& R = "", const std::string& M = "") + TestInfo(absl::string_view I, absl::string_view P, absl::string_view E = "", + absl::string_view L = "", absl::string_view R = "", + absl::string_view M = "") : I(I), P(P), E(E), L(L), R(R), M(M) {} + static TestInfo MacroCallCase(absl::string_view I, absl::string_view P, + absl::string_view M) { + return TestInfo(I, P, /*E=*/"", /*L=*/"", /*R=*/"", M); + } + + static TestInfo ErrorCase(absl::string_view I, absl::string_view E) { + return TestInfo(I, /*P=*/"", E, /*L=*/"", /*R=*/"", /*M=*/""); + } + // I contains the input expression to be parsed. std::string I; @@ -1889,6 +1897,244 @@ TEST_P(UpdatedAccuVarDisabledTest, Parse) { } } +const std::vector& AnnotationsTestCases() { + static const std::vector* kInstance = new std::vector{ + TestInfo::MacroCallCase("cel.annotate(" + " foo.bar," + " 'com.example.SimpleAnnotation'" + ")", + R"( +cel.@annotated( + foo^#3:Expr.Ident#.bar^#4:Expr.Select#, + { + 4^#10:int64#:[ + cel.Annotation{ + name:"com.example.SimpleAnnotation"^#5:string#^#6:Expr.CreateStruct.Entry#, + inspect_only:true^#7:bool#^#8:Expr.CreateStruct.Entry# + }^#9:Expr.CreateStruct# + ]^#11:Expr.CreateList#^#12:Expr.CreateStruct.Entry# + }^#13:Expr.CreateStruct# +)^#14:Expr.Call#)", + "cel^#1:Expr.Ident#.annotate(\n" + " foo^#3:Expr.Ident#.bar^#4:annotate#,\n" + " \"com.example.SimpleAnnotation\"^#5:string#\n" + ")^#4:annotate"), + TestInfo::MacroCallCase( + R"cel( + cel.annotate( + foo.bar, + 'com.example.SimpleAnnotation') || + cel.annotate( + foo.baz, + 'com.example.MyOtherAnnotation'))cel", + R"( +cel.@annotated( + _||_( + foo^#3:Expr.Ident#.bar^#4:Expr.Select#, + foo^#8:Expr.Ident#.baz^#9:Expr.Select# + )^#11:Expr.Call#, + { + 4^#16:int64#:[ + cel.Annotation{ + name:"com.example.SimpleAnnotation"^#5:string#^#12:Expr.CreateStruct.Entry#, + inspect_only:true^#13:bool#^#14:Expr.CreateStruct.Entry# + }^#15:Expr.CreateStruct# + ]^#17:Expr.CreateList#^#18:Expr.CreateStruct.Entry#, + 9^#23:int64#:[ + cel.Annotation{ + name:"com.example.MyOtherAnnotation"^#10:string#^#19:Expr.CreateStruct.Entry#, + inspect_only:true^#20:bool#^#21:Expr.CreateStruct.Entry# + }^#22:Expr.CreateStruct# + ]^#24:Expr.CreateList#^#25:Expr.CreateStruct.Entry# + }^#26:Expr.CreateStruct# +)^#27:Expr.Call#)", + /*M=*/ + "cel^#6:Expr.Ident#.annotate(\n" + " foo^#8:Expr.Ident#.baz^#9:annotate#,\n" + " \"com.example.MyOtherAnnotation\"^#10:string#\n" + ")^#9:annotate#,\n" + "cel^#1:Expr.Ident#.annotate(\n" + " foo^#3:Expr.Ident#.bar^#4:annotate#,\n" + " \"com.example.SimpleAnnotation\"^#5:string#\n" + ")^#4:annotate"), + TestInfo::MacroCallCase(R"cel( + cel.annotate( + foo.bar, + ['com.example.SimpleAnnotation', + 'com.example.MyOtherAnnotation'] + ))cel", + + /*P=*/R"( +cel.@annotated( + foo^#3:Expr.Ident#.bar^#4:Expr.Select#, + { + 4^#16:int64#:[ + cel.Annotation{ + name:"com.example.SimpleAnnotation"^#6:string#^#8:Expr.CreateStruct.Entry#, + inspect_only:true^#9:bool#^#10:Expr.CreateStruct.Entry# + }^#11:Expr.CreateStruct#, + cel.Annotation{ + name:"com.example.MyOtherAnnotation"^#7:string#^#12:Expr.CreateStruct.Entry#, + inspect_only:true^#13:bool#^#14:Expr.CreateStruct.Entry# + }^#15:Expr.CreateStruct# + ]^#17:Expr.CreateList#^#18:Expr.CreateStruct.Entry# + }^#19:Expr.CreateStruct# +)^#20:Expr.Call#)", + + /*M=*/R"(cel^#1:Expr.Ident#.annotate( + foo^#3:Expr.Ident#.bar^#4:annotate#, + [ + "com.example.SimpleAnnotation"^#6:string#, + "com.example.MyOtherAnnotation"^#7:string# + ]^#5:Expr.CreateList# +)^#4:annotate)"), + TestInfo::MacroCallCase(R"cel( + cel.annotate( + baz in foo.bar, + cel.Annotation{ + name: 'com.example.Explainer', + value: "baz is in foo.bar." + cel.annotation_value ? " oh no" : "" + } + ))cel", + + R"( +cel.@annotated( + @in( + baz^#3:Expr.Ident#, + foo^#5:Expr.Ident#.bar^#6:Expr.Select# + )^#4:Expr.Call#, + { + 4^#18:int64#:[ + cel.Annotation{ + name:"com.example.Explainer"^#9:string#^#8:Expr.CreateStruct.Entry#, + value:_?_:_( + _+_( + "baz is in foo.bar."^#11:string#, + cel^#13:Expr.Ident#.annotation_value^#14:Expr.Select# + )^#12:Expr.Call#, + " oh no"^#16:string#, + ""^#17:string# + )^#15:Expr.Call#^#10:Expr.CreateStruct.Entry# + }^#7:Expr.CreateStruct# + ]^#19:Expr.CreateList#^#20:Expr.CreateStruct.Entry# + }^#21:Expr.CreateStruct# +)^#22:Expr.Call#)", + + /*M=*/R"(cel^#1:Expr.Ident#.annotate( + @in( + baz^#3:Expr.Ident#, + foo^#5:Expr.Ident#.bar^#6:Expr.Select# + )^#4:annotate#, + cel.Annotation{ + name:"com.example.Explainer"^#9:string#^#8:Expr.CreateStruct.Entry#, + value:_?_:_( + _+_( + "baz is in foo.bar."^#11:string#, + cel^#13:Expr.Ident#.annotation_value^#14:Expr.Select# + )^#12:Expr.Call#, + " oh no"^#16:string#, + ""^#17:string# + )^#15:Expr.Call#^#10:Expr.CreateStruct.Entry# + }^#7:Expr.CreateStruct# +)^#4:annotate)"), + + TestInfo::MacroCallCase(R"cel( + cel.annotate( + baz in foo.bar, + [ + cel.Annotation{ + name: 'com.example.Explainer', + value: "baz is in foo.bar. oh no" + }, + "com.example.SimpleAnnotation" + ] + ))cel", + + /*P=*/R"( +cel.@annotated( + @in( + baz^#3:Expr.Ident#, + foo^#5:Expr.Ident#.bar^#6:Expr.Select# + )^#4:Expr.Call#, + { + 4^#18:int64#:[ + cel.Annotation{ + name:"com.example.Explainer"^#10:string#^#9:Expr.CreateStruct.Entry#, + value:"baz is in foo.bar. oh no"^#12:string#^#11:Expr.CreateStruct.Entry# + }^#8:Expr.CreateStruct#, + cel.Annotation{ + name:"com.example.SimpleAnnotation"^#13:string#^#14:Expr.CreateStruct.Entry#, + inspect_only:true^#15:bool#^#16:Expr.CreateStruct.Entry# + }^#17:Expr.CreateStruct# + ]^#19:Expr.CreateList#^#20:Expr.CreateStruct.Entry# + }^#21:Expr.CreateStruct# +)^#22:Expr.Call#)", + + /*M=*/R"(cel^#1:Expr.Ident#.annotate( + @in( + baz^#3:Expr.Ident#, + foo^#5:Expr.Ident#.bar^#6:Expr.Select# + )^#4:annotate#, + [ + cel.Annotation{ + name:"com.example.Explainer"^#10:string#^#9:Expr.CreateStruct.Entry#, + value:"baz is in foo.bar. oh no"^#12:string#^#11:Expr.CreateStruct.Entry# + }^#8:Expr.CreateStruct#, + "com.example.SimpleAnnotation"^#13:string# + ]^#7:Expr.CreateList# +)^#4:annotate)")}; + + return *kInstance; +} + +class AnnotationsTest : public testing::TestWithParam {}; + +TEST_P(AnnotationsTest, Parse) { + const TestInfo& test_info = GetParam(); + ParserOptions options; + options.enable_annotations = true; + + if (!test_info.M.empty()) { + options.add_macro_calls = true; + } + + auto result = + EnrichedParse(test_info.I, Macro::AllMacros(), "", options); + if (test_info.E.empty()) { + EXPECT_THAT(result, IsOk()); + } else { + EXPECT_THAT(result, Not(IsOk())); + EXPECT_EQ(test_info.E, result.status().message()); + } + + if (!test_info.P.empty()) { + KindAndIdAdorner kind_and_id_adorner; + ExprPrinter w(kind_and_id_adorner); + std::string adorned_string = w.PrintProto(result->parsed_expr().expr()); + EXPECT_EQ(absl::StripAsciiWhitespace(test_info.P), adorned_string) + << result->parsed_expr(); + } + + if (!test_info.L.empty()) { + LocationAdorner location_adorner(result->parsed_expr().source_info()); + ExprPrinter w(location_adorner); + std::string adorned_string = w.PrintProto(result->parsed_expr().expr()); + EXPECT_EQ(test_info.L, adorned_string) << result->parsed_expr(); + } + + if (!test_info.R.empty()) { + EXPECT_EQ(test_info.R, ConvertEnrichedSourceInfoToString( + result->enriched_source_info())); + } + + if (!test_info.M.empty()) { + EXPECT_EQ( + absl::StripAsciiWhitespace(test_info.M), + ConvertMacroCallsToString(result.value().parsed_expr().source_info())) + << result->parsed_expr(); + } +} + TEST(NewParserBuilderTest, Defaults) { auto builder = cel::NewParserBuilder(); ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); @@ -1954,5 +2200,8 @@ INSTANTIATE_TEST_SUITE_P(UpdatedAccuVarTest, UpdatedAccuVarDisabledTest, testing::ValuesIn(UpdatedAccuVarTestCases()), TestName); +INSTANTIATE_TEST_SUITE_P(AnnotationsTest, AnnotationsTest, + testing::ValuesIn(AnnotationsTestCases()), TestName); + } // namespace } // namespace google::api::expr::parser diff --git a/runtime/runtime_options.h b/runtime/runtime_options.h index 49596a5ea..a730221b6 100644 --- a/runtime/runtime_options.h +++ b/runtime/runtime_options.h @@ -43,6 +43,18 @@ enum class ProtoWrapperTypeOptions { kUnsetNull, }; +enum class AnnotationProcessingOptions { + // Annotations are discarded. + kIgnore, + + // Annotations are retained for inspection with tracing, but are not + // evaluable. + kRetain, + + // Annotations are retained and support dynamic evaluation during tracing. + kPlan, +}; + // LINT.IfChange // Interpreter options for controlling evaluation and builtin functions. // @@ -167,6 +179,13 @@ struct RuntimeOptions { // Currently applies to !_, @not_strictly_false, _==_, _!=_, @in bool enable_fast_builtins = true; + // Annotation support level. + // + // Default behavior is to ignore, the annotations are just discarded and have + // no effect on evaluation or planning. + AnnotationProcessingOptions annotation_processing = + AnnotationProcessingOptions::kIgnore; + // The locale to use for string formatting. // // Default is the "en_US" locale.