From 7566cb998a75f300e1950c91387523abedf27696 Mon Sep 17 00:00:00 2001 From: John Chadwick Date: Tue, 23 Jul 2024 20:20:48 -0400 Subject: [PATCH 001/180] Fix antlr4 parser generation on Windows --- bazel/antlr.bzl | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/bazel/antlr.bzl b/bazel/antlr.bzl index 7e74a2e56..c1c20778e 100644 --- a/bazel/antlr.bzl +++ b/bazel/antlr.bzl @@ -42,13 +42,21 @@ def antlr_cc_library(name, src, package): def _antlr_library(ctx): output = ctx.actions.declare_directory(ctx.attr.name) + src_path = ctx.file.src.path + + # Workaround for Antlr4 bug: + # https://github.com/antlr/antlr4/issues/3138 + windows_constraint = ctx.attr._windows_constraint[platform_common.ConstraintValueInfo] + if ctx.target_platform_has_constraint(windows_constraint): + src_path = src_path.replace('/', '\\') + antlr_args = ctx.actions.args() antlr_args.add("-Dlanguage=Cpp") antlr_args.add("-no-listener") antlr_args.add("-visitor") antlr_args.add("-o", output.path) antlr_args.add("-package", ctx.attr.package) - antlr_args.add(ctx.file.src) + antlr_args.add(src_path) # Strip ".g4" extension. basename = ctx.file.src.basename[:-3] @@ -98,5 +106,6 @@ antlr_library = rule( cfg = "exec", # buildifier: disable=attr-cfg default = Label("//bazel:antlr4_tool"), ), + '_windows_constraint': attr.label(default = '@platforms//os:windows'), }, ) From 4d4a1a495347fc254fe9036a0049638536ddbcf2 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Thu, 24 Oct 2024 15:42:17 -0700 Subject: [PATCH 002/180] Sync with GitHub PiperOrigin-RevId: 689541392 --- conformance/BUILD | 4 +- conformance/service.cc | 86 ++++++++++++++++++++---------------------- 2 files changed, 43 insertions(+), 47 deletions(-) diff --git a/conformance/BUILD b/conformance/BUILD index e09b21f0c..7aebd2bc8 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -98,8 +98,8 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", - "@com_google_cel_spec//proto/test/v1/proto2:test_all_types_cc_proto", - "@com_google_cel_spec//proto/test/v1/proto3:test_all_types_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", "@com_google_googleapis//google/api/expr/conformance/v1alpha1:conformance_cc_proto", "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", diff --git a/conformance/service.cc b/conformance/service.cc index 6c5c5752a..803e80e35 100644 --- a/conformance/service.cc +++ b/conformance/service.cc @@ -83,9 +83,9 @@ #include "runtime/runtime.h" #include "runtime/runtime_options.h" #include "runtime/standard_runtime_builder_factory.h" -#include "proto/test/v1/proto2/test_all_types.pb.h" -#include "proto/test/v1/proto2/test_all_types_extensions.pb.h" -#include "proto/test/v1/proto3/test_all_types.pb.h" +#include "proto/cel/expr/conformance/proto2/test_all_types.pb.h" +#include "proto/cel/expr/conformance/proto2/test_all_types_extensions.pb.h" +#include "proto/cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" @@ -271,34 +271,32 @@ class LegacyConformanceServiceImpl : public ConformanceServiceInterface { static auto* constant_arena = new Arena(); google::protobuf::LinkMessageReflection< - google::api::expr::test::v1::proto3::TestAllTypes>(); + cel::expr::conformance::proto3::TestAllTypes>(); google::protobuf::LinkMessageReflection< - google::api::expr::test::v1::proto2::TestAllTypes>(); + cel::expr::conformance::google::protobuf::TestAllTypes>(); google::protobuf::LinkMessageReflection< - google::api::expr::test::v1::proto3::NestedTestAllTypes>(); + cel::expr::conformance::proto3::NestedTestAllTypes>(); google::protobuf::LinkMessageReflection< - google::api::expr::test::v1::proto2::NestedTestAllTypes>(); + cel::expr::conformance::google::protobuf::NestedTestAllTypes>(); + google::protobuf::LinkExtensionReflection(cel::expr::conformance::google::protobuf::int32_ext); + google::protobuf::LinkExtensionReflection(cel::expr::conformance::google::protobuf::nested_ext); google::protobuf::LinkExtensionReflection( - google::api::expr::test::v1::proto2::int32_ext); + cel::expr::conformance::google::protobuf::test_all_types_ext); google::protobuf::LinkExtensionReflection( - google::api::expr::test::v1::proto2::nested_ext); + cel::expr::conformance::google::protobuf::nested_enum_ext); google::protobuf::LinkExtensionReflection( - google::api::expr::test::v1::proto2::test_all_types_ext); + cel::expr::conformance::google::protobuf::repeated_test_all_types); google::protobuf::LinkExtensionReflection( - google::api::expr::test::v1::proto2::nested_enum_ext); - google::protobuf::LinkExtensionReflection( - google::api::expr::test::v1::proto2::repeated_test_all_types); - google::protobuf::LinkExtensionReflection( - google::api::expr::test::v1::proto2::Proto2ExtensionScopedMessage:: + cel::expr::conformance::google::protobuf::Proto2ExtensionScopedMessage:: int64_ext); google::protobuf::LinkExtensionReflection( - google::api::expr::test::v1::proto2::Proto2ExtensionScopedMessage:: + cel::expr::conformance::google::protobuf::Proto2ExtensionScopedMessage:: message_scoped_nested_ext); google::protobuf::LinkExtensionReflection( - google::api::expr::test::v1::proto2::Proto2ExtensionScopedMessage:: + cel::expr::conformance::google::protobuf::Proto2ExtensionScopedMessage:: nested_enum_ext); google::protobuf::LinkExtensionReflection( - google::api::expr::test::v1::proto2::Proto2ExtensionScopedMessage:: + cel::expr::conformance::google::protobuf::Proto2ExtensionScopedMessage:: message_scoped_repeated_test_all_types); InterpreterOptions options; @@ -322,13 +320,13 @@ class LegacyConformanceServiceImpl : public ConformanceServiceInterface { CreateCelExpressionBuilder(options); auto type_registry = builder->GetTypeRegistry(); type_registry->Register( - google::api::expr::test::v1::proto2::GlobalEnum_descriptor()); + cel::expr::conformance::google::protobuf::GlobalEnum_descriptor()); + type_registry->Register( + cel::expr::conformance::proto3::GlobalEnum_descriptor()); + type_registry->Register( + cel::expr::conformance::google::protobuf::TestAllTypes::NestedEnum_descriptor()); type_registry->Register( - google::api::expr::test::v1::proto3::GlobalEnum_descriptor()); - type_registry->Register(google::api::expr::test::v1::proto2::TestAllTypes:: - NestedEnum_descriptor()); - type_registry->Register(google::api::expr::test::v1::proto3::TestAllTypes:: - NestedEnum_descriptor()); + cel::expr::conformance::proto3::TestAllTypes::NestedEnum_descriptor()); CEL_RETURN_IF_ERROR( RegisterBuiltinFunctions(builder->GetRegistry(), options)); CEL_RETURN_IF_ERROR(cel::extensions::RegisterEncodersFunctions( @@ -426,34 +424,32 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { static absl::StatusOr> Create( bool optimize, bool use_arena, bool recursive) { google::protobuf::LinkMessageReflection< - google::api::expr::test::v1::proto3::TestAllTypes>(); + cel::expr::conformance::proto3::TestAllTypes>(); google::protobuf::LinkMessageReflection< - google::api::expr::test::v1::proto2::TestAllTypes>(); + cel::expr::conformance::google::protobuf::TestAllTypes>(); google::protobuf::LinkMessageReflection< - google::api::expr::test::v1::proto3::NestedTestAllTypes>(); + cel::expr::conformance::proto3::NestedTestAllTypes>(); google::protobuf::LinkMessageReflection< - google::api::expr::test::v1::proto2::NestedTestAllTypes>(); + cel::expr::conformance::google::protobuf::NestedTestAllTypes>(); + google::protobuf::LinkExtensionReflection(cel::expr::conformance::google::protobuf::int32_ext); + google::protobuf::LinkExtensionReflection(cel::expr::conformance::google::protobuf::nested_ext); google::protobuf::LinkExtensionReflection( - google::api::expr::test::v1::proto2::int32_ext); + cel::expr::conformance::google::protobuf::test_all_types_ext); google::protobuf::LinkExtensionReflection( - google::api::expr::test::v1::proto2::nested_ext); + cel::expr::conformance::google::protobuf::nested_enum_ext); google::protobuf::LinkExtensionReflection( - google::api::expr::test::v1::proto2::test_all_types_ext); + cel::expr::conformance::google::protobuf::repeated_test_all_types); google::protobuf::LinkExtensionReflection( - google::api::expr::test::v1::proto2::nested_enum_ext); - google::protobuf::LinkExtensionReflection( - google::api::expr::test::v1::proto2::repeated_test_all_types); - google::protobuf::LinkExtensionReflection( - google::api::expr::test::v1::proto2::Proto2ExtensionScopedMessage:: + cel::expr::conformance::google::protobuf::Proto2ExtensionScopedMessage:: int64_ext); google::protobuf::LinkExtensionReflection( - google::api::expr::test::v1::proto2::Proto2ExtensionScopedMessage:: + cel::expr::conformance::google::protobuf::Proto2ExtensionScopedMessage:: message_scoped_nested_ext); google::protobuf::LinkExtensionReflection( - google::api::expr::test::v1::proto2::Proto2ExtensionScopedMessage:: + cel::expr::conformance::google::protobuf::Proto2ExtensionScopedMessage:: nested_enum_ext); google::protobuf::LinkExtensionReflection( - google::api::expr::test::v1::proto2::Proto2ExtensionScopedMessage:: + cel::expr::conformance::google::protobuf::Proto2ExtensionScopedMessage:: message_scoped_repeated_test_all_types); RuntimeOptions options; @@ -491,16 +487,16 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { std::make_unique()); CEL_RETURN_IF_ERROR(RegisterProtobufEnum( type_registry, - google::api::expr::test::v1::proto2::GlobalEnum_descriptor())); + cel::expr::conformance::google::protobuf::GlobalEnum_descriptor())); CEL_RETURN_IF_ERROR(RegisterProtobufEnum( type_registry, - google::api::expr::test::v1::proto3::GlobalEnum_descriptor())); + cel::expr::conformance::proto3::GlobalEnum_descriptor())); CEL_RETURN_IF_ERROR(RegisterProtobufEnum( - type_registry, google::api::expr::test::v1::proto2::TestAllTypes:: - NestedEnum_descriptor())); + type_registry, + cel::expr::conformance::google::protobuf::TestAllTypes::NestedEnum_descriptor())); CEL_RETURN_IF_ERROR(RegisterProtobufEnum( - type_registry, google::api::expr::test::v1::proto3::TestAllTypes:: - NestedEnum_descriptor())); + type_registry, + cel::expr::conformance::proto3::TestAllTypes::NestedEnum_descriptor())); CEL_RETURN_IF_ERROR(cel::extensions::EnableOptionalTypes(builder)); CEL_RETURN_IF_ERROR(cel::extensions::RegisterEncodersFunctions( From aff86967fc811ed4605f099ee2f502d331d5bbe7 Mon Sep 17 00:00:00 2001 From: Justin King Date: Fri, 25 Oct 2024 11:58:17 -0700 Subject: [PATCH 003/180] Breaking change: Use dev.cel (canonical) protos instead of google.api.expr Historically C++ was implemented using the unversioned protobuf descriptors. It eventually included a mix of the versioned variant. When moving to OSS, unversioned protobuf descriptors were rewritten to the versioned variant unconditionally. This has created a bit of a conundrum, now that their is an unversioned canonical protobuf descriptor in cel-spec. We have decided to bite the bullet and break the world. The versioned and unversioned protobuf descriptors are wire compatible, so if you run into issues you can serialize and parse to convert between the two. PiperOrigin-RevId: 689865345 --- base/ast_internal/expr.h | 2 +- bazel/deps.bzl | 4 +- checker/internal/BUILD | 4 +- checker/internal/type_checker_impl_test.cc | 108 +++++++------- checker/optional_test.cc | 28 ++-- codelab/BUILD | 4 +- codelab/exercise1.cc | 4 +- codelab/exercise2.cc | 4 +- codelab/solutions/BUILD | 4 +- codelab/solutions/exercise1.cc | 4 +- codelab/solutions/exercise2.cc | 4 +- codelab/solutions/exercise4.cc | 2 +- common/BUILD | 6 +- common/ast_rewrite_test.cc | 6 +- common/operators.cc | 4 +- common/operators.h | 6 +- common/type_test.cc | 52 +++---- common/types/type_pool_test.cc | 2 +- common/value_test.cc | 4 +- common/values/message_value_test.cc | 6 +- common/values/parsed_json_list_value_test.cc | 4 +- common/values/parsed_json_map_value_test.cc | 4 +- common/values/parsed_json_value_test.cc | 4 +- common/values/parsed_map_field_value_test.cc | 4 +- common/values/parsed_message_value_test.cc | 6 +- .../parsed_repeated_field_value_test.cc | 4 +- common/values/struct_value_test.cc | 4 +- conformance/BUILD | 15 +- conformance/run.cc | 41 ++++-- conformance/service.cc | 118 ++++++++------- conformance/value_conversion.cc | 74 +++++----- conformance/value_conversion.h | 78 +++++++++- eval/compiler/BUILD | 28 ++-- .../cel_expression_builder_flat_impl.cc | 10 +- .../cel_expression_builder_flat_impl.h | 16 +-- .../cel_expression_builder_flat_impl_test.cc | 26 ++-- eval/compiler/constant_folding_test.cc | 4 +- .../flat_expr_builder_comprehensions_test.cc | 6 +- ...ilder_short_circuiting_conformance_test.cc | 2 +- eval/compiler/flat_expr_builder_test.cc | 24 ++-- eval/compiler/instrumentation_test.cc | 4 +- .../qualified_reference_resolver_test.cc | 24 ++-- .../regex_precompilation_optimization_test.cc | 6 +- eval/eval/BUILD | 24 ++-- eval/eval/attribute_trail_test.cc | 2 +- eval/eval/comprehension_step_test.cc | 2 +- eval/eval/container_access_step_test.cc | 4 +- eval/eval/create_map_step_test.cc | 2 +- eval/eval/create_struct_step_test.cc | 2 +- eval/eval/evaluator_core_test.cc | 6 +- eval/eval/jump_step.h | 2 +- eval/eval/lazy_init_step.cc | 2 +- eval/eval/regex_match_step_test.cc | 8 +- eval/eval/select_step_test.cc | 6 +- eval/public/BUILD | 36 ++--- eval/public/activation_test.cc | 2 +- eval/public/ast_rewrite.cc | 18 +-- eval/public/ast_rewrite.h | 66 ++++----- eval/public/ast_rewrite_test.cc | 22 +-- eval/public/ast_traverse.cc | 18 +-- eval/public/ast_traverse.h | 6 +- eval/public/ast_traverse_test.cc | 18 +-- eval/public/ast_visitor.h | 74 +++++----- eval/public/ast_visitor_base.h | 42 +++--- eval/public/builtin_func_registrar_test.cc | 6 +- eval/public/builtin_func_test.cc | 6 +- eval/public/cel_attribute.h | 2 +- eval/public/cel_attribute_test.cc | 2 +- eval/public/cel_expression.h | 16 +-- eval/public/comparison_functions_test.cc | 4 +- .../container_function_registrar_test.cc | 4 +- eval/public/containers/BUILD | 2 +- eval/public/containers/field_access_test.cc | 4 +- .../equality_function_registrar_test.cc | 4 +- .../public/logical_function_registrar_test.cc | 6 +- .../portable_cel_expr_builder_factory_test.cc | 2 +- eval/public/source_position.cc | 2 +- eval/public/source_position.h | 6 +- eval/public/source_position_test.cc | 4 +- .../string_extension_func_registrar_test.cc | 2 +- eval/public/structs/BUILD | 6 +- ...dynamic_descriptor_pool_end_to_end_test.cc | 32 ++--- eval/public/structs/field_access_impl_test.cc | 4 +- eval/public/transform_utility.cc | 2 +- eval/public/transform_utility.h | 8 +- eval/public/unknown_attribute_set_test.cc | 2 +- eval/public/unknown_set_test.cc | 2 +- eval/tests/BUILD | 16 +-- eval/tests/allocation_benchmark_test.cc | 4 +- eval/tests/benchmark_test.cc | 8 +- eval/tests/end_to_end_test.cc | 6 +- .../expression_builder_benchmark_test.cc | 8 +- eval/tests/memory_safety_test.cc | 4 +- eval/tests/modern_benchmark_test.cc | 8 +- eval/tests/unknowns_end_to_end_test.cc | 8 +- extensions/BUILD | 14 +- extensions/bindings_ext_benchmark_test.cc | 2 +- extensions/bindings_ext_test.cc | 34 ++--- extensions/math_ext_test.cc | 8 +- extensions/protobuf/BUILD | 22 +-- extensions/protobuf/ast_converters.cc | 134 +++++++++--------- extensions/protobuf/ast_converters.h | 26 ++-- extensions/protobuf/ast_converters_test.cc | 102 ++++++------- .../protobuf/bind_proto_to_activation_test.cc | 4 +- extensions/protobuf/internal/BUILD | 6 +- extensions/protobuf/internal/ast.cc | 18 +-- extensions/protobuf/internal/ast.h | 6 +- extensions/protobuf/internal/ast_test.cc | 4 +- extensions/protobuf/internal/constant.cc | 4 +- extensions/protobuf/internal/constant.h | 6 +- extensions/protobuf/runtime_adapter.cc | 12 +- extensions/protobuf/runtime_adapter.h | 12 +- extensions/protobuf/type_introspector_test.cc | 12 +- extensions/protobuf/type_reflector_test.cc | 4 +- extensions/protobuf/value_end_to_end_test.cc | 6 +- extensions/protobuf/value_test.cc | 4 +- extensions/protobuf/value_testing_test.cc | 4 +- extensions/sets_functions_benchmark_test.cc | 4 +- extensions/sets_functions_test.cc | 8 +- extensions/strings_test.cc | 4 +- internal/BUILD | 20 +-- internal/json_test.cc | 52 +++---- internal/message_equality_test.cc | 4 +- internal/testing_descriptor_pool_test.cc | 4 +- internal/well_known_types_test.cc | 6 +- parser/BUILD | 4 +- parser/parser.cc | 12 +- parser/parser.h | 14 +- parser/parser_test.cc | 22 +-- runtime/BUILD | 14 +- .../comprehension_vulnerability_check_test.cc | 4 +- runtime/constant_folding_test.cc | 4 +- runtime/optional_types_test.cc | 4 +- runtime/reference_resolver_test.cc | 18 +-- runtime/regex_precompilation_test.cc | 4 +- .../standard_runtime_builder_factory_test.cc | 4 +- testutil/BUILD | 4 +- testutil/baseline_tests.cc | 2 +- testutil/baseline_tests.h | 4 +- testutil/baseline_tests_test.cc | 2 +- testutil/expr_printer.cc | 2 +- testutil/expr_printer.h | 4 +- tools/BUILD | 8 +- tools/branch_coverage.cc | 6 +- tools/branch_coverage.h | 6 +- tools/branch_coverage_test.cc | 2 +- tools/navigable_ast.cc | 4 +- tools/navigable_ast.h | 12 +- tools/navigable_ast_test.cc | 4 +- 149 files changed, 1066 insertions(+), 960 deletions(-) diff --git a/base/ast_internal/expr.h b/base/ast_internal/expr.h index 7ae08797c..bdba1363d 100644 --- a/base/ast_internal/expr.h +++ b/base/ast_internal/expr.h @@ -575,7 +575,7 @@ using TypeKind = absl::Nullable>, ErrorType, AbstractType>; -// Analogous to google::api::expr::v1alpha1::Type. +// Analogous to cel::expr::Type. // Represents a CEL type. // // TODO: align with value.proto diff --git a/bazel/deps.bzl b/bazel/deps.bzl index 7fbdd7925..51eb3e9d6 100644 --- a/bazel/deps.bzl +++ b/bazel/deps.bzl @@ -142,10 +142,10 @@ def cel_spec_deps(): url = "https://github.com/bazelbuild/rules_python/releases/download/0.33.2/rules_python-0.33.2.tar.gz", ) - CEL_SPEC_GIT_SHA = "f027a86d2e5bf18f796be0c4373f637a61041cde" # Aug 23, 2024 + CEL_SPEC_GIT_SHA = "373994d7e20e582fce56767b01ac5039524cddab" # Oct 23, 2024 http_archive( name = "com_google_cel_spec", - sha256 = "006594fa4f97819a4e4cd98404e4522f5f46ed5ac65402b354649bcc871b0cf2", + sha256 = "b498a768140fc0ed0314eef8b2519a48287661d09ca15b17c8ca34088af6aac3", strip_prefix = "cel-spec-" + CEL_SPEC_GIT_SHA, urls = ["https://github.com/google/cel-spec/archive/" + CEL_SPEC_GIT_SHA + ".zip"], ) diff --git a/checker/internal/BUILD b/checker/internal/BUILD index e07fb2e36..453a2c309 100644 --- a/checker/internal/BUILD +++ b/checker/internal/BUILD @@ -175,8 +175,8 @@ cc_test( "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/strings", - "@com_google_cel_spec//proto/test/v1/proto2:test_all_types_cc_proto", - "@com_google_cel_spec//proto/test/v1/proto3:test_all_types_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", "@com_google_protobuf//:protobuf", ], ) diff --git a/checker/internal/type_checker_impl_test.cc b/checker/internal/type_checker_impl_test.cc index c53ca2255..50be6d671 100644 --- a/checker/internal/type_checker_impl_test.cc +++ b/checker/internal/type_checker_impl_test.cc @@ -45,8 +45,8 @@ #include "internal/status_macros.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" -#include "proto/test/v1/proto2/test_all_types.pb.h" -#include "proto/test/v1/proto3/test_all_types.pb.h" +#include "cel/expr/conformance/proto2/test_all_types.pb.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/arena.h" #include "google/protobuf/message.h" @@ -58,8 +58,8 @@ namespace { using ::absl_testing::IsOk; using ::cel::ast_internal::AstImpl; using ::cel::ast_internal::Reference; +using ::cel::expr::conformance::proto3::TestAllTypes; using ::cel::internal::GetSharedTestingDescriptorPool; -using ::google::api::expr::test::v1::proto3::TestAllTypes; using ::testing::_; using ::testing::Contains; using ::testing::ElementsAre; @@ -71,7 +71,7 @@ using ::testing::Property; using AstType = ast_internal::Type; using Severity = TypeCheckIssue::Severity; -namespace testpb3 = ::google::api::expr::test::v1::proto3; +namespace testpb3 = ::cel::expr::conformance::proto3; std::string SevString(Severity severity) { switch (severity) { @@ -989,7 +989,7 @@ INSTANTIATE_TEST_SUITE_P( AstTypeConversionTestCase{ .decl_type = StructType(MessageType(TestAllTypes::descriptor())), .expected_type = AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes"))})); + "cel.expr.conformance.proto3.TestAllTypes"))})); TEST(TypeCheckerImplTest, NullLiteral) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); @@ -1280,7 +1280,7 @@ TEST(TypeCheckerImplTest, ContainerLookupForMessageCreationNoRewrite) { TEST(TypeCheckerImplTest, EnumValueCopiedToReferenceMap) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); - env.set_container("google.api.expr.test.v1.proto3"); + env.set_container("cel.expr.conformance.proto3"); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, @@ -1293,7 +1293,7 @@ TEST(TypeCheckerImplTest, EnumValueCopiedToReferenceMap) { auto ref_iter = ast_impl.reference_map().find(ast_impl.root_expr().id()); ASSERT_NE(ref_iter, ast_impl.reference_map().end()); EXPECT_EQ(ref_iter->second.name(), - "google.api.expr.test.v1.proto3.TestAllTypes.NestedEnum.BAZ"); + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum.BAZ"); EXPECT_EQ(ref_iter->second.value().int_value(), 2); } @@ -1472,7 +1472,7 @@ TEST_P(GenericMessagesTest, TypeChecksProto3) { google::protobuf::Arena arena; TypeCheckEnv env(GetSharedTestingDescriptorPool()); - env.set_container("google.api.expr.test.v1.proto3"); + env.set_container("cel.expr.conformance.proto3"); google::protobuf::LinkMessageReflection(); ASSERT_TRUE(env.InsertVariableIfAbsent(MakeVariableDecl( @@ -1512,11 +1512,11 @@ INSTANTIATE_TEST_SUITE_P( .expected_result_type = AstType(), .error_substring = "undefined field 'not_a_field' not found in " - "struct 'google.api.expr.test.v1.proto3.TestAllTypes'"}, + "struct 'cel.expr.conformance.proto3.TestAllTypes'"}, CheckedExprTestCase{ .expr = "TestAllTypes{single_int64: 10}", .expected_result_type = AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")), + "cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_int64: 'string'}", @@ -1527,113 +1527,113 @@ INSTANTIATE_TEST_SUITE_P( CheckedExprTestCase{ .expr = "TestAllTypes{single_int32: 10}", .expected_result_type = AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")), + "cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_uint64: 10u}", .expected_result_type = AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")), + "cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_uint32: 10u}", .expected_result_type = AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")), + "cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_sint64: 10}", .expected_result_type = AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")), + "cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_sint32: 10}", .expected_result_type = AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")), + "cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_fixed64: 10u}", .expected_result_type = AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")), + "cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_fixed32: 10u}", .expected_result_type = AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")), + "cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_sfixed64: 10}", .expected_result_type = AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")), + "cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_sfixed32: 10}", .expected_result_type = AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")), + "cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_double: 1.25}", .expected_result_type = AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")), + "cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_float: 1.25}", .expected_result_type = AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")), + "cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_string: 'string'}", .expected_result_type = AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")), + "cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_bool: true}", .expected_result_type = AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")), + "cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_bytes: b'string'}", .expected_result_type = AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")), + "cel.expr.conformance.proto3.TestAllTypes")), }, // Well-known CheckedExprTestCase{ .expr = "TestAllTypes{single_any: TestAllTypes{single_int64: 10}}", .expected_result_type = AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")), + "cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_any: 1}", .expected_result_type = AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")), + "cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_any: 'string'}", .expected_result_type = AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")), + "cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_any: ['string']}", .expected_result_type = AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")), + "cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_duration: duration('1s')}", .expected_result_type = AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")), + "cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_timestamp: timestamp(0)}", .expected_result_type = AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")), + "cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_struct: {}}", .expected_result_type = AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")), + "cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_struct: {'key': 'value'}}", .expected_result_type = AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")), + "cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_struct: {1: 2}}", @@ -1644,12 +1644,12 @@ INSTANTIATE_TEST_SUITE_P( CheckedExprTestCase{ .expr = "TestAllTypes{list_value: [1, 2, 3]}", .expected_result_type = AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")), + "cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{list_value: []}", .expected_result_type = AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")), + "cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{list_value: 1}", @@ -1660,42 +1660,42 @@ INSTANTIATE_TEST_SUITE_P( CheckedExprTestCase{ .expr = "TestAllTypes{single_int64_wrapper: 1}", .expected_result_type = AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")), + "cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_int64_wrapper: null}", .expected_result_type = AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")), + "cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_value: null}", .expected_result_type = AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")), + "cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_value: 1.0}", .expected_result_type = AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")), + "cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_value: 'string'}", .expected_result_type = AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")), + "cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_value: {'string': 'string'}}", .expected_result_type = AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")), + "cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_value: ['string']}", .expected_result_type = AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")), + "cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{repeated_int64: [1, 2, 3]}", .expected_result_type = AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")), + "cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{repeated_int64: ['string']}", @@ -1710,18 +1710,18 @@ INSTANTIATE_TEST_SUITE_P( CheckedExprTestCase{ .expr = "TestAllTypes{map_string_int64: {'string': 1}}", .expected_result_type = AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")), + "cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_nested_enum: 1}", .expected_result_type = AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")), + "cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_nested_enum: TestAllTypes.NestedEnum.BAR}", .expected_result_type = AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")), + "cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes.NestedEnum.BAR", @@ -1732,7 +1732,7 @@ INSTANTIATE_TEST_SUITE_P( .expr = "TestAllTypes", .expected_result_type = AstType(std::make_unique(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes"))), + "cel.expr.conformance.proto3.TestAllTypes"))), }, CheckedExprTestCase{ .expr = "TestAllTypes == type(TestAllTypes{})", @@ -1742,28 +1742,28 @@ INSTANTIATE_TEST_SUITE_P( CheckedExprTestCase{ .expr = "TestAllTypes{null_value: 0}", .expected_result_type = AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")), + "cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{null_value: null}", .expected_result_type = AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")), + "cel.expr.conformance.proto3.TestAllTypes")), }, // Legacy nullability behaviors. CheckedExprTestCase{ .expr = "TestAllTypes{single_duration: null}", .expected_result_type = AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")), + "cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_timestamp: null}", .expected_result_type = AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")), + "cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_nested_message: null}", .expected_result_type = AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")), + "cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{}.single_duration == null", @@ -1786,7 +1786,7 @@ INSTANTIATE_TEST_SUITE_P( .expected_result_type = AstType(), .error_substring = "undefined field 'not_a_field' not found in " - "struct 'google.api.expr.test.v1.proto3.TestAllTypes'"}, + "struct 'cel.expr.conformance.proto3.TestAllTypes'"}, CheckedExprTestCase{ .expr = "test_msg.single_int64", .expected_result_type = @@ -1811,7 +1811,7 @@ INSTANTIATE_TEST_SUITE_P( .expected_result_type = AstType(), .error_substring = "undefined field 'not_a_field' not found in " - "struct 'google.api.expr.test.v1.proto3.TestAllTypes'"}, + "struct 'cel.expr.conformance.proto3.TestAllTypes'"}, CheckedExprTestCase{ .expr = "has(test_msg.single_int64)", .expected_result_type = AstType(ast_internal::PrimitiveType::kBool), @@ -1983,7 +1983,7 @@ TEST_P(StrictNullAssignmentTest, TypeChecksProto3) { google::protobuf::Arena arena; TypeCheckEnv env(GetSharedTestingDescriptorPool()); - env.set_container("google.api.expr.test.v1.proto3"); + env.set_container("cel.expr.conformance.proto3"); google::protobuf::LinkMessageReflection(); ASSERT_TRUE(env.InsertVariableIfAbsent(MakeVariableDecl( diff --git a/checker/optional_test.cc b/checker/optional_test.cc index 7c81dea59..841597061 100644 --- a/checker/optional_test.cc +++ b/checker/optional_test.cc @@ -81,7 +81,7 @@ TEST(OptionalTest, OptSelectDoesNotAnnotateFieldType) { CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); ASSERT_THAT(builder.AddLibrary(StandardLibrary()), IsOk()); ASSERT_THAT(builder.AddLibrary(OptionalCheckerLibrary()), IsOk()); - builder.set_container("google.api.expr.test.v1.proto3"); + builder.set_container("cel.expr.conformance.proto3"); ASSERT_OK_AND_ASSIGN(std::unique_ptr checker, std::move(builder).Build()); @@ -227,10 +227,10 @@ INSTANTIATE_TEST_SUITE_P( new AstType(ast_internal::PrimitiveType::kString)))))}, TestCase{"['v1', ?'v2']", _, "expected type 'optional_type' but found 'string'"}, - TestCase{"google.api.expr.test.v1.proto3.TestAllTypes{?single_int64: " + TestCase{"cel.expr.conformance.proto3.TestAllTypes{?single_int64: " "optional.of(1)}", Eq(AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")))}, + "cel.expr.conformance.proto3.TestAllTypes")))}, TestCase{"[0][?1]", IsOptionalType(AstType(ast_internal::PrimitiveType::kInt64))}, TestCase{"[[0]][?1][?1]", @@ -250,19 +250,18 @@ INSTANTIATE_TEST_SUITE_P( TestCase{"optional.of('abc').optFlatMap(x, optional.of(x + 'def'))", IsOptionalType(AstType(ast_internal::PrimitiveType::kString))}, // Legacy nullability behaviors. - TestCase{"google.api.expr.test.v1.proto3.TestAllTypes{?null_value: " + TestCase{"cel.expr.conformance.proto3.TestAllTypes{?null_value: " "optional.of(0)}", Eq(AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")))}, - TestCase{ - "google.api.expr.test.v1.proto3.TestAllTypes{?null_value: null}", - Eq(AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")))}, - TestCase{"google.api.expr.test.v1.proto3.TestAllTypes{?null_value: " + "cel.expr.conformance.proto3.TestAllTypes")))}, + TestCase{"cel.expr.conformance.proto3.TestAllTypes{?null_value: null}", + Eq(AstType(ast_internal::MessageType( + "cel.expr.conformance.proto3.TestAllTypes")))}, + TestCase{"cel.expr.conformance.proto3.TestAllTypes{?null_value: " "optional.of(null)}", Eq(AstType(ast_internal::MessageType( - "google.api.expr.test.v1.proto3.TestAllTypes")))}, - TestCase{"google.api.expr.test.v1.proto3.TestAllTypes{}.?single_int64 " + "cel.expr.conformance.proto3.TestAllTypes")))}, + TestCase{"cel.expr.conformance.proto3.TestAllTypes{}.?single_int64 " "== null", Eq(AstType(ast_internal::PrimitiveType::kBool))})); @@ -311,11 +310,10 @@ INSTANTIATE_TEST_SUITE_P( OptionalTests, OptionalStrictNullAssignmentTest, ::testing::Values( TestCase{ - "google.api.expr.test.v1.proto3.TestAllTypes{?single_int64: null}", - _, + "cel.expr.conformance.proto3.TestAllTypes{?single_int64: null}", _, "expected type of field 'single_int64' is 'optional_type' but " "provided type is 'null_type'"}, - TestCase{"google.api.expr.test.v1.proto3.TestAllTypes{}.?single_int64 " + TestCase{"cel.expr.conformance.proto3.TestAllTypes{}.?single_int64 " "== null", _, "no matching overload for '_==_'"})); diff --git a/codelab/BUILD b/codelab/BUILD index 5c98be576..b80219f21 100644 --- a/codelab/BUILD +++ b/codelab/BUILD @@ -48,7 +48,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -80,7 +80,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto", "@com_google_protobuf//:protobuf", ], diff --git a/codelab/exercise1.cc b/codelab/exercise1.cc index ba0fdfa14..85908250b 100644 --- a/codelab/exercise1.cc +++ b/codelab/exercise1.cc @@ -17,7 +17,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "google/protobuf/arena.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" @@ -54,7 +54,7 @@ absl::StatusOr ConvertResult(const CelValue& value) { absl::StatusOr ParseAndEvaluate(absl::string_view cel_expr) { // === Start Codelab === // Parse the expression using ::google::api::expr::parser::Parse; - // This will return a google::api::expr::v1alpha1::ParsedExpr message. + // This will return a cel::expr::ParsedExpr message. // Setup a default environment for building expressions. // std::unique_ptr builder = diff --git a/codelab/exercise2.cc b/codelab/exercise2.cc index 28b68e49c..93f060ccd 100644 --- a/codelab/exercise2.cc +++ b/codelab/exercise2.cc @@ -17,7 +17,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "google/rpc/context/attribute_context.pb.h" #include "google/protobuf/arena.h" #include "absl/status/status.h" @@ -35,7 +35,7 @@ namespace google::api::expr::codelab { namespace { -using ::google::api::expr::v1alpha1::ParsedExpr; +using ::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; using ::google::api::expr::runtime::Activation; using ::google::api::expr::runtime::CelError; diff --git a/codelab/solutions/BUILD b/codelab/solutions/BUILD index 5767d35ff..a85f0f668 100644 --- a/codelab/solutions/BUILD +++ b/codelab/solutions/BUILD @@ -32,7 +32,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -63,7 +63,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto", "@com_google_protobuf//:protobuf", ], diff --git a/codelab/solutions/exercise1.cc b/codelab/solutions/exercise1.cc index 69bbafff7..83e729c9c 100644 --- a/codelab/solutions/exercise1.cc +++ b/codelab/solutions/exercise1.cc @@ -17,7 +17,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "google/protobuf/arena.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" @@ -34,7 +34,7 @@ namespace google::api::expr::codelab { namespace { -using ::google::api::expr::v1alpha1::ParsedExpr; +using ::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; using ::google::api::expr::runtime::Activation; using ::google::api::expr::runtime::CelExpression; diff --git a/codelab/solutions/exercise2.cc b/codelab/solutions/exercise2.cc index e6c8ed567..236ad9312 100644 --- a/codelab/solutions/exercise2.cc +++ b/codelab/solutions/exercise2.cc @@ -17,7 +17,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "google/protobuf/arena.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" @@ -35,7 +35,7 @@ namespace google::api::expr::codelab { namespace { -using ::google::api::expr::v1alpha1::ParsedExpr; +using ::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; using ::google::api::expr::runtime::Activation; using ::google::api::expr::runtime::BindProtoToActivation; diff --git a/codelab/solutions/exercise4.cc b/codelab/solutions/exercise4.cc index 4caf23322..c07bc3413 100644 --- a/codelab/solutions/exercise4.cc +++ b/codelab/solutions/exercise4.cc @@ -16,7 +16,7 @@ #include #include -#include "google/api/expr/v1alpha1/checked.pb.h" +#include "cel/expr/checked.pb.h" #include "google/protobuf/text_format.h" #include "absl/status/status.h" #include "absl/status/statusor.h" diff --git a/common/BUILD b/common/BUILD index 11c60e5e2..0544969b2 100644 --- a/common/BUILD +++ b/common/BUILD @@ -136,7 +136,7 @@ cc_test( "//extensions/protobuf:ast_converters", "//internal:testing", "//parser", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -238,7 +238,7 @@ cc_library( deps = [ "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) @@ -730,7 +730,7 @@ cc_test( "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/time", "@com_google_absl//absl/types:optional", - "@com_google_cel_spec//proto/test/v1/proto3:test_all_types_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", "@com_google_protobuf//:protobuf", ], ) diff --git a/common/ast_rewrite_test.cc b/common/ast_rewrite_test.cc index 2c2e45455..ed4e3eabf 100644 --- a/common/ast_rewrite_test.cc +++ b/common/ast_rewrite_test.cc @@ -18,7 +18,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "base/ast_internal/ast_impl.h" #include "common/ast.h" #include "common/ast_visitor.h" @@ -536,7 +536,7 @@ TEST(AstRewrite, SelectRewriteExample) { RewriterExample example; ASSERT_TRUE(AstRewrite(ast_impl.root_expr(), example)); - google::api::expr::v1alpha1::Expr expected_expr; + cel::expr::Expr expected_expr; google::protobuf::TextFormat::ParseFromString( R"pb( id: 3 @@ -588,7 +588,7 @@ TEST(AstRewrite, PreAndPostVisitExpample) { AstImpl& ast_impl = AstImpl::CastFromPublicAst(*ast); ASSERT_TRUE(AstRewrite(ast_impl.root_expr(), visitor)); - google::api::expr::v1alpha1::Expr expected_expr; + cel::expr::Expr expected_expr; google::protobuf::TextFormat::ParseFromString( R"pb( id: 1 diff --git a/common/operators.cc b/common/operators.cc index de3b3a082..9c469da2c 100644 --- a/common/operators.cc +++ b/common/operators.cc @@ -217,7 +217,7 @@ absl::optional ReverseLookupOperator(const std::string& op) { } bool IsOperatorSamePrecedence(const std::string& op, - const google::api::expr::v1alpha1::Expr& expr) { + const cel::expr::Expr& expr) { if (!expr.has_call_expr()) { return false; } @@ -225,7 +225,7 @@ bool IsOperatorSamePrecedence(const std::string& op, } bool IsOperatorLowerPrecedence(const std::string& op, - const google::api::expr::v1alpha1::Expr& expr) { + const cel::expr::Expr& expr) { if (!expr.has_call_expr()) { return false; } diff --git a/common/operators.h b/common/operators.h index b12b0a46f..dcafce2dd 100644 --- a/common/operators.h +++ b/common/operators.h @@ -4,7 +4,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" @@ -61,10 +61,10 @@ absl::optional ReverseLookupOperator(const std::string& op); // returns true if op has a lower precedence than the one expressed in expr bool IsOperatorLowerPrecedence(const std::string& op, - const google::api::expr::v1alpha1::Expr& expr); + const cel::expr::Expr& expr); // returns true if op has the same precedence as the one expressed in expr bool IsOperatorSamePrecedence(const std::string& op, - const google::api::expr::v1alpha1::Expr& expr); + const cel::expr::Expr& expr); // return true if operator is left recursive, i.e., neither && nor ||. bool IsOperatorLeftRecursive(const std::string& op); diff --git a/common/type_test.cc b/common/type_test.cc index 024d8b1f7..119234fdc 100644 --- a/common/type_test.cc +++ b/common/type_test.cc @@ -39,9 +39,9 @@ TEST(Type, Enum) { EXPECT_EQ( Type::Enum( ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindEnumTypeByName( - "google.api.expr.test.v1.proto3.TestAllTypes.NestedEnum"))), + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum"))), EnumType(ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindEnumTypeByName( - "google.api.expr.test.v1.proto3.TestAllTypes.NestedEnum")))); + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum")))); EXPECT_EQ(Type::Enum( ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindEnumTypeByName( "google.protobuf.NullValue"))), @@ -52,7 +52,7 @@ TEST(Type, Field) { google::protobuf::Arena arena; const auto* descriptor = ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindMessageTypeByName( - "google.api.expr.test.v1.proto3.TestAllTypes")); + "cel.expr.conformance.proto3.TestAllTypes")); EXPECT_EQ( Type::Field(ABSL_DIE_IF_NULL(descriptor->FindFieldByName("single_bool"))), BoolType()); @@ -150,7 +150,7 @@ TEST(Type, Field) { Type::Field( ABSL_DIE_IF_NULL(descriptor->FindFieldByName("standalone_enum"))), EnumType(ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindEnumTypeByName( - "google.api.expr.test.v1.proto3.TestAllTypes.NestedEnum")))); + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum")))); EXPECT_EQ(Type::Field(ABSL_DIE_IF_NULL( descriptor->FindFieldByName("repeated_int32"))), ListType(&arena, IntType())); @@ -183,7 +183,7 @@ TEST(Type, Kind) { EXPECT_EQ( Type(EnumType( ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindEnumTypeByName( - "google.api.expr.test.v1.proto3.TestAllTypes.NestedEnum")))) + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum")))) .kind(), EnumType::kKind); @@ -202,12 +202,12 @@ TEST(Type, Kind) { EXPECT_EQ(Type(MessageType(ABSL_DIE_IF_NULL( GetTestingDescriptorPool()->FindMessageTypeByName( - "google.api.expr.test.v1.proto3.TestAllTypes")))) + "cel.expr.conformance.proto3.TestAllTypes")))) .kind(), MessageType::kKind); EXPECT_EQ(Type(MessageType(ABSL_DIE_IF_NULL( GetTestingDescriptorPool()->FindMessageTypeByName( - "google.api.expr.test.v1.proto3.TestAllTypes")))) + "cel.expr.conformance.proto3.TestAllTypes")))) .kind(), MessageType::kKind); @@ -252,7 +252,7 @@ TEST(Type, GetParameters) { EXPECT_THAT( Type(EnumType( ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindEnumTypeByName( - "google.api.expr.test.v1.proto3.TestAllTypes.NestedEnum")))) + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum")))) .GetParameters(), IsEmpty()); @@ -274,7 +274,7 @@ TEST(Type, GetParameters) { EXPECT_THAT(Type(MessageType(ABSL_DIE_IF_NULL( GetTestingDescriptorPool()->FindMessageTypeByName( - "google.api.expr.test.v1.proto3.TestAllTypes")))) + "cel.expr.conformance.proto3.TestAllTypes")))) .GetParameters(), IsEmpty()); @@ -322,7 +322,7 @@ TEST(Type, Is) { EXPECT_TRUE( Type(EnumType( ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindEnumTypeByName( - "google.api.expr.test.v1.proto3.TestAllTypes.NestedEnum")))) + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum")))) .Is()); EXPECT_TRUE(Type(ErrorType()).Is()); @@ -340,11 +340,11 @@ TEST(Type, Is) { EXPECT_TRUE(Type(MessageType(ABSL_DIE_IF_NULL( GetTestingDescriptorPool()->FindMessageTypeByName( - "google.api.expr.test.v1.proto3.TestAllTypes")))) + "cel.expr.conformance.proto3.TestAllTypes")))) .IsStruct()); EXPECT_TRUE(Type(MessageType(ABSL_DIE_IF_NULL( GetTestingDescriptorPool()->FindMessageTypeByName( - "google.api.expr.test.v1.proto3.TestAllTypes")))) + "cel.expr.conformance.proto3.TestAllTypes")))) .IsMessage()); EXPECT_TRUE(Type(NullType()).Is()); @@ -399,7 +399,7 @@ TEST(Type, As) { EXPECT_THAT( Type(EnumType( ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindEnumTypeByName( - "google.api.expr.test.v1.proto3.TestAllTypes.NestedEnum")))) + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum")))) .As(), Optional(An())); @@ -418,12 +418,12 @@ TEST(Type, As) { EXPECT_THAT(Type(MessageType(ABSL_DIE_IF_NULL( GetTestingDescriptorPool()->FindMessageTypeByName( - "google.api.expr.test.v1.proto3.TestAllTypes")))) + "cel.expr.conformance.proto3.TestAllTypes")))) .As(), Optional(An())); EXPECT_THAT(Type(MessageType(ABSL_DIE_IF_NULL( GetTestingDescriptorPool()->FindMessageTypeByName( - "google.api.expr.test.v1.proto3.TestAllTypes")))) + "cel.expr.conformance.proto3.TestAllTypes")))) .As(), Optional(An())); @@ -494,7 +494,7 @@ TEST(Type, Get) { EXPECT_THAT( DoGet(Type(EnumType( ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindEnumTypeByName( - "google.api.expr.test.v1.proto3.TestAllTypes.NestedEnum"))))), + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum"))))), An()); EXPECT_THAT(DoGet(Type(ErrorType())), An()); @@ -515,11 +515,11 @@ TEST(Type, Get) { EXPECT_THAT(DoGet(Type(MessageType(ABSL_DIE_IF_NULL( GetTestingDescriptorPool()->FindMessageTypeByName( - "google.api.expr.test.v1.proto3.TestAllTypes"))))), + "cel.expr.conformance.proto3.TestAllTypes"))))), An()); EXPECT_THAT(DoGet(Type(MessageType(ABSL_DIE_IF_NULL( GetTestingDescriptorPool()->FindMessageTypeByName( - "google.api.expr.test.v1.proto3.TestAllTypes"))))), + "cel.expr.conformance.proto3.TestAllTypes"))))), An()); EXPECT_THAT(DoGet(Type(NullType())), An()); @@ -585,37 +585,37 @@ TEST(Type, VerifyTypeImplementsAbslHashCorrectly) { EXPECT_EQ( absl::HashOf(Type::Field( ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindMessageTypeByName( - "google.api.expr.test.v1.proto3.TestAllTypes")) + "cel.expr.conformance.proto3.TestAllTypes")) ->FindFieldByName("repeated_int64"))), absl::HashOf(Type(ListType(&arena, IntType())))); EXPECT_EQ(Type::Field(ABSL_DIE_IF_NULL( GetTestingDescriptorPool()->FindMessageTypeByName( - "google.api.expr.test.v1.proto3.TestAllTypes")) + "cel.expr.conformance.proto3.TestAllTypes")) ->FindFieldByName("repeated_int64")), Type(ListType(&arena, IntType()))); EXPECT_EQ( absl::HashOf(Type::Field( ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindMessageTypeByName( - "google.api.expr.test.v1.proto3.TestAllTypes")) + "cel.expr.conformance.proto3.TestAllTypes")) ->FindFieldByName("map_int64_int64"))), absl::HashOf(Type(MapType(&arena, IntType(), IntType())))); EXPECT_EQ(Type::Field(ABSL_DIE_IF_NULL( GetTestingDescriptorPool()->FindMessageTypeByName( - "google.api.expr.test.v1.proto3.TestAllTypes")) + "cel.expr.conformance.proto3.TestAllTypes")) ->FindFieldByName("map_int64_int64")), Type(MapType(&arena, IntType(), IntType()))); EXPECT_EQ(absl::HashOf(Type(MessageType(ABSL_DIE_IF_NULL( GetTestingDescriptorPool()->FindMessageTypeByName( - "google.api.expr.test.v1.proto3.TestAllTypes"))))), + "cel.expr.conformance.proto3.TestAllTypes"))))), absl::HashOf(Type(StructType(common_internal::MakeBasicStructType( - "google.api.expr.test.v1.proto3.TestAllTypes"))))); + "cel.expr.conformance.proto3.TestAllTypes"))))); EXPECT_EQ(Type(MessageType(ABSL_DIE_IF_NULL( GetTestingDescriptorPool()->FindMessageTypeByName( - "google.api.expr.test.v1.proto3.TestAllTypes")))), + "cel.expr.conformance.proto3.TestAllTypes")))), Type(StructType(common_internal::MakeBasicStructType( - "google.api.expr.test.v1.proto3.TestAllTypes")))); + "cel.expr.conformance.proto3.TestAllTypes")))); } TEST(Type, Unwrap) { diff --git a/common/types/type_pool_test.cc b/common/types/type_pool_test.cc index 2f36121be..4d32113d0 100644 --- a/common/types/type_pool_test.cc +++ b/common/types/type_pool_test.cc @@ -31,7 +31,7 @@ TEST(TypePool, MakeStructType) { EXPECT_EQ(type_pool.MakeStructType("foo.Bar"), MakeBasicStructType("foo.Bar")); EXPECT_TRUE( - type_pool.MakeStructType("google.api.expr.test.v1.proto3.TestAllTypes") + type_pool.MakeStructType("cel.expr.conformance.proto3.TestAllTypes") .IsMessage()); EXPECT_DEBUG_DEATH( static_cast(type_pool.MakeStructType("google.protobuf.BoolValue")), diff --git a/common/value_test.cc b/common/value_test.cc index 090f71357..7ef8c006a 100644 --- a/common/value_test.cc +++ b/common/value_test.cc @@ -30,7 +30,7 @@ #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "internal/testing_message_factory.h" -#include "proto/test/v1/proto3/test_all_types.pb.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/generated_enum_reflection.h" @@ -48,7 +48,7 @@ using ::testing::Eq; using ::testing::NotNull; using ::testing::Optional; -using TestAllTypesProto3 = ::google::api::expr::test::v1::proto3::TestAllTypes; +using TestAllTypesProto3 = ::cel::expr::conformance::proto3::TestAllTypes; TEST(Value, KindDebugDeath) { Value value; diff --git a/common/values/message_value_test.cc b/common/values/message_value_test.cc index bbd49421f..08cfbb083 100644 --- a/common/values/message_value_test.cc +++ b/common/values/message_value_test.cc @@ -31,7 +31,7 @@ #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "internal/testing_message_factory.h" -#include "proto/test/v1/proto3/test_all_types.pb.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" @@ -48,7 +48,7 @@ using ::testing::Optional; using ::testing::PrintToStringParamName; using ::testing::TestWithParam; -using TestAllTypesProto3 = ::google::api::expr::test::v1::proto3::TestAllTypes; +using TestAllTypesProto3 = ::cel::expr::conformance::proto3::TestAllTypes; class MessageValueTest : public TestWithParam { public: @@ -179,7 +179,7 @@ TEST_P(MessageValueTest, GetTypeName) { MessageValue value( ParsedMessageValue(DynamicParseTextProto( allocator(), R"pb()pb", descriptor_pool(), message_factory()))); - EXPECT_EQ(value.GetTypeName(), "google.api.expr.test.v1.proto3.TestAllTypes"); + EXPECT_EQ(value.GetTypeName(), "cel.expr.conformance.proto3.TestAllTypes"); } TEST_P(MessageValueTest, GetRuntimeType) { diff --git a/common/values/parsed_json_list_value_test.cc b/common/values/parsed_json_list_value_test.cc index e50793b5e..40b05fde7 100644 --- a/common/values/parsed_json_list_value_test.cc +++ b/common/values/parsed_json_list_value_test.cc @@ -36,7 +36,7 @@ #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "internal/testing_message_factory.h" -#include "proto/test/v1/proto3/test_all_types.pb.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" @@ -58,7 +58,7 @@ using ::testing::PrintToStringParamName; using ::testing::TestWithParam; using ::testing::VariantWith; -using TestAllTypesProto3 = ::google::api::expr::test::v1::proto3::TestAllTypes; +using TestAllTypesProto3 = ::cel::expr::conformance::proto3::TestAllTypes; class ParsedJsonListValueTest : public TestWithParam { public: diff --git a/common/values/parsed_json_map_value_test.cc b/common/values/parsed_json_map_value_test.cc index 24af12d3d..d4ebbd686 100644 --- a/common/values/parsed_json_map_value_test.cc +++ b/common/values/parsed_json_map_value_test.cc @@ -36,7 +36,7 @@ #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "internal/testing_message_factory.h" -#include "proto/test/v1/proto3/test_all_types.pb.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" @@ -63,7 +63,7 @@ using ::testing::TestWithParam; using ::testing::UnorderedElementsAre; using ::testing::VariantWith; -using TestAllTypesProto3 = ::google::api::expr::test::v1::proto3::TestAllTypes; +using TestAllTypesProto3 = ::cel::expr::conformance::proto3::TestAllTypes; class ParsedJsonMapValueTest : public TestWithParam { public: diff --git a/common/values/parsed_json_value_test.cc b/common/values/parsed_json_value_test.cc index ff0193835..256a0b659 100644 --- a/common/values/parsed_json_value_test.cc +++ b/common/values/parsed_json_value_test.cc @@ -28,7 +28,7 @@ #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "internal/testing_message_factory.h" -#include "proto/test/v1/proto3/test_all_types.pb.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" @@ -52,7 +52,7 @@ using ::testing::PrintToStringParamName; using ::testing::TestWithParam; using ::testing::UnorderedElementsAre; -using TestAllTypesProto3 = ::google::api::expr::test::v1::proto3::TestAllTypes; +using TestAllTypesProto3 = ::cel::expr::conformance::proto3::TestAllTypes; class ParsedJsonValueTest : public TestWithParam { public: diff --git a/common/values/parsed_map_field_value_test.cc b/common/values/parsed_map_field_value_test.cc index e17d2ac59..a90f782e1 100644 --- a/common/values/parsed_map_field_value_test.cc +++ b/common/values/parsed_map_field_value_test.cc @@ -39,7 +39,7 @@ #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "internal/testing_message_factory.h" -#include "proto/test/v1/proto3/test_all_types.pb.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" @@ -70,7 +70,7 @@ using ::testing::PrintToStringParamName; using ::testing::TestWithParam; using ::testing::VariantWith; -using TestAllTypesProto3 = ::google::api::expr::test::v1::proto3::TestAllTypes; +using TestAllTypesProto3 = ::cel::expr::conformance::proto3::TestAllTypes; class ParsedMapFieldValueTest : public TestWithParam { public: diff --git a/common/values/parsed_message_value_test.cc b/common/values/parsed_message_value_test.cc index 1036ccd00..d2840de22 100644 --- a/common/values/parsed_message_value_test.cc +++ b/common/values/parsed_message_value_test.cc @@ -30,7 +30,7 @@ #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "internal/testing_message_factory.h" -#include "proto/test/v1/proto3/test_all_types.pb.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" @@ -50,7 +50,7 @@ using ::testing::PrintToStringParamName; using ::testing::TestWithParam; using ::testing::VariantWith; -using TestAllTypesProto3 = ::google::api::expr::test::v1::proto3::TestAllTypes; +using TestAllTypesProto3 = ::cel::expr::conformance::proto3::TestAllTypes; class ParsedMessageValueTest : public TestWithParam { public: @@ -122,7 +122,7 @@ TEST_P(ParsedMessageValueTest, Kind) { TEST_P(ParsedMessageValueTest, GetTypeName) { ParsedMessageValue value = MakeParsedMessage(R"pb()pb"); - EXPECT_EQ(value.GetTypeName(), "google.api.expr.test.v1.proto3.TestAllTypes"); + EXPECT_EQ(value.GetTypeName(), "cel.expr.conformance.proto3.TestAllTypes"); } TEST_P(ParsedMessageValueTest, GetRuntimeType) { diff --git a/common/values/parsed_repeated_field_value_test.cc b/common/values/parsed_repeated_field_value_test.cc index 4bcc84aa5..05439f131 100644 --- a/common/values/parsed_repeated_field_value_test.cc +++ b/common/values/parsed_repeated_field_value_test.cc @@ -38,7 +38,7 @@ #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "internal/testing_message_factory.h" -#include "proto/test/v1/proto3/test_all_types.pb.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" @@ -66,7 +66,7 @@ using ::testing::PrintToStringParamName; using ::testing::TestWithParam; using ::testing::VariantWith; -using TestAllTypesProto3 = ::google::api::expr::test::v1::proto3::TestAllTypes; +using TestAllTypesProto3 = ::cel::expr::conformance::proto3::TestAllTypes; class ParsedRepeatedFieldValueTest : public TestWithParam { public: diff --git a/common/values/struct_value_test.cc b/common/values/struct_value_test.cc index ab485cb6d..e0667e085 100644 --- a/common/values/struct_value_test.cc +++ b/common/values/struct_value_test.cc @@ -18,7 +18,7 @@ #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "internal/testing_message_factory.h" -#include "proto/test/v1/proto3/test_all_types.pb.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/arena.h" namespace cel { @@ -30,7 +30,7 @@ using ::cel::internal::GetTestingMessageFactory; using ::testing::An; using ::testing::Optional; -using TestAllTypesProto3 = ::google::api::expr::test::v1::proto3::TestAllTypes; +using TestAllTypesProto3 = ::cel::expr::conformance::proto3::TestAllTypes; TEST(StructValue, Is) { EXPECT_TRUE(StructValue(ParsedMessageValue()).Is()); diff --git a/conformance/BUILD b/conformance/BUILD index 7aebd2bc8..aca4c2795 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -30,12 +30,17 @@ cc_library( "//extensions/protobuf:value", "//internal:proto_time_encoding", "//internal:status_macros", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:cord", "@com_google_absl//absl/types:optional", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:value_cc_proto", "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", "@com_google_googleapis//google/api/expr/v1alpha1:value_cc_proto", "@com_google_protobuf//:protobuf", @@ -91,18 +96,18 @@ cc_library( "//runtime:reference_resolver", "//runtime:runtime_options", "//runtime:standard_runtime_builder_factory", - "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto", "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", "@com_google_googleapis//google/api/expr/conformance/v1alpha1:conformance_cc_proto", "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", "@com_google_googleapis//google/rpc:code_cc_proto", "@com_google_protobuf//:protobuf", ], @@ -119,9 +124,13 @@ cc_library( "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", - "@com_google_cel_spec//proto/test/v1:simple_cc_proto", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_cel_spec//proto/cel/expr:eval_cc_proto", + "@com_google_cel_spec//proto/cel/expr:value_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/test:simple_cc_proto", "@com_google_googleapis//google/api/expr/conformance/v1alpha1:conformance_cc_proto", "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", "@com_google_googleapis//google/rpc:code_cc_proto", diff --git a/conformance/run.cc b/conformance/run.cc index 325c82a7e..d810833e3 100644 --- a/conformance/run.cc +++ b/conformance/run.cc @@ -27,14 +27,18 @@ #include #include +#include "cel/expr/checked.pb.h" #include "google/api/expr/conformance/v1alpha1/conformance_service.pb.h" +#include "cel/expr/eval.pb.h" #include "google/api/expr/v1alpha1/checked.pb.h" // IWYU pragma: keep #include "google/api/expr/v1alpha1/eval.pb.h" #include "google/api/expr/v1alpha1/value.pb.h" +#include "cel/expr/value.pb.h" #include "google/rpc/code.pb.h" #include "absl/flags/flag.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" +#include "absl/strings/cord.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" @@ -42,7 +46,7 @@ #include "absl/types/span.h" #include "conformance/service.h" #include "internal/testing.h" -#include "proto/test/v1/simple.pb.h" +#include "cel/expr/conformance/test/simple.pb.h" #include "google/protobuf/io/zero_copy_stream_impl.h" #include "google/protobuf/message.h" #include "google/protobuf/text_format.h" @@ -67,14 +71,14 @@ namespace { using ::testing::IsEmpty; +using cel::expr::conformance::test::SimpleTest; +using cel::expr::conformance::test::SimpleTestFile; using google::api::expr::conformance::v1alpha1::CheckRequest; using google::api::expr::conformance::v1alpha1::CheckResponse; using google::api::expr::conformance::v1alpha1::EvalRequest; using google::api::expr::conformance::v1alpha1::EvalResponse; using google::api::expr::conformance::v1alpha1::ParseRequest; using google::api::expr::conformance::v1alpha1::ParseResponse; -using google::api::expr::test::v1::SimpleTest; -using google::api::expr::test::v1::SimpleTestFile; using google::protobuf::TextFormat; using google::protobuf::util::DefaultFieldComparator; using google::protobuf::util::MessageDifferencer; @@ -102,8 +106,7 @@ MATCHER_P(MatchesConformanceValue, expected, "") { auto* differencer = new MessageDifferencer(); differencer->set_message_field_comparison(MessageDifferencer::EQUIVALENT); differencer->set_field_comparator(kFieldComparator); - const auto* descriptor = - google::api::expr::v1alpha1::MapValue::descriptor(); + const auto* descriptor = cel::expr::MapValue::descriptor(); const auto* entries_field = descriptor->FindFieldByName("entries"); const auto* key_field = entries_field->message_type()->FindFieldByName("key"); @@ -111,10 +114,10 @@ MATCHER_P(MatchesConformanceValue, expected, "") { return differencer; }(); - const google::api::expr::v1alpha1::ExprValue& got = arg; - const google::api::expr::v1alpha1::Value& want = expected; + const cel::expr::ExprValue& got = arg; + const cel::expr::Value& want = expected; - google::api::expr::v1alpha1::ExprValue test_value; + cel::expr::ExprValue test_value; (*test_value.mutable_value()) = want; if (kDifferencer->Compare(got, test_value)) { @@ -172,7 +175,12 @@ class ConformanceTest : public testing::Test { eval_request.set_container(test_.container()); } if (!test_.bindings().empty()) { - *eval_request.mutable_bindings() = test_.bindings(); + for (const auto& binding : test_.bindings()) { + absl::Cord serialized; + ABSL_CHECK(binding.second.SerializePartialToCord(&serialized)); + ABSL_CHECK((*eval_request.mutable_bindings())[binding.first] + .ParsePartialFromCord(serialized)); + } } if (absl::GetFlag(FLAGS_skip_check) || test_.disable_check()) { @@ -183,7 +191,12 @@ class ConformanceTest : public testing::Test { check_request.set_allocated_parsed_expr( parse_response.release_parsed_expr()); check_request.set_container(test_.container()); - (*check_request.mutable_type_env()) = test_.type_env(); + for (const auto& type_env : test_.type_env()) { + absl::Cord serialized; + ABSL_CHECK(type_env.SerializePartialToCord(&serialized)); + ABSL_CHECK( + check_request.add_type_env()->ParsePartialFromCord(serialized)); + } CheckResponse check_response; service_->Check(check_request, check_response); ASSERT_THAT(check_response.issues(), IsEmpty()) << absl::StrCat( @@ -202,9 +215,11 @@ class ConformanceTest : public testing::Test { ASSERT_TRUE(eval_response.has_result()) << eval_response; switch (test_.result_matcher_case()) { case SimpleTest::kValue: { - google::api::expr::v1alpha1::ExprValue test_value; - EXPECT_THAT(eval_response.result(), - MatchesConformanceValue(test_.value())); + absl::Cord serialized; + ABSL_CHECK(eval_response.result().SerializePartialToCord(&serialized)); + cel::expr::ExprValue test_value; + ABSL_CHECK(test_value.ParsePartialFromCord(serialized)); + EXPECT_THAT(test_value, MatchesConformanceValue(test_.value())); break; } case SimpleTest::kEvalError: diff --git a/conformance/service.cc b/conformance/service.cc index 803e80e35..a6d90c0b0 100644 --- a/conformance/service.cc +++ b/conformance/service.cc @@ -21,7 +21,7 @@ #include #include "google/api/expr/conformance/v1alpha1/conformance_service.pb.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "google/api/expr/v1alpha1/checked.pb.h" #include "google/api/expr/v1alpha1/eval.pb.h" #include "google/api/expr/v1alpha1/syntax.pb.h" @@ -31,6 +31,7 @@ #include "google/protobuf/struct.pb.h" #include "google/protobuf/timestamp.pb.h" #include "google/rpc/code.pb.h" +#include "absl/log/absl_check.h" #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -83,19 +84,19 @@ #include "runtime/runtime.h" #include "runtime/runtime_options.h" #include "runtime/standard_runtime_builder_factory.h" -#include "proto/cel/expr/conformance/proto2/test_all_types.pb.h" -#include "proto/cel/expr/conformance/proto2/test_all_types_extensions.pb.h" -#include "proto/cel/expr/conformance/proto3/test_all_types.pb.h" +#include "cel/expr/conformance/proto2/test_all_types.pb.h" +#include "cel/expr/conformance/proto2/test_all_types_extensions.pb.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" - using ::cel::CreateStandardRuntimeBuilder; using ::cel::FunctionDecl; using ::cel::Runtime; using ::cel::RuntimeOptions; using ::cel::VariableDecl; +using ::cel::conformance_internal::ConvertWireCompatProto; using ::cel::conformance_internal::FromConformanceValue; using ::cel::conformance_internal::ToConformanceValue; using ::cel::extensions::ProtoMemoryManagerRef; @@ -218,7 +219,7 @@ using ConformanceServiceInterface = ::cel_conformance::ConformanceServiceInterface; // Return a normalized raw expr for evaluation. -google::api::expr::v1alpha1::Expr ExtractExpr( +cel::expr::Expr ExtractExpr( const conformance::v1alpha1::EvalRequest& request) { const v1alpha1::Expr* expr = nullptr; @@ -228,14 +229,16 @@ google::api::expr::v1alpha1::Expr ExtractExpr( } else if (request.has_checked_expr()) { expr = &request.checked_expr().expr(); } - google::api::expr::v1alpha1::Expr out; - (out).MergeFrom(*expr); + cel::expr::Expr out; + if (expr != nullptr) { + ABSL_CHECK(ConvertWireCompatProto(*expr, &out)); // Crash OK + } return out; } absl::StatusOr FromConformanceType( google::protobuf::Arena* arena, const google::api::expr::v1alpha1::Type& type) { - google::api::expr::v1alpha1::Type unversioned; + cel::expr::Type unversioned; if (!unversioned.MergeFromString(type.SerializeAsString())) { return absl::InternalError("Failed to convert from v1alpha1 type."); } @@ -260,7 +263,8 @@ absl::Status LegacyParse(const conformance::v1alpha1::ParseRequest& request, request.source_location())); CEL_ASSIGN_OR_RETURN(auto parsed_expr, parser::Parse(*source, macros, options)); - (*response.mutable_parsed_expr()).MergeFrom(parsed_expr); + ABSL_CHECK( // Crash OK + ConvertWireCompatProto(parsed_expr, response.mutable_parsed_expr())); return absl::OkStatus(); } @@ -273,30 +277,30 @@ class LegacyConformanceServiceImpl : public ConformanceServiceInterface { google::protobuf::LinkMessageReflection< cel::expr::conformance::proto3::TestAllTypes>(); google::protobuf::LinkMessageReflection< - cel::expr::conformance::google::protobuf::TestAllTypes>(); + cel::expr::conformance::proto2::TestAllTypes>(); google::protobuf::LinkMessageReflection< cel::expr::conformance::proto3::NestedTestAllTypes>(); google::protobuf::LinkMessageReflection< - cel::expr::conformance::google::protobuf::NestedTestAllTypes>(); - google::protobuf::LinkExtensionReflection(cel::expr::conformance::google::protobuf::int32_ext); - google::protobuf::LinkExtensionReflection(cel::expr::conformance::google::protobuf::nested_ext); + cel::expr::conformance::proto2::NestedTestAllTypes>(); + google::protobuf::LinkExtensionReflection(cel::expr::conformance::proto2::int32_ext); + google::protobuf::LinkExtensionReflection(cel::expr::conformance::proto2::nested_ext); google::protobuf::LinkExtensionReflection( - cel::expr::conformance::google::protobuf::test_all_types_ext); + cel::expr::conformance::proto2::test_all_types_ext); google::protobuf::LinkExtensionReflection( - cel::expr::conformance::google::protobuf::nested_enum_ext); + cel::expr::conformance::proto2::nested_enum_ext); google::protobuf::LinkExtensionReflection( - cel::expr::conformance::google::protobuf::repeated_test_all_types); + cel::expr::conformance::proto2::repeated_test_all_types); google::protobuf::LinkExtensionReflection( - cel::expr::conformance::google::protobuf::Proto2ExtensionScopedMessage:: + cel::expr::conformance::proto2::Proto2ExtensionScopedMessage:: int64_ext); google::protobuf::LinkExtensionReflection( - cel::expr::conformance::google::protobuf::Proto2ExtensionScopedMessage:: + cel::expr::conformance::proto2::Proto2ExtensionScopedMessage:: message_scoped_nested_ext); google::protobuf::LinkExtensionReflection( - cel::expr::conformance::google::protobuf::Proto2ExtensionScopedMessage:: + cel::expr::conformance::proto2::Proto2ExtensionScopedMessage:: nested_enum_ext); google::protobuf::LinkExtensionReflection( - cel::expr::conformance::google::protobuf::Proto2ExtensionScopedMessage:: + cel::expr::conformance::proto2::Proto2ExtensionScopedMessage:: message_scoped_repeated_test_all_types); InterpreterOptions options; @@ -320,11 +324,11 @@ class LegacyConformanceServiceImpl : public ConformanceServiceInterface { CreateCelExpressionBuilder(options); auto type_registry = builder->GetTypeRegistry(); type_registry->Register( - cel::expr::conformance::google::protobuf::GlobalEnum_descriptor()); + cel::expr::conformance::proto2::GlobalEnum_descriptor()); type_registry->Register( cel::expr::conformance::proto3::GlobalEnum_descriptor()); type_registry->Register( - cel::expr::conformance::google::protobuf::TestAllTypes::NestedEnum_descriptor()); + cel::expr::conformance::proto2::TestAllTypes::NestedEnum_descriptor()); type_registry->Register( cel::expr::conformance::proto3::TestAllTypes::NestedEnum_descriptor()); CEL_RETURN_IF_ERROR( @@ -361,8 +365,8 @@ class LegacyConformanceServiceImpl : public ConformanceServiceInterface { absl::Status Eval(const conformance::v1alpha1::EvalRequest& request, conformance::v1alpha1::EvalResponse& response) override { Arena arena; - google::api::expr::v1alpha1::SourceInfo source_info; - google::api::expr::v1alpha1::Expr expr = ExtractExpr(request); + cel::expr::SourceInfo source_info; + cel::expr::Expr expr = ExtractExpr(request); builder_->set_container(request.container()); auto cel_expression_status = builder_->CreateExpression(&expr, &source_info); @@ -375,8 +379,9 @@ class LegacyConformanceServiceImpl : public ConformanceServiceInterface { Activation activation; for (const auto& pair : request.bindings()) { - auto* import_value = Arena::Create(&arena); - (*import_value).MergeFrom(pair.second.value()); + auto* import_value = Arena::Create(&arena); + ABSL_CHECK(ConvertWireCompatProto(pair.second.value(), // Crash OK + import_value)); auto import_status = ValueToCelValue(*import_value, &arena); if (!import_status.ok()) { return absl::InternalError(import_status.status().ToString()); @@ -400,13 +405,14 @@ class LegacyConformanceServiceImpl : public ConformanceServiceInterface { ->add_errors() ->mutable_message() = std::string(result.ErrorOrDie()->message()); } else { - google::api::expr::v1alpha1::Value export_value; + cel::expr::Value export_value; auto export_status = CelValueToValue(result, &export_value); if (!export_status.ok()) { return absl::InternalError(export_status.ToString()); } auto* result_value = response.mutable_result()->mutable_value(); - (*result_value).MergeFrom(export_value); + ABSL_CHECK( // Crash OK + ConvertWireCompatProto(export_value, result_value)); } return absl::OkStatus(); } @@ -426,30 +432,30 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { google::protobuf::LinkMessageReflection< cel::expr::conformance::proto3::TestAllTypes>(); google::protobuf::LinkMessageReflection< - cel::expr::conformance::google::protobuf::TestAllTypes>(); + cel::expr::conformance::proto2::TestAllTypes>(); google::protobuf::LinkMessageReflection< cel::expr::conformance::proto3::NestedTestAllTypes>(); google::protobuf::LinkMessageReflection< - cel::expr::conformance::google::protobuf::NestedTestAllTypes>(); - google::protobuf::LinkExtensionReflection(cel::expr::conformance::google::protobuf::int32_ext); - google::protobuf::LinkExtensionReflection(cel::expr::conformance::google::protobuf::nested_ext); + cel::expr::conformance::proto2::NestedTestAllTypes>(); + google::protobuf::LinkExtensionReflection(cel::expr::conformance::proto2::int32_ext); + google::protobuf::LinkExtensionReflection(cel::expr::conformance::proto2::nested_ext); google::protobuf::LinkExtensionReflection( - cel::expr::conformance::google::protobuf::test_all_types_ext); + cel::expr::conformance::proto2::test_all_types_ext); google::protobuf::LinkExtensionReflection( - cel::expr::conformance::google::protobuf::nested_enum_ext); + cel::expr::conformance::proto2::nested_enum_ext); google::protobuf::LinkExtensionReflection( - cel::expr::conformance::google::protobuf::repeated_test_all_types); + cel::expr::conformance::proto2::repeated_test_all_types); google::protobuf::LinkExtensionReflection( - cel::expr::conformance::google::protobuf::Proto2ExtensionScopedMessage:: + cel::expr::conformance::proto2::Proto2ExtensionScopedMessage:: int64_ext); google::protobuf::LinkExtensionReflection( - cel::expr::conformance::google::protobuf::Proto2ExtensionScopedMessage:: + cel::expr::conformance::proto2::Proto2ExtensionScopedMessage:: message_scoped_nested_ext); google::protobuf::LinkExtensionReflection( - cel::expr::conformance::google::protobuf::Proto2ExtensionScopedMessage:: + cel::expr::conformance::proto2::Proto2ExtensionScopedMessage:: nested_enum_ext); google::protobuf::LinkExtensionReflection( - cel::expr::conformance::google::protobuf::Proto2ExtensionScopedMessage:: + cel::expr::conformance::proto2::Proto2ExtensionScopedMessage:: message_scoped_repeated_test_all_types); RuntimeOptions options; @@ -487,13 +493,13 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { std::make_unique()); CEL_RETURN_IF_ERROR(RegisterProtobufEnum( type_registry, - cel::expr::conformance::google::protobuf::GlobalEnum_descriptor())); + cel::expr::conformance::proto2::GlobalEnum_descriptor())); CEL_RETURN_IF_ERROR(RegisterProtobufEnum( type_registry, cel::expr::conformance::proto3::GlobalEnum_descriptor())); CEL_RETURN_IF_ERROR(RegisterProtobufEnum( type_registry, - cel::expr::conformance::google::protobuf::TestAllTypes::NestedEnum_descriptor())); + cel::expr::conformance::proto2::TestAllTypes::NestedEnum_descriptor())); CEL_RETURN_IF_ERROR(RegisterProtobufEnum( type_registry, cel::expr::conformance::proto3::TestAllTypes::NestedEnum_descriptor())); @@ -558,8 +564,9 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { cel::Activation activation; for (const auto& pair : request.bindings()) { - google::api::expr::v1alpha1::Value import_value; - (import_value).MergeFrom(pair.second.value()); + cel::expr::Value import_value; + ABSL_CHECK(ConvertWireCompatProto(pair.second.value(), // Crash OK + &import_value)); auto import_status = FromConformanceValue(value_factory.get(), import_value); if (!import_status.ok()) { @@ -595,7 +602,8 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { absl::StatusToStringMode::kWithEverything)); } auto* result_value = response.mutable_result()->mutable_value(); - (*result_value).MergeFrom(*export_status); + ABSL_CHECK( // Crash OK + ConvertWireCompatProto(*export_status, result_value)); } return absl::OkStatus(); } @@ -614,9 +622,10 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { static absl::Status DoCheck( google::protobuf::Arena* arena, const conformance::v1alpha1::CheckRequest& request, conformance::v1alpha1::CheckResponse& response) { - google::api::expr::v1alpha1::ParsedExpr parsed_expr; + cel::expr::ParsedExpr parsed_expr; - (parsed_expr).MergeFrom(request.parsed_expr()); + ABSL_CHECK(ConvertWireCompatProto(request.parsed_expr(), // Crash OK + &parsed_expr)); CEL_ASSIGN_OR_RETURN(std::unique_ptr ast, cel::extensions::CreateAstFromParsedExpr(parsed_expr)); @@ -688,9 +697,10 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { return absl::OkStatus(); } CEL_ASSIGN_OR_RETURN( - google::api::expr::v1alpha1::CheckedExpr pb_checked_ast, + cel::expr::CheckedExpr pb_checked_ast, cel::extensions::CreateCheckedExprFromAst(*validation_result.GetAst())); - *response.mutable_checked_expr() = std::move(pb_checked_ast); + ABSL_CHECK(ConvertWireCompatProto(pb_checked_ast, // Crash OK + response.mutable_checked_expr())); return absl::OkStatus(); } @@ -699,15 +709,17 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { const conformance::v1alpha1::EvalRequest& request) { std::unique_ptr ast; if (request.has_parsed_expr()) { - google::api::expr::v1alpha1::ParsedExpr unversioned; - (unversioned).MergeFrom(request.parsed_expr()); + cel::expr::ParsedExpr unversioned; + ABSL_CHECK(ConvertWireCompatProto(request.parsed_expr(), // Crash OK + &unversioned)); CEL_ASSIGN_OR_RETURN(ast, cel::extensions::CreateAstFromParsedExpr( std::move(unversioned))); } else if (request.has_checked_expr()) { - google::api::expr::v1alpha1::CheckedExpr unversioned; - (unversioned).MergeFrom(request.checked_expr()); + cel::expr::CheckedExpr unversioned; + ABSL_CHECK(ConvertWireCompatProto(request.checked_expr(), // Crash OK + &unversioned)); CEL_ASSIGN_OR_RETURN(ast, cel::extensions::CreateAstFromCheckedExpr( std::move(unversioned))); } diff --git a/conformance/value_conversion.cc b/conformance/value_conversion.cc index 8da26613f..9c8f8c361 100644 --- a/conformance/value_conversion.cc +++ b/conformance/value_conversion.cc @@ -17,7 +17,7 @@ #include #include -#include "google/api/expr/v1alpha1/value.pb.h" +#include "cel/expr/value.pb.h" #include "google/protobuf/any.pb.h" #include "google/protobuf/struct.pb.h" #include "google/protobuf/timestamp.pb.h" @@ -43,9 +43,9 @@ namespace cel::conformance_internal { namespace { -using ConformanceKind = google::api::expr::v1alpha1::Value::KindCase; -using ConformanceMapValue = google::api::expr::v1alpha1::MapValue; -using ConformanceListValue = google::api::expr::v1alpha1::ListValue; +using ConformanceKind = cel::expr::Value::KindCase; +using ConformanceMapValue = cel::expr::MapValue; +using ConformanceListValue = cel::expr::ListValue; std::string ToString(ConformanceKind kind_case) { switch (kind_case) { @@ -214,8 +214,8 @@ absl::optional MaybeWellKnownType(absl::string_view type_name) { } // namespace absl::StatusOr FromConformanceValue( - ValueManager& value_manager, const google::api::expr::v1alpha1::Value& value) { - google::protobuf::LinkMessageReflection(); + ValueManager& value_manager, const cel::expr::Value& value) { + google::protobuf::LinkMessageReflection(); switch (value.kind_case()) { case ConformanceKind::kBoolValue: return value_manager.CreateBoolValue(value.bool_value()); @@ -244,9 +244,9 @@ absl::StatusOr FromConformanceValue( } } -absl::StatusOr ToConformanceValue( +absl::StatusOr ToConformanceValue( ValueManager& value_manager, const Value& value) { - google::api::expr::v1alpha1::Value result; + cel::expr::Value result; switch (value->kind()) { case ValueKind::kBool: result.set_bool_value(value.GetBool().NativeValue()); @@ -312,70 +312,70 @@ absl::StatusOr ToConformanceValue( } absl::StatusOr FromConformanceType(google::protobuf::Arena* arena, - const google::api::expr::v1alpha1::Type& type) { + const cel::expr::Type& type) { switch (type.type_kind_case()) { - case google::api::expr::v1alpha1::Type::kNull: + case cel::expr::Type::kNull: return NullType(); - case google::api::expr::v1alpha1::Type::kDyn: + case cel::expr::Type::kDyn: return DynType(); - case google::api::expr::v1alpha1::Type::kPrimitive: { + case cel::expr::Type::kPrimitive: { switch (type.primitive()) { - case google::api::expr::v1alpha1::Type::BOOL: + case cel::expr::Type::BOOL: return BoolType(); - case google::api::expr::v1alpha1::Type::INT64: + case cel::expr::Type::INT64: return IntType(); - case google::api::expr::v1alpha1::Type::UINT64: + case cel::expr::Type::UINT64: return UintType(); - case google::api::expr::v1alpha1::Type::DOUBLE: + case cel::expr::Type::DOUBLE: return DoubleType(); - case google::api::expr::v1alpha1::Type::STRING: + case cel::expr::Type::STRING: return StringType(); - case google::api::expr::v1alpha1::Type::BYTES: + case cel::expr::Type::BYTES: return BytesType(); default: return absl::UnimplementedError(absl::StrCat( "FromConformanceType not supported ", type.primitive())); } } - case google::api::expr::v1alpha1::Type::kWrapper: { + case cel::expr::Type::kWrapper: { switch (type.wrapper()) { - case google::api::expr::v1alpha1::Type::BOOL: + case cel::expr::Type::BOOL: return BoolWrapperType(); - case google::api::expr::v1alpha1::Type::INT64: + case cel::expr::Type::INT64: return IntWrapperType(); - case google::api::expr::v1alpha1::Type::UINT64: + case cel::expr::Type::UINT64: return UintWrapperType(); - case google::api::expr::v1alpha1::Type::DOUBLE: + case cel::expr::Type::DOUBLE: return DoubleWrapperType(); - case google::api::expr::v1alpha1::Type::STRING: + case cel::expr::Type::STRING: return StringWrapperType(); - case google::api::expr::v1alpha1::Type::BYTES: + case cel::expr::Type::BYTES: return BytesWrapperType(); default: return absl::InvalidArgumentError(absl::StrCat( "FromConformanceType not supported ", type.wrapper())); } } - case google::api::expr::v1alpha1::Type::kWellKnown: { + case cel::expr::Type::kWellKnown: { switch (type.well_known()) { - case google::api::expr::v1alpha1::Type::DURATION: + case cel::expr::Type::DURATION: return DurationType(); - case google::api::expr::v1alpha1::Type::TIMESTAMP: + case cel::expr::Type::TIMESTAMP: return TimestampType(); - case google::api::expr::v1alpha1::Type::ANY: + case cel::expr::Type::ANY: return DynType(); default: return absl::InvalidArgumentError(absl::StrCat( "FromConformanceType not supported ", type.well_known())); } } - case google::api::expr::v1alpha1::Type::kListType: { + case cel::expr::Type::kListType: { CEL_ASSIGN_OR_RETURN( Type element_type, FromConformanceType(arena, type.list_type().elem_type())); return ListType(arena, element_type); } - case google::api::expr::v1alpha1::Type::kMapType: { + case cel::expr::Type::kMapType: { CEL_ASSIGN_OR_RETURN( auto key_type, FromConformanceType(arena, type.map_type().key_type())); @@ -384,10 +384,10 @@ absl::StatusOr FromConformanceType(google::protobuf::Arena* arena, FromConformanceType(arena, type.map_type().value_type())); return MapType(arena, key_type, value_type); } - case google::api::expr::v1alpha1::Type::kFunction: { + case cel::expr::Type::kFunction: { return absl::UnimplementedError("Function support not yet implemented"); } - case google::api::expr::v1alpha1::Type::kMessageType: { + case cel::expr::Type::kMessageType: { if (absl::optional wkt = MaybeWellKnownType(type.message_type()); wkt.has_value()) { return *wkt; @@ -401,20 +401,20 @@ absl::StatusOr FromConformanceType(google::protobuf::Arena* arena, } return MessageType(descriptor); } - case google::api::expr::v1alpha1::Type::kTypeParam: { + case cel::expr::Type::kTypeParam: { auto* param = google::protobuf::Arena::Create(arena, type.type_param()); return TypeParamType(*param); } - case google::api::expr::v1alpha1::Type::kType: { + case cel::expr::Type::kType: { CEL_ASSIGN_OR_RETURN(Type param_type, FromConformanceType(arena, type.type())); return TypeType(arena, param_type); } - case google::api::expr::v1alpha1::Type::kError: { + case cel::expr::Type::kError: { return absl::InvalidArgumentError("Error type not supported"); } - case google::api::expr::v1alpha1::Type::kAbstractType: { + case cel::expr::Type::kAbstractType: { std::vector parameters; for (const auto& param : type.abstract_type().parameter_types()) { CEL_ASSIGN_OR_RETURN(auto param_type, diff --git a/conformance/value_conversion.h b/conformance/value_conversion.h index c8a9bd962..dcf5ea8f4 100644 --- a/conformance/value_conversion.h +++ b/conformance/value_conversion.h @@ -16,24 +16,96 @@ #ifndef THIRD_PARTY_CEL_CPP_CONFORMANCE_VALUE_CONVERSION_H_ #define THIRD_PARTY_CEL_CPP_CONFORMANCE_VALUE_CONVERSION_H_ +#include "cel/expr/checked.pb.h" +#include "cel/expr/syntax.pb.h" #include "google/api/expr/v1alpha1/checked.pb.h" +#include "google/api/expr/v1alpha1/syntax.pb.h" #include "google/api/expr/v1alpha1/value.pb.h" +#include "cel/expr/value.pb.h" +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" #include "absl/status/statusor.h" +#include "absl/strings/cord.h" #include "common/type.h" #include "common/value.h" #include "common/value_manager.h" #include "google/protobuf/arena.h" +#include "google/protobuf/message_lite.h" namespace cel::conformance_internal { +ABSL_MUST_USE_RESULT +inline bool UnsafeConvertWireCompatProto( + const google::protobuf::MessageLite& src, absl::Nonnull dest) { + absl::Cord serialized; + return src.SerializePartialToCord(&serialized) && + dest->ParsePartialFromCord(serialized); +} + +ABSL_MUST_USE_RESULT +inline bool ConvertWireCompatProto( + const cel::expr::CheckedExpr& src, + absl::Nonnull dest) { + return UnsafeConvertWireCompatProto(src, dest); +} + +ABSL_MUST_USE_RESULT +inline bool ConvertWireCompatProto( + const google::api::expr::v1alpha1::CheckedExpr& src, + absl::Nonnull dest) { + return UnsafeConvertWireCompatProto(src, dest); +} + +ABSL_MUST_USE_RESULT +inline bool ConvertWireCompatProto( + const cel::expr::ParsedExpr& src, + absl::Nonnull dest) { + return UnsafeConvertWireCompatProto(src, dest); +} + +ABSL_MUST_USE_RESULT +inline bool ConvertWireCompatProto( + const google::api::expr::v1alpha1::ParsedExpr& src, + absl::Nonnull dest) { + return UnsafeConvertWireCompatProto(src, dest); +} + +ABSL_MUST_USE_RESULT +inline bool ConvertWireCompatProto( + const cel::expr::Expr& src, + absl::Nonnull dest) { + return UnsafeConvertWireCompatProto(src, dest); +} + +ABSL_MUST_USE_RESULT +inline bool ConvertWireCompatProto( + const google::api::expr::v1alpha1::Expr& src, + absl::Nonnull dest) { + return UnsafeConvertWireCompatProto(src, dest); +} + +ABSL_MUST_USE_RESULT +inline bool ConvertWireCompatProto( + const cel::expr::Value& src, + absl::Nonnull dest) { + return UnsafeConvertWireCompatProto(src, dest); +} + +ABSL_MUST_USE_RESULT +inline bool ConvertWireCompatProto( + const google::api::expr::v1alpha1::Value& src, + absl::Nonnull dest) { + return UnsafeConvertWireCompatProto(src, dest); +} + absl::StatusOr FromConformanceValue( - ValueManager& value_manager, const google::api::expr::v1alpha1::Value& value); + ValueManager& value_manager, const cel::expr::Value& value); -absl::StatusOr ToConformanceValue( +absl::StatusOr ToConformanceValue( ValueManager& value_manager, const Value& value); absl::StatusOr FromConformanceType(google::protobuf::Arena* arena, - const google::api::expr::v1alpha1::Type& type); + const cel::expr::Type& type); } // namespace cel::conformance_internal #endif // THIRD_PARTY_CEL_CPP_CONFORMANCE_VALUE_CONVERSION_H_ diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index 5974a27c9..396cca677 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -185,9 +185,9 @@ cc_test( "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@com_google_cel_spec//proto/test/v1/proto3:test_all_types_cc_proto", - "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -215,7 +215,7 @@ cc_test( "//runtime:runtime_options", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -244,8 +244,8 @@ cc_library( "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) @@ -280,9 +280,9 @@ cc_test( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", - "@com_google_cel_spec//proto/test/v1/proto3:test_all_types_cc_proto", - "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -351,7 +351,7 @@ cc_test( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -433,7 +433,7 @@ cc_test( "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -534,8 +534,8 @@ cc_test( "//runtime:runtime_issue", "//runtime/internal:issue_collector", "@com_google_absl//absl/status", - "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -596,7 +596,7 @@ cc_test( "//runtime:type_registry", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) diff --git a/eval/compiler/cel_expression_builder_flat_impl.cc b/eval/compiler/cel_expression_builder_flat_impl.cc index 0aa9fc4f1..63b601cc4 100644 --- a/eval/compiler/cel_expression_builder_flat_impl.cc +++ b/eval/compiler/cel_expression_builder_flat_impl.cc @@ -20,8 +20,8 @@ #include #include -#include "google/api/expr/v1alpha1/checked.pb.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/checked.pb.h" +#include "cel/expr/syntax.pb.h" #include "absl/base/macros.h" #include "absl/log/check.h" #include "absl/status/status.h" @@ -40,9 +40,9 @@ namespace google::api::expr::runtime { using ::cel::Ast; using ::cel::RuntimeIssue; -using ::google::api::expr::v1alpha1::CheckedExpr; -using ::google::api::expr::v1alpha1::Expr; // NOLINT: adjusted in OSS -using ::google::api::expr::v1alpha1::SourceInfo; +using ::cel::expr::CheckedExpr; +using ::cel::expr::Expr; // NOLINT: adjusted in OSS +using ::cel::expr::SourceInfo; absl::StatusOr> CelExpressionBuilderFlatImpl::CreateExpression( diff --git a/eval/compiler/cel_expression_builder_flat_impl.h b/eval/compiler/cel_expression_builder_flat_impl.h index 8c4581e54..98efc4b74 100644 --- a/eval/compiler/cel_expression_builder_flat_impl.h +++ b/eval/compiler/cel_expression_builder_flat_impl.h @@ -22,8 +22,8 @@ #include #include -#include "google/api/expr/v1alpha1/checked.pb.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/checked.pb.h" +#include "cel/expr/syntax.pb.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "base/ast.h" @@ -46,19 +46,19 @@ class CelExpressionBuilderFlatImpl : public CelExpressionBuilder { *GetTypeRegistry()) {} absl::StatusOr> CreateExpression( - const google::api::expr::v1alpha1::Expr* expr, - const google::api::expr::v1alpha1::SourceInfo* source_info) const override; + const cel::expr::Expr* expr, + const cel::expr::SourceInfo* source_info) const override; absl::StatusOr> CreateExpression( - const google::api::expr::v1alpha1::Expr* expr, - const google::api::expr::v1alpha1::SourceInfo* source_info, + const cel::expr::Expr* expr, + const cel::expr::SourceInfo* source_info, std::vector* warnings) const override; absl::StatusOr> CreateExpression( - const google::api::expr::v1alpha1::CheckedExpr* checked_expr) const override; + const cel::expr::CheckedExpr* checked_expr) const override; absl::StatusOr> CreateExpression( - const google::api::expr::v1alpha1::CheckedExpr* checked_expr, + const cel::expr::CheckedExpr* checked_expr, std::vector* warnings) const override; FlatExprBuilder& flat_expr_builder() { return flat_expr_builder_; } diff --git a/eval/compiler/cel_expression_builder_flat_impl_test.cc b/eval/compiler/cel_expression_builder_flat_impl_test.cc index 8a79e19a7..c70a04396 100644 --- a/eval/compiler/cel_expression_builder_flat_impl_test.cc +++ b/eval/compiler/cel_expression_builder_flat_impl_test.cc @@ -24,8 +24,8 @@ #include #include -#include "google/api/expr/v1alpha1/checked.pb.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/checked.pb.h" +#include "cel/expr/syntax.pb.h" #include "absl/algorithm/container.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -50,7 +50,7 @@ #include "parser/macro.h" #include "parser/parser.h" #include "runtime/runtime_options.h" -#include "proto/test/v1/proto3/test_all_types.pb.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" @@ -60,15 +60,15 @@ namespace google::api::expr::runtime { namespace { using ::absl_testing::StatusIs; -using ::google::api::expr::v1alpha1::CheckedExpr; -using ::google::api::expr::v1alpha1::Expr; -using ::google::api::expr::v1alpha1::ParsedExpr; -using ::google::api::expr::v1alpha1::SourceInfo; +using ::cel::expr::conformance::proto3::NestedTestAllTypes; +using ::cel::expr::conformance::proto3::TestAllTypes; +using ::cel::expr::CheckedExpr; +using ::cel::expr::Expr; +using ::cel::expr::ParsedExpr; +using ::cel::expr::SourceInfo; using ::google::api::expr::parser::Macro; using ::google::api::expr::parser::Parse; using ::google::api::expr::parser::ParseWithMacros; -using ::google::api::expr::test::v1::proto3::NestedTestAllTypes; -using ::google::api::expr::test::v1::proto3::TestAllTypes; using ::testing::_; using ::testing::Contains; using ::testing::HasSubstr; @@ -163,7 +163,7 @@ TEST_P(RecursivePlanTest, ParsedExprRecursiveImpl) { const RecursiveTestCase& test_case = GetParam(); ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, ParseWithBind(test_case.expr)); cel::RuntimeOptions options; - options.container = "google.api.expr.test.v1.proto3"; + options.container = "cel.expr.conformance.proto3"; google::protobuf::Arena arena; // Unbounded. options.max_recursion_depth = -1; @@ -190,7 +190,7 @@ TEST_P(RecursivePlanTest, ParsedExprRecursiveOptimizedImpl) { const RecursiveTestCase& test_case = GetParam(); ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, ParseWithBind(test_case.expr)); cel::RuntimeOptions options; - options.container = "google.api.expr.test.v1.proto3"; + options.container = "cel.expr.conformance.proto3"; google::protobuf::Arena arena; // Unbounded. options.max_recursion_depth = -1; @@ -224,7 +224,7 @@ TEST_P(RecursivePlanTest, ParsedExprRecursiveTraceSupport) { const RecursiveTestCase& test_case = GetParam(); ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, ParseWithBind(test_case.expr)); cel::RuntimeOptions options; - options.container = "google.api.expr.test.v1.proto3"; + options.container = "cel.expr.conformance.proto3"; google::protobuf::Arena arena; auto cb = [](int64_t id, const CelValue& value, google::protobuf::Arena* arena) { return absl::OkStatus(); @@ -257,7 +257,7 @@ TEST_P(RecursivePlanTest, Disabled) { const RecursiveTestCase& test_case = GetParam(); ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, ParseWithBind(test_case.expr)); cel::RuntimeOptions options; - options.container = "google.api.expr.test.v1.proto3"; + options.container = "cel.expr.conformance.proto3"; google::protobuf::Arena arena; // disabled. options.max_recursion_depth = 0; diff --git a/eval/compiler/constant_folding_test.cc b/eval/compiler/constant_folding_test.cc index b724795ad..7aafa7442 100644 --- a/eval/compiler/constant_folding_test.cc +++ b/eval/compiler/constant_folding_test.cc @@ -17,7 +17,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" @@ -58,7 +58,7 @@ using ::cel::ast_internal::AstImpl; using ::cel::ast_internal::Expr; using ::cel::extensions::ProtoMemoryManagerRef; using ::cel::runtime_internal::IssueCollector; -using ::google::api::expr::v1alpha1::ParsedExpr; +using ::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; using ::google::api::expr::runtime::CreateConstValueStep; using ::google::api::expr::runtime::CreateCreateListStep; diff --git a/eval/compiler/flat_expr_builder_comprehensions_test.cc b/eval/compiler/flat_expr_builder_comprehensions_test.cc index a3aa8ff29..4b9ff2b8c 100644 --- a/eval/compiler/flat_expr_builder_comprehensions_test.cc +++ b/eval/compiler/flat_expr_builder_comprehensions_test.cc @@ -16,7 +16,7 @@ #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "google/protobuf/field_mask.pb.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" @@ -43,8 +43,8 @@ namespace google::api::expr::runtime { namespace { using ::absl_testing::StatusIs; -using ::google::api::expr::v1alpha1::CheckedExpr; -using ::google::api::expr::v1alpha1::ParsedExpr; +using ::cel::expr::CheckedExpr; +using ::cel::expr::ParsedExpr; using ::testing::HasSubstr; class CelExpressionBuilderFlatImplComprehensionsTest diff --git a/eval/compiler/flat_expr_builder_short_circuiting_conformance_test.cc b/eval/compiler/flat_expr_builder_short_circuiting_conformance_test.cc index 1c3be14ab..b7bed3655 100644 --- a/eval/compiler/flat_expr_builder_short_circuiting_conformance_test.cc +++ b/eval/compiler/flat_expr_builder_short_circuiting_conformance_test.cc @@ -23,7 +23,7 @@ namespace google::api::expr::runtime { namespace { -using ::google::api::expr::v1alpha1::Expr; +using ::cel::expr::Expr; using ::testing::Eq; using ::testing::SizeIs; diff --git a/eval/compiler/flat_expr_builder_test.cc b/eval/compiler/flat_expr_builder_test.cc index bd25cea2d..488f81a8d 100644 --- a/eval/compiler/flat_expr_builder_test.cc +++ b/eval/compiler/flat_expr_builder_test.cc @@ -22,8 +22,8 @@ #include #include -#include "google/api/expr/v1alpha1/checked.pb.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/checked.pb.h" +#include "cel/expr/syntax.pb.h" #include "google/protobuf/field_mask.pb.h" #include "google/protobuf/descriptor.pb.h" #include "absl/container/flat_hash_map.h" @@ -62,7 +62,7 @@ #include "internal/testing.h" #include "parser/parser.h" #include "runtime/runtime_options.h" -#include "proto/test/v1/proto3/test_all_types.pb.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/dynamic_message.h" #include "google/protobuf/message.h" @@ -74,14 +74,14 @@ namespace { using ::absl_testing::StatusIs; using ::cel::Value; +using ::cel::expr::conformance::proto3::TestAllTypes; using ::cel::extensions::ProtoMemoryManagerRef; using ::cel::internal::test::EqualsProto; using ::cel::internal::test::ReadBinaryProtoFromFile; -using ::google::api::expr::v1alpha1::CheckedExpr; -using ::google::api::expr::v1alpha1::Expr; -using ::google::api::expr::v1alpha1::ParsedExpr; -using ::google::api::expr::v1alpha1::SourceInfo; -using ::google::api::expr::test::v1::proto3::TestAllTypes; +using ::cel::expr::CheckedExpr; +using ::cel::expr::Expr; +using ::cel::expr::ParsedExpr; +using ::cel::expr::SourceInfo; using ::testing::_; using ::testing::Eq; using ::testing::HasSubstr; @@ -1866,7 +1866,7 @@ TEST(FlatExprBuilderTest, AnyPackingList) { std::make_unique( google::protobuf::DescriptorPool::generated_pool(), google::protobuf::MessageFactory::generated_factory())); - builder.set_container("google.api.expr.test.v1.proto3"); + builder.set_container("cel.expr.conformance.proto3"); ASSERT_OK_AND_ASSIGN(auto expression, builder.CreateExpression(&parsed_expr.expr(), @@ -1901,7 +1901,7 @@ TEST(FlatExprBuilderTest, AnyPackingNestedNumbers) { std::make_unique( google::protobuf::DescriptorPool::generated_pool(), google::protobuf::MessageFactory::generated_factory())); - builder.set_container("google.api.expr.test.v1.proto3"); + builder.set_container("cel.expr.conformance.proto3"); ASSERT_OK_AND_ASSIGN(auto expression, builder.CreateExpression(&parsed_expr.expr(), @@ -1934,7 +1934,7 @@ TEST(FlatExprBuilderTest, AnyPackingInt) { std::make_unique( google::protobuf::DescriptorPool::generated_pool(), google::protobuf::MessageFactory::generated_factory())); - builder.set_container("google.api.expr.test.v1.proto3"); + builder.set_container("cel.expr.conformance.proto3"); ASSERT_OK_AND_ASSIGN(auto expression, builder.CreateExpression(&parsed_expr.expr(), @@ -1966,7 +1966,7 @@ TEST(FlatExprBuilderTest, AnyPackingMap) { std::make_unique( google::protobuf::DescriptorPool::generated_pool(), google::protobuf::MessageFactory::generated_factory())); - builder.set_container("google.api.expr.test.v1.proto3"); + builder.set_container("cel.expr.conformance.proto3"); ASSERT_OK_AND_ASSIGN(auto expression, builder.CreateExpression(&parsed_expr.expr(), diff --git a/eval/compiler/instrumentation_test.cc b/eval/compiler/instrumentation_test.cc index b429127f2..beb94fe2c 100644 --- a/eval/compiler/instrumentation_test.cc +++ b/eval/compiler/instrumentation_test.cc @@ -18,7 +18,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "base/ast_internal/ast_impl.h" @@ -45,7 +45,7 @@ namespace { using ::cel::IntValue; using ::cel::Value; -using ::google::api::expr::v1alpha1::ParsedExpr; +using ::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; using ::testing::ElementsAre; using ::testing::Pair; diff --git a/eval/compiler/qualified_reference_resolver_test.cc b/eval/compiler/qualified_reference_resolver_test.cc index 0ca81a87c..10dd91f59 100644 --- a/eval/compiler/qualified_reference_resolver_test.cc +++ b/eval/compiler/qualified_reference_resolver_test.cc @@ -18,7 +18,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" #include "absl/status/status.h" @@ -108,7 +108,7 @@ MATCHER_P(StatusCodeIs, x, "") { } std::unique_ptr ParseTestProto(const std::string& pb) { - google::api::expr::v1alpha1::Expr expr; + cel::expr::Expr expr; EXPECT_TRUE(google::protobuf::TextFormat::ParseFromString(pb, &expr)); return absl::WrapUnique(cel::internal::down_cast( cel::extensions::CreateAstFromParsedExpr(expr).value().release())); @@ -137,7 +137,7 @@ TEST(ResolveReferences, Basic) { auto result = ResolveReferences(registry, issues, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(true)); - google::api::expr::v1alpha1::Expr expected_expr; + cel::expr::Expr expected_expr; google::protobuf::TextFormat::ParseFromString(R"pb( id: 1 call_expr { @@ -194,7 +194,7 @@ TEST(ResolveReferences, NamespacedIdent) { auto result = ResolveReferences(registry, issues, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(true)); - google::api::expr::v1alpha1::Expr expected_expr; + cel::expr::Expr expected_expr; google::protobuf::TextFormat::ParseFromString( R"pb( id: 1 @@ -313,7 +313,7 @@ TEST(ResolveReferences, EnumConstReferenceUsed) { auto result = ResolveReferences(registry, issues, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(true)); - google::api::expr::v1alpha1::Expr expected_expr; + cel::expr::Expr expected_expr; google::protobuf::TextFormat::ParseFromString(R"pb( id: 1 call_expr { @@ -353,7 +353,7 @@ TEST(ResolveReferences, EnumConstReferenceUsedSelect) { auto result = ResolveReferences(registry, issues, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(true)); - google::api::expr::v1alpha1::Expr expected_expr; + cel::expr::Expr expected_expr; google::protobuf::TextFormat::ParseFromString(R"pb( id: 1 call_expr { @@ -392,7 +392,7 @@ TEST(ResolveReferences, ConstReferenceSkipped) { auto result = ResolveReferences(registry, issues, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(true)); - google::api::expr::v1alpha1::Expr expected_expr; + cel::expr::Expr expected_expr; google::protobuf::TextFormat::ParseFromString(R"pb( id: 1 call_expr { @@ -685,7 +685,7 @@ TEST(ResolveReferences, FunctionReferenceWithTargetToNamespacedFunction) { auto result = ResolveReferences(registry, issues, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(true)); - google::api::expr::v1alpha1::Expr expected_expr; + cel::expr::Expr expected_expr; google::protobuf::TextFormat::ParseFromString(R"pb( id: 1 call_expr { @@ -724,7 +724,7 @@ TEST(ResolveReferences, auto result = ResolveReferences(registry, issues, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(true)); - google::api::expr::v1alpha1::Expr expected_expr; + cel::expr::Expr expected_expr; google::protobuf::TextFormat::ParseFromString(R"pb( id: 1 call_expr { @@ -791,7 +791,7 @@ TEST(ResolveReferences, FunctionReferenceWithHasTargetNoChange) { ASSERT_THAT(result, IsOkAndHolds(false)); // The target is unchanged because it is a test_only select. - google::api::expr::v1alpha1::Expr expected_expr; + cel::expr::Expr expected_expr; google::protobuf::TextFormat::ParseFromString(kReceiverCallHasExtensionAndExpr, &expected_expr); EXPECT_EQ(expr_ast->root_expr(), @@ -888,7 +888,7 @@ TEST(ResolveReferences, EnumConstReferenceUsedInComprehension) { auto result = ResolveReferences(registry, issues, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(true)); - google::api::expr::v1alpha1::Expr expected_expr; + cel::expr::Expr expected_expr; google::protobuf::TextFormat::ParseFromString( R"pb( id: 17 @@ -995,7 +995,7 @@ TEST(ResolveReferences, ReferenceToId0Warns) { auto result = ResolveReferences(registry, issues, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(false)); - google::api::expr::v1alpha1::Expr expected_expr; + cel::expr::Expr expected_expr; google::protobuf::TextFormat::ParseFromString(R"pb( id: 0 select_expr { diff --git a/eval/compiler/regex_precompilation_optimization_test.cc b/eval/compiler/regex_precompilation_optimization_test.cc index dca6bdfe7..65d2d9058 100644 --- a/eval/compiler/regex_precompilation_optimization_test.cc +++ b/eval/compiler/regex_precompilation_optimization_test.cc @@ -19,8 +19,8 @@ #include #include -#include "google/api/expr/v1alpha1/checked.pb.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/checked.pb.h" +#include "cel/expr/syntax.pb.h" #include "absl/status/status.h" #include "base/ast_internal/ast_impl.h" #include "common/memory.h" @@ -49,7 +49,7 @@ using ::cel::runtime_internal::IssueCollector; using ::google::api::expr::parser::Parse; using ::testing::ElementsAre; -namespace exprpb = google::api::expr::v1alpha1; +namespace exprpb = cel::expr; class RegexPrecompilationExtensionTest : public testing::TestWithParam { public: diff --git a/eval/eval/BUILD b/eval/eval/BUILD index fce68475a..62c67c0e9 100644 --- a/eval/eval/BUILD +++ b/eval/eval/BUILD @@ -422,7 +422,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:optional", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) @@ -517,7 +517,7 @@ cc_test( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -541,7 +541,7 @@ cc_test( "//internal:testing", "//runtime:activation", "//runtime:runtime_options", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) @@ -602,7 +602,7 @@ cc_test( "//internal:testing", "//parser", "@com_google_absl//absl/status", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -623,8 +623,8 @@ cc_test( "//parser", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -773,8 +773,8 @@ cc_test( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_cel_spec//proto/test/v1/proto3:test_all_types_cc_proto", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -849,7 +849,7 @@ cc_test( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -877,7 +877,7 @@ cc_test( "//runtime:runtime_options", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -904,7 +904,7 @@ cc_test( "//eval/public:cel_attribute", "//eval/public:cel_value", "//internal:testing", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) @@ -1091,7 +1091,7 @@ cc_library( "//internal:status_macros", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", - "@com_google_googleapis//google/api/expr/v1alpha1:value_cc_proto", + "@com_google_cel_spec//proto/cel/expr:value_cc_proto", ], ) diff --git a/eval/eval/attribute_trail_test.cc b/eval/eval/attribute_trail_test.cc index 1ab889ed8..3143b9ed4 100644 --- a/eval/eval/attribute_trail_test.cc +++ b/eval/eval/attribute_trail_test.cc @@ -2,7 +2,7 @@ #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "eval/public/cel_attribute.h" #include "eval/public/cel_value.h" #include "internal/testing.h" diff --git a/eval/eval/comprehension_step_test.cc b/eval/eval/comprehension_step_test.cc index 8fb5cfc27..2fd513ee7 100644 --- a/eval/eval/comprehension_step_test.cc +++ b/eval/eval/comprehension_step_test.cc @@ -5,7 +5,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "google/protobuf/struct.pb.h" #include "absl/status/status.h" #include "absl/status/statusor.h" diff --git a/eval/eval/container_access_step_test.cc b/eval/eval/container_access_step_test.cc index 688907a66..232d0e469 100644 --- a/eval/eval/container_access_step_test.cc +++ b/eval/eval/container_access_step_test.cc @@ -6,7 +6,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "google/protobuf/struct.pb.h" #include "absl/status/status.h" #include "base/builtins.h" @@ -38,7 +38,7 @@ using ::absl_testing::StatusIs; using ::cel::TypeProvider; using ::cel::ast_internal::Expr; using ::cel::ast_internal::SourceInfo; -using ::google::api::expr::v1alpha1::ParsedExpr; +using ::cel::expr::ParsedExpr; using ::google::protobuf::Struct; using ::testing::_; using ::testing::AllOf; diff --git a/eval/eval/create_map_step_test.cc b/eval/eval/create_map_step_test.cc index c7c0e8493..44554aee4 100644 --- a/eval/eval/create_map_step_test.cc +++ b/eval/eval/create_map_step_test.cc @@ -20,7 +20,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "base/ast_internal/expr.h" diff --git a/eval/eval/create_struct_step_test.cc b/eval/eval/create_struct_step_test.cc index 7b56f2a23..ffcfb5faf 100644 --- a/eval/eval/create_struct_step_test.cc +++ b/eval/eval/create_struct_step_test.cc @@ -21,7 +21,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" diff --git a/eval/eval/evaluator_core_test.cc b/eval/eval/evaluator_core_test.cc index 1a5a7fd38..7b4404af1 100644 --- a/eval/eval/evaluator_core_test.cc +++ b/eval/eval/evaluator_core_test.cc @@ -4,7 +4,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "base/type_provider.h" #include "eval/compiler/cel_expression_builder_flat_impl.h" #include "eval/eval/cel_expression_flat_impl.h" @@ -23,7 +23,7 @@ using ::cel::IntValue; using ::cel::TypeProvider; using ::cel::extensions::ProtoMemoryManagerRef; using ::cel::interop_internal::CreateIntValue; -using ::google::api::expr::v1alpha1::Expr; +using ::cel::expr::Expr; using ::google::api::expr::runtime::RegisterBuiltinFunctions; using ::testing::_; using ::testing::Eq; @@ -116,7 +116,7 @@ class MockTraceCallback { TEST(EvaluatorCoreTest, TraceTest) { Expr expr; - google::api::expr::v1alpha1::SourceInfo source_info; + cel::expr::SourceInfo source_info; // 1 && [1,2,3].all(x, x > 0) diff --git a/eval/eval/jump_step.h b/eval/eval/jump_step.h index c46d3a15c..fe33d4628 100644 --- a/eval/eval/jump_step.h +++ b/eval/eval/jump_step.h @@ -17,7 +17,7 @@ #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "absl/status/statusor.h" #include "absl/types/optional.h" #include "eval/eval/evaluator_core.h" diff --git a/eval/eval/lazy_init_step.cc b/eval/eval/lazy_init_step.cc index a022d244f..73466f1ef 100644 --- a/eval/eval/lazy_init_step.cc +++ b/eval/eval/lazy_init_step.cc @@ -19,7 +19,7 @@ #include #include -#include "google/api/expr/v1alpha1/value.pb.h" +#include "cel/expr/value.pb.h" #include "absl/base/nullability.h" #include "absl/status/status.h" #include "common/value.h" diff --git a/eval/eval/regex_match_step_test.cc b/eval/eval/regex_match_step_test.cc index 3dfd793b2..367a8de25 100644 --- a/eval/eval/regex_match_step_test.cc +++ b/eval/eval/regex_match_step_test.cc @@ -14,8 +14,8 @@ #include "eval/eval/regex_match_step.h" -#include "google/api/expr/v1alpha1/checked.pb.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/checked.pb.h" +#include "cel/expr/syntax.pb.h" #include "google/protobuf/arena.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -30,8 +30,8 @@ namespace google::api::expr::runtime { namespace { using ::absl_testing::StatusIs; -using google::api::expr::v1alpha1::CheckedExpr; -using google::api::expr::v1alpha1::Reference; +using cel::expr::CheckedExpr; +using cel::expr::Reference; using ::testing::Eq; using ::testing::HasSubstr; diff --git a/eval/eval/select_step_test.cc b/eval/eval/select_step_test.cc index 48676f36b..3bb22fca8 100644 --- a/eval/eval/select_step_test.cc +++ b/eval/eval/select_step_test.cc @@ -4,7 +4,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "google/protobuf/wrappers.pb.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" @@ -43,7 +43,7 @@ #include "runtime/activation.h" #include "runtime/managed_value_factory.h" #include "runtime/runtime_options.h" -#include "proto/test/v1/proto3/test_all_types.pb.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" namespace google::api::expr::runtime { @@ -65,11 +65,11 @@ using ::cel::TypeProvider; using ::cel::UnknownValue; using ::cel::Value; using ::cel::ast_internal::Expr; +using ::cel::expr::conformance::proto3::TestAllTypes; using ::cel::extensions::ProtoMemoryManagerRef; using ::cel::extensions::ProtoMessageToValue; using ::cel::internal::test::EqualsProto; using ::cel::test::IntValueIs; -using ::google::api::expr::test::v1::proto3::TestAllTypes; using ::testing::_; using ::testing::Eq; using ::testing::HasSubstr; diff --git a/eval/public/BUILD b/eval/public/BUILD index cb0a556bd..c11f16fc1 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -119,7 +119,7 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:variant", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) @@ -317,7 +317,7 @@ cc_test( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto", "@com_google_protobuf//:protobuf", ], @@ -373,7 +373,7 @@ cc_test( "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto", "@com_google_protobuf//:protobuf", ], @@ -453,7 +453,7 @@ cc_test( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -492,8 +492,8 @@ cc_library( "//common:legacy_value", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) @@ -502,7 +502,7 @@ cc_library( srcs = ["source_position.cc"], hdrs = ["source_position.h"], deps = [ - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) @@ -513,7 +513,7 @@ cc_library( ], deps = [ ":source_position", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) @@ -524,7 +524,7 @@ cc_library( ], deps = [ ":ast_visitor", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) @@ -541,7 +541,7 @@ cc_library( ":source_position", "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/types:variant", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) @@ -717,7 +717,7 @@ cc_library( "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) @@ -733,7 +733,7 @@ cc_test( "//internal:testing", "//parser", "//testutil:util", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) @@ -868,7 +868,7 @@ cc_test( "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@com_google_absl//absl/types:optional", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) @@ -905,7 +905,7 @@ cc_test( deps = [ ":source_position", "//internal:testing", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) @@ -992,7 +992,7 @@ cc_test( ":unknown_function_result_set", ":unknown_set", "//internal:testing", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -1015,7 +1015,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:value_cc_proto", + "@com_google_cel_spec//proto/cel/expr:value_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -1079,7 +1079,7 @@ cc_test( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -1147,7 +1147,7 @@ cc_test( "//eval/public/containers:container_backed_list_impl", "//internal:testing", "@com_google_absl//absl/types:span", - "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", "@com_google_protobuf//:protobuf", ], ) diff --git a/eval/public/activation_test.cc b/eval/public/activation_test.cc index 6e228e188..f490f0ca8 100644 --- a/eval/public/activation_test.cc +++ b/eval/public/activation_test.cc @@ -22,7 +22,7 @@ namespace { using ::absl_testing::StatusIs; using ::cel::extensions::ProtoMemoryManager; -using ::google::api::expr::v1alpha1::Expr; +using ::cel::expr::Expr; using ::google::protobuf::Arena; using ::testing::ElementsAre; using ::testing::Eq; diff --git a/eval/public/ast_rewrite.cc b/eval/public/ast_rewrite.cc index 1d4f09393..3c210e607 100644 --- a/eval/public/ast_rewrite.cc +++ b/eval/public/ast_rewrite.cc @@ -17,7 +17,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "absl/log/absl_log.h" #include "absl/types/variant.h" #include "eval/public/ast_visitor.h" @@ -25,14 +25,14 @@ namespace google::api::expr::runtime { -using google::api::expr::v1alpha1::Expr; -using google::api::expr::v1alpha1::SourceInfo; -using Ident = google::api::expr::v1alpha1::Expr::Ident; -using Select = google::api::expr::v1alpha1::Expr::Select; -using Call = google::api::expr::v1alpha1::Expr::Call; -using CreateList = google::api::expr::v1alpha1::Expr::CreateList; -using CreateStruct = google::api::expr::v1alpha1::Expr::CreateStruct; -using Comprehension = google::api::expr::v1alpha1::Expr::Comprehension; +using cel::expr::Expr; +using cel::expr::SourceInfo; +using Ident = cel::expr::Expr::Ident; +using Select = cel::expr::Expr::Select; +using Call = cel::expr::Expr::Call; +using CreateList = cel::expr::Expr::CreateList; +using CreateStruct = cel::expr::Expr::CreateStruct; +using Comprehension = cel::expr::Expr::Comprehension; namespace { diff --git a/eval/public/ast_rewrite.h b/eval/public/ast_rewrite.h index b4519e7d0..791778c4f 100644 --- a/eval/public/ast_rewrite.h +++ b/eval/public/ast_rewrite.h @@ -15,7 +15,7 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_AST_REWRITE_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_AST_REWRITE_H_ -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "absl/types/span.h" #include "eval/public/ast_visitor.h" @@ -40,18 +40,18 @@ class AstRewriter : public AstVisitor { // Rewrite a sub expression before visiting. // Occurs before visiting Expr. If expr is modified, the new value will be // visited. - virtual bool PreVisitRewrite(google::api::expr::v1alpha1::Expr* expr, + virtual bool PreVisitRewrite(cel::expr::Expr* expr, const SourcePosition* position) = 0; // Rewrite a sub expression after visiting. // Occurs after visiting expr and it's children. If expr is modified, the old // sub expression is visited. - virtual bool PostVisitRewrite(google::api::expr::v1alpha1::Expr* expr, + virtual bool PostVisitRewrite(cel::expr::Expr* expr, const SourcePosition* position) = 0; // Notify the visitor of updates to the traversal stack. virtual void TraversalStackUpdate( - absl::Span path) = 0; + absl::Span path) = 0; }; // Trivial implementation for AST rewriters. @@ -60,66 +60,66 @@ class AstRewriterBase : public AstRewriter { public: ~AstRewriterBase() override {} - void PreVisitExpr(const google::api::expr::v1alpha1::Expr*, + void PreVisitExpr(const cel::expr::Expr*, const SourcePosition*) override {} - void PostVisitExpr(const google::api::expr::v1alpha1::Expr*, + void PostVisitExpr(const cel::expr::Expr*, const SourcePosition*) override {} - void PostVisitConst(const google::api::expr::v1alpha1::Constant*, - const google::api::expr::v1alpha1::Expr*, + void PostVisitConst(const cel::expr::Constant*, + const cel::expr::Expr*, const SourcePosition*) override {} - void PostVisitIdent(const google::api::expr::v1alpha1::Expr::Ident*, - const google::api::expr::v1alpha1::Expr*, + void PostVisitIdent(const cel::expr::Expr::Ident*, + const cel::expr::Expr*, const SourcePosition*) override {} - void PostVisitSelect(const google::api::expr::v1alpha1::Expr::Select*, - const google::api::expr::v1alpha1::Expr*, + void PostVisitSelect(const cel::expr::Expr::Select*, + const cel::expr::Expr*, const SourcePosition*) override {} - void PreVisitCall(const google::api::expr::v1alpha1::Expr::Call*, - const google::api::expr::v1alpha1::Expr*, + void PreVisitCall(const cel::expr::Expr::Call*, + const cel::expr::Expr*, const SourcePosition*) override {} - void PostVisitCall(const google::api::expr::v1alpha1::Expr::Call*, - const google::api::expr::v1alpha1::Expr*, + void PostVisitCall(const cel::expr::Expr::Call*, + const cel::expr::Expr*, const SourcePosition*) override {} - void PreVisitComprehension(const google::api::expr::v1alpha1::Expr::Comprehension*, - const google::api::expr::v1alpha1::Expr*, + void PreVisitComprehension(const cel::expr::Expr::Comprehension*, + const cel::expr::Expr*, const SourcePosition*) override {} - void PostVisitComprehension(const google::api::expr::v1alpha1::Expr::Comprehension*, - const google::api::expr::v1alpha1::Expr*, + void PostVisitComprehension(const cel::expr::Expr::Comprehension*, + const cel::expr::Expr*, const SourcePosition*) override {} - void PostVisitArg(int, const google::api::expr::v1alpha1::Expr*, + void PostVisitArg(int, const cel::expr::Expr*, const SourcePosition*) override {} - void PostVisitTarget(const google::api::expr::v1alpha1::Expr*, + void PostVisitTarget(const cel::expr::Expr*, const SourcePosition*) override {} - void PostVisitCreateList(const google::api::expr::v1alpha1::Expr::CreateList*, - const google::api::expr::v1alpha1::Expr*, + void PostVisitCreateList(const cel::expr::Expr::CreateList*, + const cel::expr::Expr*, const SourcePosition*) override {} - void PostVisitCreateStruct(const google::api::expr::v1alpha1::Expr::CreateStruct*, - const google::api::expr::v1alpha1::Expr*, + void PostVisitCreateStruct(const cel::expr::Expr::CreateStruct*, + const cel::expr::Expr*, const SourcePosition*) override {} - bool PreVisitRewrite(google::api::expr::v1alpha1::Expr* expr, + bool PreVisitRewrite(cel::expr::Expr* expr, const SourcePosition* position) override { return false; } - bool PostVisitRewrite(google::api::expr::v1alpha1::Expr* expr, + bool PostVisitRewrite(cel::expr::Expr* expr, const SourcePosition* position) override { return false; } void TraversalStackUpdate( - absl::Span path) override {} + absl::Span path) override {} }; // Traverses the AST representation in an expr proto. Returns true if any @@ -162,12 +162,12 @@ class AstRewriterBase : public AstRewriter { // ..PostVisitCall(fn) // PostVisitExpr -bool AstRewrite(google::api::expr::v1alpha1::Expr* expr, - const google::api::expr::v1alpha1::SourceInfo* source_info, +bool AstRewrite(cel::expr::Expr* expr, + const cel::expr::SourceInfo* source_info, AstRewriter* visitor); -bool AstRewrite(google::api::expr::v1alpha1::Expr* expr, - const google::api::expr::v1alpha1::SourceInfo* source_info, +bool AstRewrite(cel::expr::Expr* expr, + const cel::expr::SourceInfo* source_info, AstRewriter* visitor, RewriteTraversalOptions options); } // namespace google::api::expr::runtime diff --git a/eval/public/ast_rewrite_test.cc b/eval/public/ast_rewrite_test.cc index 3159d4607..b2ee8d13c 100644 --- a/eval/public/ast_rewrite_test.cc +++ b/eval/public/ast_rewrite_test.cc @@ -17,7 +17,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "eval/public/ast_visitor.h" #include "eval/public/source_position.h" #include "internal/testing.h" @@ -28,20 +28,20 @@ namespace google::api::expr::runtime { namespace { -using ::google::api::expr::v1alpha1::Constant; -using ::google::api::expr::v1alpha1::Expr; -using ::google::api::expr::v1alpha1::ParsedExpr; -using ::google::api::expr::v1alpha1::SourceInfo; +using ::cel::expr::Constant; +using ::cel::expr::Expr; +using ::cel::expr::ParsedExpr; +using ::cel::expr::SourceInfo; using ::testing::_; using ::testing::ElementsAre; using ::testing::InSequence; -using Ident = google::api::expr::v1alpha1::Expr::Ident; -using Select = google::api::expr::v1alpha1::Expr::Select; -using Call = google::api::expr::v1alpha1::Expr::Call; -using CreateList = google::api::expr::v1alpha1::Expr::CreateList; -using CreateStruct = google::api::expr::v1alpha1::Expr::CreateStruct; -using Comprehension = google::api::expr::v1alpha1::Expr::Comprehension; +using Ident = cel::expr::Expr::Ident; +using Select = cel::expr::Expr::Select; +using Call = cel::expr::Expr::Call; +using CreateList = cel::expr::Expr::CreateList; +using CreateStruct = cel::expr::Expr::CreateStruct; +using Comprehension = cel::expr::Expr::Comprehension; class MockAstRewriter : public AstRewriter { public: diff --git a/eval/public/ast_traverse.cc b/eval/public/ast_traverse.cc index ce1a66202..a86923c67 100644 --- a/eval/public/ast_traverse.cc +++ b/eval/public/ast_traverse.cc @@ -16,7 +16,7 @@ #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "absl/log/absl_log.h" #include "absl/types/variant.h" #include "eval/public/ast_visitor.h" @@ -24,14 +24,14 @@ namespace google::api::expr::runtime { -using google::api::expr::v1alpha1::Expr; -using google::api::expr::v1alpha1::SourceInfo; -using Ident = google::api::expr::v1alpha1::Expr::Ident; -using Select = google::api::expr::v1alpha1::Expr::Select; -using Call = google::api::expr::v1alpha1::Expr::Call; -using CreateList = google::api::expr::v1alpha1::Expr::CreateList; -using CreateStruct = google::api::expr::v1alpha1::Expr::CreateStruct; -using Comprehension = google::api::expr::v1alpha1::Expr::Comprehension; +using cel::expr::Expr; +using cel::expr::SourceInfo; +using Ident = cel::expr::Expr::Ident; +using Select = cel::expr::Expr::Select; +using Call = cel::expr::Expr::Call; +using CreateList = cel::expr::Expr::CreateList; +using CreateStruct = cel::expr::Expr::CreateStruct; +using Comprehension = cel::expr::Expr::Comprehension; namespace { diff --git a/eval/public/ast_traverse.h b/eval/public/ast_traverse.h index f9fe13752..f81c6f47a 100644 --- a/eval/public/ast_traverse.h +++ b/eval/public/ast_traverse.h @@ -17,7 +17,7 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_AST_TRAVERSE_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_AST_TRAVERSE_H_ -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "eval/public/ast_visitor.h" namespace google::api::expr::runtime { @@ -57,8 +57,8 @@ struct TraversalOptions { // ....PostVisitArg(fn, 1) // ..PostVisitCall(fn) // PostVisitExpr -void AstTraverse(const google::api::expr::v1alpha1::Expr* expr, - const google::api::expr::v1alpha1::SourceInfo* source_info, +void AstTraverse(const cel::expr::Expr* expr, + const cel::expr::SourceInfo* source_info, AstVisitor* visitor, TraversalOptions options = TraversalOptions()); diff --git a/eval/public/ast_traverse_test.cc b/eval/public/ast_traverse_test.cc index 45c0c523d..ca6d81b72 100644 --- a/eval/public/ast_traverse_test.cc +++ b/eval/public/ast_traverse_test.cc @@ -21,16 +21,16 @@ namespace google::api::expr::runtime { namespace { -using google::api::expr::v1alpha1::Constant; -using google::api::expr::v1alpha1::Expr; -using google::api::expr::v1alpha1::SourceInfo; +using cel::expr::Constant; +using cel::expr::Expr; +using cel::expr::SourceInfo; using testing::_; -using Ident = google::api::expr::v1alpha1::Expr::Ident; -using Select = google::api::expr::v1alpha1::Expr::Select; -using Call = google::api::expr::v1alpha1::Expr::Call; -using CreateList = google::api::expr::v1alpha1::Expr::CreateList; -using CreateStruct = google::api::expr::v1alpha1::Expr::CreateStruct; -using Comprehension = google::api::expr::v1alpha1::Expr::Comprehension; +using Ident = cel::expr::Expr::Ident; +using Select = cel::expr::Expr::Select; +using Call = cel::expr::Expr::Call; +using CreateList = cel::expr::Expr::CreateList; +using CreateStruct = cel::expr::Expr::CreateStruct; +using Comprehension = cel::expr::Expr::Comprehension; class MockAstVisitor : public AstVisitor { public: diff --git a/eval/public/ast_visitor.h b/eval/public/ast_visitor.h index 09eb133ea..4f0ef2a0a 100644 --- a/eval/public/ast_visitor.h +++ b/eval/public/ast_visitor.h @@ -17,7 +17,7 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_AST_VISITOR_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_AST_VISITOR_H_ -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "eval/public/source_position.h" namespace google { @@ -49,117 +49,117 @@ class AstVisitor { // Is invoked before child Expr nodes being processed. // TODO: this method is not pure virtual to avoid dependencies // breakage. Change it in subsequent CLs. - virtual void PreVisitExpr(const google::api::expr::v1alpha1::Expr*, + virtual void PreVisitExpr(const cel::expr::Expr*, const SourcePosition*) {} // Expr node handler method. Called for all Expr nodes. // Is invoked after child Expr nodes are processed. // TODO: this method is not pure virtual to avoid dependencies // breakage. Change it in subsequent CLs. - virtual void PostVisitExpr(const google::api::expr::v1alpha1::Expr*, + virtual void PostVisitExpr(const cel::expr::Expr*, const SourcePosition*) {} // Const node handler. // Invoked before child nodes are processed. // TODO: this method is not pure virtual to avoid dependencies // breakage. Change it in subsequent CLs. - virtual void PreVisitConst(const google::api::expr::v1alpha1::Constant*, - const google::api::expr::v1alpha1::Expr*, + virtual void PreVisitConst(const cel::expr::Constant*, + const cel::expr::Expr*, const SourcePosition*) {} // Const node handler. // Invoked after child nodes are processed. - virtual void PostVisitConst(const google::api::expr::v1alpha1::Constant*, - const google::api::expr::v1alpha1::Expr*, + virtual void PostVisitConst(const cel::expr::Constant*, + const cel::expr::Expr*, const SourcePosition*) = 0; // Ident node handler. // Invoked before child nodes are processed. // TODO: this method is not pure virtual to avoid dependencies // breakage. Change it in subsequent CLs. - virtual void PreVisitIdent(const google::api::expr::v1alpha1::Expr::Ident*, - const google::api::expr::v1alpha1::Expr*, + virtual void PreVisitIdent(const cel::expr::Expr::Ident*, + const cel::expr::Expr*, const SourcePosition*) {} // Ident node handler. // Invoked after child nodes are processed. - virtual void PostVisitIdent(const google::api::expr::v1alpha1::Expr::Ident*, - const google::api::expr::v1alpha1::Expr*, + virtual void PostVisitIdent(const cel::expr::Expr::Ident*, + const cel::expr::Expr*, const SourcePosition*) = 0; // Select node handler // Invoked before child nodes are processed. // TODO: this method is not pure virtual to avoid dependencies // breakage. Change it in subsequent CLs. - virtual void PreVisitSelect(const google::api::expr::v1alpha1::Expr::Select*, - const google::api::expr::v1alpha1::Expr*, + virtual void PreVisitSelect(const cel::expr::Expr::Select*, + const cel::expr::Expr*, const SourcePosition*) {} // Select node handler // Invoked after child nodes are processed. - virtual void PostVisitSelect(const google::api::expr::v1alpha1::Expr::Select*, - const google::api::expr::v1alpha1::Expr*, + virtual void PostVisitSelect(const cel::expr::Expr::Select*, + const cel::expr::Expr*, const SourcePosition*) = 0; // Call node handler group // We provide finer granularity for Call node callbacks to allow special // handling for short-circuiting // PreVisitCall is invoked before child nodes are processed. - virtual void PreVisitCall(const google::api::expr::v1alpha1::Expr::Call*, - const google::api::expr::v1alpha1::Expr*, + virtual void PreVisitCall(const cel::expr::Expr::Call*, + const cel::expr::Expr*, const SourcePosition*) = 0; // Invoked after all child nodes are processed. - virtual void PostVisitCall(const google::api::expr::v1alpha1::Expr::Call*, - const google::api::expr::v1alpha1::Expr*, + virtual void PostVisitCall(const cel::expr::Expr::Call*, + const cel::expr::Expr*, const SourcePosition*) = 0; // Invoked after target node is processed. // Expr is the call expression. - virtual void PostVisitTarget(const google::api::expr::v1alpha1::Expr*, + virtual void PostVisitTarget(const cel::expr::Expr*, const SourcePosition*) = 0; // Invoked before all child nodes are processed. virtual void PreVisitComprehension( - const google::api::expr::v1alpha1::Expr::Comprehension*, - const google::api::expr::v1alpha1::Expr*, const SourcePosition*) = 0; + const cel::expr::Expr::Comprehension*, + const cel::expr::Expr*, const SourcePosition*) = 0; // Invoked before comprehension child node is processed. virtual void PreVisitComprehensionSubexpression( - const google::api::expr::v1alpha1::Expr* subexpr, - const google::api::expr::v1alpha1::Expr::Comprehension* compr, + const cel::expr::Expr* subexpr, + const cel::expr::Expr::Comprehension* compr, ComprehensionArg comprehension_arg, const SourcePosition*) {} // Invoked after comprehension child node is processed. virtual void PostVisitComprehensionSubexpression( - const google::api::expr::v1alpha1::Expr* subexpr, - const google::api::expr::v1alpha1::Expr::Comprehension* compr, + const cel::expr::Expr* subexpr, + const cel::expr::Expr::Comprehension* compr, ComprehensionArg comprehension_arg, const SourcePosition*) {} // Invoked after all child nodes are processed. virtual void PostVisitComprehension( - const google::api::expr::v1alpha1::Expr::Comprehension*, - const google::api::expr::v1alpha1::Expr*, const SourcePosition*) = 0; + const cel::expr::Expr::Comprehension*, + const cel::expr::Expr*, const SourcePosition*) = 0; // Invoked after each argument node processed. // For Call arg_num is the index of the argument. // For Comprehension arg_num is specified by ComprehensionArg. // Expr is the call expression. - virtual void PostVisitArg(int arg_num, const google::api::expr::v1alpha1::Expr*, + virtual void PostVisitArg(int arg_num, const cel::expr::Expr*, const SourcePosition*) = 0; // CreateList node handler // Invoked before child nodes are processed. // TODO: this method is not pure virtual to avoid dependencies // breakage. Change it in subsequent CLs. - virtual void PreVisitCreateList(const google::api::expr::v1alpha1::Expr::CreateList*, - const google::api::expr::v1alpha1::Expr*, + virtual void PreVisitCreateList(const cel::expr::Expr::CreateList*, + const cel::expr::Expr*, const SourcePosition*) {} // CreateList node handler // Invoked after child nodes are processed. - virtual void PostVisitCreateList(const google::api::expr::v1alpha1::Expr::CreateList*, - const google::api::expr::v1alpha1::Expr*, + virtual void PostVisitCreateList(const cel::expr::Expr::CreateList*, + const cel::expr::Expr*, const SourcePosition*) = 0; // CreateStruct node handler @@ -167,14 +167,14 @@ class AstVisitor { // TODO: this method is not pure virtual to avoid dependencies // breakage. Change it in subsequent CLs. virtual void PreVisitCreateStruct( - const google::api::expr::v1alpha1::Expr::CreateStruct*, - const google::api::expr::v1alpha1::Expr*, const SourcePosition*) {} + const cel::expr::Expr::CreateStruct*, + const cel::expr::Expr*, const SourcePosition*) {} // CreateStruct node handler // Invoked after child nodes are processed. virtual void PostVisitCreateStruct( - const google::api::expr::v1alpha1::Expr::CreateStruct*, - const google::api::expr::v1alpha1::Expr*, const SourcePosition*) = 0; + const cel::expr::Expr::CreateStruct*, + const cel::expr::Expr*, const SourcePosition*) = 0; }; } // namespace runtime diff --git a/eval/public/ast_visitor_base.h b/eval/public/ast_visitor_base.h index 317253118..df8d8a926 100644 --- a/eval/public/ast_visitor_base.h +++ b/eval/public/ast_visitor_base.h @@ -18,7 +18,7 @@ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_AST_VISITOR_BASE_H_ #include "eval/public/ast_visitor.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" namespace google { namespace api { @@ -38,66 +38,66 @@ class AstVisitorBase : public AstVisitor { // Const node handler. // Invoked after child nodes are processed. - void PostVisitConst(const google::api::expr::v1alpha1::Constant*, - const google::api::expr::v1alpha1::Expr*, + void PostVisitConst(const cel::expr::Constant*, + const cel::expr::Expr*, const SourcePosition*) override {} // Ident node handler. // Invoked after child nodes are processed. - void PostVisitIdent(const google::api::expr::v1alpha1::Expr::Ident*, - const google::api::expr::v1alpha1::Expr*, + void PostVisitIdent(const cel::expr::Expr::Ident*, + const cel::expr::Expr*, const SourcePosition*) override {} // Select node handler // Invoked after child nodes are processed. - void PostVisitSelect(const google::api::expr::v1alpha1::Expr::Select*, - const google::api::expr::v1alpha1::Expr*, + void PostVisitSelect(const cel::expr::Expr::Select*, + const cel::expr::Expr*, const SourcePosition*) override {} // Call node handler group // We provide finer granularity for Call node callbacks to allow special // handling for short-circuiting // PreVisitCall is invoked before child nodes are processed. - void PreVisitCall(const google::api::expr::v1alpha1::Expr::Call*, - const google::api::expr::v1alpha1::Expr*, + void PreVisitCall(const cel::expr::Expr::Call*, + const cel::expr::Expr*, const SourcePosition*) override {} // Invoked after all child nodes are processed. - void PostVisitCall(const google::api::expr::v1alpha1::Expr::Call*, - const google::api::expr::v1alpha1::Expr*, + void PostVisitCall(const cel::expr::Expr::Call*, + const cel::expr::Expr*, const SourcePosition*) override {} // Invoked before all child nodes are processed. - void PreVisitComprehension(const google::api::expr::v1alpha1::Expr::Comprehension*, - const google::api::expr::v1alpha1::Expr*, + void PreVisitComprehension(const cel::expr::Expr::Comprehension*, + const cel::expr::Expr*, const SourcePosition*) override {} // Invoked after all child nodes are processed. - void PostVisitComprehension(const google::api::expr::v1alpha1::Expr::Comprehension*, - const google::api::expr::v1alpha1::Expr*, + void PostVisitComprehension(const cel::expr::Expr::Comprehension*, + const cel::expr::Expr*, const SourcePosition*) override {} // Invoked after each argument node processed. // For Call arg_num is the index of the argument. // For Comprehension arg_num is specified by ComprehensionArg. // Expr is the call expression. - void PostVisitArg(int, const google::api::expr::v1alpha1::Expr*, + void PostVisitArg(int, const cel::expr::Expr*, const SourcePosition*) override {} // Invoked after target node processed. - void PostVisitTarget(const google::api::expr::v1alpha1::Expr*, + void PostVisitTarget(const cel::expr::Expr*, const SourcePosition*) override {} // CreateList node handler // Invoked after child nodes are processed. - void PostVisitCreateList(const google::api::expr::v1alpha1::Expr::CreateList*, - const google::api::expr::v1alpha1::Expr*, + void PostVisitCreateList(const cel::expr::Expr::CreateList*, + const cel::expr::Expr*, const SourcePosition*) override {} // CreateStruct node handler // Invoked after child nodes are processed. - void PostVisitCreateStruct(const google::api::expr::v1alpha1::Expr::CreateStruct*, - const google::api::expr::v1alpha1::Expr*, + void PostVisitCreateStruct(const cel::expr::Expr::CreateStruct*, + const cel::expr::Expr*, const SourcePosition*) override {} }; diff --git a/eval/public/builtin_func_registrar_test.cc b/eval/public/builtin_func_registrar_test.cc index e81cfaa46..44900f274 100644 --- a/eval/public/builtin_func_registrar_test.cc +++ b/eval/public/builtin_func_registrar_test.cc @@ -19,7 +19,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "google/protobuf/arena.h" #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" @@ -39,8 +39,8 @@ namespace google::api::expr::runtime { namespace { -using google::api::expr::v1alpha1::Expr; -using google::api::expr::v1alpha1::SourceInfo; +using cel::expr::Expr; +using cel::expr::SourceInfo; using ::absl_testing::StatusIs; using ::cel::internal::MaxDuration; diff --git a/eval/public/builtin_func_test.cc b/eval/public/builtin_func_test.cc index a4fbdd872..4727345d5 100644 --- a/eval/public/builtin_func_test.cc +++ b/eval/public/builtin_func_test.cc @@ -20,7 +20,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/time/time.h" @@ -43,8 +43,8 @@ namespace { using google::protobuf::Duration; using google::protobuf::Timestamp; -using google::api::expr::v1alpha1::Expr; -using google::api::expr::v1alpha1::SourceInfo; +using cel::expr::Expr; +using cel::expr::SourceInfo; using google::protobuf::Arena; diff --git a/eval/public/cel_attribute.h b/eval/public/cel_attribute.h index 923d3b918..959fff75e 100644 --- a/eval/public/cel_attribute.h +++ b/eval/public/cel_attribute.h @@ -14,7 +14,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" diff --git a/eval/public/cel_attribute_test.cc b/eval/public/cel_attribute_test.cc index 9c8a15d36..3024e486d 100644 --- a/eval/public/cel_attribute_test.cc +++ b/eval/public/cel_attribute_test.cc @@ -13,7 +13,7 @@ namespace google::api::expr::runtime { namespace { -using google::api::expr::v1alpha1::Expr; +using cel::expr::Expr; using ::absl_testing::StatusIs; using ::google::protobuf::Duration; diff --git a/eval/public/cel_expression.h b/eval/public/cel_expression.h index 56e83eebe..98b58aa98 100644 --- a/eval/public/cel_expression.h +++ b/eval/public/cel_expression.h @@ -7,8 +7,8 @@ #include #include -#include "google/api/expr/v1alpha1/checked.pb.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/checked.pb.h" +#include "cel/expr/syntax.pb.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "eval/public/base_activation.h" @@ -89,8 +89,8 @@ class CelExpressionBuilder { // IMPORTANT: The `expr` and `source_info` must outlive the resulting // CelExpression. virtual absl::StatusOr> CreateExpression( - const google::api::expr::v1alpha1::Expr* expr, - const google::api::expr::v1alpha1::SourceInfo* source_info) const = 0; + const cel::expr::Expr* expr, + const cel::expr::SourceInfo* source_info) const = 0; // Creates CelExpression object from AST tree. // expr specifies root of AST tree. @@ -99,8 +99,8 @@ class CelExpressionBuilder { // IMPORTANT: The `expr` and `source_info` must outlive the resulting // CelExpression. virtual absl::StatusOr> CreateExpression( - const google::api::expr::v1alpha1::Expr* expr, - const google::api::expr::v1alpha1::SourceInfo* source_info, + const cel::expr::Expr* expr, + const cel::expr::SourceInfo* source_info, std::vector* warnings) const = 0; // Creates CelExpression object from a checked expression. @@ -108,7 +108,7 @@ class CelExpressionBuilder { // // IMPORTANT: The `checked_expr` must outlive the resulting CelExpression. virtual absl::StatusOr> CreateExpression( - const google::api::expr::v1alpha1::CheckedExpr* checked_expr) const { + const cel::expr::CheckedExpr* checked_expr) const { // Default implementation just passes through the expr and source info. return CreateExpression(&checked_expr->expr(), &checked_expr->source_info()); @@ -120,7 +120,7 @@ class CelExpressionBuilder { // // IMPORTANT: The `checked_expr` must outlive the resulting CelExpression. virtual absl::StatusOr> CreateExpression( - const google::api::expr::v1alpha1::CheckedExpr* checked_expr, + const cel::expr::CheckedExpr* checked_expr, std::vector* warnings) const { // Default implementation just passes through the expr and source_info. return CreateExpression(&checked_expr->expr(), &checked_expr->source_info(), diff --git a/eval/public/comparison_functions_test.cc b/eval/public/comparison_functions_test.cc index 7efdc48e2..c2b4efec9 100644 --- a/eval/public/comparison_functions_test.cc +++ b/eval/public/comparison_functions_test.cc @@ -17,7 +17,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "google/rpc/context/attribute_context.pb.h" #include "google/protobuf/arena.h" #include "absl/status/statusor.h" @@ -38,7 +38,7 @@ namespace google::api::expr::runtime { namespace { -using ::google::api::expr::v1alpha1::ParsedExpr; +using ::cel::expr::ParsedExpr; using ::google::rpc::context::AttributeContext; using ::testing::Combine; using ::testing::ValuesIn; diff --git a/eval/public/container_function_registrar_test.cc b/eval/public/container_function_registrar_test.cc index 2593bc098..e6d5f93d8 100644 --- a/eval/public/container_function_registrar_test.cc +++ b/eval/public/container_function_registrar_test.cc @@ -30,8 +30,8 @@ namespace google::api::expr::runtime { namespace { -using google::api::expr::v1alpha1::Expr; -using google::api::expr::v1alpha1::SourceInfo; +using cel::expr::Expr; +using cel::expr::SourceInfo; using ::testing::ValuesIn; struct TestCase { diff --git a/eval/public/containers/BUILD b/eval/public/containers/BUILD index ff5acad65..0313d8200 100644 --- a/eval/public/containers/BUILD +++ b/eval/public/containers/BUILD @@ -156,7 +156,7 @@ cc_test( "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", - "@com_google_cel_spec//proto/test/v1/proto3:test_all_types_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", "@com_google_protobuf//:protobuf", ], ) diff --git a/eval/public/containers/field_access_test.cc b/eval/public/containers/field_access_test.cc index 20be75ebb..03bed365b 100644 --- a/eval/public/containers/field_access_test.cc +++ b/eval/public/containers/field_access_test.cc @@ -29,16 +29,16 @@ #include "eval/testutil/test_message.pb.h" #include "internal/testing.h" #include "internal/time.h" -#include "proto/test/v1/proto3/test_all_types.pb.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" namespace google::api::expr::runtime { namespace { using ::absl_testing::StatusIs; +using ::cel::expr::conformance::proto3::TestAllTypes; using ::cel::internal::MaxDuration; using ::cel::internal::MaxTimestamp; -using ::google::api::expr::test::v1::proto3::TestAllTypes; using ::google::protobuf::Arena; using ::google::protobuf::FieldDescriptor; using ::testing::HasSubstr; diff --git a/eval/public/equality_function_registrar_test.cc b/eval/public/equality_function_registrar_test.cc index 7930eac59..0c56dd709 100644 --- a/eval/public/equality_function_registrar_test.cc +++ b/eval/public/equality_function_registrar_test.cc @@ -23,7 +23,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "google/protobuf/any.pb.h" #include "google/rpc/context/attribute_context.pb.h" #include "google/protobuf/descriptor.pb.h" @@ -62,7 +62,7 @@ namespace google::api::expr::runtime { namespace { using ::absl_testing::StatusIs; -using ::google::api::expr::v1alpha1::ParsedExpr; +using ::cel::expr::ParsedExpr; using ::google::rpc::context::AttributeContext; using ::testing::_; using ::testing::Combine; diff --git a/eval/public/logical_function_registrar_test.cc b/eval/public/logical_function_registrar_test.cc index c9944bca0..4895b4931 100644 --- a/eval/public/logical_function_registrar_test.cc +++ b/eval/public/logical_function_registrar_test.cc @@ -19,7 +19,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "google/protobuf/arena.h" #include "absl/base/no_destructor.h" #include "absl/container/flat_hash_map.h" @@ -38,8 +38,8 @@ namespace google::api::expr::runtime { namespace { -using google::api::expr::v1alpha1::Expr; -using google::api::expr::v1alpha1::SourceInfo; +using cel::expr::Expr; +using cel::expr::SourceInfo; using ::absl_testing::StatusIs; using ::testing::HasSubstr; diff --git a/eval/public/portable_cel_expr_builder_factory_test.cc b/eval/public/portable_cel_expr_builder_factory_test.cc index cf5e807f7..31d41f1d0 100644 --- a/eval/public/portable_cel_expr_builder_factory_test.cc +++ b/eval/public/portable_cel_expr_builder_factory_test.cc @@ -44,7 +44,7 @@ namespace google::api::expr::runtime { namespace { -using ::google::api::expr::v1alpha1::ParsedExpr; +using ::cel::expr::ParsedExpr; using ::google::protobuf::Int64Value; // Helpers for c++ / proto to cel value conversions. diff --git a/eval/public/source_position.cc b/eval/public/source_position.cc index 52a4c1185..ac902fa0e 100644 --- a/eval/public/source_position.cc +++ b/eval/public/source_position.cc @@ -21,7 +21,7 @@ namespace api { namespace expr { namespace runtime { -using google::api::expr::v1alpha1::SourceInfo; +using cel::expr::SourceInfo; namespace { diff --git a/eval/public/source_position.h b/eval/public/source_position.h index 739f501b4..c4b7f0f88 100644 --- a/eval/public/source_position.h +++ b/eval/public/source_position.h @@ -17,7 +17,7 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_SOURCE_POSITION_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_SOURCE_POSITION_H_ -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" namespace google { namespace api { @@ -31,7 +31,7 @@ class SourcePosition { // Constructor for a SourcePosition value. The source_info may be nullptr, // in which case line, column, and character_offset will return 0. SourcePosition(const int64_t expr_id, - const google::api::expr::v1alpha1::SourceInfo* source_info) + const cel::expr::SourceInfo* source_info) : expr_id_(expr_id), source_info_(source_info) {} // Non-copyable @@ -54,7 +54,7 @@ class SourcePosition { // The expression identifier. const int64_t expr_id_; // The source information reference generated during expression parsing. - const google::api::expr::v1alpha1::SourceInfo* source_info_; + const cel::expr::SourceInfo* source_info_; }; } // namespace runtime diff --git a/eval/public/source_position_test.cc b/eval/public/source_position_test.cc index 5808312d4..16140d96f 100644 --- a/eval/public/source_position_test.cc +++ b/eval/public/source_position_test.cc @@ -14,7 +14,7 @@ #include "eval/public/source_position.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "internal/testing.h" namespace google { @@ -25,7 +25,7 @@ namespace runtime { namespace { using ::testing::Eq; -using google::api::expr::v1alpha1::SourceInfo; +using cel::expr::SourceInfo; class SourcePositionTest : public testing::Test { protected: diff --git a/eval/public/string_extension_func_registrar_test.cc b/eval/public/string_extension_func_registrar_test.cc index f1151d0e4..7fd6e746f 100644 --- a/eval/public/string_extension_func_registrar_test.cc +++ b/eval/public/string_extension_func_registrar_test.cc @@ -18,7 +18,7 @@ #include #include -#include "google/api/expr/v1alpha1/checked.pb.h" +#include "cel/expr/checked.pb.h" #include "absl/types/span.h" #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_function_registry.h" diff --git a/eval/public/structs/BUILD b/eval/public/structs/BUILD index 2da148ef6..10c1441a7 100644 --- a/eval/public/structs/BUILD +++ b/eval/public/structs/BUILD @@ -144,7 +144,7 @@ cc_test( "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", - "@com_google_cel_spec//proto/test/v1/proto3:test_all_types_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -402,8 +402,8 @@ cc_test( "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@com_google_cel_spec//proto/test/v1/proto3:test_all_types_cc_proto", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", "@com_google_protobuf//:protobuf", ], ) diff --git a/eval/public/structs/dynamic_descriptor_pool_end_to_end_test.cc b/eval/public/structs/dynamic_descriptor_pool_end_to_end_test.cc index 9fd0fc295..2261dab83 100644 --- a/eval/public/structs/dynamic_descriptor_pool_end_to_end_test.cc +++ b/eval/public/structs/dynamic_descriptor_pool_end_to_end_test.cc @@ -15,7 +15,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "google/protobuf/descriptor.pb.h" #include "absl/memory/memory.h" #include "absl/status/status.h" @@ -30,7 +30,7 @@ #include "eval/public/testing/matchers.h" #include "internal/testing.h" #include "parser/parser.h" -#include "proto/test/v1/proto3/test_all_types.pb.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/dynamic_message.h" #include "google/protobuf/message.h" @@ -40,9 +40,9 @@ namespace google::api::expr::runtime { namespace { -using ::google::api::expr::v1alpha1::ParsedExpr; +using ::cel::expr::conformance::proto3::TestAllTypes; +using ::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; -using ::google::api::expr::test::v1::proto3::TestAllTypes; using ::google::protobuf::DescriptorPool; constexpr int32_t kStartingFieldNumber = 512; @@ -79,7 +79,7 @@ absl::Status AddTestTypes(DescriptorPool& pool) { dynamic_message_field->set_name("dynamic_message_field"); dynamic_message_field->set_type(google::protobuf::FieldDescriptorProto::TYPE_MESSAGE); dynamic_message_field->set_type_name( - ".google.api.expr.test.v1.proto3.TestAllTypes"); + ".cel.expr.conformance.proto3.TestAllTypes"); CEL_RETURN_IF_ERROR(AddStandardMessageTypesToDescriptorPool(pool)); if (!pool.BuildFile(file_descriptor)) { @@ -101,7 +101,7 @@ class DynamicDescriptorPoolTest : public ::testing::Test { absl::string_view text_format) { const google::protobuf::Descriptor* dynamic_desc = descriptor_pool_.FindMessageTypeByName( - "google.api.expr.test.v1.proto3.TestAllTypes"); + "cel.expr.conformance.proto3.TestAllTypes"); auto message = absl::WrapUnique(factory_.GetPrototype(dynamic_desc)->New()); if (!google::protobuf::TextFormat::ParseFromString(text_format, message.get())) { @@ -142,7 +142,7 @@ TEST_F(DynamicDescriptorPoolTest, Create) { std::unique_ptr builder = CreateCelExpressionBuilder(&descriptor_pool_, &factory_, options); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); - builder->set_container("google.api.expr.test.v1.proto3"); + builder->set_container("cel.expr.conformance.proto3"); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse( R"cel( @@ -176,7 +176,7 @@ TEST_F(DynamicDescriptorPoolTest, AnyUnpack) { ASSERT_OK_AND_ASSIGN( auto message, CreateMessageFromText(R"pb( single_any { - [type.googleapis.com/google.api.expr.test.v1.proto3.TestAllTypes] { + [type.googleapis.com/cel.expr.conformance.proto3.TestAllTypes] { dynamic_int_field: 45 } } @@ -229,12 +229,12 @@ TEST_F(DynamicDescriptorPoolTest, AnyUnpackRepeated) { ASSERT_OK_AND_ASSIGN( auto message, CreateMessageFromText(R"pb( repeated_any { - [type.googleapis.com/google.api.expr.test.v1.proto3.TestAllTypes] { + [type.googleapis.com/cel.expr.conformance.proto3.TestAllTypes] { dynamic_int_field: 0 } } repeated_any { - [type.googleapis.com/google.api.expr.test.v1.proto3.TestAllTypes] { + [type.googleapis.com/cel.expr.conformance.proto3.TestAllTypes] { dynamic_int_field: 1 } } @@ -259,7 +259,7 @@ TEST_F(DynamicDescriptorPoolTest, AnyPack) { std::unique_ptr builder = CreateCelExpressionBuilder(&descriptor_pool_, &factory_, options); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); - builder->set_container("google.api.expr.test.v1.proto3"); + builder->set_container("cel.expr.conformance.proto3"); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse(R"cel( TestAllTypes{ @@ -274,7 +274,7 @@ TEST_F(DynamicDescriptorPoolTest, AnyPack) { ASSERT_OK_AND_ASSIGN( auto expected_message, CreateMessageFromText(R"pb( single_any { - [type.googleapis.com/google.api.expr.test.v1.proto3.TestAllTypes] { + [type.googleapis.com/cel.expr.conformance.proto3.TestAllTypes] { dynamic_int_field: 42 } } @@ -288,7 +288,7 @@ TEST_F(DynamicDescriptorPoolTest, AnyWrapperPack) { std::unique_ptr builder = CreateCelExpressionBuilder(&descriptor_pool_, &factory_, options); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); - builder->set_container("google.api.expr.test.v1.proto3"); + builder->set_container("cel.expr.conformance.proto3"); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse(R"cel( TestAllTypes{ @@ -315,7 +315,7 @@ TEST_F(DynamicDescriptorPoolTest, AnyPackRepeated) { std::unique_ptr builder = CreateCelExpressionBuilder(&descriptor_pool_, &factory_, options); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); - builder->set_container("google.api.expr.test.v1.proto3"); + builder->set_container("cel.expr.conformance.proto3"); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse(R"cel( TestAllTypes{ @@ -333,12 +333,12 @@ TEST_F(DynamicDescriptorPoolTest, AnyPackRepeated) { ASSERT_OK_AND_ASSIGN( auto expected_message, CreateMessageFromText(R"pb( repeated_any { - [type.googleapis.com/google.api.expr.test.v1.proto3.TestAllTypes] { + [type.googleapis.com/cel.expr.conformance.proto3.TestAllTypes] { dynamic_int_field: 0 } } repeated_any { - [type.googleapis.com/google.api.expr.test.v1.proto3.TestAllTypes] { + [type.googleapis.com/cel.expr.conformance.proto3.TestAllTypes] { dynamic_int_field: 1 } } diff --git a/eval/public/structs/field_access_impl_test.cc b/eval/public/structs/field_access_impl_test.cc index 8c9ff918f..8947373e9 100644 --- a/eval/public/structs/field_access_impl_test.cc +++ b/eval/public/structs/field_access_impl_test.cc @@ -32,16 +32,16 @@ #include "internal/testing.h" #include "internal/time.h" #include "testutil/util.h" -#include "proto/test/v1/proto3/test_all_types.pb.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" namespace google::api::expr::runtime::internal { namespace { using ::absl_testing::StatusIs; +using ::cel::expr::conformance::proto3::TestAllTypes; using ::cel::internal::MaxDuration; using ::cel::internal::MaxTimestamp; -using ::google::api::expr::test::v1::proto3::TestAllTypes; using ::google::protobuf::Arena; using ::google::protobuf::FieldDescriptor; using ::testing::HasSubstr; diff --git a/eval/public/transform_utility.cc b/eval/public/transform_utility.cc index 1fd1f9b21..bdc76712c 100644 --- a/eval/public/transform_utility.cc +++ b/eval/public/transform_utility.cc @@ -4,7 +4,7 @@ #include #include -#include "google/api/expr/v1alpha1/value.pb.h" +#include "cel/expr/value.pb.h" #include "google/protobuf/any.pb.h" #include "google/protobuf/struct.pb.h" #include "google/protobuf/arena.h" diff --git a/eval/public/transform_utility.h b/eval/public/transform_utility.h index 2ec628505..9836bc5fe 100644 --- a/eval/public/transform_utility.h +++ b/eval/public/transform_utility.h @@ -1,7 +1,7 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_TRANSFORM_UTILITY_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_TRANSFORM_UTILITY_H_ -#include "google/api/expr/v1alpha1/value.pb.h" +#include "cel/expr/value.pb.h" #include "google/protobuf/arena.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -12,9 +12,9 @@ namespace api { namespace expr { namespace runtime { -using google::api::expr::v1alpha1::Value; +using cel::expr::Value; -// Translates a CelValue into a google::api::expr::v1alpha1::Value. Returns an error if +// Translates a CelValue into a cel::expr::Value. Returns an error if // translation is not supported. absl::Status CelValueToValue(const CelValue& value, Value* result, google::protobuf::Arena* arena); @@ -24,7 +24,7 @@ inline absl::Status CelValueToValue(const CelValue& value, Value* result) { return CelValueToValue(value, result, &arena); } -// Translates a google::api::expr::v1alpha1::Value into a CelValue. Allocates any required +// Translates a cel::expr::Value into a CelValue. Allocates any required // external data on the provided arena. Returns an error if translation is not // supported. absl::StatusOr ValueToCelValue(const Value& value, diff --git a/eval/public/unknown_attribute_set_test.cc b/eval/public/unknown_attribute_set_test.cc index 36a301ca6..efd27537f 100644 --- a/eval/public/unknown_attribute_set_test.cc +++ b/eval/public/unknown_attribute_set_test.cc @@ -17,7 +17,7 @@ namespace { using ::testing::Eq; -using google::api::expr::v1alpha1::Expr; +using cel::expr::Expr; TEST(UnknownAttributeSetTest, TestCreate) { const std::string kAttr1 = "a1"; diff --git a/eval/public/unknown_set_test.cc b/eval/public/unknown_set_test.cc index 25922a773..26e1e1a15 100644 --- a/eval/public/unknown_set_test.cc +++ b/eval/public/unknown_set_test.cc @@ -2,7 +2,7 @@ #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "google/protobuf/arena.h" #include "eval/public/cel_attribute.h" #include "eval/public/cel_function.h" diff --git a/eval/tests/BUILD b/eval/tests/BUILD index 38b99e48f..256aa26e9 100644 --- a/eval/tests/BUILD +++ b/eval/tests/BUILD @@ -35,7 +35,7 @@ cc_library( "@com_google_absl//absl/container:node_hash_set", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto", "@com_google_protobuf//:protobuf", ], @@ -113,7 +113,7 @@ cc_test( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto", "@com_google_protobuf//:protobuf", ], @@ -148,7 +148,7 @@ cc_test( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto", "@com_google_protobuf//:protobuf", ], @@ -172,7 +172,7 @@ cc_test( "//testutil:util", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto", "@com_google_protobuf//:protobuf", ], @@ -198,8 +198,8 @@ cc_test( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:node_hash_set", "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -222,7 +222,7 @@ cc_test( "//internal:testing", "//testutil:util", "@com_google_absl//absl/status", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -252,7 +252,7 @@ cc_test( "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) diff --git a/eval/tests/allocation_benchmark_test.cc b/eval/tests/allocation_benchmark_test.cc index b70ec4899..2b442a12a 100644 --- a/eval/tests/allocation_benchmark_test.cc +++ b/eval/tests/allocation_benchmark_test.cc @@ -14,7 +14,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "google/rpc/context/attribute_context.pb.h" #include "google/protobuf/text_format.h" #include "absl/base/attributes.h" @@ -43,7 +43,7 @@ namespace google::api::expr::runtime { namespace { using ::absl_testing::StatusIs; -using ::google::api::expr::v1alpha1::ParsedExpr; +using ::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; using ::testing::HasSubstr; diff --git a/eval/tests/benchmark_test.cc b/eval/tests/benchmark_test.cc index 53266f4eb..b62929428 100644 --- a/eval/tests/benchmark_test.cc +++ b/eval/tests/benchmark_test.cc @@ -4,7 +4,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "google/protobuf/struct.pb.h" #include "google/rpc/context/attribute_context.pb.h" #include "google/protobuf/text_format.h" @@ -39,9 +39,9 @@ namespace runtime { namespace { -using ::google::api::expr::v1alpha1::Expr; -using ::google::api::expr::v1alpha1::ParsedExpr; -using ::google::api::expr::v1alpha1::SourceInfo; +using ::cel::expr::Expr; +using ::cel::expr::ParsedExpr; +using ::cel::expr::SourceInfo; using ::google::rpc::context::AttributeContext; InterpreterOptions GetOptions(google::protobuf::Arena& arena) { diff --git a/eval/tests/end_to_end_test.cc b/eval/tests/end_to_end_test.cc index e60db8fa1..56a652068 100644 --- a/eval/tests/end_to_end_test.cc +++ b/eval/tests/end_to_end_test.cc @@ -2,7 +2,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "google/protobuf/struct.pb.h" #include "google/protobuf/text_format.h" #include "absl/status/status.h" @@ -25,8 +25,8 @@ namespace runtime { namespace { using ::absl_testing::StatusIs; -using ::google::api::expr::v1alpha1::Expr; -using ::google::api::expr::v1alpha1::SourceInfo; +using ::cel::expr::Expr; +using ::cel::expr::SourceInfo; using ::google::protobuf::Arena; using ::google::protobuf::TextFormat; diff --git a/eval/tests/expression_builder_benchmark_test.cc b/eval/tests/expression_builder_benchmark_test.cc index 468450749..7fc84697b 100644 --- a/eval/tests/expression_builder_benchmark_test.cc +++ b/eval/tests/expression_builder_benchmark_test.cc @@ -17,8 +17,8 @@ #include #include -#include "google/api/expr/v1alpha1/checked.pb.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/checked.pb.h" +#include "cel/expr/syntax.pb.h" #include "google/protobuf/text_format.h" #include "absl/container/flat_hash_set.h" #include "absl/container/node_hash_set.h" @@ -37,8 +37,8 @@ namespace google::api::expr::runtime { namespace { -using google::api::expr::v1alpha1::CheckedExpr; -using google::api::expr::v1alpha1::ParsedExpr; +using cel::expr::CheckedExpr; +using cel::expr::ParsedExpr; enum BenchmarkParam : int { kDefault = 0, diff --git a/eval/tests/memory_safety_test.cc b/eval/tests/memory_safety_test.cc index 738c025be..35c397520 100644 --- a/eval/tests/memory_safety_test.cc +++ b/eval/tests/memory_safety_test.cc @@ -18,7 +18,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "google/rpc/context/attribute_context.pb.h" #include "google/protobuf/arena.h" #include "google/protobuf/message.h" @@ -39,7 +39,7 @@ namespace google::api::expr::runtime { namespace { using ::absl_testing::IsOkAndHolds; -using ::google::api::expr::v1alpha1::ParsedExpr; +using ::cel::expr::ParsedExpr; using ::google::rpc::context::AttributeContext; using testutil::EqualsProto; diff --git a/eval/tests/modern_benchmark_test.cc b/eval/tests/modern_benchmark_test.cc index 7e320b0a4..81cf91ef0 100644 --- a/eval/tests/modern_benchmark_test.cc +++ b/eval/tests/modern_benchmark_test.cc @@ -22,7 +22,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "google/protobuf/struct.pb.h" #include "google/rpc/context/attribute_context.pb.h" #include "absl/base/attributes.h" @@ -71,9 +71,9 @@ namespace { using ::absl_testing::IsOkAndHolds; using ::cel::extensions::ProtobufRuntimeAdapter; using ::cel::extensions::ProtoMemoryManagerRef; -using ::google::api::expr::v1alpha1::Expr; -using ::google::api::expr::v1alpha1::ParsedExpr; -using ::google::api::expr::v1alpha1::SourceInfo; +using ::cel::expr::Expr; +using ::cel::expr::ParsedExpr; +using ::cel::expr::SourceInfo; using ::google::api::expr::parser::Parse; using ::google::api::expr::runtime::RequestContext; using ::google::rpc::context::AttributeContext; diff --git a/eval/tests/unknowns_end_to_end_test.cc b/eval/tests/unknowns_end_to_end_test.cc index 5d9cea55c..fb3f60585 100644 --- a/eval/tests/unknowns_end_to_end_test.cc +++ b/eval/tests/unknowns_end_to_end_test.cc @@ -8,7 +8,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "google/protobuf/struct.pb.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" @@ -37,8 +37,8 @@ namespace expr { namespace runtime { namespace { -using google::api::expr::v1alpha1::Expr; -using google::api::expr::v1alpha1::ParsedExpr; +using cel::expr::Expr; +using cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; using ::google::protobuf::Arena; using ::testing::ElementsAre; @@ -164,7 +164,7 @@ class UnknownsTest : public testing::Test { Arena arena_; Activation activation_; std::unique_ptr builder_; - google::api::expr::v1alpha1::Expr expr_; + cel::expr::Expr expr_; }; MATCHER_P(FunctionCallIs, fn_name, "") { diff --git a/extensions/BUILD b/extensions/BUILD index e83cabc91..ae34b1194 100644 --- a/extensions/BUILD +++ b/extensions/BUILD @@ -103,7 +103,7 @@ cc_test( "@com_google_absl//absl/status", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -186,8 +186,8 @@ cc_test( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:string_view", - "@com_google_cel_spec//proto/test/v1/proto2:test_all_types_cc_proto", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -211,7 +211,7 @@ cc_test( "//parser:macro", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/log:absl_check", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -287,7 +287,7 @@ cc_test( "//internal:testing", "//parser", "//runtime:runtime_options", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -319,7 +319,7 @@ cc_test( "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -367,6 +367,6 @@ cc_test( "//runtime:standard_runtime_builder_factory", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/strings:cord", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) diff --git a/extensions/bindings_ext_benchmark_test.cc b/extensions/bindings_ext_benchmark_test.cc index 8c4ccc603..52203d810 100644 --- a/extensions/bindings_ext_benchmark_test.cc +++ b/extensions/bindings_ext_benchmark_test.cc @@ -16,7 +16,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "absl/base/no_destructor.h" #include "absl/log/absl_check.h" #include "eval/public/activation.h" diff --git a/extensions/bindings_ext_test.cc b/extensions/bindings_ext_test.cc index 0c40937ec..c8b12c24a 100644 --- a/extensions/bindings_ext_test.cc +++ b/extensions/bindings_ext_test.cc @@ -20,7 +20,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" @@ -38,7 +38,7 @@ #include "internal/testing.h" #include "parser/macro.h" #include "parser/parser.h" -#include "proto/test/v1/proto2/test_all_types.pb.h" +#include "cel/expr/conformance/proto2/test_all_types.pb.h" #include "google/protobuf/arena.h" #include "google/protobuf/text_format.h" @@ -47,10 +47,11 @@ namespace { using ::absl_testing::IsOk; using ::absl_testing::StatusIs; -using ::google::api::expr::v1alpha1::CheckedExpr; -using ::google::api::expr::v1alpha1::Expr; -using ::google::api::expr::v1alpha1::ParsedExpr; -using ::google::api::expr::v1alpha1::SourceInfo; +using ::cel::expr::conformance::proto2::NestedTestAllTypes; +using ::cel::expr::CheckedExpr; +using ::cel::expr::Expr; +using ::cel::expr::ParsedExpr; +using ::cel::expr::SourceInfo; using ::google::api::expr::parser::ParseWithMacros; using ::google::api::expr::runtime::Activation; using ::google::api::expr::runtime::CelExpressionBuilder; @@ -64,7 +65,6 @@ using ::google::api::expr::runtime::InterpreterOptions; using ::google::api::expr::runtime::RegisterBuiltinFunctions; using ::google::api::expr::runtime::UnknownProcessingOptions; using ::google::api::expr::runtime::test::IsCelInt64; -using ::google::api::expr::test::v1::proto2::NestedTestAllTypes; using ::google::protobuf::Arena; using ::google::protobuf::TextFormat; using ::testing::Contains; @@ -317,7 +317,7 @@ constexpr absl::string_view kFieldSelectTestExpr = R"pb( } reference_map: { key: 9 - value: { name: "google.api.expr.test.v1.proto2.TestAllTypes" } + value: { name: "cel.expr.conformance.proto2.TestAllTypes" } } reference_map: { key: 13 @@ -329,15 +329,15 @@ constexpr absl::string_view kFieldSelectTestExpr = R"pb( } type_map: { key: 4 - value: { message_type: "google.api.expr.test.v1.proto2.NestedTestAllTypes" } + value: { message_type: "cel.expr.conformance.proto2.NestedTestAllTypes" } } type_map: { key: 5 - value: { message_type: "google.api.expr.test.v1.proto2.NestedTestAllTypes" } + value: { message_type: "cel.expr.conformance.proto2.NestedTestAllTypes" } } type_map: { key: 6 - value: { message_type: "google.api.expr.test.v1.proto2.NestedTestAllTypes" } + value: { message_type: "cel.expr.conformance.proto2.NestedTestAllTypes" } } type_map: { key: 7 @@ -349,7 +349,7 @@ constexpr absl::string_view kFieldSelectTestExpr = R"pb( } type_map: { key: 9 - value: { message_type: "google.api.expr.test.v1.proto2.TestAllTypes" } + value: { message_type: "cel.expr.conformance.proto2.TestAllTypes" } } type_map: { key: 11 @@ -361,11 +361,11 @@ constexpr absl::string_view kFieldSelectTestExpr = R"pb( } type_map: { key: 13 - value: { message_type: "google.api.expr.test.v1.proto2.NestedTestAllTypes" } + value: { message_type: "cel.expr.conformance.proto2.NestedTestAllTypes" } } type_map: { key: 14 - value: { message_type: "google.api.expr.test.v1.proto2.TestAllTypes" } + value: { message_type: "cel.expr.conformance.proto2.TestAllTypes" } } type_map: { key: 15 @@ -381,7 +381,7 @@ constexpr absl::string_view kFieldSelectTestExpr = R"pb( } type_map: { key: 18 - value: { message_type: "google.api.expr.test.v1.proto2.NestedTestAllTypes" } + value: { message_type: "cel.expr.conformance.proto2.NestedTestAllTypes" } } type_map: { key: 19 @@ -452,7 +452,7 @@ constexpr absl::string_view kFieldSelectTestExpr = R"pb( operand: { id: 9 struct_expr: { - message_name: "google.api.expr.test.v1.proto2.TestAllTypes" + message_name: "cel.expr.conformance.proto2.TestAllTypes" entries: { id: 10 field_key: "single_int64" @@ -535,7 +535,7 @@ constexpr absl::string_view kFieldSelectTestExpr = R"pb( operand: { id: 9 struct_expr: { - message_name: "google.api.expr.test.v1.proto2.TestAllTypes" + message_name: "cel.expr.conformance.proto2.TestAllTypes" entries: { id: 10 field_key: "single_int64" diff --git a/extensions/math_ext_test.cc b/extensions/math_ext_test.cc index bc7c45023..fbd1635ff 100644 --- a/extensions/math_ext_test.cc +++ b/extensions/math_ext_test.cc @@ -16,7 +16,7 @@ #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "absl/status/status.h" #include "absl/types/optional.h" #include "absl/types/span.h" @@ -38,9 +38,9 @@ namespace cel::extensions { namespace { using ::absl_testing::StatusIs; -using ::google::api::expr::v1alpha1::Expr; -using ::google::api::expr::v1alpha1::ParsedExpr; -using ::google::api::expr::v1alpha1::SourceInfo; +using ::cel::expr::Expr; +using ::cel::expr::ParsedExpr; +using ::cel::expr::SourceInfo; using ::google::api::expr::parser::ParseWithMacros; using ::google::api::expr::runtime::Activation; using ::google::api::expr::runtime::CelExpressionBuilder; diff --git a/extensions/protobuf/BUILD b/extensions/protobuf/BUILD index b6a302a6d..4e23471a0 100644 --- a/extensions/protobuf/BUILD +++ b/extensions/protobuf/BUILD @@ -61,8 +61,8 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/time", "@com_google_absl//absl/types:variant", - "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -85,8 +85,8 @@ cc_test( "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:variant", - "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -102,8 +102,8 @@ cc_library( "//runtime:runtime_builder", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -154,7 +154,7 @@ cc_test( "//common:type_testing", "//internal:testing", "@com_google_absl//absl/types:optional", - "@com_google_cel_spec//proto/test/v1/proto2:test_all_types_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -209,7 +209,7 @@ cc_test( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/time", - "@com_google_cel_spec//proto/test/v1/proto2:test_all_types_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -232,7 +232,7 @@ cc_test( "@com_google_absl//absl/status", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/time", - "@com_google_cel_spec//proto/test/v1/proto3:test_all_types_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -268,7 +268,7 @@ cc_test( "//runtime:managed_value_factory", "@com_google_absl//absl/status", "@com_google_absl//absl/types:optional", - "@com_google_cel_spec//proto/test/v1/proto2:test_all_types_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -298,7 +298,7 @@ cc_test( "//common:value_testing", "//internal:proto_matchers", "//internal:testing", - "@com_google_cel_spec//proto/test/v1/proto2:test_all_types_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto", "@com_google_protobuf//:protobuf", ], ) diff --git a/extensions/protobuf/ast_converters.cc b/extensions/protobuf/ast_converters.cc index 39d06dd6e..63b893940 100644 --- a/extensions/protobuf/ast_converters.cc +++ b/extensions/protobuf/ast_converters.cc @@ -20,8 +20,8 @@ #include #include -#include "google/api/expr/v1alpha1/checked.pb.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/checked.pb.h" +#include "cel/expr/syntax.pb.h" #include "google/protobuf/duration.pb.h" #include "google/protobuf/struct.pb.h" #include "google/protobuf/timestamp.pb.h" @@ -70,34 +70,34 @@ using ::cel::ast_internal::Type; using ::cel::ast_internal::UnspecifiedType; using ::cel::ast_internal::WellKnownType; -using ExprPb = google::api::expr::v1alpha1::Expr; -using ParsedExprPb = google::api::expr::v1alpha1::ParsedExpr; -using CheckedExprPb = google::api::expr::v1alpha1::CheckedExpr; -using ExtensionPb = google::api::expr::v1alpha1::SourceInfo::Extension; +using ExprPb = cel::expr::Expr; +using ParsedExprPb = cel::expr::ParsedExpr; +using CheckedExprPb = cel::expr::CheckedExpr; +using ExtensionPb = cel::expr::SourceInfo::Extension; absl::StatusOr ConvertConstant( - const google::api::expr::v1alpha1::Constant& constant) { + const cel::expr::Constant& constant) { switch (constant.constant_kind_case()) { - case google::api::expr::v1alpha1::Constant::CONSTANT_KIND_NOT_SET: + case cel::expr::Constant::CONSTANT_KIND_NOT_SET: return Constant(); - case google::api::expr::v1alpha1::Constant::kNullValue: + case cel::expr::Constant::kNullValue: return Constant(nullptr); - case google::api::expr::v1alpha1::Constant::kBoolValue: + case cel::expr::Constant::kBoolValue: return Constant(constant.bool_value()); - case google::api::expr::v1alpha1::Constant::kInt64Value: + case cel::expr::Constant::kInt64Value: return Constant(constant.int64_value()); - case google::api::expr::v1alpha1::Constant::kUint64Value: + case cel::expr::Constant::kUint64Value: return Constant(constant.uint64_value()); - case google::api::expr::v1alpha1::Constant::kDoubleValue: + case cel::expr::Constant::kDoubleValue: return Constant(constant.double_value()); - case google::api::expr::v1alpha1::Constant::kStringValue: + case cel::expr::Constant::kStringValue: return Constant(StringConstant{constant.string_value()}); - case google::api::expr::v1alpha1::Constant::kBytesValue: + case cel::expr::Constant::kBytesValue: return Constant(BytesConstant{constant.bytes_value()}); - case google::api::expr::v1alpha1::Constant::kDurationValue: + case cel::expr::Constant::kDurationValue: return Constant(absl::Seconds(constant.duration_value().seconds()) + absl::Nanoseconds(constant.duration_value().nanos())); - case google::api::expr::v1alpha1::Constant::kTimestampValue: + case cel::expr::Constant::kTimestampValue: return Constant( absl::FromUnixSeconds(constant.timestamp_value().seconds()) + absl::Nanoseconds(constant.timestamp_value().nanos())); @@ -107,14 +107,14 @@ absl::StatusOr ConvertConstant( } absl::StatusOr ConvertProtoExprToNative( - const google::api::expr::v1alpha1::Expr& expr) { + const cel::expr::Expr& expr) { Expr native_expr; CEL_RETURN_IF_ERROR(protobuf_internal::ExprFromProto(expr, native_expr)); return native_expr; } absl::StatusOr ConvertProtoSourceInfoToNative( - const google::api::expr::v1alpha1::SourceInfo& source_info) { + const cel::expr::SourceInfo& source_info) { absl::flat_hash_map macro_calls; for (const auto& pair : source_info.macro_calls()) { auto native_expr = ConvertProtoExprToNative(pair.second); @@ -160,49 +160,49 @@ absl::StatusOr ConvertProtoSourceInfoToNative( } absl::StatusOr ToNative( - google::api::expr::v1alpha1::Type::PrimitiveType primitive_type) { + cel::expr::Type::PrimitiveType primitive_type) { switch (primitive_type) { - case google::api::expr::v1alpha1::Type::PRIMITIVE_TYPE_UNSPECIFIED: + case cel::expr::Type::PRIMITIVE_TYPE_UNSPECIFIED: return PrimitiveType::kPrimitiveTypeUnspecified; - case google::api::expr::v1alpha1::Type::BOOL: + case cel::expr::Type::BOOL: return PrimitiveType::kBool; - case google::api::expr::v1alpha1::Type::INT64: + case cel::expr::Type::INT64: return PrimitiveType::kInt64; - case google::api::expr::v1alpha1::Type::UINT64: + case cel::expr::Type::UINT64: return PrimitiveType::kUint64; - case google::api::expr::v1alpha1::Type::DOUBLE: + case cel::expr::Type::DOUBLE: return PrimitiveType::kDouble; - case google::api::expr::v1alpha1::Type::STRING: + case cel::expr::Type::STRING: return PrimitiveType::kString; - case google::api::expr::v1alpha1::Type::BYTES: + case cel::expr::Type::BYTES: return PrimitiveType::kBytes; default: return absl::InvalidArgumentError( "Illegal type specified for " - "google::api::expr::v1alpha1::Type::PrimitiveType."); + "cel::expr::Type::PrimitiveType."); } } absl::StatusOr ToNative( - google::api::expr::v1alpha1::Type::WellKnownType well_known_type) { + cel::expr::Type::WellKnownType well_known_type) { switch (well_known_type) { - case google::api::expr::v1alpha1::Type::WELL_KNOWN_TYPE_UNSPECIFIED: + case cel::expr::Type::WELL_KNOWN_TYPE_UNSPECIFIED: return WellKnownType::kWellKnownTypeUnspecified; - case google::api::expr::v1alpha1::Type::ANY: + case cel::expr::Type::ANY: return WellKnownType::kAny; - case google::api::expr::v1alpha1::Type::TIMESTAMP: + case cel::expr::Type::TIMESTAMP: return WellKnownType::kTimestamp; - case google::api::expr::v1alpha1::Type::DURATION: + case cel::expr::Type::DURATION: return WellKnownType::kDuration; default: return absl::InvalidArgumentError( "Illegal type specified for " - "google::api::expr::v1alpha1::Type::WellKnownType."); + "cel::expr::Type::WellKnownType."); } } absl::StatusOr ToNative( - const google::api::expr::v1alpha1::Type::ListType& list_type) { + const cel::expr::Type::ListType& list_type) { auto native_elem_type = ConvertProtoTypeToNative(list_type.elem_type()); if (!native_elem_type.ok()) { return native_elem_type.status(); @@ -211,7 +211,7 @@ absl::StatusOr ToNative( } absl::StatusOr ToNative( - const google::api::expr::v1alpha1::Type::MapType& map_type) { + const cel::expr::Type::MapType& map_type) { auto native_key_type = ConvertProtoTypeToNative(map_type.key_type()); if (!native_key_type.ok()) { return native_key_type.status(); @@ -225,7 +225,7 @@ absl::StatusOr ToNative( } absl::StatusOr ToNative( - const google::api::expr::v1alpha1::Type::FunctionType& function_type) { + const cel::expr::Type::FunctionType& function_type) { std::vector arg_types; arg_types.reserve(function_type.arg_types_size()); for (const auto& arg_type : function_type.arg_types()) { @@ -244,7 +244,7 @@ absl::StatusOr ToNative( } absl::StatusOr ToNative( - const google::api::expr::v1alpha1::Type::AbstractType& abstract_type) { + const cel::expr::Type::AbstractType& abstract_type) { std::vector parameter_types; for (const auto& parameter_type : abstract_type.parameter_types()) { auto native_parameter_type = ConvertProtoTypeToNative(parameter_type); @@ -257,61 +257,61 @@ absl::StatusOr ToNative( } absl::StatusOr ConvertProtoTypeToNative( - const google::api::expr::v1alpha1::Type& type) { + const cel::expr::Type& type) { switch (type.type_kind_case()) { - case google::api::expr::v1alpha1::Type::kDyn: + case cel::expr::Type::kDyn: return Type(DynamicType()); - case google::api::expr::v1alpha1::Type::kNull: + case cel::expr::Type::kNull: return Type(nullptr); - case google::api::expr::v1alpha1::Type::kPrimitive: { + case cel::expr::Type::kPrimitive: { auto native_primitive = ToNative(type.primitive()); if (!native_primitive.ok()) { return native_primitive.status(); } return Type(*(std::move(native_primitive))); } - case google::api::expr::v1alpha1::Type::kWrapper: { + case cel::expr::Type::kWrapper: { auto native_wrapper = ToNative(type.wrapper()); if (!native_wrapper.ok()) { return native_wrapper.status(); } return Type(PrimitiveTypeWrapper(*(std::move(native_wrapper)))); } - case google::api::expr::v1alpha1::Type::kWellKnown: { + case cel::expr::Type::kWellKnown: { auto native_well_known = ToNative(type.well_known()); if (!native_well_known.ok()) { return native_well_known.status(); } return Type(*std::move(native_well_known)); } - case google::api::expr::v1alpha1::Type::kListType: { + case cel::expr::Type::kListType: { auto native_list_type = ToNative(type.list_type()); if (!native_list_type.ok()) { return native_list_type.status(); } return Type(*(std::move(native_list_type))); } - case google::api::expr::v1alpha1::Type::kMapType: { + case cel::expr::Type::kMapType: { auto native_map_type = ToNative(type.map_type()); if (!native_map_type.ok()) { return native_map_type.status(); } return Type(*(std::move(native_map_type))); } - case google::api::expr::v1alpha1::Type::kFunction: { + case cel::expr::Type::kFunction: { auto native_function = ToNative(type.function()); if (!native_function.ok()) { return native_function.status(); } return Type(*(std::move(native_function))); } - case google::api::expr::v1alpha1::Type::kMessageType: + case cel::expr::Type::kMessageType: return Type(MessageType(type.message_type())); - case google::api::expr::v1alpha1::Type::kTypeParam: + case cel::expr::Type::kTypeParam: return Type(ParamType(type.type_param())); - case google::api::expr::v1alpha1::Type::kType: { + case cel::expr::Type::kType: { if (type.type().type_kind_case() == - google::api::expr::v1alpha1::Type::TypeKindCase::TYPE_KIND_NOT_SET) { + cel::expr::Type::TypeKindCase::TYPE_KIND_NOT_SET) { return Type(std::unique_ptr()); } auto native_type = ConvertProtoTypeToNative(type.type()); @@ -320,25 +320,25 @@ absl::StatusOr ConvertProtoTypeToNative( } return Type(std::make_unique(*std::move(native_type))); } - case google::api::expr::v1alpha1::Type::kError: + case cel::expr::Type::kError: return Type(ErrorType::kErrorTypeValue); - case google::api::expr::v1alpha1::Type::kAbstractType: { + case cel::expr::Type::kAbstractType: { auto native_abstract = ToNative(type.abstract_type()); if (!native_abstract.ok()) { return native_abstract.status(); } return Type(*(std::move(native_abstract))); } - case google::api::expr::v1alpha1::Type::TYPE_KIND_NOT_SET: + case cel::expr::Type::TYPE_KIND_NOT_SET: return Type(UnspecifiedType()); default: return absl::InvalidArgumentError( - "Illegal type specified for google::api::expr::v1alpha1::Type."); + "Illegal type specified for cel::expr::Type."); } } absl::StatusOr ConvertProtoReferenceToNative( - const google::api::expr::v1alpha1::Reference& reference) { + const cel::expr::Reference& reference) { Reference ret_val; ret_val.set_name(reference.name()); ret_val.mutable_overload_id().reserve(reference.overload_id_size()); @@ -387,13 +387,13 @@ using ::cel::ast_internal::Type; using ::cel::ast_internal::UnspecifiedType; using ::cel::ast_internal::WellKnownType; -using ExprPb = google::api::expr::v1alpha1::Expr; -using ParsedExprPb = google::api::expr::v1alpha1::ParsedExpr; -using CheckedExprPb = google::api::expr::v1alpha1::CheckedExpr; -using SourceInfoPb = google::api::expr::v1alpha1::SourceInfo; -using ExtensionPb = google::api::expr::v1alpha1::SourceInfo::Extension; -using ReferencePb = google::api::expr::v1alpha1::Reference; -using TypePb = google::api::expr::v1alpha1::Type; +using ExprPb = cel::expr::Expr; +using ParsedExprPb = cel::expr::ParsedExpr; +using CheckedExprPb = cel::expr::CheckedExpr; +using SourceInfoPb = cel::expr::SourceInfo; +using ExtensionPb = cel::expr::SourceInfo::Extension; +using ReferencePb = cel::expr::Reference; +using TypePb = cel::expr::Type; struct ToProtoStackEntry { absl::Nonnull source; @@ -401,7 +401,7 @@ struct ToProtoStackEntry { }; absl::Status ConstantToProto(const ast_internal::Constant& source, - google::api::expr::v1alpha1::Constant& dest) { + cel::expr::Constant& dest) { return absl::visit(absl::Overload( [&](absl::monostate) -> absl::Status { dest.clear_constant_kind(); @@ -658,8 +658,8 @@ absl::Status TypeToProto(const Type& type, TypePb* result) { } // namespace absl::StatusOr> CreateAstFromParsedExpr( - const google::api::expr::v1alpha1::Expr& expr, - const google::api::expr::v1alpha1::SourceInfo* source_info) { + const cel::expr::Expr& expr, + const cel::expr::SourceInfo* source_info) { CEL_ASSIGN_OR_RETURN(auto runtime_expr, internal::ConvertProtoExprToNative(expr)); cel::ast_internal::SourceInfo runtime_source_info; @@ -720,7 +720,7 @@ absl::StatusOr> CreateAstFromCheckedExpr( std::move(type_map), checked_expr.expr_version()); } -absl::StatusOr CreateCheckedExprFromAst( +absl::StatusOr CreateCheckedExprFromAst( const Ast& ast) { if (!ast.IsChecked()) { return absl::InvalidArgumentError("AST is not type-checked"); diff --git a/extensions/protobuf/ast_converters.h b/extensions/protobuf/ast_converters.h index 611c41c79..79bdc8a44 100644 --- a/extensions/protobuf/ast_converters.h +++ b/extensions/protobuf/ast_converters.h @@ -17,8 +17,8 @@ #include -#include "google/api/expr/v1alpha1/checked.pb.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/checked.pb.h" +#include "cel/expr/syntax.pb.h" #include "absl/status/statusor.h" #include "base/ast.h" #include "base/ast_internal/expr.h" @@ -28,17 +28,17 @@ namespace internal { // Utilities for converting protobuf CEL message types to their corresponding // internal C++ representations. absl::StatusOr ConvertProtoExprToNative( - const google::api::expr::v1alpha1::Expr& expr); + const cel::expr::Expr& expr); absl::StatusOr ConvertProtoSourceInfoToNative( - const google::api::expr::v1alpha1::SourceInfo& source_info); + const cel::expr::SourceInfo& source_info); absl::StatusOr ConvertProtoTypeToNative( - const google::api::expr::v1alpha1::Type& type); + const cel::expr::Type& type); absl::StatusOr ConvertProtoReferenceToNative( - const google::api::expr::v1alpha1::Reference& reference); + const cel::expr::Reference& reference); // Conversion utility for the protobuf constant CEL value representation. absl::StatusOr ConvertConstant( - const google::api::expr::v1alpha1::Constant& constant); + const cel::expr::Constant& constant); } // namespace internal @@ -46,21 +46,21 @@ absl::StatusOr ConvertConstant( // May return a non-ok Status if the AST is malformed (e.g. unset required // fields). absl::StatusOr> CreateAstFromParsedExpr( - const google::api::expr::v1alpha1::Expr& expr, - const google::api::expr::v1alpha1::SourceInfo* source_info = nullptr); + const cel::expr::Expr& expr, + const cel::expr::SourceInfo* source_info = nullptr); absl::StatusOr> CreateAstFromParsedExpr( - const google::api::expr::v1alpha1::ParsedExpr& parsed_expr); + const cel::expr::ParsedExpr& parsed_expr); -absl::StatusOr CreateParsedExprFromAst( +absl::StatusOr CreateParsedExprFromAst( const Ast& ast); // Creates a runtime AST from a checked protobuf AST. // May return a non-ok Status if the AST is malformed (e.g. unset required // fields). absl::StatusOr> CreateAstFromCheckedExpr( - const google::api::expr::v1alpha1::CheckedExpr& checked_expr); + const cel::expr::CheckedExpr& checked_expr); -absl::StatusOr CreateCheckedExprFromAst( +absl::StatusOr CreateCheckedExprFromAst( const Ast& ast); } // namespace cel::extensions diff --git a/extensions/protobuf/ast_converters_test.cc b/extensions/protobuf/ast_converters_test.cc index 632f7a310..3cf01295f 100644 --- a/extensions/protobuf/ast_converters_test.cc +++ b/extensions/protobuf/ast_converters_test.cc @@ -19,8 +19,8 @@ #include #include -#include "google/api/expr/v1alpha1/checked.pb.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/checked.pb.h" +#include "cel/expr/syntax.pb.h" #include "google/protobuf/duration.pb.h" #include "google/protobuf/struct.pb.h" #include "google/protobuf/timestamp.pb.h" @@ -48,7 +48,7 @@ using ::cel::ast_internal::PrimitiveType; using ::cel::ast_internal::WellKnownType; TEST(AstConvertersTest, SourceInfoToNative) { - google::api::expr::v1alpha1::SourceInfo source_info; + cel::expr::SourceInfo source_info; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"pb( syntax_version: "version" @@ -77,8 +77,8 @@ TEST(AstConvertersTest, SourceInfoToNative) { } TEST(AstConvertersTest, PrimitiveTypeUnspecifiedToNative) { - google::api::expr::v1alpha1::Type type; - type.set_primitive(google::api::expr::v1alpha1::Type::PRIMITIVE_TYPE_UNSPECIFIED); + cel::expr::Type type; + type.set_primitive(cel::expr::Type::PRIMITIVE_TYPE_UNSPECIFIED); auto native_type = ConvertProtoTypeToNative(type); @@ -87,8 +87,8 @@ TEST(AstConvertersTest, PrimitiveTypeUnspecifiedToNative) { } TEST(AstConvertersTest, PrimitiveTypeBoolToNative) { - google::api::expr::v1alpha1::Type type; - type.set_primitive(google::api::expr::v1alpha1::Type::BOOL); + cel::expr::Type type; + type.set_primitive(cel::expr::Type::BOOL); auto native_type = ConvertProtoTypeToNative(type); @@ -97,8 +97,8 @@ TEST(AstConvertersTest, PrimitiveTypeBoolToNative) { } TEST(AstConvertersTest, PrimitiveTypeInt64ToNative) { - google::api::expr::v1alpha1::Type type; - type.set_primitive(google::api::expr::v1alpha1::Type::INT64); + cel::expr::Type type; + type.set_primitive(cel::expr::Type::INT64); auto native_type = ConvertProtoTypeToNative(type); @@ -107,8 +107,8 @@ TEST(AstConvertersTest, PrimitiveTypeInt64ToNative) { } TEST(AstConvertersTest, PrimitiveTypeUint64ToNative) { - google::api::expr::v1alpha1::Type type; - type.set_primitive(google::api::expr::v1alpha1::Type::UINT64); + cel::expr::Type type; + type.set_primitive(cel::expr::Type::UINT64); auto native_type = ConvertProtoTypeToNative(type); @@ -117,8 +117,8 @@ TEST(AstConvertersTest, PrimitiveTypeUint64ToNative) { } TEST(AstConvertersTest, PrimitiveTypeDoubleToNative) { - google::api::expr::v1alpha1::Type type; - type.set_primitive(google::api::expr::v1alpha1::Type::DOUBLE); + cel::expr::Type type; + type.set_primitive(cel::expr::Type::DOUBLE); auto native_type = ConvertProtoTypeToNative(type); @@ -127,8 +127,8 @@ TEST(AstConvertersTest, PrimitiveTypeDoubleToNative) { } TEST(AstConvertersTest, PrimitiveTypeStringToNative) { - google::api::expr::v1alpha1::Type type; - type.set_primitive(google::api::expr::v1alpha1::Type::STRING); + cel::expr::Type type; + type.set_primitive(cel::expr::Type::STRING); auto native_type = ConvertProtoTypeToNative(type); @@ -137,8 +137,8 @@ TEST(AstConvertersTest, PrimitiveTypeStringToNative) { } TEST(AstConvertersTest, PrimitiveTypeBytesToNative) { - google::api::expr::v1alpha1::Type type; - type.set_primitive(google::api::expr::v1alpha1::Type::BYTES); + cel::expr::Type type; + type.set_primitive(cel::expr::Type::BYTES); auto native_type = ConvertProtoTypeToNative(type); @@ -147,20 +147,20 @@ TEST(AstConvertersTest, PrimitiveTypeBytesToNative) { } TEST(AstConvertersTest, PrimitiveTypeError) { - google::api::expr::v1alpha1::Type type; - type.set_primitive(::google::api::expr::v1alpha1::Type_PrimitiveType(7)); + cel::expr::Type type; + type.set_primitive(::cel::expr::Type_PrimitiveType(7)); auto native_type = ConvertProtoTypeToNative(type); EXPECT_EQ(native_type.status().code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(native_type.status().message(), ::testing::HasSubstr("Illegal type specified for " - "google::api::expr::v1alpha1::Type::PrimitiveType.")); + "cel::expr::Type::PrimitiveType.")); } TEST(AstConvertersTest, WellKnownTypeUnspecifiedToNative) { - google::api::expr::v1alpha1::Type type; - type.set_well_known(google::api::expr::v1alpha1::Type::WELL_KNOWN_TYPE_UNSPECIFIED); + cel::expr::Type type; + type.set_well_known(cel::expr::Type::WELL_KNOWN_TYPE_UNSPECIFIED); auto native_type = ConvertProtoTypeToNative(type); @@ -170,8 +170,8 @@ TEST(AstConvertersTest, WellKnownTypeUnspecifiedToNative) { } TEST(AstConvertersTest, WellKnownTypeAnyToNative) { - google::api::expr::v1alpha1::Type type; - type.set_well_known(google::api::expr::v1alpha1::Type::ANY); + cel::expr::Type type; + type.set_well_known(cel::expr::Type::ANY); auto native_type = ConvertProtoTypeToNative(type); @@ -180,8 +180,8 @@ TEST(AstConvertersTest, WellKnownTypeAnyToNative) { } TEST(AstConvertersTest, WellKnownTypeTimestampToNative) { - google::api::expr::v1alpha1::Type type; - type.set_well_known(google::api::expr::v1alpha1::Type::TIMESTAMP); + cel::expr::Type type; + type.set_well_known(cel::expr::Type::TIMESTAMP); auto native_type = ConvertProtoTypeToNative(type); @@ -190,8 +190,8 @@ TEST(AstConvertersTest, WellKnownTypeTimestampToNative) { } TEST(AstConvertersTest, WellKnownTypeDuraionToNative) { - google::api::expr::v1alpha1::Type type; - type.set_well_known(google::api::expr::v1alpha1::Type::DURATION); + cel::expr::Type type; + type.set_well_known(cel::expr::Type::DURATION); auto native_type = ConvertProtoTypeToNative(type); @@ -200,21 +200,21 @@ TEST(AstConvertersTest, WellKnownTypeDuraionToNative) { } TEST(AstConvertersTest, WellKnownTypeError) { - google::api::expr::v1alpha1::Type type; - type.set_well_known(::google::api::expr::v1alpha1::Type_WellKnownType(4)); + cel::expr::Type type; + type.set_well_known(::cel::expr::Type_WellKnownType(4)); auto native_type = ConvertProtoTypeToNative(type); EXPECT_EQ(native_type.status().code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(native_type.status().message(), ::testing::HasSubstr("Illegal type specified for " - "google::api::expr::v1alpha1::Type::WellKnownType.")); + "cel::expr::Type::WellKnownType.")); } TEST(AstConvertersTest, ListTypeToNative) { - google::api::expr::v1alpha1::Type type; + cel::expr::Type type; type.mutable_list_type()->mutable_elem_type()->set_primitive( - google::api::expr::v1alpha1::Type::BOOL); + cel::expr::Type::BOOL); auto native_type = ConvertProtoTypeToNative(type); @@ -225,7 +225,7 @@ TEST(AstConvertersTest, ListTypeToNative) { } TEST(AstConvertersTest, MapTypeToNative) { - google::api::expr::v1alpha1::Type type; + cel::expr::Type type; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"pb( map_type { @@ -246,7 +246,7 @@ TEST(AstConvertersTest, MapTypeToNative) { } TEST(AstConvertersTest, FunctionTypeToNative) { - google::api::expr::v1alpha1::Type type; + cel::expr::Type type; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"pb( function { @@ -273,7 +273,7 @@ TEST(AstConvertersTest, FunctionTypeToNative) { } TEST(AstConvertersTest, AbstractTypeToNative) { - google::api::expr::v1alpha1::Type type; + cel::expr::Type type; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"pb( abstract_type { @@ -298,7 +298,7 @@ TEST(AstConvertersTest, AbstractTypeToNative) { } TEST(AstConvertersTest, DynamicTypeToNative) { - google::api::expr::v1alpha1::Type type; + cel::expr::Type type; type.mutable_dyn(); auto native_type = ConvertProtoTypeToNative(type); @@ -307,7 +307,7 @@ TEST(AstConvertersTest, DynamicTypeToNative) { } TEST(AstConvertersTest, NullTypeToNative) { - google::api::expr::v1alpha1::Type type; + cel::expr::Type type; type.set_null(google::protobuf::NULL_VALUE); auto native_type = ConvertProtoTypeToNative(type); @@ -317,8 +317,8 @@ TEST(AstConvertersTest, NullTypeToNative) { } TEST(AstConvertersTest, PrimitiveTypeWrapperToNative) { - google::api::expr::v1alpha1::Type type; - type.set_wrapper(google::api::expr::v1alpha1::Type::BOOL); + cel::expr::Type type; + type.set_wrapper(cel::expr::Type::BOOL); auto native_type = ConvertProtoTypeToNative(type); @@ -327,7 +327,7 @@ TEST(AstConvertersTest, PrimitiveTypeWrapperToNative) { } TEST(AstConvertersTest, MessageTypeToNative) { - google::api::expr::v1alpha1::Type type; + cel::expr::Type type; type.set_message_type("message"); auto native_type = ConvertProtoTypeToNative(type); @@ -337,7 +337,7 @@ TEST(AstConvertersTest, MessageTypeToNative) { } TEST(AstConvertersTest, ParamTypeToNative) { - google::api::expr::v1alpha1::Type type; + cel::expr::Type type; type.set_type_param("param"); auto native_type = ConvertProtoTypeToNative(type); @@ -347,7 +347,7 @@ TEST(AstConvertersTest, ParamTypeToNative) { } TEST(AstConvertersTest, NestedTypeToNative) { - google::api::expr::v1alpha1::Type type; + cel::expr::Type type; type.mutable_type()->mutable_dyn(); auto native_type = ConvertProtoTypeToNative(type); @@ -357,7 +357,7 @@ TEST(AstConvertersTest, NestedTypeToNative) { } TEST(AstConvertersTest, TypeTypeDefault) { - auto native_type = ConvertProtoTypeToNative(google::api::expr::v1alpha1::Type()); + auto native_type = ConvertProtoTypeToNative(cel::expr::Type()); ASSERT_THAT(native_type, IsOk()); EXPECT_TRUE(absl::holds_alternative( @@ -365,7 +365,7 @@ TEST(AstConvertersTest, TypeTypeDefault) { } TEST(AstConvertersTest, ReferenceToNative) { - google::api::expr::v1alpha1::Reference reference; + cel::expr::Reference reference; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"pb( name: "name" @@ -394,9 +394,9 @@ using ::cel::internal::test::EqualsProto; using ::google::api::expr::parser::Parse; using ::testing::HasSubstr; -using ParsedExprPb = google::api::expr::v1alpha1::ParsedExpr; -using CheckedExprPb = google::api::expr::v1alpha1::CheckedExpr; -using TypePb = google::api::expr::v1alpha1::Type; +using ParsedExprPb = cel::expr::ParsedExpr; +using CheckedExprPb = cel::expr::CheckedExpr; +using TypePb = cel::expr::Type; TEST(AstConvertersTest, CheckedExprToAst) { CheckedExprPb checked_expr; @@ -667,7 +667,7 @@ TEST(AstConvertersTest, AstToParsedExprBasic) { } TEST(AstConvertersTest, ExprToAst) { - google::api::expr::v1alpha1::Expr expr; + cel::expr::Expr expr; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"pb( ident_expr { name: "expr" } @@ -679,8 +679,8 @@ TEST(AstConvertersTest, ExprToAst) { } TEST(AstConvertersTest, ExprAndSourceInfoToAst) { - google::api::expr::v1alpha1::Expr expr; - google::api::expr::v1alpha1::SourceInfo source_info; + cel::expr::Expr expr; + cel::expr::SourceInfo source_info; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"pb( diff --git a/extensions/protobuf/bind_proto_to_activation_test.cc b/extensions/protobuf/bind_proto_to_activation_test.cc index 83b7faf01..de15c0a15 100644 --- a/extensions/protobuf/bind_proto_to_activation_test.cc +++ b/extensions/protobuf/bind_proto_to_activation_test.cc @@ -25,7 +25,7 @@ #include "internal/testing.h" #include "runtime/activation.h" #include "runtime/managed_value_factory.h" -#include "proto/test/v1/proto2/test_all_types.pb.h" +#include "cel/expr/conformance/proto2/test_all_types.pb.h" #include "google/protobuf/arena.h" namespace cel::extensions { @@ -33,8 +33,8 @@ namespace { using ::absl_testing::IsOkAndHolds; using ::absl_testing::StatusIs; +using ::cel::expr::conformance::proto2::TestAllTypes; using ::cel::test::IntValueIs; -using ::google::api::expr::test::v1::proto2::TestAllTypes; using ::testing::Eq; using ::testing::HasSubstr; using ::testing::Optional; diff --git a/extensions/protobuf/internal/BUILD b/extensions/protobuf/internal/BUILD index b9e560074..f8fb80eac 100644 --- a/extensions/protobuf/internal/BUILD +++ b/extensions/protobuf/internal/BUILD @@ -35,7 +35,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:variant", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -49,7 +49,7 @@ cc_test( "//internal:proto_matchers", "//internal:testing", "@com_google_absl//absl/status", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -67,7 +67,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@com_google_absl//absl/types:variant", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) diff --git a/extensions/protobuf/internal/ast.cc b/extensions/protobuf/internal/ast.cc index 0ac4bb963..e6972317c 100644 --- a/extensions/protobuf/internal/ast.cc +++ b/extensions/protobuf/internal/ast.cc @@ -20,7 +20,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "google/protobuf/struct.pb.h" #include "absl/base/attributes.h" #include "absl/base/nullability.h" @@ -37,20 +37,20 @@ namespace cel::extensions::protobuf_internal { namespace { -using ExprProto = google::api::expr::v1alpha1::Expr; -using ConstantProto = google::api::expr::v1alpha1::Constant; -using StructExprProto = google::api::expr::v1alpha1::Expr::CreateStruct; +using ExprProto = cel::expr::Expr; +using ConstantProto = cel::expr::Constant; +using StructExprProto = cel::expr::Expr::CreateStruct; class ExprToProtoState final { private: struct Frame final { absl::Nonnull expr; - absl::Nonnull proto; + absl::Nonnull proto; }; public: absl::Status ExprToProto(const Expr& expr, - absl::Nonnull proto) { + absl::Nonnull proto) { Push(expr, proto); Frame frame; while (Pop(frame)) { @@ -61,7 +61,7 @@ class ExprToProtoState final { private: absl::Status ExprToProtoImpl(const Expr& expr, - absl::Nonnull proto) { + absl::Nonnull proto) { return absl::visit( absl::Overload( [&expr, proto](const UnspecifiedExpr&) -> absl::Status { @@ -499,12 +499,12 @@ class ExprFromProtoState final { } // namespace absl::Status ExprToProto(const Expr& expr, - absl::Nonnull proto) { + absl::Nonnull proto) { ExprToProtoState state; return state.ExprToProto(expr, proto); } -absl::Status ExprFromProto(const google::api::expr::v1alpha1::Expr& proto, Expr& expr) { +absl::Status ExprFromProto(const cel::expr::Expr& proto, Expr& expr) { ExprFromProtoState state; return state.ExprFromProto(proto, expr); } diff --git a/extensions/protobuf/internal/ast.h b/extensions/protobuf/internal/ast.h index d43217e34..9e6aa79ea 100644 --- a/extensions/protobuf/internal/ast.h +++ b/extensions/protobuf/internal/ast.h @@ -15,7 +15,7 @@ #ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_INTERNAL_AST_H_ #define THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_INTERNAL_AST_H_ -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "absl/base/nullability.h" #include "absl/status/status.h" #include "common/expr.h" @@ -23,9 +23,9 @@ namespace cel::extensions::protobuf_internal { absl::Status ExprToProto(const Expr& expr, - absl::Nonnull proto); + absl::Nonnull proto); -absl::Status ExprFromProto(const google::api::expr::v1alpha1::Expr& proto, Expr& expr); +absl::Status ExprFromProto(const cel::expr::Expr& proto, Expr& expr); } // namespace cel::extensions::protobuf_internal diff --git a/extensions/protobuf/internal/ast_test.cc b/extensions/protobuf/internal/ast_test.cc index ba4ad6ce6..243d75920 100644 --- a/extensions/protobuf/internal/ast_test.cc +++ b/extensions/protobuf/internal/ast_test.cc @@ -16,7 +16,7 @@ #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "absl/status/status.h" #include "common/ast.h" #include "internal/proto_matchers.h" @@ -30,7 +30,7 @@ using ::absl_testing::IsOk; using ::absl_testing::StatusIs; using ::cel::internal::test::EqualsProto; -using ExprProto = google::api::expr::v1alpha1::Expr; +using ExprProto = cel::expr::Expr; struct ExprRoundtripTestCase { std::string input; diff --git a/extensions/protobuf/internal/constant.cc b/extensions/protobuf/internal/constant.cc index 83c7d9279..40c85a78f 100644 --- a/extensions/protobuf/internal/constant.cc +++ b/extensions/protobuf/internal/constant.cc @@ -17,7 +17,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "google/protobuf/struct.pb.h" #include "absl/base/nullability.h" #include "absl/functional/overload.h" @@ -30,7 +30,7 @@ namespace cel::extensions::protobuf_internal { -using ConstantProto = google::api::expr::v1alpha1::Constant; +using ConstantProto = cel::expr::Constant; absl::Status ConstantToProto(const Constant& constant, absl::Nonnull proto) { diff --git a/extensions/protobuf/internal/constant.h b/extensions/protobuf/internal/constant.h index b55345545..0ea87cdd0 100644 --- a/extensions/protobuf/internal/constant.h +++ b/extensions/protobuf/internal/constant.h @@ -15,7 +15,7 @@ #ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_INTERNAL_CONSTANT_H_ #define THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_INTERNAL_CONSTANT_H_ -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "absl/base/nullability.h" #include "absl/status/status.h" #include "common/constant.h" @@ -25,11 +25,11 @@ namespace cel::extensions::protobuf_internal { // `ConstantToProto` converts from native `Constant` to its protocol buffer // message equivalent. absl::Status ConstantToProto(const Constant& constant, - absl::Nonnull proto); + absl::Nonnull proto); // `ConstantToProto` converts to native `Constant` from its protocol buffer // message equivalent. -absl::Status ConstantFromProto(const google::api::expr::v1alpha1::Constant& proto, +absl::Status ConstantFromProto(const cel::expr::Constant& proto, Constant& constant); } // namespace cel::extensions::protobuf_internal diff --git a/extensions/protobuf/runtime_adapter.cc b/extensions/protobuf/runtime_adapter.cc index 4da274b50..ca9f9354a 100644 --- a/extensions/protobuf/runtime_adapter.cc +++ b/extensions/protobuf/runtime_adapter.cc @@ -17,8 +17,8 @@ #include #include -#include "google/api/expr/v1alpha1/checked.pb.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/checked.pb.h" +#include "cel/expr/syntax.pb.h" #include "absl/status/statusor.h" #include "extensions/protobuf/ast_converters.h" #include "internal/status_macros.h" @@ -28,7 +28,7 @@ namespace cel::extensions { absl::StatusOr> ProtobufRuntimeAdapter::CreateProgram( - const Runtime& runtime, const google::api::expr::v1alpha1::CheckedExpr& expr, + const Runtime& runtime, const cel::expr::CheckedExpr& expr, const Runtime::CreateProgramOptions options) { CEL_ASSIGN_OR_RETURN(auto ast, CreateAstFromCheckedExpr(expr)); return runtime.CreateTraceableProgram(std::move(ast), options); @@ -36,7 +36,7 @@ ProtobufRuntimeAdapter::CreateProgram( absl::StatusOr> ProtobufRuntimeAdapter::CreateProgram( - const Runtime& runtime, const google::api::expr::v1alpha1::ParsedExpr& expr, + const Runtime& runtime, const cel::expr::ParsedExpr& expr, const Runtime::CreateProgramOptions options) { CEL_ASSIGN_OR_RETURN(auto ast, CreateAstFromParsedExpr(expr)); return runtime.CreateTraceableProgram(std::move(ast), options); @@ -44,8 +44,8 @@ ProtobufRuntimeAdapter::CreateProgram( absl::StatusOr> ProtobufRuntimeAdapter::CreateProgram( - const Runtime& runtime, const google::api::expr::v1alpha1::Expr& expr, - const google::api::expr::v1alpha1::SourceInfo* source_info, + const Runtime& runtime, const cel::expr::Expr& expr, + const cel::expr::SourceInfo* source_info, const Runtime::CreateProgramOptions options) { CEL_ASSIGN_OR_RETURN(auto ast, CreateAstFromParsedExpr(expr, source_info)); return runtime.CreateTraceableProgram(std::move(ast), options); diff --git a/extensions/protobuf/runtime_adapter.h b/extensions/protobuf/runtime_adapter.h index 48854cfe9..49af58a07 100644 --- a/extensions/protobuf/runtime_adapter.h +++ b/extensions/protobuf/runtime_adapter.h @@ -17,8 +17,8 @@ #include -#include "google/api/expr/v1alpha1/checked.pb.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/checked.pb.h" +#include "cel/expr/syntax.pb.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "runtime/runtime.h" @@ -35,14 +35,14 @@ class ProtobufRuntimeAdapter { ProtobufRuntimeAdapter() = delete; static absl::StatusOr> CreateProgram( - const Runtime& runtime, const google::api::expr::v1alpha1::CheckedExpr& expr, + const Runtime& runtime, const cel::expr::CheckedExpr& expr, const Runtime::CreateProgramOptions options = {}); static absl::StatusOr> CreateProgram( - const Runtime& runtime, const google::api::expr::v1alpha1::ParsedExpr& expr, + const Runtime& runtime, const cel::expr::ParsedExpr& expr, const Runtime::CreateProgramOptions options = {}); static absl::StatusOr> CreateProgram( - const Runtime& runtime, const google::api::expr::v1alpha1::Expr& expr, - const google::api::expr::v1alpha1::SourceInfo* source_info = nullptr, + const Runtime& runtime, const cel::expr::Expr& expr, + const cel::expr::SourceInfo* source_info = nullptr, const Runtime::CreateProgramOptions options = {}); }; diff --git a/extensions/protobuf/type_introspector_test.cc b/extensions/protobuf/type_introspector_test.cc index 35cb0a5e3..0ea783fe1 100644 --- a/extensions/protobuf/type_introspector_test.cc +++ b/extensions/protobuf/type_introspector_test.cc @@ -19,14 +19,14 @@ #include "common/type_kind.h" #include "common/type_testing.h" #include "internal/testing.h" -#include "proto/test/v1/proto2/test_all_types.pb.h" +#include "cel/expr/conformance/proto2/test_all_types.pb.h" #include "google/protobuf/descriptor.h" namespace cel::extensions { namespace { using ::absl_testing::IsOkAndHolds; -using ::google::api::expr::test::v1::proto2::TestAllTypes; +using ::cel::expr::conformance::proto2::TestAllTypes; using ::testing::Eq; using ::testing::Optional; @@ -69,8 +69,8 @@ TEST_P(ProtoTypeIntrospectorTest, FindEnumConstant) { ASSERT_OK_AND_ASSIGN( auto enum_constant, introspector.FindEnumConstant( - type_manager(), - "google.api.expr.test.v1.proto2.TestAllTypes.NestedEnum", "BAZ")); + type_manager(), "cel.expr.conformance.proto2.TestAllTypes.NestedEnum", + "BAZ")); ASSERT_TRUE(enum_constant.has_value()); EXPECT_EQ(enum_constant->type.kind(), TypeKind::kEnum); EXPECT_EQ(enum_constant->type_full_name, enum_desc->full_name()); @@ -106,8 +106,8 @@ TEST_P(ProtoTypeIntrospectorTest, FindEnumConstantUnknownValue) { ASSERT_OK_AND_ASSIGN( auto enum_constant, introspector.FindEnumConstant( - type_manager(), - "google.api.expr.test.v1.proto2.TestAllTypes.NestedEnum", "QUX")); + type_manager(), "cel.expr.conformance.proto2.TestAllTypes.NestedEnum", + "QUX")); ASSERT_FALSE(enum_constant.has_value()); } diff --git a/extensions/protobuf/type_reflector_test.cc b/extensions/protobuf/type_reflector_test.cc index d51861650..b56047b90 100644 --- a/extensions/protobuf/type_reflector_test.cc +++ b/extensions/protobuf/type_reflector_test.cc @@ -22,13 +22,13 @@ #include "common/value.h" #include "common/value_testing.h" #include "internal/testing.h" -#include "proto/test/v1/proto2/test_all_types.pb.h" +#include "cel/expr/conformance/proto2/test_all_types.pb.h" namespace cel::extensions { namespace { using ::absl_testing::StatusIs; -using ::google::api::expr::test::v1::proto2::TestAllTypes; +using ::cel::expr::conformance::proto2::TestAllTypes; using ::testing::IsNull; using ::testing::NotNull; diff --git a/extensions/protobuf/value_end_to_end_test.cc b/extensions/protobuf/value_end_to_end_test.cc index e1c2b1841..7e90347d1 100644 --- a/extensions/protobuf/value_end_to_end_test.cc +++ b/extensions/protobuf/value_end_to_end_test.cc @@ -34,7 +34,7 @@ #include "runtime/runtime.h" #include "runtime/runtime_options.h" #include "runtime/standard_runtime_builder_factory.h" -#include "proto/test/v1/proto3/test_all_types.pb.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/text_format.h" @@ -42,6 +42,7 @@ namespace cel::extensions { namespace { using ::absl_testing::StatusIs; +using ::cel::expr::conformance::proto3::TestAllTypes; using ::cel::test::BoolValueIs; using ::cel::test::BytesValueIs; using ::cel::test::DoubleValueIs; @@ -56,9 +57,8 @@ using ::cel::test::StructValueIs; using ::cel::test::TimestampValueIs; using ::cel::test::UintValueIs; using ::cel::test::ValueMatcher; -using ::google::api::expr::v1alpha1::ParsedExpr; +using ::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; -using ::google::api::expr::test::v1::proto3::TestAllTypes; using ::testing::_; using ::testing::AnyOf; using ::testing::HasSubstr; diff --git a/extensions/protobuf/value_test.cc b/extensions/protobuf/value_test.cc index 3f74f0a6f..279fd378a 100644 --- a/extensions/protobuf/value_test.cc +++ b/extensions/protobuf/value_test.cc @@ -38,7 +38,7 @@ #include "common/value_testing.h" #include "extensions/protobuf/memory_manager.h" #include "internal/testing.h" -#include "proto/test/v1/proto2/test_all_types.pb.h" +#include "cel/expr/conformance/proto2/test_all_types.pb.h" #include "google/protobuf/arena.h" #include "google/protobuf/text_format.h" @@ -47,6 +47,7 @@ namespace { using ::absl_testing::IsOkAndHolds; using ::absl_testing::StatusIs; +using ::cel::expr::conformance::proto2::TestAllTypes; using ::cel::test::BoolValueIs; using ::cel::test::BytesValueIs; using ::cel::test::DoubleValueIs; @@ -62,7 +63,6 @@ using ::cel::test::StructValueIs; using ::cel::test::TimestampValueIs; using ::cel::test::UintValueIs; using ::cel::test::ValueKindIs; -using ::google::api::expr::test::v1::proto2::TestAllTypes; using ::testing::_; using ::testing::ElementsAre; using ::testing::Eq; diff --git a/extensions/protobuf/value_testing_test.cc b/extensions/protobuf/value_testing_test.cc index eaa109d1b..edd594d2c 100644 --- a/extensions/protobuf/value_testing_test.cc +++ b/extensions/protobuf/value_testing_test.cc @@ -21,15 +21,15 @@ #include "extensions/protobuf/value.h" #include "internal/proto_matchers.h" #include "internal/testing.h" -#include "proto/test/v1/proto2/test_all_types.pb.h" +#include "cel/expr/conformance/proto2/test_all_types.pb.h" #include "google/protobuf/arena.h" namespace cel::extensions::test { namespace { +using ::cel::expr::conformance::proto2::TestAllTypes; using ::cel::extensions::ProtoMessageToValue; using ::cel::internal::test::EqualsProto; -using ::google::api::expr::test::v1::proto2::TestAllTypes; class ProtoValueTesting : public common_internal::ThreadCompatibleValueTest<> { protected: diff --git a/extensions/sets_functions_benchmark_test.cc b/extensions/sets_functions_benchmark_test.cc index 1ea2ee3d8..401b8e638 100644 --- a/extensions/sets_functions_benchmark_test.cc +++ b/extensions/sets_functions_benchmark_test.cc @@ -17,7 +17,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "absl/base/nullability.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -50,7 +50,7 @@ namespace cel::extensions { namespace { using ::cel::Value; -using ::google::api::expr::v1alpha1::ParsedExpr; +using ::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; using ::google::api::expr::runtime::Activation; using ::google::api::expr::runtime::CelValue; diff --git a/extensions/sets_functions_test.cc b/extensions/sets_functions_test.cc index c1c6780e7..4f0376a76 100644 --- a/extensions/sets_functions_test.cc +++ b/extensions/sets_functions_test.cc @@ -18,7 +18,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "eval/public/activation.h" #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_expr_builder_factory.h" @@ -32,9 +32,9 @@ namespace cel::extensions { namespace { -using ::google::api::expr::v1alpha1::Expr; -using ::google::api::expr::v1alpha1::ParsedExpr; -using ::google::api::expr::v1alpha1::SourceInfo; +using ::cel::expr::Expr; +using ::cel::expr::ParsedExpr; +using ::cel::expr::SourceInfo; using ::google::api::expr::parser::ParseWithMacros; using ::google::api::expr::runtime::Activation; diff --git a/extensions/strings_test.cc b/extensions/strings_test.cc index 0dcc99d9d..8174f4e66 100644 --- a/extensions/strings_test.cc +++ b/extensions/strings_test.cc @@ -17,7 +17,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "absl/status/status_matchers.h" #include "absl/strings/cord.h" #include "common/memory.h" @@ -38,7 +38,7 @@ namespace cel::extensions { namespace { using ::absl_testing::IsOk; -using ::google::api::expr::v1alpha1::ParsedExpr; +using ::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; using ::google::api::expr::parser::ParserOptions; diff --git a/internal/BUILD b/internal/BUILD index 18064b629..5dd0d26ad 100644 --- a/internal/BUILD +++ b/internal/BUILD @@ -549,14 +549,14 @@ cel_proto_transitive_descriptor_set( name = "testing_descriptor_set", testonly = True, deps = [ + "@com_google_cel_spec//proto/cel/expr:checked_proto", + "@com_google_cel_spec//proto/cel/expr:eval_proto", + "@com_google_cel_spec//proto/cel/expr:explain_proto", "@com_google_cel_spec//proto/cel/expr:expr_proto", - "@com_google_cel_spec//proto/test/v1/proto2:test_all_types_proto", - "@com_google_cel_spec//proto/test/v1/proto3:test_all_types_proto", - "@com_google_googleapis//google/api/expr/v1alpha1:checked_proto", - "@com_google_googleapis//google/api/expr/v1alpha1:eval_proto", - "@com_google_googleapis//google/api/expr/v1alpha1:explain_proto", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_proto", - "@com_google_googleapis//google/api/expr/v1alpha1:value_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_proto", + "@com_google_cel_spec//proto/cel/expr:value_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_proto", "@com_google_protobuf//:any_proto", "@com_google_protobuf//:duration_proto", "@com_google_protobuf//:empty_proto", @@ -720,7 +720,7 @@ cc_test( "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/time", "@com_google_absl//absl/types:variant", - "@com_google_cel_spec//proto/test/v1/proto3:test_all_types_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -768,7 +768,7 @@ cc_test( "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/strings:string_view", - "@com_google_cel_spec//proto/test/v1/proto3:test_all_types_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -819,7 +819,7 @@ cc_test( "@com_google_absl//absl/strings:cord", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:optional", - "@com_google_cel_spec//proto/test/v1/proto3:test_all_types_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", "@com_google_protobuf//:protobuf", ], ) diff --git a/internal/json_test.cc b/internal/json_test.cc index 96df6d0c2..262d31fa3 100644 --- a/internal/json_test.cc +++ b/internal/json_test.cc @@ -33,7 +33,7 @@ #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "internal/testing_message_factory.h" -#include "proto/test/v1/proto3/test_all_types.pb.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" @@ -50,7 +50,7 @@ using ::testing::HasSubstr; using ::testing::Test; using ::testing::VariantWith; -using TestAllTypesProto3 = ::google::api::expr::test::v1::proto3::TestAllTypes; +using TestAllTypesProto3 = ::cel::expr::conformance::proto3::TestAllTypes; class CheckJsonTest : public Test { public: @@ -622,7 +622,7 @@ TEST_F(MessageToJsonTest, Any_Generated) { EXPECT_THAT( MessageToJson( *DynamicParseTextProto( - R"pb(type_url: "type.googleapis.com/google.api.expr.test.v1.proto3.TestAllTypes" + R"pb(type_url: "type.googleapis.com/cel.expr.conformance.proto3.TestAllTypes" value: "\x68\x01")pb"), descriptor_pool(), message_factory(), result), IsOk()); @@ -633,7 +633,7 @@ TEST_F(MessageToJsonTest, Any_Generated) { fields { key: "@type" value: { - string_value: "type.googleapis.com/google.api.expr.test.v1.proto3.TestAllTypes" + string_value: "type.googleapis.com/cel.expr.conformance.proto3.TestAllTypes" } } fields { @@ -648,7 +648,7 @@ TEST_F(MessageToJsonTest, Any_Dynamic) { EXPECT_THAT( MessageToJson( *DynamicParseTextProto( - R"pb(type_url: "type.googleapis.com/google.api.expr.test.v1.proto3.TestAllTypes" + R"pb(type_url: "type.googleapis.com/cel.expr.conformance.proto3.TestAllTypes" value: "\x68\x01")pb"), descriptor_pool(), message_factory(), result), IsOk()); @@ -659,7 +659,7 @@ TEST_F(MessageToJsonTest, Any_Dynamic) { fields { key: "@type" value: { - string_value: "type.googleapis.com/google.api.expr.test.v1.proto3.TestAllTypes" + string_value: "type.googleapis.com/cel.expr.conformance.proto3.TestAllTypes" } } fields { @@ -2051,32 +2051,32 @@ class MessageFieldToJsonTest : public Test { TEST_F(MessageFieldToJsonTest, TestAllTypesProto3_Generated) { auto* result = MakeGenerated(); - EXPECT_THAT(MessageFieldToJson( - *DynamicParseTextProto( - R"pb(single_bool: true)pb"), - ABSL_DIE_IF_NULL( - ABSL_DIE_IF_NULL( - descriptor_pool()->FindMessageTypeByName( - "google.api.expr.test.v1.proto3.TestAllTypes")) - ->FindFieldByName("single_bool")), - descriptor_pool(), message_factory(), result), - IsOk()); + EXPECT_THAT( + MessageFieldToJson( + *DynamicParseTextProto( + R"pb(single_bool: true)pb"), + ABSL_DIE_IF_NULL( + ABSL_DIE_IF_NULL(descriptor_pool()->FindMessageTypeByName( + "cel.expr.conformance.proto3.TestAllTypes")) + ->FindFieldByName("single_bool")), + descriptor_pool(), message_factory(), result), + IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(bool_value: true)pb")); } TEST_F(MessageFieldToJsonTest, TestAllTypesProto3_Dynamic) { auto* result = MakeDynamic(); - EXPECT_THAT(MessageFieldToJson( - *DynamicParseTextProto( - R"pb(single_bool: true)pb"), - ABSL_DIE_IF_NULL( - ABSL_DIE_IF_NULL( - descriptor_pool()->FindMessageTypeByName( - "google.api.expr.test.v1.proto3.TestAllTypes")) - ->FindFieldByName("single_bool")), - descriptor_pool(), message_factory(), result), - IsOk()); + EXPECT_THAT( + MessageFieldToJson( + *DynamicParseTextProto( + R"pb(single_bool: true)pb"), + ABSL_DIE_IF_NULL( + ABSL_DIE_IF_NULL(descriptor_pool()->FindMessageTypeByName( + "cel.expr.conformance.proto3.TestAllTypes")) + ->FindFieldByName("single_bool")), + descriptor_pool(), message_factory(), result), + IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(bool_value: true)pb")); } diff --git a/internal/message_equality_test.cc b/internal/message_equality_test.cc index 0394b539e..cc448c7bd 100644 --- a/internal/message_equality_test.cc +++ b/internal/message_equality_test.cc @@ -39,7 +39,7 @@ #include "internal/testing_descriptor_pool.h" #include "internal/testing_message_factory.h" #include "internal/well_known_types.h" -#include "proto/test/v1/proto3/test_all_types.pb.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" @@ -54,7 +54,7 @@ using ::testing::TestParamInfo; using ::testing::TestWithParam; using ::testing::ValuesIn; -using TestAllTypesProto3 = ::google::api::expr::test::v1::proto3::TestAllTypes; +using TestAllTypesProto3 = ::cel::expr::conformance::proto3::TestAllTypes; template Owned ParseTextProto(absl::string_view text) { diff --git a/internal/testing_descriptor_pool_test.cc b/internal/testing_descriptor_pool_test.cc index d31ff2d15..093ce8beb 100644 --- a/internal/testing_descriptor_pool_test.cc +++ b/internal/testing_descriptor_pool_test.cc @@ -161,13 +161,13 @@ TEST(TestingDescriptorPool, Empty) { TEST(TestingDescriptorPool, TestAllTypesProto2) { EXPECT_THAT(GetTestingDescriptorPool()->FindMessageTypeByName( - "google.api.expr.test.v1.proto2.TestAllTypes"), + "cel.expr.conformance.proto2.TestAllTypes"), NotNull()); } TEST(TestingDescriptorPool, TestAllTypesProto3) { EXPECT_THAT(GetTestingDescriptorPool()->FindMessageTypeByName( - "google.api.expr.test.v1.proto3.TestAllTypes"), + "cel.expr.conformance.proto3.TestAllTypes"), NotNull()); } diff --git a/internal/well_known_types_test.cc b/internal/well_known_types_test.cc index 0447fda90..80d033477 100644 --- a/internal/well_known_types_test.cc +++ b/internal/well_known_types_test.cc @@ -42,7 +42,7 @@ #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "internal/testing_message_factory.h" -#include "proto/test/v1/proto3/test_all_types.pb.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" @@ -63,7 +63,7 @@ using ::testing::NotNull; using ::testing::Test; using ::testing::VariantWith; -using TestAllTypesProto3 = ::google::api::expr::test::v1::proto3::TestAllTypes; +using TestAllTypesProto3 = ::cel::expr::conformance::proto3::TestAllTypes; class ReflectionTest : public Test { public: @@ -904,7 +904,7 @@ TEST_F(AdaptFromMessageTest, Any_Struct) { TEST_F(AdaptFromMessageTest, Any_TestAllTypesProto3) { auto message = DynamicParseTextProto( - R"pb(type_url: "type.googleapis.com/google.api.expr.test.v1.proto3.TestAllTypes")pb"); + R"pb(type_url: "type.googleapis.com/cel.expr.conformance.proto3.TestAllTypes")pb"); EXPECT_THAT(AdaptFromMessage(*message), IsOkAndHolds(VariantWith>(NotNull()))); } diff --git a/parser/BUILD b/parser/BUILD index 9fb4b6ab1..f5e91a5d1 100644 --- a/parser/BUILD +++ b/parser/BUILD @@ -57,7 +57,7 @@ cc_library( "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) @@ -181,7 +181,7 @@ cc_test( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) diff --git a/parser/parser.cc b/parser/parser.cc index fe47b9223..317663f80 100644 --- a/parser/parser.cc +++ b/parser/parser.cc @@ -30,7 +30,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "absl/base/macros.h" #include "absl/base/optimization.h" #include "absl/container/btree_map.h" @@ -423,7 +423,7 @@ using ::cel_parser_internal::CelLexer; using ::cel_parser_internal::CelParser; using common::CelOperator; using common::ReverseLookupOperator; -using ::google::api::expr::v1alpha1::ParsedExpr; +using ::cel::expr::ParsedExpr; class CodePointStream final : public CharStream { public: @@ -637,7 +637,7 @@ class ParserVisitor final : public CelBaseVisitor, std::any visitBoolTrue(CelParser::BoolTrueContext* ctx) override; std::any visitBoolFalse(CelParser::BoolFalseContext* ctx) override; std::any visitNull(CelParser::NullContext* ctx) override; - absl::Status GetSourceInfo(google::api::expr::v1alpha1::SourceInfo* source_info) const; + absl::Status GetSourceInfo(cel::expr::SourceInfo* source_info) const; EnrichedSourceInfo enriched_source_info() const; void syntaxError(antlr4::Recognizer* recognizer, antlr4::Token* offending_symbol, size_t line, size_t col, @@ -1345,7 +1345,7 @@ std::any ParserVisitor::visitNull(CelParser::NullContext* ctx) { } absl::Status ParserVisitor::GetSourceInfo( - google::api::expr::v1alpha1::SourceInfo* source_info) const { + cel::expr::SourceInfo* source_info) const { source_info->set_location(source_.description()); for (const auto& positions : factory_.positions()) { source_info->mutable_positions()->insert( @@ -1356,7 +1356,7 @@ absl::Status ParserVisitor::GetSourceInfo( source_info->mutable_line_offsets()->Add(line_offset); } for (const auto& macro_call : factory_.macro_calls()) { - google::api::expr::v1alpha1::Expr macro_call_proto; + cel::expr::Expr macro_call_proto; CEL_RETURN_IF_ERROR(cel::extensions::protobuf_internal::ExprToProto( macro_call.second, ¯o_call_proto)); source_info->mutable_macro_calls()->insert( @@ -1684,7 +1684,7 @@ absl::StatusOr EnrichedParse( } } -absl::StatusOr Parse( +absl::StatusOr Parse( const cel::Source& source, const cel::MacroRegistry& registry, const ParserOptions& options) { CEL_ASSIGN_OR_RETURN(auto verbose_expr, diff --git a/parser/parser.h b/parser/parser.h index 8b3347c1f..24229bc2a 100644 --- a/parser/parser.h +++ b/parser/parser.h @@ -24,7 +24,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "common/source.h" @@ -37,12 +37,12 @@ namespace google::api::expr::parser { class VerboseParsedExpr { public: - VerboseParsedExpr(google::api::expr::v1alpha1::ParsedExpr parsed_expr, + VerboseParsedExpr(cel::expr::ParsedExpr parsed_expr, EnrichedSourceInfo enriched_source_info) : parsed_expr_(std::move(parsed_expr)), enriched_source_info_(std::move(enriched_source_info)) {} - const google::api::expr::v1alpha1::ParsedExpr& parsed_expr() const { + const cel::expr::ParsedExpr& parsed_expr() const { return parsed_expr_; } const EnrichedSourceInfo& enriched_source_info() const { @@ -50,7 +50,7 @@ class VerboseParsedExpr { } private: - google::api::expr::v1alpha1::ParsedExpr parsed_expr_; + cel::expr::ParsedExpr parsed_expr_; EnrichedSourceInfo enriched_source_info_; }; @@ -63,13 +63,13 @@ absl::StatusOr EnrichedParse( // See comments at the top of the file for information about usage during C++ // static initialization. -absl::StatusOr Parse( +absl::StatusOr Parse( absl::string_view expression, absl::string_view description = "", const ParserOptions& options = ParserOptions()); // See comments at the top of the file for information about usage during C++ // static initialization. -absl::StatusOr ParseWithMacros( +absl::StatusOr ParseWithMacros( absl::string_view expression, const std::vector& macros, absl::string_view description = "", const ParserOptions& options = ParserOptions()); @@ -82,7 +82,7 @@ absl::StatusOr EnrichedParse( // See comments at the top of the file for information about usage during C++ // static initialization. -absl::StatusOr Parse( +absl::StatusOr Parse( const cel::Source& source, const cel::MacroRegistry& registry, const ParserOptions& options = ParserOptions()); diff --git a/parser/parser_test.cc b/parser/parser_test.cc index 34b59b56c..9ca16b0d0 100644 --- a/parser/parser_test.cc +++ b/parser/parser_test.cc @@ -21,7 +21,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "absl/algorithm/container.h" #include "absl/status/status_matchers.h" #include "absl/strings/ascii.h" @@ -47,7 +47,7 @@ using ::absl_testing::IsOk; using ::cel::ConstantKindCase; using ::cel::ExprKindCase; using ::cel::test::ExprPrinter; -using ::google::api::expr::v1alpha1::Expr; +using ::cel::expr::Expr; using ::testing::HasSubstr; using ::testing::Not; @@ -1272,8 +1272,8 @@ class KindAndIdAdorner : public cel::test::ExpressionAdorner { // will prevent macro_calls lookups from interfering with adorning expressions // that don't need to use macro_calls, such as the parsed AST. explicit KindAndIdAdorner( - const google::api::expr::v1alpha1::SourceInfo& source_info = - google::api::expr::v1alpha1::SourceInfo::default_instance()) + const cel::expr::SourceInfo& source_info = + cel::expr::SourceInfo::default_instance()) : source_info_(source_info) {} std::string Adorn(const cel::Expr& e) const override { @@ -1302,12 +1302,12 @@ class KindAndIdAdorner : public cel::test::ExpressionAdorner { } private: - const google::api::expr::v1alpha1::SourceInfo& source_info_; + const cel::expr::SourceInfo& source_info_; }; class LocationAdorner : public cel::test::ExpressionAdorner { public: - explicit LocationAdorner(const google::api::expr::v1alpha1::SourceInfo& source_info) + explicit LocationAdorner(const cel::expr::SourceInfo& source_info) : source_info_(source_info) {} std::string Adorn(const cel::Expr& e) const override { @@ -1355,7 +1355,7 @@ class LocationAdorner : public cel::test::ExpressionAdorner { return std::make_pair(line, col); } - const google::api::expr::v1alpha1::SourceInfo& source_info_; + const cel::expr::SourceInfo& source_info_; }; std::string ConvertEnrichedSourceInfoToString( @@ -1369,11 +1369,11 @@ std::string ConvertEnrichedSourceInfoToString( } std::string ConvertMacroCallsToString( - const google::api::expr::v1alpha1::SourceInfo& source_info) { + const cel::expr::SourceInfo& source_info) { KindAndIdAdorner macro_calls_adorner(source_info); ExprPrinter w(macro_calls_adorner); // Use a list so we can sort the macro calls ensuring order for appending - std::vector> macro_calls; + std::vector> macro_calls; for (auto pair : source_info.macro_calls()) { // Set ID to the map key for the adorner pair.second.set_id(pair.first); @@ -1381,8 +1381,8 @@ std::string ConvertMacroCallsToString( } // Sort in reverse because the first macro will have the highest id absl::c_sort(macro_calls, - [](const std::pair& p1, - const std::pair& p2) { + [](const std::pair& p1, + const std::pair& p2) { return p1.first > p2.first; }); std::string result = ""; diff --git a/runtime/BUILD b/runtime/BUILD index e5cb7f268..6dd16eeda 100644 --- a/runtime/BUILD +++ b/runtime/BUILD @@ -266,7 +266,7 @@ cc_test( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -337,7 +337,7 @@ cc_test( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) @@ -382,7 +382,7 @@ cc_test( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) @@ -425,8 +425,8 @@ cc_test( "//parser", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -470,7 +470,7 @@ cc_test( "//parser", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:string_view", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -564,7 +564,7 @@ cc_test( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) diff --git a/runtime/comprehension_vulnerability_check_test.cc b/runtime/comprehension_vulnerability_check_test.cc index 3ded61824..ba9c7572a 100644 --- a/runtime/comprehension_vulnerability_check_test.cc +++ b/runtime/comprehension_vulnerability_check_test.cc @@ -16,7 +16,7 @@ #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "extensions/protobuf/runtime_adapter.h" @@ -34,7 +34,7 @@ namespace { using ::absl_testing::IsOk; using ::absl_testing::StatusIs; using ::cel::extensions::ProtobufRuntimeAdapter; -using ::google::api::expr::v1alpha1::ParsedExpr; +using ::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; using ::google::protobuf::TextFormat; using ::testing::HasSubstr; diff --git a/runtime/constant_folding_test.cc b/runtime/constant_folding_test.cc index 13145a4b4..af3010b62 100644 --- a/runtime/constant_folding_test.cc +++ b/runtime/constant_folding_test.cc @@ -18,7 +18,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" @@ -39,7 +39,7 @@ namespace cel::extensions { namespace { using ::absl_testing::StatusIs; -using ::google::api::expr::v1alpha1::ParsedExpr; +using ::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; using ::testing::HasSubstr; diff --git a/runtime/optional_types_test.cc b/runtime/optional_types_test.cc index 18ea1841a..a2381c9e8 100644 --- a/runtime/optional_types_test.cc +++ b/runtime/optional_types_test.cc @@ -21,7 +21,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" @@ -58,7 +58,7 @@ using ::cel::test::BoolValueIs; using ::cel::test::IntValueIs; using ::cel::test::OptionalValueIs; using ::cel::test::OptionalValueIsEmpty; -using ::google::api::expr::v1alpha1::ParsedExpr; +using ::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; using ::google::api::expr::parser::ParserOptions; using ::testing::ElementsAre; diff --git a/runtime/reference_resolver_test.cc b/runtime/reference_resolver_test.cc index 3afcae2f6..2f6a7f483 100644 --- a/runtime/reference_resolver_test.cc +++ b/runtime/reference_resolver_test.cc @@ -16,8 +16,8 @@ #include #include -#include "google/api/expr/v1alpha1/checked.pb.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/checked.pb.h" +#include "cel/expr/syntax.pb.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "base/function_adapter.h" @@ -38,9 +38,9 @@ namespace cel { namespace { using ::cel::extensions::ProtobufRuntimeAdapter; -using ::google::api::expr::v1alpha1::CheckedExpr; -using ::google::api::expr::v1alpha1::Expr; -using ::google::api::expr::v1alpha1::ParsedExpr; +using ::cel::expr::CheckedExpr; +using ::cel::expr::Expr; +using ::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; @@ -260,12 +260,12 @@ TEST(ReferenceResolver, ResolveQualifiedIdentifiersSkipParseOnly) { StatusIs(absl::StatusCode::kUnknown, HasSubstr("\"com\""))); } -// google.api.expr.test.v1.proto2.GlobalEnum.GAZ == 2 +// cel.expr.conformance.proto2.GlobalEnum.GAZ == 2 constexpr absl::string_view kEnumExpr = R"pb( reference_map: { key: 8 value: { - name: "google.api.expr.test.v1.proto2.GlobalEnum.GAZ" + name: "cel.expr.conformance.proto2.GlobalEnum.GAZ" value: { int64_value: 2 } } } @@ -307,7 +307,7 @@ constexpr absl::string_view kEnumExpr = R"pb( function: "_==_" args: { id: 8 - ident_expr: { name: "google.api.expr.test.v1.proto2.GlobalEnum.GAZ" } + ident_expr: { name: "cel.expr.conformance.proto2.GlobalEnum.GAZ" } } args: { id: 10 @@ -373,7 +373,7 @@ TEST(ReferenceResolver, ResolveEnumConstantsSkipParseOnly) { EXPECT_THAT( value.GetError().NativeValue(), StatusIs(absl::StatusCode::kUnknown, - HasSubstr("\"google.api.expr.test.v1.proto2.GlobalEnum.GAZ\""))); + HasSubstr("\"cel.expr.conformance.proto2.GlobalEnum.GAZ\""))); } } // namespace diff --git a/runtime/regex_precompilation_test.cc b/runtime/regex_precompilation_test.cc index ec081e4a6..b5da4aa4e 100644 --- a/runtime/regex_precompilation_test.cc +++ b/runtime/regex_precompilation_test.cc @@ -18,7 +18,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" @@ -40,7 +40,7 @@ namespace cel::extensions { namespace { using ::absl_testing::StatusIs; -using ::google::api::expr::v1alpha1::ParsedExpr; +using ::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; using ::testing::_; using ::testing::HasSubstr; diff --git a/runtime/standard_runtime_builder_factory_test.cc b/runtime/standard_runtime_builder_factory_test.cc index a56e2d900..00a9899e3 100644 --- a/runtime/standard_runtime_builder_factory_test.cc +++ b/runtime/standard_runtime_builder_factory_test.cc @@ -21,7 +21,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "absl/base/no_destructor.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" @@ -55,7 +55,7 @@ namespace { using ::cel::extensions::ProtobufRuntimeAdapter; using ::cel::extensions::ProtoMemoryManagerRef; using ::cel::test::BoolValueIs; -using ::google::api::expr::v1alpha1::ParsedExpr; +using ::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; using ::testing::ElementsAre; using ::testing::Truly; diff --git a/testutil/BUILD b/testutil/BUILD index f11150d37..96124bb06 100644 --- a/testutil/BUILD +++ b/testutil/BUILD @@ -31,7 +31,7 @@ cc_library( "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) @@ -71,7 +71,7 @@ cc_library( "//common:expr", "//extensions/protobuf:ast_converters", "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", ], ) diff --git a/testutil/baseline_tests.cc b/testutil/baseline_tests.cc index ab94c7a2b..18ef9cd7b 100644 --- a/testutil/baseline_tests.cc +++ b/testutil/baseline_tests.cc @@ -146,7 +146,7 @@ std::string FormatBaselineAst(const Ast& ast) { } std::string FormatBaselineCheckedExpr( - const google::api::expr::v1alpha1::CheckedExpr& checked) { + const cel::expr::CheckedExpr& checked) { auto ast = cel::extensions::CreateAstFromCheckedExpr(checked); if (!ast.ok()) { return ast.status().ToString(); diff --git a/testutil/baseline_tests.h b/testutil/baseline_tests.h index 857211729..bcb7852a2 100644 --- a/testutil/baseline_tests.h +++ b/testutil/baseline_tests.h @@ -41,7 +41,7 @@ #include -#include "google/api/expr/v1alpha1/checked.pb.h" +#include "cel/expr/checked.pb.h" #include "common/ast.h" namespace cel::test { @@ -49,7 +49,7 @@ namespace cel::test { std::string FormatBaselineAst(const Ast& ast); std::string FormatBaselineCheckedExpr( - const google::api::expr::v1alpha1::CheckedExpr& checked); + const cel::expr::CheckedExpr& checked); } // namespace cel::test diff --git a/testutil/baseline_tests_test.cc b/testutil/baseline_tests_test.cc index 20cfc207a..325597d42 100644 --- a/testutil/baseline_tests_test.cc +++ b/testutil/baseline_tests_test.cc @@ -26,7 +26,7 @@ namespace cel::test { namespace { using ::cel::ast_internal::AstImpl; -using ::google::api::expr::v1alpha1::CheckedExpr; +using ::cel::expr::CheckedExpr; using AstType = ast_internal::Type; diff --git a/testutil/expr_printer.cc b/testutil/expr_printer.cc index 13b468a02..d38d7aa77 100644 --- a/testutil/expr_printer.cc +++ b/testutil/expr_printer.cc @@ -313,7 +313,7 @@ const ExpressionAdorner& EmptyAdorner() { return *kInstance; } -std::string ExprPrinter::PrintProto(const google::api::expr::v1alpha1::Expr& expr) const { +std::string ExprPrinter::PrintProto(const cel::expr::Expr& expr) const { StringBuilder w(adorner_); absl::StatusOr> ast = CreateAstFromParsedExpr(expr); if (!ast.ok()) { diff --git a/testutil/expr_printer.h b/testutil/expr_printer.h index 643ee9728..6b0a8c161 100644 --- a/testutil/expr_printer.h +++ b/testutil/expr_printer.h @@ -17,7 +17,7 @@ #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "common/expr.h" namespace cel::test { @@ -45,7 +45,7 @@ class ExprPrinter { ExprPrinter() : adorner_(EmptyAdorner()) {} explicit ExprPrinter(const ExpressionAdorner& adorner) : adorner_(adorner) {} - std::string PrintProto(const google::api::expr::v1alpha1::Expr& expr) const; + std::string PrintProto(const cel::expr::Expr& expr) const; std::string Print(const Expr& expr) const; private: diff --git a/tools/BUILD b/tools/BUILD index 38d80f4e2..029aed119 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -54,8 +54,8 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", - "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) @@ -69,7 +69,7 @@ cc_test( "//parser", "@com_google_absl//absl/base", "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) @@ -91,7 +91,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:variant", - "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", "@com_google_protobuf//:protobuf", ], ) diff --git a/tools/branch_coverage.cc b/tools/branch_coverage.cc index 904b5876f..d6155cb86 100644 --- a/tools/branch_coverage.cc +++ b/tools/branch_coverage.cc @@ -17,7 +17,7 @@ #include #include -#include "google/api/expr/v1alpha1/checked.pb.h" +#include "cel/expr/checked.pb.h" #include "absl/base/no_destructor.h" #include "absl/base/nullability.h" #include "absl/base/thread_annotations.h" @@ -36,8 +36,8 @@ namespace cel { namespace { -using ::google::api::expr::v1alpha1::CheckedExpr; -using ::google::api::expr::v1alpha1::Type; +using ::cel::expr::CheckedExpr; +using ::cel::expr::Type; using ::google::api::expr::runtime::CelValue; const absl::Status& UnsupportedConversionError() { diff --git a/tools/branch_coverage.h b/tools/branch_coverage.h index 69f25e07d..77c28952c 100644 --- a/tools/branch_coverage.h +++ b/tools/branch_coverage.h @@ -18,7 +18,7 @@ #include #include -#include "google/api/expr/v1alpha1/checked.pb.h" +#include "cel/expr/checked.pb.h" #include "absl/base/attributes.h" #include "common/value.h" #include "eval/public/cel_value.h" @@ -55,12 +55,12 @@ class BranchCoverage { virtual NodeCoverageStats StatsForNode(int64_t expr_id) const = 0; virtual const NavigableAst& ast() const ABSL_ATTRIBUTE_LIFETIME_BOUND = 0; - virtual const google::api::expr::v1alpha1::CheckedExpr& expr() const + virtual const cel::expr::CheckedExpr& expr() const ABSL_ATTRIBUTE_LIFETIME_BOUND = 0; }; std::unique_ptr CreateBranchCoverage( - const google::api::expr::v1alpha1::CheckedExpr& expr); + const cel::expr::CheckedExpr& expr); } // namespace cel diff --git a/tools/branch_coverage_test.cc b/tools/branch_coverage_test.cc index 235d11ffc..29fa9bbe4 100644 --- a/tools/branch_coverage_test.cc +++ b/tools/branch_coverage_test.cc @@ -39,7 +39,7 @@ namespace cel { namespace { using ::cel::internal::test::ReadTextProtoFromFile; -using ::google::api::expr::v1alpha1::CheckedExpr; +using ::cel::expr::CheckedExpr; using ::google::api::expr::runtime::Activation; using ::google::api::expr::runtime::CelValue; using ::google::api::expr::runtime::CreateCelExpressionBuilder; diff --git a/tools/navigable_ast.cc b/tools/navigable_ast.cc index 7aa862a71..84025c77c 100644 --- a/tools/navigable_ast.cc +++ b/tools/navigable_ast.cc @@ -20,7 +20,7 @@ #include #include -#include "google/api/expr/v1alpha1/checked.pb.h" +#include "cel/expr/checked.pb.h" #include "absl/container/flat_hash_map.h" #include "absl/log/absl_check.h" #include "absl/memory/memory.h" @@ -50,7 +50,7 @@ size_t AstMetadata::AddNode() { namespace { -using google::api::expr::v1alpha1::Expr; +using cel::expr::Expr; using google::api::expr::runtime::AstTraverse; using google::api::expr::runtime::SourcePosition; diff --git a/tools/navigable_ast.h b/tools/navigable_ast.h index c1f4bf23a..3bc71e7d1 100644 --- a/tools/navigable_ast.h +++ b/tools/navigable_ast.h @@ -22,7 +22,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/types/span.h" @@ -88,7 +88,7 @@ struct AstMetadata; // without exposing too much mutable state on the non-internal classes. struct AstNodeData { AstNode* parent; - const ::google::api::expr::v1alpha1::Expr* expr; + const ::cel::expr::Expr* expr; ChildKind parent_relation; NodeKind node_kind; const AstMetadata* metadata; @@ -101,7 +101,7 @@ struct AstMetadata { std::vector> nodes; std::vector postorder; absl::flat_hash_map id_to_node; - absl::flat_hash_map expr_to_node; + absl::flat_hash_map expr_to_node; AstNodeData& NodeDataAt(size_t index); size_t AddNode(); @@ -133,7 +133,7 @@ class AstNode { // The parent of this node or nullptr if it is a root. absl::Nullable parent() const { return data_.parent; } - absl::Nonnull expr() const { + absl::Nonnull expr() const { return data_.expr; } @@ -192,7 +192,7 @@ class AstNode { // if no mutations take place on the input. class NavigableAst { public: - static NavigableAst Build(const google::api::expr::v1alpha1::Expr& expr); + static NavigableAst Build(const cel::expr::Expr& expr); // Default constructor creates an empty instance. // @@ -222,7 +222,7 @@ class NavigableAst { // Return ptr to the AST node representing the given Expr protobuf node. absl::Nullable FindExpr( - const google::api::expr::v1alpha1::Expr* expr) const { + const cel::expr::Expr* expr) const { auto it = metadata_->expr_to_node.find(expr); if (it == metadata_->expr_to_node.end()) { return nullptr; diff --git a/tools/navigable_ast_test.cc b/tools/navigable_ast_test.cc index 2e3622fb7..63b4ebd5c 100644 --- a/tools/navigable_ast_test.cc +++ b/tools/navigable_ast_test.cc @@ -17,7 +17,7 @@ #include #include -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/syntax.pb.h" #include "absl/base/casts.h" #include "absl/strings/str_cat.h" #include "base/builtins.h" @@ -27,7 +27,7 @@ namespace cel { namespace { -using ::google::api::expr::v1alpha1::Expr; +using ::cel::expr::Expr; using ::google::api::expr::parser::Parse; using ::testing::ElementsAre; using ::testing::IsEmpty; From 13e252f09676a19c0051d942b14f2e2e33ca8368 Mon Sep 17 00:00:00 2001 From: Justin King Date: Fri, 25 Oct 2024 12:52:12 -0700 Subject: [PATCH 004/180] Remove dependency on getting list/map value builders from factory/managers PiperOrigin-RevId: 689882490 --- common/values/list_value_builder.h | 2 + common/values/map_value_builder.h | 2 + common/values/value_builder.cc | 63 ++++++++++++++---------------- eval/eval/create_list_step.cc | 14 +++---- eval/eval/create_map_step.cc | 12 +++--- 5 files changed, 45 insertions(+), 48 deletions(-) diff --git a/common/values/list_value_builder.h b/common/values/list_value_builder.h index e213574ff..0845ced08 100644 --- a/common/values/list_value_builder.h +++ b/common/values/list_value_builder.h @@ -96,6 +96,8 @@ const MutableListValue& GetMutableListValue( const MutableListValue& GetMutableListValue( const ListValue& value ABSL_ATTRIBUTE_LIFETIME_BOUND); +absl::Nonnull NewListValueBuilder( + Allocator<> allocator); absl::Nonnull NewListValueBuilder( ValueFactory& value_factory); diff --git a/common/values/map_value_builder.h b/common/values/map_value_builder.h index 05621512a..ac2cdb1dd 100644 --- a/common/values/map_value_builder.h +++ b/common/values/map_value_builder.h @@ -96,6 +96,8 @@ const MutableMapValue& GetMutableMapValue( const MutableMapValue& GetMutableMapValue( const MapValue& value ABSL_ATTRIBUTE_LIFETIME_BOUND); +absl::Nonnull NewMapValueBuilder( + Allocator<> allocator); absl::Nonnull NewMapValueBuilder( ValueFactory& value_factory); diff --git a/common/values/value_builder.cc b/common/values/value_builder.cc index 3afe373ce..cc3120113 100644 --- a/common/values/value_builder.cc +++ b/common/values/value_builder.cc @@ -502,11 +502,8 @@ class NonTrivialMutableListValueImpl final : public MutableListValue { class TrivialListValueBuilderImpl final : public ListValueBuilder { public: - TrivialListValueBuilderImpl(ValueFactory& value_factory, - absl::Nonnull arena) - : value_factory_(value_factory), elements_(arena) { - ABSL_DCHECK_EQ(value_factory_.GetMemoryManager().arena(), arena); - } + explicit TrivialListValueBuilderImpl(absl::Nonnull arena) + : arena_(arena), elements_(arena_) {} absl::Status Add(Value value) override { CEL_RETURN_IF_ERROR(CheckListElement(value)); @@ -524,19 +521,18 @@ class TrivialListValueBuilderImpl final : public ListValueBuilder { return ListValue(); } return ParsedListValue( - value_factory_.GetMemoryManager().MakeShared( + MemoryManager::Pooling(arena_).MakeShared( std::move(elements_))); } private: - ValueFactory& value_factory_; + absl::Nonnull const arena_; TrivialValueVector elements_; }; class NonTrivialListValueBuilderImpl final : public ListValueBuilder { public: - explicit NonTrivialListValueBuilderImpl(ValueFactory& value_factory) - : value_factory_(value_factory) {} + NonTrivialListValueBuilderImpl() = default; absl::Status Add(Value value) override { CEL_RETURN_IF_ERROR(CheckListElement(value)); @@ -553,12 +549,11 @@ class NonTrivialListValueBuilderImpl final : public ListValueBuilder { return ListValue(); } return ParsedListValue( - value_factory_.GetMemoryManager().MakeShared( + MemoryManager::ReferenceCounting().MakeShared( std::move(elements_))); } private: - ValueFactory& value_factory_; NonTrivialValueVector elements_; }; @@ -676,13 +671,17 @@ const MutableListValue& GetMutableListValue(const ListValue& value) { } absl::Nonnull NewListValueBuilder( - ValueFactory& value_factory) { - if (absl::Nullable arena = - value_factory.GetMemoryManager().arena(); + Allocator<> allocator) { + if (absl::Nullable arena = allocator.arena(); arena != nullptr) { - return std::make_unique(value_factory, arena); + return std::make_unique(arena); } - return std::make_unique(value_factory); + return std::make_unique(); +} + +absl::Nonnull NewListValueBuilder( + ValueFactory& value_factory) { + return NewListValueBuilder(value_factory.GetMemoryManager()); } } // namespace common_internal @@ -1451,11 +1450,8 @@ class NonTrivialMutableMapValueImpl final : public MutableMapValue { class TrivialMapValueBuilderImpl final : public MapValueBuilder { public: - TrivialMapValueBuilderImpl(ValueFactory& value_factory, - absl::Nonnull arena) - : value_factory_(value_factory), map_(arena) { - ABSL_DCHECK_EQ(value_factory_.GetMemoryManager().arena(), arena); - } + explicit TrivialMapValueBuilderImpl(absl::Nonnull arena) + : arena_(arena), map_(arena_) {} absl::Status Put(Value key, Value value) override { CEL_RETURN_IF_ERROR(CheckMapKey(key)); @@ -1480,20 +1476,18 @@ class TrivialMapValueBuilderImpl final : public MapValueBuilder { return MapValue(); } return ParsedMapValue( - value_factory_.GetMemoryManager().MakeShared( + MemoryManager::Pooling(arena_).MakeShared( std::move(map_))); } private: - ValueFactory& value_factory_; + absl::Nonnull const arena_; TrivialValueFlatHashMap map_; }; class NonTrivialMapValueBuilderImpl final : public MapValueBuilder { public: - explicit NonTrivialMapValueBuilderImpl(ValueFactory& value_factory) - : value_factory_(value_factory), - map_(NonTrivialValueFlatHashMapAllocator{}) {} + NonTrivialMapValueBuilderImpl() = default; absl::Status Put(Value key, Value value) override { CEL_RETURN_IF_ERROR(CheckMapKey(key)); @@ -1517,12 +1511,11 @@ class NonTrivialMapValueBuilderImpl final : public MapValueBuilder { return MapValue(); } return ParsedMapValue( - value_factory_.GetMemoryManager().MakeShared( + MemoryManager::ReferenceCounting().MakeShared( std::move(map_))); } private: - ValueFactory& value_factory_; NonTrivialValueFlatHashMap map_; }; @@ -1644,13 +1637,17 @@ const MutableMapValue& GetMutableMapValue(const MapValue& value) { } absl::Nonnull NewMapValueBuilder( - ValueFactory& value_factory) { - if (absl::Nullable arena = - value_factory.GetMemoryManager().arena(); + Allocator<> allocator) { + if (absl::Nullable arena = allocator.arena(); arena != nullptr) { - return std::make_unique(value_factory, arena); + return std::make_unique(arena); } - return std::make_unique(value_factory); + return std::make_unique(); +} + +absl::Nonnull NewMapValueBuilder( + ValueFactory& value_factory) { + return NewMapValueBuilder(value_factory.GetMemoryManager()); } } // namespace common_internal diff --git a/eval/eval/create_list_step.cc b/eval/eval/create_list_step.cc index 065534daf..e1895ad82 100644 --- a/eval/eval/create_list_step.cc +++ b/eval/eval/create_list_step.cc @@ -12,7 +12,6 @@ #include "absl/types/optional.h" #include "base/ast_internal/expr.h" #include "common/casting.h" -#include "common/type.h" #include "common/value.h" #include "common/values/list_value_builder.h" #include "eval/eval/attribute_trail.h" @@ -29,9 +28,10 @@ namespace { using ::cel::Cast; using ::cel::ErrorValue; using ::cel::InstanceOf; -using ::cel::ListValueBuilderInterface; +using ::cel::ListValueBuilderPtr; using ::cel::UnknownValue; using ::cel::Value; +using ::cel::common_internal::NewListValueBuilder; class CreateListStep : public ExpressionStepBase { public: @@ -83,8 +83,7 @@ absl::Status CreateListStep::Evaluate(ExecutionFrame* frame) const { } } - CEL_ASSIGN_OR_RETURN(auto builder, frame->value_manager().NewListValueBuilder( - cel::ListType())); + ListValueBuilderPtr builder = NewListValueBuilder(frame->memory_manager()); builder->Reserve(args.size()); for (size_t i = 0; i < args.size(); ++i) { @@ -130,9 +129,8 @@ class CreateListDirectStep : public DirectExpressionStep { absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, AttributeTrail& attribute_trail) const override { - CEL_ASSIGN_OR_RETURN( - auto builder, - frame.value_manager().NewListValueBuilder(cel::ListType())); + ListValueBuilderPtr builder = + NewListValueBuilder(frame.value_manager().GetMemoryManager()); builder->Reserve(elements_.size()); AttributeUtility::Accumulator unknowns = @@ -231,8 +229,6 @@ class DirectMutableListStep : public DirectExpressionStep { absl::Status DirectMutableListStep::Evaluate( ExecutionFrameBase& frame, Value& result, AttributeTrail& attribute_trail) const { - CEL_ASSIGN_OR_RETURN( - auto builder, frame.value_manager().NewListValueBuilder(cel::ListType())); result = cel::ParsedListValue(cel::common_internal::NewMutableListValue( frame.value_manager().GetMemoryManager().arena())); return absl::OkStatus(); diff --git a/eval/eval/create_map_step.cc b/eval/eval/create_map_step.cc index f52d7b2ea..3d8d86729 100644 --- a/eval/eval/create_map_step.cc +++ b/eval/eval/create_map_step.cc @@ -26,9 +26,9 @@ #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "common/casting.h" -#include "common/type.h" #include "common/value.h" #include "common/value_manager.h" +#include "common/values/map_value_builder.h" #include "eval/eval/attribute_trail.h" #include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" @@ -42,9 +42,10 @@ namespace { using ::cel::Cast; using ::cel::ErrorValue; using ::cel::InstanceOf; -using ::cel::StructValueBuilderInterface; +using ::cel::MapValueBuilderPtr; using ::cel::UnknownValue; using ::cel::Value; +using ::cel::common_internal::NewMapValueBuilder; // `CreateStruct` implementation for map. class CreateStructStepForMap final : public ExpressionStepBase { @@ -77,8 +78,7 @@ absl::StatusOr CreateStructStepForMap::DoEvaluate( } } - CEL_ASSIGN_OR_RETURN( - auto builder, frame->value_manager().NewMapValueBuilder(cel::MapType{})); + MapValueBuilderPtr builder = NewMapValueBuilder(frame->memory_manager()); builder->Reserve(entry_count_); for (size_t i = 0; i < entry_count_; i += 1) { @@ -151,8 +151,8 @@ absl::Status DirectCreateMapStep::Evaluate( AttributeTrail tmp_attr; auto unknowns = frame.attribute_utility().CreateAccumulator(); - CEL_ASSIGN_OR_RETURN( - auto builder, frame.value_manager().NewMapValueBuilder(cel::MapType())); + MapValueBuilderPtr builder = + NewMapValueBuilder(frame.value_manager().GetMemoryManager()); builder->Reserve(entry_count_); for (size_t i = 0; i < entry_count_; i += 1) { From 833abb15a4e1f5d2c793c52890f51bb20980f27b Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Fri, 25 Oct 2024 15:45:50 -0700 Subject: [PATCH 005/180] Add support for declaring an overall expected type to the type checker. PiperOrigin-RevId: 689936845 --- checker/internal/type_check_env.h | 6 ++++ checker/internal/type_checker_impl.cc | 13 ++++++- checker/internal/type_checker_impl.h | 2 +- checker/internal/type_checker_impl_test.cc | 42 ++++++++++++++++++++++ checker/type_checker_builder.cc | 5 +++ checker/type_checker_builder.h | 3 ++ 6 files changed, 69 insertions(+), 2 deletions(-) diff --git a/checker/internal/type_check_env.h b/checker/internal/type_check_env.h index 2c694dd2e..9cad4ae72 100644 --- a/checker/internal/type_check_env.h +++ b/checker/internal/type_check_env.h @@ -107,6 +107,10 @@ class TypeCheckEnv { container_ = std::move(container); } + void set_expected_type(const Type& type) { expected_type_ = std::move(type); } + + const absl::optional& expected_type() const { return expected_type_; } + absl::Span> type_providers() const { return type_providers_; } @@ -198,6 +202,8 @@ class TypeCheckEnv { // Type providers for custom types. std::vector> type_providers_; + + absl::optional expected_type_; }; } // namespace cel::checker_internal diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index 4c1975bfa..fd665b5d9 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -239,7 +239,7 @@ absl::StatusOr FlattenType(const Type& type) { return absl::InternalError( absl::StrCat("Unsupported type: ", type.DebugString())); } -} // namespace +} class ResolveVisitor : public AstVisitorBase { public: @@ -322,6 +322,13 @@ class ResolveVisitor : public AstVisitorBase { const absl::Status& status() const { return status_; } + void AssertExpectedType(const Expr& expr, const Type& expected_type) { + Type observed = GetTypeOrDyn(&expr); + if (!inference_context_->IsAssignable(observed, expected_type)) { + ReportTypeMismatch(expr.id(), expected_type, observed); + } + } + private: struct ComprehensionScope { const Expr* comprehension_expr; @@ -1231,6 +1238,10 @@ absl::StatusOr TypeCheckerImpl::Check( AstTraverse(ast_impl.root_expr(), visitor, opts); CEL_RETURN_IF_ERROR(visitor.status()); + if (env_.expected_type().has_value()) { + visitor.AssertExpectedType(ast_impl.root_expr(), *env_.expected_type()); + } + // If any issues are errors, return without an AST. for (const auto& issue : issues) { if (issue.severity() == Severity::kError) { diff --git a/checker/internal/type_checker_impl.h b/checker/internal/type_checker_impl.h index f28621030..1b9062ec1 100644 --- a/checker/internal/type_checker_impl.h +++ b/checker/internal/type_checker_impl.h @@ -33,7 +33,7 @@ namespace cel::checker_internal { // See cel::TypeCheckerBuilder for constructing instances. class TypeCheckerImpl : public TypeChecker { public: - TypeCheckerImpl(TypeCheckEnv env, CheckerOptions options = {}) + explicit TypeCheckerImpl(TypeCheckEnv env, CheckerOptions options = {}) : env_(std::move(env)), options_(options) {} TypeCheckerImpl(const TypeCheckerImpl&) = delete; diff --git a/checker/internal/type_checker_impl_test.cc b/checker/internal/type_checker_impl_test.cc index 50be6d671..e0ff26ff8 100644 --- a/checker/internal/type_checker_impl_test.cc +++ b/checker/internal/type_checker_impl_test.cc @@ -1230,6 +1230,48 @@ TEST(TypeCheckerImplTest, TypeInferredFromStructCreation) { std::make_unique(ast_internal::DynamicType()))))))); } +TEST(TypeCheckerImplTest, ExpectedTypeMatches) { + google::protobuf::Arena arena; + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + env.set_expected_type(MapType(&arena, StringType(), StringType())); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("{}")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + + const auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast); + + EXPECT_THAT( + ast_impl.type_map(), + Contains(Pair( + ast_impl.root_expr().id(), + Eq(AstType(ast_internal::MapType( + std::make_unique(ast_internal::PrimitiveType::kString), + std::make_unique( + ast_internal::PrimitiveType::kString))))))); +} + +TEST(TypeCheckerImplTest, ExpectedTypeDoesntMatch) { + google::protobuf::Arena arena; + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + env.set_expected_type(MapType(&arena, StringType(), StringType())); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("{'abc': 123}")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT( + result.GetIssues(), + Contains(IsIssueWithSubstring( + Severity::kError, + "expected type 'map' but found 'map'"))); +} + TEST(TypeCheckerImplTest, ContainerLookupForMessageCreation) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); env.set_container("google.protobuf"); diff --git a/checker/type_checker_builder.cc b/checker/type_checker_builder.cc index bd5eee3f9..c10675bf4 100644 --- a/checker/type_checker_builder.cc +++ b/checker/type_checker_builder.cc @@ -32,6 +32,7 @@ #include "checker/internal/type_checker_impl.h" #include "checker/type_checker.h" #include "common/decl.h" +#include "common/type.h" #include "common/type_introspector.h" #include "internal/status_macros.h" #include "internal/well_known_types.h" @@ -170,4 +171,8 @@ void TypeCheckerBuilder::set_container(absl::string_view container) { env_.set_container(std::string(container)); } +void TypeCheckerBuilder::SetExpectedType(const Type& type) { + env_.set_expected_type(type); +} + } // namespace cel diff --git a/checker/type_checker_builder.h b/checker/type_checker_builder.h index f6eb5aec0..1253c0cae 100644 --- a/checker/type_checker_builder.h +++ b/checker/type_checker_builder.h @@ -30,6 +30,7 @@ #include "checker/internal/type_check_env.h" #include "checker/type_checker.h" #include "common/decl.h" +#include "common/type.h" #include "common/type_introspector.h" #include "google/protobuf/descriptor.h" @@ -91,6 +92,8 @@ class TypeCheckerBuilder { absl::Status AddVariable(const VariableDecl& decl); absl::Status AddFunction(const FunctionDecl& decl); + void SetExpectedType(const Type& type); + // Adds function declaration overloads to the TypeChecker being built. // // Attempts to merge with any existing overloads for a function decl with the From f10fd17c42c62c58f5de8bb83c7850146846019e Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Fri, 25 Oct 2024 16:47:32 -0700 Subject: [PATCH 006/180] Fix forwarding checker options from CreateTypeCheckerBuilder overload. PiperOrigin-RevId: 689953414 --- checker/type_checker_builder.cc | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/checker/type_checker_builder.cc b/checker/type_checker_builder.cc index c10675bf4..06e755ddd 100644 --- a/checker/type_checker_builder.cc +++ b/checker/type_checker_builder.cc @@ -86,8 +86,11 @@ absl::StatusOr CreateTypeCheckerBuilder( absl::Nonnull descriptor_pool, const CheckerOptions& options) { ABSL_DCHECK(descriptor_pool != nullptr); - return CreateTypeCheckerBuilder(std::shared_ptr( - descriptor_pool, [](absl::Nullable) {})); + return CreateTypeCheckerBuilder( + std::shared_ptr( + descriptor_pool, + [](absl::Nullable) {}), + options); } absl::StatusOr CreateTypeCheckerBuilder( From 80f0b1107590ac10f7ef29c87dd574be277d5178 Mon Sep 17 00:00:00 2001 From: Justin King Date: Mon, 28 Oct 2024 10:55:00 -0700 Subject: [PATCH 007/180] Update `RuntimeBuilder` to accept `std::shared_ptr` to match type checker PiperOrigin-RevId: 690674918 --- checker/BUILD | 1 + checker/type_checker_builder.cc | 3 +- internal/BUILD | 7 +++ internal/noop_delete.h | 53 +++++++++++++++++++++ internal/testing_descriptor_pool.cc | 3 +- internal/well_known_types.cc | 12 +++++ internal/well_known_types.h | 2 + runtime/BUILD | 4 ++ runtime/internal/BUILD | 3 ++ runtime/internal/runtime_impl.h | 20 +++++--- runtime/runtime_builder.h | 6 ++- runtime/runtime_builder_factory.cc | 35 ++++++++++---- runtime/runtime_builder_factory.h | 6 +++ runtime/standard_runtime_builder_factory.cc | 22 ++++++++- runtime/standard_runtime_builder_factory.h | 6 +++ 15 files changed, 163 insertions(+), 20 deletions(-) create mode 100644 internal/noop_delete.h diff --git a/checker/BUILD b/checker/BUILD index 25074887a..7a5ffab13 100644 --- a/checker/BUILD +++ b/checker/BUILD @@ -90,6 +90,7 @@ cc_library( "//checker/internal:type_checker_impl", "//common:decl", "//common:type", + "//internal:noop_delete", "//internal:status_macros", "//internal:well_known_types", "//parser:macro", diff --git a/checker/type_checker_builder.cc b/checker/type_checker_builder.cc index 06e755ddd..aada156ed 100644 --- a/checker/type_checker_builder.cc +++ b/checker/type_checker_builder.cc @@ -34,6 +34,7 @@ #include "common/decl.h" #include "common/type.h" #include "common/type_introspector.h" +#include "internal/noop_delete.h" #include "internal/status_macros.h" #include "internal/well_known_types.h" #include "parser/macro.h" @@ -89,7 +90,7 @@ absl::StatusOr CreateTypeCheckerBuilder( return CreateTypeCheckerBuilder( std::shared_ptr( descriptor_pool, - [](absl::Nullable) {}), + internal::NoopDeleteFor()), options); } diff --git a/internal/BUILD b/internal/BUILD index 5dd0d26ad..13cc004b5 100644 --- a/internal/BUILD +++ b/internal/BUILD @@ -580,6 +580,7 @@ cc_library( hdrs = ["testing_descriptor_pool.h"], textual_hdrs = [":testing_descriptor_set_embed"], deps = [ + ":noop_delete", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", @@ -829,3 +830,9 @@ cc_library( hdrs = ["protobuf_runtime_version.h"], deps = ["@com_google_protobuf//:protobuf"], ) + +cc_library( + name = "noop_delete", + hdrs = ["noop_delete.h"], + deps = ["@com_google_absl//absl/base:nullability"], +) diff --git a/internal/noop_delete.h b/internal/noop_delete.h new file mode 100644 index 000000000..151a87c0f --- /dev/null +++ b/internal/noop_delete.h @@ -0,0 +1,53 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_NOOP_DELETE_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_NOOP_DELETE_H_ + +#include + +#include "absl/base/nullability.h" + +namespace cel::internal { + +// Like `std::default_delete`, except it does nothing. +template +struct NoopDelete { + static_assert(!std::is_function::value, + "NoopDelete cannot be instantiated for function types"); + + constexpr NoopDelete() noexcept = default; + constexpr NoopDelete(const NoopDelete&) noexcept = default; + + template < + typename U, + typename = std::enable_if_t>, std::is_convertible>>> + // NOLINTNEXTLINE(google-explicit-constructor) + constexpr NoopDelete(const NoopDelete&) noexcept {} + + constexpr void operator()(absl::Nullable) const noexcept { + static_assert(sizeof(T) >= 0, "cannot delete an incomplete type"); + static_assert(!std::is_void::value, "cannot delete an incomplete type"); + } +}; + +template +inline constexpr NoopDelete NoopDeleteFor() noexcept { + return NoopDelete{}; +} + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_NOOP_DELETE_H_ diff --git a/internal/testing_descriptor_pool.cc b/internal/testing_descriptor_pool.cc index 3e3ab193e..bcacf1b5d 100644 --- a/internal/testing_descriptor_pool.cc +++ b/internal/testing_descriptor_pool.cc @@ -23,6 +23,7 @@ #include "absl/base/no_destructor.h" #include "absl/base/nullability.h" #include "absl/log/absl_check.h" +#include "internal/noop_delete.h" #include "google/protobuf/descriptor.h" namespace cel::internal { @@ -54,7 +55,7 @@ GetSharedTestingDescriptorPool() { static const absl::NoDestructor< absl::Nonnull>> instance(GetTestingDescriptorPool(), - [](absl::Nullable) {}); + internal::NoopDeleteFor()); return *instance; } diff --git a/internal/well_known_types.cc b/internal/well_known_types.cc index f6511cff2..653ba65ad 100644 --- a/internal/well_known_types.cc +++ b/internal/well_known_types.cc @@ -1762,6 +1762,18 @@ absl::Status Reflection::Initialize(absl::Nonnull pool) { return absl::OkStatus(); } +bool Reflection::IsInitialized() const { + // Check that everything is initialized except field mask, which is optional. + return NullValue().IsInitialized() && BoolValue().IsInitialized() && + Int32Value().IsInitialized() && Int64Value().IsInitialized() && + UInt32Value().IsInitialized() && UInt64Value().IsInitialized() && + FloatValue().IsInitialized() && DoubleValue().IsInitialized() && + BytesValue().IsInitialized() && StringValue().IsInitialized() && + Any().IsInitialized() && Duration().IsInitialized() && + Timestamp().IsInitialized() && Value().IsInitialized() && + ListValue().IsInitialized() && Struct().IsInitialized(); +} + namespace { // AdaptListValue verifies the message is the well known type diff --git a/internal/well_known_types.h b/internal/well_known_types.h index 94d3b37d6..fa4fe485c 100644 --- a/internal/well_known_types.h +++ b/internal/well_known_types.h @@ -1391,6 +1391,8 @@ class Reflection final { absl::Status Initialize(absl::Nonnull pool); + bool IsInitialized() const; + // At the moment we only use this class for verifying well known types in // descriptor pools. We could eagerly initialize it and cache it somewhere to // make things faster. diff --git a/runtime/BUILD b/runtime/BUILD index 6dd16eeda..c453afb89 100644 --- a/runtime/BUILD +++ b/runtime/BUILD @@ -213,10 +213,12 @@ cc_library( deps = [ ":runtime_builder", ":runtime_options", + "//internal:noop_delete", "//internal:status_macros", "//runtime/internal:runtime_impl", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status:statusor", "@com_google_protobuf//:protobuf", ], @@ -231,9 +233,11 @@ cc_library( ":runtime_builder_factory", ":runtime_options", ":standard_functions", + "//internal:noop_delete", "//internal:status_macros", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status:statusor", "@com_google_protobuf//:protobuf", ], diff --git a/runtime/internal/BUILD b/runtime/internal/BUILD index 503fbe786..69b8a8e3e 100644 --- a/runtime/internal/BUILD +++ b/runtime/internal/BUILD @@ -69,8 +69,11 @@ cc_library( "//runtime:function_registry", "//runtime:runtime_options", "//runtime:type_registry", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", ], ) diff --git a/runtime/internal/runtime_impl.h b/runtime/internal/runtime_impl.h index 4782fe95b..4dc2fe929 100644 --- a/runtime/internal/runtime_impl.h +++ b/runtime/internal/runtime_impl.h @@ -16,7 +16,11 @@ #define THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_IMPL_H_ #include +#include +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" #include "absl/status/statusor.h" #include "base/ast.h" #include "base/type_provider.h" @@ -27,21 +31,28 @@ #include "runtime/runtime.h" #include "runtime/runtime_options.h" #include "runtime/type_registry.h" +#include "google/protobuf/descriptor.h" namespace cel::runtime_internal { class RuntimeImpl : public Runtime { public: struct Environment { + ABSL_ATTRIBUTE_UNUSED + absl::Nonnull> + descriptor_pool; TypeRegistry type_registry; FunctionRegistry function_registry; well_known_types::Reflection well_known_types; }; - explicit RuntimeImpl(const RuntimeOptions& options) - : environment_(std::make_shared()), + RuntimeImpl(absl::Nonnull> environment, + const RuntimeOptions& options) + : environment_(std::move(environment)), expr_builder_(environment_->function_registry, - environment_->type_registry, options) {} + environment_->type_registry, options) { + ABSL_DCHECK(environment_->well_known_types.IsInitialized()); + } TypeRegistry& type_registry() { return environment_->type_registry; } const TypeRegistry& type_registry() const { @@ -55,9 +66,6 @@ class RuntimeImpl : public Runtime { return environment_->function_registry; } - well_known_types::Reflection& well_known_types() { - return environment_->well_known_types; - } const well_known_types::Reflection& well_known_types() const { return environment_->well_known_types; } diff --git a/runtime/runtime_builder.h b/runtime/runtime_builder.h index 3dcb3e280..3bfbcd62f 100644 --- a/runtime/runtime_builder.h +++ b/runtime/runtime_builder.h @@ -36,7 +36,8 @@ class RuntimeFriendAccess; class RuntimeBuilder; absl::StatusOr CreateRuntimeBuilder( - absl::Nonnull, const RuntimeOptions&); + absl::Nonnull>, + const RuntimeOptions&); // RuntimeBuilder provides mutable accessors to configure a new runtime. // @@ -64,7 +65,8 @@ class RuntimeBuilder { private: friend class runtime_internal::RuntimeFriendAccess; friend absl::StatusOr CreateRuntimeBuilder( - absl::Nonnull, const RuntimeOptions&); + absl::Nonnull>, + const RuntimeOptions&); // Constructor for a new runtime builder. // diff --git a/runtime/runtime_builder_factory.cc b/runtime/runtime_builder_factory.cc index 7b726bff0..34e16b03a 100644 --- a/runtime/runtime_builder_factory.cc +++ b/runtime/runtime_builder_factory.cc @@ -18,7 +18,9 @@ #include #include "absl/base/nullability.h" +#include "absl/log/absl_check.h" #include "absl/status/statusor.h" +#include "internal/noop_delete.h" #include "internal/status_macros.h" #include "runtime/internal/runtime_impl.h" #include "runtime/runtime_builder.h" @@ -27,27 +29,44 @@ namespace cel { +using ::cel::runtime_internal::RuntimeImpl; + absl::StatusOr CreateRuntimeBuilder( absl::Nonnull descriptor_pool, const RuntimeOptions& options) { + ABSL_DCHECK(descriptor_pool != nullptr); + return CreateRuntimeBuilder( + std::shared_ptr( + descriptor_pool, + internal::NoopDeleteFor()), + options); +} + +absl::StatusOr CreateRuntimeBuilder( + absl::Nonnull> + descriptor_pool, + const RuntimeOptions& options) { // TODO: and internal API for adding extensions that need to // downcast to the runtime impl. // TODO: add API for attaching an issue listener (replacing the // vector overloads). - auto mutable_runtime = - std::make_unique(options); - CEL_RETURN_IF_ERROR( - mutable_runtime->well_known_types().Initialize(descriptor_pool)); - mutable_runtime->expr_builder().set_container(options.container); + ABSL_DCHECK(descriptor_pool != nullptr); + auto environment = std::make_shared(); + environment->descriptor_pool = std::move(descriptor_pool); + CEL_RETURN_IF_ERROR(environment->well_known_types.Initialize( + environment->descriptor_pool.get())); + auto runtime_impl = + std::make_unique(std::move(environment), options); + runtime_impl->expr_builder().set_container(options.container); - auto& type_registry = mutable_runtime->type_registry(); - auto& function_registry = mutable_runtime->function_registry(); + auto& type_registry = runtime_impl->type_registry(); + auto& function_registry = runtime_impl->function_registry(); type_registry.set_use_legacy_container_builders( options.use_legacy_container_builders); return RuntimeBuilder(type_registry, function_registry, - std::move(mutable_runtime)); + std::move(runtime_impl)); } } // namespace cel diff --git a/runtime/runtime_builder_factory.h b/runtime/runtime_builder_factory.h index 8ee9f2ec0..377727bea 100644 --- a/runtime/runtime_builder_factory.h +++ b/runtime/runtime_builder_factory.h @@ -15,6 +15,8 @@ #ifndef THIRD_PARTY_CEL_CPP_RUNTIME_RUNTIME_BUILDER_FACTORY_H_ #define THIRD_PARTY_CEL_CPP_RUNTIME_RUNTIME_BUILDER_FACTORY_H_ +#include + #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/status/statusor.h" @@ -54,6 +56,10 @@ absl::StatusOr CreateRuntimeBuilder( absl::Nonnull descriptor_pool ABSL_ATTRIBUTE_LIFETIME_BOUND, const RuntimeOptions& options); +absl::StatusOr CreateRuntimeBuilder( + absl::Nonnull> + descriptor_pool, + const RuntimeOptions& options); } // namespace cel diff --git a/runtime/standard_runtime_builder_factory.cc b/runtime/standard_runtime_builder_factory.cc index e42766398..2d28c9444 100644 --- a/runtime/standard_runtime_builder_factory.cc +++ b/runtime/standard_runtime_builder_factory.cc @@ -14,8 +14,13 @@ #include "runtime/standard_runtime_builder_factory.h" +#include +#include + #include "absl/base/nullability.h" +#include "absl/log/absl_check.h" #include "absl/status/statusor.h" +#include "internal/noop_delete.h" #include "internal/status_macros.h" #include "runtime/runtime_builder.h" #include "runtime/runtime_builder_factory.h" @@ -28,8 +33,21 @@ namespace cel { absl::StatusOr CreateStandardRuntimeBuilder( absl::Nonnull descriptor_pool, const RuntimeOptions& options) { - CEL_ASSIGN_OR_RETURN(auto builder, - CreateRuntimeBuilder(descriptor_pool, options)); + ABSL_DCHECK(descriptor_pool != nullptr); + return CreateStandardRuntimeBuilder( + std::shared_ptr( + descriptor_pool, + internal::NoopDeleteFor()), + options); +} + +absl::StatusOr CreateStandardRuntimeBuilder( + absl::Nonnull> + descriptor_pool, + const RuntimeOptions& options) { + ABSL_DCHECK(descriptor_pool != nullptr); + CEL_ASSIGN_OR_RETURN( + auto builder, CreateRuntimeBuilder(std::move(descriptor_pool), options)); CEL_RETURN_IF_ERROR( RegisterStandardFunctions(builder.function_registry(), options)); return builder; diff --git a/runtime/standard_runtime_builder_factory.h b/runtime/standard_runtime_builder_factory.h index 523b9fb02..70ff62e31 100644 --- a/runtime/standard_runtime_builder_factory.h +++ b/runtime/standard_runtime_builder_factory.h @@ -15,6 +15,8 @@ #ifndef THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_RUNTIME_BUILDER_FACTORY_H_ #define THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_RUNTIME_BUILDER_FACTORY_H_ +#include + #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/status/statusor.h" @@ -32,6 +34,10 @@ absl::StatusOr CreateStandardRuntimeBuilder( absl::Nonnull descriptor_pool ABSL_ATTRIBUTE_LIFETIME_BOUND, const RuntimeOptions& options); +absl::StatusOr CreateStandardRuntimeBuilder( + absl::Nonnull> + descriptor_pool, + const RuntimeOptions& options); } // namespace cel From 29818b826c130c8dd5b0d5b64c8d4953349e5ad8 Mon Sep 17 00:00:00 2001 From: Justin King Date: Mon, 28 Oct 2024 14:07:25 -0700 Subject: [PATCH 008/180] Remove portable expression factory PiperOrigin-RevId: 690743957 --- eval/public/BUILD | 72 +- eval/public/cel_expr_builder_factory.cc | 106 ++- eval/public/cel_expr_builder_factory.h | 5 +- .../portable_cel_expr_builder_factory.cc | 139 ---- .../portable_cel_expr_builder_factory.h | 44 -- .../portable_cel_expr_builder_factory_test.cc | 689 ------------------ 6 files changed, 118 insertions(+), 937 deletions(-) delete mode 100644 eval/public/portable_cel_expr_builder_factory.cc delete mode 100644 eval/public/portable_cel_expr_builder_factory.h delete mode 100644 eval/public/portable_cel_expr_builder_factory_test.cc diff --git a/eval/public/BUILD b/eval/public/BUILD index c11f16fc1..be7b3a1c8 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -570,13 +570,26 @@ cc_library( ], deps = [ ":cel_expression", + ":cel_function", ":cel_options", - ":portable_cel_expr_builder_factory", + "//base:kind", + "//base/ast_internal:ast_impl", + "//common:memory", + "//eval/compiler:cel_expression_builder_flat_impl", + "//eval/compiler:comprehension_vulnerability_check", + "//eval/compiler:constant_folding", "//eval/compiler:flat_expr_builder", - "//eval/public/structs:proto_message_type_adapter", + "//eval/compiler:flat_expr_builder_extensions", + "//eval/compiler:qualified_reference_resolver", + "//eval/compiler:regex_precompilation_optimization", "//eval/public/structs:protobuf_descriptor_type_provider", + "//extensions:select_optimization", + "//extensions/protobuf:memory_manager", "//internal:proto_util", + "//runtime:runtime_options", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_protobuf//:protobuf", ], ) @@ -1095,35 +1108,6 @@ cc_library( ], ) -cc_library( - name = "portable_cel_expr_builder_factory", - srcs = ["portable_cel_expr_builder_factory.cc"], - hdrs = ["portable_cel_expr_builder_factory.h"], - deps = [ - ":cel_expression", - ":cel_function", - ":cel_options", - "//base:kind", - "//base/ast_internal:ast_impl", - "//common:memory", - "//common:value", - "//eval/compiler:cel_expression_builder_flat_impl", - "//eval/compiler:comprehension_vulnerability_check", - "//eval/compiler:constant_folding", - "//eval/compiler:flat_expr_builder", - "//eval/compiler:flat_expr_builder_extensions", - "//eval/compiler:qualified_reference_resolver", - "//eval/compiler:regex_precompilation_optimization", - "//eval/public/structs:legacy_type_provider", - "//extensions:select_optimization", - "//extensions/protobuf:memory_manager", - "//runtime:runtime_options", - "@com_google_absl//absl/log:absl_log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - ], -) - cc_library( name = "string_extension_func_registrar", srcs = ["string_extension_func_registrar.cc"], @@ -1151,29 +1135,3 @@ cc_test( "@com_google_protobuf//:protobuf", ], ) - -cc_test( - name = "portable_cel_expr_builder_factory_test", - srcs = ["portable_cel_expr_builder_factory_test.cc"], - deps = [ - ":activation", - ":builtin_func_registrar", - ":cel_options", - ":cel_value", - ":portable_cel_expr_builder_factory", - "//eval/public/structs:legacy_type_adapter", - "//eval/public/structs:legacy_type_info_apis", - "//eval/public/structs:legacy_type_provider", - "//eval/testutil:test_message_cc_proto", - "//extensions/protobuf:memory_manager", - "//internal:casts", - "//internal:proto_time_encoding", - "//internal:testing", - "//parser", - "@com_google_absl//absl/container:node_hash_set", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@com_google_protobuf//:protobuf", - ], -) diff --git a/eval/public/cel_expr_builder_factory.cc b/eval/public/cel_expr_builder_factory.cc index b0eda9a55..cc061a7ea 100644 --- a/eval/public/cel_expr_builder_factory.cc +++ b/eval/public/cel_expr_builder_factory.cc @@ -17,21 +17,56 @@ #include "eval/public/cel_expr_builder_factory.h" #include -#include -#include +#include "absl/log/absl_log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "base/ast_internal/ast_impl.h" +#include "base/kind.h" +#include "common/memory.h" +#include "eval/compiler/cel_expression_builder_flat_impl.h" +#include "eval/compiler/comprehension_vulnerability_check.h" +#include "eval/compiler/constant_folding.h" #include "eval/compiler/flat_expr_builder.h" +#include "eval/compiler/flat_expr_builder_extensions.h" +#include "eval/compiler/qualified_reference_resolver.h" +#include "eval/compiler/regex_precompilation_optimization.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_function.h" #include "eval/public/cel_options.h" -#include "eval/public/portable_cel_expr_builder_factory.h" -#include "eval/public/structs/proto_message_type_adapter.h" #include "eval/public/structs/protobuf_descriptor_type_provider.h" +#include "extensions/protobuf/memory_manager.h" +#include "extensions/select_optimization.h" #include "internal/proto_util.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" namespace google::api::expr::runtime { namespace { + +using ::cel::MemoryManagerRef; +using ::cel::ast_internal::AstImpl; +using ::cel::extensions::CreateSelectOptimizationProgramOptimizer; +using ::cel::extensions::kCelAttribute; +using ::cel::extensions::kCelHasField; +using ::cel::extensions::ProtoMemoryManagerRef; +using ::cel::extensions::SelectOptimizationAstUpdater; +using ::cel::runtime_internal::CreateConstantFoldingOptimizer; using ::google::api::expr::internal::ValidateStandardMessageTypes; + +// Adapter for a raw arena* pointer. Manages a MemoryManager object for the +// constant folding extension. +struct ArenaBackedConstfoldingFactory { + MemoryManagerRef memory_manager; + + absl::StatusOr> operator()( + PlannerContext& ctx, const AstImpl& ast) const { + return CreateConstantFoldingOptimizer(memory_manager)(ctx, ast); + } +}; + } // namespace std::unique_ptr CreateCelExpressionBuilder( @@ -49,10 +84,67 @@ std::unique_ptr CreateCelExpressionBuilder( return nullptr; } + cel::RuntimeOptions runtime_options = ConvertToRuntimeOptions(options); auto builder = - CreatePortableExprBuilder(std::make_unique( - descriptor_pool, message_factory), - options); + std::make_unique(runtime_options); + + builder->GetTypeRegistry() + ->InternalGetModernRegistry() + .set_use_legacy_container_builders(options.use_legacy_container_builders); + + builder->GetTypeRegistry()->RegisterTypeProvider( + std::make_unique(descriptor_pool, + message_factory)); + + FlatExprBuilder& flat_expr_builder = builder->flat_expr_builder(); + + flat_expr_builder.AddAstTransform(NewReferenceResolverExtension( + (options.enable_qualified_identifier_rewrites) + ? ReferenceResolverOption::kAlways + : ReferenceResolverOption::kCheckedOnly)); + + if (options.enable_comprehension_vulnerability_check) { + builder->flat_expr_builder().AddProgramOptimizer( + CreateComprehensionVulnerabilityCheck()); + } + + if (options.constant_folding) { + builder->flat_expr_builder().AddProgramOptimizer( + ArenaBackedConstfoldingFactory{ + ProtoMemoryManagerRef(options.constant_arena)}); + } + + if (options.enable_regex_precompilation) { + flat_expr_builder.AddProgramOptimizer( + CreateRegexPrecompilationExtension(options.regex_max_program_size)); + } + + if (options.enable_select_optimization) { + // Add AST transform to update select branches on a stored + // CheckedExpression. This may already be performed by a type checker. + flat_expr_builder.AddAstTransform( + std::make_unique()); + // Add overloads for select optimization signature. + // These are never bound, only used to prevent the builder from failing on + // the overloads check. + absl::Status status = + builder->GetRegistry()->RegisterLazyFunction(CelFunctionDescriptor( + kCelAttribute, false, {cel::Kind::kAny, cel::Kind::kList})); + if (!status.ok()) { + ABSL_LOG(ERROR) << "Failed to register " << kCelAttribute << ": " + << status; + } + status = builder->GetRegistry()->RegisterLazyFunction(CelFunctionDescriptor( + kCelHasField, false, {cel::Kind::kAny, cel::Kind::kList})); + if (!status.ok()) { + ABSL_LOG(ERROR) << "Failed to register " << kCelHasField << ": " + << status; + } + // Add runtime implementation. + flat_expr_builder.AddProgramOptimizer( + CreateSelectOptimizationProgramOptimizer()); + } + return builder; } diff --git a/eval/public/cel_expr_builder_factory.h b/eval/public/cel_expr_builder_factory.h index 7321e29a2..0fd7f95fc 100644 --- a/eval/public/cel_expr_builder_factory.h +++ b/eval/public/cel_expr_builder_factory.h @@ -1,9 +1,12 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_EXPR_BUILDER_FACTORY_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_EXPR_BUILDER_FACTORY_H_ -#include "google/protobuf/descriptor.h" +#include + #include "eval/public/cel_expression.h" #include "eval/public/cel_options.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" namespace google { namespace api { diff --git a/eval/public/portable_cel_expr_builder_factory.cc b/eval/public/portable_cel_expr_builder_factory.cc deleted file mode 100644 index eb78854c9..000000000 --- a/eval/public/portable_cel_expr_builder_factory.cc +++ /dev/null @@ -1,139 +0,0 @@ -/* - * Copyright 2022 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "eval/public/portable_cel_expr_builder_factory.h" - -#include -#include - -#include "absl/log/absl_log.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "base/ast_internal/ast_impl.h" -#include "base/kind.h" -#include "common/memory.h" -#include "common/values/legacy_type_reflector.h" -#include "eval/compiler/cel_expression_builder_flat_impl.h" -#include "eval/compiler/comprehension_vulnerability_check.h" -#include "eval/compiler/constant_folding.h" -#include "eval/compiler/flat_expr_builder.h" -#include "eval/compiler/flat_expr_builder_extensions.h" -#include "eval/compiler/qualified_reference_resolver.h" -#include "eval/compiler/regex_precompilation_optimization.h" -#include "eval/public/cel_expression.h" -#include "eval/public/cel_function.h" -#include "eval/public/cel_options.h" -#include "eval/public/structs/legacy_type_provider.h" -#include "extensions/protobuf/memory_manager.h" -#include "extensions/select_optimization.h" -#include "runtime/runtime_options.h" - -namespace google::api::expr::runtime { -namespace { - -using ::cel::MemoryManagerRef; -using ::cel::ast_internal::AstImpl; -using ::cel::extensions::CreateSelectOptimizationProgramOptimizer; -using ::cel::extensions::kCelAttribute; -using ::cel::extensions::kCelHasField; -using ::cel::extensions::ProtoMemoryManagerRef; -using ::cel::extensions::SelectOptimizationAstUpdater; -using ::cel::runtime_internal::CreateConstantFoldingOptimizer; - -// Adapter for a raw arena* pointer. Manages a MemoryManager object for the -// constant folding extension. -struct ArenaBackedConstfoldingFactory { - MemoryManagerRef memory_manager; - - absl::StatusOr> operator()( - PlannerContext& ctx, const AstImpl& ast) const { - return CreateConstantFoldingOptimizer(memory_manager)(ctx, ast); - } -}; - -} // namespace - -std::unique_ptr CreatePortableExprBuilder( - std::unique_ptr type_provider, - const InterpreterOptions& options) { - if (type_provider == nullptr) { - ABSL_LOG(ERROR) << "Cannot pass nullptr as type_provider to " - "CreatePortableExprBuilder"; - return nullptr; - } - cel::RuntimeOptions runtime_options = ConvertToRuntimeOptions(options); - auto builder = - std::make_unique(runtime_options); - - builder->GetTypeRegistry() - ->InternalGetModernRegistry() - .set_use_legacy_container_builders(options.use_legacy_container_builders); - - builder->GetTypeRegistry()->RegisterTypeProvider(std::move(type_provider)); - - FlatExprBuilder& flat_expr_builder = builder->flat_expr_builder(); - - flat_expr_builder.AddAstTransform(NewReferenceResolverExtension( - (options.enable_qualified_identifier_rewrites) - ? ReferenceResolverOption::kAlways - : ReferenceResolverOption::kCheckedOnly)); - - if (options.enable_comprehension_vulnerability_check) { - builder->flat_expr_builder().AddProgramOptimizer( - CreateComprehensionVulnerabilityCheck()); - } - - if (options.constant_folding) { - builder->flat_expr_builder().AddProgramOptimizer( - ArenaBackedConstfoldingFactory{ - ProtoMemoryManagerRef(options.constant_arena)}); - } - - if (options.enable_regex_precompilation) { - flat_expr_builder.AddProgramOptimizer( - CreateRegexPrecompilationExtension(options.regex_max_program_size)); - } - - if (options.enable_select_optimization) { - // Add AST transform to update select branches on a stored - // CheckedExpression. This may already be performed by a type checker. - flat_expr_builder.AddAstTransform( - std::make_unique()); - // Add overloads for select optimization signature. - // These are never bound, only used to prevent the builder from failing on - // the overloads check. - absl::Status status = - builder->GetRegistry()->RegisterLazyFunction(CelFunctionDescriptor( - kCelAttribute, false, {cel::Kind::kAny, cel::Kind::kList})); - if (!status.ok()) { - ABSL_LOG(ERROR) << "Failed to register " << kCelAttribute << ": " - << status; - } - status = builder->GetRegistry()->RegisterLazyFunction(CelFunctionDescriptor( - kCelHasField, false, {cel::Kind::kAny, cel::Kind::kList})); - if (!status.ok()) { - ABSL_LOG(ERROR) << "Failed to register " << kCelHasField << ": " - << status; - } - // Add runtime implementation. - flat_expr_builder.AddProgramOptimizer( - CreateSelectOptimizationProgramOptimizer()); - } - - return builder; -} - -} // namespace google::api::expr::runtime diff --git a/eval/public/portable_cel_expr_builder_factory.h b/eval/public/portable_cel_expr_builder_factory.h deleted file mode 100644 index b31b51ccf..000000000 --- a/eval/public/portable_cel_expr_builder_factory.h +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Copyright 2022 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_PORTABLE_CEL_EXPR_BUILDER_FACTORY_H_ -#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_PORTABLE_CEL_EXPR_BUILDER_FACTORY_H_ - -#include "eval/public/cel_expression.h" -#include "eval/public/cel_options.h" -#include "eval/public/structs/legacy_type_provider.h" - -namespace google { -namespace api { -namespace expr { -namespace runtime { - -// Factory for initializing a CelExpressionBuilder implementation for public -// use. -// -// This version does not include any message type information, instead deferring -// to the type_provider argument. type_provider is guaranteed to be the first -// type provider in the type registry. -std::unique_ptr CreatePortableExprBuilder( - std::unique_ptr type_provider, - const InterpreterOptions& options = InterpreterOptions()); - -} // namespace runtime -} // namespace expr -} // namespace api -} // namespace google - -#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_PORTABLE_CEL_EXPR_BUILDER_FACTORY_H_ diff --git a/eval/public/portable_cel_expr_builder_factory_test.cc b/eval/public/portable_cel_expr_builder_factory_test.cc deleted file mode 100644 index 31d41f1d0..000000000 --- a/eval/public/portable_cel_expr_builder_factory_test.cc +++ /dev/null @@ -1,689 +0,0 @@ -// Copyright 2022 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "eval/public/portable_cel_expr_builder_factory.h" - -#include -#include -#include -#include -#include - -#include "google/protobuf/duration.pb.h" -#include "google/protobuf/timestamp.pb.h" -#include "google/protobuf/wrappers.pb.h" -#include "absl/container/node_hash_set.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "eval/public/activation.h" -#include "eval/public/builtin_func_registrar.h" -#include "eval/public/cel_options.h" -#include "eval/public/cel_value.h" -#include "eval/public/structs/legacy_type_adapter.h" -#include "eval/public/structs/legacy_type_info_apis.h" -#include "eval/public/structs/legacy_type_provider.h" -#include "eval/testutil/test_message.pb.h" -#include "extensions/protobuf/memory_manager.h" -#include "internal/casts.h" -#include "internal/proto_time_encoding.h" -#include "internal/testing.h" -#include "parser/parser.h" - -namespace google::api::expr::runtime { -namespace { - -using ::cel::expr::ParsedExpr; -using ::google::protobuf::Int64Value; - -// Helpers for c++ / proto to cel value conversions. -absl::optional Unwrap(const google::protobuf::MessageLite* wrapper) { - if (wrapper->GetTypeName() == "google.protobuf.Duration") { - const auto* duration = - cel::internal::down_cast(wrapper); - return CelValue::CreateDuration(cel::internal::DecodeDuration(*duration)); - } else if (wrapper->GetTypeName() == "google.protobuf.Timestamp") { - const auto* timestamp = - cel::internal::down_cast(wrapper); - return CelValue::CreateTimestamp(cel::internal::DecodeTime(*timestamp)); - } - return absl::nullopt; -} - -struct NativeToCelValue { - template - absl::optional Convert(T arg) const { - return absl::nullopt; - } - - absl::optional Convert(int64_t v) const { - return CelValue::CreateInt64(v); - } - - absl::optional Convert(const std::string& str) const { - return CelValue::CreateString(&str); - } - - absl::optional Convert(double v) const { - return CelValue::CreateDouble(v); - } - - absl::optional Convert(bool v) const { - return CelValue::CreateBool(v); - } - - absl::optional Convert(const Int64Value& v) const { - return CelValue::CreateInt64(v.value()); - } -}; - -template -class FieldImpl; - -template -class ProtoField { - public: - template - using FieldImpl = FieldImpl; - - virtual ~ProtoField() = default; - virtual absl::Status Set(MessageT* m, CelValue v) const = 0; - virtual absl::StatusOr Get(const MessageT* m) const = 0; - virtual bool Has(const MessageT* m) const = 0; -}; - -// template helpers for wrapping member accessors generically. -template -struct ScalarApiWrap { - using GetFn = FieldT (MessageT::*)() const; - using HasFn = bool (MessageT::*)() const; - using SetFn = void (MessageT::*)(FieldT); - - ScalarApiWrap(GetFn get_fn, HasFn has_fn, SetFn set_fn) - : get_fn(get_fn), has_fn(has_fn), set_fn(set_fn) {} - - FieldT InvokeGet(const MessageT* msg) const { - return std::invoke(get_fn, msg); - } - bool InvokeHas(const MessageT* msg) const { - if (has_fn == nullptr) return true; - return std::invoke(has_fn, msg); - } - void InvokeSet(MessageT* msg, FieldT arg) const { - if (set_fn != nullptr) { - std::invoke(set_fn, msg, arg); - } - } - - GetFn get_fn; - HasFn has_fn; - SetFn set_fn; -}; - -template -struct ComplexTypeApiWrap { - public: - using GetFn = const FieldT& (MessageT::*)() const; - using HasFn = bool (MessageT::*)() const; - using SetAllocatedFn = void (MessageT::*)(FieldT*); - - ComplexTypeApiWrap(GetFn get_fn, HasFn has_fn, - SetAllocatedFn set_allocated_fn) - : get_fn(get_fn), has_fn(has_fn), set_allocated_fn(set_allocated_fn) {} - - const FieldT& InvokeGet(const MessageT* msg) const { - return std::invoke(get_fn, msg); - } - bool InvokeHas(const MessageT* msg) const { - if (has_fn == nullptr) return true; - return std::invoke(has_fn, msg); - } - - void InvokeSetAllocated(MessageT* msg, FieldT* arg) const { - if (set_allocated_fn != nullptr) { - std::invoke(set_allocated_fn, msg, arg); - } - } - - GetFn get_fn; - HasFn has_fn; - SetAllocatedFn set_allocated_fn; -}; - -template -class FieldImpl : public ProtoField { - private: - using ApiWrap = ScalarApiWrap; - - public: - FieldImpl(typename ApiWrap::GetFn get_fn, typename ApiWrap::HasFn has_fn, - typename ApiWrap::SetFn set_fn) - : api_wrapper_(get_fn, has_fn, set_fn) {} - absl::Status Set(TestMessage* m, CelValue v) const override { - FieldT arg; - if (!v.GetValue(&arg)) { - return absl::InvalidArgumentError("wrong type for set"); - } - api_wrapper_.InvokeSet(m, arg); - return absl::OkStatus(); - } - - absl::StatusOr Get(const TestMessage* m) const override { - FieldT result = api_wrapper_.InvokeGet(m); - auto converted = NativeToCelValue().Convert(result); - if (converted.has_value()) { - return *converted; - } - return absl::UnimplementedError("not implemented for type"); - } - - bool Has(const TestMessage* m) const override { - return api_wrapper_.InvokeHas(m); - } - - private: - ApiWrap api_wrapper_; -}; - -template -class FieldImpl : public ProtoField { - using ApiWrap = ComplexTypeApiWrap; - - public: - FieldImpl(typename ApiWrap::GetFn get_fn, typename ApiWrap::HasFn has_fn, - typename ApiWrap::SetAllocatedFn set_fn) - : api_wrapper_(get_fn, has_fn, set_fn) {} - absl::Status Set(TestMessage* m, CelValue v) const override { - int64_t arg; - if (!v.GetValue(&arg)) { - return absl::InvalidArgumentError("wrong type for set"); - } - Int64Value* proto_value = new Int64Value(); - proto_value->set_value(arg); - api_wrapper_.InvokeSetAllocated(m, proto_value); - return absl::OkStatus(); - } - - absl::StatusOr Get(const TestMessage* m) const override { - if (!api_wrapper_.InvokeHas(m)) { - return CelValue::CreateNull(); - } - Int64Value result = api_wrapper_.InvokeGet(m); - auto converted = NativeToCelValue().Convert(std::move(result)); - if (converted.has_value()) { - return *converted; - } - return absl::UnimplementedError("not implemented for type"); - } - - bool Has(const TestMessage* m) const override { - return api_wrapper_.InvokeHas(m); - } - - private: - ApiWrap api_wrapper_; -}; - -// Simple type system for Testing. -class DemoTypeProvider; - -class DemoTimestamp : public LegacyTypeInfoApis, public LegacyTypeMutationApis { - public: - DemoTimestamp() {} - - std::string DebugString( - const MessageWrapper& wrapped_message) const override { - return std::string(GetTypename(wrapped_message)); - } - - absl::string_view GetTypename( - const MessageWrapper& wrapped_message) const override { - return "google.protobuf.Timestamp"; - } - - const LegacyTypeAccessApis* GetAccessApis( - const MessageWrapper& wrapped_message) const override { - return nullptr; - } - - bool DefinesField(absl::string_view field_name) const override { - return field_name == "seconds" || field_name == "nanos"; - } - - absl::StatusOr NewInstance( - cel::MemoryManagerRef memory_manager) const override; - - absl::StatusOr AdaptFromWellKnownType( - cel::MemoryManagerRef memory_manager, - CelValue::MessageWrapper::Builder instance) const override; - - absl::Status SetField( - absl::string_view field_name, const CelValue& value, - cel::MemoryManagerRef memory_manager, - CelValue::MessageWrapper::Builder& instance) const override; - - private: - absl::Status Validate(const google::protobuf::MessageLite* wrapped_message) const { - if (wrapped_message->GetTypeName() != "google.protobuf.Timestamp") { - return absl::InvalidArgumentError("not a timestamp"); - } - return absl::OkStatus(); - } -}; - -class DemoTypeInfo : public LegacyTypeInfoApis { - public: - explicit DemoTypeInfo(const DemoTypeProvider* owning_provider) - : owning_provider_(*owning_provider) {} - std::string DebugString(const MessageWrapper& wrapped_message) const override; - - absl::string_view GetTypename( - const MessageWrapper& wrapped_message) const override; - - const LegacyTypeAccessApis* GetAccessApis( - const MessageWrapper& wrapped_message) const override; - - private: - const DemoTypeProvider& owning_provider_; -}; - -class DemoTestMessage : public LegacyTypeInfoApis, - public LegacyTypeMutationApis, - public LegacyTypeAccessApis { - public: - explicit DemoTestMessage(const DemoTypeProvider* owning_provider); - - std::string DebugString( - const MessageWrapper& wrapped_message) const override { - return std::string(GetTypename(wrapped_message)); - } - - absl::string_view GetTypename( - const MessageWrapper& wrapped_message) const override { - return "google.api.expr.runtime.TestMessage"; - } - - const LegacyTypeAccessApis* GetAccessApis( - const MessageWrapper& wrapped_message) const override { - return this; - } - - const LegacyTypeMutationApis* GetMutationApis( - const MessageWrapper& wrapped_message) const override { - return this; - } - - absl::optional FindFieldByName( - absl::string_view name) const override { - if (auto it = fields_.find(name); it != fields_.end()) { - return FieldDescription{0, name}; - } - return absl::nullopt; - } - - bool DefinesField(absl::string_view field_name) const override { - return fields_.contains(field_name); - } - - absl::StatusOr NewInstance( - cel::MemoryManagerRef memory_manager) const override; - - absl::StatusOr AdaptFromWellKnownType( - cel::MemoryManagerRef memory_manager, - CelValue::MessageWrapper::Builder instance) const override; - - absl::Status SetField( - absl::string_view field_name, const CelValue& value, - cel::MemoryManagerRef memory_manager, - CelValue::MessageWrapper::Builder& instance) const override; - - absl::StatusOr HasField( - absl::string_view field_name, - const CelValue::MessageWrapper& value) const override; - - absl::StatusOr GetField( - absl::string_view field_name, const CelValue::MessageWrapper& instance, - ProtoWrapperTypeOptions unboxing_option, - cel::MemoryManagerRef memory_manager) const override; - - std::vector ListFields( - const CelValue::MessageWrapper& instance) const override { - std::vector fields; - fields.reserve(fields_.size()); - for (const auto& field : fields_) { - fields.emplace_back(field.first); - } - return fields; - } - - private: - using Field = ProtoField; - const DemoTypeProvider& owning_provider_; - absl::flat_hash_map> fields_; -}; - -class DemoTypeProvider : public LegacyTypeProvider { - public: - DemoTypeProvider() : timestamp_type_(), test_message_(this), info_(this) {} - const LegacyTypeInfoApis* GetTypeInfoInstance() const { return &info_; } - - absl::optional ProvideLegacyType( - absl::string_view name) const override { - if (name == "google.protobuf.Timestamp") { - return LegacyTypeAdapter(nullptr, ×tamp_type_); - } else if (name == "google.api.expr.runtime.TestMessage") { - return LegacyTypeAdapter(&test_message_, &test_message_); - } - return absl::nullopt; - } - - absl::optional ProvideLegacyTypeInfo( - absl::string_view name) const override { - if (name == "google.protobuf.Timestamp") { - return ×tamp_type_; - } else if (name == "google.api.expr.runtime.TestMessage") { - return &test_message_; - } - return absl::nullopt; - } - - const std::string& GetStableType( - const google::protobuf::MessageLite* wrapped_message) const { - std::string name(wrapped_message->GetTypeName()); - auto [iter, inserted] = stable_types_.insert(name); - return *iter; - } - - CelValue WrapValue(const google::protobuf::MessageLite* message) const { - return CelValue::CreateMessageWrapper( - CelValue::MessageWrapper(message, GetTypeInfoInstance())); - } - - private: - DemoTimestamp timestamp_type_; - DemoTestMessage test_message_; - DemoTypeInfo info_; - mutable absl::node_hash_set stable_types_; // thread hostile -}; - -std::string DemoTypeInfo::DebugString( - const MessageWrapper& wrapped_message) const { - return std::string(wrapped_message.message_ptr()->GetTypeName()); -} - -absl::string_view DemoTypeInfo::GetTypename( - const MessageWrapper& wrapped_message) const { - return owning_provider_.GetStableType(wrapped_message.message_ptr()); -} - -const LegacyTypeAccessApis* DemoTypeInfo::GetAccessApis( - const MessageWrapper& wrapped_message) const { - auto adapter = owning_provider_.ProvideLegacyType( - wrapped_message.message_ptr()->GetTypeName()); - if (adapter.has_value()) { - return adapter->access_apis(); - } - return nullptr; // not implemented yet. -} - -absl::StatusOr DemoTimestamp::NewInstance( - cel::MemoryManagerRef memory_manager) const { - auto* ts = google::protobuf::Arena::Create( - cel::extensions::ProtoMemoryManagerArena(memory_manager)); - return CelValue::MessageWrapper::Builder(ts); -} -absl::StatusOr DemoTimestamp::AdaptFromWellKnownType( - cel::MemoryManagerRef memory_manager, - CelValue::MessageWrapper::Builder instance) const { - auto value = Unwrap(instance.message_ptr()); - ABSL_ASSERT(value.has_value()); - return *value; -} - -absl::Status DemoTimestamp::SetField( - absl::string_view field_name, const CelValue& value, - cel::MemoryManagerRef memory_manager, - CelValue::MessageWrapper::Builder& instance) const { - ABSL_ASSERT(Validate(instance.message_ptr()).ok()); - auto* mutable_ts = cel::internal::down_cast( - instance.message_ptr()); - if (field_name == "seconds" && value.IsInt64()) { - mutable_ts->set_seconds(value.Int64OrDie()); - } else if (field_name == "nanos" && value.IsInt64()) { - mutable_ts->set_nanos(value.Int64OrDie()); - } else { - return absl::UnknownError("no such field"); - } - return absl::OkStatus(); -} - -DemoTestMessage::DemoTestMessage(const DemoTypeProvider* owning_provider) - : owning_provider_(*owning_provider) { - // Note: has for non-optional scalars on proto3 messages would be implemented - // as msg.value() != MessageType::default_instance.value(), but omited for - // brevity. - fields_["int64_value"] = std::make_unique>( - &TestMessage::int64_value, - /*has_fn=*/nullptr, &TestMessage::set_int64_value); - fields_["double_value"] = std::make_unique>( - &TestMessage::double_value, - /*has_fn=*/nullptr, &TestMessage::set_double_value); - fields_["bool_value"] = std::make_unique>( - &TestMessage::bool_value, - /*has_fn=*/nullptr, &TestMessage::set_bool_value); - fields_["int64_wrapper_value"] = - std::make_unique>( - &TestMessage::int64_wrapper_value, - &TestMessage::has_int64_wrapper_value, - &TestMessage::set_allocated_int64_wrapper_value); -} - -absl::StatusOr DemoTestMessage::NewInstance( - cel::MemoryManagerRef memory_manager) const { - auto* ts = google::protobuf::Arena::Create( - cel::extensions::ProtoMemoryManagerArena(memory_manager)); - return CelValue::MessageWrapper::Builder(ts); -} - -absl::Status DemoTestMessage::SetField( - absl::string_view field_name, const CelValue& value, - cel::MemoryManagerRef memory_manager, - CelValue::MessageWrapper::Builder& instance) const { - auto iter = fields_.find(field_name); - if (iter == fields_.end()) { - return absl::UnknownError("no such field"); - } - auto* mutable_test_msg = - cel::internal::down_cast(instance.message_ptr()); - return iter->second->Set(mutable_test_msg, value); -} - -absl::StatusOr DemoTestMessage::AdaptFromWellKnownType( - cel::MemoryManagerRef memory_manager, - CelValue::MessageWrapper::Builder instance) const { - return CelValue::CreateMessageWrapper( - instance.Build(owning_provider_.GetTypeInfoInstance())); -} - -absl::StatusOr DemoTestMessage::HasField( - absl::string_view field_name, const CelValue::MessageWrapper& value) const { - auto iter = fields_.find(field_name); - if (iter == fields_.end()) { - return absl::UnknownError("no such field"); - } - auto* test_msg = - cel::internal::down_cast(value.message_ptr()); - return iter->second->Has(test_msg); -} - -// Access field on instance. -absl::StatusOr DemoTestMessage::GetField( - absl::string_view field_name, const CelValue::MessageWrapper& instance, - ProtoWrapperTypeOptions unboxing_option, - cel::MemoryManagerRef memory_manager) const { - auto iter = fields_.find(field_name); - if (iter == fields_.end()) { - return absl::UnknownError("no such field"); - } - auto* test_msg = - cel::internal::down_cast(instance.message_ptr()); - return iter->second->Get(test_msg); -} - -TEST(PortableCelExprBuilderFactoryTest, CreateNullOnMissingTypeProvider) { - std::unique_ptr builder = - CreatePortableExprBuilder(nullptr); - ASSERT_EQ(builder, nullptr); -} - -TEST(PortableCelExprBuilderFactoryTest, CreateSuccess) { - google::protobuf::Arena arena; - - InterpreterOptions opts; - Activation activation; - std::unique_ptr builder = - CreatePortableExprBuilder(std::make_unique(), opts); - ASSERT_OK_AND_ASSIGN( - ParsedExpr expr, - parser::Parse("google.protobuf.Timestamp{seconds: 3000, nanos: 20}")); - ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); - - ASSERT_OK_AND_ASSIGN( - auto plan, builder->CreateExpression(&expr.expr(), &expr.source_info())); - - ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(activation, &arena)); - - absl::Time result_time; - ASSERT_TRUE(result.GetValue(&result_time)); - EXPECT_EQ(result_time, - absl::UnixEpoch() + absl::Minutes(50) + absl::Nanoseconds(20)); -} - -TEST(PortableCelExprBuilderFactoryTest, CreateCustomMessage) { - google::protobuf::Arena arena; - - InterpreterOptions opts; - Activation activation; - std::unique_ptr builder = - CreatePortableExprBuilder(std::make_unique(), opts); - ASSERT_OK_AND_ASSIGN( - ParsedExpr expr, - parser::Parse("google.api.expr.runtime.TestMessage{int64_value: 20, " - "double_value: 3.5}.double_value")); - ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), opts)); - - ASSERT_OK_AND_ASSIGN( - auto plan, builder->CreateExpression(&expr.expr(), &expr.source_info())); - - ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(activation, &arena)); - - double result_double; - ASSERT_TRUE(result.GetValue(&result_double)) << result.DebugString(); - EXPECT_EQ(result_double, 3.5); -} - -TEST(PortableCelExprBuilderFactoryTest, ActivationAndCreate) { - google::protobuf::Arena arena; - - InterpreterOptions opts; - Activation activation; - auto provider = std::make_unique(); - auto* provider_view = provider.get(); - std::unique_ptr builder = - CreatePortableExprBuilder(std::move(provider), opts); - builder->set_container("google.api.expr.runtime"); - ASSERT_OK_AND_ASSIGN( - ParsedExpr expr, - parser::Parse("TestMessage{int64_value: 20, bool_value: " - "false}.bool_value || my_var.bool_value ? 1 : 2")); - ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), opts)); - - ASSERT_OK_AND_ASSIGN( - auto plan, builder->CreateExpression(&expr.expr(), &expr.source_info())); - TestMessage my_var; - my_var.set_bool_value(true); - activation.InsertValue("my_var", provider_view->WrapValue(&my_var)); - ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(activation, &arena)); - - int64_t result_int64; - ASSERT_TRUE(result.GetValue(&result_int64)) << result.DebugString(); - EXPECT_EQ(result_int64, 1); -} - -TEST(PortableCelExprBuilderFactoryTest, WrapperTypes) { - google::protobuf::Arena arena; - InterpreterOptions opts; - opts.enable_heterogeneous_equality = true; - Activation activation; - auto provider = std::make_unique(); - const auto* provider_view = provider.get(); - std::unique_ptr builder = - CreatePortableExprBuilder(std::move(provider), opts); - builder->set_container("google.api.expr.runtime"); - ASSERT_OK_AND_ASSIGN(ParsedExpr null_expr, - parser::Parse("my_var.int64_wrapper_value != null ? " - "my_var.int64_wrapper_value > 29 : null")); - ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), opts)); - TestMessage my_var; - my_var.set_bool_value(true); - activation.InsertValue("my_var", provider_view->WrapValue(&my_var)); - - ASSERT_OK_AND_ASSIGN( - auto plan, - builder->CreateExpression(&null_expr.expr(), &null_expr.source_info())); - ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(activation, &arena)); - - EXPECT_TRUE(result.IsNull()) << result.DebugString(); - - my_var.mutable_int64_wrapper_value()->set_value(30); - - ASSERT_OK_AND_ASSIGN(result, plan->Evaluate(activation, &arena)); - bool result_bool; - ASSERT_TRUE(result.GetValue(&result_bool)) << result.DebugString(); - EXPECT_TRUE(result_bool); -} - -TEST(PortableCelExprBuilderFactoryTest, SimpleBuiltinFunctions) { - google::protobuf::Arena arena; - InterpreterOptions opts; - opts.enable_heterogeneous_equality = true; - Activation activation; - auto provider = std::make_unique(); - std::unique_ptr builder = - CreatePortableExprBuilder(std::move(provider), opts); - builder->set_container("google.api.expr.runtime"); - - // Fairly complicated but silly expression to cover a mix of builtins - // (comparisons, arithmetic, datetime). - ASSERT_OK_AND_ASSIGN( - ParsedExpr ternary_expr, - parser::Parse( - "TestMessage{int64_value: 2}.int64_value + 1 < " - " TestMessage{double_value: 3.5}.double_value - 0.1 ? " - " (google.protobuf.Timestamp{seconds: 300} - timestamp(240) " - " >= duration('1m') ? 'yes' : 'no') :" - " null")); - ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), opts)); - - ASSERT_OK_AND_ASSIGN(auto plan, - builder->CreateExpression(&ternary_expr.expr(), - &ternary_expr.source_info())); - ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(activation, &arena)); - - ASSERT_TRUE(result.IsString()) << result.DebugString(); - EXPECT_EQ(result.StringOrDie().value(), "yes"); -} - -} // namespace -} // namespace google::api::expr::runtime From feaecdbad313e3a4cd8006ed03b8ce6d5c332bc6 Mon Sep 17 00:00:00 2001 From: Justin King Date: Mon, 28 Oct 2024 14:18:52 -0700 Subject: [PATCH 009/180] Remove `TypeFactory` from `TypeIntrospector` PiperOrigin-RevId: 690748051 --- checker/internal/BUILD | 2 - checker/internal/type_check_env.cc | 26 +++++-------- checker/internal/type_check_env.h | 12 ++---- checker/internal/type_checker_impl.cc | 33 +++------------- common/type_factory.h | 13 ------- common/type_introspector.cc | 20 +++++----- common/type_introspector.h | 22 ++++------- common/type_manager.cc | 2 +- common/type_manager.h | 6 +-- common/types/legacy_type_manager.h | 9 +---- .../thread_compatible_type_introspector.cc | 39 ------------------- .../thread_compatible_type_introspector.h | 12 ------ common/types/thread_compatible_type_manager.h | 7 +--- common/value_factory.h | 5 +++ common/values/legacy_value_manager.h | 6 ++- common/values/struct_value_builder.cc | 7 ++-- .../values/thread_compatible_value_manager.h | 6 ++- eval/public/structs/legacy_type_provider.cc | 5 +-- eval/public/structs/legacy_type_provider.h | 5 +-- extensions/protobuf/type_introspector.cc | 9 ++--- extensions/protobuf/type_introspector.h | 8 ++-- extensions/protobuf/type_introspector_test.cc | 14 +++---- runtime/internal/composed_type_provider.cc | 11 +++--- runtime/internal/composed_type_provider.h | 5 +-- runtime/optional_types.cc | 2 +- 25 files changed, 83 insertions(+), 203 deletions(-) delete mode 100644 common/types/thread_compatible_type_introspector.cc diff --git a/checker/internal/BUILD b/checker/internal/BUILD index 453a2c309..6d43a83b3 100644 --- a/checker/internal/BUILD +++ b/checker/internal/BUILD @@ -127,11 +127,9 @@ cc_library( "//common:constant", "//common:decl", "//common:expr", - "//common:memory", "//common:source", "//common:type", "//common:type_kind", - "//extensions/protobuf:memory_manager", "//internal:status_macros", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", diff --git a/checker/internal/type_check_env.cc b/checker/internal/type_check_env.cc index b95aa652a..1ac9bd618 100644 --- a/checker/internal/type_check_env.cc +++ b/checker/internal/type_check_env.cc @@ -26,7 +26,6 @@ #include "common/constant.h" #include "common/decl.h" #include "common/type.h" -#include "common/type_factory.h" #include "common/type_introspector.h" #include "internal/status_macros.h" #include "google/protobuf/arena.h" @@ -59,7 +58,7 @@ absl::Nullable TypeCheckEnv::LookupFunction( } absl::StatusOr> TypeCheckEnv::LookupTypeName( - TypeFactory& type_factory, absl::string_view name) const { + absl::string_view name) const { { // Check the descriptor pool first, then fallback to custom type providers. absl::Nullable descriptor = @@ -77,7 +76,7 @@ absl::StatusOr> TypeCheckEnv::LookupTypeName( do { for (auto iter = type_providers_.rbegin(); iter != type_providers_.rend(); ++iter) { - auto type = (*iter)->FindType(type_factory, name); + auto type = (*iter)->FindType(name); if (!type.ok() || type->has_value()) { return type; } @@ -88,8 +87,7 @@ absl::StatusOr> TypeCheckEnv::LookupTypeName( } absl::StatusOr> TypeCheckEnv::LookupEnumConstant( - TypeFactory& type_factory, absl::string_view type, - absl::string_view value) const { + absl::string_view type, absl::string_view value) const { { // Check the descriptor pool first, then fallback to custom type providers. absl::Nullable enum_descriptor = @@ -113,7 +111,7 @@ absl::StatusOr> TypeCheckEnv::LookupEnumConstant( do { for (auto iter = type_providers_.rbegin(); iter != type_providers_.rend(); ++iter) { - auto enum_constant = (*iter)->FindEnumConstant(type_factory, type, value); + auto enum_constant = (*iter)->FindEnumConstant(type, value); if (!enum_constant.ok()) { return enum_constant.status(); } @@ -133,10 +131,8 @@ absl::StatusOr> TypeCheckEnv::LookupEnumConstant( } absl::StatusOr> TypeCheckEnv::LookupTypeConstant( - TypeFactory& type_factory, absl::Nonnull arena, - absl::string_view name) const { - CEL_ASSIGN_OR_RETURN(absl::optional type, - LookupTypeName(type_factory, name)); + absl::Nonnull arena, absl::string_view name) const { + CEL_ASSIGN_OR_RETURN(absl::optional type, LookupTypeName(name)); if (type.has_value()) { return MakeVariableDecl(std::string(type->name()), TypeType(arena, *type)); } @@ -145,16 +141,14 @@ absl::StatusOr> TypeCheckEnv::LookupTypeConstant( size_t last_dot = name.rfind('.'); absl::string_view enum_name_candidate = name.substr(0, last_dot); absl::string_view value_name_candidate = name.substr(last_dot + 1); - return LookupEnumConstant(type_factory, enum_name_candidate, - value_name_candidate); + return LookupEnumConstant(enum_name_candidate, value_name_candidate); } return absl::nullopt; } absl::StatusOr> TypeCheckEnv::LookupStructField( - TypeFactory& type_factory, absl::string_view type_name, - absl::string_view field_name) const { + absl::string_view type_name, absl::string_view field_name) const { { // Check the descriptor pool first, then fallback to custom type providers. absl::Nullable descriptor = @@ -180,8 +174,8 @@ absl::StatusOr> TypeCheckEnv::LookupStructField( // checking field accesses. for (auto iter = type_providers_.rbegin(); iter != type_providers_.rend(); ++iter) { - auto field_info = (*iter)->FindStructTypeFieldByName( - type_factory, type_name, field_name); + auto field_info = + (*iter)->FindStructTypeFieldByName(type_name, field_name); if (!field_info.ok() || field_info->has_value()) { return field_info; } diff --git a/checker/internal/type_check_env.h b/checker/internal/type_check_env.h index 9cad4ae72..f42a205a9 100644 --- a/checker/internal/type_check_env.h +++ b/checker/internal/type_check_env.h @@ -31,7 +31,6 @@ #include "common/constant.h" #include "common/decl.h" #include "common/type.h" -#include "common/type_factory.h" #include "common/type_introspector.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" @@ -161,15 +160,13 @@ class TypeCheckEnv { absl::string_view name) const; absl::StatusOr> LookupTypeName( - TypeFactory& type_factory, absl::string_view name) const; + absl::string_view name) const; absl::StatusOr> LookupStructField( - TypeFactory& type_factory, absl::string_view type_name, - absl::string_view field_name) const; + absl::string_view type_name, absl::string_view field_name) const; absl::StatusOr> LookupTypeConstant( - TypeFactory& type_factory, absl::Nonnull arena, - absl::string_view type_name) const; + absl::Nonnull arena, absl::string_view type_name) const; TypeCheckEnv MakeExtendedEnvironment() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return TypeCheckEnv(this); @@ -189,8 +186,7 @@ class TypeCheckEnv { parent_(parent) {} absl::StatusOr> LookupEnumConstant( - TypeFactory& type_factory, absl::string_view type, - absl::string_view value) const; + absl::string_view type, absl::string_view value) const; absl::Nonnull> descriptor_pool_; std::string container_; diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index fd665b5d9..6cda11e4f 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -48,12 +48,9 @@ #include "common/constant.h" #include "common/decl.h" #include "common/expr.h" -#include "common/memory.h" #include "common/source.h" #include "common/type.h" -#include "common/type_factory.h" #include "common/type_kind.h" -#include "extensions/protobuf/memory_manager.h" #include "internal/status_macros.h" #include "google/protobuf/arena.h" @@ -67,19 +64,6 @@ using Severity = TypeCheckIssue::Severity; constexpr const char kOptionalSelect[] = "_?._"; -class TrivialTypeFactory : public TypeFactory { - public: - explicit TrivialTypeFactory(absl::Nonnull arena) - : arena_(arena) {} - - MemoryManagerRef GetMemoryManager() const override { - return extensions::ProtoMemoryManagerRef(arena_); - } - - private: - absl::Nonnull arena_; -}; - std::string FormatCandidate(absl::Span qualifiers) { return absl::StrJoin(qualifiers, "."); } @@ -253,7 +237,7 @@ class ResolveVisitor : public AstVisitorBase { const TypeCheckEnv& env, const AstImpl& ast, TypeInferenceContext& inference_context, std::vector& issues, - absl::Nonnull arena, TypeFactory& type_factory) + absl::Nonnull arena) : container_(container), namespace_generator_(std::move(namespace_generator)), env_(&env), @@ -262,7 +246,6 @@ class ResolveVisitor : public AstVisitorBase { ast_(&ast), root_scope_(env.MakeVariableScope()), arena_(arena), - type_factory_(&type_factory), current_scope_(&root_scope_) {} void PreVisitExpr(const Expr& expr) override { expr_stack_.push_back(&expr); } @@ -408,7 +391,7 @@ class ResolveVisitor : public AstVisitorBase { // Lookup message type by name to support WellKnownType creation. CEL_ASSIGN_OR_RETURN( absl::optional field_info, - env_->LookupStructField(*type_factory_, resolved_name, field.name())); + env_->LookupStructField(resolved_name, field.name())); if (!field_info.has_value()) { ReportUndefinedField(field.id(), field.name(), resolved_name); continue; @@ -455,7 +438,6 @@ class ResolveVisitor : public AstVisitorBase { absl::Nonnull ast_; VariableScope root_scope_; absl::Nonnull arena_; - absl::Nonnull type_factory_; // state tracking for the traversal. const VariableScope* current_scope_; @@ -669,7 +651,7 @@ void ResolveVisitor::PostVisitStruct(const Expr& expr, Type resolved_type; namespace_generator_.GenerateCandidates( create_struct.name(), [&](const absl::string_view name) { - auto type = env_->LookupTypeName(*type_factory_, name); + auto type = env_->LookupTypeName(name); if (!type.ok()) { status.Update(type.status()); return false; @@ -938,7 +920,7 @@ absl::Nullable ResolveVisitor::LookupIdentifier( return decl; } absl::StatusOr> constant = - env_->LookupTypeConstant(*type_factory_, arena_, name); + env_->LookupTypeConstant(arena_, name); if (!constant.ok()) { status_.Update(constant.status()); @@ -1032,8 +1014,7 @@ absl::optional ResolveVisitor::CheckFieldType(int64_t id, switch (operand_type.kind()) { case TypeKind::kStruct: { StructType struct_type = operand_type.GetStruct(); - auto field_info = - env_->LookupStructField(*type_factory_, struct_type.name(), field); + auto field_info = env_->LookupStructField(struct_type.name(), field); if (!field_info.ok()) { status_.Update(field_info.status()); return absl::nullopt; @@ -1228,10 +1209,8 @@ absl::StatusOr TypeCheckerImpl::Check( TypeInferenceContext type_inference_context( &type_arena, options_.enable_legacy_null_assignment); - TrivialTypeFactory type_factory(&type_arena); ResolveVisitor visitor(env_.container(), std::move(generator), env_, ast_impl, - type_inference_context, issues, &type_arena, - type_factory); + type_inference_context, issues, &type_arena); TraversalOptions opts; opts.use_comprehension_callbacks = true; diff --git a/common/type_factory.h b/common/type_factory.h index 5752a232d..33829ea8b 100644 --- a/common/type_factory.h +++ b/common/type_factory.h @@ -15,27 +15,14 @@ #ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPE_FACTORY_H_ #define THIRD_PARTY_CEL_CPP_COMMON_TYPE_FACTORY_H_ -#include "common/memory.h" - namespace cel { -namespace common_internal { -class PiecewiseValueManager; -} - // `TypeFactory` is the preferred way for constructing compound types such as // lists, maps, structs, and opaques. It caches types and avoids constructing // them multiple times. class TypeFactory { public: virtual ~TypeFactory() = default; - - // Returns a `MemoryManagerRef` which is used to manage memory for internal - // data structures as well as created types. - virtual MemoryManagerRef GetMemoryManager() const = 0; - - protected: - friend class common_internal::PiecewiseValueManager; }; } // namespace cel diff --git a/common/type_introspector.cc b/common/type_introspector.cc index 23151654a..f0737dda2 100644 --- a/common/type_introspector.cc +++ b/common/type_introspector.cc @@ -214,49 +214,47 @@ const WellKnownTypesMap& GetWellKnownTypesMap() { } // namespace absl::StatusOr> TypeIntrospector::FindType( - TypeFactory& type_factory, absl::string_view name) const { + absl::string_view name) const { const auto& well_known_types = GetWellKnownTypesMap(); if (auto it = well_known_types.find(name); it != well_known_types.end()) { return it->second.type; } - return FindTypeImpl(type_factory, name); + return FindTypeImpl(name); } absl::StatusOr> -TypeIntrospector::FindEnumConstant(TypeFactory& type_factory, - absl::string_view type, +TypeIntrospector::FindEnumConstant(absl::string_view type, absl::string_view value) const { if (type == "google.protobuf.NullValue" && value == "NULL_VALUE") { return EnumConstant{NullType{}, "google.protobuf.NullValue", "NULL_VALUE", 0}; } - return FindEnumConstantImpl(type_factory, type, value); + return FindEnumConstantImpl(type, value); } absl::StatusOr> -TypeIntrospector::FindStructTypeFieldByName(TypeFactory& type_factory, - absl::string_view type, +TypeIntrospector::FindStructTypeFieldByName(absl::string_view type, absl::string_view name) const { const auto& well_known_types = GetWellKnownTypesMap(); if (auto it = well_known_types.find(type); it != well_known_types.end()) { return it->second.FieldByName(name); } - return FindStructTypeFieldByNameImpl(type_factory, type, name); + return FindStructTypeFieldByNameImpl(type, name); } absl::StatusOr> TypeIntrospector::FindTypeImpl( - TypeFactory&, absl::string_view) const { + absl::string_view) const { return absl::nullopt; } absl::StatusOr> -TypeIntrospector::FindEnumConstantImpl(TypeFactory&, absl::string_view, +TypeIntrospector::FindEnumConstantImpl(absl::string_view, absl::string_view) const { return absl::nullopt; } absl::StatusOr> -TypeIntrospector::FindStructTypeFieldByNameImpl(TypeFactory&, absl::string_view, +TypeIntrospector::FindStructTypeFieldByNameImpl(absl::string_view, absl::string_view) const { return absl::nullopt; } diff --git a/common/type_introspector.h b/common/type_introspector.h index 2e504465b..00b1da758 100644 --- a/common/type_introspector.h +++ b/common/type_introspector.h @@ -46,40 +46,34 @@ class TypeIntrospector { virtual ~TypeIntrospector() = default; // `FindType` find the type corresponding to name `name`. - absl::StatusOr> FindType(TypeFactory& type_factory, - absl::string_view name) const; + absl::StatusOr> FindType(absl::string_view name) const; // `FindEnumConstant` find a fully qualified enumerator name `name` in enum // type `type`. absl::StatusOr> FindEnumConstant( - TypeFactory& type_factory, absl::string_view type, - absl::string_view value) const; + absl::string_view type, absl::string_view value) const; // `FindStructTypeFieldByName` find the name, number, and type of the field // `name` in type `type`. absl::StatusOr> FindStructTypeFieldByName( - TypeFactory& type_factory, absl::string_view type, - absl::string_view name) const; + absl::string_view type, absl::string_view name) const; // `FindStructTypeFieldByName` find the name, number, and type of the field // `name` in struct type `type`. absl::StatusOr> FindStructTypeFieldByName( - TypeFactory& type_factory, const StructType& type, - absl::string_view name) const { - return FindStructTypeFieldByName(type_factory, type.name(), name); + const StructType& type, absl::string_view name) const { + return FindStructTypeFieldByName(type.name(), name); } protected: virtual absl::StatusOr> FindTypeImpl( - TypeFactory& type_factory, absl::string_view name) const; + absl::string_view name) const; virtual absl::StatusOr> FindEnumConstantImpl( - TypeFactory& type_factory, absl::string_view type, - absl::string_view value) const; + absl::string_view type, absl::string_view value) const; virtual absl::StatusOr> - FindStructTypeFieldByNameImpl(TypeFactory& type_factory, - absl::string_view type, + FindStructTypeFieldByNameImpl(absl::string_view type, absl::string_view name) const; }; diff --git a/common/type_manager.cc b/common/type_manager.cc index 42e9180d9..0928f5346 100644 --- a/common/type_manager.cc +++ b/common/type_manager.cc @@ -27,7 +27,7 @@ Shared NewThreadCompatibleTypeManager( Shared type_introspector) { return memory_manager .MakeShared( - memory_manager, std::move(type_introspector)); + std::move(type_introspector)); } } // namespace cel diff --git a/common/type_manager.h b/common/type_manager.h index c1980b57d..576fe22fa 100644 --- a/common/type_manager.h +++ b/common/type_manager.h @@ -33,19 +33,19 @@ class TypeManager : public virtual TypeFactory { // See `TypeIntrospector::FindType`. absl::StatusOr> FindType(absl::string_view name) { - return GetTypeIntrospector().FindType(*this, name); + return GetTypeIntrospector().FindType(name); } // See `TypeIntrospector::FindStructTypeFieldByName`. absl::StatusOr> FindStructTypeFieldByName( absl::string_view type, absl::string_view name) { - return GetTypeIntrospector().FindStructTypeFieldByName(*this, type, name); + return GetTypeIntrospector().FindStructTypeFieldByName(type, name); } // See `TypeIntrospector::FindStructTypeFieldByName`. absl::StatusOr> FindStructTypeFieldByName( const StructType& type, absl::string_view name) { - return GetTypeIntrospector().FindStructTypeFieldByName(*this, type, name); + return GetTypeIntrospector().FindStructTypeFieldByName(type, name); } protected: diff --git a/common/types/legacy_type_manager.h b/common/types/legacy_type_manager.h index 198e00d22..238335b52 100644 --- a/common/types/legacy_type_manager.h +++ b/common/types/legacy_type_manager.h @@ -28,12 +28,8 @@ namespace cel::common_internal { // and only then. class LegacyTypeManager : public virtual TypeManager { public: - LegacyTypeManager(MemoryManagerRef memory_manager, - const TypeIntrospector& type_introspector) - : memory_manager_(memory_manager), - type_introspector_(type_introspector) {} - - MemoryManagerRef GetMemoryManager() const final { return memory_manager_; } + explicit LegacyTypeManager(const TypeIntrospector& type_introspector) + : type_introspector_(type_introspector) {} protected: const TypeIntrospector& GetTypeIntrospector() const final { @@ -41,7 +37,6 @@ class LegacyTypeManager : public virtual TypeManager { } private: - MemoryManagerRef memory_manager_; const TypeIntrospector& type_introspector_; }; diff --git a/common/types/thread_compatible_type_introspector.cc b/common/types/thread_compatible_type_introspector.cc deleted file mode 100644 index 47ff31cd8..000000000 --- a/common/types/thread_compatible_type_introspector.cc +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright 2023 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// IWYU pragma: private - -#include "common/types/thread_compatible_type_introspector.h" - -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "absl/types/optional.h" -#include "common/type.h" -#include "common/type_introspector.h" - -namespace cel::common_internal { - -absl::StatusOr> -ThreadCompatibleTypeIntrospector::FindTypeImpl(TypeFactory&, - absl::string_view) const { - return absl::nullopt; -} - -absl::StatusOr> -ThreadCompatibleTypeIntrospector::FindStructTypeFieldByNameImpl( - TypeFactory&, absl::string_view, absl::string_view) const { - return absl::nullopt; -} - -} // namespace cel::common_internal diff --git a/common/types/thread_compatible_type_introspector.h b/common/types/thread_compatible_type_introspector.h index 159d3fa19..870ea9054 100644 --- a/common/types/thread_compatible_type_introspector.h +++ b/common/types/thread_compatible_type_introspector.h @@ -17,10 +17,6 @@ #ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_THREAD_COMPATIBLE_TYPE_INTROSPECTOR_H_ #define THIRD_PARTY_CEL_CPP_COMMON_TYPES_THREAD_COMPATIBLE_TYPE_INTROSPECTOR_H_ -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "absl/types/optional.h" -#include "common/type.h" #include "common/type_introspector.h" namespace cel::common_internal { @@ -31,14 +27,6 @@ namespace cel::common_internal { class ThreadCompatibleTypeIntrospector : public virtual TypeIntrospector { public: ThreadCompatibleTypeIntrospector() = default; - - protected: - absl::StatusOr> FindTypeImpl( - TypeFactory& type_factory, absl::string_view name) const override; - - absl::StatusOr> FindStructTypeFieldByNameImpl( - TypeFactory& type_factory, absl::string_view type, - absl::string_view name) const override; }; } // namespace cel::common_internal diff --git a/common/types/thread_compatible_type_manager.h b/common/types/thread_compatible_type_manager.h index 848186774..937d6b13f 100644 --- a/common/types/thread_compatible_type_manager.h +++ b/common/types/thread_compatible_type_manager.h @@ -28,12 +28,8 @@ namespace cel::common_internal { class ThreadCompatibleTypeManager : public virtual TypeManager { public: explicit ThreadCompatibleTypeManager( - MemoryManagerRef memory_manager, Shared type_introspector) - : memory_manager_(memory_manager), - type_introspector_(std::move(type_introspector)) {} - - MemoryManagerRef GetMemoryManager() const final { return memory_manager_; } + : type_introspector_(std::move(type_introspector)) {} protected: TypeIntrospector& GetTypeIntrospector() const final { @@ -41,7 +37,6 @@ class ThreadCompatibleTypeManager : public virtual TypeManager { } private: - MemoryManagerRef memory_manager_; Shared type_introspector_; }; diff --git a/common/value_factory.h b/common/value_factory.h index 4d11a6ce7..9d0c6635a 100644 --- a/common/value_factory.h +++ b/common/value_factory.h @@ -26,6 +26,7 @@ #include "absl/strings/string_view.h" #include "absl/time/time.h" #include "common/json.h" +#include "common/memory.h" #include "common/type.h" #include "common/type_factory.h" #include "common/unknown.h" @@ -40,6 +41,10 @@ class PiecewiseValueManager; // `ValueFactory` is the preferred way for constructing values. class ValueFactory : public virtual TypeFactory { public: + // Returns a `MemoryManagerRef` which is used to manage memory for internal + // data structures as well as created types. + virtual MemoryManagerRef GetMemoryManager() const = 0; + // `CreateValueFromJson` constructs a new `Value` that is equivalent to the // JSON value `json`. ABSL_DEPRECATED("Avoid using Json/JsonArray/JsonObject") diff --git a/common/values/legacy_value_manager.h b/common/values/legacy_value_manager.h index d8b4b024d..61c9b9bae 100644 --- a/common/values/legacy_value_manager.h +++ b/common/values/legacy_value_manager.h @@ -30,10 +30,11 @@ class LegacyValueManager : public LegacyTypeManager, public ValueManager { public: LegacyValueManager(MemoryManagerRef memory_manager, const TypeReflector& type_reflector) - : LegacyTypeManager(memory_manager, type_reflector), + : LegacyTypeManager(type_reflector), + memory_manager_(memory_manager), type_reflector_(type_reflector) {} - using LegacyTypeManager::GetMemoryManager; + MemoryManagerRef GetMemoryManager() const override { return memory_manager_; } protected: const TypeReflector& GetTypeReflector() const final { @@ -41,6 +42,7 @@ class LegacyValueManager : public LegacyTypeManager, public ValueManager { } private: + MemoryManagerRef memory_manager_; const TypeReflector& type_reflector_; }; diff --git a/common/values/struct_value_builder.cc b/common/values/struct_value_builder.cc index 8ddbfb967..26717947b 100644 --- a/common/values/struct_value_builder.cc +++ b/common/values/struct_value_builder.cc @@ -74,7 +74,7 @@ class CompatTypeReflector final : public TypeReflector { protected: absl::StatusOr> FindTypeImpl( - TypeFactory& type_factory, absl::string_view name) const final { + absl::string_view name) const final { // We do not have to worry about well known types here. // `TypeIntrospector::FindType` handles those directly. const auto* desc = descriptor_pool()->FindMessageTypeByName(name); @@ -85,7 +85,7 @@ class CompatTypeReflector final : public TypeReflector { } absl::StatusOr> - FindEnumConstantImpl(TypeFactory&, absl::string_view type, + FindEnumConstantImpl(absl::string_view type, absl::string_view value) const final { const google::protobuf::EnumDescriptor* enum_desc = descriptor_pool()->FindEnumTypeByName(type); @@ -109,8 +109,7 @@ class CompatTypeReflector final : public TypeReflector { } absl::StatusOr> FindStructTypeFieldByNameImpl( - TypeFactory& type_factory, absl::string_view type, - absl::string_view name) const final { + absl::string_view type, absl::string_view name) const final { // We do not have to worry about well known types here. // `TypeIntrospector::FindStructTypeFieldByName` handles those directly. const auto* desc = descriptor_pool()->FindMessageTypeByName(type); diff --git a/common/values/thread_compatible_value_manager.h b/common/values/thread_compatible_value_manager.h index d90959fb9..798cfcdf1 100644 --- a/common/values/thread_compatible_value_manager.h +++ b/common/values/thread_compatible_value_manager.h @@ -32,15 +32,17 @@ class ThreadCompatibleValueManager : public ThreadCompatibleTypeManager, public: explicit ThreadCompatibleValueManager(MemoryManagerRef memory_manager, Shared type_reflector) - : ThreadCompatibleTypeManager(memory_manager, type_reflector), + : ThreadCompatibleTypeManager(type_reflector), + memory_manager_(memory_manager), type_reflector_(std::move(type_reflector)) {} - using ThreadCompatibleTypeManager::GetMemoryManager; + MemoryManagerRef GetMemoryManager() const override { return memory_manager_; } protected: TypeReflector& GetTypeReflector() const final { return *type_reflector_; } private: + MemoryManagerRef memory_manager_; Shared type_reflector_; }; diff --git a/eval/public/structs/legacy_type_provider.cc b/eval/public/structs/legacy_type_provider.cc index 9d58ef048..efab51f00 100644 --- a/eval/public/structs/legacy_type_provider.cc +++ b/eval/public/structs/legacy_type_provider.cc @@ -149,7 +149,7 @@ LegacyTypeProvider::DeserializeValueImpl(cel::ValueFactory& value_factory, } absl::StatusOr> LegacyTypeProvider::FindTypeImpl( - cel::TypeFactory& type_factory, absl::string_view name) const { + absl::string_view name) const { if (auto type_info = ProvideLegacyTypeInfo(name); type_info.has_value()) { const auto* descriptor = (*type_info)->GetDescriptor(MessageWrapper()); if (descriptor != nullptr) { @@ -163,8 +163,7 @@ absl::StatusOr> LegacyTypeProvider::FindTypeImpl( absl::StatusOr> LegacyTypeProvider::FindStructTypeFieldByNameImpl( - cel::TypeFactory& type_factory, absl::string_view type, - absl::string_view name) const { + absl::string_view type, absl::string_view name) const { if (auto type_info = ProvideLegacyTypeInfo(type); type_info.has_value()) { if (auto field_desc = (*type_info)->FindFieldByName(name); field_desc.has_value()) { diff --git a/eval/public/structs/legacy_type_provider.h b/eval/public/structs/legacy_type_provider.h index f9245511a..78a8a09b5 100644 --- a/eval/public/structs/legacy_type_provider.h +++ b/eval/public/structs/legacy_type_provider.h @@ -69,11 +69,10 @@ class LegacyTypeProvider : public cel::TypeReflector { const absl::Cord& value) const final; absl::StatusOr> FindTypeImpl( - cel::TypeFactory& type_factory, absl::string_view name) const final; + absl::string_view name) const final; absl::StatusOr> - FindStructTypeFieldByNameImpl(cel::TypeFactory& type_factory, - absl::string_view type, + FindStructTypeFieldByNameImpl(absl::string_view type, absl::string_view name) const final; }; diff --git a/extensions/protobuf/type_introspector.cc b/extensions/protobuf/type_introspector.cc index f681d41fc..8b445c359 100644 --- a/extensions/protobuf/type_introspector.cc +++ b/extensions/protobuf/type_introspector.cc @@ -18,14 +18,13 @@ #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "common/type.h" -#include "common/type_factory.h" #include "common/type_introspector.h" #include "google/protobuf/descriptor.h" namespace cel::extensions { absl::StatusOr> ProtoTypeIntrospector::FindTypeImpl( - TypeFactory& type_factory, absl::string_view name) const { + absl::string_view name) const { // We do not have to worry about well known types here. // `TypeIntrospector::FindType` handles those directly. const auto* desc = descriptor_pool()->FindMessageTypeByName(name); @@ -36,8 +35,7 @@ absl::StatusOr> ProtoTypeIntrospector::FindTypeImpl( } absl::StatusOr> -ProtoTypeIntrospector::FindEnumConstantImpl(TypeFactory&, - absl::string_view type, +ProtoTypeIntrospector::FindEnumConstantImpl(absl::string_view type, absl::string_view value) const { const google::protobuf::EnumDescriptor* enum_desc = descriptor_pool()->FindEnumTypeByName(type); @@ -62,8 +60,7 @@ ProtoTypeIntrospector::FindEnumConstantImpl(TypeFactory&, absl::StatusOr> ProtoTypeIntrospector::FindStructTypeFieldByNameImpl( - TypeFactory& type_factory, absl::string_view type, - absl::string_view name) const { + absl::string_view type, absl::string_view name) const { // We do not have to worry about well known types here. // `TypeIntrospector::FindStructTypeFieldByName` handles those directly. const auto* desc = descriptor_pool()->FindMessageTypeByName(type); diff --git a/extensions/protobuf/type_introspector.h b/extensions/protobuf/type_introspector.h index eae18aa06..034b908fb 100644 --- a/extensions/protobuf/type_introspector.h +++ b/extensions/protobuf/type_introspector.h @@ -20,7 +20,6 @@ #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "common/type.h" -#include "common/type_factory.h" #include "common/type_introspector.h" #include "google/protobuf/descriptor.h" @@ -41,15 +40,14 @@ class ProtoTypeIntrospector : public virtual TypeIntrospector { protected: absl::StatusOr> FindTypeImpl( - TypeFactory& type_factory, absl::string_view name) const final; + absl::string_view name) const final; absl::StatusOr> - FindEnumConstantImpl(TypeFactory&, absl::string_view type, + FindEnumConstantImpl(absl::string_view type, absl::string_view value) const final; absl::StatusOr> FindStructTypeFieldByNameImpl( - TypeFactory& type_factory, absl::string_view type, - absl::string_view name) const final; + absl::string_view type, absl::string_view name) const final; private: absl::Nonnull const descriptor_pool_; diff --git a/extensions/protobuf/type_introspector_test.cc b/extensions/protobuf/type_introspector_test.cc index 0ea783fe1..592fa946f 100644 --- a/extensions/protobuf/type_introspector_test.cc +++ b/extensions/protobuf/type_introspector_test.cc @@ -69,8 +69,7 @@ TEST_P(ProtoTypeIntrospectorTest, FindEnumConstant) { ASSERT_OK_AND_ASSIGN( auto enum_constant, introspector.FindEnumConstant( - type_manager(), "cel.expr.conformance.proto2.TestAllTypes.NestedEnum", - "BAZ")); + "cel.expr.conformance.proto2.TestAllTypes.NestedEnum", "BAZ")); ASSERT_TRUE(enum_constant.has_value()); EXPECT_EQ(enum_constant->type.kind(), TypeKind::kEnum); EXPECT_EQ(enum_constant->type_full_name, enum_desc->full_name()); @@ -82,8 +81,7 @@ TEST_P(ProtoTypeIntrospectorTest, FindEnumConstantNull) { ProtoTypeIntrospector introspector; ASSERT_OK_AND_ASSIGN( auto enum_constant, - introspector.FindEnumConstant(type_manager(), "google.protobuf.NullValue", - "NULL_VALUE")); + introspector.FindEnumConstant("google.protobuf.NullValue", "NULL_VALUE")); ASSERT_TRUE(enum_constant.has_value()); EXPECT_EQ(enum_constant->type.kind(), TypeKind::kNull); EXPECT_EQ(enum_constant->type_full_name, "google.protobuf.NullValue"); @@ -94,9 +92,8 @@ TEST_P(ProtoTypeIntrospectorTest, FindEnumConstantNull) { TEST_P(ProtoTypeIntrospectorTest, FindEnumConstantUnknownEnum) { ProtoTypeIntrospector introspector; - ASSERT_OK_AND_ASSIGN( - auto enum_constant, - introspector.FindEnumConstant(type_manager(), "NotARealEnum", "BAZ")); + ASSERT_OK_AND_ASSIGN(auto enum_constant, + introspector.FindEnumConstant("NotARealEnum", "BAZ")); EXPECT_FALSE(enum_constant.has_value()); } @@ -106,8 +103,7 @@ TEST_P(ProtoTypeIntrospectorTest, FindEnumConstantUnknownValue) { ASSERT_OK_AND_ASSIGN( auto enum_constant, introspector.FindEnumConstant( - type_manager(), "cel.expr.conformance.proto2.TestAllTypes.NestedEnum", - "QUX")); + "cel.expr.conformance.proto2.TestAllTypes.NestedEnum", "QUX")); ASSERT_FALSE(enum_constant.has_value()); } diff --git a/runtime/internal/composed_type_provider.cc b/runtime/internal/composed_type_provider.cc index 60d15193e..5bb377956 100644 --- a/runtime/internal/composed_type_provider.cc +++ b/runtime/internal/composed_type_provider.cc @@ -89,9 +89,9 @@ ComposedTypeProvider::DeserializeValueImpl(ValueFactory& value_factory, } absl::StatusOr> ComposedTypeProvider::FindTypeImpl( - TypeFactory& type_factory, absl::string_view name) const { + absl::string_view name) const { for (const std::unique_ptr& provider : providers_) { - CEL_ASSIGN_OR_RETURN(auto result, provider->FindType(type_factory, name)); + CEL_ASSIGN_OR_RETURN(auto result, provider->FindType(name)); if (result.has_value()) { return result; } @@ -101,11 +101,10 @@ absl::StatusOr> ComposedTypeProvider::FindTypeImpl( absl::StatusOr> ComposedTypeProvider::FindStructTypeFieldByNameImpl( - TypeFactory& type_factory, absl::string_view type, - absl::string_view name) const { + absl::string_view type, absl::string_view name) const { for (const std::unique_ptr& provider : providers_) { - CEL_ASSIGN_OR_RETURN(auto result, provider->FindStructTypeFieldByName( - type_factory, type, name)); + CEL_ASSIGN_OR_RETURN(auto result, + provider->FindStructTypeFieldByName(type, name)); if (result.has_value()) { return result; } diff --git a/runtime/internal/composed_type_provider.h b/runtime/internal/composed_type_provider.h index c74141d5a..b451e27fa 100644 --- a/runtime/internal/composed_type_provider.h +++ b/runtime/internal/composed_type_provider.h @@ -74,11 +74,10 @@ class ComposedTypeProvider : public TypeReflector { const absl::Cord& value) const override; absl::StatusOr> FindTypeImpl( - TypeFactory& type_factory, absl::string_view name) const override; + absl::string_view name) const override; absl::StatusOr> FindStructTypeFieldByNameImpl( - TypeFactory& type_factory, absl::string_view type, - absl::string_view name) const override; + absl::string_view type, absl::string_view name) const override; private: std::vector> providers_; diff --git a/runtime/optional_types.cc b/runtime/optional_types.cc index ccca7cfa4..9aa36e77a 100644 --- a/runtime/optional_types.cc +++ b/runtime/optional_types.cc @@ -282,7 +282,7 @@ absl::Status RegisterOptionalTypeFunctions(FunctionRegistry& registry, class OptionalTypeProvider final : public TypeReflector { protected: absl::StatusOr> FindTypeImpl( - TypeFactory&, absl::string_view name) const override { + absl::string_view name) const override { if (name != "optional_type") { return absl::nullopt; } From 02cbb35f0b7cba2cecaec8d93a7768b7b5b030e1 Mon Sep 17 00:00:00 2001 From: Justin King Date: Tue, 29 Oct 2024 07:26:57 -0700 Subject: [PATCH 010/180] Ensure mandatory messages are linked with the generated descriptor pool PiperOrigin-RevId: 691010654 --- internal/BUILD | 1 + internal/well_known_types.cc | 31 +++++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/internal/BUILD b/internal/BUILD index 13cc004b5..105ce07ea 100644 --- a/internal/BUILD +++ b/internal/BUILD @@ -683,6 +683,7 @@ cc_library( "//common:json", "//common:memory", "//extensions/protobuf/internal:map_reflection", + "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", diff --git a/internal/well_known_types.cc b/internal/well_known_types.cc index 653ba65ad..0ef9da869 100644 --- a/internal/well_known_types.cc +++ b/internal/well_known_types.cc @@ -23,11 +23,13 @@ #include "google/protobuf/any.pb.h" #include "google/protobuf/duration.pb.h" +#include "google/protobuf/field_mask.pb.h" #include "google/protobuf/struct.pb.h" #include "google/protobuf/timestamp.pb.h" #include "google/protobuf/wrappers.pb.h" #include "google/protobuf/descriptor.pb.h" #include "absl/base/attributes.h" +#include "absl/base/call_once.h" #include "absl/base/no_destructor.h" #include "absl/base/nullability.h" #include "absl/base/optimization.h" @@ -1735,7 +1737,36 @@ absl::StatusOr GetFieldMaskReflection( return reflection; } +namespace { + +ABSL_CONST_INIT absl::once_flag link_well_known_message_reflection; + +void LinkWellKnownMessageReflection() { + google::protobuf::LinkMessageReflection(); + google::protobuf::LinkMessageReflection(); + google::protobuf::LinkMessageReflection(); + google::protobuf::LinkMessageReflection(); + google::protobuf::LinkMessageReflection(); + google::protobuf::LinkMessageReflection(); + google::protobuf::LinkMessageReflection(); + google::protobuf::LinkMessageReflection(); + google::protobuf::LinkMessageReflection(); + google::protobuf::LinkMessageReflection(); + google::protobuf::LinkMessageReflection(); + google::protobuf::LinkMessageReflection(); + google::protobuf::LinkMessageReflection(); + google::protobuf::LinkMessageReflection(); + google::protobuf::LinkMessageReflection(); + google::protobuf::LinkMessageReflection(); +} + +} // namespace + absl::Status Reflection::Initialize(absl::Nonnull pool) { + if (pool == DescriptorPool::generated_pool()) { + absl::call_once(link_well_known_message_reflection, + &LinkWellKnownMessageReflection); + } CEL_RETURN_IF_ERROR(NullValue().Initialize(pool)); CEL_RETURN_IF_ERROR(BoolValue().Initialize(pool)); CEL_RETURN_IF_ERROR(Int32Value().Initialize(pool)); From a546afa5d02f7c15cf19af25f4390625824d8126 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Tue, 29 Oct 2024 09:30:31 -0700 Subject: [PATCH 011/180] Add a `[[maybe_unused]]` annotation PiperOrigin-RevId: 691049754 --- internal/well_known_types.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/internal/well_known_types.cc b/internal/well_known_types.cc index 0ef9da869..f18d11b03 100644 --- a/internal/well_known_types.cc +++ b/internal/well_known_types.cc @@ -1739,7 +1739,8 @@ absl::StatusOr GetFieldMaskReflection( namespace { -ABSL_CONST_INIT absl::once_flag link_well_known_message_reflection; +[[maybe_unused]] ABSL_CONST_INIT absl::once_flag + link_well_known_message_reflection; void LinkWellKnownMessageReflection() { google::protobuf::LinkMessageReflection(); From 883d7273f84271525de0cedbbe9bdc047aad31d6 Mon Sep 17 00:00:00 2001 From: Justin King Date: Tue, 29 Oct 2024 10:35:05 -0700 Subject: [PATCH 012/180] Make the minimal descriptor pool public PiperOrigin-RevId: 691074913 --- common/BUILD | 21 ++++++++++ common/minimal_descriptor_pool.cc | 27 +++++++++++++ common/minimal_descriptor_pool.h | 31 +++++++++++++++ .../minimal_descriptor_pool_test.cc | 38 +++++++++---------- internal/BUILD | 10 ----- 5 files changed, 98 insertions(+), 29 deletions(-) create mode 100644 common/minimal_descriptor_pool.cc create mode 100644 common/minimal_descriptor_pool.h rename {internal => common}/minimal_descriptor_pool_test.cc (85%) diff --git a/common/BUILD b/common/BUILD index 0544969b2..cc37eee2b 100644 --- a/common/BUILD +++ b/common/BUILD @@ -867,3 +867,24 @@ cc_test( "@com_google_protobuf//:protobuf", ], ) + +cc_library( + name = "minimal_descriptor_pool", + srcs = ["minimal_descriptor_pool.cc"], + hdrs = ["minimal_descriptor_pool.h"], + deps = [ + "//internal:minimal_descriptor_pool", + "@com_google_absl//absl/base:nullability", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "minimal_descriptor_pool_test", + srcs = ["minimal_descriptor_pool_test.cc"], + deps = [ + ":minimal_descriptor_pool", + "//internal:testing", + "@com_google_protobuf//:protobuf", + ], +) diff --git a/common/minimal_descriptor_pool.cc b/common/minimal_descriptor_pool.cc new file mode 100644 index 000000000..fc29790b4 --- /dev/null +++ b/common/minimal_descriptor_pool.cc @@ -0,0 +1,27 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/minimal_descriptor_pool.h" + +#include "absl/base/nullability.h" +#include "internal/minimal_descriptor_pool.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +absl::Nonnull GetMinimalDescriptorPool() { + return internal::GetMinimalDescriptorPool(); +} + +} // namespace cel diff --git a/common/minimal_descriptor_pool.h b/common/minimal_descriptor_pool.h new file mode 100644 index 000000000..17772bcb0 --- /dev/null +++ b/common/minimal_descriptor_pool.h @@ -0,0 +1,31 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_MINIMAL_DESCRIPTOR_POOL_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_MINIMAL_DESCRIPTOR_POOL_H_ + +#include "absl/base/nullability.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +// GetMinimalDescriptorPool returns a pointer to a `google::protobuf::DescriptorPool` +// which includes has the minimally necessary descriptors required by the Common +// Expression Language. The returned `google::protobuf::DescriptorPool` is valid for the +// lifetime of the process and should not be deleted. +absl::Nonnull GetMinimalDescriptorPool(); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_MINIMAL_DESCRIPTOR_POOL_H_ diff --git a/internal/minimal_descriptor_pool_test.cc b/common/minimal_descriptor_pool_test.cc similarity index 85% rename from internal/minimal_descriptor_pool_test.cc rename to common/minimal_descriptor_pool_test.cc index 642d448e0..a654a1a1a 100644 --- a/internal/minimal_descriptor_pool_test.cc +++ b/common/minimal_descriptor_pool_test.cc @@ -12,23 +12,23 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "internal/minimal_descriptor_pool.h" +#include "common/minimal_descriptor_pool.h" #include "internal/testing.h" #include "google/protobuf/descriptor.h" -namespace cel::internal { +namespace cel { namespace { using ::testing::NotNull; -TEST(MinimalDescriptorPool, NullValue) { +TEST(GetMinimalDescriptorPool, NullValue) { ASSERT_THAT(GetMinimalDescriptorPool()->FindEnumTypeByName( "google.protobuf.NullValue"), NotNull()); } -TEST(MinimalDescriptorPool, BoolValue) { +TEST(GetMinimalDescriptorPool, BoolValue) { const auto* desc = GetMinimalDescriptorPool()->FindMessageTypeByName( "google.protobuf.BoolValue"); ASSERT_THAT(desc, NotNull()); @@ -36,7 +36,7 @@ TEST(MinimalDescriptorPool, BoolValue) { google::protobuf::Descriptor::WELLKNOWNTYPE_BOOLVALUE); } -TEST(MinimalDescriptorPool, Int32Value) { +TEST(GetMinimalDescriptorPool, Int32Value) { const auto* desc = GetMinimalDescriptorPool()->FindMessageTypeByName( "google.protobuf.Int32Value"); ASSERT_THAT(desc, NotNull()); @@ -44,7 +44,7 @@ TEST(MinimalDescriptorPool, Int32Value) { google::protobuf::Descriptor::WELLKNOWNTYPE_INT32VALUE); } -TEST(MinimalDescriptorPool, Int64Value) { +TEST(GetMinimalDescriptorPool, Int64Value) { const auto* desc = GetMinimalDescriptorPool()->FindMessageTypeByName( "google.protobuf.Int64Value"); ASSERT_THAT(desc, NotNull()); @@ -52,7 +52,7 @@ TEST(MinimalDescriptorPool, Int64Value) { google::protobuf::Descriptor::WELLKNOWNTYPE_INT64VALUE); } -TEST(MinimalDescriptorPool, UInt32Value) { +TEST(GetMinimalDescriptorPool, UInt32Value) { const auto* desc = GetMinimalDescriptorPool()->FindMessageTypeByName( "google.protobuf.UInt32Value"); ASSERT_THAT(desc, NotNull()); @@ -60,7 +60,7 @@ TEST(MinimalDescriptorPool, UInt32Value) { google::protobuf::Descriptor::WELLKNOWNTYPE_UINT32VALUE); } -TEST(MinimalDescriptorPool, UInt64Value) { +TEST(GetMinimalDescriptorPool, UInt64Value) { const auto* desc = GetMinimalDescriptorPool()->FindMessageTypeByName( "google.protobuf.UInt64Value"); ASSERT_THAT(desc, NotNull()); @@ -68,7 +68,7 @@ TEST(MinimalDescriptorPool, UInt64Value) { google::protobuf::Descriptor::WELLKNOWNTYPE_UINT64VALUE); } -TEST(MinimalDescriptorPool, FloatValue) { +TEST(GetMinimalDescriptorPool, FloatValue) { const auto* desc = GetMinimalDescriptorPool()->FindMessageTypeByName( "google.protobuf.FloatValue"); ASSERT_THAT(desc, NotNull()); @@ -76,7 +76,7 @@ TEST(MinimalDescriptorPool, FloatValue) { google::protobuf::Descriptor::WELLKNOWNTYPE_FLOATVALUE); } -TEST(MinimalDescriptorPool, DoubleValue) { +TEST(GetMinimalDescriptorPool, DoubleValue) { const auto* desc = GetMinimalDescriptorPool()->FindMessageTypeByName( "google.protobuf.DoubleValue"); ASSERT_THAT(desc, NotNull()); @@ -84,7 +84,7 @@ TEST(MinimalDescriptorPool, DoubleValue) { google::protobuf::Descriptor::WELLKNOWNTYPE_DOUBLEVALUE); } -TEST(MinimalDescriptorPool, BytesValue) { +TEST(GetMinimalDescriptorPool, BytesValue) { const auto* desc = GetMinimalDescriptorPool()->FindMessageTypeByName( "google.protobuf.BytesValue"); ASSERT_THAT(desc, NotNull()); @@ -92,7 +92,7 @@ TEST(MinimalDescriptorPool, BytesValue) { google::protobuf::Descriptor::WELLKNOWNTYPE_BYTESVALUE); } -TEST(MinimalDescriptorPool, StringValue) { +TEST(GetMinimalDescriptorPool, StringValue) { const auto* desc = GetMinimalDescriptorPool()->FindMessageTypeByName( "google.protobuf.StringValue"); ASSERT_THAT(desc, NotNull()); @@ -100,14 +100,14 @@ TEST(MinimalDescriptorPool, StringValue) { google::protobuf::Descriptor::WELLKNOWNTYPE_STRINGVALUE); } -TEST(MinimalDescriptorPool, Any) { +TEST(GetMinimalDescriptorPool, Any) { const auto* desc = GetMinimalDescriptorPool()->FindMessageTypeByName("google.protobuf.Any"); ASSERT_THAT(desc, NotNull()); EXPECT_EQ(desc->well_known_type(), google::protobuf::Descriptor::WELLKNOWNTYPE_ANY); } -TEST(MinimalDescriptorPool, Duration) { +TEST(GetMinimalDescriptorPool, Duration) { const auto* desc = GetMinimalDescriptorPool()->FindMessageTypeByName( "google.protobuf.Duration"); ASSERT_THAT(desc, NotNull()); @@ -115,7 +115,7 @@ TEST(MinimalDescriptorPool, Duration) { google::protobuf::Descriptor::WELLKNOWNTYPE_DURATION); } -TEST(MinimalDescriptorPool, Timestamp) { +TEST(GetMinimalDescriptorPool, Timestamp) { const auto* desc = GetMinimalDescriptorPool()->FindMessageTypeByName( "google.protobuf.Timestamp"); ASSERT_THAT(desc, NotNull()); @@ -123,14 +123,14 @@ TEST(MinimalDescriptorPool, Timestamp) { google::protobuf::Descriptor::WELLKNOWNTYPE_TIMESTAMP); } -TEST(MinimalDescriptorPool, Value) { +TEST(GetMinimalDescriptorPool, Value) { const auto* desc = GetMinimalDescriptorPool()->FindMessageTypeByName( "google.protobuf.Value"); ASSERT_THAT(desc, NotNull()); EXPECT_EQ(desc->well_known_type(), google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); } -TEST(MinimalDescriptorPool, ListValue) { +TEST(GetMinimalDescriptorPool, ListValue) { const auto* desc = GetMinimalDescriptorPool()->FindMessageTypeByName( "google.protobuf.ListValue"); ASSERT_THAT(desc, NotNull()); @@ -138,7 +138,7 @@ TEST(MinimalDescriptorPool, ListValue) { google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE); } -TEST(MinimalDescriptorPool, Struct) { +TEST(GetMinimalDescriptorPool, Struct) { const auto* desc = GetMinimalDescriptorPool()->FindMessageTypeByName( "google.protobuf.Struct"); ASSERT_THAT(desc, NotNull()); @@ -146,4 +146,4 @@ TEST(MinimalDescriptorPool, Struct) { } } // namespace -} // namespace cel::internal +} // namespace cel diff --git a/internal/BUILD b/internal/BUILD index 105ce07ea..fa833c1ae 100644 --- a/internal/BUILD +++ b/internal/BUILD @@ -535,16 +535,6 @@ cc_library( ], ) -cc_test( - name = "minimal_descriptor_pool_test", - srcs = ["minimal_descriptor_pool_test.cc"], - deps = [ - ":minimal_descriptor_pool", - ":testing", - "@com_google_protobuf//:protobuf", - ], -) - cel_proto_transitive_descriptor_set( name = "testing_descriptor_set", testonly = True, From 03536668776c94f40a55780387b87696d5b449f7 Mon Sep 17 00:00:00 2001 From: Justin King Date: Wed, 30 Oct 2024 14:47:53 -0700 Subject: [PATCH 013/180] Overhaul propagation of arenas and message factories PiperOrigin-RevId: 691561398 --- conformance/service.cc | 13 +- eval/compiler/BUILD | 49 ++++-- .../cel_expression_builder_flat_impl.h | 40 ++++- .../cel_expression_builder_flat_impl_test.cc | 24 +-- eval/compiler/constant_folding.cc | 59 +++++-- eval/compiler/constant_folding.h | 9 +- eval/compiler/constant_folding_test.cc | 75 +++++---- eval/compiler/flat_expr_builder.cc | 29 ++-- eval/compiler/flat_expr_builder.h | 45 +++--- .../flat_expr_builder_comprehensions_test.cc | 30 ++-- eval/compiler/flat_expr_builder_extensions.h | 64 +++++++- .../flat_expr_builder_extensions_test.cc | 62 +++++--- ...ilder_short_circuiting_conformance_test.cc | 26 +-- eval/compiler/flat_expr_builder_test.cc | 150 +++++++++--------- eval/compiler/instrumentation_test.cc | 34 ++-- .../regex_precompilation_optimization_test.cc | 18 ++- eval/compiler/resolver.cc | 19 ++- eval/compiler/resolver.h | 15 +- eval/eval/BUILD | 2 + eval/eval/evaluator_core.h | 17 +- eval/eval/evaluator_core_test.cc | 4 +- eval/public/BUILD | 8 +- eval/public/cel_expr_builder_factory.cc | 61 +++---- eval/public/cel_expression.h | 21 +-- eval/tests/modern_benchmark_test.cc | 3 +- extensions/select_optimization.cc | 4 +- runtime/BUILD | 7 +- runtime/constant_folding.cc | 126 ++++++++++++--- runtime/constant_folding.h | 46 ++++-- runtime/constant_folding_test.cc | 7 +- runtime/internal/BUILD | 34 +++- runtime/internal/runtime_env.cc | 74 +++++++++ runtime/internal/runtime_env.h | 133 ++++++++++++++++ runtime/internal/runtime_env_testing.cc | 33 ++++ runtime/internal/runtime_env_testing.h | 29 ++++ runtime/internal/runtime_impl.h | 38 +++-- runtime/regex_precompilation_test.cc | 13 +- runtime/runtime_builder_factory.cc | 9 +- 38 files changed, 1019 insertions(+), 411 deletions(-) create mode 100644 runtime/internal/runtime_env.cc create mode 100644 runtime/internal/runtime_env.h create mode 100644 runtime/internal/runtime_env_testing.cc create mode 100644 runtime/internal/runtime_env_testing.h diff --git a/conformance/service.cc b/conformance/service.cc index a6d90c0b0..a0593971d 100644 --- a/conformance/service.cc +++ b/conformance/service.cc @@ -481,8 +481,7 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { if (enable_optimizations_) { CEL_RETURN_IF_ERROR(cel::extensions::EnableConstantFolding( - builder, constant_memory_manager_, - google::protobuf::MessageFactory::generated_factory())); + builder, google::protobuf::MessageFactory::generated_factory())); } CEL_RETURN_IF_ERROR(cel::EnableReferenceResolver( builder, cel::ReferenceResolverEnabled::kAlways)); @@ -528,7 +527,7 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { void Check(const conformance::v1alpha1::CheckRequest& request, conformance::v1alpha1::CheckResponse& response) override { - auto status = DoCheck(&constant_arena_, request, response); + auto status = DoCheck(&arena_, request, response); if (!status.ok()) { auto* issue = response.add_issues(); issue->set_code(ToGrpcCode(status.code())); @@ -614,10 +613,7 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { bool enable_optimizations) : options_(options), use_arena_(use_arena), - enable_optimizations_(enable_optimizations), - constant_memory_manager_( - use_arena_ ? ProtoMemoryManagerRef(&constant_arena_) - : cel::MemoryManagerRef::ReferenceCounting()) {} + enable_optimizations_(enable_optimizations) {} static absl::Status DoCheck( google::protobuf::Arena* arena, const conformance::v1alpha1::CheckRequest& request, @@ -733,8 +729,7 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { RuntimeOptions options_; bool use_arena_; bool enable_optimizations_; - Arena constant_arena_; - cel::MemoryManagerRef constant_memory_manager_; + Arena arena_; }; } // namespace diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index 396cca677..d31e90451 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -36,7 +36,9 @@ cc_library( "//internal:casts", "//runtime:runtime_options", "//runtime/internal:issue_collector", + "//runtime/internal:runtime_env", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/functional:any_invocable", @@ -46,6 +48,7 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:variant", + "@com_google_protobuf//:protobuf", ], ) @@ -56,7 +59,6 @@ cc_test( ":flat_expr_builder_extensions", ":resolver", "//base/ast_internal:expr", - "//common:casting", "//common:memory", "//common:native_type", "//common:value", @@ -71,8 +73,12 @@ cc_test( "//runtime:runtime_options", "//runtime:type_registry", "//runtime/internal:issue_collector", + "//runtime/internal:runtime_env", + "//runtime/internal:runtime_env_testing", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", ], ) @@ -115,7 +121,6 @@ cc_library( "//eval/eval:shadowable_value_step", "//eval/eval:ternary_step", "//eval/eval:trace_step", - "//eval/public:cel_type_registry", "//internal:status_macros", "//runtime:function_registry", "//runtime:runtime_issue", @@ -123,13 +128,14 @@ cc_library( "//runtime:type_registry", "//runtime/internal:convert_constant", "//runtime/internal:issue_collector", + "//runtime/internal:runtime_env", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/log:absl_check", - "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -137,6 +143,7 @@ cc_library( "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", + "@com_google_protobuf//:protobuf", ], ) @@ -155,6 +162,7 @@ cc_test( ":qualified_reference_resolver", "//base:function", "//base:function_descriptor", + "//common:value", "//eval/public:activation", "//eval/public:builtin_func_registrar", "//eval/public:cel_attribute", @@ -174,13 +182,13 @@ cc_test( "//eval/public/structs:protobuf_descriptor_type_provider", "//eval/public/testing:matchers", "//eval/testutil:test_message_cc_proto", - "//extensions/protobuf:memory_manager", "//internal:proto_file_util", "//internal:proto_matchers", "//internal:status_macros", "//internal:testing", "//parser", "//runtime:runtime_options", + "//runtime/internal:runtime_env_testing", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", @@ -213,6 +221,7 @@ cc_test( "//internal:testing", "//parser", "//runtime:runtime_options", + "//runtime/internal:runtime_env_testing", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", @@ -236,14 +245,20 @@ cc_library( "//eval/eval:direct_expression_step", "//eval/eval:evaluator_core", "//eval/public:cel_expression", + "//eval/public:cel_function_registry", + "//eval/public:cel_type_registry", "//extensions/protobuf:ast_converters", "//internal:status_macros", "//runtime:runtime_issue", "//runtime:runtime_options", + "//runtime/internal:runtime_env", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], @@ -270,12 +285,12 @@ cc_test( "//eval/public/structs:protobuf_descriptor_type_provider", "//eval/public/testing:matchers", "//extensions:bindings_ext", - "//extensions/protobuf:memory_manager", "//internal:status_macros", "//internal:testing", "//parser", "//parser:macro", "//runtime:runtime_options", + "//runtime/internal:runtime_env_testing", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -303,7 +318,7 @@ cc_library( "//base:kind", "//base/ast_internal:ast_impl", "//base/ast_internal:expr", - "//common:allocator", + "//common:memory", "//common:value", "//eval/eval:const_value_step", "//eval/eval:evaluator_core", @@ -331,8 +346,6 @@ cc_test( "//base:ast", "//base/ast_internal:ast_impl", "//base/ast_internal:expr", - "//common:memory", - "//common:type", "//common:value", "//eval/eval:const_value_step", "//eval/eval:create_list_step", @@ -342,12 +355,16 @@ cc_test( "//extensions/protobuf:memory_manager", "//internal:status_macros", "//internal:testing", + "//internal:testing_descriptor_pool", "//parser", "//runtime:function_registry", "//runtime:runtime_issue", "//runtime:runtime_options", "//runtime:type_registry", "//runtime/internal:issue_collector", + "//runtime/internal:runtime_env", + "//runtime/internal:runtime_env_testing", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -390,14 +407,12 @@ cc_library( hdrs = ["resolver.h"], deps = [ "//base:kind", - "//common:memory", "//common:type", "//common:value", "//internal:status_macros", "//runtime:function_overload_reference", "//runtime:function_registry", "//runtime:type_registry", - "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -445,17 +460,16 @@ cc_test( ], deps = [ ":cel_expression_builder_flat_impl", - ":flat_expr_builder", + "//base:builtins", "//eval/public:activation", "//eval/public:cel_attribute", - "//eval/public:cel_builtins", "//eval/public:cel_expression", - "//eval/public:cel_options", + "//eval/public:cel_value", "//eval/public:unknown_attribute_set", "//eval/public:unknown_set", - "//internal:status_macros", "//internal:testing", "//runtime:runtime_options", + "//runtime/internal:runtime_env_testing", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf", @@ -533,6 +547,9 @@ cc_test( "//parser", "//runtime:runtime_issue", "//runtime/internal:issue_collector", + "//runtime/internal:runtime_env", + "//runtime/internal:runtime_env_testing", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", @@ -581,7 +598,6 @@ cc_test( ":instrumentation", ":regex_precompilation_optimization", "//base/ast_internal:ast_impl", - "//common:type", "//common:value", "//eval/eval:evaluator_core", "//extensions/protobuf:ast_converters", @@ -594,6 +610,9 @@ cc_test( "//runtime:runtime_options", "//runtime:standard_functions", "//runtime:type_registry", + "//runtime/internal:runtime_env", + "//runtime/internal:runtime_env_testing", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", diff --git a/eval/compiler/cel_expression_builder_flat_impl.h b/eval/compiler/cel_expression_builder_flat_impl.h index 98efc4b74..7b09b7879 100644 --- a/eval/compiler/cel_expression_builder_flat_impl.h +++ b/eval/compiler/cel_expression_builder_flat_impl.h @@ -24,11 +24,17 @@ #include "cel/expr/checked.pb.h" #include "cel/expr/syntax.pb.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "base/ast.h" #include "eval/compiler/flat_expr_builder.h" #include "eval/public/cel_expression.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_type_registry.h" +#include "runtime/internal/runtime_env.h" #include "runtime/runtime_options.h" namespace google::api::expr::runtime { @@ -37,13 +43,16 @@ namespace google::api::expr::runtime { // Builds instances of CelExpressionFlatImpl. class CelExpressionBuilderFlatImpl : public CelExpressionBuilder { public: - explicit CelExpressionBuilderFlatImpl(const cel::RuntimeOptions& options) - : flat_expr_builder_(GetRegistry()->InternalGetRegistry(), - *GetTypeRegistry(), options) {} + CelExpressionBuilderFlatImpl( + absl::Nonnull> env, + const cel::RuntimeOptions& options) + : env_(std::move(env)), flat_expr_builder_(env_, options) { + ABSL_DCHECK(env_->IsInitialized()); + } - CelExpressionBuilderFlatImpl() - : flat_expr_builder_(GetRegistry()->InternalGetRegistry(), - *GetTypeRegistry()) {} + explicit CelExpressionBuilderFlatImpl( + absl::Nonnull> env) + : CelExpressionBuilderFlatImpl(std::move(env), cel::RuntimeOptions()) {} absl::StatusOr> CreateExpression( const cel::expr::Expr* expr, @@ -64,15 +73,32 @@ class CelExpressionBuilderFlatImpl : public CelExpressionBuilder { FlatExprBuilder& flat_expr_builder() { return flat_expr_builder_; } void set_container(std::string container) override { - CelExpressionBuilder::set_container(container); flat_expr_builder_.set_container(std::move(container)); } + // CelFunction registry. Extension function should be registered with it + // prior to expression creation. + CelFunctionRegistry* GetRegistry() const override { + return &env_->legacy_function_registry; + } + + // CEL Type registry. Provides a means to resolve the CEL built-in types to + // CelValue instances, and to extend the set of types and enums known to + // expressions by registering them ahead of time. + CelTypeRegistry* GetTypeRegistry() const override { + return &env_->legacy_type_registry; + } + + absl::string_view container() const override { + return flat_expr_builder_.container(); + } + private: absl::StatusOr> CreateExpressionImpl( std::unique_ptr converted_ast, std::vector* warnings) const; + absl::Nonnull> env_; FlatExprBuilder flat_expr_builder_; }; diff --git a/eval/compiler/cel_expression_builder_flat_impl_test.cc b/eval/compiler/cel_expression_builder_flat_impl_test.cc index c70a04396..46212128b 100644 --- a/eval/compiler/cel_expression_builder_flat_impl_test.cc +++ b/eval/compiler/cel_expression_builder_flat_impl_test.cc @@ -44,11 +44,11 @@ #include "eval/public/structs/protobuf_descriptor_type_provider.h" #include "eval/public/testing/matchers.h" #include "extensions/bindings_ext.h" -#include "extensions/protobuf/memory_manager.h" #include "internal/status_macros.h" #include "internal/testing.h" #include "parser/macro.h" #include "parser/parser.h" +#include "runtime/internal/runtime_env_testing.h" #include "runtime/runtime_options.h" #include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/arena.h" @@ -62,6 +62,7 @@ namespace { using ::absl_testing::StatusIs; using ::cel::expr::conformance::proto3::NestedTestAllTypes; using ::cel::expr::conformance::proto3::TestAllTypes; +using ::cel::runtime_internal::NewTestingRuntimeEnv; using ::cel::expr::CheckedExpr; using ::cel::expr::Expr; using ::cel::expr::ParsedExpr; @@ -78,7 +79,7 @@ using ::testing::NotNull; TEST(CelExpressionBuilderFlatImplTest, Error) { Expr expr; SourceInfo source_info; - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Invalid empty expression"))); @@ -87,7 +88,7 @@ TEST(CelExpressionBuilderFlatImplTest, Error) { TEST(CelExpressionBuilderFlatImplTest, ParsedExpr) { ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse("1 + 2")); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); ASSERT_OK_AND_ASSIGN(std::unique_ptr plan, @@ -167,7 +168,7 @@ TEST_P(RecursivePlanTest, ParsedExprRecursiveImpl) { google::protobuf::Arena arena; // Unbounded. options.max_recursion_depth = -1; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_OK(SetupBuilder(builder)); @@ -195,13 +196,12 @@ TEST_P(RecursivePlanTest, ParsedExprRecursiveOptimizedImpl) { // Unbounded. options.max_recursion_depth = -1; options.enable_comprehension_list_append = true; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_OK(SetupBuilder(builder)); builder.flat_expr_builder().AddProgramOptimizer( - cel::runtime_internal::CreateConstantFoldingOptimizer( - cel::extensions::ProtoMemoryManagerRef(&arena))); + cel::runtime_internal::CreateConstantFoldingOptimizer()); builder.flat_expr_builder().AddProgramOptimizer( CreateRegexPrecompilationExtension(options.regex_max_program_size)); @@ -232,7 +232,7 @@ TEST_P(RecursivePlanTest, ParsedExprRecursiveTraceSupport) { // Unbounded. options.max_recursion_depth = -1; options.enable_recursive_tracing = true; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_OK(SetupBuilder(builder)); @@ -261,7 +261,7 @@ TEST_P(RecursivePlanTest, Disabled) { google::protobuf::Arena arena; // disabled. options.max_recursion_depth = 0; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_OK(SetupBuilder(builder)); @@ -343,7 +343,7 @@ TEST(CelExpressionBuilderFlatImplTest, ParsedExprWithWarnings) { cel::RuntimeOptions options; options.fail_on_warnings = false; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); std::vector warnings; ASSERT_OK_AND_ASSIGN( @@ -367,7 +367,7 @@ TEST(CelExpressionBuilderFlatImplTest, CheckedExpr) { checked_expr.mutable_expr()->Swap(parsed_expr.mutable_expr()); checked_expr.mutable_source_info()->Swap(parsed_expr.mutable_source_info()); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); ASSERT_OK_AND_ASSIGN(std::unique_ptr plan, @@ -387,7 +387,7 @@ TEST(CelExpressionBuilderFlatImplTest, CheckedExprWithWarnings) { cel::RuntimeOptions options; options.fail_on_warnings = false; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); std::vector warnings; ASSERT_OK_AND_ASSIGN(std::unique_ptr plan, diff --git a/eval/compiler/constant_folding.cc b/eval/compiler/constant_folding.cc index faf0b0387..73eccad0e 100644 --- a/eval/compiler/constant_folding.cc +++ b/eval/compiler/constant_folding.cc @@ -29,7 +29,7 @@ #include "base/builtins.h" #include "base/kind.h" #include "base/type_provider.h" -#include "common/allocator.h" +#include "common/memory.h" #include "common/value.h" #include "common/value_manager.h" #include "eval/compiler/flat_expr_builder_extensions.h" @@ -39,6 +39,7 @@ #include "internal/status_macros.h" #include "runtime/activation.h" #include "runtime/internal/convert_constant.h" +#include "google/protobuf/arena.h" #include "google/protobuf/message.h" namespace cel::runtime_internal { @@ -73,13 +74,18 @@ using ::google::api::expr::runtime::Resolver; class ConstantFoldingExtension : public ProgramOptimizer { public: ConstantFoldingExtension( - Allocator<> allocator, - absl::Nullable message_factory, + absl::Nullable> shared_arena, + absl::Nonnull arena, + absl::Nullable> + shared_message_factory, + absl::Nonnull message_factory, const TypeProvider& type_provider) - : memory_manager_(allocator), + : shared_arena_(std::move(shared_arena)), + arena_(arena), + shared_message_factory_(std::move(shared_message_factory)), + message_factory_(message_factory), state_(kDefaultStackLimit, kComprehensionSlotCount, type_provider, - MemoryManager(allocator)), - message_factory_(message_factory) {} + MemoryManager::Pooling(arena)) {} absl::Status OnPreVisit(google::api::expr::runtime::PlannerContext& context, const Expr& node) override; @@ -99,12 +105,15 @@ class ConstantFoldingExtension : public ProgramOptimizer { // if the comprehension variables are only used in a const way. static constexpr size_t kComprehensionSlotCount = 0; - MemoryManager memory_manager_; + absl::Nullable> shared_arena_; + ABSL_ATTRIBUTE_UNUSED + absl::Nonnull arena_; + absl::Nullable> + shared_message_factory_; + ABSL_ATTRIBUTE_UNUSED + absl::Nonnull message_factory_; Activation empty_; FlatExpressionEvaluatorState state_; - // Not yet used, will be in future. - ABSL_ATTRIBUTE_UNUSED - absl::Nullable message_factory_; std::vector is_const_; }; @@ -254,13 +263,29 @@ absl::Status ConstantFoldingExtension::OnPostVisit(PlannerContext& context, } // namespace ProgramOptimizerFactory CreateConstantFoldingOptimizer( - Allocator<> allocator, - absl::Nullable message_factory) { - return [allocator, message_factory](PlannerContext& ctx, const AstImpl&) - -> absl::StatusOr> { - return std::make_unique( - allocator, message_factory, ctx.value_factory().type_provider()); - }; + absl::Nullable> arena, + absl::Nullable> message_factory) { + return + [shared_arena = std::move(arena), + shared_message_factory = std::move(message_factory)]( + PlannerContext& context, + const AstImpl&) -> absl::StatusOr> { + // If one was explicitly provided during planning or none was explicitly + // provided during configuration, request one from the planning context. + // Otherwise use the one provided during configuration. + absl::Nonnull arena = + context.HasExplicitArena() || shared_arena == nullptr + ? context.MutableArena() + : shared_arena.get(); + absl::Nonnull message_factory = + context.HasExplicitMessageFactory() || + shared_message_factory == nullptr + ? context.MutableMessageFactory() + : shared_message_factory.get(); + return std::make_unique( + shared_arena, arena, shared_message_factory, message_factory, + context.type_reflector()); + }; } } // namespace cel::runtime_internal diff --git a/eval/compiler/constant_folding.h b/eval/compiler/constant_folding.h index a69df01a3..532ba2b4b 100644 --- a/eval/compiler/constant_folding.h +++ b/eval/compiler/constant_folding.h @@ -15,9 +15,11 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_COMPILER_CONSTANT_FOLDING_H_ #define THIRD_PARTY_CEL_CPP_EVAL_COMPILER_CONSTANT_FOLDING_H_ +#include + #include "absl/base/nullability.h" -#include "common/allocator.h" #include "eval/compiler/flat_expr_builder_extensions.h" +#include "google/protobuf/arena.h" #include "google/protobuf/message.h" namespace cel::runtime_internal { @@ -31,8 +33,9 @@ namespace cel::runtime_internal { // extension. google::api::expr::runtime::ProgramOptimizerFactory CreateConstantFoldingOptimizer( - Allocator<> allocator, - absl::Nullable message_factory = nullptr); + absl::Nullable> arena = nullptr, + absl::Nullable> message_factory = + nullptr); } // namespace cel::runtime_internal diff --git a/eval/compiler/constant_folding_test.cc b/eval/compiler/constant_folding_test.cc index 7aafa7442..fcecf1297 100644 --- a/eval/compiler/constant_folding_test.cc +++ b/eval/compiler/constant_folding_test.cc @@ -18,17 +18,14 @@ #include #include "cel/expr/syntax.pb.h" +#include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "base/ast.h" #include "base/ast_internal/ast_impl.h" #include "base/ast_internal/expr.h" -#include "common/memory.h" -#include "common/type_factory.h" -#include "common/type_manager.h" #include "common/value.h" -#include "common/value_manager.h" #include "common/values/legacy_value_manager.h" #include "eval/compiler/flat_expr_builder_extensions.h" #include "eval/compiler/resolver.h" @@ -40,9 +37,12 @@ #include "extensions/protobuf/memory_manager.h" #include "internal/status_macros.h" #include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" #include "parser/parser.h" #include "runtime/function_registry.h" #include "runtime/internal/issue_collector.h" +#include "runtime/internal/runtime_env.h" +#include "runtime/internal/runtime_env_testing.h" #include "runtime/runtime_issue.h" #include "runtime/runtime_options.h" #include "runtime/type_registry.h" @@ -58,6 +58,7 @@ using ::cel::ast_internal::AstImpl; using ::cel::ast_internal::Expr; using ::cel::extensions::ProtoMemoryManagerRef; using ::cel::runtime_internal::IssueCollector; +using ::cel::runtime_internal::NewTestingRuntimeEnv; using ::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; using ::google::api::expr::runtime::CreateConstValueStep; @@ -74,16 +75,20 @@ using ::testing::SizeIs; class UpdatedConstantFoldingTest : public testing::Test { public: UpdatedConstantFoldingTest() - : value_factory_(ProtoMemoryManagerRef(&arena_), + : env_(NewTestingRuntimeEnv()), + function_registry_(env_->function_registry), + type_registry_(env_->type_registry), + value_factory_(ProtoMemoryManagerRef(&arena_), type_registry_.GetComposedTypeProvider()), issue_collector_(RuntimeIssue::Severity::kError), resolver_("", function_registry_, type_registry_, value_factory_, type_registry_.resolveable_enums()) {} protected: + absl::Nonnull> env_; google::protobuf::Arena arena_; - cel::FunctionRegistry function_registry_; - cel::TypeRegistry type_registry_; + cel::FunctionRegistry& function_registry_; + cel::TypeRegistry& type_registry_; cel::common_internal::LegacyValueManager value_factory_; cel::RuntimeOptions options_; IssueCollector issue_collector_; @@ -143,12 +148,12 @@ TEST_F(UpdatedConstantFoldingTest, SkipsTernary) { program_builder.AddStep(std::move(step)); program_builder.ExitSubexpression(&call); - PlannerContext context(resolver_, options_, value_factory_, issue_collector_, - program_builder); + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, value_factory_, + issue_collector_, program_builder, arena); - google::protobuf::Arena arena; ProgramOptimizerFactory constant_folder_factory = - CreateConstantFoldingOptimizer(ProtoMemoryManagerRef(&arena_)); + CreateConstantFoldingOptimizer(); // Act // Issue the visitation calls. @@ -205,12 +210,12 @@ TEST_F(UpdatedConstantFoldingTest, SkipsOr) { program_builder.AddStep(std::move(step)); program_builder.ExitSubexpression(&call); - PlannerContext context(resolver_, options_, value_factory_, issue_collector_, - program_builder); + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, value_factory_, + issue_collector_, program_builder, arena); - google::protobuf::Arena arena; ProgramOptimizerFactory constant_folder_factory = - CreateConstantFoldingOptimizer(ProtoMemoryManagerRef(&arena_)); + CreateConstantFoldingOptimizer(); // Act // Issue the visitation calls. @@ -264,12 +269,12 @@ TEST_F(UpdatedConstantFoldingTest, SkipsAnd) { program_builder.AddStep(std::move(step)); program_builder.ExitSubexpression(&call); - PlannerContext context(resolver_, options_, value_factory_, issue_collector_, - program_builder); + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, value_factory_, + issue_collector_, program_builder, arena); - google::protobuf::Arena arena; ProgramOptimizerFactory constant_folder_factory = - CreateConstantFoldingOptimizer(ProtoMemoryManagerRef(&arena_)); + CreateConstantFoldingOptimizer(); // Act // Issue the visitation calls. @@ -320,12 +325,12 @@ TEST_F(UpdatedConstantFoldingTest, CreatesList) { program_builder.ExitSubexpression(&create_list); // Insert the list creation step - PlannerContext context(resolver_, options_, value_factory_, issue_collector_, - program_builder); + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, value_factory_, + issue_collector_, program_builder, arena); - google::protobuf::Arena arena; ProgramOptimizerFactory constant_folder_factory = - CreateConstantFoldingOptimizer(ProtoMemoryManagerRef(&arena_)); + CreateConstantFoldingOptimizer(); // Act // Issue the visitation calls. @@ -377,12 +382,12 @@ TEST_F(UpdatedConstantFoldingTest, CreatesMap) { program_builder.AddStep(std::move(step)); program_builder.ExitSubexpression(&create_map); - PlannerContext context(resolver_, options_, value_factory_, issue_collector_, - program_builder); + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, value_factory_, + issue_collector_, program_builder, arena); - google::protobuf::Arena arena; ProgramOptimizerFactory constant_folder_factory = - CreateConstantFoldingOptimizer(ProtoMemoryManagerRef(&arena_)); + CreateConstantFoldingOptimizer(); // Act // Issue the visitation calls. @@ -435,12 +440,12 @@ TEST_F(UpdatedConstantFoldingTest, CreatesInvalidMap) { program_builder.AddStep(std::move(step)); program_builder.ExitSubexpression(&create_map); - PlannerContext context(resolver_, options_, value_factory_, issue_collector_, - program_builder); + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, value_factory_, + issue_collector_, program_builder, arena); - google::protobuf::Arena arena; ProgramOptimizerFactory constant_folder_factory = - CreateConstantFoldingOptimizer(ProtoMemoryManagerRef(&arena_)); + CreateConstantFoldingOptimizer(); // Act // Issue the visitation calls. @@ -494,12 +499,12 @@ TEST_F(UpdatedConstantFoldingTest, ErrorsOnUnexpectedOrder) { program_builder.AddStep(std::move(step)); program_builder.ExitSubexpression(&call); - PlannerContext context(resolver_, options_, value_factory_, issue_collector_, - program_builder); + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, value_factory_, + issue_collector_, program_builder, arena); - google::protobuf::Arena arena; ProgramOptimizerFactory constant_folder_factory = - CreateConstantFoldingOptimizer(ProtoMemoryManagerRef(&arena_)); + CreateConstantFoldingOptimizer(); // Act / Assert ASSERT_OK_AND_ASSIGN(std::unique_ptr constant_folder, diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index 1bd7c205b..fcb4d5c44 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -82,6 +82,7 @@ #include "runtime/internal/issue_collector.h" #include "runtime/runtime_issue.h" #include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" namespace google::api::expr::runtime { @@ -2128,23 +2129,20 @@ std::vector FlattenExpressionTable( absl::StatusOr FlatExprBuilder::CreateExpressionImpl( std::unique_ptr ast, std::vector* issues) const { - // These objects are expected to remain scoped to one build call -- references - // to them shouldn't be persisted in any part of the result expression. - cel::common_internal::LegacyValueManager value_factory( - cel::MemoryManagerRef::ReferenceCounting(), - type_registry_.GetComposedTypeProvider()); - RuntimeIssue::Severity max_severity = options_.fail_on_warnings ? RuntimeIssue::Severity::kWarning : RuntimeIssue::Severity::kError; IssueCollector issue_collector(max_severity); Resolver resolver(container_, function_registry_, type_registry_, - value_factory, type_registry_.resolveable_enums(), + type_registry_.GetComposedTypeProvider(), + type_registry_.resolveable_enums(), options_.enable_qualified_type_identifiers); + std::shared_ptr arena; ProgramBuilder program_builder; - PlannerContext extension_context(resolver, options_, value_factory, - issue_collector, program_builder); + PlannerContext extension_context(env_, resolver, options_, + type_registry_.GetComposedTypeProvider(), + issue_collector, program_builder, arena); auto& ast_impl = AstImpl::CastFromPublicAst(*ast); @@ -2166,6 +2164,11 @@ absl::StatusOr FlatExprBuilder::CreateExpressionImpl( } } + // These objects are expected to remain scoped to one build call -- references + // to them shouldn't be persisted in any part of the result expression. + cel::common_internal::LegacyValueManager value_factory( + cel::MemoryManagerRef::ReferenceCounting(), + type_registry_.GetComposedTypeProvider()); FlatExprVisitor visitor(resolver, options_, std::move(optimizers), ast_impl.reference_map(), value_factory, issue_collector, program_builder, extension_context, @@ -2187,9 +2190,15 @@ absl::StatusOr FlatExprBuilder::CreateExpressionImpl( std::vector subexpressions = FlattenExpressionTable(program_builder, execution_path); + if (arena != nullptr && arena->SpaceUsed() == 0) { + // Arena was requested but no memory was used. Destroy it. + arena.reset(); + } + return FlatExpression(std::move(execution_path), std::move(subexpressions), visitor.slot_count(), - type_registry_.GetComposedTypeProvider(), options_); + type_registry_.GetComposedTypeProvider(), options_, + std::move(arena)); } } // namespace google::api::expr::runtime diff --git a/eval/compiler/flat_expr_builder.h b/eval/compiler/flat_expr_builder.h index f1081d5c4..eafb58781 100644 --- a/eval/compiler/flat_expr_builder.h +++ b/eval/compiler/flat_expr_builder.h @@ -22,12 +22,14 @@ #include #include +#include "absl/base/nullability.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "base/ast.h" #include "eval/compiler/flat_expr_builder_extensions.h" #include "eval/eval/evaluator_core.h" -#include "eval/public/cel_type_registry.h" #include "runtime/function_registry.h" +#include "runtime/internal/runtime_env.h" #include "runtime/runtime_issue.h" #include "runtime/runtime_options.h" #include "runtime/type_registry.h" @@ -38,29 +40,28 @@ namespace google::api::expr::runtime { // Builds instances of CelExpressionFlatImpl. class FlatExprBuilder { public: - FlatExprBuilder(const cel::FunctionRegistry& function_registry, - const CelTypeRegistry& type_registry, - const cel::RuntimeOptions& options) - : options_(options), + FlatExprBuilder( + absl::Nonnull> + env, + const cel::RuntimeOptions& options) + : env_(std::move(env)), + options_(options), container_(options.container), - function_registry_(function_registry), - type_registry_(type_registry.InternalGetModernRegistry()) {} - - FlatExprBuilder(const cel::FunctionRegistry& function_registry, - const cel::TypeRegistry& type_registry, - const cel::RuntimeOptions& options) - : options_(options), + function_registry_(env_->function_registry), + type_registry_(env_->type_registry) {} + + FlatExprBuilder( + absl::Nonnull> + env, + const cel::FunctionRegistry& function_registry, + const cel::TypeRegistry& type_registry, + const cel::RuntimeOptions& options) + : env_(std::move(env)), + options_(options), container_(options.container), function_registry_(function_registry), type_registry_(type_registry) {} - // Create a flat expr builder with defaulted options. - FlatExprBuilder(const cel::FunctionRegistry& function_registry, - const CelTypeRegistry& type_registry) - : options_(cel::RuntimeOptions()), - function_registry_(function_registry), - type_registry_(type_registry.InternalGetModernRegistry()) {} - void AddAstTransform(std::unique_ptr transform) { ast_transforms_.push_back(std::move(transform)); } @@ -73,12 +74,16 @@ class FlatExprBuilder { container_ = std::move(container); } + absl::string_view container() const { return container_; } + // TODO: Add overload for cref AST. At the moment, all the users // can pass ownership of a freshly converted AST. absl::StatusOr CreateExpressionImpl( std::unique_ptr ast, std::vector* issues) const; + const cel::runtime_internal::RuntimeEnv& env() const { return *env_; } + const cel::RuntimeOptions& options() const { return options_; } // Called by `cel::extensions::EnableOptionalTypes` to indicate that special @@ -86,6 +91,8 @@ class FlatExprBuilder { void enable_optional_types() { enable_optional_types_ = true; } private: + const absl::Nonnull> + env_; cel::RuntimeOptions options_; std::string container_; bool enable_optional_types_ = false; diff --git a/eval/compiler/flat_expr_builder_comprehensions_test.cc b/eval/compiler/flat_expr_builder_comprehensions_test.cc index 4b9ff2b8c..9d46d8dd8 100644 --- a/eval/compiler/flat_expr_builder_comprehensions_test.cc +++ b/eval/compiler/flat_expr_builder_comprehensions_test.cc @@ -34,6 +34,7 @@ #include "eval/testutil/test_message.pb.h" #include "internal/testing.h" #include "parser/parser.h" +#include "runtime/internal/runtime_env_testing.h" #include "runtime/runtime_options.h" #include "google/protobuf/arena.h" #include "google/protobuf/text_format.h" @@ -43,6 +44,7 @@ namespace google::api::expr::runtime { namespace { using ::absl_testing::StatusIs; +using ::cel::runtime_internal::NewTestingRuntimeEnv; using ::cel::expr::CheckedExpr; using ::cel::expr::ParsedExpr; using ::testing::HasSubstr; @@ -66,7 +68,7 @@ class CelExpressionBuilderFlatImplComprehensionsTest TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, NestedComp) { cel::RuntimeOptions options = GetRuntimeOptions(); - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_OK_AND_ASSIGN(auto parsed_expr, parser::Parse("[1, 2].filter(x, [3, 4].all(y, x < y))")); @@ -84,7 +86,7 @@ TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, NestedComp) { TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, MapComp) { cel::RuntimeOptions options = GetRuntimeOptions(); - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_OK_AND_ASSIGN(auto parsed_expr, parser::Parse("[1, 2].map(x, x * 2)")); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); @@ -105,7 +107,7 @@ TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, MapComp) { TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, ExistsOneTrue) { cel::RuntimeOptions options = GetRuntimeOptions(); - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_OK_AND_ASSIGN(auto parsed_expr, parser::Parse("[7].exists_one(a, a == 7)")); @@ -122,7 +124,7 @@ TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, ExistsOneTrue) { TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, ExistsOneFalse) { cel::RuntimeOptions options = GetRuntimeOptions(); - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_OK_AND_ASSIGN(auto parsed_expr, parser::Parse("[7, 7].exists_one(a, a == 7)")); @@ -140,7 +142,7 @@ TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, ExistsOneFalse) { TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, ListCompWithUnknowns) { cel::RuntimeOptions options = GetRuntimeOptions(); options.unknown_processing = UnknownProcessingOptions::kAttributeAndFunction; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_OK_AND_ASSIGN(auto parsed_expr, parser::Parse("items.exists(i, i < 0)")); @@ -203,7 +205,7 @@ TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, })pb", &expr); cel::RuntimeOptions options = GetRuntimeOptions(); - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); EXPECT_THAT(builder.CreateExpression(&expr).status(), StatusIs(absl::StatusCode::kInvalidArgument, @@ -256,7 +258,7 @@ TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, &expr); cel::RuntimeOptions options = GetRuntimeOptions(); - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); builder.flat_expr_builder().AddProgramOptimizer( CreateComprehensionVulnerabilityCheck()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); @@ -300,7 +302,7 @@ TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, &expr); cel::RuntimeOptions options = GetRuntimeOptions(); - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); builder.flat_expr_builder().AddProgramOptimizer( CreateComprehensionVulnerabilityCheck()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); @@ -357,7 +359,7 @@ TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, &expr); cel::RuntimeOptions options = GetRuntimeOptions(); - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); builder.flat_expr_builder().AddProgramOptimizer( CreateComprehensionVulnerabilityCheck()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); @@ -425,7 +427,7 @@ TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, &expr); cel::RuntimeOptions options = GetRuntimeOptions(); - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); builder.flat_expr_builder().AddProgramOptimizer( CreateComprehensionVulnerabilityCheck()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); @@ -472,7 +474,7 @@ TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, &expr); cel::RuntimeOptions options = GetRuntimeOptions(); - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); builder.flat_expr_builder().AddProgramOptimizer( CreateComprehensionVulnerabilityCheck()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); @@ -524,7 +526,7 @@ TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, &expr); cel::RuntimeOptions options = GetRuntimeOptions(); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddProgramOptimizer( CreateComprehensionVulnerabilityCheck()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); @@ -571,7 +573,7 @@ TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, &expr); cel::RuntimeOptions options = GetRuntimeOptions(); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddProgramOptimizer( CreateComprehensionVulnerabilityCheck()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); @@ -614,7 +616,7 @@ TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, &expr)); cel::RuntimeOptions options = GetRuntimeOptions(); - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); EXPECT_THAT( diff --git a/eval/compiler/flat_expr_builder_extensions.h b/eval/compiler/flat_expr_builder_extensions.h index 10f5513ce..f7d46de0e 100644 --- a/eval/compiler/flat_expr_builder_extensions.h +++ b/eval/compiler/flat_expr_builder_extensions.h @@ -27,6 +27,7 @@ #include #include +#include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/functional/any_invocable.h" @@ -39,6 +40,7 @@ #include "base/ast_internal/ast_impl.h" #include "base/ast_internal/expr.h" #include "common/native_type.h" +#include "common/type_reflector.h" #include "common/value.h" #include "common/value_manager.h" #include "eval/compiler/resolver.h" @@ -47,7 +49,10 @@ #include "eval/eval/trace_step.h" #include "internal/casts.h" #include "runtime/internal/issue_collector.h" +#include "runtime/internal/runtime_env.h" #include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/message.h" namespace google::api::expr::runtime { @@ -321,16 +326,35 @@ const Subclass* TryDowncastDirectStep(const DirectExpressionStep* step) { // Class representing FlatExpr internals exposed to extensions. class PlannerContext { public: - explicit PlannerContext( + PlannerContext( + std::shared_ptr environment, const Resolver& resolver, const cel::RuntimeOptions& options, - cel::ValueManager& value_factory, + cel::ValueManager& value_manager, cel::runtime_internal::IssueCollector& issue_collector, - ProgramBuilder& program_builder) - : resolver_(resolver), - value_factory_(value_factory), + ProgramBuilder& program_builder, + std::shared_ptr& arena ABSL_ATTRIBUTE_LIFETIME_BOUND, + std::shared_ptr message_factory = nullptr) + : PlannerContext(std::move(environment), resolver, options, + value_manager.type_provider(), issue_collector, + program_builder, arena, std::move(message_factory)) {} + + PlannerContext( + std::shared_ptr environment, + const Resolver& resolver, const cel::RuntimeOptions& options, + const cel::TypeReflector& type_reflector, + cel::runtime_internal::IssueCollector& issue_collector, + ProgramBuilder& program_builder, + std::shared_ptr& arena ABSL_ATTRIBUTE_LIFETIME_BOUND, + std::shared_ptr message_factory = nullptr) + : environment_(std::move(environment)), + resolver_(resolver), + type_reflector_(type_reflector), options_(options), issue_collector_(issue_collector), - program_builder_(program_builder) {} + program_builder_(program_builder), + arena_(arena), + explicit_arena_(arena_ != nullptr), + message_factory_(std::move(message_factory)) {} ProgramBuilder& program_builder() { return program_builder_; } @@ -374,18 +398,42 @@ class PlannerContext { std::unique_ptr step); const Resolver& resolver() const { return resolver_; } - cel::ValueManager& value_factory() const { return value_factory_; } + const cel::TypeReflector& type_reflector() const { return type_reflector_; } const cel::RuntimeOptions& options() const { return options_; } cel::runtime_internal::IssueCollector& issue_collector() { return issue_collector_; } + // Returns `true` if an arena was explicitly provided during planning. + bool HasExplicitArena() const { return explicit_arena_; } + + absl::Nonnull MutableArena() { + if (!explicit_arena_ && arena_ == nullptr) { + arena_ = std::make_shared(); + } + ABSL_DCHECK(arena_ != nullptr); + return arena_.get(); + } + + // Returns `true` if a message factory was explicitly provided during + // planning. + bool HasExplicitMessageFactory() const { return message_factory_ != nullptr; } + + absl::Nonnull MutableMessageFactory() { + return HasExplicitMessageFactory() ? message_factory_.get() + : environment_->MutableMessageFactory(); + } + private: + const std::shared_ptr environment_; const Resolver& resolver_; - cel::ValueManager& value_factory_; + const cel::TypeReflector& type_reflector_; const cel::RuntimeOptions& options_; cel::runtime_internal::IssueCollector& issue_collector_; ProgramBuilder& program_builder_; + std::shared_ptr& arena_; + const bool explicit_arena_; + const std::shared_ptr message_factory_; }; // Interface for Ast Transforms. diff --git a/eval/compiler/flat_expr_builder_extensions_test.cc b/eval/compiler/flat_expr_builder_extensions_test.cc index 1374cdfbf..c3b22c5ca 100644 --- a/eval/compiler/flat_expr_builder_extensions_test.cc +++ b/eval/compiler/flat_expr_builder_extensions_test.cc @@ -13,8 +13,10 @@ // limitations under the License. #include "eval/compiler/flat_expr_builder_extensions.h" +#include #include +#include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "base/ast_internal/expr.h" @@ -31,9 +33,12 @@ #include "internal/testing.h" #include "runtime/function_registry.h" #include "runtime/internal/issue_collector.h" +#include "runtime/internal/runtime_env.h" +#include "runtime/internal/runtime_env_testing.h" #include "runtime/runtime_issue.h" #include "runtime/runtime_options.h" #include "runtime/type_registry.h" +#include "google/protobuf/arena.h" namespace google::api::expr::runtime { namespace { @@ -42,6 +47,8 @@ using ::absl_testing::StatusIs; using ::cel::RuntimeIssue; using ::cel::ast_internal::Expr; using ::cel::runtime_internal::IssueCollector; +using ::cel::runtime_internal::NewTestingRuntimeEnv; +using ::cel::runtime_internal::RuntimeEnv; using ::testing::ElementsAre; using ::testing::IsEmpty; using ::testing::Optional; @@ -51,8 +58,9 @@ using Subexpression = ProgramBuilder::Subexpression; class PlannerContextTest : public testing::Test { public: PlannerContextTest() - : type_registry_(), - function_registry_(), + : env_(NewTestingRuntimeEnv()), + type_registry_(env_->type_registry), + function_registry_(env_->function_registry), value_factory_(cel::MemoryManagerRef::ReferenceCounting(), type_registry_.GetComposedTypeProvider()), resolver_("", function_registry_, type_registry_, value_factory_, @@ -60,8 +68,9 @@ class PlannerContextTest : public testing::Test { issue_collector_(RuntimeIssue::Severity::kError) {} protected: - cel::TypeRegistry type_registry_; - cel::FunctionRegistry function_registry_; + absl::Nonnull> env_; + cel::TypeRegistry& type_registry_; + cel::FunctionRegistry& function_registry_; cel::RuntimeOptions options_; cel::common_internal::LegacyValueManager value_factory_; Resolver resolver_; @@ -117,8 +126,9 @@ TEST_F(PlannerContextTest, GetPlan) { ASSERT_OK_AND_ASSIGN( auto step_ptrs, InitSimpleTree(a, b, c, value_factory_, program_builder)); - PlannerContext context(resolver_, options_, value_factory_, issue_collector_, - program_builder); + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, value_factory_, + issue_collector_, program_builder, arena); EXPECT_THAT(context.GetSubplan(b), ElementsAre(UniquePtrHolds(step_ptrs.b))); @@ -142,8 +152,9 @@ TEST_F(PlannerContextTest, ReplacePlan) { ASSERT_OK_AND_ASSIGN( auto step_ptrs, InitSimpleTree(a, b, c, value_factory_, program_builder)); - PlannerContext context(resolver_, options_, value_factory_, issue_collector_, - program_builder); + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, value_factory_, + issue_collector_, program_builder, arena); EXPECT_THAT(context.GetSubplan(a), ElementsAre(UniquePtrHolds(step_ptrs.b), UniquePtrHolds(step_ptrs.c), @@ -172,8 +183,9 @@ TEST_F(PlannerContextTest, ExtractPlan) { ASSERT_OK_AND_ASSIGN(auto plan_steps, InitSimpleTree(a, b, c, value_factory_, program_builder)); - PlannerContext context(resolver_, options_, value_factory_, issue_collector_, - program_builder); + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, value_factory_, + issue_collector_, program_builder, arena); EXPECT_TRUE(context.IsSubplanInspectable(a)); EXPECT_TRUE(context.IsSubplanInspectable(b)); @@ -191,8 +203,9 @@ TEST_F(PlannerContextTest, ExtractFailsOnReplacedNode) { ASSERT_OK(InitSimpleTree(a, b, c, value_factory_, program_builder).status()); - PlannerContext context(resolver_, options_, value_factory_, issue_collector_, - program_builder); + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, value_factory_, + issue_collector_, program_builder, arena); ASSERT_OK(context.ReplaceSubplan(a, {})); @@ -208,8 +221,9 @@ TEST_F(PlannerContextTest, ReplacePlanUpdatesParent) { ASSERT_OK_AND_ASSIGN(auto plan_steps, InitSimpleTree(a, b, c, value_factory_, program_builder)); - PlannerContext context(resolver_, options_, value_factory_, issue_collector_, - program_builder); + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, value_factory_, + issue_collector_, program_builder, arena); EXPECT_TRUE(context.IsSubplanInspectable(a)); @@ -229,8 +243,9 @@ TEST_F(PlannerContextTest, ReplacePlanUpdatesSibling) { ASSERT_OK_AND_ASSIGN(auto plan_steps, InitSimpleTree(a, b, c, value_factory_, program_builder)); - PlannerContext context(resolver_, options_, value_factory_, issue_collector_, - program_builder); + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, value_factory_, + issue_collector_, program_builder, arena); ExecutionPath new_b; @@ -263,8 +278,9 @@ TEST_F(PlannerContextTest, ReplacePlanFailsOnUpdatedNode) { ASSERT_OK_AND_ASSIGN(auto plan_steps, InitSimpleTree(a, b, c, value_factory_, program_builder)); - PlannerContext context(resolver_, options_, value_factory_, issue_collector_, - program_builder); + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, value_factory_, + issue_collector_, program_builder, arena); EXPECT_THAT(context.GetSubplan(a), ElementsAre(UniquePtrHolds(plan_steps.b), UniquePtrHolds(plan_steps.c), @@ -289,8 +305,9 @@ TEST_F(PlannerContextTest, AddSubplanStep) { const ExpressionStep* b2_step_ptr = b2_step.get(); - PlannerContext context(resolver_, options_, value_factory_, issue_collector_, - program_builder); + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, value_factory_, + issue_collector_, program_builder, arena); ASSERT_OK(context.AddSubplanStep(b, std::move(b2_step))); @@ -315,8 +332,9 @@ TEST_F(PlannerContextTest, AddSubplanStepFailsOnUnknownNode) { ASSERT_OK_AND_ASSIGN(auto b2_step, CreateConstValueStep(value_factory_.GetNullValue(), -1)); - PlannerContext context(resolver_, options_, value_factory_, issue_collector_, - program_builder); + std::shared_ptr arena; + PlannerContext context(env_, resolver_, options_, value_factory_, + issue_collector_, program_builder, arena); EXPECT_THAT(context.GetSubplan(d), IsEmpty()); diff --git a/eval/compiler/flat_expr_builder_short_circuiting_conformance_test.cc b/eval/compiler/flat_expr_builder_short_circuiting_conformance_test.cc index b7bed3655..afe7c5f9f 100644 --- a/eval/compiler/flat_expr_builder_short_circuiting_conformance_test.cc +++ b/eval/compiler/flat_expr_builder_short_circuiting_conformance_test.cc @@ -2,27 +2,28 @@ // produce expressions with the same outputs. #include -#include "google/protobuf/text_format.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/strings/substitute.h" +#include "base/builtins.h" #include "eval/compiler/cel_expression_builder_flat_impl.h" -#include "eval/compiler/flat_expr_builder.h" #include "eval/public/activation.h" #include "eval/public/cel_attribute.h" -#include "eval/public/cel_builtins.h" #include "eval/public/cel_expression.h" -#include "eval/public/cel_options.h" +#include "eval/public/cel_value.h" #include "eval/public/unknown_attribute_set.h" #include "eval/public/unknown_set.h" -#include "internal/status_macros.h" #include "internal/testing.h" +#include "runtime/internal/runtime_env_testing.h" #include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/text_format.h" namespace google::api::expr::runtime { namespace { +using ::cel::runtime_internal::NewTestingRuntimeEnv; using ::cel::expr::Expr; using ::testing::Eq; using ::testing::SizeIs; @@ -104,7 +105,8 @@ class ShortCircuitingTest : public testing::TestWithParam { options.unknown_processing = cel::UnknownProcessingOptions::kAttributeAndFunction; } - auto result = std::make_unique(options); + auto result = std::make_unique( + NewTestingRuntimeEnv(), options); return result; } }; @@ -114,7 +116,7 @@ TEST_P(ShortCircuitingTest, BasicAnd) { Activation activation; google::protobuf::Arena arena; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( - absl::Substitute(kTwoLogicalOp, builtin::kAnd), &expr)); + absl::Substitute(kTwoLogicalOp, ::cel::builtin::kAnd), &expr)); auto builder = GetBuilder(); activation.InsertValue("var1", CelValue::CreateBool(true)); @@ -142,7 +144,7 @@ TEST_P(ShortCircuitingTest, BasicOr) { Activation activation; google::protobuf::Arena arena; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( - absl::Substitute(kTwoLogicalOp, builtin::kOr), &expr)); + absl::Substitute(kTwoLogicalOp, ::cel::builtin::kOr), &expr)); auto builder = GetBuilder(); activation.InsertValue("var1", CelValue::CreateBool(false)); @@ -170,7 +172,7 @@ TEST_P(ShortCircuitingTest, ErrorAnd) { Activation activation; google::protobuf::Arena arena; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( - absl::Substitute(kTwoLogicalOp, builtin::kAnd), &expr)); + absl::Substitute(kTwoLogicalOp, ::cel::builtin::kAnd), &expr)); auto builder = GetBuilder(); absl::Status error = absl::InternalError("error"); @@ -200,7 +202,7 @@ TEST_P(ShortCircuitingTest, ErrorOr) { Activation activation; google::protobuf::Arena arena; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( - absl::Substitute(kTwoLogicalOp, builtin::kOr), &expr)); + absl::Substitute(kTwoLogicalOp, ::cel::builtin::kOr), &expr)); auto builder = GetBuilder(); absl::Status error = absl::InternalError("error"); @@ -230,7 +232,7 @@ TEST_P(ShortCircuitingTest, UnknownAnd) { Activation activation; google::protobuf::Arena arena; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( - absl::Substitute(kTwoLogicalOp, builtin::kAnd), &expr)); + absl::Substitute(kTwoLogicalOp, ::cel::builtin::kAnd), &expr)); auto builder = GetBuilder(/* enable_unknowns=*/true); absl::Status error = absl::InternalError("error"); @@ -262,7 +264,7 @@ TEST_P(ShortCircuitingTest, UnknownOr) { Activation activation; google::protobuf::Arena arena; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( - absl::Substitute(kTwoLogicalOp, builtin::kOr), &expr)); + absl::Substitute(kTwoLogicalOp, ::cel::builtin::kOr), &expr)); auto builder = GetBuilder(/* enable_unknowns=*/true); absl::Status error = absl::InternalError("error"); diff --git a/eval/compiler/flat_expr_builder_test.cc b/eval/compiler/flat_expr_builder_test.cc index 488f81a8d..ad0664777 100644 --- a/eval/compiler/flat_expr_builder_test.cc +++ b/eval/compiler/flat_expr_builder_test.cc @@ -33,6 +33,7 @@ #include "absl/types/span.h" #include "base/function.h" #include "base/function_descriptor.h" +#include "common/value.h" #include "eval/compiler/cel_expression_builder_flat_impl.h" #include "eval/compiler/constant_folding.h" #include "eval/compiler/qualified_reference_resolver.h" @@ -55,12 +56,12 @@ #include "eval/public/unknown_attribute_set.h" #include "eval/public/unknown_set.h" #include "eval/testutil/test_message.pb.h" -#include "extensions/protobuf/memory_manager.h" #include "internal/proto_file_util.h" #include "internal/proto_matchers.h" #include "internal/status_macros.h" #include "internal/testing.h" #include "parser/parser.h" +#include "runtime/internal/runtime_env_testing.h" #include "runtime/runtime_options.h" #include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/descriptor.h" @@ -75,9 +76,9 @@ namespace { using ::absl_testing::StatusIs; using ::cel::Value; using ::cel::expr::conformance::proto3::TestAllTypes; -using ::cel::extensions::ProtoMemoryManagerRef; using ::cel::internal::test::EqualsProto; using ::cel::internal::test::ReadBinaryProtoFromFile; +using ::cel::runtime_internal::NewTestingRuntimeEnv; using ::cel::expr::CheckedExpr; using ::cel::expr::Expr; using ::cel::expr::ParsedExpr; @@ -150,7 +151,7 @@ TEST(FlatExprBuilderTest, SimpleEndToEnd) { auto arg2 = call_expr->add_args(); arg2->mutable_ident_expr()->set_name("value"); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK( builder.GetRegistry()->Register(std::make_unique())); @@ -172,7 +173,7 @@ TEST(FlatExprBuilderTest, SimpleEndToEnd) { TEST(FlatExprBuilderTest, ExprUnset) { Expr expr; SourceInfo source_info; - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Invalid empty expression"))); @@ -181,7 +182,7 @@ TEST(FlatExprBuilderTest, ExprUnset) { TEST(FlatExprBuilderTest, ConstValueUnset) { Expr expr; SourceInfo source_info; - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); // Create an empty constant expression to ensure that it triggers an error. expr.mutable_const_expr(); @@ -193,7 +194,7 @@ TEST(FlatExprBuilderTest, ConstValueUnset) { TEST(FlatExprBuilderTest, MapKeyValueUnset) { Expr expr; SourceInfo source_info; - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); // Don't set either the key or the value for the map creation step. auto* entry = expr.mutable_struct_expr()->add_entries(); @@ -211,7 +212,7 @@ TEST(FlatExprBuilderTest, MapKeyValueUnset) { TEST(FlatExprBuilderTest, MessageFieldValueUnset) { Expr expr; SourceInfo source_info; - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.GetTypeRegistry()->RegisterTypeProvider( std::make_unique( google::protobuf::DescriptorPool::generated_pool(), @@ -235,7 +236,7 @@ TEST(FlatExprBuilderTest, MessageFieldValueUnset) { TEST(FlatExprBuilderTest, BinaryCallTooManyArguments) { Expr expr; SourceInfo source_info; - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); auto* call = expr.mutable_call_expr(); call->set_function(builtin::kAnd); @@ -261,7 +262,7 @@ TEST(FlatExprBuilderTest, TernaryCallTooManyArguments) { { cel::RuntimeOptions options; options.short_circuiting = true; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), StatusIs(absl::StatusCode::kInvalidArgument, @@ -272,7 +273,7 @@ TEST(FlatExprBuilderTest, TernaryCallTooManyArguments) { { cel::RuntimeOptions options; options.short_circuiting = false; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), StatusIs(absl::StatusCode::kInvalidArgument, @@ -294,7 +295,7 @@ TEST(FlatExprBuilderTest, DelayedFunctionResolutionErrors) { cel::RuntimeOptions options; options.fail_on_warnings = false; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); std::vector warnings; // Concat function not registered. @@ -338,7 +339,7 @@ TEST(FlatExprBuilderTest, Shortcircuiting) { { cel::RuntimeOptions options; options.short_circuiting = true; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); auto builtin = RegisterBuiltinFunctions(builder.GetRegistry()); int count1 = 0; @@ -361,7 +362,7 @@ TEST(FlatExprBuilderTest, Shortcircuiting) { { cel::RuntimeOptions options; options.short_circuiting = false; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); auto builtin = RegisterBuiltinFunctions(builder.GetRegistry()); int count1 = 0; @@ -409,7 +410,7 @@ TEST(FlatExprBuilderTest, ShortcircuitingComprehension) { { cel::RuntimeOptions options; options.short_circuiting = true; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); auto builtin = RegisterBuiltinFunctions(builder.GetRegistry()); int count = 0; @@ -427,7 +428,7 @@ TEST(FlatExprBuilderTest, ShortcircuitingComprehension) { { cel::RuntimeOptions options; options.short_circuiting = false; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); auto builtin = RegisterBuiltinFunctions(builder.GetRegistry()); int count = 0; @@ -446,7 +447,7 @@ TEST(FlatExprBuilderTest, IdentExprUnsetName) { // An empty ident without the name set should error. google::protobuf::TextFormat::ParseFromString(R"(ident_expr {})", &expr); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), StatusIs(absl::StatusCode::kInvalidArgument, @@ -462,7 +463,7 @@ TEST(FlatExprBuilderTest, SelectExprUnsetField) { })", &expr); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), StatusIs(absl::StatusCode::kInvalidArgument, @@ -474,7 +475,7 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetAccuVar) { SourceInfo source_info; // An empty ident without the name set should error. google::protobuf::TextFormat::ParseFromString(R"(comprehension_expr{})", &expr); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), StatusIs(absl::StatusCode::kInvalidArgument, @@ -489,7 +490,7 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetIterVar) { comprehension_expr{accu_var: "a"} )", &expr); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), StatusIs(absl::StatusCode::kInvalidArgument, @@ -506,7 +507,7 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetAccuInit) { iter_var: "b"} )", &expr); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), StatusIs(absl::StatusCode::kInvalidArgument, @@ -526,7 +527,7 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetLoopCondition) { }} )", &expr); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), StatusIs(absl::StatusCode::kInvalidArgument, @@ -549,7 +550,7 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetLoopStep) { }} )", &expr); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), StatusIs(absl::StatusCode::kInvalidArgument, @@ -575,7 +576,7 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetResult) { }} )", &expr); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), StatusIs(absl::StatusCode::kInvalidArgument, @@ -625,7 +626,7 @@ TEST(FlatExprBuilderTest, MapComprehension) { })", &expr); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr, &source_info)); @@ -657,7 +658,7 @@ TEST(FlatExprBuilderTest, InvalidContainer) { })", &expr); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); builder.set_container(".bad"); @@ -673,7 +674,7 @@ TEST(FlatExprBuilderTest, InvalidContainer) { TEST(FlatExprBuilderTest, ParsedNamespacedFunctionSupport) { ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("ext.XOr(a, b)")); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddAstTransform( NewReferenceResolverExtension(ReferenceResolverOption::kAlways)); using FunctionAdapterT = FunctionAdapter; @@ -703,7 +704,7 @@ TEST(FlatExprBuilderTest, ParsedNamespacedFunctionSupport) { TEST(FlatExprBuilderTest, ParsedNamespacedFunctionSupportWithContainer) { ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("XOr(a, b)")); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddAstTransform( NewReferenceResolverExtension(ReferenceResolverOption::kAlways)); builder.set_container("ext"); @@ -733,7 +734,7 @@ TEST(FlatExprBuilderTest, ParsedNamespacedFunctionSupportWithContainer) { TEST(FlatExprBuilderTest, ParsedNamespacedFunctionResolutionOrder) { ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("c.d.Get()")); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddAstTransform( NewReferenceResolverExtension(ReferenceResolverOption::kAlways)); builder.set_container("a.b"); @@ -760,7 +761,7 @@ TEST(FlatExprBuilderTest, ParsedNamespacedFunctionResolutionOrder) { TEST(FlatExprBuilderTest, ParsedNamespacedFunctionResolutionOrderParentContainer) { ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("c.d.Get()")); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddAstTransform( NewReferenceResolverExtension(ReferenceResolverOption::kAlways)); builder.set_container("a.b"); @@ -787,7 +788,7 @@ TEST(FlatExprBuilderTest, TEST(FlatExprBuilderTest, ParsedNamespacedFunctionResolutionOrderExplicitGlobal) { ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(".c.d.Get()")); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddAstTransform( NewReferenceResolverExtension(ReferenceResolverOption::kAlways)); builder.set_container("a.b"); @@ -813,7 +814,7 @@ TEST(FlatExprBuilderTest, TEST(FlatExprBuilderTest, ParsedNamespacedFunctionResolutionOrderReceiverCall) { ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("e.Get()")); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddAstTransform( NewReferenceResolverExtension(ReferenceResolverOption::kAlways)); builder.set_container("a.b"); @@ -842,7 +843,7 @@ TEST(FlatExprBuilderTest, ParsedNamespacedFunctionSupportDisabled) { ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("ext.XOr(a, b)")); cel::RuntimeOptions options; options.fail_on_warnings = false; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); std::vector build_warnings; builder.set_container("ext"); using FunctionAdapterT = FunctionAdapter; @@ -888,7 +889,7 @@ TEST(FlatExprBuilderTest, BasicCheckedExprSupport) { })", &expr); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr)); @@ -948,7 +949,7 @@ TEST(FlatExprBuilderTest, CheckedExprWithReferenceMap) { })", &expr); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddAstTransform( NewReferenceResolverExtension(ReferenceResolverOption::kCheckedOnly)); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); @@ -1017,7 +1018,7 @@ TEST(FlatExprBuilderTest, CheckedExprWithReferenceMapFunction) { })", &expr); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddAstTransform( NewReferenceResolverExtension(ReferenceResolverOption::kCheckedOnly)); builder.set_container("com.foo"); @@ -1085,7 +1086,7 @@ TEST(FlatExprBuilderTest, CheckedExprActivationMissesReferences) { })", &expr); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddAstTransform( NewReferenceResolverExtension(ReferenceResolverOption::kCheckedOnly)); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); @@ -1150,13 +1151,12 @@ TEST(FlatExprBuilderTest, CheckedExprWithReferenceMapAndConstantFolding) { })", &expr); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddAstTransform( NewReferenceResolverExtension(ReferenceResolverOption::kCheckedOnly)); google::protobuf::Arena arena; - auto memory_manager = ProtoMemoryManagerRef(&arena); builder.flat_expr_builder().AddProgramOptimizer( - cel::runtime_internal::CreateConstantFoldingOptimizer(memory_manager)); + cel::runtime_internal::CreateConstantFoldingOptimizer()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr)); @@ -1239,7 +1239,7 @@ TEST(FlatExprBuilderTest, ComprehensionWorksForError) { })", &expr); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr, &source_info)); @@ -1310,7 +1310,7 @@ TEST(FlatExprBuilderTest, ComprehensionWorksForNonContainer) { })", &expr); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr, &source_info)); @@ -1362,7 +1362,7 @@ TEST(FlatExprBuilderTest, ComprehensionBudget) { cel::RuntimeOptions options; options.comprehension_max_iterations = 1; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr, &source_info)); @@ -1392,7 +1392,7 @@ TEST(FlatExprBuilderTest, SimpleEnumTest) { cur_expr->mutable_ident_expr()->set_name(enum_name_parts[0]); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.GetTypeRegistry()->Register(TestMessage::TestEnum_descriptor()); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr, &source_info)); @@ -1414,7 +1414,7 @@ TEST(FlatExprBuilderTest, SimpleEnumIdentTest) { Expr* cur_expr = &expr; cur_expr->mutable_ident_expr()->set_name(enum_name); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.GetTypeRegistry()->Register(TestMessage::TestEnum_descriptor()); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr, &source_info)); @@ -1431,7 +1431,7 @@ TEST(FlatExprBuilderTest, ContainerStringFormat) { SourceInfo source_info; expr.mutable_ident_expr()->set_name("ident"); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.set_container(""); ASSERT_OK(builder.CreateExpression(&expr, &source_info)); @@ -1469,7 +1469,7 @@ void EvalExpressionWithEnum(absl::string_view enum_name, cur_expr->mutable_ident_expr()->set_name(enum_name_parts[0]); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.GetTypeRegistry()->Register(TestMessage::TestEnum_descriptor()); builder.GetTypeRegistry()->Register(TestEnum_descriptor()); builder.set_container(std::string(container)); @@ -1552,7 +1552,7 @@ TEST(FlatExprBuilderTest, MapFieldPresence) { })", &expr); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr, &source_info)); @@ -1596,7 +1596,7 @@ TEST(FlatExprBuilderTest, RepeatedFieldPresence) { })", &expr); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr, &source_info)); @@ -1639,7 +1639,7 @@ absl::Status RunTernaryExpression(CelValue selector, CelValue value1, auto arg2 = call_expr->add_args(); arg2->mutable_ident_expr()->set_name("value2"); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); CEL_ASSIGN_OR_RETURN(auto cel_expr, builder.CreateExpression(&expr, &source_info)); @@ -1668,7 +1668,7 @@ TEST(FlatExprBuilderTest, Ternary) { auto arg2 = call_expr->add_args(); arg2->mutable_ident_expr()->set_name("value1"); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr, &source_info)); @@ -1768,7 +1768,7 @@ TEST(FlatExprBuilderTest, EmptyCallList) { SourceInfo source_info; auto call_expr = expr.mutable_call_expr(); call_expr->set_function(op); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); auto build = builder.CreateExpression(&expr, &source_info); ASSERT_FALSE(build.ok()); @@ -1782,7 +1782,7 @@ TEST(FlatExprBuilderTest, HeterogeneousListsAllowed) { parser::Parse("[17, 'seventeen']")); cel::RuntimeOptions options; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_OK_AND_ASSIGN(auto expression, builder.CreateExpression(&parsed_expr.expr(), @@ -1812,7 +1812,7 @@ TEST(FlatExprBuilderTest, NullUnboxingEnabled) { parser::Parse("message.int32_wrapper_value")); cel::RuntimeOptions options; options.enable_empty_wrapper_null_unboxing = true; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_OK_AND_ASSIGN(auto expression, builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info())); @@ -1833,7 +1833,7 @@ TEST(FlatExprBuilderTest, TypeResolve) { parser::Parse("type(message) == runtime.TestMessage")); cel::RuntimeOptions options; options.enable_qualified_type_identifiers = true; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); builder.GetTypeRegistry()->RegisterTypeProvider( std::make_unique( google::protobuf::DescriptorPool::generated_pool(), @@ -1861,7 +1861,7 @@ TEST(FlatExprBuilderTest, AnyPackingList) { parser::Parse("TestAllTypes{single_any: [1, 2, 3]}")); cel::RuntimeOptions options; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); builder.GetTypeRegistry()->RegisterTypeProvider( std::make_unique( google::protobuf::DescriptorPool::generated_pool(), @@ -1896,7 +1896,7 @@ TEST(FlatExprBuilderTest, AnyPackingNestedNumbers) { parser::Parse("TestAllTypes{single_any: [1, 2.3]}")); cel::RuntimeOptions options; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); builder.GetTypeRegistry()->RegisterTypeProvider( std::make_unique( google::protobuf::DescriptorPool::generated_pool(), @@ -1929,7 +1929,7 @@ TEST(FlatExprBuilderTest, AnyPackingInt) { parser::Parse("TestAllTypes{single_any: 1}")); cel::RuntimeOptions options; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); builder.GetTypeRegistry()->RegisterTypeProvider( std::make_unique( google::protobuf::DescriptorPool::generated_pool(), @@ -1961,7 +1961,7 @@ TEST(FlatExprBuilderTest, AnyPackingMap) { parser::Parse("TestAllTypes{single_any: {'key': 'value'}}")); cel::RuntimeOptions options; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); builder.GetTypeRegistry()->RegisterTypeProvider( std::make_unique( google::protobuf::DescriptorPool::generated_pool(), @@ -1996,7 +1996,7 @@ TEST(FlatExprBuilderTest, NullUnboxingDisabled) { parser::Parse("message.int32_wrapper_value")); cel::RuntimeOptions options; options.enable_empty_wrapper_null_unboxing = false; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_OK_AND_ASSIGN(auto expression, builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info())); @@ -2016,7 +2016,7 @@ TEST(FlatExprBuilderTest, HeterogeneousEqualityEnabled) { parser::Parse("{1: 2, 2u: 3}[1.0]")); cel::RuntimeOptions options; options.enable_heterogeneous_equality = true; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_OK_AND_ASSIGN(auto expression, builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info())); @@ -2034,7 +2034,7 @@ TEST(FlatExprBuilderTest, HeterogeneousEqualityDisabled) { parser::Parse("{1: 2, 2u: 3}[1.0]")); cel::RuntimeOptions options; options.enable_heterogeneous_equality = false; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_OK_AND_ASSIGN(auto expression, builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info())); @@ -2056,7 +2056,7 @@ TEST(FlatExprBuilderTest, CustomDescriptorPoolForCreateStruct) { // This time, the message is unknown. We only have the proto as data, we did // not link the generated message, so it's not included in the generated pool. - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.GetTypeRegistry()->RegisterTypeProvider( std::make_unique( google::protobuf::DescriptorPool::generated_pool(), @@ -2079,7 +2079,7 @@ TEST(FlatExprBuilderTest, CustomDescriptorPoolForCreateStruct) { // This time, the message is *known*. We are using a custom descriptor pool // that has been primed with the relevant message. - CelExpressionBuilderFlatImpl builder2; + CelExpressionBuilderFlatImpl builder2(NewTestingRuntimeEnv()); builder2.GetTypeRegistry()->RegisterTypeProvider( std::make_unique(&desc_pool, &message_factory)); @@ -2121,7 +2121,7 @@ TEST(FlatExprBuilderTest, CustomDescriptorPoolForSelect) { // The since this is access only, the evaluator will work with message duck // typing. - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK_AND_ASSIGN(auto expression, builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info())); @@ -2170,7 +2170,7 @@ TEST_P(CustomDescriptorPoolTest, TestType) { ASSERT_OK(AddStandardMessageTypesToDescriptorPool(descriptor_pool)); google::protobuf::DynamicMessageFactory message_factory(&descriptor_pool); ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("m")); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.GetTypeRegistry()->RegisterTypeProvider( std::make_unique(&descriptor_pool, &message_factory)); @@ -2408,7 +2408,7 @@ TEST(FlatExprBuilderTest, BlockBadIndex) { )pb", &parsed_expr)); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); EXPECT_THAT( builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("bad @index"))); @@ -2430,7 +2430,7 @@ TEST(FlatExprBuilderTest, OutOfRangeBlockIndex) { )pb", &parsed_expr)); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); EXPECT_THAT( builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), StatusIs(absl::StatusCode::kInvalidArgument, @@ -2451,7 +2451,7 @@ TEST(FlatExprBuilderTest, EarlyBlockIndex) { )pb", &parsed_expr)); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); EXPECT_THAT( builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), StatusIs(absl::StatusCode::kInvalidArgument, @@ -2466,7 +2466,7 @@ TEST(FlatExprBuilderTest, OutOfScopeCSE) { )pb", &parsed_expr)); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); EXPECT_THAT( builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), StatusIs(absl::StatusCode::kInvalidArgument, @@ -2482,7 +2482,7 @@ TEST(FlatExprBuilderTest, BlockMissingBindings) { )pb", &parsed_expr)); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); EXPECT_THAT( builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), StatusIs(absl::StatusCode::kInvalidArgument, @@ -2503,7 +2503,7 @@ TEST(FlatExprBuilderTest, BlockMissingExpression) { )pb", &parsed_expr)); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); EXPECT_THAT( builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), StatusIs(absl::StatusCode::kInvalidArgument, @@ -2524,7 +2524,7 @@ TEST(FlatExprBuilderTest, BlockNotListOfBoundExpressions) { )pb", &parsed_expr)); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); EXPECT_THAT( builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), StatusIs(absl::StatusCode::kInvalidArgument, @@ -2546,7 +2546,7 @@ TEST(FlatExprBuilderTest, BlockEmptyListOfBoundExpressions) { )pb", &parsed_expr)); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); EXPECT_THAT( builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), StatusIs( @@ -2574,7 +2574,7 @@ TEST(FlatExprBuilderTest, BlockOptionalListOfBoundExpressions) { )pb", &parsed_expr)); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); EXPECT_THAT( builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), StatusIs(absl::StatusCode::kInvalidArgument, @@ -2608,7 +2608,7 @@ TEST(FlatExprBuilderTest, BlockNested) { )pb", &parsed_expr)); - CelExpressionBuilderFlatImpl builder; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); EXPECT_THAT( builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), StatusIs(absl::StatusCode::kInvalidArgument, diff --git a/eval/compiler/instrumentation_test.cc b/eval/compiler/instrumentation_test.cc index beb94fe2c..78b2ba59b 100644 --- a/eval/compiler/instrumentation_test.cc +++ b/eval/compiler/instrumentation_test.cc @@ -15,14 +15,15 @@ #include "eval/compiler/instrumentation.h" #include +#include #include #include #include "cel/expr/syntax.pb.h" +#include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "base/ast_internal/ast_impl.h" -#include "common/type.h" #include "common/value.h" #include "eval/compiler/constant_folding.h" #include "eval/compiler/flat_expr_builder.h" @@ -34,6 +35,8 @@ #include "parser/parser.h" #include "runtime/activation.h" #include "runtime/function_registry.h" +#include "runtime/internal/runtime_env.h" +#include "runtime/internal/runtime_env_testing.h" #include "runtime/managed_value_factory.h" #include "runtime/runtime_options.h" #include "runtime/standard_functions.h" @@ -45,6 +48,8 @@ namespace { using ::cel::IntValue; using ::cel::Value; +using ::cel::runtime_internal::NewTestingRuntimeEnv; +using ::cel::runtime_internal::RuntimeEnv; using ::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; using ::testing::ElementsAre; @@ -54,7 +59,10 @@ using ::testing::UnorderedElementsAre; class InstrumentationTest : public ::testing::Test { public: InstrumentationTest() - : managed_value_factory_( + : env_(NewTestingRuntimeEnv()), + function_registry_(env_->function_registry), + type_registry_(env_->type_registry), + managed_value_factory_( type_registry_.GetComposedTypeProvider(), cel::extensions::ProtoMemoryManagerRef(&arena_)) {} void SetUp() override { @@ -62,9 +70,10 @@ class InstrumentationTest : public ::testing::Test { } protected: + absl::Nonnull> env_; cel::RuntimeOptions options_; - cel::FunctionRegistry function_registry_; - cel::TypeRegistry type_registry_; + cel::FunctionRegistry& function_registry_; + cel::TypeRegistry& type_registry_; google::protobuf::Arena arena_; cel::ManagedValueFactory managed_value_factory_; }; @@ -76,7 +85,7 @@ MATCHER_P(IsIntValue, expected, "") { } TEST_F(InstrumentationTest, Basic) { - FlatExprBuilder builder(function_registry_, type_registry_, options_); + FlatExprBuilder builder(env_, options_); std::vector expr_ids; Instrumentation expr_id_recorder = @@ -114,7 +123,7 @@ TEST_F(InstrumentationTest, Basic) { } TEST_F(InstrumentationTest, BasicWithConstFolding) { - FlatExprBuilder builder(function_registry_, type_registry_, options_); + FlatExprBuilder builder(env_, options_); absl::flat_hash_map expr_id_to_value; Instrumentation expr_id_recorder = [&expr_id_to_value]( @@ -124,8 +133,7 @@ TEST_F(InstrumentationTest, BasicWithConstFolding) { return absl::OkStatus(); }; builder.AddProgramOptimizer( - cel::runtime_internal::CreateConstantFoldingOptimizer( - managed_value_factory_.get().GetMemoryManager())); + cel::runtime_internal::CreateConstantFoldingOptimizer()); builder.AddProgramOptimizer(CreateInstrumentationExtension( [=](const cel::ast_internal::AstImpl&) -> Instrumentation { return expr_id_recorder; @@ -161,7 +169,7 @@ TEST_F(InstrumentationTest, BasicWithConstFolding) { } TEST_F(InstrumentationTest, AndShortCircuit) { - FlatExprBuilder builder(function_registry_, type_registry_, options_); + FlatExprBuilder builder(env_, options_); std::vector expr_ids; Instrumentation expr_id_recorder = @@ -206,7 +214,7 @@ TEST_F(InstrumentationTest, AndShortCircuit) { } TEST_F(InstrumentationTest, OrShortCircuit) { - FlatExprBuilder builder(function_registry_, type_registry_, options_); + FlatExprBuilder builder(env_, options_); std::vector expr_ids; Instrumentation expr_id_recorder = @@ -251,7 +259,7 @@ TEST_F(InstrumentationTest, OrShortCircuit) { } TEST_F(InstrumentationTest, Ternary) { - FlatExprBuilder builder(function_registry_, type_registry_, options_); + FlatExprBuilder builder(env_, options_); std::vector expr_ids; Instrumentation expr_id_recorder = @@ -304,7 +312,7 @@ TEST_F(InstrumentationTest, Ternary) { } TEST_F(InstrumentationTest, OptimizedStepsNotEvaluated) { - FlatExprBuilder builder(function_registry_, type_registry_, options_); + FlatExprBuilder builder(env_, options_); builder.AddProgramOptimizer(CreateRegexPrecompilationExtension(0)); @@ -340,7 +348,7 @@ TEST_F(InstrumentationTest, OptimizedStepsNotEvaluated) { } TEST_F(InstrumentationTest, NoopSkipped) { - FlatExprBuilder builder(function_registry_, type_registry_, options_); + FlatExprBuilder builder(env_, options_); builder.AddProgramOptimizer(CreateInstrumentationExtension( [=](const cel::ast_internal::AstImpl&) -> Instrumentation { diff --git a/eval/compiler/regex_precompilation_optimization_test.cc b/eval/compiler/regex_precompilation_optimization_test.cc index 65d2d9058..2a6341a44 100644 --- a/eval/compiler/regex_precompilation_optimization_test.cc +++ b/eval/compiler/regex_precompilation_optimization_test.cc @@ -21,6 +21,7 @@ #include "cel/expr/checked.pb.h" #include "cel/expr/syntax.pb.h" +#include "absl/base/nullability.h" #include "absl/status/status.h" #include "base/ast_internal/ast_impl.h" #include "common/memory.h" @@ -38,6 +39,8 @@ #include "internal/testing.h" #include "parser/parser.h" #include "runtime/internal/issue_collector.h" +#include "runtime/internal/runtime_env.h" +#include "runtime/internal/runtime_env_testing.h" #include "runtime/runtime_issue.h" #include "google/protobuf/arena.h" @@ -46,6 +49,8 @@ namespace { using ::cel::RuntimeIssue; using ::cel::runtime_internal::IssueCollector; +using ::cel::runtime_internal::NewTestingRuntimeEnv; +using ::cel::runtime_internal::RuntimeEnv; using ::google::api::expr::parser::Parse; using ::testing::ElementsAre; @@ -54,7 +59,9 @@ namespace exprpb = cel::expr; class RegexPrecompilationExtensionTest : public testing::TestWithParam { public: RegexPrecompilationExtensionTest() - : type_registry_(*builder_.GetTypeRegistry()), + : env_(NewTestingRuntimeEnv()), + builder_(env_), + type_registry_(*builder_.GetTypeRegistry()), function_registry_(*builder_.GetRegistry()), value_factory_(cel::MemoryManagerRef::ReferenceCounting(), type_registry_.GetTypeProvider()), @@ -88,6 +95,7 @@ class RegexPrecompilationExtensionTest : public testing::TestWithParam { }; } + absl::Nonnull> env_; CelExpressionBuilderFlatImpl builder_; CelTypeRegistry& type_registry_; CelFunctionRegistry& function_registry_; @@ -106,8 +114,9 @@ TEST_P(RegexPrecompilationExtensionTest, SmokeTest) { ProgramBuilder program_builder; cel::ast_internal::AstImpl ast_impl; ast_impl.set_is_checked(true); - PlannerContext context(resolver_, runtime_options_, value_factory_, - issue_collector_, program_builder); + std::shared_ptr arena; + PlannerContext context(env_, resolver_, runtime_options_, value_factory_, + issue_collector_, program_builder, arena); ASSERT_OK_AND_ASSIGN(std::unique_ptr optimizer, factory(context, ast_impl)); @@ -209,8 +218,7 @@ class RegexConstFoldInteropTest : public RegexPrecompilationExtensionTest { public: RegexConstFoldInteropTest() : RegexPrecompilationExtensionTest() { builder_.flat_expr_builder().AddProgramOptimizer( - cel::runtime_internal::CreateConstantFoldingOptimizer( - cel::MemoryManagerRef::ReferenceCounting())); + cel::runtime_internal::CreateConstantFoldingOptimizer()); } protected: diff --git a/eval/compiler/resolver.cc b/eval/compiler/resolver.cc index d2f0ae184..ddbdf1be7 100644 --- a/eval/compiler/resolver.cc +++ b/eval/compiler/resolver.cc @@ -15,12 +15,10 @@ #include "eval/compiler/resolver.h" #include -#include #include #include #include -#include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" @@ -30,10 +28,9 @@ #include "absl/strings/strip.h" #include "absl/types/optional.h" #include "base/kind.h" -#include "common/memory.h" #include "common/type.h" +#include "common/type_reflector.h" #include "common/value.h" -#include "common/value_manager.h" #include "internal/status_macros.h" #include "runtime/function_overload_reference.h" #include "runtime/function_registry.h" @@ -41,18 +38,20 @@ namespace google::api::expr::runtime { +using ::cel::IntValue; +using ::cel::TypeValue; using ::cel::Value; Resolver::Resolver( absl::string_view container, const cel::FunctionRegistry& function_registry, - const cel::TypeRegistry&, cel::ValueManager& value_factory, + const cel::TypeRegistry&, const cel::TypeReflector& type_reflector, const absl::flat_hash_map& resolveable_enums, bool resolve_qualified_type_identifiers) : namespace_prefixes_(), enum_value_map_(), function_registry_(function_registry), - value_factory_(value_factory), + type_reflector_(type_reflector), resolveable_enums_(resolveable_enums), resolve_qualified_type_identifiers_(resolve_qualified_type_identifiers) { // The constructor for the registry determines the set of possible namespace @@ -85,7 +84,7 @@ Resolver::Resolver( for (const auto& enumerator : enum_type.enumerators) { auto key = absl::StrCat(remainder, !remainder.empty() ? "." : "", enumerator.name); - enum_value_map_[key] = value_factory.CreateIntValue(enumerator.number); + enum_value_map_[key] = IntValue(enumerator.number); } } } @@ -127,9 +126,9 @@ absl::optional Resolver::FindConstant(absl::string_view name, // to do so is configured in the expression builder. If the type name is // not qualified, then it too may be returned as a constant value. if (resolve_qualified_type_identifiers_ || !absl::StrContains(name, ".")) { - auto type_value = value_factory_.FindType(name); + auto type_value = type_reflector_.FindType(name); if (type_value.ok() && type_value->has_value()) { - return value_factory_.CreateTypeValue(**type_value); + return TypeValue(**type_value); } } } @@ -179,7 +178,7 @@ Resolver::FindType(absl::string_view name, int64_t expr_id) const { auto qualified_names = FullyQualifiedNames(name, expr_id); for (auto& qualified_name : qualified_names) { CEL_ASSIGN_OR_RETURN(auto maybe_type, - value_factory_.FindType(qualified_name)); + type_reflector_.FindType(qualified_name)); if (maybe_type.has_value()) { return std::make_pair(std::move(qualified_name), std::move(*maybe_type)); } diff --git a/eval/compiler/resolver.h b/eval/compiler/resolver.h index 2d164cb14..ee0e55ce1 100644 --- a/eval/compiler/resolver.h +++ b/eval/compiler/resolver.h @@ -25,6 +25,7 @@ #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "base/kind.h" +#include "common/type_reflector.h" #include "common/value.h" #include "common/value_manager.h" #include "runtime/function_overload_reference.h" @@ -47,6 +48,18 @@ class Resolver { absl::string_view container, const cel::FunctionRegistry& function_registry, const cel::TypeRegistry& type_registry, cel::ValueManager& value_factory, + const absl::flat_hash_map& + resolveable_enums, + bool resolve_qualified_type_identifiers = true) + : Resolver(container, function_registry, type_registry, + value_factory.type_provider(), resolveable_enums, + resolve_qualified_type_identifiers) {} + + Resolver( + absl::string_view container, + const cel::FunctionRegistry& function_registry, + const cel::TypeRegistry& type_registry, + const cel::TypeReflector& type_reflector, const absl::flat_hash_map& resolveable_enums, bool resolve_qualified_type_identifiers = true); @@ -89,7 +102,7 @@ class Resolver { std::vector namespace_prefixes_; absl::flat_hash_map enum_value_map_; const cel::FunctionRegistry& function_registry_; - cel::ValueManager& value_factory_; + const cel::TypeReflector& type_reflector_; const absl::flat_hash_map& resolveable_enums_; diff --git a/eval/eval/BUILD b/eval/eval/BUILD index 62c67c0e9..d7769f22f 100644 --- a/eval/eval/BUILD +++ b/eval/eval/BUILD @@ -60,6 +60,7 @@ cc_library( "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@com_google_absl//absl/utility", + "@com_google_protobuf//:protobuf", ], ) @@ -541,6 +542,7 @@ cc_test( "//internal:testing", "//runtime:activation", "//runtime:runtime_options", + "//runtime/internal:runtime_env_testing", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) diff --git a/eval/eval/evaluator_core.h b/eval/eval/evaluator_core.h index b654d92b7..468a06634 100644 --- a/eval/eval/evaluator_core.h +++ b/eval/eval/evaluator_core.h @@ -17,7 +17,6 @@ #include #include -#include #include #include #include @@ -44,6 +43,7 @@ #include "runtime/managed_value_factory.h" #include "runtime/runtime.h" #include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" namespace google::api::expr::runtime { @@ -369,23 +369,27 @@ class FlatExpression { // value creation in evaluation FlatExpression(ExecutionPath path, size_t comprehension_slots_size, const cel::TypeProvider& type_provider, - const cel::RuntimeOptions& options) + const cel::RuntimeOptions& options, + absl::Nullable> arena = nullptr) : path_(std::move(path)), subexpressions_({path_}), comprehension_slots_size_(comprehension_slots_size), type_provider_(type_provider), - options_(options) {} + options_(options), + arena_(std::move(arena)) {} FlatExpression(ExecutionPath path, std::vector subexpressions, size_t comprehension_slots_size, const cel::TypeProvider& type_provider, - const cel::RuntimeOptions& options) + const cel::RuntimeOptions& options, + absl::Nullable> arena = nullptr) : path_(std::move(path)), subexpressions_(std::move(subexpressions)), comprehension_slots_size_(comprehension_slots_size), type_provider_(type_provider), - options_(options) {} + options_(options), + arena_(std::move(arena)) {} // Move-only FlatExpression(FlatExpression&&) = default; @@ -429,6 +433,9 @@ class FlatExpression { size_t comprehension_slots_size_; const cel::TypeProvider& type_provider_; cel::RuntimeOptions options_; + // Arena used during planning phase, may hold constant values so should be + // kept alive. + absl::Nullable> arena_; }; } // namespace google::api::expr::runtime diff --git a/eval/eval/evaluator_core_test.cc b/eval/eval/evaluator_core_test.cc index 7b4404af1..da15f4b4e 100644 --- a/eval/eval/evaluator_core_test.cc +++ b/eval/eval/evaluator_core_test.cc @@ -15,6 +15,7 @@ #include "extensions/protobuf/memory_manager.h" #include "internal/testing.h" #include "runtime/activation.h" +#include "runtime/internal/runtime_env_testing.h" #include "runtime/runtime_options.h" namespace google::api::expr::runtime { @@ -23,6 +24,7 @@ using ::cel::IntValue; using ::cel::TypeProvider; using ::cel::extensions::ProtoMemoryManagerRef; using ::cel::interop_internal::CreateIntValue; +using ::cel::runtime_internal::NewTestingRuntimeEnv; using ::cel::expr::Expr; using ::google::api::expr::runtime::RegisterBuiltinFunctions; using ::testing::_; @@ -173,7 +175,7 @@ TEST(EvaluatorCoreTest, TraceTest) { cel::RuntimeOptions options; options.short_circuiting = false; - CelExpressionBuilderFlatImpl builder(options); + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr, &source_info)); diff --git a/eval/public/BUILD b/eval/public/BUILD index be7b3a1c8..6142f3fa4 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -573,23 +573,21 @@ cc_library( ":cel_function", ":cel_options", "//base:kind", - "//base/ast_internal:ast_impl", "//common:memory", "//eval/compiler:cel_expression_builder_flat_impl", "//eval/compiler:comprehension_vulnerability_check", "//eval/compiler:constant_folding", "//eval/compiler:flat_expr_builder", - "//eval/compiler:flat_expr_builder_extensions", "//eval/compiler:qualified_reference_resolver", "//eval/compiler:regex_precompilation_optimization", "//eval/public/structs:protobuf_descriptor_type_provider", "//extensions:select_optimization", - "//extensions/protobuf:memory_manager", - "//internal:proto_util", + "//internal:noop_delete", "//runtime:runtime_options", + "//runtime/internal:runtime_env", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", "@com_google_protobuf//:protobuf", ], ) diff --git a/eval/public/cel_expr_builder_factory.cc b/eval/public/cel_expr_builder_factory.cc index cc061a7ea..436a85752 100644 --- a/eval/public/cel_expr_builder_factory.cc +++ b/eval/public/cel_expr_builder_factory.cc @@ -17,28 +17,28 @@ #include "eval/public/cel_expr_builder_factory.h" #include +#include +#include "absl/base/nullability.h" #include "absl/log/absl_log.h" #include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "base/ast_internal/ast_impl.h" #include "base/kind.h" #include "common/memory.h" #include "eval/compiler/cel_expression_builder_flat_impl.h" #include "eval/compiler/comprehension_vulnerability_check.h" #include "eval/compiler/constant_folding.h" #include "eval/compiler/flat_expr_builder.h" -#include "eval/compiler/flat_expr_builder_extensions.h" #include "eval/compiler/qualified_reference_resolver.h" #include "eval/compiler/regex_precompilation_optimization.h" #include "eval/public/cel_expression.h" #include "eval/public/cel_function.h" #include "eval/public/cel_options.h" #include "eval/public/structs/protobuf_descriptor_type_provider.h" -#include "extensions/protobuf/memory_manager.h" #include "extensions/select_optimization.h" -#include "internal/proto_util.h" +#include "internal/noop_delete.h" +#include "runtime/internal/runtime_env.h" #include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" @@ -47,25 +47,12 @@ namespace google::api::expr::runtime { namespace { using ::cel::MemoryManagerRef; -using ::cel::ast_internal::AstImpl; using ::cel::extensions::CreateSelectOptimizationProgramOptimizer; using ::cel::extensions::kCelAttribute; using ::cel::extensions::kCelHasField; -using ::cel::extensions::ProtoMemoryManagerRef; using ::cel::extensions::SelectOptimizationAstUpdater; using ::cel::runtime_internal::CreateConstantFoldingOptimizer; -using ::google::api::expr::internal::ValidateStandardMessageTypes; - -// Adapter for a raw arena* pointer. Manages a MemoryManager object for the -// constant folding extension. -struct ArenaBackedConstfoldingFactory { - MemoryManagerRef memory_manager; - - absl::StatusOr> operator()( - PlannerContext& ctx, const AstImpl& ast) const { - return CreateConstantFoldingOptimizer(memory_manager)(ctx, ast); - } -}; +using ::cel::runtime_internal::RuntimeEnv; } // namespace @@ -78,15 +65,27 @@ std::unique_ptr CreateCelExpressionBuilder( "CreateCelExpressionBuilder"; return nullptr; } - if (auto s = ValidateStandardMessageTypes(*descriptor_pool); !s.ok()) { - ABSL_LOG(WARNING) << "Failed to validate standard message types: " - << s.ToString(); // NOLINT: OSS compatibility - return nullptr; - } cel::RuntimeOptions runtime_options = ConvertToRuntimeOptions(options); - auto builder = - std::make_unique(runtime_options); + absl::Nullable> + shared_message_factory; + if (message_factory != nullptr) { + shared_message_factory = std::shared_ptr( + message_factory, + cel::internal::NoopDeleteFor()); + } + auto env = std::make_shared( + std::shared_ptr( + descriptor_pool, + cel::internal::NoopDeleteFor()), + shared_message_factory); + if (auto status = env->Initialize(); !status.ok()) { + ABSL_LOG(ERROR) << "Failed to validate standard message types: " + << status.ToString(); // NOLINT: OSS compatibility + return nullptr; + } + auto builder = std::make_unique( + std::move(env), runtime_options); builder->GetTypeRegistry() ->InternalGetModernRegistry() @@ -109,9 +108,15 @@ std::unique_ptr CreateCelExpressionBuilder( } if (options.constant_folding) { + std::shared_ptr shared_arena; + if (options.constant_arena != nullptr) { + shared_arena = std::shared_ptr( + options.constant_arena, + cel::internal::NoopDeleteFor()); + } builder->flat_expr_builder().AddProgramOptimizer( - ArenaBackedConstfoldingFactory{ - ProtoMemoryManagerRef(options.constant_arena)}); + CreateConstantFoldingOptimizer(std::move(shared_arena), + std::move(shared_message_factory))); } if (options.enable_regex_precompilation) { diff --git a/eval/public/cel_expression.h b/eval/public/cel_expression.h index 98b58aa98..3f52ad60d 100644 --- a/eval/public/cel_expression.h +++ b/eval/public/cel_expression.h @@ -5,7 +5,6 @@ #include #include #include -#include #include "cel/expr/checked.pb.h" #include "cel/expr/syntax.pb.h" @@ -76,10 +75,7 @@ class CelExpression { // it built. class CelExpressionBuilder { public: - CelExpressionBuilder() - : func_registry_(std::make_unique()), - type_registry_(std::make_unique()), - container_("") {} + CelExpressionBuilder() = default; virtual ~CelExpressionBuilder() = default; @@ -129,23 +125,16 @@ class CelExpressionBuilder { // CelFunction registry. Extension function should be registered with it // prior to expression creation. - CelFunctionRegistry* GetRegistry() const { return func_registry_.get(); } + virtual CelFunctionRegistry* GetRegistry() const = 0; // CEL Type registry. Provides a means to resolve the CEL built-in types to // CelValue instances, and to extend the set of types and enums known to // expressions by registering them ahead of time. - CelTypeRegistry* GetTypeRegistry() const { return type_registry_.get(); } + virtual CelTypeRegistry* GetTypeRegistry() const = 0; - virtual void set_container(std::string container) { - container_ = std::move(container); - } - - absl::string_view container() const { return container_; } + virtual void set_container(std::string container) = 0; - private: - std::unique_ptr func_registry_; - std::unique_ptr type_registry_; - std::string container_; + virtual absl::string_view container() const = 0; }; } // namespace google::api::expr::runtime diff --git a/eval/tests/modern_benchmark_test.cc b/eval/tests/modern_benchmark_test.cc index 81cf91ef0..22233210a 100644 --- a/eval/tests/modern_benchmark_test.cc +++ b/eval/tests/modern_benchmark_test.cc @@ -102,8 +102,7 @@ std::unique_ptr StandardRuntimeOrDie( break; case ConstFoldingEnabled::kYes: ABSL_CHECK(arena != nullptr); - ABSL_CHECK_OK(extensions::EnableConstantFolding( - *builder, ProtoMemoryManagerRef(arena))); + ABSL_CHECK_OK(extensions::EnableConstantFolding(*builder)); break; } diff --git a/extensions/select_optimization.cc b/extensions/select_optimization.cc index 2e34096e0..2083d3b82 100644 --- a/extensions/select_optimization.cc +++ b/extensions/select_optimization.cc @@ -152,7 +152,7 @@ Expr MakeSelectPathExpr( absl::optional GetSelectInstruction( const StructType& runtime_type, PlannerContext& planner_context, absl::string_view field_name) { - auto field_or = planner_context.value_factory() + auto field_or = planner_context.type_reflector() .FindStructTypeFieldByName(runtime_type, field_name) .value_or(absl::nullopt); if (field_or.has_value()) { @@ -515,7 +515,7 @@ class RewriterImpl : public AstRewriterBase { } absl::optional GetRuntimeType(absl::string_view type_name) { - return planner_context_.value_factory().FindType(type_name).value_or( + return planner_context_.type_reflector().FindType(type_name).value_or( absl::nullopt); } diff --git a/runtime/BUILD b/runtime/BUILD index c453afb89..1d8c3dbfc 100644 --- a/runtime/BUILD +++ b/runtime/BUILD @@ -213,8 +213,10 @@ cc_library( deps = [ ":runtime_builder", ":runtime_options", + ":type_registry", "//internal:noop_delete", "//internal:status_macros", + "//runtime/internal:runtime_env", "//runtime/internal:runtime_impl", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:nullability", @@ -304,11 +306,10 @@ cc_library( deps = [ ":runtime", ":runtime_builder", - "//common:allocator", - "//common:memory", "//common:native_type", "//eval/compiler:constant_folding", "//internal:casts", + "//internal:noop_delete", "//internal:status_macros", "//runtime/internal:runtime_friend_access", "//runtime/internal:runtime_impl", @@ -339,6 +340,7 @@ cc_test( "//internal:testing_descriptor_pool", "//parser", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", @@ -384,6 +386,7 @@ cc_test( "//internal:testing_descriptor_pool", "//parser", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", diff --git a/runtime/constant_folding.cc b/runtime/constant_folding.cc index 57ead8096..0174ef267 100644 --- a/runtime/constant_folding.cc +++ b/runtime/constant_folding.cc @@ -14,20 +14,24 @@ #include "runtime/constant_folding.h" -#include "absl/base/macros.h" +#include +#include + +#include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "common/allocator.h" #include "common/native_type.h" #include "eval/compiler/constant_folding.h" #include "internal/casts.h" +#include "internal/noop_delete.h" #include "internal/status_macros.h" #include "runtime/internal/runtime_friend_access.h" #include "runtime/internal/runtime_impl.h" #include "runtime/runtime.h" #include "runtime/runtime_builder.h" +#include "google/protobuf/arena.h" #include "google/protobuf/message.h" namespace cel::extensions { @@ -37,44 +41,122 @@ using ::cel::internal::down_cast; using ::cel::runtime_internal::RuntimeFriendAccess; using ::cel::runtime_internal::RuntimeImpl; -absl::StatusOr RuntimeImplFromBuilder(RuntimeBuilder& builder) { +absl::StatusOr> RuntimeImplFromBuilder( + RuntimeBuilder& builder ABSL_ATTRIBUTE_LIFETIME_BOUND) { Runtime& runtime = RuntimeFriendAccess::GetMutableRuntime(builder); - if (RuntimeFriendAccess::RuntimeTypeId(runtime) != NativeTypeId::For()) { return absl::UnimplementedError( "constant folding only supported on the default cel::Runtime " "implementation."); } + return down_cast(&runtime); +} - RuntimeImpl& runtime_impl = down_cast(runtime); - - return &runtime_impl; +absl::Status EnableConstantFoldingImpl( + RuntimeBuilder& builder, + absl::Nullable> arena, + absl::Nullable> message_factory) { + CEL_ASSIGN_OR_RETURN(absl::Nonnull runtime_impl, + RuntimeImplFromBuilder(builder)); + if (arena != nullptr) { + runtime_impl->environment().KeepAlive(arena); + } + if (message_factory != nullptr) { + runtime_impl->environment().KeepAlive(message_factory); + } + runtime_impl->expr_builder().AddProgramOptimizer( + runtime_internal::CreateConstantFoldingOptimizer( + std::move(arena), std::move(message_factory))); + return absl::OkStatus(); } } // namespace +absl::Status EnableConstantFolding(RuntimeBuilder& builder) { + return EnableConstantFoldingImpl(builder, nullptr, nullptr); +} + absl::Status EnableConstantFolding(RuntimeBuilder& builder, - Allocator<> allocator) { - CEL_ASSIGN_OR_RETURN(RuntimeImpl * runtime_impl, - RuntimeImplFromBuilder(builder)); - ABSL_ASSERT(runtime_impl != nullptr); - runtime_impl->expr_builder().AddProgramOptimizer( - runtime_internal::CreateConstantFoldingOptimizer(allocator, nullptr)); - return absl::OkStatus(); + absl::Nonnull arena) { + ABSL_DCHECK(arena != nullptr); + return EnableConstantFoldingImpl( + builder, + std::shared_ptr(arena, + internal::NoopDeleteFor()), + nullptr); +} + +absl::Status EnableConstantFolding( + RuntimeBuilder& builder, + absl::Nonnull> arena) { + ABSL_DCHECK(arena != nullptr); + return EnableConstantFoldingImpl(builder, std::move(arena), nullptr); } absl::Status EnableConstantFolding( - RuntimeBuilder& builder, Allocator<> allocator, + RuntimeBuilder& builder, absl::Nonnull message_factory) { ABSL_DCHECK(message_factory != nullptr); - CEL_ASSIGN_OR_RETURN(RuntimeImpl * runtime_impl, - RuntimeImplFromBuilder(builder)); - ABSL_ASSERT(runtime_impl != nullptr); - runtime_impl->expr_builder().AddProgramOptimizer( - runtime_internal::CreateConstantFoldingOptimizer(allocator, - message_factory)); - return absl::OkStatus(); + return EnableConstantFoldingImpl( + builder, nullptr, + std::shared_ptr( + message_factory, internal::NoopDeleteFor())); +} + +absl::Status EnableConstantFolding( + RuntimeBuilder& builder, + absl::Nonnull> message_factory) { + ABSL_DCHECK(message_factory != nullptr); + return EnableConstantFoldingImpl(builder, nullptr, + std::move(message_factory)); +} + +absl::Status EnableConstantFolding( + RuntimeBuilder& builder, absl::Nonnull arena, + absl::Nonnull message_factory) { + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(message_factory != nullptr); + return EnableConstantFoldingImpl( + builder, + std::shared_ptr(arena, + internal::NoopDeleteFor()), + std::shared_ptr( + message_factory, internal::NoopDeleteFor())); +} + +absl::Status EnableConstantFolding( + RuntimeBuilder& builder, absl::Nonnull arena, + absl::Nonnull> message_factory) { + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(message_factory != nullptr); + return EnableConstantFoldingImpl( + builder, + std::shared_ptr(arena, + internal::NoopDeleteFor()), + std::move(message_factory)); +} + +absl::Status EnableConstantFolding( + RuntimeBuilder& builder, + absl::Nonnull> arena, + absl::Nonnull message_factory) { + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(message_factory != nullptr); + return EnableConstantFoldingImpl( + builder, std::move(arena), + std::shared_ptr( + message_factory, internal::NoopDeleteFor())); +} + +absl::Status EnableConstantFolding( + RuntimeBuilder& builder, + absl::Nonnull> arena, + absl::Nonnull> message_factory) { + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(message_factory != nullptr); + return EnableConstantFoldingImpl(builder, std::move(arena), + std::move(message_factory)); } } // namespace cel::extensions diff --git a/runtime/constant_folding.h b/runtime/constant_folding.h index be5cf6044..58cd4cfd0 100644 --- a/runtime/constant_folding.h +++ b/runtime/constant_folding.h @@ -15,10 +15,12 @@ #ifndef THIRD_PARTY_CEL_CPP_RUNTIME_CONSTANT_FOLDING_H_ #define THIRD_PARTY_CEL_CPP_RUNTIME_CONSTANT_FOLDING_H_ +#include + #include "absl/base/nullability.h" #include "absl/status/status.h" -#include "common/allocator.h" #include "runtime/runtime_builder.h" +#include "google/protobuf/arena.h" #include "google/protobuf/message.h" namespace cel::extensions { @@ -26,20 +28,44 @@ namespace cel::extensions { // Enable constant folding in the runtime being built. // // Constant folding eagerly evaluates sub-expressions with all constant inputs -// at plan time to simplify the resulting program. User extensions functions are -// executed if they are eagerly bound. +// at plan time to simplify the resulting program. User functions are executed +// if they are eagerly bound. // -// The underlying implementation of `allocator` must outlive the resulting -// runtime and any programs it creates. +// The provided, the `google::protobuf::Arena` must outlive the resulting runtime +// and any program it creates. Otherwise the runtime will create one as needed +// during planning for each program, unless one is explicitly provided during +// planning. // -// The provided `google::protobuf::MessageFactory` must outlive the resulting runtime and -// any program it creates. Failure to pass a message factory may result in -// certain optimizations being disabled. +// The provided, the `google::protobuf::MessageFactory` must outlive the resulting runtime +// and any program it creates. Otherwise the runtime will create one as needed +// and use it for all planning and the resulting programs created from the +// runtime, unless one is explicitly provided during planning or evaluation. +absl::Status EnableConstantFolding(RuntimeBuilder& builder); absl::Status EnableConstantFolding(RuntimeBuilder& builder, - Allocator<> allocator); + absl::Nonnull arena); +absl::Status EnableConstantFolding( + RuntimeBuilder& builder, + absl::Nonnull> arena); +absl::Status EnableConstantFolding( + RuntimeBuilder& builder, + absl::Nonnull message_factory); +absl::Status EnableConstantFolding( + RuntimeBuilder& builder, + absl::Nonnull> message_factory); absl::Status EnableConstantFolding( - RuntimeBuilder& builder, Allocator<> allocator, + RuntimeBuilder& builder, absl::Nonnull arena, absl::Nonnull message_factory); +absl::Status EnableConstantFolding( + RuntimeBuilder& builder, absl::Nonnull arena, + absl::Nonnull> message_factory); +absl::Status EnableConstantFolding( + RuntimeBuilder& builder, + absl::Nonnull> arena, + absl::Nonnull message_factory); +absl::Status EnableConstantFolding( + RuntimeBuilder& builder, + absl::Nonnull> arena, + absl::Nonnull> message_factory); } // namespace cel::extensions diff --git a/runtime/constant_folding_test.cc b/runtime/constant_folding_test.cc index af3010b62..f579cb400 100644 --- a/runtime/constant_folding_test.cc +++ b/runtime/constant_folding_test.cc @@ -20,6 +20,7 @@ #include "cel/expr/syntax.pb.h" #include "absl/status/status.h" +#include "absl/status/status_matchers.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" #include "base/function_adapter.h" @@ -38,6 +39,7 @@ namespace cel::extensions { namespace { +using ::absl_testing::IsOk; using ::absl_testing::StatusIs; using ::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; @@ -87,10 +89,9 @@ TEST_P(ConstantFoldingExtTest, Runner) { return StringValue::Concat(f, prefix, value); }, builder.function_registry()); - ASSERT_OK(status); + ASSERT_THAT(status, IsOk()); - ASSERT_OK( - EnableConstantFolding(builder, MemoryManagerRef::ReferenceCounting())); + ASSERT_THAT(EnableConstantFolding(builder), IsOk()); ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); diff --git a/runtime/internal/BUILD b/runtime/internal/BUILD index 69b8a8e3e..9e5078ccd 100644 --- a/runtime/internal/BUILD +++ b/runtime/internal/BUILD @@ -47,11 +47,31 @@ cc_library( ], ) +cc_library( + name = "runtime_env", + srcs = ["runtime_env.cc"], + hdrs = ["runtime_env.h"], + deps = [ + "//eval/public:cel_function_registry", + "//eval/public:cel_type_registry", + "//internal:noop_delete", + "//internal:well_known_types", + "//runtime:function_registry", + "//runtime:type_registry", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_absl//absl/synchronization", + "@com_google_protobuf//:protobuf", + ], +) + cc_library( name = "runtime_impl", srcs = ["runtime_impl.cc"], hdrs = ["runtime_impl.h"], deps = [ + ":runtime_env", "//base:ast", "//base:data", "//common:native_type", @@ -73,7 +93,6 @@ cc_library( "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status:statusor", - "@com_google_protobuf//:protobuf", ], ) @@ -156,3 +175,16 @@ cc_test( "@com_google_absl//absl/time", ], ) + +cc_library( + name = "runtime_env_testing", + testonly = True, + srcs = ["runtime_env_testing.cc"], + hdrs = ["runtime_env_testing.h"], + deps = [ + ":runtime_env", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", + ], +) diff --git a/runtime/internal/runtime_env.cc b/runtime/internal/runtime_env.cc new file mode 100644 index 000000000..dbe78d538 --- /dev/null +++ b/runtime/internal/runtime_env.cc @@ -0,0 +1,74 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/internal/runtime_env.h" + +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/synchronization/mutex.h" +#include "internal/noop_delete.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/dynamic_message.h" +#include "google/protobuf/message.h" + +namespace cel::runtime_internal { + +RuntimeEnv::KeepAlives::~KeepAlives() { + while (!deque.empty()) { + deque.pop_back(); + } +} + +absl::Nonnull RuntimeEnv::MutableMessageFactory() + const { + absl::Nullable shared_message_factory = + message_factory_ptr.load(std::memory_order_relaxed); + if (shared_message_factory != nullptr) { + return shared_message_factory; + } + absl::MutexLock lock(&message_factory_mutex); + shared_message_factory = message_factory_ptr.load(std::memory_order_relaxed); + if (shared_message_factory == nullptr) { + if (descriptor_pool.get() == google::protobuf::DescriptorPool::generated_pool()) { + // Using the generated descriptor pool, just use the generated message + // factory. + message_factory = std::shared_ptr( + google::protobuf::MessageFactory::generated_factory(), + internal::NoopDeleteFor()); + } else { + auto dynamic_message_factory = + std::make_shared(); + // Ensure we do not delegate to the generated factory, if the default + // every changes. We prefer being hermetic. + dynamic_message_factory->SetDelegateToGeneratedFactory(false); + message_factory = std::move(dynamic_message_factory); + } + shared_message_factory = message_factory.get(); + message_factory_ptr.store(shared_message_factory, + std::memory_order_seq_cst); + } + return shared_message_factory; +} + +void RuntimeEnv::KeepAlive(std::shared_ptr keep_alive) { + if (keep_alive == nullptr) { + return; + } + keep_alives.deque.push_back(std::move(keep_alive)); +} + +} // namespace cel::runtime_internal diff --git a/runtime/internal/runtime_env.h b/runtime/internal/runtime_env.h new file mode 100644 index 000000000..e0ab566b1 --- /dev/null +++ b/runtime/internal/runtime_env.h @@ -0,0 +1,133 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_ENV_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_ENV_H_ + +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/base/thread_annotations.h" +#include "absl/status/status.h" +#include "absl/synchronization/mutex.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_type_registry.h" +#include "internal/well_known_types.h" +#include "runtime/function_registry.h" +#include "runtime/type_registry.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel::runtime_internal { + +// Shared state used by the runtime during creation, configuration, planning, +// and evaluation. Passed around via `std::shared_ptr`. +// +// TODO: Make this a class. +struct RuntimeEnv final { + explicit RuntimeEnv( + absl::Nonnull> + descriptor_pool, + absl::Nullable> message_factory = + nullptr) + : descriptor_pool(std::move(descriptor_pool)), + message_factory(std::move(message_factory)), + type_registry(legacy_type_registry.InternalGetModernRegistry()), + function_registry(legacy_function_registry.InternalGetRegistry()) { + if (this->message_factory != nullptr) { + message_factory_ptr.store(this->message_factory.get(), + std::memory_order_seq_cst); + } + } + + // Not copyable or moveable. + RuntimeEnv(const RuntimeEnv&) = delete; + RuntimeEnv(RuntimeEnv&&) = delete; + RuntimeEnv& operator=(const RuntimeEnv&) = delete; + RuntimeEnv& operator=(RuntimeEnv&&) = delete; + + // Ideally the environment would already be initialized, but things are a bit + // awkward. This should only be called once immediately after construction. + absl::Status Initialize() { + return well_known_types.Initialize(descriptor_pool.get()); + } + + bool IsInitialized() const { return well_known_types.IsInitialized(); } + + ABSL_ATTRIBUTE_UNUSED + const absl::Nonnull> + descriptor_pool; + + private: + // These fields deal with a message factory that is lazily initialized as + // needed. This might be called during the planning phase of an expression or + // during evaluation. We want the ability to get the message factory when it + // is already created to be cheap, so we use an atomic and a mutex for the + // slow path. + // + // Do not access any of these fields directly, use member functions. + mutable absl::Mutex message_factory_mutex; + mutable absl::Nullable> + message_factory ABSL_GUARDED_BY(message_factory_mutex); + // std::atomic> is not really a simple atomic, so we + // avoid it. + mutable std::atomic> + message_factory_ptr = nullptr; + + struct KeepAlives final { + KeepAlives() = default; + + ~KeepAlives(); + + // Not copyable or moveable. + KeepAlives(const KeepAlives&) = delete; + KeepAlives(KeepAlives&&) = delete; + KeepAlives& operator=(const KeepAlives&) = delete; + KeepAlives& operator=(KeepAlives&&) = delete; + + std::deque> deque; + }; + + KeepAlives keep_alives; + + public: + // Because of legacy shenanigans, we use shared_ptr here. For legacy, this is + // an unowned shared_ptr (a noop deleter) pointing to the modern equivalent + // which is a member of the legacy variant. + google::api::expr::runtime::CelTypeRegistry legacy_type_registry; + google::api::expr::runtime::CelFunctionRegistry legacy_function_registry; + TypeRegistry& type_registry; + FunctionRegistry& function_registry; + + well_known_types::Reflection well_known_types; + + absl::Nonnull MutableMessageFactory() const + ABSL_ATTRIBUTE_LIFETIME_BOUND; + + // Not thread safe. Adds `keep_alive` to a list owned by this environment + // and ensures it survives at least as long as this environment. Keep alives + // are released in reverse order of their registration. This mimics normal + // destructor rules of members. + // + // IMPORTANT: This should only be when building the runtime, and not after. + void KeepAlive(std::shared_ptr keep_alive); +}; + +} // namespace cel::runtime_internal + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_ENV_H_ diff --git a/runtime/internal/runtime_env_testing.cc b/runtime/internal/runtime_env_testing.cc new file mode 100644 index 000000000..ae7dd0ab9 --- /dev/null +++ b/runtime/internal/runtime_env_testing.cc @@ -0,0 +1,33 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/internal/runtime_env_testing.h" + +#include + +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "internal/testing_descriptor_pool.h" +#include "runtime/internal/runtime_env.h" + +namespace cel::runtime_internal { + +absl::Nonnull> NewTestingRuntimeEnv() { + auto env = + std::make_shared(internal::GetSharedTestingDescriptorPool()); + ABSL_CHECK_OK(env->Initialize()); // Crash OK + return env; +} + +} // namespace cel::runtime_internal diff --git a/runtime/internal/runtime_env_testing.h b/runtime/internal/runtime_env_testing.h new file mode 100644 index 000000000..1645ce4dd --- /dev/null +++ b/runtime/internal/runtime_env_testing.h @@ -0,0 +1,29 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_ENV_TESTING_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_ENV_TESTING_H_ + +#include + +#include "absl/base/nullability.h" +#include "runtime/internal/runtime_env.h" + +namespace cel::runtime_internal { + +absl::Nonnull> NewTestingRuntimeEnv(); + +} // namespace cel::runtime_internal + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_ENV_TESTING_H_ diff --git a/runtime/internal/runtime_impl.h b/runtime/internal/runtime_impl.h index 4dc2fe929..0c4972fcf 100644 --- a/runtime/internal/runtime_impl.h +++ b/runtime/internal/runtime_impl.h @@ -28,48 +28,51 @@ #include "eval/compiler/flat_expr_builder.h" #include "internal/well_known_types.h" #include "runtime/function_registry.h" +#include "runtime/internal/runtime_env.h" #include "runtime/runtime.h" #include "runtime/runtime_options.h" #include "runtime/type_registry.h" -#include "google/protobuf/descriptor.h" namespace cel::runtime_internal { class RuntimeImpl : public Runtime { public: - struct Environment { - ABSL_ATTRIBUTE_UNUSED - absl::Nonnull> - descriptor_pool; - TypeRegistry type_registry; - FunctionRegistry function_registry; - well_known_types::Reflection well_known_types; - }; + using Environment = RuntimeEnv; RuntimeImpl(absl::Nonnull> environment, const RuntimeOptions& options) : environment_(std::move(environment)), - expr_builder_(environment_->function_registry, - environment_->type_registry, options) { + expr_builder_(environment_, options) { ABSL_DCHECK(environment_->well_known_types.IsInitialized()); } - TypeRegistry& type_registry() { return environment_->type_registry; } - const TypeRegistry& type_registry() const { + TypeRegistry& type_registry() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return environment_->type_registry; + } + const TypeRegistry& type_registry() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return environment_->type_registry; } - FunctionRegistry& function_registry() { + FunctionRegistry& function_registry() ABSL_ATTRIBUTE_LIFETIME_BOUND { return environment_->function_registry; } - const FunctionRegistry& function_registry() const { + const FunctionRegistry& function_registry() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { return environment_->function_registry; } - const well_known_types::Reflection& well_known_types() const { + const well_known_types::Reflection& well_known_types() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { return environment_->well_known_types; } + Environment& environment() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return *environment_; + } + const Environment& environment() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return *environment_; + } + // implement Runtime absl::StatusOr> CreateProgram( std::unique_ptr ast, @@ -84,7 +87,8 @@ class RuntimeImpl : public Runtime { } // exposed for extensions access - google::api::expr::runtime::FlatExprBuilder& expr_builder() { + google::api::expr::runtime::FlatExprBuilder& expr_builder() + ABSL_ATTRIBUTE_LIFETIME_BOUND { return expr_builder_; } diff --git a/runtime/regex_precompilation_test.cc b/runtime/regex_precompilation_test.cc index b5da4aa4e..5cbdb291c 100644 --- a/runtime/regex_precompilation_test.cc +++ b/runtime/regex_precompilation_test.cc @@ -20,6 +20,7 @@ #include "cel/expr/syntax.pb.h" #include "absl/status/status.h" +#include "absl/status/status_matchers.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" #include "base/function_adapter.h" @@ -39,6 +40,7 @@ namespace cel::extensions { namespace { +using ::absl_testing::IsOk; using ::absl_testing::StatusIs; using ::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; @@ -89,9 +91,9 @@ TEST_P(RegexPrecompilationTest, Basic) { return StringValue::Concat(f, prefix, value); }, builder.function_registry()); - ASSERT_OK(status); + ASSERT_THAT(status, IsOk()); - ASSERT_OK(EnableRegexPrecompilation(builder)); + ASSERT_THAT(EnableRegexPrecompilation(builder), IsOk()); ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); @@ -136,11 +138,10 @@ TEST_P(RegexPrecompilationTest, WithConstantFolding) { return StringValue::Concat(f, prefix, value); }, builder.function_registry()); - ASSERT_OK(status); + ASSERT_THAT(status, IsOk()); - ASSERT_OK( - EnableConstantFolding(builder, MemoryManagerRef::ReferenceCounting())); - ASSERT_OK(EnableRegexPrecompilation(builder)); + ASSERT_THAT(EnableConstantFolding(builder), IsOk()); + ASSERT_THAT(EnableRegexPrecompilation(builder), IsOk()); ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); diff --git a/runtime/runtime_builder_factory.cc b/runtime/runtime_builder_factory.cc index 34e16b03a..cdfb0058f 100644 --- a/runtime/runtime_builder_factory.cc +++ b/runtime/runtime_builder_factory.cc @@ -22,13 +22,16 @@ #include "absl/status/statusor.h" #include "internal/noop_delete.h" #include "internal/status_macros.h" +#include "runtime/internal/runtime_env.h" #include "runtime/internal/runtime_impl.h" #include "runtime/runtime_builder.h" #include "runtime/runtime_options.h" +#include "runtime/type_registry.h" #include "google/protobuf/descriptor.h" namespace cel { +using ::cel::runtime_internal::RuntimeEnv; using ::cel::runtime_internal::RuntimeImpl; absl::StatusOr CreateRuntimeBuilder( @@ -51,10 +54,8 @@ absl::StatusOr CreateRuntimeBuilder( // TODO: add API for attaching an issue listener (replacing the // vector overloads). ABSL_DCHECK(descriptor_pool != nullptr); - auto environment = std::make_shared(); - environment->descriptor_pool = std::move(descriptor_pool); - CEL_RETURN_IF_ERROR(environment->well_known_types.Initialize( - environment->descriptor_pool.get())); + auto environment = std::make_shared(std::move(descriptor_pool)); + CEL_RETURN_IF_ERROR(environment->Initialize()); auto runtime_impl = std::make_unique(std::move(environment), options); runtime_impl->expr_builder().set_container(options.container); From 1cfeb63d7fbfcb0ff588d025487e1237e520e49d Mon Sep 17 00:00:00 2001 From: Justin King Date: Wed, 30 Oct 2024 16:44:49 -0700 Subject: [PATCH 014/180] Remove the usage of type reflectors for optional types PiperOrigin-RevId: 691596486 --- common/list_type_reflector.cc | 9 ------ common/map_type_reflector.cc | 10 ------- common/type_reflector.cc | 7 +---- common/type_reflector.h | 34 ++++++---------------- runtime/BUILD | 3 +- runtime/function_adapter_test.cc | 6 ++-- runtime/internal/BUILD | 4 ++- runtime/internal/composed_type_provider.cc | 31 ++++++++------------ runtime/internal/composed_type_provider.h | 16 ++++------ runtime/internal/function_adapter_test.cc | 7 ++--- runtime/optional_types.cc | 16 +--------- runtime/type_registry.h | 9 +++++- 12 files changed, 47 insertions(+), 105 deletions(-) diff --git a/common/list_type_reflector.cc b/common/list_type_reflector.cc index 81b8a1cc7..2713f7a13 100644 --- a/common/list_type_reflector.cc +++ b/common/list_type_reflector.cc @@ -28,13 +28,4 @@ TypeReflector::NewListValueBuilder(ValueFactory& value_factory, return common_internal::NewListValueBuilder(value_factory); } -namespace common_internal { - -absl::StatusOr> -LegacyTypeReflector::NewListValueBuilder(ValueFactory& value_factory, - const ListType& type) const { - return TypeReflector::NewListValueBuilder(value_factory, type); -} -} // namespace common_internal - } // namespace cel diff --git a/common/map_type_reflector.cc b/common/map_type_reflector.cc index 8278e2fbd..39852163e 100644 --- a/common/map_type_reflector.cc +++ b/common/map_type_reflector.cc @@ -28,14 +28,4 @@ TypeReflector::NewMapValueBuilder(ValueFactory& value_factory, return common_internal::NewMapValueBuilder(value_factory); } -namespace common_internal { - -absl::StatusOr> -LegacyTypeReflector::NewMapValueBuilder(ValueFactory& value_factory, - const MapType& type) const { - return TypeReflector::NewMapValueBuilder(value_factory, type); -} - -} // namespace common_internal - } // namespace cel diff --git a/common/type_reflector.cc b/common/type_reflector.cc index 472e64a79..4677f038a3 100644 --- a/common/type_reflector.cc +++ b/common/type_reflector.cc @@ -968,12 +968,7 @@ absl::StatusOr TypeReflector::FindValue(ValueFactory&, absl::string_view, return false; } -TypeReflector& TypeReflector::LegacyBuiltin() { - static absl::NoDestructor instance; - return *instance; -} - -TypeReflector& TypeReflector::ModernBuiltin() { +TypeReflector& TypeReflector::Builtin() { static absl::NoDestructor instance; return *instance; } diff --git a/common/type_reflector.h b/common/type_reflector.h index d53da9c67..b0c1c66d3 100644 --- a/common/type_reflector.h +++ b/common/type_reflector.h @@ -15,6 +15,7 @@ #ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPE_REFLECTOR_H_ #define THIRD_PARTY_CEL_CPP_COMMON_TYPE_REFLECTOR_H_ +#include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/status/statusor.h" #include "absl/strings/cord.h" @@ -35,25 +36,22 @@ namespace cel { class TypeReflector : public virtual TypeIntrospector { public: // Legacy type reflector, will prefer builders for legacy value. - static TypeReflector& LegacyBuiltin(); + ABSL_DEPRECATED("Is now the same as Builtin()") + static TypeReflector& LegacyBuiltin() { return Builtin(); } // Will prefer builders for modern values. - static TypeReflector& ModernBuiltin(); + ABSL_DEPRECATED("Is now the same as Builtin()") + static TypeReflector& ModernBuiltin() { return Builtin(); } - static TypeReflector& Builtin() { - // TODO: Check if it's safe to default to modern. - // Legacy will prefer legacy container builders for faster interop with - // client extensions. - return LegacyBuiltin(); - } + static TypeReflector& Builtin(); // `NewListValueBuilder` returns a new `ListValueBuilderInterface` for the // corresponding `ListType` `type`. - virtual absl::StatusOr> - NewListValueBuilder(ValueFactory& value_factory, const ListType& type) const; + absl::StatusOr> NewListValueBuilder( + ValueFactory& value_factory, const ListType& type) const; // `NewMapValueBuilder` returns a new `MapValueBuilderInterface` for the // corresponding `MapType` `type`. - virtual absl::StatusOr> NewMapValueBuilder( + absl::StatusOr> NewMapValueBuilder( ValueFactory& value_factory, const MapType& type) const; // `NewStructValueBuilder` returns a new `StructValueBuilder` for the @@ -98,20 +96,6 @@ class TypeReflector : public virtual TypeIntrospector { Shared NewThreadCompatibleTypeReflector( MemoryManagerRef memory_manager); -namespace common_internal { - -// Implementation backing LegacyBuiltin(). -class LegacyTypeReflector : public TypeReflector { - public: - absl::StatusOr> NewListValueBuilder( - ValueFactory& value_factory, const ListType& type) const override; - - absl::StatusOr> NewMapValueBuilder( - ValueFactory& value_factory, const MapType& type) const override; -}; - -} // namespace common_internal - } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_TYPE_REFLECTOR_H_ diff --git a/runtime/BUILD b/runtime/BUILD index 1d8c3dbfc..1e10f3d4a 100644 --- a/runtime/BUILD +++ b/runtime/BUILD @@ -158,9 +158,10 @@ cc_library( hdrs = ["type_registry.h"], deps = [ "//base:data", + "//common:type", "//runtime/internal:composed_type_provider", - "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", ], ) diff --git a/runtime/function_adapter_test.cc b/runtime/function_adapter_test.cc index 62bfaf02f..a54daeb21 100644 --- a/runtime/function_adapter_test.cc +++ b/runtime/function_adapter_test.cc @@ -26,6 +26,7 @@ #include "base/function_descriptor.h" #include "common/kind.h" #include "common/memory.h" +#include "common/type_reflector.h" #include "common/value.h" #include "common/value_manager.h" #include "common/values/legacy_type_reflector.h" @@ -43,8 +44,8 @@ using ::testing::IsEmpty; class FunctionAdapterTest : public ::testing::Test { public: FunctionAdapterTest() - : type_reflector_(), - value_manager_(MemoryManagerRef::ReferenceCounting(), type_reflector_), + : value_manager_(MemoryManagerRef::ReferenceCounting(), + TypeReflector::Builtin()), test_context_(value_manager_) {} ValueManager& value_factory() { return value_manager_; } @@ -52,7 +53,6 @@ class FunctionAdapterTest : public ::testing::Test { const FunctionEvaluationContext& test_context() { return test_context_; } private: - common_internal::LegacyTypeReflector type_reflector_; common_internal::LegacyValueManager value_manager_; FunctionEvaluationContext test_context_; }; diff --git a/runtime/internal/BUILD b/runtime/internal/BUILD index 9e5078ccd..33b5c22b0 100644 --- a/runtime/internal/BUILD +++ b/runtime/internal/BUILD @@ -25,12 +25,14 @@ cc_library( hdrs = ["composed_type_provider.h"], deps = [ "//base:data", - "//common:memory", "//common:type", "//common:value", "//internal:status_macros", "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:cord", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:optional", diff --git a/runtime/internal/composed_type_provider.cc b/runtime/internal/composed_type_provider.cc index 5bb377956..65542ac04 100644 --- a/runtime/internal/composed_type_provider.cc +++ b/runtime/internal/composed_type_provider.cc @@ -14,12 +14,14 @@ #include "runtime/internal/composed_type_provider.h" #include +#include #include "absl/base/nullability.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" -#include "common/memory.h" #include "common/type.h" #include "common/type_reflector.h" #include "common/value.h" @@ -28,25 +30,13 @@ namespace cel::runtime_internal { -absl::StatusOr> -ComposedTypeProvider::NewListValueBuilder(ValueFactory& value_factory, - const ListType& type) const { - if (use_legacy_container_builders_) { - return TypeReflector::LegacyBuiltin().NewListValueBuilder(value_factory, - type); +absl::Status ComposedTypeProvider::RegisterType(const OpaqueType& type) { + auto insertion = types_.insert(std::pair{type.name(), Type(type)}); + if (!insertion.second) { + return absl::AlreadyExistsError( + absl::StrCat("type already registered: ", insertion.first->first)); } - return TypeReflector::ModernBuiltin().NewListValueBuilder(value_factory, - type); -} - -absl::StatusOr> -ComposedTypeProvider::NewMapValueBuilder(ValueFactory& value_factory, - const MapType& type) const { - if (use_legacy_container_builders_) { - return TypeReflector::LegacyBuiltin().NewMapValueBuilder(value_factory, - type); - } - return TypeReflector::ModernBuiltin().NewMapValueBuilder(value_factory, type); + return absl::OkStatus(); } absl::StatusOr> @@ -90,6 +80,9 @@ ComposedTypeProvider::DeserializeValueImpl(ValueFactory& value_factory, absl::StatusOr> ComposedTypeProvider::FindTypeImpl( absl::string_view name) const { + if (auto type = types_.find(name); type != types_.end()) { + return type->second; + } for (const std::unique_ptr& provider : providers_) { CEL_ASSIGN_OR_RETURN(auto result, provider->FindType(name)); if (result.has_value()) { diff --git a/runtime/internal/composed_type_provider.h b/runtime/internal/composed_type_provider.h index b451e27fa..8ec9ecda2 100644 --- a/runtime/internal/composed_type_provider.h +++ b/runtime/internal/composed_type_provider.h @@ -19,12 +19,13 @@ #include #include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/cord.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "base/type_provider.h" -#include "common/memory.h" #include "common/type.h" #include "common/type_reflector.h" #include "common/value.h" @@ -47,20 +48,12 @@ class ComposedTypeProvider : public TypeReflector { providers_.push_back(std::move(provider)); } + absl::Status RegisterType(const OpaqueType& type); + void set_use_legacy_container_builders(bool use_legacy_container_builders) { use_legacy_container_builders_ = use_legacy_container_builders; } - // `NewListValueBuilder` returns a new `ListValueBuilderInterface` for the - // corresponding `ListType` `type`. - absl::StatusOr> NewListValueBuilder( - ValueFactory& value_factory, const ListType& type) const override; - - // `NewMapValueBuilder` returns a new `MapValueBuilderInterface` for the - // corresponding `MapType` `type`. - absl::StatusOr> NewMapValueBuilder( - ValueFactory& value_factory, const MapType& type) const override; - absl::StatusOr> NewStructValueBuilder( ValueFactory& value_factory, const StructType& type) const override; @@ -80,6 +73,7 @@ class ComposedTypeProvider : public TypeReflector { absl::string_view type, absl::string_view name) const override; private: + absl::flat_hash_map types_; std::vector> providers_; bool use_legacy_container_builders_ = true; }; diff --git a/runtime/internal/function_adapter_test.cc b/runtime/internal/function_adapter_test.cc index 4689f6dad..2c8291fb0 100644 --- a/runtime/internal/function_adapter_test.cc +++ b/runtime/internal/function_adapter_test.cc @@ -22,6 +22,7 @@ #include "common/casting.h" #include "common/kind.h" #include "common/memory.h" +#include "common/type_reflector.h" #include "common/value.h" #include "common/values/legacy_type_reflector.h" #include "common/values/legacy_value_manager.h" @@ -73,14 +74,12 @@ static_assert(AdaptedKind() == Kind::kNullType, class ValueFactoryTestBase : public testing::Test { public: ValueFactoryTestBase() - : type_reflector_(), - value_manager_(MemoryManagerRef::ReferenceCounting(), type_reflector_) { - } + : value_manager_(MemoryManagerRef::ReferenceCounting(), + TypeReflector::Builtin()) {} ValueFactory& value_factory() { return value_manager_; } private: - common_internal::LegacyTypeReflector type_reflector_; common_internal::LegacyValueManager value_manager_; }; diff --git a/runtime/optional_types.cc b/runtime/optional_types.cc index 9aa36e77a..a26fac504 100644 --- a/runtime/optional_types.cc +++ b/runtime/optional_types.cc @@ -17,7 +17,6 @@ #include #include #include -#include #include #include #include @@ -29,7 +28,6 @@ #include "base/function_adapter.h" #include "common/casting.h" #include "common/type.h" -#include "common/type_reflector.h" #include "common/value.h" #include "common/value_manager.h" #include "internal/casts.h" @@ -279,17 +277,6 @@ absl::Status RegisterOptionalTypeFunctions(FunctionRegistry& registry, return absl::OkStatus(); } -class OptionalTypeProvider final : public TypeReflector { - protected: - absl::StatusOr> FindTypeImpl( - absl::string_view name) const override { - if (name != "optional_type") { - return absl::nullopt; - } - return OptionalType{}; - } -}; - } // namespace absl::Status EnableOptionalTypes(RuntimeBuilder& builder) { @@ -297,8 +284,7 @@ absl::Status EnableOptionalTypes(RuntimeBuilder& builder) { runtime_internal::RuntimeFriendAccess::GetMutableRuntime(builder)); CEL_RETURN_IF_ERROR(RegisterOptionalTypeFunctions( builder.function_registry(), runtime.expr_builder().options())); - builder.type_registry().AddTypeProvider( - std::make_unique()); + CEL_RETURN_IF_ERROR(builder.type_registry().RegisterType(OptionalType())); runtime.expr_builder().enable_optional_types(); return absl::OkStatus(); } diff --git a/runtime/type_registry.h b/runtime/type_registry.h index a4f3ac85b..fb47723dd 100644 --- a/runtime/type_registry.h +++ b/runtime/type_registry.h @@ -21,10 +21,11 @@ #include #include -#include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "base/type_provider.h" +#include "common/type.h" #include "runtime/internal/composed_type_provider.h" namespace cel { @@ -57,6 +58,12 @@ class TypeRegistry { impl_.AddTypeProvider(std::move(provider)); } + // Registers a type such that it can be accessed by name, i.e. `type(foo) == + // my_type`. Where `my_type` is the type being registered. + absl::Status RegisterType(const OpaqueType& type) { + return impl_.RegisterType(type); + } + // Register a custom enum type. // // This adds the enum to the set consulted at plan time to identify constant From c7264239a14ce8bea60bf4ddf503ac004c75985a Mon Sep 17 00:00:00 2001 From: Justin King Date: Thu, 31 Oct 2024 11:26:36 -0700 Subject: [PATCH 015/180] Deprecate the `CreateCelExpressionBuilder` overload that defaults to the generated descriptor pool PiperOrigin-RevId: 691875330 --- eval/public/BUILD | 1 + eval/public/cel_expr_builder_factory.h | 9 +++++++++ 2 files changed, 10 insertions(+) diff --git a/eval/public/BUILD b/eval/public/BUILD index 6142f3fa4..cbb5b1c3a 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -585,6 +585,7 @@ cc_library( "//internal:noop_delete", "//runtime:runtime_options", "//runtime/internal:runtime_env", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/status", diff --git a/eval/public/cel_expr_builder_factory.h b/eval/public/cel_expr_builder_factory.h index 0fd7f95fc..61450069f 100644 --- a/eval/public/cel_expr_builder_factory.h +++ b/eval/public/cel_expr_builder_factory.h @@ -3,6 +3,7 @@ #include +#include "absl/base/attributes.h" #include "eval/public/cel_expression.h" #include "eval/public/cel_options.h" #include "google/protobuf/descriptor.h" @@ -19,6 +20,14 @@ std::unique_ptr CreateCelExpressionBuilder( google::protobuf::MessageFactory* message_factory, const InterpreterOptions& options = InterpreterOptions()); +ABSL_DEPRECATED( + "This overload uses the generated descriptor pool, which allows " + "expressions to create any messages linked into the binary. This is not " + "hermetic and potentially dangerous, you should select the descriptor pool " + "carefully. Use the other overload and explicitly pass your descriptor " + "pool. It can still be the generated descriptor pool, but the choice " + "should be explicit. If you do not need struct creation, use " + "`cel::GetMinimalDescriptorPool()`.") inline std::unique_ptr CreateCelExpressionBuilder( const InterpreterOptions& options = InterpreterOptions()) { return CreateCelExpressionBuilder(google::protobuf::DescriptorPool::generated_pool(), From eebc4e9440fa676be2a93fccce0e8f28cfa8d804 Mon Sep 17 00:00:00 2001 From: Justin King Date: Thu, 31 Oct 2024 14:29:11 -0700 Subject: [PATCH 016/180] Remove deadcode PiperOrigin-RevId: 691936622 --- eval/public/BUILD | 1 - eval/public/cel_type_registry.h | 10 ---------- 2 files changed, 11 deletions(-) diff --git a/eval/public/BUILD b/eval/public/BUILD index cbb5b1c3a..609956706 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -810,7 +810,6 @@ cc_library( "//eval/public/structs:legacy_type_info_apis", "//eval/public/structs:legacy_type_provider", "//runtime:type_registry", - "//runtime/internal:composed_type_provider", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status:statusor", diff --git a/eval/public/cel_type_registry.h b/eval/public/cel_type_registry.h index e7a3f841b..80854f43d 100644 --- a/eval/public/cel_type_registry.h +++ b/eval/public/cel_type_registry.h @@ -24,7 +24,6 @@ #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" #include "eval/public/structs/legacy_type_provider.h" -#include "runtime/internal/composed_type_provider.h" #include "runtime/type_registry.h" namespace google::api::expr::runtime { @@ -83,15 +82,6 @@ class CelTypeRegistry { return modern_type_registry_.GetComposedTypeProvider(); } - // Register an additional type provider with the registry. - // - // A pointer to the registered provider is returned to support testing, - // but users should prefer to use the composed type provider from - // GetTypeProvider() - void RegisterModernTypeProvider(std::unique_ptr provider) { - return modern_type_registry_.AddTypeProvider(std::move(provider)); - } - // Find a type adapter given a fully qualified type name. // Adapter provides a generic interface for the reflection operations the // interpreter needs to provide. From af71fd2b14b74aa77666484308764d0704e4517b Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Fri, 1 Nov 2024 11:41:28 -0700 Subject: [PATCH 017/180] Refactor parser implementation to build non-proto version of SourceInfo then convert to proto. PiperOrigin-RevId: 692244526 --- extensions/protobuf/ast_converters.cc | 128 ++++++++++++----------- extensions/protobuf/ast_converters.h | 5 + parser/BUILD | 2 + parser/options.h | 3 + parser/parser.cc | 142 +++++++++++++++----------- parser/parser_test.cc | 18 +++- parser/source_factory.h | 6 ++ 7 files changed, 180 insertions(+), 124 deletions(-) diff --git a/extensions/protobuf/ast_converters.cc b/extensions/protobuf/ast_converters.cc index 63b893940..d88c11f82 100644 --- a/extensions/protobuf/ast_converters.cc +++ b/extensions/protobuf/ast_converters.cc @@ -355,6 +355,58 @@ absl::StatusOr ConvertProtoReferenceToNative( return ret_val; } +absl::StatusOr ConvertSourceInfoToProto( + const ast_internal::SourceInfo& source_info) { + cel::expr::SourceInfo result; + result.set_syntax_version(source_info.syntax_version()); + result.set_location(source_info.location()); + + for (int32_t line_offset : source_info.line_offsets()) { + result.add_line_offsets(line_offset); + } + + for (auto pos_iter = source_info.positions().begin(); + pos_iter != source_info.positions().end(); ++pos_iter) { + (*result.mutable_positions())[pos_iter->first] = pos_iter->second; + } + + for (auto macro_iter = source_info.macro_calls().begin(); + macro_iter != source_info.macro_calls().end(); ++macro_iter) { + ExprPb& dest_macro = (*result.mutable_macro_calls())[macro_iter->first]; + CEL_RETURN_IF_ERROR( + protobuf_internal::ExprToProto(macro_iter->second, &dest_macro)); + } + + for (const auto& extension : source_info.extensions()) { + auto* extension_pb = result.add_extensions(); + extension_pb->set_id(extension.id()); + auto* version_pb = extension_pb->mutable_version(); + version_pb->set_major(extension.version().major()); + version_pb->set_minor(extension.version().minor()); + + for (auto component : extension.affected_components()) { + switch (component) { + case Extension::Component::kParser: + extension_pb->add_affected_components(ExtensionPb::COMPONENT_PARSER); + break; + case Extension::Component::kTypeChecker: + extension_pb->add_affected_components( + ExtensionPb::COMPONENT_TYPE_CHECKER); + break; + case Extension::Component::kRuntime: + extension_pb->add_affected_components(ExtensionPb::COMPONENT_RUNTIME); + break; + default: + extension_pb->add_affected_components( + ExtensionPb::COMPONENT_UNSPECIFIED); + break; + } + } + } + + return result; +} + } // namespace internal namespace { @@ -370,7 +422,6 @@ using ::cel::ast_internal::CreateStruct; using ::cel::ast_internal::DynamicType; using ::cel::ast_internal::ErrorType; using ::cel::ast_internal::Expr; -using ::cel::ast_internal::Extension; using ::cel::ast_internal::FunctionType; using ::cel::ast_internal::Ident; using ::cel::ast_internal::ListType; @@ -386,6 +437,7 @@ using ::cel::ast_internal::SourceInfo; using ::cel::ast_internal::Type; using ::cel::ast_internal::UnspecifiedType; using ::cel::ast_internal::WellKnownType; +using ::cel::extensions::protobuf_internal::ExprToProto; using ExprPb = cel::expr::Expr; using ParsedExprPb = cel::expr::ParsedExpr; @@ -446,62 +498,6 @@ absl::Status ConstantToProto(const ast_internal::Constant& source, source.constant_kind()); } -absl::StatusOr ExprToProto(const Expr& expr) { - ExprPb proto_expr; - CEL_RETURN_IF_ERROR(protobuf_internal::ExprToProto(expr, &proto_expr)); - return proto_expr; -} - -absl::StatusOr SourceInfoToProto(const SourceInfo& source_info) { - SourceInfoPb result; - result.set_syntax_version(source_info.syntax_version()); - result.set_location(source_info.location()); - - for (int32_t line_offset : source_info.line_offsets()) { - result.add_line_offsets(line_offset); - } - - for (auto pos_iter = source_info.positions().begin(); - pos_iter != source_info.positions().end(); ++pos_iter) { - (*result.mutable_positions())[pos_iter->first] = pos_iter->second; - } - - for (auto macro_iter = source_info.macro_calls().begin(); - macro_iter != source_info.macro_calls().end(); ++macro_iter) { - ExprPb& dest_macro = (*result.mutable_macro_calls())[macro_iter->first]; - CEL_ASSIGN_OR_RETURN(dest_macro, ExprToProto(macro_iter->second)); - } - - for (const auto& extension : source_info.extensions()) { - auto* extension_pb = result.add_extensions(); - extension_pb->set_id(extension.id()); - auto* version_pb = extension_pb->mutable_version(); - version_pb->set_major(extension.version().major()); - version_pb->set_minor(extension.version().minor()); - - for (auto component : extension.affected_components()) { - switch (component) { - case Extension::Component::kParser: - extension_pb->add_affected_components(ExtensionPb::COMPONENT_PARSER); - break; - case Extension::Component::kTypeChecker: - extension_pb->add_affected_components( - ExtensionPb::COMPONENT_TYPE_CHECKER); - break; - case Extension::Component::kRuntime: - extension_pb->add_affected_components(ExtensionPb::COMPONENT_RUNTIME); - break; - default: - extension_pb->add_affected_components( - ExtensionPb::COMPONENT_UNSPECIFIED); - break; - } - } - } - - return result; -} - absl::StatusOr ReferenceToProto(const Reference& reference) { ReferencePb result; @@ -681,10 +677,11 @@ absl::StatusOr> CreateAstFromParsedExpr( absl::StatusOr CreateParsedExprFromAst(const Ast& ast) { const auto& ast_impl = ast_internal::AstImpl::CastFromPublicAst(ast); ParsedExprPb parsed_expr; - CEL_ASSIGN_OR_RETURN(*parsed_expr.mutable_expr(), - ExprToProto(ast_impl.root_expr())); - CEL_ASSIGN_OR_RETURN(*parsed_expr.mutable_source_info(), - SourceInfoToProto(ast_impl.source_info())); + CEL_RETURN_IF_ERROR( + ExprToProto(ast_impl.root_expr(), parsed_expr.mutable_expr())); + CEL_ASSIGN_OR_RETURN( + *parsed_expr.mutable_source_info(), + internal::ConvertSourceInfoToProto(ast_impl.source_info())); return parsed_expr; } @@ -728,10 +725,11 @@ absl::StatusOr CreateCheckedExprFromAst( const auto& ast_impl = ast_internal::AstImpl::CastFromPublicAst(ast); CheckedExprPb checked_expr; checked_expr.set_expr_version(ast_impl.expr_version()); - CEL_ASSIGN_OR_RETURN(*checked_expr.mutable_expr(), - ExprToProto(ast_impl.root_expr())); - CEL_ASSIGN_OR_RETURN(*checked_expr.mutable_source_info(), - SourceInfoToProto(ast_impl.source_info())); + CEL_RETURN_IF_ERROR( + ExprToProto(ast_impl.root_expr(), checked_expr.mutable_expr())); + CEL_ASSIGN_OR_RETURN( + *checked_expr.mutable_source_info(), + internal::ConvertSourceInfoToProto(ast_impl.source_info())); for (auto it = ast_impl.reference_map().begin(); it != ast_impl.reference_map().end(); ++it) { ReferencePb& dest_reference = diff --git a/extensions/protobuf/ast_converters.h b/extensions/protobuf/ast_converters.h index 79bdc8a44..e2944953c 100644 --- a/extensions/protobuf/ast_converters.h +++ b/extensions/protobuf/ast_converters.h @@ -40,6 +40,11 @@ absl::StatusOr ConvertProtoReferenceToNative( absl::StatusOr ConvertConstant( const cel::expr::Constant& constant); +// Conversion utility for the CEL source info representation to the protobuf +// representation. +absl::StatusOr ConvertSourceInfoToProto( + const ast_internal::SourceInfo& source_info); + } // namespace internal // Creates a runtime AST from a parsed-only protobuf AST. diff --git a/parser/BUILD b/parser/BUILD index f5e91a5d1..97ada22d8 100644 --- a/parser/BUILD +++ b/parser/BUILD @@ -33,11 +33,13 @@ cc_library( ":macro_registry", ":options", ":source_factory", + "//base/ast_internal:expr", "//common:ast", "//common:constant", "//common:expr_factory", "//common:operators", "//common:source", + "//extensions/protobuf:ast_converters", "//extensions/protobuf/internal:ast", "//internal:lexis", "//internal:status_macros", diff --git a/parser/options.h b/parser/options.h index 230e16e18..c879d8e37 100644 --- a/parser/options.h +++ b/parser/options.h @@ -47,6 +47,9 @@ struct ParserOptions final { // Enable support for optional syntax. bool enable_optional_syntax = false; + + // Disable standard macros (has, all, exists, exists_one, filter, map). + bool disable_standard_macros = false; }; } // namespace cel diff --git a/parser/parser.cc b/parser/parser.cc index 317663f80..44ad7c412 100644 --- a/parser/parser.cc +++ b/parser/parser.cc @@ -50,11 +50,13 @@ #include "absl/types/span.h" #include "absl/types/variant.h" #include "antlr4-runtime.h" +#include "base/ast_internal/expr.h" #include "common/ast.h" #include "common/constant.h" #include "common/expr_factory.h" #include "common/operators.h" #include "common/source.h" +#include "extensions/protobuf/ast_converters.h" #include "extensions/protobuf/internal/ast.h" #include "internal/lexis.h" #include "internal/status_macros.h" @@ -363,6 +365,13 @@ class ParserMacroExprFactory final : public MacroExprFactory { return macro_calls_; } + absl::flat_hash_map release_macro_calls() { + using std::swap; + absl::flat_hash_map result; + swap(result, macro_calls_); + return result; + } + void EraseId(ExprId id) { positions_.erase(id); if (expr_id_ == id + 1) { @@ -637,7 +646,9 @@ class ParserVisitor final : public CelBaseVisitor, std::any visitBoolTrue(CelParser::BoolTrueContext* ctx) override; std::any visitBoolFalse(CelParser::BoolFalseContext* ctx) override; std::any visitNull(CelParser::NullContext* ctx) override; - absl::Status GetSourceInfo(cel::expr::SourceInfo* source_info) const; + // Note: this is destructive and intended to be called after the parse is + // finished. + cel::ast_internal::SourceInfo GetSourceInfo(); EnrichedSourceInfo enriched_source_info() const; void syntaxError(antlr4::Recognizer* recognizer, antlr4::Token* offending_symbol, size_t line, size_t col, @@ -1344,25 +1355,20 @@ std::any ParserVisitor::visitNull(CelParser::NullContext* ctx) { factory_.NextId(SourceRangeFromParserRuleContext(ctx)))); } -absl::Status ParserVisitor::GetSourceInfo( - cel::expr::SourceInfo* source_info) const { - source_info->set_location(source_.description()); +cel::ast_internal::SourceInfo ParserVisitor::GetSourceInfo() { + cel::ast_internal::SourceInfo source_info; + source_info.set_location(std::string(source_.description())); for (const auto& positions : factory_.positions()) { - source_info->mutable_positions()->insert( + source_info.mutable_positions().insert( std::pair{positions.first, positions.second.begin}); } - source_info->mutable_line_offsets()->Reserve(source_.line_offsets().size()); + source_info.mutable_line_offsets().reserve(source_.line_offsets().size()); for (const auto& line_offset : source_.line_offsets()) { - source_info->mutable_line_offsets()->Add(line_offset); + source_info.mutable_line_offsets().push_back(line_offset); } - for (const auto& macro_call : factory_.macro_calls()) { - cel::expr::Expr macro_call_proto; - CEL_RETURN_IF_ERROR(cel::extensions::protobuf_internal::ExprToProto( - macro_call.second, ¯o_call_proto)); - source_info->mutable_macro_calls()->insert( - std::pair{macro_call.first, std::move(macro_call_proto)}); - } - return absl::OkStatus(); + + source_info.mutable_macro_calls() = factory_.release_macro_calls(); + return source_info; } EnrichedSourceInfo ParserVisitor::enriched_source_info() const { @@ -1588,41 +1594,15 @@ class RecoveryLimitErrorStrategy final : public DefaultErrorStrategy { int recovery_token_lookahead_limit_; }; -} // namespace - -absl::StatusOr Parse(absl::string_view expression, - absl::string_view description, - const ParserOptions& options) { - std::vector macros = Macro::AllMacros(); - if (options.enable_optional_syntax) { - macros.push_back(cel::OptMapMacro()); - macros.push_back(cel::OptFlatMapMacro()); - } - return ParseWithMacros(expression, macros, description, options); -} - -absl::StatusOr ParseWithMacros(absl::string_view expression, - const std::vector& macros, - absl::string_view description, - const ParserOptions& options) { - CEL_ASSIGN_OR_RETURN(auto verbose_parsed_expr, - EnrichedParse(expression, macros, description, options)); - return verbose_parsed_expr.parsed_expr(); -} - -absl::StatusOr EnrichedParse( - absl::string_view expression, const std::vector& macros, - absl::string_view description, const ParserOptions& options) { - CEL_ASSIGN_OR_RETURN(auto source, - cel::NewSource(expression, std::string(description))); - cel::MacroRegistry macro_registry; - CEL_RETURN_IF_ERROR(macro_registry.RegisterMacros(macros)); - return EnrichedParse(*source, macro_registry, options); -} +struct ParseResult { + cel::Expr expr; + cel::ast_internal::SourceInfo source_info; + EnrichedSourceInfo enriched_source_info; +}; -absl::StatusOr EnrichedParse( - const cel::Source& source, const cel::MacroRegistry& registry, - const ParserOptions& options) { +absl::StatusOr ParseImpl(const cel::Source& source, + const cel::MacroRegistry& registry, + const ParserOptions& options) { try { CodePointStream input(source.content(), source.description()); if (input.size() > options.expression_size_codepoint_limit) { @@ -1664,15 +1644,10 @@ absl::StatusOr EnrichedParse( return absl::InvalidArgumentError(visitor.ErrorMessage()); } - // root is deleted as part of the parser context - ParsedExpr parsed_expr; - CEL_RETURN_IF_ERROR(cel::extensions::protobuf_internal::ExprToProto( - expr, parsed_expr.mutable_expr())); - CEL_RETURN_IF_ERROR( - visitor.GetSourceInfo(parsed_expr.mutable_source_info())); - auto enriched_source_info = visitor.enriched_source_info(); - return VerboseParsedExpr(std::move(parsed_expr), - std::move(enriched_source_info)); + return { + ParseResult{.expr = std::move(expr), + .source_info = visitor.GetSourceInfo(), + .enriched_source_info = visitor.enriched_source_info()}}; } catch (const std::exception& e) { return absl::AbortedError(e.what()); } catch (const char* what) { @@ -1684,6 +1659,57 @@ absl::StatusOr EnrichedParse( } } +} // namespace + +absl::StatusOr Parse(absl::string_view expression, + absl::string_view description, + const ParserOptions& options) { + std::vector macros; + if (!options.disable_standard_macros) { + macros = Macro::AllMacros(); + } + if (options.enable_optional_syntax) { + macros.push_back(cel::OptMapMacro()); + macros.push_back(cel::OptFlatMapMacro()); + } + return ParseWithMacros(expression, macros, description, options); +} + +absl::StatusOr ParseWithMacros(absl::string_view expression, + const std::vector& macros, + absl::string_view description, + const ParserOptions& options) { + CEL_ASSIGN_OR_RETURN(auto verbose_parsed_expr, + EnrichedParse(expression, macros, description, options)); + return verbose_parsed_expr.parsed_expr(); +} + +absl::StatusOr EnrichedParse( + absl::string_view expression, const std::vector& macros, + absl::string_view description, const ParserOptions& options) { + CEL_ASSIGN_OR_RETURN(auto source, + cel::NewSource(expression, std::string(description))); + cel::MacroRegistry macro_registry; + CEL_RETURN_IF_ERROR(macro_registry.RegisterMacros(macros)); + return EnrichedParse(*source, macro_registry, options); +} + +absl::StatusOr EnrichedParse( + const cel::Source& source, const cel::MacroRegistry& registry, + const ParserOptions& options) { + CEL_ASSIGN_OR_RETURN(ParseResult parse_result, + ParseImpl(source, registry, options)); + ParsedExpr parsed_expr; + CEL_RETURN_IF_ERROR(cel::extensions::protobuf_internal::ExprToProto( + parse_result.expr, parsed_expr.mutable_expr())); + + CEL_ASSIGN_OR_RETURN((*parsed_expr.mutable_source_info()), + cel::extensions::internal::ConvertSourceInfoToProto( + parse_result.source_info)); + return VerboseParsedExpr(std::move(parsed_expr), + std::move(parse_result.enriched_source_info)); +} + absl::StatusOr Parse( const cel::Source& source, const cel::MacroRegistry& registry, const ParserOptions& options) { diff --git a/parser/parser_test.cc b/parser/parser_test.cc index 9ca16b0d0..ea0484b1a 100644 --- a/parser/parser_test.cc +++ b/parser/parser_test.cc @@ -15,7 +15,6 @@ #include "parser/parser.h" #include -#include #include #include #include @@ -1512,6 +1511,23 @@ TEST(ExpressionTest, RecursionDepthExceeded) { HasSubstr("Exceeded max recursion depth of 6 when parsing.")); } +TEST(ExpressionTest, DisableStandardMacros) { + ParserOptions options; + options.disable_standard_macros = true; + + auto result = Parse("has(foo.bar)", "", options); + + ASSERT_THAT(result, IsOk()); + KindAndIdAdorner kind_and_id_adorner; + ExprPrinter w(kind_and_id_adorner); + std::string adorned_string = w.PrintProto(result->expr()); + EXPECT_EQ(adorned_string, + "has(\n" + " foo^#2:Expr.Ident#.bar^#3:Expr.Select#\n" + ")^#1:Expr.Call#") + << adorned_string; +} + TEST(ExpressionTest, RecursionDepthIgnoresParentheses) { ParserOptions options; options.max_recursion_depth = 6; diff --git a/parser/source_factory.h b/parser/source_factory.h index 71a184474..501e1017a 100644 --- a/parser/source_factory.h +++ b/parser/source_factory.h @@ -27,6 +27,12 @@ class EnrichedSourceInfo { std::map> offsets) : offsets_(std::move(offsets)) {} + EnrichedSourceInfo() = default; + EnrichedSourceInfo(const EnrichedSourceInfo& other) = default; + EnrichedSourceInfo& operator=(const EnrichedSourceInfo& other) = default; + EnrichedSourceInfo(EnrichedSourceInfo&& other) = default; + EnrichedSourceInfo& operator=(EnrichedSourceInfo&& other) = default; + const std::map>& offsets() const { return offsets_; } From a43cbbf03fb6d7f1a40885ed2085e747aab41711 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Mon, 4 Nov 2024 12:05:27 -0800 Subject: [PATCH 018/180] Add ParserBuilderInterface and ParserInterface classes. Add factory for stateful parser object. PiperOrigin-RevId: 693047557 --- parser/BUILD | 19 +++++++++++ parser/parser.cc | 71 +++++++++++++++++++++++++++++++++++++++ parser/parser.h | 11 ++++++ parser/parser_interface.h | 60 +++++++++++++++++++++++++++++++++ parser/parser_test.cc | 55 ++++++++++++++++++++++++++++++ 5 files changed, 216 insertions(+) create mode 100644 parser/parser_interface.h diff --git a/parser/BUILD b/parser/BUILD index 97ada22d8..94fc70d65 100644 --- a/parser/BUILD +++ b/parser/BUILD @@ -32,7 +32,9 @@ cc_library( ":macro_expr_factory", ":macro_registry", ":options", + ":parser_interface", ":source_factory", + "//base/ast_internal:ast_impl", "//base/ast_internal:expr", "//common:ast", "//common:constant", @@ -51,6 +53,7 @@ cc_library( "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/functional:overload", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -173,12 +176,15 @@ cc_test( ":options", ":parser", ":source_factory", + "//base/ast_internal:ast_impl", "//common:constant", "//common:expr", + "//common:source", "//internal:benchmark", "//internal:testing", "//testutil:expr_printer", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -199,3 +205,16 @@ cc_library( "@com_google_absl//absl/status", ], ) + +cc_library( + name = "parser_interface", + hdrs = ["parser_interface.h"], + deps = [ + ":macro", + ":options", + "//common:ast", + "//common:source", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], +) diff --git a/parser/parser.cc b/parser/parser.cc index 44ad7c412..b051b1ca0 100644 --- a/parser/parser.cc +++ b/parser/parser.cc @@ -36,6 +36,7 @@ #include "absl/container/btree_map.h" #include "absl/container/flat_hash_map.h" #include "absl/functional/overload.h" +#include "absl/log/absl_check.h" #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -50,6 +51,7 @@ #include "absl/types/span.h" #include "absl/types/variant.h" #include "antlr4-runtime.h" +#include "base/ast_internal/ast_impl.h" #include "base/ast_internal/expr.h" #include "common/ast.h" #include "common/constant.h" @@ -69,6 +71,7 @@ #include "parser/macro_expr_factory.h" #include "parser/macro_registry.h" #include "parser/options.h" +#include "parser/parser_interface.h" #include "parser/source_factory.h" namespace google::api::expr::parser { @@ -1659,6 +1662,61 @@ absl::StatusOr ParseImpl(const cel::Source& source, } } +class ParserImpl : public cel::Parser { + public: + explicit ParserImpl(const ParserOptions& options, + cel::MacroRegistry macro_registry) + : options_(options), macro_registry_(std::move(macro_registry)) {} + absl::StatusOr> Parse( + const cel::Source& source) const override { + CEL_ASSIGN_OR_RETURN(auto parse_result, + ParseImpl(source, macro_registry_, options_)); + return std::make_unique( + std::move(parse_result.expr), std::move(parse_result.source_info)); + } + + private: + const ParserOptions options_; + const cel::MacroRegistry macro_registry_; +}; + +class ParserBuilderImpl : public cel::ParserBuilder { + public: + explicit ParserBuilderImpl(const ParserOptions& options) + : options_(options) {} + + ParserOptions& GetOptions() override { return options_; } + + absl::Status AddMacro(const cel::Macro& macro) override { + for (const auto& existing_macro : macros_) { + if (existing_macro.key() == macro.key()) { + return absl::AlreadyExistsError( + absl::StrCat("macro already exists: ", macro.key())); + } + } + macros_.push_back(macro); + return absl::OkStatus(); + } + + absl::StatusOr> Build() && override { + cel::MacroRegistry macro_registry; + + if (!options_.disable_standard_macros) { + CEL_RETURN_IF_ERROR(macro_registry.RegisterMacros(Macro::AllMacros())); + } + if (options_.enable_optional_syntax) { + CEL_RETURN_IF_ERROR(macro_registry.RegisterMacro(cel::OptMapMacro())); + CEL_RETURN_IF_ERROR(macro_registry.RegisterMacro(cel::OptFlatMapMacro())); + } + CEL_RETURN_IF_ERROR(macro_registry.RegisterMacros(macros_)); + return std::make_unique(options_, std::move(macro_registry)); + } + + private: + ParserOptions options_; + std::vector macros_; +}; + } // namespace absl::StatusOr Parse(absl::string_view expression, @@ -1719,3 +1777,16 @@ absl::StatusOr Parse( } } // namespace google::api::expr::parser + +namespace cel { + +// Creates a new parser builder. +// +// Intended for use with the Compiler class, most users should prefer the free +// functions above for independent parsing of expressions. +std::unique_ptr NewParserBuilder(const ParserOptions& options) { + return std::make_unique( + options); +} + +} // namespace cel diff --git a/parser/parser.h b/parser/parser.h index 24229bc2a..4b32c1c42 100644 --- a/parser/parser.h +++ b/parser/parser.h @@ -21,6 +21,7 @@ #ifndef THIRD_PARTY_CEL_CPP_PARSER_PARSER_H_ #define THIRD_PARTY_CEL_CPP_PARSER_PARSER_H_ +#include #include #include @@ -31,6 +32,7 @@ #include "parser/macro.h" #include "parser/macro_registry.h" #include "parser/options.h" +#include "parser/parser_interface.h" #include "parser/source_factory.h" namespace google::api::expr::parser { @@ -88,4 +90,13 @@ absl::StatusOr Parse( } // namespace google::api::expr::parser +namespace cel { +// Creates a new parser builder. +// +// Intended for use with the Compiler class, most users should prefer the free +// functions above for independent parsing of expressions. +std::unique_ptr NewParserBuilder( + const ParserOptions& options = {}); +} // namespace cel + #endif // THIRD_PARTY_CEL_CPP_PARSER_PARSER_H_ diff --git a/parser/parser_interface.h b/parser/parser_interface.h new file mode 100644 index 000000000..edbcb1fa3 --- /dev/null +++ b/parser/parser_interface.h @@ -0,0 +1,60 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#ifndef THIRD_PARTY_CEL_CPP_PARSER_PARSER_INTERFACE_H_ +#define THIRD_PARTY_CEL_CPP_PARSER_PARSER_INTERFACE_H_ + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "common/ast.h" +#include "common/source.h" +#include "parser/macro.h" +#include "parser/options.h" + +namespace cel { + +class Parser; + +// Interface for building a CEL parser, see comments on `Parser` below. +class ParserBuilder { + public: + virtual ~ParserBuilder() = default; + + // Returns the (mutable) current parser options. + virtual ParserOptions& GetOptions() = 0; + + // Adds a macro to the parser. + // Standard macros should be automatically added based on parser options. + virtual absl::Status AddMacro(const cel::Macro& macro) = 0; + + // Builds a new parser instance, may error if incompatible macros are added. + virtual absl::StatusOr> Build() && = 0; +}; + +// Interface for stateful CEL parser objects for use with a `Compiler` +// (bundled parse and type check). This is not needed for most users: +// prefer using the free functions in `parser.h` for more flexibility. +class Parser { + public: + virtual ~Parser() = default; + + // Parses the given source into a CEL AST. + virtual absl::StatusOr> Parse( + const cel::Source& source) const = 0; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_PARSER_PARSER_INTERFACE_H_ diff --git a/parser/parser_test.cc b/parser/parser_test.cc index ea0484b1a..0838fbfff 100644 --- a/parser/parser_test.cc +++ b/parser/parser_test.cc @@ -22,6 +22,7 @@ #include "cel/expr/syntax.pb.h" #include "absl/algorithm/container.h" +#include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" @@ -29,8 +30,10 @@ #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "base/ast_internal/ast_impl.h" #include "common/constant.h" #include "common/expr.h" +#include "common/source.h" #include "internal/benchmark.h" #include "internal/testing.h" #include "parser/macro.h" @@ -43,6 +46,7 @@ namespace google::api::expr::parser { namespace { using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; using ::cel::ConstantKindCase; using ::cel::ExprKindCase; using ::cel::test::ExprPrinter; @@ -1536,6 +1540,57 @@ TEST(ExpressionTest, RecursionDepthIgnoresParentheses) { EXPECT_THAT(result, IsOk()); } +TEST(NewParserBuilderTest, Defaults) { + auto builder = cel::NewParserBuilder(); + ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto source, + cel::NewSource("has(a.b) && [].exists(x, x > 0)")); + ASSERT_OK_AND_ASSIGN(auto ast, parser->Parse(*source)); + + EXPECT_FALSE(ast->IsChecked()); +} + +TEST(NewParserBuilderTest, CustomMacros) { + auto builder = cel::NewParserBuilder(); + builder->GetOptions().disable_standard_macros = true; + ASSERT_THAT(builder->AddMacro(cel::HasMacro()), IsOk()); + ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); + builder.reset(); + + ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource("has(a.b) && [].map(x, x)")); + ASSERT_OK_AND_ASSIGN(auto ast, parser->Parse(*source)); + + EXPECT_FALSE(ast->IsChecked()); + KindAndIdAdorner kind_and_id_adorner; + ExprPrinter w(kind_and_id_adorner); + const auto& ast_impl = cel::ast_internal::AstImpl::CastFromPublicAst(*ast); + EXPECT_EQ(w.Print(ast_impl.root_expr()), + "_&&_(\n" + " a^#2:Expr.Ident#.b~test-only~^#4:Expr.Select#,\n" + " []^#5:Expr.CreateList#.map(\n" + " x^#7:Expr.Ident#,\n" + " x^#8:Expr.Ident#\n" + " )^#6:Expr.Call#\n" + ")^#9:Expr.Call#"); +} + +TEST(NewParserBuilderTest, ForwardsOptions) { + auto builder = cel::NewParserBuilder(); + builder->GetOptions().enable_optional_syntax = true; + ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); + ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource("a.?b")); + ASSERT_OK_AND_ASSIGN(auto ast, parser->Parse(*source)); + EXPECT_FALSE(ast->IsChecked()); + + builder = cel::NewParserBuilder(); + builder->GetOptions().enable_optional_syntax = false; + ASSERT_OK_AND_ASSIGN(parser, std::move(*builder).Build()); + ASSERT_OK_AND_ASSIGN(source, cel::NewSource("a.?b")); + EXPECT_THAT(parser->Parse(*source), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + std::string TestName(const testing::TestParamInfo& test_info) { std::string name = absl::StrCat(test_info.index, "-", test_info.param.I); absl::c_replace_if(name, [](char c) { return !absl::ascii_isalnum(c); }, '_'); From e6aa21172bdb83bb88ed63008c3ad7b7e1f10412 Mon Sep 17 00:00:00 2001 From: Justin King Date: Mon, 4 Nov 2024 12:13:34 -0800 Subject: [PATCH 019/180] Update AST conversion to support `iter_var2` PiperOrigin-RevId: 693050279 --- common/expr.h | 23 +++++++++++++++++++ extensions/protobuf/internal/ast.cc | 2 ++ extensions/protobuf/internal/ast_test.cc | 28 ++++++++++++++++++++++++ 3 files changed, 53 insertions(+) diff --git a/common/expr.h b/common/expr.h index f6a32d4ee..18828471f 100644 --- a/common/expr.h +++ b/common/expr.h @@ -581,6 +581,27 @@ class ComprehensionExpr final { return release(iter_var_); } + ABSL_MUST_USE_RESULT const std::string& iter_var2() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return iter_var2_; + } + + void set_iter_var2(std::string iter_var2) { + iter_var2_ = std::move(iter_var2); + } + + void set_iter_var2(absl::string_view iter_var2) { + iter_var2_.assign(iter_var2.data(), iter_var2.size()); + } + + void set_iter_var2(const char* iter_var2) { + set_iter_var2(absl::NullSafeStringView(iter_var2)); + } + + ABSL_MUST_USE_RESULT std::string release_iter_var2() { + return release(iter_var2_); + } + ABSL_MUST_USE_RESULT bool has_iter_range() const { return iter_range_ != nullptr; } @@ -685,6 +706,7 @@ class ComprehensionExpr final { friend void swap(ComprehensionExpr& lhs, ComprehensionExpr& rhs) noexcept { using std::swap; swap(lhs.iter_var_, rhs.iter_var_); + swap(lhs.iter_var2_, rhs.iter_var2_); swap(lhs.iter_range_, rhs.iter_range_); swap(lhs.accu_var_, rhs.accu_var_); swap(lhs.accu_init_, rhs.accu_init_); @@ -711,6 +733,7 @@ class ComprehensionExpr final { } std::string iter_var_; + std::string iter_var2_; std::unique_ptr iter_range_; std::string accu_var_; std::unique_ptr accu_init_; diff --git a/extensions/protobuf/internal/ast.cc b/extensions/protobuf/internal/ast.cc index e6972317c..4d2d0ed13 100644 --- a/extensions/protobuf/internal/ast.cc +++ b/extensions/protobuf/internal/ast.cc @@ -227,6 +227,7 @@ class ExprToProtoState final { auto* comprehension_proto = proto->mutable_comprehension_expr(); proto->set_id(expr.id()); comprehension_proto->set_iter_var(comprehension_expr.iter_var()); + comprehension_proto->set_iter_var2(comprehension_expr.iter_var2()); if (comprehension_expr.has_iter_range()) { Push(comprehension_expr.iter_range(), comprehension_proto->mutable_iter_range()); @@ -457,6 +458,7 @@ class ExprFromProtoState final { expr.set_id(proto.id()); auto& comprehension_expr = expr.mutable_comprehension_expr(); comprehension_expr.set_iter_var(comprehension_proto.iter_var()); + comprehension_expr.set_iter_var2(comprehension_proto.iter_var2()); comprehension_expr.set_accu_var(comprehension_proto.accu_var()); if (comprehension_proto.has_iter_range()) { Push(comprehension_proto.iter_range(), diff --git a/extensions/protobuf/internal/ast_test.cc b/extensions/protobuf/internal/ast_test.cc index 243d75920..cf8d36748 100644 --- a/extensions/protobuf/internal/ast_test.cc +++ b/extensions/protobuf/internal/ast_test.cc @@ -220,6 +220,34 @@ INSTANTIATE_TEST_SUITE_P( } } )pb"}, + {R"pb( + id: 1 + comprehension_expr { + iter_var: "foo" + iter_var2: "baz" + iter_range { + id: 2 + list_expr {} + } + accu_var: "bar" + accu_init { + id: 3 + list_expr {} + } + loop_condition { + id: 4 + const_expr { bool_value: true } + } + loop_step { + id: 4 + ident_expr { name: "bar" } + } + result { + id: 5 + ident_expr { name: "foo" } + } + } + )pb"}, })); TEST(ExprFromProto, StructFieldInMap) { From 61d0334e263008403a50859afa71bb05bf519c97 Mon Sep 17 00:00:00 2001 From: Justin King Date: Mon, 4 Nov 2024 12:39:17 -0800 Subject: [PATCH 020/180] Update type checker to handle `iter_var2` PiperOrigin-RevId: 693058259 --- checker/internal/type_checker_impl.cc | 21 ++++++++++++++++----- checker/standard_library.cc | 13 ++++++++++++- 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index 6cda11e4f..9adf4f164 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -807,13 +807,17 @@ void ResolveVisitor::PostVisitComprehensionSubexpression( break; case ComprehensionArg::ITER_RANGE: { Type range_type = GetTypeOrDyn(&comprehension.iter_range()); - Type iter_type = DynType(); + Type iter_type = DynType(); // iter_var for non comprehensions v2. + Type iter_type1 = DynType(); // iter_var for comprehensions v2. + Type iter_type2 = DynType(); // iter_var2 for comprehensions v2. switch (range_type.kind()) { case TypeKind::kList: - iter_type = range_type.GetList().element(); + iter_type1 = IntType(); + iter_type = iter_type2 = range_type.GetList().element(); break; case TypeKind::kMap: - iter_type = range_type.GetMap().key(); + iter_type = iter_type1 = range_type.GetMap().key(); + iter_type2 = range_type.GetMap().value(); break; case TypeKind::kDyn: break; @@ -827,8 +831,15 @@ void ResolveVisitor::PostVisitComprehensionSubexpression( "list, map, or dynamic)"))); break; } - scope.iter_scope->InsertVariableIfAbsent( - MakeVariableDecl(comprehension.iter_var(), iter_type)); + if (comprehension.iter_var2().empty()) { + scope.iter_scope->InsertVariableIfAbsent( + MakeVariableDecl(comprehension.iter_var(), iter_type)); + } else { + scope.iter_scope->InsertVariableIfAbsent( + MakeVariableDecl(comprehension.iter_var(), iter_type1)); + scope.iter_scope->InsertVariableIfAbsent( + MakeVariableDecl(comprehension.iter_var2(), iter_type2)); + } break; } case ComprehensionArg::RESULT: diff --git a/checker/standard_library.cc b/checker/standard_library.cc index dcdda3fb8..18c43a8ad 100644 --- a/checker/standard_library.cc +++ b/checker/standard_library.cc @@ -1058,6 +1058,17 @@ absl::Status AddEnumConstants(TypeCheckerBuilder& builder) { return absl::OkStatus(); } +absl::Status AddComprehensionsV2Functions(TypeCheckerBuilder& builder) { + FunctionDecl map_insert; + map_insert.set_name("@cel.mapInsert"); + CEL_RETURN_IF_ERROR(map_insert.AddOverload( + MakeOverloadDecl("@mapInsert_map_key_value", MapOfAB(), MapOfAB(), + TypeParamA(), TypeParamB()))); + CEL_RETURN_IF_ERROR(map_insert.AddOverload( + MakeOverloadDecl("@mapInsert_map_map", MapOfAB(), MapOfAB(), MapOfAB()))); + return builder.AddFunction(map_insert); +} + absl::Status AddStandardLibraryDecls(TypeCheckerBuilder& builder) { CEL_RETURN_IF_ERROR(AddLogicalOps(builder)); CEL_RETURN_IF_ERROR(AddArithmeticOps(builder)); @@ -1070,7 +1081,7 @@ absl::Status AddStandardLibraryDecls(TypeCheckerBuilder& builder) { CEL_RETURN_IF_ERROR(AddTimeFunctions(builder)); CEL_RETURN_IF_ERROR(AddTypeConstantVariables(builder)); CEL_RETURN_IF_ERROR(AddEnumConstants(builder)); - + CEL_RETURN_IF_ERROR(AddComprehensionsV2Functions(builder)); return absl::OkStatus(); } From 834c7fd3e94be4ed57964cda62986e8302f449f9 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Mon, 4 Nov 2024 16:20:17 -0800 Subject: [PATCH 021/180] Rename StandardLibrary() to StandardCheckerLibrary(). PiperOrigin-RevId: 693126582 --- checker/optional_test.cc | 6 +++--- checker/standard_library.cc | 4 +++- checker/standard_library.h | 2 +- checker/standard_library_test.cc | 12 ++++++------ conformance/service.cc | 2 +- 5 files changed, 14 insertions(+), 12 deletions(-) diff --git a/checker/optional_test.cc b/checker/optional_test.cc index 841597061..ae4383883 100644 --- a/checker/optional_test.cc +++ b/checker/optional_test.cc @@ -79,7 +79,7 @@ TEST(OptionalTest, OptSelectDoesNotAnnotateFieldType) { ASSERT_OK_AND_ASSIGN( TypeCheckerBuilder builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); - ASSERT_THAT(builder.AddLibrary(StandardLibrary()), IsOk()); + ASSERT_THAT(builder.AddLibrary(StandardCheckerLibrary()), IsOk()); ASSERT_THAT(builder.AddLibrary(OptionalCheckerLibrary()), IsOk()); builder.set_container("cel.expr.conformance.proto3"); ASSERT_OK_AND_ASSIGN(std::unique_ptr checker, @@ -116,7 +116,7 @@ TEST_P(OptionalTest, Runner) { TypeCheckerBuilder builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); const TestCase& test_case = GetParam(); - ASSERT_THAT(builder.AddLibrary(StandardLibrary()), IsOk()); + ASSERT_THAT(builder.AddLibrary(StandardCheckerLibrary()), IsOk()); ASSERT_THAT(builder.AddLibrary(OptionalCheckerLibrary()), IsOk()); ASSERT_OK_AND_ASSIGN(std::unique_ptr checker, std::move(builder).Build()); @@ -275,7 +275,7 @@ TEST_P(OptionalStrictNullAssignmentTest, Runner) { TypeCheckerBuilder builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool(), options)); const TestCase& test_case = GetParam(); - ASSERT_THAT(builder.AddLibrary(StandardLibrary()), IsOk()); + ASSERT_THAT(builder.AddLibrary(StandardCheckerLibrary()), IsOk()); ASSERT_THAT(builder.AddLibrary(OptionalCheckerLibrary()), IsOk()); ASSERT_OK_AND_ASSIGN(std::unique_ptr checker, std::move(builder).Build()); diff --git a/checker/standard_library.cc b/checker/standard_library.cc index 18c43a8ad..3cc246482 100644 --- a/checker/standard_library.cc +++ b/checker/standard_library.cc @@ -1088,5 +1088,7 @@ absl::Status AddStandardLibraryDecls(TypeCheckerBuilder& builder) { } // namespace // Returns a CheckerLibrary containing all of the standard CEL declarations. -CheckerLibrary StandardLibrary() { return {"stdlib", AddStandardLibraryDecls}; } +CheckerLibrary StandardCheckerLibrary() { + return {"stdlib", AddStandardLibraryDecls}; +} } // namespace cel diff --git a/checker/standard_library.h b/checker/standard_library.h index e42fb0a24..05f6d5bb7 100644 --- a/checker/standard_library.h +++ b/checker/standard_library.h @@ -19,7 +19,7 @@ namespace cel { // Returns a CheckerLibrary containing all of the standard CEL declarations. -CheckerLibrary StandardLibrary(); +CheckerLibrary StandardCheckerLibrary(); } // namespace cel diff --git a/checker/standard_library_test.cc b/checker/standard_library_test.cc index 1968c0294..ca51b8aaa 100644 --- a/checker/standard_library_test.cc +++ b/checker/standard_library_test.cc @@ -52,7 +52,7 @@ TEST(StandardLibraryTest, StandardLibraryAddsDecls) { ASSERT_OK_AND_ASSIGN( TypeCheckerBuilder builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); - EXPECT_THAT(builder.AddLibrary(StandardLibrary()), IsOk()); + EXPECT_THAT(builder.AddLibrary(StandardCheckerLibrary()), IsOk()); EXPECT_THAT(std::move(builder).Build(), IsOk()); } @@ -60,8 +60,8 @@ TEST(StandardLibraryTest, StandardLibraryErrorsIfAddedTwice) { ASSERT_OK_AND_ASSIGN( TypeCheckerBuilder builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); - EXPECT_THAT(builder.AddLibrary(StandardLibrary()), IsOk()); - EXPECT_THAT(builder.AddLibrary(StandardLibrary()), + EXPECT_THAT(builder.AddLibrary(StandardCheckerLibrary()), IsOk()); + EXPECT_THAT(builder.AddLibrary(StandardCheckerLibrary()), StatusIs(absl::StatusCode::kAlreadyExists)); } @@ -70,7 +70,7 @@ TEST(StandardLibraryTest, ComprehensionVarsIndirectCyclicParamAssignability) { ASSERT_OK_AND_ASSIGN( TypeCheckerBuilder builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); - ASSERT_THAT(builder.AddLibrary(StandardLibrary()), IsOk()); + ASSERT_THAT(builder.AddLibrary(StandardCheckerLibrary()), IsOk()); // Note: this is atypical -- parameterized variables aren't well supported // outside of built-in syntax. @@ -110,7 +110,7 @@ class StandardLibraryDefinitionsTest : public ::testing::Test { ASSERT_OK_AND_ASSIGN( TypeCheckerBuilder builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); - ASSERT_THAT(builder.AddLibrary(StandardLibrary()), IsOk()); + ASSERT_THAT(builder.AddLibrary(StandardCheckerLibrary()), IsOk()); ASSERT_OK_AND_ASSIGN(stdlib_type_checker_, std::move(builder).Build()); } @@ -215,7 +215,7 @@ TEST_P(StdLibDefinitionsTest, Runner) { TypeCheckerBuilder builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool(), GetParam().options)); - ASSERT_THAT(builder.AddLibrary(StandardLibrary()), IsOk()); + ASSERT_THAT(builder.AddLibrary(StandardCheckerLibrary()), IsOk()); ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, std::move(builder).Build()); diff --git a/conformance/service.cc b/conformance/service.cc index a0593971d..d2a8bffed 100644 --- a/conformance/service.cc +++ b/conformance/service.cc @@ -638,7 +638,7 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { google::protobuf::DescriptorPool::generated_pool())); if (!request.no_std_env()) { - CEL_RETURN_IF_ERROR(builder.AddLibrary(cel::StandardLibrary())); + CEL_RETURN_IF_ERROR(builder.AddLibrary(cel::StandardCheckerLibrary())); CEL_RETURN_IF_ERROR(builder.AddLibrary(cel::OptionalCheckerLibrary())); } From e8fdff4f3f3f1ada0637928ed911c0798904d6a9 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 5 Nov 2024 10:44:24 -0800 Subject: [PATCH 022/180] Update assignability checks for lists and maps to consider all elements before accepting new inferred types. PiperOrigin-RevId: 693403660 --- checker/internal/BUILD | 1 - checker/internal/type_checker_impl.cc | 21 +++++- checker/internal/type_checker_impl_test.cc | 29 +++++++ checker/internal/type_inference_context.cc | 34 ++++++--- checker/internal/type_inference_context.h | 62 +++++++++++++-- .../internal/type_inference_context_test.cc | 75 +++++++++++++++++++ 6 files changed, 205 insertions(+), 17 deletions(-) diff --git a/checker/internal/BUILD b/checker/internal/BUILD index 6d43a83b3..2fbbf47d2 100644 --- a/checker/internal/BUILD +++ b/checker/internal/BUILD @@ -162,7 +162,6 @@ cc_test( "//common:expr", "//common:source", "//common:type", - "//extensions/protobuf:value", "//internal:status_macros", "//internal:testing", "//internal:testing_descriptor_pool", diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index 9adf4f164..027c3a87b 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -576,6 +576,7 @@ void ResolveVisitor::PostVisitMap(const Expr& expr, const MapExpr& map) { Type overall_value_type = inference_context_->InstantiateTypeParams(TypeParamType("V")); + auto assignability_context = inference_context_->CreateAssignabilityContext(); for (const auto& entry : map.entries()) { const Expr* key = &entry.key(); Type key_type = GetTypeOrDyn(key); @@ -593,10 +594,17 @@ void ResolveVisitor::PostVisitMap(const Expr& expr, const MapExpr& map) { inference_context_->FinalizeType(key_type).DebugString()))); } - if (!inference_context_->IsAssignable(key_type, overall_key_type)) { + if (!assignability_context.IsAssignable(key_type, overall_key_type)) { overall_key_type = DynType(); } + } + + if (!overall_key_type.IsDyn()) { + assignability_context.UpdateInferredTypeAssignments(); + } + assignability_context.Reset(); + for (const auto& entry : map.entries()) { const Expr* value = &entry.value(); Type value_type = GetTypeOrDyn(value); if (entry.optional()) { @@ -613,6 +621,10 @@ void ResolveVisitor::PostVisitMap(const Expr& expr, const MapExpr& map) { } } + if (!overall_value_type.IsDyn()) { + assignability_context.UpdateInferredTypeAssignments(); + } + types_[&expr] = inference_context_->FullySubstitute( MapType(arena_, overall_key_type, overall_value_type)); } @@ -622,6 +634,7 @@ void ResolveVisitor::PostVisitList(const Expr& expr, const ListExpr& list) { Type overall_elem_type = inference_context_->InstantiateTypeParams(TypeParamType("E")); + auto assignability_context = inference_context_->CreateAssignabilityContext(); for (const auto& element : list.elements()) { const Expr* value = &element.expr(); Type value_type = GetTypeOrDyn(value); @@ -635,11 +648,15 @@ void ResolveVisitor::PostVisitList(const Expr& expr, const ListExpr& list) { } } - if (!inference_context_->IsAssignable(value_type, overall_elem_type)) { + if (!assignability_context.IsAssignable(value_type, overall_elem_type)) { overall_elem_type = DynType(); } } + if (!overall_elem_type.IsDyn()) { + assignability_context.UpdateInferredTypeAssignments(); + } + types_[&expr] = inference_context_->FullySubstitute(ListType(arena_, overall_elem_type)); } diff --git a/checker/internal/type_checker_impl_test.cc b/checker/internal/type_checker_impl_test.cc index e0ff26ff8..d4eb2c1a3 100644 --- a/checker/internal/type_checker_impl_test.cc +++ b/checker/internal/type_checker_impl_test.cc @@ -2017,6 +2017,35 @@ INSTANTIATE_TEST_SUITE_P( .expected_result_type = AstType(ast_internal::DynamicType()), })); +INSTANTIATE_TEST_SUITE_P( + TypeInferences, GenericMessagesTest, + ::testing::Values( + CheckedExprTestCase{ + .expr = "[1, test_msg.single_int64_wrapper]", + .expected_result_type = AstType(ast_internal::ListType( + std::make_unique(ast_internal::PrimitiveTypeWrapper( + ast_internal::PrimitiveType::kInt64))))}, + CheckedExprTestCase{ + .expr = "[1, 2, test_msg.single_int64_wrapper]", + .expected_result_type = AstType(ast_internal::ListType( + std::make_unique(ast_internal::PrimitiveTypeWrapper( + ast_internal::PrimitiveType::kInt64))))}, + CheckedExprTestCase{ + .expr = "[test_msg.single_int64_wrapper, 1]", + .expected_result_type = AstType(ast_internal::ListType( + std::make_unique(ast_internal::PrimitiveTypeWrapper( + ast_internal::PrimitiveType::kInt64))))}, + CheckedExprTestCase{ + .expr = "[1, 2, test_msg.single_int64_wrapper, dyn(1)]", + .expected_result_type = AstType(ast_internal::ListType( + std::make_unique(ast_internal::DynamicType())))}, + + CheckedExprTestCase{ + .expr = "test_msg.single_int64", + .expected_result_type = + AstType(ast_internal::PrimitiveType::kInt64), + })); + class StrictNullAssignmentTest : public testing::TestWithParam {}; diff --git a/checker/internal/type_inference_context.cc b/checker/internal/type_inference_context.cc index 2a508038a..4c4900058 100644 --- a/checker/internal/type_inference_context.cc +++ b/checker/internal/type_inference_context.cc @@ -261,15 +261,6 @@ bool TypeInferenceContext::IsAssignableInternal( prospective_substitutions); } - // Type is as concrete as it can be under current substitutions. - if (absl::optional wrapped_type = WrapperToPrimitive(to_subs); - wrapped_type.has_value()) { - return IsAssignableInternal(NullType(), from_subs, - prospective_substitutions) || - IsAssignableInternal(*wrapped_type, from_subs, - prospective_substitutions); - } - // Maybe widen a prospective type binding if it is a member of a union type. // This enables things like `true ? 1 : single_int64_wrapper` to promote // the left hand side of the ternary to an int wrapper. @@ -287,6 +278,15 @@ bool TypeInferenceContext::IsAssignableInternal( return true; } + // Type is as concrete as it can be under current substitutions. + if (absl::optional wrapped_type = WrapperToPrimitive(to_subs); + wrapped_type.has_value()) { + return IsAssignableInternal(NullType(), from_subs, + prospective_substitutions) || + IsAssignableInternal(*wrapped_type, from_subs, + prospective_substitutions); + } + // Wrapper types are assignable to their corresponding primitive type ( // somewhat similar to auto unboxing). This is a bit odd with CEL's null_type, // but there isn't a dedicated syntax for narrowing from the nullable. @@ -538,4 +538,20 @@ Type TypeInferenceContext::FullySubstitute(const Type& type, } } +bool TypeInferenceContext::AssignabilityContext::IsAssignable(const Type& from, + const Type& to) { + return inference_context_.IsAssignableInternal(from, to, + prospective_substitutions_); +} + +void TypeInferenceContext::AssignabilityContext:: + UpdateInferredTypeAssignments() { + inference_context_.UpdateTypeParameterBindings( + std::move(prospective_substitutions_)); +} + +void TypeInferenceContext::AssignabilityContext::Reset() { + prospective_substitutions_.clear(); +} + } // namespace cel::checker_internal diff --git a/checker/internal/type_inference_context.h b/checker/internal/type_inference_context.h index c4e15188c..3b1939d2b 100644 --- a/checker/internal/type_inference_context.h +++ b/checker/internal/type_inference_context.h @@ -50,11 +50,68 @@ class TypeInferenceContext { std::vector overloads; }; + private: + // Alias for a map from type var name to the type it is bound to. + // + // Used for prospective substitutions during type inference to make progress + // without affecting final assigned types. + using SubstitutionMap = absl::flat_hash_map; + + public: + // Helper class for managing several dependent type assignability checks. + // + // Note: while allowed, updating multiple AssignabilityContexts concurrently + // can lead to inconsistencies in the final type bindings. + class AssignabilityContext { + public: + // Checks if `from` is assignable to `to` with the current type + // substitutions and any additional prospective substitutions in the parent + // inference context. + bool IsAssignable(const Type& from, const Type& to); + + // Applies any prospective type assignments to the parent inference context. + // + // This should only be called after all assignability checks have completed. + // + // Leaves the AssignabilityContext in the starting state (i.e. no + // prospective substitutions). + void UpdateInferredTypeAssignments(); + + // Return the AssignabilityContext to the starting state (i.e. no + // prospective substitutions). + void Reset(); + + private: + explicit AssignabilityContext(TypeInferenceContext& inference_context) + : inference_context_(inference_context) {} + + AssignabilityContext(const AssignabilityContext&) = delete; + AssignabilityContext& operator=(const AssignabilityContext&) = delete; + AssignabilityContext(AssignabilityContext&&) = delete; + AssignabilityContext& operator=(AssignabilityContext&&) = delete; + + friend class TypeInferenceContext; + + TypeInferenceContext& inference_context_; + SubstitutionMap prospective_substitutions_; + }; + explicit TypeInferenceContext(google::protobuf::Arena* arena, bool enable_legacy_null_assignment = true) : arena_(arena), enable_legacy_null_assignment_(enable_legacy_null_assignment) {} + // Creates a new AssignabilityContext for the current inference context. + // + // This is intended for managing several dependent type assignability checks + // that should only be added to the final type bindings if all checks succeed. + // + // Note: while allowed, updating multiple AssignabilityContexts concurrently + // can lead to inconsistencies in the final type bindings. + AssignabilityContext CreateAssignabilityContext() + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return AssignabilityContext(*this); + } // Resolves any remaining type parameters in the given type to a concrete // type or dyn. Type FinalizeType(const Type& type) const { @@ -98,11 +155,6 @@ class TypeInferenceContext { } private: - // Alias for a map from type var name to the type it is bound to. - // - // Used for prospective substitutions during type inference. - using SubstitutionMap = absl::flat_hash_map; - struct TypeVar { absl::optional type; absl::string_view name; diff --git a/checker/internal/type_inference_context_test.cc b/checker/internal/type_inference_context_test.cc index faef3879a..bc9513574 100644 --- a/checker/internal/type_inference_context_test.cc +++ b/checker/internal/type_inference_context_test.cc @@ -711,5 +711,80 @@ TEST(TypeInferenceContextTest, ResolveOverloadWithInferredTypeType) { ElementsAre(IsTypeKind(TypeKind::kInt))); } +TEST(TypeInferenceContextTest, AssignabilityContext) { + google::protobuf::Arena arena; + TypeInferenceContext context(&arena); + + Type list_of_a = ListType(&arena, TypeParamType("A")); + + Type list_of_a_instance = context.InstantiateTypeParams(list_of_a); + + { + auto assignability_context = context.CreateAssignabilityContext(); + EXPECT_TRUE(assignability_context.IsAssignable( + IntType(), list_of_a_instance.AsList()->GetElement())); + EXPECT_TRUE(assignability_context.IsAssignable( + IntType(), list_of_a_instance.AsList()->GetElement())); + EXPECT_TRUE(assignability_context.IsAssignable( + IntWrapperType(), list_of_a_instance.AsList()->GetElement())); + + assignability_context.UpdateInferredTypeAssignments(); + } + Type resolved_type = context.FinalizeType(list_of_a_instance); + + ASSERT_THAT(resolved_type, IsTypeKind(TypeKind::kList)); + EXPECT_THAT(resolved_type.AsList()->GetElement(), + IsTypeKind(TypeKind::kIntWrapper)); +} + +TEST(TypeInferenceContextTest, AssignabilityContextNotApplied) { + google::protobuf::Arena arena; + TypeInferenceContext context(&arena); + + Type list_of_a = ListType(&arena, TypeParamType("A")); + + Type list_of_a_instance = context.InstantiateTypeParams(list_of_a); + + { + auto assignability_context = context.CreateAssignabilityContext(); + EXPECT_TRUE(assignability_context.IsAssignable( + IntType(), list_of_a_instance.AsList()->GetElement())); + EXPECT_TRUE(assignability_context.IsAssignable( + IntType(), list_of_a_instance.AsList()->GetElement())); + EXPECT_TRUE(assignability_context.IsAssignable( + IntWrapperType(), list_of_a_instance.AsList()->GetElement())); + } + + Type resolved_type = context.FinalizeType(list_of_a_instance); + + ASSERT_THAT(resolved_type, IsTypeKind(TypeKind::kList)); + EXPECT_THAT(resolved_type.AsList()->GetElement(), IsTypeKind(TypeKind::kDyn)); +} + +TEST(TypeInferenceContextTest, AssignabilityContextReset) { + google::protobuf::Arena arena; + TypeInferenceContext context(&arena); + + Type list_of_a = ListType(&arena, TypeParamType("A")); + + Type list_of_a_instance = context.InstantiateTypeParams(list_of_a); + + { + auto assignability_context = context.CreateAssignabilityContext(); + EXPECT_TRUE(assignability_context.IsAssignable( + IntType(), list_of_a_instance.AsList()->GetElement())); + assignability_context.Reset(); + EXPECT_TRUE(assignability_context.IsAssignable( + DoubleType(), list_of_a_instance.AsList()->GetElement())); + assignability_context.UpdateInferredTypeAssignments(); + } + + Type resolved_type = context.FinalizeType(list_of_a_instance); + + ASSERT_THAT(resolved_type, IsTypeKind(TypeKind::kList)); + EXPECT_THAT(resolved_type.AsList()->GetElement(), + IsTypeKind(TypeKind::kDouble)); +} + } // namespace } // namespace cel::checker_internal From 9310c4910e598362695930f0e11b7f278f714755 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 5 Nov 2024 11:05:34 -0800 Subject: [PATCH 023/180] Update type assignment widening behavior to more closely follow the 'MoreGeneral' check in the Go and Java implementations. PiperOrigin-RevId: 693412150 --- checker/internal/BUILD | 1 + checker/internal/type_checker_impl_test.cc | 64 ++++++++++- checker/internal/type_inference_context.cc | 103 +++++++++++++++--- checker/internal/type_inference_context.h | 19 ++++ .../internal/type_inference_context_test.cc | 60 ++++++++++ checker/optional_test.cc | 12 ++ 6 files changed, 243 insertions(+), 16 deletions(-) diff --git a/checker/internal/BUILD b/checker/internal/BUILD index 2fbbf47d2..68ea74f4f 100644 --- a/checker/internal/BUILD +++ b/checker/internal/BUILD @@ -165,6 +165,7 @@ cc_test( "//internal:status_macros", "//internal:testing", "//internal:testing_descriptor_pool", + "//testutil:baseline_tests", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_set", diff --git a/checker/internal/type_checker_impl_test.cc b/checker/internal/type_checker_impl_test.cc index d4eb2c1a3..d64e22cc3 100644 --- a/checker/internal/type_checker_impl_test.cc +++ b/checker/internal/type_checker_impl_test.cc @@ -45,6 +45,7 @@ #include "internal/status_macros.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" +#include "testutil/baseline_tests.h" #include "cel/expr/conformance/proto2/test_all_types.pb.h" #include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/arena.h" @@ -221,11 +222,18 @@ absl::Status RegisterMinimalBuiltins(absl::Nonnull are FunctionDecl ternary_op; ternary_op.set_name("_?_:_"); - CEL_RETURN_IF_ERROR(eq_op.AddOverload(MakeOverloadDecl( + CEL_RETURN_IF_ERROR(ternary_op.AddOverload(MakeOverloadDecl( "conditional", /*return_type=*/ TypeParamType("A"), BoolType{}, TypeParamType("A"), TypeParamType("A")))); + FunctionDecl index_op; + index_op.set_name("_[_]"); + CEL_RETURN_IF_ERROR(index_op.AddOverload(MakeOverloadDecl( + "index", + /*return_type=*/ + TypeParamType("A"), ListType(arena, TypeParamType("A")), IntType()))); + FunctionDecl to_int; to_int.set_name("int"); CEL_RETURN_IF_ERROR(to_int.AddOverload( @@ -268,6 +276,7 @@ absl::Status RegisterMinimalBuiltins(absl::Nonnull are env.InsertFunctionIfAbsent(std::move(to_int)); env.InsertFunctionIfAbsent(std::move(eq_op)); env.InsertFunctionIfAbsent(std::move(ternary_op)); + env.InsertFunctionIfAbsent(std::move(index_op)); env.InsertFunctionIfAbsent(std::move(to_dyn)); env.InsertFunctionIfAbsent(std::move(to_type)); env.InsertFunctionIfAbsent(std::move(to_duration)); @@ -1543,7 +1552,8 @@ TEST_P(GenericMessagesTest, TypeChecksProto3) { const auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast); EXPECT_THAT(ast_impl.type_map(), Contains(Pair(ast_impl.root_expr().id(), - Eq(test_case.expected_result_type)))); + Eq(test_case.expected_result_type)))) + << cel::test::FormatBaselineAst(*checked_ast); } INSTANTIATE_TEST_SUITE_P( @@ -2039,11 +2049,59 @@ INSTANTIATE_TEST_SUITE_P( .expr = "[1, 2, test_msg.single_int64_wrapper, dyn(1)]", .expected_result_type = AstType(ast_internal::ListType( std::make_unique(ast_internal::DynamicType())))}, - + CheckedExprTestCase{ + .expr = "[null, test_msg][0]", + .expected_result_type = AstType(ast_internal::MessageType( + "cel.expr.conformance.proto3.TestAllTypes"))}, + CheckedExprTestCase{ + .expr = "[{'k': dyn(1)}, {dyn('k'): 1}][0]", + // Ambiguous type resolution, but we prefer the first option. + .expected_result_type = AstType(ast_internal::MapType( + std::make_unique(ast_internal::PrimitiveType::kString), + std::make_unique(ast_internal::DynamicType())))}, + CheckedExprTestCase{ + .expr = "[{'k': 1}, {dyn('k'): 1}][0]", + .expected_result_type = AstType(ast_internal::MapType( + std::make_unique(ast_internal::DynamicType()), + std::make_unique( + ast_internal::PrimitiveType::kInt64)))}, + CheckedExprTestCase{ + .expr = "[{dyn('k'): 1}, {'k': 1}][0]", + .expected_result_type = AstType(ast_internal::MapType( + std::make_unique(ast_internal::DynamicType()), + std::make_unique( + ast_internal::PrimitiveType::kInt64)))}, + CheckedExprTestCase{ + .expr = "[{'k': 1}, {'k': dyn(1)}][0]", + .expected_result_type = AstType(ast_internal::MapType( + std::make_unique(ast_internal::PrimitiveType::kString), + std::make_unique(ast_internal::DynamicType())))}, + CheckedExprTestCase{ + .expr = "[{'k': 1}, {dyn('k'): dyn(1)}][0]", + .expected_result_type = AstType(ast_internal::MapType( + std::make_unique(ast_internal::DynamicType()), + std::make_unique(ast_internal::DynamicType())))}, + CheckedExprTestCase{ + .expr = + "[{'k': 1.0}, {dyn('k'): test_msg.single_int64_wrapper}][0]", + .expected_result_type = AstType(ast_internal::DynamicType())}, CheckedExprTestCase{ .expr = "test_msg.single_int64", .expected_result_type = AstType(ast_internal::PrimitiveType::kInt64), + }, + CheckedExprTestCase{ + .expr = "[[1], {1: 2u}][0]", + .expected_result_type = AstType(ast_internal::DynamicType()), + }, + CheckedExprTestCase{ + .expr = "[{1: 2u}, [1]][0]", + .expected_result_type = AstType(ast_internal::DynamicType()), + }, + CheckedExprTestCase{ + .expr = "[test_msg.single_int64_wrapper," + " test_msg.single_string_wrapper][0]", + .expected_result_type = AstType(ast_internal::DynamicType()), })); class StrictNullAssignmentTest diff --git a/checker/internal/type_inference_context.cc b/checker/internal/type_inference_context.cc index 4c4900058..19d59daec 100644 --- a/checker/internal/type_inference_context.cc +++ b/checker/internal/type_inference_context.cc @@ -261,28 +261,30 @@ bool TypeInferenceContext::IsAssignableInternal( prospective_substitutions); } - // Maybe widen a prospective type binding if it is a member of a union type. - // This enables things like `true ? 1 : single_int64_wrapper` to promote - // the left hand side of the ternary to an int wrapper. - // This is a bit restricted to encourage more specific type -> type var - // assignments. + // Maybe widen a prospective type binding if another potential binding is + // more general and admits the previous binding. if ( // Checking assignability to a specific type var // that has a prospective type assignment. to.kind() == TypeKind::kTypeParam && - prospective_substitutions.contains(to.AsTypeParam()->name()) && - // from is a more general type that to and accepts the current - // prospective binding for to. - IsUnionType(from_subs) && IsSubsetOf(to_subs, from_subs)) { - prospective_substitutions[to.AsTypeParam()->name()] = from_subs; - return true; + prospective_substitutions.contains(to.AsTypeParam()->name())) { + auto prospective_subs_cpy(prospective_substitutions); + if (CompareGenerality(from_subs, to_subs, prospective_subs_cpy) == + RelativeGenerality::kMoreGeneral) { + if (IsAssignableInternal(to_subs, from_subs, prospective_subs_cpy) && + !OccursWithin(to.name(), from_subs, prospective_subs_cpy)) { + prospective_subs_cpy[to.AsTypeParam()->name()] = from_subs; + prospective_substitutions = prospective_subs_cpy; + return true; + // otherwise, continue with normal assignability check. + } + } } // Type is as concrete as it can be under current substitutions. if (absl::optional wrapped_type = WrapperToPrimitive(to_subs); wrapped_type.has_value()) { - return IsAssignableInternal(NullType(), from_subs, - prospective_substitutions) || + return from_subs.IsNull() || IsAssignableInternal(*wrapped_type, from_subs, prospective_substitutions); } @@ -364,6 +366,81 @@ Type TypeInferenceContext::Substitute( return subs; } +TypeInferenceContext::RelativeGenerality +TypeInferenceContext::CompareGenerality( + const Type& from, const Type& to, + const SubstitutionMap& prospective_substitutions) const { + Type from_subs = Substitute(from, prospective_substitutions); + Type to_subs = Substitute(to, prospective_substitutions); + + if (from_subs == to_subs) { + return RelativeGenerality::kEquivalent; + } + + if (IsUnionType(from_subs) && IsSubsetOf(to_subs, from_subs)) { + return RelativeGenerality::kMoreGeneral; + } + + if (IsUnionType(to_subs)) { + return RelativeGenerality::kLessGeneral; + } + + if (enable_legacy_null_assignment_ && IsLegacyNullable(from_subs) && + to_subs.IsNull()) { + return RelativeGenerality::kMoreGeneral; + } + + // Not a polytype. Check if it is a parameterized type and all parameters are + // equivalent and at least one is more general. + if (from_subs.IsList() && to_subs.IsList()) { + return CompareGenerality(from_subs.AsList()->GetElement(), + to_subs.AsList()->GetElement(), + prospective_substitutions); + } + + if (from_subs.IsMap() && to_subs.IsMap()) { + RelativeGenerality key_generality = + CompareGenerality(from_subs.AsMap()->GetKey(), + to_subs.AsMap()->GetKey(), prospective_substitutions); + RelativeGenerality value_generality = CompareGenerality( + from_subs.AsMap()->GetValue(), to_subs.AsMap()->GetValue(), + prospective_substitutions); + if (key_generality == RelativeGenerality::kLessGeneral || + value_generality == RelativeGenerality::kLessGeneral) { + return RelativeGenerality::kLessGeneral; + } + if (key_generality == RelativeGenerality::kMoreGeneral || + value_generality == RelativeGenerality::kMoreGeneral) { + return RelativeGenerality::kMoreGeneral; + } + return RelativeGenerality::kEquivalent; + } + + if (from_subs.IsOpaque() && to_subs.IsOpaque() && + from_subs.AsOpaque()->name() == to_subs.AsOpaque()->name() && + from_subs.AsOpaque()->GetParameters().size() == + to_subs.AsOpaque()->GetParameters().size()) { + RelativeGenerality max_generality = RelativeGenerality::kEquivalent; + for (int i = 0; i < from_subs.AsOpaque()->GetParameters().size(); ++i) { + RelativeGenerality generality = CompareGenerality( + from_subs.AsOpaque()->GetParameters()[i], + to_subs.AsOpaque()->GetParameters()[i], prospective_substitutions); + if (generality == RelativeGenerality::kLessGeneral) { + return RelativeGenerality::kLessGeneral; + } + if (generality == RelativeGenerality::kMoreGeneral) { + max_generality = RelativeGenerality::kMoreGeneral; + } + } + return max_generality; + } + + // Default not comparable. Since we ruled out polytypes, they should be + // equivalent for the purposes of deciding the most general eligible + // substitution. + return RelativeGenerality::kEquivalent; +} + bool TypeInferenceContext::OccursWithin( absl::string_view var_name, const Type& type, const SubstitutionMap& substitutions) const { diff --git a/checker/internal/type_inference_context.h b/checker/internal/type_inference_context.h index 3b1939d2b..898af657f 100644 --- a/checker/internal/type_inference_context.h +++ b/checker/internal/type_inference_context.h @@ -160,6 +160,15 @@ class TypeInferenceContext { absl::string_view name; }; + // Relative generality between two types. + enum class RelativeGenerality { + kMoreGeneral, + // Note: kLessGeneral does not imply it is definitely more specific, only + // that we cannot determine if equivalent or more general. + kLessGeneral, + kEquivalent, + }; + absl::string_view NewTypeVar(absl::string_view name = "") { next_type_parameter_id_++; auto inserted = type_parameter_bindings_.insert( @@ -190,6 +199,16 @@ class TypeInferenceContext { bool IsAssignableWithConstraints(const Type& from, const Type& to, SubstitutionMap& prospective_substitutions); + // Relative generality of `from` as compared to `to` with the current type + // substitutions and any additional prospective substitutions. + // + // Generality is only defined as a partial ordering. Some types are + // incomparable. However we only need to know if a type is definitely more + // general or not. + RelativeGenerality CompareGenerality( + const Type& from, const Type& to, + const SubstitutionMap& prospective_substitutions) const; + Type Substitute(const Type& type, const SubstitutionMap& substitutions) const; bool OccursWithin(absl::string_view var_name, const Type& type, diff --git a/checker/internal/type_inference_context_test.cc b/checker/internal/type_inference_context_test.cc index bc9513574..93543c82d 100644 --- a/checker/internal/type_inference_context_test.cc +++ b/checker/internal/type_inference_context_test.cc @@ -737,6 +737,66 @@ TEST(TypeInferenceContextTest, AssignabilityContext) { IsTypeKind(TypeKind::kIntWrapper)); } +TEST(TypeInferenceContextTest, AssignabilityContextAbstractType) { + google::protobuf::Arena arena; + TypeInferenceContext context(&arena); + + Type list_of_a = ListType(&arena, TypeParamType("A")); + + Type list_of_a_instance = context.InstantiateTypeParams(list_of_a); + + { + auto assignability_context = context.CreateAssignabilityContext(); + EXPECT_TRUE(assignability_context.IsAssignable( + OptionalType(&arena, IntType()), + list_of_a_instance.AsList()->GetElement())); + EXPECT_TRUE(assignability_context.IsAssignable( + OptionalType(&arena, DynType()), + list_of_a_instance.AsList()->GetElement())); + + assignability_context.UpdateInferredTypeAssignments(); + } + Type resolved_type = context.FinalizeType(list_of_a_instance); + + ASSERT_THAT(resolved_type, IsTypeKind(TypeKind::kList)); + ASSERT_THAT(resolved_type.AsList()->GetElement(), + IsTypeKind(TypeKind::kOpaque)); + EXPECT_THAT(resolved_type.AsList()->GetElement().AsOpaque()->name(), + "optional_type"); + EXPECT_THAT(resolved_type.AsList()->GetElement().AsOpaque()->GetParameters(), + ElementsAre(IsTypeKind(TypeKind::kDyn))); +} + +TEST(TypeInferenceContextTest, AssignabilityContextAbstractTypeWrapper) { + google::protobuf::Arena arena; + TypeInferenceContext context(&arena); + + Type list_of_a = ListType(&arena, TypeParamType("A")); + + Type list_of_a_instance = context.InstantiateTypeParams(list_of_a); + + { + auto assignability_context = context.CreateAssignabilityContext(); + EXPECT_TRUE(assignability_context.IsAssignable( + OptionalType(&arena, IntType()), + list_of_a_instance.AsList()->GetElement())); + EXPECT_TRUE(assignability_context.IsAssignable( + OptionalType(&arena, IntWrapperType()), + list_of_a_instance.AsList()->GetElement())); + + assignability_context.UpdateInferredTypeAssignments(); + } + Type resolved_type = context.FinalizeType(list_of_a_instance); + + ASSERT_THAT(resolved_type, IsTypeKind(TypeKind::kList)); + ASSERT_THAT(resolved_type.AsList()->GetElement(), + IsTypeKind(TypeKind::kOpaque)); + EXPECT_THAT(resolved_type.AsList()->GetElement().AsOpaque()->name(), + "optional_type"); + EXPECT_THAT(resolved_type.AsList()->GetElement().AsOpaque()->GetParameters(), + ElementsAre(IsTypeKind(TypeKind::kIntWrapper))); +} + TEST(TypeInferenceContextTest, AssignabilityContextNotApplied) { google::protobuf::Arena arena; TypeInferenceContext context(&arena); diff --git a/checker/optional_test.cc b/checker/optional_test.cc index ae4383883..126225668 100644 --- a/checker/optional_test.cc +++ b/checker/optional_test.cc @@ -227,6 +227,18 @@ INSTANTIATE_TEST_SUITE_P( new AstType(ast_internal::PrimitiveType::kString)))))}, TestCase{"['v1', ?'v2']", _, "expected type 'optional_type' but found 'string'"}, + TestCase{"[optional.of(dyn('1')), optional.of('2')][0]", + IsOptionalType(AstType(ast_internal::DynamicType()))}, + TestCase{"[optional.of('1'), optional.of(dyn('2'))][0]", + IsOptionalType(AstType(ast_internal::DynamicType()))}, + TestCase{"[{1: optional.of(1)}, {1: optional.of(dyn(1))}][0][1]", + IsOptionalType(AstType(ast_internal::DynamicType()))}, + TestCase{"[{1: optional.of(dyn(1))}, {1: optional.of(1)}][0][1]", + IsOptionalType(AstType(ast_internal::DynamicType()))}, + TestCase{"[optional.of('1'), optional.of(2)][0]", + Eq(AstType(ast_internal::DynamicType()))}, + TestCase{"['v1', ?'v2']", _, + "expected type 'optional_type' but found 'string'"}, TestCase{"cel.expr.conformance.proto3.TestAllTypes{?single_int64: " "optional.of(1)}", Eq(AstType(ast_internal::MessageType( From 997ac2ee800513ec208ced6d42509ebed95755ea Mon Sep 17 00:00:00 2001 From: Justin King Date: Tue, 5 Nov 2024 16:08:59 -0800 Subject: [PATCH 024/180] Harden builtin macros against usages of `__result__` PiperOrigin-RevId: 693509809 --- parser/BUILD | 15 ++++++ parser/macro.cc | 56 +++++++++++++++++--- parser/standard_macros_test.cc | 95 ++++++++++++++++++++++++++++++++++ 3 files changed, 158 insertions(+), 8 deletions(-) create mode 100644 parser/standard_macros_test.cc diff --git a/parser/BUILD b/parser/BUILD index 94fc70d65..11cf95f85 100644 --- a/parser/BUILD +++ b/parser/BUILD @@ -218,3 +218,18 @@ cc_library( "@com_google_absl//absl/status:statusor", ], ) + +cc_test( + name = "standard_macros_test", + srcs = ["standard_macros_test.cc"], + deps = [ + ":macro_registry", + ":options", + ":parser", + ":standard_macros", + "//common:source", + "//internal:testing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + ], +) diff --git a/parser/macro.cc b/parser/macro.cc index e9312ce8a..b11dca5db 100644 --- a/parser/macro.cc +++ b/parser/macro.cc @@ -87,10 +87,15 @@ absl::optional ExpandAllMacro(MacroExprFactory& factory, Expr& target, if (args.size() != 2) { return factory.ReportError("all() requires 2 arguments"); } - if (!args[0].has_ident_expr()) { + if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { return factory.ReportErrorAt( args[0], "all() variable name must be a simple identifier"); } + if (args[0].ident_expr().name() == kAccumulatorVariableName) { + return factory.ReportErrorAt(args[1], + absl::StrCat("all() variable name cannot be ", + kAccumulatorVariableName)); + } auto init = factory.NewBoolConst(true); auto condition = factory.NewCall(CelOperator::NOT_STRICTLY_FALSE, factory.NewAccuIdent()); @@ -114,10 +119,15 @@ absl::optional ExpandExistsMacro(MacroExprFactory& factory, Expr& target, if (args.size() != 2) { return factory.ReportError("exists() requires 2 arguments"); } - if (!args[0].has_ident_expr()) { + if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { return factory.ReportErrorAt( args[0], "exists() variable name must be a simple identifier"); } + if (args[0].ident_expr().name() == kAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("exists() variable name cannot be ", + kAccumulatorVariableName)); + } auto init = factory.NewBoolConst(false); auto condition = factory.NewCall( CelOperator::NOT_STRICTLY_FALSE, @@ -143,10 +153,15 @@ absl::optional ExpandExistsOneMacro(MacroExprFactory& factory, if (args.size() != 2) { return factory.ReportError("exists_one() requires 2 arguments"); } - if (!args[0].has_ident_expr()) { + if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { return factory.ReportErrorAt( args[0], "exists_one() variable name must be a simple identifier"); } + if (args[0].ident_expr().name() == kAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("exists_one() variable name cannot be ", + kAccumulatorVariableName)); + } auto init = factory.NewIntConst(0); auto condition = factory.NewBoolConst(true); auto step = @@ -174,10 +189,15 @@ absl::optional ExpandMap2Macro(MacroExprFactory& factory, Expr& target, if (args.size() != 2) { return factory.ReportError("map() requires 2 arguments"); } - if (!args[0].has_ident_expr()) { + if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { return factory.ReportErrorAt( args[0], "map() variable name must be a simple identifier"); } + if (args[0].ident_expr().name() == kAccumulatorVariableName) { + return factory.ReportErrorAt(args[1], + absl::StrCat("map() variable name cannot be ", + kAccumulatorVariableName)); + } auto init = factory.NewList(); auto condition = factory.NewBoolConst(true); auto step = factory.NewCall( @@ -200,10 +220,15 @@ absl::optional ExpandMap3Macro(MacroExprFactory& factory, Expr& target, if (args.size() != 3) { return factory.ReportError("map() requires 3 arguments"); } - if (!args[0].has_ident_expr()) { + if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { return factory.ReportErrorAt( args[0], "map() variable name must be a simple identifier"); } + if (args[0].ident_expr().name() == kAccumulatorVariableName) { + return factory.ReportErrorAt(args[1], + absl::StrCat("map() variable name cannot be ", + kAccumulatorVariableName)); + } auto init = factory.NewList(); auto condition = factory.NewBoolConst(true); auto step = factory.NewCall( @@ -228,10 +253,15 @@ absl::optional ExpandFilterMacro(MacroExprFactory& factory, Expr& target, if (args.size() != 2) { return factory.ReportError("filter() requires 2 arguments"); } - if (!args[0].has_ident_expr()) { + if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { return factory.ReportErrorAt( args[0], "filter() variable name must be a simple identifier"); } + if (args[0].ident_expr().name() == kAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("filter() variable name cannot be ", + kAccumulatorVariableName)); + } auto name = args[0].ident_expr().name(); auto init = factory.NewList(); @@ -259,10 +289,15 @@ absl::optional ExpandOptMapMacro(MacroExprFactory& factory, Expr& target, if (args.size() != 2) { return factory.ReportError("optMap() requires 2 arguments"); } - if (!args[0].has_ident_expr()) { + if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { return factory.ReportErrorAt( args[0], "optMap() variable name must be a simple identifier"); } + if (args[0].ident_expr().name() == kAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("optMap() variable name cannot be ", + kAccumulatorVariableName)); + } auto var_name = args[0].ident_expr().name(); auto target_copy = factory.Copy(target); @@ -293,10 +328,15 @@ absl::optional ExpandOptFlatMapMacro(MacroExprFactory& factory, if (args.size() != 2) { return factory.ReportError("optFlatMap() requires 2 arguments"); } - if (!args[0].has_ident_expr()) { + if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { return factory.ReportErrorAt( args[0], "optFlatMap() variable name must be a simple identifier"); } + if (args[0].ident_expr().name() == kAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("optFlatMap() variable name cannot be ", + kAccumulatorVariableName)); + } auto var_name = args[0].ident_expr().name(); auto target_copy = factory.Copy(target); diff --git a/parser/standard_macros_test.cc b/parser/standard_macros_test.cc new file mode 100644 index 000000000..a79390f06 --- /dev/null +++ b/parser/standard_macros_test.cc @@ -0,0 +1,95 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "parser/standard_macros.h" + +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "common/source.h" +#include "internal/testing.h" +#include "parser/macro_registry.h" +#include "parser/options.h" +#include "parser/parser.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::google::api::expr::parser::EnrichedParse; +using ::testing::HasSubstr; + +struct StandardMacrosTestCase { + std::string expression; + std::string error; +}; + +using StandardMacrosTest = ::testing::TestWithParam; + +TEST_P(StandardMacrosTest, Errors) { + const auto& test_param = GetParam(); + ASSERT_OK_AND_ASSIGN(auto source, NewSource(test_param.expression)); + + ParserOptions options; + options.enable_optional_syntax = true; + + MacroRegistry registry; + ASSERT_THAT(RegisterStandardMacros(registry, options), IsOk()); + + EXPECT_THAT(EnrichedParse(*source, registry, options), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr(test_param.error))); +} + +INSTANTIATE_TEST_SUITE_P( + StandardMacrosTest, StandardMacrosTest, + ::testing::ValuesIn({ + { + .expression = "[].all(__result__, __result__ == 0)", + .error = "variable name cannot be __result__", + }, + { + .expression = "[].exists(__result__, __result__ == 0)", + .error = "variable name cannot be __result__", + }, + { + .expression = "[].exists_one(__result__, __result__ == 0)", + .error = "variable name cannot be __result__", + }, + { + .expression = "[].map(__result__, __result__)", + .error = "variable name cannot be __result__", + }, + { + .expression = "[].map(__result__, true, __result__)", + .error = "variable name cannot be __result__", + }, + { + .expression = "[].filter(__result__, __result__ == 0)", + .error = "variable name cannot be __result__", + }, + { + .expression = "foo.optMap(__result__, __result__)", + .error = "variable name cannot be __result__", + }, + { + .expression = "foo.optFlatMap(__result__, __result__)", + .error = "variable name cannot be __result__", + }, + })); + +} // namespace +} // namespace cel From ada9b5ff127eec05d1ef64322b83e348460f9df3 Mon Sep 17 00:00:00 2001 From: Justin King Date: Wed, 6 Nov 2024 07:56:33 -0800 Subject: [PATCH 025/180] Consolidate type providers into two: one for legacy and one for modern PiperOrigin-RevId: 693727813 --- common/BUILD | 2 - common/json.h | 4 +- common/type_reflector.cc | 917 +------- common/type_reflector.h | 9 +- common/type_reflector_test.cc | 111 +- common/value.cc | 27 + common/value.h | 18 +- common/value_factory.h | 4 + common/value_manager.h | 3 + common/values/legacy_value_manager.h | 6 + common/values/piecewise_value_manager.h | 12 + common/values/struct_value_builder.cc | 205 +- common/values/struct_value_builder.h | 13 +- .../values/thread_compatible_value_manager.h | 6 + common/values/value_builder.h | 37 + conformance/BUILD | 17 - conformance/run.bzl | 18 +- conformance/run.cc | 4 - conformance/service.cc | 42 +- eval/compiler/BUILD | 4 +- .../cel_expression_builder_flat_impl.cc | 4 +- .../cel_expression_builder_flat_impl.h | 3 +- .../cel_expression_builder_flat_impl_test.cc | 4 - eval/compiler/constant_folding.cc | 3 +- eval/compiler/flat_expr_builder.cc | 24 +- eval/compiler/flat_expr_builder.h | 14 +- eval/compiler/flat_expr_builder_test.cc | 119 - eval/compiler/resolver_test.cc | 20 - eval/eval/BUILD | 43 +- eval/eval/cel_expression_flat_impl.cc | 39 +- eval/eval/cel_expression_flat_impl.h | 39 +- eval/eval/comprehension_step_test.cc | 6 +- eval/eval/const_value_step.cc | 6 +- eval/eval/const_value_step.h | 4 +- eval/eval/const_value_step_test.cc | 58 +- eval/eval/container_access_step_test.cc | 27 +- eval/eval/create_list_step_test.cc | 67 +- eval/eval/create_map_step_test.cc | 35 +- eval/eval/create_struct_step.cc | 3 +- eval/eval/create_struct_step_test.cc | 219 +- eval/eval/evaluator_core.cc | 6 - eval/eval/evaluator_core.h | 5 +- eval/eval/evaluator_core_test.cc | 7 +- eval/eval/function_step_test.cc | 70 +- eval/eval/ident_step_test.cc | 41 +- eval/eval/logic_step_test.cc | 11 +- eval/eval/select_step_test.cc | 36 +- eval/eval/shadowable_value_step_test.cc | 25 +- eval/eval/ternary_step_test.cc | 10 +- eval/public/BUILD | 18 +- eval/public/cel_expr_builder_factory.cc | 9 - eval/public/cel_type_registry.cc | 38 +- eval/public/cel_type_registry.h | 27 +- ..._type_registry_protobuf_reflection_test.cc | 4 - eval/public/cel_type_registry_test.cc | 20 +- eval/public/structs/BUILD | 4 + eval/public/structs/legacy_type_provider.cc | 58 + eval/public/structs/legacy_type_provider.h | 6 +- .../protobuf_descriptor_type_provider.h | 20 +- eval/tests/BUILD | 1 - eval/tests/modern_benchmark_test.cc | 127 +- eval/testutil/BUILD | 7 - eval/testutil/args.proto | 47 - eval/testutil/simple_test_message.proto | 9 - extensions/BUILD | 2 +- extensions/protobuf/BUILD | 9 - extensions/protobuf/type_reflector.cc | 44 +- extensions/protobuf/type_reflector.h | 22 +- extensions/protobuf/type_reflector_test.cc | 117 - extensions/protobuf/value_end_to_end_test.cc | 1933 ++++++++--------- extensions/strings_test.cc | 43 +- internal/BUILD | 11 +- runtime/BUILD | 16 +- runtime/constant_folding_test.cc | 8 +- runtime/internal/BUILD | 70 +- runtime/internal/composed_type_provider.cc | 108 - runtime/internal/convert_constant.cc | 32 +- runtime/internal/convert_constant.h | 4 +- .../internal/legacy_runtime_type_provider.h | 37 + runtime/internal/runtime_env.h | 2 + runtime/internal/runtime_env_testing.cc | 10 +- runtime/internal/runtime_impl.cc | 52 +- runtime/internal/runtime_impl.h | 11 + runtime/internal/runtime_type_provider.cc | 161 ++ ...ype_provider.h => runtime_type_provider.h} | 56 +- runtime/internal/runtime_value_manager.h | 75 + runtime/optional_types_test.cc | 51 +- runtime/reference_resolver_test.cc | 44 +- runtime/regex_precompilation_test.cc | 24 +- runtime/runtime.h | 67 +- .../standard_runtime_builder_factory_test.cc | 623 +++--- runtime/type_registry.cc | 12 +- runtime/type_registry.h | 57 +- 93 files changed, 2879 insertions(+), 3624 deletions(-) create mode 100644 common/values/value_builder.h delete mode 100644 eval/testutil/args.proto delete mode 100644 eval/testutil/simple_test_message.proto delete mode 100644 extensions/protobuf/type_reflector_test.cc delete mode 100644 runtime/internal/composed_type_provider.cc create mode 100644 runtime/internal/legacy_runtime_type_provider.h create mode 100644 runtime/internal/runtime_type_provider.cc rename runtime/internal/{composed_type_provider.h => runtime_type_provider.h} (53%) create mode 100644 runtime/internal/runtime_value_manager.h diff --git a/common/BUILD b/common/BUILD index cc37eee2b..bc2fd69d3 100644 --- a/common/BUILD +++ b/common/BUILD @@ -653,11 +653,9 @@ cc_library( "//extensions/protobuf/internal:map_reflection", "//extensions/protobuf/internal:qualify", "//internal:casts", - "//internal:deserialize", "//internal:json", "//internal:message_equality", "//internal:number", - "//internal:overflow", "//internal:protobuf_runtime_version", "//internal:serialize", "//internal:status_macros", diff --git a/common/json.h b/common/json.h index 7233d06dc..b7fa1ddb3 100644 --- a/common/json.h +++ b/common/json.h @@ -495,9 +495,7 @@ class AnyToJsonConverter { return nullptr; } - virtual absl::Nullable message_factory() const { - return nullptr; - } + virtual absl::Nullable message_factory() const = 0; }; inline std::pair, diff --git a/common/type_reflector.cc b/common/type_reflector.cc index 4677f038a3..89c363425 100644 --- a/common/type_reflector.cc +++ b/common/type_reflector.cc @@ -14,942 +14,28 @@ #include "common/type_reflector.h" -#include -#include -#include -#include - #include "absl/base/no_destructor.h" #include "absl/base/nullability.h" -#include "absl/container/flat_hash_map.h" -#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/cord.h" -#include "absl/strings/match.h" -#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" -#include "absl/strings/strip.h" -#include "absl/time/time.h" #include "absl/types/optional.h" -#include "common/any.h" -#include "common/casting.h" -#include "common/json.h" #include "common/memory.h" #include "common/type.h" #include "common/value.h" #include "common/value_factory.h" -#include "common/values/piecewise_value_manager.h" #include "common/values/thread_compatible_type_reflector.h" -#include "internal/deserialize.h" -#include "internal/overflow.h" -#include "internal/status_macros.h" namespace cel { -namespace { - -// Exception to `ValueBuilder` which also functions as a deserializer. -class WellKnownValueBuilder : public ValueBuilder { - public: - virtual absl::Status Deserialize(const absl::Cord& serialized_value) = 0; -}; - -class BoolValueBuilder final : public WellKnownValueBuilder { - public: - explicit BoolValueBuilder(const TypeReflector& type_reflector, - ValueFactory& value_factory) {} - - absl::Status SetFieldByName(absl::string_view name, Value value) override { - if (name != "value") { - return NoSuchFieldError(name).NativeValue(); - } - return SetValue(std::move(value)); - } - - absl::Status SetFieldByNumber(int64_t number, Value value) override { - if (number != 1) { - return NoSuchFieldError(absl::StrCat(number)).NativeValue(); - } - return SetValue(std::move(value)); - } - - Value Build() && override { return BoolValue(value_); } - - absl::Status Deserialize(const absl::Cord& serialized_value) override { - CEL_ASSIGN_OR_RETURN(value_, - internal::DeserializeBoolValue(serialized_value)); - return absl::OkStatus(); - } - - private: - absl::Status SetValue(Value value) { - if (auto bool_value = As(value); bool_value.has_value()) { - value_ = bool_value->NativeValue(); - return absl::OkStatus(); - } - return TypeConversionError(value.GetTypeName(), "bool").NativeValue(); - } - - bool value_ = false; -}; - -class Int32ValueBuilder final : public WellKnownValueBuilder { - public: - explicit Int32ValueBuilder(const TypeReflector& type_reflector, - ValueFactory& value_factory) {} - - absl::Status SetFieldByName(absl::string_view name, Value value) override { - if (name != "value") { - return NoSuchFieldError(name).NativeValue(); - } - return SetValue(std::move(value)); - } - - absl::Status SetFieldByNumber(int64_t number, Value value) override { - if (number != 1) { - return NoSuchFieldError(absl::StrCat(number)).NativeValue(); - } - return SetValue(std::move(value)); - } - - Value Build() && override { return IntValue(value_); } - - absl::Status Deserialize(const absl::Cord& serialized_value) override { - CEL_ASSIGN_OR_RETURN(value_, - internal::DeserializeInt32Value(serialized_value)); - return absl::OkStatus(); - } - - private: - absl::Status SetValue(Value value) { - if (auto int_value = As(value); int_value.has_value()) { - CEL_ASSIGN_OR_RETURN( - value_, internal::CheckedInt64ToInt32(int_value->NativeValue())); - return absl::OkStatus(); - } - return TypeConversionError(value.GetTypeName(), "int").NativeValue(); - } - - int64_t value_ = 0; -}; - -class Int64ValueBuilder final : public WellKnownValueBuilder { - public: - explicit Int64ValueBuilder(const TypeReflector& type_reflector, - ValueFactory& value_factory) {} - - absl::Status SetFieldByName(absl::string_view name, Value value) override { - if (name != "value") { - return NoSuchFieldError(name).NativeValue(); - } - return SetValue(std::move(value)); - } - - absl::Status SetFieldByNumber(int64_t number, Value value) override { - if (number != 1) { - return NoSuchFieldError(absl::StrCat(number)).NativeValue(); - } - return SetValue(std::move(value)); - } - - Value Build() && override { return IntValue(value_); } - - absl::Status Deserialize(const absl::Cord& serialized_value) override { - CEL_ASSIGN_OR_RETURN(value_, - internal::DeserializeInt64Value(serialized_value)); - return absl::OkStatus(); - } - - private: - absl::Status SetValue(Value value) { - if (auto int_value = As(value); int_value.has_value()) { - value_ = int_value->NativeValue(); - return absl::OkStatus(); - } - return TypeConversionError(value.GetTypeName(), "int").NativeValue(); - } - - int64_t value_ = 0; -}; - -class UInt32ValueBuilder final : public WellKnownValueBuilder { - public: - explicit UInt32ValueBuilder(const TypeReflector& type_reflector, - ValueFactory& value_factory) {} - - absl::Status SetFieldByName(absl::string_view name, Value value) override { - if (name != "value") { - return NoSuchFieldError(name).NativeValue(); - } - return SetValue(std::move(value)); - } - - absl::Status SetFieldByNumber(int64_t number, Value value) override { - if (number != 1) { - return NoSuchFieldError(absl::StrCat(number)).NativeValue(); - } - return SetValue(std::move(value)); - } - - Value Build() && override { return UintValue(value_); } - - absl::Status Deserialize(const absl::Cord& serialized_value) override { - CEL_ASSIGN_OR_RETURN(value_, - internal::DeserializeUInt32Value(serialized_value)); - return absl::OkStatus(); - } - - private: - absl::Status SetValue(Value value) { - if (auto uint_value = As(value); uint_value.has_value()) { - CEL_ASSIGN_OR_RETURN( - value_, internal::CheckedUint64ToUint32(uint_value->NativeValue())); - return absl::OkStatus(); - } - return TypeConversionError(value.GetTypeName(), "uint").NativeValue(); - } - - uint64_t value_ = 0; -}; - -class UInt64ValueBuilder final : public WellKnownValueBuilder { - public: - explicit UInt64ValueBuilder(const TypeReflector& type_reflector, - ValueFactory& value_factory) {} - - absl::Status SetFieldByName(absl::string_view name, Value value) override { - if (name != "value") { - return NoSuchFieldError(name).NativeValue(); - } - return SetValue(std::move(value)); - } - - absl::Status SetFieldByNumber(int64_t number, Value value) override { - if (number != 1) { - return NoSuchFieldError(absl::StrCat(number)).NativeValue(); - } - return SetValue(std::move(value)); - } - - Value Build() && override { return UintValue(value_); } - - absl::Status Deserialize(const absl::Cord& serialized_value) override { - CEL_ASSIGN_OR_RETURN(value_, - internal::DeserializeUInt64Value(serialized_value)); - return absl::OkStatus(); - } - - private: - absl::Status SetValue(Value value) { - if (auto uint_value = As(value); uint_value.has_value()) { - value_ = uint_value->NativeValue(); - return absl::OkStatus(); - } - return TypeConversionError(value.GetTypeName(), "uint").NativeValue(); - } - - uint64_t value_ = 0; -}; - -class FloatValueBuilder final : public WellKnownValueBuilder { - public: - explicit FloatValueBuilder(const TypeReflector& type_reflector, - ValueFactory& value_factory) {} - - absl::Status SetFieldByName(absl::string_view name, Value value) override { - if (name != "value") { - return NoSuchFieldError(name).NativeValue(); - } - return SetValue(std::move(value)); - } - - absl::Status SetFieldByNumber(int64_t number, Value value) override { - if (number != 1) { - return NoSuchFieldError(absl::StrCat(number)).NativeValue(); - } - return SetValue(std::move(value)); - } - - Value Build() && override { return DoubleValue(value_); } - - absl::Status Deserialize(const absl::Cord& serialized_value) override { - CEL_ASSIGN_OR_RETURN(value_, - internal::DeserializeFloatValue(serialized_value)); - return absl::OkStatus(); - } - - private: - absl::Status SetValue(Value value) { - if (auto double_value = As(value); double_value.has_value()) { - // Ensure we truncate to `float`. - value_ = static_cast(double_value->NativeValue()); - return absl::OkStatus(); - } - return TypeConversionError(value.GetTypeName(), "double").NativeValue(); - } - - double value_ = 0; -}; - -class DoubleValueBuilder final : public WellKnownValueBuilder { - public: - explicit DoubleValueBuilder(const TypeReflector& type_reflector, - ValueFactory& value_factory) {} - - absl::Status SetFieldByName(absl::string_view name, Value value) override { - if (name != "value") { - return NoSuchFieldError(name).NativeValue(); - } - return SetValue(std::move(value)); - } - - absl::Status SetFieldByNumber(int64_t number, Value value) override { - if (number != 1) { - return NoSuchFieldError(absl::StrCat(number)).NativeValue(); - } - return SetValue(std::move(value)); - } - - Value Build() && override { return DoubleValue(value_); } - - absl::Status Deserialize(const absl::Cord& serialized_value) override { - CEL_ASSIGN_OR_RETURN(value_, - internal::DeserializeDoubleValue(serialized_value)); - return absl::OkStatus(); - } - - private: - absl::Status SetValue(Value value) { - if (auto double_value = As(value); double_value.has_value()) { - value_ = double_value->NativeValue(); - return absl::OkStatus(); - } - return TypeConversionError(value.GetTypeName(), "double").NativeValue(); - } - - double value_ = 0; -}; - -class StringValueBuilder final : public WellKnownValueBuilder { - public: - explicit StringValueBuilder(const TypeReflector& type_reflector, - ValueFactory& value_factory) {} - - absl::Status SetFieldByName(absl::string_view name, Value value) override { - if (name != "value") { - return NoSuchFieldError(name).NativeValue(); - } - return SetValue(std::move(value)); - } - - absl::Status SetFieldByNumber(int64_t number, Value value) override { - if (number != 1) { - return NoSuchFieldError(absl::StrCat(number)).NativeValue(); - } - return SetValue(std::move(value)); - } - - Value Build() && override { return StringValue(std::move(value_)); } - - absl::Status Deserialize(const absl::Cord& serialized_value) override { - CEL_ASSIGN_OR_RETURN(value_, - internal::DeserializeStringValue(serialized_value)); - return absl::OkStatus(); - } - - private: - absl::Status SetValue(Value value) { - if (auto string_value = As(value); string_value.has_value()) { - value_ = string_value->NativeCord(); - return absl::OkStatus(); - } - return TypeConversionError(value.GetTypeName(), "string").NativeValue(); - } - - absl::Cord value_; -}; - -class BytesValueBuilder final : public WellKnownValueBuilder { - public: - explicit BytesValueBuilder(const TypeReflector& type_reflector, - ValueFactory& value_factory) {} - - absl::Status SetFieldByName(absl::string_view name, Value value) override { - if (name != "value") { - return NoSuchFieldError(name).NativeValue(); - } - return SetValue(std::move(value)); - } - - absl::Status SetFieldByNumber(int64_t number, Value value) override { - if (number != 1) { - return NoSuchFieldError(absl::StrCat(number)).NativeValue(); - } - return SetValue(std::move(value)); - } - - Value Build() && override { return BytesValue(std::move(value_)); } - - absl::Status Deserialize(const absl::Cord& serialized_value) override { - CEL_ASSIGN_OR_RETURN(value_, - internal::DeserializeBytesValue(serialized_value)); - return absl::OkStatus(); - } - - private: - absl::Status SetValue(Value value) { - if (auto bytes_value = As(value); bytes_value.has_value()) { - value_ = bytes_value->NativeCord(); - return absl::OkStatus(); - } - return TypeConversionError(value.GetTypeName(), "bytes").NativeValue(); - } - - absl::Cord value_; -}; - -class DurationValueBuilder final : public WellKnownValueBuilder { - public: - explicit DurationValueBuilder(const TypeReflector& type_reflector, - ValueFactory& value_factory) {} - - absl::Status SetFieldByName(absl::string_view name, Value value) override { - if (name == "seconds") { - return SetSeconds(std::move(value)); - } - if (name == "nanos") { - return SetNanos(std::move(value)); - } - return NoSuchFieldError(name).NativeValue(); - } - - absl::Status SetFieldByNumber(int64_t number, Value value) override { - if (number == 1) { - return SetSeconds(std::move(value)); - } - if (number == 2) { - return SetNanos(std::move(value)); - } - return NoSuchFieldError(absl::StrCat(number)).NativeValue(); - } - - Value Build() && override { - return DurationValue(absl::Seconds(seconds_) + absl::Nanoseconds(nanos_)); - } - - absl::Status Deserialize(const absl::Cord& serialized_value) override { - CEL_ASSIGN_OR_RETURN(auto value, - internal::DeserializeDuration(serialized_value)); - seconds_ = absl::IDivDuration(value, absl::Seconds(1), &value); - nanos_ = static_cast( - absl::IDivDuration(value, absl::Nanoseconds(1), &value)); - return absl::OkStatus(); - } - - private: - absl::Status SetSeconds(Value value) { - if (auto int_value = As(value); int_value.has_value()) { - seconds_ = int_value->NativeValue(); - return absl::OkStatus(); - } - return TypeConversionError(value.GetTypeName(), "int").NativeValue(); - } - - absl::Status SetNanos(Value value) { - if (auto int_value = As(value); int_value.has_value()) { - CEL_ASSIGN_OR_RETURN( - nanos_, internal::CheckedInt64ToInt32(int_value->NativeValue())); - return absl::OkStatus(); - } - return TypeConversionError(value.GetTypeName(), "int").NativeValue(); - } - - int64_t seconds_ = 0; - int32_t nanos_ = 0; -}; - -class TimestampValueBuilder final : public WellKnownValueBuilder { - public: - explicit TimestampValueBuilder(const TypeReflector& type_reflector, - ValueFactory& value_factory) {} - - absl::Status SetFieldByName(absl::string_view name, Value value) override { - if (name == "seconds") { - return SetSeconds(std::move(value)); - } - if (name == "nanos") { - return SetNanos(std::move(value)); - } - return NoSuchFieldError(name).NativeValue(); - } - - absl::Status SetFieldByNumber(int64_t number, Value value) override { - if (number == 1) { - return SetSeconds(std::move(value)); - } - if (number == 2) { - return SetNanos(std::move(value)); - } - return NoSuchFieldError(absl::StrCat(number)).NativeValue(); - } - - Value Build() && override { - return TimestampValue(absl::UnixEpoch() + absl::Seconds(seconds_) + - absl::Nanoseconds(nanos_)); - } - - absl::Status Deserialize(const absl::Cord& serialized_value) override { - CEL_ASSIGN_OR_RETURN(auto value, - internal::DeserializeTimestamp(serialized_value)); - auto duration = value - absl::UnixEpoch(); - seconds_ = absl::IDivDuration(duration, absl::Seconds(1), &duration); - nanos_ = static_cast( - absl::IDivDuration(duration, absl::Nanoseconds(1), &duration)); - return absl::OkStatus(); - } - - private: - absl::Status SetSeconds(Value value) { - if (auto int_value = As(value); int_value.has_value()) { - seconds_ = int_value->NativeValue(); - return absl::OkStatus(); - } - return TypeConversionError(value.GetTypeName(), "int").NativeValue(); - } - - absl::Status SetNanos(Value value) { - if (auto int_value = As(value); int_value.has_value()) { - CEL_ASSIGN_OR_RETURN( - nanos_, internal::CheckedInt64ToInt32(int_value->NativeValue())); - return absl::OkStatus(); - } - return TypeConversionError(value.GetTypeName(), "int").NativeValue(); - } - - int64_t seconds_ = 0; - int32_t nanos_ = 0; -}; - -class JsonValueBuilder final : public WellKnownValueBuilder { - public: - explicit JsonValueBuilder(const TypeReflector& type_reflector, - ValueFactory& value_factory) - : type_reflector_(type_reflector), value_factory_(value_factory) {} - - absl::Status SetFieldByName(absl::string_view name, Value value) override { - if (name == "null_value") { - return SetNullValue(); - } - if (name == "number_value") { - return SetNumberValue(std::move(value)); - } - if (name == "string_value") { - return SetStringValue(std::move(value)); - } - if (name == "bool_value") { - return SetBoolValue(std::move(value)); - } - if (name == "struct_value") { - return SetStructValue(std::move(value)); - } - if (name == "list_value") { - return SetListValue(std::move(value)); - } - return NoSuchFieldError(name).NativeValue(); - } - - absl::Status SetFieldByNumber(int64_t number, Value value) override { - switch (number) { - case 1: - return SetNullValue(); - case 2: - return SetNumberValue(std::move(value)); - case 3: - return SetStringValue(std::move(value)); - case 4: - return SetBoolValue(std::move(value)); - case 5: - return SetStructValue(std::move(value)); - case 6: - return SetListValue(std::move(value)); - default: - return NoSuchFieldError(absl::StrCat(number)).NativeValue(); - } - } - - Value Build() && override { - return value_factory_.CreateValueFromJson(std::move(json_)); - } - - absl::Status Deserialize(const absl::Cord& serialized_value) override { - CEL_ASSIGN_OR_RETURN(json_, internal::DeserializeValue(serialized_value)); - return absl::OkStatus(); - } - - private: - absl::Status SetNullValue() { - json_ = kJsonNull; - return absl::OkStatus(); - } - - absl::Status SetNumberValue(Value value) { - if (auto double_value = As(value); double_value.has_value()) { - json_ = double_value->NativeValue(); - return absl::OkStatus(); - } - return TypeConversionError(value.GetTypeName(), "double").NativeValue(); - } - - absl::Status SetStringValue(Value value) { - if (auto string_value = As(value); string_value.has_value()) { - json_ = string_value->NativeCord(); - return absl::OkStatus(); - } - return TypeConversionError(value.GetTypeName(), "string").NativeValue(); - } - - absl::Status SetBoolValue(Value value) { - if (auto bool_value = As(value); bool_value.has_value()) { - json_ = bool_value->NativeValue(); - return absl::OkStatus(); - } - return TypeConversionError(value.GetTypeName(), "bool").NativeValue(); - } - - absl::Status SetStructValue(Value value) { - if (auto map_value = As(value); map_value.has_value()) { - common_internal::PiecewiseValueManager value_manager(type_reflector_, - value_factory_); - CEL_ASSIGN_OR_RETURN(json_, map_value->ConvertToJson(value_manager)); - return absl::OkStatus(); - } - if (auto struct_value = As(value); struct_value.has_value()) { - common_internal::PiecewiseValueManager value_manager(type_reflector_, - value_factory_); - CEL_ASSIGN_OR_RETURN(json_, struct_value->ConvertToJson(value_manager)); - return absl::OkStatus(); - } - return TypeConversionError(value.GetTypeName(), "google.protobuf.Struct") - .NativeValue(); - } - - absl::Status SetListValue(Value value) { - if (auto list_value = As(value); list_value.has_value()) { - common_internal::PiecewiseValueManager value_manager(type_reflector_, - value_factory_); - CEL_ASSIGN_OR_RETURN(json_, list_value->ConvertToJson(value_manager)); - return absl::OkStatus(); - } - return TypeConversionError(value.GetTypeName(), "google.protobuf.ListValue") - .NativeValue(); - } - - const TypeReflector& type_reflector_; - ValueFactory& value_factory_; - Json json_; -}; - -class JsonArrayValueBuilder final : public WellKnownValueBuilder { - public: - explicit JsonArrayValueBuilder(const TypeReflector& type_reflector, - ValueFactory& value_factory) - : type_reflector_(type_reflector), value_factory_(value_factory) {} - - absl::Status SetFieldByName(absl::string_view name, Value value) override { - if (name == "values") { - return SetValues(std::move(value)); - } - return NoSuchFieldError(name).NativeValue(); - } - - absl::Status SetFieldByNumber(int64_t number, Value value) override { - if (number == 1) { - return SetValues(std::move(value)); - } - return NoSuchFieldError(absl::StrCat(number)).NativeValue(); - } - - Value Build() && override { - return value_factory_.CreateListValueFromJsonArray(std::move(array_)); - } - - absl::Status Deserialize(const absl::Cord& serialized_value) override { - CEL_ASSIGN_OR_RETURN(array_, - internal::DeserializeListValue(serialized_value)); - return absl::OkStatus(); - } - - private: - absl::Status SetValues(Value value) { - if (auto list_value = As(value); list_value.has_value()) { - common_internal::PiecewiseValueManager value_manager(type_reflector_, - value_factory_); - CEL_ASSIGN_OR_RETURN(array_, - list_value->ConvertToJsonArray(value_manager)); - return absl::OkStatus(); - } - return TypeConversionError(value.GetTypeName(), "list(dyn)").NativeValue(); - } - - const TypeReflector& type_reflector_; - ValueFactory& value_factory_; - JsonArray array_; -}; - -class JsonObjectValueBuilder final : public WellKnownValueBuilder { - public: - explicit JsonObjectValueBuilder(const TypeReflector& type_reflector, - ValueFactory& value_factory) - : type_reflector_(type_reflector), value_factory_(value_factory) {} - - absl::Status SetFieldByName(absl::string_view name, Value value) override { - if (name == "fields") { - return SetFields(std::move(value)); - } - return NoSuchFieldError(name).NativeValue(); - } - - absl::Status SetFieldByNumber(int64_t number, Value value) override { - if (number == 1) { - return SetFields(std::move(value)); - } - return NoSuchFieldError(absl::StrCat(number)).NativeValue(); - } - - Value Build() && override { - return value_factory_.CreateMapValueFromJsonObject(std::move(object_)); - } - - absl::Status Deserialize(const absl::Cord& serialized_value) override { - CEL_ASSIGN_OR_RETURN(object_, - internal::DeserializeStruct(serialized_value)); - return absl::OkStatus(); - } - - private: - absl::Status SetFields(Value value) { - if (auto map_value = As(value); map_value.has_value()) { - common_internal::PiecewiseValueManager value_manager(type_reflector_, - value_factory_); - CEL_ASSIGN_OR_RETURN(object_, - map_value->ConvertToJsonObject(value_manager)); - return absl::OkStatus(); - } - if (auto struct_value = As(value); struct_value.has_value()) { - common_internal::PiecewiseValueManager value_manager(type_reflector_, - value_factory_); - CEL_ASSIGN_OR_RETURN(auto json_value, - struct_value->ConvertToJson(value_manager)); - if (absl::holds_alternative(json_value)) { - object_ = absl::get(std::move(json_value)); - return absl::OkStatus(); - } - } - return TypeConversionError(value.GetTypeName(), "map(string, dyn)") - .NativeValue(); - } - - const TypeReflector& type_reflector_; - ValueFactory& value_factory_; - JsonObject object_; -}; - -class AnyValueBuilder final : public WellKnownValueBuilder { - public: - explicit AnyValueBuilder(const TypeReflector& type_reflector, - ValueFactory& value_factory) - : type_reflector_(type_reflector), value_factory_(value_factory) {} - - absl::Status SetFieldByName(absl::string_view name, Value value) override { - if (name == "type_url") { - return SetTypeUrl(std::move(value)); - } - if (name == "value") { - return SetValue(std::move(value)); - } - return NoSuchFieldError(name).NativeValue(); - } - - absl::Status SetFieldByNumber(int64_t number, Value value) override { - if (number == 1) { - return SetTypeUrl(std::move(value)); - } - if (number == 2) { - return SetValue(std::move(value)); - } - return NoSuchFieldError(absl::StrCat(number)).NativeValue(); - } - - Value Build() && override { - auto status_or_value = - type_reflector_.DeserializeValue(value_factory_, type_url_, value_); - if (!status_or_value.ok()) { - return ErrorValue(std::move(status_or_value).status()); - } - if (!(*status_or_value).has_value()) { - return NoSuchTypeError(type_url_); - } - return std::move(*std::move(*status_or_value)); - } - - absl::Status Deserialize(const absl::Cord& serialized_value) override { - CEL_ASSIGN_OR_RETURN(auto any, internal::DeserializeAny(serialized_value)); - type_url_ = any.type_url(); - value_ = GetAnyValueAsCord(any); - return absl::OkStatus(); - } - - private: - absl::Status SetTypeUrl(Value value) { - if (auto string_value = As(value); string_value.has_value()) { - type_url_ = string_value->NativeString(); - return absl::OkStatus(); - } - return TypeConversionError(value.GetTypeName(), "string").NativeValue(); - } - - absl::Status SetValue(Value value) { - if (auto bytes_value = As(value); bytes_value.has_value()) { - value_ = bytes_value->NativeCord(); - return absl::OkStatus(); - } - return TypeConversionError(value.GetTypeName(), "bytes").NativeValue(); - } - - const TypeReflector& type_reflector_; - ValueFactory& value_factory_; - std::string type_url_; - absl::Cord value_; -}; - -using WellKnownValueBuilderProvider = - std::unique_ptr (*)(MemoryManagerRef, - const TypeReflector&, - ValueFactory&); - -template -std::unique_ptr WellKnownValueBuilderProviderFor( - MemoryManagerRef memory_manager, const TypeReflector& type_reflector, - ValueFactory& value_factory) { - return std::make_unique(type_reflector, value_factory); -} - -using WellKnownValueBuilderMap = - absl::flat_hash_map; - -const WellKnownValueBuilderMap& GetWellKnownValueBuilderMap() { - static const WellKnownValueBuilderMap* builders = - []() -> WellKnownValueBuilderMap* { - WellKnownValueBuilderMap* builders = new WellKnownValueBuilderMap(); - builders->insert_or_assign( - "google.protobuf.BoolValue", - &WellKnownValueBuilderProviderFor); - builders->insert_or_assign( - "google.protobuf.Int32Value", - &WellKnownValueBuilderProviderFor); - builders->insert_or_assign( - "google.protobuf.Int64Value", - &WellKnownValueBuilderProviderFor); - builders->insert_or_assign( - "google.protobuf.UInt32Value", - &WellKnownValueBuilderProviderFor); - builders->insert_or_assign( - "google.protobuf.UInt64Value", - &WellKnownValueBuilderProviderFor); - builders->insert_or_assign( - "google.protobuf.FloatValue", - &WellKnownValueBuilderProviderFor); - builders->insert_or_assign( - "google.protobuf.DoubleValue", - &WellKnownValueBuilderProviderFor); - builders->insert_or_assign( - "google.protobuf.StringValue", - &WellKnownValueBuilderProviderFor); - builders->insert_or_assign( - "google.protobuf.BytesValue", - &WellKnownValueBuilderProviderFor); - builders->insert_or_assign( - "google.protobuf.Duration", - &WellKnownValueBuilderProviderFor); - builders->insert_or_assign( - "google.protobuf.Timestamp", - &WellKnownValueBuilderProviderFor); - builders->insert_or_assign( - "google.protobuf.Value", - &WellKnownValueBuilderProviderFor); - builders->insert_or_assign( - "google.protobuf.ListValue", - &WellKnownValueBuilderProviderFor); - builders->insert_or_assign( - "google.protobuf.Struct", - &WellKnownValueBuilderProviderFor); - builders->insert_or_assign( - "google.protobuf.Any", - &WellKnownValueBuilderProviderFor); - return builders; - }(); - return *builders; -} - -class ValueBuilderForStruct final : public ValueBuilder { - public: - explicit ValueBuilderForStruct(StructValueBuilderPtr delegate) - : delegate_(std::move(delegate)) {} - - absl::Status SetFieldByName(absl::string_view name, Value value) override { - return delegate_->SetFieldByName(name, std::move(value)); - } - - absl::Status SetFieldByNumber(int64_t number, Value value) override { - return delegate_->SetFieldByNumber(number, std::move(value)); - } - - Value Build() && override { - auto status_or_value = std::move(*delegate_).Build(); - if (!status_or_value.ok()) { - return ErrorValue(status_or_value.status()); - } - return std::move(status_or_value).value(); - } - - private: - StructValueBuilderPtr delegate_; -}; - -} // namespace - absl::StatusOr> TypeReflector::NewValueBuilder( ValueFactory& value_factory, absl::string_view name) const { - const auto& well_known_value_builders = GetWellKnownValueBuilderMap(); - if (auto well_known_value_builder = well_known_value_builders.find(name); - well_known_value_builder != well_known_value_builders.end()) { - return (*well_known_value_builder->second)(value_factory.GetMemoryManager(), - *this, value_factory); - } - CEL_ASSIGN_OR_RETURN( - auto maybe_builder, - NewStructValueBuilder(value_factory, - common_internal::MakeBasicStructType(name))); - if (maybe_builder != nullptr) { - return std::make_unique(std::move(maybe_builder)); - } return nullptr; } absl::StatusOr> TypeReflector::DeserializeValue( ValueFactory& value_factory, absl::string_view type_url, const absl::Cord& value) const { - if (absl::StartsWith(type_url, kTypeGoogleApisComPrefix)) { - const auto& well_known_value_builders = GetWellKnownValueBuilderMap(); - if (auto well_known_value_builder = well_known_value_builders.find( - absl::StripPrefix(type_url, kTypeGoogleApisComPrefix)); - well_known_value_builder != well_known_value_builders.end()) { - auto deserializer = (*well_known_value_builder->second)( - value_factory.GetMemoryManager(), *this, value_factory); - CEL_RETURN_IF_ERROR(deserializer->Deserialize(value)); - return std::move(*deserializer).Build(); - } - } return DeserializeValueImpl(value_factory, type_url, value); } @@ -959,7 +45,8 @@ absl::StatusOr> TypeReflector::DeserializeValueImpl( } absl::StatusOr> -TypeReflector::NewStructValueBuilder(ValueFactory&, const StructType&) const { +TypeReflector::NewStructValueBuilder(ValueFactory& value_factory, + const StructType& type) const { return nullptr; } diff --git a/common/type_reflector.h b/common/type_reflector.h index b0c1c66d3..20b922971 100644 --- a/common/type_reflector.h +++ b/common/type_reflector.h @@ -27,7 +27,6 @@ #include "common/value.h" #include "common/value_factory.h" #include "google/protobuf/descriptor.h" -#include "google/protobuf/message.h" namespace cel { @@ -63,7 +62,7 @@ class TypeReflector : public virtual TypeIntrospector { // `NewValueBuilder` returns a new `ValueBuilder` for the corresponding type // `name`. It is primarily used to handle wrapper types which sometimes show // up literally in expressions. - absl::StatusOr> NewValueBuilder( + virtual absl::StatusOr> NewValueBuilder( ValueFactory& value_factory, absl::string_view name) const; // `FindValue` returns a new `Value` for the corresponding name `name`. This @@ -74,7 +73,7 @@ class TypeReflector : public virtual TypeIntrospector { // `DeserializeValue` deserializes the bytes of `value` according to // `type_url`. Returns `NOT_FOUND` if `type_url` is unrecognized. - absl::StatusOr> DeserializeValue( + virtual absl::StatusOr> DeserializeValue( ValueFactory& value_factory, absl::string_view type_url, const absl::Cord& value) const; @@ -83,10 +82,6 @@ class TypeReflector : public virtual TypeIntrospector { return nullptr; } - virtual absl::Nullable message_factory() const { - return nullptr; - } - protected: virtual absl::StatusOr> DeserializeValueImpl( ValueFactory& value_factory, absl::string_view type_url, diff --git a/common/type_reflector_test.cc b/common/type_reflector_test.cc index 91d48551f..32e906c14 100644 --- a/common/type_reflector_test.cc +++ b/common/type_reflector_test.cc @@ -25,7 +25,10 @@ #include "common/value.h" #include "common/value_testing.h" #include "common/values/list_value.h" +#include "common/values/value_builder.h" #include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "internal/testing_message_factory.h" namespace cel { namespace { @@ -219,8 +222,11 @@ TEST_P(TypeReflectorTest, JsonKeyCoverage) { } TEST_P(TypeReflectorTest, NewValueBuilder_BoolValue) { - ASSERT_OK_AND_ASSIGN(auto builder, value_manager().NewValueBuilder( - "google.protobuf.BoolValue")); + ASSERT_OK_AND_ASSIGN( + auto builder, + common_internal::NewValueBuilder( + memory_manager(), internal::GetTestingDescriptorPool(), + internal::GetTestingMessageFactory(), "google.protobuf.BoolValue")); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("value", BoolValue(true)), IsOk()); EXPECT_THAT(builder->SetFieldByName("does_not_exist", BoolValue(true)), @@ -232,14 +238,17 @@ TEST_P(TypeReflectorTest, NewValueBuilder_BoolValue) { StatusIs(absl::StatusCode::kNotFound)); EXPECT_THAT(builder->SetFieldByNumber(1, IntValue(1)), StatusIs(absl::StatusCode::kInvalidArgument)); - auto value = std::move(*builder).Build(); + ASSERT_OK_AND_ASSIGN(auto value, std::move(*builder).Build()); EXPECT_TRUE(InstanceOf(value)); EXPECT_EQ(Cast(value).NativeValue(), true); } TEST_P(TypeReflectorTest, NewValueBuilder_Int32Value) { - ASSERT_OK_AND_ASSIGN(auto builder, value_manager().NewValueBuilder( - "google.protobuf.Int32Value")); + ASSERT_OK_AND_ASSIGN( + auto builder, + common_internal::NewValueBuilder( + memory_manager(), internal::GetTestingDescriptorPool(), + internal::GetTestingMessageFactory(), "google.protobuf.Int32Value")); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("value", IntValue(1)), IsOk()); EXPECT_THAT(builder->SetFieldByName("does_not_exist", IntValue(1)), @@ -257,14 +266,17 @@ TEST_P(TypeReflectorTest, NewValueBuilder_Int32Value) { EXPECT_THAT(builder->SetFieldByNumber( 1, IntValue(std::numeric_limits::max())), StatusIs(absl::StatusCode::kOutOfRange)); - auto value = std::move(*builder).Build(); + ASSERT_OK_AND_ASSIGN(auto value, std::move(*builder).Build()); EXPECT_TRUE(InstanceOf(value)); EXPECT_EQ(Cast(value).NativeValue(), 1); } TEST_P(TypeReflectorTest, NewValueBuilder_Int64Value) { - ASSERT_OK_AND_ASSIGN(auto builder, value_manager().NewValueBuilder( - "google.protobuf.Int64Value")); + ASSERT_OK_AND_ASSIGN( + auto builder, + common_internal::NewValueBuilder( + memory_manager(), internal::GetTestingDescriptorPool(), + internal::GetTestingMessageFactory(), "google.protobuf.Int64Value")); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("value", IntValue(1)), IsOk()); EXPECT_THAT(builder->SetFieldByName("does_not_exist", IntValue(1)), @@ -276,14 +288,17 @@ TEST_P(TypeReflectorTest, NewValueBuilder_Int64Value) { StatusIs(absl::StatusCode::kNotFound)); EXPECT_THAT(builder->SetFieldByNumber(1, BoolValue(true)), StatusIs(absl::StatusCode::kInvalidArgument)); - auto value = std::move(*builder).Build(); + ASSERT_OK_AND_ASSIGN(auto value, std::move(*builder).Build()); EXPECT_TRUE(InstanceOf(value)); EXPECT_EQ(Cast(value).NativeValue(), 1); } TEST_P(TypeReflectorTest, NewValueBuilder_UInt32Value) { - ASSERT_OK_AND_ASSIGN(auto builder, value_manager().NewValueBuilder( - "google.protobuf.UInt32Value")); + ASSERT_OK_AND_ASSIGN( + auto builder, + common_internal::NewValueBuilder( + memory_manager(), internal::GetTestingDescriptorPool(), + internal::GetTestingMessageFactory(), "google.protobuf.UInt32Value")); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("value", UintValue(1)), IsOk()); EXPECT_THAT(builder->SetFieldByName("does_not_exist", UintValue(1)), @@ -301,14 +316,17 @@ TEST_P(TypeReflectorTest, NewValueBuilder_UInt32Value) { EXPECT_THAT(builder->SetFieldByNumber( 1, UintValue(std::numeric_limits::max())), StatusIs(absl::StatusCode::kOutOfRange)); - auto value = std::move(*builder).Build(); + ASSERT_OK_AND_ASSIGN(auto value, std::move(*builder).Build()); EXPECT_TRUE(InstanceOf(value)); EXPECT_EQ(Cast(value).NativeValue(), 1); } TEST_P(TypeReflectorTest, NewValueBuilder_UInt64Value) { - ASSERT_OK_AND_ASSIGN(auto builder, value_manager().NewValueBuilder( - "google.protobuf.UInt64Value")); + ASSERT_OK_AND_ASSIGN( + auto builder, + common_internal::NewValueBuilder( + memory_manager(), internal::GetTestingDescriptorPool(), + internal::GetTestingMessageFactory(), "google.protobuf.UInt64Value")); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("value", UintValue(1)), IsOk()); EXPECT_THAT(builder->SetFieldByName("does_not_exist", UintValue(1)), @@ -320,14 +338,17 @@ TEST_P(TypeReflectorTest, NewValueBuilder_UInt64Value) { StatusIs(absl::StatusCode::kNotFound)); EXPECT_THAT(builder->SetFieldByNumber(1, BoolValue(true)), StatusIs(absl::StatusCode::kInvalidArgument)); - auto value = std::move(*builder).Build(); + ASSERT_OK_AND_ASSIGN(auto value, std::move(*builder).Build()); EXPECT_TRUE(InstanceOf(value)); EXPECT_EQ(Cast(value).NativeValue(), 1); } TEST_P(TypeReflectorTest, NewValueBuilder_FloatValue) { - ASSERT_OK_AND_ASSIGN(auto builder, value_manager().NewValueBuilder( - "google.protobuf.FloatValue")); + ASSERT_OK_AND_ASSIGN( + auto builder, + common_internal::NewValueBuilder( + memory_manager(), internal::GetTestingDescriptorPool(), + internal::GetTestingMessageFactory(), "google.protobuf.FloatValue")); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("value", DoubleValue(1)), IsOk()); EXPECT_THAT(builder->SetFieldByName("does_not_exist", DoubleValue(1)), @@ -339,14 +360,17 @@ TEST_P(TypeReflectorTest, NewValueBuilder_FloatValue) { StatusIs(absl::StatusCode::kNotFound)); EXPECT_THAT(builder->SetFieldByNumber(1, BoolValue(true)), StatusIs(absl::StatusCode::kInvalidArgument)); - auto value = std::move(*builder).Build(); + ASSERT_OK_AND_ASSIGN(auto value, std::move(*builder).Build()); EXPECT_TRUE(InstanceOf(value)); EXPECT_EQ(Cast(value).NativeValue(), 1); } TEST_P(TypeReflectorTest, NewValueBuilder_DoubleValue) { - ASSERT_OK_AND_ASSIGN(auto builder, value_manager().NewValueBuilder( - "google.protobuf.DoubleValue")); + ASSERT_OK_AND_ASSIGN( + auto builder, + common_internal::NewValueBuilder( + memory_manager(), internal::GetTestingDescriptorPool(), + internal::GetTestingMessageFactory(), "google.protobuf.DoubleValue")); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("value", DoubleValue(1)), IsOk()); EXPECT_THAT(builder->SetFieldByName("does_not_exist", DoubleValue(1)), @@ -358,14 +382,17 @@ TEST_P(TypeReflectorTest, NewValueBuilder_DoubleValue) { StatusIs(absl::StatusCode::kNotFound)); EXPECT_THAT(builder->SetFieldByNumber(1, BoolValue(true)), StatusIs(absl::StatusCode::kInvalidArgument)); - auto value = std::move(*builder).Build(); + ASSERT_OK_AND_ASSIGN(auto value, std::move(*builder).Build()); EXPECT_TRUE(InstanceOf(value)); EXPECT_EQ(Cast(value).NativeValue(), 1); } TEST_P(TypeReflectorTest, NewValueBuilder_StringValue) { - ASSERT_OK_AND_ASSIGN(auto builder, value_manager().NewValueBuilder( - "google.protobuf.StringValue")); + ASSERT_OK_AND_ASSIGN( + auto builder, + common_internal::NewValueBuilder( + memory_manager(), internal::GetTestingDescriptorPool(), + internal::GetTestingMessageFactory(), "google.protobuf.StringValue")); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("value", StringValue("foo")), IsOk()); EXPECT_THAT(builder->SetFieldByName("does_not_exist", StringValue("foo")), @@ -377,14 +404,17 @@ TEST_P(TypeReflectorTest, NewValueBuilder_StringValue) { StatusIs(absl::StatusCode::kNotFound)); EXPECT_THAT(builder->SetFieldByNumber(1, BoolValue(true)), StatusIs(absl::StatusCode::kInvalidArgument)); - auto value = std::move(*builder).Build(); + ASSERT_OK_AND_ASSIGN(auto value, std::move(*builder).Build()); EXPECT_TRUE(InstanceOf(value)); EXPECT_EQ(Cast(value).NativeString(), "foo"); } TEST_P(TypeReflectorTest, NewValueBuilder_BytesValue) { - ASSERT_OK_AND_ASSIGN(auto builder, value_manager().NewValueBuilder( - "google.protobuf.BytesValue")); + ASSERT_OK_AND_ASSIGN( + auto builder, + common_internal::NewValueBuilder( + memory_manager(), internal::GetTestingDescriptorPool(), + internal::GetTestingMessageFactory(), "google.protobuf.BytesValue")); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("value", BytesValue("foo")), IsOk()); EXPECT_THAT(builder->SetFieldByName("does_not_exist", BytesValue("foo")), @@ -396,14 +426,17 @@ TEST_P(TypeReflectorTest, NewValueBuilder_BytesValue) { StatusIs(absl::StatusCode::kNotFound)); EXPECT_THAT(builder->SetFieldByNumber(1, BoolValue(true)), StatusIs(absl::StatusCode::kInvalidArgument)); - auto value = std::move(*builder).Build(); + ASSERT_OK_AND_ASSIGN(auto value, std::move(*builder).Build()); EXPECT_TRUE(InstanceOf(value)); EXPECT_EQ(Cast(value).NativeString(), "foo"); } TEST_P(TypeReflectorTest, NewValueBuilder_Duration) { - ASSERT_OK_AND_ASSIGN(auto builder, value_manager().NewValueBuilder( - "google.protobuf.Duration")); + ASSERT_OK_AND_ASSIGN( + auto builder, + common_internal::NewValueBuilder( + memory_manager(), internal::GetTestingDescriptorPool(), + internal::GetTestingMessageFactory(), "google.protobuf.Duration")); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("seconds", IntValue(1)), IsOk()); EXPECT_THAT(builder->SetFieldByName("does_not_exist", IntValue(1)), @@ -427,15 +460,18 @@ TEST_P(TypeReflectorTest, NewValueBuilder_Duration) { StatusIs(absl::StatusCode::kOutOfRange)); EXPECT_THAT(builder->SetFieldByNumber(2, BoolValue(true)), StatusIs(absl::StatusCode::kInvalidArgument)); - auto value = std::move(*builder).Build(); + ASSERT_OK_AND_ASSIGN(auto value, std::move(*builder).Build()); EXPECT_TRUE(InstanceOf(value)); EXPECT_EQ(Cast(value).NativeValue(), absl::Seconds(1) + absl::Nanoseconds(1)); } TEST_P(TypeReflectorTest, NewValueBuilder_Timestamp) { - ASSERT_OK_AND_ASSIGN(auto builder, value_manager().NewValueBuilder( - "google.protobuf.Timestamp")); + ASSERT_OK_AND_ASSIGN( + auto builder, + common_internal::NewValueBuilder( + memory_manager(), internal::GetTestingDescriptorPool(), + internal::GetTestingMessageFactory(), "google.protobuf.Timestamp")); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("seconds", IntValue(1)), IsOk()); EXPECT_THAT(builder->SetFieldByName("does_not_exist", IntValue(1)), @@ -459,15 +495,18 @@ TEST_P(TypeReflectorTest, NewValueBuilder_Timestamp) { StatusIs(absl::StatusCode::kOutOfRange)); EXPECT_THAT(builder->SetFieldByNumber(2, BoolValue(true)), StatusIs(absl::StatusCode::kInvalidArgument)); - auto value = std::move(*builder).Build(); + ASSERT_OK_AND_ASSIGN(auto value, std::move(*builder).Build()); EXPECT_TRUE(InstanceOf(value)); EXPECT_EQ(Cast(value).NativeValue(), absl::UnixEpoch() + absl::Seconds(1) + absl::Nanoseconds(1)); } TEST_P(TypeReflectorTest, NewValueBuilder_Any) { - ASSERT_OK_AND_ASSIGN(auto builder, - value_manager().NewValueBuilder("google.protobuf.Any")); + ASSERT_OK_AND_ASSIGN( + auto builder, + common_internal::NewValueBuilder( + memory_manager(), internal::GetTestingDescriptorPool(), + internal::GetTestingMessageFactory(), "google.protobuf.Any")); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName( "type_url", @@ -491,7 +530,7 @@ TEST_P(TypeReflectorTest, NewValueBuilder_Any) { EXPECT_THAT(builder->SetFieldByNumber(2, BytesValue()), IsOk()); EXPECT_THAT(builder->SetFieldByNumber(2, BoolValue(true)), StatusIs(absl::StatusCode::kInvalidArgument)); - auto value = std::move(*builder).Build(); + ASSERT_OK_AND_ASSIGN(auto value, std::move(*builder).Build()); EXPECT_TRUE(InstanceOf(value)); EXPECT_EQ(Cast(value).NativeValue(), false); } diff --git a/common/value.cc b/common/value.cc index 2bd8fbbec..c4a7a8a28 100644 --- a/common/value.cc +++ b/common/value.cc @@ -48,6 +48,9 @@ #include "common/optional_ref.h" #include "common/type.h" #include "common/value_kind.h" +#include "common/values/list_value_builder.h" +#include "common/values/map_value_builder.h" +#include "common/values/struct_value_builder.h" #include "common/values/values.h" #include "internal/number.h" #include "internal/protobuf_runtime_version.h" @@ -2537,6 +2540,30 @@ absl::Nonnull> NewEmptyValueIterator() { return std::make_unique(); } +absl::Nonnull NewListValueBuilder( + absl::Nonnull arena) { + ABSL_DCHECK(arena != nullptr); + return common_internal::NewListValueBuilder(arena); +} + +absl::Nonnull NewMapValueBuilder( + absl::Nonnull arena) { + ABSL_DCHECK(arena != nullptr); + return common_internal::NewMapValueBuilder(arena); +} + +absl::StatusOr> NewStructValueBuilder( + absl::Nonnull arena, + absl::Nonnull descriptor_pool, + absl::Nonnull message_factory, + absl::string_view name) { + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + return common_internal::NewStructValueBuilder(arena, descriptor_pool, + message_factory, name); +} + bool operator==(IntValue lhs, UintValue rhs) { return internal::Number::FromInt64(lhs.NativeValue()) == internal::Number::FromUint64(rhs.NativeValue()); diff --git a/common/value.h b/common/value.h index 0a325c312..e14223e37 100644 --- a/common/value.h +++ b/common/value.h @@ -2640,11 +2640,27 @@ class ValueBuilder { virtual absl::Status SetFieldByNumber(int64_t number, Value value) = 0; - virtual Value Build() && = 0; + virtual absl::StatusOr Build() && = 0; }; using ValueBuilderPtr = std::unique_ptr; +absl::Nonnull NewListValueBuilder( + absl::Nonnull arena); + +absl::Nonnull NewMapValueBuilder( + absl::Nonnull arena); + +// Returns a new `StructValueBuilder`. Returns `nullptr` if there is no such +// message type with the name `name` in `descriptor_pool`. Returns an error if +// `message_factory` is unable to provide a prototype for the descriptor +// returned from `descriptor_pool`. +absl::StatusOr> NewStructValueBuilder( + absl::Nonnull arena, + absl::Nonnull descriptor_pool, + absl::Nonnull message_factory, + absl::string_view name); + using ListValueBuilderInterface = ListValueBuilder; using MapValueBuilderInterface = MapValueBuilder; using StructValueBuilderInterface = StructValueBuilder; diff --git a/common/value_factory.h b/common/value_factory.h index 9d0c6635a..388f8401e 100644 --- a/common/value_factory.h +++ b/common/value_factory.h @@ -20,6 +20,7 @@ #include #include "absl/base/attributes.h" +#include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/cord.h" @@ -31,6 +32,7 @@ #include "common/type_factory.h" #include "common/unknown.h" #include "common/value.h" +#include "google/protobuf/message.h" namespace cel { @@ -184,6 +186,8 @@ class ValueFactory : public virtual TypeFactory { Unknown{std::move(attribute_set), std::move(function_result_set)}}; } + virtual absl::Nullable message_factory() const = 0; + protected: friend class common_internal::PiecewiseValueManager; }; diff --git a/common/value_manager.h b/common/value_manager.h index 0abc61594..c6244f049 100644 --- a/common/value_manager.h +++ b/common/value_manager.h @@ -27,6 +27,7 @@ #include "common/type_reflector.h" #include "common/value.h" #include "common/value_factory.h" +#include "google/protobuf/message.h" namespace cel { @@ -76,6 +77,8 @@ class ValueManager : public virtual ValueFactory, absl::StatusOr ConvertToJson(absl::string_view type_url, const absl::Cord& value) final; + absl::Nullable message_factory() const override = 0; + protected: virtual const TypeReflector& GetTypeReflector() const = 0; }; diff --git a/common/values/legacy_value_manager.h b/common/values/legacy_value_manager.h index 61c9b9bae..a1d7b4a62 100644 --- a/common/values/legacy_value_manager.h +++ b/common/values/legacy_value_manager.h @@ -17,12 +17,14 @@ #ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_LEGACY_VALUE_MANAGER_H_ #define THIRD_PARTY_CEL_CPP_COMMON_VALUES_LEGACY_VALUE_MANAGER_H_ +#include "absl/base/nullability.h" #include "common/memory.h" #include "common/type_reflector.h" #include "common/types/legacy_type_manager.h" #include "common/value.h" #include "common/value_manager.h" #include "common/values/legacy_type_reflector.h" +#include "google/protobuf/message.h" namespace cel::common_internal { @@ -36,6 +38,10 @@ class LegacyValueManager : public LegacyTypeManager, public ValueManager { MemoryManagerRef GetMemoryManager() const override { return memory_manager_; } + absl::Nullable message_factory() const override { + return nullptr; + } + protected: const TypeReflector& GetTypeReflector() const final { return type_reflector_; diff --git a/common/values/piecewise_value_manager.h b/common/values/piecewise_value_manager.h index 8078637ce..7dfd0e3f4 100644 --- a/common/values/piecewise_value_manager.h +++ b/common/values/piecewise_value_manager.h @@ -17,12 +17,15 @@ #ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_PIECEWISE_VALUE_MANAGER_H_ #define THIRD_PARTY_CEL_CPP_COMMON_VALUES_PIECEWISE_VALUE_MANAGER_H_ +#include "absl/base/nullability.h" #include "common/memory.h" #include "common/type_introspector.h" #include "common/type_reflector.h" #include "common/value.h" #include "common/value_factory.h" #include "common/value_manager.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" namespace cel::common_internal { @@ -39,6 +42,15 @@ class PiecewiseValueManager final : public ValueManager { return value_factory_.GetMemoryManager(); } + absl::Nullable descriptor_pool() + const override { + return type_reflector_.descriptor_pool(); + } + + absl::Nullable message_factory() const override { + return value_factory_.message_factory(); + } + protected: const TypeIntrospector& GetTypeIntrospector() const override { return type_reflector_; diff --git a/common/values/struct_value_builder.cc b/common/values/struct_value_builder.cc index 26717947b..6bf9440a9 100644 --- a/common/values/struct_value_builder.cc +++ b/common/values/struct_value_builder.cc @@ -41,8 +41,10 @@ #include "common/type_introspector.h" #include "common/type_reflector.h" #include "common/value.h" +#include "common/value_factory.h" #include "common/value_kind.h" #include "common/value_manager.h" +#include "common/values/value_builder.h" #include "extensions/protobuf/internal/map_reflection.h" #include "internal/json.h" #include "internal/status_macros.h" @@ -59,19 +61,15 @@ namespace { class CompatTypeReflector final : public TypeReflector { public: - CompatTypeReflector(absl::Nonnull pool, - absl::Nonnull factory) - : pool_(pool), factory_(factory) {} + explicit CompatTypeReflector( + absl::Nonnull pool) + : pool_(pool) {} - absl::Nullable descriptor_pool() + absl::Nonnull descriptor_pool() const override { return pool_; } - absl::Nullable message_factory() const override { - return factory_; - } - protected: absl::StatusOr> FindTypeImpl( absl::string_view name) const final { @@ -126,19 +124,45 @@ class CompatTypeReflector final : public TypeReflector { return MessageTypeField(field_desc); } - absl::StatusOr> DeserializeValueImpl( + absl::StatusOr> NewStructValueBuilder( + ValueFactory& value_factory, const StructType& type) const override { + auto* message_factory = value_factory.message_factory(); + if (message_factory == nullptr) { + return nullptr; + } + return common_internal::NewStructValueBuilder( + value_factory.GetMemoryManager().arena(), descriptor_pool(), + message_factory, type.name()); + } + + absl::StatusOr> NewValueBuilder( + ValueFactory& value_factory, absl::string_view name) const override { + auto* message_factory = value_factory.message_factory(); + if (message_factory == nullptr) { + return nullptr; + } + return common_internal::NewValueBuilder(value_factory.GetMemoryManager(), + descriptor_pool(), message_factory, + name); + } + + absl::StatusOr> DeserializeValue( ValueFactory& value_factory, absl::string_view type_url, const absl::Cord& value) const override { + const auto* descriptor_pool = this->descriptor_pool(); + auto* message_factory = value_factory.message_factory(); + if (message_factory == nullptr) { + return absl::nullopt; + } absl::string_view type_name; if (!ParseTypeUrl(type_url, &type_name)) { return absl::InvalidArgumentError("invalid type URL"); } - const auto* descriptor = - descriptor_pool()->FindMessageTypeByName(type_name); + const auto* descriptor = descriptor_pool->FindMessageTypeByName(type_name); if (descriptor == nullptr) { return absl::nullopt; } - const auto* prototype = message_factory()->GetPrototype(descriptor); + const auto* prototype = message_factory->GetPrototype(descriptor); if (prototype == nullptr) { return absl::nullopt; } @@ -149,13 +173,12 @@ class CompatTypeReflector final : public TypeReflector { return absl::InvalidArgumentError( absl::StrCat("failed to parse `", type_url, "`")); } - return Value::Message(WrapShared(prototype->New(arena), arena), pool_, - factory_); + return Value::Message(WrapShared(prototype->New(arena), arena), + descriptor_pool, message_factory); } private: const google::protobuf::DescriptorPool* const pool_; - google::protobuf::MessageFactory* const factory_; }; class CompatValueManager final : public ValueManager { @@ -163,7 +186,7 @@ class CompatValueManager final : public ValueManager { CompatValueManager(absl::Nullable arena, absl::Nonnull pool, absl::Nonnull factory) - : arena_(arena), reflector_(pool, factory) {} + : arena_(arena), reflector_(pool), factory_(factory) {} MemoryManagerRef GetMemoryManager() const override { return arena_ != nullptr ? MemoryManager::Pooling(arena_) @@ -182,12 +205,13 @@ class CompatValueManager final : public ValueManager { } absl::Nullable message_factory() const override { - return reflector_.message_factory(); + return factory_; } private: absl::Nullable const arena_; CompatTypeReflector reflector_; + absl::Nonnull factory_; }; absl::StatusOr> GetDescriptor( @@ -1011,9 +1035,9 @@ GetProtoRepeatedFieldFromValueMutator( } } -class StructValueBuilderImpl final : public StructValueBuilder { +class MessageValueBuilderImpl { public: - StructValueBuilderImpl( + MessageValueBuilderImpl( absl::Nullable arena, absl::Nonnull descriptor_pool, absl::Nonnull message_factory, @@ -1025,13 +1049,13 @@ class StructValueBuilderImpl final : public StructValueBuilder { descriptor_(message_->GetDescriptor()), reflection_(message_->GetReflection()) {} - ~StructValueBuilderImpl() override { + ~MessageValueBuilderImpl() { if (arena_ == nullptr && message_ != nullptr) { delete message_; } } - absl::Status SetFieldByName(absl::string_view name, Value value) override { + absl::Status SetFieldByName(absl::string_view name, Value value) { const auto* field = descriptor_->FindFieldByName(name); if (field == nullptr) { field = descriptor_pool_->FindExtensionByPrintableName(descriptor_, name); @@ -1042,7 +1066,7 @@ class StructValueBuilderImpl final : public StructValueBuilder { return SetField(field, std::move(value)); } - absl::Status SetFieldByNumber(int64_t number, Value value) override { + absl::Status SetFieldByNumber(int64_t number, Value value) { if (number < std::numeric_limits::min() || number > std::numeric_limits::max()) { return NoSuchFieldError(absl::StrCat(number)).NativeValue(); @@ -1055,7 +1079,12 @@ class StructValueBuilderImpl final : public StructValueBuilder { return SetField(field, std::move(value)); } - absl::StatusOr Build() && override { + absl::StatusOr Build() && { + return Value::Message(WrapShared(std::exchange(message_, nullptr)), + descriptor_pool_, message_factory_); + } + + absl::StatusOr BuildStruct() && { return ParsedMessageValue( WrapShared(std::exchange(message_, nullptr), Allocator(arena_))); } @@ -1232,6 +1261,10 @@ class StructValueBuilderImpl final : public StructValueBuilder { case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: { switch (field->message_type()->well_known_type()) { case google::protobuf::Descriptor::WELLKNOWNTYPE_BOOLVALUE: { + if (value.IsNull()) { + // Allowing assigning `null` to message fields. + return absl::OkStatus(); + } if (auto bool_value = value.AsBool(); bool_value) { CEL_RETURN_IF_ERROR(well_known_types_.BoolValue().Initialize( field->message_type())); @@ -1246,6 +1279,10 @@ class StructValueBuilderImpl final : public StructValueBuilder { .NativeValue(); } case google::protobuf::Descriptor::WELLKNOWNTYPE_INT32VALUE: { + if (value.IsNull()) { + // Allowing assigning `null` to message fields. + return absl::OkStatus(); + } if (auto int_value = value.AsInt(); int_value) { if (int_value->NativeValue() < std::numeric_limits::min() || @@ -1266,6 +1303,10 @@ class StructValueBuilderImpl final : public StructValueBuilder { .NativeValue(); } case google::protobuf::Descriptor::WELLKNOWNTYPE_INT64VALUE: { + if (value.IsNull()) { + // Allowing assigning `null` to message fields. + return absl::OkStatus(); + } if (auto int_value = value.AsInt(); int_value) { CEL_RETURN_IF_ERROR(well_known_types_.Int64Value().Initialize( field->message_type())); @@ -1280,6 +1321,10 @@ class StructValueBuilderImpl final : public StructValueBuilder { .NativeValue(); } case google::protobuf::Descriptor::WELLKNOWNTYPE_UINT32VALUE: { + if (value.IsNull()) { + // Allowing assigning `null` to message fields. + return absl::OkStatus(); + } if (auto uint_value = value.AsUint(); uint_value) { if (uint_value->NativeValue() > std::numeric_limits::max()) { @@ -1298,6 +1343,10 @@ class StructValueBuilderImpl final : public StructValueBuilder { .NativeValue(); } case google::protobuf::Descriptor::WELLKNOWNTYPE_UINT64VALUE: { + if (value.IsNull()) { + // Allowing assigning `null` to message fields. + return absl::OkStatus(); + } if (auto uint_value = value.AsUint(); uint_value) { CEL_RETURN_IF_ERROR(well_known_types_.UInt64Value().Initialize( field->message_type())); @@ -1312,6 +1361,10 @@ class StructValueBuilderImpl final : public StructValueBuilder { .NativeValue(); } case google::protobuf::Descriptor::WELLKNOWNTYPE_FLOATVALUE: { + if (value.IsNull()) { + // Allowing assigning `null` to message fields. + return absl::OkStatus(); + } if (auto double_value = value.AsDouble(); double_value) { CEL_RETURN_IF_ERROR(well_known_types_.FloatValue().Initialize( field->message_type())); @@ -1326,6 +1379,10 @@ class StructValueBuilderImpl final : public StructValueBuilder { .NativeValue(); } case google::protobuf::Descriptor::WELLKNOWNTYPE_DOUBLEVALUE: { + if (value.IsNull()) { + // Allowing assigning `null` to message fields. + return absl::OkStatus(); + } if (auto double_value = value.AsDouble(); double_value) { CEL_RETURN_IF_ERROR(well_known_types_.DoubleValue().Initialize( field->message_type())); @@ -1340,6 +1397,10 @@ class StructValueBuilderImpl final : public StructValueBuilder { .NativeValue(); } case google::protobuf::Descriptor::WELLKNOWNTYPE_BYTESVALUE: { + if (value.IsNull()) { + // Allowing assigning `null` to message fields. + return absl::OkStatus(); + } if (auto bytes_value = value.AsBytes(); bytes_value) { CEL_RETURN_IF_ERROR(well_known_types_.BytesValue().Initialize( field->message_type())); @@ -1354,6 +1415,10 @@ class StructValueBuilderImpl final : public StructValueBuilder { .NativeValue(); } case google::protobuf::Descriptor::WELLKNOWNTYPE_STRINGVALUE: { + if (value.IsNull()) { + // Allowing assigning `null` to message fields. + return absl::OkStatus(); + } if (auto string_value = value.AsString(); string_value) { CEL_RETURN_IF_ERROR(well_known_types_.StringValue().Initialize( field->message_type())); @@ -1368,6 +1433,10 @@ class StructValueBuilderImpl final : public StructValueBuilder { .NativeValue(); } case google::protobuf::Descriptor::WELLKNOWNTYPE_DURATION: { + if (value.IsNull()) { + // Allowing assigning `null` to message fields. + return absl::OkStatus(); + } if (auto duration_value = value.AsDuration(); duration_value) { CEL_RETURN_IF_ERROR(well_known_types_.Duration().Initialize( field->message_type())); @@ -1381,6 +1450,10 @@ class StructValueBuilderImpl final : public StructValueBuilder { .NativeValue(); } case google::protobuf::Descriptor::WELLKNOWNTYPE_TIMESTAMP: { + if (value.IsNull()) { + // Allowing assigning `null` to message fields. + return absl::OkStatus(); + } if (auto timestamp_value = value.AsTimestamp(); timestamp_value) { CEL_RETURN_IF_ERROR(well_known_types_.Timestamp().Initialize( field->message_type())); @@ -1486,6 +1559,10 @@ class StructValueBuilderImpl final : public StructValueBuilder { return absl::OkStatus(); } default: + if (value.IsNull()) { + // Allowing assigning `null` to message fields. + return absl::OkStatus(); + } break; } return ProtoMessageFromValueImpl( @@ -1519,19 +1596,91 @@ class StructValueBuilderImpl final : public StructValueBuilder { well_known_types::Reflection well_known_types_; }; +class ValueBuilderImpl final : public ValueBuilder { + public: + ValueBuilderImpl(absl::Nullable arena, + absl::Nonnull descriptor_pool, + absl::Nonnull message_factory, + absl::Nonnull message) + : builder_(arena, descriptor_pool, message_factory, message) {} + + absl::Status SetFieldByName(absl::string_view name, Value value) override { + return builder_.SetFieldByName(name, std::move(value)); + } + + absl::Status SetFieldByNumber(int64_t number, Value value) override { + return builder_.SetFieldByNumber(number, std::move(value)); + } + + absl::StatusOr Build() && override { + return std::move(builder_).Build(); + } + + private: + MessageValueBuilderImpl builder_; +}; + +class StructValueBuilderImpl final : public StructValueBuilder { + public: + StructValueBuilderImpl( + absl::Nullable arena, + absl::Nonnull descriptor_pool, + absl::Nonnull message_factory, + absl::Nonnull message) + : builder_(arena, descriptor_pool, message_factory, message) {} + + absl::Status SetFieldByName(absl::string_view name, Value value) override { + return builder_.SetFieldByName(name, std::move(value)); + } + + absl::Status SetFieldByNumber(int64_t number, Value value) override { + return builder_.SetFieldByNumber(number, std::move(value)); + } + + absl::StatusOr Build() && override { + return std::move(builder_).BuildStruct(); + } + + private: + MessageValueBuilderImpl builder_; +}; + } // namespace -absl::StatusOr> NewStructValueBuilder( +absl::StatusOr> NewValueBuilder( + Allocator<> allocator, + absl::Nonnull descriptor_pool, + absl::Nonnull message_factory, + absl::string_view name) { + absl::Nullable descriptor = + descriptor_pool->FindMessageTypeByName(name); + if (descriptor == nullptr) { + return nullptr; + } + absl::Nullable prototype = + message_factory->GetPrototype(descriptor); + if (prototype == nullptr) { + return absl::NotFoundError(absl::StrCat( + "unable to get prototype for descriptor: ", descriptor->full_name())); + } + return std::make_unique(allocator.arena(), descriptor_pool, + message_factory, + prototype->New(allocator.arena())); +} + +absl::StatusOr> +NewStructValueBuilder( Allocator<> allocator, absl::Nonnull descriptor_pool, absl::Nonnull message_factory, absl::string_view name) { - const auto* descriptor = descriptor_pool->FindMessageTypeByName(name); + absl::Nullable descriptor = + descriptor_pool->FindMessageTypeByName(name); if (descriptor == nullptr) { - return absl::NotFoundError( - absl::StrCat("unable to find descriptor for type: ", name)); + return nullptr; } - const auto* prototype = message_factory->GetPrototype(descriptor); + absl::Nullable prototype = + message_factory->GetPrototype(descriptor); if (prototype == nullptr) { return absl::NotFoundError(absl::StrCat( "unable to get prototype for descriptor: ", descriptor->full_name())); diff --git a/common/values/struct_value_builder.h b/common/values/struct_value_builder.h index 76a7217d2..063dc8c84 100644 --- a/common/values/struct_value_builder.h +++ b/common/values/struct_value_builder.h @@ -23,20 +23,15 @@ #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" -namespace cel { +namespace cel::common_internal { -class ValueFactory; - -namespace common_internal { - -absl::StatusOr> NewStructValueBuilder( +absl::StatusOr> +NewStructValueBuilder( Allocator<> allocator, absl::Nonnull descriptor_pool, absl::Nonnull message_factory, absl::string_view name); -} // namespace common_internal - -} // namespace cel +} // namespace cel::common_internal #endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_STRUCT_VALUE_BUILDER_H_ diff --git a/common/values/thread_compatible_value_manager.h b/common/values/thread_compatible_value_manager.h index 798cfcdf1..84e6f088b 100644 --- a/common/values/thread_compatible_value_manager.h +++ b/common/values/thread_compatible_value_manager.h @@ -19,11 +19,13 @@ #include +#include "absl/base/nullability.h" #include "common/memory.h" #include "common/type_reflector.h" #include "common/types/thread_compatible_type_manager.h" #include "common/value.h" #include "common/value_manager.h" +#include "google/protobuf/message.h" namespace cel::common_internal { @@ -38,6 +40,10 @@ class ThreadCompatibleValueManager : public ThreadCompatibleTypeManager, MemoryManagerRef GetMemoryManager() const override { return memory_manager_; } + absl::Nullable message_factory() const override { + return nullptr; + } + protected: TypeReflector& GetTypeReflector() const final { return *type_reflector_; } diff --git a/common/values/value_builder.h b/common/values/value_builder.h new file mode 100644 index 000000000..e93704884 --- /dev/null +++ b/common/values/value_builder.h @@ -0,0 +1,37 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_VALUE_BUILDER_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_VALUE_BUILDER_H_ + +#include "absl/base/nullability.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "common/allocator.h" +#include "common/value.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel::common_internal { + +// Like NewStructValueBuilder, but deals with well known types. +absl::StatusOr> NewValueBuilder( + Allocator<> allocator, + absl::Nonnull descriptor_pool, + absl::Nonnull message_factory, + absl::string_view name); + +} // namespace cel::common_internal + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_VALUE_BUILDER_H_ diff --git a/conformance/BUILD b/conformance/BUILD index aca4c2795..f2de5277a 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -212,23 +212,6 @@ _TESTS_TO_SKIP_MODERN = [ "string_ext/value_errors", "string_ext/type_errors", - # TODO: Fix null assignment to a field - "proto2/set_null/single_message", - "proto2/set_null/single_duration", - "proto2/set_null/single_timestamp", - "proto3/set_null/single_message", - "proto3/set_null/single_duration", - "proto3/set_null/single_timestamp", - "wrappers/bool/to_null", - "wrappers/int32/to_null", - "wrappers/int64/to_null", - "wrappers/uint32/to_null", - "wrappers/uint64/to_null", - "wrappers/float/to_null", - "wrappers/double/to_null", - "wrappers/bytes/to_null", - "wrappers/string/to_null", - # TODO: Add missing conversion function "conversions/bool", ] diff --git a/conformance/run.bzl b/conformance/run.bzl index 0a454c632..86fc01ace 100644 --- a/conformance/run.bzl +++ b/conformance/run.bzl @@ -35,24 +35,19 @@ def _expand_tests_to_skip(tests_to_skip): result.append(test_to_skip[0:slash] + part) return result -def _conformance_test_name(name, modern, arena, optimize, recursive, skip_check): +def _conformance_test_name(name, optimize, recursive): return "_".join( [ name, - "arena" if arena else "refcount", "optimized" if optimize else "unoptimized", "recursive" if recursive else "iterative", ], ) -def _conformance_test_args(modern, arena, optimize, recursive, skip_check, skip_tests, dashboard): +def _conformance_test_args(modern, optimize, recursive, skip_check, skip_tests, dashboard): args = [] if modern: args.append("--modern") - elif not arena: - fail("arena must be true for legacy") - if not modern or arena: - args.append("--arena") if optimize: args.append("--opt") if recursive: @@ -66,10 +61,10 @@ def _conformance_test_args(modern, arena, optimize, recursive, skip_check, skip_ args.append("--dashboard") return args -def _conformance_test(name, data, modern, arena, optimize, recursive, skip_check, skip_tests, tags, dashboard): +def _conformance_test(name, data, modern, optimize, recursive, skip_check, skip_tests, tags, dashboard): native.cc_test( - name = _conformance_test_name(name, modern, arena, optimize, recursive, skip_check), - args = _conformance_test_args(modern, arena, optimize, recursive, skip_check, skip_tests, dashboard) + ["$(location " + test + ")" for test in data], + name = _conformance_test_name(name, optimize, recursive), + args = _conformance_test_args(modern, optimize, recursive, skip_check, skip_tests, dashboard) + ["$(location " + test + ")" for test in data], data = data, deps = ["//conformance:run"], tags = tags, @@ -89,15 +84,12 @@ def gen_conformance_tests(name, data, modern = False, checked = False, dashboard dashboard: enable dashboard mode """ skip_check = not checked - - # TODO: enable refcount mode for modern. for optimize in (True, False): for recursive in (True, False): _conformance_test( name, data, modern = modern, - arena = True, optimize = optimize, recursive = recursive, skip_check = skip_check, diff --git a/conformance/run.cc b/conformance/run.cc index d810833e3..c76569f8c 100644 --- a/conformance/run.cc +++ b/conformance/run.cc @@ -57,9 +57,6 @@ ABSL_FLAG(bool, opt, false, "Enable optimizations (constant folding)"); ABSL_FLAG( bool, modern, false, "Use modern cel::Value APIs implementation of the conformance service."); -ABSL_FLAG(bool, arena, false, - "Use arena memory manager (default: global heap ref-counted). Only " - "affects the modern implementation"); ABSL_FLAG(bool, recursive, false, "Enable recursive plans. Depth limited to slightly more than the " "default nesting limit."); @@ -279,7 +276,6 @@ NewConformanceServiceFromFlags() { cel_conformance::ConformanceServiceOptions{ .optimize = absl::GetFlag(FLAGS_opt), .modern = absl::GetFlag(FLAGS_modern), - .arena = absl::GetFlag(FLAGS_arena), .recursive = absl::GetFlag(FLAGS_recursive)}); ABSL_CHECK_OK(status_or_service); return std::shared_ptr( diff --git a/conformance/service.cc b/conformance/service.cc index d2a8bffed..da5ca39a7 100644 --- a/conformance/service.cc +++ b/conformance/service.cc @@ -372,7 +372,8 @@ class LegacyConformanceServiceImpl : public ConformanceServiceInterface { builder_->CreateExpression(&expr, &source_info); if (!cel_expression_status.ok()) { - return absl::InternalError(cel_expression_status.status().ToString()); + return absl::InternalError(cel_expression_status.status().ToString( + absl::StatusToStringMode::kWithEverything)); } auto cel_expression = std::move(cel_expression_status.value()); @@ -384,7 +385,8 @@ class LegacyConformanceServiceImpl : public ConformanceServiceInterface { import_value)); auto import_status = ValueToCelValue(*import_value, &arena); if (!import_status.ok()) { - return absl::InternalError(import_status.status().ToString()); + return absl::InternalError(import_status.status().ToString( + absl::StatusToStringMode::kWithEverything)); } activation.InsertValue(pair.first, import_status.value()); } @@ -394,7 +396,8 @@ class LegacyConformanceServiceImpl : public ConformanceServiceInterface { *response.mutable_result() ->mutable_error() ->add_errors() - ->mutable_message() = eval_status.status().ToString(); + ->mutable_message() = eval_status.status().ToString( + absl::StatusToStringMode::kWithEverything); return absl::OkStatus(); } @@ -403,12 +406,14 @@ class LegacyConformanceServiceImpl : public ConformanceServiceInterface { *response.mutable_result() ->mutable_error() ->add_errors() - ->mutable_message() = std::string(result.ErrorOrDie()->message()); + ->mutable_message() = std::string(result.ErrorOrDie()->ToString( + absl::StatusToStringMode::kWithEverything)); } else { cel::expr::Value export_value; auto export_status = CelValueToValue(result, &export_value); if (!export_status.ok()) { - return absl::InternalError(export_status.ToString()); + return absl::InternalError( + export_status.ToString(absl::StatusToStringMode::kWithEverything)); } auto* result_value = response.mutable_result()->mutable_value(); ABSL_CHECK( // Crash OK @@ -428,7 +433,7 @@ class LegacyConformanceServiceImpl : public ConformanceServiceInterface { class ModernConformanceServiceImpl : public ConformanceServiceInterface { public: static absl::StatusOr> Create( - bool optimize, bool use_arena, bool recursive) { + bool optimize, bool recursive) { google::protobuf::LinkMessageReflection< cel::expr::conformance::proto3::TestAllTypes>(); google::protobuf::LinkMessageReflection< @@ -468,7 +473,7 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { } return absl::WrapUnique( - new ModernConformanceServiceImpl(options, use_arena, optimize)); + new ModernConformanceServiceImpl(options, optimize)); } absl::StatusOr> Setup( @@ -488,8 +493,6 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { auto& type_registry = builder.type_registry(); // Use linked pbs in the generated descriptor pool. - type_registry.AddTypeProvider( - std::make_unique()); CEL_RETURN_IF_ERROR(RegisterProtobufEnum( type_registry, cel::expr::conformance::proto2::GlobalEnum_descriptor())); @@ -527,7 +530,8 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { void Check(const conformance::v1alpha1::CheckRequest& request, conformance::v1alpha1::CheckResponse& response) override { - auto status = DoCheck(&arena_, request, response); + google::protobuf::Arena arena; + auto status = DoCheck(&arena, request, response); if (!status.ok()) { auto* issue = response.add_issues(); issue->set_code(ToGrpcCode(status.code())); @@ -539,9 +543,7 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { conformance::v1alpha1::EvalResponse& response) override { google::protobuf::Arena arena; auto proto_memory_manager = ProtoMemoryManagerRef(&arena); - cel::MemoryManagerRef memory_manager = - (use_arena_ ? proto_memory_manager - : cel::MemoryManagerRef::ReferenceCounting()); + cel::MemoryManagerRef memory_manager = proto_memory_manager; auto runtime_status = Setup(request.container()); if (!runtime_status.ok()) { @@ -569,14 +571,15 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { auto import_status = FromConformanceValue(value_factory.get(), import_value); if (!import_status.ok()) { - return absl::InternalError(import_status.status().ToString()); + return absl::InternalError(import_status.status().ToString( + absl::StatusToStringMode::kWithEverything)); } activation.InsertOrAssignValue(pair.first, std::move(import_status).value()); } - auto eval_status = program->Evaluate(activation, value_factory.get()); + auto eval_status = program->Evaluate(&arena, activation); if (!eval_status.ok()) { *response.mutable_result() ->mutable_error() @@ -609,11 +612,8 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { private: explicit ModernConformanceServiceImpl(const RuntimeOptions& options, - bool use_arena, bool enable_optimizations) - : options_(options), - use_arena_(use_arena), - enable_optimizations_(enable_optimizations) {} + : options_(options), enable_optimizations_(enable_optimizations) {} static absl::Status DoCheck( google::protobuf::Arena* arena, const conformance::v1alpha1::CheckRequest& request, @@ -727,9 +727,7 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { } RuntimeOptions options_; - bool use_arena_; bool enable_optimizations_; - Arena arena_; }; } // namespace @@ -742,7 +740,7 @@ absl::StatusOr> NewConformanceService(const ConformanceServiceOptions& options) { if (options.modern) { return google::api::expr::runtime::ModernConformanceServiceImpl::Create( - options.optimize, options.arena, options.recursive); + options.optimize, options.recursive); } else { return google::api::expr::runtime::LegacyConformanceServiceImpl::Create( options.optimize, options.recursive); diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index d31e90451..680751fcd 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -95,6 +95,7 @@ cc_library( ":resolver", "//base:ast", "//base:builtins", + "//base:data", "//base/ast_internal:ast_impl", "//base/ast_internal:expr", "//common:ast", @@ -152,9 +153,6 @@ cc_test( srcs = [ "flat_expr_builder_test.cc", ], - data = [ - "//eval/testutil:simple_test_message_proto", - ], deps = [ ":cel_expression_builder_flat_impl", ":constant_folding", diff --git a/eval/compiler/cel_expression_builder_flat_impl.cc b/eval/compiler/cel_expression_builder_flat_impl.cc index 63b601cc4..98ecc6aae 100644 --- a/eval/compiler/cel_expression_builder_flat_impl.cc +++ b/eval/compiler/cel_expression_builder_flat_impl.cc @@ -102,10 +102,10 @@ CelExpressionBuilderFlatImpl::CreateExpressionImpl( impl.subexpressions().front().size() == 1 && impl.subexpressions().front().front()->GetNativeTypeId() == cel::NativeTypeId::For()) { - return CelExpressionRecursiveImpl::Create(std::move(impl)); + return CelExpressionRecursiveImpl::Create(env_, std::move(impl)); } - return std::make_unique(std::move(impl)); + return std::make_unique(env_, std::move(impl)); } } // namespace google::api::expr::runtime diff --git a/eval/compiler/cel_expression_builder_flat_impl.h b/eval/compiler/cel_expression_builder_flat_impl.h index 7b09b7879..ac6f46ce1 100644 --- a/eval/compiler/cel_expression_builder_flat_impl.h +++ b/eval/compiler/cel_expression_builder_flat_impl.h @@ -46,7 +46,8 @@ class CelExpressionBuilderFlatImpl : public CelExpressionBuilder { CelExpressionBuilderFlatImpl( absl::Nonnull> env, const cel::RuntimeOptions& options) - : env_(std::move(env)), flat_expr_builder_(env_, options) { + : env_(std::move(env)), + flat_expr_builder_(env_, options, /*use_legacy_type_provider=*/true) { ABSL_DCHECK(env_->IsInitialized()); } diff --git a/eval/compiler/cel_expression_builder_flat_impl_test.cc b/eval/compiler/cel_expression_builder_flat_impl_test.cc index 46212128b..73365e4e6 100644 --- a/eval/compiler/cel_expression_builder_flat_impl_test.cc +++ b/eval/compiler/cel_expression_builder_flat_impl_test.cc @@ -110,10 +110,6 @@ struct RecursiveTestCase { class RecursivePlanTest : public ::testing::TestWithParam { protected: absl::Status SetupBuilder(CelExpressionBuilderFlatImpl& builder) { - builder.GetTypeRegistry()->RegisterTypeProvider( - std::make_unique( - google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory())); builder.GetTypeRegistry()->RegisterEnum("TestEnum", {{"FOO", 1}, {"BAR", 2}}); diff --git a/eval/compiler/constant_folding.cc b/eval/compiler/constant_folding.cc index 73eccad0e..e90adf122 100644 --- a/eval/compiler/constant_folding.cc +++ b/eval/compiler/constant_folding.cc @@ -218,7 +218,8 @@ absl::Status ConstantFoldingExtension::OnPostVisit(PlannerContext& context, Value value; if (node.has_const_expr()) { CEL_ASSIGN_OR_RETURN( - value, ConvertConstant(node.const_expr(), state_.value_factory())); + value, + ConvertConstant(node.const_expr(), state_.memory_manager().arena())); } else { ExecutionFrame frame(subplan, empty_, context.options(), state_); state_.Reset(); diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index fcb4d5c44..e787f3411 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -49,6 +49,7 @@ #include "base/ast_internal/ast_impl.h" #include "base/ast_internal/expr.h" #include "base/builtins.h" +#include "base/type_provider.h" #include "common/ast.h" #include "common/ast_traverse.h" #include "common/ast_visitor.h" @@ -82,6 +83,7 @@ #include "runtime/internal/issue_collector.h" #include "runtime/runtime_issue.h" #include "runtime/runtime_options.h" +#include "runtime/type_registry.h" #include "google/protobuf/arena.h" namespace google::api::expr::runtime { @@ -96,6 +98,8 @@ using ::cel::Value; using ::cel::ValueManager; using ::cel::ast_internal::AstImpl; using ::cel::runtime_internal::ConvertConstant; +using ::cel::runtime_internal::GetLegacyRuntimeTypeProvider; +using ::cel::runtime_internal::GetRuntimeTypeProvider; using ::cel::runtime_internal::IssueCollector; constexpr absl::string_view kOptionalOrFn = "or"; @@ -525,7 +529,7 @@ class FlatExprVisitor : public cel::AstVisitor { } absl::StatusOr converted_value = - ConvertConstant(const_expr, value_factory_); + ConvertConstant(const_expr, value_factory_.GetMemoryManager().arena()); if (!converted_value.ok()) { SetProgressStatusError(converted_value.status()); @@ -2134,14 +2138,12 @@ absl::StatusOr FlatExprBuilder::CreateExpressionImpl( : RuntimeIssue::Severity::kError; IssueCollector issue_collector(max_severity); Resolver resolver(container_, function_registry_, type_registry_, - type_registry_.GetComposedTypeProvider(), - type_registry_.resolveable_enums(), + GetTypeProvider(), type_registry_.resolveable_enums(), options_.enable_qualified_type_identifiers); std::shared_ptr arena; ProgramBuilder program_builder; - PlannerContext extension_context(env_, resolver, options_, - type_registry_.GetComposedTypeProvider(), + PlannerContext extension_context(env_, resolver, options_, GetTypeProvider(), issue_collector, program_builder, arena); auto& ast_impl = AstImpl::CastFromPublicAst(*ast); @@ -2167,8 +2169,7 @@ absl::StatusOr FlatExprBuilder::CreateExpressionImpl( // These objects are expected to remain scoped to one build call -- references // to them shouldn't be persisted in any part of the result expression. cel::common_internal::LegacyValueManager value_factory( - cel::MemoryManagerRef::ReferenceCounting(), - type_registry_.GetComposedTypeProvider()); + cel::MemoryManagerRef::ReferenceCounting(), GetTypeProvider()); FlatExprVisitor visitor(resolver, options_, std::move(optimizers), ast_impl.reference_map(), value_factory, issue_collector, program_builder, extension_context, @@ -2196,9 +2197,14 @@ absl::StatusOr FlatExprBuilder::CreateExpressionImpl( } return FlatExpression(std::move(execution_path), std::move(subexpressions), - visitor.slot_count(), - type_registry_.GetComposedTypeProvider(), options_, + visitor.slot_count(), GetTypeProvider(), options_, std::move(arena)); } +const cel::TypeProvider& FlatExprBuilder::GetTypeProvider() const { + return use_legacy_type_provider_ + ? static_cast( + *GetLegacyRuntimeTypeProvider(type_registry_)) + : GetRuntimeTypeProvider(type_registry_); +} } // namespace google::api::expr::runtime diff --git a/eval/compiler/flat_expr_builder.h b/eval/compiler/flat_expr_builder.h index eafb58781..f0263f065 100644 --- a/eval/compiler/flat_expr_builder.h +++ b/eval/compiler/flat_expr_builder.h @@ -26,6 +26,7 @@ #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "base/ast.h" +#include "base/type_provider.h" #include "eval/compiler/flat_expr_builder_extensions.h" #include "eval/eval/evaluator_core.h" #include "runtime/function_registry.h" @@ -43,24 +44,26 @@ class FlatExprBuilder { FlatExprBuilder( absl::Nonnull> env, - const cel::RuntimeOptions& options) + const cel::RuntimeOptions& options, bool use_legacy_type_provider = false) : env_(std::move(env)), options_(options), container_(options.container), function_registry_(env_->function_registry), - type_registry_(env_->type_registry) {} + type_registry_(env_->type_registry), + use_legacy_type_provider_(use_legacy_type_provider) {} FlatExprBuilder( absl::Nonnull> env, const cel::FunctionRegistry& function_registry, const cel::TypeRegistry& type_registry, - const cel::RuntimeOptions& options) + const cel::RuntimeOptions& options, bool use_legacy_type_provider = false) : env_(std::move(env)), options_(options), container_(options.container), function_registry_(function_registry), - type_registry_(type_registry) {} + type_registry_(type_registry), + use_legacy_type_provider_(use_legacy_type_provider) {} void AddAstTransform(std::unique_ptr transform) { ast_transforms_.push_back(std::move(transform)); @@ -91,6 +94,8 @@ class FlatExprBuilder { void enable_optional_types() { enable_optional_types_ = true; } private: + const cel::TypeProvider& GetTypeProvider() const; + const absl::Nonnull> env_; cel::RuntimeOptions options_; @@ -100,6 +105,7 @@ class FlatExprBuilder { // allow built expressions to keep the registries alive. const cel::FunctionRegistry& function_registry_; const cel::TypeRegistry& type_registry_; + bool use_legacy_type_provider_; std::vector> ast_transforms_; std::vector program_optimizers_; }; diff --git a/eval/compiler/flat_expr_builder_test.cc b/eval/compiler/flat_expr_builder_test.cc index ad0664777..f3435bd31 100644 --- a/eval/compiler/flat_expr_builder_test.cc +++ b/eval/compiler/flat_expr_builder_test.cc @@ -77,7 +77,6 @@ using ::absl_testing::StatusIs; using ::cel::Value; using ::cel::expr::conformance::proto3::TestAllTypes; using ::cel::internal::test::EqualsProto; -using ::cel::internal::test::ReadBinaryProtoFromFile; using ::cel::runtime_internal::NewTestingRuntimeEnv; using ::cel::expr::CheckedExpr; using ::cel::expr::Expr; @@ -89,10 +88,6 @@ using ::testing::HasSubstr; using ::testing::SizeIs; using ::testing::Truly; -inline constexpr absl::string_view kSimpleTestMessageDescriptorSetFile = - "eval/testutil/" - "simple_test_message_proto-descriptor-set.proto.bin"; - class ConcatFunction : public CelFunction { public: explicit ConcatFunction() : CelFunction(CreateDescriptor()) {} @@ -213,10 +208,6 @@ TEST(FlatExprBuilderTest, MessageFieldValueUnset) { Expr expr; SourceInfo source_info; CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); - builder.GetTypeRegistry()->RegisterTypeProvider( - std::make_unique( - google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory())); // Don't set either the field or the value for the message creation step. auto* create_message = expr.mutable_struct_expr(); @@ -1834,10 +1825,6 @@ TEST(FlatExprBuilderTest, TypeResolve) { cel::RuntimeOptions options; options.enable_qualified_type_identifiers = true; CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); - builder.GetTypeRegistry()->RegisterTypeProvider( - std::make_unique( - google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory())); builder.set_container("google.api.expr"); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); ASSERT_OK_AND_ASSIGN(auto expression, @@ -1862,10 +1849,6 @@ TEST(FlatExprBuilderTest, AnyPackingList) { cel::RuntimeOptions options; CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); - builder.GetTypeRegistry()->RegisterTypeProvider( - std::make_unique( - google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory())); builder.set_container("cel.expr.conformance.proto3"); ASSERT_OK_AND_ASSIGN(auto expression, @@ -1897,10 +1880,6 @@ TEST(FlatExprBuilderTest, AnyPackingNestedNumbers) { cel::RuntimeOptions options; CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); - builder.GetTypeRegistry()->RegisterTypeProvider( - std::make_unique( - google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory())); builder.set_container("cel.expr.conformance.proto3"); ASSERT_OK_AND_ASSIGN(auto expression, @@ -1930,10 +1909,6 @@ TEST(FlatExprBuilderTest, AnyPackingInt) { cel::RuntimeOptions options; CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); - builder.GetTypeRegistry()->RegisterTypeProvider( - std::make_unique( - google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory())); builder.set_container("cel.expr.conformance.proto3"); ASSERT_OK_AND_ASSIGN(auto expression, @@ -1962,10 +1937,6 @@ TEST(FlatExprBuilderTest, AnyPackingMap) { cel::RuntimeOptions options; CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); - builder.GetTypeRegistry()->RegisterTypeProvider( - std::make_unique( - google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory())); builder.set_container("cel.expr.conformance.proto3"); ASSERT_OK_AND_ASSIGN(auto expression, @@ -2049,93 +2020,6 @@ TEST(FlatExprBuilderTest, HeterogeneousEqualityDisabled) { HasSubstr("Invalid map key type")))); } -TEST(FlatExprBuilderTest, CustomDescriptorPoolForCreateStruct) { - ASSERT_OK_AND_ASSIGN( - ParsedExpr parsed_expr, - parser::Parse("google.api.expr.runtime.SimpleTestMessage{}")); - - // This time, the message is unknown. We only have the proto as data, we did - // not link the generated message, so it's not included in the generated pool. - CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); - builder.GetTypeRegistry()->RegisterTypeProvider( - std::make_unique( - google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory())); - - EXPECT_THAT( - builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), - StatusIs(absl::StatusCode::kInvalidArgument)); - - // Now we create a custom DescriptorPool to which we add SimpleTestMessage - google::protobuf::DescriptorPool desc_pool; - google::protobuf::FileDescriptorSet filedesc_set; - - ASSERT_OK(ReadBinaryProtoFromFile(kSimpleTestMessageDescriptorSetFile, - filedesc_set)); - ASSERT_EQ(filedesc_set.file_size(), 1); - desc_pool.BuildFile(filedesc_set.file(0)); - - google::protobuf::DynamicMessageFactory message_factory(&desc_pool); - - // This time, the message is *known*. We are using a custom descriptor pool - // that has been primed with the relevant message. - CelExpressionBuilderFlatImpl builder2(NewTestingRuntimeEnv()); - builder2.GetTypeRegistry()->RegisterTypeProvider( - std::make_unique(&desc_pool, - &message_factory)); - - ASSERT_OK_AND_ASSIGN(auto expression, - builder2.CreateExpression(&parsed_expr.expr(), - &parsed_expr.source_info())); - - Activation activation; - google::protobuf::Arena arena; - ASSERT_OK_AND_ASSIGN(CelValue result, - expression->Evaluate(activation, &arena)); - ASSERT_TRUE(result.IsMessage()); - EXPECT_EQ(result.MessageOrDie()->GetTypeName(), - "google.api.expr.runtime.SimpleTestMessage"); -} - -TEST(FlatExprBuilderTest, CustomDescriptorPoolForSelect) { - ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, - parser::Parse("message.int64_value")); - - google::protobuf::DescriptorPool desc_pool; - google::protobuf::FileDescriptorSet filedesc_set; - - ASSERT_OK(ReadBinaryProtoFromFile(kSimpleTestMessageDescriptorSetFile, - filedesc_set)); - ASSERT_EQ(filedesc_set.file_size(), 1); - desc_pool.BuildFile(filedesc_set.file(0)); - - google::protobuf::DynamicMessageFactory message_factory(&desc_pool); - - const google::protobuf::Descriptor* desc = desc_pool.FindMessageTypeByName( - "google.api.expr.runtime.SimpleTestMessage"); - const google::protobuf::Message* message_prototype = message_factory.GetPrototype(desc); - google::protobuf::Message* message = message_prototype->New(); - const google::protobuf::Reflection* refl = message->GetReflection(); - const google::protobuf::FieldDescriptor* field = desc->FindFieldByName("int64_value"); - refl->SetInt64(message, field, 123); - - // The since this is access only, the evaluator will work with message duck - // typing. - CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); - ASSERT_OK_AND_ASSIGN(auto expression, - builder.CreateExpression(&parsed_expr.expr(), - &parsed_expr.source_info())); - Activation activation; - google::protobuf::Arena arena; - activation.InsertValue("message", - CelProtoWrapper::CreateMessage(message, &arena)); - ASSERT_OK_AND_ASSIGN(CelValue result, - expression->Evaluate(activation, &arena)); - EXPECT_THAT(result, test::IsCelInt64(123)); - - delete message; -} - std::pair CreateTestMessage( const google::protobuf::DescriptorPool& descriptor_pool, google::protobuf::MessageFactory& message_factory, absl::string_view name) { @@ -2171,9 +2055,6 @@ TEST_P(CustomDescriptorPoolTest, TestType) { google::protobuf::DynamicMessageFactory message_factory(&descriptor_pool); ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("m")); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); - builder.GetTypeRegistry()->RegisterTypeProvider( - std::make_unique(&descriptor_pool, - &message_factory)); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); // Create test subject, invoke custom setter for message diff --git a/eval/compiler/resolver_test.cc b/eval/compiler/resolver_test.cc index 978596973..0c8b7178e 100644 --- a/eval/compiler/resolver_test.cc +++ b/eval/compiler/resolver_test.cc @@ -141,10 +141,6 @@ TEST_F(ResolverTest, TestFindConstantUnqualifiedType) { TEST_F(ResolverTest, TestFindConstantFullyQualifiedType) { google::protobuf::LinkMessageReflection(); CelFunctionRegistry func_registry; - type_registry_.RegisterTypeProvider( - std::make_unique( - google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory())); Resolver resolver("cel", func_registry.InternalGetRegistry(), type_registry_.InternalGetModernRegistry(), value_factory_, type_registry_.resolveable_enums()); @@ -159,10 +155,6 @@ TEST_F(ResolverTest, TestFindConstantFullyQualifiedType) { TEST_F(ResolverTest, TestFindConstantQualifiedTypeDisabled) { CelFunctionRegistry func_registry; - type_registry_.RegisterTypeProvider( - std::make_unique( - google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory())); Resolver resolver("", func_registry.InternalGetRegistry(), type_registry_.InternalGetModernRegistry(), value_factory_, type_registry_.resolveable_enums(), false); @@ -177,10 +169,6 @@ TEST_F(ResolverTest, FindTypeBySimpleName) { func_registry.InternalGetRegistry(), type_registry_.InternalGetModernRegistry(), value_factory_, type_registry_.resolveable_enums()); - type_registry_.RegisterTypeProvider( - std::make_unique( - google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory())); ASSERT_OK_AND_ASSIGN(auto type, resolver.FindType("TestMessage", -1)); EXPECT_TRUE(type.has_value()); @@ -189,10 +177,6 @@ TEST_F(ResolverTest, FindTypeBySimpleName) { TEST_F(ResolverTest, FindTypeByQualifiedName) { CelFunctionRegistry func_registry; - type_registry_.RegisterTypeProvider( - std::make_unique( - google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory())); Resolver resolver("google.api.expr.runtime", func_registry.InternalGetRegistry(), type_registry_.InternalGetModernRegistry(), value_factory_, @@ -206,10 +190,6 @@ TEST_F(ResolverTest, FindTypeByQualifiedName) { TEST_F(ResolverTest, TestFindDescriptorNotFound) { CelFunctionRegistry func_registry; - type_registry_.RegisterTypeProvider( - std::make_unique( - google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory())); Resolver resolver("google.api.expr.runtime", func_registry.InternalGetRegistry(), type_registry_.InternalGetModernRegistry(), value_factory_, diff --git a/eval/eval/BUILD b/eval/eval/BUILD index d7769f22f..9dbf19433 100644 --- a/eval/eval/BUILD +++ b/eval/eval/BUILD @@ -87,7 +87,9 @@ cc_library( "//extensions/protobuf:memory_manager", "//internal:casts", "//internal:status_macros", - "//runtime:managed_value_factory", + "//runtime/internal:runtime_env", + "//runtime/internal:runtime_value_manager", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -181,6 +183,7 @@ cc_library( ":direct_expression_step", ":evaluator_core", "//base/ast_internal:expr", + "//common:allocator", "//common:value", "//internal:status_macros", "//runtime/internal:convert_constant", @@ -371,7 +374,6 @@ cc_library( ":evaluator_core", ":expression_step_base", "//common:casting", - "//common:memory", "//common:value", "//internal:status_macros", "@com_google_absl//absl/container:flat_hash_set", @@ -515,6 +517,7 @@ cc_test( "//runtime:activation", "//runtime:managed_value_factory", "//runtime:runtime_options", + "//runtime/internal:runtime_env_testing", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -559,7 +562,6 @@ cc_test( ":evaluator_core", "//base:data", "//base/ast_internal:expr", - "//common:type", "//common:value", "//eval/internal:errors", "//eval/public:activation", @@ -569,6 +571,9 @@ cc_test( "//internal:status_macros", "//internal:testing", "//runtime:runtime_options", + "//runtime/internal:runtime_env", + "//runtime/internal:runtime_env_testing", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/time", @@ -603,6 +608,9 @@ cc_test( "//eval/public/testing:matchers", "//internal:testing", "//parser", + "//runtime/internal:runtime_env", + "//runtime/internal:runtime_env_testing", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", @@ -652,6 +660,7 @@ cc_test( "//runtime:activation", "//runtime:managed_value_factory", "//runtime:runtime_options", + "//runtime/internal:runtime_env_testing", "@com_google_absl//absl/status", ], ) @@ -691,6 +700,9 @@ cc_test( "//runtime:managed_value_factory", "//runtime:runtime_options", "//runtime:standard_functions", + "//runtime/internal:runtime_env_testing", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf", ], @@ -726,6 +738,9 @@ cc_test( "//runtime:activation", "//runtime:managed_value_factory", "//runtime:runtime_options", + "//runtime/internal:runtime_env", + "//runtime/internal:runtime_env_testing", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", @@ -771,6 +786,9 @@ cc_test( "//runtime:activation", "//runtime:managed_value_factory", "//runtime:runtime_options", + "//runtime/internal:runtime_env", + "//runtime/internal:runtime_env_testing", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -805,6 +823,7 @@ cc_test( "//eval/internal:interop", "//eval/public:activation", "//eval/public:cel_attribute", + "//eval/public:cel_value", "//eval/public:unknown_attribute_set", "//eval/public/testing:matchers", "//internal:status_macros", @@ -812,9 +831,13 @@ cc_test( "//runtime:activation", "//runtime:managed_value_factory", "//runtime:runtime_options", + "//runtime/internal:runtime_env", + "//runtime/internal:runtime_env_testing", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", ], ) @@ -832,7 +855,6 @@ cc_test( ":ident_step", "//base:data", "//base/ast_internal:expr", - "//common:value", "//eval/public:activation", "//eval/public:cel_type_registry", "//eval/public:cel_value", @@ -840,13 +862,14 @@ cc_test( "//eval/public/containers:container_backed_list_impl", "//eval/public/containers:container_backed_map_impl", "//eval/public/structs:cel_proto_wrapper", - "//eval/public/structs:protobuf_descriptor_type_provider", "//eval/testutil:test_message_cc_proto", - "//extensions/protobuf:memory_manager", "//internal:proto_matchers", "//internal:status_macros", "//internal:testing", "//runtime:runtime_options", + "//runtime/internal:runtime_env", + "//runtime/internal:runtime_env_testing", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -877,6 +900,9 @@ cc_test( "//internal:status_macros", "//internal:testing", "//runtime:runtime_options", + "//runtime/internal:runtime_env", + "//runtime/internal:runtime_env_testing", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", @@ -1004,6 +1030,8 @@ cc_test( "//runtime:activation", "//runtime:managed_value_factory", "//runtime:runtime_options", + "//runtime/internal:runtime_env", + "//runtime/internal:runtime_env_testing", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_protobuf//:protobuf", @@ -1043,6 +1071,9 @@ cc_test( "//internal:status_macros", "//internal:testing", "//runtime:runtime_options", + "//runtime/internal:runtime_env", + "//runtime/internal:runtime_env_testing", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status:statusor", ], ) diff --git a/eval/eval/cel_expression_flat_impl.cc b/eval/eval/cel_expression_flat_impl.cc index b23dc7aac..9b168723e 100644 --- a/eval/eval/cel_expression_flat_impl.cc +++ b/eval/eval/cel_expression_flat_impl.cc @@ -18,6 +18,7 @@ #include #include +#include "absl/base/nullability.h" #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -37,8 +38,11 @@ #include "extensions/protobuf/memory_manager.h" #include "internal/casts.h" #include "internal/status_macros.h" -#include "runtime/managed_value_factory.h" +#include "runtime/internal/runtime_env.h" +#include "runtime/internal/runtime_value_manager.h" #include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" namespace google::api::expr::runtime { namespace { @@ -46,7 +50,8 @@ namespace { using ::cel::Value; using ::cel::ValueManager; using ::cel::extensions::ProtoMemoryManagerArena; -using ::cel::extensions::ProtoMemoryManagerRef; +using ::cel::runtime_internal::RuntimeEnv; +using ::cel::runtime_internal::RuntimeValueManager; EvaluationListener AdaptListener(const CelEvaluationListener& listener) { if (!listener) return nullptr; @@ -67,9 +72,13 @@ EvaluationListener AdaptListener(const CelEvaluationListener& listener) { } // namespace CelExpressionFlatEvaluationState::CelExpressionFlatEvaluationState( - google::protobuf::Arena* arena, const FlatExpression& expression) - : arena_(arena), - state_(expression.MakeEvaluatorState(ProtoMemoryManagerRef(arena_))) {} + google::protobuf::Arena* arena, + absl::Nonnull descriptor_pool, + absl::Nonnull message_factory, + const FlatExpression& expression) + : value_manager_(arena, descriptor_pool, message_factory, + expression.type_provider()), + state_(expression.MakeEvaluatorState(value_manager_)) {} absl::StatusOr CelExpressionFlatImpl::Trace( const BaseActivation& activation, CelEvaluationState* _state, @@ -90,8 +99,9 @@ absl::StatusOr CelExpressionFlatImpl::Trace( std::unique_ptr CelExpressionFlatImpl::InitializeState( google::protobuf::Arena* arena) const { - return std::make_unique(arena, - flat_expression_); + return std::make_unique( + arena, env_->descriptor_pool.get(), env_->MutableMessageFactory(), + flat_expression_); } absl::StatusOr CelExpressionFlatImpl::Evaluate( @@ -100,7 +110,9 @@ absl::StatusOr CelExpressionFlatImpl::Evaluate( } absl::StatusOr> -CelExpressionRecursiveImpl::Create(FlatExpression flat_expr) { +CelExpressionRecursiveImpl::Create( + absl::Nonnull> env, + FlatExpression flat_expr) { if (flat_expr.path().empty() || flat_expr.path().front()->GetNativeTypeId() != cel::NativeTypeId::For()) { @@ -108,7 +120,8 @@ CelExpressionRecursiveImpl::Create(FlatExpression flat_expr) { "Expected a recursive program step", flat_expr.path().size())); } - auto* instance = new CelExpressionRecursiveImpl(std::move(flat_expr)); + auto* instance = + new CelExpressionRecursiveImpl(std::move(env), std::move(flat_expr)); return absl::WrapUnique(instance); } @@ -117,12 +130,12 @@ absl::StatusOr CelExpressionRecursiveImpl::Trace( const BaseActivation& activation, google::protobuf::Arena* arena, CelEvaluationListener callback) const { cel::interop_internal::AdapterActivationImpl modern_activation(activation); - cel::ManagedValueFactory factory = flat_expression_.MakeValueFactory( - cel::extensions::ProtoMemoryManagerRef(arena)); - + RuntimeValueManager value_manager(arena, env_->descriptor_pool.get(), + env_->MutableMessageFactory(), + flat_expression_.type_provider()); ComprehensionSlots slots(flat_expression_.comprehension_slots_size()); ExecutionFrameBase execution_frame(modern_activation, AdaptListener(callback), - flat_expression_.options(), factory.get(), + flat_expression_.options(), value_manager, slots); cel::Value result; diff --git a/eval/eval/cel_expression_flat_impl.h b/eval/eval/cel_expression_flat_impl.h index f14e967f3..f5b01a3c2 100644 --- a/eval/eval/cel_expression_flat_impl.h +++ b/eval/eval/cel_expression_flat_impl.h @@ -18,28 +18,36 @@ #include #include -#include "absl/status/status.h" +#include "absl/base/nullability.h" #include "absl/status/statusor.h" #include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" +#include "eval/public/base_activation.h" #include "eval/public/cel_expression.h" -#include "extensions/protobuf/memory_manager.h" +#include "eval/public/cel_value.h" #include "internal/casts.h" +#include "runtime/internal/runtime_env.h" +#include "runtime/internal/runtime_value_manager.h" #include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" namespace google::api::expr::runtime { // Wrapper for FlatExpressionEvaluationState used to implement CelExpression. class CelExpressionFlatEvaluationState : public CelEvaluationState { public: - CelExpressionFlatEvaluationState(google::protobuf::Arena* arena, - const FlatExpression& expr); + CelExpressionFlatEvaluationState( + google::protobuf::Arena* arena, + absl::Nonnull descriptor_pool, + absl::Nonnull message_factory, + const FlatExpression& expr); - google::protobuf::Arena* arena() { return arena_; } + google::protobuf::Arena* arena() { return value_manager_.GetMemoryManager().arena(); } FlatExpressionEvaluatorState& state() { return state_; } private: - google::protobuf::Arena* arena_; + cel::runtime_internal::RuntimeValueManager value_manager_; FlatExpressionEvaluatorState state_; }; @@ -49,8 +57,11 @@ class CelExpressionFlatEvaluationState : public CelEvaluationState { // This class adapts FlatExpression to implement the CelExpression interface. class CelExpressionFlatImpl : public CelExpression { public: - explicit CelExpressionFlatImpl(FlatExpression flat_expression) - : flat_expression_(std::move(flat_expression)) {} + CelExpressionFlatImpl( + absl::Nonnull> + env, + FlatExpression flat_expression) + : env_(std::move(env)), flat_expression_(std::move(flat_expression)) {} // Move-only CelExpressionFlatImpl(const CelExpressionFlatImpl&) = delete; @@ -83,6 +94,7 @@ class CelExpressionFlatImpl : public CelExpression { const FlatExpression& flat_expression() const { return flat_expression_; } private: + absl::Nonnull> env_; FlatExpression flat_expression_; }; @@ -105,6 +117,8 @@ class CelExpressionRecursiveImpl : public CelExpression { public: static absl::StatusOr> Create( + absl::Nonnull> + env, FlatExpression flat_expression); // Move-only @@ -146,12 +160,17 @@ class CelExpressionRecursiveImpl : public CelExpression { const DirectExpressionStep* root() const { return root_; } private: - explicit CelExpressionRecursiveImpl(FlatExpression flat_expression) - : flat_expression_(std::move(flat_expression)), + explicit CelExpressionRecursiveImpl( + absl::Nonnull> + env, + FlatExpression flat_expression) + : env_(std::move(env)), + flat_expression_(std::move(flat_expression)), root_(cel::internal::down_cast( flat_expression_.path()[0].get()) ->wrapped()) {} + absl::Nonnull> env_; FlatExpression flat_expression_; const DirectExpressionStep* root_; }; diff --git a/eval/eval/comprehension_step_test.cc b/eval/eval/comprehension_step_test.cc index 2fd513ee7..776b0e238 100644 --- a/eval/eval/comprehension_step_test.cc +++ b/eval/eval/comprehension_step_test.cc @@ -31,6 +31,7 @@ #include "internal/status_macros.h" #include "internal/testing.h" #include "runtime/activation.h" +#include "runtime/internal/runtime_env_testing.h" #include "runtime/managed_value_factory.h" #include "runtime/runtime_options.h" #include "google/protobuf/arena.h" @@ -46,6 +47,7 @@ using ::cel::Value; using ::cel::ast_internal::Expr; using ::cel::ast_internal::Ident; using ::cel::extensions::ProtoMemoryManagerRef; +using ::cel::runtime_internal::NewTestingRuntimeEnv; using ::cel::test::BoolValueIs; using ::google::protobuf::ListValue; using ::google::protobuf::Struct; @@ -72,9 +74,11 @@ class ListKeysStepTest : public testing::Test { options.unknown_processing = cel::UnknownProcessingOptions::kAttributeAndFunction; } + auto env = NewTestingRuntimeEnv(); return std::make_unique( + env, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, - TypeProvider::Builtin(), options)); + env->type_registry.GetComposedTypeProvider(), options)); } private: diff --git a/eval/eval/const_value_step.cc b/eval/eval/const_value_step.cc index 53ed03faa..24a8ae032 100644 --- a/eval/eval/const_value_step.cc +++ b/eval/eval/const_value_step.cc @@ -6,8 +6,8 @@ #include "absl/status/statusor.h" #include "base/ast_internal/expr.h" +#include "common/allocator.h" #include "common/value.h" -#include "common/value_manager.h" #include "eval/eval/compiler_constant_step.h" #include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" @@ -35,10 +35,10 @@ absl::StatusOr> CreateConstValueStep( } absl::StatusOr> CreateConstValueStep( - const Constant& value, int64_t expr_id, cel::ValueManager& value_factory, + const Constant& value, int64_t expr_id, cel::Allocator<> allocator, bool comes_from_ast) { CEL_ASSIGN_OR_RETURN(cel::Value converted_value, - ConvertConstant(value, value_factory)); + ConvertConstant(value, allocator)); return std::make_unique(std::move(converted_value), expr_id, comes_from_ast); diff --git a/eval/eval/const_value_step.h b/eval/eval/const_value_step.h index f3a95a6cb..4d96d9403 100644 --- a/eval/eval/const_value_step.h +++ b/eval/eval/const_value_step.h @@ -6,8 +6,8 @@ #include "absl/status/statusor.h" #include "base/ast_internal/expr.h" +#include "common/allocator.h" #include "common/value.h" -#include "common/value_manager.h" #include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" @@ -25,7 +25,7 @@ absl::StatusOr> CreateConstValueStep( // expression. absl::StatusOr> CreateConstValueStep( const cel::ast_internal::Constant&, int64_t expr_id, - cel::ValueManager& value_factory, bool comes_from_ast = true); + cel::Allocator<> allocator, bool comes_from_ast = true); } // namespace google::api::expr::runtime diff --git a/eval/eval/const_value_step_test.cc b/eval/eval/const_value_step_test.cc index a22687e3c..78e7c3ab8 100644 --- a/eval/eval/const_value_step_test.cc +++ b/eval/eval/const_value_step_test.cc @@ -1,15 +1,14 @@ #include "eval/eval/const_value_step.h" +#include #include +#include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/time/time.h" #include "base/ast_internal/expr.h" #include "base/type_provider.h" -#include "common/type_factory.h" -#include "common/type_manager.h" -#include "common/value_manager.h" #include "common/values/legacy_value_manager.h" #include "eval/eval/cel_expression_flat_impl.h" #include "eval/eval/evaluator_core.h" @@ -20,6 +19,8 @@ #include "extensions/protobuf/memory_manager.h" #include "internal/status_macros.h" #include "internal/testing.h" +#include "runtime/internal/runtime_env.h" +#include "runtime/internal/runtime_env_testing.h" #include "runtime/runtime_options.h" #include "google/protobuf/arena.h" @@ -33,21 +34,24 @@ using ::cel::ast_internal::Constant; using ::cel::ast_internal::Expr; using ::cel::ast_internal::NullValue; using ::cel::extensions::ProtoMemoryManagerRef; +using ::cel::runtime_internal::NewTestingRuntimeEnv; +using ::cel::runtime_internal::RuntimeEnv; using ::testing::Eq; using ::testing::HasSubstr; absl::StatusOr RunConstantExpression( - const Expr* expr, const Constant& const_expr, google::protobuf::Arena* arena, - cel::ValueManager& value_factory) { - CEL_ASSIGN_OR_RETURN( - auto step, CreateConstValueStep(const_expr, expr->id(), value_factory)); + const absl::Nonnull>& env, + const Expr* expr, const Constant& const_expr, google::protobuf::Arena* arena) { + CEL_ASSIGN_OR_RETURN(auto step, + CreateConstValueStep(const_expr, expr->id(), arena)); google::api::expr::runtime::ExecutionPath path; path.push_back(std::move(step)); CelExpressionFlatImpl impl( - FlatExpression(std::move(path), /*comprehension_slot_count=*/0, - TypeProvider::Builtin(), cel::RuntimeOptions{})); + env, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, + env->type_registry.GetComposedTypeProvider(), + cel::RuntimeOptions{})); google::api::expr::runtime::Activation activation; @@ -57,10 +61,12 @@ absl::StatusOr RunConstantExpression( class ConstValueStepTest : public ::testing::Test { public: ConstValueStepTest() - : value_factory_(ProtoMemoryManagerRef(&arena_), - cel::TypeProvider::Builtin()) {} + : env_(NewTestingRuntimeEnv()), + value_factory_(ProtoMemoryManagerRef(&arena_), + env_->type_registry.GetComposedTypeProvider()) {} protected: + absl::Nonnull> env_; google::protobuf::Arena arena_; cel::common_internal::LegacyValueManager value_factory_; }; @@ -70,8 +76,7 @@ TEST_F(ConstValueStepTest, TestEvaluationConstInt64) { auto& const_expr = expr.mutable_const_expr(); const_expr.set_int64_value(1); - auto status = - RunConstantExpression(&expr, const_expr, &arena_, value_factory_); + auto status = RunConstantExpression(env_, &expr, const_expr, &arena_); ASSERT_OK(status); @@ -86,8 +91,7 @@ TEST_F(ConstValueStepTest, TestEvaluationConstUint64) { auto& const_expr = expr.mutable_const_expr(); const_expr.set_uint64_value(1); - auto status = - RunConstantExpression(&expr, const_expr, &arena_, value_factory_); + auto status = RunConstantExpression(env_, &expr, const_expr, &arena_); ASSERT_OK(status); @@ -102,8 +106,7 @@ TEST_F(ConstValueStepTest, TestEvaluationConstBool) { auto& const_expr = expr.mutable_const_expr(); const_expr.set_bool_value(true); - auto status = - RunConstantExpression(&expr, const_expr, &arena_, value_factory_); + auto status = RunConstantExpression(env_, &expr, const_expr, &arena_); ASSERT_OK(status); @@ -118,8 +121,7 @@ TEST_F(ConstValueStepTest, TestEvaluationConstNull) { auto& const_expr = expr.mutable_const_expr(); const_expr.set_null_value(nullptr); - auto status = - RunConstantExpression(&expr, const_expr, &arena_, value_factory_); + auto status = RunConstantExpression(env_, &expr, const_expr, &arena_); ASSERT_OK(status); @@ -133,8 +135,7 @@ TEST_F(ConstValueStepTest, TestEvaluationConstString) { auto& const_expr = expr.mutable_const_expr(); const_expr.set_string_value("test"); - auto status = - RunConstantExpression(&expr, const_expr, &arena_, value_factory_); + auto status = RunConstantExpression(env_, &expr, const_expr, &arena_); ASSERT_OK(status); @@ -149,8 +150,7 @@ TEST_F(ConstValueStepTest, TestEvaluationConstDouble) { auto& const_expr = expr.mutable_const_expr(); const_expr.set_double_value(1.0); - auto status = - RunConstantExpression(&expr, const_expr, &arena_, value_factory_); + auto status = RunConstantExpression(env_, &expr, const_expr, &arena_); ASSERT_OK(status); @@ -167,8 +167,7 @@ TEST_F(ConstValueStepTest, TestEvaluationConstBytes) { auto& const_expr = expr.mutable_const_expr(); const_expr.set_bytes_value("test"); - auto status = - RunConstantExpression(&expr, const_expr, &arena_, value_factory_); + auto status = RunConstantExpression(env_, &expr, const_expr, &arena_); ASSERT_OK(status); @@ -183,8 +182,7 @@ TEST_F(ConstValueStepTest, TestEvaluationConstDuration) { auto& const_expr = expr.mutable_const_expr(); const_expr.set_duration_value(absl::Seconds(5) + absl::Nanoseconds(2000)); - auto status = - RunConstantExpression(&expr, const_expr, &arena_, value_factory_); + auto status = RunConstantExpression(env_, &expr, const_expr, &arena_); ASSERT_OK(status); @@ -199,8 +197,7 @@ TEST_F(ConstValueStepTest, TestEvaluationConstDurationOutOfRange) { auto& const_expr = expr.mutable_const_expr(); const_expr.set_duration_value(cel::runtime_internal::kDurationHigh); - auto status = - RunConstantExpression(&expr, const_expr, &arena_, value_factory_); + auto status = RunConstantExpression(env_, &expr, const_expr, &arena_); ASSERT_OK(status); @@ -217,8 +214,7 @@ TEST_F(ConstValueStepTest, TestEvaluationConstTimestamp) { const_expr.set_time_value(absl::FromUnixSeconds(3600) + absl::Nanoseconds(1000)); - auto status = - RunConstantExpression(&expr, const_expr, &arena_, value_factory_); + auto status = RunConstantExpression(env_, &expr, const_expr, &arena_); ASSERT_OK(status); diff --git a/eval/eval/container_access_step_test.cc b/eval/eval/container_access_step_test.cc index 232d0e469..56eb66d3c 100644 --- a/eval/eval/container_access_step_test.cc +++ b/eval/eval/container_access_step_test.cc @@ -8,6 +8,7 @@ #include "cel/expr/syntax.pb.h" #include "google/protobuf/struct.pb.h" +#include "absl/base/nullability.h" #include "absl/status/status.h" #include "base/builtins.h" #include "base/type_provider.h" @@ -28,6 +29,8 @@ #include "eval/public/unknown_set.h" #include "internal/testing.h" #include "parser/parser.h" +#include "runtime/internal/runtime_env.h" +#include "runtime/internal/runtime_env_testing.h" #include "google/protobuf/arena.h" namespace google::api::expr::runtime { @@ -38,6 +41,8 @@ using ::absl_testing::StatusIs; using ::cel::TypeProvider; using ::cel::ast_internal::Expr; using ::cel::ast_internal::SourceInfo; +using ::cel::runtime_internal::NewTestingRuntimeEnv; +using ::cel::runtime_internal::RuntimeEnv; using ::cel::expr::ParsedExpr; using ::google::protobuf::Struct; using ::testing::_; @@ -47,6 +52,7 @@ using ::testing::HasSubstr; using TestParamType = std::tuple; CelValue EvaluateAttributeHelper( + const absl::Nonnull>& env, google::protobuf::Arena* arena, CelValue container, CelValue key, bool use_recursive_impl, bool receiver_style, bool enable_unknown, const std::vector& patterns) { @@ -84,8 +90,9 @@ CelValue EvaluateAttributeHelper( options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; options.enable_heterogeneous_equality = false; CelExpressionFlatImpl cel_expr( + env, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, - TypeProvider::Builtin(), options)); + env->type_registry.GetComposedTypeProvider(), options)); Activation activation; activation.InsertValue("container", container); @@ -100,16 +107,17 @@ class ContainerAccessStepTest : public ::testing::Test { protected: ContainerAccessStepTest() = default; - void SetUp() override {} + void SetUp() override { env_ = NewTestingRuntimeEnv(); } CelValue EvaluateAttribute( CelValue container, CelValue key, bool receiver_style, bool enable_unknown, bool use_recursive_impl = false, const std::vector& patterns = {}) { - return EvaluateAttributeHelper(&arena_, container, key, receiver_style, - enable_unknown, use_recursive_impl, - patterns); + return EvaluateAttributeHelper(env_, &arena_, container, key, + receiver_style, enable_unknown, + use_recursive_impl, patterns); } + absl::Nonnull> env_; google::protobuf::Arena arena_; }; @@ -118,7 +126,7 @@ class ContainerAccessStepUniformityTest protected: ContainerAccessStepUniformityTest() = default; - void SetUp() override {} + void SetUp() override { env_ = NewTestingRuntimeEnv(); } bool receiver_style() { TestParamType params = GetParam(); @@ -140,10 +148,11 @@ class ContainerAccessStepUniformityTest CelValue container, CelValue key, bool receiver_style, bool enable_unknown, bool use_recursive_impl = false, const std::vector& patterns = {}) { - return EvaluateAttributeHelper(&arena_, container, key, receiver_style, - enable_unknown, use_recursive_impl, - patterns); + return EvaluateAttributeHelper(env_, &arena_, container, key, + receiver_style, enable_unknown, + use_recursive_impl, patterns); } + absl::Nonnull> env_; google::protobuf::Arena arena_; }; diff --git a/eval/eval/create_list_step_test.cc b/eval/eval/create_list_step_test.cc index 9f6af5e11..53dc990f0 100644 --- a/eval/eval/create_list_step_test.cc +++ b/eval/eval/create_list_step_test.cc @@ -5,6 +5,7 @@ #include #include +#include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -25,13 +26,17 @@ #include "eval/internal/interop.h" #include "eval/public/activation.h" #include "eval/public/cel_attribute.h" +#include "eval/public/cel_value.h" #include "eval/public/testing/matchers.h" #include "eval/public/unknown_attribute_set.h" #include "internal/status_macros.h" #include "internal/testing.h" #include "runtime/activation.h" +#include "runtime/internal/runtime_env.h" +#include "runtime/internal/runtime_env_testing.h" #include "runtime/managed_value_factory.h" #include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" namespace google::api::expr::runtime { @@ -52,6 +57,8 @@ using ::cel::TypeProvider; using ::cel::UnknownValue; using ::cel::Value; using ::cel::ast_internal::Expr; +using ::cel::runtime_internal::NewTestingRuntimeEnv; +using ::cel::runtime_internal::RuntimeEnv; using ::cel::test::IntValueIs; using ::testing::Eq; using ::testing::HasSubstr; @@ -59,9 +66,10 @@ using ::testing::Not; using ::testing::UnorderedElementsAre; // Helper method. Creates simple pipeline containing Select step and runs it. -absl::StatusOr RunExpression(const std::vector& values, - google::protobuf::Arena* arena, - bool enable_unknowns) { +absl::StatusOr RunExpression( + const absl::Nonnull>& env, + const std::vector& values, google::protobuf::Arena* arena, + bool enable_unknowns) { ExecutionPath path; Expr dummy_expr; @@ -84,10 +92,11 @@ absl::StatusOr RunExpression(const std::vector& values, options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; } CelExpressionFlatImpl cel_expr( + env, FlatExpression(std::move(path), - /*comprehension_slot_count=*/0, TypeProvider::Builtin(), - options)); + /*comprehension_slot_count=*/0, + env->type_registry.GetComposedTypeProvider(), options)); Activation activation; return cel_expr.Evaluate(activation, arena); @@ -95,6 +104,7 @@ absl::StatusOr RunExpression(const std::vector& values, // Helper method. Creates simple pipeline containing Select step and runs it. absl::StatusOr RunExpressionWithCelValues( + const absl::Nonnull>& env, const std::vector& values, google::protobuf::Arena* arena, bool enable_unknowns) { ExecutionPath path; @@ -125,13 +135,21 @@ absl::StatusOr RunExpressionWithCelValues( } CelExpressionFlatImpl cel_expr( + env, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, - TypeProvider::Builtin(), options)); + env->type_registry.GetComposedTypeProvider(), options)); return cel_expr.Evaluate(activation, arena); } -class CreateListStepTest : public testing::TestWithParam {}; +class CreateListStepTest : public testing::TestWithParam { + public: + CreateListStepTest() : env_(NewTestingRuntimeEnv()) {} + + protected: + absl::Nonnull> env_; + google::protobuf::Arena arena_; +}; // Tests error when not enough list elements are on the stack during list // creation. @@ -147,9 +165,11 @@ TEST(CreateListStepTest, TestCreateListStackUnderflow) { CreateCreateListStep(create_list, dummy_expr.id())); path.push_back(std::move(step0)); + auto env = NewTestingRuntimeEnv(); CelExpressionFlatImpl cel_expr( - FlatExpression(std::move(path), /*comprehension_slot_count=*/0, - TypeProvider::Builtin(), cel::RuntimeOptions{})); + env, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, + env->type_registry.GetComposedTypeProvider(), + cel::RuntimeOptions{})); Activation activation; google::protobuf::Arena arena; @@ -159,37 +179,34 @@ TEST(CreateListStepTest, TestCreateListStackUnderflow) { } TEST_P(CreateListStepTest, CreateListEmpty) { - google::protobuf::Arena arena; - ASSERT_OK_AND_ASSIGN(CelValue result, RunExpression({}, &arena, GetParam())); + ASSERT_OK_AND_ASSIGN(CelValue result, + RunExpression(env_, {}, &arena_, GetParam())); ASSERT_TRUE(result.IsList()); EXPECT_THAT(result.ListOrDie()->size(), Eq(0)); } TEST_P(CreateListStepTest, CreateListOne) { - google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(CelValue result, - RunExpression({100}, &arena, GetParam())); + RunExpression(env_, {100}, &arena_, GetParam())); ASSERT_TRUE(result.IsList()); const auto& list = *result.ListOrDie(); ASSERT_THAT(list.size(), Eq(1)); - const CelValue& value = list.Get(&arena, 0); + const CelValue& value = list.Get(&arena_, 0); EXPECT_THAT(value, test::IsCelInt64(100)); } TEST_P(CreateListStepTest, CreateListWithError) { - google::protobuf::Arena arena; std::vector values; CelError error = absl::InvalidArgumentError("bad arg"); values.push_back(CelValue::CreateError(&error)); - ASSERT_OK_AND_ASSIGN(CelValue result, - RunExpressionWithCelValues(values, &arena, GetParam())); + ASSERT_OK_AND_ASSIGN(CelValue result, RunExpressionWithCelValues( + env_, values, &arena_, GetParam())); ASSERT_TRUE(result.IsError()); EXPECT_THAT(*result.ErrorOrDie(), Eq(absl::InvalidArgumentError("bad arg"))); } TEST_P(CreateListStepTest, CreateListWithErrorAndUnknown) { - google::protobuf::Arena arena; // list composition is: {unknown, error} std::vector values; Expr expr0; @@ -200,8 +217,8 @@ TEST_P(CreateListStepTest, CreateListWithErrorAndUnknown) { CelError error = absl::InvalidArgumentError("bad arg"); values.push_back(CelValue::CreateError(&error)); - ASSERT_OK_AND_ASSIGN(CelValue result, - RunExpressionWithCelValues(values, &arena, GetParam())); + ASSERT_OK_AND_ASSIGN(CelValue result, RunExpressionWithCelValues( + env_, values, &arena_, GetParam())); // The bad arg should win. ASSERT_TRUE(result.IsError()); @@ -209,18 +226,17 @@ TEST_P(CreateListStepTest, CreateListWithErrorAndUnknown) { } TEST_P(CreateListStepTest, CreateListHundred) { - google::protobuf::Arena arena; std::vector values; for (size_t i = 0; i < 100; i++) { values.push_back(i); } ASSERT_OK_AND_ASSIGN(CelValue result, - RunExpression(values, &arena, GetParam())); + RunExpression(env_, values, &arena_, GetParam())); ASSERT_TRUE(result.IsList()); const auto& list = *result.ListOrDie(); EXPECT_THAT(list.size(), Eq(static_cast(values.size()))); for (size_t i = 0; i < values.size(); i++) { - EXPECT_THAT(list.Get(&arena, i), test::IsCelInt64(values[i])); + EXPECT_THAT(list.Get(&arena_, i), test::IsCelInt64(values[i])); } } @@ -245,8 +261,9 @@ TEST(CreateListStepTest, CreateListHundredAnd2Unknowns) { values.push_back(CelValue::CreateUnknownSet(&unknown_set0)); values.push_back(CelValue::CreateUnknownSet(&unknown_set1)); - ASSERT_OK_AND_ASSIGN(CelValue result, - RunExpressionWithCelValues(values, &arena, true)); + ASSERT_OK_AND_ASSIGN( + CelValue result, + RunExpressionWithCelValues(NewTestingRuntimeEnv(), values, &arena, true)); ASSERT_TRUE(result.IsUnknownSet()); const UnknownSet* result_set = result.UnknownSetOrDie(); EXPECT_THAT(result_set->unknown_attributes().size(), Eq(2)); diff --git a/eval/eval/create_map_step_test.cc b/eval/eval/create_map_step_test.cc index 44554aee4..3f779660f 100644 --- a/eval/eval/create_map_step_test.cc +++ b/eval/eval/create_map_step_test.cc @@ -21,6 +21,7 @@ #include #include "cel/expr/syntax.pb.h" +#include "absl/base/nullability.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "base/ast_internal/expr.h" @@ -35,6 +36,8 @@ #include "eval/testutil/test_message.pb.h" #include "internal/status_macros.h" #include "internal/testing.h" +#include "runtime/internal/runtime_env.h" +#include "runtime/internal/runtime_env_testing.h" #include "runtime/runtime_options.h" #include "google/protobuf/arena.h" @@ -44,6 +47,8 @@ namespace { using ::cel::TypeProvider; using ::cel::ast_internal::Expr; +using ::cel::runtime_internal::NewTestingRuntimeEnv; +using ::cel::runtime_internal::RuntimeEnv; using ::google::protobuf::Arena; absl::StatusOr CreateStackMachineProgram( @@ -121,6 +126,7 @@ absl::StatusOr CreateRecursiveProgram( // builds Map and runs it. // Equivalent to {key0: value0, ...} absl::StatusOr RunCreateMapExpression( + const absl::Nonnull>& env, const std::vector>& values, google::protobuf::Arena* arena, bool enable_unknowns, bool enable_recursive_program) { Activation activation; @@ -137,29 +143,34 @@ absl::StatusOr RunCreateMapExpression( } CelExpressionFlatImpl cel_expr( + env, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, - TypeProvider::Builtin(), options)); + env->type_registry.GetComposedTypeProvider(), options)); return cel_expr.Evaluate(activation, arena); } class CreateMapStepTest : public testing::TestWithParam> { public: + CreateMapStepTest() : env_(NewTestingRuntimeEnv()) {} + bool enable_unknowns() { return std::get<0>(GetParam()); } bool enable_recursive_program() { return std::get<1>(GetParam()); } absl::StatusOr RunMapExpression( - const std::vector>& values, - google::protobuf::Arena* arena) { - return RunCreateMapExpression(values, arena, enable_unknowns(), + const std::vector>& values) { + return RunCreateMapExpression(env_, values, &arena_, enable_unknowns(), enable_recursive_program()); } + + protected: + absl::Nonnull> env_; + google::protobuf::Arena arena_; }; // Test that Empty Map is created successfully. TEST_P(CreateMapStepTest, TestCreateEmptyMap) { - Arena arena; - ASSERT_OK_AND_ASSIGN(CelValue result, RunMapExpression({}, &arena)); + ASSERT_OK_AND_ASSIGN(CelValue result, RunMapExpression({})); ASSERT_TRUE(result.IsMap()); const CelMap* cel_map = result.MapOrDie(); @@ -168,6 +179,7 @@ TEST_P(CreateMapStepTest, TestCreateEmptyMap) { // Test message creation if unknown argument is passed TEST(CreateMapStepTest, TestMapCreateWithUnknown) { + absl::Nonnull> env = NewTestingRuntimeEnv(); Arena arena; UnknownSet unknown_set; std::vector> entries; @@ -179,12 +191,13 @@ TEST(CreateMapStepTest, TestMapCreateWithUnknown) { entries.push_back({CelValue::CreateString(&kKeys[1]), CelValue::CreateUnknownSet(&unknown_set)}); - ASSERT_OK_AND_ASSIGN(CelValue result, - RunCreateMapExpression(entries, &arena, true, false)); + ASSERT_OK_AND_ASSIGN(CelValue result, RunCreateMapExpression( + env, entries, &arena, true, false)); ASSERT_TRUE(result.IsUnknownSet()); } TEST(CreateMapStepTest, TestMapCreateWithUnknownRecursiveProgram) { + absl::Nonnull> env = NewTestingRuntimeEnv(); Arena arena; UnknownSet unknown_set; std::vector> entries; @@ -196,8 +209,8 @@ TEST(CreateMapStepTest, TestMapCreateWithUnknownRecursiveProgram) { entries.push_back({CelValue::CreateString(&kKeys[1]), CelValue::CreateUnknownSet(&unknown_set)}); - ASSERT_OK_AND_ASSIGN(CelValue result, - RunCreateMapExpression(entries, &arena, true, true)); + ASSERT_OK_AND_ASSIGN(CelValue result, RunCreateMapExpression( + env, entries, &arena, true, true)); ASSERT_TRUE(result.IsUnknownSet()); } @@ -214,7 +227,7 @@ TEST_P(CreateMapStepTest, TestCreateStringMap) { entries.push_back( {CelValue::CreateString(&kKeys[1]), CelValue::CreateInt64(1)}); - ASSERT_OK_AND_ASSIGN(CelValue result, RunMapExpression(entries, &arena)); + ASSERT_OK_AND_ASSIGN(CelValue result, RunMapExpression(entries)); ASSERT_TRUE(result.IsMap()); const CelMap* cel_map = result.MapOrDie(); diff --git a/eval/eval/create_struct_step.cc b/eval/eval/create_struct_step.cc index c2f170171..4dbf163dd 100644 --- a/eval/eval/create_struct_step.cc +++ b/eval/eval/create_struct_step.cc @@ -27,7 +27,6 @@ #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "common/casting.h" -#include "common/memory.h" #include "common/value.h" #include "common/value_manager.h" #include "eval/eval/attribute_trail.h" @@ -223,7 +222,7 @@ absl::Status DirectCreateStructStep::Evaluate(ExecutionFrameBase& frame, return absl::OkStatus(); } - result = std::move(*builder).Build(); + CEL_ASSIGN_OR_RETURN(result, std::move(*builder).Build()); return absl::OkStatus(); } diff --git a/eval/eval/create_struct_step_test.cc b/eval/eval/create_struct_step_test.cc index ffcfb5faf..0d766c238 100644 --- a/eval/eval/create_struct_step_test.cc +++ b/eval/eval/create_struct_step_test.cc @@ -22,13 +22,13 @@ #include #include "cel/expr/syntax.pb.h" +#include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "base/ast_internal/expr.h" #include "base/type_provider.h" -#include "common/values/legacy_value_manager.h" #include "eval/eval/cel_expression_flat_impl.h" #include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" @@ -39,13 +39,13 @@ #include "eval/public/containers/container_backed_list_impl.h" #include "eval/public/containers/container_backed_map_impl.h" #include "eval/public/structs/cel_proto_wrapper.h" -#include "eval/public/structs/protobuf_descriptor_type_provider.h" #include "eval/public/unknown_set.h" #include "eval/testutil/test_message.pb.h" -#include "extensions/protobuf/memory_manager.h" #include "internal/proto_matchers.h" #include "internal/status_macros.h" #include "internal/testing.h" +#include "runtime/internal/runtime_env.h" +#include "runtime/internal/runtime_env_testing.h" #include "runtime/runtime_options.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" @@ -57,8 +57,9 @@ namespace { using ::cel::TypeProvider; using ::cel::ast_internal::Expr; -using ::cel::extensions::ProtoMemoryManagerRef; using ::cel::internal::test::EqualsProto; +using ::cel::runtime_internal::NewTestingRuntimeEnv; +using ::cel::runtime_internal::RuntimeEnv; using ::google::protobuf::Arena; using ::google::protobuf::Message; using ::testing::Eq; @@ -106,23 +107,14 @@ absl::StatusOr MakeRecursivePath(absl::string_view field) { // Helper method. Creates simple pipeline containing CreateStruct step that // builds message and runs it. -absl::StatusOr RunExpression(absl::string_view field, - const CelValue& value, - google::protobuf::Arena* arena, - bool enable_unknowns, - bool enable_recursive_planning) { - CelTypeRegistry type_registry; - type_registry.RegisterTypeProvider( - std::make_unique( - google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory())); - auto memory_manager = ProtoMemoryManagerRef(arena); - cel::common_internal::LegacyValueManager type_manager( - memory_manager, type_registry.GetTypeProvider()); - - CEL_ASSIGN_OR_RETURN( - auto maybe_type, - type_manager.FindType("google.api.expr.runtime.TestMessage")); +absl::StatusOr RunExpression( + const absl::Nonnull>& env, + absl::string_view field, const CelValue& value, google::protobuf::Arena* arena, + bool enable_unknowns, bool enable_recursive_planning) { + google::protobuf::LinkMessageReflection(); + CEL_ASSIGN_OR_RETURN(auto maybe_type, + env->type_registry.GetComposedTypeProvider().FindType( + "google.api.expr.runtime.TestMessage")); if (!maybe_type.has_value()) { return absl::Status(absl::StatusCode::kFailedPrecondition, "missing proto message type"); @@ -141,77 +133,78 @@ absl::StatusOr RunExpression(absl::string_view field, } CelExpressionFlatImpl cel_expr( + env, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, - type_registry.GetTypeProvider(), options)); + env->type_registry.GetComposedTypeProvider(), options)); Activation activation; activation.InsertValue("message", value); return cel_expr.Evaluate(activation, arena); } -void RunExpressionAndGetMessage(absl::string_view field, const CelValue& value, - google::protobuf::Arena* arena, TestMessage* test_msg, - bool enable_unknowns, - bool enable_recursive_planning) { +void RunExpressionAndGetMessage( + const absl::Nonnull>& env, + absl::string_view field, const CelValue& value, google::protobuf::Arena* arena, + TestMessage* test_msg, bool enable_unknowns, + bool enable_recursive_planning) { ASSERT_OK_AND_ASSIGN(auto result, - RunExpression(field, value, arena, enable_unknowns, + RunExpression(env, field, value, arena, enable_unknowns, enable_recursive_planning)); ASSERT_TRUE(result.IsMessage()) << result.DebugString(); const Message* msg = result.MessageOrDie(); ASSERT_THAT(msg, Not(IsNull())); - ASSERT_EQ(msg->GetDescriptor(), TestMessage::descriptor()); - test_msg->MergeFrom(*msg); + ASSERT_EQ(msg->GetDescriptor()->full_name(), + "google.api.expr.runtime.TestMessage"); + test_msg->MergePartialFromCord(msg->SerializePartialAsCord()); } -void RunExpressionAndGetMessage(absl::string_view field, - std::vector values, - google::protobuf::Arena* arena, TestMessage* test_msg, - bool enable_unknowns, - bool enable_recursive_planning) { +void RunExpressionAndGetMessage( + const absl::Nonnull>& env, + absl::string_view field, std::vector values, google::protobuf::Arena* arena, + TestMessage* test_msg, bool enable_unknowns, + bool enable_recursive_planning) { ContainerBackedListImpl cel_list(std::move(values)); CelValue value = CelValue::CreateList(&cel_list); ASSERT_OK_AND_ASSIGN(auto result, - RunExpression(field, value, arena, enable_unknowns, + RunExpression(env, field, value, arena, enable_unknowns, enable_recursive_planning)); ASSERT_TRUE(result.IsMessage()) << result.DebugString(); const Message* msg = result.MessageOrDie(); ASSERT_THAT(msg, Not(IsNull())); - ASSERT_EQ(msg->GetDescriptor(), TestMessage::descriptor()); - test_msg->MergeFrom(*msg); + ASSERT_EQ(msg->GetDescriptor()->full_name(), + "google.api.expr.runtime.TestMessage"); + test_msg->MergePartialFromCord(msg->SerializePartialAsCord()); } class CreateCreateStructStepTest : public testing::TestWithParam> { public: + CreateCreateStructStepTest() : env_(NewTestingRuntimeEnv()) {} + bool enable_unknowns() { return std::get<0>(GetParam()); } bool enable_recursive_planning() { return std::get<1>(GetParam()); } + + protected: + absl::Nonnull> env_; + google::protobuf::Arena arena_; }; TEST_P(CreateCreateStructStepTest, TestEmptyMessageCreation) { ExecutionPath path; - CelTypeRegistry type_registry; - type_registry.RegisterTypeProvider( - std::make_unique( - google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory())); - google::protobuf::Arena arena; - auto memory_manager = ProtoMemoryManagerRef(&arena); - cel::common_internal::LegacyValueManager type_manager( - memory_manager, type_registry.GetTypeProvider()); - - auto adapter = - type_registry.FindTypeAdapter("google.api.expr.runtime.TestMessage"); + + auto adapter = env_->legacy_type_registry.FindTypeAdapter( + "google.api.expr.runtime.TestMessage"); ASSERT_TRUE(adapter.has_value() && adapter->mutation_apis() != nullptr); - ASSERT_OK_AND_ASSIGN( - auto maybe_type, - type_manager.FindType("google.api.expr.runtime.TestMessage")); + ASSERT_OK_AND_ASSIGN(auto maybe_type, + env_->type_registry.GetComposedTypeProvider().FindType( + "google.api.expr.runtime.TestMessage")); ASSERT_TRUE(maybe_type.has_value()); if (enable_recursive_planning()) { auto step = @@ -235,26 +228,29 @@ TEST_P(CreateCreateStructStepTest, TestEmptyMessageCreation) { options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; } CelExpressionFlatImpl cel_expr( + env_, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, - type_registry.GetTypeProvider(), options)); + env_->type_registry.GetComposedTypeProvider(), options)); Activation activation; - ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr.Evaluate(activation, &arena)); + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr.Evaluate(activation, &arena_)); ASSERT_TRUE(result.IsMessage()) << result.DebugString(); const Message* msg = result.MessageOrDie(); ASSERT_THAT(msg, Not(IsNull())); - ASSERT_EQ(msg->GetDescriptor(), TestMessage::descriptor()); + ASSERT_EQ(msg->GetDescriptor()->full_name(), + "google.api.expr.runtime.TestMessage"); } // Test message creation if unknown argument is passed TEST(CreateCreateStructStepTest, TestMessageCreateWithUnknown) { + absl::Nonnull> env = NewTestingRuntimeEnv(); Arena arena; TestMessage test_msg; UnknownSet unknown_set; auto eval_status = - RunExpression("bool_value", CelValue::CreateUnknownSet(&unknown_set), + RunExpression(env, "bool_value", CelValue::CreateUnknownSet(&unknown_set), &arena, true, /*enable_recursive_planning=*/false); ASSERT_OK(eval_status); ASSERT_TRUE(eval_status->IsUnknownSet()); @@ -262,12 +258,13 @@ TEST(CreateCreateStructStepTest, TestMessageCreateWithUnknown) { // Test message creation if unknown argument is passed TEST(CreateCreateStructStepTest, TestMessageCreateWithUnknownRecursive) { + absl::Nonnull> env = NewTestingRuntimeEnv(); Arena arena; TestMessage test_msg; UnknownSet unknown_set; auto eval_status = - RunExpression("bool_value", CelValue::CreateUnknownSet(&unknown_set), + RunExpression(env, "bool_value", CelValue::CreateUnknownSet(&unknown_set), &arena, true, /*enable_recursive_planning=*/true); ASSERT_OK(eval_status); ASSERT_TRUE(eval_status->IsUnknownSet()) << eval_status->DebugString(); @@ -275,22 +272,20 @@ TEST(CreateCreateStructStepTest, TestMessageCreateWithUnknownRecursive) { // Test that fields of type bool are set correctly TEST_P(CreateCreateStructStepTest, TestSetBoolField) { - Arena arena; TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "bool_value", CelValue::CreateBool(true), &arena, &test_msg, + env_, "bool_value", CelValue::CreateBool(true), &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); ASSERT_EQ(test_msg.bool_value(), true); } // Test that fields of type int32_t are set correctly TEST_P(CreateCreateStructStepTest, TestSetInt32Field) { - Arena arena; TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "int32_value", CelValue::CreateInt64(1), &arena, &test_msg, + env_, "int32_value", CelValue::CreateInt64(1), &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); ASSERT_EQ(test_msg.int32_value(), 1); @@ -298,11 +293,10 @@ TEST_P(CreateCreateStructStepTest, TestSetInt32Field) { // Test that fields of type uint32_t are set correctly. TEST_P(CreateCreateStructStepTest, TestSetUInt32Field) { - Arena arena; TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "uint32_value", CelValue::CreateUint64(1), &arena, &test_msg, + env_, "uint32_value", CelValue::CreateUint64(1), &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); ASSERT_EQ(test_msg.uint32_value(), 1); @@ -310,11 +304,10 @@ TEST_P(CreateCreateStructStepTest, TestSetUInt32Field) { // Test that fields of type int64_t are set correctly. TEST_P(CreateCreateStructStepTest, TestSetInt64Field) { - Arena arena; TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "int64_value", CelValue::CreateInt64(1), &arena, &test_msg, + env_, "int64_value", CelValue::CreateInt64(1), &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); EXPECT_EQ(test_msg.int64_value(), 1); @@ -322,11 +315,10 @@ TEST_P(CreateCreateStructStepTest, TestSetInt64Field) { // Test that fields of type uint64_t are set correctly. TEST_P(CreateCreateStructStepTest, TestSetUInt64Field) { - Arena arena; TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "uint64_value", CelValue::CreateUint64(1), &arena, &test_msg, + env_, "uint64_value", CelValue::CreateUint64(1), &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); EXPECT_EQ(test_msg.uint64_value(), 1); @@ -334,11 +326,10 @@ TEST_P(CreateCreateStructStepTest, TestSetUInt64Field) { // Test that fields of type float are set correctly TEST_P(CreateCreateStructStepTest, TestSetFloatField) { - Arena arena; TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "float_value", CelValue::CreateDouble(2.0), &arena, &test_msg, + env_, "float_value", CelValue::CreateDouble(2.0), &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); EXPECT_DOUBLE_EQ(test_msg.float_value(), 2.0); @@ -346,11 +337,10 @@ TEST_P(CreateCreateStructStepTest, TestSetFloatField) { // Test that fields of type double are set correctly TEST_P(CreateCreateStructStepTest, TestSetDoubleField) { - Arena arena; TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "double_value", CelValue::CreateDouble(2.0), &arena, &test_msg, + env_, "double_value", CelValue::CreateDouble(2.0), &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); EXPECT_DOUBLE_EQ(test_msg.double_value(), 2.0); } @@ -359,63 +349,55 @@ TEST_P(CreateCreateStructStepTest, TestSetDoubleField) { TEST_P(CreateCreateStructStepTest, TestSetStringField) { const std::string kTestStr = "test"; - Arena arena; TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "string_value", CelValue::CreateString(&kTestStr), &arena, &test_msg, - enable_unknowns(), enable_recursive_planning())); + env_, "string_value", CelValue::CreateString(&kTestStr), &arena_, + &test_msg, enable_unknowns(), enable_recursive_planning())); EXPECT_EQ(test_msg.string_value(), kTestStr); } // Test that fields of type bytes are set correctly. TEST_P(CreateCreateStructStepTest, TestSetBytesField) { - Arena arena; - const std::string kTestStr = "test"; TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "bytes_value", CelValue::CreateBytes(&kTestStr), &arena, &test_msg, + env_, "bytes_value", CelValue::CreateBytes(&kTestStr), &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); EXPECT_EQ(test_msg.bytes_value(), kTestStr); } // Test that fields of type duration are set correctly. TEST_P(CreateCreateStructStepTest, TestSetDurationField) { - Arena arena; - google::protobuf::Duration test_duration; test_duration.set_seconds(2); test_duration.set_nanos(3); TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "duration_value", CelProtoWrapper::CreateDuration(&test_duration), &arena, - &test_msg, enable_unknowns(), enable_recursive_planning())); + env_, "duration_value", CelProtoWrapper::CreateDuration(&test_duration), + &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); EXPECT_THAT(test_msg.duration_value(), EqualsProto(test_duration)); } // Test that fields of type timestamp are set correctly. TEST_P(CreateCreateStructStepTest, TestSetTimestampField) { - Arena arena; - google::protobuf::Timestamp test_timestamp; test_timestamp.set_seconds(2); test_timestamp.set_nanos(3); TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "timestamp_value", CelProtoWrapper::CreateTimestamp(&test_timestamp), - &arena, &test_msg, enable_unknowns(), enable_recursive_planning())); + env_, "timestamp_value", + CelProtoWrapper::CreateTimestamp(&test_timestamp), &arena_, &test_msg, + enable_unknowns(), enable_recursive_planning())); EXPECT_THAT(test_msg.timestamp_value(), EqualsProto(test_timestamp)); } // Test that fields of type Message are set correctly. TEST_P(CreateCreateStructStepTest, TestSetMessageField) { - Arena arena; - // Create payload message and set some fields. TestMessage orig_msg; orig_msg.set_bool_value(true); @@ -424,15 +406,13 @@ TEST_P(CreateCreateStructStepTest, TestSetMessageField) { TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "message_value", CelProtoWrapper::CreateMessage(&orig_msg, &arena), - &arena, &test_msg, enable_unknowns(), enable_recursive_planning())); + env_, "message_value", CelProtoWrapper::CreateMessage(&orig_msg, &arena_), + &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); EXPECT_THAT(test_msg.message_value(), EqualsProto(orig_msg)); } // Test that fields of type Any are set correctly. TEST_P(CreateCreateStructStepTest, TestSetAnyField) { - Arena arena; - // Create payload message and set some fields. TestMessage orig_embedded_msg; orig_embedded_msg.set_bool_value(true); @@ -444,8 +424,9 @@ TEST_P(CreateCreateStructStepTest, TestSetAnyField) { TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "any_value", CelProtoWrapper::CreateMessage(&orig_embedded_msg, &arena), - &arena, &test_msg, enable_unknowns(), enable_recursive_planning())); + env_, "any_value", + CelProtoWrapper::CreateMessage(&orig_embedded_msg, &arena_), &arena_, + &test_msg, enable_unknowns(), enable_recursive_planning())); EXPECT_THAT(test_msg, EqualsProto(orig_msg)); TestMessage test_embedded_msg; @@ -455,18 +436,16 @@ TEST_P(CreateCreateStructStepTest, TestSetAnyField) { // Test that fields of type Message are set correctly. TEST_P(CreateCreateStructStepTest, TestSetEnumField) { - Arena arena; TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "enum_value", CelValue::CreateInt64(TestMessage::TEST_ENUM_2), &arena, - &test_msg, enable_unknowns(), enable_recursive_planning())); + env_, "enum_value", CelValue::CreateInt64(TestMessage::TEST_ENUM_2), + &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); EXPECT_EQ(test_msg.enum_value(), TestMessage::TEST_ENUM_2); } // Test that fields of type bool are set correctly TEST_P(CreateCreateStructStepTest, TestSetRepeatedBoolField) { - Arena arena; TestMessage test_msg; std::vector kValues = {true, false}; @@ -476,14 +455,13 @@ TEST_P(CreateCreateStructStepTest, TestSetRepeatedBoolField) { } ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "bool_list", values, &arena, &test_msg, enable_unknowns(), + env_, "bool_list", values, &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); ASSERT_THAT(test_msg.bool_list(), Pointwise(Eq(), kValues)); } // Test that repeated fields of type int32_t are set correctly TEST_P(CreateCreateStructStepTest, TestSetRepeatedInt32Field) { - Arena arena; TestMessage test_msg; std::vector kValues = {23, 12}; @@ -493,14 +471,13 @@ TEST_P(CreateCreateStructStepTest, TestSetRepeatedInt32Field) { } ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "int32_list", values, &arena, &test_msg, enable_unknowns(), + env_, "int32_list", values, &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); ASSERT_THAT(test_msg.int32_list(), Pointwise(Eq(), kValues)); } // Test that repeated fields of type uint32_t are set correctly TEST_P(CreateCreateStructStepTest, TestSetRepeatedUInt32Field) { - Arena arena; TestMessage test_msg; std::vector kValues = {23, 12}; @@ -510,14 +487,13 @@ TEST_P(CreateCreateStructStepTest, TestSetRepeatedUInt32Field) { } ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "uint32_list", values, &arena, &test_msg, enable_unknowns(), + env_, "uint32_list", values, &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); ASSERT_THAT(test_msg.uint32_list(), Pointwise(Eq(), kValues)); } // Test that repeated fields of type int64_t are set correctly TEST_P(CreateCreateStructStepTest, TestSetRepeatedInt64Field) { - Arena arena; TestMessage test_msg; std::vector kValues = {23, 12}; @@ -527,14 +503,13 @@ TEST_P(CreateCreateStructStepTest, TestSetRepeatedInt64Field) { } ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "int64_list", values, &arena, &test_msg, enable_unknowns(), + env_, "int64_list", values, &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); ASSERT_THAT(test_msg.int64_list(), Pointwise(Eq(), kValues)); } // Test that repeated fields of type uint64_t are set correctly TEST_P(CreateCreateStructStepTest, TestSetRepeatedUInt64Field) { - Arena arena; TestMessage test_msg; std::vector kValues = {23, 12}; @@ -544,14 +519,13 @@ TEST_P(CreateCreateStructStepTest, TestSetRepeatedUInt64Field) { } ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "uint64_list", values, &arena, &test_msg, enable_unknowns(), + env_, "uint64_list", values, &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); ASSERT_THAT(test_msg.uint64_list(), Pointwise(Eq(), kValues)); } // Test that repeated fields of type float are set correctly TEST_P(CreateCreateStructStepTest, TestSetRepeatedFloatField) { - Arena arena; TestMessage test_msg; std::vector kValues = {23, 12}; @@ -561,14 +535,13 @@ TEST_P(CreateCreateStructStepTest, TestSetRepeatedFloatField) { } ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "float_list", values, &arena, &test_msg, enable_unknowns(), + env_, "float_list", values, &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); ASSERT_THAT(test_msg.float_list(), Pointwise(Eq(), kValues)); } // Test that repeated fields of type uint32_t are set correctly TEST_P(CreateCreateStructStepTest, TestSetRepeatedDoubleField) { - Arena arena; TestMessage test_msg; std::vector kValues = {23, 12}; @@ -578,14 +551,13 @@ TEST_P(CreateCreateStructStepTest, TestSetRepeatedDoubleField) { } ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "double_list", values, &arena, &test_msg, enable_unknowns(), + env_, "double_list", values, &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); ASSERT_THAT(test_msg.double_list(), Pointwise(Eq(), kValues)); } // Test that repeated fields of type String are set correctly TEST_P(CreateCreateStructStepTest, TestSetRepeatedStringField) { - Arena arena; TestMessage test_msg; std::vector kValues = {"test1", "test2"}; @@ -595,14 +567,13 @@ TEST_P(CreateCreateStructStepTest, TestSetRepeatedStringField) { } ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "string_list", values, &arena, &test_msg, enable_unknowns(), + env_, "string_list", values, &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); ASSERT_THAT(test_msg.string_list(), Pointwise(Eq(), kValues)); } // Test that repeated fields of type String are set correctly TEST_P(CreateCreateStructStepTest, TestSetRepeatedBytesField) { - Arena arena; TestMessage test_msg; std::vector kValues = {"test1", "test2"}; @@ -612,7 +583,7 @@ TEST_P(CreateCreateStructStepTest, TestSetRepeatedBytesField) { } ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "bytes_list", values, &arena, &test_msg, enable_unknowns(), + env_, "bytes_list", values, &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); ASSERT_THAT(test_msg.bytes_list(), Pointwise(Eq(), kValues)); } @@ -620,7 +591,6 @@ TEST_P(CreateCreateStructStepTest, TestSetRepeatedBytesField) { // Test that repeated fields of type Message are set correctly TEST_P(CreateCreateStructStepTest, TestSetRepeatedMessageField) { - Arena arena; TestMessage test_msg; std::vector kValues(2); @@ -628,11 +598,11 @@ TEST_P(CreateCreateStructStepTest, TestSetRepeatedMessageField) { kValues[1].set_string_value("test2"); std::vector values; for (const auto& value : kValues) { - values.push_back(CelProtoWrapper::CreateMessage(&value, &arena)); + values.push_back(CelProtoWrapper::CreateMessage(&value, &arena_)); } ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "message_list", values, &arena, &test_msg, enable_unknowns(), + env_, "message_list", values, &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); ASSERT_THAT(test_msg.message_list()[0], EqualsProto(kValues[0])); ASSERT_THAT(test_msg.message_list()[1], EqualsProto(kValues[1])); @@ -641,7 +611,6 @@ TEST_P(CreateCreateStructStepTest, TestSetRepeatedMessageField) { // Test that fields of type map are set correctly TEST_P(CreateCreateStructStepTest, TestSetStringMapField) { - Arena arena; TestMessage test_msg; std::vector> entries; @@ -658,8 +627,8 @@ TEST_P(CreateCreateStructStepTest, TestSetStringMapField) { entries.data(), entries.size())); ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "string_int32_map", CelValue::CreateMap(cel_map.get()), &arena, &test_msg, - enable_unknowns(), enable_recursive_planning())); + env_, "string_int32_map", CelValue::CreateMap(cel_map.get()), &arena_, + &test_msg, enable_unknowns(), enable_recursive_planning())); ASSERT_EQ(test_msg.string_int32_map().size(), 2); ASSERT_EQ(test_msg.string_int32_map().at(kKeys[0]), 2); @@ -668,7 +637,6 @@ TEST_P(CreateCreateStructStepTest, TestSetStringMapField) { // Test that fields of type map are set correctly TEST_P(CreateCreateStructStepTest, TestSetInt64MapField) { - Arena arena; TestMessage test_msg; std::vector> entries; @@ -685,8 +653,8 @@ TEST_P(CreateCreateStructStepTest, TestSetInt64MapField) { entries.data(), entries.size())); ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "int64_int32_map", CelValue::CreateMap(cel_map.get()), &arena, &test_msg, - enable_unknowns(), enable_recursive_planning())); + env_, "int64_int32_map", CelValue::CreateMap(cel_map.get()), &arena_, + &test_msg, enable_unknowns(), enable_recursive_planning())); ASSERT_EQ(test_msg.int64_int32_map().size(), 2); ASSERT_EQ(test_msg.int64_int32_map().at(kKeys[0]), 1); @@ -695,7 +663,6 @@ TEST_P(CreateCreateStructStepTest, TestSetInt64MapField) { // Test that fields of type map are set correctly TEST_P(CreateCreateStructStepTest, TestSetUInt64MapField) { - Arena arena; TestMessage test_msg; std::vector> entries; @@ -712,8 +679,8 @@ TEST_P(CreateCreateStructStepTest, TestSetUInt64MapField) { entries.data(), entries.size())); ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "uint64_int32_map", CelValue::CreateMap(cel_map.get()), &arena, &test_msg, - enable_unknowns(), enable_recursive_planning())); + env_, "uint64_int32_map", CelValue::CreateMap(cel_map.get()), &arena_, + &test_msg, enable_unknowns(), enable_recursive_planning())); ASSERT_EQ(test_msg.uint64_int32_map().size(), 2); ASSERT_EQ(test_msg.uint64_int32_map().at(kKeys[0]), 1); diff --git a/eval/eval/evaluator_core.cc b/eval/eval/evaluator_core.cc index 253edbc71..bef821a44 100644 --- a/eval/eval/evaluator_core.cc +++ b/eval/eval/evaluator_core.cc @@ -32,7 +32,6 @@ #include "common/value.h" #include "common/value_manager.h" #include "runtime/activation_interface.h" -#include "runtime/managed_value_factory.h" namespace google::api::expr::runtime { @@ -195,9 +194,4 @@ absl::StatusOr FlatExpression::EvaluateWithCallback( return frame.Evaluate(frame.callback()); } -cel::ManagedValueFactory FlatExpression::MakeValueFactory( - cel::MemoryManagerRef memory_manager) const { - return cel::ManagedValueFactory(type_provider_, memory_manager); -} - } // namespace google::api::expr::runtime diff --git a/eval/eval/evaluator_core.h b/eval/eval/evaluator_core.h index 468a06634..931c76651 100644 --- a/eval/eval/evaluator_core.h +++ b/eval/eval/evaluator_core.h @@ -414,9 +414,6 @@ class FlatExpression { const cel::ActivationInterface& activation, EvaluationListener listener, FlatExpressionEvaluatorState& state) const; - cel::ManagedValueFactory MakeValueFactory( - cel::MemoryManagerRef memory_manager) const; - const ExecutionPath& path() const { return path_; } absl::Span subexpressions() const { @@ -427,6 +424,8 @@ class FlatExpression { size_t comprehension_slots_size() const { return comprehension_slots_size_; } + const cel::TypeProvider& type_provider() const { return type_provider_; } + private: ExecutionPath path_; std::vector subexpressions_; diff --git a/eval/eval/evaluator_core_test.cc b/eval/eval/evaluator_core_test.cc index da15f4b4e..a656a5078 100644 --- a/eval/eval/evaluator_core_test.cc +++ b/eval/eval/evaluator_core_test.cc @@ -96,8 +96,11 @@ TEST(EvaluatorCoreTest, SimpleEvaluatorTest) { path.push_back(std::move(incr_step1)); path.push_back(std::move(incr_step2)); - CelExpressionFlatImpl impl(FlatExpression( - std::move(path), 0, cel::TypeProvider::Builtin(), cel::RuntimeOptions{})); + auto env = NewTestingRuntimeEnv(); + CelExpressionFlatImpl impl( + env, FlatExpression(std::move(path), 0, + env->type_registry.GetComposedTypeProvider(), + cel::RuntimeOptions{})); Activation activation; google::protobuf::Arena arena; diff --git a/eval/eval/function_step_test.cc b/eval/eval/function_step_test.cc index 1fc9b6e10..74557e6ce 100644 --- a/eval/eval/function_step_test.cc +++ b/eval/eval/function_step_test.cc @@ -7,6 +7,8 @@ #include #include +#include "absl/base/nullability.h" +#include "absl/status/status_matchers.h" #include "absl/strings/string_view.h" #include "base/ast_internal/expr.h" #include "base/builtins.h" @@ -32,6 +34,7 @@ #include "internal/testing.h" #include "runtime/function_overload_reference.h" #include "runtime/function_registry.h" +#include "runtime/internal/runtime_env_testing.h" #include "runtime/managed_value_factory.h" #include "runtime/runtime_options.h" #include "runtime/standard_functions.h" @@ -47,6 +50,7 @@ using ::cel::TypeProvider; using ::cel::ast_internal::Call; using ::cel::ast_internal::Expr; using ::cel::ast_internal::Ident; +using ::cel::runtime_internal::NewTestingRuntimeEnv; using ::testing::Eq; using ::testing::Not; using ::testing::Truly; @@ -212,9 +216,11 @@ std::unique_ptr CreateExpressionImpl( ExecutionPath path; path.push_back(std::make_unique(std::move(expr), -1)); + auto env = NewTestingRuntimeEnv(); return std::make_unique( + env, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, - TypeProvider::Builtin(), options)); + env->type_registry.GetComposedTypeProvider(), options)); } absl::StatusOr> MakeTestFunctionStep( @@ -239,9 +245,11 @@ class FunctionStepTest cel::RuntimeOptions options; options.unknown_processing = GetParam(); + auto env = NewTestingRuntimeEnv(); return std::make_unique( + env, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, - TypeProvider::Builtin(), options)); + env->type_registry.GetComposedTypeProvider(), options)); } }; @@ -585,9 +593,11 @@ class FunctionStepTestUnknowns cel::RuntimeOptions options; options.unknown_processing = GetParam(); + auto env = NewTestingRuntimeEnv(); return std::make_unique( + env, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, - TypeProvider::Builtin(), options)); + env->type_registry.GetComposedTypeProvider(), options)); } }; @@ -722,9 +732,12 @@ TEST(FunctionStepTestUnknownFunctionResults, CaptureArgs) { cel::RuntimeOptions options; options.unknown_processing = cel::UnknownProcessingOptions::kAttributeAndFunction; - CelExpressionFlatImpl impl(FlatExpression(std::move(path), - /*comprehension_slot_count=*/0, - TypeProvider::Builtin(), options)); + auto env = NewTestingRuntimeEnv(); + CelExpressionFlatImpl impl( + env, + FlatExpression(std::move(path), + /*comprehension_slot_count=*/0, + env->type_registry.GetComposedTypeProvider(), options)); Activation activation; google::protobuf::Arena arena; @@ -769,9 +782,12 @@ TEST(FunctionStepTestUnknownFunctionResults, MergeDownCaptureArgs) { cel::RuntimeOptions options; options.unknown_processing = cel::UnknownProcessingOptions::kAttributeAndFunction; - CelExpressionFlatImpl impl(FlatExpression(std::move(path), - /*comprehension_slot_count=*/0, - TypeProvider::Builtin(), options)); + auto env = NewTestingRuntimeEnv(); + CelExpressionFlatImpl impl( + env, + FlatExpression(std::move(path), + /*comprehension_slot_count=*/0, + env->type_registry.GetComposedTypeProvider(), options)); Activation activation; google::protobuf::Arena arena; @@ -816,9 +832,12 @@ TEST(FunctionStepTestUnknownFunctionResults, MergeCaptureArgs) { cel::RuntimeOptions options; options.unknown_processing = cel::UnknownProcessingOptions::kAttributeAndFunction; - CelExpressionFlatImpl impl(FlatExpression(std::move(path), - /*comprehension_slot_count=*/0, - TypeProvider::Builtin(), options)); + auto env = NewTestingRuntimeEnv(); + CelExpressionFlatImpl impl( + env, + FlatExpression(std::move(path), + /*comprehension_slot_count=*/0, + env->type_registry.GetComposedTypeProvider(), options)); Activation activation; google::protobuf::Arena arena; @@ -858,9 +877,12 @@ TEST(FunctionStepTestUnknownFunctionResults, UnknownVsErrorPrecedenceTest) { cel::RuntimeOptions options; options.unknown_processing = cel::UnknownProcessingOptions::kAttributeAndFunction; - CelExpressionFlatImpl impl(FlatExpression(std::move(path), - /*comprehension_slot_count=*/0, - TypeProvider::Builtin(), options)); + auto env = NewTestingRuntimeEnv(); + CelExpressionFlatImpl impl( + env, + FlatExpression(std::move(path), + /*comprehension_slot_count=*/0, + env->type_registry.GetComposedTypeProvider(), options)); Activation activation; google::protobuf::Arena arena; @@ -945,9 +967,12 @@ TEST(FunctionStepStrictnessTest, cel::RuntimeOptions options; options.unknown_processing = cel::UnknownProcessingOptions::kAttributeAndFunction; - CelExpressionFlatImpl impl(FlatExpression(std::move(path), - /*comprehension_slot_count=*/0, - TypeProvider::Builtin(), options)); + auto env = NewTestingRuntimeEnv(); + CelExpressionFlatImpl impl( + env, + FlatExpression(std::move(path), + /*comprehension_slot_count=*/0, + env->type_registry.GetComposedTypeProvider(), options)); Activation activation; google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(CelValue value, impl.Evaluate(activation, &arena)); @@ -974,9 +999,12 @@ TEST(FunctionStepStrictnessTest, IfFunctionNonStrictAndGivenUnknownInvokesIt) { cel::RuntimeOptions options; options.unknown_processing = cel::UnknownProcessingOptions::kAttributeAndFunction; - CelExpressionFlatImpl impl(FlatExpression(std::move(path), - /*comprehension_slot_count=*/0, - TypeProvider::Builtin(), options)); + auto env = NewTestingRuntimeEnv(); + CelExpressionFlatImpl impl( + env, + FlatExpression(std::move(path), + /*comprehension_slot_count=*/0, + env->type_registry.GetComposedTypeProvider(), options)); Activation activation; google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(CelValue value, impl.Evaluate(activation, &arena)); diff --git a/eval/eval/ident_step_test.cc b/eval/eval/ident_step_test.cc index 725517d7f..042adc9d8 100644 --- a/eval/eval/ident_step_test.cc +++ b/eval/eval/ident_step_test.cc @@ -17,6 +17,7 @@ #include "eval/public/cel_attribute.h" #include "internal/testing.h" #include "runtime/activation.h" +#include "runtime/internal/runtime_env_testing.h" #include "runtime/managed_value_factory.h" #include "runtime/runtime_options.h" @@ -36,6 +37,7 @@ using ::cel::TypeProvider; using ::cel::UnknownValue; using ::cel::Value; using ::cel::ast_internal::Expr; +using ::cel::runtime_internal::NewTestingRuntimeEnv; using ::google::protobuf::Arena; using ::testing::Eq; using ::testing::HasSubstr; @@ -51,9 +53,11 @@ TEST(IdentStepTest, TestIdentStep) { ExecutionPath path; path.push_back(std::move(step)); + auto env = NewTestingRuntimeEnv(); CelExpressionFlatImpl impl( - FlatExpression(std::move(path), /*comprehension_slot_count=*/0, - TypeProvider::Builtin(), cel::RuntimeOptions{})); + env, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, + env->type_registry.GetComposedTypeProvider(), + cel::RuntimeOptions{})); Activation activation; Arena arena; @@ -79,9 +83,11 @@ TEST(IdentStepTest, TestIdentStepNameNotFound) { ExecutionPath path; path.push_back(std::move(step)); + auto env = NewTestingRuntimeEnv(); CelExpressionFlatImpl impl( - FlatExpression(std::move(path), /*comprehension_slot_count=*/0, - TypeProvider::Builtin(), cel::RuntimeOptions{})); + env, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, + env->type_registry.GetComposedTypeProvider(), + cel::RuntimeOptions{})); Activation activation; Arena arena; @@ -105,9 +111,12 @@ TEST(IdentStepTest, DisableMissingAttributeErrorsOK) { path.push_back(std::move(step)); cel::RuntimeOptions options; options.unknown_processing = cel::UnknownProcessingOptions::kDisabled; - CelExpressionFlatImpl impl(FlatExpression(std::move(path), - /*comprehension_slot_count=*/0, - TypeProvider::Builtin(), options)); + auto env = NewTestingRuntimeEnv(); + CelExpressionFlatImpl impl( + env, + FlatExpression(std::move(path), + /*comprehension_slot_count=*/0, + env->type_registry.GetComposedTypeProvider(), options)); Activation activation; Arena arena; @@ -145,9 +154,12 @@ TEST(IdentStepTest, TestIdentStepMissingAttributeErrors) { options.unknown_processing = cel::UnknownProcessingOptions::kDisabled; options.enable_missing_attribute_errors = true; - CelExpressionFlatImpl impl(FlatExpression(std::move(path), - /*comprehension_slot_count=*/0, - TypeProvider::Builtin(), options)); + auto env = NewTestingRuntimeEnv(); + CelExpressionFlatImpl impl( + env, + FlatExpression(std::move(path), + /*comprehension_slot_count=*/0, + env->type_registry.GetComposedTypeProvider(), options)); Activation activation; Arena arena; @@ -185,9 +197,12 @@ TEST(IdentStepTest, TestIdentStepUnknownAttribute) { // Expression with unknowns enabled. cel::RuntimeOptions options; options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; - CelExpressionFlatImpl impl(FlatExpression(std::move(path), - /*comprehension_slot_count=*/0, - TypeProvider::Builtin(), options)); + auto env = NewTestingRuntimeEnv(); + CelExpressionFlatImpl impl( + env, + FlatExpression(std::move(path), + /*comprehension_slot_count=*/0, + env->type_registry.GetComposedTypeProvider(), options)); Activation activation; Arena arena; diff --git a/eval/eval/logic_step_test.cc b/eval/eval/logic_step_test.cc index d4035e806..4224fdb4f 100644 --- a/eval/eval/logic_step_test.cc +++ b/eval/eval/logic_step_test.cc @@ -6,6 +6,7 @@ #include #include +#include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" @@ -31,6 +32,8 @@ #include "internal/status_macros.h" #include "internal/testing.h" #include "runtime/activation.h" +#include "runtime/internal/runtime_env.h" +#include "runtime/internal/runtime_env_testing.h" #include "runtime/managed_value_factory.h" #include "runtime/runtime_options.h" #include "google/protobuf/arena.h" @@ -53,11 +56,15 @@ using ::cel::Value; using ::cel::ValueManager; using ::cel::ast_internal::Expr; using ::cel::extensions::ProtoMemoryManagerRef; +using ::cel::runtime_internal::NewTestingRuntimeEnv; +using ::cel::runtime_internal::RuntimeEnv; using ::google::protobuf::Arena; using ::testing::Eq; class LogicStepTest : public testing::TestWithParam { public: + LogicStepTest() : env_(NewTestingRuntimeEnv()) {} + absl::Status EvaluateLogic(CelValue arg0, CelValue arg1, bool is_or, CelValue* result, bool enable_unknown) { Expr expr0; @@ -85,8 +92,9 @@ class LogicStepTest : public testing::TestWithParam { cel::UnknownProcessingOptions::kAttributeOnly; } CelExpressionFlatImpl impl( + env_, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, - TypeProvider::Builtin(), options)); + env_->type_registry.GetComposedTypeProvider(), options)); Activation activation; activation.InsertValue("name0", arg0); @@ -97,6 +105,7 @@ class LogicStepTest : public testing::TestWithParam { } private: + absl::Nonnull> env_; Arena arena_; }; diff --git a/eval/eval/select_step_test.cc b/eval/eval/select_step_test.cc index 3bb22fca8..8cba56a04 100644 --- a/eval/eval/select_step_test.cc +++ b/eval/eval/select_step_test.cc @@ -1,11 +1,13 @@ #include "eval/eval/select_step.h" +#include #include #include #include #include "cel/expr/syntax.pb.h" #include "google/protobuf/wrappers.pb.h" +#include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -41,6 +43,8 @@ #include "internal/status_macros.h" #include "internal/testing.h" #include "runtime/activation.h" +#include "runtime/internal/runtime_env.h" +#include "runtime/internal/runtime_env_testing.h" #include "runtime/managed_value_factory.h" #include "runtime/runtime_options.h" #include "cel/expr/conformance/proto3/test_all_types.pb.h" @@ -69,6 +73,8 @@ using ::cel::expr::conformance::proto3::TestAllTypes; using ::cel::extensions::ProtoMemoryManagerRef; using ::cel::extensions::ProtoMessageToValue; using ::cel::internal::test::EqualsProto; +using ::cel::runtime_internal::NewTestingRuntimeEnv; +using ::cel::runtime_internal::RuntimeEnv; using ::cel::test::IntValueIs; using ::testing::_; using ::testing::Eq; @@ -110,8 +116,9 @@ class MockAccessor : public LegacyTypeAccessApis, public LegacyTypeInfoApis { class SelectStepTest : public testing::Test { public: SelectStepTest() - : value_factory_(ProtoMemoryManagerRef(&arena_), - cel::TypeProvider::Builtin()) {} + : env_(NewTestingRuntimeEnv()), + value_factory_(ProtoMemoryManagerRef(&arena_), + env_->type_registry.GetComposedTypeProvider()) {} // Helper method. Creates simple pipeline containing Select step and runs it. absl::StatusOr RunExpression(const CelValue target, absl::string_view field, bool test, @@ -142,8 +149,9 @@ class SelectStepTest : public testing::Test { cel::UnknownProcessingOptions::kAttributeOnly; } CelExpressionFlatImpl cel_expr( - FlatExpression(std::move(path), /*comprehension_slot_count=*/0, - TypeProvider::Builtin(), runtime_options)); + env_, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, + env_->type_registry.GetComposedTypeProvider(), + runtime_options)); Activation activation; activation.InsertValue("target", target); @@ -186,6 +194,7 @@ class SelectStepTest : public testing::Test { } protected: + absl::Nonnull> env_; google::protobuf::Arena arena_; cel::common_internal::LegacyValueManager value_factory_; }; @@ -338,8 +347,9 @@ TEST_F(SelectStepTest, MapPresenseIsErrorTest) { path.push_back(std::move(step1)); path.push_back(std::move(step2)); CelExpressionFlatImpl cel_expr( - FlatExpression(std::move(path), /*comprehension_slot_count=*/0, - TypeProvider::Builtin(), cel::RuntimeOptions{})); + env_, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, + env_->type_registry.GetComposedTypeProvider(), + cel::RuntimeOptions{})); Activation activation; activation.InsertValue("target", CelProtoWrapper::CreateMessage(&message, &arena_)); @@ -845,8 +855,9 @@ TEST_P(SelectStepConformanceTest, CelErrorAsArgument) { options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; } CelExpressionFlatImpl cel_expr( + env_, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, - TypeProvider::Builtin(), options)); + env_->type_registry.GetComposedTypeProvider(), options)); Activation activation; activation.InsertValue("message", CelValue::CreateError(&error)); @@ -879,8 +890,9 @@ TEST_F(SelectStepTest, DisableMissingAttributeOK) { path.push_back(std::move(step1)); CelExpressionFlatImpl cel_expr( - FlatExpression(std::move(path), /*comprehension_slot_count=*/0, - TypeProvider::Builtin(), cel::RuntimeOptions{})); + env_, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, + env_->type_registry.GetComposedTypeProvider(), + cel::RuntimeOptions{})); Activation activation; activation.InsertValue("message", CelProtoWrapper::CreateMessage(&message, &arena_)); @@ -922,8 +934,9 @@ TEST_F(SelectStepTest, UnrecoverableUnknownValueProducesError) { cel::RuntimeOptions options; options.enable_missing_attribute_errors = true; CelExpressionFlatImpl cel_expr( + env_, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, - TypeProvider::Builtin(), options)); + env_->type_registry.GetComposedTypeProvider(), options)); Activation activation; activation.InsertValue("message", CelProtoWrapper::CreateMessage(&message, &arena_)); @@ -971,8 +984,9 @@ TEST_F(SelectStepTest, UnknownPatternResolvesToUnknown) { cel::RuntimeOptions options; options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; CelExpressionFlatImpl cel_expr( + env_, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, - TypeProvider::Builtin(), options)); + env_->type_registry.GetComposedTypeProvider(), options)); { std::vector unknown_patterns; diff --git a/eval/eval/shadowable_value_step_test.cc b/eval/eval/shadowable_value_step_test.cc index 935fc0d44..10ad88e70 100644 --- a/eval/eval/shadowable_value_step_test.cc +++ b/eval/eval/shadowable_value_step_test.cc @@ -1,8 +1,10 @@ #include "eval/eval/shadowable_value_step.h" +#include #include #include +#include "absl/base/nullability.h" #include "absl/status/statusor.h" #include "base/type_provider.h" #include "common/value.h" @@ -13,6 +15,8 @@ #include "eval/public/cel_value.h" #include "internal/status_macros.h" #include "internal/testing.h" +#include "runtime/internal/runtime_env.h" +#include "runtime/internal/runtime_env_testing.h" #include "runtime/runtime_options.h" namespace google::api::expr::runtime { @@ -21,13 +25,15 @@ namespace { using ::cel::TypeProvider; using ::cel::interop_internal::CreateTypeValueFromView; +using ::cel::runtime_internal::NewTestingRuntimeEnv; +using ::cel::runtime_internal::RuntimeEnv; using ::google::protobuf::Arena; using ::testing::Eq; -absl::StatusOr RunShadowableExpression(std::string identifier, - cel::Value value, - const Activation& activation, - Arena* arena) { +absl::StatusOr RunShadowableExpression( + const absl::Nonnull>& env, + std::string identifier, cel::Value value, const Activation& activation, + Arena* arena) { CEL_ASSIGN_OR_RETURN( auto step, CreateShadowableValueStep(std::move(identifier), std::move(value), 1)); @@ -35,12 +41,14 @@ absl::StatusOr RunShadowableExpression(std::string identifier, path.push_back(std::move(step)); CelExpressionFlatImpl impl( - FlatExpression(std::move(path), /*comprehension_slot_count=*/0, - TypeProvider::Builtin(), cel::RuntimeOptions{})); + env, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, + env->type_registry.GetComposedTypeProvider(), + cel::RuntimeOptions{})); return impl.Evaluate(activation, arena); } TEST(ShadowableValueStepTest, TestEvaluateNoShadowing) { + absl::Nonnull> env = NewTestingRuntimeEnv(); std::string type_name = "google.api.expr.runtime.TestMessage"; Activation activation; @@ -48,7 +56,7 @@ TEST(ShadowableValueStepTest, TestEvaluateNoShadowing) { auto type_value = CreateTypeValueFromView(&arena, type_name); auto status = - RunShadowableExpression(type_name, type_value, activation, &arena); + RunShadowableExpression(env, type_name, type_value, activation, &arena); ASSERT_OK(status); auto value = status.value(); @@ -57,6 +65,7 @@ TEST(ShadowableValueStepTest, TestEvaluateNoShadowing) { } TEST(ShadowableValueStepTest, TestEvaluateShadowedIdentifier) { + absl::Nonnull> env = NewTestingRuntimeEnv(); std::string type_name = "int"; auto shadow_value = CelValue::CreateInt64(1024L); @@ -66,7 +75,7 @@ TEST(ShadowableValueStepTest, TestEvaluateShadowedIdentifier) { auto type_value = CreateTypeValueFromView(&arena, type_name); auto status = - RunShadowableExpression(type_name, type_value, activation, &arena); + RunShadowableExpression(env, type_name, type_value, activation, &arena); ASSERT_OK(status); auto value = status.value(); diff --git a/eval/eval/ternary_step_test.cc b/eval/eval/ternary_step_test.cc index d622ee125..3d6553c49 100644 --- a/eval/eval/ternary_step_test.cc +++ b/eval/eval/ternary_step_test.cc @@ -28,6 +28,8 @@ #include "internal/status_macros.h" #include "internal/testing.h" #include "runtime/activation.h" +#include "runtime/internal/runtime_env.h" +#include "runtime/internal/runtime_env_testing.h" #include "runtime/managed_value_factory.h" #include "runtime/runtime_options.h" #include "google/protobuf/arena.h" @@ -48,6 +50,8 @@ using ::cel::UnknownValue; using ::cel::ValueManager; using ::cel::ast_internal::Expr; using ::cel::extensions::ProtoMemoryManagerRef; +using ::cel::runtime_internal::NewTestingRuntimeEnv; +using ::cel::runtime_internal::RuntimeEnv; using ::google::protobuf::Arena; using ::testing::ElementsAre; using ::testing::Eq; @@ -56,6 +60,8 @@ using ::testing::Truly; class LogicStepTest : public testing::TestWithParam { public: + LogicStepTest() : env_(NewTestingRuntimeEnv()) {} + absl::Status EvaluateLogic(CelValue arg0, CelValue arg1, CelValue arg2, CelValue* result, bool enable_unknown) { Expr expr0; @@ -93,8 +99,9 @@ class LogicStepTest : public testing::TestWithParam { cel::UnknownProcessingOptions::kAttributeOnly; } CelExpressionFlatImpl impl( + env_, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, - TypeProvider::Builtin(), options)); + env_->type_registry.GetComposedTypeProvider(), options)); Activation activation; std::string value("test"); @@ -110,6 +117,7 @@ class LogicStepTest : public testing::TestWithParam { } private: + absl::Nonnull> env_; Arena arena_; }; diff --git a/eval/public/BUILD b/eval/public/BUILD index 609956706..ba09d6cfe 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -580,7 +580,6 @@ cc_library( "//eval/compiler:flat_expr_builder", "//eval/compiler:qualified_reference_resolver", "//eval/compiler:regex_precompilation_optimization", - "//eval/public/structs:protobuf_descriptor_type_provider", "//extensions:select_optimization", "//internal:noop_delete", "//runtime:runtime_options", @@ -803,16 +802,14 @@ cc_library( hdrs = ["cel_type_registry.h"], deps = [ "//base:data", - "//common:type", - "//common:value", - "//eval/internal:interop", "//eval/public/structs:legacy_type_adapter", "//eval/public/structs:legacy_type_info_apis", "//eval/public/structs:legacy_type_provider", + "//eval/public/structs:protobuf_descriptor_type_provider", "//runtime:type_registry", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@com_google_protobuf//:protobuf", @@ -1106,6 +1103,17 @@ cc_library( ], ) +cc_test( + name = "cel_number_test", + srcs = ["cel_number_test.cc"], + deps = [ + ":cel_number", + ":cel_value", + "//internal:testing", + "@com_google_absl//absl/types:optional", + ], +) + cc_library( name = "string_extension_func_registrar", srcs = ["string_extension_func_registrar.cc"], diff --git a/eval/public/cel_expr_builder_factory.cc b/eval/public/cel_expr_builder_factory.cc index 436a85752..1db5e5174 100644 --- a/eval/public/cel_expr_builder_factory.cc +++ b/eval/public/cel_expr_builder_factory.cc @@ -33,7 +33,6 @@ #include "eval/public/cel_expression.h" #include "eval/public/cel_function.h" #include "eval/public/cel_options.h" -#include "eval/public/structs/protobuf_descriptor_type_provider.h" #include "extensions/select_optimization.h" #include "internal/noop_delete.h" #include "runtime/internal/runtime_env.h" @@ -87,14 +86,6 @@ std::unique_ptr CreateCelExpressionBuilder( auto builder = std::make_unique( std::move(env), runtime_options); - builder->GetTypeRegistry() - ->InternalGetModernRegistry() - .set_use_legacy_container_builders(options.use_legacy_container_builders); - - builder->GetTypeRegistry()->RegisterTypeProvider( - std::make_unique(descriptor_pool, - message_factory)); - FlatExprBuilder& flat_expr_builder = builder->flat_expr_builder(); flat_expr_builder.AddAstTransform(NewReferenceResolverExtension( diff --git a/eval/public/cel_type_registry.cc b/eval/public/cel_type_registry.cc index 1a2fc234d..1167ea4db 100644 --- a/eval/public/cel_type_registry.cc +++ b/eval/public/cel_type_registry.cc @@ -19,14 +19,8 @@ #include #include -#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" -#include "base/type_provider.h" -#include "common/type.h" -#include "common/type_factory.h" -#include "common/value.h" -#include "eval/internal/interop.h" #include "eval/public/structs/legacy_type_adapter.h" #include "eval/public/structs/legacy_type_info_apis.h" #include "eval/public/structs/legacy_type_provider.h" @@ -36,9 +30,6 @@ namespace google::api::expr::runtime { namespace { -using cel::Type; -using cel::TypeFactory; - class LegacyToModernTypeProviderAdapter : public LegacyTypeProvider { public: explicit LegacyToModernTypeProviderAdapter(const LegacyTypeProvider& provider) @@ -71,8 +62,6 @@ void AddEnumFromDescriptor(const google::protobuf::EnumDescriptor* desc, } // namespace -CelTypeRegistry::CelTypeRegistry() = default; - void CelTypeRegistry::Register(const google::protobuf::EnumDescriptor* enum_descriptor) { AddEnumFromDescriptor(enum_descriptor, *this); } @@ -82,33 +71,14 @@ void CelTypeRegistry::RegisterEnum(absl::string_view enum_name, modern_type_registry_.RegisterEnum(enum_name, std::move(enumerators)); } -void CelTypeRegistry::RegisterTypeProvider( - std::unique_ptr provider) { - legacy_type_providers_.push_back( - std::shared_ptr(std::move(provider))); - modern_type_registry_.AddTypeProvider( - std::make_unique( - *legacy_type_providers_.back())); -} - -std::shared_ptr -CelTypeRegistry::GetFirstTypeProvider() const { - if (legacy_type_providers_.empty()) { - return nullptr; - } - return legacy_type_providers_[0]; -} - // Find a type's CelValue instance by its fully qualified name. absl::optional CelTypeRegistry::FindTypeAdapter( absl::string_view fully_qualified_type_name) const { - for (const auto& provider : legacy_type_providers_) { - auto maybe_adapter = provider->ProvideLegacyType(fully_qualified_type_name); - if (maybe_adapter.has_value()) { - return maybe_adapter; - } + auto maybe_adapter = + GetFirstTypeProvider()->ProvideLegacyType(fully_qualified_type_name); + if (maybe_adapter.has_value()) { + return maybe_adapter; } - return absl::nullopt; } diff --git a/eval/public/cel_type_registry.h b/eval/public/cel_type_registry.h index 80854f43d..097d143a9 100644 --- a/eval/public/cel_type_registry.h +++ b/eval/public/cel_type_registry.h @@ -20,11 +20,18 @@ #include #include +#include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "base/type_provider.h" +#include "eval/public/structs/legacy_type_adapter.h" #include "eval/public/structs/legacy_type_provider.h" +#include "eval/public/structs/protobuf_descriptor_type_provider.h" #include "runtime/type_registry.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" namespace google::api::expr::runtime { @@ -48,7 +55,13 @@ class CelTypeRegistry { // Representation of an enum. using Enumeration = cel::TypeRegistry::Enumeration; - CelTypeRegistry(); + CelTypeRegistry() + : CelTypeRegistry(google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory()) {} + + CelTypeRegistry(absl::Nonnull descriptor_pool, + absl::Nullable message_factory) + : modern_type_registry_(descriptor_pool, message_factory) {} ~CelTypeRegistry() = default; @@ -63,13 +76,11 @@ class CelTypeRegistry { void RegisterEnum(absl::string_view name, std::vector enumerators); - // Register a new type provider. - // - // Type providers are consulted in the order they are added. - void RegisterTypeProvider(std::unique_ptr provider); - // Get the first registered type provider. - std::shared_ptr GetFirstTypeProvider() const; + std::shared_ptr GetFirstTypeProvider() const { + return cel::runtime_internal::GetLegacyRuntimeTypeProvider( + modern_type_registry_); + } // Returns the effective type provider that has been configured with the // registry. @@ -132,7 +143,7 @@ class CelTypeRegistry { // TODO: This is needed to inspect the registered legacy type // providers for client tests. This can be removed when they are migrated to // use the modern APIs. - std::vector> legacy_type_providers_; + std::shared_ptr legacy_type_provider_; }; } // namespace google::api::expr::runtime diff --git a/eval/public/cel_type_registry_protobuf_reflection_test.cc b/eval/public/cel_type_registry_protobuf_reflection_test.cc index 4f9ba7be1..ee4e79e3a 100644 --- a/eval/public/cel_type_registry_protobuf_reflection_test.cc +++ b/eval/public/cel_type_registry_protobuf_reflection_test.cc @@ -96,10 +96,6 @@ TEST(CelTypeRegistryTypeProviderTest, StructTypes) { google::protobuf::LinkMessageReflection(); google::protobuf::LinkMessageReflection(); - registry.RegisterTypeProvider(std::make_unique( - google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory())); - cel::common_internal::LegacyValueManager value_manager( MemoryManagerRef::ReferenceCounting(), registry.GetTypeProvider()); diff --git a/eval/public/cel_type_registry_test.cc b/eval/public/cel_type_registry_test.cc index 60809e9b7..fc593b83a 100644 --- a/eval/public/cel_type_registry_test.cc +++ b/eval/public/cel_type_registry_test.cc @@ -89,38 +89,22 @@ TEST(CelTypeRegistryTest, TestRegisterBuiltInEnum) { TEST(CelTypeRegistryTest, TestGetFirstTypeProviderSuccess) { CelTypeRegistry registry; - registry.RegisterTypeProvider(std::make_unique( - std::vector{"google.protobuf.Int64"})); - registry.RegisterTypeProvider(std::make_unique( - std::vector{"google.protobuf.Any"})); auto type_provider = registry.GetFirstTypeProvider(); ASSERT_NE(type_provider, nullptr); - ASSERT_TRUE( - type_provider->ProvideLegacyType("google.protobuf.Int64").has_value()); ASSERT_FALSE( + type_provider->ProvideLegacyType("google.protobuf.Int64").has_value()); + ASSERT_TRUE( type_provider->ProvideLegacyType("google.protobuf.Any").has_value()); } -TEST(CelTypeRegistryTest, TestGetFirstTypeProviderFailureOnEmpty) { - CelTypeRegistry registry; - auto type_provider = registry.GetFirstTypeProvider(); - ASSERT_EQ(type_provider, nullptr); -} - TEST(CelTypeRegistryTest, TestFindTypeAdapterFound) { CelTypeRegistry registry; - registry.RegisterTypeProvider(std::make_unique( - std::vector{"google.protobuf.Any"})); auto desc = registry.FindTypeAdapter("google.protobuf.Any"); ASSERT_TRUE(desc.has_value()); } TEST(CelTypeRegistryTest, TestFindTypeAdapterFoundMultipleProviders) { CelTypeRegistry registry; - registry.RegisterTypeProvider(std::make_unique( - std::vector{"google.protobuf.Int64"})); - registry.RegisterTypeProvider(std::make_unique( - std::vector{"google.protobuf.Any"})); auto desc = registry.FindTypeAdapter("google.protobuf.Any"); ASSERT_TRUE(desc.has_value()); } diff --git a/eval/public/structs/BUILD b/eval/public/structs/BUILD index 10c1441a7..86bf4053e 100644 --- a/eval/public/structs/BUILD +++ b/eval/public/structs/BUILD @@ -213,6 +213,7 @@ cc_library( "//eval/public:message_wrapper", "//extensions/protobuf:memory_manager", "//internal:status_macros", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -315,9 +316,12 @@ cc_library( srcs = ["protobuf_descriptor_type_provider.cc"], hdrs = ["protobuf_descriptor_type_provider.h"], deps = [ + ":legacy_type_adapter", + ":legacy_type_info_apis", ":legacy_type_provider", ":proto_message_type_adapter", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", diff --git a/eval/public/structs/legacy_type_provider.cc b/eval/public/structs/legacy_type_provider.cc index efab51f00..fe7ed35cc 100644 --- a/eval/public/structs/legacy_type_provider.cc +++ b/eval/public/structs/legacy_type_provider.cc @@ -95,8 +95,66 @@ class LegacyStructValueBuilder final : public cel::StructValueBuilder { MessageWrapper::Builder builder_; }; +class LegacyValueBuilder final : public cel::ValueBuilder { + public: + LegacyValueBuilder(cel::MemoryManagerRef memory_manager, + LegacyTypeAdapter adapter, MessageWrapper::Builder builder) + : memory_manager_(memory_manager), + adapter_(adapter), + builder_(std::move(builder)) {} + + absl::Status SetFieldByName(absl::string_view name, + cel::Value value) override { + CEL_ASSIGN_OR_RETURN( + auto legacy_value, + LegacyValue(cel::extensions::ProtoMemoryManagerArena(memory_manager_), + value)); + return adapter_.mutation_apis()->SetField(name, legacy_value, + memory_manager_, builder_); + } + + absl::Status SetFieldByNumber(int64_t number, cel::Value value) override { + CEL_ASSIGN_OR_RETURN( + auto legacy_value, + LegacyValue(cel::extensions::ProtoMemoryManagerArena(memory_manager_), + value)); + return adapter_.mutation_apis()->SetFieldByNumber( + number, legacy_value, memory_manager_, builder_); + } + + absl::StatusOr Build() && override { + CEL_ASSIGN_OR_RETURN(auto value, + adapter_.mutation_apis()->AdaptFromWellKnownType( + memory_manager_, std::move(builder_))); + return cel::ModernValue( + cel::extensions::ProtoMemoryManagerArena(memory_manager_), value); + } + + private: + cel::MemoryManagerRef memory_manager_; + LegacyTypeAdapter adapter_; + MessageWrapper::Builder builder_; +}; + } // namespace +absl::StatusOr> +LegacyTypeProvider::NewValueBuilder(cel::ValueFactory& value_factory, + absl::string_view name) const { + if (auto type_adapter = ProvideLegacyType(name); type_adapter.has_value()) { + const auto* mutation_apis = type_adapter->mutation_apis(); + if (mutation_apis == nullptr) { + return absl::FailedPreconditionError( + absl::StrCat("LegacyTypeMutationApis missing for type: ", name)); + } + CEL_ASSIGN_OR_RETURN(auto builder, mutation_apis->NewInstance( + value_factory.GetMemoryManager())); + return std::make_unique( + value_factory.GetMemoryManager(), *type_adapter, std::move(builder)); + } + return nullptr; +} + absl::StatusOr> LegacyTypeProvider::NewStructValueBuilder(cel::ValueFactory& value_factory, const cel::StructType& type) const { diff --git a/eval/public/structs/legacy_type_provider.h b/eval/public/structs/legacy_type_provider.h index 78a8a09b5..002f32bd1 100644 --- a/eval/public/structs/legacy_type_provider.h +++ b/eval/public/structs/legacy_type_provider.h @@ -15,17 +15,18 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_TYPE_PROVIDER_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_TYPE_PROVIDER_H_ +#include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/status/statusor.h" #include "absl/strings/cord.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" -#include "common/memory.h" #include "common/type.h" #include "common/type_reflector.h" #include "common/value.h" #include "common/value_factory.h" #include "eval/public/structs/legacy_type_adapter.h" +#include "eval/public/structs/legacy_type_info_apis.h" namespace google::api::expr::runtime { @@ -59,6 +60,9 @@ class LegacyTypeProvider : public cel::TypeReflector { return absl::nullopt; } + absl::StatusOr> NewValueBuilder( + cel::ValueFactory& value_factory, absl::string_view name) const final; + absl::StatusOr> NewStructValueBuilder(cel::ValueFactory& value_factory, const cel::StructType& type) const final; diff --git a/eval/public/structs/protobuf_descriptor_type_provider.h b/eval/public/structs/protobuf_descriptor_type_provider.h index 5856f4f8a..2d82fcfce 100644 --- a/eval/public/structs/protobuf_descriptor_type_provider.h +++ b/eval/public/structs/protobuf_descriptor_type_provider.h @@ -17,32 +17,40 @@ #include #include -#include -#include "google/protobuf/descriptor.h" -#include "google/protobuf/message.h" +#include "absl/base/nullability.h" #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" #include "absl/types/optional.h" +#include "eval/public/structs/legacy_type_adapter.h" +#include "eval/public/structs/legacy_type_info_apis.h" #include "eval/public/structs/legacy_type_provider.h" #include "eval/public/structs/proto_message_type_adapter.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" namespace google::api::expr::runtime { // Implementation of a type provider that generates types from protocol buffer // descriptors. -class ProtobufDescriptorProvider final : public LegacyTypeProvider { +class ProtobufDescriptorProvider : public LegacyTypeProvider { public: ProtobufDescriptorProvider(const google::protobuf::DescriptorPool* pool, google::protobuf::MessageFactory* factory) : descriptor_pool_(pool), message_factory_(factory) {} absl::optional ProvideLegacyType( - absl::string_view name) const override; + absl::string_view name) const final; absl::optional ProvideLegacyTypeInfo( - absl::string_view name) const override; + absl::string_view name) const final; + + absl::Nonnull descriptor_pool() + const override { + return descriptor_pool_; + } private: // Create a new type instance if found in the registered descriptor pool. diff --git a/eval/tests/BUILD b/eval/tests/BUILD index 256aa26e9..59b57749b 100644 --- a/eval/tests/BUILD +++ b/eval/tests/BUILD @@ -89,7 +89,6 @@ cc_test( "//common:native_type", "//common:type", "//common:value", - "//extensions/protobuf:memory_manager", "//extensions/protobuf:runtime_adapter", "//extensions/protobuf:value", "//internal:benchmark", diff --git a/eval/tests/modern_benchmark_test.cc b/eval/tests/modern_benchmark_test.cc index 22233210a..5063179c9 100644 --- a/eval/tests/modern_benchmark_test.cc +++ b/eval/tests/modern_benchmark_test.cc @@ -44,7 +44,6 @@ #include "common/type.h" #include "common/value.h" #include "eval/tests/request_context.pb.h" -#include "extensions/protobuf/memory_manager.h" #include "extensions/protobuf/runtime_adapter.h" #include "extensions/protobuf/value.h" #include "internal/benchmark.h" @@ -61,8 +60,6 @@ #include "google/protobuf/text_format.h" ABSL_FLAG(bool, enable_recursive_planning, false, "enable recursive planning"); -ABSL_FLAG(bool, enable_ref_counting, false, - "enable reference counting memory management"); namespace cel { @@ -70,7 +67,6 @@ namespace { using ::absl_testing::IsOkAndHolds; using ::cel::extensions::ProtobufRuntimeAdapter; -using ::cel::extensions::ProtoMemoryManagerRef; using ::cel::expr::Expr; using ::cel::expr::ParsedExpr; using ::cel::expr::SourceInfo; @@ -111,15 +107,6 @@ std::unique_ptr StandardRuntimeOrDie( return std::move(runtime).value(); } -// Set the appropriate memory manager for based on flags. -MemoryManagerRef GetMemoryManagerForBenchmark(google::protobuf::Arena* arena) { - if (absl::GetFlag(FLAGS_enable_ref_counting)) { - return MemoryManagerRef::ReferenceCounting(); - } else { - return ProtoMemoryManagerRef(arena); - } -} - template Value WrapMessageOrDie(ValueManager& value_manager, const T& message) { auto value = extensions::ProtoMessageToValue(value_manager, message); @@ -155,11 +142,9 @@ static void BM_Eval(benchmark::State& state) { for (auto _ : state) { google::protobuf::Arena arena; - ManagedValueFactory value_factory(runtime->GetTypeProvider(), - GetMemoryManagerForBenchmark(&arena)); Activation activation; ASSERT_OK_AND_ASSIGN(cel::Value result, - cel_expr->Evaluate(activation, value_factory.get())); + cel_expr->Evaluate(&arena, activation)); ASSERT_TRUE(InstanceOf(result)); ASSERT_TRUE(Cast(result) == len + 1); } @@ -202,11 +187,8 @@ static void BM_Eval_Trace(benchmark::State& state) { for (auto _ : state) { google::protobuf::Arena arena; Activation activation; - cel::ManagedValueFactory value_factory( - runtime->GetTypeProvider(), GetMemoryManagerForBenchmark(&arena)); - ASSERT_OK_AND_ASSIGN( - cel::Value result, - cel_expr->Trace(activation, EmptyCallback, value_factory.get())); + ASSERT_OK_AND_ASSIGN(cel::Value result, + cel_expr->Trace(&arena, activation, EmptyCallback)); ASSERT_TRUE(InstanceOf(result)); ASSERT_TRUE(Cast(result) == len + 1); } @@ -246,10 +228,8 @@ static void BM_EvalString(benchmark::State& state) { for (auto _ : state) { google::protobuf::Arena arena; Activation activation; - cel::ManagedValueFactory value_factory( - runtime->GetTypeProvider(), GetMemoryManagerForBenchmark(&arena)); ASSERT_OK_AND_ASSIGN(cel::Value result, - cel_expr->Evaluate(activation, value_factory.get())); + cel_expr->Evaluate(&arena, activation)); ASSERT_TRUE(InstanceOf(result)); ASSERT_TRUE(Cast(result).Size() == len + 1); } @@ -290,11 +270,8 @@ static void BM_EvalString_Trace(benchmark::State& state) { for (auto _ : state) { google::protobuf::Arena arena; Activation activation; - cel::ManagedValueFactory value_factory( - runtime->GetTypeProvider(), GetMemoryManagerForBenchmark(&arena)); - ASSERT_OK_AND_ASSIGN( - cel::Value result, - cel_expr->Trace(activation, EmptyCallback, value_factory.get())); + ASSERT_OK_AND_ASSIGN(cel::Value result, + cel_expr->Trace(&arena, activation, EmptyCallback)); ASSERT_TRUE(InstanceOf(result)); ASSERT_TRUE(Cast(result).Size() == len + 1); } @@ -370,18 +347,13 @@ void BM_PolicySymbolic(benchmark::State& state) { *runtime, parsed_expr)); Activation activation; - ManagedValueFactory value_factory(runtime->GetTypeProvider(), - GetMemoryManagerForBenchmark(&arena)); - activation.InsertOrAssignValue( - "ip", value_factory.get().CreateUncheckedStringValue(kIP)); - activation.InsertOrAssignValue( - "path", value_factory.get().CreateUncheckedStringValue(kPath)); - activation.InsertOrAssignValue( - "token", value_factory.get().CreateUncheckedStringValue(kToken)); + activation.InsertOrAssignValue("ip", StringValue(&arena, kIP)); + activation.InsertOrAssignValue("path", StringValue(&arena, kPath)); + activation.InsertOrAssignValue("token", StringValue(&arena, kToken)); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(cel::Value result, - cel_expr->Evaluate(activation, value_factory.get())); + cel_expr->Evaluate(&arena, activation)); auto result_bool = As(result); ASSERT_TRUE(result_bool && result_bool->NativeValue()); } @@ -469,16 +441,14 @@ void BM_PolicySymbolicMap(benchmark::State& state) { *runtime, parsed_expr)); Activation activation; - ManagedValueFactory value_factory(runtime->GetTypeProvider(), - GetMemoryManagerForBenchmark(&arena)); ParsedMapValue map_value( - value_factory.get().GetMemoryManager().MakeShared()); + MemoryManager::Pooling(&arena).MakeShared()); activation.InsertOrAssignValue("request", std::move(map_value)); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(cel::Value result, - cel_expr->Evaluate(activation, value_factory.get())); + cel_expr->Evaluate(&arena, activation)); ASSERT_TRUE(InstanceOf(result) && Cast(result).NativeValue()); } @@ -506,7 +476,7 @@ void BM_PolicySymbolicProto(benchmark::State& state) { *runtime, parsed_expr)); ManagedValueFactory value_factory(runtime->GetTypeProvider(), - GetMemoryManagerForBenchmark(&arena)); + MemoryManager::Pooling(&arena)); Activation activation; RequestContext request; request.set_ip(kIP); @@ -516,7 +486,7 @@ void BM_PolicySymbolicProto(benchmark::State& state) { "request", WrapMessageOrDie(value_factory.get(), request)); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(cel::Value result, - cel_expr->Evaluate(activation, value_factory.get())); + cel_expr->Evaluate(&arena, activation)); ASSERT_TRUE(InstanceOf(result) && Cast(result).NativeValue()); } @@ -585,7 +555,7 @@ void BM_Comprehension(benchmark::State& state) { google::protobuf::Arena arena; Activation activation; ManagedValueFactory value_factory(runtime->GetTypeProvider(), - GetMemoryManagerForBenchmark(&arena)); + MemoryManager::Pooling(&arena)); ASSERT_OK_AND_ASSIGN( auto list_builder, @@ -603,7 +573,7 @@ void BM_Comprehension(benchmark::State& state) { ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(cel::Value result, - cel_expr->Evaluate(activation, value_factory.get())); + cel_expr->Evaluate(&arena, activation)); ASSERT_TRUE(InstanceOf(result)); ASSERT_EQ(Cast(result), len); } @@ -626,7 +596,7 @@ void BM_Comprehension_Trace(benchmark::State& state) { ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); ManagedValueFactory value_factory(runtime->GetTypeProvider(), - GetMemoryManagerForBenchmark(&arena)); + MemoryManager::Pooling(&arena)); ASSERT_OK_AND_ASSIGN( auto list_builder, @@ -640,9 +610,8 @@ void BM_Comprehension_Trace(benchmark::State& state) { activation.InsertOrAssignValue("list_var", std::move(*list_builder).Build()); for (auto _ : state) { - ASSERT_OK_AND_ASSIGN( - cel::Value result, - cel_expr->Trace(activation, EmptyCallback, value_factory.get())); + ASSERT_OK_AND_ASSIGN(cel::Value result, + cel_expr->Trace(&arena, activation, EmptyCallback)); ASSERT_TRUE(InstanceOf(result)); ASSERT_EQ(Cast(result), len); } @@ -663,7 +632,7 @@ void BM_HasMap(benchmark::State& state) { *runtime, parsed_expr)); ManagedValueFactory value_factory(runtime->GetTypeProvider(), - GetMemoryManagerForBenchmark(&arena)); + MemoryManager::Pooling(&arena)); ASSERT_OK_AND_ASSIGN(auto map_builder, value_factory.get().NewMapValueBuilder( cel::JsonMapType())); @@ -676,7 +645,7 @@ void BM_HasMap(benchmark::State& state) { for (auto _ : state) { ASSERT_OK_AND_ASSIGN(cel::Value result, - cel_expr->Evaluate(activation, value_factory.get())); + cel_expr->Evaluate(&arena, activation)); ASSERT_TRUE(InstanceOf(result) && Cast(result).NativeValue()); } @@ -696,7 +665,7 @@ void BM_HasProto(benchmark::State& state) { google::protobuf::Arena arena; Activation activation; ManagedValueFactory value_factory(runtime->GetTypeProvider(), - GetMemoryManagerForBenchmark(&arena)); + MemoryManager::Pooling(&arena)); RequestContext request; request.set_path(kPath); @@ -706,7 +675,7 @@ void BM_HasProto(benchmark::State& state) { for (auto _ : state) { ASSERT_OK_AND_ASSIGN(cel::Value result, - cel_expr->Evaluate(activation, value_factory.get())); + cel_expr->Evaluate(&arena, activation)); ASSERT_TRUE(InstanceOf(result) && Cast(result).NativeValue()); } @@ -727,7 +696,7 @@ void BM_HasProtoMap(benchmark::State& state) { google::protobuf::Arena arena; Activation activation; ManagedValueFactory value_factory(runtime->GetTypeProvider(), - GetMemoryManagerForBenchmark(&arena)); + MemoryManager::Pooling(&arena)); RequestContext request; request.mutable_headers()->insert({"create_time", "2021-01-01"}); @@ -736,7 +705,7 @@ void BM_HasProtoMap(benchmark::State& state) { for (auto _ : state) { ASSERT_OK_AND_ASSIGN(cel::Value result, - cel_expr->Evaluate(activation, value_factory.get())); + cel_expr->Evaluate(&arena, activation)); ASSERT_TRUE(InstanceOf(result) && Cast(result).NativeValue()); } @@ -758,7 +727,7 @@ void BM_ReadProtoMap(benchmark::State& state) { google::protobuf::Arena arena; Activation activation; ManagedValueFactory value_factory(runtime->GetTypeProvider(), - GetMemoryManagerForBenchmark(&arena)); + MemoryManager::Pooling(&arena)); RequestContext request; request.mutable_headers()->insert({"create_time", "2021-01-01"}); @@ -767,7 +736,7 @@ void BM_ReadProtoMap(benchmark::State& state) { for (auto _ : state) { ASSERT_OK_AND_ASSIGN(cel::Value result, - cel_expr->Evaluate(activation, value_factory.get())); + cel_expr->Evaluate(&arena, activation)); ASSERT_TRUE(InstanceOf(result) && Cast(result).NativeValue()); } @@ -789,7 +758,7 @@ void BM_NestedProtoFieldRead(benchmark::State& state) { google::protobuf::Arena arena; Activation activation; ManagedValueFactory value_factory(runtime->GetTypeProvider(), - GetMemoryManagerForBenchmark(&arena)); + MemoryManager::Pooling(&arena)); RequestContext request; request.mutable_a()->mutable_b()->mutable_c()->mutable_d()->set_e(false); @@ -798,7 +767,7 @@ void BM_NestedProtoFieldRead(benchmark::State& state) { for (auto _ : state) { ASSERT_OK_AND_ASSIGN(cel::Value result, - cel_expr->Evaluate(activation, value_factory.get())); + cel_expr->Evaluate(&arena, activation)); ASSERT_TRUE(InstanceOf(result) && Cast(result).NativeValue()); } @@ -820,7 +789,7 @@ void BM_NestedProtoFieldReadDefaults(benchmark::State& state) { google::protobuf::Arena arena; Activation activation; ManagedValueFactory value_factory(runtime->GetTypeProvider(), - GetMemoryManagerForBenchmark(&arena)); + MemoryManager::Pooling(&arena)); RequestContext request; activation.InsertOrAssignValue( @@ -828,7 +797,7 @@ void BM_NestedProtoFieldReadDefaults(benchmark::State& state) { for (auto _ : state) { ASSERT_OK_AND_ASSIGN(cel::Value result, - cel_expr->Evaluate(activation, value_factory.get())); + cel_expr->Evaluate(&arena, activation)); ASSERT_TRUE(InstanceOf(result) && Cast(result).NativeValue()); } @@ -850,7 +819,7 @@ void BM_ProtoStructAccess(benchmark::State& state) { google::protobuf::Arena arena; Activation activation; ManagedValueFactory value_factory(runtime->GetTypeProvider(), - GetMemoryManagerForBenchmark(&arena)); + MemoryManager::Pooling(&arena)); AttributeContext::Request request; auto* auth = request.mutable_auth(); @@ -861,7 +830,7 @@ void BM_ProtoStructAccess(benchmark::State& state) { for (auto _ : state) { ASSERT_OK_AND_ASSIGN(cel::Value result, - cel_expr->Evaluate(activation, value_factory.get())); + cel_expr->Evaluate(&arena, activation)); ASSERT_TRUE(InstanceOf(result) && Cast(result).NativeValue()); } @@ -883,7 +852,7 @@ void BM_ProtoListAccess(benchmark::State& state) { google::protobuf::Arena arena; Activation activation; ManagedValueFactory value_factory(runtime->GetTypeProvider(), - GetMemoryManagerForBenchmark(&arena)); + MemoryManager::Pooling(&arena)); AttributeContext::Request request; auto* auth = request.mutable_auth(); @@ -897,7 +866,7 @@ void BM_ProtoListAccess(benchmark::State& state) { for (auto _ : state) { ASSERT_OK_AND_ASSIGN(cel::Value result, - cel_expr->Evaluate(activation, value_factory.get())); + cel_expr->Evaluate(&arena, activation)); ASSERT_TRUE(InstanceOf(result) && Cast(result).NativeValue()); } @@ -1010,7 +979,7 @@ void BM_NestedComprehension(benchmark::State& state) { google::protobuf::Arena arena; Activation activation; cel::ManagedValueFactory value_factory(runtime->GetTypeProvider(), - GetMemoryManagerForBenchmark(&arena)); + MemoryManager::Pooling(&arena)); ASSERT_OK_AND_ASSIGN( auto list_builder, @@ -1029,7 +998,7 @@ void BM_NestedComprehension(benchmark::State& state) { for (auto _ : state) { ASSERT_OK_AND_ASSIGN(cel::Value result, - cel_expr->Evaluate(activation, value_factory.get())); + cel_expr->Evaluate(&arena, activation)); ASSERT_TRUE(InstanceOf(result)); ASSERT_EQ(Cast(result), len * len); } @@ -1051,7 +1020,7 @@ void BM_NestedComprehension_Trace(benchmark::State& state) { google::protobuf::Arena arena; Activation activation; ManagedValueFactory value_factory(runtime->GetTypeProvider(), - GetMemoryManagerForBenchmark(&arena)); + MemoryManager::Pooling(&arena)); ASSERT_OK_AND_ASSIGN( auto list_builder, @@ -1069,9 +1038,8 @@ void BM_NestedComprehension_Trace(benchmark::State& state) { ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); for (auto _ : state) { - ASSERT_OK_AND_ASSIGN( - cel::Value result, - cel_expr->Trace(activation, &EmptyCallback, value_factory.get())); + ASSERT_OK_AND_ASSIGN(cel::Value result, + cel_expr->Trace(&arena, activation, &EmptyCallback)); ASSERT_TRUE(InstanceOf(result)); ASSERT_EQ(Cast(result), len * len); } @@ -1093,7 +1061,7 @@ void BM_ListComprehension(benchmark::State& state) { google::protobuf::Arena arena; Activation activation; ManagedValueFactory value_factory(runtime->GetTypeProvider(), - GetMemoryManagerForBenchmark(&arena)); + MemoryManager::Pooling(&arena)); ASSERT_OK_AND_ASSIGN( auto list_builder, @@ -1109,7 +1077,7 @@ void BM_ListComprehension(benchmark::State& state) { for (auto _ : state) { ASSERT_OK_AND_ASSIGN(cel::Value result, - cel_expr->Evaluate(activation, value_factory.get())); + cel_expr->Evaluate(&arena, activation)); ASSERT_TRUE(InstanceOf(result)); ASSERT_THAT(Cast(result).Size(), IsOkAndHolds(len)); } @@ -1132,7 +1100,7 @@ void BM_ListComprehension_Trace(benchmark::State& state) { Activation activation; ManagedValueFactory value_factory(runtime->GetTypeProvider(), - GetMemoryManagerForBenchmark(&arena)); + MemoryManager::Pooling(&arena)); ASSERT_OK_AND_ASSIGN( auto list_builder, @@ -1147,9 +1115,8 @@ void BM_ListComprehension_Trace(benchmark::State& state) { activation.InsertOrAssignValue("list_var", std::move(*list_builder).Build()); for (auto _ : state) { - ASSERT_OK_AND_ASSIGN( - cel::Value result, - cel_expr->Trace(activation, EmptyCallback, value_factory.get())); + ASSERT_OK_AND_ASSIGN(cel::Value result, + cel_expr->Trace(&arena, activation, EmptyCallback)); ASSERT_TRUE(InstanceOf(result)); ASSERT_THAT(Cast(result).Size(), IsOkAndHolds(len)); } @@ -1170,7 +1137,7 @@ void BM_ListComprehension_Opt(benchmark::State& state) { Activation activation; ManagedValueFactory value_factory(runtime->GetTypeProvider(), - GetMemoryManagerForBenchmark(&arena)); + MemoryManager::Pooling(&arena)); ASSERT_OK_AND_ASSIGN( auto list_builder, @@ -1189,7 +1156,7 @@ void BM_ListComprehension_Opt(benchmark::State& state) { for (auto _ : state) { ASSERT_OK_AND_ASSIGN(cel::Value result, - cel_expr->Evaluate(activation, value_factory.get())); + cel_expr->Evaluate(&arena, activation)); ASSERT_TRUE(InstanceOf(result)); ASSERT_THAT(Cast(result).Size(), IsOkAndHolds(len)); } diff --git a/eval/testutil/BUILD b/eval/testutil/BUILD index 034291962..5d80af860 100644 --- a/eval/testutil/BUILD +++ b/eval/testutil/BUILD @@ -22,13 +22,6 @@ cc_proto_library( deps = [":test_message_proto"], ) -proto_library( - name = "simple_test_message_proto", - srcs = [ - "simple_test_message.proto", - ], -) - proto_library( name = "test_extensions_proto", srcs = [ diff --git a/eval/testutil/args.proto b/eval/testutil/args.proto deleted file mode 100644 index f4ec6991e..000000000 --- a/eval/testutil/args.proto +++ /dev/null @@ -1,47 +0,0 @@ -syntax = "proto3"; - -package google.api.expr.runtime; -option cc_enable_arenas = true; - -// Message representing errors -// during CEL evaluation. -message Argument { - oneof arg_kind { - bool bool_value = 1; - int64 int64_value = 2; - uint64 uint64_value = 3; - - float float_value = 4; - double double_value = 5; - - string string_value = 6; - bytes bytes_value = 7; - - google.protobuf.Duration duration = 8; - google.protobuf.Timestamp timestamp = 9; - } - - TestMessage message_value = 12; - - repeated int32 int32_list = 101; - repeated int64 int64_list = 102; - - repeated uint32 uint32_list = 103; - repeated uint64 uint64_list = 104; - - repeated float float_list = 105; - repeated double double_list = 106; - - repeated string string_list = 107; - repeated string cord_list = 108 [ctype = CORD]; - repeated bytes bytes_list = 109; - - repeated bool bool_list = 110; - - repeated TestEnum enum_list = 111; - repeated TestMessage message_list = 112; - - map int64_int32_map = 201; - map uint64_int32_map = 202; - map string_int32_map = 203; -} diff --git a/eval/testutil/simple_test_message.proto b/eval/testutil/simple_test_message.proto deleted file mode 100644 index 27a822fbb..000000000 --- a/eval/testutil/simple_test_message.proto +++ /dev/null @@ -1,9 +0,0 @@ -syntax = "proto3"; - -package google.api.expr.runtime; - -// This has no dependencies on any other messages to keep the file descriptor -// set needed to parse this message simple. -message SimpleTestMessage { - int64 int64_value = 1; -} diff --git a/extensions/BUILD b/extensions/BUILD index ae34b1194..3b428eefd 100644 --- a/extensions/BUILD +++ b/extensions/BUILD @@ -353,7 +353,6 @@ cc_test( srcs = ["strings_test.cc"], deps = [ ":strings", - "//common:memory", "//common:value", "//extensions/protobuf:runtime_adapter", "//internal:testing", @@ -368,5 +367,6 @@ cc_test( "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/strings:cord", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_protobuf//:protobuf", ], ) diff --git a/extensions/protobuf/BUILD b/extensions/protobuf/BUILD index 4e23471a0..4e78b51fd 100644 --- a/extensions/protobuf/BUILD +++ b/extensions/protobuf/BUILD @@ -130,12 +130,8 @@ cc_library( ], deps = [ "//common:type", - "//internal:status_macros", - "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:nullability", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:optional", "@com_google_protobuf//:protobuf", @@ -172,7 +168,6 @@ cc_library( ":type", "//base/internal:message_wrapper", "//common:allocator", - "//common:any", "//common:memory", "//common:type", "//common:value", @@ -190,7 +185,6 @@ cc_library( cc_test( name = "value_test", srcs = [ - "type_reflector_test.cc", "value_test.cc", ], deps = [ @@ -199,7 +193,6 @@ cc_test( "//base:attributes", "//common:casting", "//common:memory", - "//common:type", "//common:value", "//common:value_kind", "//common:value_testing", @@ -219,8 +212,6 @@ cc_test( srcs = ["value_end_to_end_test.cc"], deps = [ ":runtime_adapter", - ":value", - "//common:memory", "//common:value", "//common:value_testing", "//internal:testing", diff --git a/extensions/protobuf/type_reflector.cc b/extensions/protobuf/type_reflector.cc index b9994f1e5..d8ce2cd30 100644 --- a/extensions/protobuf/type_reflector.cc +++ b/extensions/protobuf/type_reflector.cc @@ -14,59 +14,21 @@ #include "extensions/protobuf/type_reflector.h" -#include "absl/base/nullability.h" -#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/cord.h" -#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" -#include "common/any.h" -#include "common/memory.h" -#include "common/type.h" #include "common/value.h" #include "common/value_factory.h" -#include "common/values/struct_value_builder.h" -#include "google/protobuf/arena.h" -#include "google/protobuf/descriptor.h" -#include "google/protobuf/message.h" namespace cel::extensions { -absl::StatusOr> -ProtoTypeReflector::NewStructValueBuilder(ValueFactory& value_factory, - const StructType& type) const { - auto status_or_builder = common_internal::NewStructValueBuilder( - value_factory.GetMemoryManager().arena(), descriptor_pool(), - message_factory(), type.name()); - if (!status_or_builder.ok() && absl::IsNotFound(status_or_builder.status())) { - return nullptr; - } - return status_or_builder; -} absl::StatusOr> ProtoTypeReflector::DeserializeValueImpl( ValueFactory& value_factory, absl::string_view type_url, const absl::Cord& value) const { - absl::string_view type_name; - if (!ParseTypeUrl(type_url, &type_name)) { - return absl::InvalidArgumentError("invalid type URL"); - } - const auto* descriptor = descriptor_pool()->FindMessageTypeByName(type_name); - if (descriptor == nullptr) { - return absl::nullopt; - } - const auto* prototype = message_factory()->GetPrototype(descriptor); - if (prototype == nullptr) { - return absl::nullopt; - } - absl::Nullable arena = - value_factory.GetMemoryManager().arena(); - auto message = WrapShared(prototype->New(arena), arena); - if (!message->ParsePartialFromCord(value)) { - return absl::UnknownError( - absl::StrCat("failed to parse message: ", descriptor->full_name())); - } - return Value::Message(message, descriptor_pool(), message_factory()); + // This should not be reachable, as we provide both the pool and the factory + // which should trigger DeserializeValue to handle the call and not call us. + return absl::nullopt; } } // namespace cel::extensions diff --git a/extensions/protobuf/type_reflector.h b/extensions/protobuf/type_reflector.h index 0b49738e2..668d15e47 100644 --- a/extensions/protobuf/type_reflector.h +++ b/extensions/protobuf/type_reflector.h @@ -20,46 +20,32 @@ #include "absl/strings/cord.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" -#include "common/type.h" #include "common/type_reflector.h" #include "common/value.h" #include "common/value_factory.h" #include "extensions/protobuf/type_introspector.h" #include "google/protobuf/descriptor.h" -#include "google/protobuf/message.h" namespace cel::extensions { class ProtoTypeReflector : public TypeReflector, public ProtoTypeIntrospector { public: ProtoTypeReflector() - : ProtoTypeReflector(google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory()) {} + : ProtoTypeReflector(google::protobuf::DescriptorPool::generated_pool()) {} - ProtoTypeReflector( - absl::Nonnull descriptor_pool, - absl::Nonnull message_factory) - : ProtoTypeIntrospector(descriptor_pool), - message_factory_(message_factory) {} - - absl::StatusOr> NewStructValueBuilder( - ValueFactory& value_factory, const StructType& type) const final; + explicit ProtoTypeReflector( + absl::Nonnull descriptor_pool) + : ProtoTypeIntrospector(descriptor_pool) {} absl::Nonnull descriptor_pool() const override { return ProtoTypeIntrospector::descriptor_pool(); } - absl::Nonnull message_factory() const override { - return message_factory_; - } - private: absl::StatusOr> DeserializeValueImpl( ValueFactory& value_factory, absl::string_view type_url, const absl::Cord& value) const final; - - absl::Nonnull const message_factory_; }; } // namespace cel::extensions diff --git a/extensions/protobuf/type_reflector_test.cc b/extensions/protobuf/type_reflector_test.cc deleted file mode 100644 index b56047b90..000000000 --- a/extensions/protobuf/type_reflector_test.cc +++ /dev/null @@ -1,117 +0,0 @@ -// Copyright 2024 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "extensions/protobuf/type_reflector.h" - -#include "google/protobuf/wrappers.pb.h" -#include "absl/status/status.h" -#include "common/memory.h" -#include "common/type.h" -#include "common/type_reflector.h" -#include "common/value.h" -#include "common/value_testing.h" -#include "internal/testing.h" -#include "cel/expr/conformance/proto2/test_all_types.pb.h" - -namespace cel::extensions { -namespace { - -using ::absl_testing::StatusIs; -using ::cel::expr::conformance::proto2::TestAllTypes; -using ::testing::IsNull; -using ::testing::NotNull; - -class ProtoTypeReflectorTest - : public common_internal::ThreadCompatibleValueTest<> { - private: - Shared NewTypeReflector( - MemoryManagerRef memory_manager) override { - return memory_manager.MakeShared(); - } -}; - -TEST_P(ProtoTypeReflectorTest, NewStructValueBuilder_NoSuchType) { - ASSERT_OK_AND_ASSIGN( - auto builder, - value_manager().NewStructValueBuilder( - common_internal::MakeBasicStructType("message.that.does.not.Exist"))); - EXPECT_THAT(builder, IsNull()); -} - -TEST_P(ProtoTypeReflectorTest, NewStructValueBuilder_SetFieldByNumber) { - ASSERT_OK_AND_ASSIGN(auto builder, - value_manager().NewStructValueBuilder( - MessageType(TestAllTypes::descriptor()))); - ASSERT_THAT(builder, NotNull()); - EXPECT_THAT(builder->SetFieldByNumber(13, UnknownValue{}), - StatusIs(absl::StatusCode::kInvalidArgument)); -} - -TEST_P(ProtoTypeReflectorTest, NewStructValueBuilder_TypeConversionError) { - ASSERT_OK_AND_ASSIGN(auto builder, - value_manager().NewStructValueBuilder( - MessageType(TestAllTypes::descriptor()))); - ASSERT_THAT(builder, NotNull()); - EXPECT_THAT(builder->SetFieldByName("single_bool", UnknownValue{}), - StatusIs(absl::StatusCode::kInvalidArgument)); - EXPECT_THAT(builder->SetFieldByName("single_int32", UnknownValue{}), - StatusIs(absl::StatusCode::kInvalidArgument)); - EXPECT_THAT(builder->SetFieldByName("single_int64", UnknownValue{}), - StatusIs(absl::StatusCode::kInvalidArgument)); - EXPECT_THAT(builder->SetFieldByName("single_uint32", UnknownValue{}), - StatusIs(absl::StatusCode::kInvalidArgument)); - EXPECT_THAT(builder->SetFieldByName("single_uint64", UnknownValue{}), - StatusIs(absl::StatusCode::kInvalidArgument)); - EXPECT_THAT(builder->SetFieldByName("single_float", UnknownValue{}), - StatusIs(absl::StatusCode::kInvalidArgument)); - EXPECT_THAT(builder->SetFieldByName("single_double", UnknownValue{}), - StatusIs(absl::StatusCode::kInvalidArgument)); - EXPECT_THAT(builder->SetFieldByName("single_string", UnknownValue{}), - StatusIs(absl::StatusCode::kInvalidArgument)); - EXPECT_THAT(builder->SetFieldByName("single_bytes", UnknownValue{}), - StatusIs(absl::StatusCode::kInvalidArgument)); - EXPECT_THAT(builder->SetFieldByName("single_bool_wrapper", UnknownValue{}), - StatusIs(absl::StatusCode::kInvalidArgument)); - EXPECT_THAT(builder->SetFieldByName("single_int32_wrapper", UnknownValue{}), - StatusIs(absl::StatusCode::kInvalidArgument)); - EXPECT_THAT(builder->SetFieldByName("single_int64_wrapper", UnknownValue{}), - StatusIs(absl::StatusCode::kInvalidArgument)); - EXPECT_THAT(builder->SetFieldByName("single_uint32_wrapper", UnknownValue{}), - StatusIs(absl::StatusCode::kInvalidArgument)); - EXPECT_THAT(builder->SetFieldByName("single_uint64_wrapper", UnknownValue{}), - StatusIs(absl::StatusCode::kInvalidArgument)); - EXPECT_THAT(builder->SetFieldByName("single_float_wrapper", UnknownValue{}), - StatusIs(absl::StatusCode::kInvalidArgument)); - EXPECT_THAT(builder->SetFieldByName("single_double_wrapper", UnknownValue{}), - StatusIs(absl::StatusCode::kInvalidArgument)); - EXPECT_THAT(builder->SetFieldByName("single_string_wrapper", UnknownValue{}), - StatusIs(absl::StatusCode::kInvalidArgument)); - EXPECT_THAT(builder->SetFieldByName("single_bytes_wrapper", UnknownValue{}), - StatusIs(absl::StatusCode::kInvalidArgument)); - EXPECT_THAT(builder->SetFieldByName("null_value", UnknownValue{}), - StatusIs(absl::StatusCode::kInvalidArgument)); - EXPECT_THAT(builder->SetFieldByName("repeated_bool", UnknownValue{}), - StatusIs(absl::StatusCode::kInvalidArgument)); - EXPECT_THAT(builder->SetFieldByName("map_bool_bool", UnknownValue{}), - StatusIs(absl::StatusCode::kInvalidArgument)); -} - -INSTANTIATE_TEST_SUITE_P( - ProtoTypeReflectorTest, ProtoTypeReflectorTest, - ::testing::Values(MemoryManagement::kPooling, - MemoryManagement::kReferenceCounting), - ProtoTypeReflectorTest::ToString); - -} // namespace -} // namespace cel::extensions diff --git a/extensions/protobuf/value_end_to_end_test.cc b/extensions/protobuf/value_end_to_end_test.cc index 7e90347d1..1ff700fa7 100644 --- a/extensions/protobuf/value_end_to_end_test.cc +++ b/extensions/protobuf/value_end_to_end_test.cc @@ -15,7 +15,6 @@ // Functional tests for protobuf backed CEL structs in the default runtime. #include -#include #include #include #include @@ -23,11 +22,9 @@ #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" -#include "common/memory.h" #include "common/value.h" #include "common/value_testing.h" #include "extensions/protobuf/runtime_adapter.h" -#include "extensions/protobuf/value.h" #include "internal/testing.h" #include "parser/parser.h" #include "runtime/activation.h" @@ -35,7 +32,9 @@ #include "runtime/runtime_options.h" #include "runtime/standard_runtime_builder_factory.h" #include "cel/expr/conformance/proto3/test_all_types.pb.h" +#include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" #include "google/protobuf/text_format.h" namespace cel::extensions { @@ -62,25 +61,28 @@ using ::google::api::expr::parser::Parse; using ::testing::_; using ::testing::AnyOf; using ::testing::HasSubstr; +using ::testing::TestWithParam; struct TestCase { std::string name; std::string expr; std::string msg_textproto; ValueMatcher matcher; -}; -std::ostream& operator<<(std::ostream& out, const TestCase& tc) { - return out << tc.name; -} + template + friend void AbslStringify(S& sink, const TestCase& tc) { + sink.Append(tc.name); + } +}; -class ProtobufValueEndToEndTest - : public common_internal::ThreadCompatibleValueTest { +class ProtobufValueEndToEndTest : public TestWithParam { public: ProtobufValueEndToEndTest() = default; protected: - const TestCase& test_case() const { return std::get<1>(GetParam()); } + const TestCase& test_case() const { return GetParam(); } + + google::protobuf::Arena arena_; }; TEST_P(ProtobufValueEndToEndTest, Runner) { @@ -89,11 +91,11 @@ TEST_P(ProtobufValueEndToEndTest, Runner) { ASSERT_TRUE( google::protobuf::TextFormat::ParseFromString(test_case().msg_textproto, &message)); - ASSERT_OK_AND_ASSIGN(Value value, - ProtoMessageToValue(value_manager(), message)); - Activation activation; - activation.InsertOrAssignValue("msg", std::move(value)); + activation.InsertOrAssignValue( + "msg", + Value::Message(&arena_, message, google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory())); RuntimeOptions opts; opts.enable_empty_wrapper_null_unboxing = true; @@ -109,643 +111,622 @@ TEST_P(ProtobufValueEndToEndTest, Runner) { ASSERT_OK_AND_ASSIGN(std::unique_ptr program, ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); - ASSERT_OK_AND_ASSIGN(Value result, - program->Evaluate(activation, value_manager())); + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena_, activation)); EXPECT_THAT(result, test_case().matcher); } INSTANTIATE_TEST_SUITE_P( Singular, ProtobufValueEndToEndTest, - testing::Combine( - testing::Values(MemoryManagement::kPooling, - MemoryManagement::kReferenceCounting), - testing::ValuesIn(std::vector{ - {"single_int64", "msg.single_int64", - R"pb( - single_int64: 42 - )pb", - IntValueIs(42)}, - {"single_int64_has", "has(msg.single_int64)", - R"pb( - single_int64: 42 - )pb", - BoolValueIs(true)}, - {"single_int64_has_false", "has(msg.single_int64)", "", - BoolValueIs(false)}, - {"single_int32", "msg.single_int32", - R"pb( - single_int32: 42 - )pb", - IntValueIs(42)}, - {"single_uint64", "msg.single_uint64", - R"pb( - single_uint64: 42 - )pb", - UintValueIs(42)}, - {"single_uint32", "msg.single_uint32", - R"pb( - single_uint32: 42 - )pb", - UintValueIs(42)}, - {"single_sint64", "msg.single_sint64", - R"pb( - single_sint64: 42 - )pb", - IntValueIs(42)}, - {"single_sint32", "msg.single_sint32", - R"pb( - single_sint32: 42 - )pb", - IntValueIs(42)}, - {"single_fixed64", "msg.single_fixed64", - R"pb( - single_fixed64: 42 - )pb", - UintValueIs(42)}, - {"single_fixed32", "msg.single_fixed32", - R"pb( - single_fixed32: 42 - )pb", - UintValueIs(42)}, - {"single_sfixed64", "msg.single_sfixed64", - R"pb( - single_sfixed64: 42 - )pb", - IntValueIs(42)}, - {"single_sfixed32", "msg.single_sfixed32", - R"pb( - single_sfixed32: 42 - )pb", - IntValueIs(42)}, - {"single_float", "msg.single_float", - R"pb( - single_float: 4.25 - )pb", - DoubleValueIs(4.25)}, - {"single_double", "msg.single_double", - R"pb( - single_double: 4.25 - )pb", - DoubleValueIs(4.25)}, - {"single_bool", "msg.single_bool", - R"pb( - single_bool: true - )pb", - BoolValueIs(true)}, - {"single_string", "msg.single_string", - R"pb( - single_string: "Hello 😀" - )pb", - StringValueIs("Hello 😀")}, - {"single_bytes", "msg.single_bytes", - R"pb( - single_bytes: "Hello" - )pb", - BytesValueIs("Hello")}, - {"wkt_duration", "msg.single_duration", - R"pb( - single_duration { seconds: 10 } - )pb", - DurationValueIs(absl::Seconds(10))}, - {"wkt_duration_default", "msg.single_duration", "", - DurationValueIs(absl::Seconds(0))}, - {"wkt_timestamp", "msg.single_timestamp", - R"pb( - single_timestamp { seconds: 10 } - )pb", - TimestampValueIs(absl::FromUnixSeconds(10))}, - {"wkt_timestamp_default", "msg.single_timestamp", "", - TimestampValueIs(absl::UnixEpoch())}, - {"wkt_int64", "msg.single_int64_wrapper", - R"pb( - single_int64_wrapper { value: -20 } - )pb", - IntValueIs(-20)}, - {"wkt_int64_default", "msg.single_int64_wrapper", "", - IsNullValue()}, - {"wkt_int32", "msg.single_int32_wrapper", - R"pb( - single_int32_wrapper { value: -10 } - )pb", - IntValueIs(-10)}, - {"wkt_int32_default", "msg.single_int32_wrapper", "", - IsNullValue()}, - {"wkt_uint64", "msg.single_uint64_wrapper", - R"pb( - single_uint64_wrapper { value: 10 } - )pb", - UintValueIs(10)}, - {"wkt_uint64_default", "msg.single_uint64_wrapper", "", - IsNullValue()}, - {"wkt_uint32", "msg.single_uint32_wrapper", - R"pb( - single_uint32_wrapper { value: 11 } - )pb", - UintValueIs(11)}, - {"wkt_uint32_default", "msg.single_uint32_wrapper", "", - IsNullValue()}, - {"wkt_float", "msg.single_float_wrapper", - R"pb( - single_float_wrapper { value: 10.25 } - )pb", - DoubleValueIs(10.25)}, - {"wkt_float_default", "msg.single_float_wrapper", "", - IsNullValue()}, - {"wkt_double", "msg.single_double_wrapper", - R"pb( - single_double_wrapper { value: 10.25 } - )pb", - DoubleValueIs(10.25)}, - {"wkt_double_default", "msg.single_double_wrapper", "", - IsNullValue()}, - {"wkt_bool", "msg.single_bool_wrapper", - R"pb( - single_bool_wrapper { value: false } - )pb", - BoolValueIs(false)}, - {"wkt_bool_default", "msg.single_bool_wrapper", "", IsNullValue()}, - {"wkt_string", "msg.single_string_wrapper", - R"pb( - single_string_wrapper { value: "abcd" } - )pb", - StringValueIs("abcd")}, - {"wkt_string_default", "msg.single_string_wrapper", "", - IsNullValue()}, - {"wkt_bytes", "msg.single_bytes_wrapper", - R"pb( - single_bytes_wrapper { value: "abcd" } - )pb", - BytesValueIs("abcd")}, - {"wkt_bytes_default", "msg.single_bytes_wrapper", "", - IsNullValue()}, - {"wkt_null", "msg.null_value", - R"pb( - null_value: NULL_VALUE - )pb", - IsNullValue()}, - {"message_field", "msg.standalone_message", - R"pb( - standalone_message { bb: 2 } - )pb", - StructValueIs(_)}, - {"message_field_has", "has(msg.standalone_message)", - R"pb( - standalone_message { bb: 2 } - )pb", - BoolValueIs(true)}, - {"message_field_has_false", "has(msg.standalone_message)", "", - BoolValueIs(false)}, - {"single_enum", "msg.standalone_enum", - R"pb( - standalone_enum: BAR - )pb", - // BAR - IntValueIs(1)}})), - ProtobufValueEndToEndTest::ToString); + testing::ValuesIn(std::vector{ + {"single_int64", "msg.single_int64", + R"pb( + single_int64: 42 + )pb", + IntValueIs(42)}, + {"single_int64_has", "has(msg.single_int64)", + R"pb( + single_int64: 42 + )pb", + BoolValueIs(true)}, + {"single_int64_has_false", "has(msg.single_int64)", "", + BoolValueIs(false)}, + {"single_int32", "msg.single_int32", + R"pb( + single_int32: 42 + )pb", + IntValueIs(42)}, + {"single_uint64", "msg.single_uint64", + R"pb( + single_uint64: 42 + )pb", + UintValueIs(42)}, + {"single_uint32", "msg.single_uint32", + R"pb( + single_uint32: 42 + )pb", + UintValueIs(42)}, + {"single_sint64", "msg.single_sint64", + R"pb( + single_sint64: 42 + )pb", + IntValueIs(42)}, + {"single_sint32", "msg.single_sint32", + R"pb( + single_sint32: 42 + )pb", + IntValueIs(42)}, + {"single_fixed64", "msg.single_fixed64", + R"pb( + single_fixed64: 42 + )pb", + UintValueIs(42)}, + {"single_fixed32", "msg.single_fixed32", + R"pb( + single_fixed32: 42 + )pb", + UintValueIs(42)}, + {"single_sfixed64", "msg.single_sfixed64", + R"pb( + single_sfixed64: 42 + )pb", + IntValueIs(42)}, + {"single_sfixed32", "msg.single_sfixed32", + R"pb( + single_sfixed32: 42 + )pb", + IntValueIs(42)}, + {"single_float", "msg.single_float", + R"pb( + single_float: 4.25 + )pb", + DoubleValueIs(4.25)}, + {"single_double", "msg.single_double", + R"pb( + single_double: 4.25 + )pb", + DoubleValueIs(4.25)}, + {"single_bool", "msg.single_bool", + R"pb( + single_bool: true + )pb", + BoolValueIs(true)}, + {"single_string", "msg.single_string", + R"pb( + single_string: "Hello 😀" + )pb", + StringValueIs("Hello 😀")}, + {"single_bytes", "msg.single_bytes", + R"pb( + single_bytes: "Hello" + )pb", + BytesValueIs("Hello")}, + {"wkt_duration", "msg.single_duration", + R"pb( + single_duration { seconds: 10 } + )pb", + DurationValueIs(absl::Seconds(10))}, + {"wkt_duration_default", "msg.single_duration", "", + DurationValueIs(absl::Seconds(0))}, + {"wkt_timestamp", "msg.single_timestamp", + R"pb( + single_timestamp { seconds: 10 } + )pb", + TimestampValueIs(absl::FromUnixSeconds(10))}, + {"wkt_timestamp_default", "msg.single_timestamp", "", + TimestampValueIs(absl::UnixEpoch())}, + {"wkt_int64", "msg.single_int64_wrapper", + R"pb( + single_int64_wrapper { value: -20 } + )pb", + IntValueIs(-20)}, + {"wkt_int64_default", "msg.single_int64_wrapper", "", IsNullValue()}, + {"wkt_int32", "msg.single_int32_wrapper", + R"pb( + single_int32_wrapper { value: -10 } + )pb", + IntValueIs(-10)}, + {"wkt_int32_default", "msg.single_int32_wrapper", "", IsNullValue()}, + {"wkt_uint64", "msg.single_uint64_wrapper", + R"pb( + single_uint64_wrapper { value: 10 } + )pb", + UintValueIs(10)}, + {"wkt_uint64_default", "msg.single_uint64_wrapper", "", IsNullValue()}, + {"wkt_uint32", "msg.single_uint32_wrapper", + R"pb( + single_uint32_wrapper { value: 11 } + )pb", + UintValueIs(11)}, + {"wkt_uint32_default", "msg.single_uint32_wrapper", "", IsNullValue()}, + {"wkt_float", "msg.single_float_wrapper", + R"pb( + single_float_wrapper { value: 10.25 } + )pb", + DoubleValueIs(10.25)}, + {"wkt_float_default", "msg.single_float_wrapper", "", IsNullValue()}, + {"wkt_double", "msg.single_double_wrapper", + R"pb( + single_double_wrapper { value: 10.25 } + )pb", + DoubleValueIs(10.25)}, + {"wkt_double_default", "msg.single_double_wrapper", "", IsNullValue()}, + {"wkt_bool", "msg.single_bool_wrapper", + R"pb( + single_bool_wrapper { value: false } + )pb", + BoolValueIs(false)}, + {"wkt_bool_default", "msg.single_bool_wrapper", "", IsNullValue()}, + {"wkt_string", "msg.single_string_wrapper", + R"pb( + single_string_wrapper { value: "abcd" } + )pb", + StringValueIs("abcd")}, + {"wkt_string_default", "msg.single_string_wrapper", "", IsNullValue()}, + {"wkt_bytes", "msg.single_bytes_wrapper", + R"pb( + single_bytes_wrapper { value: "abcd" } + )pb", + BytesValueIs("abcd")}, + {"wkt_bytes_default", "msg.single_bytes_wrapper", "", IsNullValue()}, + {"wkt_null", "msg.null_value", + R"pb( + null_value: NULL_VALUE + )pb", + IsNullValue()}, + {"message_field", "msg.standalone_message", + R"pb( + standalone_message { bb: 2 } + )pb", + StructValueIs(_)}, + {"message_field_has", "has(msg.standalone_message)", + R"pb( + standalone_message { bb: 2 } + )pb", + BoolValueIs(true)}, + {"message_field_has_false", "has(msg.standalone_message)", "", + BoolValueIs(false)}, + {"single_enum", "msg.standalone_enum", + R"pb( + standalone_enum: BAR + )pb", + // BAR + IntValueIs(1)}})); INSTANTIATE_TEST_SUITE_P( Repeated, ProtobufValueEndToEndTest, - testing::Combine( - testing::Values(MemoryManagement::kPooling, - MemoryManagement::kReferenceCounting), - testing::ValuesIn(std::vector{ - {"repeated_int64", "msg.repeated_int64[0]", - R"pb( - repeated_int64: 42 - )pb", - IntValueIs(42)}, - {"repeated_int64_has", "has(msg.repeated_int64)", - R"pb( - repeated_int64: 42 - )pb", - BoolValueIs(true)}, - {"repeated_int64_has_false", "has(msg.repeated_int64)", "", - BoolValueIs(false)}, - {"repeated_int32", "msg.repeated_int32[0]", - R"pb( - repeated_int32: 42 - )pb", - IntValueIs(42)}, - {"repeated_uint64", "msg.repeated_uint64[0]", - R"pb( - repeated_uint64: 42 - )pb", - UintValueIs(42)}, - {"repeated_uint32", "msg.repeated_uint32[0]", - R"pb( - repeated_uint32: 42 - )pb", - UintValueIs(42)}, - {"repeated_sint64", "msg.repeated_sint64[0]", - R"pb( - repeated_sint64: 42 - )pb", - IntValueIs(42)}, - {"repeated_sint32", "msg.repeated_sint32[0]", - R"pb( - repeated_sint32: 42 - )pb", - IntValueIs(42)}, - {"repeated_fixed64", "msg.repeated_fixed64[0]", - R"pb( - repeated_fixed64: 42 - )pb", - UintValueIs(42)}, - {"repeated_fixed32", "msg.repeated_fixed32[0]", - R"pb( - repeated_fixed32: 42 - )pb", - UintValueIs(42)}, - {"repeated_sfixed64", "msg.repeated_sfixed64[0]", - R"pb( - repeated_sfixed64: 42 - )pb", - IntValueIs(42)}, - {"repeated_sfixed32", "msg.repeated_sfixed32[0]", - R"pb( - repeated_sfixed32: 42 - )pb", - IntValueIs(42)}, - {"repeated_float", "msg.repeated_float[0]", - R"pb( - repeated_float: 4.25 - )pb", - DoubleValueIs(4.25)}, - {"repeated_double", "msg.repeated_double[0]", - R"pb( - repeated_double: 4.25 - )pb", - DoubleValueIs(4.25)}, - {"repeated_bool", "msg.repeated_bool[0]", - R"pb( - repeated_bool: true - )pb", - BoolValueIs(true)}, - {"repeated_string", "msg.repeated_string[0]", - R"pb( - repeated_string: "Hello 😀" - )pb", - StringValueIs("Hello 😀")}, - {"repeated_bytes", "msg.repeated_bytes[0]", - R"pb( - repeated_bytes: "Hello" - )pb", - BytesValueIs("Hello")}, - {"wkt_duration", "msg.repeated_duration[0]", - R"pb( - repeated_duration { seconds: 10 } - )pb", - DurationValueIs(absl::Seconds(10))}, - {"wkt_timestamp", "msg.repeated_timestamp[0]", - R"pb( - repeated_timestamp { seconds: 10 } - )pb", - TimestampValueIs(absl::FromUnixSeconds(10))}, - {"wkt_int64", "msg.repeated_int64_wrapper[0]", - R"pb( - repeated_int64_wrapper { value: -20 } - )pb", - IntValueIs(-20)}, - {"wkt_int32", "msg.repeated_int32_wrapper[0]", - R"pb( - repeated_int32_wrapper { value: -10 } - )pb", - IntValueIs(-10)}, - {"wkt_uint64", "msg.repeated_uint64_wrapper[0]", - R"pb( - repeated_uint64_wrapper { value: 10 } - )pb", - UintValueIs(10)}, - {"wkt_uint32", "msg.repeated_uint32_wrapper[0]", - R"pb( - repeated_uint32_wrapper { value: 11 } - )pb", - UintValueIs(11)}, - {"wkt_float", "msg.repeated_float_wrapper[0]", - R"pb( - repeated_float_wrapper { value: 10.25 } - )pb", - DoubleValueIs(10.25)}, - {"wkt_double", "msg.repeated_double_wrapper[0]", - R"pb( - repeated_double_wrapper { value: 10.25 } - )pb", - DoubleValueIs(10.25)}, - {"wkt_bool", "msg.repeated_bool_wrapper[0]", - R"pb( + testing::ValuesIn(std::vector{ + {"repeated_int64", "msg.repeated_int64[0]", + R"pb( + repeated_int64: 42 + )pb", + IntValueIs(42)}, + {"repeated_int64_has", "has(msg.repeated_int64)", + R"pb( + repeated_int64: 42 + )pb", + BoolValueIs(true)}, + {"repeated_int64_has_false", "has(msg.repeated_int64)", "", + BoolValueIs(false)}, + {"repeated_int32", "msg.repeated_int32[0]", + R"pb( + repeated_int32: 42 + )pb", + IntValueIs(42)}, + {"repeated_uint64", "msg.repeated_uint64[0]", + R"pb( + repeated_uint64: 42 + )pb", + UintValueIs(42)}, + {"repeated_uint32", "msg.repeated_uint32[0]", + R"pb( + repeated_uint32: 42 + )pb", + UintValueIs(42)}, + {"repeated_sint64", "msg.repeated_sint64[0]", + R"pb( + repeated_sint64: 42 + )pb", + IntValueIs(42)}, + {"repeated_sint32", "msg.repeated_sint32[0]", + R"pb( + repeated_sint32: 42 + )pb", + IntValueIs(42)}, + {"repeated_fixed64", "msg.repeated_fixed64[0]", + R"pb( + repeated_fixed64: 42 + )pb", + UintValueIs(42)}, + {"repeated_fixed32", "msg.repeated_fixed32[0]", + R"pb( + repeated_fixed32: 42 + )pb", + UintValueIs(42)}, + {"repeated_sfixed64", "msg.repeated_sfixed64[0]", + R"pb( + repeated_sfixed64: 42 + )pb", + IntValueIs(42)}, + {"repeated_sfixed32", "msg.repeated_sfixed32[0]", + R"pb( + repeated_sfixed32: 42 + )pb", + IntValueIs(42)}, + {"repeated_float", "msg.repeated_float[0]", + R"pb( + repeated_float: 4.25 + )pb", + DoubleValueIs(4.25)}, + {"repeated_double", "msg.repeated_double[0]", + R"pb( + repeated_double: 4.25 + )pb", + DoubleValueIs(4.25)}, + {"repeated_bool", "msg.repeated_bool[0]", + R"pb( + repeated_bool: true + )pb", + BoolValueIs(true)}, + {"repeated_string", "msg.repeated_string[0]", + R"pb( + repeated_string: "Hello 😀" + )pb", + StringValueIs("Hello 😀")}, + {"repeated_bytes", "msg.repeated_bytes[0]", + R"pb( + repeated_bytes: "Hello" + )pb", + BytesValueIs("Hello")}, + {"wkt_duration", "msg.repeated_duration[0]", + R"pb( + repeated_duration { seconds: 10 } + )pb", + DurationValueIs(absl::Seconds(10))}, + {"wkt_timestamp", "msg.repeated_timestamp[0]", + R"pb( + repeated_timestamp { seconds: 10 } + )pb", + TimestampValueIs(absl::FromUnixSeconds(10))}, + {"wkt_int64", "msg.repeated_int64_wrapper[0]", + R"pb( + repeated_int64_wrapper { value: -20 } + )pb", + IntValueIs(-20)}, + {"wkt_int32", "msg.repeated_int32_wrapper[0]", + R"pb( + repeated_int32_wrapper { value: -10 } + )pb", + IntValueIs(-10)}, + {"wkt_uint64", "msg.repeated_uint64_wrapper[0]", + R"pb( + repeated_uint64_wrapper { value: 10 } + )pb", + UintValueIs(10)}, + {"wkt_uint32", "msg.repeated_uint32_wrapper[0]", + R"pb( + repeated_uint32_wrapper { value: 11 } + )pb", + UintValueIs(11)}, + {"wkt_float", "msg.repeated_float_wrapper[0]", + R"pb( + repeated_float_wrapper { value: 10.25 } + )pb", + DoubleValueIs(10.25)}, + {"wkt_double", "msg.repeated_double_wrapper[0]", + R"pb( + repeated_double_wrapper { value: 10.25 } + )pb", + DoubleValueIs(10.25)}, + {"wkt_bool", "msg.repeated_bool_wrapper[0]", + R"pb( - repeated_bool_wrapper { value: false } - )pb", - BoolValueIs(false)}, - {"wkt_string", "msg.repeated_string_wrapper[0]", - R"pb( - repeated_string_wrapper { value: "abcd" } - )pb", - StringValueIs("abcd")}, - {"wkt_bytes", "msg.repeated_bytes_wrapper[0]", - R"pb( - repeated_bytes_wrapper { value: "abcd" } - )pb", - BytesValueIs("abcd")}, - {"wkt_null", "msg.repeated_null_value[0]", - R"pb( - repeated_null_value: NULL_VALUE - )pb", - IsNullValue()}, - {"message_field", "msg.repeated_nested_message[0]", - R"pb( - repeated_nested_message { bb: 42 } - )pb", - StructValueIs(_)}, - {"repeated_enum", "msg.repeated_nested_enum[0]", - R"pb( - repeated_nested_enum: BAR - )pb", - // BAR - IntValueIs(1)}, - // Implements CEL list interface - {"repeated_size", "msg.repeated_int64.size()", - R"pb( - repeated_int64: 42 repeated_int64: 43 - )pb", - IntValueIs(2)}, - {"in_repeated", "42 in msg.repeated_int64", - R"pb( - repeated_int64: 42 repeated_int64: 43 - )pb", - BoolValueIs(true)}, - {"in_repeated_false", "44 in msg.repeated_int64", - R"pb( - repeated_int64: 42 repeated_int64: 43 - )pb", - BoolValueIs(false)}, - {"repeated_compre_exists", "msg.repeated_int64.exists(x, x > 42)", - R"pb( - repeated_int64: 42 repeated_int64: 43 - )pb", - BoolValueIs(true)}, - {"repeated_compre_map", "msg.repeated_int64.map(x, x * 2)[0]", - R"pb( - repeated_int64: 42 repeated_int64: 43 - )pb", - IntValueIs(84)}, - })), - ProtobufValueEndToEndTest::ToString); + repeated_bool_wrapper { value: false } + )pb", + BoolValueIs(false)}, + {"wkt_string", "msg.repeated_string_wrapper[0]", + R"pb( + repeated_string_wrapper { value: "abcd" } + )pb", + StringValueIs("abcd")}, + {"wkt_bytes", "msg.repeated_bytes_wrapper[0]", + R"pb( + repeated_bytes_wrapper { value: "abcd" } + )pb", + BytesValueIs("abcd")}, + {"wkt_null", "msg.repeated_null_value[0]", + R"pb( + repeated_null_value: NULL_VALUE + )pb", + IsNullValue()}, + {"message_field", "msg.repeated_nested_message[0]", + R"pb( + repeated_nested_message { bb: 42 } + )pb", + StructValueIs(_)}, + {"repeated_enum", "msg.repeated_nested_enum[0]", + R"pb( + repeated_nested_enum: BAR + )pb", + // BAR + IntValueIs(1)}, + // Implements CEL list interface + {"repeated_size", "msg.repeated_int64.size()", + R"pb( + repeated_int64: 42 repeated_int64: 43 + )pb", + IntValueIs(2)}, + {"in_repeated", "42 in msg.repeated_int64", + R"pb( + repeated_int64: 42 repeated_int64: 43 + )pb", + BoolValueIs(true)}, + {"in_repeated_false", "44 in msg.repeated_int64", + R"pb( + repeated_int64: 42 repeated_int64: 43 + )pb", + BoolValueIs(false)}, + {"repeated_compre_exists", "msg.repeated_int64.exists(x, x > 42)", + R"pb( + repeated_int64: 42 repeated_int64: 43 + )pb", + BoolValueIs(true)}, + {"repeated_compre_map", "msg.repeated_int64.map(x, x * 2)[0]", + R"pb( + repeated_int64: 42 repeated_int64: 43 + )pb", + IntValueIs(84)}, + })); INSTANTIATE_TEST_SUITE_P( Maps, ProtobufValueEndToEndTest, - testing::Combine( - testing::Values(MemoryManagement::kPooling, - MemoryManagement::kReferenceCounting), - testing::ValuesIn(std::vector{ - {"map_bool_int64", "msg.map_bool_int64[false]", - R"pb( - map_bool_int64 { key: false value: 42 } - )pb", - IntValueIs(42)}, - {"map_bool_int64_has", "has(msg.map_bool_int64)", - R"pb( - map_bool_int64 { key: false value: 42 } - )pb", - BoolValueIs(true)}, - {"map_bool_int64_has_false", "has(msg.map_bool_int64)", "", - BoolValueIs(false)}, - {"map_bool_int32", "msg.map_bool_int32[false]", - R"pb( - map_bool_int32 { key: false value: 42 } - )pb", - IntValueIs(42)}, - {"map_bool_uint64", "msg.map_bool_uint64[false]", - R"pb( - map_bool_uint64 { key: false value: 42 } - )pb", - UintValueIs(42)}, - {"map_bool_uint32", "msg.map_bool_uint32[false]", - R"pb( - map_bool_uint32 { key: false, value: 42 } - )pb", - UintValueIs(42)}, - {"map_bool_float", "msg.map_bool_float[false]", - R"pb( - map_bool_float { key: false value: 4.25 } - )pb", - DoubleValueIs(4.25)}, - {"map_bool_double", "msg.map_bool_double[false]", - R"pb( - map_bool_double { key: false value: 4.25 } - )pb", - DoubleValueIs(4.25)}, - {"map_bool_bool", "msg.map_bool_bool[false]", - R"pb( - map_bool_bool { key: false value: true } - )pb", - BoolValueIs(true)}, - {"map_bool_string", "msg.map_bool_string[false]", - R"pb( - map_bool_string { key: false value: "Hello 😀" } - )pb", - StringValueIs("Hello 😀")}, - {"map_bool_bytes", "msg.map_bool_bytes[false]", - R"pb( - map_bool_bytes { key: false value: "Hello" } - )pb", - BytesValueIs("Hello")}, - {"wkt_duration", "msg.map_bool_duration[false]", - R"pb( - map_bool_duration { - key: false - value { seconds: 10 } - } - )pb", - DurationValueIs(absl::Seconds(10))}, - {"wkt_timestamp", "msg.map_bool_timestamp[false]", - R"pb( - map_bool_timestamp { - key: false - value { seconds: 10 } - } - )pb", - TimestampValueIs(absl::FromUnixSeconds(10))}, - {"wkt_int64", "msg.map_bool_int64_wrapper[false]", - R"pb( - map_bool_int64_wrapper { - key: false - value { value: -20 } - } - )pb", - IntValueIs(-20)}, - {"wkt_int32", "msg.map_bool_int32_wrapper[false]", - R"pb( - map_bool_int32_wrapper { - key: false - value { value: -10 } - } - )pb", - IntValueIs(-10)}, - {"wkt_uint64", "msg.map_bool_uint64_wrapper[false]", - R"pb( - map_bool_uint64_wrapper { - key: false - value { value: 10 } - } - )pb", - UintValueIs(10)}, - {"wkt_uint32", "msg.map_bool_uint32_wrapper[false]", - R"pb( - map_bool_uint32_wrapper { - key: false - value { value: 11 } - } - )pb", - UintValueIs(11)}, - {"wkt_float", "msg.map_bool_float_wrapper[false]", - R"pb( - map_bool_float_wrapper { - key: false - value { value: 10.25 } - } - )pb", - DoubleValueIs(10.25)}, - {"wkt_double", "msg.map_bool_double_wrapper[false]", - R"pb( - map_bool_double_wrapper { - key: false - value { value: 10.25 } - } - )pb", - DoubleValueIs(10.25)}, - {"wkt_bool", "msg.map_bool_bool_wrapper[false]", - R"pb( - map_bool_bool_wrapper { - key: false - value { value: false } - } - )pb", - BoolValueIs(false)}, - {"wkt_string", "msg.map_bool_string_wrapper[false]", - R"pb( - map_bool_string_wrapper { - key: false - value { value: "abcd" } - } - )pb", - StringValueIs("abcd")}, - {"wkt_bytes", "msg.map_bool_bytes_wrapper[false]", - R"pb( - map_bool_bytes_wrapper { - key: false - value { value: "abcd" } - } - )pb", - BytesValueIs("abcd")}, - {"wkt_null", "msg.map_bool_null_value[false]", - R"pb( - map_bool_null_value { key: false value: NULL_VALUE } - )pb", - IsNullValue()}, - {"message_field", "msg.map_bool_message[false]", - R"pb( - map_bool_message { - key: false - value { bb: 42 } - } - )pb", - StructValueIs(_)}, - {"map_bool_enum", "msg.map_bool_enum[false]", - R"pb( - map_bool_enum { key: false value: BAR } - )pb", - // BAR - IntValueIs(1)}, - // Simplified for remaining key types. - {"map_int32_int64", "msg.map_int32_int64[42]", - R"pb( - map_int32_int64 { key: 42 value: -42 } - )pb", - IntValueIs(-42)}, - {"map_int64_int64", "msg.map_int64_int64[42]", - R"pb( - map_int64_int64 { key: 42 value: -42 } - )pb", - IntValueIs(-42)}, - {"map_uint32_int64", "msg.map_uint32_int64[42u]", - R"pb( - map_uint32_int64 { key: 42 value: -42 } - )pb", - IntValueIs(-42)}, - {"map_uint64_int64", "msg.map_uint64_int64[42u]", - R"pb( - map_uint64_int64 { key: 42 value: -42 } - )pb", - IntValueIs(-42)}, - {"map_string_int64", "msg.map_string_int64['key1']", - R"pb( - map_string_int64 { key: "key1" value: -42 } - )pb", - IntValueIs(-42)}, - // Implements CEL map - {"in_map_int64_true", "42 in msg.map_int64_int64", - R"pb( - map_int64_int64 { key: 42 value: -42 } - map_int64_int64 { key: 43 value: -43 } - )pb", - BoolValueIs(true)}, - {"in_map_int64_false", "44 in msg.map_int64_int64", - R"pb( - map_int64_int64 { key: 42 value: -42 } - map_int64_int64 { key: 43 value: -43 } - )pb", - BoolValueIs(false)}, - {"int_map_int64_compre_exists", - "msg.map_int64_int64.exists(key, key > 42)", - R"pb( - map_int64_int64 { key: 42 value: -42 } - map_int64_int64 { key: 43 value: -43 } - )pb", - BoolValueIs(true)}, - {"int_map_int64_compre_map", - "msg.map_int64_int64.map(key, key + 20)[0]", - R"pb( - map_int64_int64 { key: 42 value: -42 } - map_int64_int64 { key: 43 value: -43 } - )pb", + testing::ValuesIn(std::vector{ + {"map_bool_int64", "msg.map_bool_int64[false]", + R"pb( + map_bool_int64 { key: false value: 42 } + )pb", + IntValueIs(42)}, + {"map_bool_int64_has", "has(msg.map_bool_int64)", + R"pb( + map_bool_int64 { key: false value: 42 } + )pb", + BoolValueIs(true)}, + {"map_bool_int64_has_false", "has(msg.map_bool_int64)", "", + BoolValueIs(false)}, + {"map_bool_int32", "msg.map_bool_int32[false]", + R"pb( + map_bool_int32 { key: false value: 42 } + )pb", + IntValueIs(42)}, + {"map_bool_uint64", "msg.map_bool_uint64[false]", + R"pb( + map_bool_uint64 { key: false value: 42 } + )pb", + UintValueIs(42)}, + {"map_bool_uint32", "msg.map_bool_uint32[false]", + R"pb( + map_bool_uint32 { key: false, value: 42 } + )pb", + UintValueIs(42)}, + {"map_bool_float", "msg.map_bool_float[false]", + R"pb( + map_bool_float { key: false value: 4.25 } + )pb", + DoubleValueIs(4.25)}, + {"map_bool_double", "msg.map_bool_double[false]", + R"pb( + map_bool_double { key: false value: 4.25 } + )pb", + DoubleValueIs(4.25)}, + {"map_bool_bool", "msg.map_bool_bool[false]", + R"pb( + map_bool_bool { key: false value: true } + )pb", + BoolValueIs(true)}, + {"map_bool_string", "msg.map_bool_string[false]", + R"pb( + map_bool_string { key: false value: "Hello 😀" } + )pb", + StringValueIs("Hello 😀")}, + {"map_bool_bytes", "msg.map_bool_bytes[false]", + R"pb( + map_bool_bytes { key: false value: "Hello" } + )pb", + BytesValueIs("Hello")}, + {"wkt_duration", "msg.map_bool_duration[false]", + R"pb( + map_bool_duration { + key: false + value { seconds: 10 } + } + )pb", + DurationValueIs(absl::Seconds(10))}, + {"wkt_timestamp", "msg.map_bool_timestamp[false]", + R"pb( + map_bool_timestamp { + key: false + value { seconds: 10 } + } + )pb", + TimestampValueIs(absl::FromUnixSeconds(10))}, + {"wkt_int64", "msg.map_bool_int64_wrapper[false]", + R"pb( + map_bool_int64_wrapper { + key: false + value { value: -20 } + } + )pb", + IntValueIs(-20)}, + {"wkt_int32", "msg.map_bool_int32_wrapper[false]", + R"pb( + map_bool_int32_wrapper { + key: false + value { value: -10 } + } + )pb", + IntValueIs(-10)}, + {"wkt_uint64", "msg.map_bool_uint64_wrapper[false]", + R"pb( + map_bool_uint64_wrapper { + key: false + value { value: 10 } + } + )pb", + UintValueIs(10)}, + {"wkt_uint32", "msg.map_bool_uint32_wrapper[false]", + R"pb( + map_bool_uint32_wrapper { + key: false + value { value: 11 } + } + )pb", + UintValueIs(11)}, + {"wkt_float", "msg.map_bool_float_wrapper[false]", + R"pb( + map_bool_float_wrapper { + key: false + value { value: 10.25 } + } + )pb", + DoubleValueIs(10.25)}, + {"wkt_double", "msg.map_bool_double_wrapper[false]", + R"pb( + map_bool_double_wrapper { + key: false + value { value: 10.25 } + } + )pb", + DoubleValueIs(10.25)}, + {"wkt_bool", "msg.map_bool_bool_wrapper[false]", + R"pb( + map_bool_bool_wrapper { + key: false + value { value: false } + } + )pb", + BoolValueIs(false)}, + {"wkt_string", "msg.map_bool_string_wrapper[false]", + R"pb( + map_bool_string_wrapper { + key: false + value { value: "abcd" } + } + )pb", + StringValueIs("abcd")}, + {"wkt_bytes", "msg.map_bool_bytes_wrapper[false]", + R"pb( + map_bool_bytes_wrapper { + key: false + value { value: "abcd" } + } + )pb", + BytesValueIs("abcd")}, + {"wkt_null", "msg.map_bool_null_value[false]", + R"pb( + map_bool_null_value { key: false value: NULL_VALUE } + )pb", + IsNullValue()}, + {"message_field", "msg.map_bool_message[false]", + R"pb( + map_bool_message { + key: false + value { bb: 42 } + } + )pb", + StructValueIs(_)}, + {"map_bool_enum", "msg.map_bool_enum[false]", + R"pb( + map_bool_enum { key: false value: BAR } + )pb", + // BAR + IntValueIs(1)}, + // Simplified for remaining key types. + {"map_int32_int64", "msg.map_int32_int64[42]", + R"pb( + map_int32_int64 { key: 42 value: -42 } + )pb", + IntValueIs(-42)}, + {"map_int64_int64", "msg.map_int64_int64[42]", + R"pb( + map_int64_int64 { key: 42 value: -42 } + )pb", + IntValueIs(-42)}, + {"map_uint32_int64", "msg.map_uint32_int64[42u]", + R"pb( + map_uint32_int64 { key: 42 value: -42 } + )pb", + IntValueIs(-42)}, + {"map_uint64_int64", "msg.map_uint64_int64[42u]", + R"pb( + map_uint64_int64 { key: 42 value: -42 } + )pb", + IntValueIs(-42)}, + {"map_string_int64", "msg.map_string_int64['key1']", + R"pb( + map_string_int64 { key: "key1" value: -42 } + )pb", + IntValueIs(-42)}, + // Implements CEL map + {"in_map_int64_true", "42 in msg.map_int64_int64", + R"pb( + map_int64_int64 { key: 42 value: -42 } + map_int64_int64 { key: 43 value: -43 } + )pb", + BoolValueIs(true)}, + {"in_map_int64_false", "44 in msg.map_int64_int64", + R"pb( + map_int64_int64 { key: 42 value: -42 } + map_int64_int64 { key: 43 value: -43 } + )pb", + BoolValueIs(false)}, + {"int_map_int64_compre_exists", + "msg.map_int64_int64.exists(key, key > 42)", + R"pb( + map_int64_int64 { key: 42 value: -42 } + map_int64_int64 { key: 43 value: -43 } + )pb", + BoolValueIs(true)}, + {"int_map_int64_compre_map", + "msg.map_int64_int64.map(key, key + 20)[0]", + R"pb( + map_int64_int64 { key: 42 value: -42 } + map_int64_int64 { key: 43 value: -43 } + )pb", - IntValueIs(AnyOf(62, 63))}, - {"map_string_key_not_found", "msg.map_string_int64['key2']", - R"pb( - map_string_int64 { key: "key1" value: -42 } - )pb", - ErrorValueIs(StatusIs(absl::StatusCode::kNotFound, - HasSubstr("Key not found in map")))}, - {"map_string_select_key", "msg.map_string_int64.key1", - R"pb( - map_string_int64 { key: "key1" value: -42 } - )pb", - IntValueIs(-42)}, - {"map_string_has_key", "has(msg.map_string_int64.key1)", - R"pb( - map_string_int64 { key: "key1" value: -42 } - )pb", - BoolValueIs(true)}, - {"map_string_has_key_false", "has(msg.map_string_int64.key2)", - R"pb( - map_string_int64 { key: "key1" value: -42 } - )pb", - BoolValueIs(false)}, - {"map_int32_out_of_range", "msg.map_int32_int64[0x1FFFFFFFF]", - R"pb( - map_int32_int64 { key: 10 value: -42 } - )pb", - ErrorValueIs(StatusIs(absl::StatusCode::kNotFound, - HasSubstr("Key not found in map")))}, - {"map_uint32_out_of_range", "msg.map_uint32_int64[0x1FFFFFFFFu]", - R"pb( - map_uint32_int64 { key: 10 value: -42 } - )pb", - ErrorValueIs(StatusIs(absl::StatusCode::kNotFound, - HasSubstr("Key not found in map")))}})), - ProtobufValueEndToEndTest::ToString); + IntValueIs(AnyOf(62, 63))}, + {"map_string_key_not_found", "msg.map_string_int64['key2']", + R"pb( + map_string_int64 { key: "key1" value: -42 } + )pb", + ErrorValueIs(StatusIs(absl::StatusCode::kNotFound, + HasSubstr("Key not found in map")))}, + {"map_string_select_key", "msg.map_string_int64.key1", + R"pb( + map_string_int64 { key: "key1" value: -42 } + )pb", + IntValueIs(-42)}, + {"map_string_has_key", "has(msg.map_string_int64.key1)", + R"pb( + map_string_int64 { key: "key1" value: -42 } + )pb", + BoolValueIs(true)}, + {"map_string_has_key_false", "has(msg.map_string_int64.key2)", + R"pb( + map_string_int64 { key: "key1" value: -42 } + )pb", + BoolValueIs(false)}, + {"map_int32_out_of_range", "msg.map_int32_int64[0x1FFFFFFFF]", + R"pb( + map_int32_int64 { key: 10 value: -42 } + )pb", + ErrorValueIs(StatusIs(absl::StatusCode::kNotFound, + HasSubstr("Key not found in map")))}, + {"map_uint32_out_of_range", "msg.map_uint32_int64[0x1FFFFFFFFu]", + R"pb( + map_uint32_int64 { key: 10 value: -42 } + )pb", + ErrorValueIs(StatusIs(absl::StatusCode::kNotFound, + HasSubstr("Key not found in map")))}})); MATCHER_P(CelSizeIs, size, "") { auto s = arg.Size(); @@ -754,367 +735,353 @@ MATCHER_P(CelSizeIs, size, "") { INSTANTIATE_TEST_SUITE_P( JsonWrappers, ProtobufValueEndToEndTest, - testing::Combine( - testing::Values(MemoryManagement::kPooling, - MemoryManagement::kReferenceCounting), - testing::ValuesIn(std::vector{ - {"single_struct", "msg.single_struct", - R"pb( - single_struct { - fields { - key: "field1" - value { null_value: NULL_VALUE } - } - } - )pb", - MapValueIs(CelSizeIs(1))}, - {"single_struct_null_value_field", "msg.single_struct['field1']", - R"pb( - single_struct { - fields { - key: "field1" - value { null_value: NULL_VALUE } - } - } - )pb", - IsNullValue()}, - {"single_struct_number_value_field", "msg.single_struct['field1']", - R"pb( - single_struct { - fields { - key: "field1" - value { number_value: 10.25 } - } - } - )pb", - DoubleValueIs(10.25)}, - {"single_struct_bool_value_field", "msg.single_struct['field1']", - R"pb( - single_struct { - fields { - key: "field1" - value { bool_value: true } - } - } - )pb", - BoolValueIs(true)}, - {"single_struct_string_value_field", "msg.single_struct['field1']", - R"pb( - single_struct { - fields { - key: "field1" - value { string_value: "abcd" } - } - } - )pb", - StringValueIs("abcd")}, - {"single_struct_struct_value_field", "msg.single_struct['field1']", - R"pb( - single_struct { - fields { - key: "field1" - value { - struct_value { - fields { - key: "field2", - value: { null_value: NULL_VALUE } - } - } + testing::ValuesIn(std::vector{ + {"single_struct", "msg.single_struct", + R"pb( + single_struct { + fields { + key: "field1" + value { null_value: NULL_VALUE } + } + } + )pb", + MapValueIs(CelSizeIs(1))}, + {"single_struct_null_value_field", "msg.single_struct['field1']", + R"pb( + single_struct { + fields { + key: "field1" + value { null_value: NULL_VALUE } + } + } + )pb", + IsNullValue()}, + {"single_struct_number_value_field", "msg.single_struct['field1']", + R"pb( + single_struct { + fields { + key: "field1" + value { number_value: 10.25 } + } + } + )pb", + DoubleValueIs(10.25)}, + {"single_struct_bool_value_field", "msg.single_struct['field1']", + R"pb( + single_struct { + fields { + key: "field1" + value { bool_value: true } + } + } + )pb", + BoolValueIs(true)}, + {"single_struct_string_value_field", "msg.single_struct['field1']", + R"pb( + single_struct { + fields { + key: "field1" + value { string_value: "abcd" } + } + } + )pb", + StringValueIs("abcd")}, + {"single_struct_struct_value_field", "msg.single_struct['field1']", + R"pb( + single_struct { + fields { + key: "field1" + value { + struct_value { + fields { + key: "field2", + value: { null_value: NULL_VALUE } } } } - )pb", - MapValueIs(CelSizeIs(1))}, - {"single_struct_list_value_field", "msg.single_struct['field1']", - R"pb( - single_struct { - fields { - key: "field1" - value { list_value { values { null_value: NULL_VALUE } } } - } - } - )pb", - ListValueIs(CelSizeIs(1))}, - {"single_struct_select_field", "msg.single_struct.field1", - R"pb( - single_struct { - fields { - key: "field1" - value { bool_value: true } - } - } - )pb", - BoolValueIs(true)}, - {"single_struct_has_field", "has(msg.single_struct.field1)", - R"pb( - single_struct { + } + } + )pb", + MapValueIs(CelSizeIs(1))}, + {"single_struct_list_value_field", "msg.single_struct['field1']", + R"pb( + single_struct { + fields { + key: "field1" + value { list_value { values { null_value: NULL_VALUE } } } + } + } + )pb", + ListValueIs(CelSizeIs(1))}, + {"single_struct_select_field", "msg.single_struct.field1", + R"pb( + single_struct { + fields { + key: "field1" + value { bool_value: true } + } + } + )pb", + BoolValueIs(true)}, + {"single_struct_has_field", "has(msg.single_struct.field1)", + R"pb( + single_struct { + fields { + key: "field1" + value { bool_value: true } + } + } + )pb", + BoolValueIs(true)}, + {"single_struct_has_field_false", "has(msg.single_struct.field2)", + R"pb( + single_struct { + fields { + key: "field1" + value { bool_value: true } + } + } + )pb", + BoolValueIs(false)}, + {"single_struct_map_size", "msg.single_struct.size()", + R"pb( + single_struct { + fields { + key: "field1" + value { bool_value: true } + } + fields { + key: "field2" + value { bool_value: false } + } + } + )pb", + IntValueIs(2)}, + {"single_struct_map_in", "'field2' in msg.single_struct", + R"pb( + single_struct { + fields { + key: "field1" + value { bool_value: true } + } + fields { + key: "field2" + value { bool_value: false } + } + } + )pb", + BoolValueIs(true)}, + {"single_struct_map_compre_exists", + "msg.single_struct.exists(key, key == 'field2')", + R"pb( + single_struct { + fields { + key: "field1" + value { bool_value: true } + } + fields { + key: "field2" + value { bool_value: false } + } + } + )pb", + BoolValueIs(true)}, + {"single_struct_map_compre_map", + "'__field1' in msg.single_struct.map(key, '__' + key)", + R"pb( + single_struct { + fields { + key: "field1" + value { bool_value: true } + } + fields { + key: "field2" + value { bool_value: false } + } + } + )pb", + BoolValueIs(true)}, + {"single_list_value", "msg.list_value", + R"pb( + list_value { values { null_value: NULL_VALUE } } + )pb", + ListValueIs(CelSizeIs(1))}, + {"single_list_value_index_null", "msg.list_value[0]", + R"pb( + list_value { values { null_value: NULL_VALUE } } + )pb", + IsNullValue()}, + {"single_list_value_index_number", "msg.list_value[0]", + R"pb( + list_value { values { number_value: 10.25 } } + )pb", + DoubleValueIs(10.25)}, + {"single_list_value_index_string", "msg.list_value[0]", + R"pb( + list_value { values { string_value: "abc" } } + )pb", + StringValueIs("abc")}, + {"single_list_value_index_bool", "msg.list_value[0]", + R"pb( + list_value { values { bool_value: false } } + )pb", + BoolValueIs(false)}, + {"single_list_value_list_size", "msg.list_value.size()", + R"pb( + list_value { + values { bool_value: false } + values { bool_value: false } + } + )pb", + IntValueIs(2)}, + {"single_list_value_list_in", "10.25 in msg.list_value", + R"pb( + list_value { + values { number_value: 10.0 } + values { number_value: 10.25 } + } + )pb", + BoolValueIs(true)}, + {"single_list_value_list_compre_exists", + "msg.list_value.exists(x, x == 10.25)", + R"pb( + list_value { + values { number_value: 10.0 } + values { number_value: 10.25 } + } + )pb", + BoolValueIs(true)}, + {"single_list_value_list_compre_map", + "msg.list_value.map(x, x + 0.5)[1]", + R"pb( + list_value { + values { number_value: 10.0 } + values { number_value: 10.25 } + } + )pb", + DoubleValueIs(10.75)}, + {"single_list_value_index_struct", "msg.list_value[0]", + R"pb( + list_value { + values { + struct_value { fields { key: "field1" - value { bool_value: true } - } - } - )pb", - BoolValueIs(true)}, - {"single_struct_has_field_false", "has(msg.single_struct.field2)", - R"pb( - single_struct { - fields { - key: "field1" - value { bool_value: true } - } - } - )pb", - BoolValueIs(false)}, - {"single_struct_map_size", "msg.single_struct.size()", - R"pb( - single_struct { - fields { - key: "field1" - value { bool_value: true } - } - fields { - key: "field2" - value { bool_value: false } - } - } - )pb", - IntValueIs(2)}, - {"single_struct_map_in", "'field2' in msg.single_struct", - R"pb( - single_struct { - fields { - key: "field1" - value { bool_value: true } - } - fields { - key: "field2" - value { bool_value: false } - } - } - )pb", - BoolValueIs(true)}, - {"single_struct_map_compre_exists", - "msg.single_struct.exists(key, key == 'field2')", - R"pb( - single_struct { - fields { - key: "field1" - value { bool_value: true } - } - fields { - key: "field2" - value { bool_value: false } - } - } - )pb", - BoolValueIs(true)}, - {"single_struct_map_compre_map", - "'__field1' in msg.single_struct.map(key, '__' + key)", - R"pb( - single_struct { - fields { - key: "field1" - value { bool_value: true } - } - fields { - key: "field2" - value { bool_value: false } - } - } - )pb", - BoolValueIs(true)}, - {"single_list_value", "msg.list_value", - R"pb( - list_value { values { null_value: NULL_VALUE } } - )pb", - ListValueIs(CelSizeIs(1))}, - {"single_list_value_index_null", "msg.list_value[0]", - R"pb( - list_value { values { null_value: NULL_VALUE } } - )pb", - IsNullValue()}, - {"single_list_value_index_number", "msg.list_value[0]", - R"pb( - list_value { values { number_value: 10.25 } } - )pb", - DoubleValueIs(10.25)}, - {"single_list_value_index_string", "msg.list_value[0]", - R"pb( - list_value { values { string_value: "abc" } } - )pb", - StringValueIs("abc")}, - {"single_list_value_index_bool", "msg.list_value[0]", - R"pb( - list_value { values { bool_value: false } } - )pb", - BoolValueIs(false)}, - {"single_list_value_list_size", "msg.list_value.size()", - R"pb( - list_value { - values { bool_value: false } - values { bool_value: false } - } - )pb", - IntValueIs(2)}, - {"single_list_value_list_in", "10.25 in msg.list_value", - R"pb( - list_value { - values { number_value: 10.0 } - values { number_value: 10.25 } - } - )pb", - BoolValueIs(true)}, - {"single_list_value_list_compre_exists", - "msg.list_value.exists(x, x == 10.25)", - R"pb( - list_value { - values { number_value: 10.0 } - values { number_value: 10.25 } - } - )pb", - BoolValueIs(true)}, - {"single_list_value_list_compre_map", - "msg.list_value.map(x, x + 0.5)[1]", - R"pb( - list_value { - values { number_value: 10.0 } - values { number_value: 10.25 } - } - )pb", - DoubleValueIs(10.75)}, - {"single_list_value_index_struct", "msg.list_value[0]", - R"pb( - list_value { - values { - struct_value { - fields { - key: "field1" - value { null_value: NULL_VALUE } - } - } + value { null_value: NULL_VALUE } } } - )pb", - MapValueIs(CelSizeIs(1))}, - {"single_list_value_index_list", "msg.list_value[0]", - R"pb( - list_value { - values { list_value { values { null_value: NULL_VALUE } } } - } - )pb", - ListValueIs(CelSizeIs(1))}, - {"single_json_value_null", "msg.single_value", - R"pb( - single_value { null_value: NULL_VALUE } - )pb", - IsNullValue()}, - {"single_json_value_number", "msg.single_value", - R"pb( - single_value { number_value: 13.25 } - )pb", - DoubleValueIs(13.25)}, - {"single_json_value_string", "msg.single_value", - R"pb( - single_value { string_value: "abcd" } - )pb", - StringValueIs("abcd")}, - {"single_json_value_bool", "msg.single_value", - R"pb( - single_value { bool_value: false } - )pb", - BoolValueIs(false)}, - {"single_json_value_struct", "msg.single_value", - R"pb( - single_value { struct_value {} } - )pb", - MapValueIs(CelSizeIs(0))}, - {"single_json_value_list", "msg.single_value", - R"pb( - single_value { list_value {} } - )pb", - ListValueIs(CelSizeIs(0))}, - })), - ProtobufValueEndToEndTest::ToString); + } + } + )pb", + MapValueIs(CelSizeIs(1))}, + {"single_list_value_index_list", "msg.list_value[0]", + R"pb( + list_value { + values { list_value { values { null_value: NULL_VALUE } } } + } + )pb", + ListValueIs(CelSizeIs(1))}, + {"single_json_value_null", "msg.single_value", + R"pb( + single_value { null_value: NULL_VALUE } + )pb", + IsNullValue()}, + {"single_json_value_number", "msg.single_value", + R"pb( + single_value { number_value: 13.25 } + )pb", + DoubleValueIs(13.25)}, + {"single_json_value_string", "msg.single_value", + R"pb( + single_value { string_value: "abcd" } + )pb", + StringValueIs("abcd")}, + {"single_json_value_bool", "msg.single_value", + R"pb( + single_value { bool_value: false } + )pb", + BoolValueIs(false)}, + {"single_json_value_struct", "msg.single_value", + R"pb( + single_value { struct_value {} } + )pb", + MapValueIs(CelSizeIs(0))}, + {"single_json_value_list", "msg.single_value", + R"pb( + single_value { list_value {} } + )pb", + ListValueIs(CelSizeIs(0))}, + })); // TODO: any support needs the reflection impl for looking up the // type name and corresponding deserializer (outside of the WKTs which are // special cased). INSTANTIATE_TEST_SUITE_P( Any, ProtobufValueEndToEndTest, - testing::Combine( - testing::Values(MemoryManagement::kPooling, - MemoryManagement::kReferenceCounting), - testing::ValuesIn(std::vector{ - {"single_any_wkt_int64", "msg.single_any", - R"pb( - single_any { - [type.googleapis.com/google.protobuf.Int64Value] { value: 42 } - } - )pb", - IntValueIs(42)}, - {"single_any_wkt_int32", "msg.single_any", - R"pb( - single_any { - [type.googleapis.com/google.protobuf.Int32Value] { value: 42 } - } - )pb", - IntValueIs(42)}, - {"single_any_wkt_uint64", "msg.single_any", - R"pb( - single_any { - [type.googleapis.com/google.protobuf.UInt64Value] { value: 42 } - } - )pb", - UintValueIs(42)}, - {"single_any_wkt_uint32", "msg.single_any", - R"pb( - single_any { - [type.googleapis.com/google.protobuf.UInt32Value] { value: 42 } - } - )pb", - UintValueIs(42)}, - {"single_any_wkt_double", "msg.single_any", - R"pb( - single_any { - [type.googleapis.com/google.protobuf.DoubleValue] { - value: 30.5 - } - } - )pb", - DoubleValueIs(30.5)}, - {"single_any_wkt_string", "msg.single_any", - R"pb( - single_any { - [type.googleapis.com/google.protobuf.StringValue] { - value: "abcd" - } - } - )pb", - StringValueIs("abcd")}, + testing::ValuesIn(std::vector{ + {"single_any_wkt_int64", "msg.single_any", + R"pb( + single_any { + [type.googleapis.com/google.protobuf.Int64Value] { value: 42 } + } + )pb", + IntValueIs(42)}, + {"single_any_wkt_int32", "msg.single_any", + R"pb( + single_any { + [type.googleapis.com/google.protobuf.Int32Value] { value: 42 } + } + )pb", + IntValueIs(42)}, + {"single_any_wkt_uint64", "msg.single_any", + R"pb( + single_any { + [type.googleapis.com/google.protobuf.UInt64Value] { value: 42 } + } + )pb", + UintValueIs(42)}, + {"single_any_wkt_uint32", "msg.single_any", + R"pb( + single_any { + [type.googleapis.com/google.protobuf.UInt32Value] { value: 42 } + } + )pb", + UintValueIs(42)}, + {"single_any_wkt_double", "msg.single_any", + R"pb( + single_any { + [type.googleapis.com/google.protobuf.DoubleValue] { value: 30.5 } + } + )pb", + DoubleValueIs(30.5)}, + {"single_any_wkt_string", "msg.single_any", + R"pb( + single_any { + [type.googleapis.com/google.protobuf.StringValue] { value: "abcd" } + } + )pb", + StringValueIs("abcd")}, - {"repeated_any_wkt_string", "msg.repeated_any[0]", - R"pb( - repeated_any { - [type.googleapis.com/google.protobuf.StringValue] { - value: "abcd" - } - } - )pb", - StringValueIs("abcd")}, - {"map_int64_any_wkt_string", "msg.map_int64_any[0]", - R"pb( - map_int64_any { - key: 0 - value { - [type.googleapis.com/google.protobuf.StringValue] { - value: "abcd" - } - } + {"repeated_any_wkt_string", "msg.repeated_any[0]", + R"pb( + repeated_any { + [type.googleapis.com/google.protobuf.StringValue] { value: "abcd" } + } + )pb", + StringValueIs("abcd")}, + {"map_int64_any_wkt_string", "msg.map_int64_any[0]", + R"pb( + map_int64_any { + key: 0 + value { + [type.googleapis.com/google.protobuf.StringValue] { + value: "abcd" } - )pb", - StringValueIs("abcd")}, - })), - ProtobufValueEndToEndTest::ToString); + } + } + )pb", + StringValueIs("abcd")}, + })); } // namespace } // namespace cel::extensions diff --git a/extensions/strings_test.cc b/extensions/strings_test.cc index 8174f4e66..16e48b113 100644 --- a/extensions/strings_test.cc +++ b/extensions/strings_test.cc @@ -20,9 +20,7 @@ #include "cel/expr/syntax.pb.h" #include "absl/status/status_matchers.h" #include "absl/strings/cord.h" -#include "common/memory.h" #include "common/value.h" -#include "common/values/legacy_value_manager.h" #include "extensions/protobuf/runtime_adapter.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" @@ -33,6 +31,7 @@ #include "runtime/runtime_builder.h" #include "runtime/runtime_options.h" #include "runtime/standard_runtime_builder_factory.h" +#include "google/protobuf/arena.h" namespace cel::extensions { namespace { @@ -43,7 +42,7 @@ using ::google::api::expr::parser::Parse; using ::google::api::expr::parser::ParserOptions; TEST(Strings, SplitWithEmptyDelimiterCord) { - MemoryManagerRef memory_manager = MemoryManagerRef::ReferenceCounting(); + google::protobuf::Arena arena; const auto options = RuntimeOptions{}; ASSERT_OK_AND_ASSIGN(auto builder, CreateStandardRuntimeBuilder( @@ -61,21 +60,17 @@ TEST(Strings, SplitWithEmptyDelimiterCord) { ASSERT_OK_AND_ASSIGN(std::unique_ptr program, ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); - common_internal::LegacyValueManager value_factory(memory_manager, - runtime->GetTypeProvider()); - Activation activation; activation.InsertOrAssignValue("foo", StringValue{absl::Cord("hello world!")}); - ASSERT_OK_AND_ASSIGN(Value result, - program->Evaluate(activation, value_factory)); + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); ASSERT_TRUE(result.Is()); EXPECT_TRUE(result.GetBool().NativeValue()); } TEST(Strings, Replace) { - MemoryManagerRef memory_manager = MemoryManagerRef::ReferenceCounting(); + google::protobuf::Arena arena; const auto options = RuntimeOptions{}; ASSERT_OK_AND_ASSIGN(auto builder, CreateStandardRuntimeBuilder( @@ -92,20 +87,16 @@ TEST(Strings, Replace) { ASSERT_OK_AND_ASSIGN(std::unique_ptr program, ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); - common_internal::LegacyValueManager value_factory(memory_manager, - runtime->GetTypeProvider()); - Activation activation; activation.InsertOrAssignValue("foo", StringValue{absl::Cord("hello hello")}); - ASSERT_OK_AND_ASSIGN(Value result, - program->Evaluate(activation, value_factory)); + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); ASSERT_TRUE(result.Is()); EXPECT_TRUE(result.GetBool().NativeValue()); } TEST(Strings, ReplaceWithNegativeLimit) { - MemoryManagerRef memory_manager = MemoryManagerRef::ReferenceCounting(); + google::protobuf::Arena arena; const auto options = RuntimeOptions{}; ASSERT_OK_AND_ASSIGN(auto builder, CreateStandardRuntimeBuilder( @@ -122,20 +113,16 @@ TEST(Strings, ReplaceWithNegativeLimit) { ASSERT_OK_AND_ASSIGN(std::unique_ptr program, ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); - common_internal::LegacyValueManager value_factory(memory_manager, - runtime->GetTypeProvider()); - Activation activation; activation.InsertOrAssignValue("foo", StringValue{absl::Cord("hello hello")}); - ASSERT_OK_AND_ASSIGN(Value result, - program->Evaluate(activation, value_factory)); + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); ASSERT_TRUE(result.Is()); EXPECT_TRUE(result.GetBool().NativeValue()); } TEST(Strings, ReplaceWithLimit) { - MemoryManagerRef memory_manager = MemoryManagerRef::ReferenceCounting(); + google::protobuf::Arena arena; const auto options = RuntimeOptions{}; ASSERT_OK_AND_ASSIGN(auto builder, CreateStandardRuntimeBuilder( @@ -152,20 +139,16 @@ TEST(Strings, ReplaceWithLimit) { ASSERT_OK_AND_ASSIGN(std::unique_ptr program, ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); - common_internal::LegacyValueManager value_factory(memory_manager, - runtime->GetTypeProvider()); - Activation activation; activation.InsertOrAssignValue("foo", StringValue{absl::Cord("hello hello")}); - ASSERT_OK_AND_ASSIGN(Value result, - program->Evaluate(activation, value_factory)); + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); ASSERT_TRUE(result.Is()); EXPECT_TRUE(result.GetBool().NativeValue()); } TEST(Strings, ReplaceWithZeroLimit) { - MemoryManagerRef memory_manager = MemoryManagerRef::ReferenceCounting(); + google::protobuf::Arena arena; const auto options = RuntimeOptions{}; ASSERT_OK_AND_ASSIGN(auto builder, CreateStandardRuntimeBuilder( @@ -182,14 +165,10 @@ TEST(Strings, ReplaceWithZeroLimit) { ASSERT_OK_AND_ASSIGN(std::unique_ptr program, ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); - common_internal::LegacyValueManager value_factory(memory_manager, - runtime->GetTypeProvider()); - Activation activation; activation.InsertOrAssignValue("foo", StringValue{absl::Cord("hello hello")}); - ASSERT_OK_AND_ASSIGN(Value result, - program->Evaluate(activation, value_factory)); + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); ASSERT_TRUE(result.Is()); EXPECT_TRUE(result.GetBool().NativeValue()); } diff --git a/internal/BUILD b/internal/BUILD index fa833c1ae..0c7b1e6e0 100644 --- a/internal/BUILD +++ b/internal/BUILD @@ -118,12 +118,7 @@ cc_library( name = "benchmark", testonly = True, hdrs = ["benchmark.h"], - deps = [ - "@com_github_google_benchmark//:benchmark_main", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:cord", - ], + deps = ["@com_github_google_benchmark//:benchmark_main"], ) cc_library( @@ -177,7 +172,6 @@ cc_test( deps = [ ":number", ":testing", - "@com_google_absl//absl/types:optional", ], ) @@ -194,7 +188,6 @@ cc_library( ":status_builder", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", ], ) @@ -539,6 +532,8 @@ cel_proto_transitive_descriptor_set( name = "testing_descriptor_set", testonly = True, deps = [ + "//eval/testutil:test_extensions_proto", + "//eval/testutil:test_message_proto", "@com_google_cel_spec//proto/cel/expr:checked_proto", "@com_google_cel_spec//proto/cel/expr:eval_proto", "@com_google_cel_spec//proto/cel/expr:explain_proto", diff --git a/runtime/BUILD b/runtime/BUILD index 1e10f3d4a..91b8382b2 100644 --- a/runtime/BUILD +++ b/runtime/BUILD @@ -159,10 +159,13 @@ cc_library( deps = [ "//base:data", "//common:type", - "//runtime/internal:composed_type_provider", + "//runtime/internal:legacy_runtime_type_provider", + "//runtime/internal:runtime_type_provider", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", ], ) @@ -176,9 +179,12 @@ cc_library( "//base:data", "//common:native_type", "//common:value", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", ], ) @@ -329,7 +335,6 @@ cc_test( deps = [ ":activation", ":constant_folding", - ":managed_value_factory", ":register_function_helper", ":runtime_builder", ":runtime_options", @@ -345,6 +350,7 @@ cc_test( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_protobuf//:protobuf", ], ) @@ -374,7 +380,6 @@ cc_test( deps = [ ":activation", ":constant_folding", - ":managed_value_factory", ":regex_precompilation", ":register_function_helper", ":runtime_builder", @@ -391,6 +396,7 @@ cc_test( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_protobuf//:protobuf", ], ) @@ -419,7 +425,6 @@ cc_test( srcs = ["reference_resolver_test.cc"], deps = [ ":activation", - ":managed_value_factory", ":reference_resolver", ":register_function_helper", ":runtime_builder", @@ -559,10 +564,8 @@ cc_test( "//base:function", "//base:function_descriptor", "//common:kind", - "//common:memory", "//common:value", "//common:value_testing", - "//extensions/protobuf:memory_manager", "//extensions/protobuf:runtime_adapter", "//internal:testing", "//internal:testing_descriptor_pool", @@ -570,6 +573,7 @@ cc_test( "//parser:options", "//runtime/internal:runtime_impl", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", diff --git a/runtime/constant_folding_test.cc b/runtime/constant_folding_test.cc index f579cb400..775503ebe 100644 --- a/runtime/constant_folding_test.cc +++ b/runtime/constant_folding_test.cc @@ -30,11 +30,11 @@ #include "internal/testing_descriptor_pool.h" #include "parser/parser.h" #include "runtime/activation.h" -#include "runtime/managed_value_factory.h" #include "runtime/register_function_helper.h" #include "runtime/runtime_builder.h" #include "runtime/runtime_options.h" #include "runtime/standard_runtime_builder_factory.h" +#include "google/protobuf/arena.h" namespace cel::extensions { namespace { @@ -74,6 +74,7 @@ MATCHER_P(IsErrorValue, expected_substr, "") { class ConstantFoldingExtTest : public testing::TestWithParam {}; TEST_P(ConstantFoldingExtTest, Runner) { + google::protobuf::Arena arena; RuntimeOptions options; const TestCase& test_case = GetParam(); ASSERT_OK_AND_ASSIGN(cel::RuntimeBuilder builder, @@ -99,12 +100,9 @@ TEST_P(ConstantFoldingExtTest, Runner) { ASSERT_OK_AND_ASSIGN(auto program, ProtobufRuntimeAdapter::CreateProgram( *runtime, parsed_expr)); - - ManagedValueFactory value_factory(program->GetTypeProvider(), - MemoryManagerRef::ReferenceCounting()); Activation activation; - auto result = program->Evaluate(activation, value_factory.get()); + auto result = program->Evaluate(&arena, activation); if (test_case.status.ok()) { ASSERT_OK_AND_ASSIGN(Value value, std::move(result)); diff --git a/runtime/internal/BUILD b/runtime/internal/BUILD index 33b5c22b0..21500ca1a 100644 --- a/runtime/internal/BUILD +++ b/runtime/internal/BUILD @@ -19,26 +19,6 @@ package( licenses(["notice"]) -cc_library( - name = "composed_type_provider", - srcs = ["composed_type_provider.cc"], - hdrs = ["composed_type_provider.h"], - deps = [ - "//base:data", - "//common:type", - "//common:value", - "//internal:status_macros", - "@com_google_absl//absl/base:nullability", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:cord", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:optional", - ], -) - cc_library( name = "runtime_friend_access", hdrs = ["runtime_friend_access.h"], @@ -74,6 +54,7 @@ cc_library( hdrs = ["runtime_impl.h"], deps = [ ":runtime_env", + ":runtime_value_manager", "//base:ast", "//base:data", "//common:native_type", @@ -95,6 +76,7 @@ cc_library( "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", ], ) @@ -104,6 +86,7 @@ cc_library( hdrs = ["convert_constant.h"], deps = [ "//base/ast_internal:expr", + "//common:allocator", "//common:constant", "//common:value", "//eval/internal:errors", @@ -185,8 +168,55 @@ cc_library( hdrs = ["runtime_env_testing.h"], deps = [ ":runtime_env", + "//internal:noop_delete", "//internal:testing_descriptor_pool", + "//internal:testing_message_factory", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/log:absl_check", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "runtime_value_manager", + hdrs = ["runtime_value_manager.h"], + deps = [ + "//common:memory", + "//common:type", + "//common:value", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "legacy_runtime_type_provider", + hdrs = ["legacy_runtime_type_provider.h"], + deps = [ + "//eval/public/structs:protobuf_descriptor_type_provider", + "@com_google_absl//absl/base:nullability", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "runtime_type_provider", + srcs = ["runtime_type_provider.cc"], + hdrs = ["runtime_type_provider.h"], + deps = [ + "//common:any", + "//common:memory", + "//common:type", + "//common:value", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:optional", + "@com_google_protobuf//:protobuf", ], ) diff --git a/runtime/internal/composed_type_provider.cc b/runtime/internal/composed_type_provider.cc deleted file mode 100644 index 65542ac04..000000000 --- a/runtime/internal/composed_type_provider.cc +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright 2022 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "runtime/internal/composed_type_provider.h" - -#include -#include - -#include "absl/base/nullability.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "absl/types/optional.h" -#include "common/type.h" -#include "common/type_reflector.h" -#include "common/value.h" -#include "common/value_factory.h" -#include "internal/status_macros.h" - -namespace cel::runtime_internal { - -absl::Status ComposedTypeProvider::RegisterType(const OpaqueType& type) { - auto insertion = types_.insert(std::pair{type.name(), Type(type)}); - if (!insertion.second) { - return absl::AlreadyExistsError( - absl::StrCat("type already registered: ", insertion.first->first)); - } - return absl::OkStatus(); -} - -absl::StatusOr> -ComposedTypeProvider::NewStructValueBuilder(ValueFactory& value_factory, - const StructType& type) const { - for (const std::unique_ptr& provider : providers_) { - CEL_ASSIGN_OR_RETURN(auto builder, - provider->NewStructValueBuilder(value_factory, type)); - if (builder != nullptr) { - return builder; - } - } - return nullptr; -} - -absl::StatusOr ComposedTypeProvider::FindValue( - ValueFactory& value_factory, absl::string_view name, Value& result) const { - for (const std::unique_ptr& provider : providers_) { - CEL_ASSIGN_OR_RETURN(auto value, - provider->FindValue(value_factory, name, result)); - if (value) { - return value; - } - } - return false; -} - -absl::StatusOr> -ComposedTypeProvider::DeserializeValueImpl(ValueFactory& value_factory, - absl::string_view type_url, - const absl::Cord& value) const { - for (const std::unique_ptr& provider : providers_) { - CEL_ASSIGN_OR_RETURN(auto result, provider->DeserializeValue( - value_factory, type_url, value)); - if (result.has_value()) { - return result; - } - } - return absl::nullopt; -} - -absl::StatusOr> ComposedTypeProvider::FindTypeImpl( - absl::string_view name) const { - if (auto type = types_.find(name); type != types_.end()) { - return type->second; - } - for (const std::unique_ptr& provider : providers_) { - CEL_ASSIGN_OR_RETURN(auto result, provider->FindType(name)); - if (result.has_value()) { - return result; - } - } - return absl::nullopt; -} - -absl::StatusOr> -ComposedTypeProvider::FindStructTypeFieldByNameImpl( - absl::string_view type, absl::string_view name) const { - for (const std::unique_ptr& provider : providers_) { - CEL_ASSIGN_OR_RETURN(auto result, - provider->FindStructTypeFieldByName(type, name)); - if (result.has_value()) { - return result; - } - } - return absl::nullopt; -} - -} // namespace cel::runtime_internal diff --git a/runtime/internal/convert_constant.cc b/runtime/internal/convert_constant.cc index a70531334..b4b1ed4a5 100644 --- a/runtime/internal/convert_constant.cc +++ b/runtime/internal/convert_constant.cc @@ -15,16 +15,15 @@ #include "runtime/internal/convert_constant.h" #include -#include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/time/time.h" #include "absl/types/variant.h" #include "base/ast_internal/expr.h" +#include "common/allocator.h" #include "common/constant.h" #include "common/value.h" -#include "common/value_manager.h" #include "eval/internal/errors.h" namespace cel::runtime_internal { @@ -32,52 +31,51 @@ namespace { using ::cel::ast_internal::Constant; struct ConvertVisitor { - cel::ValueManager& value_factory; + Allocator<> allocator; absl::StatusOr operator()(absl::monostate) { return absl::InvalidArgumentError("unspecified constant"); } absl::StatusOr operator()( const cel::ast_internal::NullValue& value) { - return value_factory.GetNullValue(); - } - absl::StatusOr operator()(bool value) { - return value_factory.CreateBoolValue(value); + return NullValue(); } + absl::StatusOr operator()(bool value) { return BoolValue(value); } absl::StatusOr operator()(int64_t value) { - return value_factory.CreateIntValue(value); + return IntValue(value); } absl::StatusOr operator()(uint64_t value) { - return value_factory.CreateUintValue(value); + return UintValue(value); } absl::StatusOr operator()(double value) { - return value_factory.CreateDoubleValue(value); + return DoubleValue(value); } absl::StatusOr operator()(const cel::StringConstant& value) { - return value_factory.CreateUncheckedStringValue(value); + return StringValue(allocator, value); } absl::StatusOr operator()(const cel::BytesConstant& value) { - return value_factory.CreateBytesValue(value); + return BytesValue(allocator, value); } absl::StatusOr operator()(const absl::Duration duration) { if (duration >= kDurationHigh || duration <= kDurationLow) { - return value_factory.CreateErrorValue(*DurationOverflowError()); + return ErrorValue(*DurationOverflowError()); } - return value_factory.CreateUncheckedDurationValue(duration); + return DurationValue(duration); } absl::StatusOr operator()(const absl::Time timestamp) { - return value_factory.CreateUncheckedTimestampValue(timestamp); + return TimestampValue(timestamp); } }; } // namespace + // Converts an Ast constant into a runtime value, managed according to the // given value factory. // // A status maybe returned if value creation fails. absl::StatusOr ConvertConstant(const Constant& constant, - ValueManager& value_factory) { - return absl::visit(ConvertVisitor{value_factory}, constant.constant_kind()); + Allocator<> allocator) { + return absl::visit(ConvertVisitor{allocator}, constant.constant_kind()); } } // namespace cel::runtime_internal diff --git a/runtime/internal/convert_constant.h b/runtime/internal/convert_constant.h index ae51ba63b..b0d445f6c 100644 --- a/runtime/internal/convert_constant.h +++ b/runtime/internal/convert_constant.h @@ -16,8 +16,8 @@ #include "absl/status/statusor.h" #include "base/ast_internal/expr.h" +#include "common/allocator.h" #include "common/value.h" -#include "common/value_manager.h" namespace cel::runtime_internal { @@ -32,7 +32,7 @@ namespace cel::runtime_internal { // A status may still be returned if value creation fails according to // value_factory's policy. absl::StatusOr ConvertConstant(const ast_internal::Constant& constant, - ValueManager& value_factory); + Allocator<> allocator); } // namespace cel::runtime_internal diff --git a/runtime/internal/legacy_runtime_type_provider.h b/runtime/internal/legacy_runtime_type_provider.h new file mode 100644 index 000000000..f12242f12 --- /dev/null +++ b/runtime/internal/legacy_runtime_type_provider.h @@ -0,0 +1,37 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_LEGACY_RUNTIME_TYPE_PROVIDER_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_LEGACY_RUNTIME_TYPE_PROVIDER_H_ + +#include "absl/base/nullability.h" +#include "eval/public/structs/protobuf_descriptor_type_provider.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel::runtime_internal { + +class LegacyRuntimeTypeProvider final + : public google::api::expr::runtime::ProtobufDescriptorProvider { + public: + LegacyRuntimeTypeProvider( + absl::Nonnull descriptor_pool, + absl::Nullable message_factory) + : google::api::expr::runtime::ProtobufDescriptorProvider( + descriptor_pool, message_factory) {} +}; + +} // namespace cel::runtime_internal + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_LEGACY_RUNTIME_TYPE_PROVIDER_H_ diff --git a/runtime/internal/runtime_env.h b/runtime/internal/runtime_env.h index e0ab566b1..08bc792ee 100644 --- a/runtime/internal/runtime_env.h +++ b/runtime/internal/runtime_env.h @@ -47,6 +47,8 @@ struct RuntimeEnv final { nullptr) : descriptor_pool(std::move(descriptor_pool)), message_factory(std::move(message_factory)), + legacy_type_registry(this->descriptor_pool.get(), + this->message_factory.get()), type_registry(legacy_type_registry.InternalGetModernRegistry()), function_registry(legacy_function_registry.InternalGetRegistry()) { if (this->message_factory != nullptr) { diff --git a/runtime/internal/runtime_env_testing.cc b/runtime/internal/runtime_env_testing.cc index ae7dd0ab9..25b9d1792 100644 --- a/runtime/internal/runtime_env_testing.cc +++ b/runtime/internal/runtime_env_testing.cc @@ -18,14 +18,20 @@ #include "absl/base/nullability.h" #include "absl/log/absl_check.h" +#include "internal/noop_delete.h" #include "internal/testing_descriptor_pool.h" +#include "internal/testing_message_factory.h" #include "runtime/internal/runtime_env.h" +#include "google/protobuf/message.h" namespace cel::runtime_internal { absl::Nonnull> NewTestingRuntimeEnv() { - auto env = - std::make_shared(internal::GetSharedTestingDescriptorPool()); + auto env = std::make_shared( + internal::GetSharedTestingDescriptorPool(), + std::shared_ptr( + internal::GetTestingMessageFactory(), + internal::NoopDeleteFor())); ABSL_CHECK_OK(env->Initialize()); // Crash OK return env; } diff --git a/runtime/internal/runtime_impl.cc b/runtime/internal/runtime_impl.cc index a85112a30..767cd1890 100644 --- a/runtime/internal/runtime_impl.cc +++ b/runtime/internal/runtime_impl.cc @@ -17,6 +17,7 @@ #include #include "absl/base/nullability.h" +#include "absl/log/absl_check.h" #include "absl/status/statusor.h" #include "base/ast.h" #include "base/type_provider.h" @@ -29,11 +30,15 @@ #include "internal/casts.h" #include "internal/status_macros.h" #include "runtime/activation_interface.h" +#include "runtime/internal/runtime_value_manager.h" #include "runtime/runtime.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/message.h" namespace cel::runtime_internal { namespace { +using ::cel::runtime_internal::RuntimeValueManager; using ::google::api::expr::runtime::AttributeTrail; using ::google::api::expr::runtime::ComprehensionSlots; using ::google::api::expr::runtime::DirectExpressionStep; @@ -49,16 +54,20 @@ class ProgramImpl final : public TraceableProgram { FlatExpression impl) : environment_(environment), impl_(std::move(impl)) {} - absl::StatusOr Evaluate(const ActivationInterface& activation, - ValueManager& value_factory) const override { - return Trace(activation, EvaluationListener(), value_factory); - } - - absl::StatusOr Trace(const ActivationInterface& activation, - EvaluationListener callback, - ValueManager& value_factory) const override { - auto state = impl_.MakeEvaluatorState(value_factory); - return impl_.EvaluateWithCallback(activation, std::move(callback), state); + absl::StatusOr Trace( + absl::Nonnull arena, + absl::Nullable message_factory, + const ActivationInterface& activation, + EvaluationListener evaluation_listener) const override { + ABSL_DCHECK(arena != nullptr); + RuntimeValueManager value_manager( + arena, environment_->descriptor_pool.get(), + message_factory != nullptr ? message_factory + : environment_->MutableMessageFactory(), + GetTypeProvider()); + auto state = impl_.MakeEvaluatorState(value_manager); + return impl_.EvaluateWithCallback(activation, + std::move(evaluation_listener), state); } const TypeProvider& GetTypeProvider() const override { @@ -79,17 +88,20 @@ class RecursiveProgramImpl final : public TraceableProgram { FlatExpression impl, absl::Nonnull root) : environment_(environment), impl_(std::move(impl)), root_(root) {} - absl::StatusOr Evaluate(const ActivationInterface& activation, - ValueManager& value_factory) const override { - return Trace(activation, /*callback=*/nullptr, value_factory); - } - - absl::StatusOr Trace(const ActivationInterface& activation, - EvaluationListener callback, - ValueManager& value_factory) const override { + absl::StatusOr Trace( + absl::Nonnull arena, + absl::Nullable message_factory, + const ActivationInterface& activation, + EvaluationListener evaluation_listener) const override { + ABSL_DCHECK(arena != nullptr); + RuntimeValueManager value_manager( + arena, environment_->descriptor_pool.get(), + message_factory != nullptr ? message_factory + : environment_->MutableMessageFactory(), + GetTypeProvider()); ComprehensionSlots slots(impl_.comprehension_slots_size()); - ExecutionFrameBase frame(activation, std::move(callback), impl_.options(), - value_factory, slots); + ExecutionFrameBase frame(activation, std::move(evaluation_listener), + impl_.options(), value_manager, slots); Value result; AttributeTrail attribute; diff --git a/runtime/internal/runtime_impl.h b/runtime/internal/runtime_impl.h index 0c4972fcf..74e297e96 100644 --- a/runtime/internal/runtime_impl.h +++ b/runtime/internal/runtime_impl.h @@ -32,6 +32,8 @@ #include "runtime/runtime.h" #include "runtime/runtime_options.h" #include "runtime/type_registry.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" namespace cel::runtime_internal { @@ -86,6 +88,15 @@ class RuntimeImpl : public Runtime { return environment_->type_registry.GetComposedTypeProvider(); } + absl::Nonnull GetDescriptorPool() + const override { + return environment_->descriptor_pool.get(); + } + + absl::Nonnull GetMessageFactory() const override { + return environment_->MutableMessageFactory(); + } + // exposed for extensions access google::api::expr::runtime::FlatExprBuilder& expr_builder() ABSL_ATTRIBUTE_LIFETIME_BOUND { diff --git a/runtime/internal/runtime_type_provider.cc b/runtime/internal/runtime_type_provider.cc new file mode 100644 index 000000000..a13a14af2 --- /dev/null +++ b/runtime/internal/runtime_type_provider.cc @@ -0,0 +1,161 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/internal/runtime_type_provider.h" + +#include + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/any.h" +#include "common/memory.h" +#include "common/type.h" +#include "common/type_introspector.h" +#include "common/value.h" +#include "common/value_factory.h" +#include "common/values/struct_value_builder.h" +#include "common/values/value_builder.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel::runtime_internal { + +absl::Status RuntimeTypeProvider::RegisterType(const OpaqueType& type) { + auto insertion = types_.insert(std::pair{type.name(), Type(type)}); + if (!insertion.second) { + return absl::AlreadyExistsError( + absl::StrCat("type already registered: ", insertion.first->first)); + } + return absl::OkStatus(); +} + +absl::StatusOr> RuntimeTypeProvider::FindTypeImpl( + absl::string_view name) const { + // We do not have to worry about well known types here. + // `TypeIntrospector::FindType` handles those directly. + const auto* desc = descriptor_pool()->FindMessageTypeByName(name); + if (desc == nullptr) { + if (const auto it = types_.find(name); it != types_.end()) { + return it->second; + } + return absl::nullopt; + } + return MessageType(desc); +} + +absl::StatusOr> +RuntimeTypeProvider::FindEnumConstantImpl(absl::string_view type, + absl::string_view value) const { + const google::protobuf::EnumDescriptor* enum_desc = + descriptor_pool()->FindEnumTypeByName(type); + // google.protobuf.NullValue is special cased in the base class. + if (enum_desc == nullptr) { + return absl::nullopt; + } + + // Note: we don't support strong enum typing at this time so only the fully + // qualified enum values are meaningful, so we don't provide any signal if the + // enum type is found but can't match the value name. + const google::protobuf::EnumValueDescriptor* value_desc = + enum_desc->FindValueByName(value); + if (value_desc == nullptr) { + return absl::nullopt; + } + + return TypeIntrospector::EnumConstant{ + EnumType(enum_desc), enum_desc->full_name(), value_desc->name(), + value_desc->number()}; +} + +absl::StatusOr> +RuntimeTypeProvider::FindStructTypeFieldByNameImpl( + absl::string_view type, absl::string_view name) const { + // We do not have to worry about well known types here. + // `TypeIntrospector::FindStructTypeFieldByName` handles those directly. + const auto* desc = descriptor_pool()->FindMessageTypeByName(type); + if (desc == nullptr) { + return absl::nullopt; + } + const auto* field_desc = desc->FindFieldByName(name); + if (field_desc == nullptr) { + field_desc = descriptor_pool()->FindExtensionByPrintableName(desc, name); + if (field_desc == nullptr) { + return absl::nullopt; + } + } + return MessageTypeField(field_desc); +} + +absl::StatusOr> +RuntimeTypeProvider::NewStructValueBuilder(ValueFactory& value_factory, + const StructType& type) const { + auto* message_factory = value_factory.message_factory(); + if (message_factory == nullptr) { + return nullptr; + } + return common_internal::NewStructValueBuilder( + value_factory.GetMemoryManager().arena(), descriptor_pool(), + message_factory, type.name()); +} + +absl::StatusOr> +RuntimeTypeProvider::NewValueBuilder(ValueFactory& value_factory, + absl::string_view name) const { + auto* message_factory = value_factory.message_factory(); + if (message_factory == nullptr) { + return nullptr; + } + return common_internal::NewValueBuilder(value_factory.GetMemoryManager(), + descriptor_pool(), message_factory, + name); +} + +absl::StatusOr> RuntimeTypeProvider::DeserializeValue( + ValueFactory& value_factory, absl::string_view type_url, + const absl::Cord& value) const { + const auto* descriptor_pool = this->descriptor_pool(); + auto* message_factory = value_factory.message_factory(); + if (message_factory == nullptr) { + return absl::nullopt; + } + absl::string_view type_name; + if (!ParseTypeUrl(type_url, &type_name)) { + return absl::InvalidArgumentError("invalid type URL"); + } + const auto* descriptor = descriptor_pool->FindMessageTypeByName(type_name); + if (descriptor == nullptr) { + return absl::nullopt; + } + const auto* prototype = message_factory->GetPrototype(descriptor); + if (prototype == nullptr) { + return absl::nullopt; + } + absl::Nullable arena = + value_factory.GetMemoryManager().arena(); + auto message = WrapShared(prototype->New(arena), arena); + if (!message->ParsePartialFromCord(value)) { + return absl::InvalidArgumentError( + absl::StrCat("failed to parse `", type_url, "`")); + } + return Value::Message(WrapShared(prototype->New(arena), arena), + descriptor_pool, message_factory); +} + +} // namespace cel::runtime_internal diff --git a/runtime/internal/composed_type_provider.h b/runtime/internal/runtime_type_provider.h similarity index 53% rename from runtime/internal/composed_type_provider.h rename to runtime/internal/runtime_type_provider.h index 8ec9ecda2..ed53be6c2 100644 --- a/runtime/internal/composed_type_provider.h +++ b/runtime/internal/runtime_type_provider.h @@ -1,4 +1,4 @@ -// Copyright 2023 Google LLC +// Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -11,12 +11,9 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_COMPOSED_TYPE_PROVIDER_H_ -#define THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_COMPOSED_TYPE_PROVIDER_H_ -#include -#include -#include +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_TYPE_PROVIDER_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_TYPE_PROVIDER_H_ #include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" @@ -25,59 +22,54 @@ #include "absl/strings/cord.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" -#include "base/type_provider.h" #include "common/type.h" #include "common/type_reflector.h" #include "common/value.h" #include "common/value_factory.h" +#include "google/protobuf/descriptor.h" namespace cel::runtime_internal { -// Type provider implementation managed by the runtime type registry. -// -// Maintains ownership of client provided type provider implementations and -// delegates type resolution to them in order. To meet the requirements for use -// with TypeManager, this should not be updated after any call to ProvideType. -// -// The builtin type provider is implicitly consulted first in a type manager, -// so it is not represented here. -class ComposedTypeProvider : public TypeReflector { +class RuntimeTypeProvider final : public TypeReflector { public: - // Register an additional type provider. - void AddTypeProvider(std::unique_ptr provider) { - providers_.push_back(std::move(provider)); - } + explicit RuntimeTypeProvider( + absl::Nonnull descriptor_pool) + : descriptor_pool_(descriptor_pool) {} absl::Status RegisterType(const OpaqueType& type); - void set_use_legacy_container_builders(bool use_legacy_container_builders) { - use_legacy_container_builders_ = use_legacy_container_builders; - } - absl::StatusOr> NewStructValueBuilder( ValueFactory& value_factory, const StructType& type) const override; - absl::StatusOr FindValue(ValueFactory& value_factory, - absl::string_view name, - Value& result) const override; + absl::StatusOr> NewValueBuilder( + ValueFactory& value_factory, absl::string_view name) const override; - protected: - absl::StatusOr> DeserializeValueImpl( + // `DeserializeValue` deserializes the bytes of `value` according to + // `type_url`. Returns `NOT_FOUND` if `type_url` is unrecognized. + absl::StatusOr> DeserializeValue( ValueFactory& value_factory, absl::string_view type_url, const absl::Cord& value) const override; + absl::Nonnull descriptor_pool() + const override { + return descriptor_pool_; + } + + protected: absl::StatusOr> FindTypeImpl( absl::string_view name) const override; + absl::StatusOr> FindEnumConstantImpl( + absl::string_view type, absl::string_view value) const override; + absl::StatusOr> FindStructTypeFieldByNameImpl( absl::string_view type, absl::string_view name) const override; private: + absl::Nonnull descriptor_pool_; absl::flat_hash_map types_; - std::vector> providers_; - bool use_legacy_container_builders_ = true; }; } // namespace cel::runtime_internal -#endif // THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_COMPOSED_TYPE_PROVIDER_H_ +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_TYPE_PROVIDER_H_ diff --git a/runtime/internal/runtime_value_manager.h b/runtime/internal/runtime_value_manager.h new file mode 100644 index 000000000..ddc294b1f --- /dev/null +++ b/runtime/internal/runtime_value_manager.h @@ -0,0 +1,75 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_VALUE_MANAGER_H_ +#define THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_VALUE_MANAGER_H_ + +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "common/memory.h" +#include "common/type_introspector.h" +#include "common/type_reflector.h" +#include "common/value_manager.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel::runtime_internal { + +class RuntimeValueManager final : public ValueManager { + public: + RuntimeValueManager( + absl::Nonnull arena, + absl::Nonnull descriptor_pool, + absl::Nonnull message_factory, + const TypeReflector& type_reflector) + : arena_(arena), + descriptor_pool_(descriptor_pool), + message_factory_(message_factory), + type_reflector_(type_reflector) { + ABSL_DCHECK_EQ(descriptor_pool_, type_reflector_.descriptor_pool()); + } + + MemoryManagerRef GetMemoryManager() const override { + return MemoryManagerRef::Pooling(arena_); + } + + absl::Nonnull descriptor_pool() + const override { + return descriptor_pool_; + } + + absl::Nonnull message_factory() const override { + return message_factory_; + } + + protected: + const TypeIntrospector& GetTypeIntrospector() const override { + return type_reflector_; + } + + const TypeReflector& GetTypeReflector() const override { + return type_reflector_; + } + + private: + absl::Nonnull const arena_; + absl::Nonnull const descriptor_pool_; + absl::Nonnull const message_factory_; + const TypeReflector& type_reflector_; +}; + +} // namespace cel::runtime_internal + +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_VALUE_MANAGER_H_ diff --git a/runtime/optional_types_test.cc b/runtime/optional_types_test.cc index a2381c9e8..803a8d7a3 100644 --- a/runtime/optional_types_test.cc +++ b/runtime/optional_types_test.cc @@ -16,23 +16,21 @@ #include #include -#include #include +#include #include #include #include "cel/expr/syntax.pb.h" #include "absl/status/status.h" +#include "absl/status/status_matchers.h" #include "absl/status/statusor.h" #include "absl/types/span.h" #include "base/function.h" #include "base/function_descriptor.h" #include "common/kind.h" -#include "common/memory.h" #include "common/value.h" #include "common/value_testing.h" -#include "common/values/legacy_value_manager.h" -#include "extensions/protobuf/memory_manager.h" #include "extensions/protobuf/runtime_adapter.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" @@ -53,7 +51,6 @@ namespace { using ::absl_testing::IsOk; using ::absl_testing::StatusIs; using ::cel::extensions::ProtobufRuntimeAdapter; -using ::cel::extensions::ProtoMemoryManagerRef; using ::cel::test::BoolValueIs; using ::cel::test::IntValueIs; using ::cel::test::OptionalValueIs; @@ -63,6 +60,7 @@ using ::google::api::expr::parser::Parse; using ::google::api::expr::parser::ParserOptions; using ::testing::ElementsAre; using ::testing::HasSubstr; +using ::testing::TestWithParam; MATCHER_P(MatchesOptionalReceiver1, name, "") { const FunctionDescriptor& descriptor = arg.descriptor; @@ -170,24 +168,23 @@ struct EvaluateResultTestCase { std::string name; std::string expression; test::ValueMatcher value_matcher; + + template + friend void AbslStringify(S& sink, const EvaluateResultTestCase& tc) { + sink.Append(tc.name); + } }; class OptionalTypesTest - : public common_internal::ThreadCompatibleValueTest { + : public TestWithParam> { public: const EvaluateResultTestCase& GetTestCase() { - return std::get<1>(GetParam()); + return std::get<0>(GetParam()); } - bool EnableShortCircuiting() { return std::get<2>(GetParam()); } + bool EnableShortCircuiting() { return std::get<1>(GetParam()); } }; -std::ostream& operator<<(std::ostream& os, - const EvaluateResultTestCase& test_case) { - return os << test_case.name; -} - TEST_P(OptionalTypesTest, RecursivePlan) { RuntimeOptions opts; opts.use_legacy_container_builders = false; @@ -216,13 +213,10 @@ TEST_P(OptionalTypesTest, RecursivePlan) { EXPECT_TRUE(runtime_internal::TestOnly_IsRecursiveImpl(program.get())); - cel::common_internal::LegacyValueManager value_factory( - memory_manager(), runtime->GetTypeProvider()); - + google::protobuf::Arena arena; Activation activation; - ASSERT_OK_AND_ASSIGN(Value result, - program->Evaluate(activation, value_factory)); + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); EXPECT_THAT(result, test_case.value_matcher) << test_case.expression; } @@ -251,13 +245,10 @@ TEST_P(OptionalTypesTest, Defaults) { ASSERT_OK_AND_ASSIGN(std::unique_ptr program, ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); - common_internal::LegacyValueManager value_factory(this->memory_manager(), - runtime->GetTypeProvider()); - + google::protobuf::Arena arena; Activation activation; - ASSERT_OK_AND_ASSIGN(Value result, - program->Evaluate(activation, value_factory)); + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); EXPECT_THAT(result, test_case.value_matcher) << test_case.expression; } @@ -265,8 +256,6 @@ TEST_P(OptionalTypesTest, Defaults) { INSTANTIATE_TEST_SUITE_P( Basic, OptionalTypesTest, testing::Combine( - testing::Values(MemoryManagement::kPooling, - MemoryManagement::kReferenceCounting), testing::ValuesIn(std::vector{ {"optional_none_hasValue", "optional.none().hasValue()", BoolValueIs(false)}, @@ -285,8 +274,7 @@ INSTANTIATE_TEST_SUITE_P( IntValueIs(1)}, {"list_of_optional", "[optional.of(1)][0].orValue(1)", IntValueIs(1)}}), - /*enable_short_circuiting*/ testing::Bool()), - OptionalTypesTest::ToString); + /*enable_short_circuiting*/ testing::Bool())); class UnreachableFunction final : public cel::Function { public: @@ -305,7 +293,6 @@ class UnreachableFunction final : public cel::Function { TEST(OptionalTypesTest, ErrorShortCircuiting) { RuntimeOptions opts{.enable_qualified_type_identifiers = true}; google::protobuf::Arena arena; - auto memory_manager = ProtoMemoryManagerRef(&arena); ASSERT_OK_AND_ASSIGN( auto builder, @@ -330,13 +317,9 @@ TEST(OptionalTypesTest, ErrorShortCircuiting) { ASSERT_OK_AND_ASSIGN(std::unique_ptr program, ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); - common_internal::LegacyValueManager value_factory(memory_manager, - runtime->GetTypeProvider()); - Activation activation; - ASSERT_OK_AND_ASSIGN(Value result, - program->Evaluate(activation, value_factory)); + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); EXPECT_EQ(unreachable_count, 0); ASSERT_TRUE(result->Is()) << result->DebugString(); diff --git a/runtime/reference_resolver_test.cc b/runtime/reference_resolver_test.cc index 2f6a7f483..2c8b27af1 100644 --- a/runtime/reference_resolver_test.cc +++ b/runtime/reference_resolver_test.cc @@ -27,11 +27,11 @@ #include "internal/testing_descriptor_pool.h" #include "parser/parser.h" #include "runtime/activation.h" -#include "runtime/managed_value_factory.h" #include "runtime/register_function_helper.h" #include "runtime/runtime_builder.h" #include "runtime/runtime_options.h" #include "runtime/standard_runtime_builder_factory.h" +#include "google/protobuf/arena.h" #include "google/protobuf/text_format.h" namespace cel { @@ -79,12 +79,10 @@ TEST(ReferenceResolver, ResolveQualifiedFunctions) { ASSERT_OK_AND_ASSIGN(auto program, ProtobufRuntimeAdapter::CreateProgram( *runtime, parsed_expr)); - ManagedValueFactory value_factory(program->GetTypeProvider(), - MemoryManagerRef::ReferenceCounting()); + google::protobuf::Arena arena; Activation activation; - ASSERT_OK_AND_ASSIGN(Value value, - program->Evaluate(activation, value_factory.get())); + ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); ASSERT_TRUE(value->Is()); EXPECT_TRUE(value.GetBool().NativeValue()); } @@ -207,17 +205,13 @@ TEST(ReferenceResolver, ResolveQualifiedIdentifiers) { ASSERT_OK_AND_ASSIGN(auto program, ProtobufRuntimeAdapter::CreateProgram( *runtime, checked_expr)); - ManagedValueFactory value_factory(program->GetTypeProvider(), - MemoryManagerRef::ReferenceCounting()); + google::protobuf::Arena arena; Activation activation; - activation.InsertOrAssignValue("com.example.x", - value_factory.get().CreateIntValue(3)); - activation.InsertOrAssignValue("com.example.y", - value_factory.get().CreateIntValue(4)); + activation.InsertOrAssignValue("com.example.x", IntValue(3)); + activation.InsertOrAssignValue("com.example.y", IntValue(4)); - ASSERT_OK_AND_ASSIGN(Value value, - program->Evaluate(activation, value_factory.get())); + ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); ASSERT_TRUE(value->Is()); EXPECT_EQ(value.GetInt().NativeValue(), 7); @@ -243,17 +237,13 @@ TEST(ReferenceResolver, ResolveQualifiedIdentifiersSkipParseOnly) { ASSERT_OK_AND_ASSIGN(auto program, ProtobufRuntimeAdapter::CreateProgram( *runtime, checked_expr.expr())); - ManagedValueFactory value_factory(program->GetTypeProvider(), - MemoryManagerRef::ReferenceCounting()); + google::protobuf::Arena arena; Activation activation; - activation.InsertOrAssignValue("com.example.x", - value_factory.get().CreateIntValue(3)); - activation.InsertOrAssignValue("com.example.y", - value_factory.get().CreateIntValue(4)); + activation.InsertOrAssignValue("com.example.x", IntValue(3)); + activation.InsertOrAssignValue("com.example.y", IntValue(4)); - ASSERT_OK_AND_ASSIGN(Value value, - program->Evaluate(activation, value_factory.get())); + ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); ASSERT_TRUE(value->Is()); EXPECT_THAT(value.GetError().NativeValue(), @@ -333,12 +323,10 @@ TEST(ReferenceResolver, ResolveEnumConstants) { ASSERT_OK_AND_ASSIGN(auto program, ProtobufRuntimeAdapter::CreateProgram( *runtime, checked_expr)); - ManagedValueFactory value_factory(program->GetTypeProvider(), - MemoryManagerRef::ReferenceCounting()); + google::protobuf::Arena arena; Activation activation; - ASSERT_OK_AND_ASSIGN(Value value, - program->Evaluate(activation, value_factory.get())); + ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); ASSERT_TRUE(value->Is()); EXPECT_TRUE(value.GetBool().NativeValue()); @@ -362,12 +350,10 @@ TEST(ReferenceResolver, ResolveEnumConstantsSkipParseOnly) { ASSERT_OK_AND_ASSIGN(auto program, ProtobufRuntimeAdapter::CreateProgram( *runtime, unchecked_expr)); - ManagedValueFactory value_factory(program->GetTypeProvider(), - MemoryManagerRef::ReferenceCounting()); + google::protobuf::Arena arena; Activation activation; - ASSERT_OK_AND_ASSIGN(Value value, - program->Evaluate(activation, value_factory.get())); + ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); ASSERT_TRUE(value->Is()); EXPECT_THAT( diff --git a/runtime/regex_precompilation_test.cc b/runtime/regex_precompilation_test.cc index 5cbdb291c..cbf5682ef 100644 --- a/runtime/regex_precompilation_test.cc +++ b/runtime/regex_precompilation_test.cc @@ -31,11 +31,11 @@ #include "parser/parser.h" #include "runtime/activation.h" #include "runtime/constant_folding.h" -#include "runtime/managed_value_factory.h" #include "runtime/register_function_helper.h" #include "runtime/runtime_builder.h" #include "runtime/runtime_options.h" #include "runtime/standard_runtime_builder_factory.h" +#include "google/protobuf/arena.h" namespace cel::extensions { namespace { @@ -110,15 +110,12 @@ TEST_P(RegexPrecompilationTest, Basic) { ASSERT_OK_AND_ASSIGN(auto program, std::move(program_or)); - ManagedValueFactory value_factory(program->GetTypeProvider(), - MemoryManagerRef::ReferenceCounting()); + google::protobuf::Arena arena; Activation activation; - ASSERT_OK_AND_ASSIGN(auto var, - value_factory.get().CreateStringValue("string_var")); - activation.InsertOrAssignValue("string_var", var); + activation.InsertOrAssignValue("string_var", + StringValue(&arena, "string_var")); - ASSERT_OK_AND_ASSIGN(Value value, - program->Evaluate(activation, value_factory.get())); + ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); EXPECT_THAT(value, test_case.result_matcher); } @@ -157,15 +154,12 @@ TEST_P(RegexPrecompilationTest, WithConstantFolding) { } ASSERT_OK_AND_ASSIGN(auto program, std::move(program_or)); - ManagedValueFactory value_factory(program->GetTypeProvider(), - MemoryManagerRef::ReferenceCounting()); + google::protobuf::Arena arena; Activation activation; - ASSERT_OK_AND_ASSIGN(auto var, - value_factory.get().CreateStringValue("string_var")); - activation.InsertOrAssignValue("string_var", var); + activation.InsertOrAssignValue("string_var", + StringValue(&arena, "string_var")); - ASSERT_OK_AND_ASSIGN(Value value, - program->Evaluate(activation, value_factory.get())); + ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); EXPECT_THAT(value, test_case.result_matcher); } diff --git a/runtime/runtime.h b/runtime/runtime.h index 5b1e654aa..8ec1a78db 100644 --- a/runtime/runtime.h +++ b/runtime/runtime.h @@ -22,6 +22,8 @@ #include #include +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" #include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -32,6 +34,9 @@ #include "common/value_manager.h" #include "runtime/activation_interface.h" #include "runtime/runtime_issue.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" namespace cel { @@ -58,22 +63,23 @@ class Program { // Activation manages instances of variables available in the cel expression's // environment. // - // The memory manager determines the lifecycle requirements of the returned - // value. The most common choices are: - // - cel::MemoryManagerRef::ReferenceCounting(): created values are allocated - // on the heap - // and managed by a reference count. Destructor is called when reference - // count is 0. - // - cel::extensions::ProtoMemoryManager instance: created values are - // allocated on the backing protobuf Arena. Destructors for allocated - // objects are called on destruction of the Arena. Note: instances may - // still allocate additional memory on the heap e.g. a vector's storage - // may still be on the global heap. + // The arena will be used to as necessary to allocate values and must outlive + // the returned value, as must this program. // - // For consistency, users should use the same memory manager to create values + // For consistency, users should use the same arena to create values // in the activation and for Program evaluation. - virtual absl::StatusOr Evaluate(const ActivationInterface& activation, - ValueManager& value_factory) const = 0; + virtual absl::StatusOr Evaluate( + absl::Nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND, + absl::Nullable message_factory + ABSL_ATTRIBUTE_LIFETIME_BOUND, + const ActivationInterface& activation) const + ABSL_ATTRIBUTE_LIFETIME_BOUND = 0; + virtual absl::StatusOr Evaluate( + absl::Nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND, + const ActivationInterface& activation) const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return Evaluate(arena, /*message_factory=*/nullptr, activation); + } virtual const TypeProvider& GetTypeProvider() const = 0; }; @@ -96,6 +102,16 @@ class TraceableProgram : public Program { using EvaluationListener = absl::AnyInvocable; + using Program::Evaluate; + absl::StatusOr Evaluate( + absl::Nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND, + absl::Nullable message_factory + ABSL_ATTRIBUTE_LIFETIME_BOUND, + const ActivationInterface& activation) const + ABSL_ATTRIBUTE_LIFETIME_BOUND override { + return Trace(arena, message_factory, activation, EvaluationListener()); + } + // Evaluate the Program plan with a Listener. // // The given callback will be invoked after evaluating any program step @@ -103,9 +119,21 @@ class TraceableProgram : public Program { // // If the callback returns a non-ok status, evaluation stops and the Status // is forwarded as the result of the EvaluateWithCallback call. - virtual absl::StatusOr Trace(const ActivationInterface&, - EvaluationListener evaluation_listener, - ValueManager& value_factory) const = 0; + virtual absl::StatusOr Trace( + absl::Nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND, + absl::Nullable message_factory + ABSL_ATTRIBUTE_LIFETIME_BOUND, + const ActivationInterface& activation, + EvaluationListener evaluation_listener) const + ABSL_ATTRIBUTE_LIFETIME_BOUND = 0; + virtual absl::StatusOr Trace( + absl::Nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND, + const ActivationInterface& activation, + EvaluationListener evaluation_listener) const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return Trace(arena, /*message_factory=*/nullptr, activation, + std::move(evaluation_listener)); + }; }; // Interface for a CEL runtime. @@ -144,6 +172,11 @@ class Runtime { virtual const TypeProvider& GetTypeProvider() const = 0; + virtual absl::Nonnull GetDescriptorPool() + const = 0; + + virtual absl::Nonnull GetMessageFactory() const = 0; + private: friend class runtime_internal::RuntimeFriendAccess; diff --git a/runtime/standard_runtime_builder_factory_test.cc b/runtime/standard_runtime_builder_factory_test.cc index 00a9899e3..c9ce72bb9 100644 --- a/runtime/standard_runtime_builder_factory_test.cc +++ b/runtime/standard_runtime_builder_factory_test.cc @@ -16,7 +16,6 @@ #include #include -#include #include #include #include @@ -58,6 +57,7 @@ using ::cel::test::BoolValueIs; using ::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; using ::testing::ElementsAre; +using ::testing::TestWithParam; using ::testing::Truly; struct EvaluateResultTestCase { @@ -65,12 +65,12 @@ struct EvaluateResultTestCase { std::string expression; bool expected_result; std::function activation_builder; -}; -std::ostream& operator<<(std::ostream& os, - const EvaluateResultTestCase& test_case) { - return os << test_case.name; -} + template + friend void AbslStringify(S& sink, const EvaluateResultTestCase& tc) { + sink.Append(tc.name); + } +}; const cel::MacroRegistry& GetMacros() { static absl::NoDestructor macros([]() { @@ -90,12 +90,9 @@ absl::StatusOr ParseWithTestMacros(absl::string_view expression) { return Parse(**src, GetMacros()); } -class StandardRuntimeTest : public common_internal::ThreadCompatibleValueTest< - EvaluateResultTestCase> { +class StandardRuntimeTest : public TestWithParam { public: - const EvaluateResultTestCase& GetTestCase() { - return std::get<1>(GetParam()); - } + const EvaluateResultTestCase& GetTestCase() { return GetParam(); } }; TEST_P(StandardRuntimeTest, Defaults) { @@ -116,16 +113,15 @@ TEST_P(StandardRuntimeTest, Defaults) { EXPECT_FALSE(runtime_internal::TestOnly_IsRecursiveImpl(program.get())); - common_internal::LegacyValueManager value_factory(memory_manager(), - runtime->GetTypeProvider()); - + google::protobuf::Arena arena; Activation activation; if (test_case.activation_builder != nullptr) { + common_internal::LegacyValueManager value_factory( + MemoryManager::Pooling(&arena), runtime->GetTypeProvider()); ASSERT_OK(test_case.activation_builder(value_factory, activation)); } - ASSERT_OK_AND_ASSIGN(Value result, - program->Evaluate(activation, value_factory)); + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); EXPECT_THAT(result, BoolValueIs(test_case.expected_result)) << test_case.expression; @@ -153,363 +149,310 @@ TEST_P(StandardRuntimeTest, Recursive) { // allocating a value stack). EXPECT_TRUE(runtime_internal::TestOnly_IsRecursiveImpl(program.get())); - common_internal::LegacyValueManager value_factory(memory_manager(), - runtime->GetTypeProvider()); - + google::protobuf::Arena arena; Activation activation; if (test_case.activation_builder != nullptr) { + common_internal::LegacyValueManager value_factory( + MemoryManager::Pooling(&arena), runtime->GetTypeProvider()); ASSERT_OK(test_case.activation_builder(value_factory, activation)); } - ASSERT_OK_AND_ASSIGN(Value result, - program->Evaluate(activation, value_factory)); + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); EXPECT_THAT(result, BoolValueIs(test_case.expected_result)) << test_case.expression; } INSTANTIATE_TEST_SUITE_P( Basic, StandardRuntimeTest, - testing::Combine( - testing::Values(MemoryManagement::kPooling, - MemoryManagement::kReferenceCounting), - testing::ValuesIn(std::vector{ - {"int_identifier", "int_var == 42", true, - [](ValueManager& value_factory, Activation& activation) { - activation.InsertOrAssignValue("int_var", - value_factory.CreateIntValue(42)); - return absl::OkStatus(); - }}, - {"logic_and_true", "true && 1 < 2", true}, - {"logic_and_false", "true && 1 > 2", false}, - {"logic_or_true", "false || 1 < 2", true}, - {"logic_or_false", "false && 1 > 2", false}, - {"ternary_true_cond", "(1 < 2 ? 'yes' : 'no') == 'yes'", true}, - {"ternary_false_cond", "(1 > 2 ? 'yes' : 'no') == 'no'", true}, - {"list_index", "['a', 'b', 'c', 'd'][1] == 'b'", true}, - {"map_index_bool", "{true: 1, false: 2}[false] == 2", true}, - {"map_index_string", "{'abc': 123}['abc'] == 123", true}, - {"map_index_int", "{1: 2, 2: 4}[2] == 4", true}, - {"map_index_uint", "{1u: 1, 2u: 2}[1u] == 1", true}, - {"map_index_coerced_double", "{1: 2, 2: 4}[2.0] == 4", true}, - })), - StandardRuntimeTest::ToString); + testing::ValuesIn(std::vector{ + {"int_identifier", "int_var == 42", true, + [](ValueManager& value_factory, Activation& activation) { + activation.InsertOrAssignValue("int_var", + value_factory.CreateIntValue(42)); + return absl::OkStatus(); + }}, + {"logic_and_true", "true && 1 < 2", true}, + {"logic_and_false", "true && 1 > 2", false}, + {"logic_or_true", "false || 1 < 2", true}, + {"logic_or_false", "false && 1 > 2", false}, + {"ternary_true_cond", "(1 < 2 ? 'yes' : 'no') == 'yes'", true}, + {"ternary_false_cond", "(1 > 2 ? 'yes' : 'no') == 'no'", true}, + {"list_index", "['a', 'b', 'c', 'd'][1] == 'b'", true}, + {"map_index_bool", "{true: 1, false: 2}[false] == 2", true}, + {"map_index_string", "{'abc': 123}['abc'] == 123", true}, + {"map_index_int", "{1: 2, 2: 4}[2] == 4", true}, + {"map_index_uint", "{1u: 1, 2u: 2}[1u] == 1", true}, + {"map_index_coerced_double", "{1: 2, 2: 4}[2.0] == 4", true}, + })); INSTANTIATE_TEST_SUITE_P( Equality, StandardRuntimeTest, - testing::Combine( - testing::Values(MemoryManagement::kPooling, - MemoryManagement::kReferenceCounting), - testing::ValuesIn(std::vector{ - {"eq_bool_bool_true", "false == false", true}, - {"eq_bool_bool_false", "false == true", false}, - {"eq_int_int_true", "-1 == -1", true}, - {"eq_int_int_false", "-1 == 1", false}, - {"eq_uint_uint_true", "2u == 2u", true}, - {"eq_uint_uint_false", "2u == 3u", false}, - {"eq_double_double_true", "2.4 == 2.4", true}, - {"eq_double_double_false", "2.4 == 3.3", false}, - {"eq_string_string_true", "'abc' == 'abc'", true}, - {"eq_string_string_false", "'abc' == 'def'", false}, - {"eq_bytes_bytes_true", "b'abc' == b'abc'", true}, - {"eq_bytes_bytes_false", "b'abc' == b'def'", false}, - {"eq_duration_duration_true", "duration('15m') == duration('15m')", - true}, - {"eq_duration_duration_false", "duration('15m') == duration('1h')", - false}, - {"eq_timestamp_timestamp_true", - "timestamp('1970-01-01T00:02:00Z') == " - "timestamp('1970-01-01T00:02:00Z')", - true}, - {"eq_timestamp_timestamp_false", - "timestamp('1970-01-01T00:02:00Z') == " - "timestamp('2020-01-01T00:02:00Z')", - false}, - {"eq_null_null_true", "null == null", true}, - {"eq_list_list_true", "[1, 2, 3] == [1, 2, 3]", true}, - {"eq_list_list_false", "[1, 2, 3] == [1, 2, 3, 4]", false}, - {"eq_map_map_true", "{1: 2, 2: 4} == {1: 2, 2: 4}", true}, - {"eq_map_map_false", "{1: 2, 2: 4} == {1: 2, 2: 5}", false}, - - {"neq_bool_bool_true", "false != false", false}, - {"neq_bool_bool_false", "false != true", true}, - {"neq_int_int_true", "-1 != -1", false}, - {"neq_int_int_false", "-1 != 1", true}, - {"neq_uint_uint_true", "2u != 2u", false}, - {"neq_uint_uint_false", "2u != 3u", true}, - {"neq_double_double_true", "2.4 != 2.4", false}, - {"neq_double_double_false", "2.4 != 3.3", true}, - {"neq_string_string_true", "'abc' != 'abc'", false}, - {"neq_string_string_false", "'abc' != 'def'", true}, - {"neq_bytes_bytes_true", "b'abc' != b'abc'", false}, - {"neq_bytes_bytes_false", "b'abc' != b'def'", true}, - {"neq_duration_duration_true", "duration('15m') != duration('15m')", - false}, - {"neq_duration_duration_false", "duration('15m') != duration('1h')", - true}, - {"neq_timestamp_timestamp_true", - "timestamp('1970-01-01T00:02:00Z') != " - "timestamp('1970-01-01T00:02:00Z')", - false}, - {"neq_timestamp_timestamp_false", - "timestamp('1970-01-01T00:02:00Z') != " - "timestamp('2020-01-01T00:02:00Z')", - true}, - {"neq_null_null_true", "null != null", false}, - {"neq_list_list_true", "[1, 2, 3] != [1, 2, 3]", false}, - {"neq_list_list_false", "[1, 2, 3] != [1, 2, 3, 4]", true}, - {"neq_map_map_true", "{1: 2, 2: 4} != {1: 2, 2: 4}", false}, - {"neq_map_map_false", "{1: 2, 2: 4} != {1: 2, 2: 5}", true}})), - StandardRuntimeTest::ToString); + testing::ValuesIn(std::vector{ + {"eq_bool_bool_true", "false == false", true}, + {"eq_bool_bool_false", "false == true", false}, + {"eq_int_int_true", "-1 == -1", true}, + {"eq_int_int_false", "-1 == 1", false}, + {"eq_uint_uint_true", "2u == 2u", true}, + {"eq_uint_uint_false", "2u == 3u", false}, + {"eq_double_double_true", "2.4 == 2.4", true}, + {"eq_double_double_false", "2.4 == 3.3", false}, + {"eq_string_string_true", "'abc' == 'abc'", true}, + {"eq_string_string_false", "'abc' == 'def'", false}, + {"eq_bytes_bytes_true", "b'abc' == b'abc'", true}, + {"eq_bytes_bytes_false", "b'abc' == b'def'", false}, + {"eq_duration_duration_true", "duration('15m') == duration('15m')", + true}, + {"eq_duration_duration_false", "duration('15m') == duration('1h')", + false}, + {"eq_timestamp_timestamp_true", + "timestamp('1970-01-01T00:02:00Z') == " + "timestamp('1970-01-01T00:02:00Z')", + true}, + {"eq_timestamp_timestamp_false", + "timestamp('1970-01-01T00:02:00Z') == " + "timestamp('2020-01-01T00:02:00Z')", + false}, + {"eq_null_null_true", "null == null", true}, + {"eq_list_list_true", "[1, 2, 3] == [1, 2, 3]", true}, + {"eq_list_list_false", "[1, 2, 3] == [1, 2, 3, 4]", false}, + {"eq_map_map_true", "{1: 2, 2: 4} == {1: 2, 2: 4}", true}, + {"eq_map_map_false", "{1: 2, 2: 4} == {1: 2, 2: 5}", false}, + + {"neq_bool_bool_true", "false != false", false}, + {"neq_bool_bool_false", "false != true", true}, + {"neq_int_int_true", "-1 != -1", false}, + {"neq_int_int_false", "-1 != 1", true}, + {"neq_uint_uint_true", "2u != 2u", false}, + {"neq_uint_uint_false", "2u != 3u", true}, + {"neq_double_double_true", "2.4 != 2.4", false}, + {"neq_double_double_false", "2.4 != 3.3", true}, + {"neq_string_string_true", "'abc' != 'abc'", false}, + {"neq_string_string_false", "'abc' != 'def'", true}, + {"neq_bytes_bytes_true", "b'abc' != b'abc'", false}, + {"neq_bytes_bytes_false", "b'abc' != b'def'", true}, + {"neq_duration_duration_true", "duration('15m') != duration('15m')", + false}, + {"neq_duration_duration_false", "duration('15m') != duration('1h')", + true}, + {"neq_timestamp_timestamp_true", + "timestamp('1970-01-01T00:02:00Z') != " + "timestamp('1970-01-01T00:02:00Z')", + false}, + {"neq_timestamp_timestamp_false", + "timestamp('1970-01-01T00:02:00Z') != " + "timestamp('2020-01-01T00:02:00Z')", + true}, + {"neq_null_null_true", "null != null", false}, + {"neq_list_list_true", "[1, 2, 3] != [1, 2, 3]", false}, + {"neq_list_list_false", "[1, 2, 3] != [1, 2, 3, 4]", true}, + {"neq_map_map_true", "{1: 2, 2: 4} != {1: 2, 2: 4}", false}, + {"neq_map_map_false", "{1: 2, 2: 4} != {1: 2, 2: 5}", true}})); INSTANTIATE_TEST_SUITE_P( ArithmeticFunctions, StandardRuntimeTest, - testing::Combine( - testing::Values(MemoryManagement::kPooling, - MemoryManagement::kReferenceCounting), - testing::ValuesIn(std::vector{ - {"lt_int_int_true", "-1 < 2", true}, - {"lt_int_int_false", "2 < -1", false}, - {"lt_double_double_true", "-1.1 < 2.2", true}, - {"lt_double_double_false", "2.2 < -1.1", false}, - {"lt_uint_uint_true", "1u < 2u", true}, - {"lt_uint_uint_false", "2u < 1u", false}, - {"lt_string_string_true", "'abc' < 'def'", true}, - {"lt_string_string_false", "'def' < 'abc'", false}, - {"lt_duration_duration_true", "duration('1s') < duration('2s')", - true}, - {"lt_duration_duration_false", "duration('2s') < duration('1s')", - false}, - {"lt_timestamp_timestamp_true", "timestamp(1) < timestamp(2)", - true}, - {"lt_timestamp_timestamp_false", "timestamp(2) < timestamp(1)", - false}, - - {"gt_int_int_false", "-1 > 2", false}, - {"gt_int_int_true", "2 > -1", true}, - {"gt_double_double_false", "-1.1 > 2.2", false}, - {"gt_double_double_true", "2.2 > -1.1", true}, - {"gt_uint_uint_false", "1u > 2u", false}, - {"gt_uint_uint_true", "2u > 1u", true}, - {"gt_string_string_false", "'abc' > 'def'", false}, - {"gt_string_string_true", "'def' > 'abc'", true}, - {"gt_duration_duration_false", "duration('1s') > duration('2s')", - false}, - {"gt_duration_duration_true", "duration('2s') > duration('1s')", - true}, - {"gt_timestamp_timestamp_false", "timestamp(1) > timestamp(2)", - false}, - {"gt_timestamp_timestamp_true", "timestamp(2) > timestamp(1)", - true}, - - {"le_int_int_true", "-1 <= -1", true}, - {"le_int_int_false", "2 <= -1", false}, - {"le_double_double_true", "-1.1 <= -1.1", true}, - {"le_double_double_false", "2.2 <= -1.1", false}, - {"le_uint_uint_true", "1u <= 1u", true}, - {"le_uint_uint_false", "2u <= 1u", false}, - {"le_string_string_true", "'abc' <= 'abc'", true}, - {"le_string_string_false", "'def' <= 'abc'", false}, - {"le_duration_duration_true", "duration('1s') <= duration('1s')", - true}, - {"le_duration_duration_false", "duration('2s') <= duration('1s')", - false}, - {"le_timestamp_timestamp_true", "timestamp(1) <= timestamp(1)", - true}, - {"le_timestamp_timestamp_false", "timestamp(2) <= timestamp(1)", - false}, - - {"ge_int_int_false", "-1 >= 2", false}, - {"ge_int_int_true", "2 >= 2", true}, - {"ge_double_double_false", "-1.1 >= 2.2", false}, - {"ge_double_double_true", "2.2 >= 2.2", true}, - {"ge_uint_uint_false", "1u >= 2u", false}, - {"ge_uint_uint_true", "2u >= 2u", true}, - {"ge_string_string_false", "'abc' >= 'def'", false}, - {"ge_string_string_true", "'abc' >= 'abc'", true}, - {"ge_duration_duration_false", "duration('1s') >= duration('2s')", - false}, - {"ge_duration_duration_true", "duration('1s') >= duration('1s')", - true}, - {"ge_timestamp_timestamp_false", "timestamp(1) >= timestamp(2)", - false}, - {"ge_timestamp_timestamp_true", "timestamp(1) >= timestamp(1)", - true}, - - {"sum_int_int", "1 + 2 == 3", true}, - {"sum_uint_uint", "3u + 4u == 7", true}, - {"sum_double_double", "1.0 + 2.5 == 3.5", true}, - {"sum_duration_duration", - "duration('2m') + duration('30s') == duration('150s')", true}, - {"sum_time_duration", - "timestamp(0) + duration('2m') == " - "timestamp('1970-01-01T00:02:00Z')", - true}, - - {"difference_int_int", "1 - 2 == -1", true}, - {"difference_uint_uint", "4u - 3u == 1u", true}, - {"difference_double_double", "1.0 - 2.5 == -1.5", true}, - {"difference_duration_duration", - "duration('5m') - duration('45s') == duration('4m15s')", true}, - {"difference_time_time", - "timestamp(10) - timestamp(0) == duration('10s')", true}, - {"difference_time_duration", - "timestamp(0) - duration('2m') == " - "timestamp('1969-12-31T23:58:00Z')", - true}, - - {"multiplication_int_int", "2 * 3 == 6", true}, - {"multiplication_uint_uint", "2u * 3u == 6u", true}, - {"multiplication_double_double", "2.5 * 3.0 == 7.5", true}, - - {"division_int_int", "6 / 3 == 2", true}, - {"division_uint_uint", "8u / 4u == 2u", true}, - {"division_double_double", "1.0 / 0.0 == double('inf')", true}, - - {"modulo_int_int", "6 % 4 == 2", true}, - {"modulo_uint_uint", "8u % 5u == 3u", true}, - })), - StandardRuntimeTest::ToString); + testing::ValuesIn(std::vector{ + {"lt_int_int_true", "-1 < 2", true}, + {"lt_int_int_false", "2 < -1", false}, + {"lt_double_double_true", "-1.1 < 2.2", true}, + {"lt_double_double_false", "2.2 < -1.1", false}, + {"lt_uint_uint_true", "1u < 2u", true}, + {"lt_uint_uint_false", "2u < 1u", false}, + {"lt_string_string_true", "'abc' < 'def'", true}, + {"lt_string_string_false", "'def' < 'abc'", false}, + {"lt_duration_duration_true", "duration('1s') < duration('2s')", true}, + {"lt_duration_duration_false", "duration('2s') < duration('1s')", + false}, + {"lt_timestamp_timestamp_true", "timestamp(1) < timestamp(2)", true}, + {"lt_timestamp_timestamp_false", "timestamp(2) < timestamp(1)", false}, + + {"gt_int_int_false", "-1 > 2", false}, + {"gt_int_int_true", "2 > -1", true}, + {"gt_double_double_false", "-1.1 > 2.2", false}, + {"gt_double_double_true", "2.2 > -1.1", true}, + {"gt_uint_uint_false", "1u > 2u", false}, + {"gt_uint_uint_true", "2u > 1u", true}, + {"gt_string_string_false", "'abc' > 'def'", false}, + {"gt_string_string_true", "'def' > 'abc'", true}, + {"gt_duration_duration_false", "duration('1s') > duration('2s')", + false}, + {"gt_duration_duration_true", "duration('2s') > duration('1s')", true}, + {"gt_timestamp_timestamp_false", "timestamp(1) > timestamp(2)", false}, + {"gt_timestamp_timestamp_true", "timestamp(2) > timestamp(1)", true}, + + {"le_int_int_true", "-1 <= -1", true}, + {"le_int_int_false", "2 <= -1", false}, + {"le_double_double_true", "-1.1 <= -1.1", true}, + {"le_double_double_false", "2.2 <= -1.1", false}, + {"le_uint_uint_true", "1u <= 1u", true}, + {"le_uint_uint_false", "2u <= 1u", false}, + {"le_string_string_true", "'abc' <= 'abc'", true}, + {"le_string_string_false", "'def' <= 'abc'", false}, + {"le_duration_duration_true", "duration('1s') <= duration('1s')", true}, + {"le_duration_duration_false", "duration('2s') <= duration('1s')", + false}, + {"le_timestamp_timestamp_true", "timestamp(1) <= timestamp(1)", true}, + {"le_timestamp_timestamp_false", "timestamp(2) <= timestamp(1)", false}, + + {"ge_int_int_false", "-1 >= 2", false}, + {"ge_int_int_true", "2 >= 2", true}, + {"ge_double_double_false", "-1.1 >= 2.2", false}, + {"ge_double_double_true", "2.2 >= 2.2", true}, + {"ge_uint_uint_false", "1u >= 2u", false}, + {"ge_uint_uint_true", "2u >= 2u", true}, + {"ge_string_string_false", "'abc' >= 'def'", false}, + {"ge_string_string_true", "'abc' >= 'abc'", true}, + {"ge_duration_duration_false", "duration('1s') >= duration('2s')", + false}, + {"ge_duration_duration_true", "duration('1s') >= duration('1s')", true}, + {"ge_timestamp_timestamp_false", "timestamp(1) >= timestamp(2)", false}, + {"ge_timestamp_timestamp_true", "timestamp(1) >= timestamp(1)", true}, + + {"sum_int_int", "1 + 2 == 3", true}, + {"sum_uint_uint", "3u + 4u == 7", true}, + {"sum_double_double", "1.0 + 2.5 == 3.5", true}, + {"sum_duration_duration", + "duration('2m') + duration('30s') == duration('150s')", true}, + {"sum_time_duration", + "timestamp(0) + duration('2m') == " + "timestamp('1970-01-01T00:02:00Z')", + true}, + + {"difference_int_int", "1 - 2 == -1", true}, + {"difference_uint_uint", "4u - 3u == 1u", true}, + {"difference_double_double", "1.0 - 2.5 == -1.5", true}, + {"difference_duration_duration", + "duration('5m') - duration('45s') == duration('4m15s')", true}, + {"difference_time_time", + "timestamp(10) - timestamp(0) == duration('10s')", true}, + {"difference_time_duration", + "timestamp(0) - duration('2m') == " + "timestamp('1969-12-31T23:58:00Z')", + true}, + + {"multiplication_int_int", "2 * 3 == 6", true}, + {"multiplication_uint_uint", "2u * 3u == 6u", true}, + {"multiplication_double_double", "2.5 * 3.0 == 7.5", true}, + + {"division_int_int", "6 / 3 == 2", true}, + {"division_uint_uint", "8u / 4u == 2u", true}, + {"division_double_double", "1.0 / 0.0 == double('inf')", true}, + + {"modulo_int_int", "6 % 4 == 2", true}, + {"modulo_uint_uint", "8u % 5u == 3u", true}, + })); INSTANTIATE_TEST_SUITE_P( Macros, StandardRuntimeTest, - testing::Combine(testing::Values(MemoryManagement::kPooling, - MemoryManagement::kReferenceCounting), - testing::ValuesIn(std::vector{ - {"map", "[1, 2, 3, 4].map(x, x * x)[3] == 16", true}, - {"filter", "[1, 2, 3, 4].filter(x, x < 4).size() == 3", - true}, - {"exists", "[1, 2, 3, 4].exists(x, x < 4)", true}, - {"all", "[1, 2, 3, 4].all(x, x < 5)", true}})), - StandardRuntimeTest::ToString); + testing::ValuesIn(std::vector{ + {"map", "[1, 2, 3, 4].map(x, x * x)[3] == 16", true}, + {"filter", "[1, 2, 3, 4].filter(x, x < 4).size() == 3", true}, + {"exists", "[1, 2, 3, 4].exists(x, x < 4)", true}, + {"all", "[1, 2, 3, 4].all(x, x < 5)", true}})); INSTANTIATE_TEST_SUITE_P( StringFunctions, StandardRuntimeTest, - testing::Combine( - testing::Values(MemoryManagement::kPooling, - MemoryManagement::kReferenceCounting), - testing::ValuesIn(std::vector{ - {"string_contains", "'tacocat'.contains('acoca')", true}, - {"string_contains_global", "contains('tacocat', 'dog')", false}, - {"string_ends_with", "'abcdefg'.endsWith('efg')", true}, - {"string_ends_with_global", "endsWith('abcdefg', 'fgh')", false}, - {"string_starts_with", "'abcdefg'.startsWith('abc')", true}, - {"string_starts_with_global", "startsWith('abcd', 'bcd')", false}, - {"string_size", "'Hello World! 😀'.size() == 14", true}, - {"string_size_global", "size('Hello world!') == 12", true}, - {"bytes_size", "b'0123'.size() == 4", true}, - {"bytes_size_global", "size(b'😀') == 4", true}})), - StandardRuntimeTest::ToString); + testing::ValuesIn(std::vector{ + {"string_contains", "'tacocat'.contains('acoca')", true}, + {"string_contains_global", "contains('tacocat', 'dog')", false}, + {"string_ends_with", "'abcdefg'.endsWith('efg')", true}, + {"string_ends_with_global", "endsWith('abcdefg', 'fgh')", false}, + {"string_starts_with", "'abcdefg'.startsWith('abc')", true}, + {"string_starts_with_global", "startsWith('abcd', 'bcd')", false}, + {"string_size", "'Hello World! 😀'.size() == 14", true}, + {"string_size_global", "size('Hello world!') == 12", true}, + {"bytes_size", "b'0123'.size() == 4", true}, + {"bytes_size_global", "size(b'😀') == 4", true}})); INSTANTIATE_TEST_SUITE_P( RegExFunctions, StandardRuntimeTest, - testing::Combine( - testing::Values(MemoryManagement::kPooling, - MemoryManagement::kReferenceCounting), - testing::ValuesIn(std::vector{ - {"matches_string_re", - "'127.0.0.1'.matches(r'127\\.\\d+\\.\\d+\\.\\d+')", true}, - {"matches_string_re_global", - "matches('192.168.0.1', r'127\\.\\d+\\.\\d+\\.\\d+')", false}})), - StandardRuntimeTest::ToString); + testing::ValuesIn(std::vector{ + {"matches_string_re", + "'127.0.0.1'.matches(r'127\\.\\d+\\.\\d+\\.\\d+')", true}, + {"matches_string_re_global", + "matches('192.168.0.1', r'127\\.\\d+\\.\\d+\\.\\d+')", false}})); INSTANTIATE_TEST_SUITE_P( TimeFunctions, StandardRuntimeTest, - testing::Combine( - testing::Values(MemoryManagement::kPooling, - MemoryManagement::kReferenceCounting), - testing::ValuesIn(std::vector{ - {"timestamp_get_full_year", - "timestamp('2001-02-03T04:05:06.007Z').getFullYear() == 2001", - true}, - {"timestamp_get_date", - "timestamp('2001-02-03T04:05:06.007Z').getDate() == 3", true}, - {"timestamp_get_hours", - "timestamp('2001-02-03T04:05:06.007Z').getHours() == 4", true}, - {"timestamp_get_minutes", - "timestamp('2001-02-03T04:05:06.007Z').getMinutes() == 5", true}, - {"timestamp_get_seconds", - "timestamp('2001-02-03T04:05:06.007Z').getSeconds() == 6", true}, - {"timestamp_get_milliseconds", - "timestamp('2001-02-03T04:05:06.007Z').getMilliseconds() == 7", - true}, - // Zero based indexing - {"timestamp_get_month", - "timestamp('2001-02-03T04:05:06.007Z').getMonth() == 1", true}, - {"timestamp_get_day_of_year", - "timestamp('2001-02-03T04:05:06.007Z').getDayOfYear() == 33", - true}, - {"timestamp_get_day_of_month", - "timestamp('2001-02-03T04:05:06.007Z').getDayOfMonth() == 2", - true}, - {"timestamp_get_day_of_week", - "timestamp('2001-02-03T04:05:06.007Z').getDayOfWeek() == 6", true}, - {"duration_get_hours", "duration('10h20m30s40ms').getHours() == 10", - true}, - {"duration_get_minutes", - "duration('10h20m30s40ms').getMinutes() == 20 + 600", true}, - {"duration_get_seconds", - "duration('10h20m30s40ms').getSeconds() == 30 + 20 * 60 + 10 * 60 " - "* " - "60", - true}, - {"duration_get_milliseconds", - "duration('10h20m30s40ms').getMilliseconds() == 40", true}, - })), - StandardRuntimeTest::ToString); + testing::ValuesIn(std::vector{ + {"timestamp_get_full_year", + "timestamp('2001-02-03T04:05:06.007Z').getFullYear() == 2001", true}, + {"timestamp_get_date", + "timestamp('2001-02-03T04:05:06.007Z').getDate() == 3", true}, + {"timestamp_get_hours", + "timestamp('2001-02-03T04:05:06.007Z').getHours() == 4", true}, + {"timestamp_get_minutes", + "timestamp('2001-02-03T04:05:06.007Z').getMinutes() == 5", true}, + {"timestamp_get_seconds", + "timestamp('2001-02-03T04:05:06.007Z').getSeconds() == 6", true}, + {"timestamp_get_milliseconds", + "timestamp('2001-02-03T04:05:06.007Z').getMilliseconds() == 7", true}, + // Zero based indexing + {"timestamp_get_month", + "timestamp('2001-02-03T04:05:06.007Z').getMonth() == 1", true}, + {"timestamp_get_day_of_year", + "timestamp('2001-02-03T04:05:06.007Z').getDayOfYear() == 33", true}, + {"timestamp_get_day_of_month", + "timestamp('2001-02-03T04:05:06.007Z').getDayOfMonth() == 2", true}, + {"timestamp_get_day_of_week", + "timestamp('2001-02-03T04:05:06.007Z').getDayOfWeek() == 6", true}, + {"duration_get_hours", "duration('10h20m30s40ms').getHours() == 10", + true}, + {"duration_get_minutes", + "duration('10h20m30s40ms').getMinutes() == 20 + 600", true}, + {"duration_get_seconds", + "duration('10h20m30s40ms').getSeconds() == 30 + 20 * 60 + 10 * 60 " + "* " + "60", + true}, + {"duration_get_milliseconds", + "duration('10h20m30s40ms').getMilliseconds() == 40", true}, + })); INSTANTIATE_TEST_SUITE_P( TypeConversionFunctions, StandardRuntimeTest, - testing::Combine( - testing::Values(MemoryManagement::kPooling, - MemoryManagement::kReferenceCounting), - testing::ValuesIn(std::vector{ - {"string_timestamp", - "string(timestamp(1)) == '1970-01-01T00:00:01Z'", true}, - {"string_duration", "string(duration('10m30s')) == '630s'", true}, - {"string_int", "string(-1) == '-1'", true}, - {"string_uint", "string(1u) == '1'", true}, - {"string_double", "string(double('inf')) == 'inf'", true}, - {"string_bytes", R"(string(b'\xF0\x9F\x98\x80') == '😀')", true}, - {"string_string", "string('hello!') == 'hello!'", true}, - {"bytes_bytes", "bytes(b'123') == b'123'", true}, - {"bytes_string", "bytes('😀') == b'\xF0\x9F\x98\x80'", true}, - {"timestamp", "timestamp(1) == timestamp('1970-01-01T00:00:01Z')", - true}, - {"duration", "duration('10h') == duration('600m')", true}, - {"double_string", "double('1.0') == 1.0", true}, - {"double_string_nan", "double('nan') != double('nan')", true}, - {"double_int", "double(1) == 1.0", true}, - {"double_uint", "double(1u) == 1.0", true}, - {"double_double", "double(1.0) == 1.0", true}, - {"uint_string", "uint('1') == 1u", true}, - {"uint_int", "uint(1) == 1u", true}, - {"uint_uint", "uint(1u) == 1u", true}, - {"uint_double", "uint(1.1) == 1u", true}, - {"int_string", "int('-1') == -1", true}, - {"int_int", "int(-1) == -1", true}, - {"int_uint", "int(1u) == 1", true}, - {"int_double", "int(-1.1) == -1", true}, - {"int_timestamp", "int(timestamp('1969-12-31T23:30:00Z')) == -1800", - true}, - })), - StandardRuntimeTest::ToString); + testing::ValuesIn(std::vector{ + {"string_timestamp", "string(timestamp(1)) == '1970-01-01T00:00:01Z'", + true}, + {"string_duration", "string(duration('10m30s')) == '630s'", true}, + {"string_int", "string(-1) == '-1'", true}, + {"string_uint", "string(1u) == '1'", true}, + {"string_double", "string(double('inf')) == 'inf'", true}, + {"string_bytes", R"(string(b'\xF0\x9F\x98\x80') == '😀')", true}, + {"string_string", "string('hello!') == 'hello!'", true}, + {"bytes_bytes", "bytes(b'123') == b'123'", true}, + {"bytes_string", "bytes('😀') == b'\xF0\x9F\x98\x80'", true}, + {"timestamp", "timestamp(1) == timestamp('1970-01-01T00:00:01Z')", + true}, + {"duration", "duration('10h') == duration('600m')", true}, + {"double_string", "double('1.0') == 1.0", true}, + {"double_string_nan", "double('nan') != double('nan')", true}, + {"double_int", "double(1) == 1.0", true}, + {"double_uint", "double(1u) == 1.0", true}, + {"double_double", "double(1.0) == 1.0", true}, + {"uint_string", "uint('1') == 1u", true}, + {"uint_int", "uint(1) == 1u", true}, + {"uint_uint", "uint(1u) == 1u", true}, + {"uint_double", "uint(1.1) == 1u", true}, + {"int_string", "int('-1') == -1", true}, + {"int_int", "int(-1) == -1", true}, + {"int_uint", "int(1u) == 1", true}, + {"int_double", "int(-1.1) == -1", true}, + {"int_timestamp", "int(timestamp('1969-12-31T23:30:00Z')) == -1800", + true}, + })); INSTANTIATE_TEST_SUITE_P( ContainerFunctions, StandardRuntimeTest, - testing::Combine( - testing::Values(MemoryManagement::kPooling, - MemoryManagement::kReferenceCounting), - testing::ValuesIn(std::vector{ - // Containers - {"map_size", "{'abc': 1, 'def': 2}.size() == 2", true}, - {"map_in", "'abc' in {'abc': 1, 'def': 2}", true}, - {"map_in_numeric", "1.0 in {1u: 1, 2u: 2}", true}, - {"list_size", "[1, 2, 3, 4].size() == 4", true}, - {"list_size_global", "size([1, 2, 3]) == 3", true}, - {"list_concat", "[1, 2] + [3, 4] == [1, 2, 3, 4]", true}, - {"list_in", "'a' in ['a', 'b', 'c', 'd']", true}, - {"list_in_numeric", "3u in [1.1, 2.3, 3.0, 4.4]", true}})), - StandardRuntimeTest::ToString); + testing::ValuesIn(std::vector{ + // Containers + {"map_size", "{'abc': 1, 'def': 2}.size() == 2", true}, + {"map_in", "'abc' in {'abc': 1, 'def': 2}", true}, + {"map_in_numeric", "1.0 in {1u: 1, 2u: 2}", true}, + {"list_size", "[1, 2, 3, 4].size() == 4", true}, + {"list_size_global", "size([1, 2, 3]) == 3", true}, + {"list_concat", "[1, 2] + [3, 4] == [1, 2, 3, 4]", true}, + {"list_in", "'a' in ['a', 'b', 'c', 'd']", true}, + {"list_in_numeric", "3u in [1.1, 2.3, 3.0, 4.4]", true}})); TEST(StandardRuntimeTest, RuntimeIssueSupport) { RuntimeOptions options; @@ -593,10 +536,10 @@ TEST(StandardRuntimeTest, RuntimeIssueSupport) { ManagedValueFactory value_factory(program->GetTypeProvider(), memory_manager); + google::protobuf::Arena arena; Activation activation; - ASSERT_OK_AND_ASSIGN(auto result, - program->Evaluate(activation, value_factory.get())); + ASSERT_OK_AND_ASSIGN(auto result, program->Evaluate(&arena, activation)); EXPECT_TRUE(result->Is() && result.GetBool().NativeValue()); } } diff --git a/runtime/type_registry.cc b/runtime/type_registry.cc index 5d93e725d..73a31d62c 100644 --- a/runtime/type_registry.cc +++ b/runtime/type_registry.cc @@ -18,12 +18,22 @@ #include #include +#include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" +#include "runtime/internal/legacy_runtime_type_provider.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" namespace cel { -TypeRegistry::TypeRegistry() { +TypeRegistry::TypeRegistry( + absl::Nonnull descriptor_pool, + absl::Nullable message_factory) + : type_provider_(descriptor_pool), + legacy_type_provider_( + std::make_shared( + descriptor_pool, message_factory)) { RegisterEnum("google.protobuf.NullValue", {{"NULL_VALUE", 0}}); } diff --git a/runtime/type_registry.h b/runtime/type_registry.h index fb47723dd..61a0b2b2d 100644 --- a/runtime/type_registry.h +++ b/runtime/type_registry.h @@ -18,18 +18,30 @@ #include #include #include -#include #include +#include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "base/type_provider.h" #include "common/type.h" -#include "runtime/internal/composed_type_provider.h" +#include "runtime/internal/legacy_runtime_type_provider.h" +#include "runtime/internal/runtime_type_provider.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" namespace cel { +class TypeRegistry; + +namespace runtime_internal { +const RuntimeTypeProvider& GetRuntimeTypeProvider( + const TypeRegistry& type_registry); +const absl::Nonnull>& +GetLegacyRuntimeTypeProvider(const TypeRegistry& type_registry); +} // namespace runtime_internal + // TypeRegistry manages composing TypeProviders used with a Runtime. // // It provides a single effective type provider to be used in a ValueManager. @@ -46,7 +58,12 @@ class TypeRegistry { std::vector enumerators; }; - TypeRegistry(); + TypeRegistry() + : TypeRegistry(google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory()) {} + + TypeRegistry(absl::Nonnull descriptor_pool, + absl::Nullable message_factory); // Move-only TypeRegistry(const TypeRegistry& other) = delete; @@ -54,14 +71,10 @@ class TypeRegistry { TypeRegistry(TypeRegistry&& other) = default; TypeRegistry& operator=(TypeRegistry&& other) = default; - void AddTypeProvider(std::unique_ptr provider) { - impl_.AddTypeProvider(std::move(provider)); - } - // Registers a type such that it can be accessed by name, i.e. `type(foo) == // my_type`. Where `my_type` is the type being registered. absl::Status RegisterType(const OpaqueType& type) { - return impl_.RegisterType(type); + return type_provider_.RegisterType(type); } // Register a custom enum type. @@ -77,16 +90,34 @@ class TypeRegistry { } // Returns the effective type provider. - const TypeProvider& GetComposedTypeProvider() const { return impl_; } - void set_use_legacy_container_builders(bool use_legacy_container_builders) { - impl_.set_use_legacy_container_builders(use_legacy_container_builders); - } + const TypeProvider& GetComposedTypeProvider() const { return type_provider_; } + void set_use_legacy_container_builders(bool use_legacy_container_builders) {} private: - runtime_internal::ComposedTypeProvider impl_; + friend const runtime_internal::RuntimeTypeProvider& + runtime_internal::GetRuntimeTypeProvider(const TypeRegistry& type_registry); + friend const absl::Nonnull< + std::shared_ptr>& + runtime_internal::GetLegacyRuntimeTypeProvider( + const TypeRegistry& type_registry); + + runtime_internal::RuntimeTypeProvider type_provider_; + absl::Nonnull> + legacy_type_provider_; absl::flat_hash_map enum_types_; }; +namespace runtime_internal { +inline const RuntimeTypeProvider& GetRuntimeTypeProvider( + const TypeRegistry& type_registry) { + return type_registry.type_provider_; +} +inline const absl::Nonnull>& +GetLegacyRuntimeTypeProvider(const TypeRegistry& type_registry) { + return type_registry.legacy_type_provider_; +} +} // namespace runtime_internal + } // namespace cel #endif // THIRD_PARTY_CEL_CPP_RUNTIME_TYPE_REGISTRY_H_ From 0834a1592a5ad7ae21b3c84b167196c0daa9cb9f Mon Sep 17 00:00:00 2001 From: Justin King Date: Wed, 6 Nov 2024 11:40:43 -0800 Subject: [PATCH 026/180] Drop `absl::Status` return from `NewStructValueBuilder` and `NewValueBuilder` PiperOrigin-RevId: 693806382 --- common/type_reflector_test.cc | 96 ++++++++++----------------- common/value.cc | 2 +- common/value.h | 2 +- common/values/struct_value_builder.cc | 25 ++++--- common/values/struct_value_builder.h | 4 +- common/values/value_builder.h | 3 +- 6 files changed, 56 insertions(+), 76 deletions(-) diff --git a/common/type_reflector_test.cc b/common/type_reflector_test.cc index 32e906c14..96be394c9 100644 --- a/common/type_reflector_test.cc +++ b/common/type_reflector_test.cc @@ -222,11 +222,9 @@ TEST_P(TypeReflectorTest, JsonKeyCoverage) { } TEST_P(TypeReflectorTest, NewValueBuilder_BoolValue) { - ASSERT_OK_AND_ASSIGN( - auto builder, - common_internal::NewValueBuilder( - memory_manager(), internal::GetTestingDescriptorPool(), - internal::GetTestingMessageFactory(), "google.protobuf.BoolValue")); + auto builder = common_internal::NewValueBuilder( + memory_manager(), internal::GetTestingDescriptorPool(), + internal::GetTestingMessageFactory(), "google.protobuf.BoolValue"); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("value", BoolValue(true)), IsOk()); EXPECT_THAT(builder->SetFieldByName("does_not_exist", BoolValue(true)), @@ -244,11 +242,9 @@ TEST_P(TypeReflectorTest, NewValueBuilder_BoolValue) { } TEST_P(TypeReflectorTest, NewValueBuilder_Int32Value) { - ASSERT_OK_AND_ASSIGN( - auto builder, - common_internal::NewValueBuilder( - memory_manager(), internal::GetTestingDescriptorPool(), - internal::GetTestingMessageFactory(), "google.protobuf.Int32Value")); + auto builder = common_internal::NewValueBuilder( + memory_manager(), internal::GetTestingDescriptorPool(), + internal::GetTestingMessageFactory(), "google.protobuf.Int32Value"); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("value", IntValue(1)), IsOk()); EXPECT_THAT(builder->SetFieldByName("does_not_exist", IntValue(1)), @@ -272,11 +268,9 @@ TEST_P(TypeReflectorTest, NewValueBuilder_Int32Value) { } TEST_P(TypeReflectorTest, NewValueBuilder_Int64Value) { - ASSERT_OK_AND_ASSIGN( - auto builder, - common_internal::NewValueBuilder( - memory_manager(), internal::GetTestingDescriptorPool(), - internal::GetTestingMessageFactory(), "google.protobuf.Int64Value")); + auto builder = common_internal::NewValueBuilder( + memory_manager(), internal::GetTestingDescriptorPool(), + internal::GetTestingMessageFactory(), "google.protobuf.Int64Value"); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("value", IntValue(1)), IsOk()); EXPECT_THAT(builder->SetFieldByName("does_not_exist", IntValue(1)), @@ -294,11 +288,9 @@ TEST_P(TypeReflectorTest, NewValueBuilder_Int64Value) { } TEST_P(TypeReflectorTest, NewValueBuilder_UInt32Value) { - ASSERT_OK_AND_ASSIGN( - auto builder, - common_internal::NewValueBuilder( - memory_manager(), internal::GetTestingDescriptorPool(), - internal::GetTestingMessageFactory(), "google.protobuf.UInt32Value")); + auto builder = common_internal::NewValueBuilder( + memory_manager(), internal::GetTestingDescriptorPool(), + internal::GetTestingMessageFactory(), "google.protobuf.UInt32Value"); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("value", UintValue(1)), IsOk()); EXPECT_THAT(builder->SetFieldByName("does_not_exist", UintValue(1)), @@ -322,11 +314,9 @@ TEST_P(TypeReflectorTest, NewValueBuilder_UInt32Value) { } TEST_P(TypeReflectorTest, NewValueBuilder_UInt64Value) { - ASSERT_OK_AND_ASSIGN( - auto builder, - common_internal::NewValueBuilder( - memory_manager(), internal::GetTestingDescriptorPool(), - internal::GetTestingMessageFactory(), "google.protobuf.UInt64Value")); + auto builder = common_internal::NewValueBuilder( + memory_manager(), internal::GetTestingDescriptorPool(), + internal::GetTestingMessageFactory(), "google.protobuf.UInt64Value"); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("value", UintValue(1)), IsOk()); EXPECT_THAT(builder->SetFieldByName("does_not_exist", UintValue(1)), @@ -344,11 +334,9 @@ TEST_P(TypeReflectorTest, NewValueBuilder_UInt64Value) { } TEST_P(TypeReflectorTest, NewValueBuilder_FloatValue) { - ASSERT_OK_AND_ASSIGN( - auto builder, - common_internal::NewValueBuilder( - memory_manager(), internal::GetTestingDescriptorPool(), - internal::GetTestingMessageFactory(), "google.protobuf.FloatValue")); + auto builder = common_internal::NewValueBuilder( + memory_manager(), internal::GetTestingDescriptorPool(), + internal::GetTestingMessageFactory(), "google.protobuf.FloatValue"); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("value", DoubleValue(1)), IsOk()); EXPECT_THAT(builder->SetFieldByName("does_not_exist", DoubleValue(1)), @@ -366,11 +354,9 @@ TEST_P(TypeReflectorTest, NewValueBuilder_FloatValue) { } TEST_P(TypeReflectorTest, NewValueBuilder_DoubleValue) { - ASSERT_OK_AND_ASSIGN( - auto builder, - common_internal::NewValueBuilder( - memory_manager(), internal::GetTestingDescriptorPool(), - internal::GetTestingMessageFactory(), "google.protobuf.DoubleValue")); + auto builder = common_internal::NewValueBuilder( + memory_manager(), internal::GetTestingDescriptorPool(), + internal::GetTestingMessageFactory(), "google.protobuf.DoubleValue"); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("value", DoubleValue(1)), IsOk()); EXPECT_THAT(builder->SetFieldByName("does_not_exist", DoubleValue(1)), @@ -388,11 +374,9 @@ TEST_P(TypeReflectorTest, NewValueBuilder_DoubleValue) { } TEST_P(TypeReflectorTest, NewValueBuilder_StringValue) { - ASSERT_OK_AND_ASSIGN( - auto builder, - common_internal::NewValueBuilder( - memory_manager(), internal::GetTestingDescriptorPool(), - internal::GetTestingMessageFactory(), "google.protobuf.StringValue")); + auto builder = common_internal::NewValueBuilder( + memory_manager(), internal::GetTestingDescriptorPool(), + internal::GetTestingMessageFactory(), "google.protobuf.StringValue"); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("value", StringValue("foo")), IsOk()); EXPECT_THAT(builder->SetFieldByName("does_not_exist", StringValue("foo")), @@ -410,11 +394,9 @@ TEST_P(TypeReflectorTest, NewValueBuilder_StringValue) { } TEST_P(TypeReflectorTest, NewValueBuilder_BytesValue) { - ASSERT_OK_AND_ASSIGN( - auto builder, - common_internal::NewValueBuilder( - memory_manager(), internal::GetTestingDescriptorPool(), - internal::GetTestingMessageFactory(), "google.protobuf.BytesValue")); + auto builder = common_internal::NewValueBuilder( + memory_manager(), internal::GetTestingDescriptorPool(), + internal::GetTestingMessageFactory(), "google.protobuf.BytesValue"); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("value", BytesValue("foo")), IsOk()); EXPECT_THAT(builder->SetFieldByName("does_not_exist", BytesValue("foo")), @@ -432,11 +414,9 @@ TEST_P(TypeReflectorTest, NewValueBuilder_BytesValue) { } TEST_P(TypeReflectorTest, NewValueBuilder_Duration) { - ASSERT_OK_AND_ASSIGN( - auto builder, - common_internal::NewValueBuilder( - memory_manager(), internal::GetTestingDescriptorPool(), - internal::GetTestingMessageFactory(), "google.protobuf.Duration")); + auto builder = common_internal::NewValueBuilder( + memory_manager(), internal::GetTestingDescriptorPool(), + internal::GetTestingMessageFactory(), "google.protobuf.Duration"); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("seconds", IntValue(1)), IsOk()); EXPECT_THAT(builder->SetFieldByName("does_not_exist", IntValue(1)), @@ -467,11 +447,9 @@ TEST_P(TypeReflectorTest, NewValueBuilder_Duration) { } TEST_P(TypeReflectorTest, NewValueBuilder_Timestamp) { - ASSERT_OK_AND_ASSIGN( - auto builder, - common_internal::NewValueBuilder( - memory_manager(), internal::GetTestingDescriptorPool(), - internal::GetTestingMessageFactory(), "google.protobuf.Timestamp")); + auto builder = common_internal::NewValueBuilder( + memory_manager(), internal::GetTestingDescriptorPool(), + internal::GetTestingMessageFactory(), "google.protobuf.Timestamp"); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("seconds", IntValue(1)), IsOk()); EXPECT_THAT(builder->SetFieldByName("does_not_exist", IntValue(1)), @@ -502,11 +480,9 @@ TEST_P(TypeReflectorTest, NewValueBuilder_Timestamp) { } TEST_P(TypeReflectorTest, NewValueBuilder_Any) { - ASSERT_OK_AND_ASSIGN( - auto builder, - common_internal::NewValueBuilder( - memory_manager(), internal::GetTestingDescriptorPool(), - internal::GetTestingMessageFactory(), "google.protobuf.Any")); + auto builder = common_internal::NewValueBuilder( + memory_manager(), internal::GetTestingDescriptorPool(), + internal::GetTestingMessageFactory(), "google.protobuf.Any"); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName( "type_url", diff --git a/common/value.cc b/common/value.cc index c4a7a8a28..2929a20be 100644 --- a/common/value.cc +++ b/common/value.cc @@ -2552,7 +2552,7 @@ absl::Nonnull NewMapValueBuilder( return common_internal::NewMapValueBuilder(arena); } -absl::StatusOr> NewStructValueBuilder( +absl::Nullable NewStructValueBuilder( absl::Nonnull arena, absl::Nonnull descriptor_pool, absl::Nonnull message_factory, diff --git a/common/value.h b/common/value.h index e14223e37..fe9861c1e 100644 --- a/common/value.h +++ b/common/value.h @@ -2655,7 +2655,7 @@ absl::Nonnull NewMapValueBuilder( // message type with the name `name` in `descriptor_pool`. Returns an error if // `message_factory` is unable to provide a prototype for the descriptor // returned from `descriptor_pool`. -absl::StatusOr> NewStructValueBuilder( +absl::Nullable NewStructValueBuilder( absl::Nonnull arena, absl::Nonnull descriptor_pool, absl::Nonnull message_factory, diff --git a/common/values/struct_value_builder.cc b/common/values/struct_value_builder.cc index 6bf9440a9..a7d144e89 100644 --- a/common/values/struct_value_builder.cc +++ b/common/values/struct_value_builder.cc @@ -1647,7 +1647,7 @@ class StructValueBuilderImpl final : public StructValueBuilder { } // namespace -absl::StatusOr> NewValueBuilder( +absl::Nullable NewValueBuilder( Allocator<> allocator, absl::Nonnull descriptor_pool, absl::Nonnull message_factory, @@ -1659,17 +1659,20 @@ absl::StatusOr> NewValueBuilder( } absl::Nullable prototype = message_factory->GetPrototype(descriptor); - if (prototype == nullptr) { - return absl::NotFoundError(absl::StrCat( - "unable to get prototype for descriptor: ", descriptor->full_name())); + ABSL_DCHECK(prototype != nullptr) + << "failed to get message prototype from factory, did you pass a dynamic " + "descriptor to the generated message factory? we consider this to be " + "a logic error and not a runtime error: " + << descriptor->full_name(); + if (ABSL_PREDICT_FALSE(prototype == nullptr)) { + return nullptr; } return std::make_unique(allocator.arena(), descriptor_pool, message_factory, prototype->New(allocator.arena())); } -absl::StatusOr> -NewStructValueBuilder( +absl::Nullable NewStructValueBuilder( Allocator<> allocator, absl::Nonnull descriptor_pool, absl::Nonnull message_factory, @@ -1681,9 +1684,13 @@ NewStructValueBuilder( } absl::Nullable prototype = message_factory->GetPrototype(descriptor); - if (prototype == nullptr) { - return absl::NotFoundError(absl::StrCat( - "unable to get prototype for descriptor: ", descriptor->full_name())); + ABSL_DCHECK(prototype != nullptr) + << "failed to get message prototype from factory, did you pass a dynamic " + "descriptor to the generated message factory? we consider this to be " + "a logic error and not a runtime error: " + << descriptor->full_name(); + if (ABSL_PREDICT_FALSE(prototype == nullptr)) { + return nullptr; } return std::make_unique( allocator.arena(), descriptor_pool, message_factory, diff --git a/common/values/struct_value_builder.h b/common/values/struct_value_builder.h index 063dc8c84..bf95022b5 100644 --- a/common/values/struct_value_builder.h +++ b/common/values/struct_value_builder.h @@ -16,7 +16,6 @@ #define THIRD_PARTY_CEL_CPP_COMMON_VALUES_STRUCT_VALUE_BUILDER_H_ #include "absl/base/nullability.h" -#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "common/allocator.h" #include "common/value.h" @@ -25,8 +24,7 @@ namespace cel::common_internal { -absl::StatusOr> -NewStructValueBuilder( +absl::Nullable NewStructValueBuilder( Allocator<> allocator, absl::Nonnull descriptor_pool, absl::Nonnull message_factory, diff --git a/common/values/value_builder.h b/common/values/value_builder.h index e93704884..15c6b6dd9 100644 --- a/common/values/value_builder.h +++ b/common/values/value_builder.h @@ -16,7 +16,6 @@ #define THIRD_PARTY_CEL_CPP_COMMON_VALUES_VALUE_BUILDER_H_ #include "absl/base/nullability.h" -#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "common/allocator.h" #include "common/value.h" @@ -26,7 +25,7 @@ namespace cel::common_internal { // Like NewStructValueBuilder, but deals with well known types. -absl::StatusOr> NewValueBuilder( +absl::Nullable NewValueBuilder( Allocator<> allocator, absl::Nonnull descriptor_pool, absl::Nonnull message_factory, From c58e7388594e0224f556347faa9e6e13b22eaf08 Mon Sep 17 00:00:00 2001 From: Justin King Date: Wed, 6 Nov 2024 18:39:18 -0800 Subject: [PATCH 027/180] Refactor JSON-related well known types PiperOrigin-RevId: 693937881 --- internal/json.cc | 83 +++++++++++++----------------------- internal/well_known_types.cc | 48 ++++++++++++++++++--- internal/well_known_types.h | 64 +++++++++++++++++++++++---- 3 files changed, 128 insertions(+), 67 deletions(-) diff --git a/internal/json.cc b/internal/json.cc index aa5d6cce0..f557a5491 100644 --- a/internal/json.cc +++ b/internal/json.cc @@ -59,13 +59,11 @@ namespace { using ::cel::well_known_types::AsVariant; using ::cel::well_known_types::GetListValueReflection; -using ::cel::well_known_types::GetListValueReflectionOrDie; using ::cel::well_known_types::GetRepeatedBytesField; using ::cel::well_known_types::GetRepeatedStringField; using ::cel::well_known_types::GetStructReflection; -using ::cel::well_known_types::GetStructReflectionOrDie; using ::cel::well_known_types::GetValueReflection; -using ::cel::well_known_types::GetValueReflectionOrDie; +using ::cel::well_known_types::JsonReflection; using ::cel::well_known_types::ListValueReflection; using ::cel::well_known_types::Reflection; using ::cel::well_known_types::StructReflection; @@ -1105,92 +1103,86 @@ class DynamicMessageToJsonState final : public MessageToJsonState { absl::Status Initialize( absl::Nonnull message) override { - CEL_RETURN_IF_ERROR(value_reflection_.Initialize( + CEL_RETURN_IF_ERROR(reflection_.Initialize( google::protobuf::DownCastMessage(message)->GetDescriptor())); - CEL_RETURN_IF_ERROR(list_value_reflection_.Initialize( - value_reflection_.GetListValueDescriptor())); - CEL_RETURN_IF_ERROR( - struct_reflection_.Initialize(value_reflection_.GetStructDescriptor())); return absl::OkStatus(); } private: void SetNullValue( absl::Nonnull message) const override { - value_reflection_.SetNullValue( + reflection_.Value().SetNullValue( google::protobuf::DownCastMessage(message)); } void SetBoolValue(absl::Nonnull message, bool value) const override { - value_reflection_.SetBoolValue( + reflection_.Value().SetBoolValue( google::protobuf::DownCastMessage(message), value); } void SetNumberValue(absl::Nonnull message, double value) const override { - value_reflection_.SetNumberValue( + reflection_.Value().SetNumberValue( google::protobuf::DownCastMessage(message), value); } void SetNumberValue(absl::Nonnull message, int64_t value) const override { - value_reflection_.SetNumberValue( + reflection_.Value().SetNumberValue( google::protobuf::DownCastMessage(message), value); } void SetNumberValue(absl::Nonnull message, uint64_t value) const override { - value_reflection_.SetNumberValue( + reflection_.Value().SetNumberValue( google::protobuf::DownCastMessage(message), value); } void SetStringValue(absl::Nonnull message, absl::string_view value) const override { - value_reflection_.SetStringValue( + reflection_.Value().SetStringValue( google::protobuf::DownCastMessage(message), value); } void SetStringValue(absl::Nonnull message, const absl::Cord& value) const override { - value_reflection_.SetStringValue( + reflection_.Value().SetStringValue( google::protobuf::DownCastMessage(message), value); } absl::Nonnull MutableListValue( absl::Nonnull message) const override { - return value_reflection_.MutableListValue( + return reflection_.Value().MutableListValue( google::protobuf::DownCastMessage(message)); } absl::Nonnull MutableStructValue( absl::Nonnull message) const override { - return value_reflection_.MutableStructValue( + return reflection_.Value().MutableStructValue( google::protobuf::DownCastMessage(message)); } void ReserveValues(absl::Nonnull message, int capacity) const override { - list_value_reflection_.ReserveValues( + reflection_.ListValue().ReserveValues( google::protobuf::DownCastMessage(message), capacity); } absl::Nonnull AddValues( absl::Nonnull message) const override { - return list_value_reflection_.AddValues( + return reflection_.ListValue().AddValues( google::protobuf::DownCastMessage(message)); } absl::Nonnull InsertField( absl::Nonnull message, absl::string_view name) const override { - return struct_reflection_.InsertField( + return reflection_.Struct().InsertField( google::protobuf::DownCastMessage(message), name); } - ValueReflection value_reflection_; - ListValueReflection list_value_reflection_; - StructReflection struct_reflection_; + JsonReflection reflection_; }; } // namespace @@ -1478,97 +1470,82 @@ class GeneratedJsonAccessor final : public JsonAccessor { class DynamicJsonAccessor final : public JsonAccessor { public: void InitializeValue(const google::protobuf::Message& message) { - value_reflection_ = GetValueReflectionOrDie(message.GetDescriptor()); - list_value_reflection_ = - GetListValueReflectionOrDie(value_reflection_.GetListValueDescriptor()); - struct_reflection_ = - GetStructReflectionOrDie(value_reflection_.GetStructDescriptor()); + ABSL_CHECK_OK(reflection_.Initialize(message.GetDescriptor())); // Crash OK } void InitializeListValue(const google::protobuf::Message& message) { - list_value_reflection_ = - GetListValueReflectionOrDie(message.GetDescriptor()); - value_reflection_ = - GetValueReflectionOrDie(list_value_reflection_.GetValueDescriptor()); - struct_reflection_ = - GetStructReflectionOrDie(value_reflection_.GetStructDescriptor()); + ABSL_CHECK_OK(reflection_.Initialize(message.GetDescriptor())); // Crash OK } void InitializeStruct(const google::protobuf::Message& message) { - struct_reflection_ = GetStructReflectionOrDie(message.GetDescriptor()); - value_reflection_ = - GetValueReflectionOrDie(struct_reflection_.GetValueDescriptor()); - list_value_reflection_ = - GetListValueReflectionOrDie(value_reflection_.GetListValueDescriptor()); + ABSL_CHECK_OK(reflection_.Initialize(message.GetDescriptor())); // Crash OK } google::protobuf::Value::KindCase GetKindCase( const google::protobuf::MessageLite& message) const override { - return value_reflection_.GetKindCase( + return reflection_.Value().GetKindCase( google::protobuf::DownCastMessage(message)); } bool GetBoolValue(const google::protobuf::MessageLite& message) const override { - return value_reflection_.GetBoolValue( + return reflection_.Value().GetBoolValue( google::protobuf::DownCastMessage(message)); } double GetNumberValue(const google::protobuf::MessageLite& message) const override { - return value_reflection_.GetNumberValue( + return reflection_.Value().GetNumberValue( google::protobuf::DownCastMessage(message)); } well_known_types::StringValue GetStringValue( const google::protobuf::MessageLite& message, std::string& scratch) const override { - return value_reflection_.GetStringValue( + return reflection_.Value().GetStringValue( google::protobuf::DownCastMessage(message), scratch); } const google::protobuf::MessageLite& GetListValue( const google::protobuf::MessageLite& message) const override { - return value_reflection_.GetListValue( + return reflection_.Value().GetListValue( google::protobuf::DownCastMessage(message)); } int ValuesSize(const google::protobuf::MessageLite& message) const override { - return list_value_reflection_.ValuesSize( + return reflection_.ListValue().ValuesSize( google::protobuf::DownCastMessage(message)); } const google::protobuf::MessageLite& Values(const google::protobuf::MessageLite& message, int index) const override { - return list_value_reflection_.Values( + return reflection_.ListValue().Values( google::protobuf::DownCastMessage(message), index); } const google::protobuf::MessageLite& GetStructValue( const google::protobuf::MessageLite& message) const override { - return value_reflection_.GetStructValue( + return reflection_.Value().GetStructValue( google::protobuf::DownCastMessage(message)); } int FieldsSize(const google::protobuf::MessageLite& message) const override { - return struct_reflection_.FieldsSize( + return reflection_.Struct().FieldsSize( google::protobuf::DownCastMessage(message)); } absl::Nullable FindField( const google::protobuf::MessageLite& message, absl::string_view name) const override { - return struct_reflection_.FindField( + return reflection_.Struct().FindField( google::protobuf::DownCastMessage(message), name); } JsonMapIterator IterateFields( const google::protobuf::MessageLite& message) const override { - return struct_reflection_.BeginFields( + return reflection_.Struct().BeginFields( google::protobuf::DownCastMessage(message)); } private: - ValueReflection value_reflection_; - ListValueReflection list_value_reflection_; - StructReflection struct_reflection_; + JsonReflection reflection_; }; std::string JsonStringDebugString(const well_known_types::StringValue& value) { diff --git a/internal/well_known_types.cc b/internal/well_known_types.cc index f18d11b03..2e9cae6c6 100644 --- a/internal/well_known_types.cc +++ b/internal/well_known_types.cc @@ -1737,6 +1737,47 @@ absl::StatusOr GetFieldMaskReflection( return reflection; } +absl::Status JsonReflection::Initialize( + absl::Nonnull pool) { + CEL_RETURN_IF_ERROR(Value().Initialize(pool)); + CEL_RETURN_IF_ERROR(ListValue().Initialize(pool)); + CEL_RETURN_IF_ERROR(Struct().Initialize(pool)); + return absl::OkStatus(); +} + +absl::Status JsonReflection::Initialize( + absl::Nonnull descriptor) { + switch (descriptor->well_known_type()) { + case google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE: + CEL_RETURN_IF_ERROR(Value().Initialize(descriptor)); + CEL_RETURN_IF_ERROR( + ListValue().Initialize(Value().GetListValueDescriptor())); + CEL_RETURN_IF_ERROR(Struct().Initialize(Value().GetStructDescriptor())); + return absl::OkStatus(); + case google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE: + CEL_RETURN_IF_ERROR(ListValue().Initialize(descriptor)); + CEL_RETURN_IF_ERROR(Value().Initialize(ListValue().GetValueDescriptor())); + CEL_RETURN_IF_ERROR(Struct().Initialize(Value().GetStructDescriptor())); + return absl::OkStatus(); + case google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT: + CEL_RETURN_IF_ERROR(Struct().Initialize(descriptor)); + CEL_RETURN_IF_ERROR(Value().Initialize(Struct().GetValueDescriptor())); + CEL_RETURN_IF_ERROR( + ListValue().Initialize(Value().GetListValueDescriptor())); + return absl::OkStatus(); + default: + return absl::InvalidArgumentError( + absl::StrCat("expected message to be JSON-like well known type: ", + descriptor->full_name(), " ", + WellKnownTypeToString(descriptor->well_known_type()))); + } +} + +bool JsonReflection::IsInitialized() const { + return Value().IsInitialized() && ListValue().IsInitialized() && + Struct().IsInitialized(); +} + namespace { [[maybe_unused]] ABSL_CONST_INIT absl::once_flag @@ -1781,9 +1822,7 @@ absl::Status Reflection::Initialize(absl::Nonnull pool) { CEL_RETURN_IF_ERROR(Any().Initialize(pool)); CEL_RETURN_IF_ERROR(Duration().Initialize(pool)); CEL_RETURN_IF_ERROR(Timestamp().Initialize(pool)); - CEL_RETURN_IF_ERROR(Value().Initialize(pool)); - CEL_RETURN_IF_ERROR(ListValue().Initialize(pool)); - CEL_RETURN_IF_ERROR(Struct().Initialize(pool)); + CEL_RETURN_IF_ERROR(Json().Initialize(pool)); // google.protobuf.FieldMask is not strictly mandatory, but we do have to // treat it specifically for JSON. So use it if we have it. if (const auto* descriptor = @@ -1802,8 +1841,7 @@ bool Reflection::IsInitialized() const { FloatValue().IsInitialized() && DoubleValue().IsInitialized() && BytesValue().IsInitialized() && StringValue().IsInitialized() && Any().IsInitialized() && Duration().IsInitialized() && - Timestamp().IsInitialized() && Value().IsInitialized() && - ListValue().IsInitialized() && Struct().IsInitialized(); + Timestamp().IsInitialized() && Json().IsInitialized(); } namespace { diff --git a/internal/well_known_types.h b/internal/well_known_types.h index fa4fe485c..2cef32a96 100644 --- a/internal/well_known_types.h +++ b/internal/well_known_types.h @@ -1383,6 +1383,44 @@ absl::StatusOr AdaptFromMessage( ABSL_ATTRIBUTE_LIFETIME_BOUND, std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND); +class JsonReflection final { + public: + JsonReflection() = default; + JsonReflection(const JsonReflection&) = default; + JsonReflection& operator=(const JsonReflection&) = default; + + absl::Status Initialize(absl::Nonnull pool); + + absl::Status Initialize(absl::Nonnull descriptor); + + bool IsInitialized() const; + + ValueReflection& Value() ABSL_ATTRIBUTE_LIFETIME_BOUND { return value_; } + + ListValueReflection& ListValue() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return list_value_; + } + + StructReflection& Struct() ABSL_ATTRIBUTE_LIFETIME_BOUND { return struct_; } + + const ValueReflection& Value() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return value_; + } + + const ListValueReflection& ListValue() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return list_value_; + } + + const StructReflection& Struct() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return struct_; + } + + private: + ValueReflection value_; + ListValueReflection list_value_; + StructReflection struct_; +}; + class Reflection final { public: Reflection() = default; @@ -1443,13 +1481,19 @@ class Reflection final { return timestamp_; } - ValueReflection& Value() ABSL_ATTRIBUTE_LIFETIME_BOUND { return value_; } + JsonReflection& Json() ABSL_ATTRIBUTE_LIFETIME_BOUND { return json_; } + + ValueReflection& Value() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return Json().Value(); + } ListValueReflection& ListValue() ABSL_ATTRIBUTE_LIFETIME_BOUND { - return list_value_; + return Json().ListValue(); } - StructReflection& Struct() ABSL_ATTRIBUTE_LIFETIME_BOUND { return struct_; } + StructReflection& Struct() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return Json().Struct(); + } FieldMaskReflection& FieldMask() ABSL_ATTRIBUTE_LIFETIME_BOUND { return field_mask_; @@ -1507,16 +1551,20 @@ class Reflection final { return timestamp_; } + const JsonReflection& Json() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return json_; + } + const ValueReflection& Value() const ABSL_ATTRIBUTE_LIFETIME_BOUND { - return value_; + return Json().Value(); } const ListValueReflection& ListValue() const ABSL_ATTRIBUTE_LIFETIME_BOUND { - return list_value_; + return Json().ListValue(); } const StructReflection& Struct() const ABSL_ATTRIBUTE_LIFETIME_BOUND { - return struct_; + return Json().Struct(); } const FieldMaskReflection& FieldMask() const ABSL_ATTRIBUTE_LIFETIME_BOUND { @@ -1545,9 +1593,7 @@ class Reflection final { AnyReflection any_; DurationReflection duration_; TimestampReflection timestamp_; - ValueReflection value_; - ListValueReflection list_value_; - StructReflection struct_; + JsonReflection json_; FieldMaskReflection field_mask_; }; From 5ce19bab9adf3099b1fcb5c88b0c5583dfe5ad0e Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Thu, 7 Nov 2024 08:11:40 -0800 Subject: [PATCH 028/180] Update conformance to correctly handle applying function declarations from the test file to the TypeCheckerBuilder. PiperOrigin-RevId: 694123081 --- conformance/service.cc | 7 ++++++- conformance/value_conversion.cc | 4 +++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/conformance/service.cc b/conformance/service.cc index da5ca39a7..55736ef45 100644 --- a/conformance/service.cc +++ b/conformance/service.cc @@ -655,10 +655,15 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { } for (const auto& param : overload_pb.params()) { CEL_ASSIGN_OR_RETURN(auto param_type, - FromConformanceType(arena, param.type())); + FromConformanceType(arena, param)); overload.mutable_args().push_back(param_type); } + CEL_ASSIGN_OR_RETURN( + auto return_type, + FromConformanceType(arena, overload_pb.result_type())); + overload.set_result(return_type); + CEL_RETURN_IF_ERROR(fn_decl.AddOverload(std::move(overload))); } CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(fn_decl))); diff --git a/conformance/value_conversion.cc b/conformance/value_conversion.cc index 9c8f8c361..34b7785f1 100644 --- a/conformance/value_conversion.cc +++ b/conformance/value_conversion.cc @@ -421,7 +421,9 @@ absl::StatusOr FromConformanceType(google::protobuf::Arena* arena, FromConformanceType(arena, param)); parameters.push_back(std::move(param_type)); } - return OpaqueType(arena, type.abstract_type().name(), parameters); + const auto* name = google::protobuf::Arena::Create( + arena, type.abstract_type().name()); + return OpaqueType(arena, *name, parameters); } default: return absl::UnimplementedError(absl::StrCat( From f15d8e6e2d02f8ef24726ff1069905791b8f8a4f Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Thu, 7 Nov 2024 08:24:49 -0800 Subject: [PATCH 029/180] Add documentation for baseline test utility, update formatting to match java in a few cases. PiperOrigin-RevId: 694126823 --- parser/parser_test.cc | 2 +- testutil/BUILD | 1 + testutil/baseline_tests.cc | 4 +++- testutil/baseline_tests.h | 4 ++++ testutil/baseline_tests_test.cc | 2 +- testutil/expr_printer.cc | 4 ++++ 6 files changed, 14 insertions(+), 3 deletions(-) diff --git a/parser/parser_test.cc b/parser/parser_test.cc index 0838fbfff..af31800bf 100644 --- a/parser/parser_test.cc +++ b/parser/parser_test.cc @@ -100,7 +100,7 @@ std::vector test_cases = { {"x * 2.0", "_*_(\n" " x^#1:Expr.Ident#,\n" - " 2.^#3:double#\n" + " 2.0^#3:double#\n" ")^#2:Expr.Call#"}, {"\"\\u2764\"", "\"\u2764\"^#1:string#"}, {"\"\u2764\"", "\"\u2764\"^#1:string#"}, diff --git a/testutil/BUILD b/testutil/BUILD index 96124bb06..eb62e3e8f 100644 --- a/testutil/BUILD +++ b/testutil/BUILD @@ -30,6 +30,7 @@ cc_library( "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], diff --git a/testutil/baseline_tests.cc b/testutil/baseline_tests.cc index 18ef9cd7b..c636b4c02 100644 --- a/testutil/baseline_tests.cc +++ b/testutil/baseline_tests.cc @@ -15,6 +15,8 @@ #include "testutil/baseline_tests.h" #include +#include +#include #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" @@ -95,7 +97,7 @@ std::string FormatType(const AstType& t) { } else if (t.has_list_type()) { return absl::StrCat("list(", FormatType(t.list_type().elem_type()), ")"); } else if (t.has_map_type()) { - return absl::StrCat("map(", FormatType(t.map_type().key_type()), ",", + return absl::StrCat("map(", FormatType(t.map_type().key_type()), ", ", FormatType(t.map_type().value_type()), ")"); } return ""; diff --git a/testutil/baseline_tests.h b/testutil/baseline_tests.h index bcb7852a2..35d85de4c 100644 --- a/testutil/baseline_tests.h +++ b/testutil/baseline_tests.h @@ -46,8 +46,12 @@ namespace cel::test { +// Returns a string representation of the AST that matches the baseline format +// used in tests across the CEL libraries. std::string FormatBaselineAst(const Ast& ast); +// Returns a string representation of the protobuf AST that matches the baseline +// format used in tests across the CEL libraries. std::string FormatBaselineCheckedExpr( const cel::expr::CheckedExpr& checked); diff --git a/testutil/baseline_tests_test.cc b/testutil/baseline_tests_test.cc index 325597d42..bfc80c12e 100644 --- a/testutil/baseline_tests_test.cc +++ b/testutil/baseline_tests_test.cc @@ -212,7 +212,7 @@ INSTANTIATE_TEST_SUITE_P( std::make_unique(ast_internal::PrimitiveType::kString), std::make_unique( ast_internal::PrimitiveType::kString))), - "x~map(string,string)"}, + "x~map(string, string)"}, TestCase{AstType(ast_internal::ListType(std::make_unique( ast_internal::PrimitiveType::kString))), "x~list(string)"})); diff --git a/testutil/expr_printer.cc b/testutil/expr_printer.cc index d38d7aa77..1c0fd8819 100644 --- a/testutil/expr_printer.cc +++ b/testutil/expr_printer.cc @@ -21,6 +21,7 @@ #include "absl/base/no_destructor.h" #include "absl/log/absl_log.h" #include "absl/status/statusor.h" +#include "absl/strings/match.h" #include "absl/strings/str_format.h" #include "base/ast_internal/ast_impl.h" #include "common/ast.h" @@ -285,6 +286,9 @@ class StringBuilder { auto idx = std::find_if_not(s.rbegin(), s.rend(), [](const char c) { return c == '0'; }); s.erase(idx.base(), s.end()); + if (absl::EndsWith(s, ".")) { + s += '0'; + } return s; } case ConstantKindCase::kInt: From 410d43926545bf6242fe3450bf8522666b82a1f7 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Thu, 7 Nov 2024 10:07:22 -0800 Subject: [PATCH 030/180] Add interfaces and initial implementation for cel::Compiler. PiperOrigin-RevId: 694158596 --- checker/BUILD | 1 + checker/type_checker_builder.cc | 5 +- checker/type_checker_builder.h | 9 +- checker/validation_result.h | 13 +++ compiler/BUILD | 78 +++++++++++++ compiler/compiler.h | 117 +++++++++++++++++++ compiler/compiler_factory.cc | 135 +++++++++++++++++++++ compiler/compiler_factory.h | 70 +++++++++++ compiler/compiler_factory_test.cc | 187 ++++++++++++++++++++++++++++++ 9 files changed, 610 insertions(+), 5 deletions(-) create mode 100644 compiler/BUILD create mode 100644 compiler/compiler.h create mode 100644 compiler/compiler_factory.cc create mode 100644 compiler/compiler_factory.h create mode 100644 compiler/compiler_factory_test.cc diff --git a/checker/BUILD b/checker/BUILD index 7a5ffab13..df9049a12 100644 --- a/checker/BUILD +++ b/checker/BUILD @@ -49,6 +49,7 @@ cc_library( deps = [ ":type_check_issue", "//common:ast", + "//common:source", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/checker/type_checker_builder.cc b/checker/type_checker_builder.cc index aada156ed..f7c2e9064 100644 --- a/checker/type_checker_builder.cc +++ b/checker/type_checker_builder.cc @@ -117,7 +117,10 @@ absl::Status TypeCheckerBuilder::AddLibrary(CheckerLibrary library) { return absl::AlreadyExistsError( absl::StrCat("library '", library.id, "' already exists")); } - absl::Status status = library.options(*this); + if (!library.configure) { + return absl::OkStatus(); + } + absl::Status status = library.configure(*this); libraries_.push_back(std::move(library)); return status; diff --git a/checker/type_checker_builder.h b/checker/type_checker_builder.h index 1253c0cae..f4b3386a7 100644 --- a/checker/type_checker_builder.h +++ b/checker/type_checker_builder.h @@ -66,15 +66,16 @@ absl::StatusOr CreateTypeCheckerBuilder( descriptor_pool, const CheckerOptions& options = {}); -using ConfigureBuilderCallback = - absl::AnyInvocable; +// Functional implementation to apply the library features to a +// TypeCheckerBuilder. +using TypeCheckerBuilderConfigurer = + absl::AnyInvocable; struct CheckerLibrary { // Optional identifier to avoid collisions re-adding the same declarations. // If id is empty, it is not considered. std::string id; - // Functional implementation applying the library features to the builder. - ConfigureBuilderCallback options; + TypeCheckerBuilderConfigurer configure; }; // Builder for TypeChecker instances. diff --git a/checker/validation_result.h b/checker/validation_result.h index a094915e7..c5ed50b35 100644 --- a/checker/validation_result.h +++ b/checker/validation_result.h @@ -25,6 +25,7 @@ #include "absl/types/span.h" #include "checker/type_check_issue.h" #include "common/ast.h" +#include "common/source.h" namespace cel { @@ -56,9 +57,21 @@ class ValidationResult { absl::Span GetIssues() const { return issues_; } + // The source expression may optionally be set if it is available. + absl::Nullable GetSource() const { return source_.get(); } + + void SetSource(std::unique_ptr source) { + source_ = std::move(source); + } + + absl::Nullable> ReleaseSource() { + return std::move(source_); + } + private: absl::Nullable> ast_; std::vector issues_; + absl::Nullable> source_; }; } // namespace cel diff --git a/compiler/BUILD b/compiler/BUILD new file mode 100644 index 000000000..7af72d5cd --- /dev/null +++ b/compiler/BUILD @@ -0,0 +1,78 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "compiler", + hdrs = ["compiler.h"], + deps = [ + "//checker:checker_options", + "//checker:type_checker_builder", + "//checker:validation_result", + "//parser:options", + "//parser:parser_interface", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_library( + name = "compiler_factory", + srcs = ["compiler_factory.cc"], + hdrs = ["compiler_factory.h"], + deps = [ + ":compiler", + "//checker:type_checker", + "//checker:type_checker_builder", + "//checker:validation_result", + "//common:source", + "//internal:noop_delete", + "//internal:status_macros", + "//parser", + "//parser:parser_interface", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "compiler_factory_test", + srcs = ["compiler_factory_test.cc"], + deps = [ + ":compiler", + ":compiler_factory", + "//checker:optional", + "//checker:standard_library", + "//checker:type_check_issue", + "//checker:validation_result", + "//common:decl", + "//common:type", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//parser:macro", + "//parser:parser_interface", + "//testutil:baseline_tests", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_protobuf//:protobuf", + ], +) diff --git a/compiler/compiler.h b/compiler/compiler.h new file mode 100644 index 000000000..1ea80d9da --- /dev/null +++ b/compiler/compiler.h @@ -0,0 +1,117 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMPILER_COMPILER_INTERFACE_H_ +#define THIRD_PARTY_CEL_CPP_COMPILER_COMPILER_INTERFACE_H_ + +#include +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "checker/checker_options.h" +#include "checker/type_checker_builder.h" +#include "checker/validation_result.h" +#include "parser/options.h" +#include "parser/parser_interface.h" + +namespace cel { + +class Compiler; +class CompilerBuilder; + +// Callable for configuring a ParserBuilder. +using ParserBuilderConfigurer = + absl::AnyInvocable; + +// A CompilerLibrary represents a package of CEL configuration that can be +// added to a Compiler. +// +// It may contain either or both of a Parser configuration and a +// TypeChecker configuration. +struct CompilerLibrary { + // Optional identifier to avoid collisions re-adding the same library. + // If id is empty, it is not considered. + std::string id; + // Optional callback for configuring the parser. + ParserBuilderConfigurer configure_parser; + // Optional callback for configuring the type checker. + TypeCheckerBuilderConfigurer configure_checker; + + CompilerLibrary(std::string id, ParserBuilderConfigurer configure_parser, + TypeCheckerBuilderConfigurer configure_checker = nullptr) + : id(std::move(id)), + configure_parser(std::move(configure_parser)), + configure_checker(std::move(configure_checker)) {} + + CompilerLibrary(std::string id, + TypeCheckerBuilderConfigurer configure_checker) + : id(std::move(id)), + configure_parser(std::move(nullptr)), + configure_checker(std::move(configure_checker)) {} + + // Convenience conversion from the CheckerLibrary type. + // NOLINTNEXTLINE(google-explicit-constructor) + CompilerLibrary(CheckerLibrary checker_library) + : id(std::move(checker_library.id)), + configure_parser(nullptr), + configure_checker(std::move(checker_library.configure)) {} +}; + +// General options for configuring the underlying parser and checker. +struct CompilerOptions { + ParserOptions parser_options; + CheckerOptions checker_options; +}; + +// Interface for CEL CompilerBuilder objects. +// +// Builder implementations are thread hostile, but should create +// thread-compatible Compiler instances. +class CompilerBuilder { + public: + virtual ~CompilerBuilder() = default; + + virtual absl::Status AddLibrary(cel::CompilerLibrary library) = 0; + + virtual TypeCheckerBuilder& GetCheckerBuilder() = 0; + virtual ParserBuilder& GetParserBuilder() = 0; + + virtual absl::StatusOr> Build() && = 0; +}; + +// Interface for CEL Compiler objects. +// +// For CEL, compilation is the process of bundling the parse and type-check +// passes. +// +// Compiler instances should be thread-compatible. +class Compiler { + public: + virtual ~Compiler() = default; + + virtual absl::StatusOr Compile( + absl::string_view source, absl::string_view description) const = 0; + + absl::StatusOr Compile(absl::string_view source) const { + return Compile(source, ""); + } +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMPILER_COMPILER_INTERFACE_H_ diff --git a/compiler/compiler_factory.cc b/compiler/compiler_factory.cc new file mode 100644 index 000000000..4a1f3209f --- /dev/null +++ b/compiler/compiler_factory.cc @@ -0,0 +1,135 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/compiler_factory.h" + +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "checker/type_checker.h" +#include "checker/type_checker_builder.h" +#include "checker/validation_result.h" +#include "common/source.h" +#include "compiler/compiler.h" +#include "internal/status_macros.h" +#include "parser/parser.h" +#include "parser/parser_interface.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +namespace { + +class CompilerImpl : public Compiler { + public: + CompilerImpl(std::unique_ptr type_checker, + std::unique_ptr parser) + : type_checker_(std::move(type_checker)), parser_(std::move(parser)) {} + + absl::StatusOr Compile( + absl::string_view expression, + absl::string_view description) const override { + CEL_ASSIGN_OR_RETURN(auto source, + cel::NewSource(expression, std::string(description))); + CEL_ASSIGN_OR_RETURN(auto ast, parser_->Parse(*source)); + CEL_ASSIGN_OR_RETURN(ValidationResult result, + type_checker_->Check(std::move(ast))); + + result.SetSource(std::move(source)); + return result; + } + + private: + std::unique_ptr type_checker_; + std::unique_ptr parser_; +}; + +class CompilerBuilderImpl : public CompilerBuilder { + public: + CompilerBuilderImpl(TypeCheckerBuilder type_checker_builder, + std::unique_ptr parser_builder) + : type_checker_builder_(std::move(type_checker_builder)), + parser_builder_(std::move(parser_builder)) {} + + absl::Status AddLibrary(CompilerLibrary library) override { + if (!library.id.empty()) { + auto [it, inserted] = library_ids_.insert(library.id); + + if (!inserted) { + return absl::AlreadyExistsError( + absl::StrCat("library already exists: ", library.id)); + } + } + + if (library.configure_checker) { + CEL_RETURN_IF_ERROR(type_checker_builder_.AddLibrary({ + .id = std::move(library.id), + .configure = std::move(library.configure_checker), + })); + } + if (library.configure_parser) { + parser_libraries_.push_back(std::move(library.configure_parser)); + } + return absl::OkStatus(); + } + + ParserBuilder& GetParserBuilder() override { return *parser_builder_; } + TypeCheckerBuilder& GetCheckerBuilder() override { + return type_checker_builder_; + } + + absl::StatusOr> Build() && override { + for (const auto& library : parser_libraries_) { + CEL_RETURN_IF_ERROR(library(*parser_builder_)); + } + CEL_ASSIGN_OR_RETURN(auto parser, std::move(*parser_builder_).Build()); + CEL_ASSIGN_OR_RETURN(auto type_checker, + std::move(type_checker_builder_).Build()); + return std::make_unique(std::move(type_checker), + std::move(parser)); + } + + private: + TypeCheckerBuilder type_checker_builder_; + std::unique_ptr parser_builder_; + + absl::flat_hash_set library_ids_; + std::vector parser_libraries_; +}; + +} // namespace + +absl::StatusOr> NewCompilerBuilder( + std::shared_ptr descriptor_pool, + CompilerOptions options) { + if (descriptor_pool == nullptr) { + return absl::InvalidArgumentError("descriptor_pool must not be null"); + } + CEL_ASSIGN_OR_RETURN(auto type_checker_builder, + CreateTypeCheckerBuilder(std::move(descriptor_pool), + options.checker_options)); + auto parser_builder = NewParserBuilder(options.parser_options); + + return std::make_unique(std::move(type_checker_builder), + std::move(parser_builder)); +} + +} // namespace cel diff --git a/compiler/compiler_factory.h b/compiler/compiler_factory.h new file mode 100644 index 000000000..5339be4c1 --- /dev/null +++ b/compiler/compiler_factory.h @@ -0,0 +1,70 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMPILER_COMPILER_FACTORY_H_ +#define THIRD_PARTY_CEL_CPP_COMPILER_COMPILER_FACTORY_H_ + +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/statusor.h" +#include "compiler/compiler.h" +#include "internal/noop_delete.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +// Creates a new unconfigured CompilerBuilder for creating a new CEL Compiler +// instance. +// +// The builder is thread-hostile and intended to be configured by a single +// thread, but the created Compiler instances are thread-compatible (and +// effectively immutable). +// +// The descriptor pool must include the standard definitions for the protobuf +// well-known types: +// - google.protobuf.NullValue +// - google.protobuf.BoolValue +// - google.protobuf.Int32Value +// - google.protobuf.Int64Value +// - google.protobuf.UInt32Value +// - google.protobuf.UInt64Value +// - google.protobuf.FloatValue +// - google.protobuf.DoubleValue +// - google.protobuf.BytesValue +// - google.protobuf.StringValue +// - google.protobuf.Any +// - google.protobuf.Duration +// - google.protobuf.Timestamp +absl::StatusOr> NewCompilerBuilder( + std::shared_ptr descriptor_pool, + CompilerOptions options = {}); + +// Convenience overload for non-owning pointers (such as the generated pool). +// The descriptor pool must outlive the compiler builder and any compiler +// instances it builds. +inline absl::StatusOr> NewCompilerBuilder( + absl::Nonnull descriptor_pool, + CompilerOptions options = {}) { + return NewCompilerBuilder( + std::shared_ptr( + descriptor_pool, + internal::NoopDeleteFor()), + std::move(options)); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMPILER_COMPILER_FACTORY_H_ diff --git a/compiler/compiler_factory_test.cc b/compiler/compiler_factory_test.cc new file mode 100644 index 000000000..85c2711b7 --- /dev/null +++ b/compiler/compiler_factory_test.cc @@ -0,0 +1,187 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/compiler_factory.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "checker/optional.h" +#include "checker/standard_library.h" +#include "checker/type_check_issue.h" +#include "checker/validation_result.h" +#include "common/decl.h" +#include "common/type.h" +#include "compiler/compiler.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "parser/macro.h" +#include "parser/parser_interface.h" +#include "testutil/baseline_tests.h" +#include "google/protobuf/descriptor.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::cel::test::FormatBaselineAst; +using ::testing::Contains; +using ::testing::HasSubstr; +using ::testing::Property; + +TEST(CompilerFactoryTest, Works) { + ASSERT_OK_AND_ASSIGN( + auto builder, + NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool())); + + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + + ASSERT_OK_AND_ASSIGN(auto compiler, std::move(*builder).Build()); + + ASSERT_OK_AND_ASSIGN( + ValidationResult result, + compiler->Compile("['a', 'b', 'c'].exists(x, x in ['c', 'd', 'e']) && 10 " + "< (5 % 3 * 2 + 1 - 2)")); + + ASSERT_TRUE(result.IsValid()); + + EXPECT_EQ(FormatBaselineAst(*result.GetAst()), + R"(_&&_( + __comprehension__( + // Variable + x, + // Target + [ + "a"~string, + "b"~string, + "c"~string + ]~list(string), + // Accumulator + __result__, + // Init + false~bool, + // LoopCondition + @not_strictly_false( + !_( + __result__~bool^__result__ + )~bool^logical_not + )~bool^not_strictly_false, + // LoopStep + _||_( + __result__~bool^__result__, + @in( + x~string^x, + [ + "c"~string, + "d"~string, + "e"~string + ]~list(string) + )~bool^in_list + )~bool^logical_or, + // Result + __result__~bool^__result__)~bool, + _<_( + 10~int, + _-_( + _+_( + _*_( + _%_( + 5~int, + 3~int + )~int^modulo_int64, + 2~int + )~int^multiply_int64, + 1~int + )~int^add_int64, + 2~int + )~int^subtract_int64 + )~bool^less_int64 +)~bool^logical_and)"); +} + +TEST(CompilerFactoryTest, ParserLibrary) { + ASSERT_OK_AND_ASSIGN( + auto builder, + NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool())); + + ASSERT_THAT( + builder->AddLibrary({"test", + [](ParserBuilder& builder) -> absl::Status { + builder.GetOptions().disable_standard_macros = + true; + return builder.AddMacro(cel::HasMacro()); + }}), + IsOk()); + + ASSERT_THAT(builder->GetCheckerBuilder().AddVariable( + MakeVariableDecl("a", MapType())), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto compiler, std::move(*builder).Build()); + + ASSERT_THAT(compiler->Compile("has(a.b)"), IsOk()); + + ASSERT_OK_AND_ASSIGN(ValidationResult result, + compiler->Compile("[].map(x, x)")); + + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.GetIssues(), + Contains(Property(&TypeCheckIssue::message, + HasSubstr("undeclared reference to 'map'")))) + << result.GetIssues()[2].message(); +} + +TEST(CompilerFactoryTest, ParserOptions) { + ASSERT_OK_AND_ASSIGN( + auto builder, + NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool())); + + builder->GetParserBuilder().GetOptions().enable_optional_syntax = true; + ASSERT_THAT(builder->AddLibrary(OptionalCheckerLibrary()), IsOk()); + + ASSERT_THAT(builder->GetCheckerBuilder().AddVariable( + MakeVariableDecl("a", MapType())), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto compiler, std::move(*builder).Build()); + + ASSERT_THAT(compiler->Compile("a.?b.orValue('foo')"), IsOk()); +} + +TEST(CompilerFactoryTest, FailsIfLibraryAddedTwice) { + ASSERT_OK_AND_ASSIGN( + auto builder, + NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool())); + + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), + StatusIs(absl::StatusCode::kAlreadyExists, + HasSubstr("library already exists: stdlib"))); +} + +TEST(CompilerFactoryTest, FailsIfNullDescriptorPool) { + std::shared_ptr pool = + internal::GetSharedTestingDescriptorPool(); + pool.reset(); + ASSERT_THAT( + NewCompilerBuilder(std::move(pool)), + absl_testing::StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("descriptor_pool must not be null"))); +} + +} // namespace +} // namespace cel From 74ba25a626f987e043f1dcbf16c5787f7337ba0d Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Thu, 7 Nov 2024 11:51:34 -0800 Subject: [PATCH 031/180] Add compiler library for optionals. PiperOrigin-RevId: 694198328 --- compiler/BUILD | 36 +++++ compiler/optional.cc | 38 ++++++ compiler/optional.h | 26 ++++ compiler/optional_test.cc | 275 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 375 insertions(+) create mode 100644 compiler/optional.cc create mode 100644 compiler/optional.h create mode 100644 compiler/optional_test.cc diff --git a/compiler/BUILD b/compiler/BUILD index 7af72d5cd..7cfe940e5 100644 --- a/compiler/BUILD +++ b/compiler/BUILD @@ -76,3 +76,39 @@ cc_test( "@com_google_protobuf//:protobuf", ], ) + +cc_library( + name = "optional", + srcs = ["optional.cc"], + hdrs = ["optional.h"], + deps = [ + ":compiler", + "//checker:optional", + "//checker:type_checker_builder", + "//parser:parser_interface", + "@com_google_absl//absl/status", + ], +) + +cc_test( + name = "optional_test", + srcs = ["optional_test.cc"], + deps = [ + ":compiler", + ":compiler_factory", + ":optional", + "//checker:standard_library", + "//checker:type_check_issue", + "//checker:validation_result", + "//common:decl", + "//common:source", + "//common:type", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//testutil:baseline_tests", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", + ], +) diff --git a/compiler/optional.cc b/compiler/optional.cc new file mode 100644 index 000000000..785833989 --- /dev/null +++ b/compiler/optional.cc @@ -0,0 +1,38 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/optional.h" + +#include + +#include "absl/status/status.h" +#include "checker/optional.h" +#include "checker/type_checker_builder.h" +#include "compiler/compiler.h" +#include "parser/parser_interface.h" + +namespace cel { + +CompilerLibrary OptionalCompilerLibrary() { + CheckerLibrary checker_library = OptionalCheckerLibrary(); + return CompilerLibrary( + std::move(checker_library.id), + [](ParserBuilder& builder) { + builder.GetOptions().enable_optional_syntax = true; + return absl::OkStatus(); + }, + std::move(checker_library.configure)); +} + +} // namespace cel diff --git a/compiler/optional.h b/compiler/optional.h new file mode 100644 index 000000000..cc804ddbd --- /dev/null +++ b/compiler/optional.h @@ -0,0 +1,26 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#ifndef THIRD_PARTY_CEL_CPP_COMPILER_OPTIONALS_H_ +#define THIRD_PARTY_CEL_CPP_COMPILER_OPTIONALS_H_ + +#include "compiler/compiler.h" + +namespace cel { + +// CompilerLibrary that enables support for CEL optional types. +CompilerLibrary OptionalCompilerLibrary(); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMPILER_OPTIONALS_H_ diff --git a/compiler/optional_test.cc b/compiler/optional_test.cc new file mode 100644 index 000000000..e26f1d1f3 --- /dev/null +++ b/compiler/optional_test.cc @@ -0,0 +1,275 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "compiler/optional.h" + +#include +#include +#include + +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "checker/standard_library.h" +#include "checker/type_check_issue.h" +#include "checker/validation_result.h" +#include "common/decl.h" +#include "common/source.h" +#include "common/type.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "testutil/baseline_tests.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::cel::expr::conformance::proto3::TestAllTypes; +using ::cel::test::FormatBaselineAst; +using ::testing::HasSubstr; + +struct TestCase { + std::string expr; + std::string expected_ast; +}; + +class OptionalTest : public testing::TestWithParam {}; + +std::string FormatIssues(const ValidationResult& result) { + const Source* source = result.GetSource(); + return absl::StrJoin( + result.GetIssues(), "\n", + [=](std::string* out, const TypeCheckIssue& issue) { + absl::StrAppend( + out, (source) ? issue.ToDisplayString(*source) : issue.message()); + }); +} + +TEST_P(OptionalTest, OptionalsEnabled) { + const TestCase& test_case = GetParam(); + + ASSERT_OK_AND_ASSIGN( + auto builder, + NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool())); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + ASSERT_THAT(builder->AddLibrary(OptionalCompilerLibrary()), IsOk()); + ASSERT_THAT(builder->GetCheckerBuilder().AddVariable(MakeVariableDecl( + "msg", MessageType(TestAllTypes::descriptor()))), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto compiler, std::move(*builder).Build()); + + absl::StatusOr maybe_result = + compiler->Compile(test_case.expr); + + ASSERT_OK_AND_ASSIGN(ValidationResult result, std::move(maybe_result)); + ASSERT_TRUE(result.IsValid()) << FormatIssues(result); + EXPECT_EQ(FormatBaselineAst(*result.GetAst()), + absl::StripAsciiWhitespace(test_case.expected_ast)) + << test_case.expr; +} + +INSTANTIATE_TEST_SUITE_P( + OptionalTest, OptionalTest, + ::testing::Values( + TestCase{ + .expr = "msg.?single_int64", + .expected_ast = R"( +_?._( + msg~cel.expr.conformance.proto3.TestAllTypes^msg, + "single_int64" +)~optional_type(int)^select_optional_field)", + }, + TestCase{ + .expr = "optional.of('foo')", + .expected_ast = R"( +optional.of( + "foo"~string +)~optional_type(string)^optional_of)", + }, + TestCase{ + .expr = "optional.of('foo').optMap(x, x)", + .expected_ast = R"( +_?_:_( + optional.of( + "foo"~string + )~optional_type(string)^optional_of.hasValue()~bool^optional_hasValue, + optional.of( + __comprehension__( + // Variable + #unused, + // Target + []~list(dyn), + // Accumulator + x, + // Init + optional.of( + "foo"~string + )~optional_type(string)^optional_of.value()~string^optional_value, + // LoopCondition + false~bool, + // LoopStep + x~string^x, + // Result + x~string^x)~string + )~optional_type(string)^optional_of, + optional.none()~optional_type(string)^optional_none +)~optional_type(string)^conditional +)", + }, + TestCase{ + .expr = "optional.of('foo').optFlatMap(x, optional.of(x))", + .expected_ast = R"( +_?_:_( + optional.of( + "foo"~string + )~optional_type(string)^optional_of.hasValue()~bool^optional_hasValue, + __comprehension__( + // Variable + #unused, + // Target + []~list(dyn), + // Accumulator + x, + // Init + optional.of( + "foo"~string + )~optional_type(string)^optional_of.value()~string^optional_value, + // LoopCondition + false~bool, + // LoopStep + x~string^x, + // Result + optional.of( + x~string^x + )~optional_type(string)^optional_of)~optional_type(string), + optional.none()~optional_type(string)^optional_none +)~optional_type(string)^conditional +)", + }, + TestCase{ + .expr = "optional.ofNonZeroValue(1)", + .expected_ast = R"( +optional.ofNonZeroValue( + 1~int +)~optional_type(int)^optional_ofNonZeroValue +)", + }, + TestCase{ + .expr = "[0][?1]", + .expected_ast = R"( +_[?_]( + [ + 0~int + ]~list(int), + 1~int +)~optional_type(int)^list_optindex_optional_int +)", + }, + TestCase{ + .expr = "{0: 2}[?1]", + .expected_ast = R"( +_[?_]( + { + 0~int:2~int + }~map(int, int), + 1~int +)~optional_type(int)^map_optindex_optional_value +)", + }, + TestCase{ + .expr = "msg.?repeated_int64[1]", + .expected_ast = R"( +_[_]( + _?._( + msg~cel.expr.conformance.proto3.TestAllTypes^msg, + "repeated_int64" + )~optional_type(list(int))^select_optional_field, + 1~int +)~optional_type(int)^optional_list_index_int +)", + }, + TestCase{ + .expr = "msg.?map_int64_int64[1]", + .expected_ast = R"( +_[_]( + _?._( + msg~cel.expr.conformance.proto3.TestAllTypes^msg, + "map_int64_int64" + )~optional_type(map(int, int))^select_optional_field, + 1~int +)~optional_type(int)^optional_map_index_value +)", + }, + TestCase{ + .expr = "optional.of(1).or(optional.of(2))", + .expected_ast = R"( +optional.of( + 1~int +)~optional_type(int)^optional_of.or( + optional.of( + 2~int + )~optional_type(int)^optional_of +)~optional_type(int)^optional_or_optional)", + }, + TestCase{ + .expr = "optional.of(1).orValue(2)", + .expected_ast = R"( +optional.of( + 1~int +)~optional_type(int)^optional_of.orValue( + 2~int +)~int^optional_orValue_value +)", + }, + TestCase{ + .expr = "optional.of(1).value()", + .expected_ast = R"( +optional.of( + 1~int +)~optional_type(int)^optional_of.value()~int^optional_value +)", + }, + TestCase{ + .expr = "optional.of(1).hasValue()", + .expected_ast = R"( +optional.of( + 1~int +)~optional_type(int)^optional_of.hasValue()~bool^optional_hasValue +)", + })); + +TEST(OptionalTest, NotEnabled) { + ASSERT_OK_AND_ASSIGN( + auto builder, + NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool())); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + ASSERT_THAT(builder->GetCheckerBuilder().AddVariable(MakeVariableDecl( + "msg", MessageType(TestAllTypes::descriptor()))), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto compiler, std::move(*builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile("optional.of(1)")); + + EXPECT_THAT(FormatIssues(result), + HasSubstr("undeclared reference to 'optional'")); +} + +} // namespace +} // namespace cel From 2f43982fbeb5519b92f5d09069080f7d92a54092 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Thu, 7 Nov 2024 14:58:01 -0800 Subject: [PATCH 032/180] Fix gcc warning about shadowing base class virtual overload. Adapted from https://github.com/google/cel-cpp/pull/1048. PiperOrigin-RevId: 694258444 --- eval/public/containers/internal_field_backed_map_impl.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/eval/public/containers/internal_field_backed_map_impl.h b/eval/public/containers/internal_field_backed_map_impl.h index ec773d9d2..cf43866b7 100644 --- a/eval/public/containers/internal_field_backed_map_impl.h +++ b/eval/public/containers/internal_field_backed_map_impl.h @@ -45,6 +45,10 @@ class FieldBackedMapImpl : public CelMap { absl::StatusOr ListKeys() const override; + // Include base class definitions to avoid GCC warnings about hidden virtual + // overloads. + using CelMap::ListKeys; + protected: // These methods are exposed as protected methods for testing purposes since // whether one or the other is used depends on build time flags, but each From c4f9b589c5555cf93e36ccc76510594d9258bf7a Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Fri, 8 Nov 2024 18:43:54 -0800 Subject: [PATCH 033/180] Fix for slot calculation for block expressions. With this change, all comprehension variables inside cel.@block are assigned to a dedicated slot. It's hard to identify when they can be safely reused with lazy evaluation and likely doesn't provide enough benefit to try to reuse them more aggressively. PiperOrigin-RevId: 694717283 --- eval/compiler/flat_expr_builder.cc | 16 ++++++++++++++-- runtime/internal/BUILD | 2 +- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index e787f3411..af34b13a7 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -886,6 +886,7 @@ class FlatExprVisitor : public cel::AstVisitor { block.bindings_set.insert(&list_expr_element.expr()); } block.index = index_manager().ReserveSlots(block.size); + block.slot_count = block.size; block.expr = &expr; block.bindings = &call_expr.args()[0]; block.bound = &call_expr.args()[1]; @@ -1203,8 +1204,8 @@ class FlatExprVisitor : public cel::AstVisitor { BlockInfo& block = *block_; if (block.expr == &expr) { block.in = false; - index_manager().ReleaseSlots(block.size); - AddStep(CreateClearSlotsStep(block.index, block.size, -1)); + index_manager().ReleaseSlots(block.slot_count); + AddStep(CreateClearSlotsStep(block.index, block.slot_count, -1)); return; } } @@ -1269,6 +1270,7 @@ class FlatExprVisitor : public cel::AstVisitor { size_t iter_slot, accu_slot, slot_count; bool is_bind = IsBind(&comprehension); + if (is_bind) { accu_slot = iter_slot = index_manager_.ReserveSlots(1); slot_count = 1; @@ -1277,6 +1279,14 @@ class FlatExprVisitor : public cel::AstVisitor { accu_slot = iter_slot + 1; slot_count = 2; } + + if (block_.has_value()) { + BlockInfo& block = *block_; + if (block.in) { + block.slot_count += slot_count; + slot_count = 0; + } + } // If this is in the scope of an optimized bind accu-init, account the slots // to the outermost bind-init scope. // @@ -1707,6 +1717,8 @@ class FlatExprVisitor : public cel::AstVisitor { // Starting slot index for `cel.@block`. We occupy he slot indices `index` // through `index + size + (var_size * 2)`. size_t index = 0; + // The total number of slots needed for evaluating the bound expressions. + size_t slot_count = 0; // The current slot index we are processing, any index references must be // less than this to be valid. size_t current_index = 0; diff --git a/runtime/internal/BUILD b/runtime/internal/BUILD index 21500ca1a..a4577fd16 100644 --- a/runtime/internal/BUILD +++ b/runtime/internal/BUILD @@ -13,7 +13,7 @@ # limitations under the License. package( - # Under active development, not yet being released. + # Internals for cel/runtime. default_visibility = ["//visibility:public"], ) From 58c5ea9a3704ff54e9bb827f7320819a40c9a679 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Mon, 11 Nov 2024 06:28:50 -0800 Subject: [PATCH 034/180] Introduce upperAscii() string extension function PiperOrigin-RevId: 695311716 --- extensions/strings.cc | 13 +++++++++++ extensions/strings_test.cc | 48 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+) diff --git a/extensions/strings.cc b/extensions/strings.cc index d49b43817..7ac4c2d2b 100644 --- a/extensions/strings.cc +++ b/extensions/strings.cc @@ -216,6 +216,14 @@ absl::StatusOr LowerAscii(ValueManager& value_manager, return value_manager.CreateUncheckedStringValue(std::move(content)); } +absl::StatusOr UpperAscii(ValueManager& value_manager, + const StringValue& string) { + std::string content = string.NativeString(); + absl::AsciiStrToUpper(&content); + // We assume the original string was well-formed. + return value_manager.CreateUncheckedStringValue(std::move(content)); +} + absl::StatusOr Replace2(ValueManager& value_manager, const StringValue& string, const StringValue& old_sub, @@ -291,6 +299,11 @@ absl::Status RegisterStringsFunctions(FunctionRegistry& registry, CreateDescriptor("lowerAscii", /*receiver_style=*/true), UnaryFunctionAdapter, StringValue>::WrapFunction( LowerAscii))); + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter, StringValue>:: + CreateDescriptor("upperAscii", /*receiver_style=*/true), + UnaryFunctionAdapter, StringValue>::WrapFunction( + UpperAscii))); CEL_RETURN_IF_ERROR(registry.Register( VariadicFunctionAdapter< absl::StatusOr, StringValue, StringValue, diff --git a/extensions/strings_test.cc b/extensions/strings_test.cc index 16e48b113..c41157676 100644 --- a/extensions/strings_test.cc +++ b/extensions/strings_test.cc @@ -173,5 +173,53 @@ TEST(Strings, ReplaceWithZeroLimit) { EXPECT_TRUE(result.GetBool().NativeValue()); } +TEST(Strings, LowerAscii) { + google::protobuf::Arena arena; + const auto options = RuntimeOptions{}; + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), options)); + EXPECT_THAT(RegisterStringsFunctions(builder.function_registry(), options), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, + Parse("'UPPER lower'.lowerAscii() == 'upper lower'", + "", ParserOptions{})); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); + + Activation activation; + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); + ASSERT_TRUE(result.Is()); + EXPECT_TRUE(result.GetBool().NativeValue()); +} + +TEST(Strings, UpperAscii) { + google::protobuf::Arena arena; + const auto options = RuntimeOptions{}; + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), options)); + EXPECT_THAT(RegisterStringsFunctions(builder.function_registry(), options), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, + Parse("'UPPER lower'.upperAscii() == 'UPPER LOWER'", + "", ParserOptions{})); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); + + Activation activation; + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); + ASSERT_TRUE(result.Is()); + EXPECT_TRUE(result.GetBool().NativeValue()); +} + } // namespace } // namespace cel::extensions From 818e9ced8612b9759b458c70db615df39254249a Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Mon, 11 Nov 2024 11:34:34 -0800 Subject: [PATCH 035/180] Add declarations for string extension functions. PiperOrigin-RevId: 695423779 --- conformance/BUILD | 1 - conformance/service.cc | 2 + extensions/BUILD | 10 ++++ extensions/strings.cc | 119 +++++++++++++++++++++++++++++++++++++ extensions/strings.h | 3 + extensions/strings_test.cc | 72 ++++++++++++++++++++++ 6 files changed, 206 insertions(+), 1 deletion(-) diff --git a/conformance/BUILD b/conformance/BUILD index f2de5277a..c0a41a3a7 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -318,7 +318,6 @@ gen_conformance_tests( modern = True, skip_tests = _TESTS_TO_SKIP_MODERN + [ # TODO: Need to add function declarations for these extensions. - "string_ext", "math_ext", "encoders_ext", # block is a post-check optimization that inserts internal variables. The C++ type checker diff --git a/conformance/service.cc b/conformance/service.cc index 55736ef45..5fdd8f86b 100644 --- a/conformance/service.cc +++ b/conformance/service.cc @@ -640,6 +640,8 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { if (!request.no_std_env()) { CEL_RETURN_IF_ERROR(builder.AddLibrary(cel::StandardCheckerLibrary())); CEL_RETURN_IF_ERROR(builder.AddLibrary(cel::OptionalCheckerLibrary())); + CEL_RETURN_IF_ERROR( + builder.AddLibrary(cel::extensions::StringsCheckerLibrary())); } for (const auto& decl : request.type_env()) { diff --git a/extensions/BUILD b/extensions/BUILD index 3b428eefd..0ebc0a603 100644 --- a/extensions/BUILD +++ b/extensions/BUILD @@ -329,7 +329,10 @@ cc_library( srcs = ["strings.cc"], hdrs = ["strings.h"], deps = [ + "//checker:type_checker_builder", + "//checker/internal:builtins_arena", "//common:casting", + "//common:decl", "//common:type", "//common:value", "//eval/public:cel_function_registry", @@ -340,6 +343,7 @@ cc_library( "//runtime:function_registry", "//runtime:runtime_options", "//runtime/internal:errors", + "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -353,7 +357,12 @@ cc_test( srcs = ["strings_test.cc"], deps = [ ":strings", + "//checker:standard_library", + "//checker:type_checker_builder", + "//checker:validation_result", + "//common:decl", "//common:value", + "//compiler:compiler_factory", "//extensions/protobuf:runtime_adapter", "//internal:testing", "//internal:testing_descriptor_pool", @@ -364,6 +373,7 @@ cc_test( "//runtime:runtime_builder", "//runtime:runtime_options", "//runtime:standard_runtime_builder_factory", + "//testutil:baseline_tests", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/strings:cord", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", diff --git a/extensions/strings.cc b/extensions/strings.cc index 7ac4c2d2b..535416261 100644 --- a/extensions/strings.cc +++ b/extensions/strings.cc @@ -21,12 +21,16 @@ #include #include +#include "absl/base/no_destructor.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/ascii.h" #include "absl/strings/cord.h" #include "absl/strings/string_view.h" +#include "checker/internal/builtins_arena.h" +#include "checker/type_checker_builder.h" #include "common/casting.h" +#include "common/decl.h" #include "common/type.h" #include "common/value.h" #include "common/value_manager.h" @@ -43,6 +47,8 @@ namespace cel::extensions { namespace { +using ::cel::checker_internal::BuiltinsArena; + struct AppendToStringVisitor { std::string& append_to; @@ -269,6 +275,115 @@ absl::StatusOr Replace1(ValueManager& value_manager, return Replace2(value_manager, string, old_sub, new_sub, -1); } +const Type& ListStringType() { + static absl::NoDestructor kInstance( + ListType(BuiltinsArena(), StringType())); + return *kInstance; +} + +absl::Status RegisterStringsDecls(TypeCheckerBuilder& builder) { + // Runtime Supported functions. + CEL_ASSIGN_OR_RETURN( + auto join_decl, + MakeFunctionDecl( + "join", + MakeMemberOverloadDecl("list_join", StringType(), ListStringType()), + MakeMemberOverloadDecl("list_join_string", StringType(), + ListStringType(), StringType()))); + CEL_ASSIGN_OR_RETURN( + auto split_decl, + MakeFunctionDecl( + "split", + MakeMemberOverloadDecl("string_split_string", ListStringType(), + StringType(), StringType()), + MakeMemberOverloadDecl("string_split_string_int", ListStringType(), + StringType(), StringType(), IntType()))); + CEL_ASSIGN_OR_RETURN( + auto lower_decl, + MakeFunctionDecl("lowerAscii", + MakeMemberOverloadDecl("string_lower_ascii", + StringType(), StringType()))); + + CEL_ASSIGN_OR_RETURN( + auto replace_decl, + MakeFunctionDecl( + "replace", + MakeMemberOverloadDecl("string_replace_string_string", StringType(), + StringType(), StringType(), StringType()), + MakeMemberOverloadDecl("string_replace_string_string_int", + StringType(), StringType(), StringType(), + StringType(), IntType()))); + + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(join_decl))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(split_decl))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(lower_decl))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(replace_decl))); + + // Additional functions described in the spec. + CEL_ASSIGN_OR_RETURN( + auto char_at_decl, + MakeFunctionDecl( + "charAt", MakeMemberOverloadDecl("string_char_at_int", StringType(), + StringType(), IntType()))); + CEL_ASSIGN_OR_RETURN( + auto index_of_decl, + MakeFunctionDecl( + "indexOf", + MakeMemberOverloadDecl("string_index_of_string", IntType(), + StringType(), StringType()), + MakeMemberOverloadDecl("string_index_of_string_int", IntType(), + StringType(), StringType(), IntType()))); + CEL_ASSIGN_OR_RETURN( + auto last_index_of_decl, + MakeFunctionDecl( + "lastIndexOf", + MakeMemberOverloadDecl("string_last_index_of_string", IntType(), + StringType(), StringType()), + MakeMemberOverloadDecl("string_last_index_of_string_int", IntType(), + StringType(), StringType(), IntType()))); + + CEL_ASSIGN_OR_RETURN( + auto substring_decl, + MakeFunctionDecl( + "substring", + MakeMemberOverloadDecl("string_substring_int", StringType(), + StringType(), IntType()), + MakeMemberOverloadDecl("string_substring_int_int", StringType(), + StringType(), IntType(), IntType()))); + CEL_ASSIGN_OR_RETURN( + auto upper_ascii_decl, + MakeFunctionDecl("upperAscii", + MakeMemberOverloadDecl("string_upper_ascii", + StringType(), StringType()))); + CEL_ASSIGN_OR_RETURN( + auto format_decl, + MakeFunctionDecl("format", + MakeMemberOverloadDecl("string_format", StringType(), + StringType(), ListType()))); + CEL_ASSIGN_OR_RETURN( + auto quote_decl, + MakeFunctionDecl( + "strings.quote", + MakeOverloadDecl("strings_quote", StringType(), StringType()))); + + CEL_ASSIGN_OR_RETURN( + auto reverse_decl, + MakeFunctionDecl("reverse", + MakeMemberOverloadDecl("string_reverse", StringType(), + StringType()))); + + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(char_at_decl))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(index_of_decl))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(last_index_of_decl))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(substring_decl))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(upper_ascii_decl))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(format_decl))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(quote_decl))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(reverse_decl))); + + return absl::OkStatus(); +} + } // namespace absl::Status RegisterStringsFunctions(FunctionRegistry& registry, @@ -327,4 +442,8 @@ absl::Status RegisterStringsFunctions( google::api::expr::runtime::ConvertToRuntimeOptions(options)); } +CheckerLibrary StringsCheckerLibrary() { + return {"strings", &RegisterStringsDecls}; +} + } // namespace cel::extensions diff --git a/extensions/strings.h b/extensions/strings.h index 4db2ab4ab..44f4a997e 100644 --- a/extensions/strings.h +++ b/extensions/strings.h @@ -16,6 +16,7 @@ #define THIRD_PARTY_CEL_CPP_EXTENSIONS_STRINGS_H_ #include "absl/status/status.h" +#include "checker/type_checker_builder.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" #include "runtime/function_registry.h" @@ -31,6 +32,8 @@ absl::Status RegisterStringsFunctions( google::api::expr::runtime::CelFunctionRegistry* registry, const google::api::expr::runtime::InterpreterOptions& options); +CheckerLibrary StringsCheckerLibrary(); + } // namespace cel::extensions #endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_STRINGS_H_ diff --git a/extensions/strings_test.cc b/extensions/strings_test.cc index c41157676..cb793e6f6 100644 --- a/extensions/strings_test.cc +++ b/extensions/strings_test.cc @@ -20,7 +20,12 @@ #include "cel/expr/syntax.pb.h" #include "absl/status/status_matchers.h" #include "absl/strings/cord.h" +#include "checker/standard_library.h" +#include "checker/type_checker_builder.h" +#include "checker/validation_result.h" +#include "common/decl.h" #include "common/value.h" +#include "compiler/compiler_factory.h" #include "extensions/protobuf/runtime_adapter.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" @@ -31,6 +36,7 @@ #include "runtime/runtime_builder.h" #include "runtime/runtime_options.h" #include "runtime/standard_runtime_builder_factory.h" +#include "testutil/baseline_tests.h" #include "google/protobuf/arena.h" namespace cel::extensions { @@ -40,6 +46,7 @@ using ::absl_testing::IsOk; using ::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; using ::google::api::expr::parser::ParserOptions; +using ::testing::Values; TEST(Strings, SplitWithEmptyDelimiterCord) { google::protobuf::Arena arena; @@ -221,5 +228,70 @@ TEST(Strings, UpperAscii) { EXPECT_TRUE(result.GetBool().NativeValue()); } +TEST(StringsCheckerLibrary, SmokeTest) { + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN( + auto builder, NewCompilerBuilder(internal::GetTestingDescriptorPool())); + ASSERT_THAT(builder->AddLibrary(StringsCheckerLibrary()), IsOk()); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + ASSERT_THAT(builder->GetCheckerBuilder().AddVariable( + MakeVariableDecl("foo", StringType())), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto compiler, std::move(*builder).Build()); + + ASSERT_OK_AND_ASSIGN( + ValidationResult result, + compiler->Compile("foo.replace('he', 'we', 1) == 'wello hello'")); + ASSERT_TRUE(result.IsValid()); + + EXPECT_EQ(test::FormatBaselineAst(*result.GetAst()), + R"(_==_( + foo~string^foo.replace( + "he"~string, + "we"~string, + 1~int + )~string^string_replace_string_string_int, + "wello hello"~string +)~bool^equals)"); +} + +// Basic test for the included declarations. +// Additional coverage for behavior in the spec tests. +class StringsCheckerLibraryTest : public ::testing::TestWithParam { +}; + +TEST_P(StringsCheckerLibraryTest, TypeChecks) { + const std::string& expr = GetParam(); + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN( + auto builder, NewCompilerBuilder(internal::GetTestingDescriptorPool())); + ASSERT_THAT(builder->AddLibrary(StringsCheckerLibrary()), IsOk()); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + + ASSERT_OK_AND_ASSIGN(auto compiler, std::move(*builder).Build()); + + ASSERT_OK_AND_ASSIGN(ValidationResult result, compiler->Compile(expr)); + EXPECT_TRUE(result.IsValid()) << "Failed to compile: " << expr; +} + +INSTANTIATE_TEST_SUITE_P( + Expressions, StringsCheckerLibraryTest, + Values("['a', 'b', 'c'].join() == 'abc'", + "['a', 'b', 'c'].join('|') == 'a|b|c'", + "'a|b|c'.split('|') == ['a', 'b', 'c']", + "'a|b|c'.split('|', 1) == ['a', 'b|c']", + "'a|b|c'.split('|') == ['a', 'b', 'c']", + "'AbC'.lowerAscii() == 'abc'", + "'tacocat'.replace('cat', 'dog') == 'tacodog'", + "'tacocat'.replace('aco', 'an', 2) == 'tacocat'", + "'tacocat'.charAt(2) == 'c'", "'tacocat'.indexOf('c') == 2", + "'tacocat'.indexOf('c', 3) == 4", "'tacocat'.lastIndexOf('c') == 4", + "'tacocat'.lastIndexOf('c', 5) == -1", + "'tacocat'.substring(1) == 'acocat'", + "'tacocat'.substring(1, 3) == 'aco'", "'aBc'.upperAscii() == 'ABC'", + "'abc %d'.format([2]) == 'abc 2'", + "strings.quote('abc') == \"'abc 2'\"", "'abc'.reverse() == 'cba'")); + } // namespace } // namespace cel::extensions From 2db04afae3c33fb99e93214f3fc88b77466f138a Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Mon, 11 Nov 2024 17:19:10 -0800 Subject: [PATCH 036/180] Refactor extensions/math_ext to use register helpers for function bindings. PiperOrigin-RevId: 695534639 --- extensions/math_ext.cc | 333 +++++++++++++++++------------------------ 1 file changed, 140 insertions(+), 193 deletions(-) diff --git a/extensions/math_ext.cc b/extensions/math_ext.cc index 85c89f6ec..68f1d6010 100644 --- a/extensions/math_ext.cc +++ b/extensions/math_ext.cc @@ -159,30 +159,26 @@ absl::StatusOr MaxList(ValueManager& value_manager, template absl::Status RegisterCrossNumericMin(FunctionRegistry& registry) { - CEL_RETURN_IF_ERROR(registry.Register( - BinaryFunctionAdapter::CreateDescriptor( - kMathMin, /*receiver_style=*/false), - BinaryFunctionAdapter::WrapFunction(Min))); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterGlobalOverload( + kMathMin, Min, registry))); - CEL_RETURN_IF_ERROR(registry.Register( - BinaryFunctionAdapter::CreateDescriptor( - kMathMin, /*receiver_style=*/false), - BinaryFunctionAdapter::WrapFunction(Min))); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterGlobalOverload( + kMathMin, Min, registry))); return absl::OkStatus(); } template absl::Status RegisterCrossNumericMax(FunctionRegistry& registry) { - CEL_RETURN_IF_ERROR(registry.Register( - BinaryFunctionAdapter::CreateDescriptor( - kMathMax, /*receiver_style=*/false), - BinaryFunctionAdapter::WrapFunction(Max))); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterGlobalOverload( + kMathMax, Max, registry))); - CEL_RETURN_IF_ERROR(registry.Register( - BinaryFunctionAdapter::CreateDescriptor( - kMathMax, /*receiver_style=*/false), - BinaryFunctionAdapter::WrapFunction(Max))); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterGlobalOverload( + kMathMax, Max, registry))); return absl::OkStatus(); } @@ -303,189 +299,140 @@ Value BitShiftRightUint(ValueManager&, uint64_t lhs, int64_t rhs) { absl::Status RegisterMathExtensionFunctions(FunctionRegistry& registry, const RuntimeOptions& options) { - CEL_RETURN_IF_ERROR(registry.Register( - UnaryFunctionAdapter::CreateDescriptor( - kMathMin, /*receiver_style=*/false), - UnaryFunctionAdapter::WrapFunction(Identity))); - CEL_RETURN_IF_ERROR(registry.Register( - UnaryFunctionAdapter::CreateDescriptor( - kMathMin, /*receiver_style=*/false), - UnaryFunctionAdapter::WrapFunction(Identity))); - CEL_RETURN_IF_ERROR(registry.Register( - UnaryFunctionAdapter::CreateDescriptor( - kMathMin, /*receiver_style=*/false), - UnaryFunctionAdapter::WrapFunction(Identity))); - CEL_RETURN_IF_ERROR(registry.Register( - BinaryFunctionAdapter::CreateDescriptor( - kMathMin, /*receiver_style=*/false), - BinaryFunctionAdapter::WrapFunction( - Min))); - CEL_RETURN_IF_ERROR(registry.Register( - BinaryFunctionAdapter::CreateDescriptor( - kMathMin, /*receiver_style=*/false), - BinaryFunctionAdapter::WrapFunction( - Min))); - CEL_RETURN_IF_ERROR(registry.Register( - BinaryFunctionAdapter::CreateDescriptor( - kMathMin, /*receiver_style=*/false), - BinaryFunctionAdapter::WrapFunction( - Min))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + kMathMin, Identity, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + kMathMin, Identity, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + kMathMin, Identity, registry))); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterGlobalOverload( + kMathMin, Min, registry))); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterGlobalOverload( + kMathMin, Min, registry))); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterGlobalOverload( + kMathMin, Min, registry))); CEL_RETURN_IF_ERROR((RegisterCrossNumericMin(registry))); CEL_RETURN_IF_ERROR((RegisterCrossNumericMin(registry))); CEL_RETURN_IF_ERROR((RegisterCrossNumericMin(registry))); - CEL_RETURN_IF_ERROR(registry.Register( - UnaryFunctionAdapter, ListValue>::CreateDescriptor( - kMathMin, false), - UnaryFunctionAdapter, ListValue>::WrapFunction( - MinList))); - - CEL_RETURN_IF_ERROR(registry.Register( - UnaryFunctionAdapter::CreateDescriptor( - kMathMax, /*receiver_style=*/false), - UnaryFunctionAdapter::WrapFunction(Identity))); - CEL_RETURN_IF_ERROR(registry.Register( - UnaryFunctionAdapter::CreateDescriptor( - kMathMax, /*receiver_style=*/false), - UnaryFunctionAdapter::WrapFunction(Identity))); - CEL_RETURN_IF_ERROR(registry.Register( - UnaryFunctionAdapter::CreateDescriptor( - kMathMax, /*receiver_style=*/false), - UnaryFunctionAdapter::WrapFunction(Identity))); - CEL_RETURN_IF_ERROR(registry.Register( - BinaryFunctionAdapter::CreateDescriptor( - kMathMax, /*receiver_style=*/false), - BinaryFunctionAdapter::WrapFunction( - Max))); - CEL_RETURN_IF_ERROR(registry.Register( - BinaryFunctionAdapter::CreateDescriptor( - kMathMax, /*receiver_style=*/false), - BinaryFunctionAdapter::WrapFunction( - Max))); - CEL_RETURN_IF_ERROR(registry.Register( - BinaryFunctionAdapter::CreateDescriptor( - kMathMax, /*receiver_style=*/false), - BinaryFunctionAdapter::WrapFunction( - Max))); + CEL_RETURN_IF_ERROR(( + UnaryFunctionAdapter, + ListValue>::RegisterGlobalOverload(kMathMin, MinList, + registry))); + + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + kMathMax, Identity, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + kMathMax, Identity, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + kMathMax, Identity, registry))); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterGlobalOverload( + kMathMax, Max, registry))); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterGlobalOverload( + kMathMax, Max, registry))); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterGlobalOverload( + kMathMax, Max, registry))); CEL_RETURN_IF_ERROR((RegisterCrossNumericMax(registry))); CEL_RETURN_IF_ERROR((RegisterCrossNumericMax(registry))); CEL_RETURN_IF_ERROR((RegisterCrossNumericMax(registry))); - CEL_RETURN_IF_ERROR(registry.Register( - UnaryFunctionAdapter, ListValue>::CreateDescriptor( - kMathMax, false), - UnaryFunctionAdapter, ListValue>::WrapFunction( - MaxList))); - - CEL_RETURN_IF_ERROR(registry.Register( - UnaryFunctionAdapter::CreateDescriptor( - "math.ceil", /*receiver_style=*/false), - UnaryFunctionAdapter::WrapFunction(CeilDouble))); - CEL_RETURN_IF_ERROR(registry.Register( - UnaryFunctionAdapter::CreateDescriptor( - "math.floor", /*receiver_style=*/false), - UnaryFunctionAdapter::WrapFunction(FloorDouble))); - CEL_RETURN_IF_ERROR(registry.Register( - UnaryFunctionAdapter::CreateDescriptor( - "math.round", /*receiver_style=*/false), - UnaryFunctionAdapter::WrapFunction(RoundDouble))); - CEL_RETURN_IF_ERROR(registry.Register( - UnaryFunctionAdapter::CreateDescriptor( - "math.trunc", /*receiver_style=*/false), - UnaryFunctionAdapter::WrapFunction(TruncDouble))); - CEL_RETURN_IF_ERROR(registry.Register( - UnaryFunctionAdapter::CreateDescriptor( - "math.isInf", /*receiver_style=*/false), - UnaryFunctionAdapter::WrapFunction(IsInfDouble))); - CEL_RETURN_IF_ERROR(registry.Register( - UnaryFunctionAdapter::CreateDescriptor( - "math.isNaN", /*receiver_style=*/false), - UnaryFunctionAdapter::WrapFunction(IsNaNDouble))); - CEL_RETURN_IF_ERROR(registry.Register( - UnaryFunctionAdapter::CreateDescriptor( - "math.isFinite", /*receiver_style=*/false), - UnaryFunctionAdapter::WrapFunction(IsFiniteDouble))); - CEL_RETURN_IF_ERROR(registry.Register( - UnaryFunctionAdapter::CreateDescriptor( - "math.abs", /*receiver_style=*/false), - UnaryFunctionAdapter::WrapFunction(AbsDouble))); - CEL_RETURN_IF_ERROR(registry.Register( - UnaryFunctionAdapter::CreateDescriptor( - "math.abs", /*receiver_style=*/false), - UnaryFunctionAdapter::WrapFunction(AbsInt))); - CEL_RETURN_IF_ERROR(registry.Register( - UnaryFunctionAdapter::CreateDescriptor( - "math.abs", /*receiver_style=*/false), - UnaryFunctionAdapter::WrapFunction(AbsUint))); - CEL_RETURN_IF_ERROR(registry.Register( - UnaryFunctionAdapter::CreateDescriptor( - "math.sign", /*receiver_style=*/false), - UnaryFunctionAdapter::WrapFunction(SignDouble))); - CEL_RETURN_IF_ERROR(registry.Register( - UnaryFunctionAdapter::CreateDescriptor( - "math.sign", /*receiver_style=*/false), - UnaryFunctionAdapter::WrapFunction(SignInt))); - CEL_RETURN_IF_ERROR(registry.Register( - UnaryFunctionAdapter::CreateDescriptor( - "math.sign", /*receiver_style=*/false), - UnaryFunctionAdapter::WrapFunction(SignUint))); - - CEL_RETURN_IF_ERROR(registry.Register( - BinaryFunctionAdapter::CreateDescriptor( - "math.bitAnd", /*receiver_style=*/false), - BinaryFunctionAdapter::WrapFunction( - BitAndInt))); - CEL_RETURN_IF_ERROR(registry.Register( - BinaryFunctionAdapter::CreateDescriptor( - "math.bitAnd", /*receiver_style=*/false), - BinaryFunctionAdapter::WrapFunction( - BitAndUint))); - CEL_RETURN_IF_ERROR(registry.Register( - BinaryFunctionAdapter::CreateDescriptor( - "math.bitOr", /*receiver_style=*/false), - BinaryFunctionAdapter::WrapFunction( - BitOrInt))); - CEL_RETURN_IF_ERROR(registry.Register( - BinaryFunctionAdapter::CreateDescriptor( - "math.bitOr", /*receiver_style=*/false), - BinaryFunctionAdapter::WrapFunction( - BitOrUint))); - CEL_RETURN_IF_ERROR(registry.Register( - BinaryFunctionAdapter::CreateDescriptor( - "math.bitXor", /*receiver_style=*/false), - BinaryFunctionAdapter::WrapFunction( - BitXorInt))); - CEL_RETURN_IF_ERROR(registry.Register( - BinaryFunctionAdapter::CreateDescriptor( - "math.bitXor", /*receiver_style=*/false), - BinaryFunctionAdapter::WrapFunction( - BitXorUint))); - CEL_RETURN_IF_ERROR(registry.Register( - UnaryFunctionAdapter::CreateDescriptor( - "math.bitNot", /*receiver_style=*/false), - UnaryFunctionAdapter::WrapFunction(BitNotInt))); - CEL_RETURN_IF_ERROR(registry.Register( - UnaryFunctionAdapter::CreateDescriptor( - "math.bitNot", /*receiver_style=*/false), - UnaryFunctionAdapter::WrapFunction(BitNotUint))); - CEL_RETURN_IF_ERROR(registry.Register( - BinaryFunctionAdapter::CreateDescriptor( - "math.bitShiftLeft", /*receiver_style=*/false), - BinaryFunctionAdapter::WrapFunction( - BitShiftLeftInt))); - CEL_RETURN_IF_ERROR(registry.Register( - BinaryFunctionAdapter::CreateDescriptor( - "math.bitShiftLeft", /*receiver_style=*/false), - BinaryFunctionAdapter::WrapFunction( - BitShiftLeftUint))); - CEL_RETURN_IF_ERROR(registry.Register( - BinaryFunctionAdapter::CreateDescriptor( - "math.bitShiftRight", /*receiver_style=*/false), - BinaryFunctionAdapter::WrapFunction( - BitShiftRightInt))); - CEL_RETURN_IF_ERROR(registry.Register( - BinaryFunctionAdapter::CreateDescriptor( - "math.bitShiftRight", /*receiver_style=*/false), - BinaryFunctionAdapter::WrapFunction( - BitShiftRightUint))); + CEL_RETURN_IF_ERROR(( + UnaryFunctionAdapter, + ListValue>::RegisterGlobalOverload(kMathMax, MaxList, + registry))); + + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.ceil", CeilDouble, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.floor", FloorDouble, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.round", RoundDouble, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.trunc", TruncDouble, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.isInf", IsInfDouble, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.isNaN", IsNaNDouble, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.isFinite", IsFiniteDouble, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.abs", AbsDouble, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.abs", AbsInt, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.abs", AbsUint, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.sign", SignDouble, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.sign", SignInt, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.sign", SignUint, registry))); + + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterGlobalOverload( + "math.bitAnd", BitAndInt, registry))); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterGlobalOverload("math.bitAnd", + BitAndUint, + registry))); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterGlobalOverload( + "math.bitOr", BitOrInt, registry))); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterGlobalOverload("math.bitOr", + BitOrUint, + registry))); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterGlobalOverload( + "math.bitXor", BitXorInt, registry))); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterGlobalOverload("math.bitXor", + BitXorUint, + registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.bitNot", BitNotInt, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.bitNot", BitNotUint, registry))); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterGlobalOverload( + "math.bitShiftLeft", BitShiftLeftInt, registry))); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterGlobalOverload( + "math.bitShiftLeft", BitShiftLeftUint, registry))); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterGlobalOverload( + "math.bitShiftRight", BitShiftRightInt, registry))); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterGlobalOverload( + "math.bitShiftRight", BitShiftRightUint, registry))); return absl::OkStatus(); } From 01cfac5b0f1ac6366014e015cd61f9f5c62d3762 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 12 Nov 2024 10:07:18 -0800 Subject: [PATCH 037/180] Add type checker declarations for the math extension library. PiperOrigin-RevId: 695777983 --- conformance/BUILD | 5 +- conformance/service.cc | 3 + extensions/BUILD | 33 +++ extensions/math_ext_decls.cc | 291 ++++++++++++++++++++++++++ extensions/math_ext_decls.h | 31 +++ extensions/math_ext_test.cc | 388 +++++++++++++++++++++++------------ 6 files changed, 619 insertions(+), 132 deletions(-) create mode 100644 extensions/math_ext_decls.cc create mode 100644 extensions/math_ext_decls.h diff --git a/conformance/BUILD b/conformance/BUILD index c0a41a3a7..a4d2d082b 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -74,6 +74,7 @@ cc_library( "//extensions:bindings_ext", "//extensions:encoders", "//extensions:math_ext", + "//extensions:math_ext_decls", "//extensions:math_ext_macros", "//extensions:proto_ext", "//extensions:strings", @@ -318,14 +319,10 @@ gen_conformance_tests( modern = True, skip_tests = _TESTS_TO_SKIP_MODERN + [ # TODO: Need to add function declarations for these extensions. - "math_ext", "encoders_ext", # block is a post-check optimization that inserts internal variables. The C++ type checker # needs support for a proper optimizer for this to work. "block_ext", - # Test has a typo, C++ conformance runner doesn't accept declaring a message type that isn't - # known. - "dynamic/any/var", ], ) diff --git a/conformance/service.cc b/conformance/service.cc index 5fdd8f86b..a02fbf8c5 100644 --- a/conformance/service.cc +++ b/conformance/service.cc @@ -62,6 +62,7 @@ #include "extensions/bindings_ext.h" #include "extensions/encoders.h" #include "extensions/math_ext.h" +#include "extensions/math_ext_decls.h" #include "extensions/math_ext_macros.h" #include "extensions/proto_ext.h" #include "extensions/protobuf/ast_converters.h" @@ -642,6 +643,8 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { CEL_RETURN_IF_ERROR(builder.AddLibrary(cel::OptionalCheckerLibrary())); CEL_RETURN_IF_ERROR( builder.AddLibrary(cel::extensions::StringsCheckerLibrary())); + CEL_RETURN_IF_ERROR( + builder.AddLibrary(cel::extensions::MathCheckerLibrary())); } for (const auto& decl : request.type_env()) { diff --git a/extensions/BUILD b/extensions/BUILD index 0ebc0a603..64c64f421 100644 --- a/extensions/BUILD +++ b/extensions/BUILD @@ -83,12 +83,38 @@ cc_library( ], ) +cc_library( + name = "math_ext_decls", + srcs = ["math_ext_decls.cc"], + hdrs = ["math_ext_decls.h"], + deps = [ + ":math_ext_macros", + "//checker:type_checker_builder", + "//checker/internal:builtins_arena", + "//common:decl", + "//common:type", + "//common:type_kind", + "//compiler", + "//internal:status_macros", + "//parser:parser_interface", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], +) + cc_test( name = "math_ext_test", srcs = ["math_ext_test.cc"], deps = [ ":math_ext", + ":math_ext_decls", ":math_ext_macros", + "//base:function_descriptor", + "//checker:standard_library", + "//checker:validation_result", + "//common:decl", + "//compiler:compiler_factory", "//eval/public:activation", "//eval/public:builtin_func_registrar", "//eval/public:cel_expr_builder_factory", @@ -99,8 +125,15 @@ cc_test( "//eval/public/containers:container_backed_list_impl", "//eval/public/testing:matchers", "//internal:testing", + "//internal:testing_descriptor_pool", "//parser", + "//runtime:activation", + "//runtime:runtime_options", + "//runtime:standard_runtime_builder_factory", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", diff --git a/extensions/math_ext_decls.cc b/extensions/math_ext_decls.cc new file mode 100644 index 000000000..cf3b0d273 --- /dev/null +++ b/extensions/math_ext_decls.cc @@ -0,0 +1,291 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/math_ext_decls.h" + +#include + +#include "absl/base/no_destructor.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "checker/internal/builtins_arena.h" +#include "checker/type_checker_builder.h" +#include "common/decl.h" +#include "common/type.h" +#include "common/type_kind.h" +#include "compiler/compiler.h" +#include "extensions/math_ext_macros.h" +#include "internal/status_macros.h" +#include "parser/parser_interface.h" + +namespace cel::extensions { +namespace { + +constexpr char kMathExtensionName[] = "cel.lib.ext.math"; + +const Type& ListIntType() { + static absl::NoDestructor kInstance( + ListType(checker_internal::BuiltinsArena(), IntType())); + return *kInstance; +} + +const Type& ListDoubleType() { + static absl::NoDestructor kInstance( + ListType(checker_internal::BuiltinsArena(), DoubleType())); + return *kInstance; +} + +const Type& ListUintType() { + static absl::NoDestructor kInstance( + ListType(checker_internal::BuiltinsArena(), UintType())); + return *kInstance; +} + +std::string OverloadTypeName(const Type& type) { + switch (type.kind()) { + case cel::TypeKind::kInt: + return "int"; + case TypeKind::kDouble: + return "double"; + case TypeKind::kUint: + return "uint"; + case TypeKind::kList: + return absl::StrCat("list_", + OverloadTypeName(type.AsList()->GetElement())); + default: + return "unsupported"; + } +} + +absl::Status AddMinMaxDecls(TypeCheckerBuilder& builder) { + const Type kNumerics[] = {IntType(), DoubleType(), UintType()}; + const Type kListNumerics[] = {ListIntType(), ListDoubleType(), + ListUintType()}; + + constexpr char kMinOverloadPrefix[] = "math_@min_"; + constexpr char kMaxOverloadPrefix[] = "math_@max_"; + + FunctionDecl min_decl; + min_decl.set_name("math.@min"); + + FunctionDecl max_decl; + max_decl.set_name("math.@max"); + + for (const Type& type : kNumerics) { + CEL_RETURN_IF_ERROR(min_decl.AddOverload(MakeOverloadDecl( + absl::StrCat(kMinOverloadPrefix, OverloadTypeName(type)), type, type))); + + CEL_RETURN_IF_ERROR(max_decl.AddOverload(MakeOverloadDecl( + absl::StrCat(kMaxOverloadPrefix, OverloadTypeName(type)), type, type))); + + for (const Type& other_type : kNumerics) { + CEL_RETURN_IF_ERROR(min_decl.AddOverload(MakeOverloadDecl( + absl::StrCat(kMinOverloadPrefix, OverloadTypeName(type), "_", + OverloadTypeName(other_type)), + DynType(), type, other_type))); + + CEL_RETURN_IF_ERROR(max_decl.AddOverload(MakeOverloadDecl( + absl::StrCat(kMaxOverloadPrefix, OverloadTypeName(type), "_", + OverloadTypeName(other_type)), + DynType(), type, other_type))); + } + } + + for (const Type& type : kListNumerics) { + CEL_RETURN_IF_ERROR(min_decl.AddOverload(MakeOverloadDecl( + absl::StrCat(kMinOverloadPrefix, OverloadTypeName(type)), + type.AsList()->GetElement(), type))); + + CEL_RETURN_IF_ERROR(max_decl.AddOverload(MakeOverloadDecl( + absl::StrCat(kMaxOverloadPrefix, OverloadTypeName(type)), + type.AsList()->GetElement(), type))); + } + + CEL_RETURN_IF_ERROR(builder.AddFunction(min_decl)); + CEL_RETURN_IF_ERROR(builder.AddFunction(max_decl)); + + return absl::OkStatus(); +} + +absl::Status AddSignednessDecls(TypeCheckerBuilder& builder) { + const Type kNumerics[] = {IntType(), DoubleType(), UintType()}; + + FunctionDecl sign_decl; + sign_decl.set_name("math.sign"); + + FunctionDecl abs_decl; + abs_decl.set_name("math.abs"); + + for (const Type& type : kNumerics) { + CEL_RETURN_IF_ERROR(sign_decl.AddOverload(MakeOverloadDecl( + absl::StrCat("math_sign_", OverloadTypeName(type)), type, type))); + CEL_RETURN_IF_ERROR(abs_decl.AddOverload(MakeOverloadDecl( + absl::StrCat("math_abs_", OverloadTypeName(type)), type, type))); + } + + CEL_RETURN_IF_ERROR(builder.AddFunction(sign_decl)); + CEL_RETURN_IF_ERROR(builder.AddFunction(abs_decl)); + + return absl::OkStatus(); +} + +absl::Status AddFloatingPointDecls(TypeCheckerBuilder& builder) { + // Rounding + CEL_ASSIGN_OR_RETURN( + auto ceil_decl, + MakeFunctionDecl( + "math.ceil", + MakeOverloadDecl("math_ceil_double", DoubleType(), DoubleType()))); + + CEL_ASSIGN_OR_RETURN( + auto floor_decl, + MakeFunctionDecl( + "math.floor", + MakeOverloadDecl("math_floor_double", DoubleType(), DoubleType()))); + + CEL_ASSIGN_OR_RETURN( + auto round_decl, + MakeFunctionDecl( + "math.round", + MakeOverloadDecl("math_round_double", DoubleType(), DoubleType()))); + CEL_ASSIGN_OR_RETURN( + auto trunc_decl, + MakeFunctionDecl( + "math.trunc", + MakeOverloadDecl("math_trunc_double", DoubleType(), DoubleType()))); + + // FP helpers + CEL_ASSIGN_OR_RETURN( + auto is_inf_decl, + MakeFunctionDecl( + "math.isInf", + MakeOverloadDecl("math_isInf_double", BoolType(), DoubleType()))); + + CEL_ASSIGN_OR_RETURN( + auto is_nan_decl, + MakeFunctionDecl( + "math.isNaN", + MakeOverloadDecl("math_isNaN_double", BoolType(), DoubleType()))); + + CEL_ASSIGN_OR_RETURN( + auto is_finite_decl, + MakeFunctionDecl( + "math.isFinite", + MakeOverloadDecl("math_isFinite_double", BoolType(), DoubleType()))); + + CEL_RETURN_IF_ERROR(builder.AddFunction(ceil_decl)); + CEL_RETURN_IF_ERROR(builder.AddFunction(floor_decl)); + CEL_RETURN_IF_ERROR(builder.AddFunction(round_decl)); + CEL_RETURN_IF_ERROR(builder.AddFunction(trunc_decl)); + CEL_RETURN_IF_ERROR(builder.AddFunction(is_inf_decl)); + CEL_RETURN_IF_ERROR(builder.AddFunction(is_nan_decl)); + CEL_RETURN_IF_ERROR(builder.AddFunction(is_finite_decl)); + + return absl::OkStatus(); +} + +absl::Status AddBitwiseDecls(TypeCheckerBuilder& builder) { + const Type kBitwiseTypes[] = {IntType(), UintType()}; + + FunctionDecl bit_and_decl; + bit_and_decl.set_name("math.bitAnd"); + + FunctionDecl bit_or_decl; + bit_or_decl.set_name("math.bitOr"); + + FunctionDecl bit_xor_decl; + bit_xor_decl.set_name("math.bitXor"); + + FunctionDecl bit_not_decl; + bit_not_decl.set_name("math.bitNot"); + + FunctionDecl bit_lshift_decl; + bit_lshift_decl.set_name("math.bitShiftLeft"); + + FunctionDecl bit_rshift_decl; + bit_rshift_decl.set_name("math.bitShiftRight"); + + for (const Type& type : kBitwiseTypes) { + CEL_RETURN_IF_ERROR(bit_and_decl.AddOverload( + MakeOverloadDecl(absl::StrCat("math_bitAnd_", OverloadTypeName(type), + "_", OverloadTypeName(type)), + type, type, type))); + + CEL_RETURN_IF_ERROR(bit_or_decl.AddOverload( + MakeOverloadDecl(absl::StrCat("math_bitOr_", OverloadTypeName(type), + "_", OverloadTypeName(type)), + type, type, type))); + + CEL_RETURN_IF_ERROR(bit_xor_decl.AddOverload( + MakeOverloadDecl(absl::StrCat("math_bitXor_", OverloadTypeName(type), + "_", OverloadTypeName(type)), + type, type, type))); + + CEL_RETURN_IF_ERROR(bit_not_decl.AddOverload( + MakeOverloadDecl(absl::StrCat("math_bitNot_", OverloadTypeName(type), + "_", OverloadTypeName(type)), + type, type))); + + CEL_RETURN_IF_ERROR(bit_lshift_decl.AddOverload(MakeOverloadDecl( + absl::StrCat("math_bitShiftLeft_", OverloadTypeName(type), "_int"), + type, type, IntType()))); + + CEL_RETURN_IF_ERROR(bit_rshift_decl.AddOverload(MakeOverloadDecl( + absl::StrCat("math_bitShiftRight_", OverloadTypeName(type), "_int"), + type, type, IntType()))); + } + + CEL_RETURN_IF_ERROR(builder.AddFunction(bit_and_decl)); + CEL_RETURN_IF_ERROR(builder.AddFunction(bit_or_decl)); + CEL_RETURN_IF_ERROR(builder.AddFunction(bit_xor_decl)); + CEL_RETURN_IF_ERROR(builder.AddFunction(bit_not_decl)); + CEL_RETURN_IF_ERROR(builder.AddFunction(bit_lshift_decl)); + CEL_RETURN_IF_ERROR(builder.AddFunction(bit_rshift_decl)); + + return absl::OkStatus(); +} + +absl::Status AddMathExtensionDeclarations(TypeCheckerBuilder& builder) { + CEL_RETURN_IF_ERROR(AddMinMaxDecls(builder)); + CEL_RETURN_IF_ERROR(AddSignednessDecls(builder)); + CEL_RETURN_IF_ERROR(AddFloatingPointDecls(builder)); + CEL_RETURN_IF_ERROR(AddBitwiseDecls(builder)); + + return absl::OkStatus(); +} + +absl::Status AddMathExtensionMacros(ParserBuilder& builder) { + for (const auto& m : math_macros()) { + CEL_RETURN_IF_ERROR(builder.AddMacro(m)); + } + return absl::OkStatus(); +} + +} // namespace + +// Configuration for cel::Compiler to enable the math extension declarations. +CompilerLibrary MathCompilerLibrary() { + return CompilerLibrary(kMathExtensionName, &AddMathExtensionMacros, + &AddMathExtensionDeclarations); +} + +// Configuration for cel::TypeChecker to enable the math extension declarations. +CheckerLibrary MathCheckerLibrary() { + return { + .id = kMathExtensionName, + .configure = &AddMathExtensionDeclarations, + }; +} + +} // namespace cel::extensions diff --git a/extensions/math_ext_decls.h b/extensions/math_ext_decls.h new file mode 100644 index 000000000..31758f77b --- /dev/null +++ b/extensions/math_ext_decls.h @@ -0,0 +1,31 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_MATH_EXT_DECLS_H_ +#define THIRD_PARTY_CEL_CPP_EXTENSIONS_MATH_EXT_DECLS_H_ + +#include "checker/type_checker_builder.h" +#include "compiler/compiler.h" + +namespace cel::extensions { + +// Configuration for cel::Compiler to enable the math extension declarations. +CompilerLibrary MathCompilerLibrary(); + +// Configuration for cel::TypeChecker to enable the math extension declarations. +CheckerLibrary MathCheckerLibrary(); + +} // namespace cel::extensions + +#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_MATH_EXT_DECLS_H_ diff --git a/extensions/math_ext_test.cc b/extensions/math_ext_test.cc index fbd1635ff..6cf3ff313 100644 --- a/extensions/math_ext_test.cc +++ b/extensions/math_ext_test.cc @@ -15,11 +15,21 @@ #include "extensions/math_ext.h" #include +#include +#include #include "cel/expr/syntax.pb.h" #include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" +#include "base/function_descriptor.h" +#include "checker/standard_library.h" +#include "checker/validation_result.h" +#include "common/decl.h" +#include "compiler/compiler_factory.h" #include "eval/public/activation.h" #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_expr_builder_factory.h" @@ -29,14 +39,20 @@ #include "eval/public/cel_value.h" #include "eval/public/containers/container_backed_list_impl.h" #include "eval/public/testing/matchers.h" +#include "extensions/math_ext_decls.h" #include "extensions/math_ext_macros.h" #include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" #include "parser/parser.h" +#include "runtime/activation.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_runtime_builder_factory.h" #include "google/protobuf/arena.h" namespace cel::extensions { namespace { +using ::absl_testing::IsOk; using ::absl_testing::StatusIs; using ::cel::expr::Expr; using ::cel::expr::ParsedExpr; @@ -86,13 +102,29 @@ struct MacroTestCase { absl::string_view err = ""; }; +std::string FormatIssues(const cel::ValidationResult& result) { + std::string issues; + for (const auto& issue : result.GetIssues()) { + if (!issues.empty()) { + absl::StrAppend(&issues, "\n", + issue.ToDisplayString(*result.GetSource())); + } else { + issues = issue.ToDisplayString(*result.GetSource()); + } + } + return issues; +} + class TestFunction : public CelFunction { public: explicit TestFunction(absl::string_view name) - : CelFunction(CelFunctionDescriptor( - name, true, - {CelValue::Type::kBool, CelValue::Type::kInt64, - CelValue::Type::kInt64})) {} + : CelFunction(MakeDescriptor(name)) {} + + static FunctionDescriptor MakeDescriptor(absl::string_view name) { + return FunctionDescriptor(name, true, + {CelValue::Type::kBool, CelValue::Type::kInt64, + CelValue::Type::kInt64}); + } absl::Status Evaluate(absl::Span args, CelValue* result, Arena* arena) const override { @@ -276,7 +308,7 @@ TEST(MathExtTest, MinMaxList) { } using MathExtMacroParamsTest = testing::TestWithParam; -TEST_P(MathExtMacroParamsTest, MacroTests) { +TEST_P(MathExtMacroParamsTest, ParserTests) { const MacroTestCase& test_case = GetParam(); auto result = ParseWithMacros(test_case.expr, cel::extensions::math_macros(), ""); @@ -291,6 +323,7 @@ TEST_P(MathExtMacroParamsTest, MacroTests) { Expr expr = parsed_expr.expr(); SourceInfo source_info = parsed_expr.source_info(); InterpreterOptions options; + options.enable_qualified_identifier_rewrites = true; std::unique_ptr builder = CreateCelExpressionBuilder(options); ASSERT_OK(builder->GetRegistry()->Register(CreateGreatestFunction())); @@ -309,131 +342,230 @@ TEST_P(MathExtMacroParamsTest, MacroTests) { EXPECT_EQ(value.BoolOrDie(), true); } +TEST_P(MathExtMacroParamsTest, ParserAndCheckerTests) { + const MacroTestCase& test_case = GetParam(); + + ASSERT_OK_AND_ASSIGN( + auto compiler_builder, + cel::NewCompilerBuilder(internal::GetTestingDescriptorPool())); + + ASSERT_THAT(compiler_builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + ASSERT_THAT(compiler_builder->AddLibrary(MathCompilerLibrary()), IsOk()); + + // Add test functions that check macro (non-)expansion. + ASSERT_OK_AND_ASSIGN( + auto least_decl, + MakeFunctionDecl("least", MakeMemberOverloadDecl("bool_least_int_int", + /*result*/ BoolType(), + /*receiver*/ BoolType(), + IntType(), IntType()))); + ASSERT_OK_AND_ASSIGN(auto greatest_decl, + MakeFunctionDecl("greatest", MakeMemberOverloadDecl( + "bool_greatest_int_int", + /*result*/ BoolType(), + /*receiver*/ BoolType(), + IntType(), IntType()))); + + ASSERT_THAT(compiler_builder->GetCheckerBuilder().AddFunction(least_decl), + IsOk()); + ASSERT_THAT(compiler_builder->GetCheckerBuilder().AddFunction(greatest_decl), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto compiler, std::move(*compiler_builder).Build()); + + auto result = compiler->Compile(test_case.expr, ""); + + if (!test_case.err.empty()) { + EXPECT_THAT(result.status(), StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr(test_case.err))); + return; + } + + ASSERT_THAT(result, IsOk()); + ASSERT_TRUE(result->IsValid()) << FormatIssues(*result); + + RuntimeOptions opts; + ASSERT_OK_AND_ASSIGN( + auto runtime_builder, + CreateStandardRuntimeBuilder(internal::GetTestingDescriptorPool(), opts)); + + ASSERT_THAT( + RegisterMathExtensionFunctions(runtime_builder.function_registry(), opts), + IsOk()); + + ASSERT_THAT( + runtime_builder.function_registry().Register( + TestFunction::MakeDescriptor(kGreatest), CreateGreatestFunction()), + IsOk()); + ASSERT_THAT( + runtime_builder.function_registry().Register( + TestFunction::MakeDescriptor(kLeast), CreateGreatestFunction()), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(runtime_builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto program, + runtime->CreateProgram(*result->ReleaseAst())); + + google::protobuf::Arena arena; + cel::Activation activation; + ASSERT_OK_AND_ASSIGN(auto value, program->Evaluate(&arena, activation)); + + ASSERT_TRUE(value.IsBool()); + EXPECT_EQ(value.GetBool(), true); +} + INSTANTIATE_TEST_SUITE_P( MathExtMacrosParamsTest, MathExtMacroParamsTest, - testing::ValuesIn({ - // Tests for math.least - {"math.least(-0.5) == -0.5"}, - {"math.least(-1) == -1"}, - {"math.least(1u) == 1u"}, - {"math.least(42.0, -0.5) == -0.5"}, - {"math.least(-1, 0) == -1"}, - {"math.least(-1, -1) == -1"}, - {"math.least(1u, 42u) == 1u"}, - {"math.least(42.0, -0.5, -0.25) == -0.5"}, - {"math.least(-1, 0, 1) == -1"}, - {"math.least(-1, -1, -1) == -1"}, - {"math.least(1u, 42u, 0u) == 0u"}, - // math.least two arg overloads across type. - {"math.least(1, 1.0) == 1"}, - {"math.least(1, -2.0) == -2.0"}, - {"math.least(2, 1u) == 1u"}, - {"math.least(1.5, 2) == 1.5"}, - {"math.least(1.5, -2) == -2"}, - {"math.least(2.5, 1u) == 1u"}, - {"math.least(1u, 2) == 1u"}, - {"math.least(1u, -2) == -2"}, - {"math.least(2u, 2.5) == 2u"}, - // math.least with dynamic values across type. - {"math.least(1u, dyn(42)) == 1"}, - {"math.least(1u, dyn(42), dyn(0.0)) == 0u"}, - // math.least with a list literal. - {"math.least([1u, 42u, 0u]) == 0u"}, - // math.least errors - { - "math.least()", - "math.least() requires at least one argument.", - }, - { - "math.least('hello')", - "math.least() invalid single argument value.", - }, - { - "math.least({})", - "math.least() invalid single argument value", - }, - { - "math.least([])", - "math.least() invalid single argument value", - }, - { - "math.least([1, true])", - "math.least() invalid single argument value", - }, - { - "math.least(1, true)", - "math.least() simple literal arguments must be numeric", - }, - { - "math.least(1, 2, true)", - "math.least() simple literal arguments must be numeric", - }, - - // Tests for math.greatest - {"math.greatest(-0.5) == -0.5"}, - {"math.greatest(-1) == -1"}, - {"math.greatest(1u) == 1u"}, - {"math.greatest(42.0, -0.5) == 42.0"}, - {"math.greatest(-1, 0) == 0"}, - {"math.greatest(-1, -1) == -1"}, - {"math.greatest(1u, 42u) == 42u"}, - {"math.greatest(42.0, -0.5, -0.25) == 42.0"}, - {"math.greatest(-1, 0, 1) == 1"}, - {"math.greatest(-1, -1, -1) == -1"}, - {"math.greatest(1u, 42u, 0u) == 42u"}, - // math.least two arg overloads across type. - {"math.greatest(1, 1.0) == 1"}, - {"math.greatest(1, -2.0) == 1"}, - {"math.greatest(2, 1u) == 2"}, - {"math.greatest(1.5, 2) == 2"}, - {"math.greatest(1.5, -2) == 1.5"}, - {"math.greatest(2.5, 1u) == 2.5"}, - {"math.greatest(1u, 2) == 2"}, - {"math.greatest(1u, -2) == 1u"}, - {"math.greatest(2u, 2.5) == 2.5"}, - // math.greatest with dynamic values across type. - {"math.greatest(1u, dyn(42)) == 42.0"}, - {"math.greatest(1u, dyn(0.0), 0u) == 1"}, - // math.greatest with a list literal - {"math.greatest([1u, dyn(0.0), 0u]) == 1"}, - // math.greatest errors - { - "math.greatest()", - "math.greatest() requires at least one argument.", - }, - { - "math.greatest('hello')", - "math.greatest() invalid single argument value.", - }, - { - "math.greatest({})", - "math.greatest() invalid single argument value", - }, - { - "math.greatest([])", - "math.greatest() invalid single argument value", - }, - { - "math.greatest([1, true])", - "math.greatest() invalid single argument value", - }, - { - "math.greatest(1, true)", - "math.greatest() simple literal arguments must be numeric", - }, - { - "math.greatest(1, 2, true)", - "math.greatest() simple literal arguments must be numeric", - }, - // Call signatures which trigger macro expansion, but which do not - // get expanded. The function just returns true. - { - "false.greatest(1,2)", - }, - { - "true.least(1,2)", - }, - })); + testing::ValuesIn( + {// Tests for math.least + {"math.least(-0.5) == -0.5"}, + {"math.least(-1) == -1"}, + {"math.least(1u) == 1u"}, + {"math.least(42.0, -0.5) == -0.5"}, + {"math.least(-1, 0) == -1"}, + {"math.least(-1, -1) == -1"}, + {"math.least(1u, 42u) == 1u"}, + {"math.least(42.0, -0.5, -0.25) == -0.5"}, + {"math.least(-1, 0, 1) == -1"}, + {"math.least(-1, -1, -1) == -1"}, + {"math.least(1u, 42u, 0u) == 0u"}, + // math.least two arg overloads across type. + {"math.least(1, 1.0) == 1"}, + {"math.least(1, -2.0) == -2.0"}, + {"math.least(2, 1u) == 1u"}, + {"math.least(1.5, 2) == 1.5"}, + {"math.least(1.5, -2) == -2"}, + {"math.least(2.5, 1u) == 1u"}, + {"math.least(1u, 2) == 1u"}, + {"math.least(1u, -2) == -2"}, + {"math.least(2u, 2.5) == 2u"}, + // math.least with dynamic values across type. + {"math.least(1u, dyn(42)) == 1"}, + {"math.least(1u, dyn(42), dyn(0.0)) == 0u"}, + // math.least with a list literal. + {"math.least([1u, 42u, 0u]) == 0u"}, + // math.least errors + { + "math.least()", + "math.least() requires at least one argument.", + }, + { + "math.least('hello')", + "math.least() invalid single argument value.", + }, + { + "math.least({})", + "math.least() invalid single argument value", + }, + { + "math.least([])", + "math.least() invalid single argument value", + }, + { + "math.least([1, true])", + "math.least() invalid single argument value", + }, + { + "math.least(1, true)", + "math.least() simple literal arguments must be numeric", + }, + { + "math.least(1, 2, true)", + "math.least() simple literal arguments must be numeric", + }, + + // Tests for math.greatest + {"math.greatest(-0.5) == -0.5"}, + {"math.greatest(-1) == -1"}, + {"math.greatest(1u) == 1u"}, + {"math.greatest(42.0, -0.5) == 42.0"}, + {"math.greatest(-1, 0) == 0"}, + {"math.greatest(-1, -1) == -1"}, + {"math.greatest(1u, 42u) == 42u"}, + {"math.greatest(42.0, -0.5, -0.25) == 42.0"}, + {"math.greatest(-1, 0, 1) == 1"}, + {"math.greatest(-1, -1, -1) == -1"}, + {"math.greatest(1u, 42u, 0u) == 42u"}, + // math.least two arg overloads across type. + {"math.greatest(1, 1.0) == 1"}, + {"math.greatest(1, -2.0) == 1"}, + {"math.greatest(2, 1u) == 2"}, + {"math.greatest(1.5, 2) == 2"}, + {"math.greatest(1.5, -2) == 1.5"}, + {"math.greatest(2.5, 1u) == 2.5"}, + {"math.greatest(1u, 2) == 2"}, + {"math.greatest(1u, -2) == 1u"}, + {"math.greatest(2u, 2.5) == 2.5"}, + // math.greatest with dynamic values across type. + {"math.greatest(1u, dyn(42)) == 42.0"}, + {"math.greatest(1u, dyn(0.0), 0u) == 1"}, + // math.greatest with a list literal + {"math.greatest([1u, dyn(0.0), 0u]) == 1"}, + // math.greatest errors + { + "math.greatest()", + "math.greatest() requires at least one argument.", + }, + { + "math.greatest('hello')", + "math.greatest() invalid single argument value.", + }, + { + "math.greatest({})", + "math.greatest() invalid single argument value", + }, + { + "math.greatest([])", + "math.greatest() invalid single argument value", + }, + { + "math.greatest([1, true])", + "math.greatest() invalid single argument value", + }, + { + "math.greatest(1, true)", + "math.greatest() simple literal arguments must be numeric", + }, + { + "math.greatest(1, 2, true)", + "math.greatest() simple literal arguments must be numeric", + }, + // Call signatures which trigger macro expansion, but which do not + // get expanded. The function just returns true. + { + "false.greatest(1,2)", + }, + { + "true.least(1,2)", + }, + // Basic coverage for function definitions. Behavior is tested in the + // conformance tests. + {"math.sign(-12) == -1"}, + {"math.sign(0u) == 0u"}, + {"math.sign(42.01) == 1.0"}, + {"math.abs(-12) == 12"}, + {"math.abs(0u) == 0u"}, + {"math.abs(42.01) == 42.01"}, + {"math.ceil(42.01) == 43.0"}, + {"math.floor(42.01) == 42.0"}, + {"math.round(42.5) == 43.0"}, + {"math.trunc(42.0) == 42.0"}, + {"math.isInf(42.0 / 0.0) == true"}, + {"math.isNaN(double('nan')) == true"}, + {"math.isFinite(42.1) == true"}, + {"math.bitAnd(3, 1) == 1"}, + {"math.bitAnd(3u, 1u) == 1u"}, + {"math.bitOr(2, 1) == 3"}, + {"math.bitOr(2u, 1u) == 3u"}, + {"math.bitXor(3, 1) == 2"}, + {"math.bitXor(3u, 1u) == 2u"}, + {"math.bitNot(2) == -3"}, + {"math.bitAnd(math.bitNot(0x3u), 0xFFu) == 0xFCu"}, + {"math.bitShiftLeft(1, 1) == 2"}, + {"math.bitShiftLeft(1u, 1) == 2u"}, + {"math.bitShiftRight(4, 1) == 2"}, + {"math.bitShiftRight(4u, 1) == 2u"}})); } // namespace } // namespace cel::extensions From 2c885095f25a41f6ed2a2f9ab4c61a1c22c5af4f Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 12 Nov 2024 11:03:54 -0800 Subject: [PATCH 038/180] Add checker declarations for encoder extensions. PiperOrigin-RevId: 695800058 --- conformance/BUILD | 2 - conformance/service.cc | 2 + extensions/BUILD | 23 ++++++++++ extensions/encoders.cc | 25 ++++++++++ extensions/encoders.h | 4 ++ extensions/encoders_test.cc | 91 +++++++++++++++++++++++++++++++++++++ 6 files changed, 145 insertions(+), 2 deletions(-) create mode 100644 extensions/encoders_test.cc diff --git a/conformance/BUILD b/conformance/BUILD index a4d2d082b..915c02367 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -318,8 +318,6 @@ gen_conformance_tests( data = _ALL_TESTS, modern = True, skip_tests = _TESTS_TO_SKIP_MODERN + [ - # TODO: Need to add function declarations for these extensions. - "encoders_ext", # block is a post-check optimization that inserts internal variables. The C++ type checker # needs support for a proper optimizer for this to work. "block_ext", diff --git a/conformance/service.cc b/conformance/service.cc index a02fbf8c5..b89046e1c 100644 --- a/conformance/service.cc +++ b/conformance/service.cc @@ -645,6 +645,8 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { builder.AddLibrary(cel::extensions::StringsCheckerLibrary())); CEL_RETURN_IF_ERROR( builder.AddLibrary(cel::extensions::MathCheckerLibrary())); + CEL_RETURN_IF_ERROR( + builder.AddLibrary(cel::extensions::EncodersCheckerLibrary())); } for (const auto& decl : request.type_env()) { diff --git a/extensions/BUILD b/extensions/BUILD index 64c64f421..b09ad4905 100644 --- a/extensions/BUILD +++ b/extensions/BUILD @@ -5,6 +5,9 @@ cc_library( srcs = ["encoders.cc"], hdrs = ["encoders.h"], deps = [ + "//checker:type_checker_builder", + "//common:decl", + "//common:type", "//common:value", "//eval/public:cel_function_registry", "//eval/public:cel_options", @@ -19,6 +22,26 @@ cc_library( ], ) +cc_test( + name = "encoders_test", + srcs = ["encoders_test.cc"], + deps = [ + ":encoders", + "//checker:standard_library", + "//checker:validation_result", + "//compiler", + "//compiler:compiler_factory", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//runtime", + "//runtime:activation", + "//runtime:runtime_options", + "//runtime:standard_runtime_builder_factory", + "@com_google_absl//absl/status:status_matchers", + "@com_google_protobuf//:protobuf", + ], +) + cc_library( name = "proto_ext", srcs = ["proto_ext.cc"], diff --git a/extensions/encoders.cc b/extensions/encoders.cc index 751e0283c..4941d82af 100644 --- a/extensions/encoders.cc +++ b/extensions/encoders.cc @@ -21,6 +21,9 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/escaping.h" +#include "checker/type_checker_builder.h" +#include "common/decl.h" +#include "common/type.h" #include "common/value.h" #include "common/value_manager.h" #include "eval/public/cel_function_registry.h" @@ -52,6 +55,24 @@ absl::StatusOr Base64Encode(ValueManager& value_manager, return value_manager.CreateStringValue(std::move(out)); } +absl::Status RegisterEncodersDecls(TypeCheckerBuilder& builder) { + CEL_ASSIGN_OR_RETURN( + auto base64_decode_decl, + MakeFunctionDecl( + "base64.decode", + MakeOverloadDecl("base64_decode_string", BytesType(), StringType()))); + + CEL_ASSIGN_OR_RETURN( + auto base64_encode_decl, + MakeFunctionDecl( + "base64.encode", + MakeOverloadDecl("base64_encode_bytes", StringType(), BytesType()))); + + CEL_RETURN_IF_ERROR(builder.AddFunction(base64_decode_decl)); + CEL_RETURN_IF_ERROR(builder.AddFunction(base64_encode_decl)); + return absl::OkStatus(); +} + } // namespace absl::Status RegisterEncodersFunctions(FunctionRegistry& registry, @@ -78,4 +99,8 @@ absl::Status RegisterEncodersFunctions( google::api::expr::runtime::ConvertToRuntimeOptions(options)); } +CheckerLibrary EncodersCheckerLibrary() { + return {"cel.lib.ext.encoders", &RegisterEncodersDecls}; +} + } // namespace cel::extensions diff --git a/extensions/encoders.h b/extensions/encoders.h index 1e7207943..12dc40ff9 100644 --- a/extensions/encoders.h +++ b/extensions/encoders.h @@ -17,6 +17,7 @@ #include "absl/base/nullability.h" #include "absl/status/status.h" +#include "checker/type_checker_builder.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" #include "runtime/function_registry.h" @@ -32,6 +33,9 @@ absl::Status RegisterEncodersFunctions( absl::Nonnull registry, const google::api::expr::runtime::InterpreterOptions& options); +// Declarations for the encoders extension library. +CheckerLibrary EncodersCheckerLibrary(); + } // namespace cel::extensions #endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_ENCODERS_H_ diff --git a/extensions/encoders_test.cc b/extensions/encoders_test.cc new file mode 100644 index 000000000..c95588e29 --- /dev/null +++ b/extensions/encoders_test.cc @@ -0,0 +1,91 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/encoders.h" + +#include +#include +#include + +#include "absl/status/status_matchers.h" +#include "checker/standard_library.h" +#include "checker/validation_result.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "runtime/activation.h" +#include "runtime/runtime.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_runtime_builder_factory.h" +#include "google/protobuf/arena.h" + +namespace cel::extensions { +namespace { + +using ::absl_testing::IsOk; + +struct TestCase { + std::string expr; +}; + +class EncodersTest : public ::testing::TestWithParam {}; + +TEST_P(EncodersTest, ParseCheckEval) { + const TestCase& test_case = GetParam(); + + // Configure the compiler. + ASSERT_OK_AND_ASSIGN( + auto compiler_builder, + NewCompilerBuilder(internal::GetTestingDescriptorPool())); + ASSERT_THAT(compiler_builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + ASSERT_THAT( + compiler_builder->AddLibrary(extensions::EncodersCheckerLibrary()), + IsOk()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, + std::move(*compiler_builder).Build()); + + // Configure the runtime. + cel::RuntimeOptions runtime_options; + ASSERT_OK_AND_ASSIGN( + auto runtime_builder, + CreateStandardRuntimeBuilder(internal::GetTestingDescriptorPool(), + runtime_options)); + ASSERT_THAT(RegisterEncodersFunctions(runtime_builder.function_registry(), + runtime_options), + IsOk()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + std::move(runtime_builder).Build()); + + // Compile, plan, evaluate. + ASSERT_OK_AND_ASSIGN(ValidationResult result, + compiler->Compile(test_case.expr)); + ASSERT_TRUE(result.IsValid()); + ASSERT_OK_AND_ASSIGN(auto program, + runtime->CreateProgram(*result.ReleaseAst())); + + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(auto value, program->Evaluate(&arena, activation)); + ASSERT_TRUE(value.IsBool()); + ASSERT_TRUE(value.GetBool()); +} + +INSTANTIATE_TEST_SUITE_P( + EncodersTest, EncodersTest, + testing::Values(TestCase{"base64.encode(b'hello') == 'aGVsbG8='"}, + TestCase{"base64.decode('aGVsbG8=') == b'hello'"})); + +} // namespace +} // namespace cel::extensions From acc93753181ba692b176857915a7e32762feee85 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Thu, 14 Nov 2024 14:24:41 -0800 Subject: [PATCH 039/180] Internal testing change PiperOrigin-RevId: 696650344 --- conformance/BUILD | 2 +- internal/BUILD | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/conformance/BUILD b/conformance/BUILD index 915c02367..81cf1455d 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -120,6 +120,7 @@ cc_library( srcs = ["run.cc"], deps = [ ":service", + "//google/api/expr:eval_cc_proto", "//internal:testing_no_main", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/log:absl_check", @@ -129,7 +130,6 @@ cc_library( "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", - "@com_google_cel_spec//proto/cel/expr:eval_cc_proto", "@com_google_cel_spec//proto/cel/expr:value_cc_proto", "@com_google_cel_spec//proto/cel/expr/conformance/test:simple_cc_proto", "@com_google_googleapis//google/api/expr/conformance/v1alpha1:conformance_cc_proto", diff --git a/internal/BUILD b/internal/BUILD index 0c7b1e6e0..26e92f0f4 100644 --- a/internal/BUILD +++ b/internal/BUILD @@ -534,9 +534,9 @@ cel_proto_transitive_descriptor_set( deps = [ "//eval/testutil:test_extensions_proto", "//eval/testutil:test_message_proto", + "//google/api/expr:eval_protos", + "//google/api/expr:explain_protos", "@com_google_cel_spec//proto/cel/expr:checked_proto", - "@com_google_cel_spec//proto/cel/expr:eval_proto", - "@com_google_cel_spec//proto/cel/expr:explain_proto", "@com_google_cel_spec//proto/cel/expr:expr_proto", "@com_google_cel_spec//proto/cel/expr:syntax_proto", "@com_google_cel_spec//proto/cel/expr:value_proto", From ea65f242ec8197eb380eb6422c4c698cdf25c3f6 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Thu, 14 Nov 2024 17:07:35 -0800 Subject: [PATCH 040/180] BUILD dependency bundling PiperOrigin-RevId: 696697160 --- conformance/BUILD | 5 ++--- internal/BUILD | 2 -- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/conformance/BUILD b/conformance/BUILD index 81cf1455d..43735aafc 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -120,7 +120,6 @@ cc_library( srcs = ["run.cc"], deps = [ ":service", - "//google/api/expr:eval_cc_proto", "//internal:testing_no_main", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/log:absl_check", @@ -180,7 +179,7 @@ _TESTS_TO_SKIP_MODERN = [ "basic/functions/unbound", "basic/functions/unbound_is_runtime_error", - # TODO: Parse-only qualified variable lookup "x.y" wtih binding "x.y" or "y" within container "x" fails + # TODO: Parse-only qualified variable lookup "x.y" with binding "x.y" or "y" within container "x" fails "fields/qualified_identifier_resolution/qualified_ident,map_field_select,ident_with_longest_prefix_check,qualified_identifier_resolution_unchecked", "namespace/qualified/self_eval_qualified_lookup", "namespace/namespace/self_eval_container_lookup,self_eval_container_lookup_unchecked", @@ -234,7 +233,7 @@ _TESTS_TO_SKIP_LEGACY = [ "basic/functions/unbound", "basic/functions/unbound_is_runtime_error", - # TODO: Parse-only qualified variable lookup "x.y" wtih binding "x.y" or "y" within container "x" fails + # TODO: Parse-only qualified variable lookup "x.y" with binding "x.y" or "y" within container "x" fails "fields/qualified_identifier_resolution/qualified_ident,map_field_select,ident_with_longest_prefix_check,qualified_identifier_resolution_unchecked", "namespace/qualified/self_eval_qualified_lookup", "namespace/namespace/self_eval_container_lookup,self_eval_container_lookup_unchecked", diff --git a/internal/BUILD b/internal/BUILD index 26e92f0f4..71bc99bc6 100644 --- a/internal/BUILD +++ b/internal/BUILD @@ -534,8 +534,6 @@ cel_proto_transitive_descriptor_set( deps = [ "//eval/testutil:test_extensions_proto", "//eval/testutil:test_message_proto", - "//google/api/expr:eval_protos", - "//google/api/expr:explain_protos", "@com_google_cel_spec//proto/cel/expr:checked_proto", "@com_google_cel_spec//proto/cel/expr:expr_proto", "@com_google_cel_spec//proto/cel/expr:syntax_proto", From c50e392fd309898df08b21656de45197f97ad609 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Fri, 15 Nov 2024 11:50:10 -0800 Subject: [PATCH 041/180] Reject legacy runtime type values with empty typenames when converting to modern values. PiperOrigin-RevId: 696956052 --- common/legacy_value.cc | 7 +++- .../cel_expression_builder_flat_impl_test.cc | 40 ++++++++++++++++++- 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/common/legacy_value.cc b/common/legacy_value.cc index b1aa72bcb..f0ee7fd99 100644 --- a/common/legacy_value.cc +++ b/common/legacy_value.cc @@ -1095,8 +1095,11 @@ absl::Status ModernValue(google::protobuf::Arena* arena, result = UnknownValue{*legacy_value.UnknownSetOrDie()}; return absl::OkStatus(); case CelValue::Type::kCelType: { - result = TypeValue{common_internal::LegacyRuntimeType( - legacy_value.CelTypeOrDie().value())}; + auto type_name = legacy_value.CelTypeOrDie().value(); + if (type_name.empty()) { + return absl::InvalidArgumentError("empty type name in CelValue"); + } + result = TypeValue{common_internal::LegacyRuntimeType(type_name)}; return absl::OkStatus(); } case CelValue::Type::kError: diff --git a/eval/compiler/cel_expression_builder_flat_impl_test.cc b/eval/compiler/cel_expression_builder_flat_impl_test.cc index 73365e4e6..9ae484a3f 100644 --- a/eval/compiler/cel_expression_builder_flat_impl_test.cc +++ b/eval/compiler/cel_expression_builder_flat_impl_test.cc @@ -41,7 +41,6 @@ #include "eval/public/containers/container_backed_map_impl.h" #include "eval/public/portable_cel_function_adapter.h" #include "eval/public/structs/cel_proto_wrapper.h" -#include "eval/public/structs/protobuf_descriptor_type_provider.h" #include "eval/public/testing/matchers.h" #include "extensions/bindings_ext.h" #include "internal/status_macros.h" @@ -52,7 +51,6 @@ #include "runtime/runtime_options.h" #include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/arena.h" -#include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" namespace google::api::expr::runtime { @@ -357,6 +355,44 @@ TEST(CelExpressionBuilderFlatImplTest, ParsedExprWithWarnings) { StatusIs(_, HasSubstr("No matching overloads")))); } +TEST(CelExpressionBuilderFlatImplTest, EmptyLegacyTypeViewUnsupported) { + // Creating type values directly (instead of using the builtin functions and + // identifiers from the type registry) is not recommended for CEL users. The + // name is expected to be non-empty. + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse("x")); + cel::RuntimeOptions options; + + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr plan, + builder.CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + Activation activation; + activation.InsertValue("x", CelValue::CreateCelTypeView("")); + google::protobuf::Arena arena; + ASSERT_THAT(plan->Evaluate(activation, &arena), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(CelExpressionBuilderFlatImplTest, LegacyTypeViewSupported) { + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse("x")); + cel::RuntimeOptions options; + + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr plan, + builder.CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + Activation activation; + activation.InsertValue("x", CelValue::CreateCelTypeView("MyType")); + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(activation, &arena)); + ASSERT_TRUE(result.IsCelType()); + EXPECT_EQ(result.CelTypeOrDie().value(), "MyType"); +} + TEST(CelExpressionBuilderFlatImplTest, CheckedExpr) { ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse("1 + 2")); CheckedExpr checked_expr; From 8fb742f967721d214a3ad887e9ee9826b80a3f50 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Fri, 15 Nov 2024 13:47:34 -0800 Subject: [PATCH 042/180] Fix misleading error on string -> uint conversion. When string -> uint fails it states "doesn't convert to a string" although the request is to convert from a string. Instead choose "cannot convert string to uint", which is the same kind of statement used for other string conversion failures. PiperOrigin-RevId: 696989968 --- runtime/standard/type_conversion_functions.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runtime/standard/type_conversion_functions.cc b/runtime/standard/type_conversion_functions.cc index 7db2aa4a2..1f3d40108 100644 --- a/runtime/standard/type_conversion_functions.cc +++ b/runtime/standard/type_conversion_functions.cc @@ -231,7 +231,7 @@ absl::Status RegisterUintConversionFunctions(FunctionRegistry& registry, uint64_t result; if (!absl::SimpleAtoi(s.ToString(), &result)) { return value_factory.CreateErrorValue( - absl::InvalidArgumentError("doesn't convert to a string")); + absl::InvalidArgumentError("cannot convert string to uint")); } return value_factory.CreateUintValue(result); }, From 200ae3350ac83e9de3c1dce94f0daf2eb1056a5b Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Mon, 18 Nov 2024 12:20:49 -0800 Subject: [PATCH 043/180] Update helper class for managing AST traversal. - Rename to AstTraversal - allow client to explicitly step through the visitation instead of signaling a stop request. PiperOrigin-RevId: 697717643 --- common/BUILD | 4 +-- common/ast_traverse.cc | 63 +++++++++++++++++++------------------ common/ast_traverse.h | 53 ++++++++++++++----------------- common/ast_traverse_test.cc | 54 +++++++++---------------------- 4 files changed, 72 insertions(+), 102 deletions(-) diff --git a/common/BUILD b/common/BUILD index bc2fd69d3..cdb0ca899 100644 --- a/common/BUILD +++ b/common/BUILD @@ -149,8 +149,8 @@ cc_library( ":ast_visitor", ":constant", ":expr", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log:absl_log", - "@com_google_absl//absl/status", "@com_google_absl//absl/types:variant", ], ) @@ -164,8 +164,6 @@ cc_test( ":constant", ":expr", "//internal:testing", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:status_matchers", ], ) diff --git a/common/ast_traverse.cc b/common/ast_traverse.cc index 07de5f1e8..a6ba0d1ba 100644 --- a/common/ast_traverse.cc +++ b/common/ast_traverse.cc @@ -18,7 +18,6 @@ #include #include "absl/log/absl_log.h" -#include "absl/status/status.h" #include "absl/types/variant.h" #include "common/ast_visitor.h" #include "common/constant.h" @@ -26,12 +25,6 @@ namespace cel { -namespace common_internal { -struct AstTraverseContext { - bool should_halt = false; -}; -} // namespace common_internal - namespace { struct ArgRecord { @@ -326,30 +319,44 @@ void PushDependencies(const StackRecord& record, std::stack& stack, } // namespace -AstTraverseManager::AstTraverseManager(TraversalOptions options) - : options_(options) {} +namespace common_internal { +struct AstTraversalState { + std::stack stack; +}; +} // namespace common_internal + +AstTraversal AstTraversal::Create(const cel::Expr& ast, + const TraversalOptions& options) { + AstTraversal instance(options); + instance.state_ = std::make_unique(); + instance.state_->stack.push(StackRecord(&ast)); + return instance; +} + +AstTraversal::AstTraversal(TraversalOptions options) : options_(options) {} -AstTraverseManager::AstTraverseManager() = default; -AstTraverseManager::~AstTraverseManager() = default; +AstTraversal::~AstTraversal() = default; -absl::Status AstTraverseManager::AstTraverse(const Expr& expr, - AstVisitor& visitor) { - if (context_ != nullptr) { - return absl::FailedPreconditionError( - "AstTraverseManager is already in use"); +bool AstTraversal::Step(AstVisitor& visitor) { + if (IsDone()) { + return false; + } + auto& stack = state_->stack; + StackRecord& record = stack.top(); + if (!record.visited) { + PreVisit(record, &visitor); + PushDependencies(record, stack, options_); + record.visited = true; + } else { + PostVisit(record, &visitor); + stack.pop(); } - context_ = std::make_unique(); - TraversalOptions options = options_; - options.manager_context = context_.get(); - ::cel::AstTraverse(expr, visitor, options); - context_ = nullptr; - return absl::OkStatus(); + + return !stack.empty(); } -void AstTraverseManager::RequestHalt() { - if (context_ != nullptr) { - context_->should_halt = true; - } +bool AstTraversal::IsDone() { + return state_ == nullptr || state_->stack.empty(); } void AstTraverse(const Expr& expr, AstVisitor& visitor, @@ -358,10 +365,6 @@ void AstTraverse(const Expr& expr, AstVisitor& visitor, stack.push(StackRecord(&expr)); while (!stack.empty()) { - if (options.manager_context != nullptr && - options.manager_context->should_halt) { - return; - } StackRecord& record = stack.top(); if (!record.visited) { PreVisit(record, &visitor); diff --git a/common/ast_traverse.h b/common/ast_traverse.h index 6201002f0..47d8ccc80 100644 --- a/common/ast_traverse.h +++ b/common/ast_traverse.h @@ -17,62 +17,57 @@ #include -#include "absl/status/status.h" +#include "absl/base/attributes.h" #include "common/ast_visitor.h" #include "common/expr.h" namespace cel { namespace common_internal { -struct AstTraverseContext; +struct AstTraversalState; } struct TraversalOptions { // Enable use of the comprehension specific callbacks. - bool use_comprehension_callbacks; - // Opaque context used by the traverse manager. - const common_internal::AstTraverseContext* manager_context; - - TraversalOptions() - : use_comprehension_callbacks(false), manager_context(nullptr) {} + bool use_comprehension_callbacks = false; }; // Helper class for managing the traversal of the AST. -// Allows for passing a signal to halt the traversal. +// Allows caller to step through the traversal. // // Usage: // -// AstTraverseManager manager(/*options=*/{}); +// AstTraversal traversal = AstTraversal::Create(expr); // -// MyVisitor visitor(&manager); -// CEL_RETURN_IF_ERROR(manager.AstTraverse(expr, visitor)); +// MyVisitor visitor(); +// while(!traversal.IsDone()) { +// traversal.Step(visitor); +// } // // This class is thread-hostile and should only be used in synchronous code. -class AstTraverseManager { +class AstTraversal { public: - explicit AstTraverseManager(TraversalOptions options); - AstTraverseManager(); + static AstTraversal Create(const cel::Expr& ast ABSL_ATTRIBUTE_LIFETIME_BOUND, + const TraversalOptions& options = {}); - ~AstTraverseManager(); + ~AstTraversal(); - AstTraverseManager(const AstTraverseManager&) = delete; - AstTraverseManager& operator=(const AstTraverseManager&) = delete; - AstTraverseManager(AstTraverseManager&&) = delete; - AstTraverseManager& operator=(AstTraverseManager&&) = delete; + AstTraversal(const AstTraversal&) = delete; + AstTraversal& operator=(const AstTraversal&) = delete; + AstTraversal(AstTraversal&&) = default; + AstTraversal& operator=(AstTraversal&&) = default; - // Managed traversal of the AST. Allows for interrupting the traversal. - // Re-entrant traversal is not supported and will result in a - // FailedPrecondition error. - absl::Status AstTraverse(const Expr& expr, AstVisitor& visitor); + // Advances the traversal. Returns true if there is more work to do. This is a + // no-op if the traversal is done and IsDone() is true. + bool Step(AstVisitor& visitor); - // Signals a request for the traversal to halt. The traversal routine will - // check for this signal at the start of each Expr node visitation. - // This has no effect if no traversal is in progress. - void RequestHalt(); + // Returns true if there is more work to do. + bool IsDone(); private: + explicit AstTraversal(TraversalOptions options); TraversalOptions options_; - std::unique_ptr context_; + std::unique_ptr state_; }; // Traverses the AST representation in an expr proto. diff --git a/common/ast_traverse_test.cc b/common/ast_traverse_test.cc index 26c620be6..16ee40ce0 100644 --- a/common/ast_traverse_test.cc +++ b/common/ast_traverse_test.cc @@ -14,8 +14,6 @@ #include "common/ast_traverse.h" -#include "absl/status/status.h" -#include "absl/status/status_matchers.h" #include "common/ast_visitor.h" #include "common/constant.h" #include "common/expr.h" @@ -25,8 +23,6 @@ namespace cel::ast_internal { namespace { -using ::absl_testing::IsOk; -using ::absl_testing::StatusIs; using ::testing::_; using ::testing::Ref; @@ -434,7 +430,7 @@ TEST(AstCrawlerTest, CheckExprHandlers) { AstTraverse(expr, handler); } -TEST(AstTraverseManager, Interrupt) { +TEST(AstTraversal, Interrupt) { MockAstVisitor handler; Expr expr; @@ -444,37 +440,21 @@ TEST(AstTraverseManager, Interrupt) { testing::InSequence seq; - AstTraverseManager manager; + auto traversal = AstTraversal::Create(expr); - EXPECT_CALL(handler, PostVisitIdent(Ref(operand), Ref(ident_expr))) - .Times(1) - .WillOnce([&manager](const Expr& expr, const IdentExpr& ident_expr) { - manager.RequestHalt(); - }); - EXPECT_CALL(handler, PostVisitSelect(Ref(expr), Ref(select_expr))).Times(0); - - EXPECT_THAT(manager.AstTraverse(expr, handler), IsOk()); -} - -TEST(AstTraverseManager, NoInterrupt) { - MockAstVisitor handler; - - Expr expr; - auto& select_expr = expr.mutable_select_expr(); - auto& operand = select_expr.mutable_operand(); - auto& ident_expr = operand.mutable_ident_expr(); - - testing::InSequence seq; - - AstTraverseManager manager; + EXPECT_CALL(handler, PreVisitExpr(_)).Times(2); EXPECT_CALL(handler, PostVisitIdent(Ref(operand), Ref(ident_expr))).Times(1); - EXPECT_CALL(handler, PostVisitSelect(Ref(expr), Ref(select_expr))).Times(1); + EXPECT_CALL(handler, PostVisitSelect(Ref(expr), Ref(select_expr))).Times(0); + + EXPECT_TRUE(traversal.Step(handler)); + EXPECT_TRUE(traversal.Step(handler)); + EXPECT_TRUE(traversal.Step(handler)); - EXPECT_THAT(manager.AstTraverse(expr, handler), IsOk()); + EXPECT_FALSE(traversal.IsDone()); } -TEST(AstCrawlerTest, ReentantTraversalUnsupported) { +TEST(AstTraversal, NoInterrupt) { MockAstVisitor handler; Expr expr; @@ -482,21 +462,15 @@ TEST(AstCrawlerTest, ReentantTraversalUnsupported) { auto& operand = select_expr.mutable_operand(); auto& ident_expr = operand.mutable_ident_expr(); - AstTraverseManager manager; - testing::InSequence seq; - EXPECT_CALL(handler, PostVisitIdent(Ref(operand), Ref(ident_expr))) - .Times(1) - .WillOnce( - [&manager, &handler](const Expr& expr, const IdentExpr& ident_expr) { - EXPECT_THAT(manager.AstTraverse(expr, handler), - StatusIs(absl::StatusCode::kFailedPrecondition)); - }); + auto traversal = AstTraversal::Create(expr); + EXPECT_CALL(handler, PostVisitIdent(Ref(operand), Ref(ident_expr))).Times(1); EXPECT_CALL(handler, PostVisitSelect(Ref(expr), Ref(select_expr))).Times(1); - EXPECT_THAT(manager.AstTraverse(expr, handler), IsOk()); + while (traversal.Step(handler)) continue; + EXPECT_TRUE(traversal.IsDone()); } } // namespace From 968703924e246d96965b8d11527a781bfd079596 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Mon, 18 Nov 2024 12:52:37 -0800 Subject: [PATCH 044/180] Add missing return statements after errors in C++ planner. Fixes a couple of cases where malformed expressions might lead to index out of bounds errors. PiperOrigin-RevId: 697727147 --- eval/compiler/flat_expr_builder.cc | 5 + runtime/BUILD | 2 + .../standard_runtime_builder_factory_test.cc | 114 ++++++++++++++++++ 3 files changed, 121 insertions(+) diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index af34b13a7..84e23f914 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -933,6 +933,7 @@ class FlatExprVisitor : public cel::AstVisitor { if (expr->call_expr().args().size() != 3) { SetProgressStatusError(absl::InvalidArgumentError( "unexpected number of args for builtin ternary")); + return; } const cel::ast_internal::Expr* condition_expr = @@ -981,6 +982,7 @@ class FlatExprVisitor : public cel::AstVisitor { if (expr->call_expr().args().size() != 2) { SetProgressStatusError(absl::InvalidArgumentError( "unexpected number of args for builtin boolean operator &&/||")); + return; } const cel::ast_internal::Expr* left_expr = &expr->call_expr().args()[0]; const cel::ast_internal::Expr* right_expr = &expr->call_expr().args()[1]; @@ -1028,6 +1030,7 @@ class FlatExprVisitor : public cel::AstVisitor { expr->call_expr().args().size() != 1) { SetProgressStatusError(absl::InvalidArgumentError( "unexpected number of args for optional.or{Value}")); + return; } const cel::ast_internal::Expr* left_expr = &expr->call_expr().target(); const cel::ast_internal::Expr* right_expr = &expr->call_expr().args()[0]; @@ -1188,6 +1191,7 @@ class FlatExprVisitor : public cel::AstVisitor { if (args.size() != 2) { SetProgressStatusError(absl::InvalidArgumentError( "unexpected number of args for builtin index operator")); + return; } SetRecursiveStep(CreateDirectContainerAccessStep( std::move(args[0]), std::move(args[1]), @@ -1966,6 +1970,7 @@ void TernaryCondVisitor::PostVisitArg(int arg_num, auto jump_after_first = CreateJumpStep({}, expr->id()); if (!jump_after_first.ok()) { visitor_->SetProgressStatusError(jump_after_first.status()); + return; } jump_after_first_ = diff --git a/runtime/BUILD b/runtime/BUILD index 91b8382b2..77c63be58 100644 --- a/runtime/BUILD +++ b/runtime/BUILD @@ -262,6 +262,7 @@ cc_test( ":runtime_issue", ":runtime_options", ":standard_runtime_builder_factory", + "//base:builtins", "//common:memory", "//common:source", "//common:value", @@ -277,6 +278,7 @@ cc_test( "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", diff --git a/runtime/standard_runtime_builder_factory_test.cc b/runtime/standard_runtime_builder_factory_test.cc index c9ce72bb9..6abe75a84 100644 --- a/runtime/standard_runtime_builder_factory_test.cc +++ b/runtime/standard_runtime_builder_factory_test.cc @@ -24,8 +24,10 @@ #include "absl/base/no_destructor.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" +#include "absl/status/status_matchers.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "base/builtins.h" #include "common/memory.h" #include "common/source.h" #include "common/value.h" @@ -51,6 +53,7 @@ namespace cel { namespace { +using ::absl_testing::StatusIs; using ::cel::extensions::ProtobufRuntimeAdapter; using ::cel::extensions::ProtoMemoryManagerRef; using ::cel::test::BoolValueIs; @@ -544,5 +547,116 @@ TEST(StandardRuntimeTest, RuntimeIssueSupport) { } } +enum class EvalStrategy { kIterative, kRecursive }; + +class StandardRuntimeEvalStrategyTest + : public ::testing::TestWithParam {}; + +// Check that calls to specialized builtins are validated. +TEST_P(StandardRuntimeEvalStrategyTest, InvalidBuiltinBoolOp) { + EvalStrategy eval_strategy = GetParam(); + RuntimeOptions options; + if (eval_strategy == EvalStrategy::kRecursive) { + options.max_recursion_depth = -1; + } else { + options.max_recursion_depth = 0; + } + + google::protobuf::Arena arena; + + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + google::protobuf::DescriptorPool::generated_pool(), options)); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ParsedExpr expr; + expr.mutable_expr()->mutable_call_expr()->set_function(cel::builtin::kOr); + auto* arg = expr.mutable_expr()->mutable_call_expr()->add_args(); + arg->mutable_const_expr()->set_bool_value(true); + + EXPECT_THAT(ProtobufRuntimeAdapter::CreateProgram(*runtime, expr), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_P(StandardRuntimeEvalStrategyTest, InvalidBuiltinTernaryOp) { + EvalStrategy eval_strategy = GetParam(); + RuntimeOptions options; + if (eval_strategy == EvalStrategy::kRecursive) { + options.max_recursion_depth = -1; + } else { + options.max_recursion_depth = 0; + } + + google::protobuf::Arena arena; + + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + google::protobuf::DescriptorPool::generated_pool(), options)); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ParsedExpr expr; + expr.mutable_expr()->mutable_call_expr()->set_function( + cel::builtin::kTernary); + expr.mutable_expr() + ->mutable_call_expr() + ->add_args() + ->mutable_const_expr() + ->set_bool_value(true); + expr.mutable_expr() + ->mutable_call_expr() + ->add_args() + ->mutable_const_expr() + ->set_bool_value(true); + expr.mutable_expr() + ->mutable_call_expr() + ->add_args() + ->mutable_const_expr() + ->set_bool_value(true); + expr.mutable_expr() + ->mutable_call_expr() + ->add_args() + ->mutable_const_expr() + ->set_bool_value(true); + + EXPECT_THAT(ProtobufRuntimeAdapter::CreateProgram(*runtime, expr), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_P(StandardRuntimeEvalStrategyTest, InvalidBuiltinIndex) { + EvalStrategy eval_strategy = GetParam(); + RuntimeOptions options; + if (eval_strategy == EvalStrategy::kRecursive) { + options.max_recursion_depth = -1; + } else { + options.max_recursion_depth = 0; + } + + google::protobuf::Arena arena; + + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + google::protobuf::DescriptorPool::generated_pool(), options)); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ParsedExpr expr; + expr.mutable_expr()->mutable_call_expr()->set_function(cel::builtin::kIndex); + auto* arg = expr.mutable_expr()->mutable_call_expr()->add_args(); + arg->mutable_list_expr() + ->add_elements() + ->mutable_const_expr() + ->set_int64_value(1); + + EXPECT_THAT(ProtobufRuntimeAdapter::CreateProgram(*runtime, expr), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +INSTANTIATE_TEST_SUITE_P(StandardRuntimeEvalStrategyTest, + StandardRuntimeEvalStrategyTest, + testing::Values(EvalStrategy::kIterative, + EvalStrategy::kRecursive)); + } // namespace } // namespace cel From 19fcb020913e7baf387b9d4241c233b8b8020679 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Mon, 18 Nov 2024 15:41:33 -0800 Subject: [PATCH 045/180] TypeChecker updates to filter bad line information. Add additional checks in line offset computation to avoid integer overflows for ASTs with bad position maps. PiperOrigin-RevId: 697781525 --- checker/internal/type_checker_impl.cc | 22 +++++--- checker/internal/type_checker_impl_test.cc | 58 ++++++++++++++++++++++ 2 files changed, 74 insertions(+), 6 deletions(-) diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index 027c3a87b..bae1c43d7 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -75,21 +75,31 @@ SourceLocation ComputeSourceLocation(const AstImpl& ast, int64_t expr_id) { return SourceLocation{}; } int32_t absolute_position = iter->second; + if (absolute_position < 0) { + return SourceLocation{}; + } + + // Find the first line offset that is greater than the absolute position. int32_t line_idx = -1; + int32_t offset = 0; for (int32_t i = 0; i < source_info.line_offsets().size(); ++i) { - int32_t offset = source_info.line_offsets()[i]; - if (absolute_position < offset) { + int32_t next_offset = source_info.line_offsets()[i]; + if (next_offset <= offset) { + // Line offset is not monotonically increasing, so line information is + // invalid. + return SourceLocation{}; + } + if (absolute_position < next_offset) { line_idx = i; break; } + offset = next_offset; } - if (line_idx <= 0 || line_idx >= source_info.line_offsets().size()) { - return SourceLocation{1, absolute_position}; + if (line_idx < 0 || line_idx >= source_info.line_offsets().size()) { + return SourceLocation{}; } - auto offset = source_info.line_offsets()[line_idx - 1]; - int32_t rel_position = absolute_position - offset; return SourceLocation{line_idx + 1, rel_position}; diff --git a/checker/internal/type_checker_impl_test.cc b/checker/internal/type_checker_impl_test.cc index d64e22cc3..a2b6fdade 100644 --- a/checker/internal/type_checker_impl_test.cc +++ b/checker/internal/type_checker_impl_test.cc @@ -68,6 +68,7 @@ using ::testing::Eq; using ::testing::IsEmpty; using ::testing::Pair; using ::testing::Property; +using ::testing::SizeIs; using AstType = ast_internal::Type; using Severity = TypeCheckIssue::Severity; @@ -1281,6 +1282,63 @@ TEST(TypeCheckerImplTest, ExpectedTypeDoesntMatch) { "expected type 'map' but found 'map'"))); } +TEST(TypeCheckerImplTest, BadSourcePosition) { + google::protobuf::Arena arena; + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("foo")); + auto& ast_impl = AstImpl::CastFromPublicAst(*ast); + ast_impl.source_info().mutable_positions()[1] = -42; + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + ASSERT_OK_AND_ASSIGN(auto source, NewSource("foo")); + + EXPECT_FALSE(result.IsValid()); + ASSERT_THAT(result.GetIssues(), SizeIs(1)); + + EXPECT_EQ( + result.GetIssues()[0].ToDisplayString(*source), + "ERROR: :-1:-1: undeclared reference to 'foo' (in container '')"); +} + +TEST(TypeCheckerImplTest, BadLineOffsets) { + google::protobuf::Arena arena; + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto source, NewSource("\nfoo")); + + { + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("\nfoo")); + auto& ast_impl = AstImpl::CastFromPublicAst(*ast); + ast_impl.source_info().mutable_line_offsets()[1] = 1; + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_FALSE(result.IsValid()); + ASSERT_THAT(result.GetIssues(), SizeIs(1)); + + EXPECT_EQ(result.GetIssues()[0].ToDisplayString(*source), + "ERROR: :-1:-1: undeclared reference to 'foo' (in " + "container '')"); + } + { + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("\nfoo")); + auto& ast_impl = AstImpl::CastFromPublicAst(*ast); + ast_impl.source_info().mutable_line_offsets().clear(); + ast_impl.source_info().mutable_line_offsets().push_back(-1); + ast_impl.source_info().mutable_line_offsets().push_back(2); + + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_FALSE(result.IsValid()); + ASSERT_THAT(result.GetIssues(), SizeIs(1)); + + EXPECT_EQ(result.GetIssues()[0].ToDisplayString(*source), + "ERROR: :-1:-1: undeclared reference to 'foo' (in " + "container '')"); + } +} + TEST(TypeCheckerImplTest, ContainerLookupForMessageCreation) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); env.set_container("google.protobuf"); From ef155f8fd4f4dc46e7ea8d0c87f2774eccbf48c8 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 19 Nov 2024 11:31:31 -0800 Subject: [PATCH 046/180] Add option to set input expression size limit in type checker. If exceeded, type checking fails early instead of fully visiting the input AST. PiperOrigin-RevId: 698088256 --- checker/checker_options.h | 6 ++++++ checker/internal/type_checker_impl.cc | 21 ++++++++++++++++++--- checker/internal/type_checker_impl_test.cc | 18 ++++++++++++++++++ 3 files changed, 42 insertions(+), 3 deletions(-) diff --git a/checker/checker_options.h b/checker/checker_options.h index 839446180..91fdad3e0 100644 --- a/checker/checker_options.h +++ b/checker/checker_options.h @@ -42,6 +42,12 @@ struct CheckerOptions { // Enabled by default, but can be disabled to preserve the original type name // as parsed. bool update_struct_type_names = true; + + // Maximum number (inclusive) of expression nodes to check for an input + // expression. + // + // If exceeded, the checker should return a status with code InvalidArgument. + int max_expression_node_count = 100000; }; } // namespace cel diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index bae1c43d7..c1a8ab4aa 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -458,9 +458,9 @@ class ResolveVisitor : public AstVisitorBase { // These are handled separately to disambiguate between namespaces and field // accesses absl::flat_hash_set deferred_select_operations_; - absl::Status status_; std::vector> comprehension_vars_; std::vector comprehension_scopes_; + absl::Status status_; // References that were resolved and may require AST rewrites. absl::flat_hash_map functions_; @@ -1252,8 +1252,23 @@ absl::StatusOr TypeCheckerImpl::Check( TraversalOptions opts; opts.use_comprehension_callbacks = true; - AstTraverse(ast_impl.root_expr(), visitor, opts); - CEL_RETURN_IF_ERROR(visitor.status()); + + auto traversal = AstTraversal::Create(ast_impl.root_expr(), opts); + for (int step = 0; step < options_.max_expression_node_count * 2; ++step) { + bool has_next = traversal.Step(visitor); + if (!visitor.status().ok()) { + return visitor.status(); + } + if (!has_next) { + break; + } + } + + if (!traversal.IsDone()) { + return absl::InvalidArgumentError( + absl::StrCat("Max expression node count exceeded: ", + options_.max_expression_node_count)); + } if (env_.expected_type().has_value()) { visitor.AssertExpectedType(ast_impl.root_expr(), *env_.expected_type()); diff --git a/checker/internal/type_checker_impl_test.cc b/checker/internal/type_checker_impl_test.cc index a2b6fdade..c72e57a12 100644 --- a/checker/internal/type_checker_impl_test.cc +++ b/checker/internal/type_checker_impl_test.cc @@ -57,6 +57,7 @@ namespace checker_internal { namespace { using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; using ::cel::ast_internal::AstImpl; using ::cel::ast_internal::Reference; using ::cel::expr::conformance::proto3::TestAllTypes; @@ -65,6 +66,7 @@ using ::testing::_; using ::testing::Contains; using ::testing::ElementsAre; using ::testing::Eq; +using ::testing::HasSubstr; using ::testing::IsEmpty; using ::testing::Pair; using ::testing::Property; @@ -1013,6 +1015,22 @@ TEST(TypeCheckerImplTest, NullLiteral) { EXPECT_TRUE(ast_impl.type_map()[1].has_null()); } +TEST(TypeCheckerImplTest, ExpressionLimitInclusive) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + CheckerOptions options; + options.max_expression_node_count = 2; + TypeCheckerImpl impl(std::move(env), options); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("{}.foo")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()); + + ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("{}.foo.bar")); + EXPECT_THAT(impl.Check(std::move(ast)), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("expression node count exceeded: 2"))); +} + TEST(TypeCheckerImplTest, ComprehensionUnsupportedRange) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); google::protobuf::Arena arena; From 77a8b0fa047e8b2041567ca827051de0884d770e Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 19 Nov 2024 13:16:54 -0800 Subject: [PATCH 047/180] Add option for maximum number of ERROR level issues. If the limit is passed, the checker should stop validating and just return the current set of issues. PiperOrigin-RevId: 698122196 --- checker/checker_options.h | 7 ++++ checker/internal/type_checker_impl.cc | 47 +++++++++++++++------- checker/internal/type_checker_impl_test.cc | 29 +++++++++++++ 3 files changed, 69 insertions(+), 14 deletions(-) diff --git a/checker/checker_options.h b/checker/checker_options.h index 91fdad3e0..5101281a6 100644 --- a/checker/checker_options.h +++ b/checker/checker_options.h @@ -48,6 +48,13 @@ struct CheckerOptions { // // If exceeded, the checker should return a status with code InvalidArgument. int max_expression_node_count = 100000; + + // Maximum number (inclusive) of error-level issues to tolerate for an input + // ast. + // + // If exceeded, the checker will stop processing the ast and return + // the current set of issues. + int max_error_issues = 20; }; } // namespace cel diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index c1a8ab4aa..f6b96916b 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -315,6 +315,8 @@ class ResolveVisitor : public AstVisitorBase { const absl::Status& status() const { return status_; } + int error_count() const { return error_count_; } + void AssertExpectedType(const Expr& expr, const Type& expected_type) { Type observed = GetTypeOrDyn(&expr); if (!inference_context_->IsAssignable(observed, expected_type)) { @@ -364,8 +366,15 @@ class ResolveVisitor : public AstVisitorBase { void ResolveSelectOperation(const Expr& expr, absl::string_view field, const Expr& operand); + void ReportIssue(TypeCheckIssue issue) { + if (issue.severity() == Severity::kError) { + error_count_++; + } + issues_->push_back(std::move(issue)); + } + void ReportMissingReference(const Expr& expr, absl::string_view name) { - issues_->push_back(TypeCheckIssue::CreateError( + ReportIssue(TypeCheckIssue::CreateError( ComputeSourceLocation(*ast_, expr.id()), absl::StrCat("undeclared reference to '", name, "' (in container '", container_, "')"))); @@ -373,7 +382,7 @@ class ResolveVisitor : public AstVisitorBase { void ReportUndefinedField(int64_t expr_id, absl::string_view field_name, absl::string_view struct_name) { - issues_->push_back(TypeCheckIssue::CreateError( + ReportIssue(TypeCheckIssue::CreateError( ComputeSourceLocation(*ast_, expr_id), absl::StrCat("undefined field '", field_name, "' not found in struct '", struct_name, "'"))); @@ -381,7 +390,7 @@ class ResolveVisitor : public AstVisitorBase { void ReportTypeMismatch(int64_t expr_id, const Type& expected, const Type& actual) { - issues_->push_back(TypeCheckIssue::CreateError( + ReportIssue(TypeCheckIssue::CreateError( ComputeSourceLocation(*ast_, expr_id), absl::StrCat("expected type '", inference_context_->FinalizeType(expected).DebugString(), @@ -412,7 +421,7 @@ class ResolveVisitor : public AstVisitorBase { } if (!inference_context_->IsAssignable(value_type, field_type) && !IsPbNullFieldAssignable(value_type, field_type)) { - issues_->push_back(TypeCheckIssue::CreateError( + ReportIssue(TypeCheckIssue::CreateError( ComputeSourceLocation(*ast_, field.id()), absl::StrCat( "expected type of field '", field_info->name(), "' is '", @@ -461,6 +470,7 @@ class ResolveVisitor : public AstVisitorBase { std::vector> comprehension_vars_; std::vector comprehension_scopes_; absl::Status status_; + int error_count_ = 0; // References that were resolved and may require AST rewrites. absl::flat_hash_map functions_; @@ -546,7 +556,7 @@ void ResolveVisitor::PostVisitConst(const Expr& expr, types_[&expr] = TimestampType(); break; default: - issues_->push_back(TypeCheckIssue::CreateError( + ReportIssue(TypeCheckIssue::CreateError( ComputeSourceLocation(*ast_, expr.id()), absl::StrCat("unsupported constant type: ", constant.kind().index()))); @@ -597,7 +607,7 @@ void ResolveVisitor::PostVisitMap(const Expr& expr, const MapExpr& map) { // // To match the Go implementation, we just warn here, but in the future // we should consider making this an error. - issues_->push_back(TypeCheckIssue( + ReportIssue(TypeCheckIssue( Severity::kWarning, ComputeSourceLocation(*ast_, key->id()), absl::StrCat( "unsupported map key type: ", @@ -702,7 +712,7 @@ void ResolveVisitor::PostVisitStruct(const Expr& expr, if (resolved_type.kind() != TypeKind::kStruct && !IsWellKnownMessageType(resolved_name)) { - issues_->push_back(TypeCheckIssue::CreateError( + ReportIssue(TypeCheckIssue::CreateError( ComputeSourceLocation(*ast_, expr.id()), absl::StrCat("type '", resolved_name, "' does not support message creation"))); @@ -849,7 +859,7 @@ void ResolveVisitor::PostVisitComprehensionSubexpression( case TypeKind::kDyn: break; default: - issues_->push_back(TypeCheckIssue::CreateError( + ReportIssue(TypeCheckIssue::CreateError( ComputeSourceLocation(*ast_, comprehension.iter_range().id()), absl::StrCat( "expression of type '", @@ -923,7 +933,7 @@ void ResolveVisitor::ResolveFunctionOverloads(const Expr& expr, inference_context_->ResolveOverload(decl, arg_types, is_receiver); if (!resolution.has_value()) { - issues_->push_back(TypeCheckIssue::CreateError( + ReportIssue(TypeCheckIssue::CreateError( ComputeSourceLocation(*ast_, expr.id()), absl::StrCat("found no matching overload for '", decl.name(), "' applied to '(", @@ -1085,7 +1095,7 @@ absl::optional ResolveVisitor::CheckFieldType(int64_t id, break; } - issues_->push_back(TypeCheckIssue::CreateError( + ReportIssue(TypeCheckIssue::CreateError( ComputeSourceLocation(*ast_, id), absl::StrCat("expression of type '", inference_context_->FinalizeType(operand_type).DebugString(), @@ -1252,25 +1262,34 @@ absl::StatusOr TypeCheckerImpl::Check( TraversalOptions opts; opts.use_comprehension_callbacks = true; - + bool error_limit_reached = false; auto traversal = AstTraversal::Create(ast_impl.root_expr(), opts); + for (int step = 0; step < options_.max_expression_node_count * 2; ++step) { bool has_next = traversal.Step(visitor); if (!visitor.status().ok()) { return visitor.status(); } + if (visitor.error_count() > options_.max_error_issues) { + error_limit_reached = true; + break; + } if (!has_next) { break; } } - if (!traversal.IsDone()) { + if (!traversal.IsDone() && !error_limit_reached) { return absl::InvalidArgumentError( - absl::StrCat("Max expression node count exceeded: ", + absl::StrCat("Maximum expression node count exceeded: ", options_.max_expression_node_count)); } - if (env_.expected_type().has_value()) { + if (error_limit_reached) { + issues.push_back(TypeCheckIssue::CreateError( + {}, absl::StrCat("maximum number of ERROR issues exceeded: ", + options_.max_error_issues))); + } else if (env_.expected_type().has_value()) { visitor.AssertExpectedType(ast_impl.root_expr(), *env_.expected_type()); } diff --git a/checker/internal/type_checker_impl_test.cc b/checker/internal/type_checker_impl_test.cc index c72e57a12..9c95fbf22 100644 --- a/checker/internal/type_checker_impl_test.cc +++ b/checker/internal/type_checker_impl_test.cc @@ -340,6 +340,35 @@ TEST(TypeCheckerImplTest, ReportMissingIdentDecl) { "undeclared reference to 'y'"))); } +TEST(TypeCheckerImplTest, ErrorLimitInclusive) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + CheckerOptions options; + options.max_error_issues = 1; + + TypeCheckerImpl impl(std::move(env), options); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("1 + y")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.GetIssues(), + ElementsAre(IsIssueWithSubstring(Severity::kError, + "undeclared reference to 'y'"))); + ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("x + y + z")); + ASSERT_OK_AND_ASSIGN(result, impl.Check(std::move(ast))); + + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT( + result.GetIssues(), + ElementsAre( + IsIssueWithSubstring(Severity::kError, "undeclared reference to 'x'"), + IsIssueWithSubstring(Severity::kError, "undeclared reference to 'y'"), + IsIssueWithSubstring(Severity::kError, + "maximum number of ERROR issues exceeded: 1"))); +} + MATCHER_P3(IsIssueWithLocation, line, column, message, "") { const TypeCheckIssue& issue = arg; if (issue.location().line == line && issue.location().column == column && From 8d527656cef2908a96d4feeafa7a3730d8531655 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Wed, 20 Nov 2024 10:41:31 -0800 Subject: [PATCH 048/180] Refactor factory functions for TypeCheckerBuilders. - move the factories to their own cc_library - update return value to be std::unique_ptr PiperOrigin-RevId: 698442862 --- checker/BUILD | 29 ++++- checker/optional_test.cc | 27 ++--- checker/standard_library_test.cc | 36 ++++--- checker/type_checker_builder.cc | 29 ----- checker/type_checker_builder.h | 31 +----- checker/type_checker_builder_factory.cc | 58 ++++++++++ checker/type_checker_builder_factory.h | 57 ++++++++++ ...c => type_checker_builder_factory_test.cc} | 100 +++++++++--------- compiler/BUILD | 1 + compiler/compiler_factory.cc | 11 +- conformance/BUILD | 2 +- conformance/service.cc | 22 ++-- 12 files changed, 244 insertions(+), 159 deletions(-) create mode 100644 checker/type_checker_builder_factory.cc create mode 100644 checker/type_checker_builder_factory.h rename checker/{type_checker_builder_test.cc => type_checker_builder_factory_test.cc} (72%) diff --git a/checker/BUILD b/checker/BUILD index df9049a12..b0926ee26 100644 --- a/checker/BUILD +++ b/checker/BUILD @@ -91,16 +91,13 @@ cc_library( "//checker/internal:type_checker_impl", "//common:decl", "//common:type", - "//internal:noop_delete", "//internal:status_macros", - "//internal:well_known_types", "//parser:macro", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:any_invocable", - "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -108,11 +105,30 @@ cc_library( ], ) +cc_library( + name = "type_checker_builder_factory", + srcs = ["type_checker_builder_factory.cc"], + hdrs = ["type_checker_builder_factory.h"], + deps = [ + ":checker_options", + ":type_checker_builder", + "//internal:noop_delete", + "//internal:status_macros", + "//internal:well_known_types", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + ], +) + cc_test( - name = "type_checker_builder_test", - srcs = ["type_checker_builder_test.cc"], + name = "type_checker_builder_factory_test", + srcs = ["type_checker_builder_factory_test.cc"], deps = [ ":type_checker_builder", + ":type_checker_builder_factory", ":validation_result", "//checker/internal:test_ast_helpers", "//common:decl", @@ -145,9 +161,11 @@ cc_test( name = "standard_library_test", srcs = ["standard_library_test.cc"], deps = [ + ":checker_options", ":standard_library", ":type_checker", ":type_checker_builder", + ":type_checker_builder_factory", ":validation_result", "//base/ast_internal:ast_impl", "//base/ast_internal:expr", @@ -190,6 +208,7 @@ cc_test( ":type_check_issue", ":type_checker", ":type_checker_builder", + ":type_checker_builder_factory", "//base/ast_internal:ast_impl", "//base/ast_internal:expr", "//checker/internal:test_ast_helpers", diff --git a/checker/optional_test.cc b/checker/optional_test.cc index 126225668..abc7f3950 100644 --- a/checker/optional_test.cc +++ b/checker/optional_test.cc @@ -29,6 +29,7 @@ #include "checker/type_check_issue.h" #include "checker/type_checker.h" #include "checker/type_checker_builder.h" +#include "checker/type_checker_builder_factory.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" @@ -77,13 +78,13 @@ MATCHER_P(IsOptionalType, inner_type, "") { TEST(OptionalTest, OptSelectDoesNotAnnotateFieldType) { ASSERT_OK_AND_ASSIGN( - TypeCheckerBuilder builder, + std::unique_ptr builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); - ASSERT_THAT(builder.AddLibrary(StandardCheckerLibrary()), IsOk()); - ASSERT_THAT(builder.AddLibrary(OptionalCheckerLibrary()), IsOk()); - builder.set_container("cel.expr.conformance.proto3"); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + ASSERT_THAT(builder->AddLibrary(OptionalCheckerLibrary()), IsOk()); + builder->set_container("cel.expr.conformance.proto3"); ASSERT_OK_AND_ASSIGN(std::unique_ptr checker, - std::move(builder).Build()); + std::move(*builder).Build()); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("TestAllTypes{}.?single_int64")); @@ -113,13 +114,13 @@ class OptionalTest : public testing::TestWithParam {}; TEST_P(OptionalTest, Runner) { ASSERT_OK_AND_ASSIGN( - TypeCheckerBuilder builder, + std::unique_ptr builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); const TestCase& test_case = GetParam(); - ASSERT_THAT(builder.AddLibrary(StandardCheckerLibrary()), IsOk()); - ASSERT_THAT(builder.AddLibrary(OptionalCheckerLibrary()), IsOk()); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + ASSERT_THAT(builder->AddLibrary(OptionalCheckerLibrary()), IsOk()); ASSERT_OK_AND_ASSIGN(std::unique_ptr checker, - std::move(builder).Build()); + std::move(*builder).Build()); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst(test_case.expr)); @@ -284,13 +285,13 @@ TEST_P(OptionalStrictNullAssignmentTest, Runner) { CheckerOptions options; options.enable_legacy_null_assignment = false; ASSERT_OK_AND_ASSIGN( - TypeCheckerBuilder builder, + std::unique_ptr builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool(), options)); const TestCase& test_case = GetParam(); - ASSERT_THAT(builder.AddLibrary(StandardCheckerLibrary()), IsOk()); - ASSERT_THAT(builder.AddLibrary(OptionalCheckerLibrary()), IsOk()); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + ASSERT_THAT(builder->AddLibrary(OptionalCheckerLibrary()), IsOk()); ASSERT_OK_AND_ASSIGN(std::unique_ptr checker, - std::move(builder).Build()); + std::move(*builder).Build()); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst(test_case.expr)); diff --git a/checker/standard_library_test.cc b/checker/standard_library_test.cc index ca51b8aaa..8a0ab1ac7 100644 --- a/checker/standard_library_test.cc +++ b/checker/standard_library_test.cc @@ -22,9 +22,11 @@ #include "absl/status/status_matchers.h" #include "base/ast_internal/ast_impl.h" #include "base/ast_internal/expr.h" +#include "checker/checker_options.h" #include "checker/internal/test_ast_helpers.h" #include "checker/type_checker.h" #include "checker/type_checker_builder.h" +#include "checker/type_checker_builder_factory.h" #include "checker/validation_result.h" #include "common/ast.h" #include "common/constant.h" @@ -50,27 +52,27 @@ using AstType = cel::ast_internal::Type; TEST(StandardLibraryTest, StandardLibraryAddsDecls) { ASSERT_OK_AND_ASSIGN( - TypeCheckerBuilder builder, + std::unique_ptr builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); - EXPECT_THAT(builder.AddLibrary(StandardCheckerLibrary()), IsOk()); - EXPECT_THAT(std::move(builder).Build(), IsOk()); + EXPECT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + EXPECT_THAT(std::move(*builder).Build(), IsOk()); } TEST(StandardLibraryTest, StandardLibraryErrorsIfAddedTwice) { ASSERT_OK_AND_ASSIGN( - TypeCheckerBuilder builder, + std::unique_ptr builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); - EXPECT_THAT(builder.AddLibrary(StandardCheckerLibrary()), IsOk()); - EXPECT_THAT(builder.AddLibrary(StandardCheckerLibrary()), + EXPECT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + EXPECT_THAT(builder->AddLibrary(StandardCheckerLibrary()), StatusIs(absl::StatusCode::kAlreadyExists)); } TEST(StandardLibraryTest, ComprehensionVarsIndirectCyclicParamAssignability) { google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN( - TypeCheckerBuilder builder, + std::unique_ptr builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); - ASSERT_THAT(builder.AddLibrary(StandardCheckerLibrary()), IsOk()); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); // Note: this is atypical -- parameterized variables aren't well supported // outside of built-in syntax. @@ -83,13 +85,13 @@ TEST(StandardLibraryTest, ComprehensionVarsIndirectCyclicParamAssignability) { Type list_type = ListType(&arena, TypeParamType("V")); Type map_type = MapType(&arena, TypeParamType("K"), TypeParamType("V")); - ASSERT_THAT(builder.AddVariable(MakeVariableDecl("list_var", list_type)), + ASSERT_THAT(builder->AddVariable(MakeVariableDecl("list_var", list_type)), IsOk()); - ASSERT_THAT(builder.AddVariable(MakeVariableDecl("map_var", map_type)), + ASSERT_THAT(builder->AddVariable(MakeVariableDecl("map_var", map_type)), IsOk()); ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, - std::move(builder).Build()); + std::move(*builder).Build()); ASSERT_OK_AND_ASSIGN( auto ast, checker_internal::MakeTestParsedAst( @@ -108,10 +110,10 @@ class StandardLibraryDefinitionsTest : public ::testing::Test { public: void SetUp() override { ASSERT_OK_AND_ASSIGN( - TypeCheckerBuilder builder, + std::unique_ptr builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); - ASSERT_THAT(builder.AddLibrary(StandardCheckerLibrary()), IsOk()); - ASSERT_OK_AND_ASSIGN(stdlib_type_checker_, std::move(builder).Build()); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + ASSERT_OK_AND_ASSIGN(stdlib_type_checker_, std::move(*builder).Build()); } protected: @@ -212,12 +214,12 @@ class StdLibDefinitionsTest // Type-parameterized functions are not yet checkable. TEST_P(StdLibDefinitionsTest, Runner) { ASSERT_OK_AND_ASSIGN( - TypeCheckerBuilder builder, + std::unique_ptr builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool(), GetParam().options)); - ASSERT_THAT(builder.AddLibrary(StandardCheckerLibrary()), IsOk()); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, - std::move(builder).Build()); + std::move(*builder).Build()); ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, checker_internal::MakeTestParsedAst(GetParam().expr)); diff --git a/checker/type_checker_builder.cc b/checker/type_checker_builder.cc index f7c2e9064..e17b28691 100644 --- a/checker/type_checker_builder.cc +++ b/checker/type_checker_builder.cc @@ -20,25 +20,19 @@ #include #include "absl/base/no_destructor.h" -#include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" -#include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" -#include "checker/checker_options.h" #include "checker/internal/type_check_env.h" #include "checker/internal/type_checker_impl.h" #include "checker/type_checker.h" #include "common/decl.h" #include "common/type.h" #include "common/type_introspector.h" -#include "internal/noop_delete.h" #include "internal/status_macros.h" -#include "internal/well_known_types.h" #include "parser/macro.h" -#include "google/protobuf/descriptor.h" namespace cel { namespace { @@ -83,29 +77,6 @@ absl::Status CheckStdMacroOverlap(const FunctionDecl& decl) { } // namespace -absl::StatusOr CreateTypeCheckerBuilder( - absl::Nonnull descriptor_pool, - const CheckerOptions& options) { - ABSL_DCHECK(descriptor_pool != nullptr); - return CreateTypeCheckerBuilder( - std::shared_ptr( - descriptor_pool, - internal::NoopDeleteFor()), - options); -} - -absl::StatusOr CreateTypeCheckerBuilder( - absl::Nonnull> - descriptor_pool, - const CheckerOptions& options) { - ABSL_DCHECK(descriptor_pool != nullptr); - // Verify the standard descriptors, we do not need to keep - // `well_known_types::Reflection` at the moment here. - CEL_RETURN_IF_ERROR( - well_known_types::Reflection().Initialize(descriptor_pool.get())); - return TypeCheckerBuilder(std::move(descriptor_pool), options); -} - absl::StatusOr> TypeCheckerBuilder::Build() && { auto checker = std::make_unique( std::move(env_), options_); diff --git a/checker/type_checker_builder.h b/checker/type_checker_builder.h index f4b3386a7..e2994f560 100644 --- a/checker/type_checker_builder.h +++ b/checker/type_checker_builder.h @@ -38,34 +38,6 @@ namespace cel { class TypeCheckerBuilder; -// Creates a new `TypeCheckerBuilder`. -// -// When passing a raw pointer to a descriptor pool, the descriptor pool must -// outlive the type checker builder and the type checker builder it creates. -// -// The descriptor pool must include the minimally necessary -// descriptors required by CEL. Those are the following: -// - google.protobuf.NullValue -// - google.protobuf.BoolValue -// - google.protobuf.Int32Value -// - google.protobuf.Int64Value -// - google.protobuf.UInt32Value -// - google.protobuf.UInt64Value -// - google.protobuf.FloatValue -// - google.protobuf.DoubleValue -// - google.protobuf.BytesValue -// - google.protobuf.StringValue -// - google.protobuf.Any -// - google.protobuf.Duration -// - google.protobuf.Timestamp -absl::StatusOr CreateTypeCheckerBuilder( - absl::Nonnull descriptor_pool, - const CheckerOptions& options = {}); -absl::StatusOr CreateTypeCheckerBuilder( - absl::Nonnull> - descriptor_pool, - const CheckerOptions& options = {}); - // Functional implementation to apply the library features to a // TypeCheckerBuilder. using TypeCheckerBuilderConfigurer = @@ -109,7 +81,8 @@ class TypeCheckerBuilder { const CheckerOptions& options() const { return options_; } private: - friend absl::StatusOr CreateTypeCheckerBuilder( + friend absl::StatusOr> + CreateTypeCheckerBuilder( absl::Nonnull> descriptor_pool, const CheckerOptions& options); diff --git a/checker/type_checker_builder_factory.cc b/checker/type_checker_builder_factory.cc new file mode 100644 index 000000000..78bd89fb1 --- /dev/null +++ b/checker/type_checker_builder_factory.cc @@ -0,0 +1,58 @@ + +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/type_checker_builder_factory.h" + +#include +#include + +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/memory/memory.h" +#include "absl/status/statusor.h" +#include "checker/checker_options.h" +#include "checker/type_checker_builder.h" +#include "internal/noop_delete.h" +#include "internal/status_macros.h" +#include "internal/well_known_types.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +absl::StatusOr> CreateTypeCheckerBuilder( + absl::Nonnull descriptor_pool, + const CheckerOptions& options) { + ABSL_DCHECK(descriptor_pool != nullptr); + return CreateTypeCheckerBuilder( + std::shared_ptr( + descriptor_pool, + internal::NoopDeleteFor()), + options); +} + +absl::StatusOr> CreateTypeCheckerBuilder( + absl::Nonnull> + descriptor_pool, + const CheckerOptions& options) { + ABSL_DCHECK(descriptor_pool != nullptr); + // Verify the standard descriptors, we do not need to keep + // `well_known_types::Reflection` at the moment here. + CEL_RETURN_IF_ERROR( + well_known_types::Reflection().Initialize(descriptor_pool.get())); + auto* builder = new TypeCheckerBuilder(std::move(descriptor_pool), options); + return absl::WrapUnique(builder); +} + +} // namespace cel diff --git a/checker/type_checker_builder_factory.h b/checker/type_checker_builder_factory.h new file mode 100644 index 000000000..93c603394 --- /dev/null +++ b/checker/type_checker_builder_factory.h @@ -0,0 +1,57 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#ifndef THIRD_PARTY_CEL_CPP_CHECKER_TYPE_CHECKER_BUILDER_FACTORY_H_ +#define THIRD_PARTY_CEL_CPP_CHECKER_TYPE_CHECKER_BUILDER_FACTORY_H_ + +#include + +#include "absl/base/nullability.h" +#include "absl/status/statusor.h" +#include "checker/checker_options.h" +#include "checker/type_checker_builder.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +// Creates a new `TypeCheckerBuilder`. +// +// When passing a raw pointer to a descriptor pool, the descriptor pool must +// outlive the type checker builder and the type checker builder it creates. +// +// The descriptor pool must include the minimally necessary +// descriptors required by CEL. Those are the following: +// - google.protobuf.NullValue +// - google.protobuf.BoolValue +// - google.protobuf.Int32Value +// - google.protobuf.Int64Value +// - google.protobuf.UInt32Value +// - google.protobuf.UInt64Value +// - google.protobuf.FloatValue +// - google.protobuf.DoubleValue +// - google.protobuf.BytesValue +// - google.protobuf.StringValue +// - google.protobuf.Any +// - google.protobuf.Duration +// - google.protobuf.Timestamp +absl::StatusOr> CreateTypeCheckerBuilder( + absl::Nonnull descriptor_pool, + const CheckerOptions& options = {}); +absl::StatusOr> CreateTypeCheckerBuilder( + absl::Nonnull> + descriptor_pool, + const CheckerOptions& options = {}); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_CHECKER_TYPE_CHECKER_BUILDER_FACTORY_H_ diff --git a/checker/type_checker_builder_test.cc b/checker/type_checker_builder_factory_test.cc similarity index 72% rename from checker/type_checker_builder_test.cc rename to checker/type_checker_builder_factory_test.cc index 82e255e78..fde2106a0 100644 --- a/checker/type_checker_builder_test.cc +++ b/checker/type_checker_builder_factory_test.cc @@ -12,13 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "checker/type_checker_builder.h" +#include "checker/type_checker_builder_factory.h" +#include #include #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "checker/internal/test_ast_helpers.h" +#include "checker/type_checker_builder.h" #include "checker/validation_result.h" #include "common/decl.h" #include "common/type.h" @@ -36,12 +38,12 @@ using ::testing::HasSubstr; TEST(TypeCheckerBuilderTest, AddVariable) { ASSERT_OK_AND_ASSIGN( - TypeCheckerBuilder builder, + std::unique_ptr builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); - ASSERT_THAT(builder.AddVariable(MakeVariableDecl("x", IntType())), IsOk()); + ASSERT_THAT(builder->AddVariable(MakeVariableDecl("x", IntType())), IsOk()); - ASSERT_OK_AND_ASSIGN(auto checker, std::move(builder).Build()); + ASSERT_OK_AND_ASSIGN(auto checker, std::move(*builder).Build()); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("x")); ASSERT_OK_AND_ASSIGN(ValidationResult result, checker->Check(std::move(ast))); EXPECT_TRUE(result.IsValid()); @@ -49,17 +51,17 @@ TEST(TypeCheckerBuilderTest, AddVariable) { TEST(TypeCheckerBuilderTest, AddVariableRedeclaredError) { ASSERT_OK_AND_ASSIGN( - TypeCheckerBuilder builder, + std::unique_ptr builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); - ASSERT_THAT(builder.AddVariable(MakeVariableDecl("x", IntType())), IsOk()); - EXPECT_THAT(builder.AddVariable(MakeVariableDecl("x", IntType())), + ASSERT_THAT(builder->AddVariable(MakeVariableDecl("x", IntType())), IsOk()); + EXPECT_THAT(builder->AddVariable(MakeVariableDecl("x", IntType())), StatusIs(absl::StatusCode::kAlreadyExists)); } TEST(TypeCheckerBuilderTest, AddFunction) { ASSERT_OK_AND_ASSIGN( - TypeCheckerBuilder builder, + std::unique_ptr builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); ASSERT_OK_AND_ASSIGN( @@ -67,8 +69,8 @@ TEST(TypeCheckerBuilderTest, AddFunction) { MakeFunctionDecl( "add", MakeOverloadDecl("add_int", IntType(), IntType(), IntType()))); - ASSERT_THAT(builder.AddFunction(fn_decl), IsOk()); - ASSERT_OK_AND_ASSIGN(auto checker, std::move(builder).Build()); + ASSERT_THAT(builder->AddFunction(fn_decl), IsOk()); + ASSERT_OK_AND_ASSIGN(auto checker, std::move(*builder).Build()); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("add(1, 2)")); ASSERT_OK_AND_ASSIGN(ValidationResult result, checker->Check(std::move(ast))); EXPECT_TRUE(result.IsValid()); @@ -76,7 +78,7 @@ TEST(TypeCheckerBuilderTest, AddFunction) { TEST(TypeCheckerBuilderTest, AddFunctionRedeclaredError) { ASSERT_OK_AND_ASSIGN( - TypeCheckerBuilder builder, + std::unique_ptr builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); ASSERT_OK_AND_ASSIGN( @@ -84,14 +86,14 @@ TEST(TypeCheckerBuilderTest, AddFunctionRedeclaredError) { MakeFunctionDecl( "add", MakeOverloadDecl("add_int", IntType(), IntType(), IntType()))); - ASSERT_THAT(builder.AddFunction(fn_decl), IsOk()); - EXPECT_THAT(builder.AddFunction(fn_decl), + ASSERT_THAT(builder->AddFunction(fn_decl), IsOk()); + EXPECT_THAT(builder->AddFunction(fn_decl), StatusIs(absl::StatusCode::kAlreadyExists)); } TEST(TypeCheckerBuilderTest, AddLibrary) { ASSERT_OK_AND_ASSIGN( - TypeCheckerBuilder builder, + std::unique_ptr builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); ASSERT_OK_AND_ASSIGN( @@ -99,13 +101,13 @@ TEST(TypeCheckerBuilderTest, AddLibrary) { MakeFunctionDecl( "add", MakeOverloadDecl("add_int", IntType(), IntType(), IntType()))); - ASSERT_THAT(builder.AddLibrary({"", - [&](TypeCheckerBuilder& b) { - return builder.AddFunction(fn_decl); - }}), + ASSERT_THAT(builder->AddLibrary({"", + [&](TypeCheckerBuilder& b) { + return builder->AddFunction(fn_decl); + }}), IsOk()); - ASSERT_OK_AND_ASSIGN(auto checker, std::move(builder).Build()); + ASSERT_OK_AND_ASSIGN(auto checker, std::move(*builder).Build()); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("add(1, 2)")); ASSERT_OK_AND_ASSIGN(ValidationResult result, checker->Check(std::move(ast))); EXPECT_TRUE(result.IsValid()); @@ -113,7 +115,7 @@ TEST(TypeCheckerBuilderTest, AddLibrary) { TEST(TypeCheckerBuilderTest, AddLibraryRedeclaredError) { ASSERT_OK_AND_ASSIGN( - TypeCheckerBuilder builder, + std::unique_ptr builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); ASSERT_OK_AND_ASSIGN( @@ -121,21 +123,21 @@ TEST(TypeCheckerBuilderTest, AddLibraryRedeclaredError) { MakeFunctionDecl( "add", MakeOverloadDecl("add_int", IntType(), IntType(), IntType()))); - ASSERT_THAT(builder.AddLibrary({"testlib", - [&](TypeCheckerBuilder& b) { - return builder.AddFunction(fn_decl); - }}), + ASSERT_THAT(builder->AddLibrary({"testlib", + [&](TypeCheckerBuilder& b) { + return builder->AddFunction(fn_decl); + }}), IsOk()); - EXPECT_THAT(builder.AddLibrary({"testlib", - [&](TypeCheckerBuilder& b) { - return builder.AddFunction(fn_decl); - }}), + EXPECT_THAT(builder->AddLibrary({"testlib", + [&](TypeCheckerBuilder& b) { + return builder->AddFunction(fn_decl); + }}), StatusIs(absl::StatusCode::kAlreadyExists, HasSubstr("testlib"))); } TEST(TypeCheckerBuilderTest, AddLibraryForwardsErrors) { ASSERT_OK_AND_ASSIGN( - TypeCheckerBuilder builder, + std::unique_ptr builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); ASSERT_OK_AND_ASSIGN( @@ -143,21 +145,21 @@ TEST(TypeCheckerBuilderTest, AddLibraryForwardsErrors) { MakeFunctionDecl( "add", MakeOverloadDecl("add_int", IntType(), IntType(), IntType()))); - ASSERT_THAT(builder.AddLibrary({"", - [&](TypeCheckerBuilder& b) { - return builder.AddFunction(fn_decl); - }}), + ASSERT_THAT(builder->AddLibrary({"", + [&](TypeCheckerBuilder& b) { + return builder->AddFunction(fn_decl); + }}), IsOk()); - EXPECT_THAT(builder.AddLibrary({"", - [](TypeCheckerBuilder& b) { - return absl::InternalError("test error"); - }}), + EXPECT_THAT(builder->AddLibrary({"", + [](TypeCheckerBuilder& b) { + return absl::InternalError("test error"); + }}), StatusIs(absl::StatusCode::kInternal, HasSubstr("test error"))); } TEST(TypeCheckerBuilderTest, AddFunctionOverlapsWithStdMacroError) { ASSERT_OK_AND_ASSIGN( - TypeCheckerBuilder builder, + std::unique_ptr builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); ASSERT_OK_AND_ASSIGN( @@ -165,42 +167,42 @@ TEST(TypeCheckerBuilderTest, AddFunctionOverlapsWithStdMacroError) { "ovl_3", ListType(), ListType(), DynType(), DynType()))); - EXPECT_THAT(builder.AddFunction(fn_decl), + EXPECT_THAT(builder->AddFunction(fn_decl), StatusIs(absl::StatusCode::kInvalidArgument, "overload for name 'map' with 3 argument(s) overlaps " "with predefined macro")); fn_decl.set_name("filter"); - EXPECT_THAT(builder.AddFunction(fn_decl), + EXPECT_THAT(builder->AddFunction(fn_decl), StatusIs(absl::StatusCode::kInvalidArgument, "overload for name 'filter' with 3 argument(s) overlaps " "with predefined macro")); fn_decl.set_name("exists"); - EXPECT_THAT(builder.AddFunction(fn_decl), + EXPECT_THAT(builder->AddFunction(fn_decl), StatusIs(absl::StatusCode::kInvalidArgument, "overload for name 'exists' with 3 argument(s) overlaps " "with predefined macro")); fn_decl.set_name("exists_one"); - EXPECT_THAT(builder.AddFunction(fn_decl), + EXPECT_THAT(builder->AddFunction(fn_decl), StatusIs(absl::StatusCode::kInvalidArgument, "overload for name 'exists_one' with 3 argument(s) " "overlaps with predefined macro")); fn_decl.set_name("all"); - EXPECT_THAT(builder.AddFunction(fn_decl), + EXPECT_THAT(builder->AddFunction(fn_decl), StatusIs(absl::StatusCode::kInvalidArgument, "overload for name 'all' with 3 argument(s) overlaps " "with predefined macro")); fn_decl.set_name("optMap"); - EXPECT_THAT(builder.AddFunction(fn_decl), + EXPECT_THAT(builder->AddFunction(fn_decl), StatusIs(absl::StatusCode::kInvalidArgument, "overload for name 'optMap' with 3 argument(s) overlaps " "with predefined macro")); @@ -208,7 +210,7 @@ TEST(TypeCheckerBuilderTest, AddFunctionOverlapsWithStdMacroError) { fn_decl.set_name("optFlatMap"); EXPECT_THAT( - builder.AddFunction(fn_decl), + builder->AddFunction(fn_decl), StatusIs(absl::StatusCode::kInvalidArgument, "overload for name 'optFlatMap' with 3 argument(s) overlaps " "with predefined macro")); @@ -217,7 +219,7 @@ TEST(TypeCheckerBuilderTest, AddFunctionOverlapsWithStdMacroError) { fn_decl, MakeFunctionDecl( "has", MakeOverloadDecl("ovl_1", BoolType(), DynType()))); - EXPECT_THAT(builder.AddFunction(fn_decl), + EXPECT_THAT(builder->AddFunction(fn_decl), StatusIs(absl::StatusCode::kInvalidArgument, "overload for name 'has' with 1 argument(s) overlaps " "with predefined macro")); @@ -228,7 +230,7 @@ TEST(TypeCheckerBuilderTest, AddFunctionOverlapsWithStdMacroError) { DynType(), DynType(), DynType()))); - EXPECT_THAT(builder.AddFunction(fn_decl), + EXPECT_THAT(builder->AddFunction(fn_decl), StatusIs(absl::StatusCode::kInvalidArgument, "overload for name 'map' with 4 argument(s) overlaps " "with predefined macro")); @@ -236,7 +238,7 @@ TEST(TypeCheckerBuilderTest, AddFunctionOverlapsWithStdMacroError) { TEST(TypeCheckerBuilderTest, AddFunctionNoOverlapWithStdMacroError) { ASSERT_OK_AND_ASSIGN( - TypeCheckerBuilder builder, + std::unique_ptr builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); ASSERT_OK_AND_ASSIGN( @@ -244,7 +246,7 @@ TEST(TypeCheckerBuilderTest, AddFunctionNoOverlapWithStdMacroError) { MakeFunctionDecl("has", MakeMemberOverloadDecl("ovl", BoolType(), DynType(), StringType()))); - EXPECT_THAT(builder.AddFunction(fn_decl), IsOk()); + EXPECT_THAT(builder->AddFunction(fn_decl), IsOk()); } } // namespace diff --git a/compiler/BUILD b/compiler/BUILD index 7cfe940e5..22894ee78 100644 --- a/compiler/BUILD +++ b/compiler/BUILD @@ -38,6 +38,7 @@ cc_library( ":compiler", "//checker:type_checker", "//checker:type_checker_builder", + "//checker:type_checker_builder_factory", "//checker:validation_result", "//common:source", "//internal:noop_delete", diff --git a/compiler/compiler_factory.cc b/compiler/compiler_factory.cc index 4a1f3209f..96a6c5b2e 100644 --- a/compiler/compiler_factory.cc +++ b/compiler/compiler_factory.cc @@ -26,6 +26,7 @@ #include "absl/strings/string_view.h" #include "checker/type_checker.h" #include "checker/type_checker_builder.h" +#include "checker/type_checker_builder_factory.h" #include "checker/validation_result.h" #include "common/source.h" #include "compiler/compiler.h" @@ -64,7 +65,7 @@ class CompilerImpl : public Compiler { class CompilerBuilderImpl : public CompilerBuilder { public: - CompilerBuilderImpl(TypeCheckerBuilder type_checker_builder, + CompilerBuilderImpl(std::unique_ptr type_checker_builder, std::unique_ptr parser_builder) : type_checker_builder_(std::move(type_checker_builder)), parser_builder_(std::move(parser_builder)) {} @@ -80,7 +81,7 @@ class CompilerBuilderImpl : public CompilerBuilder { } if (library.configure_checker) { - CEL_RETURN_IF_ERROR(type_checker_builder_.AddLibrary({ + CEL_RETURN_IF_ERROR(type_checker_builder_->AddLibrary({ .id = std::move(library.id), .configure = std::move(library.configure_checker), })); @@ -93,7 +94,7 @@ class CompilerBuilderImpl : public CompilerBuilder { ParserBuilder& GetParserBuilder() override { return *parser_builder_; } TypeCheckerBuilder& GetCheckerBuilder() override { - return type_checker_builder_; + return *type_checker_builder_; } absl::StatusOr> Build() && override { @@ -102,13 +103,13 @@ class CompilerBuilderImpl : public CompilerBuilder { } CEL_ASSIGN_OR_RETURN(auto parser, std::move(*parser_builder_).Build()); CEL_ASSIGN_OR_RETURN(auto type_checker, - std::move(type_checker_builder_).Build()); + std::move(*type_checker_builder_).Build()); return std::make_unique(std::move(type_checker), std::move(parser)); } private: - TypeCheckerBuilder type_checker_builder_; + std::unique_ptr type_checker_builder_; std::unique_ptr parser_builder_; absl::flat_hash_set library_ids_; diff --git a/conformance/BUILD b/conformance/BUILD index 43735aafc..9bd9fb722 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -57,6 +57,7 @@ cc_library( "//checker:optional", "//checker:standard_library", "//checker:type_checker_builder", + "//checker:type_checker_builder_factory", "//common:ast", "//common:decl", "//common:expr", @@ -81,7 +82,6 @@ cc_library( "//extensions/protobuf:ast_converters", "//extensions/protobuf:enum_adapter", "//extensions/protobuf:memory_manager", - "//extensions/protobuf:value", "//internal:status_macros", "//parser", "//parser:macro", diff --git a/conformance/service.cc b/conformance/service.cc index b89046e1c..aab66e58b 100644 --- a/conformance/service.cc +++ b/conformance/service.cc @@ -44,6 +44,7 @@ #include "checker/optional.h" #include "checker/standard_library.h" #include "checker/type_checker_builder.h" +#include "checker/type_checker_builder_factory.h" #include "common/ast.h" #include "common/decl.h" #include "common/expr.h" @@ -68,7 +69,6 @@ #include "extensions/protobuf/ast_converters.h" #include "extensions/protobuf/enum_adapter.h" #include "extensions/protobuf/memory_manager.h" -#include "extensions/protobuf/type_reflector.h" #include "extensions/strings.h" #include "internal/status_macros.h" #include "parser/macro.h" @@ -634,19 +634,19 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { CEL_ASSIGN_OR_RETURN(source, cel::NewSource(location)); } - CEL_ASSIGN_OR_RETURN(cel::TypeCheckerBuilder builder, + CEL_ASSIGN_OR_RETURN(std::unique_ptr builder, cel::CreateTypeCheckerBuilder( google::protobuf::DescriptorPool::generated_pool())); if (!request.no_std_env()) { - CEL_RETURN_IF_ERROR(builder.AddLibrary(cel::StandardCheckerLibrary())); - CEL_RETURN_IF_ERROR(builder.AddLibrary(cel::OptionalCheckerLibrary())); + CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::StandardCheckerLibrary())); + CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::OptionalCheckerLibrary())); CEL_RETURN_IF_ERROR( - builder.AddLibrary(cel::extensions::StringsCheckerLibrary())); + builder->AddLibrary(cel::extensions::StringsCheckerLibrary())); CEL_RETURN_IF_ERROR( - builder.AddLibrary(cel::extensions::MathCheckerLibrary())); + builder->AddLibrary(cel::extensions::MathCheckerLibrary())); CEL_RETURN_IF_ERROR( - builder.AddLibrary(cel::extensions::EncodersCheckerLibrary())); + builder->AddLibrary(cel::extensions::EncodersCheckerLibrary())); } for (const auto& decl : request.type_env()) { @@ -673,19 +673,19 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { CEL_RETURN_IF_ERROR(fn_decl.AddOverload(std::move(overload))); } - CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(fn_decl))); + CEL_RETURN_IF_ERROR(builder->AddFunction(std::move(fn_decl))); } else if (decl.has_ident()) { VariableDecl var_decl; var_decl.set_name(name); CEL_ASSIGN_OR_RETURN(auto var_type, FromConformanceType(arena, decl.ident().type())); var_decl.set_type(var_type); - CEL_RETURN_IF_ERROR(builder.AddVariable(var_decl)); + CEL_RETURN_IF_ERROR(builder->AddVariable(var_decl)); } } - builder.set_container(request.container()); + builder->set_container(request.container()); - CEL_ASSIGN_OR_RETURN(auto checker, std::move(builder).Build()); + CEL_ASSIGN_OR_RETURN(auto checker, std::move(*builder).Build()); CEL_ASSIGN_OR_RETURN(auto validation_result, checker->Check(std::move(ast))); From ceb592c88880f6960215d9201ad2cbb7f98d0ba9 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Wed, 20 Nov 2024 10:58:54 -0800 Subject: [PATCH 049/180] Make TypeCheckerBuilder class abstract. PiperOrigin-RevId: 698449337 --- checker/BUILD | 5 +- checker/internal/BUILD | 14 +++- .../type_checker_builder_impl.cc} | 24 +++--- checker/internal/type_checker_builder_impl.h | 83 +++++++++++++++++++ checker/type_checker_builder.h | 72 ++++++++-------- checker/type_checker_builder_factory.cc | 6 +- 6 files changed, 146 insertions(+), 58 deletions(-) rename checker/{type_checker_builder.cc => internal/type_checker_builder_impl.cc} (86%) create mode 100644 checker/internal/type_checker_builder_impl.h diff --git a/checker/BUILD b/checker/BUILD index b0926ee26..7626284a6 100644 --- a/checker/BUILD +++ b/checker/BUILD @@ -82,13 +82,10 @@ cc_library( cc_library( name = "type_checker_builder", - srcs = ["type_checker_builder.cc"], hdrs = ["type_checker_builder.h"], deps = [ ":checker_options", ":type_checker", - "//checker/internal:type_check_env", - "//checker/internal:type_checker_impl", "//common:decl", "//common:type", "//internal:status_macros", @@ -112,12 +109,12 @@ cc_library( deps = [ ":checker_options", ":type_checker_builder", + "//checker/internal:type_checker_impl", "//internal:noop_delete", "//internal:status_macros", "//internal:well_known_types", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/log:absl_check", - "@com_google_absl//absl/memory", "@com_google_absl//absl/status:statusor", "@com_google_protobuf//:protobuf", ], diff --git a/checker/internal/BUILD b/checker/internal/BUILD index 68ea74f4f..aec0701bb 100644 --- a/checker/internal/BUILD +++ b/checker/internal/BUILD @@ -107,8 +107,14 @@ cc_test( cc_library( name = "type_checker_impl", - srcs = ["type_checker_impl.cc"], - hdrs = ["type_checker_impl.h"], + srcs = [ + "type_checker_builder_impl.cc", + "type_checker_impl.cc", + ], + hdrs = [ + "type_checker_builder_impl.h", + "type_checker_impl.h", + ], deps = [ ":namespace_generator", ":type_check_env", @@ -118,6 +124,7 @@ cc_library( "//checker:checker_options", "//checker:type_check_issue", "//checker:type_checker", + "//checker:type_checker_builder", "//checker:validation_result", "//common:ast", "//common:ast_rewrite", @@ -131,9 +138,12 @@ cc_library( "//common:type", "//common:type_kind", "//internal:status_macros", + "//parser:macro", + "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/checker/type_checker_builder.cc b/checker/internal/type_checker_builder_impl.cc similarity index 86% rename from checker/type_checker_builder.cc rename to checker/internal/type_checker_builder_impl.cc index e17b28691..6f0345290 100644 --- a/checker/type_checker_builder.cc +++ b/checker/internal/type_checker_builder_impl.cc @@ -11,7 +11,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -#include "checker/type_checker_builder.h" +#include "checker/internal/type_checker_builder_impl.h" #include #include @@ -28,13 +28,14 @@ #include "checker/internal/type_check_env.h" #include "checker/internal/type_checker_impl.h" #include "checker/type_checker.h" +#include "checker/type_checker_builder.h" #include "common/decl.h" #include "common/type.h" #include "common/type_introspector.h" #include "internal/status_macros.h" #include "parser/macro.h" -namespace cel { +namespace cel::checker_internal { namespace { const absl::flat_hash_map>& GetStdMacros() { @@ -77,13 +78,14 @@ absl::Status CheckStdMacroOverlap(const FunctionDecl& decl) { } // namespace -absl::StatusOr> TypeCheckerBuilder::Build() && { +absl::StatusOr> +TypeCheckerBuilderImpl::Build() && { auto checker = std::make_unique( std::move(env_), options_); return checker; } -absl::Status TypeCheckerBuilder::AddLibrary(CheckerLibrary library) { +absl::Status TypeCheckerBuilderImpl::AddLibrary(CheckerLibrary library) { if (!library.id.empty() && !library_ids_.insert(library.id).second) { return absl::AlreadyExistsError( absl::StrCat("library '", library.id, "' already exists")); @@ -97,7 +99,7 @@ absl::Status TypeCheckerBuilder::AddLibrary(CheckerLibrary library) { return status; } -absl::Status TypeCheckerBuilder::AddVariable(const VariableDecl& decl) { +absl::Status TypeCheckerBuilderImpl::AddVariable(const VariableDecl& decl) { bool inserted = env_.InsertVariableIfAbsent(decl); if (!inserted) { return absl::AlreadyExistsError( @@ -106,7 +108,7 @@ absl::Status TypeCheckerBuilder::AddVariable(const VariableDecl& decl) { return absl::OkStatus(); } -absl::Status TypeCheckerBuilder::AddFunction(const FunctionDecl& decl) { +absl::Status TypeCheckerBuilderImpl::AddFunction(const FunctionDecl& decl) { CEL_RETURN_IF_ERROR(CheckStdMacroOverlap(decl)); bool inserted = env_.InsertFunctionIfAbsent(decl); if (!inserted) { @@ -116,7 +118,7 @@ absl::Status TypeCheckerBuilder::AddFunction(const FunctionDecl& decl) { return absl::OkStatus(); } -absl::Status TypeCheckerBuilder::MergeFunction(const FunctionDecl& decl) { +absl::Status TypeCheckerBuilderImpl::MergeFunction(const FunctionDecl& decl) { const FunctionDecl* existing = env_.LookupFunction(decl.name()); if (existing == nullptr) { return AddFunction(decl); @@ -140,17 +142,17 @@ absl::Status TypeCheckerBuilder::MergeFunction(const FunctionDecl& decl) { return absl::OkStatus(); } -void TypeCheckerBuilder::AddTypeProvider( +void TypeCheckerBuilderImpl::AddTypeProvider( std::unique_ptr provider) { env_.AddTypeProvider(std::move(provider)); } -void TypeCheckerBuilder::set_container(absl::string_view container) { +void TypeCheckerBuilderImpl::set_container(absl::string_view container) { env_.set_container(std::string(container)); } -void TypeCheckerBuilder::SetExpectedType(const Type& type) { +void TypeCheckerBuilderImpl::SetExpectedType(const Type& type) { env_.set_expected_type(type); } -} // namespace cel +} // namespace cel::checker_internal diff --git a/checker/internal/type_checker_builder_impl.h b/checker/internal/type_checker_builder_impl.h new file mode 100644 index 000000000..377b9ec20 --- /dev/null +++ b/checker/internal/type_checker_builder_impl.h @@ -0,0 +1,83 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_TYPE_CHECKER_BUILDER_IMPL_H_ +#define THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_TYPE_CHECKER_BUILDER_IMPL_H_ + +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "checker/checker_options.h" +#include "checker/internal/type_check_env.h" +#include "checker/type_checker.h" +#include "checker/type_checker_builder.h" +#include "common/decl.h" +#include "common/type.h" +#include "common/type_introspector.h" +#include "google/protobuf/descriptor.h" + +namespace cel::checker_internal { + +class TypeCheckerBuilderImpl; + +// Builder for TypeChecker instances. +class TypeCheckerBuilderImpl : public TypeCheckerBuilder { + public: + TypeCheckerBuilderImpl( + absl::Nonnull> + descriptor_pool, + const CheckerOptions& options) + : options_(options), env_(std::move(descriptor_pool)) {} + + // Move only. + TypeCheckerBuilderImpl(const TypeCheckerBuilderImpl&) = delete; + TypeCheckerBuilderImpl(TypeCheckerBuilderImpl&&) = default; + TypeCheckerBuilderImpl& operator=(const TypeCheckerBuilderImpl&) = delete; + TypeCheckerBuilderImpl& operator=(TypeCheckerBuilderImpl&&) = default; + + absl::StatusOr> Build() && override; + + absl::Status AddLibrary(CheckerLibrary library) override; + + absl::Status AddVariable(const VariableDecl& decl) override; + absl::Status AddFunction(const FunctionDecl& decl) override; + + void SetExpectedType(const Type& type) override; + + absl::Status MergeFunction(const FunctionDecl& decl) override; + + void AddTypeProvider(std::unique_ptr provider) override; + + void set_container(absl::string_view container) override; + + const CheckerOptions& options() const override { return options_; } + + private: + CheckerOptions options_; + std::vector libraries_; + absl::flat_hash_set library_ids_; + + checker_internal::TypeCheckEnv env_; +}; + +} // namespace cel::checker_internal + +#endif // THIRD_PARTY_CEL_CPP_CHECKER_TYPE_CHECKER_BUILDER_H_ diff --git a/checker/type_checker_builder.h b/checker/type_checker_builder.h index e2994f560..c93c8d1db 100644 --- a/checker/type_checker_builder.h +++ b/checker/type_checker_builder.h @@ -17,26 +17,21 @@ #include #include -#include -#include -#include "absl/base/nullability.h" -#include "absl/container/flat_hash_set.h" #include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "checker/checker_options.h" -#include "checker/internal/type_check_env.h" #include "checker/type_checker.h" #include "common/decl.h" #include "common/type.h" #include "common/type_introspector.h" -#include "google/protobuf/descriptor.h" namespace cel { class TypeCheckerBuilder; +class TypeCheckerBuilderImpl; // Functional implementation to apply the library features to a // TypeCheckerBuilder. @@ -50,54 +45,55 @@ struct CheckerLibrary { TypeCheckerBuilderConfigurer configure; }; -// Builder for TypeChecker instances. +// Interface for TypeCheckerBuilders. class TypeCheckerBuilder { public: - TypeCheckerBuilder(const TypeCheckerBuilder&) = delete; - TypeCheckerBuilder(TypeCheckerBuilder&&) = default; - TypeCheckerBuilder& operator=(const TypeCheckerBuilder&) = delete; - TypeCheckerBuilder& operator=(TypeCheckerBuilder&&) = default; + virtual ~TypeCheckerBuilder() = default; - absl::StatusOr> Build() &&; + // Adds a library to the TypeChecker being built. + virtual absl::Status AddLibrary(CheckerLibrary library) = 0; - absl::Status AddLibrary(CheckerLibrary library); + // Adds a variable declaration that may be referenced in expressions checked + // with the resulting type checker. + virtual absl::Status AddVariable(const VariableDecl& decl) = 0; - absl::Status AddVariable(const VariableDecl& decl); - absl::Status AddFunction(const FunctionDecl& decl); + // Adds a function declaration that may be referenced in expressions checked + // with the resulting TypeChecker. + virtual absl::Status AddFunction(const FunctionDecl& decl) = 0; - void SetExpectedType(const Type& type); + // Sets the expected type for checked expressions. + // + // Validation will fail with an ERROR level issue if the deduced type of the + // expression is not assignable to this type. + virtual void SetExpectedType(const Type& type) = 0; // Adds function declaration overloads to the TypeChecker being built. // // Attempts to merge with any existing overloads for a function decl with the // same name. If the overloads are not compatible, an error is returned and // no change is made. - absl::Status MergeFunction(const FunctionDecl& decl); - - void AddTypeProvider(std::unique_ptr provider); - - void set_container(absl::string_view container); - - const CheckerOptions& options() const { return options_; } + virtual absl::Status MergeFunction(const FunctionDecl& decl) = 0; - private: - friend absl::StatusOr> - CreateTypeCheckerBuilder( - absl::Nonnull> - descriptor_pool, - const CheckerOptions& options); + // Adds a type provider to the TypeChecker being built. + // + // Type providers are used to describe custom types with typed field + // traversal. This is not needed for built-in types or protobuf messages + // described by the associated descriptor pool. + virtual void AddTypeProvider(std::unique_ptr provider) = 0; - TypeCheckerBuilder( - absl::Nonnull> - descriptor_pool, - const CheckerOptions& options) - : options_(options), env_(std::move(descriptor_pool)) {} + // Set the container for the TypeChecker being built. + // + // This is used for resolving references in the expressions being built. + virtual void set_container(absl::string_view container) = 0; - CheckerOptions options_; - std::vector libraries_; - absl::flat_hash_set library_ids_; + // The current options for the TypeChecker being built. + virtual const CheckerOptions& options() const = 0; - checker_internal::TypeCheckEnv env_; + // Builds the TypeChecker. + // + // This operation is destructive: the builder instance should not be used + // after this method is called. + virtual absl::StatusOr> Build() && = 0; }; } // namespace cel diff --git a/checker/type_checker_builder_factory.cc b/checker/type_checker_builder_factory.cc index 78bd89fb1..d06a7e5f7 100644 --- a/checker/type_checker_builder_factory.cc +++ b/checker/type_checker_builder_factory.cc @@ -20,9 +20,9 @@ #include "absl/base/nullability.h" #include "absl/log/absl_check.h" -#include "absl/memory/memory.h" #include "absl/status/statusor.h" #include "checker/checker_options.h" +#include "checker/internal/type_checker_builder_impl.h" #include "checker/type_checker_builder.h" #include "internal/noop_delete.h" #include "internal/status_macros.h" @@ -51,8 +51,8 @@ absl::StatusOr> CreateTypeCheckerBuilder( // `well_known_types::Reflection` at the moment here. CEL_RETURN_IF_ERROR( well_known_types::Reflection().Initialize(descriptor_pool.get())); - auto* builder = new TypeCheckerBuilder(std::move(descriptor_pool), options); - return absl::WrapUnique(builder); + return std::make_unique( + std::move(descriptor_pool), options); } } // namespace cel From d50f3d3ebc2b1cdda23583c697af265d22748ea4 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Wed, 20 Nov 2024 15:16:41 -0800 Subject: [PATCH 050/180] Update type checker to fail (return a status) if it fails to deduce the type of subexpression. PiperOrigin-RevId: 698538634 --- checker/internal/type_checker_impl.cc | 81 +++++++++++++--------- checker/internal/type_checker_impl_test.cc | 26 +++++++ 2 files changed, 76 insertions(+), 31 deletions(-) diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index f6b96916b..f5c8c481b 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -318,7 +318,7 @@ class ResolveVisitor : public AstVisitorBase { int error_count() const { return error_count_; } void AssertExpectedType(const Expr& expr, const Type& expected_type) { - Type observed = GetTypeOrDyn(&expr); + Type observed = GetDeducedType(&expr); if (!inference_context_->IsAssignable(observed, expected_type)) { ReportTypeMismatch(expr.id(), expected_type, observed); } @@ -405,7 +405,7 @@ class ResolveVisitor : public AstVisitorBase { absl::string_view resolved_name) { for (const auto& field : create_struct.fields()) { const Expr* value = &field.value(); - Type value_type = GetTypeOrDyn(value); + Type value_type = GetDeducedType(value); // Lookup message type by name to support WellKnownType creation. CEL_ASSIGN_OR_RETURN( @@ -441,12 +441,22 @@ class ResolveVisitor : public AstVisitorBase { void HandleOptSelect(const Expr& expr); - // TODO: This should switch to a failing check once all core - // features are supported. For now, we allow dyn for implementing the - // typechecker behaviors in isolation. - Type GetTypeOrDyn(const Expr* expr) { + // Get the assigned type of the given subexpression. Should only be called if + // the given subexpression is expected to have already been checked. + // + // If unknown, returns DynType as a placeholder and reports an error. + // Whether or not the subexpression is valid for the checker configuration, + // the type checker should have assigned a type (possibly ErrorType). If there + // is no assigned type, the type checker failed to handle the subexpression + // and should not attempt to continue type checking. + Type GetDeducedType(const Expr* expr) { auto iter = types_.find(expr); - return iter != types_.end() ? iter->second : DynType(); + if (iter != types_.end()) { + return iter->second; + } + status_.Update(absl::InvalidArgumentError( + absl::StrCat("Could not deduce type for expression id: ", expr->id()))); + return DynType(); } absl::string_view container_; @@ -560,6 +570,7 @@ void ResolveVisitor::PostVisitConst(const Expr& expr, ComputeSourceLocation(*ast_, expr.id()), absl::StrCat("unsupported constant type: ", constant.kind().index()))); + types_[&expr] = ErrorType(); break; } } @@ -599,7 +610,7 @@ void ResolveVisitor::PostVisitMap(const Expr& expr, const MapExpr& map) { auto assignability_context = inference_context_->CreateAssignabilityContext(); for (const auto& entry : map.entries()) { const Expr* key = &entry.key(); - Type key_type = GetTypeOrDyn(key); + Type key_type = GetDeducedType(key); if (!IsSupportedKeyType(key_type)) { // The Go type checker implementation can allow any type as a map key, but // per the spec this should be limited to the types listed in @@ -626,7 +637,7 @@ void ResolveVisitor::PostVisitMap(const Expr& expr, const MapExpr& map) { assignability_context.Reset(); for (const auto& entry : map.entries()) { const Expr* value = &entry.value(); - Type value_type = GetTypeOrDyn(value); + Type value_type = GetDeducedType(value); if (entry.optional()) { if (value_type.IsOptional()) { value_type = value_type.GetOptional().GetParameter(); @@ -657,7 +668,7 @@ void ResolveVisitor::PostVisitList(const Expr& expr, const ListExpr& list) { auto assignability_context = inference_context_->CreateAssignabilityContext(); for (const auto& element : list.elements()) { const Expr* value = &element.expr(); - Type value_type = GetTypeOrDyn(value); + Type value_type = GetDeducedType(value); if (element.optional()) { if (value_type.IsOptional()) { value_type = value_type.GetOptional().GetParameter(); @@ -707,6 +718,7 @@ void ResolveVisitor::PostVisitStruct(const Expr& expr, if (resolved_name.empty()) { ReportMissingReference(expr, create_struct.name()); + types_[&expr] = ErrorType(); return; } @@ -716,6 +728,7 @@ void ResolveVisitor::PostVisitStruct(const Expr& expr, ComputeSourceLocation(*ast_, expr.id()), absl::StrCat("type '", resolved_name, "' does not support message creation"))); + types_[&expr] = ErrorType(); return; } @@ -758,13 +771,14 @@ void ResolveVisitor::PostVisitCall(const Expr& expr, const CallExpr& call) { const FunctionDecl* decl = ResolveFunctionCallShape( expr, call.function(), arg_count, call.has_target()); - if (decl != nullptr) { - ResolveFunctionOverloads(expr, *decl, arg_count, call.has_target(), - /* is_namespaced= */ false); + if (decl == nullptr) { + ReportMissingReference(expr, call.function()); + types_[&expr] = ErrorType(); return; } - ReportMissingReference(expr, call.function()); + ResolveFunctionOverloads(expr, *decl, arg_count, call.has_target(), + /* is_namespaced= */ false); } void ResolveVisitor::PreVisitComprehension( @@ -786,7 +800,7 @@ void ResolveVisitor::PreVisitComprehension( void ResolveVisitor::PostVisitComprehension( const Expr& expr, const ComprehensionExpr& comprehension) { comprehension_scopes_.pop_back(); - types_[&expr] = GetTypeOrDyn(&comprehension.result()); + types_[&expr] = GetDeducedType(&comprehension.result()); } void ResolveVisitor::PreVisitComprehensionSubexpression( @@ -839,11 +853,12 @@ void ResolveVisitor::PostVisitComprehensionSubexpression( // the corresponding variables can be referenced. switch (comprehension_arg) { case ComprehensionArg::ACCU_INIT: - scope.accu_scope->InsertVariableIfAbsent(MakeVariableDecl( - comprehension.accu_var(), GetTypeOrDyn(&comprehension.accu_init()))); + scope.accu_scope->InsertVariableIfAbsent( + MakeVariableDecl(comprehension.accu_var(), + GetDeducedType(&comprehension.accu_init()))); break; case ComprehensionArg::ITER_RANGE: { - Type range_type = GetTypeOrDyn(&comprehension.iter_range()); + Type range_type = GetDeducedType(&comprehension.iter_range()); Type iter_type = DynType(); // iter_var for non comprehensions v2. Type iter_type1 = DynType(); // iter_var for comprehensions v2. Type iter_type2 = DynType(); // iter_var2 for comprehensions v2. @@ -879,9 +894,6 @@ void ResolveVisitor::PostVisitComprehensionSubexpression( } break; } - case ComprehensionArg::RESULT: - types_[&expr] = types_[&expr]; - break; default: break; } @@ -923,10 +935,10 @@ void ResolveVisitor::ResolveFunctionOverloads(const Expr& expr, std::vector arg_types; arg_types.reserve(arg_count); if (is_receiver) { - arg_types.push_back(GetTypeOrDyn(&expr.call_expr().target())); + arg_types.push_back(GetDeducedType(&expr.call_expr().target())); } for (int i = 0; i < expr.call_expr().args().size(); ++i) { - arg_types.push_back(GetTypeOrDyn(&expr.call_expr().args()[i])); + arg_types.push_back(GetDeducedType(&expr.call_expr().args()[i])); } absl::optional resolution = @@ -942,6 +954,7 @@ void ResolveVisitor::ResolveFunctionOverloads(const Expr& expr, out->append(type.DebugString()); }), ")'"))); + types_[&expr] = ErrorType(); return; } @@ -1000,6 +1013,7 @@ void ResolveVisitor::ResolveSimpleIdentifier(const Expr& expr, if (decl == nullptr) { ReportMissingReference(expr, name); + types_[&expr] = ErrorType(); return; } @@ -1029,6 +1043,7 @@ void ResolveVisitor::ResolveQualifiedIdentifier( if (decl == nullptr) { ReportMissingReference(expr, FormatCandidate(qualifiers)); + types_[&expr] = ErrorType(); return; } @@ -1106,7 +1121,7 @@ absl::optional ResolveVisitor::CheckFieldType(int64_t id, void ResolveVisitor::ResolveSelectOperation(const Expr& expr, absl::string_view field, const Expr& operand) { - const Type& operand_type = GetTypeOrDyn(&operand); + const Type& operand_type = GetDeducedType(&operand); absl::optional result_type; int64_t id = expr.id(); @@ -1122,12 +1137,15 @@ void ResolveVisitor::ResolveSelectOperation(const Expr& expr, result_type = CheckFieldType(id, operand_type, field); } - if (result_type.has_value()) { - if (expr.select_expr().test_only()) { - types_[&expr] = BoolType(); - } else { - types_[&expr] = *result_type; - } + if (!result_type.has_value()) { + types_[&expr] = ErrorType(); + return; + } + + if (expr.select_expr().test_only()) { + types_[&expr] = BoolType(); + } else { + types_[&expr] = *result_type; } } @@ -1147,7 +1165,7 @@ void ResolveVisitor::HandleOptSelect(const Expr& expr) { return; } - Type operand_type = GetTypeOrDyn(operand); + Type operand_type = GetDeducedType(operand); if (operand_type.IsOptional()) { operand_type = operand_type.GetOptional().GetParameter(); } @@ -1155,6 +1173,7 @@ void ResolveVisitor::HandleOptSelect(const Expr& expr) { absl::optional field_type = CheckFieldType( expr.id(), operand_type, field->const_expr().string_value()); if (!field_type.has_value()) { + types_[&expr] = ErrorType(); return; } const FunctionDecl* select_decl = env_->LookupFunction(kOptionalSelect); diff --git a/checker/internal/type_checker_impl_test.cc b/checker/internal/type_checker_impl_test.cc index 9c95fbf22..2c429eeaa 100644 --- a/checker/internal/type_checker_impl_test.cc +++ b/checker/internal/type_checker_impl_test.cc @@ -1348,6 +1348,32 @@ TEST(TypeCheckerImplTest, BadSourcePosition) { "ERROR: :-1:-1: undeclared reference to 'foo' (in container '')"); } +// Check that the TypeChecker will fail if no type is deduced for a +// subexpression. This is meant to be a guard against failing to account for new +// types of expressions in the type checker logic. +TEST(TypeCheckerImplTest, FailsIfNoTypeDeduced) { + google::protobuf::Arena arena; + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + env.InsertVariableIfAbsent(MakeVariableDecl("a", BoolType())); + env.InsertVariableIfAbsent(MakeVariableDecl("b", BoolType())); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("a || b")); + auto& ast_impl = AstImpl::CastFromPublicAst(*ast); + + // Assume that an unspecified expr kind is not deducible. + Expr unspecified_expr; + unspecified_expr.set_id(3); + ast_impl.root_expr().mutable_call_expr().mutable_args()[1] = + std::move(unspecified_expr); + + ASSERT_THAT(impl.Check(std::move(ast)), + StatusIs(absl::StatusCode::kInvalidArgument, + "Could not deduce type for expression id: 3")); +} + TEST(TypeCheckerImplTest, BadLineOffsets) { google::protobuf::Arena arena; TypeCheckEnv env(GetSharedTestingDescriptorPool()); From 241c9dd69c77f621918066d9ff2dba7fda61c3f6 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 21 Nov 2024 12:21:36 -0800 Subject: [PATCH 051/180] Improve error messages for Value type casting errors. PiperOrigin-RevId: 698874247 --- eval/public/structs/proto_message_type_adapter.cc | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/eval/public/structs/proto_message_type_adapter.cc b/eval/public/structs/proto_message_type_adapter.cc index 8e703ae3a..06f129226 100644 --- a/eval/public/structs/proto_message_type_adapter.cc +++ b/eval/public/structs/proto_message_type_adapter.cc @@ -561,7 +561,9 @@ absl::Status ProtoMessageTypeAdapter::SetField( const CelMap* cel_map; CEL_RETURN_IF_ERROR(ValidateSetFieldOp( value.GetValue(&cel_map) && cel_map != nullptr, - field->name(), "value is not CelMap")); + field->name(), + absl::StrCat("value is not CelMap - value is ", + CelValue::TypeName(value.type())))); auto entry_descriptor = field->message_type(); @@ -598,7 +600,9 @@ absl::Status ProtoMessageTypeAdapter::SetField( const CelList* cel_list; CEL_RETURN_IF_ERROR(ValidateSetFieldOp( value.GetValue(&cel_list) && cel_list != nullptr, - field->name(), "expected CelList value")); + field->name(), + absl::StrCat("expected CelList value - value is", + CelValue::TypeName(value.type())))); for (int i = 0; i < cel_list->size(); i++) { CEL_RETURN_IF_ERROR(internal::AddValueToRepeatedField( From c39a717d0413832d5f5940c5ebd1d4cbb538edee Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Fri, 22 Nov 2024 10:47:12 -0800 Subject: [PATCH 052/180] Checker support for declaring context messages. Add support for declaring a context message type to the C++ type checker. The checker considers the top level fields of the type as variables in the type check environment. PiperOrigin-RevId: 699221362 --- checker/BUILD | 7 - checker/internal/BUILD | 24 +- checker/internal/type_checker_builder_impl.cc | 63 ++++++ checker/internal/type_checker_builder_impl.h | 5 + .../type_checker_builder_impl_test.cc | 212 ++++++++++++++++++ checker/type_checker_builder.h | 12 + checker/type_checker_builder_factory_test.cc | 21 ++ 7 files changed, 336 insertions(+), 8 deletions(-) create mode 100644 checker/internal/type_checker_builder_impl_test.cc diff --git a/checker/BUILD b/checker/BUILD index 7626284a6..a8ebbb653 100644 --- a/checker/BUILD +++ b/checker/BUILD @@ -88,17 +88,10 @@ cc_library( ":type_checker", "//common:decl", "//common:type", - "//internal:status_macros", - "//parser:macro", - "@com_google_absl//absl/base:no_destructor", - "@com_google_absl//absl/base:nullability", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_protobuf//:protobuf", ], ) diff --git a/checker/internal/BUILD b/checker/internal/BUILD index aec0701bb..336106073 100644 --- a/checker/internal/BUILD +++ b/checker/internal/BUILD @@ -139,11 +139,11 @@ cc_library( "//common:type_kind", "//internal:status_macros", "//parser:macro", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -189,6 +189,28 @@ cc_test( ], ) +cc_test( + name = "type_checker_builder_impl_test", + srcs = ["type_checker_builder_impl_test.cc"], + deps = [ + ":test_ast_helpers", + ":type_checker_impl", + "//base/ast_internal:ast_impl", + "//base/ast_internal:expr", + "//checker:type_checker", + "//checker:validation_result", + "//common:decl", + "//common:type", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:optional", + ], +) + cc_library( name = "type_inference_context", srcs = ["type_inference_context.cc"], diff --git a/checker/internal/type_checker_builder_impl.cc b/checker/internal/type_checker_builder_impl.cc index 6f0345290..4897205a4 100644 --- a/checker/internal/type_checker_builder_impl.cc +++ b/checker/internal/type_checker_builder_impl.cc @@ -19,12 +19,15 @@ #include #include +#include "absl/algorithm/container.h" #include "absl/base/no_destructor.h" +#include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "absl/types/optional.h" #include "checker/internal/type_check_env.h" #include "checker/internal/type_checker_impl.h" #include "checker/type_checker.h" @@ -34,6 +37,7 @@ #include "common/type_introspector.h" #include "internal/status_macros.h" #include "parser/macro.h" +#include "google/protobuf/descriptor.h" namespace cel::checker_internal { namespace { @@ -78,8 +82,34 @@ absl::Status CheckStdMacroOverlap(const FunctionDecl& decl) { } // namespace +absl::Status TypeCheckerBuilderImpl::AddContextDeclarationVariables( + absl::Nonnull descriptor) { + for (int i = 0; i < descriptor->field_count(); i++) { + const google::protobuf::FieldDescriptor* proto_field = descriptor->field(i); + MessageTypeField cel_field(proto_field); + cel_field.name(); + Type field_type = cel_field.GetType(); + if (field_type.IsEnum()) { + field_type = IntType(); + } + if (!env_.InsertVariableIfAbsent( + MakeVariableDecl(std::string(cel_field.name()), field_type))) { + return absl::AlreadyExistsError( + absl::StrCat("variable '", cel_field.name(), + "' already exists (from context declaration: '", + descriptor->full_name(), "')")); + } + } + + return absl::OkStatus(); +} + absl::StatusOr> TypeCheckerBuilderImpl::Build() && { + for (const auto* type : context_types_) { + CEL_RETURN_IF_ERROR(AddContextDeclarationVariables(type)); + } + auto checker = std::make_unique( std::move(env_), options_); return checker; @@ -108,6 +138,39 @@ absl::Status TypeCheckerBuilderImpl::AddVariable(const VariableDecl& decl) { return absl::OkStatus(); } +absl::Status TypeCheckerBuilderImpl::AddContextDeclaration( + absl::string_view type) { + CEL_ASSIGN_OR_RETURN(absl::optional resolved_type, + env_.LookupTypeName(type)); + + if (!resolved_type.has_value()) { + return absl::NotFoundError( + absl::StrCat("context declaration '", type, "' not found")); + } + + if (!resolved_type->IsStruct()) { + return absl::InvalidArgumentError( + absl::StrCat("context declaration '", type, "' is not a struct")); + } + + if (!resolved_type->AsStruct()->IsMessage()) { + return absl::InvalidArgumentError( + absl::StrCat("context declaration '", type, + "' is not protobuf message backed struct")); + } + + const google::protobuf::Descriptor* descriptor = + &(**(resolved_type->AsStruct()->AsMessage())); + + if (absl::c_linear_search(context_types_, descriptor)) { + return absl::AlreadyExistsError( + absl::StrCat("context declaration '", type, "' already exists")); + } + + context_types_.push_back(descriptor); + return absl::OkStatus(); +} + absl::Status TypeCheckerBuilderImpl::AddFunction(const FunctionDecl& decl) { CEL_RETURN_IF_ERROR(CheckStdMacroOverlap(decl)); bool inserted = env_.InsertFunctionIfAbsent(decl); diff --git a/checker/internal/type_checker_builder_impl.h b/checker/internal/type_checker_builder_impl.h index 377b9ec20..c9028f90b 100644 --- a/checker/internal/type_checker_builder_impl.h +++ b/checker/internal/type_checker_builder_impl.h @@ -58,6 +58,7 @@ class TypeCheckerBuilderImpl : public TypeCheckerBuilder { absl::Status AddLibrary(CheckerLibrary library) override; absl::Status AddVariable(const VariableDecl& decl) override; + absl::Status AddContextDeclaration(absl::string_view type) override; absl::Status AddFunction(const FunctionDecl& decl) override; void SetExpectedType(const Type& type) override; @@ -71,9 +72,13 @@ class TypeCheckerBuilderImpl : public TypeCheckerBuilder { const CheckerOptions& options() const override { return options_; } private: + absl::Status AddContextDeclarationVariables( + absl::Nonnull descriptor); + CheckerOptions options_; std::vector libraries_; absl::flat_hash_set library_ids_; + std::vector> context_types_; checker_internal::TypeCheckEnv env_; }; diff --git a/checker/internal/type_checker_builder_impl_test.cc b/checker/internal/type_checker_builder_impl_test.cc new file mode 100644 index 000000000..5091a8843 --- /dev/null +++ b/checker/internal/type_checker_builder_impl_test.cc @@ -0,0 +1,212 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/type_checker_builder_impl.h" + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "base/ast_internal/ast_impl.h" +#include "base/ast_internal/expr.h" +#include "checker/internal/test_ast_helpers.h" +#include "checker/type_checker.h" +#include "checker/validation_result.h" +#include "common/decl.h" +#include "common/type.h" +#include "common/type_introspector.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" + +namespace cel::checker_internal { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::cel::ast_internal::AstImpl; + +using AstType = cel::ast_internal::Type; + +struct ContextDeclsTestCase { + std::string expr; + AstType expected_type; +}; + +class ContextDeclsFieldsDefinedTest + : public testing::TestWithParam {}; + +TEST_P(ContextDeclsFieldsDefinedTest, ContextDeclsFieldsDefined) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT( + builder.AddContextDeclaration("cel.expr.conformance.proto3.TestAllTypes"), + IsOk()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + std::move(builder).Build()); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst(GetParam().expr)); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()); + + const auto& ast_impl = AstImpl::CastFromPublicAst(*result.GetAst()); + + EXPECT_EQ(ast_impl.GetReturnType(), GetParam().expected_type); +} + +INSTANTIATE_TEST_SUITE_P( + TestAllTypes, ContextDeclsFieldsDefinedTest, + testing::Values( + ContextDeclsTestCase{"single_int64", + AstType(ast_internal::PrimitiveType::kInt64)}, + ContextDeclsTestCase{"single_uint32", + AstType(ast_internal::PrimitiveType::kUint64)}, + ContextDeclsTestCase{"single_double", + AstType(ast_internal::PrimitiveType::kDouble)}, + ContextDeclsTestCase{"single_string", + AstType(ast_internal::PrimitiveType::kString)}, + ContextDeclsTestCase{"single_any", + AstType(ast_internal::WellKnownType::kAny)}, + ContextDeclsTestCase{"single_duration", + AstType(ast_internal::WellKnownType::kDuration)}, + ContextDeclsTestCase{"single_bool_wrapper", + AstType(ast_internal::PrimitiveTypeWrapper( + ast_internal::PrimitiveType::kBool))}, + ContextDeclsTestCase{ + "list_value", + AstType(ast_internal::ListType( + std::make_unique(ast_internal::DynamicType())))}, + ContextDeclsTestCase{ + "standalone_message", + AstType(ast_internal::MessageType( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage"))}, + ContextDeclsTestCase{"standalone_enum", + AstType(ast_internal::PrimitiveType::kInt64)}, + ContextDeclsTestCase{ + "repeated_bytes", + AstType(ast_internal::ListType(std::make_unique( + ast_internal::PrimitiveType::kBytes)))}, + ContextDeclsTestCase{ + "repeated_nested_message", + AstType(ast_internal::ListType(std::make_unique< + AstType>(ast_internal::MessageType( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage"))))}, + ContextDeclsTestCase{ + "map_int32_timestamp", + AstType(ast_internal::MapType( + std::make_unique(ast_internal::PrimitiveType::kInt64), + std::make_unique( + ast_internal::WellKnownType::kTimestamp)))}, + ContextDeclsTestCase{ + "single_struct", + AstType(ast_internal::MapType( + std::make_unique(ast_internal::PrimitiveType::kString), + std::make_unique(ast_internal::DynamicType())))})); + +TEST(ContextDeclsTest, ErrorOnDuplicateContextDeclaration) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT( + builder.AddContextDeclaration("cel.expr.conformance.proto3.TestAllTypes"), + IsOk()); + EXPECT_THAT( + builder.AddContextDeclaration("cel.expr.conformance.proto3.TestAllTypes"), + StatusIs(absl::StatusCode::kAlreadyExists, + "context declaration 'cel.expr.conformance.proto3.TestAllTypes' " + "already exists")); +} + +TEST(ContextDeclsTest, ErrorOnContextDeclarationNotFound) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + EXPECT_THAT( + builder.AddContextDeclaration("com.example.UnknownType"), + StatusIs(absl::StatusCode::kNotFound, + "context declaration 'com.example.UnknownType' not found")); +} + +TEST(ContextDeclsTest, ErrorOnNonStructMessageType) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + EXPECT_THAT( + builder.AddContextDeclaration("google.protobuf.Timestamp"), + StatusIs( + absl::StatusCode::kInvalidArgument, + "context declaration 'google.protobuf.Timestamp' is not a struct")); +} + +TEST(ContextDeclsTest, CustomStructNotSupported) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + class MyTypeProvider : public cel::TypeIntrospector { + public: + absl::StatusOr> FindTypeImpl( + absl::string_view name) const override { + if (name == "com.example.MyStruct") { + return common_internal::MakeBasicStructType("com.example.MyStruct"); + } + return absl::nullopt; + } + }; + + builder.AddTypeProvider(std::make_unique()); + + EXPECT_THAT(builder.AddContextDeclaration("com.example.MyStruct"), + StatusIs(absl::StatusCode::kInvalidArgument, + "context declaration 'com.example.MyStruct' is not " + "protobuf message backed struct")); +} + +TEST(ContextDeclsTest, ErrorOnOverlappingContextDeclaration) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT( + builder.AddContextDeclaration("cel.expr.conformance.proto3.TestAllTypes"), + IsOk()); + // We resolve the context declaration variables at the Build() call, so the + // error surfaces then. + ASSERT_THAT( + builder.AddContextDeclaration("cel.expr.conformance.proto2.TestAllTypes"), + IsOk()); + + EXPECT_THAT( + std::move(builder).Build(), + StatusIs(absl::StatusCode::kAlreadyExists, + "variable 'single_int32' already exists (from context " + "declaration: 'cel.expr.conformance.proto2.TestAllTypes')")); +} + +TEST(ContextDeclsTest, ErrorOnOverlappingVariableDeclaration) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT( + builder.AddContextDeclaration("cel.expr.conformance.proto3.TestAllTypes"), + IsOk()); + ASSERT_THAT(builder.AddVariable(MakeVariableDecl("single_int64", IntType())), + IsOk()); + + EXPECT_THAT( + std::move(builder).Build(), + StatusIs(absl::StatusCode::kAlreadyExists, + "variable 'single_int64' already exists (from context " + "declaration: 'cel.expr.conformance.proto3.TestAllTypes')")); +} + +} // namespace +} // namespace cel::checker_internal diff --git a/checker/type_checker_builder.h b/checker/type_checker_builder.h index c93c8d1db..21f3c35a5 100644 --- a/checker/type_checker_builder.h +++ b/checker/type_checker_builder.h @@ -57,6 +57,18 @@ class TypeCheckerBuilder { // with the resulting type checker. virtual absl::Status AddVariable(const VariableDecl& decl) = 0; + // Declares struct type by fully qualified name as a context declaration. + // + // Context declarations are a way to declare a group of variables based on the + // definition of a struct type. Each top level field of the struct is declared + // as an individual variable of the field type. + // + // It is an error if the type contains a field that overlaps with another + // declared variable. + // + // Note: only protobuf backed struct types are supported at this time. + virtual absl::Status AddContextDeclaration(absl::string_view type) = 0; + // Adds a function declaration that may be referenced in expressions checked // with the resulting TypeChecker. virtual absl::Status AddFunction(const FunctionDecl& decl) = 0; diff --git a/checker/type_checker_builder_factory_test.cc b/checker/type_checker_builder_factory_test.cc index fde2106a0..79a5ec0e4 100644 --- a/checker/type_checker_builder_factory_test.cc +++ b/checker/type_checker_builder_factory_test.cc @@ -113,6 +113,27 @@ TEST(TypeCheckerBuilderTest, AddLibrary) { EXPECT_TRUE(result.IsValid()); } +TEST(TypeCheckerBuilderTest, AddContextDeclaration) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); + + ASSERT_OK_AND_ASSIGN( + auto fn_decl, + MakeFunctionDecl("increment", MakeOverloadDecl("increment_int", IntType(), + IntType()))); + + ASSERT_THAT(builder->AddContextDeclaration( + "cel.expr.conformance.proto3.TestAllTypes"), + IsOk()); + ASSERT_THAT(builder->AddFunction(fn_decl), IsOk()); + + ASSERT_OK_AND_ASSIGN(auto checker, std::move(*builder).Build()); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("increment(single_int64)")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, checker->Check(std::move(ast))); + EXPECT_TRUE(result.IsValid()); +} + TEST(TypeCheckerBuilderTest, AddLibraryRedeclaredError) { ASSERT_OK_AND_ASSIGN( std::unique_ptr builder, From 30f440412ed23557af00da70a7e54590a093b06e Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Mon, 2 Dec 2024 11:20:54 -0800 Subject: [PATCH 053/180] internal codelab updates PiperOrigin-RevId: 702029607 --- codelab/solutions/exercise4.cc | 81 ++++++++++++++++------------------ 1 file changed, 37 insertions(+), 44 deletions(-) diff --git a/codelab/solutions/exercise4.cc b/codelab/solutions/exercise4.cc index c07bc3413..924393b1c 100644 --- a/codelab/solutions/exercise4.cc +++ b/codelab/solutions/exercise4.cc @@ -15,12 +15,14 @@ #include #include +#include #include "cel/expr/checked.pb.h" -#include "google/protobuf/text_format.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "absl/types/optional.h" #include "codelab/cel_compiler.h" #include "eval/public/activation.h" #include "eval/public/activation_bind_helper.h" @@ -31,6 +33,8 @@ #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" #include "internal/status_macros.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/text_format.h" namespace google::api::expr::codelab { namespace { @@ -44,7 +48,6 @@ using ::google::api::expr::runtime::CreateCelExpressionBuilder; using ::google::api::expr::runtime::FunctionAdapter; using ::google::api::expr::runtime::InterpreterOptions; using ::google::api::expr::runtime::RegisterBuiltinFunctions; - using ::google::rpc::context::AttributeContext; absl::StatusOr ContainsExtensionFunction( @@ -59,46 +62,35 @@ absl::StatusOr ContainsExtensionFunction( return false; } -class Compiler { - public: - explicit Compiler(std::unique_ptr compiler) - : compiler_(std::move(compiler)) {} - - absl::Status SetupCheckerEnvironment() { - // Codelab part 1: - // Add a declaration for the map.contains(string, string) function. - Decl decl; - if (!google::protobuf::TextFormat::ParseFromString( - R"pb( - name: "contains" - function { - overloads { - overload_id: "map_contains_string_string" - result_type { primitive: BOOL } - is_instance_function: true - params { - map_type { - key_type { primitive: STRING } - value_type { dyn {} } - } - } - params { primitive: STRING } - params { primitive: STRING } - } - })pb", - &decl)) { - return absl::InternalError("Failed to setup type check environment."); - } - return compiler_->AddDeclaration(std::move(decl)); - } - - absl::StatusOr Compile(absl::string_view expr) { - return compiler_->Compile(expr); +absl::StatusOr> MakeConfiguredCompiler() { + std::vector declarations; + // Codelab part 1: + // Add a declaration for the map.contains(string, string) function. + bool success = google::protobuf::TextFormat::ParseFromString( + R"pb( + name: "contains" + function { + overloads { + overload_id: "map_contains_string_string" + result_type { primitive: BOOL } + is_instance_function: true + params { + map_type { + key_type { primitive: STRING } + value_type { dyn {} } + } + } + params { primitive: STRING } + params { primitive: STRING } + } + })pb", + &declarations.emplace_back()); + if (!success) { + return absl::InternalError( + "Failed to parse Decl textproto in type check environment setup."); } - - private: - std::unique_ptr compiler_; -}; + return CreateCodelabCompiler(declarations); +} class Evaluator { public: @@ -147,9 +139,10 @@ class Evaluator { absl::StatusOr EvaluateWithExtensionFunction( absl::string_view expr, const AttributeContext& context) { // Prepare a checked expression. - Compiler compiler(GetDefaultCompiler()); - CEL_RETURN_IF_ERROR(compiler.SetupCheckerEnvironment()); - CEL_ASSIGN_OR_RETURN(auto checked_expr, compiler.Compile(expr)); + // Prepare a checked expression. + CEL_ASSIGN_OR_RETURN(std::unique_ptr compiler, + MakeConfiguredCompiler()); + CEL_ASSIGN_OR_RETURN(auto checked_expr, compiler->Compile(expr)); // Prepare an evaluation environment. Evaluator evaluator; From 26c810aae6b6fd14ec3eb1abb0f7627e6e1b65ea Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 5 Dec 2024 02:04:10 -0800 Subject: [PATCH 054/180] No public description PiperOrigin-RevId: 703031112 --- eval/eval/create_list_step.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eval/eval/create_list_step.cc b/eval/eval/create_list_step.cc index e1895ad82..3636ab8b8 100644 --- a/eval/eval/create_list_step.cc +++ b/eval/eval/create_list_step.cc @@ -99,7 +99,7 @@ absl::Status CreateListStep::Evaluate(ExecutionFrame* frame) const { .NativeValue(); } } else { - CEL_RETURN_IF_ERROR(builder->Add(std::move(arg))); + CEL_RETURN_IF_ERROR(builder->Add(arg)); } } From 16a71c5e5c8fc13b8228be883a7e0b25deaff8d6 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Thu, 5 Dec 2024 10:18:24 -0800 Subject: [PATCH 055/180] Add support for type checker type deduction conformance tests. PiperOrigin-RevId: 703160475 --- bazel/deps.bzl | 4 ++-- conformance/BUILD | 9 +++---- conformance/run.cc | 60 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 67 insertions(+), 6 deletions(-) diff --git a/bazel/deps.bzl b/bazel/deps.bzl index 51eb3e9d6..1d428f207 100644 --- a/bazel/deps.bzl +++ b/bazel/deps.bzl @@ -142,10 +142,10 @@ def cel_spec_deps(): url = "https://github.com/bazelbuild/rules_python/releases/download/0.33.2/rules_python-0.33.2.tar.gz", ) - CEL_SPEC_GIT_SHA = "373994d7e20e582fce56767b01ac5039524cddab" # Oct 23, 2024 + CEL_SPEC_GIT_SHA = "afa18f9bd5a83f5960ca06c1f9faea406ab34ccc" # Dec 2, 2024 http_archive( name = "com_google_cel_spec", - sha256 = "b498a768140fc0ed0314eef8b2519a48287661d09ca15b17c8ca34088af6aac3", + sha256 = "19b4084ba33cc8da7a640d999e46731efbec585ad2995951dc61a7af24f059cb", strip_prefix = "cel-spec-" + CEL_SPEC_GIT_SHA, urls = ["https://github.com/google/cel-spec/archive/" + CEL_SPEC_GIT_SHA + ".zip"], ) diff --git a/conformance/BUILD b/conformance/BUILD index 9bd9fb722..b128e7424 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -167,6 +167,7 @@ _ALL_TESTS = [ "@com_google_cel_spec//tests/simple:testdata/unknowns.textproto", "@com_google_cel_spec//tests/simple:testdata/wrappers.textproto", "@com_google_cel_spec//tests/simple:testdata/block_ext.textproto", + "@com_google_cel_spec//tests/simple:testdata/type_deduction.textproto", ] _TESTS_TO_SKIP_MODERN = [ @@ -301,14 +302,14 @@ gen_conformance_tests( name = "conformance_parse_only", data = _ALL_TESTS, modern = True, - skip_tests = _TESTS_TO_SKIP_MODERN, + skip_tests = _TESTS_TO_SKIP_MODERN + ["type_deductions"], ) gen_conformance_tests( name = "conformance_legacy_parse_only", data = _ALL_TESTS, modern = False, - skip_tests = _TESTS_TO_SKIP_LEGACY, + skip_tests = _TESTS_TO_SKIP_LEGACY + ["type_deductions"], ) gen_conformance_tests( @@ -330,7 +331,7 @@ gen_conformance_tests( dashboard = True, data = _ALL_TESTS, modern = True, - skip_tests = _TESTS_TO_SKIP_MODERN_DASHBOARD, + skip_tests = _TESTS_TO_SKIP_MODERN_DASHBOARD + ["type_deductions"], tags = [ "guitar", "notap", @@ -355,7 +356,7 @@ gen_conformance_tests( dashboard = True, data = _ALL_TESTS, modern = False, - skip_tests = _TESTS_TO_SKIP_LEGACY_DASHBOARD, + skip_tests = _TESTS_TO_SKIP_LEGACY_DASHBOARD + ["type_deductions"], tags = [ "guitar", "notap", diff --git a/conformance/run.cc b/conformance/run.cc index c76569f8c..d61e24fb9 100644 --- a/conformance/run.cc +++ b/conformance/run.cc @@ -19,6 +19,7 @@ // conformance tests; as well as integrating better with C++ testing // infrastructure. +#include #include #include #include @@ -32,6 +33,7 @@ #include "cel/expr/eval.pb.h" #include "google/api/expr/v1alpha1/checked.pb.h" // IWYU pragma: keep #include "google/api/expr/v1alpha1/eval.pb.h" +#include "google/api/expr/v1alpha1/syntax.pb.h" // IWYU pragma: keep #include "google/api/expr/v1alpha1/value.pb.h" #include "cel/expr/value.pb.h" #include "google/rpc/code.pb.h" @@ -126,6 +128,43 @@ MATCHER_P(MatchesConformanceValue, expected, "") { return false; } +MATCHER_P(ResultTypeMatches, expected, "") { + static auto* kDifferencer = []() { + auto* differencer = new MessageDifferencer(); + differencer->set_message_field_comparison(MessageDifferencer::EQUIVALENT); + return differencer; + }(); + + const cel::expr::Type& want = expected; + const google::api::expr::v1alpha1::CheckedExpr& checked_expr = arg; + + int64_t root_id = checked_expr.expr().id(); + auto it = checked_expr.type_map().find(root_id); + + if (it == checked_expr.type_map().end()) { + (*result_listener) << "type map does not contain root id: " << root_id; + return false; + } + + auto got_versioned = it->second; + std::string serialized; + cel::expr::Type got; + if (!got_versioned.SerializeToString(&serialized) || + !got.ParseFromString(serialized)) { + (*result_listener) << "type cannot be converted from versioned type: " + << DescribeMessage(got_versioned); + return false; + } + + if (kDifferencer->Compare(got, want)) { + return true; + } + (*result_listener) << "got: " << DescribeMessage(got); + (*result_listener) << "\n"; + (*result_listener) << "wanted: " << DescribeMessage(want); + return false; +} + bool ShouldSkipTest(absl::Span tests_to_skip, absl::string_view name) { for (absl::string_view test_to_skip : tests_to_skip) { @@ -202,6 +241,14 @@ class ConformanceTest : public testing::Test { check_response.release_checked_expr()); } + if (test_.check_only()) { + ASSERT_TRUE(test_.has_typed_result()) + << "test must specify a typed result if check_only is set"; + EXPECT_THAT(eval_request.checked_expr(), + ResultTypeMatches(test_.typed_result().deduced_type())); + return; + } + EvalResponse eval_response; if (auto status = service_->Eval(eval_request, eval_response); !status.ok()) { @@ -219,6 +266,19 @@ class ConformanceTest : public testing::Test { EXPECT_THAT(test_value, MatchesConformanceValue(test_.value())); break; } + case SimpleTest::kTypedResult: { + ASSERT_TRUE(eval_request.has_checked_expr()) + << "expression was not type checked"; + absl::Cord serialized; + ABSL_CHECK(eval_response.result().SerializePartialToCord(&serialized)); + cel::expr::ExprValue test_value; + ABSL_CHECK(test_value.ParsePartialFromCord(serialized)); + EXPECT_THAT(test_value, + MatchesConformanceValue(test_.typed_result().result())); + EXPECT_THAT(eval_request.checked_expr(), + ResultTypeMatches(test_.typed_result().deduced_type())); + break; + } case SimpleTest::kEvalError: EXPECT_TRUE(eval_response.result().has_error()) << eval_response.result(); From fe224769756e035c11e0cd213d699d175d702e43 Mon Sep 17 00:00:00 2001 From: Justin King Date: Mon, 9 Dec 2024 11:15:13 -0800 Subject: [PATCH 056/180] Initial minimal implementation of Comprehensions V2 PiperOrigin-RevId: 704353242 --- common/expr_factory.h | 22 + common/values/error_value.h | 9 + conformance/BUILD | 2 + conformance/service.cc | 8 + eval/compiler/flat_expr_builder.cc | 132 ++++- eval/eval/BUILD | 2 + eval/eval/comprehension_step.cc | 480 ++++++++++++++++- eval/eval/comprehension_step.h | 35 +- eval/eval/comprehension_step_test.cc | 16 +- eval/eval/create_map_step.cc | 35 ++ eval/eval/create_map_step.h | 14 + extensions/BUILD | 85 +++ extensions/comprehensions_v2_functions.cc | 85 +++ extensions/comprehensions_v2_functions.h | 35 ++ .../comprehensions_v2_functions_test.cc | 222 ++++++++ extensions/comprehensions_v2_macros.cc | 488 ++++++++++++++++++ extensions/comprehensions_v2_macros.h | 30 ++ extensions/comprehensions_v2_macros_test.cc | 230 +++++++++ parser/macro_expr_factory.h | 21 + 19 files changed, 1907 insertions(+), 44 deletions(-) create mode 100644 extensions/comprehensions_v2_functions.cc create mode 100644 extensions/comprehensions_v2_functions.h create mode 100644 extensions/comprehensions_v2_functions_test.cc create mode 100644 extensions/comprehensions_v2_macros.cc create mode 100644 extensions/comprehensions_v2_macros.h create mode 100644 extensions/comprehensions_v2_macros_test.cc diff --git a/common/expr_factory.h b/common/expr_factory.h index fd483bc5e..dd8e6ed25 100644 --- a/common/expr_factory.h +++ b/common/expr_factory.h @@ -317,10 +317,32 @@ class ExprFactory { AccuVar accu_var, AccuInit accu_init, LoopCondition loop_condition, LoopStep loop_step, Result result) { + return NewComprehension(id, std::move(iter_var), "", std::move(iter_range), + std::move(accu_var), std::move(accu_init), + std::move(loop_condition), std::move(loop_step), + std::move(result)); + } + + template ::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>> + Expr NewComprehension(ExprId id, IterVar iter_var, IterVar2 iter_var2, + IterRange iter_range, AccuVar accu_var, + AccuInit accu_init, LoopCondition loop_condition, + LoopStep loop_step, Result result) { Expr expr; expr.set_id(id); auto& comprehension_expr = expr.mutable_comprehension_expr(); comprehension_expr.set_iter_var(std::move(iter_var)); + comprehension_expr.set_iter_var2(std::move(iter_var2)); comprehension_expr.set_iter_range(std::move(iter_range)); comprehension_expr.set_accu_var(std::move(accu_var)); comprehension_expr.set_accu_init(std::move(accu_init)); diff --git a/common/values/error_value.h b/common/values/error_value.h index 577675776..02380d575 100644 --- a/common/values/error_value.h +++ b/common/values/error_value.h @@ -156,6 +156,15 @@ bool IsNoSuchField(const ErrorValue& value); bool IsNoSuchKey(const ErrorValue& value); +class ErrorValueReturn final { + public: + ErrorValueReturn() = default; + + ErrorValue operator()(absl::Status status) const { + return ErrorValue(std::move(status)); + } +}; + } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_ERROR_VALUE_H_ diff --git a/conformance/BUILD b/conformance/BUILD index b128e7424..698a27702 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -73,6 +73,8 @@ cc_library( "//eval/public:cel_value", "//eval/public:transform_utility", "//extensions:bindings_ext", + "//extensions:comprehensions_v2_functions", + "//extensions:comprehensions_v2_macros", "//extensions:encoders", "//extensions:math_ext", "//extensions:math_ext_decls", diff --git a/conformance/service.cc b/conformance/service.cc index aab66e58b..3250834a0 100644 --- a/conformance/service.cc +++ b/conformance/service.cc @@ -61,6 +61,8 @@ #include "eval/public/cel_value.h" #include "eval/public/transform_utility.h" #include "extensions/bindings_ext.h" +#include "extensions/comprehensions_v2_functions.h" +#include "extensions/comprehensions_v2_macros.h" #include "extensions/encoders.h" #include "extensions/math_ext.h" #include "extensions/math_ext_decls.h" @@ -256,6 +258,8 @@ absl::Status LegacyParse(const conformance::v1alpha1::ParseRequest& request, options.enable_optional_syntax = enable_optional_syntax; cel::MacroRegistry macros; CEL_RETURN_IF_ERROR(cel::RegisterStandardMacros(macros, options)); + CEL_RETURN_IF_ERROR( + cel::extensions::RegisterComprehensionsV2Macros(macros, options)); CEL_RETURN_IF_ERROR(cel::extensions::RegisterBindingsMacros(macros, options)); CEL_RETURN_IF_ERROR(cel::extensions::RegisterMathMacros(macros, options)); CEL_RETURN_IF_ERROR(cel::extensions::RegisterProtoMacros(macros, options)); @@ -334,6 +338,8 @@ class LegacyConformanceServiceImpl : public ConformanceServiceInterface { cel::expr::conformance::proto3::TestAllTypes::NestedEnum_descriptor()); CEL_RETURN_IF_ERROR( RegisterBuiltinFunctions(builder->GetRegistry(), options)); + CEL_RETURN_IF_ERROR(cel::extensions::RegisterComprehensionsV2Functions( + builder->GetRegistry(), options)); CEL_RETURN_IF_ERROR(cel::extensions::RegisterEncodersFunctions( builder->GetRegistry(), options)); CEL_RETURN_IF_ERROR(cel::extensions::RegisterStringsFunctions( @@ -507,6 +513,8 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { type_registry, cel::expr::conformance::proto3::TestAllTypes::NestedEnum_descriptor())); + CEL_RETURN_IF_ERROR(cel::extensions::RegisterComprehensionsV2Functions( + builder.function_registry(), options)); CEL_RETURN_IF_ERROR(cel::extensions::EnableOptionalTypes(builder)); CEL_RETURN_IF_ERROR(cel::extensions::RegisterEncodersFunctions( builder.function_registry(), options)); diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index 84e23f914..c687db003 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -327,12 +327,48 @@ const cel::ast_internal::Expr* GetOptimizableListAppendOperand( return &GetOptimizableListAppendCall(comprehension)->args()[1]; } +// Returns whether this comprehension appears to be a macro implementation for +// map transformations. It is not exhaustive, so it is unsafe to use with custom +// comprehensions outside of the standard macros or hand crafted ASTs. +bool IsOptimizableMapInsert( + const cel::ast_internal::Comprehension* comprehension) { + if (comprehension->iter_var().empty() || comprehension->iter_var2().empty()) { + return false; + } + absl::string_view accu_var = comprehension->accu_var(); + if (accu_var.empty() || !comprehension->has_result() || + !comprehension->result().has_ident_expr() || + comprehension->result().ident_expr().name() != accu_var) { + return false; + } + if (!comprehension->accu_init().has_map_expr()) { + return false; + } + if (!comprehension->loop_step().has_call_expr()) { + return false; + } + const auto* call_expr = &comprehension->loop_step().call_expr(); + + if (call_expr->function() == cel::builtin::kTernary && + call_expr->args().size() == 3) { + if (!call_expr->args()[1].has_call_expr()) { + return false; + } + call_expr = &(call_expr->args()[1].call_expr()); + } + return call_expr->function() == "cel.@mapInsert" && + call_expr->args().size() == 3 && + call_expr->args()[0].has_ident_expr() && + call_expr->args()[0].ident_expr().name() == accu_var; +} + bool IsBind(const cel::ast_internal::Comprehension* comprehension) { static constexpr absl::string_view kUnusedIterVar = "#unused"; return comprehension->loop_condition().const_expr().has_bool_value() && comprehension->loop_condition().const_expr().bool_value() == false && comprehension->iter_var() == kUnusedIterVar && + comprehension->iter_var2().empty() && comprehension->iter_range().has_list_expr() && comprehension->iter_range().list_expr().elements().empty(); } @@ -346,7 +382,7 @@ class ComprehensionVisitor { public: explicit ComprehensionVisitor(FlatExprVisitor* visitor, bool short_circuiting, bool is_trivial, size_t iter_slot, - size_t accu_slot) + size_t iter2_slot, size_t accu_slot) : visitor_(visitor), next_step_(nullptr), cond_step_(nullptr), @@ -354,6 +390,7 @@ class ComprehensionVisitor { is_trivial_(is_trivial), accu_init_extracted_(false), iter_slot_(iter_slot), + iter2_slot_(iter2_slot), accu_slot_(accu_slot) {} void PreVisit(const cel::ast_internal::Expr* expr); @@ -387,6 +424,7 @@ class ComprehensionVisitor { bool is_trivial_; bool accu_init_extracted_; size_t iter_slot_; + size_t iter2_slot_; size_t accu_slot_; }; @@ -603,6 +641,10 @@ class FlatExprVisitor : public cel::AstVisitor { } return {static_cast(record.iter_slot), -1}; } + if (record.iter_var2_in_scope && + record.comprehension->iter_var2() == path) { + return {static_cast(record.iter2_slot), -1}; + } if (record.accu_var_in_scope && record.comprehension->accu_var() == path) { int slot = record.accu_slot; @@ -1091,7 +1133,7 @@ class FlatExprVisitor : public cel::AstVisitor { void MaybeMakeComprehensionRecursive( const cel::ast_internal::Expr* expr, const cel::ast_internal::Comprehension* comprehension, size_t iter_slot, - size_t accu_slot) { + size_t iter2_slot, size_t accu_slot) { if (options_.max_recursion_depth == 0) { return; } @@ -1144,7 +1186,8 @@ class FlatExprVisitor : public cel::AstVisitor { } auto step = CreateDirectComprehensionStep( - iter_slot, accu_slot, range_plan->ExtractRecursiveProgram().step, + iter_slot, iter2_slot, accu_slot, + range_plan->ExtractRecursiveProgram().step, accu_plan->ExtractRecursiveProgram().step, loop_plan->ExtractRecursiveProgram().step, condition_plan->ExtractRecursiveProgram().step, @@ -1256,6 +1299,7 @@ class FlatExprVisitor : public cel::AstVisitor { } const auto& accu_var = comprehension.accu_var(); const auto& iter_var = comprehension.iter_var(); + const auto& iter_var2 = comprehension.iter_var2(); ValidateOrError(!accu_var.empty(), "Invalid comprehension: 'accu_var' must not be empty"); ValidateOrError(!iter_var.empty(), @@ -1263,6 +1307,12 @@ class FlatExprVisitor : public cel::AstVisitor { ValidateOrError( accu_var != iter_var, "Invalid comprehension: 'accu_var' must not be the same as 'iter_var'"); + ValidateOrError(accu_var != iter_var2, + "Invalid comprehension: 'accu_var' must not be the same as " + "'iter_var2'"); + ValidateOrError(iter_var2 != iter_var, + "Invalid comprehension: 'iter_var2' must not be the same " + "as 'iter_var'"); ValidateOrError(comprehension.has_accu_init(), "Invalid comprehension: 'accu_init' must be set"); ValidateOrError(comprehension.has_loop_condition(), @@ -1272,16 +1322,21 @@ class FlatExprVisitor : public cel::AstVisitor { ValidateOrError(comprehension.has_result(), "Invalid comprehension: 'result' must be set"); - size_t iter_slot, accu_slot, slot_count; + size_t iter_slot, iter2_slot, accu_slot, slot_count; bool is_bind = IsBind(&comprehension); if (is_bind) { - accu_slot = iter_slot = index_manager_.ReserveSlots(1); + accu_slot = iter_slot = iter2_slot = index_manager_.ReserveSlots(1); slot_count = 1; - } else { - iter_slot = index_manager_.ReserveSlots(2); + } else if (comprehension.iter_var2().empty()) { + iter_slot = iter2_slot = index_manager_.ReserveSlots(2); accu_slot = iter_slot + 1; slot_count = 2; + } else { + iter_slot = index_manager_.ReserveSlots(3); + iter2_slot = iter_slot + 1; + accu_slot = iter2_slot + 1; + slot_count = 3; } if (block_.has_value()) { @@ -1307,16 +1362,20 @@ class FlatExprVisitor : public cel::AstVisitor { } comprehension_stack_.push_back( - {&expr, &comprehension, iter_slot, accu_slot, slot_count, + {&expr, &comprehension, iter_slot, iter2_slot, accu_slot, slot_count, /*subexpression=*/-1, + /*.is_optimizable_list_append=*/ IsOptimizableListAppend(&comprehension, options_.enable_comprehension_list_append), - is_bind, + /*.is_optimizable_map_insert=*/IsOptimizableMapInsert(&comprehension), + /*.is_optimizable_bind=*/is_bind, /*.iter_var_in_scope=*/false, + /*.iter_var2_in_scope=*/false, /*.accu_var_in_scope=*/false, /*.in_accu_init=*/false, - std::make_unique( - this, options_.short_circuiting, is_bind, iter_slot, accu_slot)}); + std::make_unique(this, options_.short_circuiting, + is_bind, iter_slot, iter2_slot, + accu_slot)}); comprehension_stack_.back().visitor->PreVisit(&expr); } @@ -1359,30 +1418,35 @@ class FlatExprVisitor : public cel::AstVisitor { case cel::ITER_RANGE: { record.in_accu_init = false; record.iter_var_in_scope = false; + record.iter_var2_in_scope = false; record.accu_var_in_scope = false; break; } case cel::ACCU_INIT: { record.in_accu_init = true; record.iter_var_in_scope = false; + record.iter_var2_in_scope = false; record.accu_var_in_scope = false; break; } case cel::LOOP_CONDITION: { record.in_accu_init = false; record.iter_var_in_scope = true; + record.iter_var2_in_scope = true; record.accu_var_in_scope = true; break; } case cel::LOOP_STEP: { record.in_accu_init = false; record.iter_var_in_scope = true; + record.iter_var2_in_scope = true; record.accu_var_in_scope = true; break; } case cel::RESULT: { record.in_accu_init = false; record.iter_var_in_scope = false; + record.iter_var2_in_scope = false; record.accu_var_in_scope = true; break; } @@ -1486,6 +1550,21 @@ class FlatExprVisitor : public cel::AstVisitor { return; } + if (!comprehension_stack_.empty()) { + const ComprehensionStackRecord& comprehension = + comprehension_stack_.back(); + if (comprehension.is_optimizable_map_insert) { + if (&(comprehension.comprehension->accu_init()) == &expr) { + if (options_.max_recursion_depth != 0) { + SetRecursiveStep(CreateDirectMutableMapStep(expr.id()), 1); + return; + } + AddStep(CreateMutableMapStep(expr.id())); + return; + } + } + } + auto status_or_resolved_fields = ResolveCreateStructFields(struct_expr, expr.id()); if (!status_or_resolved_fields.ok()) { @@ -1690,13 +1769,16 @@ class FlatExprVisitor : public cel::AstVisitor { const cel::ast_internal::Expr* expr; const cel::ast_internal::Comprehension* comprehension; size_t iter_slot; + size_t iter2_slot; size_t accu_slot; size_t slot_count; // -1 indicates this shouldn't be used. int subexpression; bool is_optimizable_list_append; + bool is_optimizable_map_insert; bool is_optimizable_bind; bool iter_var_in_scope; + bool iter_var2_in_scope; bool accu_var_in_scope; bool in_accu_init; std::unique_ptr visitor; @@ -2032,20 +2114,26 @@ absl::Status ComprehensionVisitor::PostVisitArgDefault( case cel::ITER_RANGE: { // post process iter_range to list its keys if it's a map // and initialize the loop index. - visitor_->AddStep(CreateComprehensionInitStep(expr->id())); + // If the slots are the same, this is comprehensions v1 otherwise this is + // comprehensions v2. + if (iter_slot_ == iter2_slot_) { + visitor_->AddStep(CreateComprehensionInitStep(expr->id())); + } else { + visitor_->AddStep(CreateComprehensionInitStep2(expr->id())); + } break; } case cel::ACCU_INIT: { next_step_pos_ = visitor_->GetCurrentIndex(); - next_step_ = - new ComprehensionNextStep(iter_slot_, accu_slot_, expr->id()); + next_step_ = new ComprehensionNextStep(iter_slot_, iter2_slot_, + accu_slot_, expr->id()); visitor_->AddStep(std::unique_ptr(next_step_)); break; } case cel::LOOP_CONDITION: { cond_step_pos_ = visitor_->GetCurrentIndex(); - cond_step_ = new ComprehensionCondStep(iter_slot_, accu_slot_, - short_circuiting_, expr->id()); + cond_step_ = new ComprehensionCondStep( + iter_slot_, iter2_slot_, accu_slot_, short_circuiting_, expr->id()); visitor_->AddStep(std::unique_ptr(cond_step_)); break; } @@ -2070,7 +2158,13 @@ absl::Status ComprehensionVisitor::PostVisitArgDefault( break; } case cel::RESULT: { - visitor_->AddStep(CreateComprehensionFinishStep(accu_slot_, expr->id())); + if (iter_slot_ == iter2_slot_) { + visitor_->AddStep( + CreateComprehensionFinishStep(accu_slot_, expr->id())); + } else { + visitor_->AddStep( + CreateComprehensionFinishStep2(accu_slot_, expr->id())); + } CEL_ASSIGN_OR_RETURN( int jump_from_next, @@ -2118,8 +2212,8 @@ void ComprehensionVisitor::PostVisit(const cel::ast_internal::Expr* expr) { accu_slot_); return; } - visitor_->MaybeMakeComprehensionRecursive(expr, &expr->comprehension_expr(), - iter_slot_, accu_slot_); + visitor_->MaybeMakeComprehensionRecursive( + expr, &expr->comprehension_expr(), iter_slot_, iter2_slot_, accu_slot_); } // Flattens the expression table into the end of the mainline expression vector diff --git a/eval/eval/BUILD b/eval/eval/BUILD index 9dbf19433..e46f66084 100644 --- a/eval/eval/BUILD +++ b/eval/eval/BUILD @@ -478,10 +478,12 @@ cc_library( "//eval/internal:errors", "//eval/public:cel_attribute", "//internal:status_macros", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], ) diff --git a/eval/eval/comprehension_step.cc b/eval/eval/comprehension_step.cc index 75e723e17..c34121b09 100644 --- a/eval/eval/comprehension_step.cc +++ b/eval/eval/comprehension_step.cc @@ -5,10 +5,12 @@ #include #include +#include "absl/base/attributes.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/types/optional.h" #include "absl/types/span.h" #include "base/attribute.h" #include "base/kind.h" @@ -27,6 +29,7 @@ namespace google::api::expr::runtime { namespace { +using ::cel::AttributeQualifier; using ::cel::BoolValue; using ::cel::Cast; using ::cel::InstanceOf; @@ -35,8 +38,25 @@ using ::cel::ListValue; using ::cel::MapValue; using ::cel::UnknownValue; using ::cel::Value; +using ::cel::ValueKind; using ::cel::runtime_internal::CreateNoMatchingOverloadError; +AttributeQualifier AttributeQualifierFromValue(const Value& v) { + switch (v->kind()) { + case ValueKind::kString: + return AttributeQualifier::OfString(v.GetString().ToString()); + case ValueKind::kInt64: + return AttributeQualifier::OfInt(v.GetInt().NativeValue()); + case ValueKind::kUint64: + return AttributeQualifier::OfUint(v.GetUint().NativeValue()); + case ValueKind::kBool: + return AttributeQualifier::OfBool(v.GetBool().NativeValue()); + default: + // Non-matching qualifier. + return AttributeQualifier(); + } +} + class ComprehensionFinish : public ExpressionStepBase { public: ComprehensionFinish(size_t accu_slot, int64_t expr_id); @@ -65,6 +85,30 @@ absl::Status ComprehensionFinish::Evaluate(ExecutionFrame* frame) const { return absl::OkStatus(); } +class ComprehensionFinish2 final : public ExpressionStepBase { + public: + ComprehensionFinish2(size_t accu_slot, int64_t expr_id) + : ExpressionStepBase(expr_id), accu_slot_(accu_slot) {} + + // Stack changes of ComprehensionFinish. + // + // Stack size before: 4. + // Stack size after: 1. + absl::Status Evaluate(ExecutionFrame* frame) const override { + if (!frame->value_stack().HasEnough(4)) { + return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); + } + Value result = frame->value_stack().Peek(); + frame->value_stack().Pop(4); + frame->value_stack().Push(std::move(result)); + frame->comprehension_slots().ClearSlot(accu_slot_); + return absl::OkStatus(); + } + + private: + size_t accu_slot_; +}; + class ComprehensionInitStep : public ExpressionStepBase { public: explicit ComprehensionInitStep(int64_t expr_id) @@ -124,10 +168,49 @@ absl::Status ComprehensionInitStep::Evaluate(ExecutionFrame* frame) const { return absl::OkStatus(); } +class ComprehensionInitStep2 final : public ExpressionStepBase { + public: + explicit ComprehensionInitStep2(int64_t expr_id) + : ExpressionStepBase(expr_id, false) {} + + absl::Status Evaluate(ExecutionFrame* frame) const override { + if (!frame->value_stack().HasEnough(1)) { + return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); + } + + const auto& range = frame->value_stack().Peek(); + switch (range.kind()) { + case ValueKind::kMap: { + CEL_ASSIGN_OR_RETURN( + Value keys, ProjectKeysImpl(*frame, range.GetMap(), + frame->value_stack().PeekAttribute())); + frame->value_stack().Push(std::move(keys)); + } break; + case ValueKind::kList: + ABSL_FALLTHROUGH_INTENDED; + case ValueKind::kError: + ABSL_FALLTHROUGH_INTENDED; + case ValueKind::kUnknown: + frame->value_stack().Push(range); + break; + default: + frame->value_stack().PopAndPush(frame->value_factory().CreateErrorValue( + CreateNoMatchingOverloadError(""))); + break; + } + + // Initialize current index. + // Error handling for wrong range type is deferred until the 'Next' step + // to simplify the number of jumps. + frame->value_stack().Push(frame->value_factory().CreateIntValue(-1)); + return absl::OkStatus(); + } +}; + class ComprehensionDirectStep : public DirectExpressionStep { public: explicit ComprehensionDirectStep( - size_t iter_slot, size_t accu_slot, + size_t iter_slot, size_t iter2_slot, size_t accu_slot, std::unique_ptr range, std::unique_ptr accu_init, std::unique_ptr loop_step, @@ -136,6 +219,7 @@ class ComprehensionDirectStep : public DirectExpressionStep { int64_t expr_id) : DirectExpressionStep(expr_id), iter_slot_(iter_slot), + iter2_slot_(iter2_slot), accu_slot_(accu_slot), range_(std::move(range)), accu_init_(std::move(accu_init)), @@ -143,11 +227,22 @@ class ComprehensionDirectStep : public DirectExpressionStep { condition_(std::move(condition_step)), result_step_(std::move(result_step)), shortcircuiting_(shortcircuiting) {} + absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, - AttributeTrail& trail) const override; + AttributeTrail& trail) const final { + return iter_slot_ == iter2_slot_ ? Evaluate1(frame, result, trail) + : Evaluate2(frame, result, trail); + } private: + absl::Status Evaluate1(ExecutionFrameBase& frame, Value& result, + AttributeTrail& trail) const; + + absl::Status Evaluate2(ExecutionFrameBase& frame, Value& result, + AttributeTrail& trail) const; + size_t iter_slot_; + size_t iter2_slot_; size_t accu_slot_; std::unique_ptr range_; std::unique_ptr accu_init_; @@ -158,9 +253,9 @@ class ComprehensionDirectStep : public DirectExpressionStep { bool shortcircuiting_; }; -absl::Status ComprehensionDirectStep::Evaluate(ExecutionFrameBase& frame, - Value& result, - AttributeTrail& trail) const { +absl::Status ComprehensionDirectStep::Evaluate1(ExecutionFrameBase& frame, + Value& result, + AttributeTrail& trail) const { cel::Value range; AttributeTrail range_attr; CEL_RETURN_IF_ERROR(range_->Evaluate(frame, range, range_attr)); @@ -257,6 +352,180 @@ absl::Status ComprehensionDirectStep::Evaluate(ExecutionFrameBase& frame, return absl::OkStatus(); } +absl::Status ComprehensionDirectStep::Evaluate2(ExecutionFrameBase& frame, + Value& result, + AttributeTrail& trail) const { + cel::Value iter2_range; + AttributeTrail range_attr; + CEL_RETURN_IF_ERROR(range_->Evaluate(frame, iter2_range, range_attr)); + + absl::optional iter2_range_map; + cel::Value iter_range; + if (iter2_range.IsMap()) { + iter2_range_map = iter2_range.GetMap(); + CEL_ASSIGN_OR_RETURN(iter_range, + ProjectKeysImpl(frame, *iter2_range_map, range_attr)); + } else { + iter_range = iter2_range; + } + + switch (iter_range.kind()) { + case cel::ValueKind::kError: + ABSL_FALLTHROUGH_INTENDED; + case cel::ValueKind::kUnknown: + result = iter_range; + return absl::OkStatus(); + case cel::ValueKind::kList: + break; + default: + result = cel::ErrorValue(CreateNoMatchingOverloadError("")); + return absl::OkStatus(); + } + + const auto& iter_range_list = iter_range.GetList(); + + Value accu_init; + AttributeTrail accu_init_attr; + CEL_RETURN_IF_ERROR(accu_init_->Evaluate(frame, accu_init, accu_init_attr)); + + frame.comprehension_slots().Set(accu_slot_, std::move(accu_init), + accu_init_attr); + ComprehensionSlots::Slot* accu_slot = + frame.comprehension_slots().Get(accu_slot_); + ABSL_DCHECK(accu_slot != nullptr); + + frame.comprehension_slots().Set(iter_slot_); + ComprehensionSlots::Slot* iter_slot = + frame.comprehension_slots().Get(iter_slot_); + ABSL_DCHECK(iter_slot != nullptr); + + frame.comprehension_slots().Set(iter2_slot_); + ComprehensionSlots::Slot* iter2_slot = + frame.comprehension_slots().Get(iter2_slot_); + ABSL_DCHECK(iter2_slot != nullptr); + + Value condition; + AttributeTrail condition_attr; + bool should_skip_result = false; + if (iter2_range_map) { + CEL_RETURN_IF_ERROR(iter2_range_map->ForEach( + frame.value_manager(), + [&](const Value& k, const Value& v) -> absl::StatusOr { + CEL_RETURN_IF_ERROR(frame.IncrementIterations()); + // Evaluate loop condition first. + CEL_RETURN_IF_ERROR( + condition_->Evaluate(frame, condition, condition_attr)); + + if (condition.kind() == cel::ValueKind::kError || + condition.kind() == cel::ValueKind::kUnknown) { + result = std::move(condition); + should_skip_result = true; + return false; + } + if (condition.kind() != cel::ValueKind::kBool) { + result = frame.value_manager().CreateErrorValue( + CreateNoMatchingOverloadError("")); + should_skip_result = true; + return false; + } + if (shortcircuiting_ && !Cast(condition).NativeValue()) { + return false; + } + + iter_slot->value = k; + if (frame.unknown_processing_enabled()) { + iter_slot->attribute = + range_attr.Step(AttributeQualifierFromValue(k)); + if (frame.attribute_utility().CheckForUnknownExact( + iter_slot->attribute)) { + iter_slot->value = frame.attribute_utility().CreateUnknownSet( + iter_slot->attribute.attribute()); + } + } + + iter2_slot->value = v; + if (frame.unknown_processing_enabled()) { + iter2_slot->attribute = + range_attr.Step(AttributeQualifierFromValue(v)); + if (frame.attribute_utility().CheckForUnknownExact( + iter2_slot->attribute)) { + iter2_slot->value = frame.attribute_utility().CreateUnknownSet( + iter2_slot->attribute.attribute()); + } + } + + CEL_RETURN_IF_ERROR(loop_step_->Evaluate(frame, accu_slot->value, + accu_slot->attribute)); + + return true; + })); + } else { + CEL_RETURN_IF_ERROR(iter_range_list.ForEach( + frame.value_manager(), + [&](size_t index, const Value& v) -> absl::StatusOr { + CEL_RETURN_IF_ERROR(frame.IncrementIterations()); + // Evaluate loop condition first. + CEL_RETURN_IF_ERROR( + condition_->Evaluate(frame, condition, condition_attr)); + + if (condition.kind() == cel::ValueKind::kError || + condition.kind() == cel::ValueKind::kUnknown) { + result = std::move(condition); + should_skip_result = true; + return false; + } + if (condition.kind() != cel::ValueKind::kBool) { + result = frame.value_manager().CreateErrorValue( + CreateNoMatchingOverloadError("")); + should_skip_result = true; + return false; + } + if (shortcircuiting_ && !Cast(condition).NativeValue()) { + return false; + } + + iter_slot->value = IntValue(index); + if (frame.unknown_processing_enabled()) { + iter_slot->attribute = + range_attr.Step(CelAttributeQualifier::OfInt(index)); + if (frame.attribute_utility().CheckForUnknownExact( + iter_slot->attribute)) { + iter_slot->value = frame.attribute_utility().CreateUnknownSet( + iter_slot->attribute.attribute()); + } + } + + iter2_slot->value = v; + if (frame.unknown_processing_enabled()) { + iter2_slot->attribute = + range_attr.Step(AttributeQualifierFromValue(v)); + if (frame.attribute_utility().CheckForUnknownExact( + iter2_slot->attribute)) { + iter2_slot->value = frame.attribute_utility().CreateUnknownSet( + iter2_slot->attribute.attribute()); + } + } + + CEL_RETURN_IF_ERROR(loop_step_->Evaluate(frame, accu_slot->value, + accu_slot->attribute)); + + return true; + })); + } + + frame.comprehension_slots().ClearSlot(iter_slot_); + frame.comprehension_slots().ClearSlot(iter2_slot_); + // Error state is already set to the return value, just clean up. + if (should_skip_result) { + frame.comprehension_slots().ClearSlot(accu_slot_); + return absl::OkStatus(); + } + + CEL_RETURN_IF_ERROR(result_step_->Evaluate(frame, result, trail)); + frame.comprehension_slots().ClearSlot(accu_slot_); + return absl::OkStatus(); +} + } // namespace // Stack variables during comprehension evaluation: @@ -276,10 +545,12 @@ absl::Status ComprehensionDirectStep::Evaluate(ExecutionFrameBase& frame, // 8. result (dep) 2 -> 3 // 9. ComprehensionFinish 3 -> 1 -ComprehensionNextStep::ComprehensionNextStep(size_t iter_slot, size_t accu_slot, - int64_t expr_id) +ComprehensionNextStep::ComprehensionNextStep(size_t iter_slot, + size_t iter2_slot, + size_t accu_slot, int64_t expr_id) : ExpressionStepBase(expr_id, false), iter_slot_(iter_slot), + iter2_slot_(iter2_slot), accu_slot_(accu_slot) {} void ComprehensionNextStep::set_jump_offset(int offset) { @@ -307,7 +578,7 @@ void ComprehensionNextStep::set_error_jump_offset(int offset) { // // Stack on error: // 0. error -absl::Status ComprehensionNextStep::Evaluate(ExecutionFrame* frame) const { +absl::Status ComprehensionNextStep::Evaluate1(ExecutionFrame* frame) const { enum { POS_ITER_RANGE, POS_CURRENT_INDEX, @@ -386,11 +657,148 @@ absl::Status ComprehensionNextStep::Evaluate(ExecutionFrame* frame) const { return absl::OkStatus(); } -ComprehensionCondStep::ComprehensionCondStep(size_t iter_slot, size_t accu_slot, +absl::Status ComprehensionNextStep::Evaluate2(ExecutionFrame* frame) const { + enum { + POS_ITER2_RANGE, // Map or same as POS_ITER_RANGE. + POS_ITER_RANGE, + POS_CURRENT_INDEX, + POS_LOOP_STEP_ACCU, + }; + constexpr int kStackSize = 4; + if (!frame->value_stack().HasEnough(kStackSize)) { + return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); + } + absl::Span state = frame->value_stack().GetSpan(kStackSize); + + const cel::Value& iter2_range = state[POS_ITER2_RANGE]; + absl::optional iter2_range_map; + switch (iter2_range.kind()) { + case ValueKind::kMap: + iter2_range_map = iter2_range.GetMap(); + break; + case ValueKind::kList: + break; + case ValueKind::kError: + ABSL_FALLTHROUGH_INTENDED; + case ValueKind::kUnknown: + // Leave it on the stack. + frame->value_stack().PopAndPush(kStackSize, std::move(iter2_range)); + return frame->JumpTo(error_jump_offset_); + default: + frame->value_stack().PopAndPush( + kStackSize, frame->value_factory().CreateErrorValue( + CreateNoMatchingOverloadError(""))); + return frame->JumpTo(error_jump_offset_); + } + + // Get range from the stack. + const cel::Value& iter_range = state[POS_ITER_RANGE]; + switch (iter_range.kind()) { + case ValueKind::kList: + break; + case ValueKind::kError: + ABSL_FALLTHROUGH_INTENDED; + case ValueKind::kUnknown: + frame->value_stack().PopAndPush(kStackSize, std::move(iter_range)); + return frame->JumpTo(error_jump_offset_); + default: + frame->value_stack().PopAndPush( + kStackSize, frame->value_factory().CreateErrorValue( + CreateNoMatchingOverloadError(""))); + return frame->JumpTo(error_jump_offset_); + } + ListValue iter_range_list = iter_range.GetList(); + + // Get the current index off the stack. + const cel::Value& current_index_value = state[POS_CURRENT_INDEX]; + if (!current_index_value.IsInt()) { + return absl::InternalError(absl::StrCat( + "ComprehensionNextStep: want int, got ", + cel::KindToString(ValueKindToKind(current_index_value.kind())))); + } + CEL_RETURN_IF_ERROR(frame->IncrementIterations()); + + int64_t next_index = current_index_value.GetInt().NativeValue() + 1; + + frame->comprehension_slots().Set(accu_slot_, state[POS_LOOP_STEP_ACCU]); + + CEL_ASSIGN_OR_RETURN(auto iter_range_list_size, iter_range_list.Size()); + + if (next_index >= static_cast(iter_range_list_size)) { + // Make sure the iter var is out of scope. + frame->comprehension_slots().ClearSlot(iter_slot_); + frame->comprehension_slots().ClearSlot(iter2_slot_); + // pop loop step + frame->value_stack().Pop(1); + // jump to result production step + return frame->JumpTo(jump_offset_); + } + + AttributeTrail iter_range_trail; + if (frame->enable_unknowns()) { + iter_range_trail = + frame->value_stack().GetAttributeSpan(kStackSize)[POS_ITER_RANGE].Step( + cel::AttributeQualifier::OfInt(next_index)); + } + + Value current_iter_var; + if (frame->enable_unknowns() && + frame->attribute_utility().CheckForUnknown(iter_range_trail, + /*use_partial=*/false)) { + current_iter_var = frame->attribute_utility().CreateUnknownSet( + iter_range_trail.attribute()); + } else { + CEL_ASSIGN_OR_RETURN(current_iter_var, + iter_range_list.Get(frame->value_factory(), + static_cast(next_index))); + } + + AttributeTrail iter2_range_trail; + Value current_iter_var2; + if (iter2_range_map) { + AttributeTrail iter2_range_trail; + if (frame->enable_unknowns()) { + iter2_range_trail = + frame->value_stack() + .GetAttributeSpan(kStackSize)[POS_ITER2_RANGE] + .Step(AttributeQualifierFromValue(current_iter_var)); + } + if (frame->enable_unknowns() && + frame->attribute_utility().CheckForUnknown(iter2_range_trail, + /*use_partial=*/false)) { + current_iter_var2 = frame->attribute_utility().CreateUnknownSet( + iter2_range_trail.attribute()); + } else { + CEL_ASSIGN_OR_RETURN( + current_iter_var2, + iter2_range_map->Get(frame->value_manager(), current_iter_var)); + } + } else { + iter2_range_trail = iter_range_trail; + current_iter_var2 = current_iter_var; + current_iter_var = IntValue(next_index); + } + + // pop loop step + // pop old current_index + // push new current_index + frame->value_stack().PopAndPush( + 2, frame->value_factory().CreateIntValue(next_index)); + frame->comprehension_slots().Set(iter_slot_, std::move(current_iter_var), + std::move(iter_range_trail)); + frame->comprehension_slots().Set(iter2_slot_, std::move(current_iter_var2), + std::move(iter2_range_trail)); + return absl::OkStatus(); +} + +ComprehensionCondStep::ComprehensionCondStep(size_t iter_slot, + size_t iter2_slot, + size_t accu_slot, bool shortcircuiting, int64_t expr_id) : ExpressionStepBase(expr_id, false), iter_slot_(iter_slot), + iter2_slot_(iter2_slot), accu_slot_(accu_slot), shortcircuiting_(shortcircuiting) {} @@ -412,7 +820,7 @@ void ComprehensionCondStep::set_error_jump_offset(int offset) { // Stack size before: 3. // Stack size after: 2. // Stack size on error: 1. -absl::Status ComprehensionCondStep::Evaluate(ExecutionFrame* frame) const { +absl::Status ComprehensionCondStep::Evaluate1(ExecutionFrame* frame) const { if (!frame->value_stack().HasEnough(3)) { return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); } @@ -440,8 +848,47 @@ absl::Status ComprehensionCondStep::Evaluate(ExecutionFrame* frame) const { return absl::OkStatus(); } +// Check the break condition for the comprehension. +// +// If the condition is false jump to the `result` subexpression. +// If not a bool, clear stack and jump past the result expression. +// Otherwise, continue to the accumulate step. +// Stack changes by ComprehensionCondStep. +// +// Stack size before: 4. +// Stack size after: 3. +// Stack size on error: 1. +absl::Status ComprehensionCondStep::Evaluate2(ExecutionFrame* frame) const { + if (!frame->value_stack().HasEnough(4)) { + return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); + } + auto& loop_condition_value = frame->value_stack().Peek(); + if (!loop_condition_value->Is()) { + if (loop_condition_value->Is() || + loop_condition_value->Is()) { + frame->value_stack().PopAndPush(4, std::move(loop_condition_value)); + } else { + frame->value_stack().PopAndPush( + 4, frame->value_factory().CreateErrorValue( + CreateNoMatchingOverloadError(""))); + } + // The error jump skips the ComprehensionFinish clean-up step, so we + // need to update the iteration variable stack here. + frame->comprehension_slots().ClearSlot(iter_slot_); + frame->comprehension_slots().ClearSlot(iter2_slot_); + frame->comprehension_slots().ClearSlot(accu_slot_); + return frame->JumpTo(error_jump_offset_); + } + bool loop_condition = loop_condition_value.GetBool().NativeValue(); + frame->value_stack().Pop(1); // loop_condition + if (!loop_condition && shortcircuiting_) { + return frame->JumpTo(jump_offset_); + } + return absl::OkStatus(); +} + std::unique_ptr CreateDirectComprehensionStep( - size_t iter_slot, size_t accu_slot, + size_t iter_slot, size_t iter2_slot, size_t accu_slot, std::unique_ptr range, std::unique_ptr accu_init, std::unique_ptr loop_step, @@ -449,7 +896,7 @@ std::unique_ptr CreateDirectComprehensionStep( std::unique_ptr result_step, bool shortcircuiting, int64_t expr_id) { return std::make_unique( - iter_slot, accu_slot, std::move(range), std::move(accu_init), + iter_slot, iter2_slot, accu_slot, std::move(range), std::move(accu_init), std::move(loop_step), std::move(condition_step), std::move(result_step), shortcircuiting, expr_id); } @@ -463,4 +910,13 @@ std::unique_ptr CreateComprehensionInitStep(int64_t expr_id) { return std::make_unique(expr_id); } +std::unique_ptr CreateComprehensionFinishStep2( + size_t accu_slot, int64_t expr_id) { + return std::make_unique(accu_slot, expr_id); +} + +std::unique_ptr CreateComprehensionInitStep2(int64_t expr_id) { + return std::make_unique(expr_id); +} + } // namespace google::api::expr::runtime diff --git a/eval/eval/comprehension_step.h b/eval/eval/comprehension_step.h index c0fc78aa0..b0b8397f2 100644 --- a/eval/eval/comprehension_step.h +++ b/eval/eval/comprehension_step.h @@ -14,15 +14,23 @@ namespace google::api::expr::runtime { class ComprehensionNextStep : public ExpressionStepBase { public: - ComprehensionNextStep(size_t iter_slot, size_t accu_slot, int64_t expr_id); + ComprehensionNextStep(size_t iter_slot, size_t iter2_slot, size_t accu_slot, + int64_t expr_id); void set_jump_offset(int offset); void set_error_jump_offset(int offset); - absl::Status Evaluate(ExecutionFrame* frame) const override; + absl::Status Evaluate(ExecutionFrame* frame) const final { + return iter_slot_ == iter2_slot_ ? Evaluate1(frame) : Evaluate2(frame); + } private: + absl::Status Evaluate1(ExecutionFrame* frame) const; + + absl::Status Evaluate2(ExecutionFrame* frame) const; + size_t iter_slot_; + size_t iter2_slot_; size_t accu_slot_; int jump_offset_; int error_jump_offset_; @@ -30,16 +38,23 @@ class ComprehensionNextStep : public ExpressionStepBase { class ComprehensionCondStep : public ExpressionStepBase { public: - ComprehensionCondStep(size_t iter_slot, size_t accu_slot, + ComprehensionCondStep(size_t iter_slot, size_t iter2_slot, size_t accu_slot, bool shortcircuiting, int64_t expr_id); void set_jump_offset(int offset); void set_error_jump_offset(int offset); - absl::Status Evaluate(ExecutionFrame* frame) const override; + absl::Status Evaluate(ExecutionFrame* frame) const final { + return iter_slot_ == iter2_slot_ ? Evaluate1(frame) : Evaluate2(frame); + } private: + absl::Status Evaluate1(ExecutionFrame* frame) const; + + absl::Status Evaluate2(ExecutionFrame* frame) const; + size_t iter_slot_; + size_t iter2_slot_; size_t accu_slot_; int jump_offset_; int error_jump_offset_; @@ -48,7 +63,7 @@ class ComprehensionCondStep : public ExpressionStepBase { // Creates a step for executing a comprehension. std::unique_ptr CreateDirectComprehensionStep( - size_t iter_slot, size_t accu_slot, + size_t iter_slot, size_t iter2_slot, size_t accu_slot, std::unique_ptr range, std::unique_ptr accu_init, std::unique_ptr loop_step, @@ -66,6 +81,16 @@ std::unique_ptr CreateComprehensionFinishStep(size_t accu_slot, // context for the comprehension. std::unique_ptr CreateComprehensionInitStep(int64_t expr_id); +// Creates a cleanup step for the comprehension. +// Removes the comprehension context then pushes the 'result' sub expression to +// the top of the stack. +std::unique_ptr CreateComprehensionFinishStep2(size_t accu_slot, + int64_t expr_id); + +// Creates a step that checks that the input is iterable and sets up the loop +// context for the comprehension. +std::unique_ptr CreateComprehensionInitStep2(int64_t expr_id); + } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_COMPREHENSION_STEP_H_ diff --git a/eval/eval/comprehension_step_test.cc b/eval/eval/comprehension_step_test.cc index 776b0e238..47cbba516 100644 --- a/eval/eval/comprehension_step_test.cc +++ b/eval/eval/comprehension_step_test.cc @@ -306,7 +306,7 @@ TEST_F(DirectComprehensionTest, PropagateRangeNonOkStatus) { .WillByDefault(Return(absl::InternalError("test range error"))); auto compre_step = CreateDirectComprehensionStep( - 0, 1, + 0, 0, 1, /*range_step=*/std::move(range_step), /*accu_init=*/CreateConstValueDirectStep(BoolValue(false)), /*loop_step=*/CreateConstValueDirectStep(BoolValue(false)), @@ -335,7 +335,7 @@ TEST_F(DirectComprehensionTest, PropagateAccuInitNonOkStatus) { ASSERT_OK_AND_ASSIGN(auto list, MakeList()); auto compre_step = CreateDirectComprehensionStep( - 0, 1, + 0, 0, 1, /*range_step=*/CreateConstValueDirectStep(std::move(list)), /*accu_init=*/std::move(accu_init), /*loop_step=*/CreateConstValueDirectStep(BoolValue(false)), @@ -364,7 +364,7 @@ TEST_F(DirectComprehensionTest, PropagateLoopNonOkStatus) { ASSERT_OK_AND_ASSIGN(auto list, MakeList()); auto compre_step = CreateDirectComprehensionStep( - 0, 1, + 0, 0, 1, /*range_step=*/CreateConstValueDirectStep(std::move(list)), /*accu_init=*/CreateConstValueDirectStep(BoolValue(false)), /*loop_step=*/std::move(loop_step), @@ -393,7 +393,7 @@ TEST_F(DirectComprehensionTest, PropagateConditionNonOkStatus) { ASSERT_OK_AND_ASSIGN(auto list, MakeList()); auto compre_step = CreateDirectComprehensionStep( - 0, 1, + 0, 0, 1, /*range_step=*/CreateConstValueDirectStep(std::move(list)), /*accu_init=*/CreateConstValueDirectStep(BoolValue(false)), /*loop_step=*/CreateConstValueDirectStep(BoolValue(false)), @@ -422,7 +422,7 @@ TEST_F(DirectComprehensionTest, PropagateResultNonOkStatus) { ASSERT_OK_AND_ASSIGN(auto list, MakeList()); auto compre_step = CreateDirectComprehensionStep( - 0, 1, + 0, 0, 1, /*range_step=*/CreateConstValueDirectStep(std::move(list)), /*accu_init=*/CreateConstValueDirectStep(BoolValue(false)), /*loop_step=*/CreateConstValueDirectStep(BoolValue(false)), @@ -455,7 +455,7 @@ TEST_F(DirectComprehensionTest, Shortcircuit) { ASSERT_OK_AND_ASSIGN(auto list, MakeList()); auto compre_step = CreateDirectComprehensionStep( - 0, 1, + 0, 0, 1, /*range_step=*/CreateConstValueDirectStep(std::move(list)), /*accu_init=*/CreateConstValueDirectStep(BoolValue(false)), /*loop_step=*/std::move(loop_step), @@ -488,7 +488,7 @@ TEST_F(DirectComprehensionTest, IterationLimit) { ASSERT_OK_AND_ASSIGN(auto list, MakeList()); auto compre_step = CreateDirectComprehensionStep( - 0, 1, + 0, 0, 1, /*range_step=*/CreateConstValueDirectStep(std::move(list)), /*accu_init=*/CreateConstValueDirectStep(BoolValue(false)), /*loop_step=*/std::move(loop_step), @@ -521,7 +521,7 @@ TEST_F(DirectComprehensionTest, Exhaustive) { ASSERT_OK_AND_ASSIGN(auto list, MakeList()); auto compre_step = CreateDirectComprehensionStep( - 0, 1, + 0, 0, 1, /*range_step=*/CreateConstValueDirectStep(std::move(list)), /*accu_init=*/CreateConstValueDirectStep(BoolValue(false)), /*loop_step=*/std::move(loop_step), diff --git a/eval/eval/create_map_step.cc b/eval/eval/create_map_step.cc index 3d8d86729..f205dd4b0 100644 --- a/eval/eval/create_map_step.cc +++ b/eval/eval/create_map_step.cc @@ -46,6 +46,7 @@ using ::cel::MapValueBuilderPtr; using ::cel::UnknownValue; using ::cel::Value; using ::cel::common_internal::NewMapValueBuilder; +using ::cel::common_internal::NewMutableMapValue; // `CreateStruct` implementation for map. class CreateStructStepForMap final : public ExpressionStepBase { @@ -231,6 +232,30 @@ absl::Status DirectCreateMapStep::Evaluate( return absl::OkStatus(); } +class MutableMapStep final : public ExpressionStep { + public: + explicit MutableMapStep(int64_t expr_id) : ExpressionStep(expr_id) {} + + absl::Status Evaluate(ExecutionFrame* frame) const override { + frame->value_stack().Push(cel::ParsedMapValue( + NewMutableMapValue(frame->memory_manager().arena()))); + return absl::OkStatus(); + } +}; + +class DirectMutableMapStep final : public DirectExpressionStep { + public: + explicit DirectMutableMapStep(int64_t expr_id) + : DirectExpressionStep(expr_id) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& attribute) const override { + result = cel::ParsedMapValue( + NewMutableMapValue(frame.value_manager().GetMemoryManager().arena())); + return absl::OkStatus(); + } +}; + } // namespace std::unique_ptr CreateDirectCreateMapStep( @@ -248,4 +273,14 @@ absl::StatusOr> CreateCreateStructStepForMap( std::move(optional_indices)); } +absl::StatusOr> CreateMutableMapStep( + int64_t expr_id) { + return std::make_unique(expr_id); +} + +std::unique_ptr CreateDirectMutableMapStep( + int64_t expr_id) { + return std::make_unique(expr_id); +} + } // namespace google::api::expr::runtime diff --git a/eval/eval/create_map_step.h b/eval/eval/create_map_step.h index f9be4be0c..cf5e94644 100644 --- a/eval/eval/create_map_step.h +++ b/eval/eval/create_map_step.h @@ -40,6 +40,20 @@ absl::StatusOr> CreateCreateStructStepForMap( size_t entry_count, absl::flat_hash_set optional_indices, int64_t expr_id); +// Factory method for CreateMap which constructs a mutable map. +// +// This is intended for the map construction step is generated for a +// map-building comprehension (rather than a user authored expression). +absl::StatusOr> CreateMutableMapStep( + int64_t expr_id); + +// Factory method for CreateMap which constructs a mutable map. +// +// This is intended for the map construction step is generated for a +// map-building comprehension (rather than a user authored expression). +std::unique_ptr CreateDirectMutableMapStep( + int64_t expr_id); + } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_CREATE_MAP_STEP_H_ diff --git a/extensions/BUILD b/extensions/BUILD index b09ad4905..a4f489f41 100644 --- a/extensions/BUILD +++ b/extensions/BUILD @@ -436,3 +436,88 @@ cc_test( "@com_google_protobuf//:protobuf", ], ) + +cc_library( + name = "comprehensions_v2_functions", + srcs = ["comprehensions_v2_functions.cc"], + hdrs = ["comprehensions_v2_functions.h"], + deps = [ + "//common:value", + "//eval/public:cel_function_registry", + "//eval/public:cel_options", + "//internal:status_macros", + "//runtime:function_adapter", + "//runtime:function_registry", + "//runtime:runtime_options", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_test( + name = "comprehensions_v2_functions_test", + srcs = ["comprehensions_v2_functions_test.cc"], + deps = [ + ":bindings_ext", + ":comprehensions_v2_functions", + ":comprehensions_v2_macros", + ":strings", + "//common:source", + "//common:value_testing", + "//extensions/protobuf:runtime_adapter", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//parser", + "//parser:macro_registry", + "//parser:options", + "//parser:standard_macros", + "//runtime", + "//runtime:activation", + "//runtime:optional_types", + "//runtime:reference_resolver", + "//runtime:runtime_options", + "//runtime:standard_runtime_builder_factory", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "comprehensions_v2_macros", + srcs = ["comprehensions_v2_macros.cc"], + hdrs = ["comprehensions_v2_macros.h"], + deps = [ + "//common:expr", + "//common:operators", + "//internal:status_macros", + "//parser:macro", + "//parser:macro_expr_factory", + "//parser:macro_registry", + "//parser:options", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "comprehensions_v2_macros_test", + srcs = ["comprehensions_v2_macros_test.cc"], + deps = [ + ":comprehensions_v2_macros", + "//common:source", + "//internal:testing", + "//parser", + "//parser:macro_registry", + "//parser:options", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + ], +) diff --git a/extensions/comprehensions_v2_functions.cc b/extensions/comprehensions_v2_functions.cc new file mode 100644 index 000000000..4202eef8d --- /dev/null +++ b/extensions/comprehensions_v2_functions.cc @@ -0,0 +1,85 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/comprehensions_v2_functions.h" + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "common/value.h" +#include "common/values/map_value_builder.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_options.h" +#include "internal/status_macros.h" +#include "runtime/function_adapter.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" + +namespace cel::extensions { + +namespace { + +absl::StatusOr MapInsert(ValueManager& value_manager, + const MapValue& map, const Value& key, + const Value& value) { + if (auto mutable_map_value = common_internal::AsMutableMapValue(map); + mutable_map_value) { + // Fast path, runtime has given us a mutable map. We can mutate it directly + // and return it. + CEL_RETURN_IF_ERROR(mutable_map_value->Put(key, value)) + .With(ErrorValueReturn()); + return map; + } + // Slow path, we have to make a copy. + auto builder = common_internal::NewMapValueBuilder( + value_manager.GetMemoryManager().arena()); + if (auto size = map.Size(); size.ok()) { + builder->Reserve(*size + 1); + } else { + size.IgnoreError(); + } + CEL_RETURN_IF_ERROR( + map.ForEach(value_manager, + [&builder](const Value& key, + const Value& value) -> absl::StatusOr { + CEL_RETURN_IF_ERROR(builder->Put(key, value)); + return true; + })) + .With(ErrorValueReturn()); + CEL_RETURN_IF_ERROR(builder->Put(key, value)).With(ErrorValueReturn()); + return std::move(*builder).Build(); +} + +} // namespace + +absl::Status RegisterComprehensionsV2Functions(FunctionRegistry& registry, + const RuntimeOptions& options) { + CEL_RETURN_IF_ERROR(registry.Register( + VariadicFunctionAdapter, MapValue, Value, Value>:: + CreateDescriptor("cel.@mapInsert", /*receiver_style=*/false), + VariadicFunctionAdapter, MapValue, Value, + Value>::WrapFunction(&MapInsert))); + return absl::OkStatus(); +} + +absl::Status RegisterComprehensionsV2Functions( + google::api::expr::runtime::CelFunctionRegistry* registry, + const google::api::expr::runtime::InterpreterOptions& options) { + return RegisterComprehensionsV2Functions( + registry->InternalGetRegistry(), + google::api::expr::runtime::ConvertToRuntimeOptions(options)); +} + +} // namespace cel::extensions diff --git a/extensions/comprehensions_v2_functions.h b/extensions/comprehensions_v2_functions.h new file mode 100644 index 000000000..8f99780a2 --- /dev/null +++ b/extensions/comprehensions_v2_functions.h @@ -0,0 +1,35 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_COMPREHENSIONS_V2_FUNCTIONS_H_ +#define THIRD_PARTY_CEL_CPP_EXTENSIONS_COMPREHENSIONS_V2_FUNCTIONS_H_ + +#include "absl/status/status.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_options.h" +#include "runtime/function_registry.h" +#include "runtime/runtime_options.h" + +namespace cel::extensions { + +// Register comprehension v2 functions. +absl::Status RegisterComprehensionsV2Functions(FunctionRegistry& registry, + const RuntimeOptions& options); +absl::Status RegisterComprehensionsV2Functions( + google::api::expr::runtime::CelFunctionRegistry* registry, + const google::api::expr::runtime::InterpreterOptions& options); + +} // namespace cel::extensions + +#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_COMPREHENSIONS_V2_FUNCTIONS_H_ diff --git a/extensions/comprehensions_v2_functions_test.cc b/extensions/comprehensions_v2_functions_test.cc new file mode 100644 index 000000000..bc310fe2a --- /dev/null +++ b/extensions/comprehensions_v2_functions_test.cc @@ -0,0 +1,222 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/comprehensions_v2_functions.h" + +#include +#include +#include + +#include "cel/expr/syntax.pb.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "common/source.h" +#include "common/value_testing.h" +#include "extensions/bindings_ext.h" +#include "extensions/comprehensions_v2_macros.h" +#include "extensions/protobuf/runtime_adapter.h" +#include "extensions/strings.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "parser/macro_registry.h" +#include "parser/options.h" +#include "parser/parser.h" +#include "parser/standard_macros.h" +#include "runtime/activation.h" +#include "runtime/optional_types.h" +#include "runtime/reference_resolver.h" +#include "runtime/runtime.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_runtime_builder_factory.h" +#include "google/protobuf/arena.h" + +namespace cel::extensions { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::IsOkAndHolds; +using ::cel::test::BoolValueIs; +using ::google::api::expr::parser::EnrichedParse; +using ::testing::TestWithParam; + +struct ComprehensionsV2FunctionsTestCase { + std::string expression; +}; + +class ComprehensionsV2FunctionsTest + : public TestWithParam { + public: + void SetUp() override { + RuntimeOptions options; + options.enable_qualified_type_identifiers = true; + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), options)); + ASSERT_THAT(RegisterStringsFunctions(builder.function_registry(), options), + IsOk()); + ASSERT_THAT( + RegisterComprehensionsV2Functions(builder.function_registry(), options), + IsOk()); + ASSERT_THAT(EnableOptionalTypes(builder), IsOk()); + ASSERT_THAT( + EnableReferenceResolver(builder, ReferenceResolverEnabled::kAlways), + IsOk()); + ASSERT_OK_AND_ASSIGN(runtime_, std::move(builder).Build()); + } + + absl::StatusOr Parse(absl::string_view text) { + CEL_ASSIGN_OR_RETURN(auto source, NewSource(text)); + + ParserOptions options; + options.enable_optional_syntax = true; + + MacroRegistry registry; + CEL_RETURN_IF_ERROR(RegisterStandardMacros(registry, options)); + CEL_RETURN_IF_ERROR(RegisterComprehensionsV2Macros(registry, options)); + CEL_RETURN_IF_ERROR(RegisterBindingsMacros(registry, options)); + + CEL_ASSIGN_OR_RETURN(auto result, + EnrichedParse(*source, registry, options)); + return result.parsed_expr(); + } + + protected: + std::unique_ptr runtime_; +}; + +TEST_P(ComprehensionsV2FunctionsTest, Basic) { + ASSERT_OK_AND_ASSIGN(auto ast, Parse(GetParam().expression)); + ASSERT_OK_AND_ASSIGN(auto program, + ProtobufRuntimeAdapter::CreateProgram(*runtime_, ast)); + google::protobuf::Arena arena; + Activation activation; + EXPECT_THAT(program->Evaluate(&arena, activation), + IsOkAndHolds(BoolValueIs(true))) + << GetParam().expression; +} + +INSTANTIATE_TEST_SUITE_P( + ComprehensionsV2FunctionsTest, ComprehensionsV2FunctionsTest, + ::testing::ValuesIn({ + // list.all() + {.expression = "[1, 2, 3, 4].all(i, v, i < 5 && v > 0)"}, + {.expression = "[1, 2, 3, 4].all(i, v, i < v)"}, + {.expression = "[1, 2, 3, 4].all(i, v, i > v) == false"}, + { + .expression = + R"cel(cel.bind(listA, [1, 2, 3, 4], cel.bind(listB, [1, 2, 3, 4, 5], listA.all(i, v, listB[?i].hasValue() && listB[i] == v))))cel", + }, + { + .expression = + R"cel(cel.bind(listA, [1, 2, 3, 4, 5, 6], cel.bind(listB, [1, 2, 3, 4, 5], listA.all(i, v, listB[?i].hasValue() && listB[i] == v))) == false)cel", + }, + // list.exists() + { + .expression = + R"cel(cel.bind(l, ['hello', 'world', 'hello!', 'worlds'], l.exists(i, v, v.startsWith('hello') && l[?(i+1)].optMap(next, next.endsWith('world')).orValue(false))))cel", + }, + // list.existsOne() + { + .expression = + R"cel(cel.bind(l, ['hello', 'world', 'hello!', 'worlds'], l.existsOne(i, v, v.startsWith('hello') && l[?(i+1)].optMap(next, next.endsWith('world')).orValue(false))))cel", + }, + { + .expression = + R"cel(cel.bind(l, ['hello', 'goodbye', 'hello!', 'goodbye'], l.existsOne(i, v, v.startsWith('hello') && l[?(i+1)].optMap(next, next == "goodbye").orValue(false))) == false)cel", + }, + // list.transformList() + { + .expression = + R"cel(['Hello', 'world'].transformList(i, v, "[" + string(i) + "]" + v.lowerAscii()) == ["[0]hello", "[1]world"])cel", + }, + { + .expression = + R"cel(['hello', 'world'].transformList(i, v, v.startsWith('greeting'), "[" + string(i) + "]" + v) == [])cel", + }, + { + .expression = + R"cel([1, 2, 3].transformList(indexVar, valueVar, (indexVar * valueVar) + valueVar) == [1, 4, 9])cel", + }, + { + .expression = + R"cel([1, 2, 3].transformList(indexVar, valueVar, indexVar % 2 == 0, (indexVar * valueVar) + valueVar) == [1, 9])cel", + }, + // map.transformMap() + { + .expression = + R"cel(['Hello', 'world'].transformMap(i, v, [v.lowerAscii()]) == {0: ['hello'], 1: ['world']})cel", + }, + { + .expression = + R"cel([1, 2, 3].transformMap(indexVar, valueVar, (indexVar * valueVar) + valueVar) == {0: 1, 1: 4, 2: 9})cel", + }, + { + .expression = + R"cel([1, 2, 3].transformMap(indexVar, valueVar, indexVar % 2 == 0, (indexVar * valueVar) + valueVar) == {0: 1, 2: 9})cel", + }, + // map.all() + { + .expression = + R"cel({'hello': 'world', 'hello!': 'world'}.all(k, v, k.startsWith('hello') && v == 'world'))cel", + }, + { + .expression = + R"cel({'hello': 'world', 'hello!': 'worlds'}.all(k, v, k.startsWith('hello') && v.endsWith('world')) == false)cel", + }, + // map.exists() + { + .expression = + R"cel({'hello': 'world', 'hello!': 'worlds'}.exists(k, v, k.startsWith('hello') && v.endsWith('world')))cel", + }, + // map.existsOne() + { + .expression = + R"cel({'hello': 'world', 'hello!': 'worlds'}.existsOne(k, v, k.startsWith('hello') && v.endsWith('world')))cel", + }, + { + .expression = + R"cel({'hello': 'world', 'hello!': 'wow, world'}.existsOne(k, v, k.startsWith('hello') && v.endsWith('world')) == false)cel", + }, + // map.transformList() + { + .expression = + R"cel({'Hello': 'world'}.transformList(k, v, k.lowerAscii() + "=" + v) == ["hello=world"])cel", + }, + { + .expression = + R"cel({'hello': 'world'}.transformList(k, v, k.startsWith('greeting'), k + "=" + v) == [])cel", + }, + { + .expression = + R"cel(cel.bind(m, {'farewell': 'goodbye', 'greeting': 'hello'}.transformList(k, _, k), m == ['farewell', 'greeting'] || m == ['greeting', 'farewell']))cel", + }, + { + .expression = + R"cel(cel.bind(m, {'greeting': 'hello', 'farewell': 'goodbye'}.transformList(_, v, v), m == ['goodbye', 'hello'] || m == ['hello', 'goodbye']))cel", + }, + // map.transformMap() + { + .expression = + R"cel({'hello': 'world', 'goodbye': 'cruel world'}.transformMap(k, v, k + ", " + v + "!") == {'hello': 'hello, world!', 'goodbye': 'goodbye, cruel world!'})cel", + }, + { + .expression = + R"cel({'hello': 'world', 'goodbye': 'cruel world'}.transformMap(k, v, v.startsWith('world'), k + ", " + v + "!") == {'hello': 'hello, world!'})cel", + }, + })); + +} // namespace +} // namespace cel::extensions diff --git a/extensions/comprehensions_v2_macros.cc b/extensions/comprehensions_v2_macros.cc new file mode 100644 index 000000000..04793ad39 --- /dev/null +++ b/extensions/comprehensions_v2_macros.cc @@ -0,0 +1,488 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/comprehensions_v2_macros.h" + +#include + +#include "absl/base/no_destructor.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "common/expr.h" +#include "common/operators.h" +#include "internal/status_macros.h" +#include "parser/macro.h" +#include "parser/macro_expr_factory.h" +#include "parser/macro_registry.h" +#include "parser/options.h" + +namespace cel::extensions { + +namespace { + +using ::google::api::expr::common::CelOperator; + +absl::optional ExpandAllMacro2(MacroExprFactory& factory, Expr& target, + absl::Span args) { + if (args.size() != 3) { + return factory.ReportError("all() requires 3 arguments"); + } + if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + return factory.ReportErrorAt( + args[0], "all() first variable name must be a simple identifier"); + } + if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) { + return factory.ReportErrorAt( + args[1], "all() second variable name must be a simple identifier"); + } + if (args[0].ident_expr().name() == args[1].ident_expr().name()) { + return factory.ReportErrorAt( + args[0], + "all() second variable must be different from the first variable"); + } + if (args[0].ident_expr().name() == kAccumulatorVariableName) { + return factory.ReportErrorAt( + args[0], absl::StrCat("all() first variable name cannot be ", + kAccumulatorVariableName)); + } + if (args[1].ident_expr().name() == kAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("all() second variable name cannot be ", + kAccumulatorVariableName)); + } + auto init = factory.NewBoolConst(true); + auto condition = + factory.NewCall(CelOperator::NOT_STRICTLY_FALSE, factory.NewAccuIdent()); + auto step = factory.NewCall(CelOperator::LOGICAL_AND, factory.NewAccuIdent(), + std::move(args[2])); + auto result = factory.NewAccuIdent(); + return factory.NewComprehension( + args[0].ident_expr().name(), args[1].ident_expr().name(), + std::move(target), kAccumulatorVariableName, std::move(init), + std::move(condition), std::move(step), std::move(result)); +} + +Macro MakeAllMacro2() { + auto status_or_macro = Macro::Receiver(CelOperator::ALL, 3, ExpandAllMacro2); + ABSL_CHECK_OK(status_or_macro); // Crash OK + return std::move(*status_or_macro); +} + +absl::optional ExpandExistsMacro2(MacroExprFactory& factory, Expr& target, + absl::Span args) { + if (args.size() != 3) { + return factory.ReportError("exists() requires 3 arguments"); + } + if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + return factory.ReportErrorAt( + args[0], "exists() first variable name must be a simple identifier"); + } + if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) { + return factory.ReportErrorAt( + args[1], "exists() second variable name must be a simple identifier"); + } + if (args[0].ident_expr().name() == args[1].ident_expr().name()) { + return factory.ReportErrorAt( + args[0], + "exists() second variable must be different from the first variable"); + } + if (args[0].ident_expr().name() == kAccumulatorVariableName) { + return factory.ReportErrorAt( + args[0], absl::StrCat("exists() first variable name cannot be ", + kAccumulatorVariableName)); + } + if (args[1].ident_expr().name() == kAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("exists() second variable name cannot be ", + kAccumulatorVariableName)); + } + auto init = factory.NewBoolConst(false); + auto condition = factory.NewCall( + CelOperator::NOT_STRICTLY_FALSE, + factory.NewCall(CelOperator::LOGICAL_NOT, factory.NewAccuIdent())); + auto step = factory.NewCall(CelOperator::LOGICAL_OR, factory.NewAccuIdent(), + std::move(args[2])); + auto result = factory.NewAccuIdent(); + return factory.NewComprehension( + args[0].ident_expr().name(), args[1].ident_expr().name(), + std::move(target), kAccumulatorVariableName, std::move(init), + std::move(condition), std::move(step), std::move(result)); +} + +Macro MakeExistsMacro2() { + auto status_or_macro = + Macro::Receiver(CelOperator::EXISTS, 3, ExpandExistsMacro2); + ABSL_CHECK_OK(status_or_macro); // Crash OK + return std::move(*status_or_macro); +} + +absl::optional ExpandExistsOneMacro2(MacroExprFactory& factory, + Expr& target, + absl::Span args) { + if (args.size() != 3) { + return factory.ReportError("existsOne() requires 3 arguments"); + } + if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + return factory.ReportErrorAt( + args[0], "existsOne() first variable name must be a simple identifier"); + } + if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) { + return factory.ReportErrorAt( + args[1], + "existsOne() second variable name must be a simple identifier"); + } + if (args[0].ident_expr().name() == args[1].ident_expr().name()) { + return factory.ReportErrorAt( + args[0], + "existsOne() second variable must be different " + "from the first variable"); + } + if (args[0].ident_expr().name() == kAccumulatorVariableName) { + return factory.ReportErrorAt( + args[0], absl::StrCat("existsOne() first variable name cannot be ", + kAccumulatorVariableName)); + } + if (args[1].ident_expr().name() == kAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("existsOne() second variable name cannot be ", + kAccumulatorVariableName)); + } + auto init = factory.NewIntConst(0); + auto condition = factory.NewBoolConst(true); + auto step = + factory.NewCall(CelOperator::CONDITIONAL, std::move(args[2]), + factory.NewCall(CelOperator::ADD, factory.NewAccuIdent(), + factory.NewIntConst(1)), + factory.NewAccuIdent()); + auto result = factory.NewCall(CelOperator::EQUALS, factory.NewAccuIdent(), + factory.NewIntConst(1)); + return factory.NewComprehension( + args[0].ident_expr().name(), args[1].ident_expr().name(), + std::move(target), kAccumulatorVariableName, std::move(init), + std::move(condition), std::move(step), std::move(result)); +} + +Macro MakeExistsOneMacro2() { + auto status_or_macro = Macro::Receiver("existsOne", 3, ExpandExistsOneMacro2); + ABSL_CHECK_OK(status_or_macro); // Crash OK + return std::move(*status_or_macro); +} + +absl::optional ExpandFilterMacro2(MacroExprFactory& factory, Expr& target, + absl::Span args) { + if (args.size() != 3) { + return factory.ReportError("filter() requires 3 arguments"); + } + if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + return factory.ReportErrorAt( + args[0], "filter() first variable name must be a simple identifier"); + } + if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) { + return factory.ReportErrorAt( + args[1], "filter() second variable name must be a simple identifier"); + } + if (args[0].ident_expr().name() == args[1].ident_expr().name()) { + return factory.ReportErrorAt( + args[0], + "filter() second variable must be different from the first variable"); + } + if (args[0].ident_expr().name() == kAccumulatorVariableName) { + return factory.ReportErrorAt( + args[0], absl::StrCat("filter() first variable name cannot be ", + kAccumulatorVariableName)); + } + if (args[1].ident_expr().name() == kAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("filter() second variable name cannot be ", + kAccumulatorVariableName)); + } + auto name = args[0].ident_expr().name(); + auto name2 = args[1].ident_expr().name(); + auto init = factory.NewList(); + auto condition = factory.NewBoolConst(true); + auto step = factory.NewCall( + CelOperator::ADD, factory.NewAccuIdent(), + factory.NewList(factory.NewListElement(std::move(args[1])))); + step = factory.NewCall(CelOperator::CONDITIONAL, std::move(args[2]), + std::move(step), factory.NewAccuIdent()); + return factory.NewComprehension( + name, name2, std::move(target), kAccumulatorVariableName, std::move(init), + std::move(condition), std::move(step), factory.NewAccuIdent()); +} + +Macro MakeFilterMacro2() { + auto status_or_macro = + Macro::Receiver(CelOperator::FILTER, 3, ExpandFilterMacro2); + ABSL_CHECK_OK(status_or_macro); // Crash OK + return std::move(*status_or_macro); +} + +absl::optional ExpandTransformList3Macro(MacroExprFactory& factory, + Expr& target, + absl::Span args) { + if (args.size() != 3) { + return factory.ReportError("transformList() requires 3 arguments"); + } + if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + return factory.ReportErrorAt( + args[0], + "transformList() first variable name must be a simple identifier"); + } + if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) { + return factory.ReportErrorAt( + args[1], + "transformList() second variable name must be a simple identifier"); + } + if (args[0].ident_expr().name() == args[1].ident_expr().name()) { + return factory.ReportErrorAt(args[0], + "transformList() second variable must be " + "different from the first variable"); + } + if (args[0].ident_expr().name() == kAccumulatorVariableName) { + return factory.ReportErrorAt( + args[0], absl::StrCat("transformList() first variable name cannot be ", + kAccumulatorVariableName)); + } + if (args[1].ident_expr().name() == kAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("transformList() second variable name cannot be ", + kAccumulatorVariableName)); + } + auto iter_var = args[0].ident_expr().name(); + auto iter_var2 = args[1].ident_expr().name(); + auto step = factory.NewCall( + CelOperator::ADD, factory.NewAccuIdent(), + factory.NewList(factory.NewListElement(std::move(args[2])))); + return factory.NewComprehension(std::move(iter_var), std::move(iter_var2), + std::move(target), kAccumulatorVariableName, + factory.NewList(), factory.NewBoolConst(true), + std::move(step), factory.NewAccuIdent()); +} + +Macro MakeTransformList3Macro() { + auto status_or_macro = + Macro::Receiver("transformList", 3, ExpandTransformList3Macro); + ABSL_CHECK_OK(status_or_macro); // Crash OK + return std::move(*status_or_macro); +} + +absl::optional ExpandTransformList4Macro(MacroExprFactory& factory, + Expr& target, + absl::Span args) { + if (args.size() != 4) { + return factory.ReportError("transformList() requires 4 arguments"); + } + if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + return factory.ReportErrorAt( + args[0], + "transformList() first variable name must be a simple identifier"); + } + if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) { + return factory.ReportErrorAt( + args[1], + "transformList() second variable name must be a simple identifier"); + } + if (args[0].ident_expr().name() == args[1].ident_expr().name()) { + return factory.ReportErrorAt(args[0], + "transformList() second variable must be " + "different from the first variable"); + } + if (args[0].ident_expr().name() == kAccumulatorVariableName) { + return factory.ReportErrorAt( + args[0], absl::StrCat("transformList() first variable name cannot be ", + kAccumulatorVariableName)); + } + if (args[1].ident_expr().name() == kAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("transformList() second variable name cannot be ", + kAccumulatorVariableName)); + } + auto iter_var = args[0].ident_expr().name(); + auto iter_var2 = args[1].ident_expr().name(); + auto step = factory.NewCall( + CelOperator::ADD, factory.NewAccuIdent(), + factory.NewList(factory.NewListElement(std::move(args[3])))); + step = factory.NewCall(CelOperator::CONDITIONAL, std::move(args[2]), + std::move(step), factory.NewAccuIdent()); + return factory.NewComprehension(std::move(iter_var), std::move(iter_var2), + std::move(target), kAccumulatorVariableName, + factory.NewList(), factory.NewBoolConst(true), + std::move(step), factory.NewAccuIdent()); +} + +Macro MakeTransformList4Macro() { + auto status_or_macro = + Macro::Receiver("transformList", 4, ExpandTransformList4Macro); + ABSL_CHECK_OK(status_or_macro); // Crash OK + return std::move(*status_or_macro); +} + +absl::optional ExpandTransformMap3Macro(MacroExprFactory& factory, + Expr& target, + absl::Span args) { + if (args.size() != 3) { + return factory.ReportError("transformMap() requires 3 arguments"); + } + if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + return factory.ReportErrorAt( + args[0], + "transformMap() first variable name must be a simple identifier"); + } + if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) { + return factory.ReportErrorAt( + args[1], + "transformMap() second variable name must be a simple identifier"); + } + if (args[0].ident_expr().name() == args[1].ident_expr().name()) { + return factory.ReportErrorAt(args[0], + "transformMap() second variable must be " + "different from the first variable"); + } + if (args[0].ident_expr().name() == kAccumulatorVariableName) { + return factory.ReportErrorAt( + args[0], absl::StrCat("transforMap() first variable name cannot be ", + kAccumulatorVariableName)); + } + if (args[1].ident_expr().name() == kAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("transformMap() second variable name cannot be ", + kAccumulatorVariableName)); + } + auto iter_var = args[0].ident_expr().name(); + auto iter_var2 = args[1].ident_expr().name(); + auto step = factory.NewCall("cel.@mapInsert", factory.NewAccuIdent(), + std::move(args[0]), std::move(args[2])); + return factory.NewComprehension(std::move(iter_var), std::move(iter_var2), + std::move(target), kAccumulatorVariableName, + factory.NewMap(), factory.NewBoolConst(true), + std::move(step), factory.NewAccuIdent()); +} + +Macro MakeTransformMap3Macro() { + auto status_or_macro = + Macro::Receiver("transformMap", 3, ExpandTransformMap3Macro); + ABSL_CHECK_OK(status_or_macro); // Crash OK + return std::move(*status_or_macro); +} + +absl::optional ExpandTransformMap4Macro(MacroExprFactory& factory, + Expr& target, + absl::Span args) { + if (args.size() != 4) { + return factory.ReportError("transformMap() requires 4 arguments"); + } + if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + return factory.ReportErrorAt( + args[0], + "transformMap() first variable name must be a simple identifier"); + } + if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) { + return factory.ReportErrorAt( + args[1], + "transformMap() second variable name must be a simple identifier"); + } + if (args[0].ident_expr().name() == args[1].ident_expr().name()) { + return factory.ReportErrorAt(args[0], + "transformMap() second variable must be " + "different from the first variable"); + } + if (args[0].ident_expr().name() == kAccumulatorVariableName) { + return factory.ReportErrorAt( + args[0], absl::StrCat("transformMap() first variable name cannot be ", + kAccumulatorVariableName)); + } + if (args[1].ident_expr().name() == kAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("transformMap() second variable name cannot be ", + kAccumulatorVariableName)); + } + auto iter_var = args[0].ident_expr().name(); + auto iter_var2 = args[1].ident_expr().name(); + auto step = factory.NewCall("cel.@mapInsert", factory.NewAccuIdent(), + std::move(args[0]), std::move(args[3])); + step = factory.NewCall(CelOperator::CONDITIONAL, std::move(args[2]), + std::move(step), factory.NewAccuIdent()); + return factory.NewComprehension(std::move(iter_var), std::move(iter_var2), + std::move(target), kAccumulatorVariableName, + factory.NewMap(), factory.NewBoolConst(true), + std::move(step), factory.NewAccuIdent()); +} + +Macro MakeTransformMap4Macro() { + auto status_or_macro = + Macro::Receiver("transformMap", 4, ExpandTransformMap4Macro); + ABSL_CHECK_OK(status_or_macro); // Crash OK + return std::move(*status_or_macro); +} + +const Macro& AllMacro2() { + static const absl::NoDestructor macro(MakeAllMacro2()); + return *macro; +} + +const Macro& ExistsMacro2() { + static const absl::NoDestructor macro(MakeExistsMacro2()); + return *macro; +} + +const Macro& ExistsOneMacro2() { + static const absl::NoDestructor macro(MakeExistsOneMacro2()); + return *macro; +} + +const Macro& FilterMacro2() { + static const absl::NoDestructor macro(MakeFilterMacro2()); + return *macro; +} + +const Macro& TransformList3Macro() { + static const absl::NoDestructor macro(MakeTransformList3Macro()); + return *macro; +} + +const Macro& TransformList4Macro() { + static const absl::NoDestructor macro(MakeTransformList4Macro()); + return *macro; +} + +const Macro& TransformMap3Macro() { + static const absl::NoDestructor macro(MakeTransformMap3Macro()); + return *macro; +} + +const Macro& TransformMap4Macro() { + static const absl::NoDestructor macro(MakeTransformMap4Macro()); + return *macro; +} + +} // namespace + +// Registers the macros defined by the comprehension v2 extension. +absl::Status RegisterComprehensionsV2Macros(MacroRegistry& registry, + const ParserOptions&) { + CEL_RETURN_IF_ERROR(registry.RegisterMacro(AllMacro2())); + CEL_RETURN_IF_ERROR(registry.RegisterMacro(ExistsMacro2())); + CEL_RETURN_IF_ERROR(registry.RegisterMacro(ExistsOneMacro2())); + CEL_RETURN_IF_ERROR(registry.RegisterMacro(FilterMacro2())); + CEL_RETURN_IF_ERROR(registry.RegisterMacro(TransformList3Macro())); + CEL_RETURN_IF_ERROR(registry.RegisterMacro(TransformList4Macro())); + CEL_RETURN_IF_ERROR(registry.RegisterMacro(TransformMap3Macro())); + CEL_RETURN_IF_ERROR(registry.RegisterMacro(TransformMap4Macro())); + return absl::OkStatus(); +} + +} // namespace cel::extensions diff --git a/extensions/comprehensions_v2_macros.h b/extensions/comprehensions_v2_macros.h new file mode 100644 index 000000000..3b2bfd577 --- /dev/null +++ b/extensions/comprehensions_v2_macros.h @@ -0,0 +1,30 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_COMPREHENSIONS_V2_MACROS_H_ +#define THIRD_PARTY_CEL_CPP_EXTENSIONS_COMPREHENSIONS_V2_MACROS_H_ + +#include "absl/status/status.h" +#include "parser/macro_registry.h" +#include "parser/options.h" + +namespace cel::extensions { + +// Registers the macros defined by the comprehension v2 extension. +absl::Status RegisterComprehensionsV2Macros(MacroRegistry& registry, + const ParserOptions& options); + +} // namespace cel::extensions + +#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_COMPREHENSIONS_V2_MACROS_H_ diff --git a/extensions/comprehensions_v2_macros_test.cc b/extensions/comprehensions_v2_macros_test.cc new file mode 100644 index 000000000..4fa07123f --- /dev/null +++ b/extensions/comprehensions_v2_macros_test.cc @@ -0,0 +1,230 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/comprehensions_v2_macros.h" + +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "common/source.h" +#include "internal/testing.h" +#include "parser/macro_registry.h" +#include "parser/options.h" +#include "parser/parser.h" + +namespace cel::extensions { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::google::api::expr::parser::EnrichedParse; +using ::testing::HasSubstr; + +struct ComprehensionsV2MacrosTestCase { + std::string expression; + std::string error; +}; + +using ComprehensionsV2MacrosTest = + ::testing::TestWithParam; + +TEST_P(ComprehensionsV2MacrosTest, Basic) { + const auto& test_param = GetParam(); + ASSERT_OK_AND_ASSIGN(auto source, NewSource(test_param.expression)); + + MacroRegistry registry; + ASSERT_THAT(RegisterComprehensionsV2Macros(registry, ParserOptions()), + IsOk()); + + EXPECT_THAT(EnrichedParse(*source, registry), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr(test_param.error))); +} + +INSTANTIATE_TEST_SUITE_P( + ComprehensionsV2MacrosTest, ComprehensionsV2MacrosTest, + ::testing::ValuesIn({ + { + .expression = "[].all(__result__, v, v == 0)", + .error = "variable name cannot be __result__", + }, + { + .expression = "[].all(i, __result__, i == 0)", + .error = "variable name cannot be __result__", + }, + { + .expression = "[].all(e, e, e == e)", + .error = + "second variable must be different from the first variable", + }, + { + .expression = "[].all(foo.bar, e, true)", + .error = "first variable name must be a simple identifier", + }, + { + .expression = "[].all(e, foo.bar, true)", + .error = "second variable name must be a simple identifier", + }, + { + .expression = "[].exists(__result__, v, v == 0)", + .error = "variable name cannot be __result__", + }, + { + .expression = "[].exists(i, __result__, i == 0)", + .error = "variable name cannot be __result__", + }, + { + .expression = "[].exists(e, e, e == e)", + .error = + "second variable must be different from the first variable", + }, + { + .expression = "[].exists(foo.bar, e, true)", + .error = "first variable name must be a simple identifier", + }, + { + .expression = "[].exists(e, foo.bar, true)", + .error = "second variable name must be a simple identifier", + }, + { + .expression = "[].existsOne(__result__, v, v == 0)", + .error = "variable name cannot be __result__", + }, + { + .expression = "[].existsOne(i, __result__, i == 0)", + .error = "variable name cannot be __result__", + }, + { + .expression = "[].existsOne(e, e, e == e)", + .error = + "second variable must be different from the first variable", + }, + { + .expression = "[].existsOne(foo.bar, e, true)", + .error = "first variable name must be a simple identifier", + }, + { + .expression = "[].existsOne(e, foo.bar, true)", + .error = "second variable name must be a simple identifier", + }, + { + .expression = "[].filter(__result__, v, v == 0)", + .error = "variable name cannot be __result__", + }, + { + .expression = "[].filter(i, __result__, i == 0)", + .error = "variable name cannot be __result__", + }, + { + .expression = "[].filter(e, e, e == e)", + .error = + "second variable must be different from the first variable", + }, + { + .expression = "[].filter(foo.bar, e, true)", + .error = "first variable name must be a simple identifier", + }, + { + .expression = "[].filter(e, foo.bar, true)", + .error = "second variable name must be a simple identifier", + }, + { + .expression = "[].transformList(__result__, v, v)", + .error = "variable name cannot be __result__", + }, + { + .expression = "[].transformList(i, __result__, v)", + .error = "variable name cannot be __result__", + }, + { + .expression = "[].transformList(e, e, e)", + .error = + "second variable must be different from the first variable", + }, + { + .expression = "[].transformList(foo.bar, e, e)", + .error = "first variable name must be a simple identifier", + }, + { + .expression = "[].transformList(e, foo.bar, e)", + .error = "second variable name must be a simple identifier", + }, + { + .expression = "[].transformList(__result__, v, v == 0, v)", + .error = "variable name cannot be __result__", + }, + { + .expression = "[].transformList(i, __result__, i == 0, v)", + .error = "variable name cannot be __result__", + }, + { + .expression = "[].transformList(e, e, e == e, e)", + .error = + "second variable must be different from the first variable", + }, + { + .expression = "[].transformList(foo.bar, e, true, e)", + .error = "first variable name must be a simple identifier", + }, + { + .expression = "[].transformList(e, foo.bar, true, e)", + .error = "second variable name must be a simple identifier", + }, + { + .expression = "{}.transformMap(__result__, v, v)", + .error = "variable name cannot be __result__", + }, + { + .expression = "{}.transformMap(k, __result__, v)", + .error = "variable name cannot be __result__", + }, + { + .expression = "{}.transformMap(e, e, e)", + .error = + "second variable must be different from the first variable", + }, + { + .expression = "{}.transformMap(foo.bar, e, e)", + .error = "first variable name must be a simple identifier", + }, + { + .expression = "{}.transformMap(e, foo.bar, e)", + .error = "second variable name must be a simple identifier", + }, + { + .expression = "{}.transformMap(__result__, v, v == 0, v)", + .error = "variable name cannot be __result__", + }, + { + .expression = "{}.transformMap(k, __result__, k == 0, v)", + .error = "variable name cannot be __result__", + }, + { + .expression = "{}.transformMap(e, e, e == e, e)", + .error = + "second variable must be different from the first variable", + }, + { + .expression = "{}.transformMap(foo.bar, e, true, e)", + .error = "first variable name must be a simple identifier", + }, + { + .expression = "{}.transformMap(e, foo.bar, true, e)", + .error = "second variable name must be a simple identifier", + }, + })); + +} // namespace +} // namespace cel::extensions diff --git a/parser/macro_expr_factory.h b/parser/macro_expr_factory.h index e84e8be7a..291bccdb0 100644 --- a/parser/macro_expr_factory.h +++ b/parser/macro_expr_factory.h @@ -255,6 +255,27 @@ class MacroExprFactory : protected ExprFactory { std::move(loop_step), std::move(result)); } + template ::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr NewComprehension( + IterVar iter_var, IterVar2 iter_var2, IterRange iter_range, + AccuVar accu_var, AccuInit accu_init, LoopCondition loop_condition, + LoopStep loop_step, Result result) { + return NewComprehension(NextId(), std::move(iter_var), std::move(iter_var2), + std::move(iter_range), std::move(accu_var), + std::move(accu_init), std::move(loop_condition), + std::move(loop_step), std::move(result)); + } + ABSL_MUST_USE_RESULT virtual Expr ReportError(absl::string_view message) = 0; ABSL_MUST_USE_RESULT virtual Expr ReportErrorAt( From 0dceee5f4c2feec39b7587e74a97739e6d8963ca Mon Sep 17 00:00:00 2001 From: Justin King Date: Tue, 10 Dec 2024 11:12:03 -0800 Subject: [PATCH 057/180] Fix bug when accessing iterator variables when using recursive comprehensions PiperOrigin-RevId: 704778661 --- eval/eval/BUILD | 1 + eval/eval/comprehension_slots.h | 7 ++- eval/eval/comprehension_step.cc | 106 +++++++++++++++++--------------- 3 files changed, 62 insertions(+), 52 deletions(-) diff --git a/eval/eval/BUILD b/eval/eval/BUILD index e46f66084..32c67e673 100644 --- a/eval/eval/BUILD +++ b/eval/eval/BUILD @@ -107,6 +107,7 @@ cc_library( ":attribute_trail", "//common:value", "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/types:optional", ], diff --git a/eval/eval/comprehension_slots.h b/eval/eval/comprehension_slots.h index bfaa1792b..122b1e3f8 100644 --- a/eval/eval/comprehension_slots.h +++ b/eval/eval/comprehension_slots.h @@ -20,6 +20,7 @@ #include #include "absl/base/no_destructor.h" +#include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/types/optional.h" #include "common/value.h" @@ -58,7 +59,7 @@ class ComprehensionSlots { // Return ptr to slot at index. // If not set, returns nullptr. - Slot* Get(size_t index) { + absl::Nullable Get(size_t index) { ABSL_DCHECK_LT(index, slots_.size()); auto& slot = slots_[index]; if (!slot.has_value()) return nullptr; @@ -75,9 +76,9 @@ class ComprehensionSlots { slots_[index] = absl::nullopt; } - void Set(size_t index) { + absl::Nonnull Set(size_t index) { ABSL_DCHECK_LT(index, slots_.size()); - slots_[index].emplace(); + return &slots_[index].emplace(); } void Set(size_t index, cel::Value value) { diff --git a/eval/eval/comprehension_step.cc b/eval/eval/comprehension_step.cc index c34121b09..6c5dc5d78 100644 --- a/eval/eval/comprehension_step.cc +++ b/eval/eval/comprehension_step.cc @@ -291,9 +291,8 @@ absl::Status ComprehensionDirectStep::Evaluate1(ExecutionFrameBase& frame, frame.comprehension_slots().Get(accu_slot_); ABSL_DCHECK(accu_slot != nullptr); - frame.comprehension_slots().Set(iter_slot_); ComprehensionSlots::Slot* iter_slot = - frame.comprehension_slots().Get(iter_slot_); + frame.comprehension_slots().Set(iter_slot_); ABSL_DCHECK(iter_slot != nullptr); Value condition; @@ -303,7 +302,21 @@ absl::Status ComprehensionDirectStep::Evaluate1(ExecutionFrameBase& frame, frame.value_manager(), [&](size_t index, const Value& v) -> absl::StatusOr { CEL_RETURN_IF_ERROR(frame.IncrementIterations()); - // Evaluate loop condition first. + + // Set the iterator variable(s) first, the loop condition has access to + // them. + iter_slot->value = v; + if (frame.unknown_processing_enabled()) { + iter_slot->attribute = + range_attr.Step(CelAttributeQualifier::OfInt(index)); + if (frame.attribute_utility().CheckForUnknownExact( + iter_slot->attribute)) { + iter_slot->value = frame.attribute_utility().CreateUnknownSet( + iter_slot->attribute.attribute()); + } + } + + // Evaluate the loop condition. CEL_RETURN_IF_ERROR( condition_->Evaluate(frame, condition, condition_attr)); @@ -323,17 +336,7 @@ absl::Status ComprehensionDirectStep::Evaluate1(ExecutionFrameBase& frame, return false; } - iter_slot->value = v; - if (frame.unknown_processing_enabled()) { - iter_slot->attribute = - range_attr.Step(CelAttributeQualifier::OfInt(index)); - if (frame.attribute_utility().CheckForUnknownExact( - iter_slot->attribute)) { - iter_slot->value = frame.attribute_utility().CreateUnknownSet( - iter_slot->attribute.attribute()); - } - } - + // Evaluate the loop step. CEL_RETURN_IF_ERROR(loop_step_->Evaluate(frame, accu_slot->value, accu_slot->attribute)); @@ -394,14 +397,12 @@ absl::Status ComprehensionDirectStep::Evaluate2(ExecutionFrameBase& frame, frame.comprehension_slots().Get(accu_slot_); ABSL_DCHECK(accu_slot != nullptr); - frame.comprehension_slots().Set(iter_slot_); ComprehensionSlots::Slot* iter_slot = - frame.comprehension_slots().Get(iter_slot_); + frame.comprehension_slots().Set(iter_slot_); ABSL_DCHECK(iter_slot != nullptr); - frame.comprehension_slots().Set(iter2_slot_); ComprehensionSlots::Slot* iter2_slot = - frame.comprehension_slots().Get(iter2_slot_); + frame.comprehension_slots().Set(iter2_slot_); ABSL_DCHECK(iter2_slot != nullptr); Value condition; @@ -412,26 +413,9 @@ absl::Status ComprehensionDirectStep::Evaluate2(ExecutionFrameBase& frame, frame.value_manager(), [&](const Value& k, const Value& v) -> absl::StatusOr { CEL_RETURN_IF_ERROR(frame.IncrementIterations()); - // Evaluate loop condition first. - CEL_RETURN_IF_ERROR( - condition_->Evaluate(frame, condition, condition_attr)); - - if (condition.kind() == cel::ValueKind::kError || - condition.kind() == cel::ValueKind::kUnknown) { - result = std::move(condition); - should_skip_result = true; - return false; - } - if (condition.kind() != cel::ValueKind::kBool) { - result = frame.value_manager().CreateErrorValue( - CreateNoMatchingOverloadError("")); - should_skip_result = true; - return false; - } - if (shortcircuiting_ && !Cast(condition).NativeValue()) { - return false; - } + // Set the iterator variable(s) first, the loop condition has access + // to them. iter_slot->value = k; if (frame.unknown_processing_enabled()) { iter_slot->attribute = @@ -454,17 +438,7 @@ absl::Status ComprehensionDirectStep::Evaluate2(ExecutionFrameBase& frame, } } - CEL_RETURN_IF_ERROR(loop_step_->Evaluate(frame, accu_slot->value, - accu_slot->attribute)); - - return true; - })); - } else { - CEL_RETURN_IF_ERROR(iter_range_list.ForEach( - frame.value_manager(), - [&](size_t index, const Value& v) -> absl::StatusOr { - CEL_RETURN_IF_ERROR(frame.IncrementIterations()); - // Evaluate loop condition first. + // Evaluate the loop condition. CEL_RETURN_IF_ERROR( condition_->Evaluate(frame, condition, condition_attr)); @@ -484,6 +458,20 @@ absl::Status ComprehensionDirectStep::Evaluate2(ExecutionFrameBase& frame, return false; } + // Evaluate the loop step. + CEL_RETURN_IF_ERROR(loop_step_->Evaluate(frame, accu_slot->value, + accu_slot->attribute)); + + return true; + })); + } else { + CEL_RETURN_IF_ERROR(iter_range_list.ForEach( + frame.value_manager(), + [&](size_t index, const Value& v) -> absl::StatusOr { + CEL_RETURN_IF_ERROR(frame.IncrementIterations()); + + // Set the iterator variable(s) first, the loop condition has access + // to them. iter_slot->value = IntValue(index); if (frame.unknown_processing_enabled()) { iter_slot->attribute = @@ -494,7 +482,6 @@ absl::Status ComprehensionDirectStep::Evaluate2(ExecutionFrameBase& frame, iter_slot->attribute.attribute()); } } - iter2_slot->value = v; if (frame.unknown_processing_enabled()) { iter2_slot->attribute = @@ -506,6 +493,27 @@ absl::Status ComprehensionDirectStep::Evaluate2(ExecutionFrameBase& frame, } } + // Evaluate the loop condition. + CEL_RETURN_IF_ERROR( + condition_->Evaluate(frame, condition, condition_attr)); + + if (condition.kind() == cel::ValueKind::kError || + condition.kind() == cel::ValueKind::kUnknown) { + result = std::move(condition); + should_skip_result = true; + return false; + } + if (condition.kind() != cel::ValueKind::kBool) { + result = frame.value_manager().CreateErrorValue( + CreateNoMatchingOverloadError("")); + should_skip_result = true; + return false; + } + if (shortcircuiting_ && !Cast(condition).NativeValue()) { + return false; + } + + // Evaluate the loop step. CEL_RETURN_IF_ERROR(loop_step_->Evaluate(frame, accu_slot->value, accu_slot->attribute)); From 7768c87fe4835909f56771fd69d61156e0b85314 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 10 Dec 2024 12:28:15 -0800 Subject: [PATCH 058/180] Make check for optimizeable list append more strict. Updated to filter hand-rolled map-like comprehensions. PiperOrigin-RevId: 704806367 --- eval/compiler/flat_expr_builder.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index c687db003..e5d4a8c25 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -270,7 +270,8 @@ bool IsOptimizableListAppend( comprehension->result().ident_expr().name() != accu_var) { return false; } - if (!comprehension->accu_init().has_list_expr()) { + if (!comprehension->accu_init().has_list_expr() || + !comprehension->accu_init().list_expr().elements().empty()) { return false; } From 8a5e0745bed02ca4bb85656f183c49db17fe68e9 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Tue, 17 Dec 2024 14:52:10 -0800 Subject: [PATCH 059/180] Breaking change: Remove `MutableRepeatedFieldRef::Reserve()` (reflection) An upcoming performance improvement in RepeatedPtrField is incompatible with this API. The improvement is projected to accelerate repeated access to the elements of `RepeatedPtrField`, in particular and especially sequential access. PA: https://protobuf.dev/news/2024-12-13/ PiperOrigin-RevId: 707260596 --- common/values/parsed_map_field_value.cc | 1 - common/values/parsed_repeated_field_value.cc | 1 - internal/json.cc | 34 -------------------- internal/well_known_types.cc | 9 ------ internal/well_known_types.h | 10 ------ 5 files changed, 55 deletions(-) diff --git a/common/values/parsed_map_field_value.cc b/common/values/parsed_map_field_value.cc index 6a0e3cc5d..a9d04f039 100644 --- a/common/values/parsed_map_field_value.cc +++ b/common/values/parsed_map_field_value.cc @@ -165,7 +165,6 @@ ParsedMapFieldValue ParsedMapFieldValue::Clone(Allocator<> allocator) const { auto cloned_field = cloned->GetReflection()->GetMutableRepeatedFieldRef( cel::to_address(cloned), field_); - cloned_field.Reserve(field.size()); cloned_field.CopyFrom(field); return ParsedMapFieldValue(std::move(cloned), field_); } diff --git a/common/values/parsed_repeated_field_value.cc b/common/values/parsed_repeated_field_value.cc index e66eba49c..a288bba88 100644 --- a/common/values/parsed_repeated_field_value.cc +++ b/common/values/parsed_repeated_field_value.cc @@ -160,7 +160,6 @@ ParsedRepeatedFieldValue ParsedRepeatedFieldValue::Clone( auto cloned_field = cloned->GetReflection()->GetMutableRepeatedFieldRef( cel::to_address(cloned), field_); - cloned_field.Reserve(field.size()); cloned_field.CopyFrom(field); return ParsedRepeatedFieldValue(std::move(cloned), field_); } diff --git a/internal/json.cc b/internal/json.cc index f557a5491..44e1f4f63 100644 --- a/internal/json.cc +++ b/internal/json.cc @@ -806,7 +806,6 @@ class MessageToJsonState { if (size == 0) { return absl::OkStatus(); } - ReserveValues(result, size); CEL_ASSIGN_OR_RETURN(const auto to_value, GetRepeatedFieldToValue(field)); for (int index = 0; index < size; ++index) { CEL_RETURN_IF_ERROR((this->*to_value)(reflection, message, field, index, @@ -995,9 +994,6 @@ class MessageToJsonState { virtual absl::Nonnull MutableStructValue( absl::Nonnull message) const = 0; - virtual void ReserveValues(absl::Nonnull message, - int capacity) const = 0; - virtual absl::Nonnull AddValues( absl::Nonnull message) const = 0; @@ -1076,13 +1072,6 @@ class GeneratedMessageToJsonState final : public MessageToJsonState { google::protobuf::DownCastMessage(message)); } - void ReserveValues(absl::Nonnull message, - int capacity) const override { - ListValueReflection::ReserveValues( - google::protobuf::DownCastMessage(message), - capacity); - } - absl::Nonnull AddValues( absl::Nonnull message) const override { return ListValueReflection::AddValues( @@ -1163,12 +1152,6 @@ class DynamicMessageToJsonState final : public MessageToJsonState { google::protobuf::DownCastMessage(message)); } - void ReserveValues(absl::Nonnull message, - int capacity) const override { - reflection_.ListValue().ReserveValues( - google::protobuf::DownCastMessage(message), capacity); - } - absl::Nonnull AddValues( absl::Nonnull message) const override { return reflection_.ListValue().AddValues( @@ -2172,9 +2155,6 @@ class JsonMutator { virtual absl::Nonnull MutableListValue( absl::Nonnull message) const = 0; - virtual void ReserveValues(absl::Nonnull message, - int capacity) const = 0; - virtual absl::Nonnull AddValues( absl::Nonnull message) const = 0; @@ -2223,13 +2203,6 @@ class GeneratedJsonMutator final : public JsonMutator { google::protobuf::DownCastMessage(message)); } - void ReserveValues(absl::Nonnull message, - int capacity) const override { - ListValueReflection::ReserveValues( - google::protobuf::DownCastMessage(message), - capacity); - } - absl::Nonnull AddValues( absl::Nonnull message) const override { return ListValueReflection::AddValues( @@ -2312,12 +2285,6 @@ class DynamicJsonMutator final : public JsonMutator { google::protobuf::DownCastMessage(message)); } - void ReserveValues(absl::Nonnull message, - int capacity) const override { - list_value_reflection_.ReserveValues( - google::protobuf::DownCastMessage(message), capacity); - } - absl::Nonnull AddValues( absl::Nonnull message) const override { return list_value_reflection_.AddValues( @@ -2379,7 +2346,6 @@ class NativeJsonToProtoJsonState { absl::Status ToProtoJsonList(const JsonArray& json, absl::Nonnull proto) { - mutator_->ReserveValues(proto, static_cast(json.size())); for (const auto& element : json) { CEL_RETURN_IF_ERROR(ToProtoJson(element, mutator_->AddValues(proto))); } diff --git a/internal/well_known_types.cc b/internal/well_known_types.cc index 2e9cae6c6..085269ecd 100644 --- a/internal/well_known_types.cc +++ b/internal/well_known_types.cc @@ -1534,15 +1534,6 @@ absl::Nonnull ListValueReflection::AddValues( return message->GetReflection()->AddMessage(message, values_field_); } -void ListValueReflection::ReserveValues(absl::Nonnull message, - int capacity) const { - ABSL_DCHECK(IsInitialized()); - ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); - if (capacity > 0) { - MutableValues(message).Reserve(capacity); - } -} - absl::StatusOr GetListValueReflection( absl::Nonnull descriptor) { ListValueReflection reflection; diff --git a/internal/well_known_types.h b/internal/well_known_types.h index 2cef32a96..4c3e70d20 100644 --- a/internal/well_known_types.h +++ b/internal/well_known_types.h @@ -1066,13 +1066,6 @@ class ListValueReflection final { return message->add_values(); } - static void ReserveValues(absl::Nonnull message, - int capacity) { - if (capacity > 0) { - message->mutable_values()->Reserve(capacity); - } - } - absl::Status Initialize(absl::Nonnull pool); absl::Status Initialize(absl::Nonnull descriptor); @@ -1110,9 +1103,6 @@ class ListValueReflection final { absl::Nonnull AddValues( absl::Nonnull message) const; - void ReserveValues(absl::Nonnull message, - int capacity) const; - private: absl::Nullable descriptor_ = nullptr; absl::Nullable values_field_ = nullptr; From 22b8dd4484ed9a153c3a5cfc4c33f456d565dac4 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Tue, 17 Dec 2024 17:13:40 -0800 Subject: [PATCH 060/180] Breaking change: Remove `MutableRepeatedFieldRef::Reserve()` (reflection) An upcoming performance improvement in RepeatedPtrField is incompatible with this API. The improvement is projected to accelerate repeated access to the elements of `RepeatedPtrField`, in particular and especially sequential access. PA: https://protobuf.dev/news/2024-12-13/ PiperOrigin-RevId: 707309375 --- common/values/parsed_map_field_value.cc | 1 + common/values/parsed_repeated_field_value.cc | 1 + internal/json.cc | 34 ++++++++++++++++++++ internal/well_known_types.cc | 9 ++++++ internal/well_known_types.h | 10 ++++++ 5 files changed, 55 insertions(+) diff --git a/common/values/parsed_map_field_value.cc b/common/values/parsed_map_field_value.cc index a9d04f039..6a0e3cc5d 100644 --- a/common/values/parsed_map_field_value.cc +++ b/common/values/parsed_map_field_value.cc @@ -165,6 +165,7 @@ ParsedMapFieldValue ParsedMapFieldValue::Clone(Allocator<> allocator) const { auto cloned_field = cloned->GetReflection()->GetMutableRepeatedFieldRef( cel::to_address(cloned), field_); + cloned_field.Reserve(field.size()); cloned_field.CopyFrom(field); return ParsedMapFieldValue(std::move(cloned), field_); } diff --git a/common/values/parsed_repeated_field_value.cc b/common/values/parsed_repeated_field_value.cc index a288bba88..e66eba49c 100644 --- a/common/values/parsed_repeated_field_value.cc +++ b/common/values/parsed_repeated_field_value.cc @@ -160,6 +160,7 @@ ParsedRepeatedFieldValue ParsedRepeatedFieldValue::Clone( auto cloned_field = cloned->GetReflection()->GetMutableRepeatedFieldRef( cel::to_address(cloned), field_); + cloned_field.Reserve(field.size()); cloned_field.CopyFrom(field); return ParsedRepeatedFieldValue(std::move(cloned), field_); } diff --git a/internal/json.cc b/internal/json.cc index 44e1f4f63..f557a5491 100644 --- a/internal/json.cc +++ b/internal/json.cc @@ -806,6 +806,7 @@ class MessageToJsonState { if (size == 0) { return absl::OkStatus(); } + ReserveValues(result, size); CEL_ASSIGN_OR_RETURN(const auto to_value, GetRepeatedFieldToValue(field)); for (int index = 0; index < size; ++index) { CEL_RETURN_IF_ERROR((this->*to_value)(reflection, message, field, index, @@ -994,6 +995,9 @@ class MessageToJsonState { virtual absl::Nonnull MutableStructValue( absl::Nonnull message) const = 0; + virtual void ReserveValues(absl::Nonnull message, + int capacity) const = 0; + virtual absl::Nonnull AddValues( absl::Nonnull message) const = 0; @@ -1072,6 +1076,13 @@ class GeneratedMessageToJsonState final : public MessageToJsonState { google::protobuf::DownCastMessage(message)); } + void ReserveValues(absl::Nonnull message, + int capacity) const override { + ListValueReflection::ReserveValues( + google::protobuf::DownCastMessage(message), + capacity); + } + absl::Nonnull AddValues( absl::Nonnull message) const override { return ListValueReflection::AddValues( @@ -1152,6 +1163,12 @@ class DynamicMessageToJsonState final : public MessageToJsonState { google::protobuf::DownCastMessage(message)); } + void ReserveValues(absl::Nonnull message, + int capacity) const override { + reflection_.ListValue().ReserveValues( + google::protobuf::DownCastMessage(message), capacity); + } + absl::Nonnull AddValues( absl::Nonnull message) const override { return reflection_.ListValue().AddValues( @@ -2155,6 +2172,9 @@ class JsonMutator { virtual absl::Nonnull MutableListValue( absl::Nonnull message) const = 0; + virtual void ReserveValues(absl::Nonnull message, + int capacity) const = 0; + virtual absl::Nonnull AddValues( absl::Nonnull message) const = 0; @@ -2203,6 +2223,13 @@ class GeneratedJsonMutator final : public JsonMutator { google::protobuf::DownCastMessage(message)); } + void ReserveValues(absl::Nonnull message, + int capacity) const override { + ListValueReflection::ReserveValues( + google::protobuf::DownCastMessage(message), + capacity); + } + absl::Nonnull AddValues( absl::Nonnull message) const override { return ListValueReflection::AddValues( @@ -2285,6 +2312,12 @@ class DynamicJsonMutator final : public JsonMutator { google::protobuf::DownCastMessage(message)); } + void ReserveValues(absl::Nonnull message, + int capacity) const override { + list_value_reflection_.ReserveValues( + google::protobuf::DownCastMessage(message), capacity); + } + absl::Nonnull AddValues( absl::Nonnull message) const override { return list_value_reflection_.AddValues( @@ -2346,6 +2379,7 @@ class NativeJsonToProtoJsonState { absl::Status ToProtoJsonList(const JsonArray& json, absl::Nonnull proto) { + mutator_->ReserveValues(proto, static_cast(json.size())); for (const auto& element : json) { CEL_RETURN_IF_ERROR(ToProtoJson(element, mutator_->AddValues(proto))); } diff --git a/internal/well_known_types.cc b/internal/well_known_types.cc index 085269ecd..2e9cae6c6 100644 --- a/internal/well_known_types.cc +++ b/internal/well_known_types.cc @@ -1534,6 +1534,15 @@ absl::Nonnull ListValueReflection::AddValues( return message->GetReflection()->AddMessage(message, values_field_); } +void ListValueReflection::ReserveValues(absl::Nonnull message, + int capacity) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); + if (capacity > 0) { + MutableValues(message).Reserve(capacity); + } +} + absl::StatusOr GetListValueReflection( absl::Nonnull descriptor) { ListValueReflection reflection; diff --git a/internal/well_known_types.h b/internal/well_known_types.h index 4c3e70d20..2cef32a96 100644 --- a/internal/well_known_types.h +++ b/internal/well_known_types.h @@ -1066,6 +1066,13 @@ class ListValueReflection final { return message->add_values(); } + static void ReserveValues(absl::Nonnull message, + int capacity) { + if (capacity > 0) { + message->mutable_values()->Reserve(capacity); + } + } + absl::Status Initialize(absl::Nonnull pool); absl::Status Initialize(absl::Nonnull descriptor); @@ -1103,6 +1110,9 @@ class ListValueReflection final { absl::Nonnull AddValues( absl::Nonnull message) const; + void ReserveValues(absl::Nonnull message, + int capacity) const; + private: absl::Nullable descriptor_ = nullptr; absl::Nullable values_field_ = nullptr; From 72871f29e9e5e34be24fd382a4d02cebaa5aac01 Mon Sep 17 00:00:00 2001 From: Antoine Pietri Date: Fri, 20 Dec 2024 08:58:09 -0800 Subject: [PATCH 061/180] Add operator<() for timestamp, duration and bytes to enable sorting. PiperOrigin-RevId: 708337645 --- common/values/bytes_value.h | 4 ++++ common/values/bytes_value_test.cc | 6 ++++++ common/values/duration_value.h | 4 ++++ common/values/duration_value_test.cc | 8 ++++++++ common/values/timestamp_value.h | 4 ++++ common/values/timestamp_value_test.cc | 9 +++++++++ 6 files changed, 35 insertions(+) diff --git a/common/values/bytes_value.h b/common/values/bytes_value.h index e8439ee69..0dcb81dab 100644 --- a/common/values/bytes_value.h +++ b/common/values/bytes_value.h @@ -159,6 +159,10 @@ class BytesValue final { absl::Cord ToCord() const { return NativeCord(); } + friend bool operator<(const BytesValue& lhs, const BytesValue& rhs) { + return lhs.value_ < rhs.value_; + } + private: friend class common_internal::TrivialValue; friend const common_internal::SharedByteString& diff --git a/common/values/bytes_value_test.cc b/common/values/bytes_value_test.cc index fbd5293ad..310509207 100644 --- a/common/values/bytes_value_test.cc +++ b/common/values/bytes_value_test.cc @@ -113,6 +113,12 @@ TEST_P(BytesValueTest, StringViewInequality) { // NOLINTEND(readability/check) } +TEST_P(BytesValueTest, Comparison) { + EXPECT_LT(BytesValue("bar"), BytesValue("foo")); + EXPECT_FALSE(BytesValue("foo") < BytesValue("foo")); + EXPECT_FALSE(BytesValue("foo") < BytesValue("bar")); +} + INSTANTIATE_TEST_SUITE_P( BytesValueTest, BytesValueTest, ::testing::Combine(::testing::Values(MemoryManagement::kPooling, diff --git a/common/values/duration_value.h b/common/values/duration_value.h index 41cb0c99c..62f16ddc6 100644 --- a/common/values/duration_value.h +++ b/common/values/duration_value.h @@ -84,6 +84,10 @@ class DurationValue final { swap(lhs.value_, rhs.value_); } + friend bool operator<(const DurationValue& lhs, const DurationValue& rhs) { + return lhs.value_ < rhs.value_; + } + private: absl::Duration value_ = absl::ZeroDuration(); }; diff --git a/common/values/duration_value_test.cc b/common/values/duration_value_test.cc index efce76a61..e7c722abe 100644 --- a/common/values/duration_value_test.cc +++ b/common/values/duration_value_test.cc @@ -90,6 +90,14 @@ TEST_P(DurationValueTest, Equality) { DurationValue(absl::Seconds(1))); } +TEST_P(DurationValueTest, Comparison) { + EXPECT_LT(DurationValue(absl::ZeroDuration()), absl::Seconds(1)); + EXPECT_FALSE(DurationValue(absl::Seconds(1)) < + DurationValue(absl::Seconds(1))); + EXPECT_FALSE(DurationValue(absl::Seconds(2)) < + DurationValue(absl::Seconds(1))); +} + INSTANTIATE_TEST_SUITE_P( DurationValueTest, DurationValueTest, ::testing::Combine(::testing::Values(MemoryManagement::kPooling, diff --git a/common/values/timestamp_value.h b/common/values/timestamp_value.h index bd2c7183e..9b45a3279 100644 --- a/common/values/timestamp_value.h +++ b/common/values/timestamp_value.h @@ -84,6 +84,10 @@ class TimestampValue final { swap(lhs.value_, rhs.value_); } + friend bool operator<(const TimestampValue& lhs, const TimestampValue& rhs) { + return lhs.value_ < rhs.value_; + } + private: absl::Time value_ = absl::UnixEpoch(); }; diff --git a/common/values/timestamp_value_test.cc b/common/values/timestamp_value_test.cc index 603060969..d45aaf2a7 100644 --- a/common/values/timestamp_value_test.cc +++ b/common/values/timestamp_value_test.cc @@ -98,6 +98,15 @@ TEST_P(TimestampValueTest, Equality) { TimestampValue(absl::UnixEpoch() + absl::Seconds(1))); } +TEST_P(TimestampValueTest, Comparison) { + EXPECT_LT(TimestampValue(absl::UnixEpoch()), + TimestampValue(absl::UnixEpoch() + absl::Seconds(1))); + EXPECT_FALSE(TimestampValue(absl::UnixEpoch() + absl::Seconds(1)) < + TimestampValue(absl::UnixEpoch() + absl::Seconds(1))); + EXPECT_FALSE(TimestampValue(absl::UnixEpoch() + absl::Seconds(2)) < + TimestampValue(absl::UnixEpoch() + absl::Seconds(1))); +} + INSTANTIATE_TEST_SUITE_P( TimestampValueTest, TimestampValueTest, ::testing::Combine(::testing::Values(MemoryManagement::kPooling, From 8ce99ed8cbf432686dc6ee79382fb5a20e37a8ad Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Fri, 20 Dec 2024 11:35:01 -0800 Subject: [PATCH 062/180] Add option for changing the comprehension accumulator variable used by standard macros. PiperOrigin-RevId: 708381185 --- common/expr_factory.h | 11 +- parser/BUILD | 1 + parser/macro.cc | 12 +- parser/macro_expr_factory.h | 6 +- parser/macro_expr_factory_test.cc | 2 +- parser/options.h | 3 + parser/parser.cc | 20 ++- parser/parser_test.cc | 267 ++++++++++++++++++++++++++++++ 8 files changed, 305 insertions(+), 17 deletions(-) diff --git a/common/expr_factory.h b/common/expr_factory.h index dd8e6ed25..c8a9b831f 100644 --- a/common/expr_factory.h +++ b/common/expr_factory.h @@ -179,9 +179,9 @@ class ExprFactory { return expr; } - Expr NewAccuIdent(ExprId id) { - return NewIdent(id, kAccumulatorVariableName); - } + absl::string_view AccuVarName() { return accu_var_; } + + Expr NewAccuIdent(ExprId id) { return NewIdent(id, AccuVarName()); } template ::value>, @@ -356,7 +356,10 @@ class ExprFactory { friend class MacroExprFactory; friend class ParserMacroExprFactory; - ExprFactory() = default; + ExprFactory() : accu_var_(kAccumulatorVariableName) {} + explicit ExprFactory(absl::string_view accu_var) : accu_var_(accu_var) {} + + std::string accu_var_; }; } // namespace cel diff --git a/parser/BUILD b/parser/BUILD index 11cf95f85..aeb65342a 100644 --- a/parser/BUILD +++ b/parser/BUILD @@ -38,6 +38,7 @@ cc_library( "//base/ast_internal:expr", "//common:ast", "//common:constant", + "//common:expr", "//common:expr_factory", "//common:operators", "//common:source", diff --git a/parser/macro.cc b/parser/macro.cc index b11dca5db..db2fef502 100644 --- a/parser/macro.cc +++ b/parser/macro.cc @@ -103,7 +103,7 @@ absl::optional ExpandAllMacro(MacroExprFactory& factory, Expr& target, std::move(args[1])); auto result = factory.NewAccuIdent(); return factory.NewComprehension(args[0].ident_expr().name(), - std::move(target), kAccumulatorVariableName, + std::move(target), factory.AccuVarName(), std::move(init), std::move(condition), std::move(step), std::move(result)); } @@ -136,7 +136,7 @@ absl::optional ExpandExistsMacro(MacroExprFactory& factory, Expr& target, std::move(args[1])); auto result = factory.NewAccuIdent(); return factory.NewComprehension(args[0].ident_expr().name(), - std::move(target), kAccumulatorVariableName, + std::move(target), factory.AccuVarName(), std::move(init), std::move(condition), std::move(step), std::move(result)); } @@ -172,7 +172,7 @@ absl::optional ExpandExistsOneMacro(MacroExprFactory& factory, auto result = factory.NewCall(CelOperator::EQUALS, factory.NewAccuIdent(), factory.NewIntConst(1)); return factory.NewComprehension(args[0].ident_expr().name(), - std::move(target), kAccumulatorVariableName, + std::move(target), factory.AccuVarName(), std::move(init), std::move(condition), std::move(step), std::move(result)); } @@ -204,7 +204,7 @@ absl::optional ExpandMap2Macro(MacroExprFactory& factory, Expr& target, CelOperator::ADD, factory.NewAccuIdent(), factory.NewList(factory.NewListElement(std::move(args[1])))); return factory.NewComprehension(args[0].ident_expr().name(), - std::move(target), kAccumulatorVariableName, + std::move(target), factory.AccuVarName(), std::move(init), std::move(condition), std::move(step), factory.NewAccuIdent()); } @@ -237,7 +237,7 @@ absl::optional ExpandMap3Macro(MacroExprFactory& factory, Expr& target, step = factory.NewCall(CelOperator::CONDITIONAL, std::move(args[1]), std::move(step), factory.NewAccuIdent()); return factory.NewComprehension(args[0].ident_expr().name(), - std::move(target), kAccumulatorVariableName, + std::move(target), factory.AccuVarName(), std::move(init), std::move(condition), std::move(step), factory.NewAccuIdent()); } @@ -272,7 +272,7 @@ absl::optional ExpandFilterMacro(MacroExprFactory& factory, Expr& target, step = factory.NewCall(CelOperator::CONDITIONAL, std::move(args[1]), std::move(step), factory.NewAccuIdent()); return factory.NewComprehension(std::move(name), std::move(target), - kAccumulatorVariableName, std::move(init), + factory.AccuVarName(), std::move(init), std::move(condition), std::move(step), factory.NewAccuIdent()); } diff --git a/parser/macro_expr_factory.h b/parser/macro_expr_factory.h index 291bccdb0..83e322d4e 100644 --- a/parser/macro_expr_factory.h +++ b/parser/macro_expr_factory.h @@ -107,6 +107,8 @@ class MacroExprFactory : protected ExprFactory { return NewIdent(NextId(), std::move(name)); } + absl::string_view AccuVarName() { return ExprFactory::AccuVarName(); } + ABSL_MUST_USE_RESULT Expr NewAccuIdent() { return NewAccuIdent(NextId()); } template